diff --git a/src/pystencils_autodiff/_autodiff.py b/src/pystencils_autodiff/_autodiff.py index 9c3d0ac461b288e24fdeda1983b1e5839abe1556..c1af6fc255a196881856c233629fcd46194c1892 100644 --- a/src/pystencils_autodiff/_autodiff.py +++ b/src/pystencils_autodiff/_autodiff.py @@ -26,6 +26,7 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix): """ forward_assignments = self._forward_assignments + if hasattr(forward_assignments, 'new_without_subexpressions'): forward_assignments = forward_assignments.new_without_subexpressions() if hasattr(forward_assignments, 'main_assignments'): @@ -249,7 +250,6 @@ Backward: self._forward_input_fields = list(sorted(forward_assignments.free_fields, key=lambda x: str(x))) self._forward_output_fields = list(sorted(forward_assignments.bound_fields, key=lambda x: str(x))) self._backward_assignments = backward_assignments - self._backward_field_map = None self._backward_input_fields = list(sorted(backward_assignments.free_fields, key=lambda x: str(x))) self._backward_output_fields = list(sorted(backward_assignments.bound_fields, key=lambda x: str(x))) else: @@ -261,22 +261,16 @@ Backward: self._backward_field_map = None backward_assignments = _create_backward_assignments_tf_mad(self, diff_fields_prefix) self._backward_assignments = backward_assignments - if self._backward_field_map: - self._backward_input_fields = [ - self._backward_field_map[f] for f in self._forward_output_fields] - self._backward_output_fields = [ - self._backward_field_map[f] for f in self._forward_input_fields] - else: - self._forward_assignments = forward_assignments - self._forward_read_accesses = None - self._forward_write_accesses = None - self._forward_input_fields = list(sorted(forward_assignments.free_fields, key=lambda x: str(x))) - self._forward_output_fields = list(sorted(forward_assignments.bound_fields, key=lambda x: str(x))) - self._backward_assignments = backward_assignments - self._backward_field_map = None - self._backward_input_fields = list(sorted(backward_assignments.free_fields, key=lambda x: str(x))) - self._backward_output_fields = list(sorted(backward_assignments.bound_fields, key=lambda x: str(x))) - self._backward_field_map = None + + self._forward_assignments = forward_assignments + self._forward_read_accesses = None + self._forward_write_accesses = None + self._forward_input_fields = list(sorted(forward_assignments.free_fields, key=lambda x: str(x))) + self._forward_output_fields = list(sorted(forward_assignments.bound_fields, key=lambda x: str(x))) + self._backward_assignments = backward_assignments + self._backward_field_map = None + self._backward_input_fields = list(sorted(backward_assignments.free_fields, key=lambda x: str(x))) + self._backward_output_fields = list(sorted(backward_assignments.bound_fields, key=lambda x: str(x))) else: raise NotImplementedError() # else: @@ -528,10 +522,6 @@ Backward: self._backward_kernel_gpu = self.backward_ast_gpu.compile() return self._backward_kernel_gpu - @property - def backward_fields_map(self): - return self._backward_field_map - @property def backward_input_fields(self): return self._backward_input_fields