diff --git a/src/pystencils_autodiff/_autodiff.py b/src/pystencils_autodiff/_autodiff.py index a15afbd01a194df7f5e2597b485b5ac205b0cd1c..f072e5e0acadbe7a5b4f4914a72fc478a9084f12 100644 --- a/src/pystencils_autodiff/_autodiff.py +++ b/src/pystencils_autodiff/_autodiff.py @@ -85,7 +85,7 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix): w in zip(ra.offsets, forward_assignment.lhs.offsets)) diff_read_field_sum += sp.diff(forward_assignment.rhs, ra) * \ diff_write_field[inverted_offset](*diff_write_index) - if forward_read_field in self._time_constant_fields: + if forward_read_field in self._time_constant_fields and self.time_constant_fields is not None: # Accumulate in case of time_constant_fields assignment = ps.Assignment( diff_read_field.center(), diff_read_field.center() + diff_read_field_sum) @@ -202,7 +202,7 @@ Backward: forward_assignments: List[ps.Assignment], op_name: str = "autodiffop", boundary_handling: AutoDiffBoundaryHandling = None, - time_constant_fields: List[ps.Field] = [], + time_constant_fields: List[ps.Field] = None, constant_fields: List[ps.Field] = [], diff_fields_prefix='diff', # TODO: remove! do_common_subexpression_elimination=True, @@ -241,6 +241,7 @@ Backward: self._backward_kernel_gpu = None self._do_common_subexpression_elimination = do_common_subexpression_elimination self._boundary_handling = boundary_handling + if backward_assignments: self._forward_assignments = forward_assignments self._forward_read_accesses = None @@ -367,7 +368,7 @@ Backward: rhs = rhs[0, 0] # if field is constant over we time we can accumulate in assignment - if read_access.field in self._time_constant_fields: + if read_access.field in self._time_constant_fields and self.time_constant_fields is not None: backward_assignments.append(ps.Assignment(lhs, lhs + rhs)) else: backward_assignments.append(ps.Assignment(lhs, rhs))