From 8b7bdb9645645ccec4037b98832947da5ab65d02 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Mon, 2 Dec 2019 14:00:47 +0100
Subject: [PATCH] Make all fields == not time_constant_fields default

---
 src/pystencils_autodiff/_autodiff.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/pystencils_autodiff/_autodiff.py b/src/pystencils_autodiff/_autodiff.py
index a15afbd..f072e5e 100644
--- a/src/pystencils_autodiff/_autodiff.py
+++ b/src/pystencils_autodiff/_autodiff.py
@@ -85,7 +85,7 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
                                             w in zip(ra.offsets, forward_assignment.lhs.offsets))
                     diff_read_field_sum += sp.diff(forward_assignment.rhs, ra) * \
                         diff_write_field[inverted_offset](*diff_write_index)
-                if forward_read_field in self._time_constant_fields:
+                if forward_read_field in self._time_constant_fields and self.time_constant_fields is not None:
                     # Accumulate in case of time_constant_fields
                     assignment = ps.Assignment(
                         diff_read_field.center(), diff_read_field.center() + diff_read_field_sum)
@@ -202,7 +202,7 @@ Backward:
                  forward_assignments: List[ps.Assignment],
                  op_name: str = "autodiffop",
                  boundary_handling: AutoDiffBoundaryHandling = None,
-                 time_constant_fields: List[ps.Field] = [],
+                 time_constant_fields: List[ps.Field] = None,
                  constant_fields: List[ps.Field] = [],
                  diff_fields_prefix='diff',  # TODO: remove!
                  do_common_subexpression_elimination=True,
@@ -241,6 +241,7 @@ Backward:
         self._backward_kernel_gpu = None
         self._do_common_subexpression_elimination = do_common_subexpression_elimination
         self._boundary_handling = boundary_handling
+
         if backward_assignments:
             self._forward_assignments = forward_assignments
             self._forward_read_accesses = None
@@ -367,7 +368,7 @@ Backward:
             rhs = rhs[0, 0]
 
             # if field is constant over we time we can accumulate in assignment
-            if read_access.field in self._time_constant_fields:
+            if read_access.field in self._time_constant_fields and self.time_constant_fields is not None:
                 backward_assignments.append(ps.Assignment(lhs, lhs + rhs))
             else:
                 backward_assignments.append(ps.Assignment(lhs, rhs))
-- 
GitLab