diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py index 1d716fa9676db4a2dda8f5241990461cab48d835..31d8ea192269a9a9947457814ff5e58d63f61c14 100644 --- a/src/pystencils/backend/ast/structural.py +++ b/src/pystencils/backend/ast/structural.py @@ -1,5 +1,7 @@ from __future__ import annotations -from typing import Sequence, cast + +from abc import ABC, abstractmethod +from typing import Iterable, Sequence, cast from types import NoneType from .astnode import PsAstNode, PsLeafMixIn @@ -9,10 +11,35 @@ from ..memory import PsSymbol from .util import failing_cast -class PsBlock(PsAstNode): +class PsStructuralNode(PsAstNode, ABC): + """Base class for structural nodes in the pystencils AST. + + This class acts as a trait that structural AST nodes like blocks, conditionals, etc. can inherit from. + """ + + def clone(self): + """Clone this structure node. + + .. note:: + Subclasses of `PsStructuralNode` should not override this method, + but implement `_clone_structural` instead. + That implementation shall call `clone` on any of its children. + """ + return self._clone_structural() + + @abstractmethod + def _clone_structural(self) -> PsStructuralNode: + """Implementation of structural node cloning. + + :meta public: + """ + pass + + +class PsBlock(PsStructuralNode): __match_args__ = ("statements",) - def __init__(self, cs: Sequence[PsAstNode]): + def __init__(self, cs: Iterable[PsStructuralNode]): self._statements = list(cs) @property @@ -21,23 +48,23 @@ class PsBlock(PsAstNode): @children.setter def children(self, cs: Sequence[PsAstNode]): - self._statements = list(cs) + self._statements = list([failing_cast(PsStructuralNode, c) for c in cs]) def get_children(self) -> tuple[PsAstNode, ...]: return tuple(self._statements) def set_child(self, idx: int, c: PsAstNode): - self._statements[idx] = c + self._statements[idx] = failing_cast(PsStructuralNode, c) - def clone(self) -> PsBlock: - return PsBlock([stmt.clone() for stmt in self._statements]) + def _clone_structural(self) -> PsBlock: + return PsBlock([stmt._clone_structural() for stmt in self._statements]) @property - def statements(self) -> list[PsAstNode]: + def statements(self) -> list[PsStructuralNode]: return self._statements @statements.setter - def statements(self, stm: Sequence[PsAstNode]): + def statements(self, stm: Sequence[PsStructuralNode]): self._statements = list(stm) def __repr__(self) -> str: @@ -45,7 +72,7 @@ class PsBlock(PsAstNode): return f"PsBlock( {contents} )" -class PsStatement(PsAstNode): +class PsStatement(PsStructuralNode): __match_args__ = ("expression",) def __init__(self, expr: PsExpression): @@ -59,7 +86,7 @@ class PsStatement(PsAstNode): def expression(self, expr: PsExpression): self._expression = expr - def clone(self) -> PsStatement: + def _clone_structural(self) -> PsStatement: return PsStatement(self._expression.clone()) def get_children(self) -> tuple[PsAstNode, ...]: @@ -71,7 +98,7 @@ class PsStatement(PsAstNode): self._expression = failing_cast(PsExpression, c) -class PsAssignment(PsAstNode): +class PsAssignment(PsStructuralNode): __match_args__ = ( "lhs", "rhs", @@ -101,7 +128,7 @@ class PsAssignment(PsAstNode): def rhs(self, expr: PsExpression): self._rhs = expr - def clone(self) -> PsAssignment: + def _clone_structural(self) -> PsAssignment: return PsAssignment(self._lhs.clone(), self._rhs.clone()) def get_children(self) -> tuple[PsAstNode, ...]: @@ -141,7 +168,7 @@ class PsDeclaration(PsAssignment): def declared_symbol(self) -> PsSymbol: return cast(PsSymbolExpr, self._lhs).symbol - def clone(self) -> PsDeclaration: + def _clone_structural(self) -> PsDeclaration: return PsDeclaration(cast(PsSymbolExpr, self._lhs.clone()), self.rhs.clone()) def set_child(self, idx: int, c: PsAstNode): @@ -157,7 +184,7 @@ class PsDeclaration(PsAssignment): return f"PsDeclaration({repr(self._lhs)}, {repr(self._rhs)})" -class PsLoop(PsAstNode): +class PsLoop(PsStructuralNode): __match_args__ = ("counter", "start", "stop", "step", "body") def __init__( @@ -214,13 +241,13 @@ class PsLoop(PsAstNode): def body(self, block: PsBlock): self._body = block - def clone(self) -> PsLoop: + def _clone_structural(self) -> PsLoop: return PsLoop( self._ctr.clone(), self._start.clone(), self._stop.clone(), self._step.clone(), - self._body.clone(), + self._body._clone_structural(), ) def get_children(self) -> tuple[PsAstNode, ...]: @@ -243,7 +270,7 @@ class PsLoop(PsAstNode): assert False, "unreachable code" -class PsConditional(PsAstNode): +class PsConditional(PsStructuralNode): """Conditional branch""" __match_args__ = ("condition", "branch_true", "branch_false") @@ -282,11 +309,11 @@ class PsConditional(PsAstNode): def branch_false(self, block: PsBlock | None): self._branch_false = block - def clone(self) -> PsConditional: + def _clone_structural(self) -> PsConditional: return PsConditional( self._condition.clone(), - self._branch_true.clone(), - self._branch_false.clone() if self._branch_false is not None else None, + self._branch_true._clone_structural(), + self._branch_false._clone_structural() if self._branch_false is not None else None, ) def get_children(self) -> tuple[PsAstNode, ...]: @@ -317,7 +344,7 @@ class PsEmptyLeafMixIn: pass -class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode): +class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralNode): """A C/C++ preprocessor pragma. Example usage: ``PsPragma("omp parallel for")`` translates to ``#pragma omp parallel for``. @@ -335,7 +362,7 @@ class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode): def text(self) -> str: return self._text - def clone(self) -> PsPragma: + def _clone_structural(self) -> PsPragma: return PsPragma(self.text) def structurally_equal(self, other: PsAstNode) -> bool: @@ -345,7 +372,7 @@ class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode): return self._text == other._text -class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode): +class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsStructuralNode): __match_args__ = ("lines",) def __init__(self, text: str) -> None: @@ -360,7 +387,7 @@ class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode): def lines(self) -> tuple[str, ...]: return self._lines - def clone(self) -> PsComment: + def _clone_structural(self) -> PsComment: return PsComment(self._text) def structurally_equal(self, other: PsAstNode) -> bool: diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index ce65cd85ded0c04fd7fb7858b88610989eb6a9d0..df6bfbd1f160ceec819d8d3f1923c43c145e88b3 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -28,6 +28,7 @@ from ..ast.structural import ( PsDeclaration, PsExpression, PsSymbolExpr, + PsStructuralNode, ) from ..ast.expressions import ( PsBufferAcc, @@ -109,7 +110,7 @@ class FreezeExpressions: def __call__(self, obj: AssignmentCollection | sp.Basic) -> PsAstNode: if isinstance(obj, AssignmentCollection): - return PsBlock([self.visit(asm) for asm in obj.all_assignments]) + return PsBlock([cast(PsStructuralNode, self.visit(asm)) for asm in obj.all_assignments]) elif isinstance(obj, AssignmentBase): return cast(PsAssignment, self.visit(obj)) elif isinstance(obj, _ExprLike): diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py index a9ec9d8d6fd6292eaccc58ce30bc24a64ac547ae..7aac0d412dbf5068c5e36a5740825dbd6e1eb6e5 100644 --- a/src/pystencils/backend/platforms/cuda.py +++ b/src/pystencils/backend/platforms/cuda.py @@ -197,7 +197,7 @@ class CudaPlatform(GenericGpu): @property def required_headers(self) -> set[str]: - return {'"gpu_defines.h"'} + return {'"pystencils_runtime/hip.h"'} # TODO: move to HipPlatform once it is introduced def materialize_iteration_space( self, body: PsBlock, ispace: IterationSpace diff --git a/src/pystencils/backend/platforms/sycl.py b/src/pystencils/backend/platforms/sycl.py index eae2b7598bfa43cf5379fe8782233be11d0dfef2..b613f3756d8708bcb844ee91c29346f388497010 100644 --- a/src/pystencils/backend/platforms/sycl.py +++ b/src/pystencils/backend/platforms/sycl.py @@ -25,12 +25,13 @@ from ..extensions.cpp import CppMethodCall from ..kernelcreation import KernelCreationContext, AstFactory from ..constants import PsConstant -from .generic_gpu import GenericGpu from ..exceptions import MaterializationError from ...types import PsCustomType, PsIeeeFloatType, constify, PsIntegerType +from .platform import Platform -class SyclPlatform(GenericGpu): + +class SyclPlatform(Platform): def __init__( self, diff --git a/src/pystencils/backend/transformations/add_pragmas.py b/src/pystencils/backend/transformations/add_pragmas.py index 10e6b39d844335eef3c0c02f2e5a9c5a7be4b94f..c9e8b3994cba3b1dd3c52e85acf405c00b4817f4 100644 --- a/src/pystencils/backend/transformations/add_pragmas.py +++ b/src/pystencils/backend/transformations/add_pragmas.py @@ -6,7 +6,7 @@ from collections import defaultdict from ..kernelcreation import KernelCreationContext from ..ast import PsAstNode -from ..ast.structural import PsBlock, PsLoop, PsPragma +from ..ast.structural import PsBlock, PsLoop, PsPragma, PsStructuralNode from ..ast.expressions import PsExpression from ...types import PsScalarType @@ -56,13 +56,12 @@ class InsertPragmasAtLoops: self._insertions[ins.loop_nesting_depth].append(ins) def __call__(self, node: PsAstNode) -> PsAstNode: - is_loop = isinstance(node, PsLoop) - if is_loop: + if isinstance(node, PsLoop): node = PsBlock([node]) self.visit(node, Nesting(0)) - if is_loop and len(node.children) == 1: + if isinstance(node, PsLoop) and len(node.children) == 1: node = node.children[0] return node @@ -73,7 +72,7 @@ class InsertPragmasAtLoops: return case PsBlock(children): - new_children: list[PsAstNode] = [] + new_children: list[PsStructuralNode] = [] for c in children: if isinstance(c, PsLoop): nest.has_inner_loops = True @@ -92,8 +91,8 @@ class InsertPragmasAtLoops: node.children = new_children case other: - for c in other.children: - self.visit(c, nest) + for child in other.children: + self.visit(child, nest) class AddOpenMP: diff --git a/src/pystencils/backend/transformations/ast_vectorizer.py b/src/pystencils/backend/transformations/ast_vectorizer.py index ab4401f9ca0142d9cfeec258eeb34fb2a7f6e8eb..c793c424d2417cbbdcc0cf3782e696c7c9226bb6 100644 --- a/src/pystencils/backend/transformations/ast_vectorizer.py +++ b/src/pystencils/backend/transformations/ast_vectorizer.py @@ -18,6 +18,7 @@ from ..ast.structural import ( PsAssignment, PsLoop, PsEmptyLeafMixIn, + PsStructuralNode, ) from ..ast.expressions import ( PsExpression, @@ -268,6 +269,18 @@ class AstVectorizer: """ return self.visit(node, vc) + @overload + def visit(self, node: PsStructuralNode, vc: VectorizationContext) -> PsStructuralNode: + pass + + @overload + def visit(self, node: PsExpression, vc: VectorizationContext) -> PsExpression: + pass + + @overload + def visit(self, node: PsAstNode, vc: VectorizationContext) -> PsAstNode: + pass + def visit(self, node: PsAstNode, vc: VectorizationContext) -> PsAstNode: """Vectorize a subtree.""" diff --git a/src/pystencils/backend/transformations/eliminate_branches.py b/src/pystencils/backend/transformations/eliminate_branches.py index f098d82df1ce6a748097756aa1616a72e57487b5..69dd1dd11d726e597c15ece772846ba8cd84acba 100644 --- a/src/pystencils/backend/transformations/eliminate_branches.py +++ b/src/pystencils/backend/transformations/eliminate_branches.py @@ -1,7 +1,9 @@ +from typing import cast + from ..kernelcreation import KernelCreationContext from ..ast import PsAstNode from ..ast.analysis import collect_undefined_symbols -from ..ast.structural import PsLoop, PsBlock, PsConditional +from ..ast.structural import PsLoop, PsBlock, PsConditional, PsStructuralNode from ..ast.expressions import ( PsAnd, PsCast, @@ -71,9 +73,9 @@ class EliminateBranches: ec.enclosing_loops.pop() case PsBlock(statements): - statements_new: list[PsAstNode] = [] + statements_new: list[PsStructuralNode] = [] for stmt in statements: - statements_new.append(self.visit(stmt, ec)) + statements_new.append(cast(PsStructuralNode, self.visit(stmt, ec))) node.statements = statements_new case PsConditional(): diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py index ab1cabc557a88b03766e6a9fb2ab44a84a5711da..3a07cb56fcb8f1c60107b5b1883c679191429e7e 100644 --- a/src/pystencils/backend/transformations/eliminate_constants.py +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -6,7 +6,7 @@ import numpy as np from ..kernelcreation import KernelCreationContext, Typifier from ..ast import PsAstNode -from ..ast.structural import PsBlock, PsDeclaration +from ..ast.structural import PsBlock, PsDeclaration, PsStructuralNode from ..ast.expressions import ( PsExpression, PsConstantExpr, @@ -36,6 +36,7 @@ from ..ast.expressions import ( ) from ..ast.vector import PsVecBroadcast from ..ast.util import AstEqWrapper +from ..exceptions import PsInternalCompilerError from ..constants import PsConstant from ..memory import PsSymbol @@ -138,6 +139,11 @@ class EliminateConstants: node = self.visit(node, ecc) if ecc.extractions: + if not isinstance(node, PsStructuralNode): + raise PsInternalCompilerError( + f"Cannot extract constant expressions from outermost node {node}" + ) + prepend_decls = [ PsDeclaration(PsExpression.make(symb), expr) for symb, expr in ecc.extractions diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py index f0e4cc9f19f1a046125bb3e8aab5302a9df2790c..f7fe81ad736981bee6f38427fbd4face73f0c455 100644 --- a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py +++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py @@ -2,7 +2,7 @@ from typing import cast from ..kernelcreation import KernelCreationContext from ..ast import PsAstNode -from ..ast.structural import PsBlock, PsLoop, PsConditional, PsDeclaration, PsAssignment +from ..ast.structural import PsBlock, PsLoop, PsConditional, PsDeclaration, PsAssignment, PsStructuralNode from ..ast.expressions import ( PsExpression, PsSymbolExpr, @@ -99,7 +99,7 @@ class HoistLoopInvariantDeclarations: return temp_block case PsBlock(statements): - statements_new: list[PsAstNode] = [] + statements_new: list[PsStructuralNode] = [] for stmt in statements: if isinstance(stmt, PsLoop): loop = stmt @@ -153,7 +153,7 @@ class HoistLoopInvariantDeclarations: return case PsBlock(statements): - statements_new: list[PsAstNode] = [] + statements_new: list[PsStructuralNode] = [] for stmt in statements: if isinstance(stmt, PsLoop): loop = stmt @@ -178,7 +178,7 @@ class HoistLoopInvariantDeclarations: This method processes only statements of the given block, and any blocks directly nested inside it. It does not descend into control structures like conditionals and nested loops. """ - statements_new: list[PsAstNode] = [] + statements_new: list[PsStructuralNode] = [] for node in block.statements: if isinstance(node, PsDeclaration): diff --git a/src/pystencils/backend/transformations/rewrite.py b/src/pystencils/backend/transformations/rewrite.py index 59241c295f42eeaf60f4cd03a5138214fdbd6c50..8dff9e45ec283fc6c3712c2e77ff56a9b2aaeae5 100644 --- a/src/pystencils/backend/transformations/rewrite.py +++ b/src/pystencils/backend/transformations/rewrite.py @@ -2,7 +2,7 @@ from typing import overload from ..memory import PsSymbol from ..ast import PsAstNode -from ..ast.structural import PsBlock +from ..ast.structural import PsStructuralNode, PsBlock from ..ast.expressions import PsExpression, PsSymbolExpr @@ -18,6 +18,13 @@ def substitute_symbols( pass +@overload +def substitute_symbols( + node: PsStructuralNode, subs: dict[PsSymbol, PsExpression] +) -> PsStructuralNode: + pass + + @overload def substitute_symbols( node: PsAstNode, subs: dict[PsSymbol, PsExpression] diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py index cc3411249d0fd95a68ca51f8761c3096bc09f2d2..1e88b29721b2e12740f48ab8db2b0ca81f4c634a 100644 --- a/src/pystencils/codegen/driver.py +++ b/src/pystencils/codegen/driver.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import cast, Sequence, Iterable, Callable, TYPE_CHECKING +from typing import cast, Sequence, Callable, TYPE_CHECKING from dataclasses import dataclass, replace from .target import Target @@ -15,8 +15,7 @@ from .config import ( from .kernel import Kernel, GpuKernel from .properties import PsSymbolProperty, FieldBasePtr from .parameters import Parameter -from ..backend.functions import PsReductionFunction, ReductionFunctions -from ..backend.ast.expressions import PsSymbolExpr, PsCall, PsMemAcc, PsConstantExpr +from .functions import Lambda from .gpu_indexing import GpuIndexing, GpuLaunchConfiguration from ..field import Field @@ -24,6 +23,8 @@ from ..types import PsIntegerType, PsScalarType from ..backend.memory import PsSymbol from ..backend.ast import PsAstNode +from ..backend.functions import PsReductionFunction, ReductionFunctions +from ..backend.ast.expressions import PsExpression, PsSymbolExpr, PsCall, PsMemAcc, PsConstantExpr from ..backend.ast.structural import PsBlock, PsLoop, PsDeclaration, PsAssignment from ..backend.ast.analysis import collect_undefined_symbols, collect_required_headers from ..backend.kernelcreation import ( @@ -220,20 +221,20 @@ class DefaultKernelCreationDriver: canonicalize = CanonicalizeSymbols(self._ctx, True) kernel_ast = cast(PsBlock, canonicalize(kernel_ast)) - if self._target.is_cpu(): - return create_cpu_kernel_function( - self._ctx, + kernel_factory = KernelFactory(self._ctx) + + if self._target.is_cpu() or self._target == Target.SYCL: + return kernel_factory.create_generic_kernel( self._platform, kernel_ast, self._cfg.get_option("function_name"), self._target, self._cfg.get_jit(), ) - else: + elif self._target.is_gpu(): assert self._gpu_indexing is not None - return create_gpu_kernel_function( - self._ctx, + return kernel_factory.create_gpu_kernel( self._platform, kernel_ast, self._cfg.get_option("function_name"), @@ -241,6 +242,8 @@ class DefaultKernelCreationDriver: self._cfg.get_jit(), self._gpu_indexing.get_launch_config_factory(), ) + else: + assert False, "unexpected target" def parse_kernel_body( self, @@ -451,23 +454,11 @@ class DefaultKernelCreationDriver: f"No platform is currently available for CPU target {self._target}" ) - elif Target._GPU in self._target: + elif self._target.is_gpu(): gpu_opts = self._cfg.gpu omit_range_check: bool = gpu_opts.get_option("omit_range_check") match self._target: - case Target.SYCL: - from ..backend.platforms import SyclPlatform - - auto_block_size: bool = self._cfg.sycl.get_option( - "automatic_block_size" - ) - - return SyclPlatform( - self._ctx, - omit_range_check=omit_range_check, - automatic_block_size=auto_block_size, - ) case Target.CUDA: from ..backend.platforms import CudaPlatform @@ -482,89 +473,102 @@ class DefaultKernelCreationDriver: omit_range_check=omit_range_check, thread_mapping=thread_mapping, ) + elif self._target == Target.SYCL: + from ..backend.platforms import SyclPlatform + + auto_block_size: bool = self._cfg.sycl.get_option("automatic_block_size") + omit_range_check = self._cfg.gpu.get_option("omit_range_check") + + return SyclPlatform( + self._ctx, + omit_range_check=omit_range_check, + automatic_block_size=auto_block_size, + ) raise NotImplementedError( f"Code generation for target {self._target} not implemented" ) -def create_cpu_kernel_function( - ctx: KernelCreationContext, - platform: Platform, - body: PsBlock, - function_name: str, - target_spec: Target, - jit: JitBase, -) -> Kernel: - undef_symbols = collect_undefined_symbols(body) - - params = _get_function_params(ctx, undef_symbols) - req_headers = _get_headers(ctx, platform, body) - - kfunc = Kernel(body, target_spec, function_name, params, req_headers, jit) - kfunc.metadata.update(ctx.metadata) - return kfunc - - -def create_gpu_kernel_function( - ctx: KernelCreationContext, - platform: Platform, - body: PsBlock, - function_name: str, - target_spec: Target, - jit: JitBase, - launch_config_factory: Callable[[], GpuLaunchConfiguration], -) -> GpuKernel: - undef_symbols = collect_undefined_symbols(body) - - params = _get_function_params(ctx, undef_symbols) - req_headers = _get_headers(ctx, platform, body) - - kfunc = GpuKernel( - body, - target_spec, - function_name, - params, - req_headers, - jit, - launch_config_factory, - ) - kfunc.metadata.update(ctx.metadata) - return kfunc - - -def _symbol_to_param(ctx: KernelCreationContext, symbol: PsSymbol): - from pystencils.backend.memory import BufferBasePtr, BackendPrivateProperty - - props: set[PsSymbolProperty] = set() - for prop in symbol.properties: - match prop: - case BufferBasePtr(buf): - field = ctx.find_field(buf.name) - props.add(FieldBasePtr(field)) - case BackendPrivateProperty(): - pass - case _: - props.add(prop) - - return Parameter(symbol.name, symbol.get_dtype(), props) - - -def _get_function_params( - ctx: KernelCreationContext, symbols: Iterable[PsSymbol] -) -> list[Parameter]: - params: list[Parameter] = [_symbol_to_param(ctx, s) for s in symbols] - params.sort(key=lambda p: p.name) - return params - - -def _get_headers( - ctx: KernelCreationContext, platform: Platform, body: PsBlock -) -> set[str]: - req_headers = collect_required_headers(body) - req_headers |= platform.required_headers - req_headers |= ctx.required_headers - return req_headers +class KernelFactory: + """Factory for wrapping up backend and IR objects into exportable kernels and function objects.""" + + def __init__(self, ctx: KernelCreationContext): + self._ctx = ctx + + def create_lambda(self, expr: PsExpression) -> Lambda: + """Create a Lambda from an expression.""" + params = self._get_function_params(expr) + return Lambda(expr, params) + + def create_generic_kernel( + self, + platform: Platform, + body: PsBlock, + function_name: str, + target_spec: Target, + jit: JitBase, + ) -> Kernel: + """Create a kernel for a generic target""" + params = self._get_function_params(body) + req_headers = self._get_headers(platform, body) + + kfunc = Kernel(body, target_spec, function_name, params, req_headers, jit) + kfunc.metadata.update(self._ctx.metadata) + return kfunc + + def create_gpu_kernel( + self, + platform: Platform, + body: PsBlock, + function_name: str, + target_spec: Target, + jit: JitBase, + launch_config_factory: Callable[[], GpuLaunchConfiguration], + ) -> GpuKernel: + """Create a kernel for a GPU target""" + params = self._get_function_params(body) + req_headers = self._get_headers(platform, body) + + kfunc = GpuKernel( + body, + target_spec, + function_name, + params, + req_headers, + jit, + launch_config_factory, + ) + kfunc.metadata.update(self._ctx.metadata) + return kfunc + + def _symbol_to_param(self, symbol: PsSymbol): + from pystencils.backend.memory import BufferBasePtr, BackendPrivateProperty + + props: set[PsSymbolProperty] = set() + for prop in symbol.properties: + match prop: + case BufferBasePtr(buf): + field = self._ctx.find_field(buf.name) + props.add(FieldBasePtr(field)) + case BackendPrivateProperty(): + pass + case _: + props.add(prop) + + return Parameter(symbol.name, symbol.get_dtype(), props) + + def _get_function_params(self, ast: PsAstNode) -> list[Parameter]: + symbols = collect_undefined_symbols(ast) + params: list[Parameter] = [self._symbol_to_param(s) for s in symbols] + params.sort(key=lambda p: p.name) + return params + + def _get_headers(self, platform: Platform, body: PsBlock) -> set[str]: + req_headers = collect_required_headers(body) + req_headers |= platform.required_headers + req_headers |= self._ctx.required_headers + return req_headers @dataclass diff --git a/src/pystencils/codegen/functions.py b/src/pystencils/codegen/functions.py index f6be3b1f3446c6b9a25a0013f0e06d099edf5bed..c24dbaffb9947d68c854f83532e87386914c6677 100644 --- a/src/pystencils/codegen/functions.py +++ b/src/pystencils/codegen/functions.py @@ -4,21 +4,12 @@ from typing import Sequence, Any from .parameters import Parameter from ..types import PsType -from ..backend.kernelcreation import KernelCreationContext from ..backend.ast.expressions import PsExpression class Lambda: """A one-line function emitted by the code generator as an auxiliary object.""" - @staticmethod - def from_expression(ctx: KernelCreationContext, expr: PsExpression): - from ..backend.ast.analysis import collect_undefined_symbols - from .driver import _get_function_params - - params = _get_function_params(ctx, collect_undefined_symbols(expr)) - return Lambda(expr, params) - def __init__(self, expr: PsExpression, params: Sequence[Parameter]): self._expr = expr self._params = tuple(params) diff --git a/src/pystencils/codegen/gpu_indexing.py b/src/pystencils/codegen/gpu_indexing.py index 2d22ec624856d9cf8a0b825845fee04caaa4ee74..27d6fc817d5a9193c3faa4b170d907987fe6022e 100644 --- a/src/pystencils/codegen/gpu_indexing.py +++ b/src/pystencils/codegen/gpu_indexing.py @@ -228,8 +228,10 @@ class GpuIndexing: self._manual_launch_grid = manual_launch_grid from ..backend.kernelcreation import AstFactory + from .driver import KernelFactory - self._factory = AstFactory(self._ctx) + self._ast_factory = AstFactory(self._ctx) + self._kernel_factory = KernelFactory(self._ctx) def get_thread_mapping(self) -> ThreadMapping: """Retrieve a thread mapping object for use by the backend""" @@ -265,9 +267,14 @@ class GpuIndexing: f" for a {rank}-dimensional kernel." ) + work_items_expr += tuple( + self._ast_factory.parse_index(1) + for _ in range(3 - rank) + ) + num_work_items = cast( _Dim3Lambda, - tuple(Lambda.from_expression(self._ctx, wit) for wit in work_items_expr), + tuple(self._kernel_factory.create_lambda(wit) for wit in work_items_expr), ) def factory(): @@ -305,15 +312,15 @@ class GpuIndexing: raise ValueError(f"Iteration space rank is too large: {rank}") block_size = ( - Lambda.from_expression(self._ctx, work_items[0]), - Lambda.from_expression(self._ctx, self._factory.parse_index(1)), - Lambda.from_expression(self._ctx, self._factory.parse_index(1)), + self._kernel_factory.create_lambda(work_items[0]), + self._kernel_factory.create_lambda(self._ast_factory.parse_index(1)), + self._kernel_factory.create_lambda(self._ast_factory.parse_index(1)), ) grid_size = tuple( - Lambda.from_expression(self._ctx, wit) for wit in work_items[1:] + self._kernel_factory.create_lambda(wit) for wit in work_items[1:] ) + tuple( - Lambda.from_expression(self._ctx, self._factory.parse_index(1)) + self._kernel_factory.create_lambda(self._ast_factory.parse_index(1)) for _ in range(4 - rank) ) @@ -350,7 +357,7 @@ class GpuIndexing: return tuple(ispace.actual_iterations(dim) for dim in dimensions) case SparseIterationSpace(): - return (self._factory.parse_index(ispace.index_list.shape[0]),) + return (self._ast_factory.parse_index(ispace.index_list.shape[0]),) case _: assert False, "unexpected iteration space" diff --git a/src/pystencils/codegen/target.py b/src/pystencils/codegen/target.py index b847a8139a8725c9c926b7c12c9556aba3ec6e87..0d724b87730f0ec327772bccbb55a8bfff7c8ddd 100644 --- a/src/pystencils/codegen/target.py +++ b/src/pystencils/codegen/target.py @@ -89,10 +89,13 @@ class Target(Flag): GPU = CUDA """Alias for `Target.CUDA`, for backward compatibility.""" - SYCL = _GPU | _SYCL + SYCL = _SYCL """SYCL kernel target. Generate a function to be called within a SYCL parallel command. + + .. note:: + The SYCL target is experimental and not thoroughly tested yet. """ def is_automatic(self) -> bool: diff --git a/src/pystencils/include/PyStencilsField.h b/src/pystencils/include/PyStencilsField.h deleted file mode 100644 index 3055cae2365279e28fdcaab4353779b97ca27d35..0000000000000000000000000000000000000000 --- a/src/pystencils/include/PyStencilsField.h +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once - -extern "C++" { -#ifdef __CUDA_ARCH__ -template <typename DTYPE_T, std::size_t DIMENSION> struct PyStencilsField { - DTYPE_T *data; - DTYPE_T shape[DIMENSION]; - DTYPE_T stride[DIMENSION]; -}; -#else -#include <array> - -template <typename DTYPE_T, std::size_t DIMENSION> struct PyStencilsField { - DTYPE_T *data; - std::array<DTYPE_T, DIMENSION> shape; - std::array<DTYPE_T, DIMENSION> stride; -}; -#endif -} diff --git a/src/pystencils/include/half_precision.h b/src/pystencils/include/pystencils_runtime/half.h similarity index 100% rename from src/pystencils/include/half_precision.h rename to src/pystencils/include/pystencils_runtime/half.h diff --git a/src/pystencils/include/gpu_defines.h b/src/pystencils/include/pystencils_runtime/hip.h similarity index 95% rename from src/pystencils/include/gpu_defines.h rename to src/pystencils/include/pystencils_runtime/hip.h index 34cff79dea2f14399622a0026e362a4832bd739c..b0b4d967911688bc302537110e54ea0901e661b8 100644 --- a/src/pystencils/include/gpu_defines.h +++ b/src/pystencils/include/pystencils_runtime/hip.h @@ -1,11 +1,5 @@ #pragma once -#define POS_INFINITY __int_as_float(0x7f800000) -#define NEG_INFINITY __int_as_float(0xff800000) -#ifndef INFINITY -#define INFINITY POS_INFINITY -#endif - #ifdef __HIPCC_RTC__ typedef __hip_uint8_t uint8_t; typedef __hip_int8_t int8_t; diff --git a/src/pystencils/jit/gpu_cupy.py b/src/pystencils/jit/gpu_cupy.py index 4ea991e28cbadc68c81129bcff1a7c02d689bf07..4c3c8945e34b55cd557c180397d72086f766bf7e 100644 --- a/src/pystencils/jit/gpu_cupy.py +++ b/src/pystencils/jit/gpu_cupy.py @@ -252,8 +252,8 @@ class CupyJit(JitBase): headers = self._runtime_headers headers |= kfunc.required_headers - if '"half_precision.h"' in headers: - headers.remove('"half_precision.h"') + if '"pystencils_runtime/half.h"' in headers: + headers.remove('"pystencils_runtime/half.h"') if cp.cuda.runtime.is_hip: headers.add("<hip/hip_fp16.h>") else: diff --git a/src/pystencils/types/types.py b/src/pystencils/types/types.py index 825ac1d5d35fde0f26c5a9ebadb55ec43004c9ae..8dea97ca43539b966e260ac7a206b0e26b3b2110 100644 --- a/src/pystencils/types/types.py +++ b/src/pystencils/types/types.py @@ -661,7 +661,7 @@ class PsIeeeFloatType(PsScalarType): @property def required_headers(self) -> set[str]: if self._width == 16: - return {'"half_precision.h"'} + return {'"pystencils_runtime/half.h"'} else: return set() @@ -672,7 +672,7 @@ class PsIeeeFloatType(PsScalarType): match self.width: case 16: - return f"((half) {value})" # see include/half_precision.h + return f"((half) {value})" # see include/pystencils_runtime/half.h case 32: return f"{value}f" case 64: diff --git a/tests/kernelcreation/test_gpu.py b/tests/kernelcreation/test_gpu.py index 10b37e610cebd23c9fc961f14118aee5f24582c4..f1905b1fcb7c7406f43cfb94af2928b6f35bc3f8 100644 --- a/tests/kernelcreation/test_gpu.py +++ b/tests/kernelcreation/test_gpu.py @@ -31,7 +31,7 @@ except ImportError: @pytest.mark.parametrize("indexing_scheme", ["linear3d", "blockwise4d"]) @pytest.mark.parametrize("omit_range_check", [False, True]) @pytest.mark.parametrize("manual_grid", [False, True]) -def test_indexing_options( +def test_indexing_options_3d( indexing_scheme: str, omit_range_check: bool, manual_grid: bool ): src, dst = fields("src, dst: [3D]") @@ -76,6 +76,52 @@ def test_indexing_options( cp.testing.assert_allclose(dst_arr, expected) +@pytest.mark.parametrize("indexing_scheme", ["linear3d", "blockwise4d"]) +@pytest.mark.parametrize("omit_range_check", [False, True]) +@pytest.mark.parametrize("manual_grid", [False, True]) +def test_indexing_options_2d( + indexing_scheme: str, omit_range_check: bool, manual_grid: bool +): + src, dst = fields("src, dst: [2D]") + asm = Assignment( + dst.center(), + src[-1, 0] + + src[1, 0] + + src[0, -1] + + src[0, 1] + ) + + cfg = CreateKernelConfig(target=Target.CUDA) + cfg.gpu.indexing_scheme = indexing_scheme + cfg.gpu.omit_range_check = omit_range_check + cfg.gpu.manual_launch_grid = manual_grid + + ast = create_kernel(asm, cfg) + kernel = ast.compile() + + src_arr = cp.ones((18, 42)) + dst_arr = cp.zeros_like(src_arr) + + if manual_grid: + match indexing_scheme: + case "linear3d": + kernel.launch_config.block_size = (10, 8, 1) + kernel.launch_config.grid_size = (4, 2, 1) + case "blockwise4d": + kernel.launch_config.block_size = (40, 1, 1) + kernel.launch_config.grid_size = (16, 1, 1) + + elif indexing_scheme == "linear3d": + kernel.launch_config.block_size = (10, 8, 1) + + kernel(src=src_arr, dst=dst_arr) + + expected = cp.zeros_like(src_arr) + expected[1:-1, 1:-1].fill(4.0) + + cp.testing.assert_allclose(dst_arr, expected) + + def test_invalid_indexing_schemes(): src, dst = fields("src, dst: [4D]") asm = Assignment(src.center(0), dst.center(0)) diff --git a/tests/kernelcreation/test_sycl_codegen.py b/tests/kernelcreation/test_sycl_codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..b6907c9965b4b80a5b97865063170dfdc3654615 --- /dev/null +++ b/tests/kernelcreation/test_sycl_codegen.py @@ -0,0 +1,45 @@ +""" +Since we don't have a JIT compiler for SYCL, these tests can only +perform dry-dock testing. +If the SYCL target should ever become non-experimental, we need to +find a way to properly test SYCL kernels in execution. + +These tests primarily check that the code generation driver runs +successfully for the SYCL target. +""" + +import sympy as sp +from pystencils import ( + create_kernel, + Target, + fields, + Assignment, + CreateKernelConfig, +) + + +def test_sycl_kernel_static(): + src, dst = fields("src, dst: [2D]") + asm = Assignment(dst.center(), sp.sin(src.center()) + sp.cos(src.center())) + + cfg = CreateKernelConfig(target=Target.SYCL) + kernel = create_kernel(asm, cfg) + + code_string = kernel.get_c_code() + + assert "sycl::id< 2 >" in code_string + assert "sycl::sin(" in code_string + assert "sycl::cos(" in code_string + + +def test_sycl_kernel_manual_block_size(): + src, dst = fields("src, dst: [2D]") + asm = Assignment(dst.center(), sp.sin(src.center()) + sp.cos(src.center())) + + cfg = CreateKernelConfig(target=Target.SYCL) + cfg.sycl.automatic_block_size = False + kernel = create_kernel(asm, cfg) + + code_string = kernel.get_c_code() + + assert "sycl::nd_item< 2 >" in code_string diff --git a/tests/nbackend/test_vectorization.py b/tests/nbackend/test_vectorization.py index b60dc24774566d67eaa271c6ab775374746d89cf..fecade65d97afcaae4382bcc2ced119b2a957bed 100644 --- a/tests/nbackend/test_vectorization.py +++ b/tests/nbackend/test_vectorization.py @@ -20,7 +20,7 @@ from pystencils.backend.transformations import ( LowerToC, ) from pystencils.backend.constants import PsConstant -from pystencils.codegen.driver import create_cpu_kernel_function +from pystencils.codegen.driver import KernelFactory from pystencils.jit import LegacyCpuJit from pystencils import Target, fields, Assignment, Field from pystencils.field import create_numpy_array_with_layout @@ -135,8 +135,8 @@ def create_vector_kernel( lower = LowerToC(ctx) loop_nest = lower(loop_nest) - func = create_cpu_kernel_function( - ctx, + kfactory = KernelFactory(ctx) + func = kfactory.create_generic_kernel( platform, PsBlock([loop_nest]), "vector_kernel",