From 4b11cdd047c51d3d342c7873976eca8b42f66998 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 4 Dec 2024 15:16:36 +0100 Subject: [PATCH] code movements complete, all tests run --- src/pystencils/__init__.py | 2 +- src/pystencils/backend/constraints.py | 22 --- .../backend/emission/base_printer.py | 5 +- src/pystencils/backend/emission/c_printer.py | 11 +- src/pystencils/backend/emission/ir_printer.py | 8 +- .../backend/kernelcreation/context.py | 11 -- .../backend/kernelcreation/iteration_space.py | 2 +- src/pystencils/backend/kernelfunction.py | 113 ---------------- src/pystencils/backend/platforms/__init__.py | 3 +- src/pystencils/backend/platforms/cuda.py | 12 +- .../backend/platforms/generic_gpu.py | 68 +++------- src/pystencils/backend/platforms/sycl.py | 12 +- .../backend/transformations/add_pragmas.py | 1 + .../transformations/canonicalize_symbols.py | 1 - src/pystencils/codegen/__init__.py | 9 +- src/pystencils/codegen/config.py | 13 +- src/pystencils/codegen/driver.py | 125 +++++++++++++++++- src/pystencils/codegen/gpu.py | 28 ---- src/pystencils/codegen/kernel.py | 93 +++++++++++-- src/pystencils/codegen/parameters.py | 17 +-- src/pystencils/display_utils.py | 30 ++--- src/pystencils/inspection.py | 18 +-- src/pystencils/jit/__init__.py | 2 +- src/pystencils/jit/cpu_extension_module.py | 71 +++++----- src/pystencils/jit/gpu_cupy.py | 41 +++--- src/pystencils/jit/jit.py | 15 +-- src/pystencils/jit/legacy_cpu.py | 8 +- src/pystencils/kernel_wrapper.py | 2 +- src/pystencils/kernelcreation.py | 6 +- tests/kernelcreation/test_domain_kernels.py | 6 +- tests/kernelcreation/test_index_kernels.py | 2 +- tests/kernelcreation/test_iteration_slices.py | 2 +- tests/nbackend/test_code_printing.py | 9 +- tests/nbackend/test_cpujit.py | 7 +- tests/nbackend/test_vectorization.py | 4 +- 35 files changed, 376 insertions(+), 403 deletions(-) delete mode 100644 src/pystencils/backend/constraints.py delete mode 100644 src/pystencils/backend/kernelfunction.py delete mode 100644 src/pystencils/codegen/gpu.py diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py index 4374ccda4..028e4b885 100644 --- a/src/pystencils/__init__.py +++ b/src/pystencils/__init__.py @@ -19,7 +19,7 @@ from .cache import clear_cache from .kernel_decorator import kernel, kernel_config from .kernelcreation import create_kernel, create_staggered_kernel from .codegen import Kernel -from .backend.jit import no_jit +from .jit import no_jit from .backend.exceptions import KernelConstraintsError from .slicing import make_slice from .spatial_coordinates import ( diff --git a/src/pystencils/backend/constraints.py b/src/pystencils/backend/constraints.py deleted file mode 100644 index 229f6718c..000000000 --- a/src/pystencils/backend/constraints.py +++ /dev/null @@ -1,22 +0,0 @@ -from __future__ import annotations - -from typing import Any, TYPE_CHECKING -from dataclasses import dataclass - -if TYPE_CHECKING: - from .kernelfunction import KernelParameter - - -@dataclass -class KernelParamsConstraint: - condition: Any # FIXME Implement conditions - message: str = "" - - def to_code(self): - raise NotImplementedError() - - def get_parameters(self) -> set[KernelParameter]: - raise NotImplementedError() - - def __str__(self) -> str: - return f"{self.message} [{self.condition}]" diff --git a/src/pystencils/backend/emission/base_printer.py b/src/pystencils/backend/emission/base_printer.py index d721b9f89..a4358bbf3 100644 --- a/src/pystencils/backend/emission/base_printer.py +++ b/src/pystencils/backend/emission/base_printer.py @@ -61,7 +61,7 @@ from ..constants import PsConstant from ...types import PsType if TYPE_CHECKING: - from ...codegen import Kernel, GpuKernel + from ...codegen import Kernel class EmissionError(Exception): @@ -175,6 +175,7 @@ class BasePrinter(ABC): self._indent_width = indent_width def __call__(self, obj: PsAstNode | Kernel) -> str: + from ...codegen import Kernel if isinstance(obj, Kernel): sig = self.print_signature(obj) body_code = self.visit(obj.body, PrinterCtx()) @@ -383,6 +384,8 @@ class BasePrinter(ABC): return signature def _func_prefix(self, func: Kernel): + from ...codegen import GpuKernel + if isinstance(func, GpuKernel) and func.target == Target.CUDA: return "__global__" else: diff --git a/src/pystencils/backend/emission/c_printer.py b/src/pystencils/backend/emission/c_printer.py index 95e27bd66..90a7e54e2 100644 --- a/src/pystencils/backend/emission/c_printer.py +++ b/src/pystencils/backend/emission/c_printer.py @@ -1,18 +1,23 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + from pystencils.backend.ast.astnode import PsAstNode from pystencils.backend.constants import PsConstant from pystencils.backend.emission.base_printer import PrinterCtx, EmissionError from pystencils.backend.memory import PsSymbol from .base_printer import BasePrinter -from ..kernelfunction import KernelFunction from ...types import PsType, PsArrayType, PsScalarType, PsTypeError from ..ast.expressions import PsBufferAcc from ..ast.vector import PsVecMemAcc +if TYPE_CHECKING: + from ...codegen import Kernel + -def emit_code(kernel: KernelFunction): +def emit_code(ast: PsAstNode | Kernel): printer = CAstPrinter() - return printer(kernel) + return printer(ast) class CAstPrinter(BasePrinter): diff --git a/src/pystencils/backend/emission/ir_printer.py b/src/pystencils/backend/emission/ir_printer.py index 124ce200d..ffb65181c 100644 --- a/src/pystencils/backend/emission/ir_printer.py +++ b/src/pystencils/backend/emission/ir_printer.py @@ -1,3 +1,6 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + from pystencils.backend.constants import PsConstant from pystencils.backend.emission.base_printer import PrinterCtx from pystencils.backend.memory import PsSymbol @@ -9,8 +12,11 @@ from ..ast import PsAstNode from ..ast.expressions import PsBufferAcc from ..ast.vector import PsVecMemAcc, PsVecBroadcast +if TYPE_CHECKING: + from ...codegen import Kernel + -def emit_ir(ir: PsAstNode): +def emit_ir(ir: PsAstNode | Kernel): """Emit the IR as C-like pseudo-code for inspection.""" ir_printer = IRAstPrinter() return ir_printer(ir) diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index bb7bd708d..39fb8ef6d 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -18,7 +18,6 @@ from ...types import ( PsPointerType, deconstify, ) -from ..constraints import KernelParamsConstraint from ..exceptions import PsInternalCompilerError, KernelConstraintsError from .iteration_space import IterationSpace, FullIterationSpace, SparseIterationSpace @@ -81,7 +80,6 @@ class KernelCreationContext: self._ispace: IterationSpace | None = None - self._constraints: list[KernelParamsConstraint] = [] self._req_headers: set[str] = set() self._metadata: dict[str, Any] = dict() @@ -96,15 +94,6 @@ class KernelCreationContext: """Data type used by default for index expressions""" return self._index_dtype - # Constraints - - def add_constraints(self, *constraints: KernelParamsConstraint): - self._constraints += constraints - - @property - def constraints(self) -> tuple[KernelParamsConstraint, ...]: - return tuple(self._constraints) - @property def metadata(self) -> dict[str, Any]: return self._metadata diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py index a7802c931..031a0d843 100644 --- a/src/pystencils/backend/kernelcreation/iteration_space.py +++ b/src/pystencils/backend/kernelcreation/iteration_space.py @@ -457,7 +457,7 @@ def create_full_iteration_space( # Otherwise, if an iteration slice was specified, use that # Otherwise, use the inferred ghost layers - from ...codegen.config import AUTO + from ...codegen.config import AUTO, _AUTO_TYPE if ghost_layers is AUTO: if len(domain_field_accesses) > 0: diff --git a/src/pystencils/backend/kernelfunction.py b/src/pystencils/backend/kernelfunction.py deleted file mode 100644 index 3c7e103b3..000000000 --- a/src/pystencils/backend/kernelfunction.py +++ /dev/null @@ -1,113 +0,0 @@ -from __future__ import annotations - -from warnings import warn -from typing import Callable, Sequence, Iterable, Any, TYPE_CHECKING -from itertools import chain - -from .._deprecation import _deprecated - -from .ast.structural import PsBlock -from .ast.analysis import collect_required_headers, collect_undefined_symbols -from .memory import PsSymbol -from ..codegen.properties import ( - PsSymbolProperty, - _FieldProperty, - FieldShape, - FieldStride, - FieldBasePtr, -) -from .kernelcreation.context import KernelCreationContext -from .platforms import Platform, GpuThreadsRange - -from .constraints import KernelParamsConstraint -from ..types import PsType - -from ..codegen.target import Target -from ..field import Field -from ..sympyextensions import TypedSymbol - -if TYPE_CHECKING: - from .jit import JitBase - - - - -def create_cpu_kernel_function( - ctx: KernelCreationContext, - platform: Platform, - body: PsBlock, - function_name: str, - target_spec: Target, - jit: JitBase, -): - undef_symbols = collect_undefined_symbols(body) - - params = _get_function_params(ctx, undef_symbols) - req_headers = _get_headers(ctx, platform, body) - - kfunc = KernelFunction( - body, target_spec, function_name, params, req_headers, ctx.constraints, jit - ) - kfunc.metadata.update(ctx.metadata) - return kfunc - - - - -def create_gpu_kernel_function( - ctx: KernelCreationContext, - platform: Platform, - body: PsBlock, - threads_range: GpuThreadsRange | None, - function_name: str, - target_spec: Target, - jit: JitBase, -): - undef_symbols = collect_undefined_symbols(body) - - if threads_range is not None: - for threads in threads_range.num_work_items: - undef_symbols |= collect_undefined_symbols(threads) - - params = _get_function_params(ctx, undef_symbols) - req_headers = _get_headers(ctx, platform, body) - - kfunc = GpuKernelFunction( - body, - threads_range, - target_spec, - function_name, - params, - req_headers, - ctx.constraints, - jit, - ) - kfunc.metadata.update(ctx.metadata) - return kfunc - - -def _get_function_params(ctx: KernelCreationContext, symbols: Iterable[PsSymbol]): - params: list[KernelParameter] = [] - - from pystencils.backend.memory import BufferBasePtr - - for symb in symbols: - props: set[PsSymbolProperty] = set() - for prop in symb.properties: - match prop: - case FieldShape() | FieldStride(): - props.add(prop) - case BufferBasePtr(buf): - field = ctx.find_field(buf.name) - props.add(FieldBasePtr(field)) - params.append(KernelParameter(symb.name, symb.get_dtype(), props)) - - params.sort(key=lambda p: p.name) - return params - - -def _get_headers(ctx: KernelCreationContext, platform: Platform, body: PsBlock): - req_headers = collect_required_headers(body) - req_headers |= platform.required_headers - req_headers |= ctx.required_headers - return req_headers diff --git a/src/pystencils/backend/platforms/__init__.py b/src/pystencils/backend/platforms/__init__.py index 9332453c6..589841db8 100644 --- a/src/pystencils/backend/platforms/__init__.py +++ b/src/pystencils/backend/platforms/__init__.py @@ -1,6 +1,6 @@ from .platform import Platform from .generic_cpu import GenericCpu, GenericVectorCpu -from .generic_gpu import GenericGpu, GpuThreadsRange +from .generic_gpu import GenericGpu from .cuda import CudaPlatform from .x86 import X86VectorCpu, X86VectorArch from .sycl import SyclPlatform @@ -12,7 +12,6 @@ __all__ = [ "X86VectorCpu", "X86VectorArch", "GenericGpu", - "GpuThreadsRange", "CudaPlatform", "SyclPlatform", ] diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py index 048bcb0d5..b8f5356ae 100644 --- a/src/pystencils/backend/platforms/cuda.py +++ b/src/pystencils/backend/platforms/cuda.py @@ -1,9 +1,10 @@ +from __future__ import annotations from warnings import warn from typing import TYPE_CHECKING from ...types import constify from ..exceptions import MaterializationError -from .generic_gpu import GenericGpu, GpuThreadsRange +from .generic_gpu import GenericGpu from ..kernelcreation import ( Typifier, @@ -29,7 +30,7 @@ from ..literals import PsLiteral from ..functions import PsMathFunction, MathFunctions, CFunction if TYPE_CHECKING: - from ...codegen.config import GpuIndexingConfig + from ...codegen import GpuIndexingConfig, GpuThreadsRange int32 = PsSignedIntegerType(width=32, const=False) @@ -54,6 +55,9 @@ class CudaPlatform(GenericGpu): self, ctx: KernelCreationContext, indexing_cfg: GpuIndexingConfig | None = None ) -> None: super().__init__(ctx) + + from ...codegen.config import GpuIndexingConfig + self._cfg = indexing_cfg if indexing_cfg is not None else GpuIndexingConfig() self._typify = Typifier(ctx) @@ -134,7 +138,7 @@ class CudaPlatform(GenericGpu): if not self._cfg.manual_launch_grid: try: - threads_range = GpuThreadsRange.from_ispace(ispace) + threads_range = self.threads_from_ispace(ispace) except MaterializationError as e: warn( str(e.args[0]) @@ -212,7 +216,7 @@ class CudaPlatform(GenericGpu): body.statements = [sparse_idx_decl] + body.statements ast = body - return ast, GpuThreadsRange.from_ispace(ispace) + return ast, self.threads_from_ispace(ispace) def _linear_thread_idx(self, coord: int): block_size = BLOCK_DIM[coord] diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py index 0512351cd..da8fa64f9 100644 --- a/src/pystencils/backend/platforms/generic_gpu.py +++ b/src/pystencils/backend/platforms/generic_gpu.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Sequence +from typing import TYPE_CHECKING from abc import abstractmethod from ..ast.expressions import PsExpression @@ -12,55 +12,33 @@ from ..kernelcreation.iteration_space import ( from .platform import Platform from ..exceptions import MaterializationError +if TYPE_CHECKING: + from ...codegen.kernel import GpuThreadsRange -class GpuThreadsRange: - """Number of threads required by a GPU kernel, in order (x, y, z).""" - @staticmethod - def from_ispace(ispace: IterationSpace) -> GpuThreadsRange: +class GenericGpu(Platform): + @abstractmethod + def materialize_iteration_space( + self, block: PsBlock, ispace: IterationSpace + ) -> tuple[PsBlock, GpuThreadsRange | None]: + pass + + @classmethod + def threads_from_ispace(cls, ispace: IterationSpace) -> GpuThreadsRange: + from ...codegen.kernel import GpuThreadsRange + if isinstance(ispace, FullIterationSpace): - return GpuThreadsRange._from_full_ispace(ispace) + return cls._threads_from_full_ispace(ispace) elif isinstance(ispace, SparseIterationSpace): work_items = (PsExpression.make(ispace.index_list.shape[0]),) return GpuThreadsRange(work_items) else: assert False - def __init__( - self, - num_work_items: Sequence[PsExpression], - ): - self._dim = len(num_work_items) - self._num_work_items = tuple(num_work_items) - - # @property - # def grid_size(self) -> tuple[PsExpression, ...]: - # return self._grid_size - - # @property - # def block_size(self) -> tuple[PsExpression, ...]: - # return self._block_size - - @property - def num_work_items(self) -> tuple[PsExpression, ...]: - """Number of work items in (x, y, z)-order.""" - return self._num_work_items - - @property - def dim(self) -> int: - return self._dim - - def __str__(self) -> str: - rep = "GpuThreadsRange { " - rep += "; ".join(f"{x}: {w}" for x, w in zip("xyz", self._num_work_items)) - rep += " }" - return rep - - def _repr_html_(self) -> str: - return str(self) - - @staticmethod - def _from_full_ispace(ispace: FullIterationSpace) -> GpuThreadsRange: + @classmethod + def _threads_from_full_ispace(cls, ispace: FullIterationSpace) -> GpuThreadsRange: + from ...codegen.kernel import GpuThreadsRange + dimensions = ispace.dimensions_in_loop_order()[::-1] if len(dimensions) > 3: raise NotImplementedError( @@ -81,11 +59,3 @@ class GpuThreadsRange: work_items = [ispace.actual_iterations(dim) for dim in dimensions] return GpuThreadsRange(work_items) - - -class GenericGpu(Platform): - @abstractmethod - def materialize_iteration_space( - self, block: PsBlock, ispace: IterationSpace - ) -> tuple[PsBlock, GpuThreadsRange | None]: - pass diff --git a/src/pystencils/backend/platforms/sycl.py b/src/pystencils/backend/platforms/sycl.py index 56615af24..9c04d6074 100644 --- a/src/pystencils/backend/platforms/sycl.py +++ b/src/pystencils/backend/platforms/sycl.py @@ -1,3 +1,4 @@ +from __future__ import annotations from typing import TYPE_CHECKING from ..functions import CFunction, PsMathFunction, MathFunctions @@ -24,12 +25,12 @@ from ..extensions.cpp import CppMethodCall from ..kernelcreation import KernelCreationContext, AstFactory from ..constants import PsConstant -from .generic_gpu import GenericGpu, GpuThreadsRange +from .generic_gpu import GenericGpu from ..exceptions import MaterializationError from ...types import PsCustomType, PsIeeeFloatType, constify, PsIntegerType if TYPE_CHECKING: - from ...codegen.config import GpuIndexingConfig + from ...codegen import GpuIndexingConfig, GpuThreadsRange class SyclPlatform(GenericGpu): @@ -38,6 +39,9 @@ class SyclPlatform(GenericGpu): self, ctx: KernelCreationContext, indexing_cfg: GpuIndexingConfig | None = None ): super().__init__(ctx) + + from ...codegen.config import GpuIndexingConfig + self._cfg = indexing_cfg if indexing_cfg is not None else GpuIndexingConfig() @property @@ -113,7 +117,7 @@ class SyclPlatform(GenericGpu): id_decl = self._id_declaration(rank, id_symbol) dimensions = ispace.dimensions_in_loop_order() - launch_config = GpuThreadsRange.from_ispace(ispace) + launch_config = self.threads_from_ispace(ispace) indexing_decls = [id_decl] conds = [] @@ -188,7 +192,7 @@ class SyclPlatform(GenericGpu): body.statements = [sparse_idx_decl] + body.statements ast = body - return ast, GpuThreadsRange.from_ispace(ispace) + return ast, self.threads_from_ispace(ispace) def _item_type(self, rank: int): if not self._cfg.sycl_automatic_block_size: diff --git a/src/pystencils/backend/transformations/add_pragmas.py b/src/pystencils/backend/transformations/add_pragmas.py index 47c008819..78e721f38 100644 --- a/src/pystencils/backend/transformations/add_pragmas.py +++ b/src/pystencils/backend/transformations/add_pragmas.py @@ -1,3 +1,4 @@ +from __future__ import annotations from dataclasses import dataclass from typing import TYPE_CHECKING diff --git a/src/pystencils/backend/transformations/canonicalize_symbols.py b/src/pystencils/backend/transformations/canonicalize_symbols.py index f5b356432..c0406c25d 100644 --- a/src/pystencils/backend/transformations/canonicalize_symbols.py +++ b/src/pystencils/backend/transformations/canonicalize_symbols.py @@ -72,7 +72,6 @@ class CanonicalizeSymbols: symb.dtype = constify(symb.dtype) # Any symbols still alive now are function params or globals - # Might use that to populate KernelFunction self._last_result = cc return node diff --git a/src/pystencils/codegen/__init__.py b/src/pystencils/codegen/__init__.py index be9fd9510..86f7f2940 100644 --- a/src/pystencils/codegen/__init__.py +++ b/src/pystencils/codegen/__init__.py @@ -6,8 +6,8 @@ from .config import ( OpenMpConfig, GpuIndexingConfig, ) - -from .kernel import Kernel +from .parameters import Parameter +from .kernel import Kernel, GpuKernel, GpuThreadsRange from .driver import create_kernel, get_driver __all__ = [ @@ -17,7 +17,10 @@ __all__ = [ "VectorizationConfig", "OpenMpConfig", "GpuIndexingConfig", + "Parameter", "Kernel", + "GpuKernel", + "GpuThreadsRange", "create_kernel", "get_driver", -] \ No newline at end of file +] diff --git a/src/pystencils/codegen/config.py b/src/pystencils/codegen/config.py index 05e3ec3de..b516245fa 100644 --- a/src/pystencils/codegen/config.py +++ b/src/pystencils/codegen/config.py @@ -21,15 +21,14 @@ from ..types import ( from ..defaults import DEFAULTS if TYPE_CHECKING: - from ..backend.jit import JitBase + from ..jit import JitBase class PsOptionsError(Exception): """Indicates an option clash in the `CreateKernelConfig`.""" -class _AUTO_TYPE: - ... +class _AUTO_TYPE: ... # noqa: E701 AUTO = _AUTO_TYPE() @@ -336,12 +335,12 @@ class CreateKernelConfig: """Returns either the user-specified JIT compiler, or infers one from the target if none is given.""" if self.jit is None: if self.target.is_cpu(): - from ..backend.jit import LegacyCpuJit + from ..jit import LegacyCpuJit return LegacyCpuJit() elif self.target == Target.CUDA: try: - from ..backend.jit.gpu_cupy import CupyJit + from ..jit.gpu_cupy import CupyJit if ( self.gpu_indexing is not None @@ -352,12 +351,12 @@ class CreateKernelConfig: return CupyJit() except ImportError: - from ..backend.jit import no_jit + from ..jit import no_jit return no_jit elif self.target == Target.SYCL: - from ..backend.jit import no_jit + from ..jit import no_jit return no_jit else: diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py index bc690a598..0fd49b248 100644 --- a/src/pystencils/codegen/driver.py +++ b/src/pystencils/codegen/driver.py @@ -1,15 +1,19 @@ from __future__ import annotations - -from typing import cast, Sequence +from typing import cast, Sequence, Iterable, TYPE_CHECKING from dataclasses import dataclass, replace from .target import Target from .config import CreateKernelConfig, OpenMpConfig, VectorizationConfig, AUTO -from .kernel import Kernel +from .kernel import Kernel, GpuKernel, GpuThreadsRange +from .properties import PsSymbolProperty, FieldShape, FieldStride, FieldBasePtr +from .parameters import Parameter from ..types import create_numeric_type, PsIntegerType, PsScalarType + +from ..backend.memory import PsSymbol from ..backend.ast import PsAstNode from ..backend.ast.structural import PsBlock, PsLoop +from ..backend.ast.analysis import collect_undefined_symbols, collect_required_headers from ..backend.kernelcreation import ( KernelCreationContext, KernelAnalysis, @@ -22,7 +26,12 @@ from ..backend.kernelcreation.iteration_space import ( create_full_iteration_space, FullIterationSpace, ) -from ..backend.platforms import Platform, GenericCpu, GenericVectorCpu, GenericGpu +from ..backend.platforms import ( + Platform, + GenericCpu, + GenericVectorCpu, + GenericGpu, +) from ..backend.exceptions import VectorizationError from ..backend.transformations import ( @@ -36,6 +45,9 @@ from ..backend.transformations import ( from ..simp import AssignmentCollection from sympy.codegen.ast import AssignmentBase +if TYPE_CHECKING: + from ..jit import JitBase + __all__ = ["create_kernel"] @@ -238,7 +250,7 @@ class DefaultKernelCreationDriver: kernel_ast = self._vectorize(kernel_ast) if cpu_cfg.openmp is not False: - from .backend.transformations import AddOpenMP + from ..backend.transformations import AddOpenMP params = ( cpu_cfg.openmp @@ -262,7 +274,7 @@ class DefaultKernelCreationDriver: if vec_config is None: return kernel_ast - from .backend.transformations import LoopVectorizer, SelectIntrinsics + from ..backend.transformations import LoopVectorizer, SelectIntrinsics assert isinstance(self._platform, GenericVectorCpu) @@ -359,6 +371,107 @@ class DefaultKernelCreationDriver: f"Code generation for target {self._target} not implemented" ) + def _get_function_params(self, symbols: Iterable[PsSymbol]): + params: list[Parameter] = [] + + from pystencils.backend.memory import BufferBasePtr + + for symb in symbols: + props: set[PsSymbolProperty] = set() + for prop in symb.properties: + match prop: + case FieldShape() | FieldStride(): + props.add(prop) + case BufferBasePtr(buf): + field = self._ctx.find_field(buf.name) + props.add(FieldBasePtr(field)) + params.append(Parameter(symb.name, symb.get_dtype(), props)) + + params.sort(key=lambda p: p.name) + return params + + def _get_headers(self, body: PsBlock): + req_headers = collect_required_headers(body) + req_headers |= self._platform.required_headers + req_headers |= self._ctx.required_headers + return req_headers + + +def create_cpu_kernel_function( + ctx: KernelCreationContext, + platform: Platform, + body: PsBlock, + function_name: str, + target_spec: Target, + jit: JitBase, +): + undef_symbols = collect_undefined_symbols(body) + + params = _get_function_params(ctx, undef_symbols) + req_headers = _get_headers(ctx, platform, body) + + kfunc = Kernel( + body, target_spec, function_name, params, req_headers, jit + ) + kfunc.metadata.update(ctx.metadata) + return kfunc + + +def create_gpu_kernel_function( + ctx: KernelCreationContext, + platform: Platform, + body: PsBlock, + threads_range: GpuThreadsRange, + function_name: str, + target_spec: Target, + jit: JitBase, +): + undef_symbols = collect_undefined_symbols(body) + for threads in threads_range.num_work_items: + undef_symbols |= collect_undefined_symbols(threads) + + params = _get_function_params(ctx, undef_symbols) + req_headers = _get_headers(ctx, platform, body) + + kfunc = GpuKernel( + body, + threads_range, + target_spec, + function_name, + params, + req_headers, + jit, + ) + kfunc.metadata.update(ctx.metadata) + return kfunc + + +def _get_function_params(ctx: KernelCreationContext, symbols: Iterable[PsSymbol]): + params: list[Parameter] = [] + + from pystencils.backend.memory import BufferBasePtr + + for symb in symbols: + props: set[PsSymbolProperty] = set() + for prop in symb.properties: + match prop: + case FieldShape() | FieldStride(): + props.add(prop) + case BufferBasePtr(buf): + field = ctx.find_field(buf.name) + props.add(FieldBasePtr(field)) + params.append(Parameter(symb.name, symb.get_dtype(), props)) + + params.sort(key=lambda p: p.name) + return params + + +def _get_headers(ctx: KernelCreationContext, platform: Platform, body: PsBlock): + req_headers = collect_required_headers(body) + req_headers |= platform.required_headers + req_headers |= ctx.required_headers + return req_headers + @dataclass class StageResult: diff --git a/src/pystencils/codegen/gpu.py b/src/pystencils/codegen/gpu.py deleted file mode 100644 index 9cce9b55b..000000000 --- a/src/pystencils/codegen/gpu.py +++ /dev/null @@ -1,28 +0,0 @@ - - -from .kernel import Kernel - - -class GpuKernel(Kernel): - """Internal representation of a kernel function targeted at CUDA GPUs.""" - - def __init__( - self, - body: PsBlock, - threads_range: GpuThreadsRange | None, - target: Target, - name: str, - parameters: Sequence[KernelParameter], - required_headers: set[str], - constraints: Sequence[KernelParamsConstraint], - jit: JitBase, - ): - super().__init__( - body, target, name, parameters, required_headers, constraints, jit - ) - self._threads_range = threads_range - - @property - def threads_range(self) -> GpuThreadsRange | None: - """Object exposing the total size of the launch grid this kernel expects to be executed with.""" - return self._threads_range diff --git a/src/pystencils/codegen/kernel.py b/src/pystencils/codegen/kernel.py index 6a0a6d576..c4ad860b6 100644 --- a/src/pystencils/codegen/kernel.py +++ b/src/pystencils/codegen/kernel.py @@ -1,21 +1,19 @@ from __future__ import annotations from warnings import warn -from typing import Callable, Sequence, Iterable, Any, TYPE_CHECKING +from typing import Callable, Sequence, Any, TYPE_CHECKING from itertools import chain -from .._deprecation import _deprecated - -from ..backend.ast.structural import PsBlock -from ..backend.ast.analysis import collect_required_headers, collect_undefined_symbols -from ..backend.memory import PsSymbol - -from ..types import PsType - from .target import Target from .parameters import Parameter +from ..backend.ast.structural import PsBlock +from ..backend.ast.expressions import PsExpression from ..field import Field -from ..sympyextensions import TypedSymbol + +from .._deprecation import _deprecated + +if TYPE_CHECKING: + from ..jit import JitBase class Kernel: @@ -23,7 +21,7 @@ class Kernel: The kernel object is the final result of the translation process. It is immutable, and its AST should not be altered any more, either, as this - might invalidate information about the kernel already stored in the `KernelFunction` object. + might invalidate information about the kernel already stored in the kernel object. """ def __init__( @@ -78,7 +76,7 @@ class Kernel: return self._params def get_parameters(self) -> tuple[Parameter, ...]: - _deprecated("KernelFunction.get_parameters", "KernelFunction.parameters") + _deprecated("Kernel.get_parameters", "Kernel.parameters") return self.parameters def get_fields(self) -> set[Field]: @@ -97,6 +95,77 @@ class Kernel: def required_headers(self) -> set[str]: return self._required_headers + def get_c_code(self) -> str: + from ..backend.emission import CAstPrinter + + printer = CAstPrinter() + return printer(self) + + def get_ir_code(self) -> str: + from ..backend.emission import IRAstPrinter + + printer = IRAstPrinter() + return printer(self) + def compile(self) -> Callable[..., None]: """Invoke the underlying just-in-time compiler to obtain the kernel as an executable Python function.""" return self._jit.compile(self) + + +class GpuKernel(Kernel): + """Internal representation of a kernel function targeted at CUDA GPUs.""" + + def __init__( + self, + body: PsBlock, + threads_range: GpuThreadsRange | None, + target: Target, + name: str, + parameters: Sequence[Parameter], + required_headers: set[str], + jit: JitBase, + ): + super().__init__(body, target, name, parameters, required_headers, jit) + self._threads_range = threads_range + + @property + def threads_range(self) -> GpuThreadsRange | None: + """Object exposing the total size of the launch grid this kernel expects to be executed with.""" + return self._threads_range + + +class GpuThreadsRange: + """Number of threads required by a GPU kernel, in order (x, y, z).""" + + def __init__( + self, + num_work_items: Sequence[PsExpression], + ): + self._dim = len(num_work_items) + self._num_work_items = tuple(num_work_items) + + # @property + # def grid_size(self) -> tuple[PsExpression, ...]: + # return self._grid_size + + # @property + # def block_size(self) -> tuple[PsExpression, ...]: + # return self._block_size + + @property + def num_work_items(self) -> tuple[PsExpression, ...]: + """Number of work items in (x, y, z)-order.""" + return self._num_work_items + + @property + def dim(self) -> int: + return self._dim + + def __str__(self) -> str: + rep = "GpuThreadsRange { " + rep += "; ".join(f"{x}: {w}" for x, w in zip("xyz", self._num_work_items)) + rep += " }" + return rep + + def _repr_html_(self) -> str: + return str(self) diff --git a/src/pystencils/codegen/parameters.py b/src/pystencils/codegen/parameters.py index 1e01e07aa..d40eae220 100644 --- a/src/pystencils/codegen/parameters.py +++ b/src/pystencils/codegen/parameters.py @@ -1,14 +1,8 @@ from __future__ import annotations from warnings import warn -from typing import Callable, Sequence, Iterable, Any, TYPE_CHECKING -from itertools import chain +from typing import Sequence, Iterable -from .._deprecation import _deprecated - -from ..backend.ast.structural import PsBlock -from ..backend.ast.analysis import collect_required_headers, collect_undefined_symbols -from ..backend.memory import PsSymbol from .properties import ( PsSymbolProperty, _FieldProperty, @@ -16,16 +10,13 @@ from .properties import ( FieldStride, FieldBasePtr, ) - from ..types import PsType - -from .target import Target from ..field import Field from ..sympyextensions import TypedSymbol class Parameter: - """Parameter to a `KernelFunction`.""" + """Parameter to an output object of the code generator.""" __match_args__ = ("name", "dtype", "properties") @@ -45,7 +36,7 @@ class Parameter: lambda p: isinstance(p, _FieldProperty), self._properties ) ), - key=lambda f: f.name + key=lambda f: f.name, ) ) @@ -139,4 +130,4 @@ class Parameter: "Use `param.fields[0].name` instead.", DeprecationWarning, ) - return self._fields[0].name \ No newline at end of file + return self._fields[0].name diff --git a/src/pystencils/display_utils.py b/src/pystencils/display_utils.py index 7f110c9c0..919dea4a8 100644 --- a/src/pystencils/display_utils.py +++ b/src/pystencils/display_utils.py @@ -2,9 +2,8 @@ from typing import Any, Dict, Optional import sympy as sp -from pystencils.backend import KernelFunction -from pystencils.kernel_wrapper import KernelWrapper as OldKernelWrapper -from .backend.jit import KernelWrapper +from .codegen import Kernel +from .jit import KernelWrapper def to_dot(expr: sp.Expr, graph_style: Optional[Dict[str, Any]] = None, short=True): @@ -43,32 +42,27 @@ def highlight_cpp(code: str): return HTML(highlight(code, CppLexer(), HtmlFormatter())) -def get_code_obj(ast: KernelWrapper | KernelFunction, custom_backend=None): +def get_code_obj(ast: KernelWrapper | Kernel, custom_backend=None): """Returns an object to display generated code (C/C++ or CUDA) Can either be displayed as HTML in Jupyter notebooks or printed as normal string. """ - from pystencils.backend.emission import emit_code - - if isinstance(ast, OldKernelWrapper): - ast = ast.ast - elif isinstance(ast, KernelWrapper): - ast = ast.kernel_function + if isinstance(ast, KernelWrapper): + func = ast.kernel_function + else: + func = ast class CodeDisplay: - def __init__(self, ast_input): - self.ast = ast_input - def _repr_html_(self): - return highlight_cpp(emit_code(self.ast)).__html__() + return highlight_cpp(func.get_c_code()).__html__() def __str__(self): - return emit_code(self.ast) + return func.get_c_code() def __repr__(self): - return emit_code(self.ast) + return func.get_c_code() - return CodeDisplay(ast) + return CodeDisplay() def get_code_str(ast, custom_backend=None): @@ -88,7 +82,7 @@ def _isnotebook(): return False -def show_code(ast: KernelWrapper | KernelFunction, custom_backend=None): +def show_code(ast: KernelWrapper | Kernel, custom_backend=None): code = get_code_obj(ast, custom_backend) if _isnotebook(): diff --git a/src/pystencils/inspection.py b/src/pystencils/inspection.py index 7fa3047c6..7f050c745 100644 --- a/src/pystencils/inspection.py +++ b/src/pystencils/inspection.py @@ -2,8 +2,8 @@ from typing import overload from .backend.ast import PsAstNode from .backend.emission import CAstPrinter, IRAstPrinter, EmissionError -from .backend.kernelfunction import KernelFunction -from .kernelcreation import StageResult, CodegenIntermediates +from .codegen import Kernel +from .codegen.driver import StageResult, CodegenIntermediates from abc import ABC, abstractmethod _UNABLE_TO_DISPLAY_CPP = """ @@ -37,7 +37,7 @@ class CodeInspectionBase(ABC): self._ir_printer = IRAstPrinter(annotate_constants=False) self._c_printer = CAstPrinter() - def _ir_tab(self, ir_obj: PsAstNode | KernelFunction): + def _ir_tab(self, ir_obj: PsAstNode | Kernel): import ipywidgets as widgets ir = self._ir_printer(ir_obj) @@ -45,7 +45,7 @@ class CodeInspectionBase(ABC): self._apply_tab_layout(ir_tab) return ir_tab - def _cpp_tab(self, ir_obj: PsAstNode | KernelFunction): + def _cpp_tab(self, ir_obj: PsAstNode | Kernel): import ipywidgets as widgets try: @@ -64,7 +64,7 @@ class CodeInspectionBase(ABC): self._apply_tab_layout(cpp_tab) return cpp_tab - def _graphviz_tab(self, ir_obj: PsAstNode | KernelFunction): + def _graphviz_tab(self, ir_obj: PsAstNode | Kernel): import ipywidgets as widgets graphviz_tab = widgets.HTML(_GRAPHVIZ_NOT_IMPLEMENTED) @@ -124,7 +124,7 @@ class AstInspection(CodeInspectionBase): class KernelInspection(CodeInspectionBase): - def __init__(self, kernel: KernelFunction) -> None: + def __init__(self, kernel: Kernel) -> None: super().__init__() self._kernel = kernel @@ -190,7 +190,7 @@ def inspect(obj: PsAstNode): ... @overload -def inspect(obj: KernelFunction): ... +def inspect(obj: Kernel): ... @overload @@ -207,7 +207,7 @@ def inspect(obj): When run inside a Jupyter notebook, this function displays an inspection widget for the following types of objects: - `PsAstNode` - - `KernelFunction` + - `Kernel` - `StageResult` - `CodegenIntermediates` """ @@ -217,7 +217,7 @@ def inspect(obj): match obj: case PsAstNode(): preview = AstInspection(obj) - case KernelFunction(): + case Kernel(): preview = KernelInspection(obj) case StageResult(ast, _): preview = AstInspection(ast) diff --git a/src/pystencils/jit/__init__.py b/src/pystencils/jit/__init__.py index f45cb9bff..a47dc4aa6 100644 --- a/src/pystencils/jit/__init__.py +++ b/src/pystencils/jit/__init__.py @@ -2,7 +2,7 @@ JIT compilation is realized by subclasses of `JitBase`. A JIT compiler may freely be created and configured by the user. It can then be passed to `create_kernel` using the ``jit`` argument of -`CreateKernelConfig`, in which case it is hooked into the `KernelFunction.compile` method +`CreateKernelConfig`, in which case it is hooked into the `Kernel.compile` method of the generated kernel function:: my_jit = MyJit() diff --git a/src/pystencils/jit/cpu_extension_module.py b/src/pystencils/jit/cpu_extension_module.py index 444167f9d..befb033e6 100644 --- a/src/pystencils/jit/cpu_extension_module.py +++ b/src/pystencils/jit/cpu_extension_module.py @@ -9,22 +9,19 @@ from textwrap import indent import numpy as np -from ..exceptions import PsInternalCompilerError -from ..kernelfunction import ( - KernelFunction, - KernelParameter, +from ..codegen import ( + Kernel, + Parameter, ) -from ...codegen.properties import FieldBasePtr, FieldShape, FieldStride -from ..constraints import KernelParamsConstraint -from ...types import ( +from ..codegen.properties import FieldBasePtr, FieldShape, FieldStride +from ..types import ( PsType, PsUnsignedIntegerType, PsSignedIntegerType, PsIeeeFloatType, ) -from ...types.quick import Fp, SInt, UInt -from ...field import Field -from ..emission import emit_code +from ..types.quick import Fp, SInt, UInt +from ..field import Field class PsKernelExtensioNModule: @@ -38,11 +35,11 @@ class PsKernelExtensioNModule: self._module_name = module_name if custom_backend is not None: - raise PsInternalCompilerError( + raise Exception( "The `custom_backend` parameter exists only for interface compatibility and cannot be set." ) - self._kernels: dict[str, KernelFunction] = dict() + self._kernels: dict[str, Kernel] = dict() self._code_string: str | None = None self._code_hash: str | None = None @@ -50,7 +47,7 @@ class PsKernelExtensioNModule: def module_name(self) -> str: return self._module_name - def add_function(self, kernel_function: KernelFunction, name: str | None = None): + def add_function(self, kernel_function: Kernel, name: str | None = None): if name is None: name = kernel_function.name @@ -98,7 +95,7 @@ class PsKernelExtensioNModule: old_name = kernel.name kernel.name = f"kernel_{name}" - code += emit_code(kernel) + code += kernel.get_c_code() code += "\n" code += emit_call_wrapper(name, kernel) code += "\n" @@ -122,14 +119,14 @@ class PsKernelExtensioNModule: print(self._code_string, file=file) -def emit_call_wrapper(function_name: str, kernel: KernelFunction) -> str: +def emit_call_wrapper(function_name: str, kernel: Kernel) -> str: builder = CallWrapperBuilder() for p in kernel.parameters: builder.extract_parameter(p) - for c in kernel.constraints: - builder.check_constraint(c) + # for c in kernel.constraints: + # builder.check_constraint(c) builder.call(kernel, kernel.parameters) @@ -206,8 +203,8 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ self._array_extractions: dict[Field, str] = dict() self._array_frees: dict[Field, str] = dict() - self._array_assoc_var_extractions: dict[KernelParameter, str] = dict() - self._scalar_extractions: dict[KernelParameter, str] = dict() + self._array_assoc_var_extractions: dict[Parameter, str] = dict() + self._scalar_extractions: dict[Parameter, str] = dict() self._constraint_checks: list[str] = [] @@ -223,7 +220,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ return "PyLong_AsUnsignedLong" case _: - raise PsInternalCompilerError( + raise ValueError( f"Don't know how to cast Python objects to {dtype}" ) @@ -267,7 +264,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ return self._array_buffers[field] - def extract_scalar(self, param: KernelParameter) -> str: + def extract_scalar(self, param: Parameter) -> str: if param not in self._scalar_extractions: extract_func = self._scalar_extractor(param.dtype) code = self.TMPL_EXTRACT_SCALAR.format( @@ -279,7 +276,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ return param.name - def extract_array_assoc_var(self, param: KernelParameter) -> str: + def extract_array_assoc_var(self, param: Parameter) -> str: if param not in self._array_assoc_var_extractions: field = param.fields[0] buffer = self.extract_field(field) @@ -305,31 +302,31 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ return param.name - def extract_parameter(self, param: KernelParameter): + def extract_parameter(self, param: Parameter): if param.is_field_parameter: self.extract_array_assoc_var(param) else: self.extract_scalar(param) - def check_constraint(self, constraint: KernelParamsConstraint): - variables = constraint.get_parameters() +# def check_constraint(self, constraint: KernelParamsConstraint): +# variables = constraint.get_parameters() - for var in variables: - self.extract_parameter(var) +# for var in variables: +# self.extract_parameter(var) - cond = constraint.to_code() +# cond = constraint.to_code() - code = f""" -if(!({cond})) -{{ - PyErr_SetString(PyExc_ValueError, "Violated constraint: {constraint}"); - return NULL; -}} -""" +# code = f""" +# if(!({cond})) +# {{ +# PyErr_SetString(PyExc_ValueError, "Violated constraint: {constraint}"); +# return NULL; +# }} +# """ - self._constraint_checks.append(code) +# self._constraint_checks.append(code) - def call(self, kernel: KernelFunction, params: tuple[KernelParameter, ...]): + def call(self, kernel: Kernel, params: tuple[Parameter, ...]): param_list = ", ".join(p.name for p in params) self._call = f"{kernel.name} ({param_list});" diff --git a/src/pystencils/jit/gpu_cupy.py b/src/pystencils/jit/gpu_cupy.py index 2f5753e05..c208ac219 100644 --- a/src/pystencils/jit/gpu_cupy.py +++ b/src/pystencils/jit/gpu_cupy.py @@ -8,21 +8,20 @@ try: except ImportError: HAVE_CUPY = False -from ...codegen import Target -from ...field import FieldType +from ..codegen import Target +from ..field import FieldType -from ...types import PsType +from ..types import PsType from .jit import JitBase, JitError, KernelWrapper -from ..kernelfunction import ( - KernelFunction, - GpuKernelFunction, - KernelParameter, +from ..codegen import ( + Kernel, + GpuKernel, + Parameter, ) -from ...codegen.properties import FieldShape, FieldStride, FieldBasePtr -from ..emission import emit_code -from ...types import PsStructType +from ..codegen.properties import FieldShape, FieldStride, FieldBasePtr +from ..types import PsStructType -from ...include import get_pystencils_include_path +from ..include import get_pystencils_include_path @dataclass @@ -34,18 +33,18 @@ class LaunchGrid: class CupyKernelWrapper(KernelWrapper): def __init__( self, - kfunc: GpuKernelFunction, + kfunc: GpuKernel, raw_kernel: Any, block_size: tuple[int, int, int], ): - self._kfunc: GpuKernelFunction = kfunc + self._kfunc: GpuKernel = kfunc self._raw_kernel = raw_kernel self._block_size = block_size self._num_blocks: tuple[int, int, int] | None = None self._args_cache: dict[Any, tuple] = dict() @property - def kernel_function(self) -> GpuKernelFunction: + def kernel_function(self) -> GpuKernel: return self._kfunc @property @@ -105,7 +104,7 @@ class CupyKernelWrapper(KernelWrapper): field_shapes = set() index_shapes = set() - def check_shape(field_ptr: KernelParameter, arr: cp.ndarray): + def check_shape(field_ptr: Parameter, arr: cp.ndarray): field = field_ptr.fields[0] if field.has_fixed_shape: @@ -190,7 +189,7 @@ class CupyKernelWrapper(KernelWrapper): add_arg(kparam.name, val, kparam.dtype) # Determine launch grid - from ..ast.expressions import evaluate_expression + from ..backend.ast.expressions import evaluate_expression symbolic_threads_range = self._kfunc.threads_range @@ -243,13 +242,13 @@ class CupyJit(JitBase): tuple(default_block_size) + (1,) * (3 - len(default_block_size)), ) - def compile(self, kfunc: KernelFunction) -> KernelWrapper: + def compile(self, kfunc: Kernel) -> KernelWrapper: if not HAVE_CUPY: raise JitError( "`cupy` is not installed: just-in-time-compilation of CUDA kernels is unavailable." ) - if not isinstance(kfunc, GpuKernelFunction) or kfunc.target != Target.CUDA: + if not isinstance(kfunc, GpuKernel) or kfunc.target != Target.CUDA: raise ValueError( "The CupyJit just-in-time compiler only accepts kernels generated for CUDA or HIP" ) @@ -269,7 +268,7 @@ class CupyJit(JitBase): options.append("-I" + get_pystencils_include_path()) return tuple(options) - def _prelude(self, kfunc: GpuKernelFunction) -> str: + def _prelude(self, kfunc: GpuKernel) -> str: headers = self._runtime_headers headers |= kfunc.required_headers @@ -286,6 +285,6 @@ class CupyJit(JitBase): return code - def _kernel_code(self, kfunc: GpuKernelFunction) -> str: - kernel_code = emit_code(kfunc) + def _kernel_code(self, kfunc: GpuKernel) -> str: + kernel_code = kfunc.get_c_code() return f'extern "C" {kernel_code}' diff --git a/src/pystencils/jit/jit.py b/src/pystencils/jit/jit.py index 250bba240..4998c14ad 100644 --- a/src/pystencils/jit/jit.py +++ b/src/pystencils/jit/jit.py @@ -3,8 +3,7 @@ from typing import Sequence, TYPE_CHECKING from abc import ABC, abstractmethod if TYPE_CHECKING: - from ..kernelfunction import KernelFunction, KernelParameter - from ...codegen.target import Target + from ..codegen import Kernel, Parameter, Target class JitError(Exception): @@ -14,7 +13,7 @@ class JitError(Exception): class KernelWrapper(ABC): """Wrapper around a compiled and executable pystencils kernel.""" - def __init__(self, kfunc: KernelFunction) -> None: + def __init__(self, kfunc: Kernel) -> None: self._kfunc = kfunc @abstractmethod @@ -22,11 +21,11 @@ class KernelWrapper(ABC): pass @property - def kernel_function(self) -> KernelFunction: + def kernel_function(self) -> Kernel: return self._kfunc @property - def ast(self) -> KernelFunction: + def ast(self) -> Kernel: return self._kfunc @property @@ -34,7 +33,7 @@ class KernelWrapper(ABC): return self._kfunc.target @property - def parameters(self) -> Sequence[KernelParameter]: + def parameters(self) -> Sequence[Parameter]: return self._kfunc.parameters @property @@ -48,14 +47,14 @@ class JitBase(ABC): """Base class for just-in-time compilation interfaces implemented in pystencils.""" @abstractmethod - def compile(self, kernel: KernelFunction) -> KernelWrapper: + def compile(self, kernel: Kernel) -> KernelWrapper: """Compile a kernel function and return a callable object which invokes the kernel.""" class NoJit(JitBase): """Not a JIT compiler: Used to explicitly disable JIT compilation on an AST.""" - def compile(self, kernel: KernelFunction) -> KernelWrapper: + def compile(self, kernel: Kernel) -> KernelWrapper: raise JitError( "Just-in-time compilation of this kernel was explicitly disabled." ) diff --git a/src/pystencils/jit/legacy_cpu.py b/src/pystencils/jit/legacy_cpu.py index 1acd1b22a..514e9b60e 100644 --- a/src/pystencils/jit/legacy_cpu.py +++ b/src/pystencils/jit/legacy_cpu.py @@ -61,7 +61,7 @@ import time import warnings -from ..kernelfunction import KernelFunction +from ..codegen import Kernel from .jit import JitBase, KernelWrapper from .cpu_extension_module import PsKernelExtensioNModule @@ -71,7 +71,7 @@ from pystencils.utils import atomic_file_write, recursive_dict_update class CpuKernelWrapper(KernelWrapper): - def __init__(self, kfunc: KernelFunction, compiled_kernel: Callable[..., None]) -> None: + def __init__(self, kfunc: Kernel, compiled_kernel: Callable[..., None]) -> None: super().__init__(kfunc) self._compiled_kernel = compiled_kernel @@ -86,7 +86,7 @@ class CpuKernelWrapper(KernelWrapper): class LegacyCpuJit(JitBase): """Wrapper around ``pystencils.cpu.cpujit``""" - def compile(self, kernel: KernelFunction) -> KernelWrapper: + def compile(self, kernel: Kernel) -> KernelWrapper: return compile_and_load(kernel) @@ -436,7 +436,7 @@ def compile_module(code, code_hash, base_dir, compile_flags=None): return lib_file -def compile_and_load(kernel: KernelFunction, custom_backend=None): +def compile_and_load(kernel: Kernel, custom_backend=None): cache_config = get_cache_config() compiler_config = get_compiler_config() diff --git a/src/pystencils/kernel_wrapper.py b/src/pystencils/kernel_wrapper.py index afce06d77..5095332c1 100644 --- a/src/pystencils/kernel_wrapper.py +++ b/src/pystencils/kernel_wrapper.py @@ -1,3 +1,3 @@ -from .backend.jit import KernelWrapper as _KernelWrapper +from .jit import KernelWrapper as _KernelWrapper KernelWrapper = _KernelWrapper diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py index 9bf3eaf67..97965f709 100644 --- a/src/pystencils/kernelcreation.py +++ b/src/pystencils/kernelcreation.py @@ -4,9 +4,10 @@ from .codegen import create_kernel as _create_kernel from warnings import warn warn( - "Importing anything from `pystencils.kernelcreation` is deprecated and the module will be removed in pystencils 2.1. " + "Importing anything from `pystencils.kernelcreation` is deprecated " + "and the module will be removed in pystencils 2.1. " "Import from `pystencils` instead.", - FutureWarning + FutureWarning, ) @@ -19,4 +20,3 @@ def create_staggered_kernel( raise NotImplementedError( "Staggered kernels are not yet implemented for pystencils 2.0" ) - diff --git a/tests/kernelcreation/test_domain_kernels.py b/tests/kernelcreation/test_domain_kernels.py index d02bfd8e4..da261faec 100644 --- a/tests/kernelcreation/test_domain_kernels.py +++ b/tests/kernelcreation/test_domain_kernels.py @@ -10,16 +10,14 @@ from pystencils import ( AssignmentCollection, Target, CreateKernelConfig, - CpuOptimConfig, - VectorizationConfig, ) from pystencils.assignment import assignment_from_stencil -from pystencils.kernelcreation import create_kernel, KernelFunction +from pystencils import create_kernel, Kernel from pystencils.backend.emission import emit_code -def inspect_dp_kernel(kernel: KernelFunction, gen_config: CreateKernelConfig): +def inspect_dp_kernel(kernel: Kernel, gen_config: CreateKernelConfig): code = emit_code(kernel) match gen_config.target: diff --git a/tests/kernelcreation/test_index_kernels.py b/tests/kernelcreation/test_index_kernels.py index 5093c43ff..569c0ab6a 100644 --- a/tests/kernelcreation/test_index_kernels.py +++ b/tests/kernelcreation/test_index_kernels.py @@ -2,7 +2,7 @@ import numpy as np import pytest from pystencils import Assignment, Field, FieldType, AssignmentCollection, Target -from pystencils.kernelcreation import create_kernel, CreateKernelConfig +from pystencils import create_kernel, CreateKernelConfig @pytest.mark.parametrize("target", [Target.CPU, Target.GPU]) diff --git a/tests/kernelcreation/test_iteration_slices.py b/tests/kernelcreation/test_iteration_slices.py index 94ed02954..47f3b5fac 100644 --- a/tests/kernelcreation/test_iteration_slices.py +++ b/tests/kernelcreation/test_iteration_slices.py @@ -19,7 +19,7 @@ from pystencils import ( from pystencils.sympyextensions.integer_functions import int_rem from pystencils.simp import sympy_cse_on_assignment_list from pystencils.slicing import normalize_slice -from pystencils.backend.jit.gpu_cupy import CupyKernelWrapper +from pystencils.jit.gpu_cupy import CupyKernelWrapper def test_sliced_iteration(): diff --git a/tests/nbackend/test_code_printing.py b/tests/nbackend/test_code_printing.py index ef4806314..109cfdc19 100644 --- a/tests/nbackend/test_code_printing.py +++ b/tests/nbackend/test_code_printing.py @@ -1,11 +1,6 @@ -from pystencils import Target - from pystencils.backend.ast.expressions import PsExpression -from pystencils.backend.ast.structural import PsAssignment, PsLoop, PsBlock -from pystencils.backend.kernelfunction import KernelFunction -from pystencils.backend.memory import PsSymbol, PsBuffer +from pystencils.backend.memory import PsSymbol from pystencils.backend.constants import PsConstant -from pystencils.backend.literals import PsLiteral from pystencils.types.quick import Fp, SInt, UInt, Bool from pystencils.backend.emission import CAstPrinter @@ -129,7 +124,7 @@ def test_relations_precedence(): def test_ternary(): from pystencils.backend.ast.expressions import PsTernary - from pystencils.backend.ast.expressions import PsNot, PsAnd, PsOr + from pystencils.backend.ast.expressions import PsAnd, PsOr p, q = [PsExpression.make(PsSymbol(x, Bool())) for x in "pq"] x, y, z = [PsExpression.make(PsSymbol(x, Fp(32))) for x in "xyz"] diff --git a/tests/nbackend/test_cpujit.py b/tests/nbackend/test_cpujit.py index 648112ef9..c053df9a9 100644 --- a/tests/nbackend/test_cpujit.py +++ b/tests/nbackend/test_cpujit.py @@ -1,6 +1,6 @@ import pytest -from pystencils import Target +from pystencils import Target, Kernel # from pystencils.backend.constraints import PsKernelParamsConstraint from pystencils.backend.memory import PsSymbol, PsBuffer @@ -8,10 +8,9 @@ from pystencils.backend.constants import PsConstant from pystencils.backend.ast.expressions import PsBufferAcc, PsExpression from pystencils.backend.ast.structural import PsAssignment, PsBlock, PsLoop -from pystencils.backend.kernelfunction import KernelFunction from pystencils.types.quick import SInt, Fp -from pystencils.backend.jit import LegacyCpuJit +from pystencils.jit import LegacyCpuJit import numpy as np @@ -45,7 +44,7 @@ def test_pairwise_addition(): PsBlock([update]) ) - func = KernelFunction(PsBlock([loop]), Target.CPU, "kernel", set()) + func = Kernel(PsBlock([loop]), Target.CPU, "kernel", set()) # sizes_constraint = PsKernelParamsConstraint( # u.shape[0].eq(2 * v.shape[0]), diff --git a/tests/nbackend/test_vectorization.py b/tests/nbackend/test_vectorization.py index 55330c9ee..a4825669c 100644 --- a/tests/nbackend/test_vectorization.py +++ b/tests/nbackend/test_vectorization.py @@ -19,8 +19,8 @@ from pystencils.backend.transformations import ( LowerToC, ) from pystencils.backend.constants import PsConstant -from pystencils.backend.kernelfunction import create_cpu_kernel_function -from pystencils.backend.jit import LegacyCpuJit +from pystencils.codegen.driver import create_cpu_kernel_function +from pystencils.jit import LegacyCpuJit from pystencils import Target, fields, Assignment, Field from pystencils.field import create_numpy_array_with_layout -- GitLab