From eb7823a5b28589a45f968ac6415ee2a02543a3ab Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Wed, 12 Feb 2025 14:13:05 +0100
Subject: [PATCH] Minor refactor of reduction ops

---
 src/pystencils/backend/ast/vector.py           | 16 ++++++++--------
 src/pystencils/backend/functions.py            | 10 +++++-----
 .../backend/kernelcreation/freeze.py           |  9 +++++----
 .../backend/platforms/generic_cpu.py           |  2 +-
 src/pystencils/backend/platforms/x86.py        |  2 +-
 src/pystencils/compound_op_mapping.py          | 10 +++++-----
 src/pystencils/sympyextensions/reduction.py    | 18 +++++++++---------
 7 files changed, 34 insertions(+), 33 deletions(-)

diff --git a/src/pystencils/backend/ast/vector.py b/src/pystencils/backend/ast/vector.py
index 8ff1ff8a0..4e6b2ff00 100644
--- a/src/pystencils/backend/ast/vector.py
+++ b/src/pystencils/backend/ast/vector.py
@@ -51,7 +51,7 @@ class PsVecHorizontal(PsUnOp, PsVectorOp):
     def __init__(self, lanes: int, operand: PsExpression, reduction_op: ReductionOp):
         super().__init__(operand)
         self._lanes = lanes
-        self._reduction_operation = reduction_op
+        self._reduction_op = reduction_op
 
     @property
     def lanes(self) -> int:
@@ -62,15 +62,15 @@ class PsVecHorizontal(PsUnOp, PsVectorOp):
         self._lanes = n
 
     @property
-    def reduction_operation(self) -> ReductionOp:
-        return self._reduction_operation
+    def reduction_op(self) -> ReductionOp:
+        return self._reduction_op
 
-    @reduction_operation.setter
-    def reduction_operation(self, op: ReductionOp):
-        self._reduction_operation = op
+    @reduction_op.setter
+    def reduction_op(self, op: ReductionOp):
+        self._reduction_op = op
 
     def _clone_expr(self) -> PsVecHorizontal:
-        return PsVecHorizontal(self._lanes, self._operand.clone(), self._operation.clone())
+        return PsVecHorizontal(self._lanes, self._operand.clone(), self._reduction_op)
 
     def structurally_equal(self, other: PsAstNode) -> bool:
         if not isinstance(other, PsVecHorizontal):
@@ -78,7 +78,7 @@ class PsVecHorizontal(PsUnOp, PsVectorOp):
         return (
                 super().structurally_equal(other)
                 and self._lanes == other._lanes
-                and self._operation == other._operation
+                and self._reduction_op == other._reduction_op
         )
 
 
diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py
index e1f742386..d28ef5f44 100644
--- a/src/pystencils/backend/functions.py
+++ b/src/pystencils/backend/functions.py
@@ -152,18 +152,18 @@ class ReductionFunctions(Enum):
 
 class PsReductionFunction(PsFunction):
 
-    def __init__(self, func: ReductionFunctions, op: ReductionOp) -> None:
+    def __init__(self, func: ReductionFunctions, reduction_op: ReductionOp) -> None:
         super().__init__(func.function_name, func.num_args)
         self._func = func
-        self._op = op
+        self._reduction_op = reduction_op
 
     @property
     def func(self) -> ReductionFunctions:
         return self._func
 
     @property
-    def op(self) -> ReductionOp:
-        return self._op
+    def reduction_op(self) -> ReductionOp:
+        return self._reduction_op
 
     def __str__(self) -> str:
         return f"{self._func.function_name}"
@@ -172,7 +172,7 @@ class PsReductionFunction(PsFunction):
         if not isinstance(other, PsReductionFunction):
             return False
 
-        return self._func == other._func
+        return self._func == other._func and self._reduction_op == other._reduction_op
 
     def __hash__(self) -> int:
         return hash(self._func)
diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py
index 675655802..ce65cd85d 100644
--- a/src/pystencils/backend/kernelcreation/freeze.py
+++ b/src/pystencils/backend/kernelcreation/freeze.py
@@ -189,6 +189,7 @@ class FreezeExpressions:
         assert isinstance(rhs, PsExpression)
         assert isinstance(lhs, PsSymbolExpr)
 
+        op = expr.reduction_op
         orig_lhs_symb = lhs.symbol
         dtype = lhs.dtype
 
@@ -202,11 +203,11 @@ class FreezeExpressions:
         new_lhs = PsSymbolExpr(new_lhs_symb)
 
         # get new rhs from augmented assignment
-        new_rhs: PsExpression = compound_op_to_expr(expr.op, new_lhs.clone(), rhs)
+        new_rhs: PsExpression = compound_op_to_expr(op, new_lhs.clone(), rhs)
 
         # match for reduction operation and set neutral init_val
         init_val: PsExpression
-        match expr.op:
+        match op:
             case ReductionOp.Add:
                 init_val = PsConstantExpr(PsConstant(0))
             case ReductionOp.Sub:
@@ -218,9 +219,9 @@ class FreezeExpressions:
             case ReductionOp.Max:
                 init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), [])
             case _:
-                raise FreezeError(f"Unsupported reduced assignment: {expr.op}.")
+                raise FreezeError(f"Unsupported reduced assignment: {op}.")
 
-        reduction_info = ReductionInfo(expr.op, init_val, orig_lhs_symb_as_ptr)
+        reduction_info = ReductionInfo(op, init_val, orig_lhs_symb_as_ptr)
 
         # add new symbol for local copy, replace original copy with pointer counterpart and add reduction info
         self._ctx.add_symbol(new_lhs_symb)
diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py
index aa6e22b85..7655572de 100644
--- a/src/pystencils/backend/platforms/generic_cpu.py
+++ b/src/pystencils/backend/platforms/generic_cpu.py
@@ -74,7 +74,7 @@ class GenericCpu(Platform):
                 return PsDeclaration(symbol_expr, init_val)
             case ReductionFunctions.WriteBackToPtr:
                 ptr_expr, symbol_expr = call.args
-                op = call.function.op
+                op = call.function.reduction_op
 
                 assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptr_expr.dtype, PsPointerType)
                 assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(symbol_expr.dtype, PsScalarType)
diff --git a/src/pystencils/backend/platforms/x86.py b/src/pystencils/backend/platforms/x86.py
index 0727b65b9..59c3a178f 100644
--- a/src/pystencils/backend/platforms/x86.py
+++ b/src/pystencils/backend/platforms/x86.py
@@ -354,7 +354,7 @@ def _x86_op_intrin(
                 suffix += "x"
             atype = vtype.scalar_type
         case PsVecHorizontal():
-            opstr = f"horizontal_{op.reduction_operation.name.lower()}"
+            opstr = f"horizontal_{op.reduction_op.name.lower()}"
             rtype = vtype.scalar_type
         case PsAdd():
             opstr = "add"
diff --git a/src/pystencils/compound_op_mapping.py b/src/pystencils/compound_op_mapping.py
index 1eadfa6f0..2dd88fc94 100644
--- a/src/pystencils/compound_op_mapping.py
+++ b/src/pystencils/compound_op_mapping.py
@@ -1,6 +1,6 @@
 from operator import truediv, mul, sub, add
 
-from .backend.ast.expressions import PsExpression, PsCall
+from .backend.ast.expressions import PsExpression, PsCall, PsAdd, PsSub, PsMul, PsDiv
 from .backend.exceptions import FreezeError
 from .backend.functions import PsMathFunction, MathFunctions
 from .sympyextensions.reduction import ReductionOp
@@ -12,13 +12,13 @@ def compound_op_to_expr(op: ReductionOp, op1, op2) -> PsExpression:
     if op in _available_operator_interface:
         match op:
             case ReductionOp.Add:
-                operator = add
+                operator = PsAdd
             case ReductionOp.Sub:
-                operator = sub
+                operator = PsSub
             case ReductionOp.Mul:
-                operator = mul
+                operator = PsMul
             case ReductionOp.Div:
-                operator = truediv
+                operator = PsDiv
             case _:
                 raise FreezeError(f"Found unsupported operation type for compound assignments: {op}.")
         return operator(op1, op2)
diff --git a/src/pystencils/sympyextensions/reduction.py b/src/pystencils/sympyextensions/reduction.py
index 9d8aecb5b..25ae5c0ac 100644
--- a/src/pystencils/sympyextensions/reduction.py
+++ b/src/pystencils/sympyextensions/reduction.py
@@ -22,36 +22,36 @@ class ReductionAssignment(AssignmentBase):
     binop : CompoundOp
        Enum for binary operation being applied in the assignment, such as "Add" for "+", "Sub" for "-", etc.
     """
-    binop = None  # type: ReductionOp
+    reduction_op = None  # type: ReductionOp
 
     @property
-    def op(self):
-        return self.binop
+    def reduction_op(self):
+        return self.reduction_op
 
 
 class AddReductionAssignment(ReductionAssignment):
-    binop = ReductionOp.Add
+    reduction_op = ReductionOp.Add
 
 
 class SubReductionAssignment(ReductionAssignment):
-    binop = ReductionOp.Sub
+    reduction_op = ReductionOp.Sub
 
 
 class MulReductionAssignment(ReductionAssignment):
-    binop = ReductionOp.Mul
+    reduction_op = ReductionOp.Mul
 
 
 class MinReductionAssignment(ReductionAssignment):
-    binop = ReductionOp.Min
+    reduction_op = ReductionOp.Min
 
 
 class MaxReductionAssignment(ReductionAssignment):
-    binop = ReductionOp.Max
+    reduction_op = ReductionOp.Max
 
 
 # Mapping from ReductionOp enum to ReductionAssigment classes
 _reduction_assignment_classes = {
-    cls.binop: cls for cls in [
+    cls.reduction_op: cls for cls in [
         AddReductionAssignment, SubReductionAssignment, MulReductionAssignment,
         MinReductionAssignment, MaxReductionAssignment
     ]
-- 
GitLab