From eb7823a5b28589a45f968ac6415ee2a02543a3ab Mon Sep 17 00:00:00 2001 From: zy69guqi <richard.angersbach@fau.de> Date: Wed, 12 Feb 2025 14:13:05 +0100 Subject: [PATCH] Minor refactor of reduction ops --- src/pystencils/backend/ast/vector.py | 16 ++++++++-------- src/pystencils/backend/functions.py | 10 +++++----- .../backend/kernelcreation/freeze.py | 9 +++++---- .../backend/platforms/generic_cpu.py | 2 +- src/pystencils/backend/platforms/x86.py | 2 +- src/pystencils/compound_op_mapping.py | 10 +++++----- src/pystencils/sympyextensions/reduction.py | 18 +++++++++--------- 7 files changed, 34 insertions(+), 33 deletions(-) diff --git a/src/pystencils/backend/ast/vector.py b/src/pystencils/backend/ast/vector.py index 8ff1ff8a0..4e6b2ff00 100644 --- a/src/pystencils/backend/ast/vector.py +++ b/src/pystencils/backend/ast/vector.py @@ -51,7 +51,7 @@ class PsVecHorizontal(PsUnOp, PsVectorOp): def __init__(self, lanes: int, operand: PsExpression, reduction_op: ReductionOp): super().__init__(operand) self._lanes = lanes - self._reduction_operation = reduction_op + self._reduction_op = reduction_op @property def lanes(self) -> int: @@ -62,15 +62,15 @@ class PsVecHorizontal(PsUnOp, PsVectorOp): self._lanes = n @property - def reduction_operation(self) -> ReductionOp: - return self._reduction_operation + def reduction_op(self) -> ReductionOp: + return self._reduction_op - @reduction_operation.setter - def reduction_operation(self, op: ReductionOp): - self._reduction_operation = op + @reduction_op.setter + def reduction_op(self, op: ReductionOp): + self._reduction_op = op def _clone_expr(self) -> PsVecHorizontal: - return PsVecHorizontal(self._lanes, self._operand.clone(), self._operation.clone()) + return PsVecHorizontal(self._lanes, self._operand.clone(), self._reduction_op) def structurally_equal(self, other: PsAstNode) -> bool: if not isinstance(other, PsVecHorizontal): @@ -78,7 +78,7 @@ class PsVecHorizontal(PsUnOp, PsVectorOp): return ( super().structurally_equal(other) and self._lanes == other._lanes - and self._operation == other._operation + and self._reduction_op == other._reduction_op ) diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py index e1f742386..d28ef5f44 100644 --- a/src/pystencils/backend/functions.py +++ b/src/pystencils/backend/functions.py @@ -152,18 +152,18 @@ class ReductionFunctions(Enum): class PsReductionFunction(PsFunction): - def __init__(self, func: ReductionFunctions, op: ReductionOp) -> None: + def __init__(self, func: ReductionFunctions, reduction_op: ReductionOp) -> None: super().__init__(func.function_name, func.num_args) self._func = func - self._op = op + self._reduction_op = reduction_op @property def func(self) -> ReductionFunctions: return self._func @property - def op(self) -> ReductionOp: - return self._op + def reduction_op(self) -> ReductionOp: + return self._reduction_op def __str__(self) -> str: return f"{self._func.function_name}" @@ -172,7 +172,7 @@ class PsReductionFunction(PsFunction): if not isinstance(other, PsReductionFunction): return False - return self._func == other._func + return self._func == other._func and self._reduction_op == other._reduction_op def __hash__(self) -> int: return hash(self._func) diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 675655802..ce65cd85d 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -189,6 +189,7 @@ class FreezeExpressions: assert isinstance(rhs, PsExpression) assert isinstance(lhs, PsSymbolExpr) + op = expr.reduction_op orig_lhs_symb = lhs.symbol dtype = lhs.dtype @@ -202,11 +203,11 @@ class FreezeExpressions: new_lhs = PsSymbolExpr(new_lhs_symb) # get new rhs from augmented assignment - new_rhs: PsExpression = compound_op_to_expr(expr.op, new_lhs.clone(), rhs) + new_rhs: PsExpression = compound_op_to_expr(op, new_lhs.clone(), rhs) # match for reduction operation and set neutral init_val init_val: PsExpression - match expr.op: + match op: case ReductionOp.Add: init_val = PsConstantExpr(PsConstant(0)) case ReductionOp.Sub: @@ -218,9 +219,9 @@ class FreezeExpressions: case ReductionOp.Max: init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), []) case _: - raise FreezeError(f"Unsupported reduced assignment: {expr.op}.") + raise FreezeError(f"Unsupported reduced assignment: {op}.") - reduction_info = ReductionInfo(expr.op, init_val, orig_lhs_symb_as_ptr) + reduction_info = ReductionInfo(op, init_val, orig_lhs_symb_as_ptr) # add new symbol for local copy, replace original copy with pointer counterpart and add reduction info self._ctx.add_symbol(new_lhs_symb) diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index aa6e22b85..7655572de 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -74,7 +74,7 @@ class GenericCpu(Platform): return PsDeclaration(symbol_expr, init_val) case ReductionFunctions.WriteBackToPtr: ptr_expr, symbol_expr = call.args - op = call.function.op + op = call.function.reduction_op assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptr_expr.dtype, PsPointerType) assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(symbol_expr.dtype, PsScalarType) diff --git a/src/pystencils/backend/platforms/x86.py b/src/pystencils/backend/platforms/x86.py index 0727b65b9..59c3a178f 100644 --- a/src/pystencils/backend/platforms/x86.py +++ b/src/pystencils/backend/platforms/x86.py @@ -354,7 +354,7 @@ def _x86_op_intrin( suffix += "x" atype = vtype.scalar_type case PsVecHorizontal(): - opstr = f"horizontal_{op.reduction_operation.name.lower()}" + opstr = f"horizontal_{op.reduction_op.name.lower()}" rtype = vtype.scalar_type case PsAdd(): opstr = "add" diff --git a/src/pystencils/compound_op_mapping.py b/src/pystencils/compound_op_mapping.py index 1eadfa6f0..2dd88fc94 100644 --- a/src/pystencils/compound_op_mapping.py +++ b/src/pystencils/compound_op_mapping.py @@ -1,6 +1,6 @@ from operator import truediv, mul, sub, add -from .backend.ast.expressions import PsExpression, PsCall +from .backend.ast.expressions import PsExpression, PsCall, PsAdd, PsSub, PsMul, PsDiv from .backend.exceptions import FreezeError from .backend.functions import PsMathFunction, MathFunctions from .sympyextensions.reduction import ReductionOp @@ -12,13 +12,13 @@ def compound_op_to_expr(op: ReductionOp, op1, op2) -> PsExpression: if op in _available_operator_interface: match op: case ReductionOp.Add: - operator = add + operator = PsAdd case ReductionOp.Sub: - operator = sub + operator = PsSub case ReductionOp.Mul: - operator = mul + operator = PsMul case ReductionOp.Div: - operator = truediv + operator = PsDiv case _: raise FreezeError(f"Found unsupported operation type for compound assignments: {op}.") return operator(op1, op2) diff --git a/src/pystencils/sympyextensions/reduction.py b/src/pystencils/sympyextensions/reduction.py index 9d8aecb5b..25ae5c0ac 100644 --- a/src/pystencils/sympyextensions/reduction.py +++ b/src/pystencils/sympyextensions/reduction.py @@ -22,36 +22,36 @@ class ReductionAssignment(AssignmentBase): binop : CompoundOp Enum for binary operation being applied in the assignment, such as "Add" for "+", "Sub" for "-", etc. """ - binop = None # type: ReductionOp + reduction_op = None # type: ReductionOp @property - def op(self): - return self.binop + def reduction_op(self): + return self.reduction_op class AddReductionAssignment(ReductionAssignment): - binop = ReductionOp.Add + reduction_op = ReductionOp.Add class SubReductionAssignment(ReductionAssignment): - binop = ReductionOp.Sub + reduction_op = ReductionOp.Sub class MulReductionAssignment(ReductionAssignment): - binop = ReductionOp.Mul + reduction_op = ReductionOp.Mul class MinReductionAssignment(ReductionAssignment): - binop = ReductionOp.Min + reduction_op = ReductionOp.Min class MaxReductionAssignment(ReductionAssignment): - binop = ReductionOp.Max + reduction_op = ReductionOp.Max # Mapping from ReductionOp enum to ReductionAssigment classes _reduction_assignment_classes = { - cls.binop: cls for cls in [ + cls.reduction_op: cls for cls in [ AddReductionAssignment, SubReductionAssignment, MulReductionAssignment, MinReductionAssignment, MaxReductionAssignment ] -- GitLab