Skip to content
Snippets Groups Projects
Commit 54837529 authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Introduce reduction symbol property and add to lhs of reduced symbol

parent b263d752
No related branches found
No related tags found
1 merge request!438Reduction Support
......@@ -75,6 +75,8 @@ class KernelCreationContext:
self._symbol_ctr_pattern = re.compile(r"__[0-9]+$")
self._symbol_dup_table: defaultdict[str, int] = defaultdict(lambda: 0)
# TODO: add list of reduction symbols
self._fields_and_arrays: dict[str, FieldArrayPair] = dict()
self._fields_collection = FieldsInKernel()
......
......@@ -65,6 +65,9 @@ from ..exceptions import PsInputError
from ..functions import PsMathFunction, MathFunctions
from ..exceptions import FreezeError
import backend.functions
from codegen.properties import ReductionSymbolProperty
ExprLike = (
sp.Expr
......@@ -188,27 +191,32 @@ class FreezeExpressions:
lhs = self.visit(expr.lhs)
rhs = self.visit(expr.rhs)
assert isinstance(lhs, PsExpression)
assert isinstance(rhs, PsExpression)
assert isinstance(lhs, PsSymbolExpr)
match expr.op:
case "+=":
case "+":
op = add
case "-=":
init_val = PsConstant(0)
case "-":
op = sub
case "*=":
init_val = PsConstant(0)
case "*":
op = mul
case "/=":
op = truediv
# TODO: unsure if sp.Min & sp.Max work here
case "min=":
init_val = PsConstant(1)
# TODO: unsure if sp.Min & sp.Max are mapped by map_Min/map_Max afterwards
case "min":
op = sp.Min
case "max=":
init_val = backend.functions.NumericLimitsFunctions("min")
case "max":
op = sp.Max
init_val = backend.functions.NumericLimitsFunctions("max")
case _:
raise FreezeError(f"Unsupported reduced assignment: {expr.op}.")
return PsAssignment(lhs, op(lhs.clone(), rhs)) # TODO: PsReducedAssignment?
lhs.symbol.add_property(ReductionSymbolProperty(expr.op, init_val))
return PsAssignment(lhs, op(lhs.clone(), rhs))
def map_Symbol(self, spsym: sp.Symbol) -> PsSymbolExpr:
symb = self._ctx.get_symbol(spsym.name)
......
......@@ -3,6 +3,8 @@ from dataclasses import dataclass
from ..field import Field
from backend.ast.expressions import PsExpression
@dataclass(frozen=True)
class PsSymbolProperty:
......@@ -14,6 +16,14 @@ class UniqueSymbolProperty(PsSymbolProperty):
"""Base class for unique properties, of which only one instance may be registered at a time."""
@dataclass(frozen=True)
class ReductionSymbolProperty(UniqueSymbolProperty):
"""Symbol acts as a base pointer to a field."""
op: str
init_val: PsExpression
@dataclass(frozen=True)
class FieldShape(PsSymbolProperty):
"""Symbol acts as a shape parameter to a field."""
......
......@@ -12,13 +12,11 @@ class ReducedAssignment(AssignmentBase):
Symbol for binary operation being applied in the assignment, such as "+",
"*", etc.
"""
binop = None # type: str
# TODO: initial value
binop = None # type: str
@property
def op(self):
return self.binop + '='
return self.binop
class AddReducedAssignment(ReducedAssignment):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment