diff --git a/src/pystencils/jit/cpu/cpujit.py b/src/pystencils/jit/cpu/cpujit.py index b3a9e48aaf678189dbcbb514b5e61e3c115fbaca..eb1b60abec079b55a52d11fed1ed85ac1e85f017 100644 --- a/src/pystencils/jit/cpu/cpujit.py +++ b/src/pystencils/jit/cpu/cpujit.py @@ -104,7 +104,7 @@ class CpuJit(JitBase): def strict_scalar_types(self, v: bool): self._strict_scalar_types = v - def compile(self, kernel: Kernel) -> CpuJitKernelWrapper: + def compile(self, kernel: Kernel) -> KernelWrapper: # Get the Code module_name = f"{kernel.function_name}_jit" cpp_code = self._ext_module_builder(kernel, module_name) @@ -149,7 +149,7 @@ class CpuJit(JitBase): module_dir = Path(tmpdir) module = compile_and_load(module_dir) - return CpuJitKernelWrapper(kernel, module) + return self._ext_module_builder.get_wrapper(kernel, module) def _compile_extension_module(self, src_file: Path, libfile: Path): args = ( @@ -179,18 +179,6 @@ class CpuJit(JitBase): return mod -class CpuJitKernelWrapper(KernelWrapper): - def __init__(self, kernel: Kernel, jit_module: ModuleType): - super().__init__(kernel) - self._module = jit_module - self._check_params = getattr(jit_module, "check_params") - self._invoke = getattr(jit_module, "invoke") - - def __call__(self, **kwargs) -> None: - self._check_params(**kwargs) - return self._invoke(**kwargs) - - class ExtensionModuleBuilderBase(ABC): @staticmethod @abstractmethod @@ -198,3 +186,6 @@ class ExtensionModuleBuilderBase(ABC): @abstractmethod def __call__(self, kernel: Kernel, module_name: str) -> str: ... + + @abstractmethod + def get_wrapper(self, kernel: Kernel, extension_module: ModuleType) -> KernelWrapper: ... diff --git a/src/pystencils/jit/cpu/cpujit_pybind11.py b/src/pystencils/jit/cpu/cpujit_pybind11.py index eff3a061f2f4fef4d03a56183553ebedd33d0c6d..ba9065c8a552c4ad70ce6179ce5339daef637253 100644 --- a/src/pystencils/jit/cpu/cpujit_pybind11.py +++ b/src/pystencils/jit/cpu/cpujit_pybind11.py @@ -1,9 +1,12 @@ from __future__ import annotations +from types import ModuleType from typing import Sequence, cast from pathlib import Path from textwrap import indent +from pystencils.jit.jit import KernelWrapper + from ...types import PsPointerType, PsType from ...field import Field from ...sympyextensions import DynamicType @@ -69,6 +72,9 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase): kernel_definition=kernel_def, ) return code_str + + def get_wrapper(self, kernel: Kernel, extension_module: ModuleType) -> KernelWrapper: + return Pybind11KernelWrapper(kernel, extension_module) def _get_kernel_definition(self, kernel: Kernel) -> str: from ...backend.emission import CAstPrinter @@ -158,3 +164,15 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase): self._extract_stride(param, cast(FieldStride, stride_props.pop())) else: self._add_scalar_param(param) + + +class Pybind11KernelWrapper(KernelWrapper): + def __init__(self, kernel: Kernel, jit_module: ModuleType): + super().__init__(kernel) + self._module = jit_module + self._check_params = getattr(jit_module, "check_params") + self._invoke = getattr(jit_module, "invoke") + + def __call__(self, **kwargs) -> None: + self._check_params(**kwargs) + return self._invoke(**kwargs)