From 397becd085184c7680ebe5c8f8b4b6e0e99cd7be Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Wed, 7 Aug 2019 14:55:03 +0200
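Subject: [PATCH] Add definition of RESTRICT to Torch CPU template

Define RESTRICT as __restrict in torch_native_cpu.tmpl.cpp.

In the tests, split test_generate_torch into test_generate_torch_gpu and
test_generate_torch_cpu, drop the hard-coded CUDA_HOME override, and remove
the blank line between the skipif decorator and test_execute_torch_gpu.
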
---
 .../backends/torch_native_cpu.tmpl.cpp             |  1 +
 tests/backends/test_torch_native_compilation.py    | 14 +++++++++++---
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/src/pystencils_autodiff/backends/torch_native_cpu.tmpl.cpp b/src/pystencils_autodiff/backends/torch_native_cpu.tmpl.cpp
index 0f02323..d7ee774 100644
--- a/src/pystencils_autodiff/backends/torch_native_cpu.tmpl.cpp
+++ b/src/pystencils_autodiff/backends/torch_native_cpu.tmpl.cpp
@@ -7,6 +7,7 @@ using namespace pybind11::literals;
 using scalar_t = {{ dtype }};
 
 
+#define RESTRICT __restrict
 
 std::vector<at::Tensor> {{ kernel_name }}_forward(
 {%- for tensor in forward_tensors -%}
diff --git a/tests/backends/test_torch_native_compilation.py b/tests/backends/test_torch_native_compilation.py
index bf72109..f23fd21 100644
--- a/tests/backends/test_torch_native_compilation.py
+++ b/tests/backends/test_torch_native_compilation.py
@@ -132,10 +132,9 @@ def test_torch_native_compilation():
     print(output)
 
 
-def test_generate_torch():
+def test_generate_torch_gpu():
     x, y = pystencils.fields('x, y: float32[2d]')
 
-    os.environ['CUDA_HOME'] = "/usr/local/cuda-10.0"
     assignments = pystencils.AssignmentCollection({
         y.center(): x.center()**2
     }, {})
@@ -143,6 +142,16 @@ def test_generate_torch():
 
     op_cuda = generate_torch(appdirs.user_cache_dir('pystencils'), autodiff, is_cuda=True, dtype=np.float32)
     assert op_cuda is not None
+
+
+def test_generate_torch_cpu():
+    x, y = pystencils.fields('x, y: float32[2d]')
+
+    assignments = pystencils.AssignmentCollection({
+        y.center(): x.center()**2
+    }, {})
+    autodiff = pystencils_autodiff.AutoDiffOp(assignments)
+
     op_cpp = generate_torch(appdirs.user_cache_dir('pystencils'), autodiff, is_cuda=False, dtype=np.float32)
     assert op_cpp is not None
 
@@ -165,7 +174,6 @@ def test_execute_torch():
 
 
 @pytest.mark.skipif('NO_GPU_EXECUTION' in os.environ, reason='Skip GPU execution tests')
-
 def test_execute_torch_gpu():
     x, y = pystencils.fields('x, y: float64[32,32]')
 
-- 
GitLab