From c7f0ad481e2da8a5a48b3dff8df89ebd28851d1d Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Mon, 2 Mar 2020 00:20:44 +0100
Subject: [PATCH] Extend test_homography to support also constant H

---
 tests/test_superresolution.py | 22 ++++++++++------------
 1 file changed, 10 insertions(+), 12 deletions(-)

diff --git a/tests/test_superresolution.py b/tests/test_superresolution.py
index 258f34d..4cf20e9 100644
--- a/tests/test_superresolution.py
+++ b/tests/test_superresolution.py
@@ -16,7 +16,6 @@ import sympy
 
 import pystencils
 import pystencils_reco.transforms
-from pystencils.data_types import create_type
 from pystencils_reco import crazy
 from pystencils_reco._projective_matrix import ProjectiveMatrix
 from pystencils_reco.filters import gauss_filter
@@ -42,7 +41,8 @@ def test_superresolution():
     pyconrad.show_everything()
 
 
-def test_torch_simple():
+@pytest.mark.parametrize('constant_h', ('constant_h', False))
+def test_torch_simple(constant_h):
 
     import pytest
     pytest.importorskip("torch")
@@ -50,18 +50,21 @@ def test_torch_simple():
 
     x, y = pystencils.fields('x,y: float32[2d]')
 
+    h = pystencils.fields('h0,h1,h2,h3,h4,h5,h6,h7: float32[2d]')
     @crazy
     def move(x, y):
-        h = pystencils.fields('h(8): float32[2d]')
-        A = sympy.Matrix([[h.center(0), h.center(1), h.center(2)],
-                          [h.center(3), h.center(4), h.center(5)],
-                          [h.center(6), h.center(7), 1]])
+        A = sympy.Matrix([[h[0].center, h[1].center, h[2].center],
+                          [h[3].center, h[4].center, h[5].center],
+                          [h[6].center, h[7].center, 1]])
         return {
             y.center: x.interpolated_access(ProjectiveMatrix(A) @ pystencils.x_vector(2))
 
         }
 
-    kernel = move(x, y).create_pytorch_op()
+    if constant_h:
+        kernel = move(x, y).create_pytorch_op(constant_fields=h)
+    else:
+        kernel = move(x, y).create_pytorch_op()
     pystencils.autodiff.show_code(kernel.ast)
 
     x = torch.ones((10, 40)).cuda()
@@ -83,11 +86,6 @@ def test_torch_simple():
 
 
 def test_torch_matrix():
-
-    import pytest
-    pytest.importorskip("torch")
-    import torch
-
     # x, y = torch.zeros((20, 20)), torch.zeros((20, 20))
     x, y = pystencils.fields('x,y: float32[2d]')
     a = sympy.Symbol('a')
-- 
GitLab