From 803bbff814b4d9de219df2a957818bec4eb8ed79 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Fri, 13 Sep 2019 16:08:12 +0200
Subject: [PATCH] Add TorchModule.compile

---
 src/pystencils_autodiff/backends/astnodes.py   | 24 +++++++++++++++++--
 .../backends/test_torch_native_compilation.py  | 10 ++++++++
 2 files changed, 32 insertions(+), 2 deletions(-)

diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py
index 1d3d6a2..bd6c781 100644
--- a/src/pystencils_autodiff/backends/astnodes.py
+++ b/src/pystencils_autodiff/backends/astnodes.py
@@ -8,16 +8,19 @@
 """
 
+import os
 from collections.abc import Iterable
-from os.path import dirname, join
+from os.path import dirname, exists, join
 
+import pystencils
 from pystencils.astnodes import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
 
-from pystencils_autodiff._file_io import read_template_from_file
+from pystencils_autodiff._file_io import read_template_from_file, write_file
 from pystencils_autodiff.backends.python_bindings import (
     PybindFunctionWrapping, PybindPythonBindings, TensorflowFunctionWrapping,
     TensorflowPythonBindings, TorchPythonBindings)
 from pystencils_autodiff.framework_integration.astnodes import (
     DestructuringBindingsForFieldClass, JinjaCppFile, WrapperFunction, generate_kernel_call)
+from pystencils_autodiff.tensorflow_jit import _hash
 
 
 class TorchTensorDestructuring(DestructuringBindingsForFieldClass):
@@ -73,6 +76,7 @@ class TorchModule(JinjaCppFile):
         if not isinstance(kernel_asts, Iterable):
             kernel_asts = [kernel_asts]
         wrapper_functions = [self.generate_wrapper_function(k) for k in kernel_asts]
+        self.module_name = module_name
 
         ast_dict = {
             'kernels': kernel_asts,
@@ -88,6 +92,22 @@ class TorchModule(JinjaCppFile):
         return WrapperFunction(self.DESTRUCTURING_CLASS(generate_kernel_call(kernel_ast)),
                                function_name='call_' + kernel_ast.function_name)
 
+    def compile(self):
+        from torch.utils.cpp_extension import load
+        file_extension = '.cu' if self.is_cuda else '.cpp'
+        source_code = str(self)
+        hash = _hash(source_code.encode()).hexdigest()
+        try:
+            os.mkdir(join(pystencils.cache.cache_dir, hash))
+        except Exception:
+            pass
+        file_name = join(pystencils.cache.cache_dir, hash, f'{hash}{file_extension}')
+
+        if not exists(file_name):
+            write_file(file_name, source_code)
+        torch_extension = load(hash, [file_name], with_cuda=self.is_cuda)
+        return torch_extension
+
 
 class TensorflowModule(TorchModule):
     DESTRUCTURING_CLASS = TensorflowTensorDestructuring
diff --git a/tests/backends/test_torch_native_compilation.py b/tests/backends/test_torch_native_compilation.py
index 4e88075..40a5e23 100644
--- a/tests/backends/test_torch_native_compilation.py
+++ b/tests/backends/test_torch_native_compilation.py
@@ -73,6 +73,11 @@ def test_torch_native_compilation_cpu():
     assert 'call_forward' in dir(torch_extension)
     assert 'call_backward' in dir(torch_extension)
 
+    torch_extension = module.compile()
+    assert torch_extension is not None
+    assert 'call_forward' in dir(torch_extension)
+    assert 'call_backward' in dir(torch_extension)
+
 
 @pytest.mark.skipif("TRAVIS" in os.environ, reason="nvcc compilation currently not working on TRAVIS")
 def test_torch_native_compilation_gpu():
@@ -106,6 +111,11 @@ def test_torch_native_compilation_gpu():
     assert 'call_forward' in dir(torch_extension)
     assert 'call_backward' in dir(torch_extension)
 
+    torch_extension = module.compile()
+    assert torch_extension is not None
+    assert 'call_forward' in dir(torch_extension)
+    assert 'call_backward' in dir(torch_extension)
+
 
 @pytest.mark.skipif(True or 'NO_GPU_EXECUTION' in os.environ, reason='Skip GPU execution tests')
 def test_execute_torch_gpu():
--
GitLab
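
Usage note: the new TorchModule.compile() writes the generated C++/CUDA source
into pystencils' cache directory, keyed by a hash of the source so that an
unchanged module reuses the cached file, and then JIT-builds it with
torch.utils.cpp_extension.load. A minimal sketch of how the API fits together
is below; the field, kernel, and module names are hypothetical, and a working
C++ toolchain (plus nvcc for CUDA modules) is assumed:

    import pystencils
    from pystencils_autodiff.backends.astnodes import TorchModule

    # Hypothetical kernel: copy field 'src' into field 'dst'.
    src, dst = pystencils.fields('src, dst: float32[2D]')
    forward_ast = pystencils.create_kernel(
        [pystencils.Assignment(dst.center, src.center)])
    forward_ast.function_name = 'forward'

    # One wrapper 'call_<kernel name>' is generated per kernel AST.
    module = TorchModule('demo_module', [forward_ast])
    extension = module.compile()  # writes <hash>.cpp, then JIT-builds it
    # extension.call_forward(...) now dispatches to the compiled kernel;
    # the expected arguments are torch tensors matching the kernel's fields.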