diff --git a/src/pystencils_autodiff/_autodiff.py b/src/pystencils_autodiff/_autodiff.py
index a003cd5eaed998eca92e6b58da696df5807a12cd..d746679726c2a4b0b5f31634ea5c2c17b134e8f0 100644
--- a/src/pystencils_autodiff/_autodiff.py
+++ b/src/pystencils_autodiff/_autodiff.py
@@ -122,7 +122,7 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
 
             elif diff_read_field.index_dimensions == 1:
 
-                diff_read_field_sum = [0] * diff_read_field.index_shape[0]
+                diff_read_field_sum = {}
                 for ra in read_field_accesses:
                     if ra.field != forward_read_field:
                         continue  # ignore constant fields in differentiation
@@ -130,19 +130,19 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
                     # TF-MAD requires flipped stencils
                     inverted_offset = tuple(-v - w for v,
                                             w in zip(ra.offsets, write_field_accesses[0].offsets))
-                    diff_read_field_sum[ra.index[0]
-                                        ] += sp.diff(forward_assignment.rhs, ra) * diff_write_field[inverted_offset]
+                    diff_read_field_sum[ra.index[0]] = diff_read_field_sum.get(
+                        ra.index[0], 0) + sp.diff(forward_assignment.rhs, ra) * diff_write_field[inverted_offset]
 
-                for index in range(diff_read_field.index_shape[0]):
+                for index in diff_read_field_sum.keys():
                     if self.time_constant_fields is not None and forward_read_field in self._time_constant_fields:
                         # Accumulate in case of time_constant_fields
                         assignment = ps.Assignment(
-                            diff_read_field.center_vector[index],
-                            diff_read_field.center_vector[index] + diff_read_field_sum[index])
+                            diff_read_field.center.at_index(index),
+                            diff_read_field.center.at_index(index) + diff_read_field_sum[index])
                     else:
                         # If time dependent, we just need to assign the sum for the current time step
                         assignment = ps.Assignment(
-                            diff_read_field.center_vector[index], diff_read_field_sum[index])
+                            diff_read_field.center.at_index(index), diff_read_field_sum[index])
 
                 if assignment.lhs in backward_assignment_dict:
                     backward_assignment_dict[assignment.lhs].append(assignment.rhs)