Skip to content
Snippets Groups Projects
Commit fd31ce64 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

move kernel wrapper creation to the module builder

parent 4dcb81ac
No related branches found
No related tags found
1 merge request!445Object-Oriented CPU JIT API and Prototype Implementation
Pipeline #72944 passed
......@@ -104,7 +104,7 @@ class CpuJit(JitBase):
def strict_scalar_types(self, v: bool):
self._strict_scalar_types = v
def compile(self, kernel: Kernel) -> CpuJitKernelWrapper:
def compile(self, kernel: Kernel) -> KernelWrapper:
# Get the Code
module_name = f"{kernel.function_name}_jit"
cpp_code = self._ext_module_builder(kernel, module_name)
......@@ -149,7 +149,7 @@ class CpuJit(JitBase):
module_dir = Path(tmpdir)
module = compile_and_load(module_dir)
return CpuJitKernelWrapper(kernel, module)
return self._ext_module_builder.get_wrapper(kernel, module)
def _compile_extension_module(self, src_file: Path, libfile: Path):
args = (
......@@ -179,18 +179,6 @@ class CpuJit(JitBase):
return mod
class CpuJitKernelWrapper(KernelWrapper):
def __init__(self, kernel: Kernel, jit_module: ModuleType):
super().__init__(kernel)
self._module = jit_module
self._check_params = getattr(jit_module, "check_params")
self._invoke = getattr(jit_module, "invoke")
def __call__(self, **kwargs) -> None:
self._check_params(**kwargs)
return self._invoke(**kwargs)
class ExtensionModuleBuilderBase(ABC):
@staticmethod
@abstractmethod
......@@ -198,3 +186,6 @@ class ExtensionModuleBuilderBase(ABC):
@abstractmethod
def __call__(self, kernel: Kernel, module_name: str) -> str: ...
@abstractmethod
def get_wrapper(self, kernel: Kernel, extension_module: ModuleType) -> KernelWrapper: ...
from __future__ import annotations
from types import ModuleType
from typing import Sequence, cast
from pathlib import Path
from textwrap import indent
from pystencils.jit.jit import KernelWrapper
from ...types import PsPointerType, PsType
from ...field import Field
from ...sympyextensions import DynamicType
......@@ -69,6 +72,9 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
kernel_definition=kernel_def,
)
return code_str
def get_wrapper(self, kernel: Kernel, extension_module: ModuleType) -> KernelWrapper:
return Pybind11KernelWrapper(kernel, extension_module)
def _get_kernel_definition(self, kernel: Kernel) -> str:
from ...backend.emission import CAstPrinter
......@@ -158,3 +164,15 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
self._extract_stride(param, cast(FieldStride, stride_props.pop()))
else:
self._add_scalar_param(param)
class Pybind11KernelWrapper(KernelWrapper):
def __init__(self, kernel: Kernel, jit_module: ModuleType):
super().__init__(kernel)
self._module = jit_module
self._check_params = getattr(jit_module, "check_params")
self._invoke = getattr(jit_module, "invoke")
def __call__(self, **kwargs) -> None:
self._check_params(**kwargs)
return self._invoke(**kwargs)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment