Skip to content
Snippets Groups Projects

Reduction Support

Open Richard Angersbach requested to merge rangersbach/reductions into v2.0-dev
Viewing commit 6bc3cf3f
Show latest version
2 files
+ 14
11
Preferences
Compare changes
Files
2
@@ -61,7 +61,7 @@ from ..ast.expressions import (
from ..ast.vector import PsVecMemAcc
from ..constants import PsConstant
from ...types import PsNumericType, PsStructType, PsType
from ...types import PsNumericType, PsStructType, PsType, PsPointerType
from ..exceptions import PsInputError
from ..functions import PsMathFunction, MathFunctions, NumericLimitsFunctions
from ..exceptions import FreezeError
@@ -195,9 +195,9 @@ class FreezeExpressions:
assert isinstance(lhs, PsSymbolExpr)
# create kernel-local copy of lhs symbol to work with
new_lhs_symbol = PsSymbol(f"{lhs.symbol.name}_local", lhs.dtype)
new_lhs = PsSymbolExpr(new_lhs_symbol)
self._ctx.add_symbol(new_lhs_symbol)
new_lhs_symb = PsSymbol(f"{lhs.symbol.name}_local", rhs.dtype)
new_lhs = PsSymbolExpr(new_lhs_symb)
self._ctx.add_symbol(new_lhs_symb)
# match for reduction operation and set neutral init_val and new rhs (similar to augmented assignment)
new_rhs: PsExpression
@@ -221,8 +221,13 @@ class FreezeExpressions:
case _:
raise FreezeError(f"Unsupported reduced assignment: {expr.op}.")
# replace original symbol with pointer-based type used for export
orig_symbol_as_ptr = PsSymbol(lhs.symbol.name, PsPointerType(rhs.dtype))
self._ctx.replace_symbol(lhs.symbol, orig_symbol_as_ptr)
# set reduction symbol property in context
self._ctx.add_reduction_to_symbol(new_lhs_symbol, ReductionSymbolProperty(expr.op, init_val, lhs.symbol))
init_val.dtype = rhs.dtype
self._ctx.add_reduction_to_symbol(new_lhs_symb, ReductionSymbolProperty(expr.op, init_val, orig_symbol_as_ptr))
return PsAssignment(new_lhs, new_rhs)