Commit 4e7993e8 authored by Stephan Seitz

Avoid subdir for torch compilation / load pybind11 module with cppimport

parent 5bb54590
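The commit touches two files: the backends astnodes module (home of TorchModule and PybindModule) and the torch-native compilation tests. TorchModule.compile() no longer creates a per-hash subdirectory inside the pystencils cache; the generated source is written directly as <hash>.cpp or <hash>.cu. PybindModule gains its own compile() that builds the generated pybind11 source with cppimport instead of torch's extension loader. A minimal sketch of the cppimport pattern this relies on, with placeholder names ('my_module', the cache path) that are not part of the commit:

    import sys
    import cppimport

    cache_dir = '/some/cache/dir'        # placeholder: wherever <module_name>.cpp was written
    if cache_dir not in sys.path:        # cppimport looks the source up by module name on sys.path
        sys.path.append(cache_dir)

    # cppimport only builds files whose leading comment carries its mako configuration block, e.g.
    #   /*
    #   <%
    #   setup_pybind11(cfg)
    #   %>
    #   */
    # which is what CPP_IMPORT_PREFIX below provides.
    my_module = cppimport.imp('my_module')   # compiles if needed, then loads the extension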
@@ -9,6 +9,7 @@
 """
 import os
+import sys
 
 from collections.abc import Iterable
 from os.path import dirname, exists, join
@@ -97,14 +98,11 @@ class TorchModule(JinjaCppFile):
         file_extension = '.cu' if self.is_cuda else '.cpp'
         source_code = str(self)
         hash = _hash(source_code.encode()).hexdigest()
-        try:
-            os.mkdir(join(pystencils.cache.cache_dir, hash))
-        except Exception:
-            pass
-        file_name = join(pystencils.cache.cache_dir, hash, f'{hash}{file_extension}')
+        file_name = join(pystencils.cache.cache_dir, f'{hash}{file_extension}')
 
         if not exists(file_name):
             write_file(file_name, source_code)
+        # TODO: propagate extra headers
         torch_extension = load(hash, [file_name], with_cuda=self.is_cuda)
         return torch_extension
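For reference, a standalone sketch of the torch JIT path used above (the module name and 'demo_kernels.cpp' are illustrative, not part of the commit): torch.utils.cpp_extension.load() compiles the listed sources with ninja into its own per-name build directory and imports the result, so a single flat source file in the pystencils cache directory is all that is needed.

    from torch.utils.cpp_extension import load

    # placeholder source standing in for the generated <hash>.cpp / <hash>.cu
    extension = load(name='demo_kernels',
                     sources=['demo_kernels.cpp'],
                     with_cuda=False)     # True when the source is CUDA
    print(dir(extension))                 # the compiled bindings, e.g. call_forward / call_backward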
@@ -134,3 +132,27 @@ class TensorflowModule(TorchModule):
 class PybindModule(TorchModule):
     DESTRUCTURING_CLASS = PybindArrayDestructuring
     PYTHON_BINDINGS_CLASS = PybindPythonBindings
+
+    CPP_IMPORT_PREFIX = """/*
+<%
+setup_pybind11(cfg)
+%>
+*/
+"""
+
+    def compile(self):
+        import cppimport
+        assert not self.is_cuda
+
+        source_code = str(self)
+        file_name = join(pystencils.cache.cache_dir, f'{self.module_name}.cpp')
+
+        if not exists(file_name):
+            write_file(file_name, source_code)
+
+        # TODO: propagate extra headers
+        cache_dir = pystencils.cache.cache_dir
+        if cache_dir not in sys.path:
+            sys.path.append(cache_dir)
+        torch_extension = cppimport.imp(f'{self.module_name}')
+        return torch_extension
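Presumably the cache directory is appended to sys.path because cppimport resolves modules by name rather than by file path, which is also why the generated file has to be named <self.module_name>.cpp; the setup_pybind11(cfg) block in CPP_IMPORT_PREFIX is what makes cppimport add pybind11's include paths and compiler flags. The remaining hunks are in the torch-native compilation tests, which gain a matching CPU test for PybindModule: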
@@ -14,7 +14,7 @@ import sympy
 
 import pystencils
 from pystencils_autodiff import create_backward_assignments
 from pystencils_autodiff._file_io import write_file
-from pystencils_autodiff.backends.astnodes import TorchModule
+from pystencils_autodiff.backends.astnodes import PybindModule, TorchModule
 
 torch = pytest.importorskip('torch')
 pytestmark = pytest.mark.skipif(subprocess.call(['ninja', '--v']) != 0,
@@ -79,6 +79,34 @@ def test_torch_native_compilation_cpu():
     assert 'call_backward' in dir(torch_extension)
 
 
+def test_pybind11_compilation_cpu():
+    module_name = "Olololsada"
+
+    target = 'cpu'
+
+    z, y, x = pystencils.fields("z, y, x: [20,40]")
+    a = sympy.Symbol('a')
+
+    forward_assignments = pystencils.AssignmentCollection({
+        z[0, 0]: x[0, 0] * sympy.log(a * x[0, 0] * y[0, 0])
+    })
+
+    backward_assignments = create_backward_assignments(forward_assignments)
+
+    forward_ast = pystencils.create_kernel(forward_assignments, target)
+    forward_ast.function_name = 'forward'
+    backward_ast = pystencils.create_kernel(backward_assignments, target)
+    backward_ast.function_name = 'backward'
+    module = PybindModule(module_name, [forward_ast, backward_ast])
+    print(module)
+
+    pybind_extension = module.compile()
+    assert pybind_extension is not None
+    assert 'call_forward' in dir(pybind_extension)
+    assert 'call_backward' in dir(pybind_extension)
+
+
 @pytest.mark.skipif("TRAVIS" in os.environ, reason="nvcc compilation currently not working on TRAVIS")
 def test_torch_native_compilation_gpu():
     from torch.utils.cpp_extension import load