diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py
index 241145bd251dd3e5461c5725cbf7cee9e7603c53..84d184c15bb30334c702a5cf785eb8e0020a4a9e 100644
--- a/src/pystencils_autodiff/backends/astnodes.py
+++ b/src/pystencils_autodiff/backends/astnodes.py
@@ -8,18 +8,15 @@
 
 """
 
+from collections.abc import Iterable
 from os.path import dirname, join
 
-import jinja2
-
 from pystencils.astnodes import (
-    DestructuringBindingsForFieldClass, FieldPointerSymbol, FieldShapeSymbol,
-    FieldStrideSymbol, Node)
-from pystencils.backends.cbackend import get_headers
-from pystencils.framework_intergration_astnodes import (
-    FrameworkIntegrationPrinter, WrapperFunction, generate_kernel_call)
+    DestructuringBindingsForFieldClass, FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol)
 from pystencils_autodiff._file_io import _read_template_from_file
-from pystencils_autodiff.framework_integration.astnodes import JinjaCppFile
+from pystencils_autodiff.framework_integration.astnodes import (
+    JinjaCppFile, WrapperFunction, generate_kernel_call)
 
 # Torch
 
@@ -50,17 +47,34 @@ class TensorflowTensorDestructuring(DestructuringBindingsForFieldClass):
                ]
 
 
-class TorchCudaModule(JinjaCppFile):
-    TEMPLATE = _read_template_from_file(join(dirname(__file__), 'torch_native_cuda.tmpl.cu'))
+class PybindArrayDestructuring(DestructuringBindingsForFieldClass):
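+    """Unpacks ``pybind11::array_t`` arguments into the raw pointer, shape and
+    stride variables that generated kernels expect (via ``mutable_data()``,
+    ``shape(dim)`` and ``strides(dim)``)."""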
+    CLASS_TO_MEMBER_DICT = {
+        FieldPointerSymbol: "mutable_data()",
+        FieldShapeSymbol: "shape({dim})",
+        FieldStrideSymbol: "strides({dim})"
+    }
+
+    CLASS_NAME_TEMPLATE = "pybind11::array_t<{dtype}>"
+
+    headers = ["<pybind11/numpy.h>"]
+
+
+class TorchModule(JinjaCppFile):
+    TEMPLATE = _read_template_from_file(join(dirname(__file__), 'module.tmpl.cpp'))
     DESTRUCTURING_CLASS = TorchTensorDestructuring
 
-    def __init__(self, forward_kernel_ast, backward_kernel_ast):
+    def __init__(self, kernel_asts):
+        """Create a C++ module with forward and optional backward_kernels
+
+        :param forward_kernel_ast: one or more kernel ASTs (can have any C dialect)
+        :param backward_kernel_ast:
+        """
+        if not isinstance(kernel_asts, Iterable):
+            kernel_asts = [kernel_asts]
 
         ast_dict = {
-            'forward_kernel': forward_kernel_ast,
-            'backward_kernel': backward_kernel_ast,
-            'forward_wrapper': self.generate_wrapper_function(forward_kernel_ast),
-            'backward_wrapper': self.generate_wrapper_function(backward_kernel_ast),
+            'kernels': kernel_asts,
+            'kernel_wrappers': [self.generate_wrapper_function(k) for k in kernel_asts],
         }
 
         super().__init__(ast_dict)
@@ -70,9 +84,9 @@ class TorchCudaModule(JinjaCppFile):
                                function_name='call_' + kernel_ast.function_name)
 
 
-class TensorCudaModule(TorchCudaModule):
+class TensorflowModule(TorchModule):
     DESTRUCTURING_CLASS = TensorflowTensorDestructuring
 
 
-# Generic
-class MachineLearningBackend(FrameworkIntegrationPrinter):
+class PybindModule(TorchModule):
+    DESTRUCTURING_CLASS = PybindArrayDestructuring
diff --git a/src/pystencils_autodiff/backends/module.tmpl.cpp b/src/pystencils_autodiff/backends/module.tmpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2e9c2f0d0e14a48b2a9b672c6490ec1fab6425cc
--- /dev/null
+++ b/src/pystencils_autodiff/backends/module.tmpl.cpp
@@ -0,0 +1,24 @@
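+// Generic extension-module template: headers, globals, kernels, their
+// C++ wrappers, and finally the Python bindings are rendered in order.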
+#include <cuda.h>
+#include <vector>
+
+// Most compilers don't care whether it's __restrict or __restrict__
+#define RESTRICT __restrict__
+
+{% for header in headers -%}
+#include {{ header }}
+{% endfor %}
+
+{% for global in globals -%}
+{{ global }}
+{% endfor %}
+
+{% for kernel in kernels %}
+{{ kernel }}
+{% endfor %}
+
+{% for wrapper in kernel_wrappers %}
+{{ wrapper }}
+{% endfor %}
+
+{{ python_bindings }}
+
diff --git a/src/pystencils_autodiff/backends/torch_native_cpu.tmpl.cpp b/src/pystencils_autodiff/backends/torch_native_cpu.tmpl.cpp
deleted file mode 100644
index 74eec9ee424e3955453dc52b3516d8cdab82d8da..0000000000000000000000000000000000000000
--- a/src/pystencils_autodiff/backends/torch_native_cpu.tmpl.cpp
+++ /dev/null
@@ -1,79 +0,0 @@
-#include <torch/extension.h>
-
-#include <vector>
-
-using namespace pybind11::literals;
-
-using scalar_t = {{ dtype }};
-
-
-#define RESTRICT __restrict
-
-std::vector<at::Tensor> {{ kernel_name }}_forward(
-{%- for tensor in forward_tensors -%}
-    at::Tensor {{ tensor }} {{- ", " if not loop.last -}}
-{%- endfor %})
-{
-    //{% for tensor in forward_output_tensors -%}
-    //auto {{tensor}} = at::zeros_like({{ forward_input_tensors[0] }});
-    //{% endfor %}
-
-    {% for tensor in forward_tensors -%}
-    {%- set last = loop.last -%}
-    scalar_t* _data_{{ tensor }} = {{ tensor }}.data<scalar_t>();
-    {% for i in dimensions -%}
-    int _stride_{{tensor}}_{{i}} = {{tensor}}.strides()[{{ i }}];
-    {% endfor -%}
-    {% for i in dimensions -%}
-    int _size_{{tensor}}_{{i}} = {{tensor}}.size({{ i }});
-    {% endfor -%}
-    {% endfor -%}
-
-    {{forward_kernel}}
-
-    return {
-    {%- for tensor in forward_output_tensors -%}
-    {{ tensor }} {{- "," if not loop.last -}}
-    {% endfor -%}
-    };
-}
-
-std::vector<at::Tensor> {{ kernel_name }}_backward(
-{%- for tensor in backward_tensors -%}
-    at::Tensor {{ tensor }} {{- ", " if not loop.last -}}
-{% endfor %})
-{
-    //{% for tensor in backward_output_tensors -%}
-    //auto {{tensor}} = at::zeros_like({{ backward_input_tensors[0] }});
-    //{% endfor %}
-
-    {% for tensor in backward_tensors -%}
-    {%- set last = loop.last -%}
-    scalar_t* _data_{{ tensor }} = {{ tensor }}.data<scalar_t>();
-    {% for i in dimensions -%}
-    int _stride_{{ tensor }}_{{i}} = {{ tensor }}.strides()[{{ i }}];
-    {% endfor -%}
-    {% for i in dimensions -%}
-    int _size_{{tensor}}_{{i}} = {{tensor}}.size({{ i }});
-    {% endfor -%}
-    {% endfor -%}
-
-    {{backward_kernel}}
-
-    return {
-    {%- for tensor in backward_output_tensors -%}
-    {{ tensor }} {{- "," if not loop.last -}}
-    {% endfor -%}
-    };
-}
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward", &{{ kernel_name }}_forward, "{{ kernel_name }} forward (CPU)",
-{%- for tensor in forward_tensors -%}
-    "{{ tensor }}"_a {{ ", " if not loop.last }}  
-{%- endfor -%} );
-  m.def("backward", &{{ kernel_name }}_backward, "{{ kernel_name }} backward (CPU)",
-{%- for tensor in backward_tensors -%}
-    "{{ tensor }}"_a {{ ", " if not loop.last }}  
-{%- endfor -%} );
-}
diff --git a/src/pystencils_autodiff/backends/torch_native_cuda.tmpl.cpp b/src/pystencils_autodiff/backends/torch_native_cuda.tmpl.cpp
deleted file mode 100644
index 0662090d3320b045d41ac6ab11fb15c11c4d8d74..0000000000000000000000000000000000000000
--- a/src/pystencils_autodiff/backends/torch_native_cuda.tmpl.cpp
+++ /dev/null
@@ -1,79 +0,0 @@
-#include <torch/extension.h>
-#include <vector>
-
-// CUDA forward declarations
-using namespace pybind11::literals;
-
-void {{ kernel_name }}_cuda_forward(
-{%- for tensor in forward_tensors %}
-    at::Tensor {{ tensor.name }} {{- ", " if not loop.last -}}
-{% endfor %});
-
-std::vector<at::Tensor> {{ kernel_name }}_cuda_backward(
-{%- for tensor in backward_tensors -%}
-    at::Tensor {{ tensor.name }} {{- ", " if not loop.last -}}
-{% endfor %});
-
-// C++ interface
-
-// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
-#define CHECK_CUDA(x)                                                          \
-  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x)                                                    \
-  //AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
-#define CHECK_INPUT(x) CHECK_CUDA(x);       
-  //CHECK_CONTIGUOUS(x)
-
-std::vector<at::Tensor> {{ kernel_name }}_forward(
-{%- for tensor in forward_tensors -%}
-    at::Tensor {{ tensor.name }} {{- ", " if not loop.last -}}
-{%- endfor %})
-{
-    {% for tensor in forward_tensors -%}
-    CHECK_INPUT({{ tensor.name }});
-    {% endfor %}
-
-    {{ kernel_name }}_cuda_forward(
-        {%- for tensor in forward_tensors %}
-        {{ tensor.name }} {{- ", " if not loop.last }}
-        {%- endfor %});
-
-    return std::vector<at::Tensor>{
-        {%- for tensor in forward_output_tensors %}
-        {{ tensor.name }} {{- ", " if not loop.last }}
-        {%- endfor %}
-    }
-        ;
-}
-
-std::vector<at::Tensor> {{ kernel_name }}_backward(
-{%- for tensor in backward_tensors -%}
-    at::Tensor {{ tensor.name }} {{- ", " if not loop.last -}}
-{% endfor %})
-{
-    {%- for tensor in forward_input_tensors + backward_input_tensors -%}
-    CHECK_INPUT({{ tensor }});
-    {% endfor %}
-    {{ kernel_name }}_cuda_backward(
-        {%- for tensor in backward_tensors -%}
-        {{ tensor.name }} {{- ", " if not loop.last }}
-        {%- endfor %});
-
-    return std::vector<at::Tensor>{
-        {%- for tensor in backward_output_tensors %}
-        {{ tensor.name }} {{- ", " if not loop.last }}
-        {%- endfor %}
-    }
-        ;
-}
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward", &{{ kernel_name }}_forward, "{{ kernel_name }} forward (CUDA)",
-{%- for tensor in forward_tensors -%}
-    "{{ tensor.name }}"_a {{ ", " if not loop.last }}  
-{%- endfor -%} );
-  m.def("backward", &{{ kernel_name }}_backward, "{{ kernel_name }} backward (CUDA)",
-{%- for tensor in backward_tensors -%}
-    "{{ tensor.name }}"_a {{ ", " if not loop.last }}  
-{%- endfor -%} );
-}
diff --git a/src/pystencils_autodiff/backends/torch_native_cuda.tmpl.cu b/src/pystencils_autodiff/backends/torch_native_cuda.tmpl.cu
deleted file mode 100644
index 4d706ea2d3b9c1c6d7fbb316678ada2771d6d7db..0000000000000000000000000000000000000000
--- a/src/pystencils_autodiff/backends/torch_native_cuda.tmpl.cu
+++ /dev/null
@@ -1,21 +0,0 @@
-#include <cuda.h>
-#include <vector>
-
-
-{% for header in headers -%}
-#include {{ header }}
-{% endfor %}
-
-#define RESTRICT __restrict__
-#define FUNC_PREFIX __global__
-
-
-{{ forward_kernel }}
-
-{{ backward_kernel }}
-
-
-#define RESTRICT
-{{ forward_wrapper }}
-
-{{ backward_wrapper }}
diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py
index 1e7d12827010efc584013b2a9c5d0131dc3141cf..db85b65031baa23153c2b2a9952289ebbe36ebc8 100644
--- a/src/pystencils_autodiff/framework_integration/astnodes.py
+++ b/src/pystencils_autodiff/framework_integration/astnodes.py
@@ -9,6 +9,7 @@ Astnodes useful for the generations of C modules for frameworks (apart from waLB
 
 waLBerla currently uses `pystencils-walberla <https://pypi.org/project/pystencils-walberla/>`_.
 """
+from collections.abc import Iterable
 from typing import Any, List, Set
 
 import jinja2
@@ -19,6 +20,7 @@ import pystencils
 from pystencils.astnodes import Node, NodeOrExpr, ResolvedFieldAccess
 from pystencils.data_types import TypedSymbol
 from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
+from pystencils_autodiff.framework_integration.printer import FrameworkIntegrationPrinter
 
 
 class DestructuringBindingsForFieldClass(Node):
@@ -221,6 +223,7 @@ class JinjaCppFile(Node):
 
     def __init__(self, ast_dict):
         self.ast_dict = ast_dict
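+        # Sub-ASTs are printed with the framework-integration printer.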
+        self.printer = FrameworkIntegrationPrinter()
         Node.__init__(self)
 
     @property
@@ -246,8 +249,9 @@ class JinjaCppFile(Node):
 
     def __str__(self):
         assert self.TEMPLATE, f"Template of {self.__class__} must be set"
-        render_dict = {k: self._print(v) for k, v in self.ast_dict.items()}
-        render_dict.update({"headers": pystencils.backend.cbackend.get_headers(self)})
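+        # Print lists of ASTs element-wise (e.g. several kernels);
+        # plain strings must not be iterated character by character.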
+        render_dict = {k: [self._print(a) for a in v]
+                       if isinstance(v, Iterable) and not isinstance(v, str)
+                       else self._print(v)
+                       for k, v in self.ast_dict.items()}
+        render_dict.update({"headers": pystencils.backends.cbackend.get_headers(self)})
 
         return self.TEMPLATE.render(render_dict)
 
diff --git a/src/pystencils_autodiff/framework_integration/printer.py b/src/pystencils_autodiff/framework_integration/printer.py
index 81a182667500ee5de1c9a33a3ee8477c6383f7d2..3442ffaace5b8ff3ebbd5e567c2a7bd868d865b8 100644
--- a/src/pystencils_autodiff/framework_integration/printer.py
+++ b/src/pystencils_autodiff/framework_integration/printer.py
@@ -20,7 +20,13 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend):
         return super_result.replace('FUNC_PREFIX ', '')
 
     def _print_KernelFunction(self, node):
-        return pystencils.backends.cbackend.generate_c(node, dialect=node.backend)
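+        # Prepend a FUNC_PREFIX definition: __global__ for CUDA kernels,
+        # empty for plain C kernels.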
+        if node.backend == 'cuda':
+            prefix = '#define FUNC_PREFIX __global__\n'
+            kernel_code = pystencils.backends.cbackend.generate_c(node, dialect='cuda')
+        else:
+            prefix = '#define FUNC_PREFIX\n'
+            kernel_code = pystencils.backends.cbackend.generate_c(node, dialect='c')
+        return prefix + kernel_code
 
     def _print_KernelFunctionCall(self, node):
 
diff --git a/tests/test_module_printing.py b/tests/test_module_printing.py
index a2d253e5b97097426719d8c1c96a0f6782c33e73..db0005fe5fc44ace94a553ed82a895e475966931 100644
--- a/tests/test_module_printing.py
+++ b/tests/test_module_printing.py
@@ -11,7 +11,7 @@ import sympy
 
 import pystencils
 from pystencils_autodiff import create_backward_assignments
-from pystencils_autodiff.backends.astnodes import TensorCudaModule, TorchCudaModule
+from pystencils_autodiff.backends.astnodes import PybindModule, TensorflowModule, TorchModule
 
 TARGET_TO_DIALECT = {
     'cpu': 'c',
@@ -34,12 +34,18 @@ def test_module_printing():
         forward_ast.function_name = 'forward'
         backward_ast = pystencils.create_kernel(backward_assignments, target)
         backward_ast.function_name = 'backward'
-        module = TorchCudaModule(forward_ast, backward_ast)
+        module = TorchModule([forward_ast, backward_ast])
         print(module)
 
-        module = TensorCudaModule(forward_ast, backward_ast)
+        module = TensorflowModule({forward_ast: backward_ast})
         print(module)
 
+        if target == 'cpu':
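+            # PybindModule is only exercised for CPU kernels in this test.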
+            module = PybindModule([forward_ast, backward_ast])
+            print(module)
+            module = PybindModule(forward_ast)
+            print(module)
+
 
 def test_module_printing_parameter():
     for target in ('cpu', 'gpu'):
@@ -57,16 +63,22 @@ def test_module_printing_parameter():
         forward_ast.function_name = 'forward'
         backward_ast = pystencils.create_kernel(backward_assignments, target)
         backward_ast.function_name = 'backward'
-        module = TorchCudaModule(forward_ast, backward_ast)
+        module = TorchModule([forward_ast, backward_ast])
         print(module)
 
-        module = TensorCudaModule(forward_ast, backward_ast)
+        module = TensorflowModule({forward_ast: backward_ast})
         print(module)
 
+        if target == 'cpu':
+            module = PybindModule([forward_ast, backward_ast])
+            print(module)
+            module = PybindModule(forward_ast)
+            print(module)
+
 
 def main():
     test_module_printing()
     test_module_printing_parameter()
 
 
 if __name__ == '__main__':