From d39a683175f1ee17af4f00f67afbe41e4a07b713 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Mon, 18 Nov 2019 16:25:00 +0100
Subject: [PATCH] Add backend property to TorchModule

---
 src/pystencils_autodiff/backends/astnodes.py |  4 ++++
 tests/test_datahandling.py                   | 19 +++++++++++++++++--
 2 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py
index fdedf52..35287b8 100644
--- a/src/pystencils_autodiff/backends/astnodes.py
+++ b/src/pystencils_autodiff/backends/astnodes.py
@@ -70,6 +70,10 @@ class TorchModule(JinjaCppFile):
     PYTHON_BINDINGS_CLASS = TorchPythonBindings
     PYTHON_FUNCTION_WRAPPING_CLASS = PybindFunctionWrapping
 
+    @property
+    def backend(self):
+        return 'gpucuda' if self.is_cuda else 'c'
+
     def __init__(self, module_name, kernel_asts):
         """Create a C++ module with forward and optional backward_kernels
 
diff --git a/tests/test_datahandling.py b/tests/test_datahandling.py
index b08cb8b..f552d99 100644
--- a/tests/test_datahandling.py
+++ b/tests/test_datahandling.py
@@ -5,13 +5,28 @@
 """
 
 """
+import pytest
+import sympy
+
+import pystencils
 from pystencils_autodiff.framework_integration.datahandling import PyTorchDataHandling
 
+pystencils_reco = pytest.importorskip('pystencils_reco')
+
 
 def test_datahandling():
     dh = PyTorchDataHandling((20, 30))
 
-    dh.add_array('foo')
+    dh.add_array('x')
+    dh.add_array('y')
+    dh.add_array('z')
+    a = sympy.Symbol('a')
+
+    z, y, x = pystencils.fields("z, y, x: [20,40]")
+    forward_assignments = pystencils_reco.AssignmentCollection({
+        z[0, 0]: x[0, 0] * sympy.log(a * x[0, 0] * y[0, 0])
+    })
 
+    kernel = forward_assignments.create_pytorch_op()
 
-test_datahandling()
+    dh.run_kernel(kernel, a=3)
-- 
GitLab