diff --git a/src/pystencils_autodiff/_autodiff.py b/src/pystencils_autodiff/_autodiff.py
index 9c3d0ac461b288e24fdeda1983b1e5839abe1556..c1af6fc255a196881856c233629fcd46194c1892 100644
--- a/src/pystencils_autodiff/_autodiff.py
+++ b/src/pystencils_autodiff/_autodiff.py
@@ -26,6 +26,7 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
     """
 
     forward_assignments = self._forward_assignments
+
     if hasattr(forward_assignments, 'new_without_subexpressions'):
         forward_assignments = forward_assignments.new_without_subexpressions()
     if hasattr(forward_assignments, 'main_assignments'):
@@ -249,7 +250,6 @@ Backward:
             self._forward_input_fields = list(sorted(forward_assignments.free_fields, key=lambda x: str(x)))
             self._forward_output_fields = list(sorted(forward_assignments.bound_fields, key=lambda x: str(x)))
             self._backward_assignments = backward_assignments
-            self._backward_field_map = None
             self._backward_input_fields = list(sorted(backward_assignments.free_fields, key=lambda x: str(x)))
             self._backward_output_fields = list(sorted(backward_assignments.bound_fields, key=lambda x: str(x)))
         else:
@@ -261,22 +261,16 @@ Backward:
                 self._backward_field_map = None
                 backward_assignments = _create_backward_assignments_tf_mad(self, diff_fields_prefix)
                 self._backward_assignments = backward_assignments
-                if self._backward_field_map:
-                    self._backward_input_fields = [
-                        self._backward_field_map[f] for f in self._forward_output_fields]
-                    self._backward_output_fields = [
-                        self._backward_field_map[f] for f in self._forward_input_fields]
-                else:
-                    self._forward_assignments = forward_assignments
-                    self._forward_read_accesses = None
-                    self._forward_write_accesses = None
-                    self._forward_input_fields = list(sorted(forward_assignments.free_fields, key=lambda x: str(x)))
-                    self._forward_output_fields = list(sorted(forward_assignments.bound_fields, key=lambda x: str(x)))
-                    self._backward_assignments = backward_assignments
-                    self._backward_field_map = None
-                    self._backward_input_fields = list(sorted(backward_assignments.free_fields, key=lambda x: str(x)))
-                    self._backward_output_fields = list(sorted(backward_assignments.bound_fields, key=lambda x: str(x)))
-                    self._backward_field_map = None
+
+                self._forward_assignments = forward_assignments
+                self._forward_read_accesses = None
+                self._forward_write_accesses = None
+                self._forward_input_fields = list(sorted(forward_assignments.free_fields, key=lambda x: str(x)))
+                self._forward_output_fields = list(sorted(forward_assignments.bound_fields, key=lambda x: str(x)))
+                self._backward_assignments = backward_assignments
+                self._backward_field_map = None
+                self._backward_input_fields = list(sorted(backward_assignments.free_fields, key=lambda x: str(x)))
+                self._backward_output_fields = list(sorted(backward_assignments.bound_fields, key=lambda x: str(x)))
             else:
                 raise NotImplementedError()
             # else:
@@ -528,10 +522,6 @@ Backward:
             self._backward_kernel_gpu = self.backward_ast_gpu.compile()
         return self._backward_kernel_gpu
 
-    @property
-    def backward_fields_map(self):
-        return self._backward_field_map
-
     @property
     def backward_input_fields(self):
         return self._backward_input_fields