Commit fa2683ce authored by Frederik Hennig

Object-Oriented CPU JIT API and Prototype Implementation

parent 929cef16
1 merge request: !445 Object-Oriented CPU JIT API and Prototype Implementation
813 additions and 90 deletions
......@@ -313,6 +313,15 @@ typecheck:
- docker
- AVX
"testsuite-experimental-jit-py3.10":
extends: .testsuite-base
image: i10git.cs.fau.de:5005/pycodegen/pycodegen/nox:alpine
script:
- nox -s "testsuite-3.10(cpu)" -- --experimental-cpu-jit
tags:
- docker
- AVX
# -------------------- Documentation ---------------------------------------------------------------------
......
......@@ -3,6 +3,7 @@ import runpy
import sys
import tempfile
import warnings
import pathlib
import nbformat
import pytest
......@@ -185,24 +186,10 @@ class IPyNbFile(pytest.File):
pass
if pytest_version >= 70000:
# Since pytest 7.0, usage of `py.path.local` is deprecated and `pathlib.Path` should be used instead
import pathlib
def pytest_collect_file(file_path: pathlib.Path, parent):
glob_exprs = ["*demo*.ipynb", "*tutorial*.ipynb", "test_*.ipynb"]
if any(file_path.match(g) for g in glob_exprs):
return IPyNbFile.from_parent(path=file_path, parent=parent)
else:
def pytest_collect_file(path, parent):
glob_exprs = ["*demo*.ipynb", "*tutorial*.ipynb", "test_*.ipynb"]
if any(path.fnmatch(g) for g in glob_exprs):
if pytest_version >= 50403:
return IPyNbFile.from_parent(fspath=path, parent=parent)
else:
return IPyNbFile(path, parent)
def pytest_collect_file(file_path: pathlib.Path, parent):
glob_exprs = ["*demo*.ipynb", "*tutorial*.ipynb", "test_*.ipynb"]
if any(file_path.match(g) for g in glob_exprs):
return IPyNbFile.from_parent(path=file_path, parent=parent)
# Fixtures
......
# JIT Compilation
## Base Infrastructure
```{eval-rst}
.. module:: pystencils.jit
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/entire_class.rst
KernelWrapper
JitBase
NoJit
.. autodata:: no_jit
```
## Legacy CPU JIT
The legacy CPU JIT Compiler is a leftover from pystencils 1.3
which at the moment still drives most CPU JIT-compilation within the package,
until the new JIT compiler is ready to take over.
```{eval-rst}
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/entire_class.rst
LegacyCpuJit
```
## CPU Just-In-Time Compiler
:::{note}
The new CPU JIT compiler is still considered experimental and not yet adopted by most of pystencils.
While the APIs described here will (probably) become the default for pystencils 2.0
and can (and should) already be used for testing,
the current implementation is still *very slow*.
For more information, see [issue #120](https://i10git.cs.fau.de/pycodegen/pystencils/-/issues/120).
:::
To configure and create an instance of the CPU JIT compiler, use the `CpuJit.create` factory method:
:::{card}
```{eval-rst}
.. autofunction:: pystencils.jit.CpuJit.create
:no-index:
```
:::
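The `objcache` argument of `CpuJit.create` controls where compilation results are cached. How cache entries are keyed can be sketched as follows. This is a simplified, self-contained illustration modelled on `CpuJit.compile` in this commit, where the module file name is derived from a SHA-256 hash over the generated C++ code and the compiler command line, so a change to either yields a new cache entry. The helper name `module_stem` and the sample inputs are illustrative assumptions and not part of the pystencils API.

```python
import hashlib


def module_stem(cpp_code: str, cxx: str, flags: list[str]) -> str:
    """Derive a cache file stem from the code and the compiler invocation.

    Hypothetical helper mirroring the hashing scheme of the prototype.
    """
    code_utf8 = cpp_code.encode("utf-8")
    compiler_utf8 = " ".join([cxx] + flags).encode("utf-8")
    digest = hashlib.sha256(code_utf8 + compiler_utf8).hexdigest()
    return f"module_{digest}"


# Sample inputs (made up for illustration)
code = "void kernel(double * f) { f[0] = 1.0; }"
stem_o3 = module_stem(code, "g++", ["-O3", "-fopenmp"])
stem_o2 = module_stem(code, "g++", ["-O2", "-fopenmp"])

# Identical inputs map to the same cache entry; changing a flag does not.
print(stem_o3 == module_stem(code, "g++", ["-O3", "-fopenmp"]))
print(stem_o3 == stem_o2)
```

Because the compiler command line is part of the hash, switching the host compiler or its options should not pick up a stale module from the cache.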
### Compiler Infos
The CPU JIT compiler invokes a host C++ compiler to compile and link a Python extension
module containing the generated kernel.
The properties of the host compiler are defined in a `CompilerInfo` object.
To select a custom host compiler and customize its options, set up and pass
a custom compiler info object to `CpuJit.create`.
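As a rough illustration of what a compiler info object encodes, the following self-contained sketch mimics how the GNU-like compiler infos in this commit translate their options into command-line flags. `SketchCompilerInfo` is a hypothetical stand-in, not the pystencils class, and it replaces the `target` handling with a fixed `-march=native`.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class SketchCompilerInfo:
    """Hypothetical stand-in for a GNU-like compiler info (not the pystencils class)."""

    cxx: str = "g++"
    openmp: bool = True
    optlevel: str | None = "fast"
    cxx_standard: str = "c++11"

    def cxxflags(self) -> list[str]:
        """Assemble the compilation flags from the configured options."""
        flags = ["-DNDEBUG", f"-std={self.cxx_standard}", "-fPIC"]
        if self.optlevel is not None:
            flags.append(f"-O{self.optlevel}")
        if self.openmp:
            flags.append("-fopenmp")
        flags.append("-march=native")  # assumption: always compile for the current CPU
        return flags

    def command(self, src: str, out: str) -> list[str]:
        """Full compiler invocation; "-shared" links a shared library."""
        return [self.cxx] + self.cxxflags() + ["-shared", "-o", out, src]


info = SketchCompilerInfo(cxx="clang", optlevel="3", openmp=False)
print(" ".join(info.command("kernel.cpp", "kernel.so")))
```

In pystencils itself, the analogous step would be to construct a `GccInfo` or `ClangInfo` with the desired options and hand it to `CpuJit.create` via its `compiler_info` argument.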
```{eval-rst}
.. module:: pystencils.jit.cpu.compiler_info
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/entire_class.rst
CompilerInfo
GccInfo
ClangInfo
```
### Implementation
```{eval-rst}
.. module:: pystencils.jit.cpu
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/entire_class.rst
CpuJit
cpujit.ExtensionModuleBuilderBase
```
## CuPy-based GPU JIT
```{eval-rst}
.. module:: pystencils.jit.gpu_cupy
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/entire_class.rst
CupyJit
CupyKernelWrapper
LaunchGrid
```
JIT Compilation
===============
.. module:: pystencils.jit
Base Infrastructure
-------------------
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/entire_class.rst
KernelWrapper
JitBase
NoJit
.. autodata:: no_jit
Legacy CPU JIT
--------------
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/entire_class.rst
LegacyCpuJit
CuPy-based GPU JIT
------------------
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/entire_class.rst
CupyJit
CupyKernelWrapper
LaunchGrid
......@@ -127,18 +127,16 @@ If you think a new module is ready to be type-checked, add an exception clause t
## Running the Test Suite
Pystencils comes with an extensive and steadily growing suite of unit tests.
To run the testsuite, you may invoke a variant of the Nox `testsuite` session.
There are multiple different versions of the `testsuite` session, depending on whether you are testing with or
without CUDA, or which version of Python you wish to test with.
You can list the available sessions using `nox -l`.
Select one of the `testsuite` variants and run it via `nox -s "testsuite(<variant>)"`, e.g.
```
nox -s "testsuite(cpu)"
To run the full testsuite, invoke the Nox `testsuite` session:
```bash
nox -s testsuite
```
for the CPU-only suite.
During the testsuite run, coverage information is collected and displayed using [coverage.py](https://coverage.readthedocs.io/en/7.6.10/).
You can display a detailed overview of code coverage by opening the generated `htmlcov/index.html` page.
:::{seealso}
[](#testing_pystencils)
:::
## Building the Documentation
......
......@@ -7,6 +7,7 @@ Pystencils is an open-source package licensed under the [AGPL v3](https://www.gn
As such, the act of contributing to pystencils by submitting a merge request is taken as agreement to the terms of the licence.
:::{toctree}
:maxdepth: 2
:maxdepth: 2
dev-workflow
testing
:::
(testing_pystencils)=
# Testing pystencils
The pystencils testsuite is located in the `tests` directory,
constructed using [pytest](https://pytest.org),
and automated through [Nox](https://nox.thea.codes).
On this page, you will find instructions on how to execute and extend it.
## Running the Testsuite
The fastest way to execute the pystencils test suite is through the `testsuite` Nox session:
```bash
nox -s testsuite
```
There exist several configurations of the testsuite session, from which the above command will
select and execute only those that are available on your machine.
- *Python Versions:* The testsuite session can be run against all major Python versions between 3.10 and 3.13 (inclusive).
To only use a specific Python version, add the `-p 3.XX` argument to your Nox invocation; e.g. `nox -s testsuite -p 3.11`.
- *CuPy:* There exist three variants of `testsuite`, including or excluding tests for the CUDA GPU target: `cpu`, `cupy12` and `cupy13`.
To select one, append `(<variant>)` to the session name; e.g. `nox -s "testsuite(cupy12)"`.
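The resulting set of session configurations can be enumerated with a short sketch. It assumes Nox's naming scheme `<session>-<python>(<variant>)`, as used by the CI job that runs `testsuite-3.10(cpu)`; the authoritative list is always the output of `nox -l`.

```python
from itertools import product

python_versions = ["3.10", "3.11", "3.12", "3.13"]
variants = ["cpu", "cupy12", "cupy13"]

# One session per combination of Python version and CuPy variant
sessions = [
    f"testsuite-{py}({variant})"
    for py, variant in product(python_versions, variants)
]

for name in sessions:
    # The quotes keep the shell from interpreting the parentheses
    print(f'nox -s "{name}"')
```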
You may also pass options through to pytest via positional arguments after a pair of dashes, e.g.:
```bash
nox -s testsuite -- -k "kernelcreation"
```
During the testsuite run, coverage information is collected using [coverage.py](https://coverage.readthedocs.io/en/7.6.10/),
and the results are exported to HTML.
You can display a detailed overview of code coverage by opening the generated `htmlcov/index.html` page.
## Extending the Test Suite
### Codegen Configurations via Fixtures
In the pystencils test suite, it is often necessary to test code generation features against multiple targets.
To simplify this process, we provide a number of [pytest fixtures](https://docs.pytest.org/en/stable/how-to/fixtures.html)
you can and should use in your tests:
- `target`: Provides code generation targets for your test.
Using this fixture will make pytest create a copy of your test for each target
available on the current machine (see {any}`Target.available_targets`).
- `gen_config`: Provides default code generation configurations for your test.
This fixture depends on `target` and provides a {any}`CreateKernelConfig` instance
with target-specific optimization options (in particular vectorization) enabled.
- `xp`: The `xp` fixture gives you either the *NumPy* (`np`) or the *CuPy* (`cp`) module,
depending on whether `target` is a CPU or GPU target.
These fixtures are defined in `tests/fixtures.py`.
### Overriding Fixtures
Pytest allows you to locally override fixtures, which can be especially handy when you wish
to restrict the target selection of a test.
For example, the following test overrides `target` using a parametrization mark,
and uses this in combination with the `gen_config` fixture, which now
receives the overridden `target` parameter as input:
```Python
@pytest.mark.parametrize("target", [Target.X86_SSE, Target.X86_AVX])
def test_bogus(gen_config):
assert gen_config.target.is_vector_cpu()
```
## Testing with the Experimental CPU JIT
Currently, the testsuite by default still uses the {any}`legacy CPU JIT compiler <LegacyCpuJit>`,
since the new CPU JIT compiler is still in an experimental stage.
To test your code against the new JIT compiler, pass the `--experimental-cpu-jit` option to pytest:
```bash
nox -s testsuite -- --experimental-cpu-jit
```
This will alter the `gen_config` fixture, activating the experimental CPU JIT for CPU targets.
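A command-line switch of this kind is typically registered through pytest's `pytest_addoption` hook in a `conftest.py`. The following is a minimal sketch under that assumption, not the actual pystencils implementation; the helper `use_experimental_cpu_jit`, the `dest` name, and the help text are hypothetical.

```python
# conftest.py (sketch)


def pytest_addoption(parser):
    """Register the switch with pytest's command-line parser."""
    parser.addoption(
        "--experimental-cpu-jit",
        dest="experimental_cpu_jit",
        action="store_true",
        default=False,
        help="Use the experimental CPU JIT for CPU targets",
    )


def use_experimental_cpu_jit(config) -> bool:
    """Hypothetical helper: query the switch from a pytest config object."""
    return bool(config.getoption("experimental_cpu_jit"))
```

A fixture such as `gen_config` could then branch on this value through `request.config` to decide which JIT compiler to configure for CPU targets.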
......@@ -37,3 +37,6 @@ ignore_missing_imports=true
[mypy-cpuinfo.*]
ignore_missing_imports=true
[mypy-fasteners.*]
ignore_missing_imports=true
......@@ -86,9 +86,15 @@ def typecheck(session: nox.Session):
session.run("mypy", "src/pystencils")
@nox.session(python=["3.10", "3.12", "3.13"], tags=["test"])
@nox.session(python=["3.10", "3.11", "3.12", "3.13"], tags=["test"])
@nox.parametrize("cupy_version", [None, "12", "13"], ids=["cpu", "cupy12", "cupy13"])
def testsuite(session: nox.Session, cupy_version: str | None):
"""Run the pystencils test suite.
**Positional Arguments:** Any positional arguments passed to nox after `--`
are propagated to pytest.
"""
if cupy_version is not None:
install_cupy(session, cupy_version, skip_if_no_cuda=True)
......@@ -108,6 +114,7 @@ def testsuite(session: nox.Session, cupy_version: str | None):
"--html",
"test-report/index.html",
"--junitxml=report.xml",
*session.posargs
)
session.run("coverage", "html")
session.run("coverage", "xml")
......
......@@ -12,7 +12,7 @@ authors = [
]
license = { file = "COPYING.txt" }
requires-python = ">=3.10"
dependencies = ["sympy>=1.9,<=1.12.1", "numpy>=1.8.0", "appdirs", "joblib", "pyyaml"]
dependencies = ["sympy>=1.9,<=1.12.1", "numpy>=1.8.0", "appdirs", "joblib", "pyyaml", "pybind11", "fasteners"]
classifiers = [
"Development Status :: 4 - Beta",
"Framework :: Jupyter",
......@@ -90,6 +90,7 @@ build-backend = "setuptools.build_meta"
[tool.setuptools.package-data]
pystencils = [
"include/*.h",
"jit/cpu/*.tmpl.cpp",
"boundaries/createindexlistcython.pyx"
]
......
......@@ -3,8 +3,6 @@ from enum import Enum
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from ...codegen import Target
from ..ast.structural import (
PsAstNode,
PsBlock,
......@@ -59,6 +57,7 @@ from ..extensions.foreign_ast import PsForeignExpression
from ..memory import PsSymbol
from ..constants import PsConstant
from ...types import PsType
from ...codegen import Target
if TYPE_CHECKING:
from ...codegen import Kernel
......@@ -171,8 +170,9 @@ class BasePrinter(ABC):
and in `IRAstPrinter` for debug-printing the entire IR.
"""
def __init__(self, indent_width=3):
def __init__(self, indent_width=3, func_prefix: str | None = None):
self._indent_width = indent_width
self._func_prefix = func_prefix
def __call__(self, obj: PsAstNode | Kernel) -> str:
from ...codegen import Kernel
......@@ -376,20 +376,18 @@ class BasePrinter(ABC):
)
def print_signature(self, func: Kernel) -> str:
prefix = self._func_prefix(func)
params_str = ", ".join(
f"{self._type_str(p.dtype)} {p.name}" for p in func.parameters
)
signature = " ".join([prefix, "void", func.name, f"({params_str})"])
return signature
def _func_prefix(self, func: Kernel):
from ...codegen import GpuKernel
sig_parts = [self._func_prefix] if self._func_prefix is not None else []
if isinstance(func, GpuKernel) and func.target == Target.CUDA:
return "__global__"
else:
return "FUNC_PREFIX"
sig_parts.append("__global__")
sig_parts += ["void", func.name, f"({params_str})"]
signature = " ".join(sig_parts)
return signature
@abstractmethod
def _symbol_decl(self, symb: PsSymbol) -> str:
......
......@@ -21,10 +21,6 @@ def emit_code(ast: PsAstNode | Kernel):
class CAstPrinter(BasePrinter):
def __init__(self, indent_width=3):
super().__init__(indent_width)
def visit(self, node: PsAstNode, pc: PrinterCtx) -> str:
match node:
case PsVecMemAcc():
......
from __future__ import annotations
from typing import TYPE_CHECKING
from warnings import warn
from abc import ABC
from collections.abc import Collection
from typing import Sequence, Generic, TypeVar, Callable, Any, cast
from typing import TYPE_CHECKING, Sequence, Generic, TypeVar, Callable, Any, cast
from dataclasses import dataclass, InitVar, fields
from .target import Target
......@@ -208,7 +207,9 @@ class Category(Generic[Category_T]):
setattr(obj, self._lookup, cat.copy() if cat is not None else None)
class _AUTO_TYPE: ... # noqa: E701
class _AUTO_TYPE:
def __repr__(self) -> str:
return "AUTO" # for pretty-printing in the docs
AUTO = _AUTO_TYPE()
......
......@@ -24,6 +24,7 @@ It is due to be replaced in the near future.
from .jit import JitBase, NoJit, KernelWrapper
from .legacy_cpu import LegacyCpuJit
from .cpu import CpuJit
from .gpu_cupy import CupyJit, CupyKernelWrapper, LaunchGrid
no_jit = NoJit()
......@@ -33,6 +34,7 @@ __all__ = [
"JitBase",
"KernelWrapper",
"LegacyCpuJit",
"CpuJit",
"NoJit",
"no_jit",
"CupyJit",
......
from .compiler_info import GccInfo, ClangInfo
from .cpujit import CpuJit
__all__ = [
"GccInfo",
"ClangInfo",
"CpuJit"
]
from __future__ import annotations
from typing import Sequence
from abc import ABC, abstractmethod
from dataclasses import dataclass
from ...codegen.target import Target
@dataclass
class CompilerInfo(ABC):
"""Base class for compiler infos."""
openmp: bool = True
"""Enable/disable OpenMP compilation"""
optlevel: str | None = "fast"
"""Compiler optimization level"""
cxx_standard: str = "c++11"
"""C++ language standard to be compiled with"""
target: Target = Target.CurrentCPU
"""Hardware target to compile for.
Here, `Target.CurrentCPU` represents the current hardware,
which is reflected by ``-march=native`` in GNU-like compilers.
"""
@abstractmethod
def cxx(self) -> str:
"""Path to the executable of this compiler"""
@abstractmethod
def cxxflags(self) -> list[str]:
"""Compiler flags affecting C++ compilation"""
@abstractmethod
def linker_flags(self) -> list[str]:
"""Flags affecting linkage of the extension module"""
@abstractmethod
def include_flags(self, include_dirs: Sequence[str]) -> list[str]:
"""Convert a list of include directories into corresponding compiler flags"""
@abstractmethod
def restrict_qualifier(self) -> str:
"""*restrict* memory qualifier recognized by this compiler"""
class _GnuLikeCliCompiler(CompilerInfo):
def cxxflags(self) -> list[str]:
flags = ["-DNDEBUG", f"-std={self.cxx_standard}", "-fPIC"]
if self.optlevel is not None:
flags.append(f"-O{self.optlevel}")
if self.openmp:
flags.append("-fopenmp")
match self.target:
case Target.CurrentCPU:
flags.append("-march=native")
case Target.X86_SSE:
flags += ["-march=x86-64-v2"]
case Target.X86_AVX:
flags += ["-march=x86-64-v3"]
case Target.X86_AVX512:
flags += ["-march=x86-64-v4"]
case Target.X86_AVX512_FP16:
flags += ["-march=x86-64-v4", "-mavx512fp16"]
return flags
def linker_flags(self) -> list[str]:
return ["-shared"]
def include_flags(self, include_dirs: Sequence[str]) -> list[str]:
return [f"-I{d}" for d in include_dirs]
def restrict_qualifier(self) -> str:
return "__restrict__"
class GccInfo(_GnuLikeCliCompiler):
"""Compiler info for the GNU Compiler Collection C++ compiler (``g++``)."""
def cxx(self) -> str:
return "g++"
@dataclass
class ClangInfo(_GnuLikeCliCompiler):
"""Compiler info for the LLVM C++ compiler (``clang``)."""
llvm_version: int | None = None
"""Major version number of the LLVM installation providing the compiler."""
def cxx(self) -> str:
if self.llvm_version is None:
return "clang"
else:
return f"clang-{self.llvm_version}"
def linker_flags(self) -> list[str]:
return super().linker_flags() + ["-lstdc++"]
from __future__ import annotations
from types import ModuleType
from pathlib import Path
import subprocess
from copy import copy
from abc import ABC, abstractmethod
from ...codegen.config import _AUTO_TYPE, AUTO
from ..jit import JitError, JitBase, KernelWrapper
from ...codegen import Kernel
from .compiler_info import CompilerInfo, GccInfo
class CpuJit(JitBase):
"""Just-in-time compiler for CPU kernels.
**Creation**
To configure and create a CPU JIT compiler instance, use the `create` factory method.
**Implementation Details**
The `CpuJit` class acts as an orchestrator between two components:
- The *extension module builder* produces the code of the dynamically built extension module
that contains the kernel and its invocation wrappers;
- The *compiler info* describes the host compiler used to compile and link that extension module.
Args:
compiler_info: The compiler info object defining the capabilities
and command-line interface of the host compiler
ext_module_builder: Extension module builder object used to generate the kernel extension module
objcache: Directory to cache the generated code files and compiled modules in.
If `None`, a temporary directory will be used, and compilation results will not be cached.
"""
@staticmethod
def create(
compiler_info: CompilerInfo | None = None,
objcache: str | Path | _AUTO_TYPE | None = AUTO,
) -> CpuJit:
"""Configure and create a CPU JIT compiler object.
Args:
compiler_info: Compiler info object defining capabilities and interface of the host compiler.
If `None`, a default compiler configuration will be determined from the current OS and runtime
environment.
objcache: Directory used for caching compilation results.
If set to `AUTO`, a persistent cache directory in the current user's home will be used.
If set to `None`, compilation results will not be cached--this may impact performance.
Returns:
The CPU just-in-time compiler.
"""
if objcache is AUTO:
from appdirs import AppDirs
dirs = AppDirs(appname="pystencils")
objcache = Path(dirs.user_cache_dir) / "cpujit"
elif objcache is not None:
assert not isinstance(objcache, _AUTO_TYPE)
objcache = Path(objcache)
if compiler_info is None:
compiler_info = GccInfo()
from .cpujit_pybind11 import Pybind11KernelModuleBuilder
modbuilder = Pybind11KernelModuleBuilder(compiler_info)
return CpuJit(compiler_info, modbuilder, objcache)
def __init__(
self,
compiler_info: CompilerInfo,
ext_module_builder: ExtensionModuleBuilderBase,
objcache: Path | None,
):
self._compiler_info = copy(compiler_info)
self._ext_module_builder = ext_module_builder
self._objcache = objcache
# Include Directories
import sysconfig
from ...include import get_pystencils_include_path
include_dirs = [
sysconfig.get_path("include"),
get_pystencils_include_path(),
] + self._ext_module_builder.include_dirs()
# Compiler Flags
self._cxx = self._compiler_info.cxx()
self._cxx_fixed_flags = (
self._compiler_info.cxxflags()
+ self._compiler_info.include_flags(include_dirs)
+ self._compiler_info.linker_flags()
)
def compile(self, kernel: Kernel) -> KernelWrapper:
"""Compile the given kernel to an executable function.
Args:
kernel: The kernel object to be compiled.
Returns:
Wrapper object around the compiled function
"""
# Get the Code
module_name = f"{kernel.name}_jit"
cpp_code = self._ext_module_builder.render_module(kernel, module_name)
# Get compiler information
import sysconfig
so_abi = sysconfig.get_config_var("SOABI")
lib_suffix = f"{so_abi}.so"
# Compute Code Hash
code_utf8: bytes = cpp_code.encode("utf-8")
compiler_utf8: bytes = (" ".join([self._cxx] + self._cxx_fixed_flags)).encode("utf-8")
import hashlib
module_hash = hashlib.sha256(code_utf8 + compiler_utf8)
module_stem = f"module_{module_hash.hexdigest()}"
def compile_and_load(module_dir: Path):
cpp_file = module_dir / f"{module_stem}.cpp"
if not cpp_file.exists():
cpp_file.write_bytes(code_utf8)
lib_file = module_dir / f"{module_stem}.{lib_suffix}"
if not lib_file.exists():
self._compile_extension_module(cpp_file, lib_file)
module = self._load_extension_module(module_name, lib_file)
return module
if self._objcache is not None:
module_dir = self._objcache
# Lock module
import fasteners
lockfile = module_dir / f"{module_stem}.lock"
with fasteners.InterProcessLock(lockfile):
module = compile_and_load(module_dir)
else:
from tempfile import TemporaryDirectory
with TemporaryDirectory() as tmpdir:
module_dir = Path(tmpdir)
module = compile_and_load(module_dir)
return self._ext_module_builder.get_wrapper(kernel, module)
def _compile_extension_module(self, src_file: Path, libfile: Path):
args = (
[self._cxx]
+ self._cxx_fixed_flags
+ ["-o", str(libfile), str(src_file)]
)
result = subprocess.run(args, capture_output=True)
if result.returncode != 0:
raise JitError(
"Compilation failed: C++ compiler terminated with an error.\n"
+ result.stderr.decode()
)
def _load_extension_module(self, module_name: str, module_loc: Path) -> ModuleType:
from importlib import util as iutil
spec = iutil.spec_from_file_location(name=module_name, location=module_loc)
if spec is None:
raise JitError(
"Unable to load kernel extension module -- this is probably a bug."
)
mod = iutil.module_from_spec(spec)
spec.loader.exec_module(mod) # type: ignore
return mod
class ExtensionModuleBuilderBase(ABC):
"""Base class for CPU extension module builders."""
@staticmethod
@abstractmethod
def include_dirs() -> list[str]:
"""List of directories that must be on the include path when compiling
generated extension modules."""
@abstractmethod
def render_module(self, kernel: Kernel, module_name: str) -> str:
"""Produce the extension module code for the given kernel."""
@abstractmethod
def get_wrapper(
self, kernel: Kernel, extension_module: ModuleType
) -> KernelWrapper:
"""Produce the invocation wrapper for the given kernel
and its compiled extension module."""
from __future__ import annotations
from types import ModuleType
from typing import Sequence, cast
from pathlib import Path
from textwrap import indent
from pystencils.jit.jit import KernelWrapper
from ...types import PsPointerType, PsType
from ...field import Field
from ...sympyextensions import DynamicType
from ...codegen import Kernel, Parameter
from ...codegen.properties import FieldBasePtr, FieldShape, FieldStride
from .compiler_info import CompilerInfo
from .cpujit import ExtensionModuleBuilderBase
_module_template = Path(__file__).parent / "pybind11_kernel_module.tmpl.cpp"
class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
@staticmethod
def include_dirs() -> list[str]:
import pybind11 as pb11
pybind11_include = pb11.get_include()
return [pybind11_include]
def __init__(
self,
compiler_info: CompilerInfo,
):
self._compiler_info = compiler_info
self._actual_field_types: dict[Field, PsType]
self._param_binds: list[str]
self._public_params: list[str]
self._param_check_lines: list[str]
self._extraction_lines: list[str]
def render_module(self, kernel: Kernel, module_name: str) -> str:
self._actual_field_types = dict()
self._param_binds = []
self._public_params = []
self._param_check_lines = []
self._extraction_lines = []
self._handle_params(kernel.parameters)
kernel_def = self._get_kernel_definition(kernel)
kernel_args = [param.name for param in kernel.parameters]
includes = [f"#include {h}" for h in sorted(kernel.required_headers)]
from string import Template
templ = Template(_module_template.read_text())
code_str = templ.substitute(
includes="\n".join(includes),
restrict_qualifier=self._compiler_info.restrict_qualifier(),
module_name=module_name,
kernel_name=kernel.name,
param_binds=", ".join(self._param_binds),
public_params=", ".join(self._public_params),
param_check_lines=indent("\n".join(self._param_check_lines), prefix=" "),
extraction_lines=indent("\n".join(self._extraction_lines), prefix=" "),
kernel_args=", ".join(kernel_args),
kernel_definition=kernel_def,
)
return code_str
def get_wrapper(self, kernel: Kernel, extension_module: ModuleType) -> KernelWrapper:
return Pybind11KernelWrapper(kernel, extension_module)
def _get_kernel_definition(self, kernel: Kernel) -> str:
from ...backend.emission import CAstPrinter
printer = CAstPrinter()
return printer(kernel)
def _add_field_param(self, ptr_param: Parameter):
field: Field = ptr_param.fields[0]
ptr_type = ptr_param.dtype
assert isinstance(ptr_type, PsPointerType)
if isinstance(field.dtype, DynamicType):
elem_type = ptr_type.base_type
else:
elem_type = field.dtype
self._actual_field_types[field] = elem_type
param_bind = f'py::arg("{field.name}").noconvert()'
self._param_binds.append(param_bind)
kernel_param = f"py::array_t< {elem_type.c_string()} > & {field.name}"
self._public_params.append(kernel_param)
expect_shape = "(" + ", ".join((str(s) if isinstance(s, int) else "*") for s in field.shape) + ")"
for coord, size in enumerate(field.shape):
if isinstance(size, int):
self._param_check_lines.append(
f"checkFieldShape(\"{field.name}\", \"{expect_shape}\", {field.name}, {coord}, {size});"
)
expect_strides = "(" + ", ".join((str(s) if isinstance(s, int) else "*") for s in field.strides) + ")"
for coord, stride in enumerate(field.strides):
if isinstance(stride, int):
self._param_check_lines.append(
f"checkFieldStride(\"{field.name}\", \"{expect_strides}\", {field.name}, {coord}, {stride});"
)
def _add_scalar_param(self, sc_param: Parameter):
param_bind = f'py::arg("{sc_param.name}")'
self._param_binds.append(param_bind)
kernel_param = f"{sc_param.dtype.c_string()} {sc_param.name}"
self._public_params.append(kernel_param)
def _extract_base_ptr(self, ptr_param: Parameter, ptr_prop: FieldBasePtr):
field_name = ptr_prop.field.name
assert isinstance(ptr_param.dtype, PsPointerType)
data_method = "data()" if ptr_param.dtype.base_type.const else "mutable_data()"
extraction = f"{ptr_param.dtype.c_string()} {ptr_param.name} {{ {field_name}.{data_method} }};"
self._extraction_lines.append(extraction)
def _extract_shape(self, shape_param: Parameter, shape_prop: FieldShape):
field_name = shape_prop.field.name
coord = shape_prop.coordinate
extraction = f"{shape_param.dtype.c_string()} {shape_param.name} {{ {field_name}.shape({coord}) }};"
self._extraction_lines.append(extraction)
def _extract_stride(self, stride_param: Parameter, stride_prop: FieldStride):
field = stride_prop.field
field_name = field.name
coord = stride_prop.coordinate
field_type = self._actual_field_types[field]
assert field_type.itemsize is not None
extraction = (
f"{stride_param.dtype.c_string()} {stride_param.name} "
f"{{ {field_name}.strides({coord}) / {field_type.itemsize} }};"
)
self._extraction_lines.append(extraction)
def _handle_params(self, parameters: Sequence[Parameter]):
for param in parameters:
if param.get_properties(FieldBasePtr):
self._add_field_param(param)
for param in parameters:
if ptr_props := param.get_properties(FieldBasePtr):
self._extract_base_ptr(param, cast(FieldBasePtr, ptr_props.pop()))
elif shape_props := param.get_properties(FieldShape):
self._extract_shape(param, cast(FieldShape, shape_props.pop()))
elif stride_props := param.get_properties(FieldStride):
self._extract_stride(param, cast(FieldStride, stride_props.pop()))
else:
self._add_scalar_param(param)
class Pybind11KernelWrapper(KernelWrapper):
def __init__(self, kernel: Kernel, jit_module: ModuleType):
super().__init__(kernel)
self._module = jit_module
self._check_params = getattr(jit_module, "check_params")
self._invoke = getattr(jit_module, "invoke")
def __call__(self, **kwargs) -> None:
self._check_params(**kwargs)
return self._invoke(**kwargs)
#include "pybind11/pybind11.h"
#include "pybind11/numpy.h"
#include <array>
#include <string>
#include <sstream>
${includes}
namespace py = pybind11;
#define RESTRICT ${restrict_qualifier}
namespace internal {
${kernel_definition}
}
std::string tuple_to_str(const ssize_t * data, const size_t N){
std::stringstream acc;
acc << "(";
for(size_t i = 0; i < N; ++i){
acc << data[i];
if(i + 1 < N){
acc << ", ";
}
}
acc << ")";
return acc.str();
}
template< typename T >
void checkFieldShape(const std::string& fieldName, const std::string& expected, const py::array_t< T > & arr, size_t coord, size_t desired) {
auto panic = [&](){
std::stringstream err;
err << "Invalid shape of argument " << fieldName
<< ". Expected " << expected
<< ", but got " << tuple_to_str(arr.shape(), arr.ndim())
<< ".";
throw py::value_error{ err.str() };
};
if(arr.ndim() <= coord){
panic();
}
if(arr.shape(coord) != desired){
panic();
}
}
template< typename T >
void checkFieldStride(const std::string& fieldName, const std::string& expected, const py::array_t< T > & arr, size_t coord, size_t desired) {
auto panic = [&](){
std::stringstream err;
err << "Invalid strides of argument " << fieldName
<< ". Expected " << expected
<< ", but got " << tuple_to_str(arr.strides(), arr.ndim())
<< ".";
throw py::value_error{ err.str() };
};
if(arr.ndim() <= coord){
panic();
}
if(arr.strides(coord) / sizeof(T) != desired){
panic();
}
}
void check_params_${kernel_name} (${public_params}) {
${param_check_lines}
}
void run_${kernel_name} (${public_params}) {
${extraction_lines}
internal::${kernel_name}(${kernel_args});
}
PYBIND11_MODULE(${module_name}, m) {
m.def("check_params", &check_params_${kernel_name}, py::kw_only(), ${param_binds});
m.def("invoke", &run_${kernel_name}, py::kw_only(), ${param_binds});
}
......@@ -91,12 +91,14 @@ class PsKernelExtensioNModule:
code += "\n"
# Kernels and call wrappers
from ..backend.emission import CAstPrinter
printer = CAstPrinter(func_prefix="FUNC_PREFIX")
for name, kernel in self._kernels.items():
old_name = kernel.name
kernel.name = f"kernel_{name}"
code += kernel.get_c_code()
code += printer(kernel)
code += "\n"
code += emit_call_wrapper(name, kernel)
code += "\n"
......