Skip to content
Snippets Groups Projects
Commit 0cd0283c authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Support symbolic index assignments

parent e0a3abbd
No related branches found
No related tags found
No related merge requests found
...@@ -122,7 +122,7 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix): ...@@ -122,7 +122,7 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
elif diff_read_field.index_dimensions == 1: elif diff_read_field.index_dimensions == 1:
diff_read_field_sum = [0] * diff_read_field.index_shape[0] diff_read_field_sum = {}
for ra in read_field_accesses: for ra in read_field_accesses:
if ra.field != forward_read_field: if ra.field != forward_read_field:
continue # ignore constant fields in differentiation continue # ignore constant fields in differentiation
...@@ -130,19 +130,19 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix): ...@@ -130,19 +130,19 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
# TF-MAD requires flipped stencils # TF-MAD requires flipped stencils
inverted_offset = tuple(-v - w for v, inverted_offset = tuple(-v - w for v,
w in zip(ra.offsets, write_field_accesses[0].offsets)) w in zip(ra.offsets, write_field_accesses[0].offsets))
diff_read_field_sum[ra.index[0] diff_read_field_sum[ra.index[0]] = diff_read_field_sum.get(
] += sp.diff(forward_assignment.rhs, ra) * diff_write_field[inverted_offset] ra.index[0], 0) + sp.diff(forward_assignment.rhs, ra) * diff_write_field[inverted_offset]
for index in range(diff_read_field.index_shape[0]): for index in diff_read_field_sum.keys():
if self.time_constant_fields is not None and forward_read_field in self._time_constant_fields: if self.time_constant_fields is not None and forward_read_field in self._time_constant_fields:
# Accumulate in case of time_constant_fields # Accumulate in case of time_constant_fields
assignment = ps.Assignment( assignment = ps.Assignment(
diff_read_field.center_vector[index], diff_read_field.center.at_index(index),
diff_read_field.center_vector[index] + diff_read_field_sum[index]) diff_read_field.center.at_index(index) + diff_read_field_sum[index])
else: else:
# If time dependent, we just need to assign the sum for the current time step # If time dependent, we just need to assign the sum for the current time step
assignment = ps.Assignment( assignment = ps.Assignment(
diff_read_field.center_vector[index], diff_read_field_sum[index]) diff_read_field.center.at_index(index), diff_read_field_sum[index])
if assignment.lhs in backward_assignment_dict: if assignment.lhs in backward_assignment_dict:
backward_assignment_dict[assignment.lhs].append(assignment.rhs) backward_assignment_dict[assignment.lhs].append(assignment.rhs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment