diff --git a/src/pystencils_reco/resampling.py b/src/pystencils_reco/resampling.py
index 5ac543f367f44bba403997d620f950be62cde486..be067291a2973c5fe53411c386172f9d7eb308f2 100644
--- a/src/pystencils_reco/resampling.py
+++ b/src/pystencils_reco/resampling.py
@@ -13,7 +13,9 @@ from collections.abc import Iterable
 import sympy
 
 import pystencils
+from pystencils.autodiff import AdjointField
 from pystencils_reco import AssignmentCollection, crazy
+import types
 
 
 @crazy
@@ -23,13 +25,28 @@ def generic_spatial_matrix_transform(input_field, output_field, transform_matrix
     if inverse_matrix is None:
         inverse_matrix = transform_matrix.inv()
 
+    # output_coordinate = input_field.coordinate_transform.inv() @ (
+        # inverse_matrix @ output_field.physical_coordinates_staggered) - input_field.coordinate_origin
+    output_coordinate = input_field.physical_to_index(
+        inverse_matrix @ output_field.physical_coordinates_staggered, staggered=False)
+
     assignments = AssignmentCollection({
         output_field.center():
-        texture.at(input_field.coordinate_transform.inv() @
-                   (inverse_matrix @ output_field.physical_coordinates_staggered) - input_field.coordinate_origin)
+        texture.at(output_coordinate)
     })
-    assignments.transform_matrix = transform_matrix
 
+    def create_autodiff(self, constant_fields=None):
+        assignments.transform_matrix = transform_matrix
+
+        texture = pystencils.astnodes.TextureCachedField(AdjointField(output_field))
+        output_coordinate = output_field.physical_to_index(
+            transform_matrix @ input_field.physical_coordinates_staggered, staggered=True)
+        backward_assignments = AssignmentCollection({
+            AdjointField(input_field).center(): texture.at(output_coordinate)
+        })
+        self._autodiff = pystencils.autodiff.AutoDiffOp(assignments, "", backward_assignments=backward_assignments)
+
+    assignments._create_autodiff = types.MethodType(create_autodiff, assignments)
     return assignments
 
 
@@ -54,7 +71,7 @@ def scale_transform(input_field, output_field, scaling_factor):
 def rotation_transform(input_field, output_field, rotation_angle, rotation_axis=None):
     if input_field.spatial_dimensions == 3:
         assert rotation_axis is not None, "You must specify a rotation_axis for 3d rotations!"
-        transform_matrix = getattr(sympy, 'rot_axis%i' % (rotation_axis+1))(rotation_angle)
+        transform_matrix = getattr(sympy, 'rot_axis%i' % (rotation_axis + 1))(rotation_angle)
     elif input_field.spatial_dimensions == 2:
         # 2d rotation is 3d rotation around 3rd axis
         transform_matrix = sympy.rot_axis3(rotation_angle)[:2, :2]
diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py
index 983468173d57d920788fa74fb63942e45632cb73..8c1e376125c97e824cf52588948fabc1187b771b 100644
--- a/tests/test_pytorch.py
+++ b/tests/test_pytorch.py
@@ -8,6 +8,7 @@
 
 """
 import pystencils
+import pystencils_reco.resampling
 from pystencils.autodiff import torch_tensor_from_field
 from pystencils_reco.filters import mean_filter
 from pystencils_reco.stencils import BallStencil
@@ -54,9 +55,22 @@ def test_pytorch_from_tensors():
     print(torch_op)
 
 
+def test_texture():
+
+    x, y = pystencils.fields('x,y: float32[100,100]')
+    assignments = pystencils_reco.resampling.scale_transform(x, y, 2)
+
+    x_tensor = torch_tensor_from_field(x, requires_grad=True, cuda=True)
+    y_tensor = torch_tensor_from_field(y, cuda=True)
+    kernel = assignments.create_pytorch_op(x=x_tensor, y=y_tensor)
+    print(assignments)
+    print(kernel)
+
+
 def main():
     # test_pytorch()
-    test_pytorch_from_tensors()
+    # test_pytorch_from_tensors()
+    test_texture()
 
 
 if __name__ == '__main__':
diff --git a/tests/test_resampling.py b/tests/test_resampling.py
index e1fc0f9c5196ef61fb0bab374da4c64a9812feb5..28cb0f656aab48d3d57fcb32588fea3e5e1f3116 100644
--- a/tests/test_resampling.py
+++ b/tests/test_resampling.py
@@ -21,7 +21,7 @@ from pystencils_reco.resampling import rotation_transform, scale_transform
 def test_scaling():
 
     for ndim in range(1, 5):
-        for scale in (0.5, [(s+1)*0.1 for s in range(ndim)]):
+        for scale in (0.5, [(s + 1) * 0.1 for s in range(ndim)]):
             x, y = pystencils.fields('x,y: float32[%id]' % ndim)
             transform = scale_transform(x, y, scale)
             print(transform)
@@ -40,7 +40,7 @@ def test_rotation():
 def test_scaling_compilation():
 
     for ndim in range(1, 4):
-        for scale in (0.5, [(s+1)*0.1 for s in range(ndim)]):
+        for scale in (0.5, [(s + 1) * 0.1 for s in range(ndim)]):
             x, y = pystencils.fields('x,y: float32[%id]' % ndim)
             scale_transform(x, y, scale).compile('gpu')
 
@@ -63,7 +63,7 @@ def test_scaling_visualize():
     s = pystencils.data_types.TypedSymbol('s', 'float32')
     transform = scale_transform(x, y, s).compile('gpu')
 
-    test_image = 1-skimage.io.imread(join(dirname(__file__), "test_data",  "test_vessel2d_mask.png"), as_gray=True)
+    test_image = 1 - skimage.io.imread(join(dirname(__file__), "test_data", "test_vessel2d_mask.png"), as_gray=True)
     test_image = np.ascontiguousarray(test_image, np.float32)
     test_image = to_gpu(test_image)
     tmp = zeros_like(test_image)
@@ -82,11 +82,10 @@ def test_rotation_visualize():
     s = pystencils.data_types.TypedSymbol('s', 'float32')
     transform = rotation_transform(x, y, s).compile('gpu')
 
-    test_image = 1 - skimage.io.imread(join(dirname(__file__), "test_data",  "test_vessel2d_mask.png"), as_gray=True)
+    test_image = 1 - skimage.io.imread(join(dirname(__file__), "test_data", "test_vessel2d_mask.png"), as_gray=True)
     test_image = np.ascontiguousarray(test_image, np.float32)
     test_image = to_gpu(test_image)
     tmp = zeros_like(test_image)
-    print(transform.code)
 
     for s in (0.2, 0.5, 0.7, 1, 2):
         transform(x=test_image, y=tmp, s=s)
@@ -98,7 +97,7 @@ def test_rotation_around_center_visualize():
     import pyconrad.autoinit
     from pycuda.gpuarray import to_gpu, zeros_like
 
-    test_image = 1 - skimage.io.imread(join(dirname(__file__), "test_data",  "test_vessel2d_mask.png"), as_gray=True)
+    test_image = 1 - skimage.io.imread(join(dirname(__file__), "test_data", "test_vessel2d_mask.png"), as_gray=True)
     test_image = np.ascontiguousarray(test_image, np.float32)
     test_image = to_gpu(test_image)
 
@@ -109,7 +108,6 @@ def test_rotation_around_center_visualize():
     print(x.coordinate_origin)
     s = pystencils.data_types.TypedSymbol('s', 'float32')
     transform = rotation_transform(x, y, s).compile('gpu')
-    print(transform.code)
 
     for s in (0, 0.2, 0.5, 0.7, 1, 2):
         transform(x=test_image, y=tmp, s=s)