Skip to content
Snippets Groups Projects
Commit 803bbff8 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add TorchModule.compile

parent 352ffff9
No related branches found
No related tags found
No related merge requests found
Pipeline #18000 passed
......@@ -8,16 +8,19 @@
"""
import os
from collections.abc import Iterable
from os.path import dirname, join
from os.path import dirname, exists, join
import pystencils
from pystencils.astnodes import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
from pystencils_autodiff._file_io import read_template_from_file
from pystencils_autodiff._file_io import read_template_from_file, write_file
from pystencils_autodiff.backends.python_bindings import (
PybindFunctionWrapping, PybindPythonBindings, TensorflowFunctionWrapping,
TensorflowPythonBindings, TorchPythonBindings)
from pystencils_autodiff.framework_integration.astnodes import (
DestructuringBindingsForFieldClass, JinjaCppFile, WrapperFunction, generate_kernel_call)
from pystencils_autodiff.tensorflow_jit import _hash
class TorchTensorDestructuring(DestructuringBindingsForFieldClass):
......@@ -73,6 +76,7 @@ class TorchModule(JinjaCppFile):
if not isinstance(kernel_asts, Iterable):
kernel_asts = [kernel_asts]
wrapper_functions = [self.generate_wrapper_function(k) for k in kernel_asts]
self.module_name = module_name
ast_dict = {
'kernels': kernel_asts,
......@@ -88,6 +92,22 @@ class TorchModule(JinjaCppFile):
return WrapperFunction(self.DESTRUCTURING_CLASS(generate_kernel_call(kernel_ast)),
function_name='call_' + kernel_ast.function_name)
def compile(self):
from torch.utils.cpp_extension import load
file_extension = '.cu' if self.is_cuda else '.cpp'
source_code = str(self)
hash = _hash(source_code.encode()).hexdigest()
try:
os.mkdir(join(pystencils.cache.cache_dir, hash))
except Exception:
pass
file_name = join(pystencils.cache.cache_dir, hash, f'{hash}{file_extension}')
if not exists(file_name):
write_file(file_name, source_code)
torch_extension = load(hash, [file_name], with_cuda=self.is_cuda)
return torch_extension
class TensorflowModule(TorchModule):
DESTRUCTURING_CLASS = TensorflowTensorDestructuring
......
......@@ -73,6 +73,11 @@ def test_torch_native_compilation_cpu():
assert 'call_forward' in dir(torch_extension)
assert 'call_backward' in dir(torch_extension)
torch_extension = module.compile()
assert torch_extension is not None
assert 'call_forward' in dir(torch_extension)
assert 'call_backward' in dir(torch_extension)
@pytest.mark.skipif("TRAVIS" in os.environ, reason="nvcc compilation currently not working on TRAVIS")
def test_torch_native_compilation_gpu():
......@@ -106,6 +111,11 @@ def test_torch_native_compilation_gpu():
assert 'call_forward' in dir(torch_extension)
assert 'call_backward' in dir(torch_extension)
torch_extension = module.compile()
assert torch_extension is not None
assert 'call_forward' in dir(torch_extension)
assert 'call_backward' in dir(torch_extension)
@pytest.mark.skipif(True or 'NO_GPU_EXECUTION' in os.environ, reason='Skip GPU execution tests')
def test_execute_torch_gpu():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment