Skip to content
Snippets Groups Projects
Commit cb9fefd6 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Add test cases for pragmas and OpenMP. Remove pragma insertion from the public interface.

parent 43246f5a
No related branches found
No related tags found
1 merge request!383Pragmas and OpenMP Support
Pipeline #66664 failed
...@@ -158,7 +158,7 @@ class PsConstantExpr(PsLeafMixIn, PsExpression): ...@@ -158,7 +158,7 @@ class PsConstantExpr(PsLeafMixIn, PsExpression):
def __repr__(self) -> str: def __repr__(self) -> str:
return f"PsConstantExpr({repr(self._constant)})" return f"PsConstantExpr({repr(self._constant)})"
class PsLiteralExpr(PsLeafMixIn, PsExpression): class PsLiteralExpr(PsLeafMixIn, PsExpression):
__match_args__ = ("literal",) __match_args__ = ("literal",)
...@@ -177,7 +177,7 @@ class PsLiteralExpr(PsLeafMixIn, PsExpression): ...@@ -177,7 +177,7 @@ class PsLiteralExpr(PsLeafMixIn, PsExpression):
def clone(self) -> PsLiteralExpr: def clone(self) -> PsLiteralExpr:
return PsLiteralExpr(self._literal) return PsLiteralExpr(self._literal)
def structurally_equal(self, other: PsAstNode) -> bool: def structurally_equal(self, other: PsAstNode) -> bool:
if not isinstance(other, PsLiteralExpr): if not isinstance(other, PsLiteralExpr):
return False return False
......
...@@ -250,7 +250,7 @@ class CAstPrinter: ...@@ -250,7 +250,7 @@ class CAstPrinter:
) )
return dtype.create_literal(constant.value) return dtype.create_literal(constant.value)
case PsLiteralExpr(lit): case PsLiteralExpr(lit):
return lit.text return lit.text
......
...@@ -5,7 +5,7 @@ from .context import KernelCreationContext ...@@ -5,7 +5,7 @@ from .context import KernelCreationContext
from ..platforms import GenericCpu from ..platforms import GenericCpu
from ..ast.structural import PsBlock from ..ast.structural import PsBlock
from ...config import CpuOptimConfig, OpenMpParams from ...config import CpuOptimConfig, OpenMpConfig
def optimize_cpu( def optimize_cpu(
...@@ -34,17 +34,12 @@ def optimize_cpu( ...@@ -34,17 +34,12 @@ def optimize_cpu(
if cfg.openmp is not False: if cfg.openmp is not False:
from ..transformations import AddOpenMP from ..transformations import AddOpenMP
params = cfg.openmp if isinstance(cfg.openmp, OpenMpParams) else OpenMpParams()
params = cfg.openmp if isinstance(cfg.openmp, OpenMpConfig) else OpenMpConfig()
add_omp = AddOpenMP(ctx, params) add_omp = AddOpenMP(ctx, params)
kernel_ast = cast(PsBlock, add_omp(kernel_ast)) kernel_ast = cast(PsBlock, add_omp(kernel_ast))
if cfg.use_cacheline_zeroing: if cfg.use_cacheline_zeroing:
raise NotImplementedError("CL-zeroing not implemented yet") raise NotImplementedError("CL-zeroing not implemented yet")
if cfg.insert_loop_pragmas:
from ..transformations import InsertPragmasAtLoops
insert_pragmas = InsertPragmasAtLoops(ctx, cfg.insert_loop_pragmas)
kernel_ast = cast(PsBlock, insert_pragmas(kernel_ast))
return kernel_ast return kernel_ast
...@@ -159,7 +159,7 @@ class TypeContext: ...@@ -159,7 +159,7 @@ class TypeContext:
f" Constant type: {c.dtype}\n" f" Constant type: {c.dtype}\n"
f" Target type: {self._target_type}" f" Target type: {self._target_type}"
) )
case PsLiteralExpr(lit): case PsLiteralExpr(lit):
if not self._compatible(lit.dtype): if not self._compatible(lit.dtype):
raise TypificationError( raise TypificationError(
......
...@@ -4,7 +4,7 @@ from ..types import PsType, constify ...@@ -4,7 +4,7 @@ from ..types import PsType, constify
class PsLiteral: class PsLiteral:
"""Representation of literal code. """Representation of literal code.
Instances of this class represent code literals inside the AST. Instances of this class represent code literals inside the AST.
These literals are not to be confused with C literals; the name `Literal` refers to the fact that These literals are not to be confused with C literals; the name `Literal` refers to the fact that
the code generator takes them "literally", printing them as they are. the code generator takes them "literally", printing them as they are.
...@@ -22,22 +22,22 @@ class PsLiteral: ...@@ -22,22 +22,22 @@ class PsLiteral:
@property @property
def text(self) -> str: def text(self) -> str:
return self._text return self._text
@property @property
def dtype(self) -> PsType: def dtype(self) -> PsType:
return self._dtype return self._dtype
def __str__(self) -> str: def __str__(self) -> str:
return f"{self._text}: {self._dtype}" return f"{self._text}: {self._dtype}"
def __repr__(self) -> str: def __repr__(self) -> str:
return f"PsLiteral({repr(self._text)}, {repr(self._dtype)})" return f"PsLiteral({repr(self._text)}, {repr(self._dtype)})"
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if not isinstance(other, PsLiteral): if not isinstance(other, PsLiteral):
return False return False
return self._text == other._text and self._dtype == other._dtype return self._text == other._text and self._dtype == other._dtype
def __hash__(self) -> int: def __hash__(self) -> int:
return hash((PsLiteral, self._text, self._dtype)) return hash((PsLiteral, self._text, self._dtype))
...@@ -83,7 +83,7 @@ from .eliminate_constants import EliminateConstants ...@@ -83,7 +83,7 @@ from .eliminate_constants import EliminateConstants
from .eliminate_branches import EliminateBranches from .eliminate_branches import EliminateBranches
from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations
from .reshape_loops import ReshapeLoops from .reshape_loops import ReshapeLoops
from .add_pragmas import InsertPragmasAtLoops, AddOpenMP from .add_pragmas import InsertPragmasAtLoops, LoopPragma, AddOpenMP
from .erase_anonymous_structs import EraseAnonymousStructTypes from .erase_anonymous_structs import EraseAnonymousStructTypes
from .select_functions import SelectFunctions from .select_functions import SelectFunctions
from .select_intrinsics import MaterializeVectorIntrinsics from .select_intrinsics import MaterializeVectorIntrinsics
...@@ -96,6 +96,7 @@ __all__ = [ ...@@ -96,6 +96,7 @@ __all__ = [
"HoistLoopInvariantDeclarations", "HoistLoopInvariantDeclarations",
"ReshapeLoops", "ReshapeLoops",
"InsertPragmasAtLoops", "InsertPragmasAtLoops",
"LoopPragma",
"AddOpenMP", "AddOpenMP",
"EraseAnonymousStructTypes", "EraseAnonymousStructTypes",
"SelectFunctions", "SelectFunctions",
......
...@@ -8,9 +8,24 @@ from ..ast import PsAstNode ...@@ -8,9 +8,24 @@ from ..ast import PsAstNode
from ..ast.structural import PsBlock, PsLoop, PsPragma from ..ast.structural import PsBlock, PsLoop, PsPragma
from ..ast.expressions import PsExpression from ..ast.expressions import PsExpression
from ...config import LoopPragma, OpenMpParams from ...config import OpenMpConfig
__all__ = ["InsertPragmasAtLoops", "AddOpenMP"] __all__ = ["InsertPragmasAtLoops", "LoopPragma", "AddOpenMP"]
@dataclass
class LoopPragma:
    """Description of a pragma to be prepended to all loops of one nesting depth."""

    text: str
    """Pragma text, excluding the leading ``#pragma ``"""

    loop_nesting_depth: int
    """Nesting depth of the loops to annotate; ``-1`` selects the innermost loops."""

    def __post_init__(self):
        # Only -1 (innermost) and nonnegative depths are meaningful.
        depth = self.loop_nesting_depth
        if depth < -1:
            raise ValueError("Loop nesting depth must be nonnegative or -1.")
@dataclass @dataclass
...@@ -87,9 +102,10 @@ class AddOpenMP: ...@@ -87,9 +102,10 @@ class AddOpenMP:
`OpenMpParams` configuration. `OpenMpParams` configuration.
""" """
def __init__(self, ctx: KernelCreationContext, omp_params: OpenMpParams) -> None: def __init__(self, ctx: KernelCreationContext, omp_params: OpenMpConfig) -> None:
pragma_text = "parallel " if not omp_params.omit_parallel_construct else "" pragma_text = "omp"
pragma_text += f"for schedule({omp_params.schedule})" pragma_text += " parallel" if not omp_params.omit_parallel_construct else ""
pragma_text += f" for schedule({omp_params.schedule})"
if omp_params.collapse > 0: if omp_params.collapse > 0:
pragma_text += f" collapse({str(omp_params.collapse)})" pragma_text += f" collapse({str(omp_params.collapse)})"
......
...@@ -16,26 +16,11 @@ from .defaults import DEFAULTS ...@@ -16,26 +16,11 @@ from .defaults import DEFAULTS
@dataclass @dataclass
class LoopPragma: class OpenMpConfig:
"""A pragma that should be prepended to loops at a certain nesting depth."""
text: str
"""The pragma text, without the ``#pragma ``"""
loop_nesting_depth: int
"""Nesting depth of the loops the pragma should be added to. ``-1`` indicates the innermost loops."""
def __post_init__(self):
if self.loop_nesting_depth < -1:
raise ValueError("Loop nesting depth must be nonnegative or -1.")
@dataclass
class OpenMpParams:
"""Parameters controlling kernel parallelization using OpenMP.""" """Parameters controlling kernel parallelization using OpenMP."""
nesting_depth: int = 0 nesting_depth: int = 0
"""Nesting depth of the loop that should be parallelized""" """Nesting depth of the loop that should be parallelized. Must be a nonnegative number."""
collapse: int = 0 collapse: int = 0
"""Argument to the OpenMP ``collapse`` clause""" """Argument to the OpenMP ``collapse`` clause"""
...@@ -53,12 +38,12 @@ class OpenMpParams: ...@@ -53,12 +38,12 @@ class OpenMpParams:
@dataclass @dataclass
class CpuOptimConfig: class CpuOptimConfig:
"""Configuration for the CPU optimizer. """Configuration for the CPU optimizer.
If any flag in this configuration is set to a value not supported by the CPU specified If any flag in this configuration is set to a value not supported by the CPU specified
in `CreateKernelConfig.target`, an error will be raised. in `CreateKernelConfig.target`, an error will be raised.
""" """
openmp: bool | OpenMpParams = False openmp: bool | OpenMpConfig = False
"""Enable OpenMP parallelization. """Enable OpenMP parallelization.
If set to `True`, the kernel will be parallelized using OpenMP according to the default settings in `OpenMpParams`. If set to `True`, the kernel will be parallelized using OpenMP according to the default settings in `OpenMpParams`.
...@@ -89,22 +74,11 @@ class CpuOptimConfig: ...@@ -89,22 +74,11 @@ class CpuOptimConfig:
to produce cacheline zeroing instructions where possible. to produce cacheline zeroing instructions where possible.
""" """
insert_loop_pragmas: Sequence[LoopPragma] = ()
"""Insert pragmas before loops.
Each pragma is annotated with a nesting depth and will be prepended to all loops of that depth;
the value ``-1`` indicates the innermost loop.
The relative order of pragmas with the (exact) same nesting depth is preserved;
however, the relative order of pragmas inserted at ``-1`` and the actual depth of the deepest loop
is undefined.
"""
@dataclass @dataclass
class VectorizationConfig: class VectorizationConfig:
"""Configuration for the auto-vectorizer. """Configuration for the auto-vectorizer.
If any flag in this configuration is set to a value not supported by the CPU specified If any flag in this configuration is set to a value not supported by the CPU specified
in `CreateKernelConfig.target`, an error will be raised. in `CreateKernelConfig.target`, an error will be raised.
""" """
...@@ -228,19 +202,27 @@ class CreateKernelConfig: ...@@ -228,19 +202,27 @@ class CreateKernelConfig:
raise PsOptionsError( raise PsOptionsError(
"Only fields with `field_type == FieldType.INDEXED` can be specified as `index_field`" "Only fields with `field_type == FieldType.INDEXED` can be specified as `index_field`"
) )
# Check optim # Check optim
if self.cpu_optim is not None: if self.cpu_optim is not None:
if not self.target.is_cpu(): if not self.target.is_cpu():
raise PsOptionsError(f"`cpu_optim` cannot be set for non-CPU target {self.target}") raise PsOptionsError(
f"`cpu_optim` cannot be set for non-CPU target {self.target}"
if self.cpu_optim.vectorize is not False and not self.target.is_vector_cpu(): )
raise PsOptionsError(f"Cannot enable auto-vectorization for non-vector CPU target {self.target}")
if (
self.cpu_optim.vectorize is not False
and not self.target.is_vector_cpu()
):
raise PsOptionsError(
f"Cannot enable auto-vectorization for non-vector CPU target {self.target}"
)
# Infer JIT # Infer JIT
if self.jit is None: if self.jit is None:
if self.target.is_cpu(): if self.target.is_cpu():
from .backend.jit import LegacyCpuJit from .backend.jit import LegacyCpuJit
self.jit = LegacyCpuJit() self.jit = LegacyCpuJit()
else: else:
raise NotImplementedError( raise NotImplementedError(
......
...@@ -104,6 +104,7 @@ def create_kernel( ...@@ -104,6 +104,7 @@ def create_kernel(
# Target-Specific optimizations # Target-Specific optimizations
if config.target.is_cpu(): if config.target.is_cpu():
from .backend.kernelcreation import optimize_cpu from .backend.kernelcreation import optimize_cpu
kernel_ast = optimize_cpu(ctx, platform, kernel_ast, config.cpu_optim) kernel_ast = optimize_cpu(ctx, platform, kernel_ast, config.cpu_optim)
erase_anons = EraseAnonymousStructTypes(ctx) erase_anons = EraseAnonymousStructTypes(ctx)
......
import pytest
from pystencils import (
fields,
Assignment,
create_kernel,
CreateKernelConfig,
CpuOptimConfig,
OpenMpConfig,
Target,
)
from pystencils.backend.ast import dfs_preorder
from pystencils.backend.ast.structural import PsLoop, PsPragma
@pytest.mark.parametrize("nesting_depth", range(3))
@pytest.mark.parametrize("schedule", ["static", "static,16", "dynamic", "auto"])
@pytest.mark.parametrize("collapse", range(3))
@pytest.mark.parametrize("omit_parallel_construct", range(3))
def test_openmp(nesting_depth, schedule, collapse, omit_parallel_construct):
f, g = fields("f, g: [3D]")
asm = Assignment(f.center(0), g.center(0))
omp = OpenMpConfig(
nesting_depth=nesting_depth,
schedule=schedule,
collapse=collapse,
omit_parallel_construct=omit_parallel_construct,
)
gen_config = CreateKernelConfig(
target=Target.CPU, cpu_optim=CpuOptimConfig(openmp=omp)
)
kernel = create_kernel(asm, gen_config)
ast = kernel.body
def find_omp_pragma(ast) -> PsPragma:
num_loops = 0
generator = dfs_preorder(ast)
for node in generator:
match node:
case PsLoop():
num_loops += 1
case PsPragma():
loop = next(generator)
assert isinstance(loop, PsLoop)
assert num_loops == nesting_depth
return node
pytest.fail("No OpenMP pragma found")
pragma = find_omp_pragma(ast)
tokens = set(pragma.text.split())
expected_tokens = {"omp", "for", f"schedule({omp.schedule})"}
if not omp.omit_parallel_construct:
expected_tokens.add("parallel")
if omp.collapse > 0:
expected_tokens.add(f"collapse({omp.collapse})")
assert tokens == expected_tokens
...@@ -56,7 +56,7 @@ def test_cloning(): ...@@ -56,7 +56,7 @@ def test_cloning():
PsComment("Loop body"), PsComment("Loop body"),
PsAssignment(x, y), PsAssignment(x, y),
PsAssignment(x, y), PsAssignment(x, y),
PsPragma("anyone has an idea for a nice pragma?"), PsPragma("#pragma clang loop vectorize(enable)"),
PsStatement( PsStatement(
PsDeref(PsCast(Ptr(Fp(32)), z)) PsDeref(PsCast(Ptr(Fp(32)), z))
+ PsSubscript(z, one + one + one) + PsSubscript(z, one + one + one)
......
import sympy as sp
from itertools import product
from pystencils import make_slice, fields, Assignment
from pystencils.backend.kernelcreation import (
KernelCreationContext,
AstFactory,
FullIterationSpace,
)
from pystencils.backend.ast import dfs_preorder
from pystencils.backend.ast.structural import PsBlock, PsPragma, PsLoop
from pystencils.backend.transformations import InsertPragmasAtLoops, LoopPragma
def test_insert_pragmas():
    """Verify that `InsertPragmasAtLoops` places each pragma in front of the right loop.

    A loop nest is built over a 3D iteration space, and three pragmas are
    requested at nesting depths 0, 1 and -1 (innermost). Each pragma is then
    looked up at its expected position and its text is compared.
    """

    def check_pragma(node, expected):
        # The node must be a pragma carrying the text of the expected LoopPragma.
        assert isinstance(node, PsPragma)
        assert node.text == expected.text

    ctx = KernelCreationContext()
    factory = AstFactory(ctx)

    f, g = fields("f, g: [3D]")
    ispace = FullIterationSpace.create_from_slice(
        ctx, make_slice[:, :, :], archetype_field=f
    )
    ctx.set_iteration_space(ispace)

    # Full 27-point neighborhood as the stencil of the loop body
    stencil = list(product([-1, 0, 1], [-1, 0, 1], [-1, 0, 1]))
    update = Assignment(f.center(0), sum(g.neighbors(stencil)))
    loop_body = PsBlock([factory.parse_sympy(update)])
    loops = factory.loops_from_ispace(ispace, loop_body)

    pragmas = (
        LoopPragma("omp parallel for", 0),
        LoopPragma("some nonsense pragma", 1),
        LoopPragma("omp simd", -1),
    )
    ast = InsertPragmasAtLoops(ctx, pragmas)(loops)

    # The outermost pragma and the original loop nest are wrapped in a new block
    assert isinstance(ast, PsBlock)
    check_pragma(ast.statements[0], pragmas[0])
    assert ast.statements[1] == loops

    # The depth-1 pragma is the first statement in the outermost loop's body
    check_pragma(loops.body.statements[0], pragmas[1])

    # The innermost pragma is the first statement in the second loop's body
    all_loops = list(dfs_preorder(ast, lambda node: isinstance(node, PsLoop)))
    second_loop = all_loops[1]
    assert isinstance(second_loop, PsLoop)
    check_pragma(second_loop.body.statements[0], pragmas[2])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment