From 1a1d6d59d4d0551033e76f4edc49c795208fb71f Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Tue, 17 Sep 2019 17:48:40 +0200
Subject: [PATCH] Add test_execute_torch

---
 .../backends/test_torch_native_compilation.py | 60 ++++++++++++++++++-
 1 file changed, 57 insertions(+), 3 deletions(-)

diff --git a/tests/backends/test_torch_native_compilation.py b/tests/backends/test_torch_native_compilation.py
index 0d867b7..64b30a7 100644
--- a/tests/backends/test_torch_native_compilation.py
+++ b/tests/backends/test_torch_native_compilation.py
@@ -8,6 +8,7 @@ import subprocess
 import tempfile
 from os.path import dirname, isfile, join
 
+import numpy as np
 import pytest
 import sympy
 
@@ -148,6 +149,59 @@ def test_torch_native_compilation_gpu():
     assert 'call_backward' in dir(torch_extension)
 
 
-@pytest.mark.skipif(True or 'NO_GPU_EXECUTION' in os.environ, reason='Skip GPU execution tests')
-def test_execute_torch_gpu():
-    pass
+@pytest.mark.parametrize('target', ('gpu', 'cpu'))
+def test_execute_torch(target):
+    module_name = "Ololol" + target
+
+    z, y, x = pystencils.fields("z, y, x: [20,40]")
+    a = sympy.Symbol('a')
+
+    forward_assignments = pystencils.AssignmentCollection({
+        z[0, 0]: x[0, 0] * sympy.log(a * x[0, 0] * y[0, 0])
+    })
+
+    # backward_assignments = create_backward_assignments(forward_assignments)
+
+    if target == 'cpu':
+        x = np.random.rand(20, 40)
+        y = np.random.rand(20, 40)
+        z = np.zeros((20, 40))
+    else:
+        gpuarray = pytest.importorskip('pycuda.gpuarray')
+        x = gpuarray.to_gpu(np.random.rand(20, 40))
+        y = gpuarray.to_gpu(np.random.rand(20, 40))
+        z = gpuarray.zeros((20, 40), np.float64)
+
+    kernel = pystencils.create_kernel(forward_assignments, target=target)
+    kernel.function_name = 'forward'
+
+    torch_module = TorchModule(module_name, [kernel]).compile()
+    pystencils_module = kernel.compile()
+
+    pystencils_module(x=x, y=y, z=z, a=5.)
+    if target == 'gpu':
+        x = x.get()
+        y = y.get()
+        z = z.get()
+
+    z_pystencils = np.copy(z)
+    import torch
+    x = torch.Tensor(x)
+    y = torch.Tensor(y)
+    z = torch.Tensor(z)
+    if target == 'gpu':
+        x = x.double().cuda()
+        y = y.double().cuda()
+        z = z.double().cuda()
+    else:
+        x = x.double()
+        y = y.double()
+        z = z.double()
+
+    torch_module.call_forward(x=x, y=y, z=z, a=5.)
+    if target == 'gpu':
+        z = z.cpu()
+
+    z_torch = np.copy(z)
+
+    assert np.allclose(z_torch[1:-1, 1:-1], z_pystencils[1:-1, 1:-1], atol=1e-6)
-- 
GitLab