diff --git a/pystencils/nbackend/arrays.py b/pystencils/nbackend/arrays.py
index 23f0ff7788e9464c8ed3860e78993b3ff4527d0f..47394d2ac633897ddfd62985a0e1b17f1c49b05b 100644
--- a/pystencils/nbackend/arrays.py
+++ b/pystencils/nbackend/arrays.py
@@ -118,33 +118,67 @@ class PsArrayAssocVar(PsTypedVariable, ABC):
     to a particular array.
     """
 
+    init_arg_names: tuple[str, ...] = ("name", "dtype", "array")
+    __match_args__ = ("name", "dtype", "array")
+
     def __init__(self, name: str, dtype: PsAbstractType, array: PsLinearizedArray):
         super().__init__(name, dtype)
         self._array = array
 
+    def __getinitargs__(self):
+        return self.name, self.dtype, self.array
+
     @property
     def array(self) -> PsLinearizedArray:
         return self._array
 
 
 class PsArrayBasePointer(PsArrayAssocVar):
+    init_arg_names: tuple[str, ...] = ("name", "array")
+    __match_args__ = ("name", "array")
+
     def __init__(self, name: str, array: PsLinearizedArray):
         dtype = PsPointerType(array.element_type)
         super().__init__(name, dtype, array)
 
         self._array = array
 
+    def __getinitargs__(self):
+        return self.name, self.array
+
 
 class PsArrayShapeVar(PsArrayAssocVar):
-    def __init__(self, array: PsLinearizedArray, dimension: int, dtype: PsIntegerType):
-        name = f"{array}_size{dimension}"
+    init_arg_names: tuple[str, ...] = ("array", "coordinate", "dtype")
+    __match_args__ = ("array", "coordinate", "dtype")
+
+    def __init__(self, array: PsLinearizedArray, coordinate: int, dtype: PsIntegerType):
+        name = f"{array}_size{coordinate}"
         super().__init__(name, dtype, array)
+        self._coordinate = coordinate
+
+    @property
+    def coordinate(self) -> int:
+        return self._coordinate
+
+    def __getinitargs__(self):
+        return self.array, self.coordinate, self.dtype
 
 
 class PsArrayStrideVar(PsArrayAssocVar):
-    def __init__(self, array: PsLinearizedArray, dimension: int, dtype: PsIntegerType):
-        name = f"{array}_size{dimension}"
+    init_arg_names: tuple[str, ...] = ("array", "coordinate", "dtype")
+    __match_args__ = ("array", "coordinate", "dtype")
+
+    def __init__(self, array: PsLinearizedArray, coordinate: int, dtype: PsIntegerType):
+        name = f"{array}_size{coordinate}"
         super().__init__(name, dtype, array)
+        self._coordinate = coordinate
+
+    @property
+    def coordinate(self) -> int:
+        return self._coordinate
+
+    def __getinitargs__(self):
+        return self.array, self.coordinate, self.dtype
 
 
 class PsArrayAccess(pb.Subscript):
diff --git a/pystencils/nbackend/ast/constraints.py b/pystencils/nbackend/ast/constraints.py
index 68cbe347a3acc7c5716da86277f952323e26f8ae..d11fe1195a0df3c1f12cfb159ce1763a3dbf7c8f 100644
--- a/pystencils/nbackend/ast/constraints.py
+++ b/pystencils/nbackend/ast/constraints.py
@@ -2,6 +2,9 @@ from dataclasses import dataclass
 
 import pymbolic.primitives as pb
 from pymbolic.mapper.c_code import CCodeMapper
+from pymbolic.mapper.dependency import DependencyMapper
+
+from ..typed_expressions import PsTypedVariable
 
 
 @dataclass
@@ -9,5 +12,11 @@ class PsParamConstraint:
     condition: pb.Comparison
     message: str = ""
 
-    def print(self):
+    def print_c_condition(self):
         return CCodeMapper()(self.condition)
+    
+    def get_variables(self) -> set[PsTypedVariable]:
+        return DependencyMapper(False, False, False, False)(self.condition)
+    
+    def __str__(self) -> str:
+        return f"{self.message} [{self.condition}]"
diff --git a/pystencils/nbackend/ast/kernelfunction.py b/pystencils/nbackend/ast/kernelfunction.py
index fccd4f12a18292a53566ad3a7d93eef6fef637f4..e2f0ac8e678e3d794b01536afc76dac03245f9ab 100644
--- a/pystencils/nbackend/ast/kernelfunction.py
+++ b/pystencils/nbackend/ast/kernelfunction.py
@@ -126,3 +126,6 @@ class PsKernelFunction(PsAstNode):
 
         arrays = set(p.array for p in params_list if isinstance(p, PsArrayBasePointer))
         return PsKernelParametersSpec(tuple(params_list), tuple(arrays), tuple(self._constraints))
+    
+    def get_required_headers(self) -> set[str]:
+        raise NotImplementedError()
diff --git a/pystencils/nbackend/c_printer.py b/pystencils/nbackend/emission.py
similarity index 94%
rename from pystencils/nbackend/c_printer.py
rename to pystencils/nbackend/emission.py
index 58a61d579991f6fa2c51ee2300b7941bf584ed0e..b315b172e4173aeaa2d2308928b94c52535e7556 100644
--- a/pystencils/nbackend/c_printer.py
+++ b/pystencils/nbackend/emission.py
@@ -6,6 +6,12 @@ from .ast import ast_visitor, PsAstNode, PsBlock, PsExpression, PsDeclaration, P
 from .ast.kernelfunction import PsKernelFunction
 
 
+def emit_code(kernel: PsKernelFunction):
+    #   TODO: Specialize for different targets
+    printer = CPrinter()
+    return printer.print(kernel)    
+
+
 class CPrinter:
     def __init__(self, indent_width=3):
         self._indent_width = indent_width
diff --git a/pystencils/nbackend/jit/cpu_extension_module.py b/pystencils/nbackend/jit/cpu_extension_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4d021c5ddbb3e952a5fa2c6877f934262e590aa
--- /dev/null
+++ b/pystencils/nbackend/jit/cpu_extension_module.py
@@ -0,0 +1,286 @@
+from __future__ import annotations
+
+from typing import Any
+
+from os import path
+import hashlib
+
+from itertools import chain
+
+from ..exceptions import PsInternalCompilerError
+from ..ast import PsKernelFunction
+from ..ast.constraints import PsParamConstraint
+from ..typed_expressions import PsTypedVariable
+from ..arrays import (
+    PsLinearizedArray,
+    PsArrayAssocVar,
+    PsArrayBasePointer,
+    PsArrayShapeVar,
+    PsArrayStrideVar,
+)
+from ..types import PsAbstractType
+from ..types.quick import Fp, SInt, UInt
+from ..emission import emit_code
+
+
+class PsKernelExtensioNModule:
+    """Replacement for `pystencils.cpu.cpujit.ExtensionModuleCode`.
+    Conforms to its interface for plug-in to `compile_and_load`.
+    """
+
+    def __init__(
+        self, module_name: str = "generated", custom_backend: Any = None
+    ) -> None:
+        self._module_name = module_name
+
+        if custom_backend is not None:
+            raise PsInternalCompilerError(
+                "The `custom_backend` parameter exists only for interface compatibility and cannot be set."
+            )
+
+        self._kernels: dict[str, PsKernelFunction] = dict()
+        self._code_string: str | None = None
+        self._code_hash: str | None = None
+
+    @property
+    def module_name(self) -> str:
+        return self._module_name
+
+    def add_function(self, kernel_function: PsKernelFunction, name: str | None = None):
+        if name is None:
+            name = kernel_function.name
+
+        self._kernels[name] = kernel_function
+
+    def create_code_string(self, restrict_qualifier: str, function_prefix: str):
+        code = ""
+
+        #   Collect headers
+        headers = {"<math.h>", "<stdint.h>"}
+        for kernel in self._kernels.values():
+            headers |= kernel.get_required_headers()
+
+        header_list = sorted(headers)
+        header_list.insert(0, '"Python.h"')
+
+        from pystencils.include import get_pystencils_include_path
+
+        ps_incl_path = get_pystencils_include_path()
+
+        ps_headers = []
+        for header in header_list:
+            header = header[1:-1]
+            header_path = path.join(ps_incl_path, header)
+            if path.exists(header_path):
+                ps_headers.append(header_path)
+
+        header_hash = b"".join(
+            [hashlib.sha256(open(h, "rb").read()).digest() for h in ps_headers]
+        )
+
+        #   Prelude: Includes and definitions
+
+        includes = "\n".join(f"#include {header}" for header in header_list)
+
+        code += includes
+        code += "\n"
+        code += f"#define RESTRICT {restrict_qualifier}\n"
+        code += f"#define FUNC_PREFIX {function_prefix}\n"
+        code += "\n"
+
+        #   Kernels and call wrappers
+
+        for name, kernel in self._kernels.items():
+            old_name = kernel.name
+            kernel.name = f"kernel_{name}"
+
+            code += emit_code(kernel)
+            code += "\n"
+            code += emit_call_wrapper(name, kernel)
+            code += "\n"
+
+            kernel.name = old_name
+
+        self._code_hash = (
+            "mod_" + hashlib.sha256(code.encode() + header_hash).hexdigest()
+        )
+
+        from ...cpu.cpujit import create_module_boilerplate_code
+
+        code += create_module_boilerplate_code(self._code_hash, self._kernels.keys())
+
+    def get_hash_of_code(self):
+        assert self._code_string is not None, "The code must be generated first"
+        return self._code_hash
+
+    def write_to_file(self, file):
+        assert self._code_string is not None, "The code must be generated first"
+        print(self._code_string, file=file)
+
+
+def emit_call_wrapper(function_name: str, kernel: PsKernelFunction) -> str:
+    builder = CallWrapperBuilder()
+    params_spec = kernel.get_parameters()
+
+    for p in params_spec.params:
+        builder.extract_parameter(p)
+
+    for c in params_spec.constraints:
+        builder.check_constraint(c)
+
+    builder.call(kernel, params_spec.params)
+
+    return builder.resolve(function_name)
+
+
+class CallWrapperBuilder:
+    TMPL_EXTRACT_SCALAR = """
+PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
+if( obj_{name} == NULL) {{  PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
+{target_type} {name} = ({target_type}) {extract_function}( obj_{name} );
+if( PyErr_Occurred() ) {{ return NULL; }}
+"""
+
+    TMPL_EXTRACT_ARRAY = """
+PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
+if( obj_{name} == NULL) {{  PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
+Py_buffer buffer_{name};
+int buffer_{name}_res = PyObject_GetBuffer(obj_{name}, &buffer_{name}, PyBUF_STRIDES | PyBUF_WRITABLE | PyBUF_FORMAT);
+if (buffer_{name}_res == -1) {{ return NULL; }}
+"""
+
+    KWCHECK = """
+if( !kwargs || !PyDict_Check(kwargs) ) {{ 
+        PyErr_SetString(PyExc_TypeError, "No keyword arguments passed"); 
+        return NULL; 
+    }}
+"""
+
+    def __init__(self) -> None:
+        self._array_buffers: dict[PsLinearizedArray, str] = dict()
+        self._array_extractions: dict[PsLinearizedArray, str] = dict()
+        self._array_frees: dict[PsLinearizedArray, str] = dict()
+
+        self._array_assoc_var_extractions: dict[PsArrayAssocVar, str] = dict()
+        self._scalar_extractions: dict[PsTypedVariable, str] = dict()
+
+        self._constraint_checks: list[str] = []
+
+        self._call: str | None = None
+
+    def _scalar_extractor(self, dtype: PsAbstractType) -> str:
+        match dtype:
+            case Fp(32) | Fp(64):
+                return "PyFloat_AsDouble"
+            case SInt():
+                return "PyLong_AsLong"
+            case UInt():
+                return "PyLong_AsUnsignedLong"
+
+            case _:
+                raise PsInternalCompilerError(
+                    f"Don't know how to cast Python objects to {dtype}"
+                )
+
+    def extract_array(self, arr: PsLinearizedArray) -> str:
+        """Adds an array, and returns the name of the underlying Py_Buffer."""
+        if arr not in self._array_extractions:
+            extraction_code = self.TMPL_EXTRACT_ARRAY.format(name=arr.name)
+            self._array_buffers[arr] = f"buffer_{arr.name}"
+            self._array_extractions[arr] = extraction_code
+
+            release_code = f"PyBuffer_Release(&buffer_{arr.name});"
+            self._array_frees[arr] = release_code
+
+        return self._array_buffers[arr]
+
+    def extract_scalar(self, variable: PsTypedVariable) -> str:
+        if variable not in self._scalar_extractions:
+            self.TMPL_EXTRACT_SCALAR.format()
+
+            extract_func = self._scalar_extractor(variable.dtype)
+            code = self.TMPL_EXTRACT_SCALAR.format(
+                name=variable.name,
+                target_type=str(variable.dtype),
+                extract_function=extract_func,
+            )
+            self._scalar_extractions[variable] = code
+
+        return variable.name
+
+    def extract_array_assoc_var(self, variable: PsArrayAssocVar) -> str:
+        if variable not in self._array_assoc_var_extractions:
+            arr = variable.array
+            buffer = self.extract_array(arr)
+            match variable:
+                case PsArrayBasePointer():
+                    code = f"{variable.dtype} {variable.name} = ({variable.dtype}) {buffer}.buf;"
+                case PsArrayShapeVar():
+                    coord = variable.coordinate
+                    code = (
+                        f"{variable.dtype} {variable.name} = "
+                        f"{buffer}.shape[{coord}] / {arr.element_type.itemsize};"
+                    )
+                case PsArrayStrideVar():
+                    coord = variable.coordinate
+                    code = (
+                        f"{variable.dtype} {variable.name} = "
+                        f"{buffer}.strides[{coord}] / {arr.element_type.itemsize};"
+                    )
+                case _:
+                    assert False, "unreachable code"
+
+            self._array_assoc_var_extractions[variable] = code
+
+        return variable.name
+
+    def extract_parameter(self, variable: PsTypedVariable):
+        match variable:
+            case PsArrayAssocVar():
+                self.extract_array_assoc_var(variable)
+            case PsTypedVariable():
+                self.extract_scalar(variable)
+            case _:
+                assert False, "Invalid variable encountered."
+
+    def check_constraint(self, constraint: PsParamConstraint):
+        variables = constraint.get_variables()
+
+        for var in variables:
+            self.extract_parameter(var)
+
+        cond = constraint.print_c_condition()
+
+        code = f"""
+if(!({cond}))
+{{
+    PyErr_SetString(PyExc_ValueError, "Violated constraint: {constraint}"); 
+    return NULL;
+}}
+"""
+
+        self._constraint_checks.append(code)
+
+    def call(self, kernel: PsKernelFunction, params: tuple[PsTypedVariable, ...]):
+        param_list = ", ".join(p.name for p in params)
+        self._call = f"{kernel.name} ({param_list});"
+
+    def resolve(self, function_name) -> str:
+        assert self._call is not None
+
+        body = "\n".join(
+            chain(
+                [self.KWCHECK],
+                self._scalar_extractions.values(),
+                self._array_extractions.values(),
+                self._array_assoc_var_extractions.values(),
+                [self._call],
+                self._array_frees.values(),
+                ["Py_RETURN_NONE;"],
+            )
+        )
+
+        code = f"static PyObject * {function_name}(PyObject * self, PyObject * args, PyObject * kwargs)\n"
+        code += "{\n" + body + "\n}\n"
+
+        return code
diff --git a/pystencils/nbackend/typed_expressions.py b/pystencils/nbackend/typed_expressions.py
index e878fbfcee0b03f4444382672a624fad10f5679e..657eaf4f0845f16a784418769597b07382d6c0c8 100644
--- a/pystencils/nbackend/typed_expressions.py
+++ b/pystencils/nbackend/typed_expressions.py
@@ -13,10 +13,18 @@ from .types import (
 
 
 class PsTypedVariable(pb.Variable):
+
+    init_arg_names: tuple[str, ...] = ("name", "dtype")
+
+    __match_args__ = ("name", "dtype")
+
     def __init__(self, name: str, dtype: PsAbstractType):
         super(PsTypedVariable, self).__init__(name)
         self._dtype = dtype
 
+    def __getinitargs__(self):
+        return self.name, self._dtype
+
     @property
     def dtype(self) -> PsAbstractType:
         return self._dtype
diff --git a/pystencils/nbackend/types/basic_types.py b/pystencils/nbackend/types/basic_types.py
index 412189968ce7dcdc90db507a2e2fad5b91c3514a..45deb88a001b754338b21b46d44cfce131aeed59 100644
--- a/pystencils/nbackend/types/basic_types.py
+++ b/pystencils/nbackend/types/basic_types.py
@@ -185,6 +185,11 @@ class PsScalarType(PsNumericType, ABC):
 
     def is_float(self) -> bool:
         return isinstance(self, PsIeeeFloatType)
+    
+    @property
+    @abstractmethod
+    def itemsize(self) -> int:
+        """Size of this type's elements in bytes."""
 
 
 class PsIntegerType(PsScalarType, ABC):
@@ -216,6 +221,10 @@ class PsIntegerType(PsScalarType, ABC):
     @property
     def signed(self) -> bool:
         return self._signed
+    
+    @property
+    def itemsize(self) -> int:
+        return self.width // 8
 
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, PsIntegerType):
@@ -320,6 +329,10 @@ class PsIeeeFloatType(PsScalarType):
     @property
     def width(self) -> int:
         return self._width
+    
+    @property
+    def itemsize(self) -> int:
+        return self.width // 8
 
     def create_constant(self, value: Any) -> Any:
         np_type = self.NUMPY_TYPES[self._width]
diff --git a/pystencils_tests/nbackend/test_basic_printing.py b/pystencils_tests/nbackend/test_basic_printing.py
index 8679211146633fdd48bd8aadf0128cfe1fb9d012..4031c12feaf542d562dcd6e18e93de3e0b6dd6c3 100644
--- a/pystencils_tests/nbackend/test_basic_printing.py
+++ b/pystencils_tests/nbackend/test_basic_printing.py
@@ -5,7 +5,7 @@ from pystencils import Target
 from pystencils.nbackend.ast import *
 from pystencils.nbackend.typed_expressions import *
 from pystencils.nbackend.types.quick import *
-from pystencils.nbackend.c_printer import CPrinter
+from pystencils.nbackend.emission import CPrinter
 
 def test_basic_kernel():