Skip to content
Snippets Groups Projects

Reduction Support

Open Richard Angersbach requested to merge rangersbach/reductions into v2.0-dev
13 unresolved threads
Compare and Show latest version
4 files
+ 44
- 18
Preferences
Compare changes
Files
4
@@ -7,6 +7,7 @@ import sympy.core.relational
import sympy.logic.boolalg
from sympy.codegen.ast import AssignmentBase, AugmentedAssignment
from ..memory import PsSymbol
from ...assignment import Assignment
from ...simp import AssignmentCollection
from ...sympyextensions import (
@@ -60,7 +61,7 @@ from ..ast.expressions import (
from ..ast.vector import PsVecMemAcc
from ..constants import PsConstant
from ...types import PsNumericType, PsStructType, PsType
from ...types import PsNumericType, PsStructType, PsType, PsPointerType
from ..exceptions import PsInputError
from ..functions import PsMathFunction, MathFunctions, NumericLimitsFunctions
from ..exceptions import FreezeError
@@ -193,32 +194,42 @@ class FreezeExpressions:
assert isinstance(rhs, PsExpression)
assert isinstance(lhs, PsSymbolExpr)
# create kernel-local copy of lhs symbol to work with
new_lhs_symb = PsSymbol(f"{lhs.symbol.name}_local", rhs.dtype)
new_lhs = PsSymbolExpr(new_lhs_symb)
self._ctx.add_symbol(new_lhs_symb)
# match for reduction operation and set neutral init_val and new rhs (similar to augmented assignment)
new_rhs: PsExpression
init_val: Any # TODO: type?
init_val: PsExpression
match expr.op:
case "+":
init_val = PsConstant(0)
new_rhs = add(lhs.clone(), rhs)
init_val = PsConstantExpr(PsConstant(0))
new_rhs = add(new_lhs.clone(), rhs)
case "-":
init_val = PsConstant(0)
new_rhs = sub(lhs.clone(), rhs)
init_val = PsConstantExpr(PsConstant(0))
new_rhs = sub(new_lhs.clone(), rhs)
case "*":
init_val = PsConstant(1)
new_rhs = mul(lhs.clone(), rhs)
init_val = PsConstantExpr(PsConstant(1))
new_rhs = mul(new_lhs.clone(), rhs)
case "min":
init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), [])
new_rhs = PsCall(PsMathFunction(MathFunctions.Min), [lhs.clone(), rhs])
case "max":
init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Max), [])
new_rhs = PsCall(PsMathFunction(MathFunctions.Max), [lhs.clone(), rhs])
new_rhs = PsCall(PsMathFunction(MathFunctions.Min), [new_lhs.clone(), rhs])
case "max":
init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), [])
new_rhs = PsCall(PsMathFunction(MathFunctions.Max), [new_lhs.clone(), rhs])
case _:
raise FreezeError(f"Unsupported reduced assignment: {expr.op}.")
# replace original symbol with pointer-based type used for export
orig_symbol_as_ptr = PsSymbol(lhs.symbol.name, PsPointerType(rhs.dtype))
self._ctx.replace_symbol(lhs.symbol, orig_symbol_as_ptr)
# set reduction symbol property in context
self._ctx.add_reduction_to_symbol(lhs.symbol, ReductionSymbolProperty(expr.op, init_val))
init_val.dtype = rhs.dtype
self._ctx.add_reduction_to_symbol(new_lhs_symb, ReductionSymbolProperty(expr.op, init_val, orig_symbol_as_ptr))
return PsAssignment(lhs, new_rhs)
return PsAssignment(new_lhs, new_rhs)
def map_Symbol(self, spsym: sp.Symbol) -> PsSymbolExpr:
symb = self._ctx.get_symbol(spsym.name)