Skip to content
Snippets Groups Projects

Reduction Support

Open Richard Angersbach requested to merge rangersbach/reductions into v2.0-dev
Compare and Show latest version
5 files
+ 59
38
Preferences
Compare changes
Files
5
from typing import overload, cast, Any
from functools import reduce
from operator import add, mul, sub, truediv
from operator import add, mul, sub
import sympy as sp
import sympy.core.relational
import sympy.logic.boolalg
from sympy.codegen.ast import AssignmentBase, AugmentedAssignment
@@ -14,6 +13,7 @@ from ...sympyextensions import (
integer_functions,
ConditionalFieldAccess,
)
from ...sympyextensions.binop_mapping import binop_str_to_expr
from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc, DynamicType
from ...sympyextensions.pointers import AddressOf, mem_acc
from ...sympyextensions.reduction import ReducedAssignment
@@ -173,19 +173,7 @@ class FreezeExpressions:
assert isinstance(lhs, PsExpression)
assert isinstance(rhs, PsExpression)
match expr.op:
case "+=":
op = add
case "-=":
op = sub
case "*=":
op = mul
case "/=":
op = truediv
case _:
raise FreezeError(f"Unsupported augmented assignment: {expr.op}.")
return PsAssignment(lhs, op(lhs.clone(), rhs))
return PsAssignment(lhs, binop_str_to_expr(expr.op[0], lhs.clone(), rhs))
def map_ReducedAssignment(self, expr: ReducedAssignment):
lhs = self.visit(expr.lhs)
@@ -195,7 +183,7 @@ class FreezeExpressions:
assert isinstance(lhs, PsSymbolExpr)
orig_lhs_symb = lhs.symbol
dtype = rhs.dtype # TODO: kernel with (implicit) up/downcasts?
dtype = rhs.dtype # TODO: kernel with (implicit) up/downcasts?
# replace original symbol with pointer-based type used for export
orig_lhs_symb_as_ptr = PsSymbol(orig_lhs_symb.name, PsPointerType(dtype))
@@ -204,27 +192,25 @@ class FreezeExpressions:
new_lhs_symb = PsSymbol(f"{orig_lhs_symb.name}_local", dtype)
new_lhs = PsSymbolExpr(new_lhs_symb)
# match for reduction operation and set neutral init_val and new rhs (similar to augmented assignment)
# get new rhs from augmented assignment
new_rhs: PsExpression = binop_str_to_expr(expr.op, new_lhs.clone(), rhs)
# match for reduction operation and set neutral init_val
new_rhs: PsExpression
init_val: PsExpression
match expr.op:
case "+":
init_val = PsConstantExpr(PsConstant(0, dtype))
new_rhs = add(new_lhs.clone(), rhs)
case "-":
init_val = PsConstantExpr(PsConstant(0, dtype))
new_rhs = sub(new_lhs.clone(), rhs)
case "*":
init_val = PsConstantExpr(PsConstant(1, dtype))
new_rhs = mul(new_lhs.clone(), rhs)
case "min":
init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Max), [])
init_val.dtype = dtype
new_rhs = PsCall(PsMathFunction(MathFunctions.Min), [new_lhs.clone(), rhs])
case "max":
init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), [])
init_val.dtype = dtype
new_rhs = PsCall(PsMathFunction(MathFunctions.Max), [new_lhs.clone(), rhs])
case _:
raise FreezeError(f"Unsupported reduced assignment: {expr.op}.")