Skip to content
Snippets Groups Projects
Commit d4b7e78f authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Add initial implementation for horizontal reductions for vectorization

parent a2a59d40
No related branches found
No related tags found
1 merge request: !438 "Reduction Support"
......@@ -5,6 +5,7 @@ from typing import cast
from .astnode import PsAstNode
from .expressions import PsExpression, PsLvalue, PsUnOp
from .util import failing_cast
from ...sympyextensions import ReductionOp
from ...types import PsVectorType
......@@ -42,6 +43,45 @@ class PsVecBroadcast(PsUnOp, PsVectorOp):
)
class PsVecHorizontal(PsUnOp, PsVectorOp):
    """Horizontal reduction: extracts a scalar value from N vector lanes.

    The node wraps a single operand together with the number of lanes and
    the reduction operation that is used to combine them.

    Args:
        lanes: Number of vector lanes of the operand.
        operand: Expression that is reduced.
        reduction_op: Reduction operation applied across the lanes.
    """

    # Names must match the public properties below; positional class patterns
    # such as ``case PsVecHorizontal(lanes, operand, reduction_op)`` look these
    # attributes up by name.
    __match_args__ = ("lanes", "operand", "reduction_operation")

    def __init__(self, lanes: int, operand: PsExpression, reduction_op: ReductionOp):
        super().__init__(operand)
        self._lanes = lanes
        self._reduction_operation = reduction_op

    @property
    def lanes(self) -> int:
        """Number of vector lanes that are reduced."""
        return self._lanes

    @lanes.setter
    def lanes(self, n: int) -> None:
        self._lanes = n

    @property
    def reduction_operation(self) -> ReductionOp:
        """Reduction operation applied across the lanes."""
        return self._reduction_operation

    @reduction_operation.setter
    def reduction_operation(self, op: ReductionOp) -> None:
        self._reduction_operation = op

    def _clone_expr(self) -> PsVecHorizontal:
        # Only the operand sub-tree is cloned. The reduction operation is
        # handed over as-is: it is used like an enum member elsewhere
        # (``.name``), so it is assumed to be an immutable value without a
        # ``clone()`` method.
        return PsVecHorizontal(
            self._lanes, self._operand.clone(), self._reduction_operation
        )

    def structurally_equal(self, other: PsAstNode) -> bool:
        """Return True if ``other`` is a `PsVecHorizontal` with a structurally
        equal operand, the same lane count and the same reduction operation."""
        if not isinstance(other, PsVecHorizontal):
            return False

        return (
            super().structurally_equal(other)
            and self._lanes == other._lanes
            and self._reduction_operation == other._reduction_operation
        )
class PsVecMemAcc(PsExpression, PsLvalue, PsVectorOp):
"""Pointer-based vectorized memory access.
......
......@@ -10,7 +10,7 @@ from .base_printer import BasePrinter, Ops, LR
from ..ast import PsAstNode
from ..ast.expressions import PsBufferAcc
from ..ast.vector import PsVecMemAcc, PsVecBroadcast
from ..ast.vector import PsVecMemAcc, PsVecBroadcast, PsVecHorizontal
if TYPE_CHECKING:
from ...codegen import Kernel
......@@ -77,6 +77,15 @@ class IRAstPrinter(BasePrinter):
f"vec_broadcast<{lanes}>({operand_code})", Ops.Weakest
)
case PsVecHorizontal(lanes, operand, reduction_op):
pc.push_op(Ops.Weakest, LR.Middle)
operand_code = self.visit(operand, pc)
pc.pop_op()
return pc.parenthesize(
f"vec_horizontal_{reduction_op.name.lower()}<{lanes}>({operand_code})", Ops.Weakest
)
case _:
return super().visit(node, pc)
......
......@@ -49,7 +49,7 @@ from ..ast.expressions import (
PsNeg,
PsNot,
)
from ..ast.vector import PsVecBroadcast, PsVecMemAcc
from ..ast.vector import PsVecBroadcast, PsVecMemAcc, PsVecHorizontal
from ..functions import PsMathFunction, CFunction, PsReductionFunction
from ..ast.util import determine_memory_object
from ..exceptions import TypificationError
......@@ -640,6 +640,22 @@ class Typifier:
tc.apply_dtype(PsVectorType(op_tc.target_type, lanes), expr)
case PsVecHorizontal():
op_tc = TypeContext()
self.visit_expr(expr.operand, op_tc)
if op_tc.target_type is None:
raise TypificationError(
f"Unable to determine type of argument to vector horizontal: {expr.operand}"
)
if not isinstance(op_tc.target_type, PsVectorType):
raise TypificationError(
f"Illegal type in argument to vector horizontal: {op_tc.target_type}"
)
tc.apply_dtype(op_tc.target_type.scalar_type, expr)
case _:
raise NotImplementedError(f"Can't typify {expr}")
......
......@@ -17,8 +17,8 @@ from ..ast.expressions import (
PsCast,
PsCall,
)
from ..ast.vector import PsVecMemAcc, PsVecBroadcast
from ...types import PsCustomType, PsVectorType, PsPointerType
from ..ast.vector import PsVecMemAcc, PsVecBroadcast, PsVecHorizontal
from ...types import PsCustomType, PsVectorType, PsPointerType, PsType
from ..constants import PsConstant
from ..exceptions import MaterializationError
......@@ -160,7 +160,14 @@ class X86VectorCpu(GenericVectorCpu):
) -> PsExpression:
match expr:
case PsUnOp() | PsBinOp():
func = _x86_op_intrin(self._vector_arch, expr, expr.get_dtype())
vtype: PsType
if isinstance(expr, PsVecHorizontal):
# expression itself is scalar, but argument is a vector
vtype = expr.operand.get_dtype()
else:
vtype = expr.get_dtype()
func = _x86_op_intrin(self._vector_arch, expr, vtype)
intrinsic = func(*operands)
intrinsic.dtype = func.return_type
return intrinsic
......@@ -343,6 +350,9 @@ def _x86_op_intrin(
if vtype.scalar_type == SInt(64) and vtype.vector_entries <= 4:
suffix += "x"
atype = vtype.scalar_type
case PsVecHorizontal():
opstr = f"horizontal_{op.reduction_operation.name.lower()}"
rtype = vtype.scalar_type
case PsAdd():
opstr = "add"
case PsSub():
......
......@@ -7,9 +7,9 @@ from ...types import PsVectorType, PsScalarType
from ..kernelcreation import KernelCreationContext
from ..constants import PsConstant
from ..ast import PsAstNode
from ..ast.structural import PsLoop, PsBlock, PsDeclaration
from ..ast.expressions import PsExpression, PsTernary, PsGt
from ..ast.vector import PsVecBroadcast
from ..ast.structural import PsLoop, PsBlock, PsDeclaration, PsAssignment
from ..ast.expressions import PsExpression, PsTernary, PsGt, PsSymbolExpr
from ..ast.vector import PsVecBroadcast, PsVecHorizontal
from ..ast.analysis import collect_undefined_symbols
from .ast_vectorizer import VectorizationAxis, VectorizationContext, AstVectorizer
......@@ -134,6 +134,21 @@ class LoopVectorizer:
# Prepare vectorization context
vc = VectorizationContext(self._ctx, self._lanes, axis)
# Prepare reductions
simd_init_local_reduction_vars = []
simd_writeback_local_reduction_vars = []
for symb, reduction_info in self._ctx.symbols_reduction_info.items():
# Vectorize symbol for local copy
vector_symb = vc.vectorize_symbol(symb)
# Declare and init vector
simd_init_local_reduction_vars += [self._type_fold(PsDeclaration(
PsSymbolExpr(vector_symb), PsVecBroadcast(self._lanes, PsSymbolExpr(symb))))]
# Write back vectorization result
simd_writeback_local_reduction_vars += [self._type_fold(PsAssignment(
PsSymbolExpr(symb), PsVecHorizontal(self._lanes, PsSymbolExpr(vector_symb), reduction_info.op)))]
# Generate vectorized loop body
simd_body = self._vectorize_ast(loop.body, vc)
......@@ -224,10 +239,14 @@ class LoopVectorizer:
)
return PsBlock(
simd_init_local_reduction_vars +
[
simd_stop_decl,
simd_step_decl,
simd_loop,
simd_loop
] +
simd_writeback_local_reduction_vars +
[
trailing_start_decl,
trailing_loop,
]
......@@ -238,11 +257,13 @@ class LoopVectorizer:
case LoopVectorizer.TrailingItersTreatment.NONE:
return PsBlock(
simd_init_local_reduction_vars +
[
simd_stop_decl,
simd_step_decl,
simd_loop,
]
] +
simd_writeback_local_reduction_vars
)
@overload
......
......@@ -7,7 +7,7 @@ from ..ast.structural import PsAstNode, PsDeclaration, PsAssignment, PsStatement
from ..ast.expressions import PsExpression, PsCall, PsCast, PsLiteral
from ...types import PsCustomType, PsVectorType, constify, deconstify
from ..ast.expressions import PsSymbolExpr, PsConstantExpr, PsUnOp, PsBinOp
from ..ast.vector import PsVecMemAcc
from ..ast.vector import PsVecMemAcc, PsVecHorizontal
from ..exceptions import MaterializationError
from ..functions import CFunction, PsMathFunction
......@@ -86,6 +86,10 @@ class SelectIntrinsics:
new_rhs = self.visit_expr(rhs, sc)
return PsStatement(self._platform.vector_store(lhs, new_rhs))
case PsAssignment(lhs, rhs) if isinstance(rhs, PsVecHorizontal):
new_rhs = self.visit_expr(rhs, sc)
return PsAssignment(lhs, new_rhs)
case _:
node.children = [self.visit(c, sc) for c in node.children]
......@@ -93,7 +97,13 @@ class SelectIntrinsics:
def visit_expr(self, expr: PsExpression, sc: SelectionContext) -> PsExpression:
if not isinstance(expr.dtype, PsVectorType):
return expr
# special case: result type of horizontal reduction is scalar
if isinstance(expr, PsVecHorizontal):
op = self.visit_expr(expr.operand, sc)
print(op)
return self._platform.op_intrinsic(expr, [op])
else:
return expr
match expr:
case PsSymbolExpr(symb):
......
......@@ -29,7 +29,10 @@ def test_reduction(target, dtype, op):
red_assign = reduction_assignment_from_str(w, op, x.center())
config = ps.CreateKernelConfig(target=ps.Target.GPU) if gpu_avail else ps.CreateKernelConfig(cpu_openmp=True)
vectorize_info = {'instruction_set': 'avx', 'assume_inner_stride_one': True}
config = ps.CreateKernelConfig(target=ps.Target.GPU) if gpu_avail \
else ps.CreateKernelConfig(cpu_openmp=True, cpu_vectorize_info=vectorize_info)
ast_reduction = ps.create_kernel([red_assign], config, default_dtype=dtype)
ps.show_code(ast_reduction)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment