diff --git a/src/pystencils_autodiff/_autodiff.py b/src/pystencils_autodiff/_autodiff.py index d746679726c2a4b0b5f31634ea5c2c17b134e8f0..b182280ae67dc82a2b7aeebd5991d1ad1804631f 100644 --- a/src/pystencils_autodiff/_autodiff.py +++ b/src/pystencils_autodiff/_autodiff.py @@ -60,7 +60,8 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix): # for every field create a corresponding diff field diff_read_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix) - for f in read_fields if f not in self._constant_fields} + for f in read_fields if (f not in self._constant_fields + and f.name not in self._constant_fields)} diff_write_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix) for f in write_fields} @@ -77,7 +78,7 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix): # TODO: simplify implementation. use matrix notation like in 'transposed' mode for forward_read_field in self._forward_input_fields: - if forward_read_field in self._constant_fields: + if forward_read_field in self._constant_fields or forward_read_field.name in self._constant_fields: continue diff_read_field = diff_read_fields[forward_read_field] @@ -247,6 +248,7 @@ Backward: self._forward_assignments = forward_assignments self._backward_assignments = None self._constant_fields = constant_fields + self._constant_fields += ['indexVector'] self._time_constant_fields = time_constant_fields self._kwargs = kwargs self.op_name = op_name @@ -393,7 +395,7 @@ Backward: backward_assignments = [] for lhs, read_access in zip(diff_read_field_accesses, read_field_accesses): # don't differentiate for constant fields - if read_access.field in self._constant_fields: + if read_access.field in self._constant_fields or read_access.field.name in self._constant_fields: continue rhs = sp.Matrix(sp.Matrix([e.rhs for e in forward_assignments])).diff( diff --git a/src/pystencils_autodiff/backends/python_bindings.py b/src/pystencils_autodiff/backends/python_bindings.py index e2f56f8209e9026e8c19cec21f9114548f7930d9..a459b5eae5939248a0ffe1a133756ac5558a162f 100644 --- a/src/pystencils_autodiff/backends/python_bindings.py +++ b/src/pystencils_autodiff/backends/python_bindings.py @@ -171,5 +171,5 @@ class PybindFunctionWrapping(JinjaCppFile): super().__init__({'python_name': function_node.function_name, 'cpp_name': function_node.function_name, 'parameters': [p.symbol.name for p in function_node.get_parameters() - if hasattr(p.symbol, 'dtype') and not 'meshFunctor' in p.symbol.name] + if hasattr(p.symbol, 'dtype') and 'meshFunctor' not in p.symbol.name] })