diff --git a/src/pystencils_autodiff/_backport.py b/src/pystencils_autodiff/_backport.py
index e9ba2790bac9a743e3175a464526a579eb878b00..c49103ab7c0f21ca5b08cf150671dc4a9ebe52ef 100644
--- a/src/pystencils_autodiff/_backport.py
+++ b/src/pystencils_autodiff/_backport.py
@@ -10,8 +10,6 @@
 
 import itertools
 
-import sympy as sp
-
 import pystencils
 from pystencils.astnodes import KernelFunction, ResolvedFieldAccess, SympyAssignment
 from pystencils.interpolation_astnodes import InterpolatorAccess
@@ -44,7 +42,6 @@ def compatibility_hacks():
     pystencils.fields = fields
     KernelFunction.fields_read = property(fields_read)
     KernelFunction.fields_written = property(fields_written)
-    sp.Expr.undefined_symbols = sp.Expr.free_symbols
 
 
 compatibility_hacks()
diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py
index f724079dc11c6975c20aab97cfc57ce2490a3ed6..815738611dd89a8d8b30ebdbb2ca177f4c7f144a 100644
--- a/src/pystencils_autodiff/framework_integration/astnodes.py
+++ b/src/pystencils_autodiff/framework_integration/astnodes.py
@@ -174,7 +174,7 @@ class JinjaCppFile(Node):
     @property
     def args(self):
         """Returns all arguments/children of this node."""
-        ast_nodes = [a for a in self.ast_dict.values() if isinstance(a, (Node, str))]
+        ast_nodes = [a for a in self.ast_dict.values() if isinstance(a, (Node, sp.Expr, str))]
         iterables_of_ast_nodes = [a for a in self.ast_dict.values() if isinstance(a, Iterable)
                                   and not isinstance(a, str)]
         return ast_nodes + list(itertools.chain.from_iterable(iterables_of_ast_nodes))
@@ -184,14 +184,14 @@ class JinjaCppFile(Node):
         """Set of symbols which are defined by this node."""
         return set(itertools.chain.from_iterable(a.symbols_defined
                                                  for a in self.args
-                                                 if hasattr(a, 'symbols_defined')))
+                                                 if isinstance(a, Node)))
 
     @property
     def undefined_symbols(self):
         """Symbols which are used but are not defined inside this node."""
-        return set(itertools.chain.from_iterable(a.undefined_symbols
+        return set(itertools.chain.from_iterable(a.undefined_symbols if isinstance(a, Node) else a.free_symbols
                                                  for a in self.args
-                                                 if hasattr(a, 'undefined_symbols'))) - self.symbols_defined
+                                                 if isinstance(a, (Node, sp.Expr)))) - self.symbols_defined
 
     def _print(self, node):
         if isinstance(node, Node):
@@ -296,3 +296,25 @@ class CudaErrorCheck(CustomCodeNode):
     err_check_function = CudaErrorCheckDefinition()
     required_global_declarations = [err_check_function]
     headers = ['<cuda.h>']
+
+
+class DynamicFunction(sp.Function):
+    """
+    Function that is passed as an argument to a kernel.
+    Can be printed for example as `std::function` or as a functor template.
+    """
+
+    def __new__(cls, typed_function_symbol, return_dtype, *args):
+        return sp.Function.__new__(cls, typed_function_symbol, return_dtype, *args)
+
+    @property
+    def function_dtype(self):
+        return self.args[0].dtype
+
+    @property
+    def dtype(self):
+        return self.args[1].dtype
+
+    @property
+    def name(self):
+        return self.args[0].name
diff --git a/src/pystencils_autodiff/framework_integration/printer.py b/src/pystencils_autodiff/framework_integration/printer.py
index cf94e9584c891d59e504d92b6ea387f830ab027a..35ac8895ebb558603b055c9744b5904a68db91be 100644
--- a/src/pystencils_autodiff/framework_integration/printer.py
+++ b/src/pystencils_autodiff/framework_integration/printer.py
@@ -1,7 +1,9 @@
 import sympy as sp
 
 import pystencils.backends.cbackend
+from pystencils.data_types import TypedSymbol
 from pystencils.kernelparameters import FieldPointerSymbol
+from pystencils_autodiff.framework_integration.types import TemplateType
 
 
 class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend):
@@ -15,12 +17,12 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend):
 
     def __init__(self):
         super().__init__(dialect='c')
+        self.sympy_printer.__class__._print_DynamicFunction = self._print_DynamicFunction
 
     def _print(self, node):
         from pystencils_autodiff.framework_integration.astnodes import JinjaCppFile
         if isinstance(node, JinjaCppFile):
             node.printer = self
-
         if isinstance(node, sp.Expr):
             return self.sympy_printer._print(node)
         else:
@@ -40,6 +42,12 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend):
         else:
             prefix = '#define FUNC_PREFIX static\n'
             kernel_code = pystencils.backends.cbackend.generate_c(node, dialect='c', with_globals=False)
+        template_types = sorted([x.dtype for x in node.atoms(TypedSymbol)
+                                 if isinstance(x.dtype, TemplateType)], key=str)
+        template_types = list(map(lambda x: 'class ' + str(x), template_types))
+        if template_types:
+            prefix = f'{prefix}template <{",".join(template_types)}>\n'
+
         return prefix + kernel_code
 
     def _print_FunctionCall(self, node):
@@ -83,7 +91,8 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend):
                                                                 if hasattr(u, 'field_name')
                                                                 else u.field_names[0]]),
                                        field_name=(u.field_name if hasattr(u, "field_name") else ""),
-                                       dim=("" if type(u) == FieldPointerSymbol else u.coordinate)
+                                       dim=("" if type(u) == FieldPointerSymbol else u.coordinate),
+                                       dim_letter=("" if type(u) == FieldPointerSymbol else 'xyz'[u.coordinate])
                                    )
                                    )
                                   for u in undefined_field_symbols
@@ -105,14 +114,28 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend):
     def _print_SwapBuffer(self, node):
         return f"""std::swap({node.first_array}, {node.second_array});"""
 
+    def _print_DynamicFunction(self, expr):
+        name = expr.name
+        arg_str = ', '.join(self._print(a) for a in expr.args[2:])
+        return f'{name}({arg_str})'
+
 
 class DebugFrameworkPrinter(FrameworkIntegrationPrinter):
+    """
+    Printer with information on nodes inlined in code as comments.
+
+    Should not be used in production, will modify your SymPy printer, destroy your whole life!
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.sympy_printer._old_print = self.sympy_printer._print
+        self.sympy_printer.__class__._print = self._print
 
     def _print(self, node):
         if isinstance(node, sp.Expr):
-            return self.sympy_printer._print(node)
+            return self.sympy_printer._old_print(node) + f'/* {node.__class__.__name__}: free_symbols: {node.free_symbols} */'  # noqa
         elif isinstance(node, pystencils.astnodes.Node):
             return super()._print(node) + f'/* {node.__class__.__name__} symbols_undefined: {node.undefined_symbols}, symbols_defined: {node.symbols_defined}, args {[a if isinstance(a,str) else a.__class__.__name__ for a in node.args]} */'  # noqa
-
         else:
             return super()._print(node)
diff --git a/src/pystencils_autodiff/framework_integration/types.py b/src/pystencils_autodiff/framework_integration/types.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3929348fb06bcc1f711233def968161fe428aba
--- /dev/null
+++ b/src/pystencils_autodiff/framework_integration/types.py
@@ -0,0 +1,18 @@
+#
+# Copyright © 2020 Stephan Seitz <stephan.seitz@fau.de>
+#
+# Distributed under terms of the GPLv3 license.
+
+"""
+
+"""
+from pystencils.data_types import Type
+
+
+class TemplateType(Type):
+
+    def __init__(self, name):
+        self._name = name
+
+    def _sympystr(self, *args, **kwargs):
+        return str(self._name)
diff --git a/tests/test_dynamic_function.py b/tests/test_dynamic_function.py
new file mode 100644
index 0000000000000000000000000000000000000000..0423116585c4801ccf13a483a24c72580961d5be
--- /dev/null
+++ b/tests/test_dynamic_function.py
@@ -0,0 +1,35 @@
+import sympy as sp
+
+import pystencils
+from pystencils.data_types import TypedSymbol, create_type
+from pystencils_autodiff.framework_integration.astnodes import DynamicFunction
+from pystencils_autodiff.framework_integration.printer import (
+    DebugFrameworkPrinter, FrameworkIntegrationPrinter)
+from pystencils_autodiff.framework_integration.types import TemplateType
+
+
+def test_dynamic_function():
+    x, y = pystencils.fields('x, y:  float32[3d]')
+
+    a = sp.symbols('a')
+
+    my_fun_call = DynamicFunction(TypedSymbol('my_fun',
+                                              'std::function<double(double)>'), create_type('double'), a)
+
+    assignments = pystencils.AssignmentCollection({
+        y.center: x.center + my_fun_call
+    })
+
+    ast = pystencils.create_kernel(assignments)
+    pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter())
+
+    template_fun_call = DynamicFunction(TypedSymbol('my_fun',
+                                                    TemplateType('Functor_T')), create_type('double'), a, x.center)
+
+    assignments = pystencils.AssignmentCollection({
+        y.center: x.center + template_fun_call
+    })
+
+    ast = pystencils.create_kernel(assignments)
+    pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter())
+    pystencils.show_code(ast, custom_backend=DebugFrameworkPrinter())