Skip to content
Snippets Groups Projects
Commit 6d048af1 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

various minor fixes and refactorings

parent ee64e7e1
Branches
Tags
No related merge requests found
Pipeline #63251 failed
......@@ -156,6 +156,7 @@ class PsArrayAssocSymbol(PsSymbol, ABC):
Instances of this class represent pointers and indexing information bound
to a particular array.
"""
__match_args__ = ("name", "dtype", "array")
def __init__(self, name: str, dtype: PsAbstractType, array: PsLinearizedArray):
......@@ -214,6 +215,7 @@ class PsArrayStrideSymbol(PsArrayAssocSymbol):
Do not instantiate this class yourself, but only use its instances
as provided by `PsLinearizedArray.strides`.
"""
__match_args__ = ("array", "coordinate", "dtype")
def __init__(self, array: PsLinearizedArray, coordinate: int, dtype: PsIntegerType):
......
......@@ -70,7 +70,9 @@ class UndefinedSymbolsCollector:
return {symb}
case _:
return reduce(
set.union, (self.visit_expr(cast(PsExpression, c)) for c in expr.children), set()
set.union,
(self.visit_expr(cast(PsExpression, c)) for c in expr.children),
set(),
)
def declared_variables(self, node: PsAstNode) -> set[PsSymbol]:
......
......@@ -85,7 +85,7 @@ class Ops(Enum):
class PrinterCtx:
def __init__(self) -> None:
self.operator_stack = [Ops.Weakest]
self.branch_stack: list[LR] = []
self.branch_stack = [LR.Middle]
self.indent_level = 0
def push_op(self, operator: Ops, branch: LR):
......
from __future__ import annotations
from typing import Iterable, Iterator
from itertools import chain
from types import EllipsisType
......@@ -24,6 +25,14 @@ class FieldsInKernel:
self.custom_fields: set[Field] = set()
self.buffer_fields: set[Field] = set()
def __iter__(self) -> Iterator:
return chain(
self.domain_fields,
self.index_fields,
self.custom_fields,
self.buffer_fields,
)
class KernelCreationContext:
"""Manages the translation process from the SymPy frontend to the backend AST, and collects
......@@ -80,6 +89,7 @@ class KernelCreationContext:
return tuple(self._constraints)
# Symbols
def get_symbol(self, name: str, dtype: PsAbstractType | None = None) -> PsSymbol:
if name not in self._symbols:
symb = PsSymbol(name, None)
......@@ -109,6 +119,10 @@ class KernelCreationContext:
self._symbols[old.name] = new
@property
def symbols(self) -> Iterable[PsSymbol]:
return self._symbols.values()
# Fields and Arrays
@property
......@@ -214,6 +228,10 @@ class KernelCreationContext:
if isinstance(symb, PsSymbol):
self.add_symbol(symb)
@property
def arrays(self) -> Iterable[PsLinearizedArray]:
return self._field_arrays.values()
def get_array(self, field: Field) -> PsLinearizedArray:
"""Retrieve the underlying array for a given field.
......
......@@ -38,9 +38,6 @@ class GenericCpu(Platform):
else:
assert False, "unreachable code"
def optimize(self, kernel: PsBlock) -> PsBlock:
return kernel
# Internals
def _create_domain_loops(
......
......@@ -28,7 +28,3 @@ class Platform(ABC):
self, block: PsBlock, ispace: IterationSpace
) -> PsBlock:
pass
@abstractmethod
def optimize(self, kernel: PsBlock) -> PsBlock:
pass
......@@ -3,7 +3,12 @@ from enum import Enum
from functools import cache
from typing import Sequence
from ..ast.expressions import PsExpression, PsVectorArrayAccess, PsAddressOf, PsSubscript
from ..ast.expressions import (
PsExpression,
PsVectorArrayAccess,
PsAddressOf,
PsSubscript,
)
from ..transformations.vector_intrinsics import IntrinsicOps
from ..types import PsCustomType, PsVectorType
from ..constants import PsConstant
......@@ -135,14 +140,19 @@ class X86VectorCpu(GenericVectorCpu):
def vector_load(self, acc: PsVectorArrayAccess) -> PsExpression:
if acc.stride == 1:
load_func = _x86_packed_load(self._vector_arch, acc.dtype, False)
return load_func(PsAddressOf(PsSubscript(PsExpression.make(acc.base_ptr), acc.index)))
return load_func(
PsAddressOf(PsSubscript(PsExpression.make(acc.base_ptr), acc.index))
)
else:
raise NotImplementedError("Gather loads not implemented yet.")
def vector_store(self, acc: PsVectorArrayAccess, arg: PsExpression) -> PsExpression:
if acc.stride == 1:
store_func = _x86_packed_store(self._vector_arch, acc.dtype, False)
return store_func(PsAddressOf(PsSubscript(PsExpression.make(acc.base_ptr), acc.index)), arg)
return store_func(
PsAddressOf(PsSubscript(PsExpression.make(acc.base_ptr), acc.index)),
arg,
)
else:
raise NotImplementedError("Scatter stores not implemented yet.")
......
......@@ -32,6 +32,13 @@ class EraseAnonymousStructTypes:
def __call__(self, node: PsAstNode) -> PsAstNode:
self._substitutions = dict()
# Check if AST traversal is even necessary
if not any(
(isinstance(arr.element_type, PsStructType) and arr.element_type.anonymous)
for arr in self._ctx.arrays
):
return node
node = self.visit(node)
for old, new in self._substitutions.items():
......
from typing import cast
from .enums import Target
from .config import CreateKernelConfig
from .backend.ast import PsKernelFunction
from .backend.ast.structural import PsBlock
from .backend.kernelcreation import (
KernelCreationContext,
KernelAnalysis,
......@@ -15,7 +18,6 @@ from .backend.kernelcreation.iteration_space import (
from .backend.ast.analysis import collect_required_headers
from .backend.transformations import EraseAnonymousStructTypes
from .enums import Target
from .sympyextensions import AssignmentCollection, Assignment
......@@ -66,13 +68,12 @@ def create_kernel(
raise NotImplementedError("Target platform not implemented")
kernel_ast = platform.materialize_iteration_space(kernel_body, ispace)
kernel_ast = EraseAnonymousStructTypes(ctx)(kernel_ast)
kernel_ast = cast(PsBlock, EraseAnonymousStructTypes(ctx)(kernel_ast))
# 7. Apply optimizations
# - Vectorization
# - OpenMP
# - Loop Splitting, Tiling, Blocking
kernel_ast = platform.optimize(kernel_ast)
assert config.jit is not None
req_headers = collect_required_headers(kernel_ast) | platform.required_headers
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment