diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 7bcf62b973d8ace8e9ad9847ae165c398f1cbb0e..4063f7b539ab387d1f950a75e735f3c6201b5ef2 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -158,7 +158,7 @@ class PsConstantExpr(PsLeafMixIn, PsExpression): def __repr__(self) -> str: return f"PsConstantExpr({repr(self._constant)})" - + class PsLiteralExpr(PsLeafMixIn, PsExpression): __match_args__ = ("literal",) @@ -177,7 +177,7 @@ class PsLiteralExpr(PsLeafMixIn, PsExpression): def clone(self) -> PsLiteralExpr: return PsLiteralExpr(self._literal) - + def structurally_equal(self, other: PsAstNode) -> bool: if not isinstance(other, PsLiteralExpr): return False diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py index 08b864d74dc00f42ed1ef6ca4bf23b6cb1454d31..e8fc2a662b2f49c43fe51d021915b9e44cc59ad3 100644 --- a/src/pystencils/backend/emission.py +++ b/src/pystencils/backend/emission.py @@ -250,7 +250,7 @@ class CAstPrinter: ) return dtype.create_literal(constant.value) - + case PsLiteralExpr(lit): return lit.text diff --git a/src/pystencils/backend/kernelcreation/cpu_optimization.py b/src/pystencils/backend/kernelcreation/cpu_optimization.py index d8d761aec67f77f35484f779687e13b042cc98e7..29b133ff164e856783f14eb83357c8382db9ba5d 100644 --- a/src/pystencils/backend/kernelcreation/cpu_optimization.py +++ b/src/pystencils/backend/kernelcreation/cpu_optimization.py @@ -5,7 +5,7 @@ from .context import KernelCreationContext from ..platforms import GenericCpu from ..ast.structural import PsBlock -from ...config import CpuOptimConfig, OpenMpParams +from ...config import CpuOptimConfig, OpenMpConfig def optimize_cpu( @@ -34,17 +34,12 @@ def optimize_cpu( if cfg.openmp is not False: from ..transformations import AddOpenMP - params = cfg.openmp if isinstance(cfg.openmp, OpenMpParams) else OpenMpParams() + + params = cfg.openmp if isinstance(cfg.openmp, OpenMpConfig) else OpenMpConfig() add_omp = AddOpenMP(ctx, params) kernel_ast = cast(PsBlock, add_omp(kernel_ast)) if cfg.use_cacheline_zeroing: raise NotImplementedError("CL-zeroing not implemented yet") - if cfg.insert_loop_pragmas: - from ..transformations import InsertPragmasAtLoops - - insert_pragmas = InsertPragmasAtLoops(ctx, cfg.insert_loop_pragmas) - kernel_ast = cast(PsBlock, insert_pragmas(kernel_ast)) - return kernel_ast diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index b018157546c5bbbb1b905284b5eb77d7754d582c..06e34d4e36219366a25713d28ae758b61fb3d0d6 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -159,7 +159,7 @@ class TypeContext: f" Constant type: {c.dtype}\n" f" Target type: {self._target_type}" ) - + case PsLiteralExpr(lit): if not self._compatible(lit.dtype): raise TypificationError( diff --git a/src/pystencils/backend/literals.py b/src/pystencils/backend/literals.py index dc7504f520f8950b46df76b0359aaad371244b19..dc254da0e340d518929d6eecb483defcdffbe185 100644 --- a/src/pystencils/backend/literals.py +++ b/src/pystencils/backend/literals.py @@ -4,7 +4,7 @@ from ..types import PsType, constify class PsLiteral: """Representation of literal code. - + Instances of this class represent code literals inside the AST. These literals are not to be confused with C literals; the name `Literal` refers to the fact that the code generator takes them "literally", printing them as they are. @@ -22,22 +22,22 @@ class PsLiteral: @property def text(self) -> str: return self._text - + @property def dtype(self) -> PsType: return self._dtype - + def __str__(self) -> str: return f"{self._text}: {self._dtype}" - + def __repr__(self) -> str: return f"PsLiteral({repr(self._text)}, {repr(self._dtype)})" - + def __eq__(self, other: object) -> bool: if not isinstance(other, PsLiteral): return False - + return self._text == other._text and self._dtype == other._dtype - + def __hash__(self) -> int: return hash((PsLiteral, self._text, self._dtype)) diff --git a/src/pystencils/backend/transformations/__init__.py b/src/pystencils/backend/transformations/__init__.py index f47f1f82bcf65e887816d74564082c7ed5880d06..88ad9348f09685258d2aecb5fca66fcfe609173b 100644 --- a/src/pystencils/backend/transformations/__init__.py +++ b/src/pystencils/backend/transformations/__init__.py @@ -83,7 +83,7 @@ from .eliminate_constants import EliminateConstants from .eliminate_branches import EliminateBranches from .hoist_loop_invariant_decls import HoistLoopInvariantDeclarations from .reshape_loops import ReshapeLoops -from .add_pragmas import InsertPragmasAtLoops, AddOpenMP +from .add_pragmas import InsertPragmasAtLoops, LoopPragma, AddOpenMP from .erase_anonymous_structs import EraseAnonymousStructTypes from .select_functions import SelectFunctions from .select_intrinsics import MaterializeVectorIntrinsics @@ -96,6 +96,7 @@ __all__ = [ "HoistLoopInvariantDeclarations", "ReshapeLoops", "InsertPragmasAtLoops", + "LoopPragma", "AddOpenMP", "EraseAnonymousStructTypes", "SelectFunctions", diff --git a/src/pystencils/backend/transformations/add_pragmas.py b/src/pystencils/backend/transformations/add_pragmas.py index 9ef1ee4580c072eddeb0461686c6dd65c33c2f5c..c7015ccb62042e6beaf131e68485f7c8186b2a2a 100644 --- a/src/pystencils/backend/transformations/add_pragmas.py +++ b/src/pystencils/backend/transformations/add_pragmas.py @@ -8,9 +8,24 @@ from ..ast import PsAstNode from ..ast.structural import PsBlock, PsLoop, PsPragma from ..ast.expressions import PsExpression -from ...config import LoopPragma, OpenMpParams +from ...config import OpenMpConfig -__all__ = ["InsertPragmasAtLoops", "AddOpenMP"] +__all__ = ["InsertPragmasAtLoops", "LoopPragma", "AddOpenMP"] + + +@dataclass +class LoopPragma: + """A pragma that should be prepended to loops at a certain nesting depth.""" + + text: str + """The pragma text, without the ``#pragma ``""" + + loop_nesting_depth: int + """Nesting depth of the loops the pragma should be added to. ``-1`` indicates the innermost loops.""" + + def __post_init__(self): + if self.loop_nesting_depth < -1: + raise ValueError("Loop nesting depth must be nonnegative or -1.") @dataclass @@ -87,9 +102,10 @@ class AddOpenMP: `OpenMpParams` configuration. """ - def __init__(self, ctx: KernelCreationContext, omp_params: OpenMpParams) -> None: - pragma_text = "parallel " if not omp_params.omit_parallel_construct else "" - pragma_text += f"for schedule({omp_params.schedule})" + def __init__(self, ctx: KernelCreationContext, omp_params: OpenMpConfig) -> None: + pragma_text = "omp" + pragma_text += " parallel" if not omp_params.omit_parallel_construct else "" + pragma_text += f" for schedule({omp_params.schedule})" if omp_params.collapse > 0: pragma_text += f" collapse({str(omp_params.collapse)})" diff --git a/src/pystencils/config.py b/src/pystencils/config.py index 0fc71d6662f3f87235a465a0ef802d5b4e90f285..7c49c4c37c0257adde151c3c32680faa50e9a36a 100644 --- a/src/pystencils/config.py +++ b/src/pystencils/config.py @@ -16,26 +16,11 @@ from .defaults import DEFAULTS @dataclass -class LoopPragma: - """A pragma that should be prepended to loops at a certain nesting depth.""" - - text: str - """The pragma text, without the ``#pragma ``""" - - loop_nesting_depth: int - """Nesting depth of the loops the pragma should be added to. ``-1`` indicates the innermost loops.""" - - def __post_init__(self): - if self.loop_nesting_depth < -1: - raise ValueError("Loop nesting depth must be nonnegative or -1.") - - -@dataclass -class OpenMpParams: +class OpenMpConfig: """Parameters controlling kernel parallelization using OpenMP.""" nesting_depth: int = 0 - """Nesting depth of the loop that should be parallelized""" + """Nesting depth of the loop that should be parallelized. Must be a nonnegative number.""" collapse: int = 0 """Argument to the OpenMP ``collapse`` clause""" @@ -53,12 +38,12 @@ class OpenMpParams: @dataclass class CpuOptimConfig: """Configuration for the CPU optimizer. - + If any flag in this configuration is set to a value not supported by the CPU specified in `CreateKernelConfig.target`, an error will be raised. """ - - openmp: bool | OpenMpParams = False + + openmp: bool | OpenMpConfig = False """Enable OpenMP parallelization. If set to `True`, the kernel will be parallelized using OpenMP according to the default settings in `OpenMpParams`. @@ -89,22 +74,11 @@ class CpuOptimConfig: to produce cacheline zeroing instructions where possible. """ - insert_loop_pragmas: Sequence[LoopPragma] = () - """Insert pragmas before loops. - - Each pragma is annotated with a nesting depth and will be prepended to all loops of that depth; - the value ``-1`` indicates the innermost loop. - - The relative order of pragmas with the (exact) same nesting depth is preserved; - however, the relative order of pragmas inserted at ``-1`` and the actual depth of the deepest loop - is undefined. - """ - @dataclass class VectorizationConfig: """Configuration for the auto-vectorizer. - + If any flag in this configuration is set to a value not supported by the CPU specified in `CreateKernelConfig.target`, an error will be raised. """ @@ -228,19 +202,27 @@ class CreateKernelConfig: raise PsOptionsError( "Only fields with `field_type == FieldType.INDEXED` can be specified as `index_field`" ) - + # Check optim if self.cpu_optim is not None: if not self.target.is_cpu(): - raise PsOptionsError(f"`cpu_optim` cannot be set for non-CPU target {self.target}") - - if self.cpu_optim.vectorize is not False and not self.target.is_vector_cpu(): - raise PsOptionsError(f"Cannot enable auto-vectorization for non-vector CPU target {self.target}") + raise PsOptionsError( + f"`cpu_optim` cannot be set for non-CPU target {self.target}" + ) + + if ( + self.cpu_optim.vectorize is not False + and not self.target.is_vector_cpu() + ): + raise PsOptionsError( + f"Cannot enable auto-vectorization for non-vector CPU target {self.target}" + ) # Infer JIT if self.jit is None: if self.target.is_cpu(): from .backend.jit import LegacyCpuJit + self.jit = LegacyCpuJit() else: raise NotImplementedError( diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py index 0f6941cf52ca1c3745b5af9a1f0478f269fc8e4a..66c2a0d6c16e291ba5f6315478406668e7e91069 100644 --- a/src/pystencils/kernelcreation.py +++ b/src/pystencils/kernelcreation.py @@ -104,6 +104,7 @@ def create_kernel( # Target-Specific optimizations if config.target.is_cpu(): from .backend.kernelcreation import optimize_cpu + kernel_ast = optimize_cpu(ctx, platform, kernel_ast, config.cpu_optim) erase_anons = EraseAnonymousStructTypes(ctx) diff --git a/tests/nbackend/kernelcreation/test_openmp.py b/tests/nbackend/kernelcreation/test_openmp.py new file mode 100644 index 0000000000000000000000000000000000000000..d7be8eb98cd29bea370bc6279013ef973e621370 --- /dev/null +++ b/tests/nbackend/kernelcreation/test_openmp.py @@ -0,0 +1,61 @@ +import pytest +from pystencils import ( + fields, + Assignment, + create_kernel, + CreateKernelConfig, + CpuOptimConfig, + OpenMpConfig, + Target, +) + +from pystencils.backend.ast import dfs_preorder +from pystencils.backend.ast.structural import PsLoop, PsPragma + + +@pytest.mark.parametrize("nesting_depth", range(3)) +@pytest.mark.parametrize("schedule", ["static", "static,16", "dynamic", "auto"]) +@pytest.mark.parametrize("collapse", range(3)) +@pytest.mark.parametrize("omit_parallel_construct", range(3)) +def test_openmp(nesting_depth, schedule, collapse, omit_parallel_construct): + f, g = fields("f, g: [3D]") + asm = Assignment(f.center(0), g.center(0)) + + omp = OpenMpConfig( + nesting_depth=nesting_depth, + schedule=schedule, + collapse=collapse, + omit_parallel_construct=omit_parallel_construct, + ) + gen_config = CreateKernelConfig( + target=Target.CPU, cpu_optim=CpuOptimConfig(openmp=omp) + ) + + kernel = create_kernel(asm, gen_config) + ast = kernel.body + + def find_omp_pragma(ast) -> PsPragma: + num_loops = 0 + generator = dfs_preorder(ast) + for node in generator: + match node: + case PsLoop(): + num_loops += 1 + case PsPragma(): + loop = next(generator) + assert isinstance(loop, PsLoop) + assert num_loops == nesting_depth + return node + + pytest.fail("No OpenMP pragma found") + + pragma = find_omp_pragma(ast) + tokens = set(pragma.text.split()) + + expected_tokens = {"omp", "for", f"schedule({omp.schedule})"} + if not omp.omit_parallel_construct: + expected_tokens.add("parallel") + if omp.collapse > 0: + expected_tokens.add(f"collapse({omp.collapse})") + + assert tokens == expected_tokens diff --git a/tests/nbackend/test_ast.py b/tests/nbackend/test_ast.py index b1a220f213e15acc44c47a12d79b30cc81fc90bd..09c63a5572f7873648712dbd3279d16f3c342458 100644 --- a/tests/nbackend/test_ast.py +++ b/tests/nbackend/test_ast.py @@ -56,7 +56,7 @@ def test_cloning(): PsComment("Loop body"), PsAssignment(x, y), PsAssignment(x, y), - PsPragma("anyone has an idea for a nice pragma?"), + PsPragma("#pragma clang loop vectorize(enable)"), PsStatement( PsDeref(PsCast(Ptr(Fp(32)), z)) + PsSubscript(z, one + one + one) diff --git a/tests/nbackend/transformations/test_add_pragmas.py b/tests/nbackend/transformations/test_add_pragmas.py new file mode 100644 index 0000000000000000000000000000000000000000..1d8dd1ded148697dbe7acde3648a292dc5a7fcac --- /dev/null +++ b/tests/nbackend/transformations/test_add_pragmas.py @@ -0,0 +1,54 @@ +import sympy as sp +from itertools import product + +from pystencils import make_slice, fields, Assignment +from pystencils.backend.kernelcreation import ( + KernelCreationContext, + AstFactory, + FullIterationSpace, +) + +from pystencils.backend.ast import dfs_preorder +from pystencils.backend.ast.structural import PsBlock, PsPragma, PsLoop +from pystencils.backend.transformations import InsertPragmasAtLoops, LoopPragma + +def test_insert_pragmas(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + + f, g = fields("f, g: [3D]") + ispace = FullIterationSpace.create_from_slice( + ctx, make_slice[:, :, :], archetype_field=f + ) + ctx.set_iteration_space(ispace) + + stencil = list(product([-1, 0, 1], [-1, 0, 1], [-1, 0, 1])) + loop_body = PsBlock([ + factory.parse_sympy(Assignment(f.center(0), sum(g.neighbors(stencil)))) + ]) + loops = factory.loops_from_ispace(ispace, loop_body) + + pragmas = ( + LoopPragma("omp parallel for", 0), + LoopPragma("some nonsense pragma", 1), + LoopPragma("omp simd", -1), + ) + add_pragmas = InsertPragmasAtLoops(ctx, pragmas) + ast = add_pragmas(loops) + + assert isinstance(ast, PsBlock) + + first_pragma = ast.statements[0] + assert isinstance(first_pragma, PsPragma) + assert first_pragma.text == pragmas[0].text + + assert ast.statements[1] == loops + second_pragma = loops.body.statements[0] + assert isinstance(second_pragma, PsPragma) + assert second_pragma.text == pragmas[1].text + + second_loop = list(dfs_preorder(ast, lambda node: isinstance(node, PsLoop)))[1] + assert isinstance(second_loop, PsLoop) + third_pragma = second_loop.body.statements[0] + assert isinstance(third_pragma, PsPragma) + assert third_pragma.text == pragmas[2].text