Skip to content
Snippets Groups Projects
Commit 803bbff8 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add TorchModule.compile

parent 352ffff9
No related branches found
No related tags found
No related merge requests found
Pipeline #18000 passed
...@@ -8,16 +8,19 @@ ...@@ -8,16 +8,19 @@
""" """
import os
from collections.abc import Iterable from collections.abc import Iterable
from os.path import dirname, join from os.path import dirname, exists, join
import pystencils
from pystencils.astnodes import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol from pystencils.astnodes import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
from pystencils_autodiff._file_io import read_template_from_file from pystencils_autodiff._file_io import read_template_from_file, write_file
from pystencils_autodiff.backends.python_bindings import ( from pystencils_autodiff.backends.python_bindings import (
PybindFunctionWrapping, PybindPythonBindings, TensorflowFunctionWrapping, PybindFunctionWrapping, PybindPythonBindings, TensorflowFunctionWrapping,
TensorflowPythonBindings, TorchPythonBindings) TensorflowPythonBindings, TorchPythonBindings)
from pystencils_autodiff.framework_integration.astnodes import ( from pystencils_autodiff.framework_integration.astnodes import (
DestructuringBindingsForFieldClass, JinjaCppFile, WrapperFunction, generate_kernel_call) DestructuringBindingsForFieldClass, JinjaCppFile, WrapperFunction, generate_kernel_call)
from pystencils_autodiff.tensorflow_jit import _hash
class TorchTensorDestructuring(DestructuringBindingsForFieldClass): class TorchTensorDestructuring(DestructuringBindingsForFieldClass):
...@@ -73,6 +76,7 @@ class TorchModule(JinjaCppFile): ...@@ -73,6 +76,7 @@ class TorchModule(JinjaCppFile):
if not isinstance(kernel_asts, Iterable): if not isinstance(kernel_asts, Iterable):
kernel_asts = [kernel_asts] kernel_asts = [kernel_asts]
wrapper_functions = [self.generate_wrapper_function(k) for k in kernel_asts] wrapper_functions = [self.generate_wrapper_function(k) for k in kernel_asts]
self.module_name = module_name
ast_dict = { ast_dict = {
'kernels': kernel_asts, 'kernels': kernel_asts,
...@@ -88,6 +92,22 @@ class TorchModule(JinjaCppFile): ...@@ -88,6 +92,22 @@ class TorchModule(JinjaCppFile):
return WrapperFunction(self.DESTRUCTURING_CLASS(generate_kernel_call(kernel_ast)), return WrapperFunction(self.DESTRUCTURING_CLASS(generate_kernel_call(kernel_ast)),
function_name='call_' + kernel_ast.function_name) function_name='call_' + kernel_ast.function_name)
def compile(self):
    """Compile this module's generated C++/CUDA source into a torch extension.

    The generated source is written into a per-hash subdirectory of the
    pystencils cache directory (so identical sources are reused) and then
    built/loaded via :func:`torch.utils.cpp_extension.load`.

    Returns:
        The loaded extension module exposing the wrapped kernel functions
        (e.g. ``call_forward`` / ``call_backward``).
    """
    from torch.utils.cpp_extension import load
    file_extension = '.cu' if self.is_cuda else '.cpp'
    source_code = str(self)
    # Renamed from `hash`, which shadowed the builtin of the same name.
    source_hash = _hash(source_code.encode()).hexdigest()
    build_dir = join(pystencils.cache.cache_dir, source_hash)
    # makedirs(exist_ok=True) replaces the former `try: os.mkdir(...)
    # except Exception: pass`, which silently hid real errors (e.g. a
    # missing cache parent directory or a permission problem).
    os.makedirs(build_dir, exist_ok=True)
    file_name = join(build_dir, f'{source_hash}{file_extension}')
    if not exists(file_name):
        write_file(file_name, source_code)
    torch_extension = load(source_hash, [file_name], with_cuda=self.is_cuda)
    return torch_extension
class TensorflowModule(TorchModule): class TensorflowModule(TorchModule):
DESTRUCTURING_CLASS = TensorflowTensorDestructuring DESTRUCTURING_CLASS = TensorflowTensorDestructuring
......
...@@ -73,6 +73,11 @@ def test_torch_native_compilation_cpu(): ...@@ -73,6 +73,11 @@ def test_torch_native_compilation_cpu():
assert 'call_forward' in dir(torch_extension) assert 'call_forward' in dir(torch_extension)
assert 'call_backward' in dir(torch_extension) assert 'call_backward' in dir(torch_extension)
torch_extension = module.compile()
assert torch_extension is not None
assert 'call_forward' in dir(torch_extension)
assert 'call_backward' in dir(torch_extension)
@pytest.mark.skipif("TRAVIS" in os.environ, reason="nvcc compilation currently not working on TRAVIS") @pytest.mark.skipif("TRAVIS" in os.environ, reason="nvcc compilation currently not working on TRAVIS")
def test_torch_native_compilation_gpu(): def test_torch_native_compilation_gpu():
...@@ -106,6 +111,11 @@ def test_torch_native_compilation_gpu(): ...@@ -106,6 +111,11 @@ def test_torch_native_compilation_gpu():
assert 'call_forward' in dir(torch_extension) assert 'call_forward' in dir(torch_extension)
assert 'call_backward' in dir(torch_extension) assert 'call_backward' in dir(torch_extension)
torch_extension = module.compile()
assert torch_extension is not None
assert 'call_forward' in dir(torch_extension)
assert 'call_backward' in dir(torch_extension)
@pytest.mark.skipif(True or 'NO_GPU_EXECUTION' in os.environ, reason='Skip GPU execution tests') @pytest.mark.skipif(True or 'NO_GPU_EXECUTION' in os.environ, reason='Skip GPU execution tests')
def test_execute_torch_gpu(): def test_execute_torch_gpu():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment