diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py
index e35d1454d537eeb129be7a57f2ab02c952d95cfc..10228c8b776ed9b157c3d0e2712b9235ffe2253b 100644
--- a/src/pystencils_autodiff/backends/astnodes.py
+++ b/src/pystencils_autodiff/backends/astnodes.py
@@ -108,12 +108,14 @@ class TorchModule(JinjaCppFile):
         :param forward_kernel_ast: one or more kernel ASTs (can have any C dialect)
         :param backward_kernel_ast:
         """
+        from pystencils_autodiff.framework_integration.astnodes import CustomFunctionCall
+
         if not isinstance(kernel_asts, Iterable):
             kernel_asts = [kernel_asts]
         wrapper_functions = [self.generate_wrapper_function(k)
-                             if not isinstance(k, WrapperFunction) or wrap_wrapper_functions
+                             if not isinstance(k, WrapperFunction)
                              else k for k in kernel_asts]
-        kernel_asts = [k for k in kernel_asts if not isinstance(k, WrapperFunction)]
+        kernel_asts = [k for k in kernel_asts if not isinstance(k, (WrapperFunction, CustomFunctionCall))]
         self.module_name = module_name
         self.compiled_file = None
 
@@ -137,7 +139,12 @@ class TorchModule(JinjaCppFile):
         return WrapperFunction(cls.DESTRUCTURING_CLASS(generate_kernel_call(kernel_ast)),
                                function_name='call_' + kernel_ast.function_name)
 
-    def compile(self, extra_source_files=[], extra_cuda_flags=[], with_cuda=None, build_dir=None):
+    def compile(self,
+                extra_source_files=[],
+                extra_cuda_flags=[],
+                with_cuda=None,
+                build_dir=None,
+                compile_module_name=None):
         from torch.utils.cpp_extension import load
         file_extension = '.cu' if self.is_cuda else '.cpp'
         source_code = str(self)
@@ -147,7 +154,7 @@ class TorchModule(JinjaCppFile):
         os.makedirs(build_dir, exist_ok=True)
         file_name = join(build_dir, f'{hash}{file_extension}')
 
-        self.compiled_file = file_name
+        self.compiled_file = (join(build_dir, compile_module_name) or file_name).replace('.cpp', '') + '.so'
 
         if not exists(file_name):
             write_file(file_name, source_code)
@@ -155,7 +162,7 @@ class TorchModule(JinjaCppFile):
         # Torch regards CXX
         os.environ['CXX'] = get_compiler_config()['command']
 
-        torch_extension = load(hash,
+        torch_extension = load(compile_module_name or hash,
                                [file_name] + extra_source_files,
                                with_cuda=self.is_cuda or with_cuda,
                                extra_cflags=['--std=c++14', get_compiler_config()
diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py
index 75dd5a1ca6b1d0484533d3264a11b94f2928081d..b7c78dd7d0ee2167021381c0ab53130a0fefde8e 100644
--- a/src/pystencils_autodiff/framework_integration/astnodes.py
+++ b/src/pystencils_autodiff/framework_integration/astnodes.py
@@ -43,10 +43,11 @@ class DestructuringBindingsForFieldClass(Node):
         """Set of Field instances: fields which are accessed inside this kernel function"""
 
         from pystencils.interpolation_astnodes import InterpolatorAccess
-        return (set(o.field for o in self.atoms(ResolvedFieldAccess)
-                    | self.atoms(InterpolatorAccess))
-                | set(itertools.chain.from_iterable((k.kernel_function.fields_accessed
-                                                     for k in self.atoms(FunctionCall)))))
+        return set(itertools.chain.from_iterable(((a.field for a in self.atoms(pystencils.Field.Access)),
+                                                  (a.field for a in self.atoms(InterpolatorAccess)),
+                                                  (a.field for a in self.atoms(ResolvedFieldAccess))),)) \
+            | set(itertools.chain.from_iterable((k.kernel_function.fields_accessed
+                                                 for k in self.atoms(FunctionCall))))
 
     def __init__(self, body):
         super(DestructuringBindingsForFieldClass, self).__init__()
@@ -133,6 +134,9 @@ class WrapperFunction(pystencils.astnodes.KernelFunction):
 
 
 def generate_kernel_call(kernel_function):
+    if isinstance(kernel_function, CustomFunctionCall):
+        return kernel_function
+
     from pystencils.interpolation_astnodes import InterpolatorAccess
     from pystencils.kernelparameters import FieldPointerSymbol
 
@@ -152,6 +156,7 @@ def generate_kernel_call(kernel_function):
             FunctionCall(kernel_function),
             CudaErrorCheck(),
         ])
+
     elif kernel_function.backend == 'gpucuda':
         return pystencils.astnodes.Block([CudaErrorCheck(),
                                           FunctionCall(kernel_function),
@@ -385,3 +390,11 @@ class CustomFunctionCall(JinjaCppFile):
     @property
     def symbols_defined(self):
         return set(self.ast_dict.fields_accessed)
+
+    @property
+    def function_name(self):
+        return self.ast_dict.function_name
+
+    @property
+    def undefined_symbols(self):
+        return set(self.ast_dict.args)