diff --git a/tests/test_superresolution.py b/tests/test_superresolution.py index a879abb91111da927632b64840c9f8b854075774..dc602d7dc5b2b2597cd9882694d1e203b10d1616 100644 --- a/tests/test_superresolution.py +++ b/tests/test_superresolution.py @@ -16,6 +16,7 @@ import sympy import pystencils import pystencils_reco.transforms +from pystencils_reco import crazy from pystencils_reco.filters import gauss_filter from pystencils_reco.resampling import ( downsample, resample, resample_to_shape, scale_transform, translate) @@ -39,6 +40,25 @@ def test_superresolution(): pyconrad.show_everything() +def test_torch_simple(): + + import pytest + pytest.importorskip("torch") + import torch + + x, y = torch.zeros((20, 20)), torch.zeros((20, 20)) + a = sympy.Symbol('a') + + @crazy + def move(x, y, a): + return { + y.center: x.interpolated_access((pystencils.x_, pystencils.y_ + a)) + } + + kernel = move(x, y, a).compile() + pystencils.autodiff.show_code(kernel.ast) + + def test_downsample(): shape = (20, 10)