From 99a3335135baf5da92ec56713f96beae37798b60 Mon Sep 17 00:00:00 2001 From: zy69guqi <richard.angersbach@fau.de> Date: Thu, 16 Jan 2025 15:52:05 +0100 Subject: [PATCH] Add omp reduction clauses for reduced symbols --- src/pystencils/backend/kernelcreation/context.py | 5 +++++ src/pystencils/backend/transformations/add_pragmas.py | 9 +++++++++ 2 files changed, 14 insertions(+) diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index e41f8371c..a8728e6ac 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -223,6 +223,11 @@ class KernelCreationContext: """Return an iterable of all symbols listed in the symbol table.""" return self._symbols.values() + @property + def symbols_with_reduction(self) -> dict[PsSymbol, ReductionSymbolProperty]: + """Return a dictionary holding symbols and their reduction property.""" + return self._symbols_with_reduction + # Fields and Arrays @property diff --git a/src/pystencils/backend/transformations/add_pragmas.py b/src/pystencils/backend/transformations/add_pragmas.py index 78e721f38..6d72e1550 100644 --- a/src/pystencils/backend/transformations/add_pragmas.py +++ b/src/pystencils/backend/transformations/add_pragmas.py @@ -10,6 +10,8 @@ from ..ast import PsAstNode from ..ast.structural import PsBlock, PsLoop, PsPragma from ..ast.expressions import PsExpression +from ...types import PsScalarType + if TYPE_CHECKING: from ...codegen.config import OpenMpConfig @@ -110,6 +112,13 @@ class AddOpenMP: pragma_text += " parallel" if not omp_params.omit_parallel_construct else "" pragma_text += f" for schedule({omp_params.schedule})" + if bool(ctx.symbols_with_reduction): + for symbol, reduction in ctx.symbols_with_reduction.items(): + if isinstance(symbol.dtype, PsScalarType): + pragma_text += f" reduction({reduction.op}: {symbol.name})" + else: + NotImplementedError("OMP: Reductions for non-scalar data types are not supported yet.") + if omp_params.num_threads is not None: pragma_text += f" num_threads({str(omp_params.num_threads)})" -- GitLab