Commit 397becd0 authored by Stephan Seitz

Add definition of RESTRICT to Torch CPU template

parent f02c6825
@@ -7,6 +7,7 @@ using namespace pybind11::literals;
 using scalar_t = {{ dtype }};
+#define RESTRICT __restrict
 std::vector<at::Tensor> {{ kernel_name }}_forward(
 {%- for tensor in forward_tensors -%}
...
@@ -132,10 +132,9 @@ def test_torch_native_compilation():
     print(output)
 
 
-def test_generate_torch():
+def test_generate_torch_gpu():
     x, y = pystencils.fields('x, y: float32[2d]')
-    os.environ['CUDA_HOME'] = "/usr/local/cuda-10.0"
 
     assignments = pystencils.AssignmentCollection({
         y.center(): x.center()**2
     }, {})
@@ -143,6 +142,16 @@ def test_generate_torch():
     op_cuda = generate_torch(appdirs.user_cache_dir('pystencils'), autodiff, is_cuda=True, dtype=np.float32)
     assert op_cuda is not None
+
+
+def test_generate_torch_cpu():
+    x, y = pystencils.fields('x, y: float32[2d]')
+
+    assignments = pystencils.AssignmentCollection({
+        y.center(): x.center()**2
+    }, {})
+    autodiff = pystencils_autodiff.AutoDiffOp(assignments)
+
     op_cpp = generate_torch(appdirs.user_cache_dir('pystencils'), autodiff, is_cuda=False, dtype=np.float32)
     assert op_cpp is not None
@@ -165,7 +174,6 @@ def test_execute_torch():
 @pytest.mark.skipif('NO_GPU_EXECUTION' in os.environ, reason='Skip GPU execution tests')
 def test_execute_torch_gpu():
     x, y = pystencils.fields('x, y: float64[32,32]')
...
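
Note on the template change above: the generated kernels refer to a RESTRICT qualifier on their pointer parameters, so the Torch CPU template must define that macro; __restrict is accepted as an extension by GCC, Clang and MSVC. The stand-alone sketch below (square_kernel is a hypothetical stand-in, not the code the template actually produces) illustrates what the qualifier buys: it promises the compiler that the pointers do not alias, which lets the loop be vectorized.

#include <cstdint>
#include <iostream>
#include <vector>

#define RESTRICT __restrict

// Hypothetical stand-in for a generated stencil kernel: RESTRICT tells the
// compiler that src and dst never alias, enabling vectorization of the loop.
static void square_kernel(const float* RESTRICT src, float* RESTRICT dst, std::int64_t n)
{
    for (std::int64_t i = 0; i < n; ++i)
        dst[i] = src[i] * src[i];
}

int main()
{
    std::vector<float> x = {1.0f, 2.0f, 3.0f};
    std::vector<float> y(x.size());
    square_kernel(x.data(), y.data(), static_cast<std::int64_t>(x.size()));
    for (float v : y)
        std::cout << v << ' ';   // prints: 1 4 9
    std::cout << '\n';
    return 0;
}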