Skip to content
Snippets Groups Projects
Commit 325ca386 authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Move checks and init value determination for ReductionAssignments to add_reduction_info

parent 0c8654e3
No related branches found
No related tags found
1 merge request!438Reduction Support
......@@ -6,7 +6,8 @@ from itertools import chain, count
from collections import namedtuple, defaultdict
import re
from ..ast.expressions import PsExpression
from ..ast.expressions import PsExpression, PsConstantExpr, PsCall
from ..functions import NumericLimitsFunctions, PsMathFunction
from ...defaults import DEFAULTS
from ...field import Field, FieldType
from ...sympyextensions import ReductionOp
......@@ -208,10 +209,16 @@ class KernelCreationContext:
lhs_name: str,
lhs_dtype: PsType,
reduction_op: ReductionOp,
init_value: PsExpression,
):
"""Create ReductionInfo instance and add to its corresponding lookup table for a given symbol name."""
# make sure that lhs symbol never occurred before ReductionAssignment
if self.find_symbol(lhs_name):
raise KernelConstraintsError(
f"Left-hand side {lhs_name} of ReductionAssignment already exists in symbol table. "
f"Make sure that it is only used once in a kernel's ReductionAssignment."
)
# replace datatype of lhs symbol with pointer datatype for write-back mechanism
symb = self.get_symbol(lhs_name, lhs_dtype)
pointer_symb = PsSymbol(lhs_name, PsPointerType(lhs_dtype))
......@@ -221,9 +228,27 @@ class KernelCreationContext:
local_symb = PsSymbol(f"{lhs_name}_local", lhs_dtype)
self.add_symbol(local_symb)
# match for reduction operation and set neutral init_val
init_val: PsExpression
match reduction_op:
case ReductionOp.Add:
init_val = PsConstantExpr(PsConstant(0))
case ReductionOp.Sub:
init_val = PsConstantExpr(PsConstant(0))
case ReductionOp.Mul:
init_val = PsConstantExpr(PsConstant(1))
case ReductionOp.Min:
init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Max), [])
case ReductionOp.Max:
init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), [])
case _:
raise PsInternalCompilerError(
f"Unsupported kind of reduction assignment: {reduction_op}."
)
# create reduction info and add to set
reduction_info = ReductionInfo(
reduction_op, init_value, local_symb, pointer_symb
reduction_op, init_val, local_symb, pointer_symb
)
self._reduction_data[lhs_name] = reduction_info
......
......@@ -189,21 +189,12 @@ class FreezeExpressions:
def map_ReductionAssignment(self, expr: ReductionAssignment):
assert isinstance(expr.lhs, TypedSymbol)
# make sure that lhs symbol never occurred before ReductionAssignment
if self._ctx.find_symbol(expr.lhs.name):
raise FreezeError(
f"Left-hand side {expr.lhs} of ReductionAssignment already exists in symbol table. "
f"Make sure that it is only used once in a kernel's ReductionAssignment."
)
lhs = self.visit(expr.lhs)
rhs = self.visit(expr.rhs)
assert isinstance(rhs, PsExpression)
assert isinstance(lhs, PsSymbolExpr)
reduction_op = expr.reduction_op
lhs_symbol = lhs.symbol
lhs_symbol = expr.lhs
lhs_dtype = lhs_symbol.dtype
lhs_name = lhs_symbol.name
......@@ -211,27 +202,9 @@ class FreezeExpressions:
lhs_dtype, PsNumericType
), "Reduction assignments require type information of the lhs symbol."
# match for reduction operation and set neutral init_val
init_val: PsExpression
match reduction_op:
case ReductionOp.Add:
init_val = PsConstantExpr(PsConstant(0))
case ReductionOp.Sub:
init_val = PsConstantExpr(PsConstant(0))
case ReductionOp.Mul:
init_val = PsConstantExpr(PsConstant(1))
case ReductionOp.Min:
init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Max), [])
case ReductionOp.Max:
init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), [])
case _:
raise FreezeError(
f"Unsupported kind of reduction assignment: {reduction_op}."
)
# get reduction info from context
reduction_info = self._ctx.add_reduction_info(
lhs_name, lhs_dtype, reduction_op, init_val
lhs_name, lhs_dtype, reduction_op
)
# create new lhs from newly created local lhs symbol
......@@ -330,16 +303,6 @@ class FreezeExpressions:
def map_TypedSymbol(self, expr: TypedSymbol):
dtype = self._ctx.resolve_dynamic_type(expr.dtype)
# check if symbol is referenced after freezing a ReductionAssignment
if self._ctx.find_reduction_info(expr.name):
# check if types do not align since a ReductionAssignment modifies
# the symbol's type to PsPointerType in the context's symbol table
if (symbol := self._ctx.find_symbol(expr.name)) and symbol.dtype != dtype:
raise FreezeError(
f"Illegal access to reduction symbol {symbol.name} after freezing a kernel's ReductionAssignment. "
)
symb = self._ctx.get_symbol(expr.name, dtype)
return PsSymbolExpr(symb)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment