diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 9a34303e21071137c6e929d423660837fffdd6d0..64230203f90fe23bdf9e048e270545808db82bac 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -193,29 +193,31 @@ class FreezeExpressions: assert isinstance(rhs, PsExpression) assert isinstance(lhs, PsSymbolExpr) + # match for reduction operation and set neutral init_val and new rhs (similar to augmented assignment) + new_rhs: PsExpression match expr.op: case "+": - op = add init_val = PsConstant(0) + new_rhs = add(lhs.clone(), rhs) case "-": - op = sub init_val = PsConstant(0) + new_rhs = sub(lhs.clone(), rhs) case "*": - op = mul init_val = PsConstant(1) - # TODO: unsure if sp.Min & sp.Max are mapped by map_Min/map_Max afterwards + new_rhs = mul(lhs.clone(), rhs) case "min": - op = sp.Min init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), []) + new_rhs = PsCall(PsMathFunction(MathFunctions.Min), [lhs.clone(), rhs]) case "max": - op = sp.Max init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Max), []) + new_rhs = PsCall(PsMathFunction(MathFunctions.Max), [lhs.clone(), rhs]) case _: raise FreezeError(f"Unsupported reduced assignment: {expr.op}.") + # set reduction symbol property in context self._ctx.add_reduction_to_symbol(lhs.symbol, ReductionSymbolProperty(expr.op, init_val)) - return PsAssignment(lhs, op(lhs.clone(), rhs)) + return PsAssignment(lhs, new_rhs) def map_Symbol(self, spsym: sp.Symbol) -> PsSymbolExpr: symb = self._ctx.get_symbol(spsym.name)