From d39a683175f1ee17af4f00f67afbe41e4a07b713 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Mon, 18 Nov 2019 16:25:00 +0100 Subject: [PATCH] Add backend property to TorchModule --- src/pystencils_autodiff/backends/astnodes.py | 4 ++++ tests/test_datahandling.py | 19 +++++++++++++++++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py index fdedf52..35287b8 100644 --- a/src/pystencils_autodiff/backends/astnodes.py +++ b/src/pystencils_autodiff/backends/astnodes.py @@ -70,6 +70,10 @@ class TorchModule(JinjaCppFile): PYTHON_BINDINGS_CLASS = TorchPythonBindings PYTHON_FUNCTION_WRAPPING_CLASS = PybindFunctionWrapping + @property + def backend(self): + return 'gpucuda' if self.is_cuda else 'c' + def __init__(self, module_name, kernel_asts): """Create a C++ module with forward and optional backward_kernels diff --git a/tests/test_datahandling.py b/tests/test_datahandling.py index b08cb8b..f552d99 100644 --- a/tests/test_datahandling.py +++ b/tests/test_datahandling.py @@ -5,13 +5,28 @@ """ """ +import pytest +import sympy + +import pystencils from pystencils_autodiff.framework_integration.datahandling import PyTorchDataHandling +pystencils_reco = pytest.importorskip('pystencils_reco') + def test_datahandling(): dh = PyTorchDataHandling((20, 30)) - dh.add_array('foo') + dh.add_array('x') + dh.add_array('y') + dh.add_array('z') + a = sympy.Symbol('a') + + z, y, x = pystencils.fields("z, y, x: [20,40]") + forward_assignments = pystencils_reco.AssignmentCollection({ + z[0, 0]: x[0, 0] * sympy.log(a * x[0, 0] * y[0, 0]) + }) + kernel = forward_assignments.create_pytorch_op() -test_datahandling() + dh.run_kernel(kernel, a=3) -- GitLab