Skip to content
Snippets Groups Projects
Commit d39a6831 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add backend property to TorchModule

parent b55c70a2
No related branches found
No related tags found
No related merge requests found
Pipeline #19873 failed
......@@ -70,6 +70,10 @@ class TorchModule(JinjaCppFile):
PYTHON_BINDINGS_CLASS = TorchPythonBindings
PYTHON_FUNCTION_WRAPPING_CLASS = PybindFunctionWrapping
@property
def backend(self):
return 'gpucuda' if self.is_cuda else 'c'
def __init__(self, module_name, kernel_asts):
"""Create a C++ module with forward and optional backward_kernels
......
......@@ -5,13 +5,28 @@
"""
"""
import pytest
import sympy
import pystencils
from pystencils_autodiff.framework_integration.datahandling import PyTorchDataHandling
pystencils_reco = pytest.importorskip('pystencils_reco')
def test_datahandling():
dh = PyTorchDataHandling((20, 30))
dh.add_array('foo')
dh.add_array('x')
dh.add_array('y')
dh.add_array('z')
a = sympy.Symbol('a')
z, y, x = pystencils.fields("z, y, x: [20,40]")
forward_assignments = pystencils_reco.AssignmentCollection({
z[0, 0]: x[0, 0] * sympy.log(a * x[0, 0] * y[0, 0])
})
kernel = forward_assignments.create_pytorch_op()
test_datahandling()
dh.run_kernel(kernel, a=3)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment