From 26df2dfb6f6c5a6b2cc869a085e0499c7889786f Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Mon, 24 Feb 2020 15:12:15 +0100
Subject: [PATCH] Add some extra flags to TorchModule.compile and remove
 GOOGLE_CUDA from Template

---
 src/pystencils_autodiff/backends/astnodes.py         | 12 +++++++-----
 src/pystencils_autodiff/backends/module.tmpl.cpp     |  8 ++++----
 .../framework_integration/astnodes.py                |  8 ++++++--
 3 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py
index e853a4e..db75305 100644
--- a/src/pystencils_autodiff/backends/astnodes.py
+++ b/src/pystencils_autodiff/backends/astnodes.py
@@ -137,7 +137,7 @@ class TorchModule(JinjaCppFile):
         return WrapperFunction(cls.DESTRUCTURING_CLASS(generate_kernel_call(kernel_ast)),
                                function_name='call_' + kernel_ast.function_name)
 
-    def compile(self):
+    def compile(self, extra_source_files=[], extra_cuda_flags=[], with_cuda=None):
         from torch.utils.cpp_extension import load
         file_extension = '.cu' if self.is_cuda else '.cpp'
         source_code = str(self)
@@ -155,10 +155,12 @@ class TorchModule(JinjaCppFile):
         os.environ['CXX'] = get_compiler_config()['command']
 
         torch_extension = load(hash,
-                               [file_name],
-                               with_cuda=self.is_cuda,
-                               extra_cflags=['--std=c++14', get_compiler_config()['flags'].replace('--std=c++11', '')],
-                               extra_cuda_cflags=['-std=c++14', '-ccbin', get_compiler_config()['command']],
+                               [file_name] + extra_source_files,
+                               with_cuda=self.is_cuda or with_cuda,
+                               extra_cflags=['--std=c++14', get_compiler_config()
+                                             ['flags'].replace('--std=c++11', '')],
+                               extra_cuda_cflags=['-std=c++14', '-ccbin',
+                                                  get_compiler_config()['command']] + extra_cuda_flags,
                                build_directory=build_dir,
                                extra_include_paths=[get_pycuda_include_path(),
                                                     get_pystencils_include_path(),
diff --git a/src/pystencils_autodiff/backends/module.tmpl.cpp b/src/pystencils_autodiff/backends/module.tmpl.cpp
index f7f7ebf..9ea13c3 100644
--- a/src/pystencils_autodiff/backends/module.tmpl.cpp
+++ b/src/pystencils_autodiff/backends/module.tmpl.cpp
@@ -1,9 +1,9 @@
 #define RESTRICT __restrict__
 
-#if GOOGLE_CUDA
-#define EIGEN_USE_GPU
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#endif
+//#if GOOGLE_CUDA
+//#define EIGEN_USE_GPU
+//#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+//#endif
 
 {% for header in headers -%}
 #include {{ header }}
diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py
index 8e9cbba..75dd5a1 100644
--- a/src/pystencils_autodiff/framework_integration/astnodes.py
+++ b/src/pystencils_autodiff/framework_integration/astnodes.py
@@ -369,14 +369,18 @@ class CustomFunctionDeclaration(JinjaCppFile):
 class CustomFunctionCall(JinjaCppFile):
     TEMPLATE = jinja2.Template("""{{function_name}}({{ args | join(', ') }});""", undefined=jinja2.StrictUndefined)
 
-    def __init__(self, function_name, *args, fields_accessed=[]):
+    def __init__(self, function_name, *args, fields_accessed=[], custom_signature=None):
         ast_dict = {
             'function_name': function_name,
             'args': args,
             'fields_accessed': [f.center for f in fields_accessed]
         }
         super().__init__(ast_dict)
-        self.required_global_declarations = [CustomFunctionDeclaration(self.ast_dict.function_name, self.ast_dict.args)]
+        if custom_signature:
+            self.required_global_declarations = [CustomCodeNode(custom_signature, (), ())]
+        else:
+            self.required_global_declarations = [CustomFunctionDeclaration(
+                self.ast_dict.function_name, self.ast_dict.args)]
 
     @property
     def symbols_defined(self):
-- 
GitLab