Skip to content
Snippets Groups Projects

Reduction Support

Merged Richard Angersbach requested to merge rangersbach/reductions into v2.0-dev
5 files
+ 78
37
Compare changes
  • Side-by-side
  • Inline
Files
5
@@ -9,7 +9,7 @@ from ...defaults import DEFAULTS
from ...field import Field, FieldType
from ...sympyextensions.typed_sympy import TypedSymbol, DynamicType
from ...codegen.properties import ReductionSymbolProperty
from ...codegen.properties import LocalReductionVariable, ReductionPointerVariable
from ..memory import PsSymbol, PsBuffer
from ..constants import PsConstant
@@ -77,7 +77,8 @@ class KernelCreationContext:
self._symbol_ctr_pattern = re.compile(r"__[0-9]+$")
self._symbol_dup_table: defaultdict[str, int] = defaultdict(lambda: 0)
self._symbols_with_reduction: dict[PsSymbol, ReductionSymbolProperty] = dict()
self._local_reduction_symbols: dict[PsSymbol, LocalReductionVariable] = dict()
self._reduction_ptr_symbols: dict[PsSymbol, ReductionPointerVariable] = dict()
self._fields_and_arrays: dict[str, FieldArrayPair] = dict()
self._fields_collection = FieldsInKernel()
@@ -172,21 +173,41 @@ class KernelCreationContext:
self._symbols[old.name] = new
def add_reduction_to_symbol(self, symbol: PsSymbol, reduction: ReductionSymbolProperty):
"""Adds a reduction property to a symbol.
def add_local_reduction_symbol(self, local_symb: PsSymbol, local_var_prop: LocalReductionVariable):
"""Adds entry for a symbol and its property to the lookup table for local reduction variables.
The symbol ``symbol`` should not have a reduction property and must exist in the symbol table.
The symbol ``symbol`` should not have a 'LocalReductionSymbol' property and shall not exist in the symbol table.
"""
if self.find_symbol(symbol.name) is None:
if self.find_symbol(local_symb.name) is not None:
raise PsInternalCompilerError(
f"add_reduction_to_symbol: {symbol.name} does not exist in the symbol table"
f"add_local_reduction_symbol: {local_symb.name} already exist in the symbol table"
)
self.add_symbol(local_symb)
if symbol not in self._symbols_with_reduction and not symbol.get_properties(ReductionSymbolProperty):
symbol.add_property(reduction)
self._symbols_with_reduction[symbol] = reduction
if local_symb not in self._local_reduction_symbols and not local_symb.get_properties(LocalReductionVariable):
local_symb.add_property(local_var_prop)
self._local_reduction_symbols[local_symb] = local_var_prop
else:
raise PsInternalCompilerError(f"add_reduction_to_symbol: {symbol.name} already has a reduction property")
raise PsInternalCompilerError(
f"add_local_reduction_symbol: {local_symb.name} already exists in local reduction table"
)
def add_reduction_ptr_symbol(self, orig_symb: PsSymbol, ptr_symb: PsSymbol, ptr_var_prop: ReductionPointerVariable):
"""Replaces reduction symbol with a pointer-based counterpart used for export
and adds the new symbol and its property to the lookup table for pointer-based reduction variables
The symbol ``ptr_symbol`` should not exist in the symbol table.
"""
self.replace_symbol(orig_symb, ptr_symb)
if ptr_symb not in self._reduction_ptr_symbols and not ptr_symb.get_properties(
ReductionPointerVariable):
ptr_symb.add_property(ptr_var_prop)
self._reduction_ptr_symbols[ptr_symb] = ptr_var_prop
else:
raise PsInternalCompilerError(
f"add_reduction_ptr_symbol: {ptr_symb.name} already exists in pointer-based reduction variable table "
)
def duplicate_symbol(
self, symb: PsSymbol, new_dtype: PsType | None = None
@@ -224,9 +245,14 @@ class KernelCreationContext:
return self._symbols.values()
@property
def symbols_with_reduction(self) -> dict[PsSymbol, ReductionSymbolProperty]:
"""Return a dictionary holding symbols and their reduction property."""
return self._symbols_with_reduction
def local_reduction_symbols(self) -> dict[PsSymbol, LocalReductionVariable]:
"""Return a dictionary holding kernel-local reduction symbols and their reduction properties."""
return self._local_reduction_symbols
@property
def reduction_pointer_symbols(self) -> dict[PsSymbol, ReductionPointerVariable]:
"""Return a dictionary holding pointer-based reduction symbols and their reduction properties."""
return self._reduction_ptr_symbols
# Fields and Arrays
Loading