From 02c99c28f23966f373f7963e4cf531ca9429ddad Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Tue, 6 Aug 2019 19:44:52 +0200 Subject: [PATCH] Fix test in test_torch_native_compilation --- .../backends/test_torch_native_compilation.py | 32 +++++++++++-------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/tests/backends/test_torch_native_compilation.py b/tests/backends/test_torch_native_compilation.py index 00bad34..181201c 100644 --- a/tests/backends/test_torch_native_compilation.py +++ b/tests/backends/test_torch_native_compilation.py @@ -39,9 +39,9 @@ def test_jit(): assert isfile(cpp_file) assert isfile(cuda_file) - from torch.utils.cpp_extension import load + from torch.utils.cpp_extension import CUDAExtension - lltm_cuda = load( + lltm_cuda = CUDAExtension( join(dirname(__file__), 'lltm_cuda'), [cpp_file, cuda_file], verbose=True, extra_cuda_cflags=[]) assert lltm_cuda is not None print('hallo') @@ -60,7 +60,7 @@ def test_torch_native_compilation(): print(backward_assignments) template_string = read_file(join(dirname(__file__), - '../../pystencils/autodiff/backends/torch_native_cuda.tmpl.cpp')) + '../../src/pystencils_autodiff/backends/torch_native_cuda.tmpl.cpp')) template = jinja2.Template(template_string) print(template_string) @@ -87,27 +87,31 @@ def test_torch_native_compilation(): print(output) template_string = read_file(join(dirname(__file__), - '../../pystencils/autodiff/backends/torch_native_cuda.tmpl.cu')) + '../../src/pystencils_autodiff/backends/torch_native_cuda.tmpl.cu')) template = jinja2.Template(template_string) print(template_string) output = template.render( - forward_tensors=[f.name for f in autodiff.forward_fields], - forward_input_tensors=[f.name for f in autodiff.forward_input_fields], - forward_output_tensors=[f.name for f in autodiff.forward_output_fields], - backward_tensors=[f.name for f in autodiff.backward_fields + autodiff.forward_input_fields], - backward_input_tensors=[f.name for f in autodiff.backward_input_fields], - backward_output_tensors=[f.name for f in autodiff.backward_output_fields], + forward_tensors=[f for f in autodiff.forward_fields], + forward_input_tensors=[f for f in autodiff.forward_input_fields], + forward_output_tensors=[f for f in autodiff.forward_output_fields], + backward_tensors=[f for f in autodiff.backward_fields + autodiff.forward_input_fields], + backward_input_tensors=[f for f in autodiff.backward_input_fields], + backward_output_tensors=[f for f in autodiff.backward_output_fields], forward_kernel=forward_code, backward_kernel=backward_code, + backward_blocks=str({1, 1, 1}), + backward_threads=str({1, 1, 1}), + forward_blocks=str({1, 1, 1}), + forward_threads=str({1, 1, 1}), kernel_name="square", dimensions=range(2) ) print(output) template_string = read_file(join(dirname(__file__), - '../../pystencils/autodiff/backends/torch_native_cpu.tmpl.cpp')) + '../../src/pystencils_autodiff/backends/torch_native_cpu.tmpl.cpp')) template = jinja2.Template(template_string) print(template_string) @@ -181,10 +185,10 @@ def test_execute_torch_gpu(): def main(): - # test_jit() - # test_torch_native_compilation() + test_jit() + test_torch_native_compilation() # test_generate_torch() - test_execute_torch() + # test_execute_torch() main() -- GitLab