Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Commits on Source (2)
...@@ -57,7 +57,7 @@ class GenericCpu(Platform): ...@@ -57,7 +57,7 @@ class GenericCpu(Platform):
def select_function(self, call: PsCall) -> PsExpression: def select_function(self, call: PsCall) -> PsExpression:
assert isinstance(call.function, PsMathFunction) assert isinstance(call.function, PsMathFunction)
func = call.function.func func = call.function.func
dtype = call.get_dtype() dtype = call.get_dtype()
arg_types = (dtype,) * func.num_args arg_types = (dtype,) * func.num_args
...@@ -87,13 +87,13 @@ class GenericCpu(Platform): ...@@ -87,13 +87,13 @@ class GenericCpu(Platform):
call.function = cfunc call.function = cfunc
return call return call
if isinstance(dtype, PsIntegerType): if isinstance(dtype, PsIntegerType):
match func: match func:
case MathFunctions.Abs: case MathFunctions.Abs:
zero = PsExpression.make(PsConstant(0, dtype)) zero = PsExpression.make(PsConstant(0, dtype))
arg = call.args[0] arg = call.args[0]
return PsTernary(PsGe(arg, zero), arg, - arg) return PsTernary(PsGe(arg, zero), arg, -arg)
case MathFunctions.Min: case MathFunctions.Min:
arg1, arg2 = call.args arg1, arg2 = call.args
return PsTernary(PsLe(arg1, arg2), arg1, arg2) return PsTernary(PsLe(arg1, arg2), arg1, arg2)
...@@ -131,7 +131,10 @@ class GenericCpu(Platform): ...@@ -131,7 +131,10 @@ class GenericCpu(Platform):
PsLookup( PsLookup(
PsBufferAcc( PsBufferAcc(
ispace.index_list.base_pointer, ispace.index_list.base_pointer,
(PsExpression.make(ispace.sparse_counter), factory.parse_index(0)), (
PsExpression.make(ispace.sparse_counter),
factory.parse_index(0),
),
), ),
coord.name, coord.name,
), ),
...@@ -158,26 +161,33 @@ class GenericVectorCpu(GenericCpu, ABC): ...@@ -158,26 +161,33 @@ class GenericVectorCpu(GenericCpu, ABC):
@abstractmethod @abstractmethod
def type_intrinsic(self, vector_type: PsVectorType) -> PsCustomType: def type_intrinsic(self, vector_type: PsVectorType) -> PsCustomType:
"""Return the intrinsic vector type for the given generic vector type, """Return the intrinsic vector type for the given generic vector type,
or raise an `MaterializationError` if type is not supported.""" or raise a `MaterializationError` if type is not supported."""
@abstractmethod @abstractmethod
def constant_intrinsic(self, c: PsConstant) -> PsExpression: def constant_intrinsic(self, c: PsConstant) -> PsExpression:
"""Return an expression that initializes a constant vector, """Return an expression that initializes a constant vector,
or raise an `MaterializationError` if not supported.""" or raise a `MaterializationError` if not supported."""
@abstractmethod @abstractmethod
def op_intrinsic( def op_intrinsic(
self, expr: PsExpression, operands: Sequence[PsExpression] self, expr: PsExpression, operands: Sequence[PsExpression]
) -> PsExpression: ) -> PsExpression:
"""Return an expression intrinsically invoking the given operation """Return an expression intrinsically invoking the given operation
or raise an `MaterializationError` if not supported.""" or raise a `MaterializationError` if not supported."""
@abstractmethod
def math_func_intrinsic(
self, expr: PsCall, operands: Sequence[PsExpression]
) -> PsExpression:
"""Return an expression intrinsically invoking the given mathematical
function or raise a `MaterializationError` if not supported."""
@abstractmethod @abstractmethod
def vector_load(self, acc: PsVecMemAcc) -> PsExpression: def vector_load(self, acc: PsVecMemAcc) -> PsExpression:
"""Return an expression intrinsically performing a vector load, """Return an expression intrinsically performing a vector load,
or raise an `MaterializationError` if not supported.""" or raise a `MaterializationError` if not supported."""
@abstractmethod @abstractmethod
def vector_store(self, acc: PsVecMemAcc, arg: PsExpression) -> PsExpression: def vector_store(self, acc: PsVecMemAcc, arg: PsExpression) -> PsExpression:
"""Return an expression intrinsically performing a vector store, """Return an expression intrinsically performing a vector store,
or raise an `MaterializationError` if not supported.""" or raise a `MaterializationError` if not supported."""
...@@ -13,7 +13,9 @@ from ..ast.expressions import ( ...@@ -13,7 +13,9 @@ from ..ast.expressions import (
PsSub, PsSub,
PsMul, PsMul,
PsDiv, PsDiv,
PsConstantExpr PsConstantExpr,
PsCast,
PsCall,
) )
from ..ast.vector import PsVecMemAcc, PsVecBroadcast from ..ast.vector import PsVecMemAcc, PsVecBroadcast
from ...types import PsCustomType, PsVectorType, PsPointerType from ...types import PsCustomType, PsVectorType, PsPointerType
...@@ -23,15 +25,15 @@ from ..exceptions import MaterializationError ...@@ -23,15 +25,15 @@ from ..exceptions import MaterializationError
from .generic_cpu import GenericVectorCpu from .generic_cpu import GenericVectorCpu
from ..kernelcreation import KernelCreationContext from ..kernelcreation import KernelCreationContext
from ...types.quick import Fp, SInt from ...types.quick import Fp, UInt, SInt
from ..functions import CFunction from ..functions import CFunction, PsMathFunction, MathFunctions
class X86VectorArch(Enum): class X86VectorArch(Enum):
SSE = 128 SSE = 128
AVX = 256 AVX = 256
AVX512 = 512 AVX512 = 512
AVX512_FP16 = AVX512 + 1 # TODO improve modelling? AVX512_FP16 = AVX512 + 1 # TODO improve modelling?
def __ge__(self, other: X86VectorArch) -> bool: def __ge__(self, other: X86VectorArch) -> bool:
return self.value >= other.value return self.value >= other.value
...@@ -78,7 +80,7 @@ class X86VectorArch(Enum): ...@@ -78,7 +80,7 @@ class X86VectorArch(Enum):
) )
return suffix return suffix
def intrin_type(self, vtype: PsVectorType): def intrin_type(self, vtype: PsVectorType):
scalar_type = vtype.scalar_type scalar_type = vtype.scalar_type
match scalar_type: match scalar_type:
...@@ -96,9 +98,7 @@ class X86VectorArch(Enum): ...@@ -96,9 +98,7 @@ class X86VectorArch(Enum):
) )
if vtype.width > self.max_vector_width: if vtype.width > self.max_vector_width:
raise MaterializationError( raise MaterializationError(f"x86/{self} does not support {vtype}")
f"x86/{self} does not support {vtype}"
)
return PsCustomType(f"__m{vtype.width}{suffix}") return PsCustomType(f"__m{vtype.width}{suffix}")
...@@ -161,26 +161,124 @@ class X86VectorCpu(GenericVectorCpu): ...@@ -161,26 +161,124 @@ class X86VectorCpu(GenericVectorCpu):
match expr: match expr:
case PsUnOp() | PsBinOp(): case PsUnOp() | PsBinOp():
func = _x86_op_intrin(self._vector_arch, expr, expr.get_dtype()) func = _x86_op_intrin(self._vector_arch, expr, expr.get_dtype())
return func(*operands) intrinsic = func(*operands)
intrinsic.dtype = func.return_type
return intrinsic
case _: case _:
raise MaterializationError(f"Cannot map {type(expr)} to x86 intrinsic") raise MaterializationError(f"Cannot map {type(expr)} to x86 intrinsic")
def math_func_intrinsic(
self, expr: PsCall, operands: Sequence[PsExpression]
) -> PsExpression:
assert isinstance(expr.function, PsMathFunction)
vtype = expr.get_dtype()
assert isinstance(vtype, PsVectorType)
prefix = self._vector_arch.intrin_prefix(vtype)
suffix = self._vector_arch.intrin_suffix(vtype)
rtype = atype = self._vector_arch.intrin_type(vtype)
match expr.function.func:
case (
MathFunctions.Exp
| MathFunctions.Log
| MathFunctions.Sin
| MathFunctions.Cos
| MathFunctions.Tan
| MathFunctions.Sinh
| MathFunctions.Cosh
| MathFunctions.ASin
| MathFunctions.ACos
| MathFunctions.ATan
| MathFunctions.ATan2
| MathFunctions.Pow
):
raise MaterializationError(
"Trigonometry, exp, log, and pow require SVML."
)
case MathFunctions.Floor | MathFunctions.Ceil if vtype.is_float():
opstr = expr.function.func.function_name
if vtype.width > 256:
raise MaterializationError("512bit ceil/floor require SVML.")
case MathFunctions.Min | MathFunctions.Max:
opstr = expr.function.func.function_name
if (
vtype.is_int()
and vtype.scalar_type.width == 64
and self._vector_arch < X86VectorArch.AVX512
):
raise MaterializationError(
"64bit integer (signed and unsigned) min/max intrinsics require AVX512."
)
case MathFunctions.Abs:
assert len(operands) == 1, "abs takes exactly one argument."
op = operands[0]
match vtype.scalar_type:
case UInt():
return op
case SInt(width):
opstr = expr.function.func.function_name
if width == 64 and self._vector_arch < X86VectorArch.AVX512:
raise MaterializationError(
"64bit integer abs intrinsic requires AVX512."
)
case Fp():
neg_zero = self.constant_intrinsic(PsConstant(-0.0, vtype))
opstr = "andnot"
func = CFunction(
f"{prefix}_{opstr}_{suffix}", (atype,) * 2, rtype
)
return func(neg_zero, op)
case _:
raise MaterializationError(
f"x86/{self} does not support {expr.function.func.function_name} on type {vtype}."
)
if expr.function.func in [
MathFunctions.ATan2,
MathFunctions.Min,
MathFunctions.Max,
]:
num_args = 2
else:
num_args = 1
func = CFunction(f"{prefix}_{opstr}_{suffix}", (atype,) * num_args, rtype)
return func(*operands)
def vector_load(self, acc: PsVecMemAcc) -> PsExpression: def vector_load(self, acc: PsVecMemAcc) -> PsExpression:
if acc.stride is None: if acc.stride is None:
load_func = _x86_packed_load(self._vector_arch, acc.dtype, False) load_func, addr_type = _x86_packed_load(self._vector_arch, acc.dtype, False)
return load_func( addr: PsExpression = PsAddressOf(PsMemAcc(acc.pointer, acc.offset))
PsAddressOf(PsMemAcc(acc.pointer, acc.offset)) if addr_type:
) addr = PsCast(addr_type, addr)
intrinsic = load_func(addr)
intrinsic.dtype = load_func.return_type
return intrinsic
else: else:
raise NotImplementedError("Gather loads not implemented yet.") raise NotImplementedError("Gather loads not implemented yet.")
def vector_store(self, acc: PsVecMemAcc, arg: PsExpression) -> PsExpression: def vector_store(self, acc: PsVecMemAcc, arg: PsExpression) -> PsExpression:
if acc.stride is None: if acc.stride is None:
store_func = _x86_packed_store(self._vector_arch, acc.dtype, False) store_func, addr_type = _x86_packed_store(
return store_func( self._vector_arch, acc.dtype, False
PsAddressOf(PsMemAcc(acc.pointer, acc.offset)),
arg,
) )
addr: PsExpression = PsAddressOf(PsMemAcc(acc.pointer, acc.offset))
if addr_type:
addr = PsCast(addr_type, addr)
intrinsic = store_func(addr, arg)
intrinsic.dtype = store_func.return_type
return intrinsic
else: else:
raise NotImplementedError("Scatter stores not implemented yet.") raise NotImplementedError("Scatter stores not implemented yet.")
...@@ -188,26 +286,46 @@ class X86VectorCpu(GenericVectorCpu): ...@@ -188,26 +286,46 @@ class X86VectorCpu(GenericVectorCpu):
@cache @cache
def _x86_packed_load( def _x86_packed_load(
varch: X86VectorArch, vtype: PsVectorType, aligned: bool varch: X86VectorArch, vtype: PsVectorType, aligned: bool
) -> CFunction: ) -> tuple[CFunction, PsPointerType | None]:
prefix = varch.intrin_prefix(vtype) prefix = varch.intrin_prefix(vtype)
suffix = varch.intrin_suffix(vtype)
ptr_type = PsPointerType(vtype.scalar_type, const=True) ptr_type = PsPointerType(vtype.scalar_type, const=True)
return CFunction(
f"{prefix}_load{'' if aligned else 'u'}_{suffix}", (ptr_type,), vtype if isinstance(vtype.scalar_type, SInt):
suffix = f"si{vtype.width}"
addr_type = PsPointerType(varch.intrin_type(vtype))
else:
suffix = varch.intrin_suffix(vtype)
addr_type = None
return (
CFunction(
f"{prefix}_load{'' if aligned else 'u'}_{suffix}", (ptr_type,), vtype
),
addr_type,
) )
@cache @cache
def _x86_packed_store( def _x86_packed_store(
varch: X86VectorArch, vtype: PsVectorType, aligned: bool varch: X86VectorArch, vtype: PsVectorType, aligned: bool
) -> CFunction: ) -> tuple[CFunction, PsPointerType | None]:
prefix = varch.intrin_prefix(vtype) prefix = varch.intrin_prefix(vtype)
suffix = varch.intrin_suffix(vtype)
ptr_type = PsPointerType(vtype.scalar_type, const=True) ptr_type = PsPointerType(vtype.scalar_type, const=True)
return CFunction(
f"{prefix}_store{'' if aligned else 'u'}_{suffix}", if isinstance(vtype.scalar_type, SInt):
(ptr_type, vtype), suffix = f"si{vtype.width}"
PsCustomType("void"), addr_type = PsPointerType(varch.intrin_type(vtype))
else:
suffix = varch.intrin_suffix(vtype)
addr_type = None
return (
CFunction(
f"{prefix}_store{'' if aligned else 'u'}_{suffix}",
(ptr_type, vtype),
PsCustomType("void"),
),
addr_type,
) )
...@@ -223,7 +341,7 @@ def _x86_op_intrin( ...@@ -223,7 +341,7 @@ def _x86_op_intrin(
case PsVecBroadcast(): case PsVecBroadcast():
opstr = "set1" opstr = "set1"
if vtype.scalar_type == SInt(64) and vtype.vector_entries <= 4: if vtype.scalar_type == SInt(64) and vtype.vector_entries <= 4:
suffix += "x" suffix += "x"
atype = vtype.scalar_type atype = vtype.scalar_type
case PsAdd(): case PsAdd():
opstr = "add" opstr = "add"
...@@ -239,8 +357,52 @@ def _x86_op_intrin( ...@@ -239,8 +357,52 @@ def _x86_op_intrin(
opstr = "mul" opstr = "mul"
case PsDiv(): case PsDiv():
opstr = "div" opstr = "div"
case PsCast(target_type, arg):
atype = arg.dtype
widest_type = max(vtype, atype, key=lambda t: t.width)
assert target_type == vtype, "type mismatch"
assert isinstance(atype, PsVectorType)
def panic(detail: str = ""):
raise MaterializationError(
f"Unable to select intrinsic for type conversion: "
f"{varch.name} does not support packed conversion from {atype} to {target_type}. {detail}\n"
f" at: {op}"
)
if atype == vtype:
panic("Use `EliminateConstants` to eliminate trivial casts.")
match (atype.scalar_type, vtype.scalar_type):
# Not supported: cvtepi8_pX, cvtpX_epi8, cvtepi16_p[sd], cvtp[sd]_epi16
case (
(SInt(8), Fp())
| (Fp(), SInt(8))
| (SInt(16), Fp(32))
| (SInt(16), Fp(64))
| (Fp(32), SInt(16))
| (Fp(64), SInt(16))
):
panic()
# AVX512 only: cvtepi64_pX, cvtpX_epi64
case (SInt(64), Fp()) | (
Fp(),
SInt(64),
) if varch < X86VectorArch.AVX512:
panic()
# AVX512 only: cvtepiA_epiT if A > T
case (SInt(a), SInt(t)) if a > t and varch < X86VectorArch.AVX512:
panic()
case _:
prefix = varch.intrin_prefix(widest_type)
opstr = f"cvt{varch.intrin_suffix(atype)}"
case _: case _:
raise MaterializationError(f"Unable to select operation intrinsic for {type(op)}") raise MaterializationError(
f"Unable to select operation intrinsic for {type(op)}"
)
num_args = 1 if isinstance(op, PsUnOp) else 2 num_args = 1 if isinstance(op, PsUnOp) else 2
return CFunction(f"{prefix}_{opstr}_{suffix}", (atype,) * num_args, rtype) return CFunction(f"{prefix}_{opstr}_{suffix}", (atype,) * num_args, rtype)
from __future__ import annotations from __future__ import annotations
from typing import overload from textwrap import indent
from typing import cast, overload
from dataclasses import dataclass from dataclasses import dataclass
...@@ -11,7 +12,13 @@ from ..constants import PsConstant ...@@ -11,7 +12,13 @@ from ..constants import PsConstant
from ..functions import PsMathFunction from ..functions import PsMathFunction
from ..ast import PsAstNode from ..ast import PsAstNode
from ..ast.structural import PsBlock, PsDeclaration, PsAssignment from ..ast.structural import (
PsBlock,
PsDeclaration,
PsAssignment,
PsLoop,
PsEmptyLeafMixIn,
)
from ..ast.expressions import ( from ..ast.expressions import (
PsExpression, PsExpression,
PsAddressOf, PsAddressOf,
...@@ -24,6 +31,7 @@ from ..ast.expressions import ( ...@@ -24,6 +31,7 @@ from ..ast.expressions import (
PsCall, PsCall,
PsMemAcc, PsMemAcc,
PsBufferAcc, PsBufferAcc,
PsSubscript,
PsAdd, PsAdd,
PsMul, PsMul,
PsSub, PsSub,
...@@ -148,6 +156,10 @@ class VectorizationContext: ...@@ -148,6 +156,10 @@ class VectorizationContext:
) )
return PsVectorType(scalar_type, self._lanes) return PsVectorType(scalar_type, self._lanes)
def axis_ctr_dependees(self, symbols: set[PsSymbol]) -> set[PsSymbol]:
"""Returns all symbols in `symbols` that depend on the axis counter."""
return symbols & (self.vectorized_symbols.keys() | {self.axis.counter})
@dataclass @dataclass
class Affine: class Affine:
...@@ -246,7 +258,7 @@ class AstVectorizer: ...@@ -246,7 +258,7 @@ class AstVectorizer:
def __call__(self, node: PsAstNode, vc: VectorizationContext) -> PsAstNode: def __call__(self, node: PsAstNode, vc: VectorizationContext) -> PsAstNode:
"""Perform subtree vectorization. """Perform subtree vectorization.
Args: Args:
node: Root of the subtree that should be vectorized node: Root of the subtree that should be vectorized
vc: Object describing the current vectorization context vc: Object describing the current vectorization context
...@@ -273,6 +285,14 @@ class AstVectorizer: ...@@ -273,6 +285,14 @@ class AstVectorizer:
return PsDeclaration(vec_lhs, vec_rhs) return PsDeclaration(vec_lhs, vec_rhs)
case PsAssignment(lhs, rhs): case PsAssignment(lhs, rhs):
if (
isinstance(lhs, PsSymbolExpr)
and lhs.symbol in vc.vectorized_symbols
):
return PsAssignment(
self.visit_expr(lhs, vc), self.visit_expr(rhs, vc)
)
if not isinstance(lhs, (PsMemAcc, PsBufferAcc)): if not isinstance(lhs, (PsMemAcc, PsBufferAcc)):
raise VectorizationError(f"Unable to vectorize assignment to {lhs}") raise VectorizationError(f"Unable to vectorize assignment to {lhs}")
...@@ -286,6 +306,29 @@ class AstVectorizer: ...@@ -286,6 +306,29 @@ class AstVectorizer:
rhs_vec = self.visit_expr(rhs, vc) rhs_vec = self.visit_expr(rhs, vc)
return PsAssignment(lhs_vec, rhs_vec) return PsAssignment(lhs_vec, rhs_vec)
case PsLoop(counter, start, stop, step, body):
# Check that loop bounds are lane-invariant
free_symbols = (
self._collect_symbols(start)
| self._collect_symbols(stop)
| self._collect_symbols(step)
)
vec_dependencies = vc.axis_ctr_dependees(free_symbols)
if vec_dependencies:
raise VectorizationError(
"Unable to vectorize loop depending on vectorized symbols:\n"
f" Offending dependencies:\n"
f" {vec_dependencies}\n"
f" Found in loop:\n"
f"{indent(str(node), ' ')}"
)
vectorized_body = cast(PsBlock, self.visit(body, vc))
return PsLoop(counter, start, stop, step, vectorized_body)
case PsEmptyLeafMixIn():
return node
case _: case _:
raise NotImplementedError(f"Vectorization of {node} is not implemented") raise NotImplementedError(f"Vectorization of {node} is not implemented")
...@@ -426,6 +469,23 @@ class AstVectorizer: ...@@ -426,6 +469,23 @@ class AstVectorizer:
# Buffer access is lane-invariant # Buffer access is lane-invariant
vec_expr = PsVecBroadcast(vc.lanes, expr.clone()) vec_expr = PsVecBroadcast(vc.lanes, expr.clone())
case PsSubscript(array, index):
# Check that array expression and indices are lane-invariant
free_symbols = self._collect_symbols(array).union(
*[self._collect_symbols(i) for i in index]
)
vec_dependencies = vc.axis_ctr_dependees(free_symbols)
if vec_dependencies:
raise VectorizationError(
"Unable to vectorize array subscript depending on vectorized symbols:\n"
f" Offending dependencies:\n"
f" {vec_dependencies}\n"
f" Found in expression:\n"
f"{indent(str(expr), ' ')}"
)
vec_expr = PsVecBroadcast(vc.lanes, expr.clone())
case _: case _:
raise NotImplementedError( raise NotImplementedError(
f"Vectorization of {type(expr)} is not implemented" f"Vectorization of {type(expr)} is not implemented"
......
...@@ -40,13 +40,7 @@ from ..ast.util import AstEqWrapper ...@@ -40,13 +40,7 @@ from ..ast.util import AstEqWrapper
from ..constants import PsConstant from ..constants import PsConstant
from ..memory import PsSymbol from ..memory import PsSymbol
from ..functions import PsMathFunction from ..functions import PsMathFunction
from ...types import ( from ...types import PsNumericType, PsBoolType, PsScalarType, PsVectorType, constify
PsNumericType,
PsBoolType,
PsScalarType,
PsVectorType,
constify
)
__all__ = ["EliminateConstants"] __all__ = ["EliminateConstants"]
...@@ -261,6 +255,9 @@ class EliminateConstants: ...@@ -261,6 +255,9 @@ class EliminateConstants:
assert isinstance(target_type, PsNumericType) assert isinstance(target_type, PsNumericType)
return PsConstantExpr(c.reinterpret_as(target_type)), True return PsConstantExpr(c.reinterpret_as(target_type)), True
case PsCast(target_type, op) if target_type == op.get_dtype():
return op, all(subtree_constness)
case PsVecBroadcast(lanes, PsConstantExpr(c)): case PsVecBroadcast(lanes, PsConstantExpr(c)):
scalar_type = c.get_dtype() scalar_type = c.get_dtype()
assert isinstance(scalar_type, PsScalarType) assert isinstance(scalar_type, PsScalarType)
...@@ -358,7 +355,10 @@ class EliminateConstants: ...@@ -358,7 +355,10 @@ class EliminateConstants:
from ...utils import c_intdiv from ...utils import c_intdiv
folded = PsConstant(c_intdiv(v1, v2), dtype) folded = PsConstant(c_intdiv(v1, v2), dtype)
elif isinstance(dtype, PsNumericType) and dtype.is_float(): elif (
isinstance(dtype, PsNumericType)
and dtype.is_float()
):
folded = PsConstant(v1 / v2, dtype) folded = PsConstant(v1 / v2, dtype)
if folded is not None: if folded is not None:
......
...@@ -4,11 +4,12 @@ from typing import cast ...@@ -4,11 +4,12 @@ from typing import cast
from ..kernelcreation import KernelCreationContext from ..kernelcreation import KernelCreationContext
from ..memory import PsSymbol from ..memory import PsSymbol
from ..ast.structural import PsAstNode, PsDeclaration, PsAssignment, PsStatement from ..ast.structural import PsAstNode, PsDeclaration, PsAssignment, PsStatement
from ..ast.expressions import PsExpression from ..ast.expressions import PsExpression, PsCall, PsCast, PsLiteral
from ...types import PsVectorType, constify, deconstify from ...types import PsCustomType, PsVectorType, constify, deconstify
from ..ast.expressions import PsSymbolExpr, PsConstantExpr, PsUnOp, PsBinOp from ..ast.expressions import PsSymbolExpr, PsConstantExpr, PsUnOp, PsBinOp
from ..ast.vector import PsVecMemAcc from ..ast.vector import PsVecMemAcc
from ..exceptions import MaterializationError from ..exceptions import MaterializationError
from ..functions import CFunction, PsMathFunction
from ..platforms import GenericVectorCpu from ..platforms import GenericVectorCpu
...@@ -39,22 +40,34 @@ class SelectionContext: ...@@ -39,22 +40,34 @@ class SelectionContext:
class SelectIntrinsics: class SelectIntrinsics:
"""Lower IR vector types to intrinsic vector types, and IR vector operations to intrinsic vector operations. """Lower IR vector types to intrinsic vector types, and IR vector operations to intrinsic vector operations.
This transformation will replace all vectorial IR elements by conforming implementations using This transformation will replace all vectorial IR elements by conforming implementations using
compiler intrinsics for the given execution platform. compiler intrinsics for the given execution platform.
Args: Args:
ctx: The current kernel creation context ctx: The current kernel creation context
platform: Platform object representing the target hardware, which provides the intrinsics platform: Platform object representing the target hardware, which provides the intrinsics
use_builtin_convertvector: If `True`, type conversions between SIMD
vectors use the compiler builtin ``__builtin_convertvector``
instead of instrinsics. It is supported by Clang >= 3.7, GCC >= 9.1,
and ICX. Not supported by ICC or MSVC. Activate if you need type
conversions not natively supported by your CPU, e.g. conversion from
64bit integer to double on an x86 AVX machine. Defaults to `False`.
Raises: Raises:
MaterializationError: If a vector type or operation cannot be represented by intrinsics MaterializationError: If a vector type or operation cannot be represented by intrinsics
on the given platform on the given platform
""" """
def __init__(self, ctx: KernelCreationContext, platform: GenericVectorCpu): def __init__(
self,
ctx: KernelCreationContext,
platform: GenericVectorCpu,
use_builtin_convertvector: bool = False,
):
self._ctx = ctx self._ctx = ctx
self._platform = platform self._platform = platform
self._use_builtin_convertvector = use_builtin_convertvector
def __call__(self, node: PsAstNode) -> PsAstNode: def __call__(self, node: PsAstNode) -> PsAstNode:
return self.visit(node, SelectionContext(self._ctx, self._platform)) return self.visit(node, SelectionContext(self._ctx, self._platform))
...@@ -68,11 +81,11 @@ class SelectIntrinsics: ...@@ -68,11 +81,11 @@ class SelectIntrinsics:
lhs_new = cast(PsSymbolExpr, self.visit_expr(lhs, sc)) lhs_new = cast(PsSymbolExpr, self.visit_expr(lhs, sc))
rhs_new = self.visit_expr(rhs, sc) rhs_new = self.visit_expr(rhs, sc)
return PsDeclaration(lhs_new, rhs_new) return PsDeclaration(lhs_new, rhs_new)
case PsAssignment(lhs, rhs) if isinstance(lhs, PsVecMemAcc): case PsAssignment(lhs, rhs) if isinstance(lhs, PsVecMemAcc):
new_rhs = self.visit_expr(rhs, sc) new_rhs = self.visit_expr(rhs, sc)
return PsStatement(self._platform.vector_store(lhs, new_rhs)) return PsStatement(self._platform.vector_store(lhs, new_rhs))
case _: case _:
node.children = [self.visit(c, sc) for c in node.children] node.children = [self.visit(c, sc) for c in node.children]
...@@ -89,6 +102,22 @@ class SelectIntrinsics: ...@@ -89,6 +102,22 @@ class SelectIntrinsics:
case PsConstantExpr(c): case PsConstantExpr(c):
return self._platform.constant_intrinsic(c) return self._platform.constant_intrinsic(c)
case PsCast(target_type, operand) if self._use_builtin_convertvector:
assert isinstance(target_type, PsVectorType)
op = self.visit_expr(operand, sc)
rtype = PsCustomType(
f"{target_type.scalar_type.c_string()} __attribute__((__vector_size__({target_type.itemsize})))"
)
target_type_literal = PsExpression.make(PsLiteral(rtype.name, rtype))
func = CFunction(
"__builtin_convertvector", (op.get_dtype(), rtype), target_type
)
intrinsic = func(op, target_type_literal)
intrinsic.dtype = func.return_type
return intrinsic
case PsUnOp(operand): case PsUnOp(operand):
op = self.visit_expr(operand, sc) op = self.visit_expr(operand, sc)
return self._platform.op_intrinsic(expr, [op]) return self._platform.op_intrinsic(expr, [op])
...@@ -102,6 +131,10 @@ class SelectIntrinsics: ...@@ -102,6 +131,10 @@ class SelectIntrinsics:
case PsVecMemAcc(): case PsVecMemAcc():
return self._platform.vector_load(expr) return self._platform.vector_load(expr)
case PsCall(function, args) if isinstance(function, PsMathFunction):
arguments = [self.visit_expr(a, sc) for a in args]
return self._platform.math_func_intrinsic(expr, arguments)
case _: case _:
raise MaterializationError( raise MaterializationError(
f"Unable to select intrinsic implementation for {expr}" f"Unable to select intrinsic implementation for {expr}"
......
...@@ -115,6 +115,18 @@ class Target(Flag): ...@@ -115,6 +115,18 @@ class Target(Flag):
return avail_targets.pop() return avail_targets.pop()
else: else:
return Target.GenericCPU return Target.GenericCPU
@staticmethod
def available_targets() -> list[Target]:
targets = [Target.GenericCPU]
try:
import cupy # noqa: F401
targets.append(Target.CUDA)
except ImportError:
pass
targets += Target.available_vector_cpu_targets()
return targets
@staticmethod @staticmethod
def available_vector_cpu_targets() -> list[Target]: def available_vector_cpu_targets() -> list[Target]:
......
...@@ -31,11 +31,13 @@ except ImportError: ...@@ -31,11 +31,13 @@ except ImportError:
AVAILABLE_TARGETS += ps.Target.available_vector_cpu_targets() AVAILABLE_TARGETS += ps.Target.available_vector_cpu_targets()
TARGET_IDS = [t.name for t in AVAILABLE_TARGETS] TARGET_IDS = [t.name for t in AVAILABLE_TARGETS]
@pytest.fixture(params=AVAILABLE_TARGETS, ids=TARGET_IDS) @pytest.fixture(params=AVAILABLE_TARGETS, ids=TARGET_IDS)
def target(request) -> ps.Target: def target(request) -> ps.Target:
"""Provides all code generation targets available on the current hardware""" """Provides all code generation targets available on the current hardware"""
return request.param return request.param
@pytest.fixture @pytest.fixture
def gen_config(target: ps.Target): def gen_config(target: ps.Target):
"""Default codegen configuration for the current target. """Default codegen configuration for the current target.
...@@ -56,6 +58,7 @@ def gen_config(target: ps.Target): ...@@ -56,6 +58,7 @@ def gen_config(target: ps.Target):
return gen_config return gen_config
@pytest.fixture() @pytest.fixture()
def xp(target: ps.Target) -> ModuleType: def xp(target: ps.Target) -> ModuleType:
"""Primary array module for the current target. """Primary array module for the current target.
......
import sympy as sp import sympy as sp
import numpy as np import numpy as np
import pytest import pytest
from dataclasses import replace
from pystencils import fields, create_kernel, CreateKernelConfig, Target, Assignment, Field from itertools import product
from pystencils import (
fields,
create_kernel,
CreateKernelConfig,
Target,
Assignment,
Field,
)
from pystencils.backend.ast import dfs_preorder from pystencils.backend.ast import dfs_preorder
from pystencils.backend.ast.expressions import PsCall from pystencils.backend.ast.expressions import PsCall
...@@ -34,38 +43,42 @@ def binary_function(name, xp): ...@@ -34,38 +43,42 @@ def binary_function(name, xp):
}[name] }[name]
@pytest.mark.parametrize("target", (Target.GenericCPU, Target.CUDA)) AVAIL_TARGETS = Target.available_targets()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"function_name", "function_name, target",
( list(
"exp", product(
"log", (
"sin", "exp",
"cos", "log",
"tan", "sin",
"sinh", "cos",
"cosh", "tan",
"asin", "sinh",
"acos", "cosh",
"atan", "asin",
"abs", "acos",
"floor", "atan",
"ceil", ),
), [t for t in AVAIL_TARGETS if Target._X86 not in t],
)
)
+ list(
product(
["floor", "ceil"], [t for t in AVAIL_TARGETS if Target._AVX512 not in t]
)
)
+ list(product(["abs"], AVAIL_TARGETS)),
) )
@pytest.mark.parametrize("dtype", (np.float32, np.float64)) @pytest.mark.parametrize("dtype", (np.float32, np.float64))
def test_unary_functions(target, function_name, dtype): def test_unary_functions(gen_config, xp, function_name, dtype):
if target == Target.CUDA:
xp = pytest.importorskip("cupy")
else:
xp = np
sp_func, xp_func = unary_function(function_name, xp) sp_func, xp_func = unary_function(function_name, xp)
resolution = np.finfo(dtype).resolution resolution = np.finfo(dtype).resolution
inp = xp.array( # Array size should be larger than eight, such that vectorized kernels don't just run their remainder loop
[[0.1, 0.2, 0.3], [-0.8, -1.6, -12.592], [xp.pi, xp.e, 0.0]], dtype=dtype inp = xp.array([0.1, 0.2, 0.0, -0.8, -1.6, -12.592, xp.pi, xp.e, -0.3], dtype=dtype)
)
outp = xp.zeros_like(inp) outp = xp.zeros_like(inp)
reference = xp_func(inp) reference = xp_func(inp)
...@@ -74,7 +87,7 @@ def test_unary_functions(target, function_name, dtype): ...@@ -74,7 +87,7 @@ def test_unary_functions(target, function_name, dtype):
outp_field = inp_field.new_field_with_different_name("outp") outp_field = inp_field.new_field_with_different_name("outp")
asms = [Assignment(outp_field.center(), sp_func(inp_field.center()))] asms = [Assignment(outp_field.center(), sp_func(inp_field.center()))]
gen_config = CreateKernelConfig(target=target, default_dtype=dtype) gen_config = replace(gen_config, default_dtype=dtype)
kernel = create_kernel(asms, gen_config) kernel = create_kernel(asms, gen_config)
kfunc = kernel.compile() kfunc = kernel.compile()
...@@ -83,28 +96,26 @@ def test_unary_functions(target, function_name, dtype): ...@@ -83,28 +96,26 @@ def test_unary_functions(target, function_name, dtype):
xp.testing.assert_allclose(outp, reference, rtol=resolution) xp.testing.assert_allclose(outp, reference, rtol=resolution)
@pytest.mark.parametrize("target", (Target.GenericCPU, Target.CUDA)) @pytest.mark.parametrize(
@pytest.mark.parametrize("function_name", ("min", "max", "pow", "atan2")) "function_name,target",
list(product(["min", "max"], AVAIL_TARGETS))
+ list(
product(["pow", "atan2"], [t for t in AVAIL_TARGETS if Target._X86 not in t])
),
)
@pytest.mark.parametrize("dtype", (np.float32, np.float64)) @pytest.mark.parametrize("dtype", (np.float32, np.float64))
def test_binary_functions(target, function_name, dtype): def test_binary_functions(gen_config, xp, function_name, dtype):
if target == Target.CUDA: sp_func, xp_func = binary_function(function_name, xp)
xp = pytest.importorskip("cupy")
else:
xp = np
sp_func, np_func = binary_function(function_name, xp)
resolution: dtype = np.finfo(dtype).resolution resolution: dtype = np.finfo(dtype).resolution
inp = xp.array( inp = xp.array([0.1, 0.2, 0.3, -0.8, -1.6, -12.592, xp.pi, xp.e, 0.0], dtype=dtype)
[[0.1, 0.2, 0.3], [-0.8, -1.6, -12.592], [xp.pi, xp.e, 0.0]], dtype=dtype
)
inp2 = xp.array( inp2 = xp.array(
[[3.1, -0.5, 21.409], [11.0, 1.0, -14e3], [2.0 * xp.pi, -xp.e, 0.0]], [3.1, -0.5, 21.409, 11.0, 1.0, -14e3, 2.0 * xp.pi, -xp.e, 0.0],
dtype=dtype, dtype=dtype,
) )
outp = xp.zeros_like(inp) outp = xp.zeros_like(inp)
reference = np_func(inp, inp2) reference = xp_func(inp, inp2)
inp_field = Field.create_from_numpy_array("inp", inp) inp_field = Field.create_from_numpy_array("inp", inp)
inp2_field = Field.create_from_numpy_array("inp2", inp) inp2_field = Field.create_from_numpy_array("inp2", inp)
...@@ -115,7 +126,7 @@ def test_binary_functions(target, function_name, dtype): ...@@ -115,7 +126,7 @@ def test_binary_functions(target, function_name, dtype):
outp_field.center(), sp_func(inp_field.center(), inp2_field.center()) outp_field.center(), sp_func(inp_field.center(), inp2_field.center())
) )
] ]
gen_config = CreateKernelConfig(target=target, default_dtype=dtype) gen_config = replace(gen_config, default_dtype=dtype)
kernel = create_kernel(asms, gen_config) kernel = create_kernel(asms, gen_config)
kfunc = kernel.compile() kfunc = kernel.compile()
...@@ -124,26 +135,107 @@ def test_binary_functions(target, function_name, dtype): ...@@ -124,26 +135,107 @@ def test_binary_functions(target, function_name, dtype):
xp.testing.assert_allclose(outp, reference, rtol=resolution) xp.testing.assert_allclose(outp, reference, rtol=resolution)
@pytest.mark.parametrize('a', [sp.Symbol('a'), fields('a: float64[2d]').center]) dtype_and_target_for_integer_funcs = pytest.mark.parametrize(
"dtype, target",
list(product([np.int32], [t for t in AVAIL_TARGETS if t is not Target.CUDA]))
+ list(
product(
[np.int64],
[
t
for t in AVAIL_TARGETS
if t not in (Target.X86_SSE, Target.X86_AVX, Target.CUDA)
],
)
),
)
@dtype_and_target_for_integer_funcs
def test_integer_abs(gen_config, xp, dtype):
sp_func, xp_func = unary_function("abs", xp)
smallest = np.iinfo(dtype).min
largest = np.iinfo(dtype).max
inp = xp.array([-1, 0, 1, 3, -5, -312, smallest + 1, largest], dtype=dtype)
outp = xp.zeros_like(inp)
reference = xp_func(inp)
inp_field = Field.create_from_numpy_array("inp", inp)
outp_field = inp_field.new_field_with_different_name("outp")
asms = [Assignment(outp_field.center(), sp_func(inp_field.center()))]
gen_config = replace(gen_config, default_dtype=dtype)
kernel = create_kernel(asms, gen_config)
kfunc = kernel.compile()
kfunc(inp=inp, outp=outp)
xp.testing.assert_array_equal(outp, reference)
@pytest.mark.parametrize("function_name", ("min", "max"))
@dtype_and_target_for_integer_funcs
def test_integer_binary_functions(gen_config, xp, function_name, dtype):
sp_func, xp_func = binary_function(function_name, xp)
smallest = np.iinfo(dtype).min
largest = np.iinfo(dtype).max
inp1 = xp.array([-1, 0, 1, 3, -5, -312, smallest + 1, largest], dtype=dtype)
inp2 = xp.array([3, -5, 1, 12, 1, 11, smallest + 42, largest - 3], dtype=dtype)
outp = xp.zeros_like(inp1)
reference = xp_func(inp1, inp2)
inp_field = Field.create_from_numpy_array("inp1", inp1)
inp2_field = Field.create_from_numpy_array("inp2", inp2)
outp_field = inp_field.new_field_with_different_name("outp")
asms = [
Assignment(
outp_field.center(), sp_func(inp_field.center(), inp2_field.center())
)
]
gen_config = replace(gen_config, default_dtype=dtype)
kernel = create_kernel(asms, gen_config)
kfunc = kernel.compile()
kfunc(inp1=inp1, inp2=inp2, outp=outp)
xp.testing.assert_array_equal(outp, reference)
@pytest.mark.parametrize("a", [sp.Symbol("a"), fields("a: float64[2d]").center])
def test_avoid_pow(a): def test_avoid_pow(a):
x = fields('x: float64[2d]') x = fields("x: float64[2d]")
up = Assignment(x.center_vector[0], 2 * a ** 2 / 3) up = Assignment(x.center_vector[0], 2 * a**2 / 3)
func = create_kernel(up) func = create_kernel(up)
powers = list(dfs_preorder(func.body, lambda n: isinstance(n, PsCall) and "pow" in n.function.name)) powers = list(
dfs_preorder(
func.body, lambda n: isinstance(n, PsCall) and "pow" in n.function.name
)
)
assert not powers assert not powers
@pytest.mark.xfail(reason="fast_div not available yet") @pytest.mark.xfail(reason="fast_div not available yet")
def test_avoid_pow_fast_div(): def test_avoid_pow_fast_div():
x = fields('x: float64[2d]') x = fields("x: float64[2d]")
a = fields('a: float64[2d]').center a = fields("a: float64[2d]").center
up = Assignment(x.center_vector[0], fast_division(1, (a**2))) up = Assignment(x.center_vector[0], fast_division(1, (a**2)))
func = create_kernel(up, config=CreateKernelConfig(target=Target.GPU)) func = create_kernel(up, config=CreateKernelConfig(target=Target.GPU))
powers = list(dfs_preorder(func.body, lambda n: isinstance(n, PsCall) and "pow" in n.function.name)) powers = list(
dfs_preorder(
func.body, lambda n: isinstance(n, PsCall) and "pow" in n.function.name
)
)
assert not powers assert not powers
...@@ -151,14 +243,23 @@ def test_avoid_pow_move_constants(): ...@@ -151,14 +243,23 @@ def test_avoid_pow_move_constants():
# At the end of the kernel creation the function move_constants_before_loop will be called # At the end of the kernel creation the function move_constants_before_loop will be called
# This function additionally contains substitutions for symbols with the same value # This function additionally contains substitutions for symbols with the same value
# Thus it simplifies the equations again # Thus it simplifies the equations again
x = fields('x: float64[2d]') x = fields("x: float64[2d]")
a, b, c = sp.symbols("a, b, c") a, b, c = sp.symbols("a, b, c")
up = [Assignment(a, 0.0), up = [
Assignment(b, 0.0), Assignment(a, 0.0),
Assignment(c, 0.0), Assignment(b, 0.0),
Assignment(x.center_vector[0], a**2/18 - a*b/6 - a/18 + b**2/18 + b/18 - c**2/36)] Assignment(c, 0.0),
Assignment(
x.center_vector[0],
a**2 / 18 - a * b / 6 - a / 18 + b**2 / 18 + b / 18 - c**2 / 36,
),
]
func = create_kernel(up) func = create_kernel(up)
powers = list(dfs_preorder(func.body, lambda n: isinstance(n, PsCall) and "pow" in n.function.name)) powers = list(
dfs_preorder(
func.body, lambda n: isinstance(n, PsCall) and "pow" in n.function.name
)
)
assert not powers assert not powers
import numpy as np
import pytest
from itertools import product
from pystencils import (
create_kernel,
Target,
Assignment,
Field,
)
from pystencils.sympyextensions.typed_sympy import CastFunc
AVAIL_TARGETS_NO_SSE = [t for t in Target.available_targets() if Target._SSE not in t]
target_and_dtype = pytest.mark.parametrize(
"target, from_type, to_type",
list(
product(
[
t
for t in AVAIL_TARGETS_NO_SSE
if Target._X86 in t and Target._AVX512 not in t
],
[np.int32, np.float32, np.float64],
[np.int32, np.float32, np.float64],
)
)
+ list(
product(
[
t
for t in AVAIL_TARGETS_NO_SSE
if Target._X86 not in t or Target._AVX512 in t
],
[np.int32, np.int64, np.float32, np.float64],
[np.int32, np.int64, np.float32, np.float64],
)
),
)
@target_and_dtype
def test_type_cast(gen_config, xp, from_type, to_type):
if np.issubdtype(from_type, np.floating):
inp = xp.array([-1.25, -0, 1.5, 3, -5, -312, 42, 6.625, -9], dtype=from_type)
else:
inp = xp.array([-1, 0, 1, 3, -5, -312, 42, 6, -9], dtype=from_type)
outp = xp.zeros_like(inp).astype(to_type)
truncated = inp.astype(to_type)
rounded = xp.round(inp).astype(to_type)
inp_field = Field.create_from_numpy_array("inp", inp)
outp_field = Field.create_from_numpy_array("outp", outp)
asms = [Assignment(outp_field.center(), CastFunc(inp_field.center(), to_type))]
kernel = create_kernel(asms, gen_config)
kfunc = kernel.compile()
kfunc(inp=inp, outp=outp)
if np.issubdtype(from_type, np.floating) and not np.issubdtype(
to_type, np.floating
):
# rounding mode depends on platform
try:
xp.testing.assert_array_equal(outp, truncated)
except AssertionError:
xp.testing.assert_array_equal(outp, rounded)
else:
xp.testing.assert_array_equal(outp, truncated)
import sympy as sp import sympy as sp
import pytest import pytest
from pystencils import Assignment, TypedSymbol, fields, FieldType from pystencils import Assignment, TypedSymbol, fields, FieldType, make_slice
from pystencils.sympyextensions import CastFunc, mem_acc from pystencils.sympyextensions import CastFunc, mem_acc
from pystencils.sympyextensions.pointers import AddressOf from pystencils.sympyextensions.pointers import AddressOf
...@@ -18,19 +18,25 @@ from pystencils.backend.transformations import ( ...@@ -18,19 +18,25 @@ from pystencils.backend.transformations import (
AstVectorizer, AstVectorizer,
) )
from pystencils.backend.ast import dfs_preorder from pystencils.backend.ast import dfs_preorder
from pystencils.backend.ast.structural import PsBlock, PsDeclaration, PsAssignment from pystencils.backend.ast.structural import (
PsBlock,
PsDeclaration,
PsAssignment,
PsLoop,
)
from pystencils.backend.ast.expressions import ( from pystencils.backend.ast.expressions import (
PsSymbolExpr, PsSymbolExpr,
PsConstantExpr, PsConstantExpr,
PsExpression, PsExpression,
PsCast, PsCast,
PsMemAcc, PsMemAcc,
PsCall PsCall,
PsSubscript,
) )
from pystencils.backend.functions import CFunction from pystencils.backend.functions import CFunction
from pystencils.backend.ast.vector import PsVecBroadcast, PsVecMemAcc from pystencils.backend.ast.vector import PsVecBroadcast, PsVecMemAcc
from pystencils.backend.exceptions import VectorizationError from pystencils.backend.exceptions import VectorizationError
from pystencils.types import PsVectorType, deconstify, create_type from pystencils.types import PsArrayType, PsVectorType, deconstify, create_type
def test_vectorize_expressions(): def test_vectorize_expressions():
...@@ -56,7 +62,9 @@ def test_vectorize_expressions(): ...@@ -56,7 +62,9 @@ def test_vectorize_expressions():
factory.parse_sympy(-x * y + 13 * z - 4 * (x / w) * (x + z)), factory.parse_sympy(-x * y + 13 * z - 4 * (x / w) * (x + z)),
factory.parse_sympy(sp.sin(x + z) - sp.cos(w)), factory.parse_sympy(sp.sin(x + z) - sp.cos(w)),
factory.parse_sympy(y**2 - x**2), factory.parse_sympy(y**2 - x**2),
typify(- factory.parse_sympy(x / (w**2))), # place the negation outside, since SymPy would remove it typify(
-factory.parse_sympy(x / (w**2))
), # place the negation outside, since SymPy would remove it
factory.parse_sympy(13 + (1 / w) - sp.exp(x) * 24), factory.parse_sympy(13 + (1 / w) - sp.exp(x) * 24),
]: ]:
vec_expr = vectorize.visit(expr, vc) vec_expr = vectorize.visit(expr, vc)
...@@ -239,7 +247,31 @@ def test_reject_symbol_assignments(): ...@@ -239,7 +247,31 @@ def test_reject_symbol_assignments():
with pytest.raises(VectorizationError): with pytest.raises(VectorizationError):
_ = vectorize.visit(asm, vc) _ = vectorize.visit(asm, vc)
def test_vectorize_assignments():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
x, y = sp.symbols("x, y")
vectorize = AstVectorizer(ctx)
axis = VectorizationAxis(
ctx.get_symbol("ctr", ctx.index_dtype),
)
vc = VectorizationContext(ctx, 4, axis)
decl = PsDeclaration(factory.parse_sympy(x), factory.parse_sympy(sp.sympify(0)))
asm = PsAssignment(factory.parse_sympy(x), factory.parse_sympy(3 + y))
ast = PsBlock([decl, asm])
vec_ast = vectorize.visit(ast, vc)
vec_asm = vec_ast.statements[1]
assert isinstance(vec_asm, PsAssignment)
assert isinstance(vec_asm.lhs.symbol.dtype, PsVectorType)
def test_vectorize_memory_assignments(): def test_vectorize_memory_assignments():
ctx = KernelCreationContext() ctx = KernelCreationContext()
factory = AstFactory(ctx) factory = AstFactory(ctx)
...@@ -260,7 +292,7 @@ def test_vectorize_memory_assignments(): ...@@ -260,7 +292,7 @@ def test_vectorize_memory_assignments():
asm = typify( asm = typify(
PsAssignment( PsAssignment(
factory.parse_sympy(mem_acc(ptr, 3 * ctr + 2)), factory.parse_sympy(mem_acc(ptr, 3 * ctr + 2)),
factory.parse_sympy(x + y * mem_acc(ptr, ctr + 3)) factory.parse_sympy(x + y * mem_acc(ptr, ctr + 3)),
) )
) )
...@@ -303,7 +335,7 @@ def test_invalid_memory_assignments(): ...@@ -303,7 +335,7 @@ def test_invalid_memory_assignments():
asm = typify( asm = typify(
PsAssignment( PsAssignment(
factory.parse_sympy(mem_acc(ptr, 3 * i + 2)), factory.parse_sympy(mem_acc(ptr, 3 * i + 2)),
factory.parse_sympy(x + y * mem_acc(ptr, ctr + 3)) factory.parse_sympy(x + y * mem_acc(ptr, ctr + 3)),
) )
) )
...@@ -376,7 +408,9 @@ def test_vectorize_mem_acc(): ...@@ -376,7 +408,9 @@ def test_vectorize_mem_acc():
assert vec_acc.vector_entries == 4 assert vec_acc.vector_entries == 4
# Even more complex affine # Even more complex affine
idx = - factory.parse_index(ctr) / factory.parse_index(i) - factory.parse_index(ctr) * factory.parse_index(j) idx = -factory.parse_index(ctr) / factory.parse_index(i) - factory.parse_index(
ctr
) * factory.parse_index(j)
acc = typify(PsMemAcc(factory.parse_sympy(ptr), idx)) acc = typify(PsMemAcc(factory.parse_sympy(ptr), idx))
assert isinstance(acc, PsMemAcc) assert isinstance(acc, PsMemAcc)
...@@ -386,11 +420,15 @@ def test_vectorize_mem_acc(): ...@@ -386,11 +420,15 @@ def test_vectorize_mem_acc():
assert vec_acc.pointer.structurally_equal(acc.pointer) assert vec_acc.pointer.structurally_equal(acc.pointer)
assert vec_acc.offset is not acc.offset assert vec_acc.offset is not acc.offset
assert vec_acc.offset.structurally_equal(acc.offset) assert vec_acc.offset.structurally_equal(acc.offset)
assert vec_acc.stride.structurally_equal(factory.parse_index(-1) / factory.parse_index(i) - factory.parse_index(j)) assert vec_acc.stride.structurally_equal(
factory.parse_index(-1) / factory.parse_index(i) - factory.parse_index(j)
)
assert vec_acc.vector_entries == 4 assert vec_acc.vector_entries == 4
# Mixture of strides in affine and axis # Mixture of strides in affine and axis
vc = VectorizationContext(ctx, 4, VectorizationAxis(ctx.get_symbol("ctr"), step=factory.parse_index(3))) vc = VectorizationContext(
ctx, 4, VectorizationAxis(ctx.get_symbol("ctr"), step=factory.parse_index(3))
)
acc = factory.parse_sympy(mem_acc(ptr, 3 * i + 5 * ctr)) acc = factory.parse_sympy(mem_acc(ptr, 3 * i + 5 * ctr))
assert isinstance(acc, PsMemAcc) assert isinstance(acc, PsMemAcc)
...@@ -421,7 +459,9 @@ def test_invalid_mem_acc(): ...@@ -421,7 +459,9 @@ def test_invalid_mem_acc():
ptr = TypedSymbol("ptr", create_type("float64 *")) ptr = TypedSymbol("ptr", create_type("float64 *"))
# Non-symbol pointer # Non-symbol pointer
acc = factory.parse_sympy(mem_acc(AddressOf(mem_acc(ptr, 10)), 3 * i + ctr * (3 + ctr))) acc = factory.parse_sympy(
mem_acc(AddressOf(mem_acc(ptr, 10)), 3 * i + ctr * (3 + ctr))
)
with pytest.raises(VectorizationError): with pytest.raises(VectorizationError):
_ = vectorize.visit(acc, vc) _ = vectorize.visit(acc, vc)
...@@ -503,3 +543,107 @@ def test_invalid_buffer_acc(): ...@@ -503,3 +543,107 @@ def test_invalid_buffer_acc():
with pytest.raises(VectorizationError): with pytest.raises(VectorizationError):
_ = vectorize.visit(acc, vc) _ = vectorize.visit(acc, vc)
def test_vectorize_subscript():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
vectorize = AstVectorizer(ctx)
ctr = ctx.get_symbol("ctr", ctx.index_dtype)
axis = VectorizationAxis(ctr)
vc = VectorizationContext(ctx, 4, axis)
acc = PsSubscript(
PsExpression.make(ctx.get_symbol("arr", PsArrayType(ctx.default_dtype, 42))),
[PsExpression.make(ctx.get_symbol("i", ctx.index_dtype))],
) # independent of vectorization axis
vec_acc = vectorize.visit(factory._typify(acc), vc)
assert isinstance(vec_acc, PsVecBroadcast)
assert isinstance(vec_acc.operand, PsSubscript)
def test_invalid_subscript():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
vectorize = AstVectorizer(ctx)
ctr = ctx.get_symbol("ctr", ctx.index_dtype)
axis = VectorizationAxis(ctr)
vc = VectorizationContext(ctx, 4, axis)
acc = PsSubscript(
PsExpression.make(ctx.get_symbol("arr", PsArrayType(ctx.default_dtype, 42))),
[PsExpression.make(ctr)], # depends on vectorization axis
)
with pytest.raises(VectorizationError):
_ = vectorize.visit(factory._typify(acc), vc)
def test_vectorize_nested_loop():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
vectorize = AstVectorizer(ctx)
ctr = ctx.get_symbol("i", ctx.index_dtype)
axis = VectorizationAxis(ctr)
vc = VectorizationContext(ctx, 4, axis)
ast = factory.loop_nest(
("i", "j"),
make_slice[:8, :8], # inner loop does not depend on vectorization axis
PsBlock(
[
PsDeclaration(
PsExpression.make(ctx.get_symbol("x", ctx.default_dtype)),
PsExpression.make(PsConstant(42, ctx.default_dtype)),
)
]
),
)
vec_ast = vectorize.visit(ast, vc)
inner_loop = next(
dfs_preorder(
vec_ast,
lambda node: isinstance(node, PsLoop) and node.counter.symbol.name == "j",
)
)
decl = inner_loop.body.statements[0]
assert inner_loop.step.structurally_equal(
PsExpression.make(PsConstant(1, ctx.index_dtype))
)
assert isinstance(decl.lhs.symbol.dtype, PsVectorType)
def test_invalid_nested_loop():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
vectorize = AstVectorizer(ctx)
ctr = ctx.get_symbol("i", ctx.index_dtype)
axis = VectorizationAxis(ctr)
vc = VectorizationContext(ctx, 4, axis)
ast = factory.loop_nest(
("i", "j"),
make_slice[:8, :ctr], # inner loop depends on vectorization axis
PsBlock(
[
PsDeclaration(
PsExpression.make(ctx.get_symbol("x", ctx.default_dtype)),
PsExpression.make(PsConstant(42, ctx.default_dtype)),
)
]
),
)
with pytest.raises(VectorizationError):
_ = vectorize.visit(ast, vc)