diff --git a/tests/test_superresolution.py b/tests/test_superresolution.py index 258f34dc1a01c5ea3180e550f0c4896f5e231eff..4cf20e921779bd41e9b764594e93b2818d71411c 100644 --- a/tests/test_superresolution.py +++ b/tests/test_superresolution.py @@ -16,7 +16,6 @@ import sympy import pystencils import pystencils_reco.transforms -from pystencils.data_types import create_type from pystencils_reco import crazy from pystencils_reco._projective_matrix import ProjectiveMatrix from pystencils_reco.filters import gauss_filter @@ -42,7 +41,8 @@ def test_superresolution(): pyconrad.show_everything() -def test_torch_simple(): +@pytest.mark.parametrize('constant_h', ('constant_h', False)) +def test_torch_simple(constant_h): import pytest pytest.importorskip("torch") @@ -50,18 +50,21 @@ def test_torch_simple(): x, y = pystencils.fields('x,y: float32[2d]') + h = pystencils.fields('h0,h1,h2,h3,h4,h5,h6,h7: float32[2d]') @crazy def move(x, y): - h = pystencils.fields('h(8): float32[2d]') - A = sympy.Matrix([[h.center(0), h.center(1), h.center(2)], - [h.center(3), h.center(4), h.center(5)], - [h.center(6), h.center(7), 1]]) + A = sympy.Matrix([[h[0].center, h[1].center, h[2].center], + [h[3].center, h[4].center, h[5].center], + [h[6].center, h[7].center, 1]]) return { y.center: x.interpolated_access(ProjectiveMatrix(A) @ pystencils.x_vector(2)) } - kernel = move(x, y).create_pytorch_op() + if constant_h: + kernel = move(x, y).create_pytorch_op(constant_fields=h) + else: + kernel = move(x, y).create_pytorch_op() pystencils.autodiff.show_code(kernel.ast) x = torch.ones((10, 40)).cuda() @@ -83,11 +86,6 @@ def test_torch_simple(): def test_torch_matrix(): - - import pytest - pytest.importorskip("torch") - import torch - # x, y = torch.zeros((20, 20)), torch.zeros((20, 20)) x, y = pystencils.fields('x,y: float32[2d]') a = sympy.Symbol('a')