From 99a3335135baf5da92ec56713f96beae37798b60 Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Thu, 16 Jan 2025 15:52:05 +0100
Subject: [PATCH] Add omp reduction clauses for reduced symbols

---
 src/pystencils/backend/kernelcreation/context.py      | 5 +++++
 src/pystencils/backend/transformations/add_pragmas.py | 9 +++++++++
 2 files changed, 14 insertions(+)

diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py
index e41f8371c..a8728e6ac 100644
--- a/src/pystencils/backend/kernelcreation/context.py
+++ b/src/pystencils/backend/kernelcreation/context.py
@@ -223,6 +223,11 @@ class KernelCreationContext:
         """Return an iterable of all symbols listed in the symbol table."""
         return self._symbols.values()
 
+    @property
+    def symbols_with_reduction(self) -> dict[PsSymbol, ReductionSymbolProperty]:
+        """Return a dictionary holding symbols and their reduction property."""
+        return self._symbols_with_reduction
+
     #   Fields and Arrays
 
     @property
diff --git a/src/pystencils/backend/transformations/add_pragmas.py b/src/pystencils/backend/transformations/add_pragmas.py
index 78e721f38..6d72e1550 100644
--- a/src/pystencils/backend/transformations/add_pragmas.py
+++ b/src/pystencils/backend/transformations/add_pragmas.py
@@ -10,6 +10,8 @@ from ..ast import PsAstNode
 from ..ast.structural import PsBlock, PsLoop, PsPragma
 from ..ast.expressions import PsExpression
 
+from ...types import PsScalarType
+
 if TYPE_CHECKING:
     from ...codegen.config import OpenMpConfig
 
@@ -110,6 +112,13 @@ class AddOpenMP:
         pragma_text += " parallel" if not omp_params.omit_parallel_construct else ""
         pragma_text += f" for schedule({omp_params.schedule})"
 
+        if bool(ctx.symbols_with_reduction):
+            for symbol, reduction in ctx.symbols_with_reduction.items():
+                if isinstance(symbol.dtype, PsScalarType):
+                    pragma_text += f" reduction({reduction.op}: {symbol.name})"
+                else:
+                    NotImplementedError("OMP: Reductions for non-scalar data types are not supported yet.")
+
         if omp_params.num_threads is not None:
             pragma_text += f" num_threads({str(omp_params.num_threads)})"
 
-- 
GitLab