From 0cd0283c6901f932d4f3196ea7c2e355d691188b Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Fri, 11 Sep 2020 15:21:26 +0200
Subject: [PATCH] Support symbolic index assignments

---
 src/pystencils_autodiff/_autodiff.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/pystencils_autodiff/_autodiff.py b/src/pystencils_autodiff/_autodiff.py
index a003cd5..d746679 100644
--- a/src/pystencils_autodiff/_autodiff.py
+++ b/src/pystencils_autodiff/_autodiff.py
@@ -122,7 +122,7 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
 
             elif diff_read_field.index_dimensions == 1:
 
-                diff_read_field_sum = [0] * diff_read_field.index_shape[0]
+                diff_read_field_sum = {}
                 for ra in read_field_accesses:
                     if ra.field != forward_read_field:
                         continue  # ignore constant fields in differentiation
@@ -130,19 +130,19 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
                     # TF-MAD requires flipped stencils
                     inverted_offset = tuple(-v - w for v,
                                             w in zip(ra.offsets, write_field_accesses[0].offsets))
-                    diff_read_field_sum[ra.index[0]
-                                        ] += sp.diff(forward_assignment.rhs, ra) * diff_write_field[inverted_offset]
+                    diff_read_field_sum[ra.index[0]] = diff_read_field_sum.get(
+                        ra.index[0], 0) + sp.diff(forward_assignment.rhs, ra) * diff_write_field[inverted_offset]
 
-                for index in range(diff_read_field.index_shape[0]):
+                for index in diff_read_field_sum.keys():
                     if self.time_constant_fields is not None and forward_read_field in self._time_constant_fields:
                         # Accumulate in case of time_constant_fields
                         assignment = ps.Assignment(
-                            diff_read_field.center_vector[index],
-                            diff_read_field.center_vector[index] + diff_read_field_sum[index])
+                            diff_read_field.center.at_index(index),
+                            diff_read_field.center.at_index(index) + diff_read_field_sum[index])
                     else:
                         # If time dependent, we just need to assign the sum for the current time step
                         assignment = ps.Assignment(
-                            diff_read_field.center_vector[index], diff_read_field_sum[index])
+                            diff_read_field.center.at_index(index), diff_read_field_sum[index])
 
                 if assignment.lhs in backward_assignment_dict:
                     backward_assignment_dict[assignment.lhs].append(assignment.rhs)
-- 
GitLab