From 6efa0439757e01a87f402e956353036b6d5d47b0 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Mon, 27 Jan 2025 10:48:37 +0100
Subject: [PATCH 01/17] Implement pybind11-based CPU JIT compiler + basic
 testing.

---
 mypy.ini                                      |   3 +
 pyproject.toml                                |   3 +-
 .../backend/emission/base_printer.py          |  18 +-
 src/pystencils/backend/emission/c_printer.py  |   4 -
 src/pystencils/jit/__init__.py                |   2 +
 src/pystencils/jit/cpu/__init__.py            |   5 +
 src/pystencils/jit/cpu/cpu_pybind11.py        | 281 ++++++++++++++++++
 src/pystencils/jit/cpu/kernel_module.tmpl.cpp |  23 ++
 src/pystencils/sympyextensions/__init__.py    |   5 +-
 tests/jit/test_cpujit.py                      |  21 ++
 10 files changed, 345 insertions(+), 20 deletions(-)
 create mode 100644 src/pystencils/jit/cpu/__init__.py
 create mode 100644 src/pystencils/jit/cpu/cpu_pybind11.py
 create mode 100644 src/pystencils/jit/cpu/kernel_module.tmpl.cpp
 create mode 100644 tests/jit/test_cpujit.py

diff --git a/mypy.ini b/mypy.ini
index cc23a503a..a0533a60a 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -31,3 +31,6 @@ ignore_missing_imports=true
 
 [mypy-cpuinfo.*]
 ignore_missing_imports=true
+
+[mypy-fasteners.*]
+ignore_missing_imports=true
diff --git a/pyproject.toml b/pyproject.toml
index 59e71b8db..b3c6b1c02 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,7 +12,7 @@ authors = [
 ]
 license = { file = "COPYING.txt" }
 requires-python = ">=3.10"
-dependencies = ["sympy>=1.9,<=1.12.1", "numpy>=1.8.0", "appdirs", "joblib", "pyyaml"]
+dependencies = ["sympy>=1.9,<=1.12.1", "numpy>=1.8.0", "appdirs", "joblib", "pyyaml", "pybind11", "fasteners"]
 classifiers = [
     "Development Status :: 4 - Beta",
     "Framework :: Jupyter",
@@ -90,6 +90,7 @@ build-backend = "setuptools.build_meta"
 [tool.setuptools.package-data]
 pystencils = [
     "include/*.h",
+    "jit/cpu/*.tmpl.cpp",
     "boundaries/createindexlistcython.pyx"
 ]
 
diff --git a/src/pystencils/backend/emission/base_printer.py b/src/pystencils/backend/emission/base_printer.py
index a4358bbf3..adb9c232b 100644
--- a/src/pystencils/backend/emission/base_printer.py
+++ b/src/pystencils/backend/emission/base_printer.py
@@ -3,8 +3,6 @@ from enum import Enum
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING
 
-from ...codegen import Target
-
 from ..ast.structural import (
     PsAstNode,
     PsBlock,
@@ -171,8 +169,9 @@ class BasePrinter(ABC):
     and in `IRAstPrinter` for debug-printing the entire IR.
     """
 
-    def __init__(self, indent_width=3):
+    def __init__(self, indent_width=3, func_prefix: str | None = None):
         self._indent_width = indent_width
+        self._func_prefix = func_prefix
 
     def __call__(self, obj: PsAstNode | Kernel) -> str:
         from ...codegen import Kernel
@@ -376,21 +375,14 @@ class BasePrinter(ABC):
                 )
 
     def print_signature(self, func: Kernel) -> str:
-        prefix = self._func_prefix(func)
+        prefix = self._func_prefix
         params_str = ", ".join(
             f"{self._type_str(p.dtype)} {p.name}" for p in func.parameters
         )
-        signature = " ".join([prefix, "void", func.name, f"({params_str})"])
+        sig_parts = ([prefix] if prefix is not None else []) + ["void", func.name, f"({params_str})"]
+        signature = " ".join(sig_parts)
         return signature
 
-    def _func_prefix(self, func: Kernel):
-        from ...codegen import GpuKernel
-
-        if isinstance(func, GpuKernel) and func.target == Target.CUDA:
-            return "__global__"
-        else:
-            return "FUNC_PREFIX"
-
     @abstractmethod
     def _symbol_decl(self, symb: PsSymbol) -> str:
         pass
diff --git a/src/pystencils/backend/emission/c_printer.py b/src/pystencils/backend/emission/c_printer.py
index 90a7e54e2..40cd69283 100644
--- a/src/pystencils/backend/emission/c_printer.py
+++ b/src/pystencils/backend/emission/c_printer.py
@@ -21,10 +21,6 @@ def emit_code(ast: PsAstNode | Kernel):
 
 
 class CAstPrinter(BasePrinter):
-
-    def __init__(self, indent_width=3):
-        super().__init__(indent_width)
-
     def visit(self, node: PsAstNode, pc: PrinterCtx) -> str:
         match node:
             case PsVecMemAcc():
diff --git a/src/pystencils/jit/__init__.py b/src/pystencils/jit/__init__.py
index 1ef8378d3..3ae63fa72 100644
--- a/src/pystencils/jit/__init__.py
+++ b/src/pystencils/jit/__init__.py
@@ -24,6 +24,7 @@ It is due to be replaced in the near future.
 
 from .jit import JitBase, NoJit, KernelWrapper
 from .legacy_cpu import LegacyCpuJit
+from .cpu import CpuJit
 from .gpu_cupy import CupyJit, CupyKernelWrapper, LaunchGrid
 
 no_jit = NoJit()
@@ -33,6 +34,7 @@ __all__ = [
     "JitBase",
     "KernelWrapper",
     "LegacyCpuJit",
+    "CpuJit",
     "NoJit",
     "no_jit",
     "CupyJit",
diff --git a/src/pystencils/jit/cpu/__init__.py b/src/pystencils/jit/cpu/__init__.py
new file mode 100644
index 000000000..69985c914
--- /dev/null
+++ b/src/pystencils/jit/cpu/__init__.py
@@ -0,0 +1,5 @@
+from .cpu_pybind11 import CpuJit
+
+__all__ = [
+    "CpuJit"
+]
diff --git a/src/pystencils/jit/cpu/cpu_pybind11.py b/src/pystencils/jit/cpu/cpu_pybind11.py
new file mode 100644
index 000000000..128313a74
--- /dev/null
+++ b/src/pystencils/jit/cpu/cpu_pybind11.py
@@ -0,0 +1,281 @@
+from __future__ import annotations
+
+from typing import Sequence, cast
+from types import ModuleType
+from pathlib import Path
+from textwrap import indent
+import subprocess
+
+from ...types import PsPointerType, PsType
+from ...field import Field
+from ...sympyextensions import DynamicType
+from ..jit import KernelWrapper
+from ...codegen import Kernel, Parameter
+from ...codegen.properties import FieldBasePtr, FieldShape, FieldStride
+
+from ..jit import JitError, JitBase
+
+
+_module_template = Path(__file__).parent / "kernel_module.tmpl.cpp"
+
+
+class CpuJit(JitBase):
+    """Just-in-time compiler for CPU kernels."""
+
+    def __init__(
+        self,
+        cxx: str = "g++",
+        cxxflags: Sequence[str] = ("-O3", "-fopenmp"),
+        objcache: str | Path | None = None,
+        strict_scalar_types: bool = False,
+    ):
+        self._cxx = cxx
+        self._cxxflags: list[str] = list(cxxflags)
+
+        self._strict_scalar_types = strict_scalar_types
+        self._restrict_qualifier = "__restrict__"
+
+        if objcache is None:
+            from appdirs import AppDirs
+
+            dirs = AppDirs(appname="pystencils")
+            self._objcache = Path(dirs.user_cache_dir) / "cpujit"
+        else:
+            self._objcache = Path(objcache)
+
+        #   Include Directories
+        import pybind11 as pb11
+
+        self._pybind11_include = pb11.get_include()
+
+        import sysconfig
+
+        self._py_include = sysconfig.get_path("include")
+
+        self._cxx_fixed_flags = [
+            "-shared",
+            "-fPIC",
+            f"-I{self._py_include}",
+            f"-I{self._pybind11_include}",
+        ]
+
+    @property
+    def objcache(self) -> Path:
+        return self._objcache
+
+    @property
+    def cxx(self) -> str:
+        return self._cxx
+
+    @cxx.setter
+    def cxx(self, path: str):
+        self._cxx = path
+
+    @property
+    def cxxflags(self) -> list[str]:
+        return self._cxxflags
+
+    @cxxflags.setter
+    def cxxflags(self, flags: Sequence[str]):
+        self._cxxflags = list(flags)
+
+    @property
+    def strict_scalar_types(self) -> bool:
+        """Enable or disable implicit type casts for scalar parameters.
+
+        If `True`, values for scalar kernel parameters must always be provided with the correct NumPy type.
+        """
+        return self._strict_scalar_types
+
+    @strict_scalar_types.setter
+    def strict_scalar_types(self, v: bool):
+        self._strict_scalar_types = v
+
+    @property
+    def restrict_qualifier(self) -> str:
+        return self._restrict_qualifier
+    
+    @restrict_qualifier.setter
+    def restrict_qualifier(self, qual: str):
+        self._restrict_qualifier = qual
+
+    def compile(self, kernel: Kernel) -> CpuJitKernelWrapper:
+        #   Get the Code
+        module_name = f"{kernel.function_name}_jit"
+        modbuilder = KernelModuleBuilder(self, module_name)
+        cpp_code = modbuilder(kernel)
+
+        #   Get compiler information
+        import sysconfig
+
+        so_abi = sysconfig.get_config_var("SOABI")
+        lib_suffix = f"{so_abi}.so"
+
+        #   Compute Code Hash
+        code_utf8 = cpp_code.encode("utf-8")
+        import hashlib
+
+        code_hash = hashlib.sha256(code_utf8)
+        module_stem = f"module_{code_hash.hexdigest()}"
+
+        module_dir = self._objcache
+
+        #   Lock module
+        import fasteners
+
+        lockfile = module_dir / f"{module_stem}.lock"
+        with fasteners.InterProcessLock(lockfile):
+            cpp_file = module_dir / f"{module_stem}.cpp"
+            if not cpp_file.exists():
+                cpp_file.write_bytes(code_utf8)
+
+            lib_file = module_dir / f"{module_stem}.{lib_suffix}"
+            if not lib_file.exists():
+                self._compile_extension_module(cpp_file, lib_file)
+
+            module = self._load_extension_module(module_name, lib_file)
+
+        return CpuJitKernelWrapper(kernel, module)
+
+    def _compile_extension_module(self, src_file: Path, libfile: Path):
+        args = (
+            [self._cxx]
+            + self._cxx_fixed_flags
+            + self._cxxflags
+            + ["-o", str(libfile), str(src_file)]
+        )
+
+        result = subprocess.run(args, capture_output=True)
+        if result.returncode != 0:
+            raise JitError(
+                "Compilation failed: C++ compiler terminated with an error.\n"
+                + result.stderr.decode()
+            )
+
+    def _load_extension_module(self, module_name: str, module_loc: Path) -> ModuleType:
+        from importlib import util as iutil
+
+        spec = iutil.spec_from_file_location(name=module_name, location=module_loc)
+        if spec is None:
+            raise JitError("Unable to load kernel extension module -- this is probably a bug.")
+        mod = iutil.module_from_spec(spec)
+        spec.loader.exec_module(mod)  # type: ignore
+        return mod
+
+
+class CpuJitKernelWrapper(KernelWrapper):
+    def __init__(self, kernel: Kernel, jit_module: ModuleType):
+        super().__init__(kernel)
+        self._module = jit_module
+        self._wrapper_func = getattr(jit_module, kernel.function_name)
+
+    def __call__(self, **kwargs) -> None:
+        return self._wrapper_func(**kwargs)
+
+
+class KernelModuleBuilder:
+    def __init__(self, jit: CpuJit, module_name: str):
+        self._jit = jit
+        self._module_name = module_name
+
+        self._actual_field_types: dict[Field, PsType] = dict()
+        self._param_binds: list[str] = []
+        self._public_params: list[str] = []
+        self._extraction_lines: list[str] = []
+
+    def __call__(self, kernel: Kernel) -> str:
+        self._handle_params(kernel.parameters)
+        
+        kernel_def = self._get_kernel_definition(kernel)
+        kernel_args = [param.name for param in kernel.parameters]
+        includes = [f"#include {h}" for h in kernel.required_headers]
+
+        from string import Template
+
+        templ = Template(_module_template.read_text())
+        code_str = templ.substitute(
+            includes="\n".join(includes),
+            restrict_qualifier=self._jit.restrict_qualifier,
+            module_name=self._module_name,
+            kernel_name=kernel.function_name,
+            param_binds=", ".join(self._param_binds),
+            public_params=", ".join(self._public_params),
+            extraction_lines=indent("\n".join(self._extraction_lines), prefix="    "),
+            kernel_args=", ".join(kernel_args),
+            kernel_definition=kernel_def,
+        )
+        return code_str
+
+    def _get_kernel_definition(self, kernel: Kernel) -> str:
+        from ...backend.emission import CAstPrinter
+        printer = CAstPrinter(func_prefix="inline")
+
+        return printer(kernel)
+
+    def _add_field_param(self, ptr_param: Parameter):
+        field: Field = ptr_param.fields[0]
+
+        ptr_type = ptr_param.dtype
+        assert isinstance(ptr_type, PsPointerType)
+
+        if isinstance(field.dtype, DynamicType):
+            elem_type = ptr_type.base_type
+        else:
+            elem_type = field.dtype
+
+        self._actual_field_types[field] = elem_type
+
+        param_bind = f'py::arg("{field.name}").noconvert()'
+        self._param_binds.append(param_bind)
+
+        kernel_param = f"py::array_t< {elem_type.c_string()} > & {field.name}"
+        self._public_params.append(kernel_param)
+
+    def _add_scalar_param(self, sc_param: Parameter):
+        param_bind = f'py::arg("{sc_param.name}")'
+        if self._jit.strict_scalar_types:
+            param_bind += ".noconvert()"
+        self._param_binds.append(param_bind)
+
+        kernel_param = f"{sc_param.dtype.c_string()} {sc_param.name}"
+        self._public_params.append(kernel_param)
+
+    def _extract_base_ptr(self, ptr_param: Parameter, ptr_prop: FieldBasePtr):
+        field_name = ptr_prop.field.name
+        assert isinstance(ptr_param.dtype, PsPointerType)
+        data_method = "data()" if ptr_param.dtype.base_type.const else "mutable_data()"
+        extraction = f"{ptr_param.dtype.c_string()} {ptr_param.name} {{ {field_name}.{data_method} }};"
+        self._extraction_lines.append(extraction)
+
+    def _extract_shape(self, shape_param: Parameter, shape_prop: FieldShape):
+        field_name = shape_prop.field.name
+        coord = shape_prop.coordinate
+        extraction = f"{shape_param.dtype.c_string()} {shape_param.name} {{ {field_name}.shape({coord}) }};"
+        self._extraction_lines.append(extraction)
+
+    def _extract_stride(self, stride_param: Parameter, stride_prop: FieldStride):
+        field = stride_prop.field
+        field_name = field.name
+        coord = stride_prop.coordinate
+        field_type = self._actual_field_types[field]
+        assert field_type.itemsize is not None
+        extraction = (
+            f"{stride_param.dtype.c_string()} {stride_param.name} "
+            f"{{ {field_name}.strides({coord}) / {field_type.itemsize} }};"
+        )
+        self._extraction_lines.append(extraction)
+
+    def _handle_params(self, parameters: Sequence[Parameter]):
+        for param in parameters:
+            if param.get_properties(FieldBasePtr):
+                self._add_field_param(param)
+
+        for param in parameters:
+            if ptr_props := param.get_properties(FieldBasePtr):
+                self._extract_base_ptr(param, cast(FieldBasePtr, ptr_props.pop()))
+            elif shape_props := param.get_properties(FieldShape):
+                self._extract_shape(param, cast(FieldShape, shape_props.pop()))
+            elif stride_props := param.get_properties(FieldStride):
+                self._extract_stride(param, cast(FieldStride, stride_props.pop()))
+            else:
+                self._add_scalar_param(param)
diff --git a/src/pystencils/jit/cpu/kernel_module.tmpl.cpp b/src/pystencils/jit/cpu/kernel_module.tmpl.cpp
new file mode 100644
index 000000000..3ee5c6973
--- /dev/null
+++ b/src/pystencils/jit/cpu/kernel_module.tmpl.cpp
@@ -0,0 +1,23 @@
+#include "pybind11/pybind11.h"
+#include "pybind11/numpy.h"
+
+${includes}
+
+namespace py = pybind11;
+
+#define RESTRICT ${restrict_qualifier}
+
+namespace internal {
+
+${kernel_definition}
+
+}
+
+void callwrapper_${kernel_name} (${public_params}) {
+${extraction_lines}
+    internal::${kernel_name}(${kernel_args});
+}
+
+PYBIND11_MODULE(${module_name}, m) {
+    m.def("${kernel_name}", &callwrapper_${kernel_name}, py::kw_only(), ${param_binds});
+}
diff --git a/src/pystencils/sympyextensions/__init__.py b/src/pystencils/sympyextensions/__init__.py
index 7431416c9..d01dbe57e 100644
--- a/src/pystencils/sympyextensions/__init__.py
+++ b/src/pystencils/sympyextensions/__init__.py
@@ -1,5 +1,5 @@
 from .astnodes import ConditionalFieldAccess
-from .typed_sympy import TypedSymbol, CastFunc
+from .typed_sympy import TypedSymbol, CastFunc, DynamicType
 from .pointers import mem_acc
 
 from .math import (
@@ -61,5 +61,6 @@ __all__ = [
     "count_operations_in_ast",
     "common_denominator",
     "get_symmetric_part",
-    "SymbolCreator"
+    "SymbolCreator",
+    "DynamicType"
 ]
diff --git a/tests/jit/test_cpujit.py b/tests/jit/test_cpujit.py
new file mode 100644
index 000000000..0b09d66b0
--- /dev/null
+++ b/tests/jit/test_cpujit.py
@@ -0,0 +1,21 @@
+import sympy as sp
+import numpy as np
+from pystencils import create_kernel, Assignment, fields
+from pystencils.jit import CpuJit
+
+
+def test_basic_cpu_kernel(tmp_path):
+    jit = CpuJit("g++", ["-O3"], objcache=tmp_path)
+
+    f, g = fields("f, g: [2D]")
+    asm = Assignment(f.center(), 2.0 * g.center())
+    ker = create_kernel(asm)
+    kfunc = jit.compile(ker)
+
+    rng = np.random.default_rng()
+    f_arr = rng.random(size=(34, 26), dtype="float64")
+    g_arr = np.zeros_like(f_arr)
+
+    kfunc(f=f_arr, g=g_arr)
+
+    np.testing.assert_almost_equal(g_arr, 2.0 * f_arr)
-- 
GitLab


From 4370b2bcb24eb24a9b352fdecb8d9b415c76a200 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Mon, 27 Jan 2025 16:42:00 +0100
Subject: [PATCH 02/17] add compiler info classes for gcc and clang

---
 src/pystencils/jit/cpu/__init__.py      |  3 ++
 src/pystencils/jit/cpu/compiler_info.py | 72 +++++++++++++++++++++++++
 src/pystencils/jit/cpu/cpu_pybind11.py  | 36 +++++--------
 tests/jit/test_cpujit.py                |  2 +-
 4 files changed, 89 insertions(+), 24 deletions(-)
 create mode 100644 src/pystencils/jit/cpu/compiler_info.py

diff --git a/src/pystencils/jit/cpu/__init__.py b/src/pystencils/jit/cpu/__init__.py
index 69985c914..335ec52d4 100644
--- a/src/pystencils/jit/cpu/__init__.py
+++ b/src/pystencils/jit/cpu/__init__.py
@@ -1,5 +1,8 @@
+from .compiler_info import GccInfo, ClangInfo
 from .cpu_pybind11 import CpuJit
 
 __all__ = [
+    "GccInfo",
+    "ClangInfo",
     "CpuJit"
 ]
diff --git a/src/pystencils/jit/cpu/compiler_info.py b/src/pystencils/jit/cpu/compiler_info.py
new file mode 100644
index 000000000..07a5eebbd
--- /dev/null
+++ b/src/pystencils/jit/cpu/compiler_info.py
@@ -0,0 +1,72 @@
+from __future__ import annotations
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+
+from ...codegen.target import Target
+
+
+@dataclass
+class CompilerInfo(ABC):
+    openmp: bool = True
+
+    optlevel: str | None = "fast"
+
+    cxx_standard: str = "c++11"
+
+    target: Target = Target.CurrentCPU
+
+    @abstractmethod
+    def cxx(self) -> str: ...
+
+    @abstractmethod
+    def cxxflags(self) -> list[str]: ...
+
+    @abstractmethod
+    def restrict_qualifier(self) -> str: ...
+
+
+class _GnuLikeCliCompiler(CompilerInfo):
+    def cxxflags(self) -> list[str]:
+        flags = ["-DNDEBUG", f"-std={self.cxx_standard}"]
+
+        if self.optlevel is not None:
+            flags.append(f"-O{self.optlevel}")
+
+        if self.openmp:
+            flags.append("-fopenmp")
+
+        match self.target:
+            case Target.CurrentCPU:
+                flags.append("-march=native")
+            case Target.X86_SSE:
+                flags += ["-march=x86-64-v2"]
+            case Target.X86_AVX:
+                flags += ["-march=x86-64-v3"]
+            case Target.X86_AVX512:
+                flags += ["-march=x86-64-v4"]
+            case Target.X86_AVX512_FP16:
+                flags += ["-march=x86-64-v4", "-mavx512fp16"]
+
+        return flags
+
+    def restrict_qualifier(self) -> str:
+        return "__restrict__"
+
+
+class GccInfo(_GnuLikeCliCompiler):
+    def cxx(self) -> str:
+        return "g++"
+
+
+@dataclass
+class ClangInfo(_GnuLikeCliCompiler):
+    llvm_version: int | None = None
+
+    def cxx(self) -> str:
+        if self.llvm_version is None:
+            return "clang"
+        else:
+            return f"clang-{self.llvm_version}"
+        
+    def cxxflags(self) -> list[str]:
+        return super().cxxflags() + ["-lstdc++"]
diff --git a/src/pystencils/jit/cpu/cpu_pybind11.py b/src/pystencils/jit/cpu/cpu_pybind11.py
index 128313a74..d0368a390 100644
--- a/src/pystencils/jit/cpu/cpu_pybind11.py
+++ b/src/pystencils/jit/cpu/cpu_pybind11.py
@@ -5,6 +5,7 @@ from types import ModuleType
 from pathlib import Path
 from textwrap import indent
 import subprocess
+from copy import copy
 
 from ...types import PsPointerType, PsType
 from ...field import Field
@@ -14,6 +15,7 @@ from ...codegen import Kernel, Parameter
 from ...codegen.properties import FieldBasePtr, FieldShape, FieldStride
 
 from ..jit import JitError, JitBase
+from .compiler_info import CompilerInfo, GccInfo
 
 
 _module_template = Path(__file__).parent / "kernel_module.tmpl.cpp"
@@ -24,16 +26,12 @@ class CpuJit(JitBase):
 
     def __init__(
         self,
-        cxx: str = "g++",
-        cxxflags: Sequence[str] = ("-O3", "-fopenmp"),
+        compiler_info: CompilerInfo | None = None,
         objcache: str | Path | None = None,
         strict_scalar_types: bool = False,
     ):
-        self._cxx = cxx
-        self._cxxflags: list[str] = list(cxxflags)
-
+        self._compiler_info = copy(compiler_info) if compiler_info is not None else GccInfo()
         self._strict_scalar_types = strict_scalar_types
-        self._restrict_qualifier = "__restrict__"
 
         if objcache is None:
             from appdirs import AppDirs
@@ -64,20 +62,12 @@ class CpuJit(JitBase):
         return self._objcache
 
     @property
-    def cxx(self) -> str:
-        return self._cxx
-
-    @cxx.setter
-    def cxx(self, path: str):
-        self._cxx = path
-
-    @property
-    def cxxflags(self) -> list[str]:
-        return self._cxxflags
-
-    @cxxflags.setter
-    def cxxflags(self, flags: Sequence[str]):
-        self._cxxflags = list(flags)
+    def compiler_info(self) -> CompilerInfo:
+        return self._compiler_info
+    
+    @compiler_info.setter
+    def compiler_info(self, info: CompilerInfo):
+        self._compiler_info = info
 
     @property
     def strict_scalar_types(self) -> bool:
@@ -139,9 +129,9 @@ class CpuJit(JitBase):
 
     def _compile_extension_module(self, src_file: Path, libfile: Path):
         args = (
-            [self._cxx]
+            [self._compiler_info.cxx()]
             + self._cxx_fixed_flags
-            + self._cxxflags
+            + self._compiler_info.cxxflags()
             + ["-o", str(libfile), str(src_file)]
         )
 
@@ -195,7 +185,7 @@ class KernelModuleBuilder:
         templ = Template(_module_template.read_text())
         code_str = templ.substitute(
             includes="\n".join(includes),
-            restrict_qualifier=self._jit.restrict_qualifier,
+            restrict_qualifier=self._jit.compiler_info.restrict_qualifier(),
             module_name=self._module_name,
             kernel_name=kernel.function_name,
             param_binds=", ".join(self._param_binds),
diff --git a/tests/jit/test_cpujit.py b/tests/jit/test_cpujit.py
index 0b09d66b0..aa4f50f3f 100644
--- a/tests/jit/test_cpujit.py
+++ b/tests/jit/test_cpujit.py
@@ -5,7 +5,7 @@ from pystencils.jit import CpuJit
 
 
 def test_basic_cpu_kernel(tmp_path):
-    jit = CpuJit("g++", ["-O3"], objcache=tmp_path)
+    jit = CpuJit(objcache=tmp_path)
 
     f, g = fields("f, g: [2D]")
     asm = Assignment(f.center(), 2.0 * g.center())
-- 
GitLab


From 937ed939ccecc722da7dc85e2d5bd3653e3b3710 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Mon, 27 Jan 2025 16:48:40 +0100
Subject: [PATCH 03/17] fix include paths

---
 src/pystencils/jit/cpu/cpu_pybind11.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/src/pystencils/jit/cpu/cpu_pybind11.py b/src/pystencils/jit/cpu/cpu_pybind11.py
index d0368a390..0229c5c4d 100644
--- a/src/pystencils/jit/cpu/cpu_pybind11.py
+++ b/src/pystencils/jit/cpu/cpu_pybind11.py
@@ -44,17 +44,22 @@ class CpuJit(JitBase):
         #   Include Directories
         import pybind11 as pb11
 
-        self._pybind11_include = pb11.get_include()
+        pybind11_include = pb11.get_include()
 
         import sysconfig
 
-        self._py_include = sysconfig.get_path("include")
+        python_include = sysconfig.get_path("include")
+
+        from ...include import get_pystencils_include_path
+
+        pystencils_include = get_pystencils_include_path()
 
         self._cxx_fixed_flags = [
             "-shared",
             "-fPIC",
-            f"-I{self._py_include}",
-            f"-I{self._pybind11_include}",
+            f"-I{python_include}",
+            f"-I{pybind11_include}",
+            f"-I{pystencils_include}"
         ]
 
     @property
-- 
GitLab


From f76aa82753edee6fe79ee8c167e46703d7b74962 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Mon, 27 Jan 2025 16:59:09 +0100
Subject: [PATCH 04/17] fix __global__ in gpu kernels

---
 src/pystencils/backend/emission/base_printer.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/pystencils/backend/emission/base_printer.py b/src/pystencils/backend/emission/base_printer.py
index adb9c232b..6808d4731 100644
--- a/src/pystencils/backend/emission/base_printer.py
+++ b/src/pystencils/backend/emission/base_printer.py
@@ -375,11 +375,16 @@ class BasePrinter(ABC):
                 )
 
     def print_signature(self, func: Kernel) -> str:
-        prefix = self._func_prefix
         params_str = ", ".join(
             f"{self._type_str(p.dtype)} {p.name}" for p in func.parameters
         )
-        sig_parts = ([prefix] if prefix is not None else []) + ["void", func.name, f"({params_str})"]
+
+        from ...codegen import GpuKernel
+        
+        sig_parts = [self._func_prefix] if self._func_prefix is not None else []
+        if isinstance(func, GpuKernel):
+            sig_parts.append("__global__")
+        sig_parts += ["void", func.name, f"({params_str})"]
         signature = " ".join(sig_parts)
         return signature
 
-- 
GitLab


From e5eafb796f9bd81593bd772dad6b5eb25d492ad2 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Mon, 27 Jan 2025 17:04:49 +0100
Subject: [PATCH 05/17] fix __global__ again

---
 src/pystencils/backend/emission/base_printer.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/pystencils/backend/emission/base_printer.py b/src/pystencils/backend/emission/base_printer.py
index 6808d4731..cc4b50e21 100644
--- a/src/pystencils/backend/emission/base_printer.py
+++ b/src/pystencils/backend/emission/base_printer.py
@@ -57,6 +57,7 @@ from ..extensions.foreign_ast import PsForeignExpression
 from ..memory import PsSymbol
 from ..constants import PsConstant
 from ...types import PsType
+from ...codegen import Target
 
 if TYPE_CHECKING:
     from ...codegen import Kernel
@@ -382,7 +383,7 @@ class BasePrinter(ABC):
         from ...codegen import GpuKernel
         
         sig_parts = [self._func_prefix] if self._func_prefix is not None else []
-        if isinstance(func, GpuKernel):
+        if isinstance(func, GpuKernel) and func.target == Target.CUDA:
             sig_parts.append("__global__")
         sig_parts += ["void", func.name, f"({params_str})"]
         signature = " ".join(sig_parts)
-- 
GitLab


From fb3243dc21277f22dd02e146158ef3d1a0424d3a Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Tue, 28 Jan 2025 10:05:50 +0100
Subject: [PATCH 06/17] modularize CPU jit

---
 src/pystencils/codegen/config.py              |   7 +-
 src/pystencils/jit/cpu/__init__.py            |   2 +-
 src/pystencils/jit/cpu/cpu_pybind11.py        | 276 ------------------
 src/pystencils/jit/cpu/cpujit.py              | 198 +++++++++++++
 src/pystencils/jit/cpu/cpujit_pybind11.py     | 143 +++++++++
 ...pl.cpp => pybind11_kernel_module.tmpl.cpp} |   0
 tests/jit/test_cpujit.py                      |   2 +-
 7 files changed, 348 insertions(+), 280 deletions(-)
 delete mode 100644 src/pystencils/jit/cpu/cpu_pybind11.py
 create mode 100644 src/pystencils/jit/cpu/cpujit.py
 create mode 100644 src/pystencils/jit/cpu/cpujit_pybind11.py
 rename src/pystencils/jit/cpu/{kernel_module.tmpl.cpp => pybind11_kernel_module.tmpl.cpp} (100%)

diff --git a/src/pystencils/codegen/config.py b/src/pystencils/codegen/config.py
index 3a7647907..209df1b0e 100644
--- a/src/pystencils/codegen/config.py
+++ b/src/pystencils/codegen/config.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, TypeGuard
 
 from warnings import warn
 from collections.abc import Collection
@@ -28,7 +28,10 @@ class PsOptionsError(Exception):
     """Indicates an option clash in the `CreateKernelConfig`."""
 
 
-class _AUTO_TYPE: ...  # noqa: E701
+class _AUTO_TYPE:
+    @staticmethod
+    def is_auto(val: Any) -> TypeGuard[_AUTO_TYPE]:
+        return val == AUTO
 
 
 AUTO = _AUTO_TYPE()
diff --git a/src/pystencils/jit/cpu/__init__.py b/src/pystencils/jit/cpu/__init__.py
index 335ec52d4..5e4dfcbf7 100644
--- a/src/pystencils/jit/cpu/__init__.py
+++ b/src/pystencils/jit/cpu/__init__.py
@@ -1,5 +1,5 @@
 from .compiler_info import GccInfo, ClangInfo
-from .cpu_pybind11 import CpuJit
+from .cpujit import CpuJit
 
 __all__ = [
     "GccInfo",
diff --git a/src/pystencils/jit/cpu/cpu_pybind11.py b/src/pystencils/jit/cpu/cpu_pybind11.py
deleted file mode 100644
index 0229c5c4d..000000000
--- a/src/pystencils/jit/cpu/cpu_pybind11.py
+++ /dev/null
@@ -1,276 +0,0 @@
-from __future__ import annotations
-
-from typing import Sequence, cast
-from types import ModuleType
-from pathlib import Path
-from textwrap import indent
-import subprocess
-from copy import copy
-
-from ...types import PsPointerType, PsType
-from ...field import Field
-from ...sympyextensions import DynamicType
-from ..jit import KernelWrapper
-from ...codegen import Kernel, Parameter
-from ...codegen.properties import FieldBasePtr, FieldShape, FieldStride
-
-from ..jit import JitError, JitBase
-from .compiler_info import CompilerInfo, GccInfo
-
-
-_module_template = Path(__file__).parent / "kernel_module.tmpl.cpp"
-
-
-class CpuJit(JitBase):
-    """Just-in-time compiler for CPU kernels."""
-
-    def __init__(
-        self,
-        compiler_info: CompilerInfo | None = None,
-        objcache: str | Path | None = None,
-        strict_scalar_types: bool = False,
-    ):
-        self._compiler_info = copy(compiler_info) if compiler_info is not None else GccInfo()
-        self._strict_scalar_types = strict_scalar_types
-
-        if objcache is None:
-            from appdirs import AppDirs
-
-            dirs = AppDirs(appname="pystencils")
-            self._objcache = Path(dirs.user_cache_dir) / "cpujit"
-        else:
-            self._objcache = Path(objcache)
-
-        #   Include Directories
-        import pybind11 as pb11
-
-        pybind11_include = pb11.get_include()
-
-        import sysconfig
-
-        python_include = sysconfig.get_path("include")
-
-        from ...include import get_pystencils_include_path
-
-        pystencils_include = get_pystencils_include_path()
-
-        self._cxx_fixed_flags = [
-            "-shared",
-            "-fPIC",
-            f"-I{python_include}",
-            f"-I{pybind11_include}",
-            f"-I{pystencils_include}"
-        ]
-
-    @property
-    def objcache(self) -> Path:
-        return self._objcache
-
-    @property
-    def compiler_info(self) -> CompilerInfo:
-        return self._compiler_info
-    
-    @compiler_info.setter
-    def compiler_info(self, info: CompilerInfo):
-        self._compiler_info = info
-
-    @property
-    def strict_scalar_types(self) -> bool:
-        """Enable or disable implicit type casts for scalar parameters.
-
-        If `True`, values for scalar kernel parameters must always be provided with the correct NumPy type.
-        """
-        return self._strict_scalar_types
-
-    @strict_scalar_types.setter
-    def strict_scalar_types(self, v: bool):
-        self._strict_scalar_types = v
-
-    @property
-    def restrict_qualifier(self) -> str:
-        return self._restrict_qualifier
-    
-    @restrict_qualifier.setter
-    def restrict_qualifier(self, qual: str):
-        self._restrict_qualifier = qual
-
-    def compile(self, kernel: Kernel) -> CpuJitKernelWrapper:
-        #   Get the Code
-        module_name = f"{kernel.function_name}_jit"
-        modbuilder = KernelModuleBuilder(self, module_name)
-        cpp_code = modbuilder(kernel)
-
-        #   Get compiler information
-        import sysconfig
-
-        so_abi = sysconfig.get_config_var("SOABI")
-        lib_suffix = f"{so_abi}.so"
-
-        #   Compute Code Hash
-        code_utf8 = cpp_code.encode("utf-8")
-        import hashlib
-
-        code_hash = hashlib.sha256(code_utf8)
-        module_stem = f"module_{code_hash.hexdigest()}"
-
-        module_dir = self._objcache
-
-        #   Lock module
-        import fasteners
-
-        lockfile = module_dir / f"{module_stem}.lock"
-        with fasteners.InterProcessLock(lockfile):
-            cpp_file = module_dir / f"{module_stem}.cpp"
-            if not cpp_file.exists():
-                cpp_file.write_bytes(code_utf8)
-
-            lib_file = module_dir / f"{module_stem}.{lib_suffix}"
-            if not lib_file.exists():
-                self._compile_extension_module(cpp_file, lib_file)
-
-            module = self._load_extension_module(module_name, lib_file)
-
-        return CpuJitKernelWrapper(kernel, module)
-
-    def _compile_extension_module(self, src_file: Path, libfile: Path):
-        args = (
-            [self._compiler_info.cxx()]
-            + self._cxx_fixed_flags
-            + self._compiler_info.cxxflags()
-            + ["-o", str(libfile), str(src_file)]
-        )
-
-        result = subprocess.run(args, capture_output=True)
-        if result.returncode != 0:
-            raise JitError(
-                "Compilation failed: C++ compiler terminated with an error.\n"
-                + result.stderr.decode()
-            )
-
-    def _load_extension_module(self, module_name: str, module_loc: Path) -> ModuleType:
-        from importlib import util as iutil
-
-        spec = iutil.spec_from_file_location(name=module_name, location=module_loc)
-        if spec is None:
-            raise JitError("Unable to load kernel extension module -- this is probably a bug.")
-        mod = iutil.module_from_spec(spec)
-        spec.loader.exec_module(mod)  # type: ignore
-        return mod
-
-
-class CpuJitKernelWrapper(KernelWrapper):
-    def __init__(self, kernel: Kernel, jit_module: ModuleType):
-        super().__init__(kernel)
-        self._module = jit_module
-        self._wrapper_func = getattr(jit_module, kernel.function_name)
-
-    def __call__(self, **kwargs) -> None:
-        return self._wrapper_func(**kwargs)
-
-
-class KernelModuleBuilder:
-    def __init__(self, jit: CpuJit, module_name: str):
-        self._jit = jit
-        self._module_name = module_name
-
-        self._actual_field_types: dict[Field, PsType] = dict()
-        self._param_binds: list[str] = []
-        self._public_params: list[str] = []
-        self._extraction_lines: list[str] = []
-
-    def __call__(self, kernel: Kernel) -> str:
-        self._handle_params(kernel.parameters)
-        
-        kernel_def = self._get_kernel_definition(kernel)
-        kernel_args = [param.name for param in kernel.parameters]
-        includes = [f"#include {h}" for h in kernel.required_headers]
-
-        from string import Template
-
-        templ = Template(_module_template.read_text())
-        code_str = templ.substitute(
-            includes="\n".join(includes),
-            restrict_qualifier=self._jit.compiler_info.restrict_qualifier(),
-            module_name=self._module_name,
-            kernel_name=kernel.function_name,
-            param_binds=", ".join(self._param_binds),
-            public_params=", ".join(self._public_params),
-            extraction_lines=indent("\n".join(self._extraction_lines), prefix="    "),
-            kernel_args=", ".join(kernel_args),
-            kernel_definition=kernel_def,
-        )
-        return code_str
-
-    def _get_kernel_definition(self, kernel: Kernel) -> str:
-        from ...backend.emission import CAstPrinter
-        printer = CAstPrinter(func_prefix="inline")
-
-        return printer(kernel)
-
-    def _add_field_param(self, ptr_param: Parameter):
-        field: Field = ptr_param.fields[0]
-
-        ptr_type = ptr_param.dtype
-        assert isinstance(ptr_type, PsPointerType)
-
-        if isinstance(field.dtype, DynamicType):
-            elem_type = ptr_type.base_type
-        else:
-            elem_type = field.dtype
-
-        self._actual_field_types[field] = elem_type
-
-        param_bind = f'py::arg("{field.name}").noconvert()'
-        self._param_binds.append(param_bind)
-
-        kernel_param = f"py::array_t< {elem_type.c_string()} > & {field.name}"
-        self._public_params.append(kernel_param)
-
-    def _add_scalar_param(self, sc_param: Parameter):
-        param_bind = f'py::arg("{sc_param.name}")'
-        if self._jit.strict_scalar_types:
-            param_bind += ".noconvert()"
-        self._param_binds.append(param_bind)
-
-        kernel_param = f"{sc_param.dtype.c_string()} {sc_param.name}"
-        self._public_params.append(kernel_param)
-
-    def _extract_base_ptr(self, ptr_param: Parameter, ptr_prop: FieldBasePtr):
-        field_name = ptr_prop.field.name
-        assert isinstance(ptr_param.dtype, PsPointerType)
-        data_method = "data()" if ptr_param.dtype.base_type.const else "mutable_data()"
-        extraction = f"{ptr_param.dtype.c_string()} {ptr_param.name} {{ {field_name}.{data_method} }};"
-        self._extraction_lines.append(extraction)
-
-    def _extract_shape(self, shape_param: Parameter, shape_prop: FieldShape):
-        field_name = shape_prop.field.name
-        coord = shape_prop.coordinate
-        extraction = f"{shape_param.dtype.c_string()} {shape_param.name} {{ {field_name}.shape({coord}) }};"
-        self._extraction_lines.append(extraction)
-
-    def _extract_stride(self, stride_param: Parameter, stride_prop: FieldStride):
-        field = stride_prop.field
-        field_name = field.name
-        coord = stride_prop.coordinate
-        field_type = self._actual_field_types[field]
-        assert field_type.itemsize is not None
-        extraction = (
-            f"{stride_param.dtype.c_string()} {stride_param.name} "
-            f"{{ {field_name}.strides({coord}) / {field_type.itemsize} }};"
-        )
-        self._extraction_lines.append(extraction)
-
-    def _handle_params(self, parameters: Sequence[Parameter]):
-        for param in parameters:
-            if param.get_properties(FieldBasePtr):
-                self._add_field_param(param)
-
-        for param in parameters:
-            if ptr_props := param.get_properties(FieldBasePtr):
-                self._extract_base_ptr(param, cast(FieldBasePtr, ptr_props.pop()))
-            elif shape_props := param.get_properties(FieldShape):
-                self._extract_shape(param, cast(FieldShape, shape_props.pop()))
-            elif stride_props := param.get_properties(FieldStride):
-                self._extract_stride(param, cast(FieldStride, stride_props.pop()))
-            else:
-                self._add_scalar_param(param)
diff --git a/src/pystencils/jit/cpu/cpujit.py b/src/pystencils/jit/cpu/cpujit.py
new file mode 100644
index 000000000..f9c39529a
--- /dev/null
+++ b/src/pystencils/jit/cpu/cpujit.py
@@ -0,0 +1,198 @@
+from __future__ import annotations
+
+from types import ModuleType
+from pathlib import Path
+import subprocess
+from copy import copy
+from abc import ABC, abstractmethod
+
+from ...codegen.config import _AUTO_TYPE, AUTO
+
+from ..jit import JitError, JitBase, KernelWrapper
+from ...codegen import Kernel
+from .compiler_info import CompilerInfo, GccInfo
+
+
+class CpuJit(JitBase):
+    """Just-in-time compiler for CPU kernels.
+
+    :Creation:
+    In most cases, objects of this class should be instantiated using the `create` factory method.
+    
+    :Implementation Details:
+
+    The `CpuJit` combines two separate components:
+    - The *extension module builder* produces the code of the dynamically built extension module containing
+      the kernel
+    - The *compiler info* describes the system compiler used to compile and link that extension module.
+
+    Both can be dynamically exchanged.
+    """
+
+    @staticmethod
+    def create(
+        compiler_info: CompilerInfo | None = None,
+        objcache: str | Path | _AUTO_TYPE | None = AUTO,
+    ):
+        if objcache is AUTO:
+            from appdirs import AppDirs
+
+            dirs = AppDirs(appname="pystencils")
+            objcache = Path(dirs.user_cache_dir) / "cpujit"
+        elif objcache is not None:
+            assert not isinstance(objcache, _AUTO_TYPE)
+            objcache = Path(objcache)
+
+        if compiler_info is None:
+            compiler_info = GccInfo()
+
+        from .cpujit_pybind11 import Pybind11KernelModuleBuilder
+
+        modbuilder = Pybind11KernelModuleBuilder(compiler_info)
+
+        return CpuJit(compiler_info, modbuilder, objcache)
+
+    def __init__(
+        self,
+        compiler_info: CompilerInfo,
+        ext_module_builder: ExtensionModuleBuilderBase,
+        objcache: Path | None,
+    ):
+        self._compiler_info = copy(compiler_info)
+        self._ext_module_builder = ext_module_builder
+
+        self._objcache = objcache
+
+        #   Include Directories
+
+        import sysconfig
+
+        python_include = sysconfig.get_path("include")
+
+        from ...include import get_pystencils_include_path
+
+        pystencils_include = get_pystencils_include_path()
+
+        self._cxx_fixed_flags = [
+            "-shared",
+            "-fPIC",
+            f"-I{python_include}",
+            f"-I{pystencils_include}",
+        ] + self._ext_module_builder.include_flags()
+
+    @property
+    def objcache(self) -> Path | None:
+        return self._objcache
+
+    @property
+    def compiler_info(self) -> CompilerInfo:
+        return self._compiler_info
+
+    @compiler_info.setter
+    def compiler_info(self, info: CompilerInfo):
+        self._compiler_info = info
+
+    @property
+    def strict_scalar_types(self) -> bool:
+        """Enable or disable implicit type casts for scalar parameters.
+
+        If `True`, values for scalar kernel parameters must always be provided with the correct NumPy type.
+        """
+        return self._strict_scalar_types
+
+    @strict_scalar_types.setter
+    def strict_scalar_types(self, v: bool):
+        self._strict_scalar_types = v
+
+    def compile(self, kernel: Kernel) -> CpuJitKernelWrapper:
+        #   Get the Code
+        module_name = f"{kernel.function_name}_jit"
+        cpp_code = self._ext_module_builder(kernel, module_name)
+
+        #   Get compiler information
+        import sysconfig
+
+        so_abi = sysconfig.get_config_var("SOABI")
+        lib_suffix = f"{so_abi}.so"
+
+        #   Compute Code Hash
+        code_utf8 = cpp_code.encode("utf-8")
+        import hashlib
+
+        code_hash = hashlib.sha256(code_utf8)
+        module_stem = f"module_{code_hash.hexdigest()}"
+
+        def compile_and_load(module_dir: Path):
+            cpp_file = module_dir / f"{module_stem}.cpp"
+            if not cpp_file.exists():
+                cpp_file.write_bytes(code_utf8)
+
+            lib_file = module_dir / f"{module_stem}.{lib_suffix}"
+            if not lib_file.exists():
+                self._compile_extension_module(cpp_file, lib_file)
+
+            module = self._load_extension_module(module_name, lib_file)
+            return module
+
+        if self._objcache is not None:
+            module_dir = self._objcache
+            #   Lock module
+            import fasteners
+
+            lockfile = module_dir / f"{module_stem}.lock"
+            with fasteners.InterProcessLock(lockfile):
+                module = compile_and_load(module_dir)
+        else:
+            from tempfile import TemporaryDirectory
+
+            with TemporaryDirectory() as tmpdir:
+                module_dir = Path(tmpdir)
+                module = compile_and_load(module_dir)
+
+        return CpuJitKernelWrapper(kernel, module)
+
+    def _compile_extension_module(self, src_file: Path, libfile: Path):
+        args = (
+            [self._compiler_info.cxx()]
+            + self._cxx_fixed_flags
+            + self._compiler_info.cxxflags()
+            + ["-o", str(libfile), str(src_file)]
+        )
+
+        result = subprocess.run(args, capture_output=True)
+        if result.returncode != 0:
+            raise JitError(
+                "Compilation failed: C++ compiler terminated with an error.\n"
+                + result.stderr.decode()
+            )
+
+    def _load_extension_module(self, module_name: str, module_loc: Path) -> ModuleType:
+        from importlib import util as iutil
+
+        spec = iutil.spec_from_file_location(name=module_name, location=module_loc)
+        if spec is None:
+            raise JitError(
+                "Unable to load kernel extension module -- this is probably a bug."
+            )
+        mod = iutil.module_from_spec(spec)
+        spec.loader.exec_module(mod)  # type: ignore
+        return mod
+
+
+class CpuJitKernelWrapper(KernelWrapper):
+    def __init__(self, kernel: Kernel, jit_module: ModuleType):
+        super().__init__(kernel)
+        self._module = jit_module
+        self._wrapper_func = getattr(jit_module, kernel.function_name)
+
+    def __call__(self, **kwargs) -> None:
+        return self._wrapper_func(**kwargs)
+
+
+class ExtensionModuleBuilderBase(ABC):
+    @staticmethod
+    @abstractmethod
+    def include_flags() -> list[str]: ...
+
+    @abstractmethod
+    def __call__(self, kernel: Kernel, module_name: str) -> str: ...
diff --git a/src/pystencils/jit/cpu/cpujit_pybind11.py b/src/pystencils/jit/cpu/cpujit_pybind11.py
new file mode 100644
index 000000000..aee2e9f99
--- /dev/null
+++ b/src/pystencils/jit/cpu/cpujit_pybind11.py
@@ -0,0 +1,143 @@
+from __future__ import annotations
+
+from typing import Sequence, cast
+from pathlib import Path
+from textwrap import indent
+
+from ...types import PsPointerType, PsType
+from ...field import Field
+from ...sympyextensions import DynamicType
+from ...codegen import Kernel, Parameter
+from ...codegen.properties import FieldBasePtr, FieldShape, FieldStride
+
+from .compiler_info import CompilerInfo
+from .cpujit import ExtensionModuleBuilderBase
+
+
+_module_template = Path(__file__).parent / "pybind11_kernel_module.tmpl.cpp"
+
+
+class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
+    @staticmethod
+    def include_flags() -> list[str]:
+        import pybind11 as pb11
+
+        pybind11_include = pb11.get_include()
+        return [f"-I{pybind11_include}"]
+
+    def __init__(
+        self,
+        compiler_info: CompilerInfo,
+        strict_scalar_types: bool = False,
+    ):
+        self._compiler_info = compiler_info
+
+        self._strict_scalar_types = strict_scalar_types
+
+        self._actual_field_types: dict[Field, PsType]
+        self._param_binds: list[str]
+        self._public_params: list[str]
+        self._extraction_lines: list[str]
+
+    def __call__(self, kernel: Kernel, module_name: str) -> str:
+        self._actual_field_types = dict()
+        self._param_binds = []
+        self._public_params = []
+        self._extraction_lines = []
+
+        self._handle_params(kernel.parameters)
+
+        kernel_def = self._get_kernel_definition(kernel)
+        kernel_args = [param.name for param in kernel.parameters]
+        includes = [f"#include {h}" for h in kernel.required_headers]
+
+        from string import Template
+
+        templ = Template(_module_template.read_text())
+        code_str = templ.substitute(
+            includes="\n".join(includes),
+            restrict_qualifier=self._compiler_info.restrict_qualifier(),
+            module_name=module_name,
+            kernel_name=kernel.function_name,
+            param_binds=", ".join(self._param_binds),
+            public_params=", ".join(self._public_params),
+            extraction_lines=indent("\n".join(self._extraction_lines), prefix="    "),
+            kernel_args=", ".join(kernel_args),
+            kernel_definition=kernel_def,
+        )
+        return code_str
+
+    def _get_kernel_definition(self, kernel: Kernel) -> str:
+        from ...backend.emission import CAstPrinter
+
+        printer = CAstPrinter(func_prefix="inline")
+
+        return printer(kernel)
+
+    def _add_field_param(self, ptr_param: Parameter):
+        field: Field = ptr_param.fields[0]
+
+        ptr_type = ptr_param.dtype
+        assert isinstance(ptr_type, PsPointerType)
+
+        if isinstance(field.dtype, DynamicType):
+            elem_type = ptr_type.base_type
+        else:
+            elem_type = field.dtype
+
+        self._actual_field_types[field] = elem_type
+
+        param_bind = f'py::arg("{field.name}").noconvert()'
+        self._param_binds.append(param_bind)
+
+        kernel_param = f"py::array_t< {elem_type.c_string()} > & {field.name}"
+        self._public_params.append(kernel_param)
+
+    def _add_scalar_param(self, sc_param: Parameter):
+        param_bind = f'py::arg("{sc_param.name}")'
+        if self._strict_scalar_types:
+            param_bind += ".noconvert()"
+        self._param_binds.append(param_bind)
+
+        kernel_param = f"{sc_param.dtype.c_string()} {sc_param.name}"
+        self._public_params.append(kernel_param)
+
+    def _extract_base_ptr(self, ptr_param: Parameter, ptr_prop: FieldBasePtr):
+        field_name = ptr_prop.field.name
+        assert isinstance(ptr_param.dtype, PsPointerType)
+        data_method = "data()" if ptr_param.dtype.base_type.const else "mutable_data()"
+        extraction = f"{ptr_param.dtype.c_string()} {ptr_param.name} {{ {field_name}.{data_method} }};"
+        self._extraction_lines.append(extraction)
+
+    def _extract_shape(self, shape_param: Parameter, shape_prop: FieldShape):
+        field_name = shape_prop.field.name
+        coord = shape_prop.coordinate
+        extraction = f"{shape_param.dtype.c_string()} {shape_param.name} {{ {field_name}.shape({coord}) }};"
+        self._extraction_lines.append(extraction)
+
+    def _extract_stride(self, stride_param: Parameter, stride_prop: FieldStride):
+        field = stride_prop.field
+        field_name = field.name
+        coord = stride_prop.coordinate
+        field_type = self._actual_field_types[field]
+        assert field_type.itemsize is not None
+        extraction = (
+            f"{stride_param.dtype.c_string()} {stride_param.name} "
+            f"{{ {field_name}.strides({coord}) / {field_type.itemsize} }};"
+        )
+        self._extraction_lines.append(extraction)
+
+    def _handle_params(self, parameters: Sequence[Parameter]):
+        for param in parameters:
+            if param.get_properties(FieldBasePtr):
+                self._add_field_param(param)
+
+        for param in parameters:
+            if ptr_props := param.get_properties(FieldBasePtr):
+                self._extract_base_ptr(param, cast(FieldBasePtr, ptr_props.pop()))
+            elif shape_props := param.get_properties(FieldShape):
+                self._extract_shape(param, cast(FieldShape, shape_props.pop()))
+            elif stride_props := param.get_properties(FieldStride):
+                self._extract_stride(param, cast(FieldStride, stride_props.pop()))
+            else:
+                self._add_scalar_param(param)
diff --git a/src/pystencils/jit/cpu/kernel_module.tmpl.cpp b/src/pystencils/jit/cpu/pybind11_kernel_module.tmpl.cpp
similarity index 100%
rename from src/pystencils/jit/cpu/kernel_module.tmpl.cpp
rename to src/pystencils/jit/cpu/pybind11_kernel_module.tmpl.cpp
diff --git a/tests/jit/test_cpujit.py b/tests/jit/test_cpujit.py
index aa4f50f3f..0e17fa6a1 100644
--- a/tests/jit/test_cpujit.py
+++ b/tests/jit/test_cpujit.py
@@ -5,7 +5,7 @@ from pystencils.jit import CpuJit
 
 
 def test_basic_cpu_kernel(tmp_path):
-    jit = CpuJit(objcache=tmp_path)
+    jit = CpuJit.create(objcache=tmp_path)
 
     f, g = fields("f, g: [2D]")
     asm = Assignment(f.center(), 2.0 * g.center())
-- 
GitLab


From e2face7c9c9dbfd6981b99740c18526d38f86782 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Tue, 28 Jan 2025 13:13:33 +0100
Subject: [PATCH 07/17] add test for field types

---
 tests/jit/test_cpujit.py | 35 ++++++++++++++++++++++++++++++++---
 1 file changed, 32 insertions(+), 3 deletions(-)

diff --git a/tests/jit/test_cpujit.py b/tests/jit/test_cpujit.py
index 0e17fa6a1..3d269d837 100644
--- a/tests/jit/test_cpujit.py
+++ b/tests/jit/test_cpujit.py
@@ -1,16 +1,21 @@
+import pytest
+
 import sympy as sp
 import numpy as np
 from pystencils import create_kernel, Assignment, fields
 from pystencils.jit import CpuJit
 
 
-def test_basic_cpu_kernel(tmp_path):
-    jit = CpuJit.create(objcache=tmp_path)
+@pytest.fixture
+def cpu_jit(tmp_path) -> CpuJit:
+    return CpuJit.create(objcache=tmp_path)
+
 
+def test_basic_cpu_kernel(cpu_jit):
     f, g = fields("f, g: [2D]")
     asm = Assignment(f.center(), 2.0 * g.center())
     ker = create_kernel(asm)
-    kfunc = jit.compile(ker)
+    kfunc = cpu_jit.compile(ker)
 
     rng = np.random.default_rng()
     f_arr = rng.random(size=(34, 26), dtype="float64")
@@ -19,3 +24,27 @@ def test_basic_cpu_kernel(tmp_path):
     kfunc(f=f_arr, g=g_arr)
 
     np.testing.assert_almost_equal(g_arr, 2.0 * f_arr)
+
+
+def test_argument_type_error(cpu_jit):
+    f, g = fields("f, g: [2D]")
+    c = sp.Symbol("c")
+    asm = Assignment(f.center(), c * g.center())
+    ker = create_kernel(asm)
+    kfunc = cpu_jit.compile(ker)
+
+    arr_fp16 = np.zeros((23, 12), dtype="float16")
+    arr_fp32 = np.zeros((23, 12), dtype="float32")
+    arr_fp64 = np.zeros((23, 12), dtype="float64")
+
+    with pytest.raises(TypeError):
+        kfunc(f=arr_fp32, g=arr_fp64, c=2.0)
+
+    with pytest.raises(TypeError):
+        kfunc(f=arr_fp64, g=arr_fp32, c=2.0)
+
+    with pytest.raises(TypeError):
+        kfunc(f=arr_fp16, g=arr_fp16, c=2.0)
+
+    #   Wrong scalar types are OK, though
+    kfunc(f=arr_fp64, g=arr_fp64, c=np.float16(1.0))
-- 
GitLab


From e8c8ea8ea71f41c65afd8c53448f35d82251b6db Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Tue, 28 Jan 2025 13:46:27 +0100
Subject: [PATCH 08/17] add checks for constant shape and strides

---
 src/pystencils/jit/cpu/cpujit.py              |  6 ++-
 src/pystencils/jit/cpu/cpujit_pybind11.py     | 15 ++++++
 .../jit/cpu/pybind11_kernel_module.tmpl.cpp   | 47 ++++++++++++++++++-
 tests/jit/test_cpujit.py                      | 24 +++++++++-
 4 files changed, 86 insertions(+), 6 deletions(-)

diff --git a/src/pystencils/jit/cpu/cpujit.py b/src/pystencils/jit/cpu/cpujit.py
index f9c39529a..b3a9e48aa 100644
--- a/src/pystencils/jit/cpu/cpujit.py
+++ b/src/pystencils/jit/cpu/cpujit.py
@@ -183,10 +183,12 @@ class CpuJitKernelWrapper(KernelWrapper):
     def __init__(self, kernel: Kernel, jit_module: ModuleType):
         super().__init__(kernel)
         self._module = jit_module
-        self._wrapper_func = getattr(jit_module, kernel.function_name)
+        self._check_params = getattr(jit_module, "check_params")
+        self._invoke = getattr(jit_module, "invoke")
 
     def __call__(self, **kwargs) -> None:
-        return self._wrapper_func(**kwargs)
+        self._check_params(**kwargs)
+        return self._invoke(**kwargs)
 
 
 class ExtensionModuleBuilderBase(ABC):
diff --git a/src/pystencils/jit/cpu/cpujit_pybind11.py b/src/pystencils/jit/cpu/cpujit_pybind11.py
index aee2e9f99..b68ed9c29 100644
--- a/src/pystencils/jit/cpu/cpujit_pybind11.py
+++ b/src/pystencils/jit/cpu/cpujit_pybind11.py
@@ -37,12 +37,14 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
         self._actual_field_types: dict[Field, PsType]
         self._param_binds: list[str]
         self._public_params: list[str]
+        self._param_check_lines: list[str]
         self._extraction_lines: list[str]
 
     def __call__(self, kernel: Kernel, module_name: str) -> str:
         self._actual_field_types = dict()
         self._param_binds = []
         self._public_params = []
+        self._param_check_lines = []
         self._extraction_lines = []
 
         self._handle_params(kernel.parameters)
@@ -61,6 +63,7 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
             kernel_name=kernel.function_name,
             param_binds=", ".join(self._param_binds),
             public_params=", ".join(self._public_params),
+            param_check_lines=indent("\n".join(self._param_check_lines), prefix="    "),
             extraction_lines=indent("\n".join(self._extraction_lines), prefix="    "),
             kernel_args=", ".join(kernel_args),
             kernel_definition=kernel_def,
@@ -93,6 +96,18 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
         kernel_param = f"py::array_t< {elem_type.c_string()} > & {field.name}"
         self._public_params.append(kernel_param)
 
+        for coord, size in enumerate(field.shape):
+            if isinstance(size, int):
+                self._param_check_lines.append(
+                    f"checkFieldShape(\"{field.name}\", {field.name}, {coord}, {size});"
+                )
+
+        for coord, stride in enumerate(field.strides):
+            if isinstance(stride, int):
+                self._param_check_lines.append(
+                    f"checkFieldStride(\"{field.name}\", {field.name}, {coord}, {stride});"
+                )
+
     def _add_scalar_param(self, sc_param: Parameter):
         param_bind = f'py::arg("{sc_param.name}")'
         if self._strict_scalar_types:
diff --git a/src/pystencils/jit/cpu/pybind11_kernel_module.tmpl.cpp b/src/pystencils/jit/cpu/pybind11_kernel_module.tmpl.cpp
index 3ee5c6973..acc9d05de 100644
--- a/src/pystencils/jit/cpu/pybind11_kernel_module.tmpl.cpp
+++ b/src/pystencils/jit/cpu/pybind11_kernel_module.tmpl.cpp
@@ -1,6 +1,10 @@
 #include "pybind11/pybind11.h"
 #include "pybind11/numpy.h"
 
+#include <array>
+#include <string>
+#include <sstream>
+
 ${includes}
 
 namespace py = pybind11;
@@ -13,11 +17,50 @@ ${kernel_definition}
 
 }
 
-void callwrapper_${kernel_name} (${public_params}) {
+template< typename T >
+void checkFieldShape(const std::string fieldName, const py::array_t< T > & arr, size_t coord, size_t desired) {
+    auto panic = [&](){
+        std::stringstream err;
+        err << "Invalid shape of argument " << fieldName;
+        throw py::value_error{ err.str() };
+    };
+    
+    if(arr.ndim() <= coord){
+        panic();
+    }
+
+    if(arr.shape(coord) != desired){
+        panic();
+    }
+}
+
+template< typename T >
+void checkFieldStride(const std::string fieldName, const py::array_t< T > & arr, size_t coord, size_t desired) {
+    auto panic = [&](){
+        std::stringstream err;
+        err << "Invalid strides of argument " << fieldName;
+        throw py::value_error{ err.str() };
+    };
+    
+    if(arr.ndim() <= coord){
+        panic();
+    }
+
+    if(arr.strides(coord) / sizeof(T) != desired){
+        panic();
+    }
+}
+
+void check_params_${kernel_name} (${public_params}) {
+${param_check_lines}
+}
+
+void run_${kernel_name} (${public_params}) {
 ${extraction_lines}
     internal::${kernel_name}(${kernel_args});
 }
 
 PYBIND11_MODULE(${module_name}, m) {
-    m.def("${kernel_name}", &callwrapper_${kernel_name}, py::kw_only(), ${param_binds});
+    m.def("check_params", &check_params_${kernel_name}, py::kw_only(), ${param_binds});
+    m.def("invoke", &run_${kernel_name}, py::kw_only(), ${param_binds});
 }
diff --git a/tests/jit/test_cpujit.py b/tests/jit/test_cpujit.py
index 3d269d837..c8e75f862 100644
--- a/tests/jit/test_cpujit.py
+++ b/tests/jit/test_cpujit.py
@@ -2,13 +2,13 @@ import pytest
 
 import sympy as sp
 import numpy as np
-from pystencils import create_kernel, Assignment, fields
+from pystencils import create_kernel, Assignment, fields, Field
 from pystencils.jit import CpuJit
 
 
 @pytest.fixture
 def cpu_jit(tmp_path) -> CpuJit:
-    return CpuJit.create(objcache=tmp_path)
+    return CpuJit.create(objcache=".jit")
 
 
 def test_basic_cpu_kernel(cpu_jit):
@@ -48,3 +48,23 @@ def test_argument_type_error(cpu_jit):
 
     #   Wrong scalar types are OK, though
     kfunc(f=arr_fp64, g=arr_fp64, c=np.float16(1.0))
+
+
+def test_fixed_shape(cpu_jit):
+    a = np.zeros((12, 23), dtype="float64")
+    b = np.zeros((13, 21), dtype="float64")
+    
+    f = Field.create_from_numpy_array("f", a)
+    g = Field.create_from_numpy_array("g", a)
+
+    asm = Assignment(f.center(), 2.0 * g.center())
+    ker = create_kernel(asm)
+    kfunc = cpu_jit.compile(ker)
+
+    kfunc(f=a, g=a)
+
+    with pytest.raises(ValueError):
+        kfunc(f=b, g=a)
+
+    with pytest.raises(ValueError):
+        kfunc(f=a, g=b)
-- 
GitLab


From 4dcb81acebad62c87a6a8a6f79ea87870d858f86 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Tue, 28 Jan 2025 14:04:27 +0100
Subject: [PATCH 09/17] improve error strings for shape and stride checks. Add
 more test cases.

---
 src/pystencils/jit/cpu/cpujit_pybind11.py     |  6 ++--
 .../jit/cpu/pybind11_kernel_module.tmpl.cpp   | 27 ++++++++++++++---
 tests/jit/test_cpujit.py                      | 29 ++++++++++++++++++-
 3 files changed, 55 insertions(+), 7 deletions(-)

diff --git a/src/pystencils/jit/cpu/cpujit_pybind11.py b/src/pystencils/jit/cpu/cpujit_pybind11.py
index b68ed9c29..eff3a061f 100644
--- a/src/pystencils/jit/cpu/cpujit_pybind11.py
+++ b/src/pystencils/jit/cpu/cpujit_pybind11.py
@@ -96,16 +96,18 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
         kernel_param = f"py::array_t< {elem_type.c_string()} > & {field.name}"
         self._public_params.append(kernel_param)
 
+        expect_shape = "(" + ", ".join((str(s) if isinstance(s, int) else "*") for s in field.shape) + ")"
         for coord, size in enumerate(field.shape):
             if isinstance(size, int):
                 self._param_check_lines.append(
-                    f"checkFieldShape(\"{field.name}\", {field.name}, {coord}, {size});"
+                    f"checkFieldShape(\"{field.name}\", \"{expect_shape}\", {field.name}, {coord}, {size});"
                 )
 
+        expect_strides = "(" + ", ".join((str(s) if isinstance(s, int) else "*") for s in field.strides) + ")"
         for coord, stride in enumerate(field.strides):
             if isinstance(stride, int):
                 self._param_check_lines.append(
-                    f"checkFieldStride(\"{field.name}\", {field.name}, {coord}, {stride});"
+                    f"checkFieldStride(\"{field.name}\", \"{expect_strides}\", {field.name}, {coord}, {stride});"
                 )
 
     def _add_scalar_param(self, sc_param: Parameter):
diff --git a/src/pystencils/jit/cpu/pybind11_kernel_module.tmpl.cpp b/src/pystencils/jit/cpu/pybind11_kernel_module.tmpl.cpp
index acc9d05de..ef945586f 100644
--- a/src/pystencils/jit/cpu/pybind11_kernel_module.tmpl.cpp
+++ b/src/pystencils/jit/cpu/pybind11_kernel_module.tmpl.cpp
@@ -17,11 +17,27 @@ ${kernel_definition}
 
 }
 
+std::string tuple_to_str(const ssize_t * data, const size_t N){
+    std::stringstream acc;
+    acc << "(";
+    for(size_t i = 0; i < N; ++i){
+        acc << data[i];
+        if(i + 1 < N){
+            acc << ", ";
+        }
+    }
+    acc << ")";
+    return acc.str();
+}
+
 template< typename T >
-void checkFieldShape(const std::string fieldName, const py::array_t< T > & arr, size_t coord, size_t desired) {
+void checkFieldShape(const std::string& fieldName, const std::string& expected, const py::array_t< T > & arr, size_t coord, size_t desired) {
     auto panic = [&](){
         std::stringstream err;
-        err << "Invalid shape of argument " << fieldName;
+        err << "Invalid shape of argument " << fieldName
+            << ". Expected " << expected
+            << ", but got " << tuple_to_str(arr.shape(), arr.ndim())
+            << ".";
         throw py::value_error{ err.str() };
     };
     
@@ -35,10 +51,13 @@ void checkFieldShape(const std::string fieldName, const py::array_t< T > & arr,
 }
 
 template< typename T >
-void checkFieldStride(const std::string fieldName, const py::array_t< T > & arr, size_t coord, size_t desired) {
+void checkFieldStride(const std::string fieldName, const std::string& expected, const py::array_t< T > & arr, size_t coord, size_t desired) {
     auto panic = [&](){
         std::stringstream err;
-        err << "Invalid strides of argument " << fieldName;
+        err << "Invalid strides of argument " << fieldName 
+            << ". Expected " << expected
+            << ", but got " << tuple_to_str(arr.strides(), arr.ndim())
+            << ".";
         throw py::value_error{ err.str() };
     };
     
diff --git a/tests/jit/test_cpujit.py b/tests/jit/test_cpujit.py
index c8e75f862..bfa4c9897 100644
--- a/tests/jit/test_cpujit.py
+++ b/tests/jit/test_cpujit.py
@@ -8,7 +8,7 @@ from pystencils.jit import CpuJit
 
 @pytest.fixture
 def cpu_jit(tmp_path) -> CpuJit:
-    return CpuJit.create(objcache=".jit")
+    return CpuJit.create(objcache=tmp_path)
 
 
 def test_basic_cpu_kernel(cpu_jit):
@@ -68,3 +68,30 @@ def test_fixed_shape(cpu_jit):
 
     with pytest.raises(ValueError):
         kfunc(f=a, g=b)
+
+
+def test_fixed_index_shape(cpu_jit):
+    f, g = fields("f(3), g(2, 2): [2D]")
+
+    asm = Assignment(f.center(1), g.center(0, 0) + g.center(0, 1) + g.center(1, 0) + g.center(1, 1))
+    ker = create_kernel(asm)
+    kfunc = cpu_jit.compile(ker)
+
+    f_arr = np.zeros((12, 14, 3))
+    g_arr = np.zeros((12, 14, 2, 2))
+    kfunc(f=f_arr, g=g_arr)
+
+    with pytest.raises(ValueError):
+        f_arr = np.zeros((12, 14, 2))
+        g_arr = np.zeros((12, 14, 2, 2))
+        kfunc(f=f_arr, g=g_arr)
+
+    with pytest.raises(ValueError):
+        f_arr = np.zeros((12, 14, 3))
+        g_arr = np.zeros((12, 14, 4))
+        kfunc(f=f_arr, g=g_arr)
+
+    with pytest.raises(ValueError):
+        f_arr = np.zeros((12, 14, 3))
+        g_arr = np.zeros((12, 14, 1, 3))
+        kfunc(f=f_arr, g=g_arr)
-- 
GitLab


From fd31ce6471d87f0ed34e3e4c6196ae7a27626df4 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Tue, 28 Jan 2025 16:07:46 +0100
Subject: [PATCH 10/17] move kernel wrapper creation to the module builder

---
 src/pystencils/jit/cpu/cpujit.py          | 19 +++++--------------
 src/pystencils/jit/cpu/cpujit_pybind11.py | 18 ++++++++++++++++++
 2 files changed, 23 insertions(+), 14 deletions(-)

diff --git a/src/pystencils/jit/cpu/cpujit.py b/src/pystencils/jit/cpu/cpujit.py
index b3a9e48aa..eb1b60abe 100644
--- a/src/pystencils/jit/cpu/cpujit.py
+++ b/src/pystencils/jit/cpu/cpujit.py
@@ -104,7 +104,7 @@ class CpuJit(JitBase):
     def strict_scalar_types(self, v: bool):
         self._strict_scalar_types = v
 
-    def compile(self, kernel: Kernel) -> CpuJitKernelWrapper:
+    def compile(self, kernel: Kernel) -> KernelWrapper:
         #   Get the Code
         module_name = f"{kernel.function_name}_jit"
         cpp_code = self._ext_module_builder(kernel, module_name)
@@ -149,7 +149,7 @@ class CpuJit(JitBase):
                 module_dir = Path(tmpdir)
                 module = compile_and_load(module_dir)
 
-        return CpuJitKernelWrapper(kernel, module)
+        return self._ext_module_builder.get_wrapper(kernel, module)
 
     def _compile_extension_module(self, src_file: Path, libfile: Path):
         args = (
@@ -179,18 +179,6 @@ class CpuJit(JitBase):
         return mod
 
 
-class CpuJitKernelWrapper(KernelWrapper):
-    def __init__(self, kernel: Kernel, jit_module: ModuleType):
-        super().__init__(kernel)
-        self._module = jit_module
-        self._check_params = getattr(jit_module, "check_params")
-        self._invoke = getattr(jit_module, "invoke")
-
-    def __call__(self, **kwargs) -> None:
-        self._check_params(**kwargs)
-        return self._invoke(**kwargs)
-
-
 class ExtensionModuleBuilderBase(ABC):
     @staticmethod
     @abstractmethod
@@ -198,3 +186,6 @@ class ExtensionModuleBuilderBase(ABC):
 
     @abstractmethod
     def __call__(self, kernel: Kernel, module_name: str) -> str: ...
+
+    @abstractmethod
+    def get_wrapper(self, kernel: Kernel, extension_module: ModuleType) -> KernelWrapper: ...
diff --git a/src/pystencils/jit/cpu/cpujit_pybind11.py b/src/pystencils/jit/cpu/cpujit_pybind11.py
index eff3a061f..ba9065c8a 100644
--- a/src/pystencils/jit/cpu/cpujit_pybind11.py
+++ b/src/pystencils/jit/cpu/cpujit_pybind11.py
@@ -1,9 +1,12 @@
 from __future__ import annotations
 
+from types import ModuleType
 from typing import Sequence, cast
 from pathlib import Path
 from textwrap import indent
 
+from pystencils.jit.jit import KernelWrapper
+
 from ...types import PsPointerType, PsType
 from ...field import Field
 from ...sympyextensions import DynamicType
@@ -69,6 +72,9 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
             kernel_definition=kernel_def,
         )
         return code_str
+    
+    def get_wrapper(self, kernel: Kernel, extension_module: ModuleType) -> KernelWrapper:
+        return Pybind11KernelWrapper(kernel, extension_module)
 
     def _get_kernel_definition(self, kernel: Kernel) -> str:
         from ...backend.emission import CAstPrinter
@@ -158,3 +164,15 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
                 self._extract_stride(param, cast(FieldStride, stride_props.pop()))
             else:
                 self._add_scalar_param(param)
+
+
+class Pybind11KernelWrapper(KernelWrapper):
+    def __init__(self, kernel: Kernel, jit_module: ModuleType):
+        super().__init__(kernel)
+        self._module = jit_module
+        self._check_params = getattr(jit_module, "check_params")
+        self._invoke = getattr(jit_module, "invoke")
+
+    def __call__(self, **kwargs) -> None:
+        self._check_params(**kwargs)
+        return self._invoke(**kwargs)
-- 
GitLab


From 8ed8252d3f859b588cbc666f66e864688fd145f0 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Thu, 30 Jan 2025 13:54:07 +0100
Subject: [PATCH 11/17] make the compilerinfo provide include and linkage flags

---
 src/pystencils/codegen/config.py          |  8 ++----
 src/pystencils/jit/cpu/compiler_info.py   | 15 ++++++++++-
 src/pystencils/jit/cpu/cpujit.py          | 33 ++++++++++++-----------
 src/pystencils/jit/cpu/cpujit_pybind11.py |  4 +--
 4 files changed, 36 insertions(+), 24 deletions(-)

diff --git a/src/pystencils/codegen/config.py b/src/pystencils/codegen/config.py
index 07dd2691d..6587f16b7 100644
--- a/src/pystencils/codegen/config.py
+++ b/src/pystencils/codegen/config.py
@@ -1,11 +1,10 @@
 from __future__ import annotations
-from typing import TYPE_CHECKING, Any, TypeGuard
 
 from warnings import warn
 from abc import ABC
 from collections.abc import Collection
 
-from typing import Sequence, Generic, TypeVar, Callable, Any, cast
+from typing import TYPE_CHECKING, Sequence, Generic, TypeVar, Callable, Any, cast, TypeGuard
 from dataclasses import dataclass, InitVar, fields
 
 from .target import Target
@@ -208,10 +207,7 @@ class Category(Generic[Category_T]):
         setattr(obj, self._lookup, cat.copy() if cat is not None else None)
 
 
-class _AUTO_TYPE:
-    @staticmethod
-    def is_auto(val: Any) -> TypeGuard[_AUTO_TYPE]:
-        return val == AUTO
+class _AUTO_TYPE: ...  # noqa: E701
 
 
 AUTO = _AUTO_TYPE()
diff --git a/src/pystencils/jit/cpu/compiler_info.py b/src/pystencils/jit/cpu/compiler_info.py
index 07a5eebbd..6ed568d72 100644
--- a/src/pystencils/jit/cpu/compiler_info.py
+++ b/src/pystencils/jit/cpu/compiler_info.py
@@ -1,4 +1,5 @@
 from __future__ import annotations
+from typing import Sequence
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 
@@ -21,13 +22,19 @@ class CompilerInfo(ABC):
     @abstractmethod
     def cxxflags(self) -> list[str]: ...
 
+    @abstractmethod
+    def linker_flags(self) -> list[str]: ...
+
+    @abstractmethod
+    def include_flags(self, include_dirs: Sequence[str]) -> list[str]: ...
+
     @abstractmethod
     def restrict_qualifier(self) -> str: ...
 
 
 class _GnuLikeCliCompiler(CompilerInfo):
     def cxxflags(self) -> list[str]:
-        flags = ["-DNDEBUG", f"-std={self.cxx_standard}"]
+        flags = ["-DNDEBUG", f"-std={self.cxx_standard}", "-fPIC"]
 
         if self.optlevel is not None:
             flags.append(f"-O{self.optlevel}")
@@ -48,6 +55,12 @@ class _GnuLikeCliCompiler(CompilerInfo):
                 flags += ["-march=x86-64-v4", "-mavx512fp16"]
 
         return flags
+    
+    def linker_flags(self) -> list[str]:
+        return ["-shared"]
+
+    def include_flags(self, include_dirs: Sequence[str]) -> list[str]:
+        return [f"-I{d}" for d in include_dirs]
 
     def restrict_qualifier(self) -> str:
         return "__restrict__"
diff --git a/src/pystencils/jit/cpu/cpujit.py b/src/pystencils/jit/cpu/cpujit.py
index eb1b60abe..64045d549 100644
--- a/src/pystencils/jit/cpu/cpujit.py
+++ b/src/pystencils/jit/cpu/cpujit.py
@@ -18,7 +18,7 @@ class CpuJit(JitBase):
 
     :Creation:
     In most cases, objects of this class should be instantiated using the `create` factory method.
-    
+
     :Implementation Details:
 
     The `CpuJit` combines two separate components:
@@ -66,19 +66,21 @@ class CpuJit(JitBase):
         #   Include Directories
 
         import sysconfig
-
-        python_include = sysconfig.get_path("include")
-
         from ...include import get_pystencils_include_path
 
-        pystencils_include = get_pystencils_include_path()
+        include_dirs = [
+            sysconfig.get_path("include"),
+            get_pystencils_include_path(),
+        ] + self._ext_module_builder.include_dirs()
 
-        self._cxx_fixed_flags = [
-            "-shared",
-            "-fPIC",
-            f"-I{python_include}",
-            f"-I{pystencils_include}",
-        ] + self._ext_module_builder.include_flags()
+        #   Compiler Flags
+
+        self._cxx = self._compiler_info.cxx()
+        self._cxx_fixed_flags = (
+            self._compiler_info.cxxflags()
+            + self._compiler_info.include_flags(include_dirs)
+            + self._compiler_info.linker_flags()
+        )
 
     @property
     def objcache(self) -> Path | None:
@@ -153,9 +155,8 @@ class CpuJit(JitBase):
 
     def _compile_extension_module(self, src_file: Path, libfile: Path):
         args = (
-            [self._compiler_info.cxx()]
+            [self._cxx]
             + self._cxx_fixed_flags
-            + self._compiler_info.cxxflags()
             + ["-o", str(libfile), str(src_file)]
         )
 
@@ -182,10 +183,12 @@ class CpuJit(JitBase):
 class ExtensionModuleBuilderBase(ABC):
     @staticmethod
     @abstractmethod
-    def include_flags() -> list[str]: ...
+    def include_dirs() -> list[str]: ...
 
     @abstractmethod
     def __call__(self, kernel: Kernel, module_name: str) -> str: ...
 
     @abstractmethod
-    def get_wrapper(self, kernel: Kernel, extension_module: ModuleType) -> KernelWrapper: ...
+    def get_wrapper(
+        self, kernel: Kernel, extension_module: ModuleType
+    ) -> KernelWrapper: ...
diff --git a/src/pystencils/jit/cpu/cpujit_pybind11.py b/src/pystencils/jit/cpu/cpujit_pybind11.py
index ba9065c8a..2f43ef196 100644
--- a/src/pystencils/jit/cpu/cpujit_pybind11.py
+++ b/src/pystencils/jit/cpu/cpujit_pybind11.py
@@ -22,11 +22,11 @@ _module_template = Path(__file__).parent / "pybind11_kernel_module.tmpl.cpp"
 
 class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
     @staticmethod
-    def include_flags() -> list[str]:
+    def include_dirs() -> list[str]:
         import pybind11 as pb11
 
         pybind11_include = pb11.get_include()
-        return [f"-I{pybind11_include}"]
+        return [pybind11_include]
 
     def __init__(
         self,
-- 
GitLab


From fc169886b25651349072f938ad8a8af807da61d1 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Thu, 30 Jan 2025 15:06:05 +0100
Subject: [PATCH 12/17] add new CPU JIT classes to API reference pages

---
 docs/source/api/jit.md                    | 102 ++++++++++++++++++++++
 docs/source/api/jit.rst                   |  40 ---------
 src/pystencils/codegen/config.py          |   6 +-
 src/pystencils/jit/cpu/compiler_info.py   |  34 ++++++--
 src/pystencils/jit/cpu/cpujit.py          |  91 ++++++++++---------
 src/pystencils/jit/cpu/cpujit_pybind11.py |   2 +-
 6 files changed, 186 insertions(+), 89 deletions(-)
 create mode 100644 docs/source/api/jit.md
 delete mode 100644 docs/source/api/jit.rst

diff --git a/docs/source/api/jit.md b/docs/source/api/jit.md
new file mode 100644
index 000000000..06ea0cbaf
--- /dev/null
+++ b/docs/source/api/jit.md
@@ -0,0 +1,102 @@
+# JIT Compilation
+
+## Base Infrastructure
+
+```{eval-rst}
+.. module:: pystencils.jit
+
+.. autosummary::
+  :toctree: generated
+  :nosignatures:
+  :template: autosummary/entire_class.rst
+
+    KernelWrapper
+    JitBase
+    NoJit
+
+.. autodata:: no_jit
+```
+
+## Legacy CPU JIT
+
+The legacy CPU JIT Compiler is a leftover from pystencils 1.3
+which at the moment still drives most CPU JIT-compilation within the package,
+until the new JIT compiler is ready to take over.
+
+```{eval-rst}
+.. autosummary::
+  :toctree: generated
+  :nosignatures:
+  :template: autosummary/entire_class.rst
+
+  LegacyCpuJit
+```
+
+## CPU Just-In-Time Compiler
+
+:::{note}
+The new CPU JIT compiler is still considered experimental and not yet adopted by most of pystencils.
+While the APIs described here will (probably) become the default for pystencils 2.0
+and can (and should) already be used for testing,
+the current implementation is still *very slow*.
+For more information, see [issue !120](https://i10git.cs.fau.de/pycodegen/pystencils/-/issues/120).
+:::
+
+To configure and create an instance of the CPU JIT compiler, use the `CpuJit.create` factory method:
+
+:::{card}
+```{eval-rst}
+.. autofunction:: pystencils.jit.CpuJit.create
+  :no-index:
+```
+:::
+
+### Compiler Infos
+
+The CPU JIT compiler invokes a host C++ compiler to compile and link a Python extension
+module containing the generated kernel.
+The properties of the host compiler are defined in a `CompilerInfo` object.
+To select a custom host compiler and customize its options, set up and pass
+a custom compiler info object to `CpuJit.create`.
+
+```{eval-rst}
+.. module:: pystencils.jit.cpu.compiler_info
+
+.. autosummary::
+  :toctree: generated
+  :nosignatures:
+  :template: autosummary/entire_class.rst
+
+  CompilerInfo
+  GccInfo
+  ClangInfo
+```
+
+### Implementation
+
+```{eval-rst}
+.. module:: pystencils.jit.cpu
+
+.. autosummary::
+  :toctree: generated
+  :nosignatures:
+  :template: autosummary/entire_class.rst
+
+  CpuJit
+  cpujit.ExtensionModuleBuilderBase
+```
+
+## CuPy-based GPU JIT
+
+```{eval-rst}
+.. module:: pystencils.jit.gpu_cupy
+
+.. autosummary::
+  :toctree: generated
+  :nosignatures:
+  :template: autosummary/entire_class.rst
+
+  CupyJit
+  CupyKernelWrapper
+  LaunchGrid
+```
diff --git a/docs/source/api/jit.rst b/docs/source/api/jit.rst
deleted file mode 100644
index f2e271db3..000000000
--- a/docs/source/api/jit.rst
+++ /dev/null
@@ -1,40 +0,0 @@
-JIT Compilation
-===============
-
-.. module:: pystencils.jit
-
-Base Infrastructure
--------------------
-
-.. autosummary::
-  :toctree: generated
-  :nosignatures:
-  :template: autosummary/entire_class.rst
-
-    KernelWrapper
-    JitBase
-    NoJit
-
-.. autodata:: no_jit
-
-Legacy CPU JIT
---------------
-
-.. autosummary::
-  :toctree: generated
-  :nosignatures:
-  :template: autosummary/entire_class.rst
-
-  LegacyCpuJit
-
-CuPy-based GPU JIT
-------------------
-
-.. autosummary::
-  :toctree: generated
-  :nosignatures:
-  :template: autosummary/entire_class.rst
-
-  CupyJit
-  CupyKernelWrapper
-  LaunchGrid
diff --git a/src/pystencils/codegen/config.py b/src/pystencils/codegen/config.py
index 6587f16b7..bce075731 100644
--- a/src/pystencils/codegen/config.py
+++ b/src/pystencils/codegen/config.py
@@ -4,7 +4,7 @@ from warnings import warn
 from abc import ABC
 from collections.abc import Collection
 
-from typing import TYPE_CHECKING, Sequence, Generic, TypeVar, Callable, Any, cast, TypeGuard
+from typing import TYPE_CHECKING, Sequence, Generic, TypeVar, Callable, Any, cast
 from dataclasses import dataclass, InitVar, fields
 
 from .target import Target
@@ -207,7 +207,9 @@ class Category(Generic[Category_T]):
         setattr(obj, self._lookup, cat.copy() if cat is not None else None)
 
 
-class _AUTO_TYPE: ...  # noqa: E701
+class _AUTO_TYPE:
+    def __repr__(self) -> str:
+        return "AUTO"  # for pretty-printing in the docs
 
 
 AUTO = _AUTO_TYPE()
diff --git a/src/pystencils/jit/cpu/compiler_info.py b/src/pystencils/jit/cpu/compiler_info.py
index 6ed568d72..061f37af5 100644
--- a/src/pystencils/jit/cpu/compiler_info.py
+++ b/src/pystencils/jit/cpu/compiler_info.py
@@ -8,28 +8,43 @@ from ...codegen.target import Target
 
 @dataclass
 class CompilerInfo(ABC):
+    """Base class for compiler infos."""
+
     openmp: bool = True
+    """Enable/disable OpenMP compilation"""
 
     optlevel: str | None = "fast"
+    """Compiler optimization level"""
 
     cxx_standard: str = "c++11"
+    """C++ language standard to be compiled with"""
 
     target: Target = Target.CurrentCPU
+    """Hardware target to compile for.
+    
+    Here, `Target.CurrentCPU` represents the current hardware,
+    which is reflected by ``-march=native`` in GNU-like compilers.
+    """
 
     @abstractmethod
-    def cxx(self) -> str: ...
+    def cxx(self) -> str:
+        """Path to the executable of this compiler"""
 
     @abstractmethod
-    def cxxflags(self) -> list[str]: ...
+    def cxxflags(self) -> list[str]:
+        """Compiler flags affecting C++ compilation"""
 
     @abstractmethod
-    def linker_flags(self) -> list[str]: ...
+    def linker_flags(self) -> list[str]:
+        """Flags affecting linkage of the extension module"""
 
     @abstractmethod
-    def include_flags(self, include_dirs: Sequence[str]) -> list[str]: ...
+    def include_flags(self, include_dirs: Sequence[str]) -> list[str]:
+        """Convert a list of include directories into corresponding compiler flags"""
 
     @abstractmethod
-    def restrict_qualifier(self) -> str: ...
+    def restrict_qualifier(self) -> str:
+        """*restrict* memory qualifier recognized by this compiler"""
 
 
 class _GnuLikeCliCompiler(CompilerInfo):
@@ -67,13 +82,18 @@ class _GnuLikeCliCompiler(CompilerInfo):
 
 
 class GccInfo(_GnuLikeCliCompiler):
+    """Compiler info for the GNU Compiler Collection C++ compiler (``g++``)."""
+
     def cxx(self) -> str:
         return "g++"
 
 
 @dataclass
 class ClangInfo(_GnuLikeCliCompiler):
+    """Compiler info for the LLVM C++ compiler (``clang``)."""
+    
     llvm_version: int | None = None
+    """Major version number of the LLVM installation providing the compiler."""
 
     def cxx(self) -> str:
         if self.llvm_version is None:
@@ -81,5 +101,5 @@ class ClangInfo(_GnuLikeCliCompiler):
         else:
             return f"clang-{self.llvm_version}"
         
-    def cxxflags(self) -> list[str]:
-        return super().cxxflags() + ["-lstdc++"]
+    def linker_flags(self) -> list[str]:
+        return super().linker_flags() + ["-lstdc++"]
diff --git a/src/pystencils/jit/cpu/cpujit.py b/src/pystencils/jit/cpu/cpujit.py
index 64045d549..5fce101ee 100644
--- a/src/pystencils/jit/cpu/cpujit.py
+++ b/src/pystencils/jit/cpu/cpujit.py
@@ -16,24 +16,45 @@ from .compiler_info import CompilerInfo, GccInfo
 class CpuJit(JitBase):
     """Just-in-time compiler for CPU kernels.
 
-    :Creation:
-    In most cases, objects of this class should be instantiated using the `create` factory method.
-
-    :Implementation Details:
-
-    The `CpuJit` combines two separate components:
-    - The *extension module builder* produces the code of the dynamically built extension module containing
-      the kernel
-    - The *compiler info* describes the system compiler used to compile and link that extension module.
-
-    Both can be dynamically exchanged.
+    **Creation**
+    
+    To configure and create a CPU JIT compiler instance, use the `create` factory method.
+
+    **Implementation Details**
+
+    The `CpuJit` class acts as an orchestrator between two components:
+    
+    - The *extension module builder* produces the code of the dynamically built extension module
+      that contains the kernel and its invocation wrappers;
+    - The *compiler info* describes the host compiler used to compile and link that extension module.
+
+    Args:
+        compiler_info: The compiler info object defining the capabilities
+            and command-line interface of the host compiler
+        ext_module_builder: Extension module builder object used to generate the kernel extension module
+        objcache: Directory to cache the generated code files and compiled modules in.
+            If `None`, a temporary directory will be used, and compilation results will not be cached.
     """
 
     @staticmethod
     def create(
         compiler_info: CompilerInfo | None = None,
         objcache: str | Path | _AUTO_TYPE | None = AUTO,
-    ):
+    ) -> CpuJit:
+        """Configure and create a CPU JIT compiler object.
+        
+        Args:
+            compiler_info: Compiler info object defining capabilities and interface of the host compiler.
+                If `None`, a default compiler configuration will be determined from the current OS and runtime
+                environment.
+            objcache: Directory used for caching compilation results.
+                If set to `AUTO`, a persistent cache directory in the current user's home will be used.
+                If set to `None`, compilation results will not be cached--this may impact performance.
+
+        Returns:
+            The CPU just-in-time compiler.
+        """
+        
         if objcache is AUTO:
             from appdirs import AppDirs
 
@@ -82,34 +103,19 @@ class CpuJit(JitBase):
             + self._compiler_info.linker_flags()
         )
 
-    @property
-    def objcache(self) -> Path | None:
-        return self._objcache
-
-    @property
-    def compiler_info(self) -> CompilerInfo:
-        return self._compiler_info
-
-    @compiler_info.setter
-    def compiler_info(self, info: CompilerInfo):
-        self._compiler_info = info
-
-    @property
-    def strict_scalar_types(self) -> bool:
-        """Enable or disable implicit type casts for scalar parameters.
-
-        If `True`, values for scalar kernel parameters must always be provided with the correct NumPy type.
+    def compile(self, kernel: Kernel) -> KernelWrapper:
+        """Compile the given kernel to an executable function.
+        
+        Args:
+            kernel: The kernel object to be compiled.
+        
+        Returns:
+            Wrapper object around the compiled function
         """
-        return self._strict_scalar_types
 
-    @strict_scalar_types.setter
-    def strict_scalar_types(self, v: bool):
-        self._strict_scalar_types = v
-
-    def compile(self, kernel: Kernel) -> KernelWrapper:
         #   Get the Code
         module_name = f"{kernel.function_name}_jit"
-        cpp_code = self._ext_module_builder(kernel, module_name)
+        cpp_code = self._ext_module_builder.render_module(kernel, module_name)
 
         #   Get compiler information
         import sysconfig
@@ -181,14 +187,21 @@ class CpuJit(JitBase):
 
 
 class ExtensionModuleBuilderBase(ABC):
+    """Base class for CPU extension module builders."""
+
     @staticmethod
     @abstractmethod
-    def include_dirs() -> list[str]: ...
+    def include_dirs() -> list[str]:
+        """List of directories that must be on the include path when compiling
+        generated extension modules."""
 
     @abstractmethod
-    def __call__(self, kernel: Kernel, module_name: str) -> str: ...
+    def render_module(self, kernel: Kernel, module_name: str) -> str:
+        """Produce the extension module code for the given kernel."""
 
     @abstractmethod
     def get_wrapper(
         self, kernel: Kernel, extension_module: ModuleType
-    ) -> KernelWrapper: ...
+    ) -> KernelWrapper:
+        """Produce the invocation wrapper for the given kernel
+        and its compiled extension module."""
diff --git a/src/pystencils/jit/cpu/cpujit_pybind11.py b/src/pystencils/jit/cpu/cpujit_pybind11.py
index 2f43ef196..1742bf301 100644
--- a/src/pystencils/jit/cpu/cpujit_pybind11.py
+++ b/src/pystencils/jit/cpu/cpujit_pybind11.py
@@ -43,7 +43,7 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
         self._param_check_lines: list[str]
         self._extraction_lines: list[str]
 
-    def __call__(self, kernel: Kernel, module_name: str) -> str:
+    def render_module(self, kernel: Kernel, module_name: str) -> str:
         self._actual_field_types = dict()
         self._param_binds = []
         self._public_params = []
-- 
GitLab


From 48be020aaf4730c9e8ed02f5a39eb135776a9372 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Thu, 30 Jan 2025 16:22:50 +0100
Subject: [PATCH 13/17] add experimental-jit mode to test suite

---
 conftest.py                               | 23 +++++------------------
 src/pystencils/jit/cpu/cpujit.py          |  9 +++++----
 src/pystencils/jit/cpu/cpujit_pybind11.py |  7 +------
 tests/fixtures.py                         | 16 ++++++++++++++--
 4 files changed, 25 insertions(+), 30 deletions(-)

diff --git a/conftest.py b/conftest.py
index 4e8e2b73a..8296641ed 100644
--- a/conftest.py
+++ b/conftest.py
@@ -3,6 +3,7 @@ import runpy
 import sys
 import tempfile
 import warnings
+import pathlib
 
 import nbformat
 import pytest
@@ -185,24 +186,10 @@ class IPyNbFile(pytest.File):
         pass
 
 
-if pytest_version >= 70000:
-    #   Since pytest 7.0, usage of `py.path.local` is deprecated and `pathlib.Path` should be used instead
-    import pathlib
-
-    def pytest_collect_file(file_path: pathlib.Path, parent):
-        glob_exprs = ["*demo*.ipynb", "*tutorial*.ipynb", "test_*.ipynb"]
-        if any(file_path.match(g) for g in glob_exprs):
-            return IPyNbFile.from_parent(path=file_path, parent=parent)
-
-else:
-
-    def pytest_collect_file(path, parent):
-        glob_exprs = ["*demo*.ipynb", "*tutorial*.ipynb", "test_*.ipynb"]
-        if any(path.fnmatch(g) for g in glob_exprs):
-            if pytest_version >= 50403:
-                return IPyNbFile.from_parent(fspath=path, parent=parent)
-            else:
-                return IPyNbFile(path, parent)
+def pytest_collect_file(file_path: pathlib.Path, parent):
+    glob_exprs = ["*demo*.ipynb", "*tutorial*.ipynb", "test_*.ipynb"]
+    if any(file_path.match(g) for g in glob_exprs):
+        return IPyNbFile.from_parent(path=file_path, parent=parent)
 
 
 #   Fixtures
diff --git a/src/pystencils/jit/cpu/cpujit.py b/src/pystencils/jit/cpu/cpujit.py
index 5fce101ee..bddcc0bd9 100644
--- a/src/pystencils/jit/cpu/cpujit.py
+++ b/src/pystencils/jit/cpu/cpujit.py
@@ -114,7 +114,7 @@ class CpuJit(JitBase):
         """
 
         #   Get the Code
-        module_name = f"{kernel.function_name}_jit"
+        module_name = f"{kernel.name}_jit"
         cpp_code = self._ext_module_builder.render_module(kernel, module_name)
 
         #   Get compiler information
@@ -124,11 +124,12 @@ class CpuJit(JitBase):
         lib_suffix = f"{so_abi}.so"
 
         #   Compute Code Hash
-        code_utf8 = cpp_code.encode("utf-8")
+        code_utf8: bytes = cpp_code.encode("utf-8")
+        compiler_utf8: bytes = (" ".join([self._cxx] + self._cxx_fixed_flags)).encode("utf-8")
         import hashlib
 
-        code_hash = hashlib.sha256(code_utf8)
-        module_stem = f"module_{code_hash.hexdigest()}"
+        module_hash = hashlib.sha256(code_utf8 + compiler_utf8)
+        module_stem = f"module_{module_hash.hexdigest()}"
 
         def compile_and_load(module_dir: Path):
             cpp_file = module_dir / f"{module_stem}.cpp"
diff --git a/src/pystencils/jit/cpu/cpujit_pybind11.py b/src/pystencils/jit/cpu/cpujit_pybind11.py
index 1742bf301..a58e5bf03 100644
--- a/src/pystencils/jit/cpu/cpujit_pybind11.py
+++ b/src/pystencils/jit/cpu/cpujit_pybind11.py
@@ -31,12 +31,9 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
     def __init__(
         self,
         compiler_info: CompilerInfo,
-        strict_scalar_types: bool = False,
     ):
         self._compiler_info = compiler_info
 
-        self._strict_scalar_types = strict_scalar_types
-
         self._actual_field_types: dict[Field, PsType]
         self._param_binds: list[str]
         self._public_params: list[str]
@@ -63,7 +60,7 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
             includes="\n".join(includes),
             restrict_qualifier=self._compiler_info.restrict_qualifier(),
             module_name=module_name,
-            kernel_name=kernel.function_name,
+            kernel_name=kernel.name,
             param_binds=", ".join(self._param_binds),
             public_params=", ".join(self._public_params),
             param_check_lines=indent("\n".join(self._param_check_lines), prefix="    "),
@@ -118,8 +115,6 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
 
     def _add_scalar_param(self, sc_param: Parameter):
         param_bind = f'py::arg("{sc_param.name}")'
-        if self._strict_scalar_types:
-            param_bind += ".noconvert()"
         self._param_binds.append(param_bind)
 
         kernel_param = f"{sc_param.dtype.c_string()} {sc_param.name}"
diff --git a/tests/fixtures.py b/tests/fixtures.py
index 71e54bad8..ba2593f76 100644
--- a/tests/fixtures.py
+++ b/tests/fixtures.py
@@ -15,7 +15,6 @@ by your tests:
 import pytest
 
 from types import ModuleType
-from dataclasses import replace
 
 import pystencils as ps
 
@@ -32,6 +31,14 @@ AVAILABLE_TARGETS += ps.Target.available_vector_cpu_targets()
 TARGET_IDS = [t.name for t in AVAILABLE_TARGETS]
 
 
+def pytest_addoption(parser: pytest.Parser):
+    parser.addoption(
+        "--experimental-cpu-jit",
+        dest="experimental_cpu_jit",
+        action="store_true"
+    )
+
+
 @pytest.fixture(params=AVAILABLE_TARGETS, ids=TARGET_IDS)
 def target(request) -> ps.Target:
     """Provides all code generation targets available on the current hardware"""
@@ -39,7 +46,7 @@ def target(request) -> ps.Target:
 
 
 @pytest.fixture
-def gen_config(target: ps.Target):
+def gen_config(request: pytest.FixtureRequest, target: ps.Target):
     """Default codegen configuration for the current target.
 
     For GPU targets, set default indexing options.
@@ -52,6 +59,11 @@ def gen_config(target: ps.Target):
         gen_config.cpu.vectorize.enable = True
         gen_config.cpu.vectorize.assume_inner_stride_one = True
 
+    if target.is_cpu() and request.config.getoption("experimental_cpu_jit"):
+        from pystencils.jit.cpu import CpuJit, GccInfo
+
+        gen_config.jit = CpuJit.create(compiler_info=GccInfo(target=target))
+
     return gen_config
 
 
-- 
GitLab


From 2e8486bb8ac3069d5dd3906c07de4328738b4646 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Thu, 30 Jan 2025 17:43:44 +0100
Subject: [PATCH 14/17] Upgrade noxfile to pass params through to pytest. Add
 testing section to contrib guide. Fix a weird error in the legacy jit.

---
 docs/source/contributing/dev-workflow.md   | 18 +++--
 docs/source/contributing/index.md          |  3 +-
 docs/source/contributing/testing.md        | 77 ++++++++++++++++++++++
 noxfile.py                                 |  9 ++-
 src/pystencils/jit/cpu/cpujit_pybind11.py  |  4 +-
 src/pystencils/jit/cpu_extension_module.py |  4 +-
 6 files changed, 100 insertions(+), 15 deletions(-)
 create mode 100644 docs/source/contributing/testing.md

diff --git a/docs/source/contributing/dev-workflow.md b/docs/source/contributing/dev-workflow.md
index fe8b70e77..8daac8cbd 100644
--- a/docs/source/contributing/dev-workflow.md
+++ b/docs/source/contributing/dev-workflow.md
@@ -127,18 +127,16 @@ If you think a new module is ready to be type-checked, add an exception clause t
 ## Running the Test Suite
 
 Pystencils comes with an extensive and steadily growing suite of unit tests.
-To run the testsuite, you may invoke a variant of the Nox `testsuite` session.
-There are multiple different versions of the `testsuite` session, depending on whether you are testing with our
-without CUDA, or which version of Python you wish to test with.
-You can list the available sessions using `nox -l`.
-Select one of the `testsuite` variants and run it via `nox -s "testsuite(<variant>)"`, e.g.
-```
-nox -s "testsuite(cpu)"
+To run the full testsuite, invoke the Nox `testsuite` session:
+
+```bash
+nox -s testsuite
 ```
-for the CPU-only suite.
 
-During the testsuite run, coverage information is collected and displayed using [coverage.py](https://coverage.readthedocs.io/en/7.6.10/).
-You can display a detailed overview of code coverage by opening the generated `htmlcov/index.html` page.
+:::{seealso}
+[](#testing_pystencils)
+:::
+
 
 ## Building the Documentation
 
diff --git a/docs/source/contributing/index.md b/docs/source/contributing/index.md
index 04ad821ce..56c97509c 100644
--- a/docs/source/contributing/index.md
+++ b/docs/source/contributing/index.md
@@ -7,6 +7,7 @@ Pystencils is an open-source package licensed under the [AGPL v3](https://www.gn
 As such, the act of contributing to pystencils by submitting a merge request is taken as agreement to the terms of the licence.
 
 :::{toctree}
-:maxdepth: 2    
+:maxdepth: 2
 dev-workflow
+testing
 :::
diff --git a/docs/source/contributing/testing.md b/docs/source/contributing/testing.md
new file mode 100644
index 000000000..da8ca2d16
--- /dev/null
+++ b/docs/source/contributing/testing.md
@@ -0,0 +1,77 @@
+(testing_pystencils)=
+# Testing pystencils
+
+The pystencils testsuite is located at the `tests` directory,
+constructed using [pytest](https://pytest.org),
+and automated through [Nox](https://nox.thea.codes).
+On this page, you will find instructions on how to execute and extend it.
+
+## Running the Testsuite
+
+The fastest way to execute the pystencils test suite is through the `testsuite` Nox session:
+
+```bash
+nox -s testsuite
+```
+
+There exist several configurations of the testsuite session, from which the above command will
+select and execute only those that are available on your machine.
+ - *Python Versions:* The testsuite session can be run against all major Python versions between 3.10 and 3.13 (inclusive).
+   To only use a specific Python version, add the `-p 3.XX` argument to your Nox invocation; e.g. `nox -s testsuite -p 3.11`.
+ - *CuPy:* There exist three variants of `testsuite`, including or excluding tests for the CUDA GPU target: `cpu`, `cupy12` and `cupy13`.
+   To select one, append `(<variant>)` to the session name; e.g. `nox -s "testsuite(cupy12)"`.
+
+You may also pass options through to pytest via positional arguments after a pair of dashes, e.g.:
+
+```bash
+nox -s testsuite -- -k "kernelcreation"
+```
+
+During the testsuite run, coverage information is collected using [coverage.py](https://coverage.readthedocs.io/en/7.6.10/),
+and the results are exported to HTML.
+You can display a detailed overview of code coverage by opening the generated `htmlcov/index.html` page.
+
+## Extending the Test Suite
+
+### Codegen Configurations via Fixtures
+
+In the pystencils test suite, it is often necessary to test code generation features against multiple targets.
+To simplify this process, we provide a number of [pytest fixtures](https://docs.pytest.org/en/stable/how-to/fixtures.html)
+you can and should use in your tests:
+
+ - `target`: Provides code generation targets for your test.
+   Using this fixture will make pytest create a copy of your test for each target
+   available on the current machine (see {any}`Target.available_targets`).
+ - `gen_config`: Provides default code generation configurations for your test.
+   This fixture depends on `target` and provides a {any}`CreateKernelConfig` instance
+   with target-specific optimization options (in particular vectorization) enabled.
+ - `xp`: The `xp` fixture gives you either the *NumPy* (`np`) or the *CuPy* (`cp`) module,
+   depending on whether `target` is a CPU or GPU target.
+
+These fixtures are defined in `tests/fixtures.py`.
+
+### Overriding Fixtures
+
+Pytest allows you to locally override fixtures, which can be especially practical when you wish
+to restrict the target selection of a test.
+For example, the following test overrides `target` using a parametrization mark,
+and uses this in combination with the `gen_config` fixture, which now
+receives the overridden `target` parameter as input:
+
+```Python
+@pytest.mark.parametrize("target", [Target.X86_SSE, Target.X86_AVX])
+def test_bogus(gen_config):
+    assert gen_config.target.is_vector_cpu()
+```
+
+## Testing with the Experimental CPU JIT
+
+Currently, the testsuite by default still uses the {any}`legacy CPU JIT compiler <LegacyCpuJit>`,
+since the new CPU JIT compiler is still in an experimental stage.
+To test your code against the new JIT compiler, pass the `--experimental-cpu-jit` option to pytest:
+
+```bash
+nox -s testsuite -- --experimental-cpu-jit
+```
+
+This will alter the `gen_config` fixture, activating the experimental CPU JIT for CPU targets.
diff --git a/noxfile.py b/noxfile.py
index 11b3731ec..f5b20da4f 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -86,9 +86,15 @@ def typecheck(session: nox.Session):
     session.run("mypy", "src/pystencils")
 
 
-@nox.session(python=["3.10", "3.12", "3.13"], tags=["test"])
+@nox.session(python=["3.10", "3.11", "3.12", "3.13"], tags=["test"])
 @nox.parametrize("cupy_version", [None, "12", "13"], ids=["cpu", "cupy12", "cupy13"])
 def testsuite(session: nox.Session, cupy_version: str | None):
+    """Run the pystencils test suite.
+    
+    **Positional Arguments:** Any positional arguments passed to nox after `--`
+    are propagated to pytest.
+    """
+
     if cupy_version is not None:
         install_cupy(session, cupy_version, skip_if_no_cuda=True)
 
@@ -108,6 +114,7 @@ def testsuite(session: nox.Session, cupy_version: str | None):
         "--html",
         "test-report/index.html",
         "--junitxml=report.xml",
+        *session.posargs
     )
     session.run("coverage", "html")
     session.run("coverage", "xml")
diff --git a/src/pystencils/jit/cpu/cpujit_pybind11.py b/src/pystencils/jit/cpu/cpujit_pybind11.py
index a58e5bf03..90224b22b 100644
--- a/src/pystencils/jit/cpu/cpujit_pybind11.py
+++ b/src/pystencils/jit/cpu/cpujit_pybind11.py
@@ -51,7 +51,7 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
 
         kernel_def = self._get_kernel_definition(kernel)
         kernel_args = [param.name for param in kernel.parameters]
-        includes = [f"#include {h}" for h in kernel.required_headers]
+        includes = [f"#include {h}" for h in sorted(kernel.required_headers)]
 
         from string import Template
 
@@ -76,7 +76,7 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
     def _get_kernel_definition(self, kernel: Kernel) -> str:
         from ...backend.emission import CAstPrinter
 
-        printer = CAstPrinter(func_prefix="inline")
+        printer = CAstPrinter()
 
         return printer(kernel)
 
diff --git a/src/pystencils/jit/cpu_extension_module.py b/src/pystencils/jit/cpu_extension_module.py
index 55f1961ca..fca043db9 100644
--- a/src/pystencils/jit/cpu_extension_module.py
+++ b/src/pystencils/jit/cpu_extension_module.py
@@ -91,12 +91,14 @@ class PsKernelExtensioNModule:
         code += "\n"
 
         #   Kernels and call wrappers
+        from ..backend.emission import CAstPrinter
+        printer = CAstPrinter(func_prefix="FUNC_PREFIX")
 
         for name, kernel in self._kernels.items():
             old_name = kernel.name
             kernel.name = f"kernel_{name}"
 
-            code += kernel.get_c_code()
+            code += printer(kernel)
             code += "\n"
             code += emit_call_wrapper(name, kernel)
             code += "\n"
-- 
GitLab


From 908c91de8949914b7b6cfd927c30eac1aeb40239 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Thu, 30 Jan 2025 17:54:25 +0100
Subject: [PATCH 15/17] add CI task with experimental JIT

---
 .gitlab-ci.yml | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index cc73eb5aa..54fe457b2 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -313,6 +313,15 @@ typecheck:
     - docker
     - AVX
 
+"testsuite-experimental-jit-py3.10":
+  extends: .testsuite-base
+  image: i10git.cs.fau.de:5005/pycodegen/pycodegen/nox:alpine
+  script:
+    - nox -s "testsuite(cpu)" -p 3.10 -- --experimental-cpu-jit
+  tags:
+    - docker
+    - AVX
+
 # -------------------- Documentation ---------------------------------------------------------------------
 
 
-- 
GitLab


From 5da6ba2830d6a6c4179936591f32b1d5f51b7a22 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Thu, 30 Jan 2025 17:55:13 +0100
Subject: [PATCH 16/17] fix CI

---
 .gitlab-ci.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 54fe457b2..474afbb11 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -317,7 +317,7 @@ typecheck:
   extends: .testsuite-base
   image: i10git.cs.fau.de:5005/pycodegen/pycodegen/nox:alpine
   script:
-    - nox -s "testsuite(cpu)" -p 3.10 -- --experimental-cpu-jit
+    - nox -s "testsuite-3.10(cpu)" -- --experimental-cpu-jit
   tags:
     - docker
     - AVX
-- 
GitLab


From 56ab4abcec366edfd4a7243aa953b253651b9198 Mon Sep 17 00:00:00 2001
From: Daniel Bauer <daniel.j.bauer@fau.de>
Date: Fri, 31 Jan 2025 10:15:22 +0100
Subject: [PATCH 17/17] Apply 1 suggestion(s) to 1 file(s)

Co-authored-by: Daniel Bauer <daniel.j.bauer@fau.de>
---
 docs/source/contributing/testing.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/contributing/testing.md b/docs/source/contributing/testing.md
index da8ca2d16..b9c93b0d2 100644
--- a/docs/source/contributing/testing.md
+++ b/docs/source/contributing/testing.md
@@ -52,7 +52,7 @@ These fixtures are defined in `tests/fixtures.py`.
 
 ### Overriding Fixtures
 
-Pytest allows you to locally override fixtures, which can be especially practical when you wish
+Pytest allows you to locally override fixtures, which can be especially handy when you wish
 to restrict the target selection of a test.
 For example, the following test overrides `target` using a parametrization mark,
 and uses this in combination with the `gen_config` fixture, which now
-- 
GitLab