diff --git a/src/pystencils/backend/ast/vector.py b/src/pystencils/backend/ast/vector.py index 705d250949f3662695d506feeff30c20649eb1c5..8ff1ff8a0847e6c5d18bcdece9881b3b5d14766f 100644 --- a/src/pystencils/backend/ast/vector.py +++ b/src/pystencils/backend/ast/vector.py @@ -5,6 +5,7 @@ from typing import cast from .astnode import PsAstNode from .expressions import PsExpression, PsLvalue, PsUnOp from .util import failing_cast +from ...sympyextensions import ReductionOp from ...types import PsVectorType @@ -42,6 +43,45 @@ class PsVecBroadcast(PsUnOp, PsVectorOp): ) +class PsVecHorizontal(PsUnOp, PsVectorOp): + """Extracts scalar value from N vector lanes.""" + + __match_args__ = ("lanes", "operand", "operation") + + def __init__(self, lanes: int, operand: PsExpression, reduction_op: ReductionOp): + super().__init__(operand) + self._lanes = lanes + self._reduction_operation = reduction_op + + @property + def lanes(self) -> int: + return self._lanes + + @lanes.setter + def lanes(self, n: int): + self._lanes = n + + @property + def reduction_operation(self) -> ReductionOp: + return self._reduction_operation + + @reduction_operation.setter + def reduction_operation(self, op: ReductionOp): + self._reduction_operation = op + + def _clone_expr(self) -> PsVecHorizontal: + return PsVecHorizontal(self._lanes, self._operand.clone(), self._operation.clone()) + + def structurally_equal(self, other: PsAstNode) -> bool: + if not isinstance(other, PsVecHorizontal): + return False + return ( + super().structurally_equal(other) + and self._lanes == other._lanes + and self._operation == other._operation + ) + + class PsVecMemAcc(PsExpression, PsLvalue, PsVectorOp): """Pointer-based vectorized memory access. diff --git a/src/pystencils/backend/emission/ir_printer.py b/src/pystencils/backend/emission/ir_printer.py index ffb65181ccd71ff95dffd6d006617dadc6809eea..04084dd3bfa3b7bca02e173b6b477adffb11a7a7 100644 --- a/src/pystencils/backend/emission/ir_printer.py +++ b/src/pystencils/backend/emission/ir_printer.py @@ -10,7 +10,7 @@ from .base_printer import BasePrinter, Ops, LR from ..ast import PsAstNode from ..ast.expressions import PsBufferAcc -from ..ast.vector import PsVecMemAcc, PsVecBroadcast +from ..ast.vector import PsVecMemAcc, PsVecBroadcast, PsVecHorizontal if TYPE_CHECKING: from ...codegen import Kernel @@ -77,6 +77,15 @@ class IRAstPrinter(BasePrinter): f"vec_broadcast<{lanes}>({operand_code})", Ops.Weakest ) + case PsVecHorizontal(lanes, operand, reduction_op): + pc.push_op(Ops.Weakest, LR.Middle) + operand_code = self.visit(operand, pc) + pc.pop_op() + + return pc.parenthesize( + f"vec_horizontal_{reduction_op.name.lower()}<{lanes}>({operand_code})", Ops.Weakest + ) + case _: return super().visit(node, pc) diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 059817bfda92d4714896a86a110cb257ca4cb823..25fb55a0b40010c7ae597d8e1cfb32086b1d0abd 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -49,7 +49,7 @@ from ..ast.expressions import ( PsNeg, PsNot, ) -from ..ast.vector import PsVecBroadcast, PsVecMemAcc +from ..ast.vector import PsVecBroadcast, PsVecMemAcc, PsVecHorizontal from ..functions import PsMathFunction, CFunction, PsReductionFunction from ..ast.util import determine_memory_object from ..exceptions import TypificationError @@ -640,6 +640,22 @@ class Typifier: tc.apply_dtype(PsVectorType(op_tc.target_type, lanes), expr) + case PsVecHorizontal(): + op_tc = TypeContext() + self.visit_expr(expr.operand, op_tc) + + if op_tc.target_type is None: + raise TypificationError( + f"Unable to determine type of argument to vector horizontal: {expr.operand}" + ) + + if not isinstance(op_tc.target_type, PsVectorType): + raise TypificationError( + f"Illegal type in argument to vector horizontal: {op_tc.target_type}" + ) + + tc.apply_dtype(op_tc.target_type.scalar_type, expr) + case _: raise NotImplementedError(f"Can't typify {expr}") diff --git a/src/pystencils/backend/platforms/x86.py b/src/pystencils/backend/platforms/x86.py index 7d2fe650fc23a54ea301817d57a3816b6780bd85..acd3971551a656cc9e06f5b052812422c967cd14 100644 --- a/src/pystencils/backend/platforms/x86.py +++ b/src/pystencils/backend/platforms/x86.py @@ -17,8 +17,8 @@ from ..ast.expressions import ( PsCast, PsCall, ) -from ..ast.vector import PsVecMemAcc, PsVecBroadcast -from ...types import PsCustomType, PsVectorType, PsPointerType +from ..ast.vector import PsVecMemAcc, PsVecBroadcast, PsVecHorizontal +from ...types import PsCustomType, PsVectorType, PsPointerType, PsType from ..constants import PsConstant from ..exceptions import MaterializationError @@ -160,7 +160,14 @@ class X86VectorCpu(GenericVectorCpu): ) -> PsExpression: match expr: case PsUnOp() | PsBinOp(): - func = _x86_op_intrin(self._vector_arch, expr, expr.get_dtype()) + vtype: PsType + if isinstance(expr, PsVecHorizontal): + # expression itself is scalar, but argument is a vector + vtype = expr.operand.get_dtype() + else: + vtype = expr.get_dtype() + + func = _x86_op_intrin(self._vector_arch, expr, vtype) intrinsic = func(*operands) intrinsic.dtype = func.return_type return intrinsic @@ -343,6 +350,9 @@ def _x86_op_intrin( if vtype.scalar_type == SInt(64) and vtype.vector_entries <= 4: suffix += "x" atype = vtype.scalar_type + case PsVecHorizontal(): + opstr = f"horizontal_{op.reduction_operation.name.lower()}" + rtype = vtype.scalar_type case PsAdd(): opstr = "add" case PsSub(): diff --git a/src/pystencils/backend/transformations/loop_vectorizer.py b/src/pystencils/backend/transformations/loop_vectorizer.py index e1e4fea502c08de86e13de5e3c251f1b7a7d0ee6..39d72adb4a6e7e4481072252be54ab13997d2f32 100644 --- a/src/pystencils/backend/transformations/loop_vectorizer.py +++ b/src/pystencils/backend/transformations/loop_vectorizer.py @@ -7,9 +7,9 @@ from ...types import PsVectorType, PsScalarType from ..kernelcreation import KernelCreationContext from ..constants import PsConstant from ..ast import PsAstNode -from ..ast.structural import PsLoop, PsBlock, PsDeclaration -from ..ast.expressions import PsExpression, PsTernary, PsGt -from ..ast.vector import PsVecBroadcast +from ..ast.structural import PsLoop, PsBlock, PsDeclaration, PsAssignment +from ..ast.expressions import PsExpression, PsTernary, PsGt, PsSymbolExpr +from ..ast.vector import PsVecBroadcast, PsVecHorizontal from ..ast.analysis import collect_undefined_symbols from .ast_vectorizer import VectorizationAxis, VectorizationContext, AstVectorizer @@ -134,6 +134,21 @@ class LoopVectorizer: # Prepare vectorization context vc = VectorizationContext(self._ctx, self._lanes, axis) + # Prepare reductions + simd_init_local_reduction_vars = [] + simd_writeback_local_reduction_vars = [] + for symb, reduction_info in self._ctx.symbols_reduction_info.items(): + # Vectorize symbol for local copy + vector_symb = vc.vectorize_symbol(symb) + + # Declare and init vector + simd_init_local_reduction_vars += [self._type_fold(PsDeclaration( + PsSymbolExpr(vector_symb), PsVecBroadcast(self._lanes, PsSymbolExpr(symb))))] + + # Write back vectorization result + simd_writeback_local_reduction_vars += [self._type_fold(PsAssignment( + PsSymbolExpr(symb), PsVecHorizontal(self._lanes, PsSymbolExpr(vector_symb), reduction_info.op)))] + # Generate vectorized loop body simd_body = self._vectorize_ast(loop.body, vc) @@ -224,10 +239,14 @@ class LoopVectorizer: ) return PsBlock( + simd_init_local_reduction_vars + [ simd_stop_decl, simd_step_decl, - simd_loop, + simd_loop + ] + + simd_writeback_local_reduction_vars + + [ trailing_start_decl, trailing_loop, ] @@ -238,11 +257,13 @@ class LoopVectorizer: case LoopVectorizer.TrailingItersTreatment.NONE: return PsBlock( + simd_init_local_reduction_vars + [ simd_stop_decl, simd_step_decl, simd_loop, - ] + ] + + simd_writeback_local_reduction_vars ) @overload diff --git a/src/pystencils/backend/transformations/select_intrinsics.py b/src/pystencils/backend/transformations/select_intrinsics.py index 060192810a7ccb9ab9ed13f64dd7948791078ea4..7a03e293a5a1d939f7a086c8c6468b704756ca76 100644 --- a/src/pystencils/backend/transformations/select_intrinsics.py +++ b/src/pystencils/backend/transformations/select_intrinsics.py @@ -7,7 +7,7 @@ from ..ast.structural import PsAstNode, PsDeclaration, PsAssignment, PsStatement from ..ast.expressions import PsExpression, PsCall, PsCast, PsLiteral from ...types import PsCustomType, PsVectorType, constify, deconstify from ..ast.expressions import PsSymbolExpr, PsConstantExpr, PsUnOp, PsBinOp -from ..ast.vector import PsVecMemAcc +from ..ast.vector import PsVecMemAcc, PsVecHorizontal from ..exceptions import MaterializationError from ..functions import CFunction, PsMathFunction @@ -86,6 +86,10 @@ class SelectIntrinsics: new_rhs = self.visit_expr(rhs, sc) return PsStatement(self._platform.vector_store(lhs, new_rhs)) + case PsAssignment(lhs, rhs) if isinstance(rhs, PsVecHorizontal): + new_rhs = self.visit_expr(rhs, sc) + return PsAssignment(lhs, new_rhs) + case _: node.children = [self.visit(c, sc) for c in node.children] @@ -93,7 +97,13 @@ class SelectIntrinsics: def visit_expr(self, expr: PsExpression, sc: SelectionContext) -> PsExpression: if not isinstance(expr.dtype, PsVectorType): - return expr + # special case: result type of horizontal reduction is scalar + if isinstance(expr, PsVecHorizontal): + op = self.visit_expr(expr.operand, sc) + print(op) + return self._platform.op_intrinsic(expr, [op]) + else: + return expr match expr: case PsSymbolExpr(symb): diff --git a/tests/kernelcreation/test_reduction.py b/tests/kernelcreation/test_reduction.py index be2589912f458306a06c7f96b8d94e6ce176c8f0..f64ba154aeff3c40c9ee02eab872710545611e75 100644 --- a/tests/kernelcreation/test_reduction.py +++ b/tests/kernelcreation/test_reduction.py @@ -29,7 +29,10 @@ def test_reduction(target, dtype, op): red_assign = reduction_assignment_from_str(w, op, x.center()) - config = ps.CreateKernelConfig(target=ps.Target.GPU) if gpu_avail else ps.CreateKernelConfig(cpu_openmp=True) + vectorize_info = {'instruction_set': 'avx', 'assume_inner_stride_one': True} + + config = ps.CreateKernelConfig(target=ps.Target.GPU) if gpu_avail \ + else ps.CreateKernelConfig(cpu_openmp=True, cpu_vectorize_info=vectorize_info) ast_reduction = ps.create_kernel([red_assign], config, default_dtype=dtype) ps.show_code(ast_reduction)