diff --git a/src/pystencils/codegen/config.py b/src/pystencils/codegen/config.py index 3a7647907b82a4ee1ddbb72c2c700e42c7547f69..209df1b0e81e3893d69e532dbcd978203b2907d1 100644 --- a/src/pystencils/codegen/config.py +++ b/src/pystencils/codegen/config.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, TypeGuard from warnings import warn from collections.abc import Collection @@ -28,7 +28,10 @@ class PsOptionsError(Exception): """Indicates an option clash in the `CreateKernelConfig`.""" -class _AUTO_TYPE: ... # noqa: E701 +class _AUTO_TYPE: + @staticmethod + def is_auto(val: Any) -> TypeGuard[_AUTO_TYPE]: + return val == AUTO AUTO = _AUTO_TYPE() diff --git a/src/pystencils/jit/cpu/__init__.py b/src/pystencils/jit/cpu/__init__.py index 335ec52d4fae2b744eb0cb17cf4c9c42bcdda94e..5e4dfcbf70e89c543b6da6053709c75bfe421e25 100644 --- a/src/pystencils/jit/cpu/__init__.py +++ b/src/pystencils/jit/cpu/__init__.py @@ -1,5 +1,5 @@ from .compiler_info import GccInfo, ClangInfo -from .cpu_pybind11 import CpuJit +from .cpujit import CpuJit __all__ = [ "GccInfo", diff --git a/src/pystencils/jit/cpu/cpu_pybind11.py b/src/pystencils/jit/cpu/cpu_pybind11.py deleted file mode 100644 index 0229c5c4d671a412fe3ba3c26711c703459a9d5d..0000000000000000000000000000000000000000 --- a/src/pystencils/jit/cpu/cpu_pybind11.py +++ /dev/null @@ -1,276 +0,0 @@ -from __future__ import annotations - -from typing import Sequence, cast -from types import ModuleType -from pathlib import Path -from textwrap import indent -import subprocess -from copy import copy - -from ...types import PsPointerType, PsType -from ...field import Field -from ...sympyextensions import DynamicType -from ..jit import KernelWrapper -from ...codegen import Kernel, Parameter -from ...codegen.properties import FieldBasePtr, FieldShape, FieldStride - -from ..jit import JitError, JitBase -from .compiler_info import CompilerInfo, GccInfo - - 
-_module_template = Path(__file__).parent / "kernel_module.tmpl.cpp" - - -class CpuJit(JitBase): - """Just-in-time compiler for CPU kernels.""" - - def __init__( - self, - compiler_info: CompilerInfo | None = None, - objcache: str | Path | None = None, - strict_scalar_types: bool = False, - ): - self._compiler_info = copy(compiler_info) if compiler_info is not None else GccInfo() - self._strict_scalar_types = strict_scalar_types - - if objcache is None: - from appdirs import AppDirs - - dirs = AppDirs(appname="pystencils") - self._objcache = Path(dirs.user_cache_dir) / "cpujit" - else: - self._objcache = Path(objcache) - - # Include Directories - import pybind11 as pb11 - - pybind11_include = pb11.get_include() - - import sysconfig - - python_include = sysconfig.get_path("include") - - from ...include import get_pystencils_include_path - - pystencils_include = get_pystencils_include_path() - - self._cxx_fixed_flags = [ - "-shared", - "-fPIC", - f"-I{python_include}", - f"-I{pybind11_include}", - f"-I{pystencils_include}" - ] - - @property - def objcache(self) -> Path: - return self._objcache - - @property - def compiler_info(self) -> CompilerInfo: - return self._compiler_info - - @compiler_info.setter - def compiler_info(self, info: CompilerInfo): - self._compiler_info = info - - @property - def strict_scalar_types(self) -> bool: - """Enable or disable implicit type casts for scalar parameters. - - If `True`, values for scalar kernel parameters must always be provided with the correct NumPy type. 
- """ - return self._strict_scalar_types - - @strict_scalar_types.setter - def strict_scalar_types(self, v: bool): - self._strict_scalar_types = v - - @property - def restrict_qualifier(self) -> str: - return self._restrict_qualifier - - @restrict_qualifier.setter - def restrict_qualifier(self, qual: str): - self._restrict_qualifier = qual - - def compile(self, kernel: Kernel) -> CpuJitKernelWrapper: - # Get the Code - module_name = f"{kernel.function_name}_jit" - modbuilder = KernelModuleBuilder(self, module_name) - cpp_code = modbuilder(kernel) - - # Get compiler information - import sysconfig - - so_abi = sysconfig.get_config_var("SOABI") - lib_suffix = f"{so_abi}.so" - - # Compute Code Hash - code_utf8 = cpp_code.encode("utf-8") - import hashlib - - code_hash = hashlib.sha256(code_utf8) - module_stem = f"module_{code_hash.hexdigest()}" - - module_dir = self._objcache - - # Lock module - import fasteners - - lockfile = module_dir / f"{module_stem}.lock" - with fasteners.InterProcessLock(lockfile): - cpp_file = module_dir / f"{module_stem}.cpp" - if not cpp_file.exists(): - cpp_file.write_bytes(code_utf8) - - lib_file = module_dir / f"{module_stem}.{lib_suffix}" - if not lib_file.exists(): - self._compile_extension_module(cpp_file, lib_file) - - module = self._load_extension_module(module_name, lib_file) - - return CpuJitKernelWrapper(kernel, module) - - def _compile_extension_module(self, src_file: Path, libfile: Path): - args = ( - [self._compiler_info.cxx()] - + self._cxx_fixed_flags - + self._compiler_info.cxxflags() - + ["-o", str(libfile), str(src_file)] - ) - - result = subprocess.run(args, capture_output=True) - if result.returncode != 0: - raise JitError( - "Compilation failed: C++ compiler terminated with an error.\n" - + result.stderr.decode() - ) - - def _load_extension_module(self, module_name: str, module_loc: Path) -> ModuleType: - from importlib import util as iutil - - spec = iutil.spec_from_file_location(name=module_name, location=module_loc) - 
if spec is None: - raise JitError("Unable to load kernel extension module -- this is probably a bug.") - mod = iutil.module_from_spec(spec) - spec.loader.exec_module(mod) # type: ignore - return mod - - -class CpuJitKernelWrapper(KernelWrapper): - def __init__(self, kernel: Kernel, jit_module: ModuleType): - super().__init__(kernel) - self._module = jit_module - self._wrapper_func = getattr(jit_module, kernel.function_name) - - def __call__(self, **kwargs) -> None: - return self._wrapper_func(**kwargs) - - -class KernelModuleBuilder: - def __init__(self, jit: CpuJit, module_name: str): - self._jit = jit - self._module_name = module_name - - self._actual_field_types: dict[Field, PsType] = dict() - self._param_binds: list[str] = [] - self._public_params: list[str] = [] - self._extraction_lines: list[str] = [] - - def __call__(self, kernel: Kernel) -> str: - self._handle_params(kernel.parameters) - - kernel_def = self._get_kernel_definition(kernel) - kernel_args = [param.name for param in kernel.parameters] - includes = [f"#include {h}" for h in kernel.required_headers] - - from string import Template - - templ = Template(_module_template.read_text()) - code_str = templ.substitute( - includes="\n".join(includes), - restrict_qualifier=self._jit.compiler_info.restrict_qualifier(), - module_name=self._module_name, - kernel_name=kernel.function_name, - param_binds=", ".join(self._param_binds), - public_params=", ".join(self._public_params), - extraction_lines=indent("\n".join(self._extraction_lines), prefix=" "), - kernel_args=", ".join(kernel_args), - kernel_definition=kernel_def, - ) - return code_str - - def _get_kernel_definition(self, kernel: Kernel) -> str: - from ...backend.emission import CAstPrinter - printer = CAstPrinter(func_prefix="inline") - - return printer(kernel) - - def _add_field_param(self, ptr_param: Parameter): - field: Field = ptr_param.fields[0] - - ptr_type = ptr_param.dtype - assert isinstance(ptr_type, PsPointerType) - - if 
isinstance(field.dtype, DynamicType): - elem_type = ptr_type.base_type - else: - elem_type = field.dtype - - self._actual_field_types[field] = elem_type - - param_bind = f'py::arg("{field.name}").noconvert()' - self._param_binds.append(param_bind) - - kernel_param = f"py::array_t< {elem_type.c_string()} > & {field.name}" - self._public_params.append(kernel_param) - - def _add_scalar_param(self, sc_param: Parameter): - param_bind = f'py::arg("{sc_param.name}")' - if self._jit.strict_scalar_types: - param_bind += ".noconvert()" - self._param_binds.append(param_bind) - - kernel_param = f"{sc_param.dtype.c_string()} {sc_param.name}" - self._public_params.append(kernel_param) - - def _extract_base_ptr(self, ptr_param: Parameter, ptr_prop: FieldBasePtr): - field_name = ptr_prop.field.name - assert isinstance(ptr_param.dtype, PsPointerType) - data_method = "data()" if ptr_param.dtype.base_type.const else "mutable_data()" - extraction = f"{ptr_param.dtype.c_string()} {ptr_param.name} {{ {field_name}.{data_method} }};" - self._extraction_lines.append(extraction) - - def _extract_shape(self, shape_param: Parameter, shape_prop: FieldShape): - field_name = shape_prop.field.name - coord = shape_prop.coordinate - extraction = f"{shape_param.dtype.c_string()} {shape_param.name} {{ {field_name}.shape({coord}) }};" - self._extraction_lines.append(extraction) - - def _extract_stride(self, stride_param: Parameter, stride_prop: FieldStride): - field = stride_prop.field - field_name = field.name - coord = stride_prop.coordinate - field_type = self._actual_field_types[field] - assert field_type.itemsize is not None - extraction = ( - f"{stride_param.dtype.c_string()} {stride_param.name} " - f"{{ {field_name}.strides({coord}) / {field_type.itemsize} }};" - ) - self._extraction_lines.append(extraction) - - def _handle_params(self, parameters: Sequence[Parameter]): - for param in parameters: - if param.get_properties(FieldBasePtr): - self._add_field_param(param) - - for param in 
class CpuJit(JitBase):
    """Just-in-time compiler for CPU kernels.

    :Creation:
        In most cases, objects of this class should be instantiated using the `create` factory method.

    :Implementation Details:

        The `CpuJit` combines two separate components:
         - The *extension module builder* produces the code of the dynamically built extension module containing
           the kernel
         - The *compiler info* describes the system compiler used to compile and link that extension module.

        Both can be dynamically exchanged.
    """

    @staticmethod
    def create(
        compiler_info: CompilerInfo | None = None,
        objcache: str | Path | _AUTO_TYPE | None = AUTO,
        strict_scalar_types: bool = False,
    ) -> CpuJit:
        """Create a `CpuJit` that uses the pybind11 extension module builder.

        Args:
            compiler_info: Compiler description; a default-constructed `GccInfo` is used if `None`.
            objcache: Directory for generated sources and compiled modules.
                `AUTO` selects ``<appdirs user cache dir for "pystencils">/cpujit``;
                `None` disables the cache, so that each compilation happens in a
                temporary directory.
            strict_scalar_types: Forwarded to the `Pybind11KernelModuleBuilder` constructor
                and mirrored on the returned JIT's `strict_scalar_types` property.

        Returns:
            The configured `CpuJit` instance.
        """
        if objcache is AUTO:
            from appdirs import AppDirs

            dirs = AppDirs(appname="pystencils")
            objcache = Path(dirs.user_cache_dir) / "cpujit"
        elif objcache is not None:
            # Type narrowing only: the AUTO sentinel was handled by the branch above
            assert not isinstance(objcache, _AUTO_TYPE)
            objcache = Path(objcache)

        if compiler_info is None:
            compiler_info = GccInfo()

        from .cpujit_pybind11 import Pybind11KernelModuleBuilder

        modbuilder = Pybind11KernelModuleBuilder(
            compiler_info, strict_scalar_types=strict_scalar_types
        )

        jit = CpuJit(compiler_info, modbuilder, objcache)
        jit.strict_scalar_types = strict_scalar_types
        return jit

    def __init__(
        self,
        compiler_info: CompilerInfo,
        ext_module_builder: ExtensionModuleBuilderBase,
        objcache: Path | None,
    ):
        """Set up the JIT from its components.

        Args:
            compiler_info: Compiler description; a shallow copy of it is stored.
            ext_module_builder: Callable object producing the extension module source code.
            objcache: Cache directory, or `None` to compile in temporary directories.
        """
        self._compiler_info = copy(compiler_info)
        self._ext_module_builder = ext_module_builder

        self._objcache = objcache

        # Backing field of the `strict_scalar_types` property.
        # It must be assigned here, otherwise the getter raises `AttributeError`.
        self._strict_scalar_types = False

        # Include Directories

        import sysconfig

        python_include = sysconfig.get_path("include")

        from ...include import get_pystencils_include_path

        pystencils_include = get_pystencils_include_path()

        # Flags passed on every compiler invocation, followed by the builder's own include flags
        self._cxx_fixed_flags = [
            "-shared",
            "-fPIC",
            f"-I{python_include}",
            f"-I{pystencils_include}",
        ] + self._ext_module_builder.include_flags()

    @property
    def objcache(self) -> Path | None:
        """Directory used to cache sources and compiled modules, or `None` if caching is disabled."""
        return self._objcache

    @property
    def compiler_info(self) -> CompilerInfo:
        """Compiler description used to compile and link extension modules."""
        return self._compiler_info

    @compiler_info.setter
    def compiler_info(self, info: CompilerInfo):
        self._compiler_info = info

    @property
    def strict_scalar_types(self) -> bool:
        """Enable or disable implicit type casts for scalar parameters.

        If `True`, values for scalar kernel parameters must always be provided with the correct NumPy type.

        NOTE(review): the value is only stored on the JIT object; nothing in this
        class forwards later changes to the extension module builder -- confirm
        whether the setter is meant to propagate.
        """
        return self._strict_scalar_types

    @strict_scalar_types.setter
    def strict_scalar_types(self, v: bool):
        self._strict_scalar_types = v

    def compile(self, kernel: Kernel) -> CpuJitKernelWrapper:
        """Generate, compile and load the extension module for ``kernel``.

        The source file and shared library are named after the SHA-256 hash of the
        generated code, and existing files of that name are reused. With an object
        cache, the work is guarded by an inter-process lock file in the cache
        directory; without one, it happens inside a temporary directory.

        Args:
            kernel: The kernel to compile.

        Returns:
            A wrapper around the kernel function exported by the loaded module.

        Raises:
            JitError: If compilation fails or the compiled module cannot be loaded.
        """
        # Get the Code
        module_name = f"{kernel.function_name}_jit"
        cpp_code = self._ext_module_builder(kernel, module_name)

        # Get compiler information
        import sysconfig

        so_abi = sysconfig.get_config_var("SOABI")
        lib_suffix = f"{so_abi}.so"

        # Compute Code Hash
        code_utf8 = cpp_code.encode("utf-8")
        import hashlib

        code_hash = hashlib.sha256(code_utf8)
        module_stem = f"module_{code_hash.hexdigest()}"

        def compile_and_load(module_dir: Path):
            cpp_file = module_dir / f"{module_stem}.cpp"
            if not cpp_file.exists():
                cpp_file.write_bytes(code_utf8)

            lib_file = module_dir / f"{module_stem}.{lib_suffix}"
            if not lib_file.exists():
                self._compile_extension_module(cpp_file, lib_file)

            module = self._load_extension_module(module_name, lib_file)
            return module

        if self._objcache is not None:
            module_dir = self._objcache
            # Lock module
            import fasteners

            lockfile = module_dir / f"{module_stem}.lock"
            with fasteners.InterProcessLock(lockfile):
                module = compile_and_load(module_dir)
        else:
            from tempfile import TemporaryDirectory

            # The module is loaded while the directory still exists;
            # the directory is removed on leaving the `with` block.
            with TemporaryDirectory() as tmpdir:
                module_dir = Path(tmpdir)
                module = compile_and_load(module_dir)

        return CpuJitKernelWrapper(kernel, module)

    def _compile_extension_module(self, src_file: Path, libfile: Path):
        """Invoke the C++ compiler to build ``libfile`` from ``src_file``.

        Raises:
            JitError: If the compiler exits with a nonzero return code.
        """
        args = (
            [self._compiler_info.cxx()]
            + self._cxx_fixed_flags
            + self._compiler_info.cxxflags()
            + ["-o", str(libfile), str(src_file)]
        )

        result = subprocess.run(args, capture_output=True)
        if result.returncode != 0:
            # Decode leniently: an undecodable byte in the compiler output must not
            # replace the actual compilation error with a `UnicodeDecodeError`.
            raise JitError(
                "Compilation failed: C++ compiler terminated with an error.\n"
                + result.stderr.decode(errors="replace")
            )

    def _load_extension_module(self, module_name: str, module_loc: Path) -> ModuleType:
        """Import the compiled extension module located at ``module_loc``.

        Raises:
            JitError: If no module spec or no loader can be obtained for the file.
        """
        from importlib import util as iutil

        spec = iutil.spec_from_file_location(name=module_name, location=module_loc)
        # Both the spec and its loader may be `None`
        if spec is None or spec.loader is None:
            raise JitError(
                "Unable to load kernel extension module -- this is probably a bug."
            )
        mod = iutil.module_from_spec(spec)
        spec.loader.exec_module(mod)
        return mod


class CpuJitKernelWrapper(KernelWrapper):
    """Callable wrapper around a kernel function exported by a JIT-compiled extension module."""

    def __init__(self, kernel: Kernel, jit_module: ModuleType):
        super().__init__(kernel)
        # Keep a reference to the module alongside the function looked up from it
        self._module = jit_module
        self._wrapper_func = getattr(jit_module, kernel.function_name)

    def __call__(self, **kwargs) -> None:
        """Invoke the compiled kernel, passing all keyword arguments through."""
        return self._wrapper_func(**kwargs)


class ExtensionModuleBuilderBase(ABC):
    """Interface of the objects that generate extension module source code for `CpuJit`."""

    @staticmethod
    @abstractmethod
    def include_flags() -> list[str]:
        """Return the compiler flags `CpuJit` appends to its fixed flag list."""

    @abstractmethod
    def __call__(self, kernel: Kernel, module_name: str) -> str:
        """Return the source code of an extension module named ``module_name`` wrapping ``kernel``."""
_module_template = Path(__file__).parent / "pybind11_kernel_module.tmpl.cpp"


class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
    """Extension module builder that fills a pybind11 C++ module template for a kernel."""

    @staticmethod
    def include_flags() -> list[str]:
        """Return the ``-I`` flag pointing at the pybind11 headers."""
        import pybind11 as pb11

        pybind11_include = pb11.get_include()
        return [f"-I{pybind11_include}"]

    def __init__(
        self,
        compiler_info: CompilerInfo,
        strict_scalar_types: bool = False,
    ):
        """Set up the builder.

        Args:
            compiler_info: Source of the restrict qualifier substituted into the template.
            strict_scalar_types: If `True`, scalar arguments are bound with ``.noconvert()``.
        """
        self._compiler_info = compiler_info

        self._strict_scalar_types = strict_scalar_types

        # Per-kernel scratch state. It is assigned here (not merely annotated) so
        # the helper methods never see missing attributes, and is reset by `__call__`.
        self._actual_field_types: dict[Field, PsType] = dict()
        self._param_binds: list[str] = []
        self._public_params: list[str] = []
        self._extraction_lines: list[str] = []

    def __call__(self, kernel: Kernel, module_name: str) -> str:
        """Generate the extension module source code for ``kernel``.

        Args:
            kernel: The kernel to wrap.
            module_name: Value substituted for the template's ``module_name`` placeholder.

        Returns:
            The module template with all placeholders substituted.
        """
        # Start from a clean slate so that state from a previous kernel cannot leak in
        self._actual_field_types = dict()
        self._param_binds = []
        self._public_params = []
        self._extraction_lines = []

        self._handle_params(kernel.parameters)

        kernel_def = self._get_kernel_definition(kernel)
        kernel_args = [param.name for param in kernel.parameters]
        includes = [f"#include {h}" for h in kernel.required_headers]

        from string import Template

        # Explicit encoding: do not let the locale decide how the template is decoded
        templ = Template(_module_template.read_text(encoding="utf-8"))
        code_str = templ.substitute(
            includes="\n".join(includes),
            restrict_qualifier=self._compiler_info.restrict_qualifier(),
            module_name=module_name,
            kernel_name=kernel.function_name,
            param_binds=", ".join(self._param_binds),
            public_params=", ".join(self._public_params),
            extraction_lines=indent("\n".join(self._extraction_lines), prefix="    "),
            kernel_args=", ".join(kernel_args),
            kernel_definition=kernel_def,
        )
        return code_str

    def _get_kernel_definition(self, kernel: Kernel) -> str:
        """Print ``kernel`` with `CAstPrinter`, using ``inline`` as the function prefix."""
        from ...backend.emission import CAstPrinter

        printer = CAstPrinter(func_prefix="inline")

        return printer(kernel)

    def _add_field_param(self, ptr_param: Parameter):
        """Register the field behind ``ptr_param`` as a ``py::array_t`` argument of the wrapper."""
        field: Field = ptr_param.fields[0]

        ptr_type = ptr_param.dtype
        assert isinstance(ptr_type, PsPointerType)

        # For a dynamically typed field, take the element type from the pointer type
        if isinstance(field.dtype, DynamicType):
            elem_type = ptr_type.base_type
        else:
            elem_type = field.dtype

        self._actual_field_types[field] = elem_type

        param_bind = f'py::arg("{field.name}").noconvert()'
        self._param_binds.append(param_bind)

        kernel_param = f"py::array_t< {elem_type.c_string()} > & {field.name}"
        self._public_params.append(kernel_param)

    def _add_scalar_param(self, sc_param: Parameter):
        """Register ``sc_param`` as a scalar argument of the wrapper function."""
        param_bind = f'py::arg("{sc_param.name}")'
        if self._strict_scalar_types:
            param_bind += ".noconvert()"
        self._param_binds.append(param_bind)

        kernel_param = f"{sc_param.dtype.c_string()} {sc_param.name}"
        self._public_params.append(kernel_param)

    def _extract_base_ptr(self, ptr_param: Parameter, ptr_prop: FieldBasePtr):
        """Emit the line initializing ``ptr_param`` from the field array's data pointer."""
        field_name = ptr_prop.field.name
        assert isinstance(ptr_param.dtype, PsPointerType)
        # Const element types use `data()`, all others `mutable_data()`
        data_method = "data()" if ptr_param.dtype.base_type.const else "mutable_data()"
        extraction = f"{ptr_param.dtype.c_string()} {ptr_param.name} {{ {field_name}.{data_method} }};"
        self._extraction_lines.append(extraction)

    def _extract_shape(self, shape_param: Parameter, shape_prop: FieldShape):
        """Emit the line initializing ``shape_param`` from the field array's shape."""
        field_name = shape_prop.field.name
        coord = shape_prop.coordinate
        extraction = f"{shape_param.dtype.c_string()} {shape_param.name} {{ {field_name}.shape({coord}) }};"
        self._extraction_lines.append(extraction)

    def _extract_stride(self, stride_param: Parameter, stride_prop: FieldStride):
        """Emit the line initializing ``stride_param`` from the field array's strides."""
        field = stride_prop.field
        field_name = field.name
        coord = stride_prop.coordinate
        field_type = self._actual_field_types[field]
        assert field_type.itemsize is not None
        # The array's stride is divided by the element size
        # (presumably converting a byte stride to an element stride -- TODO confirm)
        extraction = (
            f"{stride_param.dtype.c_string()} {stride_param.name} "
            f"{{ {field_name}.strides({coord}) / {field_type.itemsize} }};"
        )
        self._extraction_lines.append(extraction)

    def _handle_params(self, parameters: Sequence[Parameter]):
        """Populate the argument bindings, wrapper parameters and extraction lines."""
        # First pass: register all fields, since `_extract_stride` looks up
        # the element types recorded by `_add_field_param`.
        for param in parameters:
            if param.get_properties(FieldBasePtr):
                self._add_field_param(param)

        # Second pass: emit one extraction line per field-related parameter;
        # everything else becomes a scalar argument of the wrapper.
        for param in parameters:
            if ptr_props := param.get_properties(FieldBasePtr):
                self._extract_base_ptr(param, cast(FieldBasePtr, ptr_props.pop()))
            elif shape_props := param.get_properties(FieldShape):
                self._extract_shape(param, cast(FieldShape, shape_props.pop()))
            elif stride_props := param.get_properties(FieldStride):
                self._extract_stride(param, cast(FieldStride, stride_props.pop()))
            else:
                self._add_scalar_param(param)
b/src/pystencils/jit/cpu/pybind11_kernel_module.tmpl.cpp similarity index 100% rename from src/pystencils/jit/cpu/kernel_module.tmpl.cpp rename to src/pystencils/jit/cpu/pybind11_kernel_module.tmpl.cpp diff --git a/tests/jit/test_cpujit.py b/tests/jit/test_cpujit.py index aa4f50f3fca04e0d7177f8d623a44f1eb7f50329..0e17fa6a14c66d104eea1e3a40366e8664040f59 100644 --- a/tests/jit/test_cpujit.py +++ b/tests/jit/test_cpujit.py @@ -5,7 +5,7 @@ from pystencils.jit import CpuJit def test_basic_cpu_kernel(tmp_path): - jit = CpuJit(objcache=tmp_path) + jit = CpuJit.create(objcache=tmp_path) f, g = fields("f, g: [2D]") asm = Assignment(f.center(), 2.0 * g.center())