From 803bbff814b4d9de219df2a957818bec4eb8ed79 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Fri, 13 Sep 2019 16:08:12 +0200
Subject: [PATCH] Add TorchModule.compile

---
 src/pystencils_autodiff/backends/astnodes.py  | 24 +++++++++++++++++--
 .../backends/test_torch_native_compilation.py | 10 ++++++++
 2 files changed, 32 insertions(+), 2 deletions(-)

diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py
index 1d3d6a2..bd6c781 100644
--- a/src/pystencils_autodiff/backends/astnodes.py
+++ b/src/pystencils_autodiff/backends/astnodes.py
@@ -8,16 +8,19 @@
 
 """
 
+import os
 from collections.abc import Iterable
-from os.path import dirname, join
+from os.path import dirname, exists, join
 
+import pystencils
 from pystencils.astnodes import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
-from pystencils_autodiff._file_io import read_template_from_file
+from pystencils_autodiff._file_io import read_template_from_file, write_file
 from pystencils_autodiff.backends.python_bindings import (
     PybindFunctionWrapping, PybindPythonBindings, TensorflowFunctionWrapping,
     TensorflowPythonBindings, TorchPythonBindings)
 from pystencils_autodiff.framework_integration.astnodes import (
     DestructuringBindingsForFieldClass, JinjaCppFile, WrapperFunction, generate_kernel_call)
+from pystencils_autodiff.tensorflow_jit import _hash
 
 
 class TorchTensorDestructuring(DestructuringBindingsForFieldClass):
@@ -73,6 +76,7 @@ class TorchModule(JinjaCppFile):
         if not isinstance(kernel_asts, Iterable):
             kernel_asts = [kernel_asts]
         wrapper_functions = [self.generate_wrapper_function(k) for k in kernel_asts]
+        self.module_name = module_name
 
         ast_dict = {
             'kernels': kernel_asts,
@@ -88,6 +92,22 @@ class TorchModule(JinjaCppFile):
         return WrapperFunction(self.DESTRUCTURING_CLASS(generate_kernel_call(kernel_ast)),
                                function_name='call_' + kernel_ast.function_name)
 
+    def compile(self):
+        from torch.utils.cpp_extension import load
+        file_extension = '.cu' if self.is_cuda else '.cpp'
+        source_code = str(self)
+        hash = _hash(source_code.encode()).hexdigest()
+        try:
+            os.mkdir(join(pystencils.cache.cache_dir, hash))
+        except Exception:
+            pass
+        file_name = join(pystencils.cache.cache_dir, hash, f'{hash}{file_extension}')
+
+        if not exists(file_name):
+            write_file(file_name, source_code)
+        torch_extension = load(hash, [file_name], with_cuda=self.is_cuda)
+        return torch_extension
+
 
 class TensorflowModule(TorchModule):
     DESTRUCTURING_CLASS = TensorflowTensorDestructuring
diff --git a/tests/backends/test_torch_native_compilation.py b/tests/backends/test_torch_native_compilation.py
index 4e88075..40a5e23 100644
--- a/tests/backends/test_torch_native_compilation.py
+++ b/tests/backends/test_torch_native_compilation.py
@@ -73,6 +73,11 @@ def test_torch_native_compilation_cpu():
     assert 'call_forward' in dir(torch_extension)
     assert 'call_backward' in dir(torch_extension)
 
+    torch_extension = module.compile()
+    assert torch_extension is not None
+    assert 'call_forward' in dir(torch_extension)
+    assert 'call_backward' in dir(torch_extension)
+
 
 @pytest.mark.skipif("TRAVIS" in os.environ, reason="nvcc compilation currently not working on TRAVIS")
 def test_torch_native_compilation_gpu():
@@ -106,6 +111,11 @@ def test_torch_native_compilation_gpu():
     assert 'call_forward' in dir(torch_extension)
     assert 'call_backward' in dir(torch_extension)
 
+    torch_extension = module.compile()
+    assert torch_extension is not None
+    assert 'call_forward' in dir(torch_extension)
+    assert 'call_backward' in dir(torch_extension)
+
 
 @pytest.mark.skipif(True or 'NO_GPU_EXECUTION' in os.environ, reason='Skip GPU execution tests')
 def test_execute_torch_gpu():
-- 
GitLab