diff --git a/src/pystencils_autodiff/_autodiff.py b/src/pystencils_autodiff/_autodiff.py
index 716178571e3e161427a2795798830de64a2f4c5b..29677281fe9d9ff0a642445edd35244056fb5ba0 100644
--- a/src/pystencils_autodiff/_autodiff.py
+++ b/src/pystencils_autodiff/_autodiff.py
@@ -9,9 +9,13 @@ import sympy as sp
 import pystencils as ps
 import pystencils.cache
 import pystencils_autodiff._layout_fixer
+from pystencils.interpolation_astnodes import InterpolatorAccess
+from pystencils.math_optimizations import ReplaceOptim, optimize_assignments
 from pystencils_autodiff.backends import AVAILABLE_BACKENDS
 from pystencils_autodiff.transformations import add_fixed_constant_boundary_handling
 
+REMOVE_CASTS = ReplaceOptim(lambda x: isinstance(x, pystencils.data_types.cast_func), lambda x: x.args[0])
+
 
 @pystencils.cache.disk_cache_no_fallback
 def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
@@ -35,25 +39,28 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
     if not hasattr(forward_assignments, 'free_symbols'):
         forward_assignments = ps.AssignmentCollection(
             forward_assignments, [])
+    forward_assignments = ps.AssignmentCollection(optimize_assignments(forward_assignments, [
+        REMOVE_CASTS
+    ]))
 
     read_field_accesses = sorted([
-        a for a in forward_assignments.free_symbols if isinstance(a, ps.Field.Access)], key=lambda x: str(x))
+        a for a in forward_assignments.free_symbols if isinstance(a, ps.Field.Access)]
+        + list(forward_assignments.atoms(InterpolatorAccess)), key=str)
     write_field_accesses = sorted([a.lhs for a in forward_assignments], key=lambda x: str(x))
     read_fields = {s.field for s in read_field_accesses}
     write_fields = {s.field for s in write_field_accesses}
 
     self._forward_read_accesses = read_field_accesses
     self._forward_write_accesses = write_field_accesses
-    self._forward_input_fields = sorted(list(read_fields), key=lambda x: str(x))
-    self._forward_output_fields = sorted(list(write_fields), key=lambda x: str(x))
+    self._forward_input_fields = sorted(list(read_fields), key=str)
+    self._forward_output_fields = sorted(list(write_fields), key=str)
 
-    read_field_accesses = [s for s in forward_assignments.free_symbols if isinstance(s, ps.Field.Access)]
     write_field_accesses = [a.lhs for a in forward_assignments if isinstance(a.lhs, ps.Field.Access)]
     assert write_field_accesses, "No write accesses found"
 
     # for every field create a corresponding diff field
     diff_read_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix)
-                        for f in read_fields}
+                        for f in read_fields if f not in self._constant_fields}
     diff_write_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix)
                          for f in write_fields}
 
@@ -81,11 +88,23 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
                     if ra.field != forward_read_field:
                         continue  # ignore constant fields in differentiation
 
-                    # TF-MAD requires flipped stencils
-                    inverted_offset = tuple(-v - w for v,
-                                            w in zip(ra.offsets, forward_assignment.lhs.offsets))
-                    diff_read_field_sum += sp.diff(forward_assignment.rhs, ra) * \
-                        diff_write_field[inverted_offset](*diff_write_index)
+                    if isinstance(ra, InterpolatorAccess):
+                        out_symbols = sp.symbols(f'o:{len(ra.offsets)}')
+                        in_symbols = pystencils.x_vector(len(ra.offsets))
+                        maybe_solution = sp.solve(sp.Matrix(ra.offsets) - sp.Matrix(out_symbols), in_symbols)
+                        assert maybe_solution, f'Could not solve for {in_symbols} when trying to derive for interpolator'  # noqa
+                        inverted_offset = tuple(maybe_solution.values())
+
+                        inverted_offset = [foo.subs({o: s for o, s in zip(out_symbols, in_symbols)})
+                                           for foo in inverted_offset]
+                        diff_read_field_sum += (sp.diff(forward_assignment.rhs, ra) *
+                                                diff_write_field.interpolated_access(inverted_offset))
+                    else:
+                        # TF-MAD requires flipped stencils
+                        inverted_offset = tuple(-v - w for v,
+                                                w in zip(ra.offsets, forward_assignment.lhs.offsets))
+                        diff_read_field_sum += sp.diff(forward_assignment.rhs, ra) * \
+                            diff_write_field[inverted_offset](*diff_write_index)
                 if self.time_constant_fields is not None and forward_read_field in self._time_constant_fields:
                     # Accumulate in case of time_constant_fields
                     assignment = ps.Assignment(
@@ -111,8 +130,8 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
                     # TF-MAD requires flipped stencils
                     inverted_offset = tuple(-v - w for v,
                                             w in zip(ra.offsets, write_field_accesses[0].offsets))
-                    diff_read_field_sum[ra.index[0]] += sp.diff(forward_assignment.rhs, ra) * \
-                        diff_write_field[inverted_offset]
+                    diff_read_field_sum[ra.index[0]
+                                        ] += sp.diff(forward_assignment.rhs, ra) * diff_write_field[inverted_offset]
 
                 for index in range(diff_read_field.index_shape[0]):
                     if forward_read_field in self._time_constant_fields:
@@ -135,7 +154,6 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
     backward_assignments = [ps.Assignment(k, sp.Add(*v)) for k, v in backward_assignment_dict.items()]
 
     try:
-
         if self._do_common_subexpression_elimination:
             backward_assignments = ps.simp.sympy_cse_on_assignment_list(
                 backward_assignments)
@@ -282,7 +300,7 @@ Backward:
                 # do_common_subexpression_elimination=do_common_subexpression_elimination)
 
     def __hash__(self):
-        return hash((str(self.forward_assignments), str(self.backward_assignments)))
+        return hash((str(self.forward_assignments), str(self.backward_assignments), str(self.constant_fields)))
 
     def __repr__(self):
         return self._REPR_TEMPLATE.render(forward_assignments=str(self.forward_assignments),
@@ -774,3 +792,6 @@ def get_jacobian_of_assignments(assignments, diff_variables):
 
     rhs = sp.Matrix([e.rhs for e in assignments])
     return rhs.jacobian(diff_variables)
+
+    rhs = sp.Matrix([e.rhs for e in assignments])
+    return rhs.jacobian(diff_variables)