diff --git a/src/pystencils_autodiff/_autodiff.py b/src/pystencils_autodiff/_autodiff.py
index 2e4ea6c160e4e135b3273bbb0200d0348bb83b14..a15afbd01a194df7f5e2597b485b5ac205b0cd1c 100644
--- a/src/pystencils_autodiff/_autodiff.py
+++ b/src/pystencils_autodiff/_autodiff.py
@@ -7,11 +7,151 @@ import numpy as np
 import sympy as sp
 
 import pystencils as ps
+import pystencils.cache
 import pystencils_autodiff._layout_fixer
 from pystencils_autodiff.backends import AVAILABLE_BACKENDS
 from pystencils_autodiff.transformations import add_fixed_constant_boundary_handling
 
 
+@pystencils.cache.disk_cache_no_fallback
+def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
+    """
+    Performs the automatic backward differentiation in a more fancy way with write accesses
+    like in the forward pass (only flipped).
+    It is called "transposed-mode forward-mode algorithmic differentiation" (TF-MAD).
+
+    See this presentation https://autodiff-workshop.github.io/slides/Hueckelheim_nips_autodiff_CNN_PDE.pdf or that
+    paper https://www.tandfonline.com/doi/full/10.1080/10556788.2018.1435654?scroll=top&needAccess=true
+    for more information
+    """
+
+    forward_assignments = self._forward_assignments
+    if hasattr(forward_assignments, 'new_without_subexpressions'):
+        forward_assignments = forward_assignments.new_without_subexpressions()
+    if hasattr(forward_assignments, 'main_assignments'):
+        forward_assignments = forward_assignments.main_assignments
+
+    if not hasattr(forward_assignments, 'free_symbols'):
+        forward_assignments = ps.AssignmentCollection(
+            forward_assignments, [])
+
+    read_field_accesses = sorted([
+        a for a in forward_assignments.free_symbols if isinstance(a, ps.Field.Access)], key=lambda x: str(x))
+    write_field_accesses = sorted([a.lhs for a in forward_assignments], key=lambda x: str(x))
+    read_fields = {s.field for s in read_field_accesses}
+    write_fields = {s.field for s in write_field_accesses}
+
+    self._forward_read_accesses = read_field_accesses
+    self._forward_write_accesses = write_field_accesses
+    self._forward_input_fields = sorted(list(read_fields), key=lambda x: str(x))
+    self._forward_output_fields = sorted(list(write_fields), key=lambda x: str(x))
+
+    read_field_accesses = [s for s in forward_assignments.free_symbols if isinstance(s, ps.Field.Access)]
+    write_field_accesses = [a.lhs for a in forward_assignments if isinstance(a.lhs, ps.Field.Access)]
+    assert write_field_accesses, "No write accesses found"
+
+    # for every field create a corresponding diff field
+    diff_read_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix)
+                        for f in read_fields}
+    diff_write_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix)
+                         for f in write_fields}
+
+    assert all(isinstance(w, ps.Field.Access)
+               for w in write_field_accesses), \
+        "Please check if your assignments are a AssignmentCollection or main_assignments only"
+
+    backward_assignment_dict = collections.OrderedDict()
+    # for each output of forward operation
+    for _, forward_assignment in enumerate(forward_assignments.main_assignments):
+        # we have only one assignment
+        diff_write_field = diff_write_fields[forward_assignment.lhs.field]
+        diff_write_index = forward_assignment.lhs.index
+
+        # TODO: simplify implementation. use matrix notation like in 'transposed' mode
+        for forward_read_field in self._forward_input_fields:
+            if forward_read_field in self._constant_fields:
+                continue
+            diff_read_field = diff_read_fields[forward_read_field]
+
+            if diff_read_field.index_dimensions == 0:
+
+                diff_read_field_sum = 0
+                for ra in read_field_accesses:
+                    if ra.field != forward_read_field:
+                        continue  # ignore constant fields in differentiation
+
+                    # TF-MAD requires flipped stencils
+                    inverted_offset = tuple(-v - w for v,
+                                            w in zip(ra.offsets, forward_assignment.lhs.offsets))
+                    diff_read_field_sum += sp.diff(forward_assignment.rhs, ra) * \
+                        diff_write_field[inverted_offset](*diff_write_index)
+                if forward_read_field in self._time_constant_fields:
+                    # Accumulate in case of time_constant_fields
+                    assignment = ps.Assignment(
+                        diff_read_field.center(), diff_read_field.center() + diff_read_field_sum)
+                else:
+                    # If time dependent, we just need to assign the sum for the current time step
+                    assignment = ps.Assignment(
+                        diff_read_field.center(), diff_read_field_sum)
+
+                # We can have contributions from multiple forward assignments
+                if assignment.lhs in backward_assignment_dict:
+                    backward_assignment_dict[assignment.lhs].append(assignment.rhs)
+                else:
+                    backward_assignment_dict[assignment.lhs] = [assignment.rhs]
+
+            elif diff_read_field.index_dimensions == 1:
+
+                diff_read_field_sum = [0] * diff_read_field.index_shape[0]
+                for ra in read_field_accesses:
+                    if ra.field != forward_read_field:
+                        continue  # ignore constant fields in differentiation
+
+                    # TF-MAD requires flipped stencils
+                    inverted_offset = tuple(-v - w for v,
+                                            w in zip(ra.offsets, write_field_accesses[0].offsets))
+                    diff_read_field_sum[ra.index[0]] += sp.diff(forward_assignment.rhs, ra) * \
+                        diff_write_field[inverted_offset]
+
+                for index in range(diff_read_field.index_shape[0]):
+                    if forward_read_field in self._time_constant_fields:
+                        # Accumulate in case of time_constant_fields
+                        assignment = ps.Assignment(
+                            diff_read_field.center_vector[index],
+                            diff_read_field.center_vector[index] + diff_read_field_sum[index])
+                    else:
+                        # If time dependent, we just need to assign the sum for the current time step
+                        assignment = ps.Assignment(
+                            diff_read_field.center_vector[index], diff_read_field_sum[index])
+
+                if assignment.lhs in backward_assignment_dict:
+                    backward_assignment_dict[assignment.lhs].append(assignment.rhs)
+                else:
+                    backward_assignment_dict[assignment.lhs] = [assignment.rhs]
+            else:
+                raise NotImplementedError()
+
+    backward_assignments = [ps.Assignment(k, sp.Add(*v)) for k, v in backward_assignment_dict.items()]
+
+    try:
+
+        if self._do_common_subexpression_elimination:
+            backward_assignments = ps.simp.sympy_cse_on_assignment_list(
+                backward_assignments)
+    except Exception:
+        pass
+        # print("Common subexpression elimination failed")
+        # print(err)
+    main_assignments = [a for a in backward_assignments if isinstance(a.lhs, ps.Field.Access)]
+    subexpressions = [a for a in backward_assignments if not isinstance(a.lhs, ps.Field.Access)]
+    backward_assignments = ps.AssignmentCollection(main_assignments, subexpressions)
+
+    assert _has_exclusive_writes(backward_assignments), "Backward assignments don't have exclusive writes!"
+    self._backward_field_map = {**diff_read_fields, **diff_write_fields}
+
+    return backward_assignments
+
+
 class AutoDiffBoundaryHandling(str, Enum):
     """
     Strategies for in-kernel boundary handling for AutoDiffOp
@@ -112,12 +252,39 @@ Backward:
             self._backward_input_fields = list(backward_assignments.free_fields)
             self._backward_output_fields = list(backward_assignments.bound_fields)
         else:
+            # if no_caching:
             if diff_mode == 'transposed':
                 self._create_backward_assignments(diff_fields_prefix)
             elif diff_mode == 'transposed-forward':
-                self._create_backward_assignments_tf_mad(diff_fields_prefix)
+                self._backward_assignments = None
+                self._backward_field_map = None
+                backward_assignments = _create_backward_assignments_tf_mad(self, diff_fields_prefix)
+                self._backward_assignments = backward_assignments
+                if self._backward_field_map:
+                    self._backward_input_fields = [
+                        self._backward_field_map[f] for f in self._forward_output_fields]
+                    self._backward_output_fields = [
+                        self._backward_field_map[f] for f in self._forward_input_fields]
+                else:
+                    self._forward_assignments = forward_assignments
+                    self._forward_read_accesses = None
+                    self._forward_write_accesses = None
+                    self._forward_input_fields = list(forward_assignments.free_fields)
+                    self._forward_output_fields = list(forward_assignments.bound_fields)
+                    self._backward_assignments = backward_assignments
+                    self._backward_field_map = None
+                    self._backward_input_fields = list(backward_assignments.free_fields)
+                    self._backward_output_fields = list(backward_assignments.bound_fields)
+                    self._backward_field_map = None
             else:
                 raise NotImplementedError()
+            # else:
+            #   # self.backward_assignments = create_backward_assignments(forward_assignments,
+            #                                                           # diff_fields_prefix,
+            #                                                           # time_constant_fields,
+            #                                                           # constant_fields,
+            #                                                           # diff_mode=diff_mode,
+                # do_common_subexpression_elimination=do_common_subexpression_elimination)
 
     def __hash__(self):
         return hash((str(self.forward_assignments), str(self.backward_assignments)))
@@ -129,6 +296,23 @@ Backward:
     def __str__(self):
         return self.__repr__()
 
+    def __setstate__(self, state):
+        forward_assignments = state['forward_assignments']
+        backward_assignments = state['backward_assignments']
+        self._forward_assignments = forward_assignments
+        self._forward_read_accesses = None
+        self._forward_write_accesses = None
+        self._forward_input_fields = list(forward_assignments.free_fields)
+        self._forward_output_fields = list(forward_assignments.bound_fields)
+        self._backward_assignments = backward_assignments
+        self._backward_field_map = None
+        self._backward_input_fields = list(backward_assignments.free_fields)
+        self._backward_output_fields = list(backward_assignments.bound_fields)
+        self._backward_field_map = None
+
+    def __getstate__(self):
+        return {'forward_assignments': self._forward_assignments, 'backward_assignments': self.backward_assignments}
+
     def _create_backward_assignments(self, diff_fields_prefix):
         """
         Performs automatic differentiation in the traditional adjoint/tangent way.
@@ -214,147 +398,6 @@ Backward:
         self._backward_output_fields = [
             self._backward_field_map[f] for f in self._forward_input_fields]
 
-    def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
-        """
-        Performs the automatic backward differentiation in a more fancy way with write accesses
-        like in the forward pass (only flipped).
-        It is called "transposed-mode forward-mode algorithmic differentiation" (TF-MAD).
-
-        See this presentation https://autodiff-workshop.github.io/slides/Hueckelheim_nips_autodiff_CNN_PDE.pdf or that
-        paper https://www.tandfonline.com/doi/full/10.1080/10556788.2018.1435654?scroll=top&needAccess=true
-        for more information
-        """
-
-        forward_assignments = self._forward_assignments
-        if hasattr(forward_assignments, 'new_without_subexpressions'):
-            forward_assignments = forward_assignments.new_without_subexpressions()
-        if hasattr(forward_assignments, 'main_assignments'):
-            forward_assignments = forward_assignments.main_assignments
-
-        if not hasattr(forward_assignments, 'free_symbols'):
-            forward_assignments = ps.AssignmentCollection(
-                forward_assignments, [])
-
-        read_field_accesses = sorted([
-            a for a in forward_assignments.free_symbols if isinstance(a, ps.Field.Access)], key=lambda x: str(x))
-        write_field_accesses = sorted([a.lhs for a in forward_assignments], key=lambda x: str(x))
-        read_fields = {s.field for s in read_field_accesses}
-        write_fields = {s.field for s in write_field_accesses}
-
-        self._forward_read_accesses = read_field_accesses
-        self._forward_write_accesses = write_field_accesses
-        self._forward_input_fields = sorted(list(read_fields), key=lambda x: str(x))
-        self._forward_output_fields = sorted(list(write_fields), key=lambda x: str(x))
-
-        read_field_accesses = [s for s in forward_assignments.free_symbols if isinstance(s, ps.Field.Access)]
-        write_field_accesses = [a.lhs for a in forward_assignments if isinstance(a.lhs, ps.Field.Access)]
-        assert write_field_accesses, "No write accesses found"
-
-        # for every field create a corresponding diff field
-        diff_read_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix)
-                            for f in read_fields}
-        diff_write_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix)
-                             for f in write_fields}
-
-        assert all(isinstance(w, ps.Field.Access)
-                   for w in write_field_accesses), \
-            "Please check if your assignments are a AssignmentCollection or main_assignments only"
-
-        backward_assignment_dict = collections.OrderedDict()
-        # for each output of forward operation
-        for _, forward_assignment in enumerate(forward_assignments.main_assignments):
-            # we have only one assignment
-            diff_write_field = diff_write_fields[forward_assignment.lhs.field]
-            diff_write_index = forward_assignment.lhs.index
-
-            # TODO: simplify implementation. use matrix notation like in 'transposed' mode
-            for forward_read_field in self._forward_input_fields:
-                if forward_read_field in self._constant_fields:
-                    continue
-                diff_read_field = diff_read_fields[forward_read_field]
-
-                if diff_read_field.index_dimensions == 0:
-
-                    diff_read_field_sum = 0
-                    for ra in read_field_accesses:
-                        if ra.field != forward_read_field:
-                            continue  # ignore constant fields in differentiation
-
-                        # TF-MAD requires flipped stencils
-                        inverted_offset = tuple(-v - w for v,
-                                                w in zip(ra.offsets, forward_assignment.lhs.offsets))
-                        diff_read_field_sum += sp.diff(forward_assignment.rhs, ra) * \
-                            diff_write_field[inverted_offset](*diff_write_index)
-                    if forward_read_field in self._time_constant_fields:
-                        # Accumulate in case of time_constant_fields
-                        assignment = ps.Assignment(
-                            diff_read_field.center(), diff_read_field.center() + diff_read_field_sum)
-                    else:
-                        # If time dependent, we just need to assign the sum for the current time step
-                        assignment = ps.Assignment(
-                            diff_read_field.center(), diff_read_field_sum)
-
-                    # We can have contributions from multiple forward assignments
-                    if assignment.lhs in backward_assignment_dict:
-                        backward_assignment_dict[assignment.lhs].append(assignment.rhs)
-                    else:
-                        backward_assignment_dict[assignment.lhs] = [assignment.rhs]
-
-                elif diff_read_field.index_dimensions == 1:
-
-                    diff_read_field_sum = [0] * diff_read_field.index_shape[0]
-                    for ra in read_field_accesses:
-                        if ra.field != forward_read_field:
-                            continue  # ignore constant fields in differentiation
-
-                        # TF-MAD requires flipped stencils
-                        inverted_offset = tuple(-v - w for v,
-                                                w in zip(ra.offsets, write_field_accesses[0].offsets))
-                        diff_read_field_sum[ra.index[0]] += sp.diff(forward_assignment.rhs, ra) * \
-                            diff_write_field[inverted_offset]
-
-                    for index in range(diff_read_field.index_shape[0]):
-                        if forward_read_field in self._time_constant_fields:
-                            # Accumulate in case of time_constant_fields
-                            assignment = ps.Assignment(
-                                diff_read_field.center_vector[index],
-                                diff_read_field.center_vector[index] + diff_read_field_sum[index])
-                        else:
-                            # If time dependent, we just need to assign the sum for the current time step
-                            assignment = ps.Assignment(
-                                diff_read_field.center_vector[index], diff_read_field_sum[index])
-
-                    if assignment.lhs in backward_assignment_dict:
-                        backward_assignment_dict[assignment.lhs].append(assignment.rhs)
-                    else:
-                        backward_assignment_dict[assignment.lhs] = [assignment.rhs]
-                else:
-                    raise NotImplementedError()
-
-        backward_assignments = [ps.Assignment(k, sp.Add(*v)) for k, v in backward_assignment_dict.items()]
-
-        try:
-
-            if self._do_common_subexpression_elimination:
-                backward_assignments = ps.simp.sympy_cse_on_assignment_list(
-                    backward_assignments)
-        except Exception:
-            pass
-            # print("Common subexpression elimination failed")
-            # print(err)
-        main_assignments = [a for a in backward_assignments if isinstance(a.lhs, ps.Field.Access)]
-        subexpressions = [a for a in backward_assignments if not isinstance(a.lhs, ps.Field.Access)]
-        backward_assignments = ps.AssignmentCollection(main_assignments, subexpressions)
-
-        assert _has_exclusive_writes(backward_assignments), "Backward assignments don't have exclusive writes!"
-
-        self._backward_assignments = backward_assignments
-        self._backward_field_map = {**diff_read_fields, **diff_write_fields}
-        self._backward_input_fields = [
-            self._backward_field_map[f] for f in self._forward_output_fields]
-        self._backward_output_fields = [
-            self._backward_field_map[f] for f in self._forward_input_fields]
-
     @property
     def forward_assignments(self):
         return self._forward_assignments
@@ -629,6 +672,7 @@ Backward:
             return op
 
 
+@pystencils.cache.disk_cache_no_fallback
 def create_backward_assignments(forward_assignments,
                                 diff_fields_prefix="diff",
                                 time_constant_fields=[],
@@ -640,7 +684,9 @@ def create_backward_assignments(forward_assignments,
                            time_constant_fields=time_constant_fields,
                            constant_fields=constant_fields,
                            diff_mode=diff_mode,
-                           do_common_subexpression_elimination=do_common_sub_expression_elimination)
+                           do_common_subexpression_elimination=do_common_sub_expression_elimination,
+                           no_chaching=True
+                           )
     backward_assignments = auto_diff.backward_assignments
 
     return backward_assignments