Skip to content
Snippets Groups Projects
Commit 1fadc0dd authored by Frederik Hennig's avatar Frederik Hennig
Browse files

basic JIT integration

parent 45a51fe6
No related branches found
No related tags found
No related merge requests found
Pipeline #61702 failed
...@@ -16,4 +16,4 @@ all code generation functionality currently implemented in *pystencils* version ...@@ -16,4 +16,4 @@ all code generation functionality currently implemented in *pystencils* version
arrays arrays
ast ast
kernelcreation kernelcreation
jit
************************
Just-In-Time Compilation
************************
.. automodule:: pystencils.nbackend.jit
    :members:
...@@ -69,9 +69,6 @@ from pystencils.kernel_wrapper import KernelWrapper ...@@ -69,9 +69,6 @@ from pystencils.kernel_wrapper import KernelWrapper
from pystencils.typing import BasicType, CastFunc, VectorType, VectorMemoryAccess from pystencils.typing import BasicType, CastFunc, VectorType, VectorMemoryAccess
from pystencils.utils import atomic_file_write, recursive_dict_update from pystencils.utils import atomic_file_write, recursive_dict_update
from ..nbackend.ast import PsKernelFunction
from ..nbackend.jit.cpu_extension_module import PsKernelExtensioNModule
def make_python_function(kernel_function_node, custom_backend=None): def make_python_function(kernel_function_node, custom_backend=None):
""" """
...@@ -622,7 +619,9 @@ def compile_and_load(ast, custom_backend=None): ...@@ -622,7 +619,9 @@ def compile_and_load(ast, custom_backend=None):
compiler_config = get_compiler_config() compiler_config = get_compiler_config()
function_prefix = '__declspec(dllexport)' if compiler_config['os'].lower() == 'windows' else '' function_prefix = '__declspec(dllexport)' if compiler_config['os'].lower() == 'windows' else ''
from ..nbackend.ast import PsKernelFunction
if isinstance(ast, PsKernelFunction): if isinstance(ast, PsKernelFunction):
from ..nbackend.jit.cpu_extension_module import PsKernelExtensioNModule
code = PsKernelExtensioNModule() code = PsKernelExtensioNModule()
else: else:
code = ExtensionModuleCode(custom_backend=custom_backend) code = ExtensionModuleCode(custom_backend=custom_backend)
......
from __future__ import annotations from __future__ import annotations
from typing import Callable
from dataclasses import dataclass from dataclasses import dataclass
from pymbolic.mapper.dependency import DependencyMapper from pymbolic.mapper.dependency import DependencyMapper
from .nodes import PsAstNode, PsBlock, failing_cast from .nodes import PsAstNode, PsBlock, failing_cast
from ..constraints import PsKernelConstraint from ..constraints import PsKernelConstraint
from ..typed_expressions import PsTypedVariable from ..typed_expressions import PsTypedVariable
from ..arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAssocVar from ..arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAssocVar
from ..jit import JitBase, no_jit
from ..exceptions import PsInternalCompilerError from ..exceptions import PsInternalCompilerError
from ...enums import Target from ...enums import Target
...@@ -64,10 +68,11 @@ class PsKernelFunction(PsAstNode): ...@@ -64,10 +68,11 @@ class PsKernelFunction(PsAstNode):
__match_args__ = ("body",) __match_args__ = ("body",)
def __init__(self, body: PsBlock, target: Target, name: str = "kernel"): def __init__(self, body: PsBlock, target: Target, name: str = "kernel", jit: JitBase = no_jit):
self._body: PsBlock = body self._body: PsBlock = body
self._target = target self._target = target
self._name = name self._name = name
self._jit = jit
self._constraints: list[PsKernelConstraint] = [] self._constraints: list[PsKernelConstraint] = []
...@@ -133,3 +138,6 @@ class PsKernelFunction(PsAstNode): ...@@ -133,3 +138,6 @@ class PsKernelFunction(PsAstNode):
# To Do: Headers from target/instruction set/... # To Do: Headers from target/instruction set/...
from .collectors import collect_required_headers from .collectors import collect_required_headers
return collect_required_headers(self) return collect_required_headers(self)
def compile(self) -> Callable[..., None]:
    """Compile this kernel function using the JIT compiler attached to it.

    The call is forwarded to the ``compile`` method of the object stored in
    ``self._jit``, with this kernel function as its only argument.

    Returns:
        Whatever the attached JIT compiler's ``compile`` method returns;
        according to the annotation, a callable that invokes the kernel.
    """
    return self._jit.compile(self)
"""
JIT compilation in the ``nbackend`` is managed by subclasses of `JitBase`.
A JIT compiler may freely be created and configured by the user.
It can then be passed to `create_kernel` using the ``jit`` argument of
`CreateKernelConfig`, in which case it is hooked into the `PsKernelFunction.compile` method
of the generated kernel function::

    my_jit = MyJit()
    kernel = create_kernel(assignments, CreateKernelConfig(jit=my_jit))
    func = kernel.compile()

Otherwise, a JIT compiler may also be created free-standing, with the same effect::

    my_jit = MyJit()
    kernel = create_kernel(assignments)
    func = my_jit.compile(kernel)

Currently, only wrappers around the legacy JIT compilers are available.
Attempting to compile with `NoJit` raises `JitError`.

Legacy Just-In-Time Compilation
-------------------------------

Historically, pystencils provides two main pathways for just-in-time compilation:
The ``cpu.cpujit`` module for CPU kernels, and the ``gpu.gpujit`` module for device kernels.
Both are available here through `LegacyCpuJit` and `LegacyGpuJit`.
"""

from .jit import JitBase, JitError, NoJit, LegacyCpuJit, LegacyGpuJit

# Ready-made module-level instances.
# `no_jit` is the default `jit` argument of `PsKernelFunction`;
# `legacy_cpu` is what `CreateKernelConfig` selects for `Target.CPU` when no JIT is given.
no_jit = NoJit()
legacy_cpu = LegacyCpuJit()
legacy_gpu = LegacyGpuJit()

__all__ = [
    "JitBase",
    # Exported so that callers can catch the error raised by `NoJit.compile`.
    "JitError",
    "LegacyCpuJit",
    "legacy_cpu",
    "NoJit",
    "no_jit",
    "LegacyGpuJit",
    "legacy_gpu",
]
from __future__ import annotations
from typing import Callable, TYPE_CHECKING
from abc import ABC, abstractmethod
if TYPE_CHECKING:
from ..ast import PsKernelFunction
class JitError(Exception):
    """Indicates an error during just-in-time compilation.

    Raised by `NoJit.compile` when compilation is requested for a kernel
    whose JIT compilation was explicitly disabled.
    """
class JitBase(ABC):
    """Base class for just-in-time compilation interfaces implemented in pystencils.

    Subclasses must override `compile`; the class cannot be instantiated
    while that method is still abstract.
    """

    @abstractmethod
    def compile(self, kernel: PsKernelFunction) -> Callable[..., None]:
        """Compile a kernel function and return a callable object which invokes the kernel.

        Args:
            kernel: The kernel function to compile.

        Returns:
            A callable object which invokes the compiled kernel.
        """
class NoJit(JitBase):
    """Not a JIT compiler: Used to explicitly disable JIT compilation on an AST."""

    def compile(self, kernel: PsKernelFunction) -> Callable[..., None]:
        """Refuse to compile: always raises.

        Args:
            kernel: The kernel function whose compilation was requested; it is not used.

        Raises:
            JitError: Always, since JIT compilation was explicitly disabled.
        """
        raise JitError(
            "Just-in-time compilation of this kernel was explicitly disabled."
        )
class LegacyCpuJit(JitBase):
    """Wrapper around ``pystencils.cpu.cpujit``"""

    def compile(self, kernel: PsKernelFunction) -> Callable[..., None]:
        """Compile the kernel through the legacy CPU JIT.

        Args:
            kernel: The kernel function to compile.

        Returns:
            The result of ``cpujit.compile_and_load`` for the given kernel.
        """
        # Imported at call time rather than at module import time.
        from ...cpu.cpujit import compile_and_load

        return compile_and_load(kernel)
class LegacyGpuJit(JitBase):
    """Wrapper around ``pystencils.gpu.gpujit``"""

    def compile(self, kernel: PsKernelFunction) -> Callable[..., None]:
        """Compile the kernel through the legacy GPU JIT.

        Args:
            kernel: The kernel function to compile.

        Returns:
            The result of ``gpujit.make_python_function`` for the given kernel.
        """
        # Imported at call time rather than at module import time.
        from ...gpu.gpujit import make_python_function

        return make_python_function(kernel)
...@@ -4,6 +4,7 @@ from dataclasses import dataclass ...@@ -4,6 +4,7 @@ from dataclasses import dataclass
from ...enums import Target from ...enums import Target
from ...field import Field, FieldType from ...field import Field, FieldType
from ..jit import JitBase
from ..exceptions import PsOptionsError from ..exceptions import PsOptionsError
from ..types import PsIntegerType, PsNumericType, PsIeeeFloatType from ..types import PsIntegerType, PsNumericType, PsIeeeFloatType
...@@ -20,6 +21,13 @@ class CreateKernelConfig: ...@@ -20,6 +21,13 @@ class CreateKernelConfig:
TODO: Enhance `Target` from enum to a larger target spec, e.g. including vectorization architecture, ... TODO: Enhance `Target` from enum to a larger target spec, e.g. including vectorization architecture, ...
""" """
jit: JitBase | None = None
"""Just-in-time compiler used to compile and load the kernel for invocation from the current Python environment.
If left at `None`, a default just-in-time compiler will be inferred from the `target` parameter.
To explicitly disable JIT compilation, pass `nbackend.jit.no_jit`.
"""
function_name: str = "kernel" function_name: str = "kernel"
"""Name of the generated function""" """Name of the generated function"""
...@@ -63,6 +71,7 @@ class CreateKernelConfig: ...@@ -63,6 +71,7 @@ class CreateKernelConfig:
""" """
def __post_init__(self): def __post_init__(self):
# Check iteration space argument consistency
if ( if (
int(self.iteration_slice is not None) int(self.iteration_slice is not None)
+ int(self.ghost_layers is not None) + int(self.ghost_layers is not None)
...@@ -74,6 +83,7 @@ class CreateKernelConfig: ...@@ -74,6 +83,7 @@ class CreateKernelConfig:
"at most one of them may be set." "at most one of them may be set."
) )
# Check index field
if ( if (
self.index_field is not None self.index_field is not None
and self.index_field.field_type != FieldType.INDEXED and self.index_field.field_type != FieldType.INDEXED
...@@ -81,3 +91,12 @@ class CreateKernelConfig: ...@@ -81,3 +91,12 @@ class CreateKernelConfig:
raise PsOptionsError( raise PsOptionsError(
"Only fields with `field_type == FieldType.INDEXED` can be specified as `index_field`" "Only fields with `field_type == FieldType.INDEXED` can be specified as `index_field`"
) )
# Infer JIT
if self.jit is None:
match self.target:
case Target.CPU:
from ..jit import legacy_cpu
self.jit = legacy_cpu
case _:
raise NotImplementedError(f"No default JIT compiler implemented yet for target {self.target}")
...@@ -19,6 +19,7 @@ def create_kernel( ...@@ -19,6 +19,7 @@ def create_kernel(
assignments: AssignmentCollection, assignments: AssignmentCollection,
config: CreateKernelConfig = CreateKernelConfig(), config: CreateKernelConfig = CreateKernelConfig(),
): ):
"""Create a kernel AST from an assignment collection."""
ctx = KernelCreationContext(config) ctx = KernelCreationContext(config)
analysis = KernelAnalysis(ctx) analysis = KernelAnalysis(ctx)
...@@ -57,7 +58,8 @@ def create_kernel( ...@@ -57,7 +58,8 @@ def create_kernel(
# - Loop Splitting, Tiling, Blocking # - Loop Splitting, Tiling, Blocking
kernel_ast = platform.optimize(kernel_ast) kernel_ast = platform.optimize(kernel_ast)
function = PsKernelFunction(kernel_ast, config.target, name=config.function_name) assert config.jit is not None
function = PsKernelFunction(kernel_ast, config.target, name=config.function_name, jit=config.jit)
function.add_constraints(*ctx.constraints) function.add_constraints(*ctx.constraints)
return function return function
...@@ -7,7 +7,6 @@ from pystencils import fields, Field, AssignmentCollection ...@@ -7,7 +7,6 @@ from pystencils import fields, Field, AssignmentCollection
from pystencils.assignment import assignment_from_stencil from pystencils.assignment import assignment_from_stencil
from pystencils.nbackend.kernelcreation import create_kernel from pystencils.nbackend.kernelcreation import create_kernel
from pystencils.cpu.cpujit import compile_and_load
def test_filter_kernel(): def test_filter_kernel():
weight = sp.Symbol("weight") weight = sp.Symbol("weight")
...@@ -22,8 +21,7 @@ def test_filter_kernel(): ...@@ -22,8 +21,7 @@ def test_filter_kernel():
asms = AssignmentCollection([asm]) asms = AssignmentCollection([asm])
ast = create_kernel(asms) ast = create_kernel(asms)
kernel = ast.compile()
kernel = compile_and_load(ast)
src_arr = np.ones((42, 42)) src_arr = np.ones((42, 42))
dst_arr = np.zeros_like(src_arr) dst_arr = np.zeros_like(src_arr)
...@@ -54,7 +52,7 @@ def test_filter_kernel_fixedsize(): ...@@ -54,7 +52,7 @@ def test_filter_kernel_fixedsize():
asms = AssignmentCollection([asm]) asms = AssignmentCollection([asm])
ast = create_kernel(asms) ast = create_kernel(asms)
kernel = compile_and_load(ast) kernel = ast.compile()
kernel(src=src_arr, dst=dst_arr, weight=2.0) kernel(src=src_arr, dst=dst_arr, weight=2.0)
......
...@@ -5,7 +5,6 @@ import numpy as np ...@@ -5,7 +5,6 @@ import numpy as np
from pystencils import Assignment, Field, FieldType, AssignmentCollection from pystencils import Assignment, Field, FieldType, AssignmentCollection
from pystencils.nbackend.kernelcreation import create_kernel, CreateKernelConfig from pystencils.nbackend.kernelcreation import create_kernel, CreateKernelConfig
from pystencils.cpu.cpujit import compile_and_load
def test_indexed_kernel(): def test_indexed_kernel():
arr = np.zeros((3, 4)) arr = np.zeros((3, 4))
...@@ -23,7 +22,7 @@ def test_indexed_kernel(): ...@@ -23,7 +22,7 @@ def test_indexed_kernel():
options = CreateKernelConfig(index_field=index_field) options = CreateKernelConfig(index_field=index_field)
ast = create_kernel(update_rule, options) ast = create_kernel(update_rule, options)
kernel = compile_and_load(ast) kernel = ast.compile()
kernel(f=arr, index=index_arr) kernel(f=arr, index=index_arr)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment