Skip to content
Snippets Groups Projects

Reduction Support

Open Richard Angersbach requested to merge rangersbach/reductions into v2.0-dev
Compare and Show latest version
2 files
+ 20
5
Preferences
Compare changes
Files
2
@@ -9,6 +9,8 @@ from ...defaults import DEFAULTS
from ...field import Field, FieldType
from ...sympyextensions.typed_sympy import TypedSymbol, DynamicType
from ...codegen.properties import ReductionSymbolProperty
from ..memory import PsSymbol, PsBuffer
from ..constants import PsConstant
from ...types import (
@@ -75,7 +77,7 @@ class KernelCreationContext:
self._symbol_ctr_pattern = re.compile(r"__[0-9]+$")
self._symbol_dup_table: defaultdict[str, int] = defaultdict(lambda: 0)
# TODO: add list of reduction symbols
self._symbols_with_reduction: dict[PsSymbol, ReductionSymbolProperty] = dict()
self._fields_and_arrays: dict[str, FieldArrayPair] = dict()
self._fields_collection = FieldsInKernel()
@@ -170,6 +172,21 @@ class KernelCreationContext:
self._symbols[old.name] = new
def add_reduction_to_symbol(self, symbol: PsSymbol, reduction: ReductionSymbolProperty):
"""Adds a reduction property to a symbol.
The symbol ``symbol`` should not have a reduction property and must exist in the symbol table.
"""
if self.find_symbol(symbol.name) is None:
raise PsInternalCompilerError(
"add_reduction_to_symbol: Symbol does not exist in the symbol table"
)
if symbol not in self._symbols_with_reduction and not symbol.get_properties(ReductionSymbolProperty):
self._symbols_with_reduction[symbol] = reduction
else:
raise PsInternalCompilerError(f"add_reduction_to_symbol: Symbol {symbol.name} already has a reduction property")
def duplicate_symbol(
self, symb: PsSymbol, new_dtype: PsType | None = None
) -> PsSymbol: