Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Commits on Source (6)
Showing
with 504 additions and 62 deletions
...@@ -16,3 +16,6 @@ ignore_missing_imports=true ...@@ -16,3 +16,6 @@ ignore_missing_imports=true
[mypy-appdirs.*] [mypy-appdirs.*]
ignore_missing_imports=true ignore_missing_imports=true
[mypy-islpy.*]
ignore_missing_imports=true
...@@ -6,8 +6,14 @@ from . import fd ...@@ -6,8 +6,14 @@ from . import fd
from . import stencil as stencil from . import stencil as stencil
from .display_utils import get_code_obj, get_code_str, show_code, to_dot from .display_utils import get_code_obj, get_code_str, show_code, to_dot
from .field import Field, FieldType, fields from .field import Field, FieldType, fields
from .types import create_type
from .cache import clear_cache from .cache import clear_cache
from .config import CreateKernelConfig, CpuOptimConfig, VectorizationConfig from .config import (
CreateKernelConfig,
CpuOptimConfig,
VectorizationConfig,
OpenMpConfig,
)
from .kernel_decorator import kernel, kernel_config from .kernel_decorator import kernel, kernel_config
from .kernelcreation import create_kernel from .kernelcreation import create_kernel
from .backend.kernelfunction import KernelFunction from .backend.kernelfunction import KernelFunction
...@@ -34,10 +40,12 @@ __all__ = [ ...@@ -34,10 +40,12 @@ __all__ = [
"fields", "fields",
"DEFAULTS", "DEFAULTS",
"TypedSymbol", "TypedSymbol",
"create_type",
"make_slice", "make_slice",
"CreateKernelConfig", "CreateKernelConfig",
"CpuOptimConfig", "CpuOptimConfig",
"VectorizationConfig", "VectorizationConfig",
"OpenMpConfig",
"create_kernel", "create_kernel",
"KernelFunction", "KernelFunction",
"Target", "Target",
......
...@@ -5,7 +5,7 @@ from .structural import ( ...@@ -5,7 +5,7 @@ from .structural import (
PsAssignment, PsAssignment,
PsAstNode, PsAstNode,
PsBlock, PsBlock,
PsComment, PsEmptyLeafMixIn,
PsConditional, PsConditional,
PsDeclaration, PsDeclaration,
PsExpression, PsExpression,
...@@ -63,7 +63,7 @@ class UndefinedSymbolsCollector: ...@@ -63,7 +63,7 @@ class UndefinedSymbolsCollector:
undefined_vars |= self(branch_false) undefined_vars |= self(branch_false)
return undefined_vars return undefined_vars
case PsComment(): case PsEmptyLeafMixIn():
return set() return set()
case unknown: case unknown:
...@@ -92,11 +92,11 @@ class UndefinedSymbolsCollector: ...@@ -92,11 +92,11 @@ class UndefinedSymbolsCollector:
case ( case (
PsAssignment() PsAssignment()
| PsBlock() | PsBlock()
| PsComment()
| PsConditional() | PsConditional()
| PsExpression() | PsExpression()
| PsLoop() | PsLoop()
| PsStatement() | PsStatement()
| PsEmptyLeafMixIn()
): ):
return set() return set()
......
...@@ -158,7 +158,7 @@ class PsConstantExpr(PsLeafMixIn, PsExpression): ...@@ -158,7 +158,7 @@ class PsConstantExpr(PsLeafMixIn, PsExpression):
def __repr__(self) -> str: def __repr__(self) -> str:
return f"PsConstantExpr({repr(self._constant)})" return f"PsConstantExpr({repr(self._constant)})"
class PsLiteralExpr(PsLeafMixIn, PsExpression): class PsLiteralExpr(PsLeafMixIn, PsExpression):
__match_args__ = ("literal",) __match_args__ = ("literal",)
...@@ -177,7 +177,7 @@ class PsLiteralExpr(PsLeafMixIn, PsExpression): ...@@ -177,7 +177,7 @@ class PsLiteralExpr(PsLeafMixIn, PsExpression):
def clone(self) -> PsLiteralExpr: def clone(self) -> PsLiteralExpr:
return PsLiteralExpr(self._literal) return PsLiteralExpr(self._literal)
def structurally_equal(self, other: PsAstNode) -> bool: def structurally_equal(self, other: PsAstNode) -> bool:
if not isinstance(other, PsLiteralExpr): if not isinstance(other, PsLiteralExpr):
return False return False
......
...@@ -4,32 +4,32 @@ from .structural import PsAstNode ...@@ -4,32 +4,32 @@ from .structural import PsAstNode
def dfs_preorder( def dfs_preorder(
node: PsAstNode, yield_pred: Callable[[PsAstNode], bool] = lambda _: True node: PsAstNode, filter_pred: Callable[[PsAstNode], bool] = lambda _: True
) -> Generator[PsAstNode, None, None]: ) -> Generator[PsAstNode, None, None]:
"""Pre-Order depth-first traversal of an abstract syntax tree. """Pre-Order depth-first traversal of an abstract syntax tree.
Args: Args:
node: The tree's root node node: The tree's root node
yield_pred: Filter predicate; a node is only yielded to the caller if `yield_pred(node)` returns True filter_pred: Filter predicate; a node is only returned to the caller if `filter_pred(node)` returns True
""" """
if yield_pred(node): if filter_pred(node):
yield node yield node
for c in node.children: for c in node.children:
yield from dfs_preorder(c, yield_pred) yield from dfs_preorder(c, filter_pred)
def dfs_postorder( def dfs_postorder(
node: PsAstNode, yield_pred: Callable[[PsAstNode], bool] = lambda _: True node: PsAstNode, filter_pred: Callable[[PsAstNode], bool] = lambda _: True
) -> Generator[PsAstNode, None, None]: ) -> Generator[PsAstNode, None, None]:
"""Post-Order depth-first traversal of an abstract syntax tree. """Post-Order depth-first traversal of an abstract syntax tree.
Args: Args:
node: The tree's root node node: The tree's root node
yield_pred: Filter predicate; a node is only yielded to the caller if `yield_pred(node)` returns True filter_pred: Filter predicate; a node is only returned to the caller if `filter_pred(node)` returns True
""" """
for c in node.children: for c in node.children:
yield from dfs_postorder(c, yield_pred) yield from dfs_postorder(c, filter_pred)
if yield_pred(node): if filter_pred(node):
yield node yield node
...@@ -307,7 +307,42 @@ class PsConditional(PsAstNode): ...@@ -307,7 +307,42 @@ class PsConditional(PsAstNode):
assert False, "unreachable code" assert False, "unreachable code"
class PsComment(PsLeafMixIn, PsAstNode): class PsEmptyLeafMixIn:
"""Mix-in marking AST leaves that can be treated as empty by the code generator,
such as comments and preprocessor directives."""
pass
class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
    """AST leaf representing a C/C++ preprocessor pragma.

    For example, ``PsPragma("omp parallel for")`` stands for the directive
    ``#pragma omp parallel for``.

    Args:
        text: Text of the pragma, not including the leading ``#pragma ``.
    """

    __match_args__ = ("text",)

    def __init__(self, text: str) -> None:
        self._text = text

    @property
    def text(self) -> str:
        """The pragma text as passed to the constructor."""
        return self._text

    def clone(self) -> PsPragma:
        """Create a new pragma node carrying the same text."""
        return PsPragma(self.text)

    def structurally_equal(self, other: PsAstNode) -> bool:
        """Two pragmas are structurally equal iff both are `PsPragma` nodes with equal text."""
        return isinstance(other, PsPragma) and self._text == other._text
class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
__match_args__ = ("lines",) __match_args__ = ("lines",)
def __init__(self, text: str) -> None: def __init__(self, text: str) -> None:
......
...@@ -10,6 +10,7 @@ from .ast.structural import ( ...@@ -10,6 +10,7 @@ from .ast.structural import (
PsLoop, PsLoop,
PsConditional, PsConditional,
PsComment, PsComment,
PsPragma,
) )
from .ast.expressions import ( from .ast.expressions import (
...@@ -235,6 +236,9 @@ class CAstPrinter: ...@@ -235,6 +236,9 @@ class CAstPrinter:
lines_list[-1] = lines_list[-1] + " */" lines_list[-1] = lines_list[-1] + " */"
return pc.indent("\n".join(lines_list)) return pc.indent("\n".join(lines_list))
case PsPragma(text):
return pc.indent("#pragma " + text)
case PsSymbolExpr(symbol): case PsSymbolExpr(symbol):
return symbol.name return symbol.name
...@@ -246,7 +250,7 @@ class CAstPrinter: ...@@ -246,7 +250,7 @@ class CAstPrinter:
) )
return dtype.create_literal(constant.value) return dtype.create_literal(constant.value)
case PsLiteralExpr(lit): case PsLiteralExpr(lit):
return lit.text return lit.text
......
...@@ -3,10 +3,9 @@ from typing import cast ...@@ -3,10 +3,9 @@ from typing import cast
from .context import KernelCreationContext from .context import KernelCreationContext
from ..platforms import GenericCpu from ..platforms import GenericCpu
from ..transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations
from ..ast.structural import PsBlock from ..ast.structural import PsBlock
from ...config import CpuOptimConfig from ...config import CpuOptimConfig, OpenMpConfig
def optimize_cpu( def optimize_cpu(
...@@ -16,6 +15,7 @@ def optimize_cpu( ...@@ -16,6 +15,7 @@ def optimize_cpu(
cfg: CpuOptimConfig | None, cfg: CpuOptimConfig | None,
) -> PsBlock: ) -> PsBlock:
"""Carry out CPU-specific optimizations according to the given configuration.""" """Carry out CPU-specific optimizations according to the given configuration."""
from ..transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations
canonicalize = CanonicalizeSymbols(ctx, True) canonicalize = CanonicalizeSymbols(ctx, True)
kernel_ast = cast(PsBlock, canonicalize(kernel_ast)) kernel_ast = cast(PsBlock, canonicalize(kernel_ast))
...@@ -32,8 +32,12 @@ def optimize_cpu( ...@@ -32,8 +32,12 @@ def optimize_cpu(
if cfg.vectorize is not False: if cfg.vectorize is not False:
raise NotImplementedError("Vectorization not implemented yet") raise NotImplementedError("Vectorization not implemented yet")
if cfg.openmp: if cfg.openmp is not False:
raise NotImplementedError("OpenMP not implemented yet") from ..transformations import AddOpenMP
params = cfg.openmp if isinstance(cfg.openmp, OpenMpConfig) else OpenMpConfig()
add_omp = AddOpenMP(ctx, params)
kernel_ast = cast(PsBlock, add_omp(kernel_ast))
if cfg.use_cacheline_zeroing: if cfg.use_cacheline_zeroing:
raise NotImplementedError("CL-zeroing not implemented yet") raise NotImplementedError("CL-zeroing not implemented yet")
......
...@@ -22,7 +22,7 @@ from ..ast.structural import ( ...@@ -22,7 +22,7 @@ from ..ast.structural import (
PsExpression, PsExpression,
PsAssignment, PsAssignment,
PsDeclaration, PsDeclaration,
PsComment, PsEmptyLeafMixIn,
) )
from ..ast.expressions import ( from ..ast.expressions import (
PsArrayAccess, PsArrayAccess,
...@@ -159,7 +159,7 @@ class TypeContext: ...@@ -159,7 +159,7 @@ class TypeContext:
f" Constant type: {c.dtype}\n" f" Constant type: {c.dtype}\n"
f" Target type: {self._target_type}" f" Target type: {self._target_type}"
) )
case PsLiteralExpr(lit): case PsLiteralExpr(lit):
if not self._compatible(lit.dtype): if not self._compatible(lit.dtype):
raise TypificationError( raise TypificationError(
...@@ -336,7 +336,7 @@ class Typifier: ...@@ -336,7 +336,7 @@ class Typifier:
self.visit(body) self.visit(body)
case PsComment(): case PsEmptyLeafMixIn():
pass pass
case _: case _:
......
...@@ -4,7 +4,7 @@ from ..types import PsType, constify ...@@ -4,7 +4,7 @@ from ..types import PsType, constify
class PsLiteral: class PsLiteral:
"""Representation of literal code. """Representation of literal code.
Instances of this class represent code literals inside the AST. Instances of this class represent code literals inside the AST.
These literals are not to be confused with C literals; the name `Literal` refers to the fact that These literals are not to be confused with C literals; the name `Literal` refers to the fact that
the code generator takes them "literally", printing them as they are. the code generator takes them "literally", printing them as they are.
...@@ -22,22 +22,22 @@ class PsLiteral: ...@@ -22,22 +22,22 @@ class PsLiteral:
@property @property
def text(self) -> str: def text(self) -> str:
return self._text return self._text
@property @property
def dtype(self) -> PsType: def dtype(self) -> PsType:
return self._dtype return self._dtype
def __str__(self) -> str: def __str__(self) -> str:
return f"{self._text}: {self._dtype}" return f"{self._text}: {self._dtype}"
def __repr__(self) -> str: def __repr__(self) -> str:
return f"PsLiteral({repr(self._text)}, {repr(self._dtype)})" return f"PsLiteral({repr(self._text)}, {repr(self._dtype)})"
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if not isinstance(other, PsLiteral): if not isinstance(other, PsLiteral):
return False return False
return self._text == other._text and self._dtype == other._dtype return self._text == other._text and self._dtype == other._dtype
def __hash__(self) -> int: def __hash__(self) -> int:
return hash((PsLiteral, self._text, self._dtype)) return hash((PsLiteral, self._text, self._dtype))
...@@ -60,6 +60,11 @@ Loop Reshaping Transformations ...@@ -60,6 +60,11 @@ Loop Reshaping Transformations
.. autoclass:: ReshapeLoops .. autoclass:: ReshapeLoops
:members: :members:
.. autoclass:: InsertPragmasAtLoops
:members:
.. autoclass:: AddOpenMP
:members:
Code Lowering and Materialization Code Lowering and Materialization
--------------------------------- ---------------------------------
...@@ -78,6 +83,7 @@ from .eliminate_constants import EliminateConstants ...@@ -78,6 +83,7 @@ from .eliminate_constants import EliminateConstants
from .eliminate_branches import EliminateBranches from .eliminate_branches import EliminateBranches
from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations
from .reshape_loops import ReshapeLoops from .reshape_loops import ReshapeLoops
from .add_pragmas import InsertPragmasAtLoops, LoopPragma, AddOpenMP
from .erase_anonymous_structs import EraseAnonymousStructTypes from .erase_anonymous_structs import EraseAnonymousStructTypes
from .select_functions import SelectFunctions from .select_functions import SelectFunctions
from .select_intrinsics import MaterializeVectorIntrinsics from .select_intrinsics import MaterializeVectorIntrinsics
...@@ -89,6 +95,9 @@ __all__ = [ ...@@ -89,6 +95,9 @@ __all__ = [
"EliminateBranches", "EliminateBranches",
"HoistLoopInvariantDeclarations", "HoistLoopInvariantDeclarations",
"ReshapeLoops", "ReshapeLoops",
"InsertPragmasAtLoops",
"LoopPragma",
"AddOpenMP",
"EraseAnonymousStructTypes", "EraseAnonymousStructTypes",
"SelectFunctions", "SelectFunctions",
"MaterializeVectorIntrinsics", "MaterializeVectorIntrinsics",
......
from dataclasses import dataclass
from typing import Sequence
from collections import defaultdict
from ..kernelcreation import KernelCreationContext
from ..ast import PsAstNode
from ..ast.structural import PsBlock, PsLoop, PsPragma
from ..ast.expressions import PsExpression
from ...config import OpenMpConfig
__all__ = ["InsertPragmasAtLoops", "LoopPragma", "AddOpenMP"]
@dataclass
class LoopPragma:
    """Description of a pragma to be placed in front of loops at a given nesting depth."""

    text: str
    """Text of the pragma, excluding the leading ``#pragma ``"""

    loop_nesting_depth: int
    """Nesting depth of the loops to annotate. ``-1`` selects the innermost loops."""

    def __post_init__(self):
        #   -1 is the only negative value with a meaning (innermost loops)
        depth_too_small = self.loop_nesting_depth < -1
        if depth_too_small:
            raise ValueError("Loop nesting depth must be nonnegative or -1.")
@dataclass
class Nesting:
    """Bookkeeping record carried along while walking a loop nest."""

    depth: int
    """Nesting depth of the block currently being visited"""

    has_inner_loops: bool = False
    """Set to ``True`` once a loop is encountered at this nesting level"""
class InsertPragmasAtLoops:
"""Insert pragmas before loops in a loop nest.
This transformation augments the AST with pragma directives which are prepended to loops.
The directives are annotated with the nesting depth of the loops they should be added to,
where ``-1`` indicates the innermost loop.
The relative order of pragmas with the (exact) same nesting depth is preserved;
however, no guarantees are given about the relative order of pragmas inserted at ``-1``
and at the actual depth of the innermost loop.
"""
def __init__(
self, ctx: KernelCreationContext, insertions: Sequence[LoopPragma]
) -> None:
self._ctx = ctx
self._insertions: dict[int, list[LoopPragma]] = defaultdict(list)
for ins in insertions:
self._insertions[ins.loop_nesting_depth].append(ins)
def __call__(self, node: PsAstNode) -> PsAstNode:
is_loop = isinstance(node, PsLoop)
if is_loop:
node = PsBlock([node])
self.visit(node, Nesting(0))
if is_loop and len(node.children) == 1:
node = node.children[0]
return node
def visit(self, node: PsAstNode, nest: Nesting) -> None:
match node:
case PsExpression():
return
case PsBlock(children):
new_children: list[PsAstNode] = []
for c in children:
if isinstance(c, PsLoop):
nest.has_inner_loops = True
inner_nest = Nesting(nest.depth + 1)
self.visit(c.body, inner_nest)
if not inner_nest.has_inner_loops:
# c is the innermost loop
for pragma in self._insertions[-1]:
new_children.append(PsPragma(pragma.text))
for pragma in self._insertions[nest.depth]:
new_children.append(PsPragma(pragma.text))
new_children.append(c)
node.children = new_children
case other:
for c in other.children:
self.visit(c, nest)
class AddOpenMP:
    """Apply OpenMP directives to loop nests.

    This transformation assembles a single ``omp ... for`` pragma from the given
    `OpenMpConfig` and uses `InsertPragmasAtLoops` to place it before the loops
    at the configured nesting depth.

    Args:
        ctx: The kernel creation context
        omp_params: OpenMP settings; its ``schedule``, ``collapse``, ``nesting_depth``
            and ``omit_parallel_construct`` fields are read
    """

    def __init__(self, ctx: KernelCreationContext, omp_params: OpenMpConfig) -> None:
        clauses = ["omp"]

        if not omp_params.omit_parallel_construct:
            clauses.append("parallel")

        clauses.append(f"for schedule({omp_params.schedule})")

        #   A non-positive value means: emit no `collapse` clause at all
        if omp_params.collapse > 0:
            clauses.append(f"collapse({omp_params.collapse})")

        pragma_text = " ".join(clauses)

        self._insert_pragmas = InsertPragmasAtLoops(
            ctx, [LoopPragma(pragma_text, omp_params.nesting_depth)]
        )

    def __call__(self, node: PsAstNode) -> PsAstNode:
        """Insert the OpenMP pragma into the given AST and return the resulting node."""
        return self._insert_pragmas(node)
...@@ -12,6 +12,7 @@ from ..ast.structural import ( ...@@ -12,6 +12,7 @@ from ..ast.structural import (
PsDeclaration, PsDeclaration,
PsAssignment, PsAssignment,
PsComment, PsComment,
PsPragma,
PsStatement, PsStatement,
) )
from ..ast.expressions import PsExpression, PsSymbolExpr from ..ast.expressions import PsExpression, PsSymbolExpr
...@@ -73,7 +74,7 @@ class CanonicalClone: ...@@ -73,7 +74,7 @@ class CanonicalClone:
), ),
) )
case PsComment(): case PsComment() | PsPragma():
return cast(Node_T, node.clone()) return cast(Node_T, node.clone())
case PsDeclaration(lhs, rhs): case PsDeclaration(lhs, rhs):
......
from ..kernelcreation import KernelCreationContext from ..kernelcreation import KernelCreationContext
from ..ast import PsAstNode from ..ast import PsAstNode
from ..ast.analysis import collect_undefined_symbols
from ..ast.structural import PsLoop, PsBlock, PsConditional from ..ast.structural import PsLoop, PsBlock, PsConditional
from ..ast.expressions import PsConstantExpr from ..ast.expressions import (
PsAnd,
PsCast,
PsConstant,
PsConstantExpr,
PsDiv,
PsEq,
PsExpression,
PsGe,
PsGt,
PsIntDiv,
PsLe,
PsLt,
PsMul,
PsNe,
PsNeg,
PsNot,
PsOr,
PsSub,
PsSymbolExpr,
PsAdd,
)
from .eliminate_constants import EliminateConstants from .eliminate_constants import EliminateConstants
from ...types import PsBoolType, PsIntegerType
__all__ = ["EliminateBranches"] __all__ = ["EliminateBranches"]
class IslAnalysisError(Exception):
    """Raised to signal a fatal error during integer set analysis (based on islpy)."""
class BranchElimContext: class BranchElimContext:
def __init__(self) -> None: def __init__(self) -> None:
self.enclosing_loops: list[PsLoop] = [] self.enclosing_loops: list[PsLoop] = []
self.enclosing_conditions: list[PsExpression] = []
class EliminateBranches: class EliminateBranches:
...@@ -20,12 +48,16 @@ class EliminateBranches: ...@@ -20,12 +48,16 @@ class EliminateBranches:
This pass will attempt to evaluate branch conditions within their context in the AST, and replace This pass will attempt to evaluate branch conditions within their context in the AST, and replace
conditionals by either their then- or their else-block if the branch is unequivocal. conditionals by either their then- or their else-block if the branch is unequivocal.
TODO: If islpy is installed, this pass will incorporate information about the iteration regions If islpy is installed, this pass will incorporate information about the iteration regions
of enclosing loops into its analysis. of enclosing loops and enclosing conditionals into its analysis.
Args:
use_isl (bool, optional): enable islpy based analysis (default: True)
""" """
def __init__(self, ctx: KernelCreationContext) -> None: def __init__(self, ctx: KernelCreationContext, use_isl: bool = True) -> None:
self._ctx = ctx self._ctx = ctx
self._use_isl = use_isl
self._elim_constants = EliminateConstants(ctx, extract_constant_exprs=False) self._elim_constants = EliminateConstants(ctx, extract_constant_exprs=False)
def __call__(self, node: PsAstNode) -> PsAstNode: def __call__(self, node: PsAstNode) -> PsAstNode:
...@@ -41,20 +73,30 @@ class EliminateBranches: ...@@ -41,20 +73,30 @@ class EliminateBranches:
case PsBlock(statements): case PsBlock(statements):
statements_new: list[PsAstNode] = [] statements_new: list[PsAstNode] = []
for stmt in statements: for stmt in statements:
if isinstance(stmt, PsConditional): statements_new.append(self.visit(stmt, ec))
result = self.handle_conditional(stmt, ec)
if result is not None:
statements_new.append(result)
else:
statements_new.append(self.visit(stmt, ec))
node.statements = statements_new node.statements = statements_new
case PsConditional(): case PsConditional():
result = self.handle_conditional(node, ec) result = self.handle_conditional(node, ec)
if result is None:
return PsBlock([]) match result:
else: case PsConditional(_, branch_true, branch_false):
return result ec.enclosing_conditions.append(result.condition)
self.visit(branch_true, ec)
ec.enclosing_conditions.pop()
if branch_false is not None:
ec.enclosing_conditions.append(PsNot(result.condition))
self.visit(branch_false, ec)
ec.enclosing_conditions.pop()
case PsBlock():
self.visit(result, ec)
case None:
result = PsBlock([])
case _:
assert False, "unreachable code"
return result
return node return node
...@@ -62,12 +104,124 @@ class EliminateBranches: ...@@ -62,12 +104,124 @@ class EliminateBranches:
self, conditional: PsConditional, ec: BranchElimContext self, conditional: PsConditional, ec: BranchElimContext
) -> PsConditional | PsBlock | None: ) -> PsConditional | PsBlock | None:
condition_simplified = self._elim_constants(conditional.condition) condition_simplified = self._elim_constants(conditional.condition)
if self._use_isl:
condition_simplified = self._isl_simplify_condition(
condition_simplified, ec
)
match condition_simplified: match condition_simplified:
case PsConstantExpr(c) if c.value: case PsConstantExpr(c) if c.value:
return conditional.branch_true return conditional.branch_true
case PsConstantExpr(c) if not c.value: case PsConstantExpr(c) if not c.value:
return conditional.branch_false return conditional.branch_false
# TODO: Analyze condition against counters of enclosing loops using ISL
return conditional return conditional
def _isl_simplify_condition(
self, condition: PsExpression, ec: BranchElimContext
) -> PsExpression:
"""If installed, use ISL to simplify the passed condition to true or
false based on enclosing loops and conditionals. If no simplification
can be made or ISL is not installed, the original condition is returned.
"""
try:
import islpy as isl
except ImportError:
return condition
def printer(expr: PsExpression):
match expr:
case PsSymbolExpr(symbol):
return symbol.name
case PsConstantExpr(constant):
dtype = constant.get_dtype()
if not isinstance(dtype, (PsIntegerType, PsBoolType)):
raise IslAnalysisError(
"Only scalar integer and bool constant may appear in isl expressions."
)
return str(constant.value)
case PsAdd(op1, op2):
return f"({printer(op1)} + {printer(op2)})"
case PsSub(op1, op2):
return f"({printer(op1)} - {printer(op2)})"
case PsMul(op1, op2):
return f"({printer(op1)} * {printer(op2)})"
case PsDiv(op1, op2) | PsIntDiv(op1, op2):
return f"({printer(op1)} / {printer(op2)})"
case PsAnd(op1, op2):
return f"({printer(op1)} and {printer(op2)})"
case PsOr(op1, op2):
return f"({printer(op1)} or {printer(op2)})"
case PsEq(op1, op2):
return f"({printer(op1)} = {printer(op2)})"
case PsNe(op1, op2):
return f"({printer(op1)} != {printer(op2)})"
case PsGt(op1, op2):
return f"({printer(op1)} > {printer(op2)})"
case PsGe(op1, op2):
return f"({printer(op1)} >= {printer(op2)})"
case PsLt(op1, op2):
return f"({printer(op1)} < {printer(op2)})"
case PsLe(op1, op2):
return f"({printer(op1)} <= {printer(op2)})"
case PsNeg(operand):
return f"(-{printer(operand)})"
case PsNot(operand):
return f"(not {printer(operand)})"
case PsCast(_, operand):
return printer(operand)
case _:
raise IslAnalysisError(
f"Not supported by isl or don't know how to print {expr}"
)
dofs = collect_undefined_symbols(condition)
outer_conditions = []
for loop in ec.enclosing_loops:
if not (
isinstance(loop.step, PsConstantExpr)
and loop.step.constant.value == 1
):
raise IslAnalysisError(
"Loops with strides != 1 are not yet supported."
)
dofs.add(loop.counter.symbol)
dofs.update(collect_undefined_symbols(loop.start))
dofs.update(collect_undefined_symbols(loop.stop))
loop_start_str = printer(loop.start)
loop_stop_str = printer(loop.stop)
ctr_name = loop.counter.symbol.name
outer_conditions.append(
f"{ctr_name} >= {loop_start_str} and {ctr_name} < {loop_stop_str}"
)
for cond in ec.enclosing_conditions:
dofs.update(collect_undefined_symbols(cond))
outer_conditions.append(printer(cond))
dofs_str = ",".join(dof.name for dof in dofs)
outer_conditions_str = " and ".join(outer_conditions)
condition_str = printer(condition)
outer_set = isl.BasicSet(f"{{ [{dofs_str}] : {outer_conditions_str} }}")
inner_set = isl.BasicSet(f"{{ [{dofs_str}] : {condition_str} }}")
if inner_set.is_empty():
return PsExpression.make(PsConstant(False))
intersection = outer_set.intersect(inner_set)
if intersection.is_empty():
return PsExpression.make(PsConstant(False))
elif intersection == outer_set:
return PsExpression.make(PsConstant(True))
else:
return condition
...@@ -15,19 +15,39 @@ from .types import PsIntegerType, PsNumericType, PsIeeeFloatType ...@@ -15,19 +15,39 @@ from .types import PsIntegerType, PsNumericType, PsIeeeFloatType
from .defaults import DEFAULTS from .defaults import DEFAULTS
@dataclass
class OpenMpConfig:
    """Settings that control how a kernel is parallelized using OpenMP."""

    nesting_depth: int = 0
    """Nesting depth of the loop to be parallelized. Must be a nonnegative number."""

    collapse: int = 0
    """Value passed to the OpenMP ``collapse`` clause"""

    schedule: str = "static"
    """Value passed to the OpenMP ``schedule`` clause"""

    omit_parallel_construct: bool = False
    """If ``True``, the OpenMP ``parallel`` construct is left out, so that only a ``#pragma omp for`` is produced.

    Use this option only if you intend to wrap the kernel into an external ``#pragma omp parallel`` region.
    """
@dataclass @dataclass
class CpuOptimConfig: class CpuOptimConfig:
"""Configuration for the CPU optimizer. """Configuration for the CPU optimizer.
If any flag in this configuration is set to a value not supported by the CPU specified If any flag in this configuration is set to a value not supported by the CPU specified
in `CreateKernelConfig.target`, an error will be raised. in `CreateKernelConfig.target`, an error will be raised.
""" """
openmp: bool = False openmp: bool | OpenMpConfig = False
"""Enable OpenMP parallelization. """Enable OpenMP parallelization.
If set to `True`, the kernel will be parallelized using OpenMP according to the OpenMP settings If set to `True`, the kernel will be parallelized using OpenMP according to the default settings in `OpenMpConfig`.
given in this configuration. To customize OpenMP parallelization, pass an instance of `OpenMpConfig` instead.
""" """
vectorize: bool | VectorizationConfig = False vectorize: bool | VectorizationConfig = False
...@@ -58,7 +78,7 @@ class CpuOptimConfig: ...@@ -58,7 +78,7 @@ class CpuOptimConfig:
@dataclass @dataclass
class VectorizationConfig: class VectorizationConfig:
"""Configuration for the auto-vectorizer. """Configuration for the auto-vectorizer.
If any flag in this configuration is set to a value not supported by the CPU specified If any flag in this configuration is set to a value not supported by the CPU specified
in `CreateKernelConfig.target`, an error will be raised. in `CreateKernelConfig.target`, an error will be raised.
""" """
...@@ -182,19 +202,27 @@ class CreateKernelConfig: ...@@ -182,19 +202,27 @@ class CreateKernelConfig:
raise PsOptionsError( raise PsOptionsError(
"Only fields with `field_type == FieldType.INDEXED` can be specified as `index_field`" "Only fields with `field_type == FieldType.INDEXED` can be specified as `index_field`"
) )
# Check optim # Check optim
if self.cpu_optim is not None: if self.cpu_optim is not None:
if not self.target.is_cpu(): if not self.target.is_cpu():
raise PsOptionsError(f"`cpu_optim` cannot be set for non-CPU target {self.target}") raise PsOptionsError(
f"`cpu_optim` cannot be set for non-CPU target {self.target}"
if self.cpu_optim.vectorize is not False and not self.target.is_vector_cpu(): )
raise PsOptionsError(f"Cannot enable auto-vectorization for non-vector CPU target {self.target}")
if (
self.cpu_optim.vectorize is not False
and not self.target.is_vector_cpu()
):
raise PsOptionsError(
f"Cannot enable auto-vectorization for non-vector CPU target {self.target}"
)
# Infer JIT # Infer JIT
if self.jit is None: if self.jit is None:
if self.target.is_cpu(): if self.target.is_cpu():
from .backend.jit import LegacyCpuJit from .backend.jit import LegacyCpuJit
self.jit = LegacyCpuJit() self.jit = LegacyCpuJit()
else: else:
raise NotImplementedError( raise NotImplementedError(
......
...@@ -104,6 +104,7 @@ def create_kernel( ...@@ -104,6 +104,7 @@ def create_kernel(
# Target-Specific optimizations # Target-Specific optimizations
if config.target.is_cpu(): if config.target.is_cpu():
from .backend.kernelcreation import optimize_cpu from .backend.kernelcreation import optimize_cpu
kernel_ast = optimize_cpu(ctx, platform, kernel_ast, config.cpu_optim) kernel_ast = optimize_cpu(ctx, platform, kernel_ast, config.cpu_optim)
erase_anons = EraseAnonymousStructTypes(ctx) erase_anons = EraseAnonymousStructTypes(ctx)
......
...@@ -483,7 +483,20 @@ class PsIntegerType(PsScalarType, ABC): ...@@ -483,7 +483,20 @@ class PsIntegerType(PsScalarType, ABC):
if not isinstance(value, np_dtype): if not isinstance(value, np_dtype):
raise PsTypeError(f"Given value {value} is not of required type {np_dtype}") raise PsTypeError(f"Given value {value} is not of required type {np_dtype}")
unsigned_suffix = "" if self.signed else "u" unsigned_suffix = "" if self.signed else "u"
return f"(({self._c_type_without_const()}) {value}{unsigned_suffix})"
match self.width:
case w if w < 32:
# Plain integer literals get at least type `int`, which is 32 bit in all relevant cases
# So we need to explicitly cast to smaller types
return f"(({self._c_type_without_const()}) {value}{unsigned_suffix})"
case 32:
# No suffix here - becomes `int`, which is 32 bit
return f"{value}{unsigned_suffix}"
case 64:
# LL suffix: `long long` is the only type guaranteed to be 64 bit wide
return f"{value}{unsigned_suffix}LL"
case _:
assert False, "unreachable code"
def create_constant(self, value: Any) -> Any: def create_constant(self, value: Any) -> Any:
np_type = self.NUMPY_TYPES[self._width] np_type = self.NUMPY_TYPES[self._width]
......
import pytest
from pystencils import (
fields,
Assignment,
create_kernel,
CreateKernelConfig,
CpuOptimConfig,
OpenMpConfig,
Target,
)
from pystencils.backend.ast import dfs_preorder
from pystencils.backend.ast.structural import PsLoop, PsPragma
@pytest.mark.parametrize("nesting_depth", range(3))
@pytest.mark.parametrize("schedule", ["static", "static,16", "dynamic", "auto"])
@pytest.mark.parametrize("collapse", range(3))
@pytest.mark.parametrize("omit_parallel_construct", range(3))
def test_openmp(nesting_depth, schedule, collapse, omit_parallel_construct):
f, g = fields("f, g: [3D]")
asm = Assignment(f.center(0), g.center(0))
omp = OpenMpConfig(
nesting_depth=nesting_depth,
schedule=schedule,
collapse=collapse,
omit_parallel_construct=omit_parallel_construct,
)
gen_config = CreateKernelConfig(
target=Target.CPU, cpu_optim=CpuOptimConfig(openmp=omp)
)
kernel = create_kernel(asm, gen_config)
ast = kernel.body
def find_omp_pragma(ast) -> PsPragma:
num_loops = 0
generator = dfs_preorder(ast)
for node in generator:
match node:
case PsLoop():
num_loops += 1
case PsPragma():
loop = next(generator)
assert isinstance(loop, PsLoop)
assert num_loops == nesting_depth
return node
pytest.fail("No OpenMP pragma found")
pragma = find_omp_pragma(ast)
tokens = set(pragma.text.split())
expected_tokens = {"omp", "for", f"schedule({omp.schedule})"}
if not omp.omit_parallel_construct:
expected_tokens.add("parallel")
if omp.collapse > 0:
expected_tokens.add(f"collapse({omp.collapse})")
assert tokens == expected_tokens
...@@ -12,6 +12,7 @@ from pystencils.backend.ast.structural import ( ...@@ -12,6 +12,7 @@ from pystencils.backend.ast.structural import (
PsBlock, PsBlock,
PsConditional, PsConditional,
PsComment, PsComment,
PsPragma,
PsLoop, PsLoop,
) )
from pystencils.types.quick import Fp, Ptr from pystencils.types.quick import Fp, Ptr
...@@ -44,6 +45,7 @@ def test_cloning(): ...@@ -44,6 +45,7 @@ def test_cloning():
PsConditional( PsConditional(
y, PsBlock([PsStatement(x + y)]), PsBlock([PsComment("hello world")]) y, PsBlock([PsStatement(x + y)]), PsBlock([PsComment("hello world")])
), ),
PsPragma("omp parallel for"),
PsLoop( PsLoop(
x, x,
y, y,
...@@ -54,6 +56,7 @@ def test_cloning(): ...@@ -54,6 +56,7 @@ def test_cloning():
PsComment("Loop body"), PsComment("Loop body"),
PsAssignment(x, y), PsAssignment(x, y),
PsAssignment(x, y), PsAssignment(x, y),
PsPragma("#pragma clang loop vectorize(enable)"),
PsStatement( PsStatement(
PsDeref(PsCast(Ptr(Fp(32)), z)) PsDeref(PsCast(Ptr(Fp(32)), z))
+ PsSubscript(z, one + one + one) + PsSubscript(z, one + one + one)
......
...@@ -54,6 +54,6 @@ def test_literals(): ...@@ -54,6 +54,6 @@ def test_literals():
print(code) print(code)
assert "const double x = C;" in code assert "const double x = C;" in code
assert "CELLS[((int64_t) 0)]" in code assert "CELLS[0LL]" in code
assert "CELLS[((int64_t) 1)]" in code assert "CELLS[1LL]" in code
assert "CELLS[((int64_t) 2)]" in code assert "CELLS[2LL]" in code