Skip to content
Snippets Groups Projects

Reduction Support

Open Richard Angersbach requested to merge rangersbach/reductions into v2.0-dev
Compare and Show latest version
4 files
+ 40
16
Preferences
Compare changes
Files
4
@@ -7,6 +7,7 @@ import sympy.core.relational
import sympy.logic.boolalg
from sympy.codegen.ast import AssignmentBase, AugmentedAssignment
from ..memory import PsSymbol
from ...assignment import Assignment
from ...simp import AssignmentCollection
from ...sympyextensions import (
@@ -193,31 +194,37 @@ class FreezeExpressions:
assert isinstance(rhs, PsExpression)
assert isinstance(lhs, PsSymbolExpr)
# create kernel-local copy of lhs symbol to work with
new_lhs_symbol = PsSymbol(f"{lhs.symbol.name}_local", lhs.dtype)
new_lhs = PsSymbolExpr(new_lhs_symbol)
self._ctx.add_symbol(new_lhs_symbol)
# match for reduction operation and set neutral init_val and new rhs (similar to augmented assignment)
new_rhs: PsExpression
init_val: PsExpression
match expr.op:
case "+":
init_val = PsConstant(0)
new_rhs = add(lhs.clone(), rhs)
init_val = PsConstantExpr(PsConstant(0))
new_rhs = add(new_lhs.clone(), rhs)
case "-":
init_val = PsConstant(0)
new_rhs = sub(lhs.clone(), rhs)
init_val = PsConstantExpr(PsConstant(0))
new_rhs = sub(new_lhs.clone(), rhs)
case "*":
init_val = PsConstant(1)
new_rhs = mul(lhs.clone(), rhs)
init_val = PsConstantExpr(PsConstant(1))
new_rhs = mul(new_lhs.clone(), rhs)
case "min":
init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), [])
new_rhs = PsCall(PsMathFunction(MathFunctions.Min), [lhs.clone(), rhs])
case "max":
init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Max), [])
new_rhs = PsCall(PsMathFunction(MathFunctions.Max), [lhs.clone(), rhs])
new_rhs = PsCall(PsMathFunction(MathFunctions.Min), [new_lhs.clone(), rhs])
case "max":
init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), [])
new_rhs = PsCall(PsMathFunction(MathFunctions.Max), [new_lhs.clone(), rhs])
case _:
raise FreezeError(f"Unsupported reduced assignment: {expr.op}.")
# set reduction symbol property in context
self._ctx.add_reduction_to_symbol(lhs.symbol, ReductionSymbolProperty(expr.op, init_val))
self._ctx.add_reduction_to_symbol(new_lhs_symbol, ReductionSymbolProperty(expr.op, init_val, lhs.symbol))
return PsAssignment(lhs, new_rhs)
return PsAssignment(new_lhs, new_rhs)
def map_Symbol(self, spsym: sp.Symbol) -> PsSymbolExpr:
symb = self._ctx.get_symbol(spsym.name)