Skip to content
Snippets Groups Projects
Commit fb3243dc authored by Frederik Hennig's avatar Frederik Hennig
Browse files

modularize CPU jit

parent e5eafb79
No related branches found
No related tags found
1 merge request: !445 Object-Oriented CPU JIT API and Prototype Implementation
Pipeline #72901 passed
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Any, TypeGuard
from warnings import warn from warnings import warn
from collections.abc import Collection from collections.abc import Collection
...@@ -28,7 +28,10 @@ class PsOptionsError(Exception): ...@@ -28,7 +28,10 @@ class PsOptionsError(Exception):
"""Indicates an option clash in the `CreateKernelConfig`.""" """Indicates an option clash in the `CreateKernelConfig`."""
class _AUTO_TYPE: ... # noqa: E701 class _AUTO_TYPE:
@staticmethod
def is_auto(val: Any) -> TypeGuard[_AUTO_TYPE]:
return val == AUTO
AUTO = _AUTO_TYPE() AUTO = _AUTO_TYPE()
......
from .compiler_info import GccInfo, ClangInfo from .compiler_info import GccInfo, ClangInfo
from .cpu_pybind11 import CpuJit from .cpujit import CpuJit
__all__ = [ __all__ = [
"GccInfo", "GccInfo",
......
from __future__ import annotations
from types import ModuleType
from pathlib import Path
import subprocess
from copy import copy
from abc import ABC, abstractmethod
from ...codegen.config import _AUTO_TYPE, AUTO
from ..jit import JitError, JitBase, KernelWrapper
from ...codegen import Kernel
from .compiler_info import CompilerInfo, GccInfo
class CpuJit(JitBase):
    """Just-in-time compiler for CPU kernels.

    :Creation:
        In most cases, objects of this class should be instantiated using the `create` factory method.

    :Implementation Details:
        The `CpuJit` combines two separate components:

        - The *extension module builder* produces the code of the dynamically built extension module containing
          the kernel
        - The *compiler info* describes the system compiler used to compile and link that extension module.

        Both can be dynamically exchanged.
    """

    @staticmethod
    def create(
        compiler_info: CompilerInfo | None = None,
        objcache: str | Path | _AUTO_TYPE | None = AUTO,
    ) -> CpuJit:
        """Create a `CpuJit` that uses the pybind11-based extension module builder.

        Args:
            compiler_info: Description of the compiler to invoke.
                If `None`, a default-constructed `GccInfo` is used.
            objcache: Directory in which generated sources and compiled modules are kept.
                With `AUTO`, the ``cpujit`` subfolder of the user cache directory that
                ``appdirs`` reports for the application name ``pystencils`` is used.
                With `None`, no cache directory is used and `compile` builds each
                module inside a temporary directory.

        Returns:
            The newly constructed `CpuJit`.
        """
        if objcache is AUTO:
            from appdirs import AppDirs

            dirs = AppDirs(appname="pystencils")
            objcache = Path(dirs.user_cache_dir) / "cpujit"
        elif objcache is not None:
            # Narrows the static type; any other `_AUTO_TYPE` instance is not a valid path
            assert not isinstance(objcache, _AUTO_TYPE)
            objcache = Path(objcache)

        if compiler_info is None:
            compiler_info = GccInfo()

        from .cpujit_pybind11 import Pybind11KernelModuleBuilder

        modbuilder = Pybind11KernelModuleBuilder(compiler_info)

        return CpuJit(compiler_info, modbuilder, objcache)

    def __init__(
        self,
        compiler_info: CompilerInfo,
        ext_module_builder: ExtensionModuleBuilderBase,
        objcache: Path | None,
    ):
        """Set up the JIT from its two components.

        Args:
            compiler_info: Compiler description; a shallow copy of it is stored.
            ext_module_builder: Object producing the C++ source of the extension module.
            objcache: Cache directory, or `None` to build in temporary directories.
        """
        self._compiler_info = copy(compiler_info)
        self._ext_module_builder = ext_module_builder
        self._objcache = objcache

        # Give the flag a defined initial value; without this, reading
        # `strict_scalar_types` before assigning it raised `AttributeError`.
        # NOTE(review): this class only stores the flag and does not pass it on to
        # the extension module builder -- confirm how it is meant to take effect.
        self._strict_scalar_types = False

        #   Include Directories

        import sysconfig

        python_include = sysconfig.get_path("include")

        from ...include import get_pystencils_include_path

        pystencils_include = get_pystencils_include_path()

        self._cxx_fixed_flags = [
            "-shared",
            "-fPIC",
            f"-I{python_include}",
            f"-I{pystencils_include}",
        ] + self._ext_module_builder.include_flags()

    @property
    def objcache(self) -> Path | None:
        """Cache directory used by `compile`, or `None` if caching is disabled."""
        return self._objcache

    @property
    def compiler_info(self) -> CompilerInfo:
        """Description of the compiler used to build extension modules."""
        return self._compiler_info

    @compiler_info.setter
    def compiler_info(self, info: CompilerInfo):
        self._compiler_info = info

    @property
    def strict_scalar_types(self) -> bool:
        """Enable or disable implicit type casts for scalar parameters.

        If `True`, values for scalar kernel parameters must always be provided with the correct NumPy type.
        """
        return self._strict_scalar_types

    @strict_scalar_types.setter
    def strict_scalar_types(self, v: bool):
        self._strict_scalar_types = v

    def compile(self, kernel: Kernel) -> CpuJitKernelWrapper:
        """Compile ``kernel`` into an extension module and wrap it for calling.

        The extension module is named ``<kernel.function_name>_jit``. Its source
        and binary files are named after the SHA-256 hash of the generated code,
        so an existing binary for identical code is loaded instead of rebuilt.

        Args:
            kernel: The kernel to compile.

        Returns:
            A wrapper that invokes the compiled kernel.

        Raises:
            JitError: If the compiler fails or the built module cannot be loaded.
        """
        #   Get the Code
        module_name = f"{kernel.function_name}_jit"
        cpp_code = self._ext_module_builder(kernel, module_name)

        #   Get compiler information
        import sysconfig

        so_abi = sysconfig.get_config_var("SOABI")
        lib_suffix = f"{so_abi}.so"

        #   Compute Code Hash
        code_utf8 = cpp_code.encode("utf-8")
        import hashlib

        code_hash = hashlib.sha256(code_utf8)
        module_stem = f"module_{code_hash.hexdigest()}"

        def compile_and_load(module_dir: Path):
            # Write and build only what is not already present in `module_dir`
            cpp_file = module_dir / f"{module_stem}.cpp"
            if not cpp_file.exists():
                cpp_file.write_bytes(code_utf8)

            lib_file = module_dir / f"{module_stem}.{lib_suffix}"
            if not lib_file.exists():
                self._compile_extension_module(cpp_file, lib_file)

            module = self._load_extension_module(module_name, lib_file)
            return module

        if self._objcache is not None:
            module_dir = self._objcache

            #   Lock module
            import fasteners

            # One lock file per code hash guards the files of this module
            # against concurrent access from other processes
            lockfile = module_dir / f"{module_stem}.lock"
            with fasteners.InterProcessLock(lockfile):
                module = compile_and_load(module_dir)
        else:
            from tempfile import TemporaryDirectory

            with TemporaryDirectory() as tmpdir:
                module_dir = Path(tmpdir)
                module = compile_and_load(module_dir)

        return CpuJitKernelWrapper(kernel, module)

    def _compile_extension_module(self, src_file: Path, libfile: Path):
        """Run the C++ compiler to build ``libfile`` from ``src_file``.

        Raises:
            JitError: If the compiler exits with a nonzero return code; the
                message contains the compiler's standard error output.
        """
        args = (
            [self._compiler_info.cxx()]
            + self._cxx_fixed_flags
            + self._compiler_info.cxxflags()
            + ["-o", str(libfile), str(src_file)]
        )

        result = subprocess.run(args, capture_output=True)
        if result.returncode != 0:
            # Decode leniently so that undecodable bytes in the compiler output
            # cannot replace the `JitError` with a `UnicodeDecodeError`
            raise JitError(
                "Compilation failed: C++ compiler terminated with an error.\n"
                + result.stderr.decode(errors="replace")
            )

    def _load_extension_module(self, module_name: str, module_loc: Path) -> ModuleType:
        """Import the compiled extension module located at ``module_loc``.

        Raises:
            JitError: If no import spec with a loader can be created for the file.
        """
        from importlib import util as iutil

        spec = iutil.spec_from_file_location(name=module_name, location=module_loc)
        # A spec without a loader cannot be executed either; report both cases alike
        if spec is None or spec.loader is None:
            raise JitError(
                "Unable to load kernel extension module -- this is probably a bug."
            )
        mod = iutil.module_from_spec(spec)
        spec.loader.exec_module(mod)
        return mod
class CpuJitKernelWrapper(KernelWrapper):
    """Callable handle for a kernel contained in a JIT-built extension module."""

    def __init__(self, kernel: Kernel, jit_module: ModuleType):
        """Bind ``kernel`` to the attribute of ``jit_module`` named like the kernel function."""
        super().__init__(kernel)
        # Store the module on the wrapper alongside the function looked up from it
        self._module = jit_module
        entry_point_name = kernel.function_name
        self._wrapper_func = getattr(jit_module, entry_point_name)

    def __call__(self, **kwargs) -> None:
        """Call the compiled kernel, forwarding all keyword arguments unchanged."""
        entry_point = self._wrapper_func
        return entry_point(**kwargs)
class ExtensionModuleBuilderBase(ABC):
    """Interface of the objects that generate the source code of a kernel extension module."""

    @staticmethod
    @abstractmethod
    def include_flags() -> list[str]:
        """Return the compiler flags for the include directories this builder's code needs."""

    @abstractmethod
    def __call__(self, kernel: Kernel, module_name: str) -> str:
        """Return the extension module source code for ``kernel``, using ``module_name`` as module name."""
from __future__ import annotations from __future__ import annotations
from typing import Sequence, cast from typing import Sequence, cast
from types import ModuleType
from pathlib import Path from pathlib import Path
from textwrap import indent from textwrap import indent
import subprocess
from copy import copy
from ...types import PsPointerType, PsType from ...types import PsPointerType, PsType
from ...field import Field from ...field import Field
from ...sympyextensions import DynamicType from ...sympyextensions import DynamicType
from ..jit import KernelWrapper
from ...codegen import Kernel, Parameter from ...codegen import Kernel, Parameter
from ...codegen.properties import FieldBasePtr, FieldShape, FieldStride from ...codegen.properties import FieldBasePtr, FieldShape, FieldStride
from ..jit import JitError, JitBase from .compiler_info import CompilerInfo
from .compiler_info import CompilerInfo, GccInfo from .cpujit import ExtensionModuleBuilderBase
_module_template = Path(__file__).parent / "kernel_module.tmpl.cpp" _module_template = Path(__file__).parent / "pybind11_kernel_module.tmpl.cpp"
class CpuJit(JitBase): class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
"""Just-in-time compiler for CPU kernels.""" @staticmethod
def include_flags() -> list[str]:
import pybind11 as pb11
pybind11_include = pb11.get_include()
return [f"-I{pybind11_include}"]
def __init__( def __init__(
self, self,
compiler_info: CompilerInfo | None = None, compiler_info: CompilerInfo,
objcache: str | Path | None = None,
strict_scalar_types: bool = False, strict_scalar_types: bool = False,
): ):
self._compiler_info = copy(compiler_info) if compiler_info is not None else GccInfo() self._compiler_info = compiler_info
self._strict_scalar_types = strict_scalar_types
if objcache is None:
from appdirs import AppDirs
dirs = AppDirs(appname="pystencils")
self._objcache = Path(dirs.user_cache_dir) / "cpujit"
else:
self._objcache = Path(objcache)
# Include Directories
import pybind11 as pb11
pybind11_include = pb11.get_include()
import sysconfig
python_include = sysconfig.get_path("include")
from ...include import get_pystencils_include_path
pystencils_include = get_pystencils_include_path()
self._cxx_fixed_flags = [
"-shared",
"-fPIC",
f"-I{python_include}",
f"-I{pybind11_include}",
f"-I{pystencils_include}"
]
@property
def objcache(self) -> Path:
return self._objcache
@property
def compiler_info(self) -> CompilerInfo:
return self._compiler_info
@compiler_info.setter
def compiler_info(self, info: CompilerInfo):
self._compiler_info = info
@property
def strict_scalar_types(self) -> bool:
"""Enable or disable implicit type casts for scalar parameters.
If `True`, values for scalar kernel parameters must always be provided with the correct NumPy type.
"""
return self._strict_scalar_types
@strict_scalar_types.setter
def strict_scalar_types(self, v: bool):
self._strict_scalar_types = v
@property
def restrict_qualifier(self) -> str:
return self._restrict_qualifier
@restrict_qualifier.setter
def restrict_qualifier(self, qual: str):
self._restrict_qualifier = qual
def compile(self, kernel: Kernel) -> CpuJitKernelWrapper:
# Get the Code
module_name = f"{kernel.function_name}_jit"
modbuilder = KernelModuleBuilder(self, module_name)
cpp_code = modbuilder(kernel)
# Get compiler information
import sysconfig
so_abi = sysconfig.get_config_var("SOABI")
lib_suffix = f"{so_abi}.so"
# Compute Code Hash
code_utf8 = cpp_code.encode("utf-8")
import hashlib
code_hash = hashlib.sha256(code_utf8)
module_stem = f"module_{code_hash.hexdigest()}"
module_dir = self._objcache
# Lock module
import fasteners
lockfile = module_dir / f"{module_stem}.lock"
with fasteners.InterProcessLock(lockfile):
cpp_file = module_dir / f"{module_stem}.cpp"
if not cpp_file.exists():
cpp_file.write_bytes(code_utf8)
lib_file = module_dir / f"{module_stem}.{lib_suffix}"
if not lib_file.exists():
self._compile_extension_module(cpp_file, lib_file)
module = self._load_extension_module(module_name, lib_file)
return CpuJitKernelWrapper(kernel, module)
def _compile_extension_module(self, src_file: Path, libfile: Path):
args = (
[self._compiler_info.cxx()]
+ self._cxx_fixed_flags
+ self._compiler_info.cxxflags()
+ ["-o", str(libfile), str(src_file)]
)
result = subprocess.run(args, capture_output=True)
if result.returncode != 0:
raise JitError(
"Compilation failed: C++ compiler terminated with an error.\n"
+ result.stderr.decode()
)
def _load_extension_module(self, module_name: str, module_loc: Path) -> ModuleType:
from importlib import util as iutil
spec = iutil.spec_from_file_location(name=module_name, location=module_loc)
if spec is None:
raise JitError("Unable to load kernel extension module -- this is probably a bug.")
mod = iutil.module_from_spec(spec)
spec.loader.exec_module(mod) # type: ignore
return mod
class CpuJitKernelWrapper(KernelWrapper):
def __init__(self, kernel: Kernel, jit_module: ModuleType):
super().__init__(kernel)
self._module = jit_module
self._wrapper_func = getattr(jit_module, kernel.function_name)
def __call__(self, **kwargs) -> None:
return self._wrapper_func(**kwargs)
self._strict_scalar_types = strict_scalar_types
class KernelModuleBuilder: self._actual_field_types: dict[Field, PsType]
def __init__(self, jit: CpuJit, module_name: str): self._param_binds: list[str]
self._jit = jit self._public_params: list[str]
self._module_name = module_name self._extraction_lines: list[str]
self._actual_field_types: dict[Field, PsType] = dict() def __call__(self, kernel: Kernel, module_name: str) -> str:
self._param_binds: list[str] = [] self._actual_field_types = dict()
self._public_params: list[str] = [] self._param_binds = []
self._extraction_lines: list[str] = [] self._public_params = []
self._extraction_lines = []
def __call__(self, kernel: Kernel) -> str:
self._handle_params(kernel.parameters) self._handle_params(kernel.parameters)
kernel_def = self._get_kernel_definition(kernel) kernel_def = self._get_kernel_definition(kernel)
kernel_args = [param.name for param in kernel.parameters] kernel_args = [param.name for param in kernel.parameters]
includes = [f"#include {h}" for h in kernel.required_headers] includes = [f"#include {h}" for h in kernel.required_headers]
...@@ -190,8 +56,8 @@ class KernelModuleBuilder: ...@@ -190,8 +56,8 @@ class KernelModuleBuilder:
templ = Template(_module_template.read_text()) templ = Template(_module_template.read_text())
code_str = templ.substitute( code_str = templ.substitute(
includes="\n".join(includes), includes="\n".join(includes),
restrict_qualifier=self._jit.compiler_info.restrict_qualifier(), restrict_qualifier=self._compiler_info.restrict_qualifier(),
module_name=self._module_name, module_name=module_name,
kernel_name=kernel.function_name, kernel_name=kernel.function_name,
param_binds=", ".join(self._param_binds), param_binds=", ".join(self._param_binds),
public_params=", ".join(self._public_params), public_params=", ".join(self._public_params),
...@@ -203,6 +69,7 @@ class KernelModuleBuilder: ...@@ -203,6 +69,7 @@ class KernelModuleBuilder:
def _get_kernel_definition(self, kernel: Kernel) -> str: def _get_kernel_definition(self, kernel: Kernel) -> str:
from ...backend.emission import CAstPrinter from ...backend.emission import CAstPrinter
printer = CAstPrinter(func_prefix="inline") printer = CAstPrinter(func_prefix="inline")
return printer(kernel) return printer(kernel)
...@@ -228,7 +95,7 @@ class KernelModuleBuilder: ...@@ -228,7 +95,7 @@ class KernelModuleBuilder:
def _add_scalar_param(self, sc_param: Parameter): def _add_scalar_param(self, sc_param: Parameter):
param_bind = f'py::arg("{sc_param.name}")' param_bind = f'py::arg("{sc_param.name}")'
if self._jit.strict_scalar_types: if self._strict_scalar_types:
param_bind += ".noconvert()" param_bind += ".noconvert()"
self._param_binds.append(param_bind) self._param_binds.append(param_bind)
......
...@@ -5,7 +5,7 @@ from pystencils.jit import CpuJit ...@@ -5,7 +5,7 @@ from pystencils.jit import CpuJit
def test_basic_cpu_kernel(tmp_path): def test_basic_cpu_kernel(tmp_path):
jit = CpuJit(objcache=tmp_path) jit = CpuJit.create(objcache=tmp_path)
f, g = fields("f, g: [2D]") f, g = fields("f, g: [2D]")
asm = Assignment(f.center(), 2.0 * g.center()) asm = Assignment(f.center(), 2.0 * g.center())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment