From 98eb75fe0cdd90d97cac3b62f8890c508363d1ee Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Mon, 26 Oct 2020 17:05:15 +0100 Subject: [PATCH] Automatically don't differentiate for indexVector --- src/pystencils_autodiff/_autodiff.py | 8 +++++--- src/pystencils_autodiff/backends/python_bindings.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/pystencils_autodiff/_autodiff.py b/src/pystencils_autodiff/_autodiff.py index d746679..b182280 100644 --- a/src/pystencils_autodiff/_autodiff.py +++ b/src/pystencils_autodiff/_autodiff.py @@ -60,7 +60,8 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix): # for every field create a corresponding diff field diff_read_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix) - for f in read_fields if f not in self._constant_fields} + for f in read_fields if (f not in self._constant_fields + and f.name not in self._constant_fields)} diff_write_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix) for f in write_fields} @@ -77,7 +78,7 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix): # TODO: simplify implementation. use matrix notation like in 'transposed' mode for forward_read_field in self._forward_input_fields: - if forward_read_field in self._constant_fields: + if forward_read_field in self._constant_fields or forward_read_field.name in self._constant_fields: continue diff_read_field = diff_read_fields[forward_read_field] @@ -247,6 +248,7 @@ Backward: self._forward_assignments = forward_assignments self._backward_assignments = None self._constant_fields = constant_fields + self._constant_fields += ['indexVector'] self._time_constant_fields = time_constant_fields self._kwargs = kwargs self.op_name = op_name @@ -393,7 +395,7 @@ Backward: backward_assignments = [] for lhs, read_access in zip(diff_read_field_accesses, read_field_accesses): # don't differentiate for constant fields - if read_access.field in self._constant_fields: + if read_access.field in self._constant_fields or read_access.field.name in self._constant_fields: continue rhs = sp.Matrix(sp.Matrix([e.rhs for e in forward_assignments])).diff( diff --git a/src/pystencils_autodiff/backends/python_bindings.py b/src/pystencils_autodiff/backends/python_bindings.py index e2f56f8..a459b5e 100644 --- a/src/pystencils_autodiff/backends/python_bindings.py +++ b/src/pystencils_autodiff/backends/python_bindings.py @@ -171,5 +171,5 @@ class PybindFunctionWrapping(JinjaCppFile): super().__init__({'python_name': function_node.function_name, 'cpp_name': function_node.function_name, 'parameters': [p.symbol.name for p in function_node.get_parameters() - if hasattr(p.symbol, 'dtype') and not 'meshFunctor' in p.symbol.name] + if hasattr(p.symbol, 'dtype') and 'meshFunctor' not in p.symbol.name] }) -- GitLab