From 8b7bdb9645645ccec4037b98832947da5ab65d02 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Mon, 2 Dec 2019 14:00:47 +0100 Subject: [PATCH] Make all fields == not time_constant_fields default --- src/pystencils_autodiff/_autodiff.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/pystencils_autodiff/_autodiff.py b/src/pystencils_autodiff/_autodiff.py index a15afbd..f072e5e 100644 --- a/src/pystencils_autodiff/_autodiff.py +++ b/src/pystencils_autodiff/_autodiff.py @@ -85,7 +85,7 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix): w in zip(ra.offsets, forward_assignment.lhs.offsets)) diff_read_field_sum += sp.diff(forward_assignment.rhs, ra) * \ diff_write_field[inverted_offset](*diff_write_index) - if forward_read_field in self._time_constant_fields: + if forward_read_field in self._time_constant_fields and self.time_constant_fields is not None: # Accumulate in case of time_constant_fields assignment = ps.Assignment( diff_read_field.center(), diff_read_field.center() + diff_read_field_sum) @@ -202,7 +202,7 @@ Backward: forward_assignments: List[ps.Assignment], op_name: str = "autodiffop", boundary_handling: AutoDiffBoundaryHandling = None, - time_constant_fields: List[ps.Field] = [], + time_constant_fields: List[ps.Field] = None, constant_fields: List[ps.Field] = [], diff_fields_prefix='diff', # TODO: remove! do_common_subexpression_elimination=True, @@ -241,6 +241,7 @@ Backward: self._backward_kernel_gpu = None self._do_common_subexpression_elimination = do_common_subexpression_elimination self._boundary_handling = boundary_handling + if backward_assignments: self._forward_assignments = forward_assignments self._forward_read_accesses = None @@ -367,7 +368,7 @@ Backward: rhs = rhs[0, 0] # if field is constant over we time we can accumulate in assignment - if read_access.field in self._time_constant_fields: + if read_access.field in self._time_constant_fields and self.time_constant_fields is not None: backward_assignments.append(ps.Assignment(lhs, lhs + rhs)) else: backward_assignments.append(ps.Assignment(lhs, rhs)) -- GitLab