From 397becd085184c7680ebe5c8f8b4b6e0e99cd7be Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Wed, 7 Aug 2019 14:55:03 +0200 Subject: [PATCH] Add definition of RESTRICT to Torch CPU template --- .../backends/torch_native_cpu.tmpl.cpp | 1 + tests/backends/test_torch_native_compilation.py | 14 +++++++++++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/pystencils_autodiff/backends/torch_native_cpu.tmpl.cpp b/src/pystencils_autodiff/backends/torch_native_cpu.tmpl.cpp index 0f02323..d7ee774 100644 --- a/src/pystencils_autodiff/backends/torch_native_cpu.tmpl.cpp +++ b/src/pystencils_autodiff/backends/torch_native_cpu.tmpl.cpp @@ -7,6 +7,7 @@ using namespace pybind11::literals; using scalar_t = {{ dtype }}; +#define RESTRICT __restrict std::vector<at::Tensor> {{ kernel_name }}_forward( {%- for tensor in forward_tensors -%} diff --git a/tests/backends/test_torch_native_compilation.py b/tests/backends/test_torch_native_compilation.py index bf72109..f23fd21 100644 --- a/tests/backends/test_torch_native_compilation.py +++ b/tests/backends/test_torch_native_compilation.py @@ -132,10 +132,9 @@ def test_torch_native_compilation(): print(output) -def test_generate_torch(): +def test_generate_torch_gpu(): x, y = pystencils.fields('x, y: float32[2d]') - os.environ['CUDA_HOME'] = "/usr/local/cuda-10.0" assignments = pystencils.AssignmentCollection({ y.center(): x.center()**2 }, {}) @@ -143,6 +142,16 @@ def test_generate_torch(): op_cuda = generate_torch(appdirs.user_cache_dir('pystencils'), autodiff, is_cuda=True, dtype=np.float32) assert op_cuda is not None + + +def test_generate_torch_cpu(): + x, y = pystencils.fields('x, y: float32[2d]') + + assignments = pystencils.AssignmentCollection({ + y.center(): x.center()**2 + }, {}) + autodiff = pystencils_autodiff.AutoDiffOp(assignments) + op_cpp = generate_torch(appdirs.user_cache_dir('pystencils'), autodiff, is_cuda=False, dtype=np.float32) assert op_cpp is not None @@ -165,7 +174,6 @@ def test_execute_torch(): @pytest.mark.skipif('NO_GPU_EXECUTION' in os.environ, reason='Skip GPU execution tests') - def test_execute_torch_gpu(): x, y = pystencils.fields('x, y: float64[32,32]') -- GitLab