From 0cd0283c6901f932d4f3196ea7c2e355d691188b Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Fri, 11 Sep 2020 15:21:26 +0200 Subject: [PATCH] Support symbolic index assignments --- src/pystencils_autodiff/_autodiff.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/pystencils_autodiff/_autodiff.py b/src/pystencils_autodiff/_autodiff.py index a003cd5..d746679 100644 --- a/src/pystencils_autodiff/_autodiff.py +++ b/src/pystencils_autodiff/_autodiff.py @@ -122,7 +122,7 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix): elif diff_read_field.index_dimensions == 1: - diff_read_field_sum = [0] * diff_read_field.index_shape[0] + diff_read_field_sum = {} for ra in read_field_accesses: if ra.field != forward_read_field: continue # ignore constant fields in differentiation @@ -130,19 +130,19 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix): # TF-MAD requires flipped stencils inverted_offset = tuple(-v - w for v, w in zip(ra.offsets, write_field_accesses[0].offsets)) - diff_read_field_sum[ra.index[0] - ] += sp.diff(forward_assignment.rhs, ra) * diff_write_field[inverted_offset] + diff_read_field_sum[ra.index[0]] = diff_read_field_sum.get( + ra.index[0], 0) + sp.diff(forward_assignment.rhs, ra) * diff_write_field[inverted_offset] - for index in range(diff_read_field.index_shape[0]): + for index in diff_read_field_sum.keys(): if self.time_constant_fields is not None and forward_read_field in self._time_constant_fields: # Accumulate in case of time_constant_fields assignment = ps.Assignment( - diff_read_field.center_vector[index], - diff_read_field.center_vector[index] + diff_read_field_sum[index]) + diff_read_field.center.at_index(index), + diff_read_field.center.at_index(index) + diff_read_field_sum[index]) else: # If time dependent, we just need to assign the sum for the current time step assignment = ps.Assignment( - diff_read_field.center_vector[index], diff_read_field_sum[index]) + diff_read_field.center.at_index(index), diff_read_field_sum[index]) if assignment.lhs in backward_assignment_dict: backward_assignment_dict[assignment.lhs].append(assignment.rhs) -- GitLab