diff --git a/tests/test_superresolution.py b/tests/test_superresolution.py
index 258f34dc1a01c5ea3180e550f0c4896f5e231eff..4cf20e921779bd41e9b764594e93b2818d71411c 100644
--- a/tests/test_superresolution.py
+++ b/tests/test_superresolution.py
@@ -16,7 +16,6 @@ import sympy
 
 import pystencils
 import pystencils_reco.transforms
-from pystencils.data_types import create_type
 from pystencils_reco import crazy
 from pystencils_reco._projective_matrix import ProjectiveMatrix
 from pystencils_reco.filters import gauss_filter
@@ -42,7 +41,8 @@ def test_superresolution():
     pyconrad.show_everything()
 
 
-def test_torch_simple():
+@pytest.mark.parametrize('constant_h', ('constant_h', False))
+def test_torch_simple(constant_h):
 
     import pytest
     pytest.importorskip("torch")
@@ -50,18 +50,21 @@ def test_torch_simple():
 
     x, y = pystencils.fields('x,y: float32[2d]')
 
+    h = pystencils.fields('h0,h1,h2,h3,h4,h5,h6,h7: float32[2d]')
     @crazy
     def move(x, y):
-        h = pystencils.fields('h(8): float32[2d]')
-        A = sympy.Matrix([[h.center(0), h.center(1), h.center(2)],
-                          [h.center(3), h.center(4), h.center(5)],
-                          [h.center(6), h.center(7), 1]])
+        A = sympy.Matrix([[h[0].center, h[1].center, h[2].center],
+                          [h[3].center, h[4].center, h[5].center],
+                          [h[6].center, h[7].center, 1]])
         return {
             y.center: x.interpolated_access(ProjectiveMatrix(A) @ pystencils.x_vector(2))
 
         }
 
-    kernel = move(x, y).create_pytorch_op()
+    if constant_h:
+        kernel = move(x, y).create_pytorch_op(constant_fields=h)
+    else:
+        kernel = move(x, y).create_pytorch_op()
     pystencils.autodiff.show_code(kernel.ast)
 
     x = torch.ones((10, 40)).cuda()
@@ -83,11 +86,6 @@ def test_torch_simple():
 
 
 def test_torch_matrix():
-
-    import pytest
-    pytest.importorskip("torch")
-    import torch
-
     # x, y = torch.zeros((20, 20)), torch.zeros((20, 20))
     x, y = pystencils.fields('x,y: float32[2d]')
     a = sympy.Symbol('a')