diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index b6bf09dba930760898b20cfa630740354fb5da4c..39205d707fa1bddeee4f5918aead91371f127003 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -75,7 +75,7 @@ class KernelCreationContext: self._symbol_ctr_pattern = re.compile(r"__[0-9]+$") self._symbol_dup_table: defaultdict[str, int] = defaultdict(lambda: 0) - # TODO: add list of reduction symbols + self._symbols_with_reduction: dict[PsSymbol, ReductionSymbolProperty] = dict() self._fields_and_arrays: dict[str, FieldArrayPair] = dict() self._fields_collection = FieldsInKernel() @@ -170,6 +170,21 @@ class KernelCreationContext: self._symbols[old.name] = new + def add_reduction_to_symbol(self, symbol: PsSymbol, reduction: ReductionSymbolProperty): + """Adds a reduction property to a symbol. + + The symbol ``symbol`` should not have a reduction property and must exist in the symbol table. + """ + if self.find_symbol(symbol.name) is None: + raise PsInternalCompilerError( + "add_reduction_to_symbol: Symbol does not exist in the symbol table" + ) + + if symbol not in self._symbols_with_reduction and not symbol.get_properties(ReductionSymbolProperty): + self._symbols_with_reduction[symbol] = reduction + else: + raise PsInternalCompilerError(f"add_reduction_to_symbol: Symbol {symbol.name} already has a reduction property") + def duplicate_symbol( self, symb: PsSymbol, new_dtype: PsType | None = None ) -> PsSymbol: