Skip to content
Snippets Groups Projects
Commit d4b7e78f authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Add initial implementation for horizontal reductions for vectorization

parent a2a59d40
No related branches found
No related tags found
1 merge request: !438 "Reduction Support"
...@@ -5,6 +5,7 @@ from typing import cast ...@@ -5,6 +5,7 @@ from typing import cast
from .astnode import PsAstNode from .astnode import PsAstNode
from .expressions import PsExpression, PsLvalue, PsUnOp from .expressions import PsExpression, PsLvalue, PsUnOp
from .util import failing_cast from .util import failing_cast
from ...sympyextensions import ReductionOp
from ...types import PsVectorType from ...types import PsVectorType
...@@ -42,6 +43,45 @@ class PsVecBroadcast(PsUnOp, PsVectorOp): ...@@ -42,6 +43,45 @@ class PsVecBroadcast(PsUnOp, PsVectorOp):
) )
class PsVecHorizontal(PsUnOp, PsVectorOp):
    """Horizontal reduction: extracts a scalar value from N vector lanes.

    The node wraps a single operand expression and records how many lanes are
    combined and which reduction operation combines them.

    Args:
        lanes: Number of vector lanes taking part in the reduction.
        operand: Expression whose lanes are reduced.
        reduction_op: Reduction operation used to combine the lanes.
    """

    # Every name listed here must be a readable attribute of the instance;
    # otherwise positional class patterns such as
    # ``case PsVecHorizontal(lanes, operand, reduction_op)`` never match.
    __match_args__ = ("lanes", "operand", "reduction_operation")

    def __init__(self, lanes: int, operand: PsExpression, reduction_op: ReductionOp):
        super().__init__(operand)
        self._lanes = lanes
        self._reduction_operation = reduction_op

    @property
    def lanes(self) -> int:
        """Number of vector lanes taking part in the reduction."""
        return self._lanes

    @lanes.setter
    def lanes(self, n: int):
        self._lanes = n

    @property
    def reduction_operation(self) -> ReductionOp:
        """Reduction operation used to combine the lanes."""
        return self._reduction_operation

    @reduction_operation.setter
    def reduction_operation(self, op: ReductionOp):
        self._reduction_operation = op

    def _clone_expr(self) -> PsVecHorizontal:
        """Return a copy of this node with a cloned operand."""
        # The reduction operation is handed over unchanged rather than cloned.
        # NOTE(review): other code reads ``.name`` on it, which suggests an
        # enum-like value without a ``clone()`` method -- confirm.
        return PsVecHorizontal(
            self._lanes, self._operand.clone(), self._reduction_operation
        )

    def structurally_equal(self, other: PsAstNode) -> bool:
        """Check structural equality, including lanes and reduction operation.

        Returns:
            ``True`` only if ``other`` is a `PsVecHorizontal`, the base-class
            comparison succeeds, and both the lane count and the reduction
            operation agree.
        """
        if not isinstance(other, PsVecHorizontal):
            return False

        return (
            super().structurally_equal(other)
            and self._lanes == other._lanes
            and self._reduction_operation == other._reduction_operation
        )
class PsVecMemAcc(PsExpression, PsLvalue, PsVectorOp): class PsVecMemAcc(PsExpression, PsLvalue, PsVectorOp):
"""Pointer-based vectorized memory access. """Pointer-based vectorized memory access.
......
...@@ -10,7 +10,7 @@ from .base_printer import BasePrinter, Ops, LR ...@@ -10,7 +10,7 @@ from .base_printer import BasePrinter, Ops, LR
from ..ast import PsAstNode from ..ast import PsAstNode
from ..ast.expressions import PsBufferAcc from ..ast.expressions import PsBufferAcc
from ..ast.vector import PsVecMemAcc, PsVecBroadcast from ..ast.vector import PsVecMemAcc, PsVecBroadcast, PsVecHorizontal
if TYPE_CHECKING: if TYPE_CHECKING:
from ...codegen import Kernel from ...codegen import Kernel
...@@ -77,6 +77,15 @@ class IRAstPrinter(BasePrinter): ...@@ -77,6 +77,15 @@ class IRAstPrinter(BasePrinter):
f"vec_broadcast<{lanes}>({operand_code})", Ops.Weakest f"vec_broadcast<{lanes}>({operand_code})", Ops.Weakest
) )
case PsVecHorizontal(lanes, operand, reduction_op):
pc.push_op(Ops.Weakest, LR.Middle)
operand_code = self.visit(operand, pc)
pc.pop_op()
return pc.parenthesize(
f"vec_horizontal_{reduction_op.name.lower()}<{lanes}>({operand_code})", Ops.Weakest
)
case _: case _:
return super().visit(node, pc) return super().visit(node, pc)
......
...@@ -49,7 +49,7 @@ from ..ast.expressions import ( ...@@ -49,7 +49,7 @@ from ..ast.expressions import (
PsNeg, PsNeg,
PsNot, PsNot,
) )
from ..ast.vector import PsVecBroadcast, PsVecMemAcc from ..ast.vector import PsVecBroadcast, PsVecMemAcc, PsVecHorizontal
from ..functions import PsMathFunction, CFunction, PsReductionFunction from ..functions import PsMathFunction, CFunction, PsReductionFunction
from ..ast.util import determine_memory_object from ..ast.util import determine_memory_object
from ..exceptions import TypificationError from ..exceptions import TypificationError
...@@ -640,6 +640,22 @@ class Typifier: ...@@ -640,6 +640,22 @@ class Typifier:
tc.apply_dtype(PsVectorType(op_tc.target_type, lanes), expr) tc.apply_dtype(PsVectorType(op_tc.target_type, lanes), expr)
case PsVecHorizontal():
op_tc = TypeContext()
self.visit_expr(expr.operand, op_tc)
if op_tc.target_type is None:
raise TypificationError(
f"Unable to determine type of argument to vector horizontal: {expr.operand}"
)
if not isinstance(op_tc.target_type, PsVectorType):
raise TypificationError(
f"Illegal type in argument to vector horizontal: {op_tc.target_type}"
)
tc.apply_dtype(op_tc.target_type.scalar_type, expr)
case _: case _:
raise NotImplementedError(f"Can't typify {expr}") raise NotImplementedError(f"Can't typify {expr}")
......
...@@ -17,8 +17,8 @@ from ..ast.expressions import ( ...@@ -17,8 +17,8 @@ from ..ast.expressions import (
PsCast, PsCast,
PsCall, PsCall,
) )
from ..ast.vector import PsVecMemAcc, PsVecBroadcast from ..ast.vector import PsVecMemAcc, PsVecBroadcast, PsVecHorizontal
from ...types import PsCustomType, PsVectorType, PsPointerType from ...types import PsCustomType, PsVectorType, PsPointerType, PsType
from ..constants import PsConstant from ..constants import PsConstant
from ..exceptions import MaterializationError from ..exceptions import MaterializationError
...@@ -160,7 +160,14 @@ class X86VectorCpu(GenericVectorCpu): ...@@ -160,7 +160,14 @@ class X86VectorCpu(GenericVectorCpu):
) -> PsExpression: ) -> PsExpression:
match expr: match expr:
case PsUnOp() | PsBinOp(): case PsUnOp() | PsBinOp():
func = _x86_op_intrin(self._vector_arch, expr, expr.get_dtype()) vtype: PsType
if isinstance(expr, PsVecHorizontal):
# expression itself is scalar, but argument is a vector
vtype = expr.operand.get_dtype()
else:
vtype = expr.get_dtype()
func = _x86_op_intrin(self._vector_arch, expr, vtype)
intrinsic = func(*operands) intrinsic = func(*operands)
intrinsic.dtype = func.return_type intrinsic.dtype = func.return_type
return intrinsic return intrinsic
...@@ -343,6 +350,9 @@ def _x86_op_intrin( ...@@ -343,6 +350,9 @@ def _x86_op_intrin(
if vtype.scalar_type == SInt(64) and vtype.vector_entries <= 4: if vtype.scalar_type == SInt(64) and vtype.vector_entries <= 4:
suffix += "x" suffix += "x"
atype = vtype.scalar_type atype = vtype.scalar_type
case PsVecHorizontal():
opstr = f"horizontal_{op.reduction_operation.name.lower()}"
rtype = vtype.scalar_type
case PsAdd(): case PsAdd():
opstr = "add" opstr = "add"
case PsSub(): case PsSub():
......
...@@ -7,9 +7,9 @@ from ...types import PsVectorType, PsScalarType ...@@ -7,9 +7,9 @@ from ...types import PsVectorType, PsScalarType
from ..kernelcreation import KernelCreationContext from ..kernelcreation import KernelCreationContext
from ..constants import PsConstant from ..constants import PsConstant
from ..ast import PsAstNode from ..ast import PsAstNode
from ..ast.structural import PsLoop, PsBlock, PsDeclaration from ..ast.structural import PsLoop, PsBlock, PsDeclaration, PsAssignment
from ..ast.expressions import PsExpression, PsTernary, PsGt from ..ast.expressions import PsExpression, PsTernary, PsGt, PsSymbolExpr
from ..ast.vector import PsVecBroadcast from ..ast.vector import PsVecBroadcast, PsVecHorizontal
from ..ast.analysis import collect_undefined_symbols from ..ast.analysis import collect_undefined_symbols
from .ast_vectorizer import VectorizationAxis, VectorizationContext, AstVectorizer from .ast_vectorizer import VectorizationAxis, VectorizationContext, AstVectorizer
...@@ -134,6 +134,21 @@ class LoopVectorizer: ...@@ -134,6 +134,21 @@ class LoopVectorizer:
# Prepare vectorization context # Prepare vectorization context
vc = VectorizationContext(self._ctx, self._lanes, axis) vc = VectorizationContext(self._ctx, self._lanes, axis)
# Prepare reductions
simd_init_local_reduction_vars = []
simd_writeback_local_reduction_vars = []
for symb, reduction_info in self._ctx.symbols_reduction_info.items():
# Vectorize symbol for local copy
vector_symb = vc.vectorize_symbol(symb)
# Declare and init vector
simd_init_local_reduction_vars += [self._type_fold(PsDeclaration(
PsSymbolExpr(vector_symb), PsVecBroadcast(self._lanes, PsSymbolExpr(symb))))]
# Write back vectorization result
simd_writeback_local_reduction_vars += [self._type_fold(PsAssignment(
PsSymbolExpr(symb), PsVecHorizontal(self._lanes, PsSymbolExpr(vector_symb), reduction_info.op)))]
# Generate vectorized loop body # Generate vectorized loop body
simd_body = self._vectorize_ast(loop.body, vc) simd_body = self._vectorize_ast(loop.body, vc)
...@@ -224,10 +239,14 @@ class LoopVectorizer: ...@@ -224,10 +239,14 @@ class LoopVectorizer:
) )
return PsBlock( return PsBlock(
simd_init_local_reduction_vars +
[ [
simd_stop_decl, simd_stop_decl,
simd_step_decl, simd_step_decl,
simd_loop, simd_loop
] +
simd_writeback_local_reduction_vars +
[
trailing_start_decl, trailing_start_decl,
trailing_loop, trailing_loop,
] ]
...@@ -238,11 +257,13 @@ class LoopVectorizer: ...@@ -238,11 +257,13 @@ class LoopVectorizer:
case LoopVectorizer.TrailingItersTreatment.NONE: case LoopVectorizer.TrailingItersTreatment.NONE:
return PsBlock( return PsBlock(
simd_init_local_reduction_vars +
[ [
simd_stop_decl, simd_stop_decl,
simd_step_decl, simd_step_decl,
simd_loop, simd_loop,
] ] +
simd_writeback_local_reduction_vars
) )
@overload @overload
......
...@@ -7,7 +7,7 @@ from ..ast.structural import PsAstNode, PsDeclaration, PsAssignment, PsStatement ...@@ -7,7 +7,7 @@ from ..ast.structural import PsAstNode, PsDeclaration, PsAssignment, PsStatement
from ..ast.expressions import PsExpression, PsCall, PsCast, PsLiteral from ..ast.expressions import PsExpression, PsCall, PsCast, PsLiteral
from ...types import PsCustomType, PsVectorType, constify, deconstify from ...types import PsCustomType, PsVectorType, constify, deconstify
from ..ast.expressions import PsSymbolExpr, PsConstantExpr, PsUnOp, PsBinOp from ..ast.expressions import PsSymbolExpr, PsConstantExpr, PsUnOp, PsBinOp
from ..ast.vector import PsVecMemAcc from ..ast.vector import PsVecMemAcc, PsVecHorizontal
from ..exceptions import MaterializationError from ..exceptions import MaterializationError
from ..functions import CFunction, PsMathFunction from ..functions import CFunction, PsMathFunction
...@@ -86,6 +86,10 @@ class SelectIntrinsics: ...@@ -86,6 +86,10 @@ class SelectIntrinsics:
new_rhs = self.visit_expr(rhs, sc) new_rhs = self.visit_expr(rhs, sc)
return PsStatement(self._platform.vector_store(lhs, new_rhs)) return PsStatement(self._platform.vector_store(lhs, new_rhs))
case PsAssignment(lhs, rhs) if isinstance(rhs, PsVecHorizontal):
new_rhs = self.visit_expr(rhs, sc)
return PsAssignment(lhs, new_rhs)
case _: case _:
node.children = [self.visit(c, sc) for c in node.children] node.children = [self.visit(c, sc) for c in node.children]
...@@ -93,7 +97,13 @@ class SelectIntrinsics: ...@@ -93,7 +97,13 @@ class SelectIntrinsics:
def visit_expr(self, expr: PsExpression, sc: SelectionContext) -> PsExpression: def visit_expr(self, expr: PsExpression, sc: SelectionContext) -> PsExpression:
if not isinstance(expr.dtype, PsVectorType): if not isinstance(expr.dtype, PsVectorType):
return expr # special case: result type of horizontal reduction is scalar
if isinstance(expr, PsVecHorizontal):
op = self.visit_expr(expr.operand, sc)
print(op)
return self._platform.op_intrinsic(expr, [op])
else:
return expr
match expr: match expr:
case PsSymbolExpr(symb): case PsSymbolExpr(symb):
......
...@@ -29,7 +29,10 @@ def test_reduction(target, dtype, op): ...@@ -29,7 +29,10 @@ def test_reduction(target, dtype, op):
red_assign = reduction_assignment_from_str(w, op, x.center()) red_assign = reduction_assignment_from_str(w, op, x.center())
config = ps.CreateKernelConfig(target=ps.Target.GPU) if gpu_avail else ps.CreateKernelConfig(cpu_openmp=True) vectorize_info = {'instruction_set': 'avx', 'assume_inner_stride_one': True}
config = ps.CreateKernelConfig(target=ps.Target.GPU) if gpu_avail \
else ps.CreateKernelConfig(cpu_openmp=True, cpu_vectorize_info=vectorize_info)
ast_reduction = ps.create_kernel([red_assign], config, default_dtype=dtype) ast_reduction = ps.create_kernel([red_assign], config, default_dtype=dtype)
ps.show_code(ast_reduction) ps.show_code(ast_reduction)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment