Skip to content
Snippets Groups Projects
Commit 02c99c28 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Fix test in test_torch_native_compilation

parent c8801997
No related branches found
No related tags found
No related merge requests found
......@@ -39,9 +39,9 @@ def test_jit():
assert isfile(cpp_file)
assert isfile(cuda_file)
from torch.utils.cpp_extension import load
from torch.utils.cpp_extension import CUDAExtension
lltm_cuda = load(
lltm_cuda = CUDAExtension(
join(dirname(__file__), 'lltm_cuda'), [cpp_file, cuda_file], verbose=True, extra_cuda_cflags=[])
assert lltm_cuda is not None
print('hallo')
......@@ -60,7 +60,7 @@ def test_torch_native_compilation():
print(backward_assignments)
template_string = read_file(join(dirname(__file__),
'../../pystencils/autodiff/backends/torch_native_cuda.tmpl.cpp'))
'../../src/pystencils_autodiff/backends/torch_native_cuda.tmpl.cpp'))
template = jinja2.Template(template_string)
print(template_string)
......@@ -87,27 +87,31 @@ def test_torch_native_compilation():
print(output)
template_string = read_file(join(dirname(__file__),
'../../pystencils/autodiff/backends/torch_native_cuda.tmpl.cu'))
'../../src/pystencils_autodiff/backends/torch_native_cuda.tmpl.cu'))
template = jinja2.Template(template_string)
print(template_string)
output = template.render(
forward_tensors=[f.name for f in autodiff.forward_fields],
forward_input_tensors=[f.name for f in autodiff.forward_input_fields],
forward_output_tensors=[f.name for f in autodiff.forward_output_fields],
backward_tensors=[f.name for f in autodiff.backward_fields + autodiff.forward_input_fields],
backward_input_tensors=[f.name for f in autodiff.backward_input_fields],
backward_output_tensors=[f.name for f in autodiff.backward_output_fields],
forward_tensors=[f for f in autodiff.forward_fields],
forward_input_tensors=[f for f in autodiff.forward_input_fields],
forward_output_tensors=[f for f in autodiff.forward_output_fields],
backward_tensors=[f for f in autodiff.backward_fields + autodiff.forward_input_fields],
backward_input_tensors=[f for f in autodiff.backward_input_fields],
backward_output_tensors=[f for f in autodiff.backward_output_fields],
forward_kernel=forward_code,
backward_kernel=backward_code,
backward_blocks=str({1, 1, 1}),
backward_threads=str({1, 1, 1}),
forward_blocks=str({1, 1, 1}),
forward_threads=str({1, 1, 1}),
kernel_name="square",
dimensions=range(2)
)
print(output)
template_string = read_file(join(dirname(__file__),
'../../pystencils/autodiff/backends/torch_native_cpu.tmpl.cpp'))
'../../src/pystencils_autodiff/backends/torch_native_cpu.tmpl.cpp'))
template = jinja2.Template(template_string)
print(template_string)
......@@ -181,10 +185,10 @@ def test_execute_torch_gpu():
def main():
# test_jit()
# test_torch_native_compilation()
test_jit()
test_torch_native_compilation()
# test_generate_torch()
test_execute_torch()
# test_execute_torch()
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment