Skip to content
Snippets Groups Projects
Commit 803bbff8 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add TorchModule.compile

parent 352ffff9
Branches
Tags
No related merge requests found
Pipeline #18000 passed
......@@ -8,16 +8,19 @@
"""
import os
from collections.abc import Iterable
from os.path import dirname, join
from os.path import dirname, exists, join
import pystencils
from pystencils.astnodes import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
from pystencils_autodiff._file_io import read_template_from_file
from pystencils_autodiff._file_io import read_template_from_file, write_file
from pystencils_autodiff.backends.python_bindings import (
PybindFunctionWrapping, PybindPythonBindings, TensorflowFunctionWrapping,
TensorflowPythonBindings, TorchPythonBindings)
from pystencils_autodiff.framework_integration.astnodes import (
DestructuringBindingsForFieldClass, JinjaCppFile, WrapperFunction, generate_kernel_call)
from pystencils_autodiff.tensorflow_jit import _hash
class TorchTensorDestructuring(DestructuringBindingsForFieldClass):
......@@ -73,6 +76,7 @@ class TorchModule(JinjaCppFile):
if not isinstance(kernel_asts, Iterable):
kernel_asts = [kernel_asts]
wrapper_functions = [self.generate_wrapper_function(k) for k in kernel_asts]
self.module_name = module_name
ast_dict = {
'kernels': kernel_asts,
......@@ -88,6 +92,22 @@ class TorchModule(JinjaCppFile):
return WrapperFunction(self.DESTRUCTURING_CLASS(generate_kernel_call(kernel_ast)),
function_name='call_' + kernel_ast.function_name)
def compile(self):
from torch.utils.cpp_extension import load
file_extension = '.cu' if self.is_cuda else '.cpp'
source_code = str(self)
hash = _hash(source_code.encode()).hexdigest()
try:
os.mkdir(join(pystencils.cache.cache_dir, hash))
except Exception:
pass
file_name = join(pystencils.cache.cache_dir, hash, f'{hash}{file_extension}')
if not exists(file_name):
write_file(file_name, source_code)
torch_extension = load(hash, [file_name], with_cuda=self.is_cuda)
return torch_extension
class TensorflowModule(TorchModule):
DESTRUCTURING_CLASS = TensorflowTensorDestructuring
......
......@@ -73,6 +73,11 @@ def test_torch_native_compilation_cpu():
assert 'call_forward' in dir(torch_extension)
assert 'call_backward' in dir(torch_extension)
torch_extension = module.compile()
assert torch_extension is not None
assert 'call_forward' in dir(torch_extension)
assert 'call_backward' in dir(torch_extension)
@pytest.mark.skipif("TRAVIS" in os.environ, reason="nvcc compilation currently not working on TRAVIS")
def test_torch_native_compilation_gpu():
......@@ -106,6 +111,11 @@ def test_torch_native_compilation_gpu():
assert 'call_forward' in dir(torch_extension)
assert 'call_backward' in dir(torch_extension)
torch_extension = module.compile()
assert torch_extension is not None
assert 'call_forward' in dir(torch_extension)
assert 'call_backward' in dir(torch_extension)
@pytest.mark.skipif(True or 'NO_GPU_EXECUTION' in os.environ, reason='Skip GPU execution tests')
def test_execute_torch_gpu():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment