Skip to content
Snippets Groups Projects
Commit eb7823a5 authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Minor refactor of reduction ops

parent b4b105be
No related branches found
No related tags found
1 merge request!438Reduction Support
...@@ -51,7 +51,7 @@ class PsVecHorizontal(PsUnOp, PsVectorOp): ...@@ -51,7 +51,7 @@ class PsVecHorizontal(PsUnOp, PsVectorOp):
def __init__(self, lanes: int, operand: PsExpression, reduction_op: ReductionOp): def __init__(self, lanes: int, operand: PsExpression, reduction_op: ReductionOp):
super().__init__(operand) super().__init__(operand)
self._lanes = lanes self._lanes = lanes
self._reduction_operation = reduction_op self._reduction_op = reduction_op
@property @property
def lanes(self) -> int: def lanes(self) -> int:
...@@ -62,15 +62,15 @@ class PsVecHorizontal(PsUnOp, PsVectorOp): ...@@ -62,15 +62,15 @@ class PsVecHorizontal(PsUnOp, PsVectorOp):
self._lanes = n self._lanes = n
@property @property
def reduction_operation(self) -> ReductionOp: def reduction_op(self) -> ReductionOp:
return self._reduction_operation return self._reduction_op
@reduction_operation.setter @reduction_op.setter
def reduction_operation(self, op: ReductionOp): def reduction_op(self, op: ReductionOp):
self._reduction_operation = op self._reduction_op = op
def _clone_expr(self) -> PsVecHorizontal: def _clone_expr(self) -> PsVecHorizontal:
return PsVecHorizontal(self._lanes, self._operand.clone(), self._operation.clone()) return PsVecHorizontal(self._lanes, self._operand.clone(), self._reduction_op)
def structurally_equal(self, other: PsAstNode) -> bool: def structurally_equal(self, other: PsAstNode) -> bool:
if not isinstance(other, PsVecHorizontal): if not isinstance(other, PsVecHorizontal):
...@@ -78,7 +78,7 @@ class PsVecHorizontal(PsUnOp, PsVectorOp): ...@@ -78,7 +78,7 @@ class PsVecHorizontal(PsUnOp, PsVectorOp):
return ( return (
super().structurally_equal(other) super().structurally_equal(other)
and self._lanes == other._lanes and self._lanes == other._lanes
and self._operation == other._operation and self._reduction_op == other._reduction_op
) )
......
...@@ -152,18 +152,18 @@ class ReductionFunctions(Enum): ...@@ -152,18 +152,18 @@ class ReductionFunctions(Enum):
class PsReductionFunction(PsFunction): class PsReductionFunction(PsFunction):
def __init__(self, func: ReductionFunctions, op: ReductionOp) -> None: def __init__(self, func: ReductionFunctions, reduction_op: ReductionOp) -> None:
super().__init__(func.function_name, func.num_args) super().__init__(func.function_name, func.num_args)
self._func = func self._func = func
self._op = op self._reduction_op = reduction_op
@property @property
def func(self) -> ReductionFunctions: def func(self) -> ReductionFunctions:
return self._func return self._func
@property @property
def op(self) -> ReductionOp: def reduction_op(self) -> ReductionOp:
return self._op return self._reduction_op
def __str__(self) -> str: def __str__(self) -> str:
return f"{self._func.function_name}" return f"{self._func.function_name}"
...@@ -172,7 +172,7 @@ class PsReductionFunction(PsFunction): ...@@ -172,7 +172,7 @@ class PsReductionFunction(PsFunction):
if not isinstance(other, PsReductionFunction): if not isinstance(other, PsReductionFunction):
return False return False
return self._func == other._func return self._func == other._func and self._reduction_op == other._reduction_op
def __hash__(self) -> int: def __hash__(self) -> int:
return hash(self._func) return hash(self._func)
......
...@@ -189,6 +189,7 @@ class FreezeExpressions: ...@@ -189,6 +189,7 @@ class FreezeExpressions:
assert isinstance(rhs, PsExpression) assert isinstance(rhs, PsExpression)
assert isinstance(lhs, PsSymbolExpr) assert isinstance(lhs, PsSymbolExpr)
op = expr.reduction_op
orig_lhs_symb = lhs.symbol orig_lhs_symb = lhs.symbol
dtype = lhs.dtype dtype = lhs.dtype
...@@ -202,11 +203,11 @@ class FreezeExpressions: ...@@ -202,11 +203,11 @@ class FreezeExpressions:
new_lhs = PsSymbolExpr(new_lhs_symb) new_lhs = PsSymbolExpr(new_lhs_symb)
# get new rhs from augmented assignment # get new rhs from augmented assignment
new_rhs: PsExpression = compound_op_to_expr(expr.op, new_lhs.clone(), rhs) new_rhs: PsExpression = compound_op_to_expr(op, new_lhs.clone(), rhs)
# match for reduction operation and set neutral init_val # match for reduction operation and set neutral init_val
init_val: PsExpression init_val: PsExpression
match expr.op: match op:
case ReductionOp.Add: case ReductionOp.Add:
init_val = PsConstantExpr(PsConstant(0)) init_val = PsConstantExpr(PsConstant(0))
case ReductionOp.Sub: case ReductionOp.Sub:
...@@ -218,9 +219,9 @@ class FreezeExpressions: ...@@ -218,9 +219,9 @@ class FreezeExpressions:
case ReductionOp.Max: case ReductionOp.Max:
init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), []) init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), [])
case _: case _:
raise FreezeError(f"Unsupported reduced assignment: {expr.op}.") raise FreezeError(f"Unsupported reduced assignment: {op}.")
reduction_info = ReductionInfo(expr.op, init_val, orig_lhs_symb_as_ptr) reduction_info = ReductionInfo(op, init_val, orig_lhs_symb_as_ptr)
# add new symbol for local copy, replace original copy with pointer counterpart and add reduction info # add new symbol for local copy, replace original copy with pointer counterpart and add reduction info
self._ctx.add_symbol(new_lhs_symb) self._ctx.add_symbol(new_lhs_symb)
......
...@@ -74,7 +74,7 @@ class GenericCpu(Platform): ...@@ -74,7 +74,7 @@ class GenericCpu(Platform):
return PsDeclaration(symbol_expr, init_val) return PsDeclaration(symbol_expr, init_val)
case ReductionFunctions.WriteBackToPtr: case ReductionFunctions.WriteBackToPtr:
ptr_expr, symbol_expr = call.args ptr_expr, symbol_expr = call.args
op = call.function.op op = call.function.reduction_op
assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptr_expr.dtype, PsPointerType) assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptr_expr.dtype, PsPointerType)
assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(symbol_expr.dtype, PsScalarType) assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(symbol_expr.dtype, PsScalarType)
......
...@@ -354,7 +354,7 @@ def _x86_op_intrin( ...@@ -354,7 +354,7 @@ def _x86_op_intrin(
suffix += "x" suffix += "x"
atype = vtype.scalar_type atype = vtype.scalar_type
case PsVecHorizontal(): case PsVecHorizontal():
opstr = f"horizontal_{op.reduction_operation.name.lower()}" opstr = f"horizontal_{op.reduction_op.name.lower()}"
rtype = vtype.scalar_type rtype = vtype.scalar_type
case PsAdd(): case PsAdd():
opstr = "add" opstr = "add"
......
from operator import truediv, mul, sub, add from operator import truediv, mul, sub, add
from .backend.ast.expressions import PsExpression, PsCall from .backend.ast.expressions import PsExpression, PsCall, PsAdd, PsSub, PsMul, PsDiv
from .backend.exceptions import FreezeError from .backend.exceptions import FreezeError
from .backend.functions import PsMathFunction, MathFunctions from .backend.functions import PsMathFunction, MathFunctions
from .sympyextensions.reduction import ReductionOp from .sympyextensions.reduction import ReductionOp
...@@ -12,13 +12,13 @@ def compound_op_to_expr(op: ReductionOp, op1, op2) -> PsExpression: ...@@ -12,13 +12,13 @@ def compound_op_to_expr(op: ReductionOp, op1, op2) -> PsExpression:
if op in _available_operator_interface: if op in _available_operator_interface:
match op: match op:
case ReductionOp.Add: case ReductionOp.Add:
operator = add operator = PsAdd
case ReductionOp.Sub: case ReductionOp.Sub:
operator = sub operator = PsSub
case ReductionOp.Mul: case ReductionOp.Mul:
operator = mul operator = PsMul
case ReductionOp.Div: case ReductionOp.Div:
operator = truediv operator = PsDiv
case _: case _:
raise FreezeError(f"Found unsupported operation type for compound assignments: {op}.") raise FreezeError(f"Found unsupported operation type for compound assignments: {op}.")
return operator(op1, op2) return operator(op1, op2)
......
...@@ -22,36 +22,36 @@ class ReductionAssignment(AssignmentBase): ...@@ -22,36 +22,36 @@ class ReductionAssignment(AssignmentBase):
binop : CompoundOp binop : CompoundOp
Enum for binary operation being applied in the assignment, such as "Add" for "+", "Sub" for "-", etc. Enum for binary operation being applied in the assignment, such as "Add" for "+", "Sub" for "-", etc.
""" """
binop = None # type: ReductionOp reduction_op = None # type: ReductionOp
@property @property
def op(self): def reduction_op(self):
return self.binop return self.reduction_op
class AddReductionAssignment(ReductionAssignment): class AddReductionAssignment(ReductionAssignment):
binop = ReductionOp.Add reduction_op = ReductionOp.Add
class SubReductionAssignment(ReductionAssignment): class SubReductionAssignment(ReductionAssignment):
binop = ReductionOp.Sub reduction_op = ReductionOp.Sub
class MulReductionAssignment(ReductionAssignment): class MulReductionAssignment(ReductionAssignment):
binop = ReductionOp.Mul reduction_op = ReductionOp.Mul
class MinReductionAssignment(ReductionAssignment): class MinReductionAssignment(ReductionAssignment):
binop = ReductionOp.Min reduction_op = ReductionOp.Min
class MaxReductionAssignment(ReductionAssignment): class MaxReductionAssignment(ReductionAssignment):
binop = ReductionOp.Max reduction_op = ReductionOp.Max
# Mapping from ReductionOp enum to ReductionAssigment classes # Mapping from ReductionOp enum to ReductionAssigment classes
_reduction_assignment_classes = { _reduction_assignment_classes = {
cls.binop: cls for cls in [ cls.reduction_op: cls for cls in [
AddReductionAssignment, SubReductionAssignment, MulReductionAssignment, AddReductionAssignment, SubReductionAssignment, MulReductionAssignment,
MinReductionAssignment, MaxReductionAssignment MinReductionAssignment, MaxReductionAssignment
] ]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment