diff --git a/src/pystencils/jit/cpu/cpujit.py b/src/pystencils/jit/cpu/cpujit.py
index b3a9e48aaf678189dbcbb514b5e61e3c115fbaca..eb1b60abec079b55a52d11fed1ed85ac1e85f017 100644
--- a/src/pystencils/jit/cpu/cpujit.py
+++ b/src/pystencils/jit/cpu/cpujit.py
@@ -104,7 +104,7 @@ class CpuJit(JitBase):
     def strict_scalar_types(self, v: bool):
         self._strict_scalar_types = v
 
-    def compile(self, kernel: Kernel) -> CpuJitKernelWrapper:
+    def compile(self, kernel: Kernel) -> KernelWrapper:
         #   Get the Code
         module_name = f"{kernel.function_name}_jit"
         cpp_code = self._ext_module_builder(kernel, module_name)
@@ -149,7 +149,7 @@ class CpuJit(JitBase):
                 module_dir = Path(tmpdir)
                 module = compile_and_load(module_dir)
 
-        return CpuJitKernelWrapper(kernel, module)
+        return self._ext_module_builder.get_wrapper(kernel, module)
 
     def _compile_extension_module(self, src_file: Path, libfile: Path):
         args = (
@@ -179,18 +179,6 @@ class CpuJit(JitBase):
         return mod
 
 
-class CpuJitKernelWrapper(KernelWrapper):
-    def __init__(self, kernel: Kernel, jit_module: ModuleType):
-        super().__init__(kernel)
-        self._module = jit_module
-        self._check_params = getattr(jit_module, "check_params")
-        self._invoke = getattr(jit_module, "invoke")
-
-    def __call__(self, **kwargs) -> None:
-        self._check_params(**kwargs)
-        return self._invoke(**kwargs)
-
-
 class ExtensionModuleBuilderBase(ABC):
     @staticmethod
     @abstractmethod
@@ -198,3 +186,6 @@ class ExtensionModuleBuilderBase(ABC):
 
     @abstractmethod
     def __call__(self, kernel: Kernel, module_name: str) -> str: ...
+
+    @abstractmethod
+    def get_wrapper(self, kernel: Kernel, extension_module: ModuleType) -> KernelWrapper: ...
diff --git a/src/pystencils/jit/cpu/cpujit_pybind11.py b/src/pystencils/jit/cpu/cpujit_pybind11.py
index eff3a061f2f4fef4d03a56183553ebedd33d0c6d..ba9065c8a552c4ad70ce6179ce5339daef637253 100644
--- a/src/pystencils/jit/cpu/cpujit_pybind11.py
+++ b/src/pystencils/jit/cpu/cpujit_pybind11.py
@@ -1,9 +1,12 @@
 from __future__ import annotations
 
+from types import ModuleType
 from typing import Sequence, cast
 from pathlib import Path
 from textwrap import indent
 
+from pystencils.jit.jit import KernelWrapper
+
 from ...types import PsPointerType, PsType
 from ...field import Field
 from ...sympyextensions import DynamicType
@@ -69,6 +72,9 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
             kernel_definition=kernel_def,
         )
         return code_str
+    
+    def get_wrapper(self, kernel: Kernel, extension_module: ModuleType) -> KernelWrapper:
+        return Pybind11KernelWrapper(kernel, extension_module)
 
     def _get_kernel_definition(self, kernel: Kernel) -> str:
         from ...backend.emission import CAstPrinter
@@ -158,3 +164,15 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
                 self._extract_stride(param, cast(FieldStride, stride_props.pop()))
             else:
                 self._add_scalar_param(param)
+
+
+class Pybind11KernelWrapper(KernelWrapper):
+    def __init__(self, kernel: Kernel, jit_module: ModuleType):
+        super().__init__(kernel)
+        self._module = jit_module
+        self._check_params = getattr(jit_module, "check_params")
+        self._invoke = getattr(jit_module, "invoke")
+
+    def __call__(self, **kwargs) -> None:
+        self._check_params(**kwargs)
+        return self._invoke(**kwargs)