Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Commits on Source (6)
Showing
with 504 additions and 62 deletions
......@@ -16,3 +16,6 @@ ignore_missing_imports=true
[mypy-appdirs.*]
ignore_missing_imports=true
[mypy-islpy.*]
ignore_missing_imports=true
......@@ -6,8 +6,14 @@ from . import fd
from . import stencil as stencil
from .display_utils import get_code_obj, get_code_str, show_code, to_dot
from .field import Field, FieldType, fields
from .types import create_type
from .cache import clear_cache
from .config import CreateKernelConfig, CpuOptimConfig, VectorizationConfig
from .config import (
CreateKernelConfig,
CpuOptimConfig,
VectorizationConfig,
OpenMpConfig,
)
from .kernel_decorator import kernel, kernel_config
from .kernelcreation import create_kernel
from .backend.kernelfunction import KernelFunction
......@@ -34,10 +40,12 @@ __all__ = [
"fields",
"DEFAULTS",
"TypedSymbol",
"create_type",
"make_slice",
"CreateKernelConfig",
"CpuOptimConfig",
"VectorizationConfig",
"OpenMpConfig",
"create_kernel",
"KernelFunction",
"Target",
......
......@@ -5,7 +5,7 @@ from .structural import (
PsAssignment,
PsAstNode,
PsBlock,
PsComment,
PsEmptyLeafMixIn,
PsConditional,
PsDeclaration,
PsExpression,
......@@ -63,7 +63,7 @@ class UndefinedSymbolsCollector:
undefined_vars |= self(branch_false)
return undefined_vars
case PsComment():
case PsEmptyLeafMixIn():
return set()
case unknown:
......@@ -92,11 +92,11 @@ class UndefinedSymbolsCollector:
case (
PsAssignment()
| PsBlock()
| PsComment()
| PsConditional()
| PsExpression()
| PsLoop()
| PsStatement()
| PsEmptyLeafMixIn()
):
return set()
......
......@@ -158,7 +158,7 @@ class PsConstantExpr(PsLeafMixIn, PsExpression):
def __repr__(self) -> str:
return f"PsConstantExpr({repr(self._constant)})"
class PsLiteralExpr(PsLeafMixIn, PsExpression):
__match_args__ = ("literal",)
......@@ -177,7 +177,7 @@ class PsLiteralExpr(PsLeafMixIn, PsExpression):
def clone(self) -> PsLiteralExpr:
return PsLiteralExpr(self._literal)
def structurally_equal(self, other: PsAstNode) -> bool:
if not isinstance(other, PsLiteralExpr):
return False
......
......@@ -4,32 +4,32 @@ from .structural import PsAstNode
def dfs_preorder(
node: PsAstNode, yield_pred: Callable[[PsAstNode], bool] = lambda _: True
node: PsAstNode, filter_pred: Callable[[PsAstNode], bool] = lambda _: True
) -> Generator[PsAstNode, None, None]:
"""Pre-Order depth-first traversal of an abstract syntax tree.
Args:
node: The tree's root node
yield_pred: Filter predicate; a node is only yielded to the caller if `yield_pred(node)` returns True
filter_pred: Filter predicate; a node is only returned to the caller if `yield_pred(node)` returns True
"""
if yield_pred(node):
if filter_pred(node):
yield node
for c in node.children:
yield from dfs_preorder(c, yield_pred)
yield from dfs_preorder(c, filter_pred)
def dfs_postorder(
node: PsAstNode, yield_pred: Callable[[PsAstNode], bool] = lambda _: True
node: PsAstNode, filter_pred: Callable[[PsAstNode], bool] = lambda _: True
) -> Generator[PsAstNode, None, None]:
"""Post-Order depth-first traversal of an abstract syntax tree.
Args:
node: The tree's root node
yield_pred: Filter predicate; a node is only yielded to the caller if `yield_pred(node)` returns True
filter_pred: Filter predicate; a node is only returned to the caller if `yield_pred(node)` returns True
"""
for c in node.children:
yield from dfs_postorder(c, yield_pred)
yield from dfs_postorder(c, filter_pred)
if yield_pred(node):
if filter_pred(node):
yield node
......@@ -307,7 +307,42 @@ class PsConditional(PsAstNode):
assert False, "unreachable code"
class PsComment(PsLeafMixIn, PsAstNode):
class PsEmptyLeafMixIn:
"""Mix-in marking AST leaves that can be treated as empty by the code generator,
such as comments and preprocessor directives."""
pass
class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
"""A C/C++ preprocessor pragma.
Example usage: ``PsPragma("omp parallel for")`` translates to ``#pragma omp parallel for``.
Args:
text: The pragma's text, without the ``#pragma ``.
"""
__match_args__ = ("text",)
def __init__(self, text: str) -> None:
self._text = text
@property
def text(self) -> str:
return self._text
def clone(self) -> PsPragma:
return PsPragma(self.text)
def structurally_equal(self, other: PsAstNode) -> bool:
if not isinstance(other, PsPragma):
return False
return self._text == other._text
class PsComment(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
__match_args__ = ("lines",)
def __init__(self, text: str) -> None:
......
......@@ -10,6 +10,7 @@ from .ast.structural import (
PsLoop,
PsConditional,
PsComment,
PsPragma,
)
from .ast.expressions import (
......@@ -235,6 +236,9 @@ class CAstPrinter:
lines_list[-1] = lines_list[-1] + " */"
return pc.indent("\n".join(lines_list))
case PsPragma(text):
return pc.indent("#pragma " + text)
case PsSymbolExpr(symbol):
return symbol.name
......@@ -246,7 +250,7 @@ class CAstPrinter:
)
return dtype.create_literal(constant.value)
case PsLiteralExpr(lit):
return lit.text
......
......@@ -3,10 +3,9 @@ from typing import cast
from .context import KernelCreationContext
from ..platforms import GenericCpu
from ..transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations
from ..ast.structural import PsBlock
from ...config import CpuOptimConfig
from ...config import CpuOptimConfig, OpenMpConfig
def optimize_cpu(
......@@ -16,6 +15,7 @@ def optimize_cpu(
cfg: CpuOptimConfig | None,
) -> PsBlock:
"""Carry out CPU-specific optimizations according to the given configuration."""
from ..transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations
canonicalize = CanonicalizeSymbols(ctx, True)
kernel_ast = cast(PsBlock, canonicalize(kernel_ast))
......@@ -32,8 +32,12 @@ def optimize_cpu(
if cfg.vectorize is not False:
raise NotImplementedError("Vectorization not implemented yet")
if cfg.openmp:
raise NotImplementedError("OpenMP not implemented yet")
if cfg.openmp is not False:
from ..transformations import AddOpenMP
params = cfg.openmp if isinstance(cfg.openmp, OpenMpConfig) else OpenMpConfig()
add_omp = AddOpenMP(ctx, params)
kernel_ast = cast(PsBlock, add_omp(kernel_ast))
if cfg.use_cacheline_zeroing:
raise NotImplementedError("CL-zeroing not implemented yet")
......
......@@ -22,7 +22,7 @@ from ..ast.structural import (
PsExpression,
PsAssignment,
PsDeclaration,
PsComment,
PsEmptyLeafMixIn,
)
from ..ast.expressions import (
PsArrayAccess,
......@@ -159,7 +159,7 @@ class TypeContext:
f" Constant type: {c.dtype}\n"
f" Target type: {self._target_type}"
)
case PsLiteralExpr(lit):
if not self._compatible(lit.dtype):
raise TypificationError(
......@@ -336,7 +336,7 @@ class Typifier:
self.visit(body)
case PsComment():
case PsEmptyLeafMixIn():
pass
case _:
......
......@@ -4,7 +4,7 @@ from ..types import PsType, constify
class PsLiteral:
"""Representation of literal code.
Instances of this class represent code literals inside the AST.
These literals are not to be confused with C literals; the name `Literal` refers to the fact that
the code generator takes them "literally", printing them as they are.
......@@ -22,22 +22,22 @@ class PsLiteral:
@property
def text(self) -> str:
return self._text
@property
def dtype(self) -> PsType:
return self._dtype
def __str__(self) -> str:
return f"{self._text}: {self._dtype}"
def __repr__(self) -> str:
return f"PsLiteral({repr(self._text)}, {repr(self._dtype)})"
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsLiteral):
return False
return self._text == other._text and self._dtype == other._dtype
def __hash__(self) -> int:
return hash((PsLiteral, self._text, self._dtype))
......@@ -60,6 +60,11 @@ Loop Reshaping Transformations
.. autoclass:: ReshapeLoops
:members:
.. autoclass:: InsertPragmasAtLoops
:members:
.. autoclass:: AddOpenMP
:members:
Code Lowering and Materialization
---------------------------------
......@@ -78,6 +83,7 @@ from .eliminate_constants import EliminateConstants
from .eliminate_branches import EliminateBranches
from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations
from .reshape_loops import ReshapeLoops
from .add_pragmas import InsertPragmasAtLoops, LoopPragma, AddOpenMP
from .erase_anonymous_structs import EraseAnonymousStructTypes
from .select_functions import SelectFunctions
from .select_intrinsics import MaterializeVectorIntrinsics
......@@ -89,6 +95,9 @@ __all__ = [
"EliminateBranches",
"HoistLoopInvariantDeclarations",
"ReshapeLoops",
"InsertPragmasAtLoops",
"LoopPragma",
"AddOpenMP",
"EraseAnonymousStructTypes",
"SelectFunctions",
"MaterializeVectorIntrinsics",
......
from dataclasses import dataclass
from typing import Sequence
from collections import defaultdict
from ..kernelcreation import KernelCreationContext
from ..ast import PsAstNode
from ..ast.structural import PsBlock, PsLoop, PsPragma
from ..ast.expressions import PsExpression
from ...config import OpenMpConfig
__all__ = ["InsertPragmasAtLoops", "LoopPragma", "AddOpenMP"]
@dataclass
class LoopPragma:
"""A pragma that should be prepended to loops at a certain nesting depth."""
text: str
"""The pragma text, without the ``#pragma ``"""
loop_nesting_depth: int
"""Nesting depth of the loops the pragma should be added to. ``-1`` indicates the innermost loops."""
def __post_init__(self):
if self.loop_nesting_depth < -1:
raise ValueError("Loop nesting depth must be nonnegative or -1.")
@dataclass
class Nesting:
depth: int
has_inner_loops: bool = False
class InsertPragmasAtLoops:
"""Insert pragmas before loops in a loop nest.
This transformation augments the AST with pragma directives which are prepended to loops.
The directives are annotated with the nesting depth of the loops they should be added to,
where ``-1`` indicates the innermost loop.
The relative order of pragmas with the (exact) same nesting depth is preserved;
however, no guarantees are given about the relative order of pragmas inserted at ``-1``
and at the actual depth of the innermost loop.
"""
def __init__(
self, ctx: KernelCreationContext, insertions: Sequence[LoopPragma]
) -> None:
self._ctx = ctx
self._insertions: dict[int, list[LoopPragma]] = defaultdict(list)
for ins in insertions:
self._insertions[ins.loop_nesting_depth].append(ins)
def __call__(self, node: PsAstNode) -> PsAstNode:
is_loop = isinstance(node, PsLoop)
if is_loop:
node = PsBlock([node])
self.visit(node, Nesting(0))
if is_loop and len(node.children) == 1:
node = node.children[0]
return node
def visit(self, node: PsAstNode, nest: Nesting) -> None:
match node:
case PsExpression():
return
case PsBlock(children):
new_children: list[PsAstNode] = []
for c in children:
if isinstance(c, PsLoop):
nest.has_inner_loops = True
inner_nest = Nesting(nest.depth + 1)
self.visit(c.body, inner_nest)
if not inner_nest.has_inner_loops:
# c is the innermost loop
for pragma in self._insertions[-1]:
new_children.append(PsPragma(pragma.text))
for pragma in self._insertions[nest.depth]:
new_children.append(PsPragma(pragma.text))
new_children.append(c)
node.children = new_children
case other:
for c in other.children:
self.visit(c, nest)
class AddOpenMP:
"""Apply OpenMP directives to loop nests.
This transformation augments the AST with OpenMP pragmas according to the given
`OpenMpParams` configuration.
"""
def __init__(self, ctx: KernelCreationContext, omp_params: OpenMpConfig) -> None:
pragma_text = "omp"
pragma_text += " parallel" if not omp_params.omit_parallel_construct else ""
pragma_text += f" for schedule({omp_params.schedule})"
if omp_params.collapse > 0:
pragma_text += f" collapse({str(omp_params.collapse)})"
self._insert_pragmas = InsertPragmasAtLoops(
ctx, [LoopPragma(pragma_text, omp_params.nesting_depth)]
)
def __call__(self, node: PsAstNode) -> PsAstNode:
return self._insert_pragmas(node)
......@@ -12,6 +12,7 @@ from ..ast.structural import (
PsDeclaration,
PsAssignment,
PsComment,
PsPragma,
PsStatement,
)
from ..ast.expressions import PsExpression, PsSymbolExpr
......@@ -73,7 +74,7 @@ class CanonicalClone:
),
)
case PsComment():
case PsComment() | PsPragma():
return cast(Node_T, node.clone())
case PsDeclaration(lhs, rhs):
......
from ..kernelcreation import KernelCreationContext
from ..ast import PsAstNode
from ..ast.analysis import collect_undefined_symbols
from ..ast.structural import PsLoop, PsBlock, PsConditional
from ..ast.expressions import PsConstantExpr
from ..ast.expressions import (
PsAnd,
PsCast,
PsConstant,
PsConstantExpr,
PsDiv,
PsEq,
PsExpression,
PsGe,
PsGt,
PsIntDiv,
PsLe,
PsLt,
PsMul,
PsNe,
PsNeg,
PsNot,
PsOr,
PsSub,
PsSymbolExpr,
PsAdd,
)
from .eliminate_constants import EliminateConstants
from ...types import PsBoolType, PsIntegerType
__all__ = ["EliminateBranches"]
class IslAnalysisError(Exception):
"""Indicates a fatal error during integer set analysis (based on islpy)"""
class BranchElimContext:
def __init__(self) -> None:
self.enclosing_loops: list[PsLoop] = []
self.enclosing_conditions: list[PsExpression] = []
class EliminateBranches:
......@@ -20,12 +48,16 @@ class EliminateBranches:
This pass will attempt to evaluate branch conditions within their context in the AST, and replace
conditionals by either their then- or their else-block if the branch is unequivocal.
TODO: If islpy is installed, this pass will incorporate information about the iteration regions
of enclosing loops into its analysis.
If islpy is installed, this pass will incorporate information about the iteration regions
of enclosing loops and enclosing conditionals into its analysis.
Args:
use_isl (bool, optional): enable islpy based analysis (default: True)
"""
def __init__(self, ctx: KernelCreationContext) -> None:
def __init__(self, ctx: KernelCreationContext, use_isl: bool = True) -> None:
self._ctx = ctx
self._use_isl = use_isl
self._elim_constants = EliminateConstants(ctx, extract_constant_exprs=False)
def __call__(self, node: PsAstNode) -> PsAstNode:
......@@ -41,20 +73,30 @@ class EliminateBranches:
case PsBlock(statements):
statements_new: list[PsAstNode] = []
for stmt in statements:
if isinstance(stmt, PsConditional):
result = self.handle_conditional(stmt, ec)
if result is not None:
statements_new.append(result)
else:
statements_new.append(self.visit(stmt, ec))
statements_new.append(self.visit(stmt, ec))
node.statements = statements_new
case PsConditional():
result = self.handle_conditional(node, ec)
if result is None:
return PsBlock([])
else:
return result
match result:
case PsConditional(_, branch_true, branch_false):
ec.enclosing_conditions.append(result.condition)
self.visit(branch_true, ec)
ec.enclosing_conditions.pop()
if branch_false is not None:
ec.enclosing_conditions.append(PsNot(result.condition))
self.visit(branch_false, ec)
ec.enclosing_conditions.pop()
case PsBlock():
self.visit(result, ec)
case None:
result = PsBlock([])
case _:
assert False, "unreachable code"
return result
return node
......@@ -62,12 +104,124 @@ class EliminateBranches:
self, conditional: PsConditional, ec: BranchElimContext
) -> PsConditional | PsBlock | None:
condition_simplified = self._elim_constants(conditional.condition)
if self._use_isl:
condition_simplified = self._isl_simplify_condition(
condition_simplified, ec
)
match condition_simplified:
case PsConstantExpr(c) if c.value:
return conditional.branch_true
case PsConstantExpr(c) if not c.value:
return conditional.branch_false
# TODO: Analyze condition against counters of enclosing loops using ISL
return conditional
def _isl_simplify_condition(
self, condition: PsExpression, ec: BranchElimContext
) -> PsExpression:
"""If installed, use ISL to simplify the passed condition to true or
false based on enclosing loops and conditionals. If no simplification
can be made or ISL is not installed, the original condition is returned.
"""
try:
import islpy as isl
except ImportError:
return condition
def printer(expr: PsExpression):
match expr:
case PsSymbolExpr(symbol):
return symbol.name
case PsConstantExpr(constant):
dtype = constant.get_dtype()
if not isinstance(dtype, (PsIntegerType, PsBoolType)):
raise IslAnalysisError(
"Only scalar integer and bool constant may appear in isl expressions."
)
return str(constant.value)
case PsAdd(op1, op2):
return f"({printer(op1)} + {printer(op2)})"
case PsSub(op1, op2):
return f"({printer(op1)} - {printer(op2)})"
case PsMul(op1, op2):
return f"({printer(op1)} * {printer(op2)})"
case PsDiv(op1, op2) | PsIntDiv(op1, op2):
return f"({printer(op1)} / {printer(op2)})"
case PsAnd(op1, op2):
return f"({printer(op1)} and {printer(op2)})"
case PsOr(op1, op2):
return f"({printer(op1)} or {printer(op2)})"
case PsEq(op1, op2):
return f"({printer(op1)} = {printer(op2)})"
case PsNe(op1, op2):
return f"({printer(op1)} != {printer(op2)})"
case PsGt(op1, op2):
return f"({printer(op1)} > {printer(op2)})"
case PsGe(op1, op2):
return f"({printer(op1)} >= {printer(op2)})"
case PsLt(op1, op2):
return f"({printer(op1)} < {printer(op2)})"
case PsLe(op1, op2):
return f"({printer(op1)} <= {printer(op2)})"
case PsNeg(operand):
return f"(-{printer(operand)})"
case PsNot(operand):
return f"(not {printer(operand)})"
case PsCast(_, operand):
return printer(operand)
case _:
raise IslAnalysisError(
f"Not supported by isl or don't know how to print {expr}"
)
dofs = collect_undefined_symbols(condition)
outer_conditions = []
for loop in ec.enclosing_loops:
if not (
isinstance(loop.step, PsConstantExpr)
and loop.step.constant.value == 1
):
raise IslAnalysisError(
"Loops with strides != 1 are not yet supported."
)
dofs.add(loop.counter.symbol)
dofs.update(collect_undefined_symbols(loop.start))
dofs.update(collect_undefined_symbols(loop.stop))
loop_start_str = printer(loop.start)
loop_stop_str = printer(loop.stop)
ctr_name = loop.counter.symbol.name
outer_conditions.append(
f"{ctr_name} >= {loop_start_str} and {ctr_name} < {loop_stop_str}"
)
for cond in ec.enclosing_conditions:
dofs.update(collect_undefined_symbols(cond))
outer_conditions.append(printer(cond))
dofs_str = ",".join(dof.name for dof in dofs)
outer_conditions_str = " and ".join(outer_conditions)
condition_str = printer(condition)
outer_set = isl.BasicSet(f"{{ [{dofs_str}] : {outer_conditions_str} }}")
inner_set = isl.BasicSet(f"{{ [{dofs_str}] : {condition_str} }}")
if inner_set.is_empty():
return PsExpression.make(PsConstant(False))
intersection = outer_set.intersect(inner_set)
if intersection.is_empty():
return PsExpression.make(PsConstant(False))
elif intersection == outer_set:
return PsExpression.make(PsConstant(True))
else:
return condition
......@@ -15,19 +15,39 @@ from .types import PsIntegerType, PsNumericType, PsIeeeFloatType
from .defaults import DEFAULTS
@dataclass
class OpenMpConfig:
"""Parameters controlling kernel parallelization using OpenMP."""
nesting_depth: int = 0
"""Nesting depth of the loop that should be parallelized. Must be a nonnegative number."""
collapse: int = 0
"""Argument to the OpenMP ``collapse`` clause"""
schedule: str = "static"
"""Argument to the OpenMP ``schedule`` clause"""
omit_parallel_construct: bool = False
"""If set to ``True``, the OpenMP ``parallel`` construct is omitted, producing just a ``#pragma omp for``.
Use this option only if you intend to wrap the kernel into an external ``#pragma omp parallel`` region.
"""
@dataclass
class CpuOptimConfig:
"""Configuration for the CPU optimizer.
If any flag in this configuration is set to a value not supported by the CPU specified
in `CreateKernelConfig.target`, an error will be raised.
"""
openmp: bool = False
openmp: bool | OpenMpConfig = False
"""Enable OpenMP parallelization.
If set to `True`, the kernel will be parallelized using OpenMP according to the OpenMP settings
given in this configuration.
If set to `True`, the kernel will be parallelized using OpenMP according to the default settings in `OpenMpParams`.
To customize OpenMP parallelization, pass an instance of `OpenMpParams` instead.
"""
vectorize: bool | VectorizationConfig = False
......@@ -58,7 +78,7 @@ class CpuOptimConfig:
@dataclass
class VectorizationConfig:
"""Configuration for the auto-vectorizer.
If any flag in this configuration is set to a value not supported by the CPU specified
in `CreateKernelConfig.target`, an error will be raised.
"""
......@@ -182,19 +202,27 @@ class CreateKernelConfig:
raise PsOptionsError(
"Only fields with `field_type == FieldType.INDEXED` can be specified as `index_field`"
)
# Check optim
if self.cpu_optim is not None:
if not self.target.is_cpu():
raise PsOptionsError(f"`cpu_optim` cannot be set for non-CPU target {self.target}")
if self.cpu_optim.vectorize is not False and not self.target.is_vector_cpu():
raise PsOptionsError(f"Cannot enable auto-vectorization for non-vector CPU target {self.target}")
raise PsOptionsError(
f"`cpu_optim` cannot be set for non-CPU target {self.target}"
)
if (
self.cpu_optim.vectorize is not False
and not self.target.is_vector_cpu()
):
raise PsOptionsError(
f"Cannot enable auto-vectorization for non-vector CPU target {self.target}"
)
# Infer JIT
if self.jit is None:
if self.target.is_cpu():
from .backend.jit import LegacyCpuJit
self.jit = LegacyCpuJit()
else:
raise NotImplementedError(
......
......@@ -104,6 +104,7 @@ def create_kernel(
# Target-Specific optimizations
if config.target.is_cpu():
from .backend.kernelcreation import optimize_cpu
kernel_ast = optimize_cpu(ctx, platform, kernel_ast, config.cpu_optim)
erase_anons = EraseAnonymousStructTypes(ctx)
......
......@@ -483,7 +483,20 @@ class PsIntegerType(PsScalarType, ABC):
if not isinstance(value, np_dtype):
raise PsTypeError(f"Given value {value} is not of required type {np_dtype}")
unsigned_suffix = "" if self.signed else "u"
return f"(({self._c_type_without_const()}) {value}{unsigned_suffix})"
match self.width:
case w if w < 32:
# Plain integer literals get at least type `int`, which is 32 bit in all relevant cases
# So we need to explicitly cast to smaller types
return f"(({self._c_type_without_const()}) {value}{unsigned_suffix})"
case 32:
# No suffix here - becomes `int`, which is 32 bit
return f"{value}{unsigned_suffix}"
case 64:
# LL suffix: `long long` is the only type guaranteed to be 64 bit wide
return f"{value}{unsigned_suffix}LL"
case _:
assert False, "unreachable code"
def create_constant(self, value: Any) -> Any:
np_type = self.NUMPY_TYPES[self._width]
......
import pytest
from pystencils import (
fields,
Assignment,
create_kernel,
CreateKernelConfig,
CpuOptimConfig,
OpenMpConfig,
Target,
)
from pystencils.backend.ast import dfs_preorder
from pystencils.backend.ast.structural import PsLoop, PsPragma
@pytest.mark.parametrize("nesting_depth", range(3))
@pytest.mark.parametrize("schedule", ["static", "static,16", "dynamic", "auto"])
@pytest.mark.parametrize("collapse", range(3))
@pytest.mark.parametrize("omit_parallel_construct", range(3))
def test_openmp(nesting_depth, schedule, collapse, omit_parallel_construct):
f, g = fields("f, g: [3D]")
asm = Assignment(f.center(0), g.center(0))
omp = OpenMpConfig(
nesting_depth=nesting_depth,
schedule=schedule,
collapse=collapse,
omit_parallel_construct=omit_parallel_construct,
)
gen_config = CreateKernelConfig(
target=Target.CPU, cpu_optim=CpuOptimConfig(openmp=omp)
)
kernel = create_kernel(asm, gen_config)
ast = kernel.body
def find_omp_pragma(ast) -> PsPragma:
num_loops = 0
generator = dfs_preorder(ast)
for node in generator:
match node:
case PsLoop():
num_loops += 1
case PsPragma():
loop = next(generator)
assert isinstance(loop, PsLoop)
assert num_loops == nesting_depth
return node
pytest.fail("No OpenMP pragma found")
pragma = find_omp_pragma(ast)
tokens = set(pragma.text.split())
expected_tokens = {"omp", "for", f"schedule({omp.schedule})"}
if not omp.omit_parallel_construct:
expected_tokens.add("parallel")
if omp.collapse > 0:
expected_tokens.add(f"collapse({omp.collapse})")
assert tokens == expected_tokens
......@@ -12,6 +12,7 @@ from pystencils.backend.ast.structural import (
PsBlock,
PsConditional,
PsComment,
PsPragma,
PsLoop,
)
from pystencils.types.quick import Fp, Ptr
......@@ -44,6 +45,7 @@ def test_cloning():
PsConditional(
y, PsBlock([PsStatement(x + y)]), PsBlock([PsComment("hello world")])
),
PsPragma("omp parallel for"),
PsLoop(
x,
y,
......@@ -54,6 +56,7 @@ def test_cloning():
PsComment("Loop body"),
PsAssignment(x, y),
PsAssignment(x, y),
PsPragma("#pragma clang loop vectorize(enable)"),
PsStatement(
PsDeref(PsCast(Ptr(Fp(32)), z))
+ PsSubscript(z, one + one + one)
......
......@@ -54,6 +54,6 @@ def test_literals():
print(code)
assert "const double x = C;" in code
assert "CELLS[((int64_t) 0)]" in code
assert "CELLS[((int64_t) 1)]" in code
assert "CELLS[((int64_t) 2)]" in code
assert "CELLS[0LL]" in code
assert "CELLS[1LL]" in code
assert "CELLS[2LL]" in code