Skip to content
Snippets Groups Projects

Reduction Support

Open Richard Angersbach requested to merge rangersbach/reductions into v2.0-dev
Compare and Show latest version
12 files
+ 204
193
Preferences
Compare changes
Files
12
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable, Iterator, Any
from itertools import chain, count
from collections import namedtuple, defaultdict
import re
from ..ast.expressions import PsExpression
from ...defaults import DEFAULTS
from ...field import Field, FieldType
from ...sympyextensions import ReductionOp
from ...sympyextensions.typed_sympy import TypedSymbol, DynamicType
from ...codegen.properties import LocalReductionVariable, ReductionPointerVariable
from ..memory import PsSymbol, PsBuffer
from ..constants import PsConstant
from ...types import (
@@ -46,6 +47,16 @@ class FieldsInKernel:
FieldArrayPair = namedtuple("FieldArrayPair", ("field", "array"))
@dataclass(frozen=True)
class ReductionInfo:
op: ReductionOp
init_val: PsExpression
orig_symbol: PsSymbol
ptr_symbol: PsSymbol
class KernelCreationContext:
"""Manages the translation process from the SymPy frontend to the backend AST, and collects
all necessary information for the translation:
@@ -77,8 +88,7 @@ class KernelCreationContext:
self._symbol_ctr_pattern = re.compile(r"__[0-9]+$")
self._symbol_dup_table: defaultdict[str, int] = defaultdict(lambda: 0)
self._local_reduction_symbols: dict[PsSymbol, LocalReductionVariable] = dict()
self._reduction_ptr_symbols: dict[PsSymbol, ReductionPointerVariable] = dict()
self._symbols_reduction_info: dict[PsSymbol, ReductionInfo] = dict()
self._fields_and_arrays: dict[str, FieldArrayPair] = dict()
self._fields_collection = FieldsInKernel()
@@ -173,41 +183,17 @@ class KernelCreationContext:
self._symbols[old.name] = new
def add_local_reduction_symbol(self, local_symb: PsSymbol, local_var_prop: LocalReductionVariable):
"""Adds entry for a symbol and its property to the lookup table for local reduction variables.
def add_symbol_reduction_info(self, local_symb: PsSymbol, reduction_info: ReductionInfo):
"""Adds entry for a symbol and its reduction info to its corresponding lookup table.
The symbol ``symbol`` should not have a 'LocalReductionSymbol' property and shall not exist in the symbol table.
The symbol ``symbol`` shall not exist in the symbol table already.
"""
if self.find_symbol(local_symb.name) is not None:
raise PsInternalCompilerError(
f"add_local_reduction_symbol: {local_symb.name} already exist in the symbol table"
)
self.add_symbol(local_symb)
if local_symb not in self._local_reduction_symbols and not local_symb.get_properties(LocalReductionVariable):
local_symb.add_property(local_var_prop)
self._local_reduction_symbols[local_symb] = local_var_prop
else:
if local_symb in self._symbols_reduction_info:
raise PsInternalCompilerError(
f"add_local_reduction_symbol: {local_symb.name} already exists in local reduction table"
f"add_symbol_reduction_info: {local_symb.name} already exist in the symbol table"
)
def add_reduction_ptr_symbol(self, orig_symb: PsSymbol, ptr_symb: PsSymbol, ptr_var_prop: ReductionPointerVariable):
"""Replaces reduction symbol with a pointer-based counterpart used for export
and adds the new symbol and its property to the lookup table for pointer-based reduction variables
The symbol ``ptr_symbol`` should not exist in the symbol table.
"""
self.replace_symbol(orig_symb, ptr_symb)
if ptr_symb not in self._reduction_ptr_symbols and not ptr_symb.get_properties(
ReductionPointerVariable):
ptr_symb.add_property(ptr_var_prop)
self._reduction_ptr_symbols[ptr_symb] = ptr_var_prop
else:
raise PsInternalCompilerError(
f"add_reduction_ptr_symbol: {ptr_symb.name} already exists in pointer-based reduction variable table "
)
self._symbols_reduction_info[local_symb] = reduction_info
def duplicate_symbol(
self, symb: PsSymbol, new_dtype: PsType | None = None
@@ -245,14 +231,9 @@ class KernelCreationContext:
return self._symbols.values()
@property
def local_reduction_symbols(self) -> dict[PsSymbol, LocalReductionVariable]:
def symbols_reduction_info(self) -> dict[PsSymbol, ReductionInfo]:
"""Return a dictionary holding kernel-local reduction symbols and their reduction properties."""
return self._local_reduction_symbols
@property
def reduction_pointer_symbols(self) -> dict[PsSymbol, ReductionPointerVariable]:
"""Return a dictionary holding pointer-based reduction symbols and their reduction properties."""
return self._reduction_ptr_symbols
return self._symbols_reduction_info
# Fields and Arrays