Skip to content
Snippets Groups Projects

Reduction Support

Merged Richard Angersbach requested to merge rangersbach/reductions into v2.0-dev
Compare and
11 files
+ 195
7
Compare changes
  • Side-by-side
  • Inline
Files
11
@@ -9,6 +9,8 @@ from ...defaults import DEFAULTS
@@ -9,6 +9,8 @@ from ...defaults import DEFAULTS
from ...field import Field, FieldType
from ...field import Field, FieldType
from ...sympyextensions.typed_sympy import TypedSymbol, DynamicType
from ...sympyextensions.typed_sympy import TypedSymbol, DynamicType
 
from ...codegen.properties import ReductionSymbolProperty
 
from ..memory import PsSymbol, PsBuffer
from ..memory import PsSymbol, PsBuffer
from ..constants import PsConstant
from ..constants import PsConstant
from ...types import (
from ...types import (
@@ -75,6 +77,8 @@ class KernelCreationContext:
@@ -75,6 +77,8 @@ class KernelCreationContext:
self._symbol_ctr_pattern = re.compile(r"__[0-9]+$")
self._symbol_ctr_pattern = re.compile(r"__[0-9]+$")
self._symbol_dup_table: defaultdict[str, int] = defaultdict(lambda: 0)
self._symbol_dup_table: defaultdict[str, int] = defaultdict(lambda: 0)
 
self._symbols_with_reduction: dict[PsSymbol, ReductionSymbolProperty] = dict()
 
self._fields_and_arrays: dict[str, FieldArrayPair] = dict()
self._fields_and_arrays: dict[str, FieldArrayPair] = dict()
self._fields_collection = FieldsInKernel()
self._fields_collection = FieldsInKernel()
@@ -168,6 +172,22 @@ class KernelCreationContext:
@@ -168,6 +172,22 @@ class KernelCreationContext:
self._symbols[old.name] = new
self._symbols[old.name] = new
 
def add_reduction_to_symbol(self, symbol: PsSymbol, reduction: ReductionSymbolProperty):
 
"""Adds a reduction property to a symbol.
 
 
The symbol ``symbol`` should not have a reduction property and must exist in the symbol table.
 
"""
 
if self.find_symbol(symbol.name) is None:
 
raise PsInternalCompilerError(
 
f"add_reduction_to_symbol: {symbol.name} does not exist in the symbol table"
 
)
 
 
if symbol not in self._symbols_with_reduction and not symbol.get_properties(ReductionSymbolProperty):
 
symbol.add_property(reduction)
 
self._symbols_with_reduction[symbol] = reduction
 
else:
 
raise PsInternalCompilerError(f"add_reduction_to_symbol: {symbol.name} already has a reduction property")
 
def duplicate_symbol(
def duplicate_symbol(
self, symb: PsSymbol, new_dtype: PsType | None = None
self, symb: PsSymbol, new_dtype: PsType | None = None
) -> PsSymbol:
) -> PsSymbol:
@@ -199,6 +219,11 @@ class KernelCreationContext:
@@ -199,6 +219,11 @@ class KernelCreationContext:
"""Return an iterable of all symbols listed in the symbol table."""
"""Return an iterable of all symbols listed in the symbol table."""
return self._symbols.values()
return self._symbols.values()
 
@property
 
def symbols_with_reduction(self) -> dict[PsSymbol, ReductionSymbolProperty]:
 
"""Return a dictionary holding symbols and their reduction property."""
 
return self._symbols_with_reduction
 
# Fields and Arrays
# Fields and Arrays
@property
@property
Loading