diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py
index fdedf522ba4e89b2f6152618f936862063d89242..35287b873b0cac49cc677b31ab2d8a0ac9f89e4f 100644
--- a/src/pystencils_autodiff/backends/astnodes.py
+++ b/src/pystencils_autodiff/backends/astnodes.py
@@ -70,6 +70,10 @@ class TorchModule(JinjaCppFile):
     PYTHON_BINDINGS_CLASS = TorchPythonBindings
     PYTHON_FUNCTION_WRAPPING_CLASS = PybindFunctionWrapping
 
+    @property
+    def backend(self):
+        return 'gpucuda' if self.is_cuda else 'c'
+
     def __init__(self, module_name, kernel_asts):
         """Create a C++ module with forward and optional backward_kernels
 
diff --git a/tests/test_datahandling.py b/tests/test_datahandling.py
index b08cb8b76b6af6e627ae8a82356bcbba67214891..f552d99818a0a086dbbb20b6c7845e25b376a639 100644
--- a/tests/test_datahandling.py
+++ b/tests/test_datahandling.py
@@ -5,13 +5,28 @@
 """
 """
 
+import pytest
+import sympy
+
+import pystencils
 from pystencils_autodiff.framework_integration.datahandling import PyTorchDataHandling
 
+pystencils_reco = pytest.importorskip('pystencils_reco')
+
 
 def test_datahandling():
     dh = PyTorchDataHandling((20, 30))
-    dh.add_array('foo')
+    dh.add_array('x')
+    dh.add_array('y')
+    dh.add_array('z')
+    a = sympy.Symbol('a')
+
+    z, y, x = pystencils.fields("z, y, x: [20,40]")
+    forward_assignments = pystencils_reco.AssignmentCollection({
+        z[0, 0]: x[0, 0] * sympy.log(a * x[0, 0] * y[0, 0])
+    })
+    kernel = forward_assignments.create_pytorch_op()
 
 
-test_datahandling()
+    dh.run_kernel(kernel, a=3)
 