diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py
index e853a4e765c8db24e4401b6ffb33207e6fbc8e17..db753059432a78a0aaff49b14d3bc96cddd95dfe 100644
--- a/src/pystencils_autodiff/backends/astnodes.py
+++ b/src/pystencils_autodiff/backends/astnodes.py
@@ -137,7 +137,7 @@ class TorchModule(JinjaCppFile):
         return WrapperFunction(cls.DESTRUCTURING_CLASS(generate_kernel_call(kernel_ast)),
                                function_name='call_' + kernel_ast.function_name)
 
-    def compile(self):
+    def compile(self, extra_source_files=[], extra_cuda_flags=[], with_cuda=None):
         from torch.utils.cpp_extension import load
         file_extension = '.cu' if self.is_cuda else '.cpp'
         source_code = str(self)
@@ -155,10 +155,12 @@ class TorchModule(JinjaCppFile):
         os.environ['CXX'] = get_compiler_config()['command']
 
         torch_extension = load(hash,
-                               [file_name],
-                               with_cuda=self.is_cuda,
-                               extra_cflags=['--std=c++14', get_compiler_config()['flags'].replace('--std=c++11', '')],
-                               extra_cuda_cflags=['-std=c++14', '-ccbin', get_compiler_config()['command']],
+                               [file_name] + extra_source_files,
+                               with_cuda=self.is_cuda or with_cuda,
+                               extra_cflags=['--std=c++14', get_compiler_config()
+                                             ['flags'].replace('--std=c++11', '')],
+                               extra_cuda_cflags=['-std=c++14', '-ccbin',
+                                                  get_compiler_config()['command']] + extra_cuda_flags,
                                build_directory=build_dir,
                                extra_include_paths=[get_pycuda_include_path(),
                                                     get_pystencils_include_path(),
diff --git a/src/pystencils_autodiff/backends/module.tmpl.cpp b/src/pystencils_autodiff/backends/module.tmpl.cpp
index f7f7ebfcb15500be81095634fb1cc1342a9978d4..9ea13c3c20ae7c06b6f23deb562559fd18a77ca6 100644
--- a/src/pystencils_autodiff/backends/module.tmpl.cpp
+++ b/src/pystencils_autodiff/backends/module.tmpl.cpp
@@ -1,9 +1,9 @@
 #define RESTRICT __restrict__
 
-#if GOOGLE_CUDA
-#define EIGEN_USE_GPU
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#endif
+//#if GOOGLE_CUDA
+//#define EIGEN_USE_GPU
+//#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+//#endif
 
 {% for header in headers -%}
 #include {{ header }}
diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py
index 8e9cbba59a2be98672575bdd41846b6387ef8765..75dd5a1ca6b1d0484533d3264a11b94f2928081d 100644
--- a/src/pystencils_autodiff/framework_integration/astnodes.py
+++ b/src/pystencils_autodiff/framework_integration/astnodes.py
@@ -369,14 +369,18 @@ class CustomFunctionDeclaration(JinjaCppFile):
 class CustomFunctionCall(JinjaCppFile):
     TEMPLATE = jinja2.Template("""{{function_name}}({{ args | join(', ') }});""", undefined=jinja2.StrictUndefined)
 
-    def __init__(self, function_name, *args, fields_accessed=[]):
+    def __init__(self, function_name, *args, fields_accessed=[], custom_signature=None):
         ast_dict = {
             'function_name': function_name,
             'args': args,
             'fields_accessed': [f.center for f in fields_accessed]
         }
         super().__init__(ast_dict)
-        self.required_global_declarations = [CustomFunctionDeclaration(self.ast_dict.function_name, self.ast_dict.args)]
+        if custom_signature:
+            self.required_global_declarations = [CustomCodeNode(custom_signature, (), ())]
+        else:
+            self.required_global_declarations = [CustomFunctionDeclaration(
+                self.ast_dict.function_name, self.ast_dict.args)]
 
     @property
     def symbols_defined(self):