From c7f0ad481e2da8a5a48b3dff8df89ebd28851d1d Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Mon, 2 Mar 2020 00:20:44 +0100 Subject: [PATCH] Extend test_homography to support also constant H --- tests/test_superresolution.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/tests/test_superresolution.py b/tests/test_superresolution.py index 258f34d..4cf20e9 100644 --- a/tests/test_superresolution.py +++ b/tests/test_superresolution.py @@ -16,7 +16,6 @@ import sympy import pystencils import pystencils_reco.transforms -from pystencils.data_types import create_type from pystencils_reco import crazy from pystencils_reco._projective_matrix import ProjectiveMatrix from pystencils_reco.filters import gauss_filter @@ -42,7 +41,8 @@ def test_superresolution(): pyconrad.show_everything() -def test_torch_simple(): +@pytest.mark.parametrize('constant_h', ('constant_h', False)) +def test_torch_simple(constant_h): import pytest pytest.importorskip("torch") @@ -50,18 +50,21 @@ def test_torch_simple(): x, y = pystencils.fields('x,y: float32[2d]') + h = pystencils.fields('h0,h1,h2,h3,h4,h5,h6,h7: float32[2d]') @crazy def move(x, y): - h = pystencils.fields('h(8): float32[2d]') - A = sympy.Matrix([[h.center(0), h.center(1), h.center(2)], - [h.center(3), h.center(4), h.center(5)], - [h.center(6), h.center(7), 1]]) + A = sympy.Matrix([[h[0].center, h[1].center, h[2].center], + [h[3].center, h[4].center, h[5].center], + [h[6].center, h[7].center, 1]]) return { y.center: x.interpolated_access(ProjectiveMatrix(A) @ pystencils.x_vector(2)) } - kernel = move(x, y).create_pytorch_op() + if constant_h: + kernel = move(x, y).create_pytorch_op(constant_fields=h) + else: + kernel = move(x, y).create_pytorch_op() pystencils.autodiff.show_code(kernel.ast) x = torch.ones((10, 40)).cuda() @@ -83,11 +86,6 @@ def test_torch_simple(): def test_torch_matrix(): - - import pytest - pytest.importorskip("torch") - import torch - # x, y = torch.zeros((20, 20)), torch.zeros((20, 20)) x, y = pystencils.fields('x,y: float32[2d]') a = sympy.Symbol('a') -- GitLab