diff --git a/tests/test_superresolution.py b/tests/test_superresolution.py
index d3068c92d8e32d32fbeee0dd1f7510740fe7d535..5f1f073506a4293102b850ca9b3387c44a45e07b 100644
--- a/tests/test_superresolution.py
+++ b/tests/test_superresolution.py
@@ -16,7 +16,9 @@ import sympy
 
 import pystencils
 import pystencils_reco.transforms
+from pystencils.data_types import create_type
 from pystencils_reco import crazy
+from pystencils_reco._projective_matrix import ProjectiveMatrix
 from pystencils_reco.filters import gauss_filter
 from pystencils_reco.resampling import (
     downsample, resample, resample_to_shape, scale_transform, translate)
@@ -46,7 +48,37 @@ def test_torch_simple():
     pytest.importorskip("torch")
     import torch
 
-    x, y = torch.zeros((20, 20)), torch.zeros((20, 20))
+    x, y = pystencils.fields('x,y: float32[2d]')
+
+    @crazy
+    def move(x, y):
+        h = pystencils.fields('h(8): float32[2d]')
+        A = sympy.Matrix([[h.center(0), h.center(1), h.center(2)],
+                          [h.center(3), h.center(4), h.center(5)],
+                          [h.center(6), h.center(7), 1]])
+        return {
+            y.center: x.interpolated_access(ProjectiveMatrix(A) @ pystencils.x_vector(2))
+
+        }
+
+    kernel = move(x, y).create_pytorch_op()
+    pystencils.autodiff.show_code(kernel.ast)
+
+    x = torch.ones((10, 40)).cuda()
+    h = torch.ones((10, 40, 8)).cuda()
+
+    kernel().forward(h, x)
+    # kernel().forward(*([1]*9), x, y)
+
+
+def test_torch_matrix():
+
+    import pytest
+    pytest.importorskip("torch")
+    import torch
+
+    # x, y = torch.zeros((20, 20)), torch.zeros((20, 20))
+    x, y = pystencils.fields('x,y: float32[2d]')
     a = sympy.Symbol('a')
 
     @crazy
@@ -55,9 +87,8 @@ def test_torch_simple():
             y.center: x.interpolated_access((pystencils.x_, pystencils.y_ + a))
         }
 
-    kernel = move(x, y, a).compile()
+    kernel = move(x, y, a).create_pytorch_op()
     pystencils.autodiff.show_code(kernel.ast)
-    kernel().forward(x, y, 3)
 
 
 def test_downsample():