diff --git a/tests/backends/test_torch_native_compilation.py b/tests/backends/test_torch_native_compilation.py index 00bad34ae1c07517efe5512b8a1d1abf3739dbf6..181201c215964fb3b6cde974473a1d494a5d9eb9 100644 --- a/tests/backends/test_torch_native_compilation.py +++ b/tests/backends/test_torch_native_compilation.py @@ -39,9 +39,9 @@ def test_jit(): assert isfile(cpp_file) assert isfile(cuda_file) - from torch.utils.cpp_extension import load + from torch.utils.cpp_extension import CUDAExtension - lltm_cuda = load( + lltm_cuda = CUDAExtension( join(dirname(__file__), 'lltm_cuda'), [cpp_file, cuda_file], verbose=True, extra_cuda_cflags=[]) assert lltm_cuda is not None print('hallo') @@ -60,7 +60,7 @@ def test_torch_native_compilation(): print(backward_assignments) template_string = read_file(join(dirname(__file__), - '../../pystencils/autodiff/backends/torch_native_cuda.tmpl.cpp')) + '../../src/pystencils_autodiff/backends/torch_native_cuda.tmpl.cpp')) template = jinja2.Template(template_string) print(template_string) @@ -87,27 +87,31 @@ def test_torch_native_compilation(): print(output) template_string = read_file(join(dirname(__file__), - '../../pystencils/autodiff/backends/torch_native_cuda.tmpl.cu')) + '../../src/pystencils_autodiff/backends/torch_native_cuda.tmpl.cu')) template = jinja2.Template(template_string) print(template_string) output = template.render( - forward_tensors=[f.name for f in autodiff.forward_fields], - forward_input_tensors=[f.name for f in autodiff.forward_input_fields], - forward_output_tensors=[f.name for f in autodiff.forward_output_fields], - backward_tensors=[f.name for f in autodiff.backward_fields + autodiff.forward_input_fields], - backward_input_tensors=[f.name for f in autodiff.backward_input_fields], - backward_output_tensors=[f.name for f in autodiff.backward_output_fields], + forward_tensors=[f for f in autodiff.forward_fields], + forward_input_tensors=[f for f in autodiff.forward_input_fields], + forward_output_tensors=[f for f in autodiff.forward_output_fields], + backward_tensors=[f for f in autodiff.backward_fields + autodiff.forward_input_fields], + backward_input_tensors=[f for f in autodiff.backward_input_fields], + backward_output_tensors=[f for f in autodiff.backward_output_fields], forward_kernel=forward_code, backward_kernel=backward_code, + backward_blocks=str({1, 1, 1}), + backward_threads=str({1, 1, 1}), + forward_blocks=str({1, 1, 1}), + forward_threads=str({1, 1, 1}), kernel_name="square", dimensions=range(2) ) print(output) template_string = read_file(join(dirname(__file__), - '../../pystencils/autodiff/backends/torch_native_cpu.tmpl.cpp')) + '../../src/pystencils_autodiff/backends/torch_native_cpu.tmpl.cpp')) template = jinja2.Template(template_string) print(template_string) @@ -181,10 +185,10 @@ def test_execute_torch_gpu(): def main(): - # test_jit() - # test_torch_native_compilation() + test_jit() + test_torch_native_compilation() # test_generate_torch() - test_execute_torch() + # test_execute_torch() main()