diff --git a/src/pystencils/codegen/config.py b/src/pystencils/codegen/config.py
index 3a7647907b82a4ee1ddbb72c2c700e42c7547f69..209df1b0e81e3893d69e532dbcd978203b2907d1 100644
--- a/src/pystencils/codegen/config.py
+++ b/src/pystencils/codegen/config.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, TypeGuard
 
 from warnings import warn
 from collections.abc import Collection
@@ -28,7 +28,10 @@ class PsOptionsError(Exception):
     """Indicates an option clash in the `CreateKernelConfig`."""
 
 
-class _AUTO_TYPE: ...  # noqa: E701
+class _AUTO_TYPE:
+    @staticmethod
+    def is_auto(val: Any) -> TypeGuard[_AUTO_TYPE]:
+        return val == AUTO
 
 
 AUTO = _AUTO_TYPE()
diff --git a/src/pystencils/jit/cpu/__init__.py b/src/pystencils/jit/cpu/__init__.py
index 335ec52d4fae2b744eb0cb17cf4c9c42bcdda94e..5e4dfcbf70e89c543b6da6053709c75bfe421e25 100644
--- a/src/pystencils/jit/cpu/__init__.py
+++ b/src/pystencils/jit/cpu/__init__.py
@@ -1,5 +1,5 @@
 from .compiler_info import GccInfo, ClangInfo
-from .cpu_pybind11 import CpuJit
+from .cpujit import CpuJit
 
 __all__ = [
     "GccInfo",
diff --git a/src/pystencils/jit/cpu/cpu_pybind11.py b/src/pystencils/jit/cpu/cpu_pybind11.py
deleted file mode 100644
index 0229c5c4d671a412fe3ba3c26711c703459a9d5d..0000000000000000000000000000000000000000
--- a/src/pystencils/jit/cpu/cpu_pybind11.py
+++ /dev/null
@@ -1,276 +0,0 @@
-from __future__ import annotations
-
-from typing import Sequence, cast
-from types import ModuleType
-from pathlib import Path
-from textwrap import indent
-import subprocess
-from copy import copy
-
-from ...types import PsPointerType, PsType
-from ...field import Field
-from ...sympyextensions import DynamicType
-from ..jit import KernelWrapper
-from ...codegen import Kernel, Parameter
-from ...codegen.properties import FieldBasePtr, FieldShape, FieldStride
-
-from ..jit import JitError, JitBase
-from .compiler_info import CompilerInfo, GccInfo
-
-
-_module_template = Path(__file__).parent / "kernel_module.tmpl.cpp"
-
-
-class CpuJit(JitBase):
-    """Just-in-time compiler for CPU kernels."""
-
-    def __init__(
-        self,
-        compiler_info: CompilerInfo | None = None,
-        objcache: str | Path | None = None,
-        strict_scalar_types: bool = False,
-    ):
-        self._compiler_info = copy(compiler_info) if compiler_info is not None else GccInfo()
-        self._strict_scalar_types = strict_scalar_types
-
-        if objcache is None:
-            from appdirs import AppDirs
-
-            dirs = AppDirs(appname="pystencils")
-            self._objcache = Path(dirs.user_cache_dir) / "cpujit"
-        else:
-            self._objcache = Path(objcache)
-
-        #   Include Directories
-        import pybind11 as pb11
-
-        pybind11_include = pb11.get_include()
-
-        import sysconfig
-
-        python_include = sysconfig.get_path("include")
-
-        from ...include import get_pystencils_include_path
-
-        pystencils_include = get_pystencils_include_path()
-
-        self._cxx_fixed_flags = [
-            "-shared",
-            "-fPIC",
-            f"-I{python_include}",
-            f"-I{pybind11_include}",
-            f"-I{pystencils_include}"
-        ]
-
-    @property
-    def objcache(self) -> Path:
-        return self._objcache
-
-    @property
-    def compiler_info(self) -> CompilerInfo:
-        return self._compiler_info
-    
-    @compiler_info.setter
-    def compiler_info(self, info: CompilerInfo):
-        self._compiler_info = info
-
-    @property
-    def strict_scalar_types(self) -> bool:
-        """Enable or disable implicit type casts for scalar parameters.
-
-        If `True`, values for scalar kernel parameters must always be provided with the correct NumPy type.
-        """
-        return self._strict_scalar_types
-
-    @strict_scalar_types.setter
-    def strict_scalar_types(self, v: bool):
-        self._strict_scalar_types = v
-
-    @property
-    def restrict_qualifier(self) -> str:
-        return self._restrict_qualifier
-    
-    @restrict_qualifier.setter
-    def restrict_qualifier(self, qual: str):
-        self._restrict_qualifier = qual
-
-    def compile(self, kernel: Kernel) -> CpuJitKernelWrapper:
-        #   Get the Code
-        module_name = f"{kernel.function_name}_jit"
-        modbuilder = KernelModuleBuilder(self, module_name)
-        cpp_code = modbuilder(kernel)
-
-        #   Get compiler information
-        import sysconfig
-
-        so_abi = sysconfig.get_config_var("SOABI")
-        lib_suffix = f"{so_abi}.so"
-
-        #   Compute Code Hash
-        code_utf8 = cpp_code.encode("utf-8")
-        import hashlib
-
-        code_hash = hashlib.sha256(code_utf8)
-        module_stem = f"module_{code_hash.hexdigest()}"
-
-        module_dir = self._objcache
-
-        #   Lock module
-        import fasteners
-
-        lockfile = module_dir / f"{module_stem}.lock"
-        with fasteners.InterProcessLock(lockfile):
-            cpp_file = module_dir / f"{module_stem}.cpp"
-            if not cpp_file.exists():
-                cpp_file.write_bytes(code_utf8)
-
-            lib_file = module_dir / f"{module_stem}.{lib_suffix}"
-            if not lib_file.exists():
-                self._compile_extension_module(cpp_file, lib_file)
-
-            module = self._load_extension_module(module_name, lib_file)
-
-        return CpuJitKernelWrapper(kernel, module)
-
-    def _compile_extension_module(self, src_file: Path, libfile: Path):
-        args = (
-            [self._compiler_info.cxx()]
-            + self._cxx_fixed_flags
-            + self._compiler_info.cxxflags()
-            + ["-o", str(libfile), str(src_file)]
-        )
-
-        result = subprocess.run(args, capture_output=True)
-        if result.returncode != 0:
-            raise JitError(
-                "Compilation failed: C++ compiler terminated with an error.\n"
-                + result.stderr.decode()
-            )
-
-    def _load_extension_module(self, module_name: str, module_loc: Path) -> ModuleType:
-        from importlib import util as iutil
-
-        spec = iutil.spec_from_file_location(name=module_name, location=module_loc)
-        if spec is None:
-            raise JitError("Unable to load kernel extension module -- this is probably a bug.")
-        mod = iutil.module_from_spec(spec)
-        spec.loader.exec_module(mod)  # type: ignore
-        return mod
-
-
-class CpuJitKernelWrapper(KernelWrapper):
-    def __init__(self, kernel: Kernel, jit_module: ModuleType):
-        super().__init__(kernel)
-        self._module = jit_module
-        self._wrapper_func = getattr(jit_module, kernel.function_name)
-
-    def __call__(self, **kwargs) -> None:
-        return self._wrapper_func(**kwargs)
-
-
-class KernelModuleBuilder:
-    def __init__(self, jit: CpuJit, module_name: str):
-        self._jit = jit
-        self._module_name = module_name
-
-        self._actual_field_types: dict[Field, PsType] = dict()
-        self._param_binds: list[str] = []
-        self._public_params: list[str] = []
-        self._extraction_lines: list[str] = []
-
-    def __call__(self, kernel: Kernel) -> str:
-        self._handle_params(kernel.parameters)
-        
-        kernel_def = self._get_kernel_definition(kernel)
-        kernel_args = [param.name for param in kernel.parameters]
-        includes = [f"#include {h}" for h in kernel.required_headers]
-
-        from string import Template
-
-        templ = Template(_module_template.read_text())
-        code_str = templ.substitute(
-            includes="\n".join(includes),
-            restrict_qualifier=self._jit.compiler_info.restrict_qualifier(),
-            module_name=self._module_name,
-            kernel_name=kernel.function_name,
-            param_binds=", ".join(self._param_binds),
-            public_params=", ".join(self._public_params),
-            extraction_lines=indent("\n".join(self._extraction_lines), prefix="    "),
-            kernel_args=", ".join(kernel_args),
-            kernel_definition=kernel_def,
-        )
-        return code_str
-
-    def _get_kernel_definition(self, kernel: Kernel) -> str:
-        from ...backend.emission import CAstPrinter
-        printer = CAstPrinter(func_prefix="inline")
-
-        return printer(kernel)
-
-    def _add_field_param(self, ptr_param: Parameter):
-        field: Field = ptr_param.fields[0]
-
-        ptr_type = ptr_param.dtype
-        assert isinstance(ptr_type, PsPointerType)
-
-        if isinstance(field.dtype, DynamicType):
-            elem_type = ptr_type.base_type
-        else:
-            elem_type = field.dtype
-
-        self._actual_field_types[field] = elem_type
-
-        param_bind = f'py::arg("{field.name}").noconvert()'
-        self._param_binds.append(param_bind)
-
-        kernel_param = f"py::array_t< {elem_type.c_string()} > & {field.name}"
-        self._public_params.append(kernel_param)
-
-    def _add_scalar_param(self, sc_param: Parameter):
-        param_bind = f'py::arg("{sc_param.name}")'
-        if self._jit.strict_scalar_types:
-            param_bind += ".noconvert()"
-        self._param_binds.append(param_bind)
-
-        kernel_param = f"{sc_param.dtype.c_string()} {sc_param.name}"
-        self._public_params.append(kernel_param)
-
-    def _extract_base_ptr(self, ptr_param: Parameter, ptr_prop: FieldBasePtr):
-        field_name = ptr_prop.field.name
-        assert isinstance(ptr_param.dtype, PsPointerType)
-        data_method = "data()" if ptr_param.dtype.base_type.const else "mutable_data()"
-        extraction = f"{ptr_param.dtype.c_string()} {ptr_param.name} {{ {field_name}.{data_method} }};"
-        self._extraction_lines.append(extraction)
-
-    def _extract_shape(self, shape_param: Parameter, shape_prop: FieldShape):
-        field_name = shape_prop.field.name
-        coord = shape_prop.coordinate
-        extraction = f"{shape_param.dtype.c_string()} {shape_param.name} {{ {field_name}.shape({coord}) }};"
-        self._extraction_lines.append(extraction)
-
-    def _extract_stride(self, stride_param: Parameter, stride_prop: FieldStride):
-        field = stride_prop.field
-        field_name = field.name
-        coord = stride_prop.coordinate
-        field_type = self._actual_field_types[field]
-        assert field_type.itemsize is not None
-        extraction = (
-            f"{stride_param.dtype.c_string()} {stride_param.name} "
-            f"{{ {field_name}.strides({coord}) / {field_type.itemsize} }};"
-        )
-        self._extraction_lines.append(extraction)
-
-    def _handle_params(self, parameters: Sequence[Parameter]):
-        for param in parameters:
-            if param.get_properties(FieldBasePtr):
-                self._add_field_param(param)
-
-        for param in parameters:
-            if ptr_props := param.get_properties(FieldBasePtr):
-                self._extract_base_ptr(param, cast(FieldBasePtr, ptr_props.pop()))
-            elif shape_props := param.get_properties(FieldShape):
-                self._extract_shape(param, cast(FieldShape, shape_props.pop()))
-            elif stride_props := param.get_properties(FieldStride):
-                self._extract_stride(param, cast(FieldStride, stride_props.pop()))
-            else:
-                self._add_scalar_param(param)
diff --git a/src/pystencils/jit/cpu/cpujit.py b/src/pystencils/jit/cpu/cpujit.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9c39529ae2519f2fbab45fc60d4c522b7b09ec2
--- /dev/null
+++ b/src/pystencils/jit/cpu/cpujit.py
@@ -0,0 +1,198 @@
+from __future__ import annotations
+
+from types import ModuleType
+from pathlib import Path
+import subprocess
+from copy import copy
+from abc import ABC, abstractmethod
+
+from ...codegen.config import _AUTO_TYPE, AUTO
+
+from ..jit import JitError, JitBase, KernelWrapper
+from ...codegen import Kernel
+from .compiler_info import CompilerInfo, GccInfo
+
+
+class CpuJit(JitBase):
+    """Just-in-time compiler for CPU kernels.
+
+    :Creation:
+    In most cases, objects of this class should be instantiated using the `create` factory method.
+    
+    :Implementation Details:
+
+    The `CpuJit` combines two separate components:
+    - The *extension module builder* produces the code of the dynamically built extension module containing
+      the kernel
+    - The *compiler info* describes the system compiler used to compile and link that extension module.
+
+    Both can be dynamically exchanged.
+    """
+
+    @staticmethod
+    def create(
+        compiler_info: CompilerInfo | None = None,
+        objcache: str | Path | _AUTO_TYPE | None = AUTO,
+    ):
+        if objcache is AUTO:
+            from appdirs import AppDirs
+
+            dirs = AppDirs(appname="pystencils")
+            objcache = Path(dirs.user_cache_dir) / "cpujit"
+        elif objcache is not None:
+            assert not isinstance(objcache, _AUTO_TYPE)
+            objcache = Path(objcache)
+
+        if compiler_info is None:
+            compiler_info = GccInfo()
+
+        from .cpujit_pybind11 import Pybind11KernelModuleBuilder
+
+        modbuilder = Pybind11KernelModuleBuilder(compiler_info)
+
+        return CpuJit(compiler_info, modbuilder, objcache)
+
+    def __init__(
+        self,
+        compiler_info: CompilerInfo,
+        ext_module_builder: ExtensionModuleBuilderBase,
+        objcache: Path | None,
+    ):
+        self._compiler_info = copy(compiler_info)
+        self._ext_module_builder = ext_module_builder
+
+        self._objcache = objcache
+
+        #   Include Directories
+
+        import sysconfig
+
+        python_include = sysconfig.get_path("include")
+
+        from ...include import get_pystencils_include_path
+
+        pystencils_include = get_pystencils_include_path()
+
+        self._cxx_fixed_flags = [
+            "-shared",
+            "-fPIC",
+            f"-I{python_include}",
+            f"-I{pystencils_include}",
+        ] + self._ext_module_builder.include_flags()
+
+    @property
+    def objcache(self) -> Path | None:
+        return self._objcache
+
+    @property
+    def compiler_info(self) -> CompilerInfo:
+        return self._compiler_info
+
+    @compiler_info.setter
+    def compiler_info(self, info: CompilerInfo):
+        self._compiler_info = info
+
+    @property
+    def strict_scalar_types(self) -> bool:
+        """Enable or disable implicit type casts for scalar parameters.
+
+        If `True`, values for scalar kernel parameters must always be provided with the correct NumPy type.
+        """
+        return self._strict_scalar_types
+
+    @strict_scalar_types.setter
+    def strict_scalar_types(self, v: bool):
+        self._strict_scalar_types = v
+
+    def compile(self, kernel: Kernel) -> CpuJitKernelWrapper:
+        #   Get the Code
+        module_name = f"{kernel.function_name}_jit"
+        cpp_code = self._ext_module_builder(kernel, module_name)
+
+        #   Get compiler information
+        import sysconfig
+
+        so_abi = sysconfig.get_config_var("SOABI")
+        lib_suffix = f"{so_abi}.so"
+
+        #   Compute Code Hash
+        code_utf8 = cpp_code.encode("utf-8")
+        import hashlib
+
+        code_hash = hashlib.sha256(code_utf8)
+        module_stem = f"module_{code_hash.hexdigest()}"
+
+        def compile_and_load(module_dir: Path):
+            cpp_file = module_dir / f"{module_stem}.cpp"
+            if not cpp_file.exists():
+                cpp_file.write_bytes(code_utf8)
+
+            lib_file = module_dir / f"{module_stem}.{lib_suffix}"
+            if not lib_file.exists():
+                self._compile_extension_module(cpp_file, lib_file)
+
+            module = self._load_extension_module(module_name, lib_file)
+            return module
+
+        if self._objcache is not None:
+            module_dir = self._objcache
+            #   Lock module
+            import fasteners
+
+            lockfile = module_dir / f"{module_stem}.lock"
+            with fasteners.InterProcessLock(lockfile):
+                module = compile_and_load(module_dir)
+        else:
+            from tempfile import TemporaryDirectory
+
+            with TemporaryDirectory() as tmpdir:
+                module_dir = Path(tmpdir)
+                module = compile_and_load(module_dir)
+
+        return CpuJitKernelWrapper(kernel, module)
+
+    def _compile_extension_module(self, src_file: Path, libfile: Path):
+        args = (
+            [self._compiler_info.cxx()]
+            + self._cxx_fixed_flags
+            + self._compiler_info.cxxflags()
+            + ["-o", str(libfile), str(src_file)]
+        )
+
+        result = subprocess.run(args, capture_output=True)
+        if result.returncode != 0:
+            raise JitError(
+                "Compilation failed: C++ compiler terminated with an error.\n"
+                + result.stderr.decode()
+            )
+
+    def _load_extension_module(self, module_name: str, module_loc: Path) -> ModuleType:
+        from importlib import util as iutil
+
+        spec = iutil.spec_from_file_location(name=module_name, location=module_loc)
+        if spec is None:
+            raise JitError(
+                "Unable to load kernel extension module -- this is probably a bug."
+            )
+        mod = iutil.module_from_spec(spec)
+        spec.loader.exec_module(mod)  # type: ignore
+        return mod
+
+
+class CpuJitKernelWrapper(KernelWrapper):
+    def __init__(self, kernel: Kernel, jit_module: ModuleType):
+        super().__init__(kernel)
+        self._module = jit_module
+        self._wrapper_func = getattr(jit_module, kernel.function_name)
+
+    def __call__(self, **kwargs) -> None:
+        return self._wrapper_func(**kwargs)
+
+
+class ExtensionModuleBuilderBase(ABC):
+    @staticmethod
+    @abstractmethod
+    def include_flags() -> list[str]: ...
+
+    @abstractmethod
+    def __call__(self, kernel: Kernel, module_name: str) -> str: ...
diff --git a/src/pystencils/jit/cpu/cpujit_pybind11.py b/src/pystencils/jit/cpu/cpujit_pybind11.py
new file mode 100644
index 0000000000000000000000000000000000000000..aee2e9f996d3147fd06435aaac60bbcb728083a2
--- /dev/null
+++ b/src/pystencils/jit/cpu/cpujit_pybind11.py
@@ -0,0 +1,143 @@
+from __future__ import annotations
+
+from typing import Sequence, cast
+from pathlib import Path
+from textwrap import indent
+
+from ...types import PsPointerType, PsType
+from ...field import Field
+from ...sympyextensions import DynamicType
+from ...codegen import Kernel, Parameter
+from ...codegen.properties import FieldBasePtr, FieldShape, FieldStride
+
+from .compiler_info import CompilerInfo
+from .cpujit import ExtensionModuleBuilderBase
+
+
+_module_template = Path(__file__).parent / "pybind11_kernel_module.tmpl.cpp"
+
+
+class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
+    @staticmethod
+    def include_flags() -> list[str]:
+        import pybind11 as pb11
+
+        pybind11_include = pb11.get_include()
+        return [f"-I{pybind11_include}"]
+
+    def __init__(
+        self,
+        compiler_info: CompilerInfo,
+        strict_scalar_types: bool = False,
+    ):
+        self._compiler_info = compiler_info
+
+        self._strict_scalar_types = strict_scalar_types
+
+        self._actual_field_types: dict[Field, PsType]
+        self._param_binds: list[str]
+        self._public_params: list[str]
+        self._extraction_lines: list[str]
+
+    def __call__(self, kernel: Kernel, module_name: str) -> str:
+        self._actual_field_types = dict()
+        self._param_binds = []
+        self._public_params = []
+        self._extraction_lines = []
+
+        self._handle_params(kernel.parameters)
+
+        kernel_def = self._get_kernel_definition(kernel)
+        kernel_args = [param.name for param in kernel.parameters]
+        includes = [f"#include {h}" for h in kernel.required_headers]
+
+        from string import Template
+
+        templ = Template(_module_template.read_text())
+        code_str = templ.substitute(
+            includes="\n".join(includes),
+            restrict_qualifier=self._compiler_info.restrict_qualifier(),
+            module_name=module_name,
+            kernel_name=kernel.function_name,
+            param_binds=", ".join(self._param_binds),
+            public_params=", ".join(self._public_params),
+            extraction_lines=indent("\n".join(self._extraction_lines), prefix="    "),
+            kernel_args=", ".join(kernel_args),
+            kernel_definition=kernel_def,
+        )
+        return code_str
+
+    def _get_kernel_definition(self, kernel: Kernel) -> str:
+        from ...backend.emission import CAstPrinter
+
+        printer = CAstPrinter(func_prefix="inline")
+
+        return printer(kernel)
+
+    def _add_field_param(self, ptr_param: Parameter):
+        field: Field = ptr_param.fields[0]
+
+        ptr_type = ptr_param.dtype
+        assert isinstance(ptr_type, PsPointerType)
+
+        if isinstance(field.dtype, DynamicType):
+            elem_type = ptr_type.base_type
+        else:
+            elem_type = field.dtype
+
+        self._actual_field_types[field] = elem_type
+
+        param_bind = f'py::arg("{field.name}").noconvert()'
+        self._param_binds.append(param_bind)
+
+        kernel_param = f"py::array_t< {elem_type.c_string()} > & {field.name}"
+        self._public_params.append(kernel_param)
+
+    def _add_scalar_param(self, sc_param: Parameter):
+        param_bind = f'py::arg("{sc_param.name}")'
+        if self._strict_scalar_types:
+            param_bind += ".noconvert()"
+        self._param_binds.append(param_bind)
+
+        kernel_param = f"{sc_param.dtype.c_string()} {sc_param.name}"
+        self._public_params.append(kernel_param)
+
+    def _extract_base_ptr(self, ptr_param: Parameter, ptr_prop: FieldBasePtr):
+        field_name = ptr_prop.field.name
+        assert isinstance(ptr_param.dtype, PsPointerType)
+        data_method = "data()" if ptr_param.dtype.base_type.const else "mutable_data()"
+        extraction = f"{ptr_param.dtype.c_string()} {ptr_param.name} {{ {field_name}.{data_method} }};"
+        self._extraction_lines.append(extraction)
+
+    def _extract_shape(self, shape_param: Parameter, shape_prop: FieldShape):
+        field_name = shape_prop.field.name
+        coord = shape_prop.coordinate
+        extraction = f"{shape_param.dtype.c_string()} {shape_param.name} {{ {field_name}.shape({coord}) }};"
+        self._extraction_lines.append(extraction)
+
+    def _extract_stride(self, stride_param: Parameter, stride_prop: FieldStride):
+        field = stride_prop.field
+        field_name = field.name
+        coord = stride_prop.coordinate
+        field_type = self._actual_field_types[field]
+        assert field_type.itemsize is not None
+        extraction = (
+            f"{stride_param.dtype.c_string()} {stride_param.name} "
+            f"{{ {field_name}.strides({coord}) / {field_type.itemsize} }};"
+        )
+        self._extraction_lines.append(extraction)
+
+    def _handle_params(self, parameters: Sequence[Parameter]):
+        for param in parameters:
+            if param.get_properties(FieldBasePtr):
+                self._add_field_param(param)
+
+        for param in parameters:
+            if ptr_props := param.get_properties(FieldBasePtr):
+                self._extract_base_ptr(param, cast(FieldBasePtr, ptr_props.pop()))
+            elif shape_props := param.get_properties(FieldShape):
+                self._extract_shape(param, cast(FieldShape, shape_props.pop()))
+            elif stride_props := param.get_properties(FieldStride):
+                self._extract_stride(param, cast(FieldStride, stride_props.pop()))
+            else:
+                self._add_scalar_param(param)
diff --git a/src/pystencils/jit/cpu/kernel_module.tmpl.cpp b/src/pystencils/jit/cpu/pybind11_kernel_module.tmpl.cpp
similarity index 100%
rename from src/pystencils/jit/cpu/kernel_module.tmpl.cpp
rename to src/pystencils/jit/cpu/pybind11_kernel_module.tmpl.cpp
diff --git a/tests/jit/test_cpujit.py b/tests/jit/test_cpujit.py
index aa4f50f3fca04e0d7177f8d623a44f1eb7f50329..0e17fa6a14c66d104eea1e3a40366e8664040f59 100644
--- a/tests/jit/test_cpujit.py
+++ b/tests/jit/test_cpujit.py
@@ -5,7 +5,7 @@ from pystencils.jit import CpuJit
 
 
 def test_basic_cpu_kernel(tmp_path):
-    jit = CpuJit(objcache=tmp_path)
+    jit = CpuJit.create(objcache=tmp_path)
 
     f, g = fields("f, g: [2D]")
     asm = Assignment(f.center(), 2.0 * g.center())