diff --git a/src/pystencils/jit/cpu/cpujit.py b/src/pystencils/jit/cpu/cpujit.py
index f9c39529ae2519f2fbab45fc60d4c522b7b09ec2..b3a9e48aaf678189dbcbb514b5e61e3c115fbaca 100644
--- a/src/pystencils/jit/cpu/cpujit.py
+++ b/src/pystencils/jit/cpu/cpujit.py
@@ -183,10 +183,12 @@ class CpuJitKernelWrapper(KernelWrapper):
     def __init__(self, kernel: Kernel, jit_module: ModuleType):
         super().__init__(kernel)
         self._module = jit_module
-        self._wrapper_func = getattr(jit_module, kernel.function_name)
+        self._check_params = getattr(jit_module, "check_params")
+        self._invoke = getattr(jit_module, "invoke")
 
     def __call__(self, **kwargs) -> None:
-        return self._wrapper_func(**kwargs)
+        self._check_params(**kwargs)
+        return self._invoke(**kwargs)
 
 
 class ExtensionModuleBuilderBase(ABC):
diff --git a/src/pystencils/jit/cpu/cpujit_pybind11.py b/src/pystencils/jit/cpu/cpujit_pybind11.py
index aee2e9f996d3147fd06435aaac60bbcb728083a2..b68ed9c29ef7899f8a731a95ebfb6d5cf5580e66 100644
--- a/src/pystencils/jit/cpu/cpujit_pybind11.py
+++ b/src/pystencils/jit/cpu/cpujit_pybind11.py
@@ -37,12 +37,14 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
         self._actual_field_types: dict[Field, PsType]
         self._param_binds: list[str]
         self._public_params: list[str]
+        self._param_check_lines: list[str]
         self._extraction_lines: list[str]
 
     def __call__(self, kernel: Kernel, module_name: str) -> str:
         self._actual_field_types = dict()
         self._param_binds = []
         self._public_params = []
+        self._param_check_lines = []
         self._extraction_lines = []
 
         self._handle_params(kernel.parameters)
@@ -61,6 +63,7 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
             kernel_name=kernel.function_name,
             param_binds=", ".join(self._param_binds),
             public_params=", ".join(self._public_params),
+            param_check_lines=indent("\n".join(self._param_check_lines), prefix="    "),
             extraction_lines=indent("\n".join(self._extraction_lines), prefix="    "),
             kernel_args=", ".join(kernel_args),
             kernel_definition=kernel_def,
@@ -93,6 +96,18 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
         kernel_param = f"py::array_t< {elem_type.c_string()} > & {field.name}"
         self._public_params.append(kernel_param)
 
+        for coord, size in enumerate(field.shape):
+            if isinstance(size, int):
+                self._param_check_lines.append(
+                    f"checkFieldShape(\"{field.name}\", {field.name}, {coord}, {size});"
+                )
+
+        for coord, stride in enumerate(field.strides):
+            if isinstance(stride, int):
+                self._param_check_lines.append(
+                    f"checkFieldStride(\"{field.name}\", {field.name}, {coord}, {stride});"
+                )
+
     def _add_scalar_param(self, sc_param: Parameter):
         param_bind = f'py::arg("{sc_param.name}")'
         if self._strict_scalar_types:
diff --git a/src/pystencils/jit/cpu/pybind11_kernel_module.tmpl.cpp b/src/pystencils/jit/cpu/pybind11_kernel_module.tmpl.cpp
index 3ee5c6973a761a6acbabf7864df947de38e4289c..acc9d05dea9f7df70f97e8ce3ce8c95fa745470b 100644
--- a/src/pystencils/jit/cpu/pybind11_kernel_module.tmpl.cpp
+++ b/src/pystencils/jit/cpu/pybind11_kernel_module.tmpl.cpp
@@ -1,6 +1,10 @@
 #include "pybind11/pybind11.h"
 #include "pybind11/numpy.h"
 
+#include <array>
+#include <string>
+#include <sstream>
+
 ${includes}
 
 namespace py = pybind11;
@@ -13,11 +17,50 @@ ${kernel_definition}
 
 }
 
-void callwrapper_${kernel_name} (${public_params}) {
+template< typename T >
+void checkFieldShape(const std::string fieldName, const py::array_t< T > & arr, size_t coord, size_t desired) {
+    auto panic = [&](){
+        std::stringstream err;
+        err << "Invalid shape of argument " << fieldName;
+        throw py::value_error{ err.str() };
+    };
+    
+    if(arr.ndim() <= coord){
+        panic();
+    }
+
+    if(arr.shape(coord) != desired){
+        panic();
+    }
+}
+
+template< typename T >
+void checkFieldStride(const std::string fieldName, const py::array_t< T > & arr, size_t coord, size_t desired) {
+    auto panic = [&](){
+        std::stringstream err;
+        err << "Invalid strides of argument " << fieldName;
+        throw py::value_error{ err.str() };
+    };
+    
+    if(arr.ndim() <= coord){
+        panic();
+    }
+
+    if(arr.strides(coord) / sizeof(T) != desired){
+        panic();
+    }
+}
+
+void check_params_${kernel_name} (${public_params}) {
+${param_check_lines}
+}
+
+void run_${kernel_name} (${public_params}) {
 ${extraction_lines}
     internal::${kernel_name}(${kernel_args});
 }
 
 PYBIND11_MODULE(${module_name}, m) {
-    m.def("${kernel_name}", &callwrapper_${kernel_name}, py::kw_only(), ${param_binds});
+    m.def("check_params", &check_params_${kernel_name}, py::kw_only(), ${param_binds});
+    m.def("invoke", &run_${kernel_name}, py::kw_only(), ${param_binds});
 }
diff --git a/tests/jit/test_cpujit.py b/tests/jit/test_cpujit.py
index 3d269d837f5cab2a12cfeb8b3b0be141763f9490..c8e75f8625129891e12ea592d7efd8bbac02636b 100644
--- a/tests/jit/test_cpujit.py
+++ b/tests/jit/test_cpujit.py
@@ -2,13 +2,13 @@ import pytest
 
 import sympy as sp
 import numpy as np
-from pystencils import create_kernel, Assignment, fields
+from pystencils import create_kernel, Assignment, fields, Field
 from pystencils.jit import CpuJit
 
 
 @pytest.fixture
 def cpu_jit(tmp_path) -> CpuJit:
-    return CpuJit.create(objcache=tmp_path)
+    return CpuJit.create(objcache=".jit")
 
 
 def test_basic_cpu_kernel(cpu_jit):
@@ -48,3 +48,23 @@ def test_argument_type_error(cpu_jit):
 
     #   Wrong scalar types are OK, though
     kfunc(f=arr_fp64, g=arr_fp64, c=np.float16(1.0))
+
+
+def test_fixed_shape(cpu_jit):
+    a = np.zeros((12, 23), dtype="float64")
+    b = np.zeros((13, 21), dtype="float64")
+    
+    f = Field.create_from_numpy_array("f", a)
+    g = Field.create_from_numpy_array("g", a)
+
+    asm = Assignment(f.center(), 2.0 * g.center())
+    ker = create_kernel(asm)
+    kfunc = cpu_jit.compile(ker)
+
+    kfunc(f=a, g=a)
+
+    with pytest.raises(ValueError):
+        kfunc(f=b, g=a)
+
+    with pytest.raises(ValueError):
+        kfunc(f=a, g=b)