Skip to content
Snippets Groups Projects
Commit 128abb90 authored by Frederik Hennig
Browse files

Basic structure for intrinsic-based vectorization

parent 1328db5a
Branches
Tags
No related merge requests found
Showing
with 361 additions and 41 deletions
......@@ -52,6 +52,9 @@ from .types import (
PsIntegerType,
PsUnsignedIntegerType,
PsSignedIntegerType,
PsScalarType,
PsVectorType,
PsTypeError
)
from .typed_expressions import PsTypedVariable, ExprOrConstant, PsTypedConstant
......@@ -293,3 +296,39 @@ class PsArrayAccess(pb.Subscript):
def dtype(self) -> PsAbstractType:
"""Data type of this expression, i.e. the element type of the underlying array"""
return self._base_ptr.array.element_type
class PsVectorArrayAccess(pb.AlgebraicLeaf):
    """Expression leaf describing a vector-wide access to a linearized array.

    The access is defined by the array's base pointer, a base index, the
    number of vector entries and a stride. The resulting data type is the
    generic vector type built from the array's element type.

    Args:
        base_ptr: Base pointer of the accessed array
        base_index: Index expression passed in by the caller
        vector_width: Number of entries of the resulting vector type
        stride: Stride of the access; stored as given (defaults to 1)

    Raises:
        PsTypeError: If the array's element type is not a scalar type.
    """

    mapper_method = intern("map_vector_array_access")

    def __init__(
        self,
        base_ptr: PsArrayBasePointer,
        base_index: ExprOrConstant,
        vector_width: int,
        stride: int = 1,
    ):
        scalar_type = base_ptr.array.element_type

        # Generic vector types can only be built on top of scalar element types.
        if not isinstance(scalar_type, PsScalarType):
            raise PsTypeError(
                "Cannot generate vector accesses to arrays with non-scalar elements"
            )

        self._base_ptr = base_ptr
        self._base_index = base_index
        # The vector type inherits the const-qualification of the array elements.
        self._vector_type = PsVectorType(
            scalar_type, vector_width, const=scalar_type.const
        )
        self._stride = stride

    @property
    def base_ptr(self) -> PsArrayBasePointer:
        """Base pointer this access refers to."""
        return self._base_ptr

    @property
    def array(self) -> PsLinearizedArray:
        """Array underlying the base pointer."""
        return self._base_ptr.array

    @property
    def base_index(self) -> ExprOrConstant:
        """Base index as passed to the constructor."""
        return self._base_index

    @property
    def dtype(self) -> PsVectorType:
        """Data type of this expression, i.e. the resulting generic vector type"""
        return self._vector_type

    @property
    def stride(self) -> int:
        """Stride as passed to the constructor."""
        return self._stride
......@@ -4,6 +4,7 @@ from .nodes import (
PsExpression,
PsLvalueExpr,
PsSymbolExpr,
PsStatement,
PsAssignment,
PsDeclaration,
PsLoop,
......@@ -23,6 +24,7 @@ __all__ = [
"PsExpression",
"PsLvalueExpr",
"PsSymbolExpr",
"PsStatement",
"PsAssignment",
"PsDeclaration",
"PsLoop",
......
......@@ -8,6 +8,7 @@ from pymbolic.mapper.dependency import DependencyMapper
from .nodes import PsAstNode, PsBlock, failing_cast
from ..constraints import PsKernelConstraint
from ..platforms import Platform
from ..typed_expressions import PsTypedVariable
from ..arrays import PsLinearizedArray, PsArrayBasePointer, PsArrayAssocVar
from ..jit import JitBase, no_jit
......@@ -69,13 +70,14 @@ class PsKernelFunction(PsAstNode):
__match_args__ = ("body",)
def __init__(
self, body: PsBlock, target: Target, name: str = "kernel", jit: JitBase = no_jit
self, body: PsBlock, target: Target, name: str, required_headers: set[str], jit: JitBase = no_jit
):
self._body: PsBlock = body
self._target = target
self._name = name
self._jit = jit
self._required_headers = required_headers
self._constraints: list[PsKernelConstraint] = []
@property
......@@ -108,6 +110,10 @@ class PsKernelFunction(PsAstNode):
def instruction_set(self) -> str | None:
"""For backward compatibility"""
return None
@property
def required_headers(self) -> set[str]:
return self._required_headers
def get_children(self) -> tuple[PsAstNode, ...]:
return (self._body,)
......@@ -136,11 +142,5 @@ class PsKernelFunction(PsAstNode):
tuple(params_list), tuple(arrays), tuple(self._constraints)
)
def get_required_headers(self) -> set[str]:
# To Do: Headers from target/instruction set/...
from .collectors import collect_required_headers
return collect_required_headers(self)
def compile(self) -> Callable[..., None]:
return self._jit.compile(self)
......@@ -135,6 +135,29 @@ class PsSymbolExpr(PsLvalueExpr):
self._expr = symbol
class PsStatement(PsAstNode):
    """AST node wrapping a single expression that is used as a statement.

    Args:
        expr: The expression node making up the statement
    """

    # Must be a tuple: a parenthesized string without a trailing comma is just
    # a `str`, and positional class patterns (`case PsStatement(e)`) then fail
    # with a TypeError.
    __match_args__ = ("expression",)

    def __init__(self, expr: PsExpression):
        self._expression = expr

    @property
    def expression(self) -> PsExpression:
        """The wrapped expression node."""
        return self._expression

    @expression.setter
    def expression(self, expr: PsExpression):
        self._expression = expr

    def get_children(self) -> tuple[PsAstNode, ...]:
        """Return the wrapped expression as the only child."""
        return (self._expression,)

    def set_child(self, idx: int, c: PsAstNode):
        """Replace the only child; `c` must be a `PsExpression`.

        Raises:
            IndexError: If `idx` is neither 0 nor -1.
        """
        # Indexing a one-element list normalizes -1 to 0 and rejects any
        # other index with an IndexError.
        idx = [0][idx]
        assert idx == 0
        self._expression = failing_cast(PsExpression, c)
PsLvalue: TypeAlias = Variable | PsArrayAccess
"""Types of expressions that may occur on the left-hand side of assignments."""
......
......@@ -7,6 +7,7 @@ from .ast import (
PsAstNode,
PsBlock,
PsExpression,
PsStatement,
PsDeclaration,
PsAssignment,
PsLoop,
......@@ -76,6 +77,10 @@ class CAstPrinter:
@visit.case(PsExpression)
def pymb_expression(self, expr: PsExpression):
    """Print an expression node by passing its wrapped expression to the expression printer."""
    wrapped = expr.expression
    return self._expr_printer(wrapped)
@visit.case(PsStatement)
def statement(self, stmt: PsStatement):
    """Print a statement node: the printed expression followed by `;`, passed through `self.indent`."""
    expr_code = self.visit(stmt.expression)
    return self.indent(f"{expr_code};")
@visit.case(PsDeclaration)
def declaration(self, decl: PsDeclaration):
......@@ -90,7 +95,7 @@ class CAstPrinter:
def assignment(self, asm: PsAssignment):
lhs_code = self.visit(asm.lhs)
rhs_code = self.visit(asm.rhs)
return self.indent(f"{lhs_code} = {rhs_code};\n")
return self.indent(f"{lhs_code} = {rhs_code};")
@visit.case(PsLoop)
def loop(self, loop: PsLoop):
......
......@@ -63,9 +63,9 @@ class PsKernelExtensioNModule:
code = ""
# Collect headers
headers = {"<math.h>", "<stdint.h>"}
headers = {"<stdint.h>"}
for kernel in self._kernels.values():
headers |= kernel.get_required_headers()
headers |= kernel.required_headers
header_list = sorted(headers)
header_list.insert(0, '"Python.h"')
......
from .basic_cpu import BasicCpu
from .platform import Platform
from .generic_cpu import GenericCpu, GenericVectorCpu
__all__ = ["BasicCpu"]
__all__ = ["Platform", "GenericCpu", "GenericVectorCpu"]
from typing import Sequence
from abc import ABC, abstractmethod
import pymbolic.primitives as pb
from .platform import Platform
from ..kernelcreation.iteration_space import (
......@@ -7,11 +12,18 @@ from ..kernelcreation.iteration_space import (
)
from ..ast import PsDeclaration, PsSymbolExpr, PsExpression, PsLoop, PsBlock
from ..types import PsVectorType, PsCustomType
from ..typed_expressions import PsTypedConstant
from ..arrays import PsArrayAccess
from ..arrays import PsArrayAccess, PsVectorArrayAccess
from ..transformations.vector_intrinsics import IntrinsicOps
class GenericCpu(Platform):
@property
def required_headers(self) -> set[str]:
    """Headers required by this platform; a fresh set containing only `<math.h>`."""
    headers = {"<math.h>"}
    return headers
class BasicCpu(Platform):
def materialize_iteration_space(
self, body: PsBlock, ispace: IterationSpace
) -> PsBlock:
......@@ -69,3 +81,40 @@ class BasicCpu(Platform):
)
return PsBlock([loop])
class IntrinsicsError(Exception):
    """Fatal error raised while materializing intrinsics.

    The platform hooks for intrinsic types, constants, operations, loads and
    stores document this exception as their way of reporting an unsupported
    request.
    """
class GenericVectorCpu(GenericCpu, ABC):
    """Abstract CPU platform that can materialize generic vector code as intrinsics.

    Concrete subclasses provide the mapping from generic vector types,
    constants, operations and array accesses to expressions that use the
    intrinsics of a specific instruction set. Every hook is expected to raise
    an `IntrinsicsError` for requests it cannot serve.
    """

    @abstractmethod
    def type_intrinsic(self, vector_type: PsVectorType) -> PsCustomType:
        """Map a generic vector type to the corresponding intrinsic vector type.

        Raises:
            IntrinsicsError: If the given type is not supported.
        """

    @abstractmethod
    def constant_vector(self, c: PsTypedConstant) -> pb.Expression:
        """Build an expression that initializes a constant vector from `c`.

        Raises:
            IntrinsicsError: If this is not supported.
        """

    @abstractmethod
    def op_intrinsic(
        self, op: IntrinsicOps, vtype: PsVectorType, args: Sequence[pb.Expression]
    ) -> pb.Expression:
        """Build an expression that invokes operation `op` intrinsically.

        The operation is applied to `args`, using the vector type `vtype`.

        Raises:
            IntrinsicsError: If this is not supported.
        """

    @abstractmethod
    def vector_load(self, acc: PsVectorArrayAccess) -> pb.Expression:
        """Build an expression that intrinsically performs the vector load `acc`.

        Raises:
            IntrinsicsError: If this is not supported.
        """

    @abstractmethod
    def vector_store(
        self, acc: PsVectorArrayAccess, arg: pb.Expression
    ) -> pb.Expression:
        """Build an expression that intrinsically stores `arg` through `acc`.

        Raises:
            IntrinsicsError: If this is not supported.
        """
......@@ -18,6 +18,11 @@ class Platform(ABC):
def __init__(self, ctx: KernelCreationContext) -> None:
self._ctx = ctx
@property
@abstractmethod
def required_headers(self) -> set[str]:
    """Set of headers required by this platform; every concrete platform must define it."""
@abstractmethod
def materialize_iteration_space(
self, block: PsBlock, ispace: IterationSpace
......
from .erase_anonymous_structs import EraseAnonymousStructTypes
__all__ = [
"EraseAnonymousStructTypes"
]
\ No newline at end of file
......@@ -5,7 +5,7 @@ from typing import TypeVar
import pymbolic.primitives as pb
from pymbolic.mapper import IdentityMapper
from .context import KernelCreationContext
from ..kernelcreation.context import KernelCreationContext
from ..ast import PsAstNode, PsExpression
from ..arrays import PsArrayAccess, TypeErasedBasePointer
......
from typing import TypeVar, TYPE_CHECKING
from enum import Enum, auto
import pymbolic.primitives as pb
from pymbolic.mapper import IdentityMapper
from ..ast import PsAstNode, PsExpression, PsAssignment, PsStatement
from ..types import PsVectorType
from ..typed_expressions import PsTypedVariable, PsTypedConstant, ExprOrConstant
from ..arrays import PsVectorArrayAccess
from ..exceptions import PsInternalCompilerError
if TYPE_CHECKING:
from ..platforms import GenericVectorCpu
__all__ = ["IntrinsicOps", "MaterializeVectorIntrinsics"]
NodeT = TypeVar("NodeT", bound=PsAstNode)
class IntrinsicOps(Enum):
    """Operations that may be requested from a platform as vector intrinsics."""

    # Arithmetic operations; values are assigned automatically in declaration order.
    ADD = auto()
    SUB = auto()
    MUL = auto()
    DIV = auto()
    # Presumably fused multiply-add -- confirm against the platform implementations.
    FMA = auto()
class VectorizationError(Exception):
    """Fatal error raised when an expression cannot be vectorized."""
class VecTypeCtx:
    """Mutable record of the vector type found while processing one expression.

    The context starts out empty. Each vector-typed entity that is encountered
    registers its type through `set`; all registrations within one context
    must agree on the type.
    """

    def __init__(self):
        self._dtype: "None | PsVectorType" = None

    def get(self) -> "PsVectorType | None":
        """Return the registered vector type, or `None` if none was registered."""
        return self._dtype

    def set(self, dtype: "PsVectorType"):
        """Register `dtype` as the vector type of the current expression.

        Registering the type that is already stored is a no-op. An expression
        with several vector-typed leaves registers its type once per leaf, so
        only a *different* type is an actual ambiguity.

        Raises:
            PsInternalCompilerError: If a different vector type was registered before.
        """
        if self._dtype is not None and self._dtype != dtype:
            raise PsInternalCompilerError("Ambiguous vector types.")
        self._dtype = dtype

    def reset(self):
        """Forget the registered vector type."""
        self._dtype = None
class MaterializeVectorIntrinsics(IdentityMapper):
def __init__(self, platform: GenericVectorCpu):
self._platform = platform
def __call__(self, node: PsAstNode) -> PsAstNode:
match node:
case PsExpression(expr):
# descend into expr
node.expression = self.rec(expr, VecTypeCtx())
return node
case PsAssignment(lhs, rhs) if isinstance(lhs.expression, PsVectorArrayAccess):
vc = VecTypeCtx()
vc.set(lhs.expression.dtype)
store_arg = self.rec(rhs.expression, vc)
return PsStatement(PsExpression(self._platform.vector_store(lhs.expression, store_arg)))
case other:
for c in other.children:
self(c)
return node
def map_typed_variable(self, tv: PsTypedVariable, vc: VecTypeCtx) -> PsTypedVariable:
if isinstance(tv.dtype, PsVectorType):
intrin_type = self._platform.type_intrinsic(tv.dtype)
vc.set(tv.dtype)
return PsTypedVariable(tv.name, intrin_type)
else:
return tv
def map_constant(self, c: PsTypedConstant, vc: VecTypeCtx) -> ExprOrConstant:
if isinstance(c.dtype, PsVectorType):
vc.set(c.dtype)
return self._platform.constant_vector(c)
else:
return c
def map_vector_array_access(self, acc: PsVectorArrayAccess, vc: VecTypeCtx) -> pb.Expression:
vc.set(acc.dtype)
return self._platform.vector_load(acc)
def map_sum(self, expr: pb.Sum, vc: VecTypeCtx) -> pb.Expression:
args = [self.rec(arg, vc) for arg in expr.children]
vtype = vc.get()
if vtype is not None:
if len(args) != 2:
raise VectorizationError("Cannot vectorize non-binary sums")
return self._platform.op_intrinsic(IntrinsicOps.ADD, vtype, args)
else:
return expr
def map_product(self, expr: pb.Product, vc: VecTypeCtx) -> pb.Expression:
args = [self.rec(arg, vc) for arg in expr.children]
vtype = vc.get()
if vtype is not None:
if len(args) != 2:
raise VectorizationError("Cannot vectorize non-binary products")
return self._platform.op_intrinsic(IntrinsicOps.MUL, vtype, args)
else:
return expr
......@@ -4,6 +4,7 @@ from .basic_types import (
PsStructType,
PsNumericType,
PsScalarType,
PsVectorType,
PsPointerType,
PsIntegerType,
PsUnsignedIntegerType,
......@@ -24,6 +25,7 @@ __all__ = [
"PsPointerType",
"PsNumericType",
"PsScalarType",
"PsVectorType",
"PsIntegerType",
"PsUnsignedIntegerType",
"PsSignedIntegerType",
......
......@@ -66,23 +66,20 @@ class PsAbstractType(ABC):
return "const " if self._const else ""
@abstractmethod
def c_string(self) -> str:
...
def c_string(self) -> str: ...
# -------------------------------------------------------------------------------------------
# Dunder Methods
# -------------------------------------------------------------------------------------------
@abstractmethod
def __eq__(self, other: object) -> bool:
...
def __eq__(self, other: object) -> bool: ...
def __str__(self) -> str:
return self.c_string()
@abstractmethod
def __hash__(self) -> int:
...
def __hash__(self) -> int: ...
class PsCustomType(PsAbstractType):
......@@ -274,6 +271,22 @@ class PsNumericType(PsAbstractType, ABC):
PsTypeError: If the given value cannot be interpreted in this type.
"""
@abstractmethod
def is_int(self) -> bool: ...
@abstractmethod
def is_sint(self) -> bool: ...
@abstractmethod
def is_uint(self) -> bool: ...
@abstractmethod
def is_float(self) -> bool: ...
class PsScalarType(PsNumericType, ABC):
"""Class to model scalar numeric types."""
@abstractmethod
def create_literal(self, value: Any) -> str:
"""Create a C numerical literal for a constant of this type.
......@@ -282,37 +295,103 @@ class PsNumericType(PsAbstractType, ABC):
PsTypeError: If the given value's type is not the numeric type's compiler-internal representation.
"""
@abstractmethod
def is_int(self) -> bool:
...
return isinstance(self, PsIntegerType)
@abstractmethod
def is_sint(self) -> bool:
...
return isinstance(self, PsIntegerType) and self.signed
@abstractmethod
def is_uint(self) -> bool:
...
return isinstance(self, PsIntegerType) and not self.signed
@abstractmethod
def is_float(self) -> bool:
...
return isinstance(self, PsIeeeFloatType)
class PsScalarType(PsNumericType, ABC):
"""Class to model scalar numeric types."""
class PsVectorType(PsNumericType):
    """Class to model packed vectors of numeric type.

    Args:
        scalar_type: Underlying scalar data type
        vector_width: Number of entries in the vector
        const: Const-qualification of the vector type; it is also applied to
            the stored scalar type
    """

    def __init__(
        self, scalar_type: PsScalarType, vector_width: int, const: bool = False
    ):
        super().__init__(const)
        self._vector_width = vector_width
        # Keep the const-ness of the element type in sync with the vector type.
        self._scalar_type = constify(scalar_type) if const else deconstify(scalar_type)

    @property
    def scalar_type(self) -> PsScalarType:
        """Scalar type of the vector entries."""
        return self._scalar_type

    @property
    def vector_width(self) -> int:
        """Number of entries in the vector."""
        return self._vector_width

    # The type predicates delegate to the scalar type. A vector type is never
    # itself an instance of the scalar type classes, so `isinstance` checks on
    # `self` would always be false.
    def is_int(self) -> bool:
        return self._scalar_type.is_int()

    def is_sint(self) -> bool:
        return self._scalar_type.is_sint()

    def is_uint(self) -> bool:
        return self._scalar_type.is_uint()

    def is_float(self) -> bool:
        return self._scalar_type.is_float()

    @property
    def itemsize(self) -> int | None:
        """Vector width times the scalar item size, or `None` if the latter is `None`."""
        if self._scalar_type.itemsize is None:
            return None
        else:
            return self._vector_width * self._scalar_type.itemsize

    @property
    def numpy_dtype(self):
        """NumPy sub-array dtype: the scalar dtype with shape `(vector_width,)`."""
        return np.dtype((self._scalar_type.numpy_dtype, (self._vector_width,)))

    def create_constant(self, value: Any) -> Any:
        """Create a constant of this vector type from `value`.

        An array that already has the scalar NumPy dtype and the shape
        `(vector_width,)` is copied. Any other value is converted by the
        scalar type's `create_constant` and repeated in every entry.
        """
        if (
            isinstance(value, np.ndarray)
            and value.dtype == self.scalar_type.numpy_dtype
            and value.shape == (self._vector_width,)
        ):
            return value.copy()

        element = self._scalar_type.create_constant(value)
        return np.array(
            [element] * self._vector_width, dtype=self.scalar_type.numpy_dtype
        )

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, PsVectorType):
            return False
        return (
            self._base_equal(other)
            and self._scalar_type == other._scalar_type
            and self._vector_width == other._vector_width
        )

    def __hash__(self) -> int:
        return hash(
            ("PsVectorType", self._scalar_type, self._vector_width, self._const)
        )

    def c_string(self) -> str:
        """Always raises: generic vector types have no C type string.

        Raises:
            PsInternalCompilerError: On every call.
        """
        raise PsInternalCompilerError(
            "Cannot retrieve C type string for generic vector types."
        )

    def __str__(self) -> str:
        # Overridden because this type's `c_string` always raises.
        return f"vector[{self._scalar_type}, {self._vector_width}]"

    def __repr__(self) -> str:
        return f"PsVectorType( scalar_type={repr(self._scalar_type)}, vector_width={self._vector_width}, const={self.const} )"
class PsIntegerType(PsScalarType, ABC):
......
......@@ -11,7 +11,9 @@ from .backend.kernelcreation.iteration_space import (
create_sparse_iteration_space,
create_full_iteration_space,
)
from .backend.kernelcreation.transformations import EraseAnonymousStructTypes
from .backend.ast.collectors import collect_required_headers
from .backend.transformations import EraseAnonymousStructTypes
from .enums import Target
from .sympyextensions import AssignmentCollection, Assignment
......@@ -54,10 +56,10 @@ def create_kernel(
match config.target:
case Target.CPU:
from .backend.platforms import BasicCpu
from .backend.platforms import GenericCpu
# TODO: CPU platform should incorporate instruction set info, OpenMP, etc.
platform = BasicCpu(ctx)
platform = GenericCpu(ctx)
case _:
# TODO: CUDA/HIP platform
# TODO: SYCL platform (?)
......@@ -73,8 +75,9 @@ def create_kernel(
kernel_ast = platform.optimize(kernel_ast)
assert config.jit is not None
req_headers = collect_required_headers(kernel_ast) | platform.required_headers
function = PsKernelFunction(
kernel_ast, config.target, name=config.function_name, jit=config.jit
kernel_ast, config.target, config.function_name, req_headers, jit=config.jit
)
function.add_constraints(*ctx.constraints)
......
......@@ -9,14 +9,14 @@ from pystencils.backend.kernelcreation import (
from pystencils.backend.ast import PsBlock, PsLoop, PsComment, dfs_preorder
from pystencils.backend.platforms import BasicCpu
from pystencils.backend.platforms import GenericCpu
@pytest.mark.parametrize("layout", ["fzyx", "zyxf", "c", "f"])
def test_loop_nest(layout):
ctx = KernelCreationContext()
body = PsBlock([PsComment("Loop body goes here")])
platform = BasicCpu(ctx)
platform = GenericCpu(ctx)
# FZYX Order
archetype_field = Field.create_generic("fzyx_field", spatial_dimensions=3, layout=layout)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment