diff --git a/mypy.ini b/mypy.ini index cc23a503a2da6c9849d3a41e82fe8ceb8de13b43..a0533a60adfd21abdd8fc57126af5511ba83d381 100644 --- a/mypy.ini +++ b/mypy.ini @@ -31,3 +31,6 @@ ignore_missing_imports=true [mypy-cpuinfo.*] ignore_missing_imports=true + +[mypy-fasteners.*] +ignore_missing_imports=true diff --git a/pyproject.toml b/pyproject.toml index 59e71b8db2d6156c25aebfcdeb88af5652dacc8e..b3c6b1c0238654bedd1ba90b09c8d98541e1c73f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ authors = [ ] license = { file = "COPYING.txt" } requires-python = ">=3.10" -dependencies = ["sympy>=1.9,<=1.12.1", "numpy>=1.8.0", "appdirs", "joblib", "pyyaml"] +dependencies = ["sympy>=1.9,<=1.12.1", "numpy>=1.8.0", "appdirs", "joblib", "pyyaml", "pybind11", "fasteners"] classifiers = [ "Development Status :: 4 - Beta", "Framework :: Jupyter", @@ -90,6 +90,7 @@ build-backend = "setuptools.build_meta" [tool.setuptools.package-data] pystencils = [ "include/*.h", + "jit/cpu/*.tmpl.cpp", "boundaries/createindexlistcython.pyx" ] diff --git a/src/pystencils/backend/emission/base_printer.py b/src/pystencils/backend/emission/base_printer.py index a4358bbf328b65aaf5e45eff5a2083ef067285a6..adb9c232b9c927a4094f54a77fa99669b99f141d 100644 --- a/src/pystencils/backend/emission/base_printer.py +++ b/src/pystencils/backend/emission/base_printer.py @@ -3,8 +3,6 @@ from enum import Enum from abc import ABC, abstractmethod from typing import TYPE_CHECKING -from ...codegen import Target - from ..ast.structural import ( PsAstNode, PsBlock, @@ -171,8 +169,9 @@ class BasePrinter(ABC): and in `IRAstPrinter` for debug-printing the entire IR. """ - def __init__(self, indent_width=3): + def __init__(self, indent_width=3, func_prefix: str | None = None): self._indent_width = indent_width + self._func_prefix = func_prefix def __call__(self, obj: PsAstNode | Kernel) -> str: from ...codegen import Kernel @@ -376,21 +375,14 @@ class BasePrinter(ABC): ) def print_signature(self, func: Kernel) -> str: - prefix = self._func_prefix(func) + prefix = self._func_prefix params_str = ", ".join( f"{self._type_str(p.dtype)} {p.name}" for p in func.parameters ) - signature = " ".join([prefix, "void", func.name, f"({params_str})"]) + sig_parts = ([prefix] if prefix is not None else []) + ["void", func.name, f"({params_str})"] + signature = " ".join(sig_parts) return signature - def _func_prefix(self, func: Kernel): - from ...codegen import GpuKernel - - if isinstance(func, GpuKernel) and func.target == Target.CUDA: - return "__global__" - else: - return "FUNC_PREFIX" - @abstractmethod def _symbol_decl(self, symb: PsSymbol) -> str: pass diff --git a/src/pystencils/backend/emission/c_printer.py b/src/pystencils/backend/emission/c_printer.py index 90a7e54e22b3eb14866c9260c85247baf8b4f340..40cd692836117d48a0ab6f955681085c90fa0b86 100644 --- a/src/pystencils/backend/emission/c_printer.py +++ b/src/pystencils/backend/emission/c_printer.py @@ -21,10 +21,6 @@ def emit_code(ast: PsAstNode | Kernel): class CAstPrinter(BasePrinter): - - def __init__(self, indent_width=3): - super().__init__(indent_width) - def visit(self, node: PsAstNode, pc: PrinterCtx) -> str: match node: case PsVecMemAcc(): diff --git a/src/pystencils/jit/__init__.py b/src/pystencils/jit/__init__.py index 1ef8378d3000e95b12bb6a3a17062fb6488e1729..3ae63fa721a4a70340bf7dd88f5a203fb6c2da66 100644 --- a/src/pystencils/jit/__init__.py +++ b/src/pystencils/jit/__init__.py @@ -24,6 +24,7 @@ It is due to be replaced in the near future. from .jit import JitBase, NoJit, KernelWrapper from .legacy_cpu import LegacyCpuJit +from .cpu import CpuJit from .gpu_cupy import CupyJit, CupyKernelWrapper, LaunchGrid no_jit = NoJit() @@ -33,6 +34,7 @@ __all__ = [ "JitBase", "KernelWrapper", "LegacyCpuJit", + "CpuJit", "NoJit", "no_jit", "CupyJit", diff --git a/src/pystencils/jit/cpu/__init__.py b/src/pystencils/jit/cpu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..69985c91420859c4071dc34bfa601eebcb472a30 --- /dev/null +++ b/src/pystencils/jit/cpu/__init__.py @@ -0,0 +1,5 @@ +from .cpu_pybind11 import CpuJit + +__all__ = [ + "CpuJit" +] diff --git a/src/pystencils/jit/cpu/cpu_pybind11.py b/src/pystencils/jit/cpu/cpu_pybind11.py new file mode 100644 index 0000000000000000000000000000000000000000..128313a74d63cb7ea81e425da23185a617e83be7 --- /dev/null +++ b/src/pystencils/jit/cpu/cpu_pybind11.py @@ -0,0 +1,281 @@ +from __future__ import annotations + +from typing import Sequence, cast +from types import ModuleType +from pathlib import Path +from textwrap import indent +import subprocess + +from ...types import PsPointerType, PsType +from ...field import Field +from ...sympyextensions import DynamicType +from ..jit import KernelWrapper +from ...codegen import Kernel, Parameter +from ...codegen.properties import FieldBasePtr, FieldShape, FieldStride + +from ..jit import JitError, JitBase + + +_module_template = Path(__file__).parent / "kernel_module.tmpl.cpp" + + +class CpuJit(JitBase): + """Just-in-time compiler for CPU kernels.""" + + def __init__( + self, + cxx: str = "g++", + cxxflags: Sequence[str] = ("-O3", "-fopenmp"), + objcache: str | Path | None = None, + strict_scalar_types: bool = False, + ): + self._cxx = cxx + self._cxxflags: list[str] = list(cxxflags) + + self._strict_scalar_types = strict_scalar_types + self._restrict_qualifier = "__restrict__" + + if objcache is None: + from appdirs import AppDirs + + dirs = AppDirs(appname="pystencils") + self._objcache = Path(dirs.user_cache_dir) / "cpujit" + else: + self._objcache = Path(objcache) + + # Include Directories + import pybind11 as pb11 + + self._pybind11_include = pb11.get_include() + + import sysconfig + + self._py_include = sysconfig.get_path("include") + + self._cxx_fixed_flags = [ + "-shared", + "-fPIC", + f"-I{self._py_include}", + f"-I{self._pybind11_include}", + ] + + @property + def objcache(self) -> Path: + return self._objcache + + @property + def cxx(self) -> str: + return self._cxx + + @cxx.setter + def cxx(self, path: str): + self._cxx = path + + @property + def cxxflags(self) -> list[str]: + return self._cxxflags + + @cxxflags.setter + def cxxflags(self, flags: Sequence[str]): + self._cxxflags = list(flags) + + @property + def strict_scalar_types(self) -> bool: + """Enable or disable implicit type casts for scalar parameters. + + If `True`, values for scalar kernel parameters must always be provided with the correct NumPy type. + """ + return self._strict_scalar_types + + @strict_scalar_types.setter + def strict_scalar_types(self, v: bool): + self._strict_scalar_types = v + + @property + def restrict_qualifier(self) -> str: + return self._restrict_qualifier + + @restrict_qualifier.setter + def restrict_qualifier(self, qual: str): + self._restrict_qualifier = qual + + def compile(self, kernel: Kernel) -> CpuJitKernelWrapper: + # Get the Code + module_name = f"{kernel.function_name}_jit" + modbuilder = KernelModuleBuilder(self, module_name) + cpp_code = modbuilder(kernel) + + # Get compiler information + import sysconfig + + so_abi = sysconfig.get_config_var("SOABI") + lib_suffix = f"{so_abi}.so" + + # Compute Code Hash + code_utf8 = cpp_code.encode("utf-8") + import hashlib + + code_hash = hashlib.sha256(code_utf8) + module_stem = f"module_{code_hash.hexdigest()}" + + module_dir = self._objcache + + # Lock module + import fasteners + + lockfile = module_dir / f"{module_stem}.lock" + with fasteners.InterProcessLock(lockfile): + cpp_file = module_dir / f"{module_stem}.cpp" + if not cpp_file.exists(): + cpp_file.write_bytes(code_utf8) + + lib_file = module_dir / f"{module_stem}.{lib_suffix}" + if not lib_file.exists(): + self._compile_extension_module(cpp_file, lib_file) + + module = self._load_extension_module(module_name, lib_file) + + return CpuJitKernelWrapper(kernel, module) + + def _compile_extension_module(self, src_file: Path, libfile: Path): + args = ( + [self._cxx] + + self._cxx_fixed_flags + + self._cxxflags + + ["-o", str(libfile), str(src_file)] + ) + + result = subprocess.run(args, capture_output=True) + if result.returncode != 0: + raise JitError( + "Compilation failed: C++ compiler terminated with an error.\n" + + result.stderr.decode() + ) + + def _load_extension_module(self, module_name: str, module_loc: Path) -> ModuleType: + from importlib import util as iutil + + spec = iutil.spec_from_file_location(name=module_name, location=module_loc) + if spec is None: + raise JitError("Unable to load kernel extension module -- this is probably a bug.") + mod = iutil.module_from_spec(spec) + spec.loader.exec_module(mod) # type: ignore + return mod + + +class CpuJitKernelWrapper(KernelWrapper): + def __init__(self, kernel: Kernel, jit_module: ModuleType): + super().__init__(kernel) + self._module = jit_module + self._wrapper_func = getattr(jit_module, kernel.function_name) + + def __call__(self, **kwargs) -> None: + return self._wrapper_func(**kwargs) + + +class KernelModuleBuilder: + def __init__(self, jit: CpuJit, module_name: str): + self._jit = jit + self._module_name = module_name + + self._actual_field_types: dict[Field, PsType] = dict() + self._param_binds: list[str] = [] + self._public_params: list[str] = [] + self._extraction_lines: list[str] = [] + + def __call__(self, kernel: Kernel) -> str: + self._handle_params(kernel.parameters) + + kernel_def = self._get_kernel_definition(kernel) + kernel_args = [param.name for param in kernel.parameters] + includes = [f"#include {h}" for h in kernel.required_headers] + + from string import Template + + templ = Template(_module_template.read_text()) + code_str = templ.substitute( + includes="\n".join(includes), + restrict_qualifier=self._jit.restrict_qualifier, + module_name=self._module_name, + kernel_name=kernel.function_name, + param_binds=", ".join(self._param_binds), + public_params=", ".join(self._public_params), + extraction_lines=indent("\n".join(self._extraction_lines), prefix=" "), + kernel_args=", ".join(kernel_args), + kernel_definition=kernel_def, + ) + return code_str + + def _get_kernel_definition(self, kernel: Kernel) -> str: + from ...backend.emission import CAstPrinter + printer = CAstPrinter(func_prefix="inline") + + return printer(kernel) + + def _add_field_param(self, ptr_param: Parameter): + field: Field = ptr_param.fields[0] + + ptr_type = ptr_param.dtype + assert isinstance(ptr_type, PsPointerType) + + if isinstance(field.dtype, DynamicType): + elem_type = ptr_type.base_type + else: + elem_type = field.dtype + + self._actual_field_types[field] = elem_type + + param_bind = f'py::arg("{field.name}").noconvert()' + self._param_binds.append(param_bind) + + kernel_param = f"py::array_t< {elem_type.c_string()} > & {field.name}" + self._public_params.append(kernel_param) + + def _add_scalar_param(self, sc_param: Parameter): + param_bind = f'py::arg("{sc_param.name}")' + if self._jit.strict_scalar_types: + param_bind += ".noconvert()" + self._param_binds.append(param_bind) + + kernel_param = f"{sc_param.dtype.c_string()} {sc_param.name}" + self._public_params.append(kernel_param) + + def _extract_base_ptr(self, ptr_param: Parameter, ptr_prop: FieldBasePtr): + field_name = ptr_prop.field.name + assert isinstance(ptr_param.dtype, PsPointerType) + data_method = "data()" if ptr_param.dtype.base_type.const else "mutable_data()" + extraction = f"{ptr_param.dtype.c_string()} {ptr_param.name} {{ {field_name}.{data_method} }};" + self._extraction_lines.append(extraction) + + def _extract_shape(self, shape_param: Parameter, shape_prop: FieldShape): + field_name = shape_prop.field.name + coord = shape_prop.coordinate + extraction = f"{shape_param.dtype.c_string()} {shape_param.name} {{ {field_name}.shape({coord}) }};" + self._extraction_lines.append(extraction) + + def _extract_stride(self, stride_param: Parameter, stride_prop: FieldStride): + field = stride_prop.field + field_name = field.name + coord = stride_prop.coordinate + field_type = self._actual_field_types[field] + assert field_type.itemsize is not None + extraction = ( + f"{stride_param.dtype.c_string()} {stride_param.name} " + f"{{ {field_name}.strides({coord}) / {field_type.itemsize} }};" + ) + self._extraction_lines.append(extraction) + + def _handle_params(self, parameters: Sequence[Parameter]): + for param in parameters: + if param.get_properties(FieldBasePtr): + self._add_field_param(param) + + for param in parameters: + if ptr_props := param.get_properties(FieldBasePtr): + self._extract_base_ptr(param, cast(FieldBasePtr, ptr_props.pop())) + elif shape_props := param.get_properties(FieldShape): + self._extract_shape(param, cast(FieldShape, shape_props.pop())) + elif stride_props := param.get_properties(FieldStride): + self._extract_stride(param, cast(FieldStride, stride_props.pop())) + else: + self._add_scalar_param(param) diff --git a/src/pystencils/jit/cpu/kernel_module.tmpl.cpp b/src/pystencils/jit/cpu/kernel_module.tmpl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3ee5c6973a761a6acbabf7864df947de38e4289c --- /dev/null +++ b/src/pystencils/jit/cpu/kernel_module.tmpl.cpp @@ -0,0 +1,23 @@ +#include "pybind11/pybind11.h" +#include "pybind11/numpy.h" + +${includes} + +namespace py = pybind11; + +#define RESTRICT ${restrict_qualifier} + +namespace internal { + +${kernel_definition} + +} + +void callwrapper_${kernel_name} (${public_params}) { +${extraction_lines} + internal::${kernel_name}(${kernel_args}); +} + +PYBIND11_MODULE(${module_name}, m) { + m.def("${kernel_name}", &callwrapper_${kernel_name}, py::kw_only(), ${param_binds}); +} diff --git a/src/pystencils/sympyextensions/__init__.py b/src/pystencils/sympyextensions/__init__.py index 7431416c9eb9bcd4433dab76c32fb1b755501105..d01dbe57e1e463042681db82a465208efb6a7512 100644 --- a/src/pystencils/sympyextensions/__init__.py +++ b/src/pystencils/sympyextensions/__init__.py @@ -1,5 +1,5 @@ from .astnodes import ConditionalFieldAccess -from .typed_sympy import TypedSymbol, CastFunc +from .typed_sympy import TypedSymbol, CastFunc, DynamicType from .pointers import mem_acc from .math import ( @@ -61,5 +61,6 @@ __all__ = [ "count_operations_in_ast", "common_denominator", "get_symmetric_part", - "SymbolCreator" + "SymbolCreator", + "DynamicType" ] diff --git a/tests/jit/test_cpujit.py b/tests/jit/test_cpujit.py new file mode 100644 index 0000000000000000000000000000000000000000..0b09d66b01a1b935665713e3507dbb8396ad777c --- /dev/null +++ b/tests/jit/test_cpujit.py @@ -0,0 +1,21 @@ +import sympy as sp +import numpy as np +from pystencils import create_kernel, Assignment, fields +from pystencils.jit import CpuJit + + +def test_basic_cpu_kernel(tmp_path): + jit = CpuJit("g++", ["-O3"], objcache=tmp_path) + + f, g = fields("f, g: [2D]") + asm = Assignment(f.center(), 2.0 * g.center()) + ker = create_kernel(asm) + kfunc = jit.compile(ker) + + rng = np.random.default_rng() + f_arr = rng.random(size=(34, 26), dtype="float64") + g_arr = np.zeros_like(f_arr) + + kfunc(f=f_arr, g=g_arr) + + np.testing.assert_almost_equal(g_arr, 2.0 * f_arr)