diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py index bd6c78131af9833e6deff4c645130aadd80d152c..864973ef785d74b7626a3b87d4a2daa0ea15afcf 100644 --- a/src/pystencils_autodiff/backends/astnodes.py +++ b/src/pystencils_autodiff/backends/astnodes.py @@ -9,6 +9,7 @@ """ import os +import sys from collections.abc import Iterable from os.path import dirname, exists, join @@ -97,14 +98,11 @@ class TorchModule(JinjaCppFile): file_extension = '.cu' if self.is_cuda else '.cpp' source_code = str(self) hash = _hash(source_code.encode()).hexdigest() - try: - os.mkdir(join(pystencils.cache.cache_dir, hash)) - except Exception: - pass - file_name = join(pystencils.cache.cache_dir, hash, f'{hash}{file_extension}') + file_name = join(pystencils.cache.cache_dir, f'{hash}{file_extension}') if not exists(file_name): write_file(file_name, source_code) + # TODO: propagate extra headers torch_extension = load(hash, [file_name], with_cuda=self.is_cuda) return torch_extension @@ -134,3 +132,27 @@ class TensorflowModule(TorchModule): class PybindModule(TorchModule): DESTRUCTURING_CLASS = PybindArrayDestructuring PYTHON_BINDINGS_CLASS = PybindPythonBindings + + CPP_IMPORT_PREFIX = """/* +<% +setup_pybind11(cfg) +%> +*/ +""" + + def compile(self): + import cppimport + + assert not self.is_cuda + + source_code = str(self) + file_name = join(pystencils.cache.cache_dir, f'{self.module_name}.cpp') + + if not exists(file_name): + write_file(file_name, source_code) + # TODO: propagate extra headers + cache_dir = pystencils.cache.cache_dir + if cache_dir not in sys.path: + sys.path.append(cache_dir) + torch_extension = cppimport.imp(f'{self.module_name}') + return torch_extension diff --git a/tests/backends/test_torch_native_compilation.py b/tests/backends/test_torch_native_compilation.py index 40a5e23f16c817f3314865ea9a48f453b2ec6a5c..03d44063e929e3cc8c0a202c98d2e4c38ee4358c 100644 --- a/tests/backends/test_torch_native_compilation.py +++ b/tests/backends/test_torch_native_compilation.py @@ -14,7 +14,7 @@ import sympy import pystencils from pystencils_autodiff import create_backward_assignments from pystencils_autodiff._file_io import write_file -from pystencils_autodiff.backends.astnodes import TorchModule +from pystencils_autodiff.backends.astnodes import PybindModule, TorchModule torch = pytest.importorskip('torch') pytestmark = pytest.mark.skipif(subprocess.call(['ninja', '--v']) != 0, @@ -79,6 +79,34 @@ def test_torch_native_compilation_cpu(): assert 'call_backward' in dir(torch_extension) +def test_pybind11_compilation_cpu(): + + module_name = "Olololsada" + + target = 'cpu' + + z, y, x = pystencils.fields("z, y, x: [20,40]") + a = sympy.Symbol('a') + + forward_assignments = pystencils.AssignmentCollection({ + z[0, 0]: x[0, 0] * sympy.log(a * x[0, 0] * y[0, 0]) + }) + + backward_assignments = create_backward_assignments(forward_assignments) + + forward_ast = pystencils.create_kernel(forward_assignments, target) + forward_ast.function_name = 'forward' + backward_ast = pystencils.create_kernel(backward_assignments, target) + backward_ast.function_name = 'backward' + module = PybindModule(module_name, [forward_ast, backward_ast]) + print(module) + + pybind_extension = module.compile() + assert pybind_extension is not None + assert 'call_forward' in dir(pybind_extension) + assert 'call_backward' in dir(pybind_extension) + + @pytest.mark.skipif("TRAVIS" in os.environ, reason="nvcc compilation currently not working on TRAVIS") def test_torch_native_compilation_gpu(): from torch.utils.cpp_extension import load