Skip to content
Snippets Groups Projects
Commit a3025645 authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Fix min/max reductions

parent bb984679
No related branches found
No related tags found
1 merge request!438Reduction Support
...@@ -193,29 +193,31 @@ class FreezeExpressions: ...@@ -193,29 +193,31 @@ class FreezeExpressions:
assert isinstance(rhs, PsExpression) assert isinstance(rhs, PsExpression)
assert isinstance(lhs, PsSymbolExpr) assert isinstance(lhs, PsSymbolExpr)
# match for reduction operation and set neutral init_val and new rhs (similar to augmented assignment)
new_rhs: PsExpression
match expr.op: match expr.op:
case "+": case "+":
op = add
init_val = PsConstant(0) init_val = PsConstant(0)
new_rhs = add(lhs.clone(), rhs)
case "-": case "-":
op = sub
init_val = PsConstant(0) init_val = PsConstant(0)
new_rhs = sub(lhs.clone(), rhs)
case "*": case "*":
op = mul
init_val = PsConstant(1) init_val = PsConstant(1)
# TODO: unsure if sp.Min & sp.Max are mapped by map_Min/map_Max afterwards new_rhs = mul(lhs.clone(), rhs)
case "min": case "min":
op = sp.Min
init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), []) init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), [])
new_rhs = PsCall(PsMathFunction(MathFunctions.Min), [lhs.clone(), rhs])
case "max": case "max":
op = sp.Max
init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Max), []) init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Max), [])
new_rhs = PsCall(PsMathFunction(MathFunctions.Max), [lhs.clone(), rhs])
case _: case _:
raise FreezeError(f"Unsupported reduced assignment: {expr.op}.") raise FreezeError(f"Unsupported reduced assignment: {expr.op}.")
# set reduction symbol property in context
self._ctx.add_reduction_to_symbol(lhs.symbol, ReductionSymbolProperty(expr.op, init_val)) self._ctx.add_reduction_to_symbol(lhs.symbol, ReductionSymbolProperty(expr.op, init_val))
return PsAssignment(lhs, op(lhs.clone(), rhs)) return PsAssignment(lhs, new_rhs)
def map_Symbol(self, spsym: sp.Symbol) -> PsSymbolExpr: def map_Symbol(self, spsym: sp.Symbol) -> PsSymbolExpr:
symb = self._ctx.get_symbol(spsym.name) symb = self._ctx.get_symbol(spsym.name)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment