diff --git a/src/pystencils/backend/ast/vector.py b/src/pystencils/backend/ast/vector.py index 55db67e7c38a377c3a55ab4d784cfea2a523c78d..d7ae8d6a94ddf747eeb8da9b8e97d350691fb2a9 100644 --- a/src/pystencils/backend/ast/vector.py +++ b/src/pystencils/backend/ast/vector.py @@ -41,10 +41,9 @@ class PsVecBroadcast(PsUnOp, PsVectorOp): class PsVecHorizontal(PsBinOp, PsVectorOp): - """Represents a binary operation between a scalar and a vector operand. - With the binary operation not being vectorized, a horizontal reduction - along the lanes of the vector operand is required to extract a scalar value. - The result type will be equal to the scalar operand. + """Perform a horizontal reduction across a vector onto a scalar base value. + + **Example:** vec_horizontal_add(s, v)` will compute `s + v[0] + v[1] + ... + v[n-1]`. Args: scalar_operand: Scalar operand diff --git a/src/pystencils/backend/reduction_op_mapping.py b/src/pystencils/backend/reduction_op_mapping.py index 832f5d0bfe96ea70abd39029e96d709d01bd22b3..59273efab4d75bdaaa28a3ffbfd36b0ac0ed640d 100644 --- a/src/pystencils/backend/reduction_op_mapping.py +++ b/src/pystencils/backend/reduction_op_mapping.py @@ -12,27 +12,20 @@ _available_operator_interface: set[ReductionOp] = { def reduction_op_to_expr(op: ReductionOp, op1, op2) -> PsExpression: - if op in _available_operator_interface: - match op: - case ReductionOp.Add: - return PsAdd(op1, op2) - case ReductionOp.Sub: - return PsSub(op1, op2) - case ReductionOp.Mul: - return PsMul(op1, op2) - case ReductionOp.Div: - return PsDiv(op1, op2) - case _: - raise FreezeError( - f"Found unsupported operation type for reduction assignments: {op}." - ) - else: - match op: - case ReductionOp.Min: - return PsCall(PsMathFunction(MathFunctions.Min), [op1, op2]) - case ReductionOp.Max: - return PsCall(PsMathFunction(MathFunctions.Max), [op1, op2]) - case _: - raise FreezeError( - f"Found unsupported operation type for reduction assignments: {op}." - ) + match op: + case ReductionOp.Add: + return PsAdd(op1, op2) + case ReductionOp.Sub: + return PsSub(op1, op2) + case ReductionOp.Mul: + return PsMul(op1, op2) + case ReductionOp.Div: + return PsDiv(op1, op2) + case ReductionOp.Min: + return PsCall(PsMathFunction(MathFunctions.Min), [op1, op2]) + case ReductionOp.Max: + return PsCall(PsMathFunction(MathFunctions.Max), [op1, op2]) + case _: + raise FreezeError( + f"Found unsupported operation type for reduction assignments: {op}." + )