diff --git a/tests/test_superresolution.py b/tests/test_superresolution.py index d3068c92d8e32d32fbeee0dd1f7510740fe7d535..5f1f073506a4293102b850ca9b3387c44a45e07b 100644 --- a/tests/test_superresolution.py +++ b/tests/test_superresolution.py @@ -16,7 +16,9 @@ import sympy import pystencils import pystencils_reco.transforms +from pystencils.data_types import create_type from pystencils_reco import crazy +from pystencils_reco._projective_matrix import ProjectiveMatrix from pystencils_reco.filters import gauss_filter from pystencils_reco.resampling import ( downsample, resample, resample_to_shape, scale_transform, translate) @@ -46,7 +48,37 @@ def test_torch_simple(): pytest.importorskip("torch") import torch - x, y = torch.zeros((20, 20)), torch.zeros((20, 20)) + x, y = pystencils.fields('x,y: float32[2d]') + + @crazy + def move(x, y): + h = pystencils.fields('h(8): float32[2d]') + A = sympy.Matrix([[h.center(0), h.center(1), h.center(2)], + [h.center(3), h.center(4), h.center(5)], + [h.center(6), h.center(7), 1]]) + return { + y.center: x.interpolated_access(ProjectiveMatrix(A) @ pystencils.x_vector(2)) + + } + + kernel = move(x, y).create_pytorch_op() + pystencils.autodiff.show_code(kernel.ast) + + x = torch.ones((10, 40)).cuda() + h = torch.ones((10, 40, 8)).cuda() + + kernel().forward(h, x) + # kernel().forward(*([1]*9), x, y) + + +def test_torch_matrix(): + + import pytest + pytest.importorskip("torch") + import torch + + # x, y = torch.zeros((20, 20)), torch.zeros((20, 20)) + x, y = pystencils.fields('x,y: float32[2d]') a = sympy.Symbol('a') @crazy @@ -55,9 +87,8 @@ def test_torch_simple(): y.center: x.interpolated_access((pystencils.x_, pystencils.y_ + a)) } - kernel = move(x, y, a).compile() + kernel = move(x, y, a).create_pytorch_op() pystencils.autodiff.show_code(kernel.ast) - kernel().forward(x, y, 3) def test_downsample():