diff --git a/mypy.ini b/mypy.ini index 8e9fe08334a08d8ab0a272f114d4a719d81398ed..07228fe24009da6ea4f21cb6cdf15a0516041149 100644 --- a/mypy.ini +++ b/mypy.ini @@ -16,3 +16,6 @@ ignore_missing_imports=true [mypy-appdirs.*] ignore_missing_imports=true + +[mypy-islpy.*] +ignore_missing_imports=true diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py index 61016e14f11a536444c798441ba8be516d97a167..3d3b7846a84bb9c477b38e90839d7f67fe12933c 100644 --- a/src/pystencils/__init__.py +++ b/src/pystencils/__init__.py @@ -6,8 +6,14 @@ from . import fd from . import stencil as stencil from .display_utils import get_code_obj, get_code_str, show_code, to_dot from .field import Field, FieldType, fields +from .types import create_type from .cache import clear_cache -from .config import CreateKernelConfig, CpuOptimConfig, VectorizationConfig +from .config import ( + CreateKernelConfig, + CpuOptimConfig, + VectorizationConfig, + OpenMpConfig, +) from .kernel_decorator import kernel, kernel_config from .kernelcreation import create_kernel from .backend.kernelfunction import KernelFunction @@ -34,10 +40,12 @@ __all__ = [ "fields", "DEFAULTS", "TypedSymbol", + "create_type", "make_slice", "CreateKernelConfig", "CpuOptimConfig", "VectorizationConfig", + "OpenMpConfig", "create_kernel", "KernelFunction", "Target", diff --git a/src/pystencils/backend/ast/analysis.py b/src/pystencils/backend/ast/analysis.py index 040c6167827dd48c97bc289c0959d937bbbefb38..15ee0680edb0b5b8197aec2545d182f90eb6c71a 100644 --- a/src/pystencils/backend/ast/analysis.py +++ b/src/pystencils/backend/ast/analysis.py @@ -5,7 +5,7 @@ from .structural import ( PsAssignment, PsAstNode, PsBlock, - PsComment, + PsEmptyLeafMixIn, PsConditional, PsDeclaration, PsExpression, @@ -63,7 +63,7 @@ class UndefinedSymbolsCollector: undefined_vars |= self(branch_false) return undefined_vars - case PsComment(): + case PsEmptyLeafMixIn(): return set() case unknown: @@ -92,11 +92,11 @@ class UndefinedSymbolsCollector: 
case ( PsAssignment() | PsBlock() - | PsComment() | PsConditional() | PsExpression() | PsLoop() | PsStatement() + | PsEmptyLeafMixIn() ): return set() diff --git a/src/pystencils/backend/ast/iteration.py b/src/pystencils/backend/ast/iteration.py index 6c1c406ed4602ed98416816886686fe29324975f..cc666c72257c5a434c0b312c9e8d8e5ea9dc028f 100644 --- a/src/pystencils/backend/ast/iteration.py +++ b/src/pystencils/backend/ast/iteration.py @@ -4,32 +4,32 @@ from .structural import PsAstNode def dfs_preorder( - node: PsAstNode, yield_pred: Callable[[PsAstNode], bool] = lambda _: True + node: PsAstNode, filter_pred: Callable[[PsAstNode], bool] = lambda _: True ) -> Generator[PsAstNode, None, None]: """Pre-Order depth-first traversal of an abstract syntax tree. Args: node: The tree's root node - yield_pred: Filter predicate; a node is only yielded to the caller if `yield_pred(node)` returns True + filter_pred: Filter predicate; a node is only returned to the caller if `filter_pred(node)` returns True """ - if yield_pred(node): + if filter_pred(node): yield node for c in node.children: - yield from dfs_preorder(c, yield_pred) + yield from dfs_preorder(c, filter_pred) def dfs_postorder( - node: PsAstNode, yield_pred: Callable[[PsAstNode], bool] = lambda _: True + node: PsAstNode, filter_pred: Callable[[PsAstNode], bool] = lambda _: True ) -> Generator[PsAstNode, None, None]: """Post-Order depth-first traversal of an abstract syntax tree. 
Args: node: The tree's root node - yield_pred: Filter predicate; a node is only yielded to the caller if `yield_pred(node)` returns True + filter_pred: Filter predicate; a node is only returned to the caller if `filter_pred(node)` returns True """ for c in node.children: - yield from dfs_postorder(c, yield_pred) + yield from dfs_postorder(c, filter_pred) - if yield_pred(node): + if filter_pred(node): yield node diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py index 47342cfedcb7c748f8c460ae99512ec64800ce61..cd3aae30d35061ab6c15c338a735aaecca83a141 100644 --- a/src/pystencils/backend/ast/structural.py +++ b/src/pystencils/backend/ast/structural.py @@ -307,7 +307,42 @@ class PsConditional(PsAstNode): assert False, "unreachable code" -class PsComment(PsLeafMixIn, PsAstNode): +class PsEmptyLeafMixIn: + """Mix-in marking AST leaves that can be treated as empty by the code generator, + such as comments and preprocessor directives.""" + + pass + + +class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode): + """A C/C++ preprocessor pragma. + + Example usage: ``PsPragma("omp parallel for")`` translates to ``#pragma omp parallel for``. + + Args: + text: The pragma's text, without the ``#pragma ``. 
+ """ + + __match_args__ = ("text",) + + def __init__(self, text: str) -> None: + self._text = text + + @property + def text(self) -> str: + return self._text + + def clone(self) -> PsPragma: + return PsPragma(self.text) + + def structurally_equal(self, other: PsAstNode) -> bool: + if not isinstance(other, PsPragma): + return False + + return self._text == other._text + + +class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode): __match_args__ = ("lines",) def __init__(self, text: str) -> None: diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py index b5f8a43549d989c568201ef43873eeca850bd966..e8fc2a662b2f49c43fe51d021915b9e44cc59ad3 100644 --- a/src/pystencils/backend/emission.py +++ b/src/pystencils/backend/emission.py @@ -10,6 +10,7 @@ from .ast.structural import ( PsLoop, PsConditional, PsComment, + PsPragma, ) from .ast.expressions import ( @@ -235,6 +236,9 @@ class CAstPrinter: lines_list[-1] = lines_list[-1] + " */" return pc.indent("\n".join(lines_list)) + case PsPragma(text): + return pc.indent("#pragma " + text) + case PsSymbolExpr(symbol): return symbol.name diff --git a/src/pystencils/backend/kernelcreation/cpu_optimization.py b/src/pystencils/backend/kernelcreation/cpu_optimization.py index b0156c7e8ce0b9cac2c6f3be9f60bbeef41e1c51..29b133ff164e856783f14eb83357c8382db9ba5d 100644 --- a/src/pystencils/backend/kernelcreation/cpu_optimization.py +++ b/src/pystencils/backend/kernelcreation/cpu_optimization.py @@ -3,10 +3,9 @@ from typing import cast from .context import KernelCreationContext from ..platforms import GenericCpu -from ..transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations from ..ast.structural import PsBlock -from ...config import CpuOptimConfig +from ...config import CpuOptimConfig, OpenMpConfig def optimize_cpu( @@ -16,6 +15,7 @@ def optimize_cpu( cfg: CpuOptimConfig | None, ) -> PsBlock: """Carry out CPU-specific optimizations according to the given configuration.""" + from 
..transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations canonicalize = CanonicalizeSymbols(ctx, True) kernel_ast = cast(PsBlock, canonicalize(kernel_ast)) @@ -32,8 +32,12 @@ def optimize_cpu( if cfg.vectorize is not False: raise NotImplementedError("Vectorization not implemented yet") - if cfg.openmp: - raise NotImplementedError("OpenMP not implemented yet") + if cfg.openmp is not False: + from ..transformations import AddOpenMP + + params = cfg.openmp if isinstance(cfg.openmp, OpenMpConfig) else OpenMpConfig() + add_omp = AddOpenMP(ctx, params) + kernel_ast = cast(PsBlock, add_omp(kernel_ast)) if cfg.use_cacheline_zeroing: raise NotImplementedError("CL-zeroing not implemented yet") diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 7e628631dba164ab8f1c756961f1f1112437f36f..06e34d4e36219366a25713d28ae758b61fb3d0d6 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -22,7 +22,7 @@ from ..ast.structural import ( PsExpression, PsAssignment, PsDeclaration, - PsComment, + PsEmptyLeafMixIn, ) from ..ast.expressions import ( PsArrayAccess, @@ -336,7 +336,7 @@ class Typifier: self.visit(body) - case PsComment(): + case PsEmptyLeafMixIn(): pass case _: diff --git a/src/pystencils/backend/transformations/__init__.py b/src/pystencils/backend/transformations/__init__.py index 518c402d27e8828cc129c06a20bd10e2ba3d3168..88ad9348f09685258d2aecb5fca66fcfe609173b 100644 --- a/src/pystencils/backend/transformations/__init__.py +++ b/src/pystencils/backend/transformations/__init__.py @@ -60,6 +60,11 @@ Loop Reshaping Transformations .. autoclass:: ReshapeLoops :members: +.. autoclass:: InsertPragmasAtLoops + :members: + +.. 
autoclass:: AddOpenMP + :members: Code Lowering and Materialization --------------------------------- @@ -78,6 +83,7 @@ from .eliminate_constants import EliminateConstants from .eliminate_branches import EliminateBranches from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations from .reshape_loops import ReshapeLoops +from .add_pragmas import InsertPragmasAtLoops, LoopPragma, AddOpenMP from .erase_anonymous_structs import EraseAnonymousStructTypes from .select_functions import SelectFunctions from .select_intrinsics import MaterializeVectorIntrinsics @@ -89,6 +95,9 @@ __all__ = [ "EliminateBranches", "HoistLoopInvariantDeclarations", "ReshapeLoops", + "InsertPragmasAtLoops", + "LoopPragma", + "AddOpenMP", "EraseAnonymousStructTypes", "SelectFunctions", "MaterializeVectorIntrinsics", diff --git a/src/pystencils/backend/transformations/add_pragmas.py b/src/pystencils/backend/transformations/add_pragmas.py new file mode 100644 index 0000000000000000000000000000000000000000..c7015ccb62042e6beaf131e68485f7c8186b2a2a --- /dev/null +++ b/src/pystencils/backend/transformations/add_pragmas.py @@ -0,0 +1,118 @@ +from dataclasses import dataclass + +from typing import Sequence +from collections import defaultdict + +from ..kernelcreation import KernelCreationContext +from ..ast import PsAstNode +from ..ast.structural import PsBlock, PsLoop, PsPragma +from ..ast.expressions import PsExpression + +from ...config import OpenMpConfig + +__all__ = ["InsertPragmasAtLoops", "LoopPragma", "AddOpenMP"] + + +@dataclass +class LoopPragma: + """A pragma that should be prepended to loops at a certain nesting depth.""" + + text: str + """The pragma text, without the ``#pragma ``""" + + loop_nesting_depth: int + """Nesting depth of the loops the pragma should be added to. 
``-1`` indicates the innermost loops.""" + + def __post_init__(self): + if self.loop_nesting_depth < -1: + raise ValueError("Loop nesting depth must be nonnegative or -1.") + + +@dataclass +class Nesting: + depth: int + has_inner_loops: bool = False + + +class InsertPragmasAtLoops: + """Insert pragmas before loops in a loop nest. + + This transformation augments the AST with pragma directives which are prepended to loops. + The directives are annotated with the nesting depth of the loops they should be added to, + where ``-1`` indicates the innermost loop. + + The relative order of pragmas with the (exact) same nesting depth is preserved; + however, no guarantees are given about the relative order of pragmas inserted at ``-1`` + and at the actual depth of the innermost loop. + """ + + def __init__( + self, ctx: KernelCreationContext, insertions: Sequence[LoopPragma] + ) -> None: + self._ctx = ctx + self._insertions: dict[int, list[LoopPragma]] = defaultdict(list) + for ins in insertions: + self._insertions[ins.loop_nesting_depth].append(ins) + + def __call__(self, node: PsAstNode) -> PsAstNode: + is_loop = isinstance(node, PsLoop) + if is_loop: + node = PsBlock([node]) + + self.visit(node, Nesting(0)) + + if is_loop and len(node.children) == 1: + node = node.children[0] + + return node + + def visit(self, node: PsAstNode, nest: Nesting) -> None: + match node: + case PsExpression(): + return + + case PsBlock(children): + new_children: list[PsAstNode] = [] + for c in children: + if isinstance(c, PsLoop): + nest.has_inner_loops = True + inner_nest = Nesting(nest.depth + 1) + self.visit(c.body, inner_nest) + + if not inner_nest.has_inner_loops: + # c is the innermost loop + for pragma in self._insertions[-1]: + new_children.append(PsPragma(pragma.text)) + + for pragma in self._insertions[nest.depth]: + new_children.append(PsPragma(pragma.text)) + + new_children.append(c) + node.children = new_children + + case other: + for c in other.children: + self.visit(c, nest) + + 
+class AddOpenMP: + """Apply OpenMP directives to loop nests. + + This transformation augments the AST with OpenMP pragmas according to the given + `OpenMpConfig` configuration. + """ + + def __init__(self, ctx: KernelCreationContext, omp_params: OpenMpConfig) -> None: + pragma_text = "omp" + pragma_text += " parallel" if not omp_params.omit_parallel_construct else "" + pragma_text += f" for schedule({omp_params.schedule})" + + if omp_params.collapse > 0: + pragma_text += f" collapse({str(omp_params.collapse)})" + + self._insert_pragmas = InsertPragmasAtLoops( + ctx, [LoopPragma(pragma_text, omp_params.nesting_depth)] + ) + + def __call__(self, node: PsAstNode) -> PsAstNode: + return self._insert_pragmas(node) diff --git a/src/pystencils/backend/transformations/canonical_clone.py b/src/pystencils/backend/transformations/canonical_clone.py index 7c040d30471c9dbe413d287142199050c7c24a37..b21fd115f98645ff4c8dfb2dd3f72c252282fcf2 100644 --- a/src/pystencils/backend/transformations/canonical_clone.py +++ b/src/pystencils/backend/transformations/canonical_clone.py @@ -12,6 +12,7 @@ from ..ast.structural import ( PsDeclaration, PsAssignment, PsComment, + PsPragma, PsStatement, ) from ..ast.expressions import PsExpression, PsSymbolExpr @@ -73,7 +74,7 @@ class CanonicalClone: ), ) - case PsComment(): + case PsComment() | PsPragma(): return cast(Node_T, node.clone()) case PsDeclaration(lhs, rhs): diff --git a/src/pystencils/backend/transformations/eliminate_branches.py b/src/pystencils/backend/transformations/eliminate_branches.py index eab3d3722c30756ab39af072e75e9d6d89874447..f098d82df1ce6a748097756aa1616a72e57487b5 100644 --- a/src/pystencils/backend/transformations/eliminate_branches.py +++ b/src/pystencils/backend/transformations/eliminate_branches.py @@ -1,16 +1,44 @@ from ..kernelcreation import KernelCreationContext from ..ast import PsAstNode +from ..ast.analysis import collect_undefined_symbols from ..ast.structural import PsLoop, PsBlock, PsConditional -from 
..ast.expressions import PsConstantExpr +from ..ast.expressions import ( + PsAnd, + PsCast, + PsConstant, + PsConstantExpr, + PsDiv, + PsEq, + PsExpression, + PsGe, + PsGt, + PsIntDiv, + PsLe, + PsLt, + PsMul, + PsNe, + PsNeg, + PsNot, + PsOr, + PsSub, + PsSymbolExpr, + PsAdd, +) from .eliminate_constants import EliminateConstants +from ...types import PsBoolType, PsIntegerType __all__ = ["EliminateBranches"] +class IslAnalysisError(Exception): + """Indicates a fatal error during integer set analysis (based on islpy)""" + + class BranchElimContext: def __init__(self) -> None: self.enclosing_loops: list[PsLoop] = [] + self.enclosing_conditions: list[PsExpression] = [] class EliminateBranches: @@ -20,12 +48,16 @@ class EliminateBranches: This pass will attempt to evaluate branch conditions within their context in the AST, and replace conditionals by either their then- or their else-block if the branch is unequivocal. - TODO: If islpy is installed, this pass will incorporate information about the iteration regions - of enclosing loops into its analysis. + If islpy is installed, this pass will incorporate information about the iteration regions + of enclosing loops and enclosing conditionals into its analysis. 
+ + Args: + use_isl (bool, optional): enable islpy based analysis (default: True) """ - def __init__(self, ctx: KernelCreationContext) -> None: + def __init__(self, ctx: KernelCreationContext, use_isl: bool = True) -> None: self._ctx = ctx + self._use_isl = use_isl self._elim_constants = EliminateConstants(ctx, extract_constant_exprs=False) def __call__(self, node: PsAstNode) -> PsAstNode: @@ -41,20 +73,30 @@ class EliminateBranches: case PsBlock(statements): statements_new: list[PsAstNode] = [] for stmt in statements: - if isinstance(stmt, PsConditional): - result = self.handle_conditional(stmt, ec) - if result is not None: - statements_new.append(result) - else: - statements_new.append(self.visit(stmt, ec)) + statements_new.append(self.visit(stmt, ec)) node.statements = statements_new case PsConditional(): result = self.handle_conditional(node, ec) - if result is None: - return PsBlock([]) - else: - return result + + match result: + case PsConditional(_, branch_true, branch_false): + ec.enclosing_conditions.append(result.condition) + self.visit(branch_true, ec) + ec.enclosing_conditions.pop() + + if branch_false is not None: + ec.enclosing_conditions.append(PsNot(result.condition)) + self.visit(branch_false, ec) + ec.enclosing_conditions.pop() + case PsBlock(): + self.visit(result, ec) + case None: + result = PsBlock([]) + case _: + assert False, "unreachable code" + + return result return node @@ -62,12 +104,124 @@ class EliminateBranches: self, conditional: PsConditional, ec: BranchElimContext ) -> PsConditional | PsBlock | None: condition_simplified = self._elim_constants(conditional.condition) + if self._use_isl: + condition_simplified = self._isl_simplify_condition( + condition_simplified, ec + ) + match condition_simplified: case PsConstantExpr(c) if c.value: return conditional.branch_true case PsConstantExpr(c) if not c.value: return conditional.branch_false - # TODO: Analyze condition against counters of enclosing loops using ISL - return conditional + + 
def _isl_simplify_condition( + self, condition: PsExpression, ec: BranchElimContext + ) -> PsExpression: + """If installed, use ISL to simplify the passed condition to true or + false based on enclosing loops and conditionals. If no simplification + can be made or ISL is not installed, the original condition is returned. + """ + + try: + import islpy as isl + except ImportError: + return condition + + def printer(expr: PsExpression): + match expr: + case PsSymbolExpr(symbol): + return symbol.name + + case PsConstantExpr(constant): + dtype = constant.get_dtype() + if not isinstance(dtype, (PsIntegerType, PsBoolType)): + raise IslAnalysisError( + "Only scalar integer and bool constant may appear in isl expressions." + ) + return str(constant.value) + + case PsAdd(op1, op2): + return f"({printer(op1)} + {printer(op2)})" + case PsSub(op1, op2): + return f"({printer(op1)} - {printer(op2)})" + case PsMul(op1, op2): + return f"({printer(op1)} * {printer(op2)})" + case PsDiv(op1, op2) | PsIntDiv(op1, op2): + return f"({printer(op1)} / {printer(op2)})" + case PsAnd(op1, op2): + return f"({printer(op1)} and {printer(op2)})" + case PsOr(op1, op2): + return f"({printer(op1)} or {printer(op2)})" + case PsEq(op1, op2): + return f"({printer(op1)} = {printer(op2)})" + case PsNe(op1, op2): + return f"({printer(op1)} != {printer(op2)})" + case PsGt(op1, op2): + return f"({printer(op1)} > {printer(op2)})" + case PsGe(op1, op2): + return f"({printer(op1)} >= {printer(op2)})" + case PsLt(op1, op2): + return f"({printer(op1)} < {printer(op2)})" + case PsLe(op1, op2): + return f"({printer(op1)} <= {printer(op2)})" + + case PsNeg(operand): + return f"(-{printer(operand)})" + case PsNot(operand): + return f"(not {printer(operand)})" + + case PsCast(_, operand): + return printer(operand) + + case _: + raise IslAnalysisError( + f"Not supported by isl or don't know how to print {expr}" + ) + + dofs = collect_undefined_symbols(condition) + outer_conditions = [] + + for loop in 
ec.enclosing_loops: + if not ( + isinstance(loop.step, PsConstantExpr) + and loop.step.constant.value == 1 + ): + raise IslAnalysisError( + "Loops with strides != 1 are not yet supported." + ) + + dofs.add(loop.counter.symbol) + dofs.update(collect_undefined_symbols(loop.start)) + dofs.update(collect_undefined_symbols(loop.stop)) + + loop_start_str = printer(loop.start) + loop_stop_str = printer(loop.stop) + ctr_name = loop.counter.symbol.name + outer_conditions.append( + f"{ctr_name} >= {loop_start_str} and {ctr_name} < {loop_stop_str}" + ) + + for cond in ec.enclosing_conditions: + dofs.update(collect_undefined_symbols(cond)) + outer_conditions.append(printer(cond)) + + dofs_str = ",".join(dof.name for dof in dofs) + outer_conditions_str = " and ".join(outer_conditions) + condition_str = printer(condition) + + outer_set = isl.BasicSet(f"{{ [{dofs_str}] : {outer_conditions_str} }}") + inner_set = isl.BasicSet(f"{{ [{dofs_str}] : {condition_str} }}") + + if inner_set.is_empty(): + return PsExpression.make(PsConstant(False)) + + intersection = outer_set.intersect(inner_set) + if intersection.is_empty(): + return PsExpression.make(PsConstant(False)) + elif intersection == outer_set: + return PsExpression.make(PsConstant(True)) + else: + return condition diff --git a/src/pystencils/config.py b/src/pystencils/config.py index 7b0ec590d33ad83d3618e110eae9921b2e52830c..8745dbfa73055142dbf7b145685ca89d774bd299 100644 --- a/src/pystencils/config.py +++ b/src/pystencils/config.py @@ -15,19 +15,39 @@ from .types import PsIntegerType, PsNumericType, PsIeeeFloatType from .defaults import DEFAULTS +@dataclass +class OpenMpConfig: + """Parameters controlling kernel parallelization using OpenMP.""" + + nesting_depth: int = 0 + """Nesting depth of the loop that should be parallelized. 
Must be a nonnegative number.""" + + collapse: int = 0 + """Argument to the OpenMP ``collapse`` clause""" + + schedule: str = "static" + """Argument to the OpenMP ``schedule`` clause""" + + omit_parallel_construct: bool = False + """If set to ``True``, the OpenMP ``parallel`` construct is omitted, producing just a ``#pragma omp for``. + + Use this option only if you intend to wrap the kernel into an external ``#pragma omp parallel`` region. + """ + + @dataclass class CpuOptimConfig: """Configuration for the CPU optimizer. - + If any flag in this configuration is set to a value not supported by the CPU specified in `CreateKernelConfig.target`, an error will be raised. """ - - openmp: bool = False + + openmp: bool | OpenMpConfig = False """Enable OpenMP parallelization. - If set to `True`, the kernel will be parallelized using OpenMP according to the OpenMP settings - given in this configuration. + If set to `True`, the kernel will be parallelized using OpenMP according to the default settings in `OpenMpConfig`. + To customize OpenMP parallelization, pass an instance of `OpenMpConfig` instead. """ vectorize: bool | VectorizationConfig = False @@ -58,7 +78,7 @@ class CpuOptimConfig: @dataclass class VectorizationConfig: """Configuration for the auto-vectorizer. - + If any flag in this configuration is set to a value not supported by the CPU specified in `CreateKernelConfig.target`, an error will be raised. 
""" @@ -211,14 +231,21 @@ class CreateKernelConfig: raise PsOptionsError( "Only fields with `field_type == FieldType.INDEXED` can be specified as `index_field`" ) - + # Check optim if self.cpu_optim is not None: if not self.target.is_cpu(): - raise PsOptionsError(f"`cpu_optim` cannot be set for non-CPU target {self.target}") - - if self.cpu_optim.vectorize is not False and not self.target.is_vector_cpu(): - raise PsOptionsError(f"Cannot enable auto-vectorization for non-vector CPU target {self.target}") + raise PsOptionsError( + f"`cpu_optim` cannot be set for non-CPU target {self.target}" + ) + + if ( + self.cpu_optim.vectorize is not False + and not self.target.is_vector_cpu() + ): + raise PsOptionsError( + f"Cannot enable auto-vectorization for non-vector CPU target {self.target}" + ) if self.gpu_indexing is not None: if self.target != Target.SYCL: @@ -228,6 +255,7 @@ class CreateKernelConfig: if self.jit is None: if self.target.is_cpu(): from .backend.jit import LegacyCpuJit + self.jit = LegacyCpuJit() elif self.target == Target.SYCL: from .backend.jit import no_jit diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py index ff0f09512bb6f40ac23d0d0cdbdc23ea32e0012f..90fcf73d3f996837b730b3cea9552800b8743a62 100644 --- a/src/pystencils/kernelcreation.py +++ b/src/pystencils/kernelcreation.py @@ -107,6 +107,7 @@ def create_kernel( # Target-Specific optimizations if config.target.is_cpu(): from .backend.kernelcreation import optimize_cpu + kernel_ast = optimize_cpu(ctx, platform, kernel_ast, config.cpu_optim) erase_anons = EraseAnonymousStructTypes(ctx) diff --git a/src/pystencils/types/types.py b/src/pystencils/types/types.py index ae0a8829df43135be99acec9abb3500f80a9abf5..2f0f2ff46ec1498eda724f23964bf28aa2ad2a9b 100644 --- a/src/pystencils/types/types.py +++ b/src/pystencils/types/types.py @@ -483,7 +483,20 @@ class PsIntegerType(PsScalarType, ABC): if not isinstance(value, np_dtype): raise PsTypeError(f"Given value {value} is not of 
required type {np_dtype}") unsigned_suffix = "" if self.signed else "u" - return f"(({self._c_type_without_const()}) {value}{unsigned_suffix})" + + match self.width: + case w if w < 32: + # Plain integer literals get at least type `int`, which is 32 bit in all relevant cases + # So we need to explicitly cast to smaller types + return f"(({self._c_type_without_const()}) {value}{unsigned_suffix})" + case 32: + # No suffix here - becomes `int`, which is 32 bit + return f"{value}{unsigned_suffix}" + case 64: + # LL suffix: `long long` is the only type guaranteed to be 64 bit wide + return f"{value}{unsigned_suffix}LL" + case _: + assert False, "unreachable code" def create_constant(self, value: Any) -> Any: np_type = self.NUMPY_TYPES[self._width] diff --git a/tests/nbackend/kernelcreation/test_openmp.py b/tests/nbackend/kernelcreation/test_openmp.py new file mode 100644 index 0000000000000000000000000000000000000000..d7be8eb98cd29bea370bc6279013ef973e621370 --- /dev/null +++ b/tests/nbackend/kernelcreation/test_openmp.py @@ -0,0 +1,61 @@ +import pytest +from pystencils import ( + fields, + Assignment, + create_kernel, + CreateKernelConfig, + CpuOptimConfig, + OpenMpConfig, + Target, +) + +from pystencils.backend.ast import dfs_preorder +from pystencils.backend.ast.structural import PsLoop, PsPragma + + +@pytest.mark.parametrize("nesting_depth", range(3)) +@pytest.mark.parametrize("schedule", ["static", "static,16", "dynamic", "auto"]) +@pytest.mark.parametrize("collapse", range(3)) +@pytest.mark.parametrize("omit_parallel_construct", range(3)) +def test_openmp(nesting_depth, schedule, collapse, omit_parallel_construct): + f, g = fields("f, g: [3D]") + asm = Assignment(f.center(0), g.center(0)) + + omp = OpenMpConfig( + nesting_depth=nesting_depth, + schedule=schedule, + collapse=collapse, + omit_parallel_construct=omit_parallel_construct, + ) + gen_config = CreateKernelConfig( + target=Target.CPU, cpu_optim=CpuOptimConfig(openmp=omp) + ) + + kernel = create_kernel(asm, 
gen_config) + ast = kernel.body + + def find_omp_pragma(ast) -> PsPragma: + num_loops = 0 + generator = dfs_preorder(ast) + for node in generator: + match node: + case PsLoop(): + num_loops += 1 + case PsPragma(): + loop = next(generator) + assert isinstance(loop, PsLoop) + assert num_loops == nesting_depth + return node + + pytest.fail("No OpenMP pragma found") + + pragma = find_omp_pragma(ast) + tokens = set(pragma.text.split()) + + expected_tokens = {"omp", "for", f"schedule({omp.schedule})"} + if not omp.omit_parallel_construct: + expected_tokens.add("parallel") + if omp.collapse > 0: + expected_tokens.add(f"collapse({omp.collapse})") + + assert tokens == expected_tokens diff --git a/tests/nbackend/test_ast.py b/tests/nbackend/test_ast.py index fb2dd0e04081050138be1c8efc0210f81c999707..09c63a5572f7873648712dbd3279d16f3c342458 100644 --- a/tests/nbackend/test_ast.py +++ b/tests/nbackend/test_ast.py @@ -12,6 +12,7 @@ from pystencils.backend.ast.structural import ( PsBlock, PsConditional, PsComment, + PsPragma, PsLoop, ) from pystencils.types.quick import Fp, Ptr @@ -44,6 +45,7 @@ def test_cloning(): PsConditional( y, PsBlock([PsStatement(x + y)]), PsBlock([PsComment("hello world")]) ), + PsPragma("omp parallel for"), PsLoop( x, y, @@ -54,6 +56,7 @@ def test_cloning(): PsComment("Loop body"), PsAssignment(x, y), PsAssignment(x, y), + PsPragma("#pragma clang loop vectorize(enable)"), PsStatement( PsDeref(PsCast(Ptr(Fp(32)), z)) + PsSubscript(z, one + one + one) diff --git a/tests/nbackend/test_extensions.py b/tests/nbackend/test_extensions.py index 8d600ef766d55c51e93c23ae5bd720c03130b555..16e610a552b426cc4245e2e5c4ee36663f6c2bfa 100644 --- a/tests/nbackend/test_extensions.py +++ b/tests/nbackend/test_extensions.py @@ -54,6 +54,6 @@ def test_literals(): print(code) assert "const double x = C;" in code - assert "CELLS[((int64_t) 0)]" in code - assert "CELLS[((int64_t) 1)]" in code - assert "CELLS[((int64_t) 2)]" in code + assert "CELLS[0LL]" in code + assert 
"CELLS[1LL]" in code + assert "CELLS[2LL]" in code diff --git a/tests/nbackend/transformations/test_add_pragmas.py b/tests/nbackend/transformations/test_add_pragmas.py new file mode 100644 index 0000000000000000000000000000000000000000..1d8dd1ded148697dbe7acde3648a292dc5a7fcac --- /dev/null +++ b/tests/nbackend/transformations/test_add_pragmas.py @@ -0,0 +1,54 @@ +import sympy as sp +from itertools import product + +from pystencils import make_slice, fields, Assignment +from pystencils.backend.kernelcreation import ( + KernelCreationContext, + AstFactory, + FullIterationSpace, +) + +from pystencils.backend.ast import dfs_preorder +from pystencils.backend.ast.structural import PsBlock, PsPragma, PsLoop +from pystencils.backend.transformations import InsertPragmasAtLoops, LoopPragma + +def test_insert_pragmas(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + + f, g = fields("f, g: [3D]") + ispace = FullIterationSpace.create_from_slice( + ctx, make_slice[:, :, :], archetype_field=f + ) + ctx.set_iteration_space(ispace) + + stencil = list(product([-1, 0, 1], [-1, 0, 1], [-1, 0, 1])) + loop_body = PsBlock([ + factory.parse_sympy(Assignment(f.center(0), sum(g.neighbors(stencil)))) + ]) + loops = factory.loops_from_ispace(ispace, loop_body) + + pragmas = ( + LoopPragma("omp parallel for", 0), + LoopPragma("some nonsense pragma", 1), + LoopPragma("omp simd", -1), + ) + add_pragmas = InsertPragmasAtLoops(ctx, pragmas) + ast = add_pragmas(loops) + + assert isinstance(ast, PsBlock) + + first_pragma = ast.statements[0] + assert isinstance(first_pragma, PsPragma) + assert first_pragma.text == pragmas[0].text + + assert ast.statements[1] == loops + second_pragma = loops.body.statements[0] + assert isinstance(second_pragma, PsPragma) + assert second_pragma.text == pragmas[1].text + + second_loop = list(dfs_preorder(ast, lambda node: isinstance(node, PsLoop)))[1] + assert isinstance(second_loop, PsLoop) + third_pragma = second_loop.body.statements[0] + assert 
isinstance(third_pragma, PsPragma) + assert third_pragma.text == pragmas[2].text diff --git a/tests/nbackend/transformations/test_branch_elimination.py b/tests/nbackend/transformations/test_branch_elimination.py index 0fb3526d0b53fd40972c4dfeb06cf3a614bc6c10..fae8f158aaa472e02efadbd93365ce042dff0ab1 100644 --- a/tests/nbackend/transformations/test_branch_elimination.py +++ b/tests/nbackend/transformations/test_branch_elimination.py @@ -4,12 +4,18 @@ from pystencils.backend.kernelcreation import ( Typifier, AstFactory, ) -from pystencils.backend.ast.expressions import PsExpression +from pystencils.backend.ast.expressions import ( + PsExpression, + PsEq, + PsGe, + PsGt, + PsLe, + PsLt, +) from pystencils.backend.ast.structural import PsConditional, PsBlock, PsComment from pystencils.backend.constants import PsConstant from pystencils.backend.transformations import EliminateBranches from pystencils.types.quick import Int -from pystencils.backend.ast.expressions import PsGt i0 = PsExpression.make(PsConstant(0, Int(32))) @@ -53,3 +59,39 @@ def test_eliminate_nested_conditional(): result = elim(ast) assert result.body.statements[0].body.statements[0] == b1 + + +def test_isl(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + typify = Typifier(ctx) + elim = EliminateBranches(ctx) + + i = PsExpression.make(ctx.get_symbol("i", ctx.index_dtype)) + j = PsExpression.make(ctx.get_symbol("j", ctx.index_dtype)) + + const_2 = PsExpression.make(PsConstant(2, ctx.index_dtype)) + const_4 = PsExpression.make(PsConstant(4, ctx.index_dtype)) + + a_true = PsBlock([PsComment("a true")]) + a_false = PsBlock([PsComment("a false")]) + b_true = PsBlock([PsComment("b true")]) + b_false = PsBlock([PsComment("b false")]) + c_true = PsBlock([PsComment("c true")]) + c_false = PsBlock([PsComment("c false")]) + + a = PsConditional(PsLt(i + j, const_2 * const_4), a_true, a_false) + b = PsConditional(PsGe(j, const_4), b_true, b_false) + c = PsConditional(PsEq(i, const_4), c_true, c_false) 
+ + outer_loop = factory.loop(j.symbol.name, slice(0, 3), PsBlock([a, b, c])) + outer_cond = typify( + PsConditional(PsLe(i, const_4), PsBlock([outer_loop]), PsBlock([])) + ) + ast = outer_cond + + result = elim(ast) + + assert result.branch_true.statements[0].body.statements[0] == a_true + assert result.branch_true.statements[0].body.statements[1] == b_false + assert result.branch_true.statements[0].body.statements[2] == c