Skip to content
Snippets Groups Projects
Commit 8b7bdb96 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Make all fields == not time_constant_fields default

parent 1e141dfb
No related branches found
No related tags found
No related merge requests found
Pipeline #20179 failed
...@@ -85,7 +85,7 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix): ...@@ -85,7 +85,7 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
w in zip(ra.offsets, forward_assignment.lhs.offsets)) w in zip(ra.offsets, forward_assignment.lhs.offsets))
diff_read_field_sum += sp.diff(forward_assignment.rhs, ra) * \ diff_read_field_sum += sp.diff(forward_assignment.rhs, ra) * \
diff_write_field[inverted_offset](*diff_write_index) diff_write_field[inverted_offset](*diff_write_index)
if forward_read_field in self._time_constant_fields: if forward_read_field in self._time_constant_fields and self.time_constant_fields is not None:
# Accumulate in case of time_constant_fields # Accumulate in case of time_constant_fields
assignment = ps.Assignment( assignment = ps.Assignment(
diff_read_field.center(), diff_read_field.center() + diff_read_field_sum) diff_read_field.center(), diff_read_field.center() + diff_read_field_sum)
...@@ -202,7 +202,7 @@ Backward: ...@@ -202,7 +202,7 @@ Backward:
forward_assignments: List[ps.Assignment], forward_assignments: List[ps.Assignment],
op_name: str = "autodiffop", op_name: str = "autodiffop",
boundary_handling: AutoDiffBoundaryHandling = None, boundary_handling: AutoDiffBoundaryHandling = None,
time_constant_fields: List[ps.Field] = [], time_constant_fields: List[ps.Field] = None,
constant_fields: List[ps.Field] = [], constant_fields: List[ps.Field] = [],
diff_fields_prefix='diff', # TODO: remove! diff_fields_prefix='diff', # TODO: remove!
do_common_subexpression_elimination=True, do_common_subexpression_elimination=True,
...@@ -241,6 +241,7 @@ Backward: ...@@ -241,6 +241,7 @@ Backward:
self._backward_kernel_gpu = None self._backward_kernel_gpu = None
self._do_common_subexpression_elimination = do_common_subexpression_elimination self._do_common_subexpression_elimination = do_common_subexpression_elimination
self._boundary_handling = boundary_handling self._boundary_handling = boundary_handling
if backward_assignments: if backward_assignments:
self._forward_assignments = forward_assignments self._forward_assignments = forward_assignments
self._forward_read_accesses = None self._forward_read_accesses = None
...@@ -367,7 +368,7 @@ Backward: ...@@ -367,7 +368,7 @@ Backward:
rhs = rhs[0, 0] rhs = rhs[0, 0]
# if field is constant over we time we can accumulate in assignment # if field is constant over we time we can accumulate in assignment
if read_access.field in self._time_constant_fields: if read_access.field in self._time_constant_fields and self.time_constant_fields is not None:
backward_assignments.append(ps.Assignment(lhs, lhs + rhs)) backward_assignments.append(ps.Assignment(lhs, lhs + rhs))
else: else:
backward_assignments.append(ps.Assignment(lhs, rhs)) backward_assignments.append(ps.Assignment(lhs, rhs))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment