Skip to content
Snippets Groups Projects

Object-Oriented CPU JIT API and Prototype Implementation

Merged Frederik Hennig requested to merge fhennig/pybind11-jit into v2.0-dev
Files
4
@@ -104,7 +104,7 @@ class CpuJit(JitBase):
def strict_scalar_types(self, v: bool):
self._strict_scalar_types = v
def compile(self, kernel: Kernel) -> CpuJitKernelWrapper:
def compile(self, kernel: Kernel) -> KernelWrapper:
# Get the Code
module_name = f"{kernel.function_name}_jit"
cpp_code = self._ext_module_builder(kernel, module_name)
@@ -149,7 +149,7 @@ class CpuJit(JitBase):
module_dir = Path(tmpdir)
module = compile_and_load(module_dir)
return CpuJitKernelWrapper(kernel, module)
return self._ext_module_builder.get_wrapper(kernel, module)
def _compile_extension_module(self, src_file: Path, libfile: Path):
args = (
@@ -179,18 +179,6 @@ class CpuJit(JitBase):
return mod
class CpuJitKernelWrapper(KernelWrapper):
def __init__(self, kernel: Kernel, jit_module: ModuleType):
super().__init__(kernel)
self._module = jit_module
self._check_params = getattr(jit_module, "check_params")
self._invoke = getattr(jit_module, "invoke")
def __call__(self, **kwargs) -> None:
self._check_params(**kwargs)
return self._invoke(**kwargs)
class ExtensionModuleBuilderBase(ABC):
@staticmethod
@abstractmethod
@@ -198,3 +186,6 @@ class ExtensionModuleBuilderBase(ABC):
@abstractmethod
def __call__(self, kernel: Kernel, module_name: str) -> str: ...
@abstractmethod
def get_wrapper(self, kernel: Kernel, extension_module: ModuleType) -> KernelWrapper: ...