Skip to content
Snippets Groups Projects
Commit 8c4f5a87 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

first version of the full translation pass

parent e5e1a95c
No related branches found
No related tags found
No related merge requests found
Pipeline #61148 failed
......@@ -231,7 +231,7 @@ class PsArrayStrideVar(PsArrayAssocVar):
__match_args__ = ("array", "coordinate", "dtype")
def __init__(self, array: PsLinearizedArray, coordinate: int, dtype: PsIntegerType):
name = f"{array.name}_size{coordinate}"
name = f"{array.name}_stride{coordinate}"
super().__init__(name, dtype, array)
self._coordinate = coordinate
......
from typing import overload, cast
import sympy as sp
import pymbolic.primitives as pb
from pymbolic.interop.sympy import SympyToPymbolicMapper
from itertools import chain
from ...assignment import Assignment
from ...simp import AssignmentCollection
from ...field import Field, FieldType
from ...typing import BasicType
from .context import KernelCreationContext
from ..ast.nodes import (
PsBlock,
PsAssignment,
PsDeclaration,
PsSymbolExpr,
......@@ -16,13 +23,36 @@ from ..ast.nodes import (
from ..types import constify, make_type
from ..typed_expressions import PsTypedVariable
from ..arrays import PsArrayAccess
from ..exceptions import PsInputError
class FreezeExpressions(SympyToPymbolicMapper):
def __init__(self, ctx: KernelCreationContext):
self._ctx = ctx
def map_Assignment(self, expr): # noqa
@overload
def __call__(self, asms: AssignmentCollection) -> PsBlock:
...
@overload
def __call__(self, expr: Assignment) -> PsAssignment:
...
@overload
def __call__(self, expr: sp.Basic) -> pb.Expression:
...
def __call__(self, obj):
if isinstance(obj, AssignmentCollection):
return PsBlock([self.rec(asm) for asm in obj.all_assignments])
elif isinstance(obj, Assignment):
return cast(PsAssignment, self.rec(obj))
elif isinstance(obj, sp.Basic):
return cast(pb.Expression, self.rec(obj))
else:
raise PsInputError(f"Don't know how to freeze {obj}")
def map_Assignment(self, expr: Assignment): # noqa
lhs = self.rec(expr.lhs)
rhs = self.rec(expr.rhs)
......
......@@ -48,6 +48,14 @@ class IterationSpace(ABC):
class FullIterationSpace(IterationSpace):
"""N-dimensional full iteration space.
Each dimension of the full iteration space is represented by an instance of `FullIterationSpace.Dimension`.
Dimensions are ordered slowest-to-fastest: The first dimension corresponds to the slowest coordinate,
translates to the outermost loop, while the last dimension is the fastest coordinate and translates
to the innermost loop.
"""
@dataclass
class Dimension:
start: VarOrConstant
......@@ -67,7 +75,7 @@ class FullIterationSpace(IterationSpace):
dim = archetype_field.spatial_dimensions
counters = [
PsTypedVariable(name, ctx.index_dtype)
for name in Defaults.spatial_counter_names
for name in Defaults.spatial_counter_names[:dim]
]
if isinstance(ghost_layers, int):
......@@ -89,13 +97,17 @@ class FullIterationSpace(IterationSpace):
for (gl_left, gl_right) in ghost_layers_spec
]
spatial_shape = archetype_array.shape[:dim]
dimensions = [
FullIterationSpace.Dimension(gl_left, shape - gl_right, one, ctr)
for (gl_left, gl_right), shape, ctr in zip(
ghost_layer_exprs, archetype_array.shape, counters, strict=True
ghost_layer_exprs, spatial_shape, counters, strict=True
)
]
# TODO: Reorder dimensions according to optimal loop layout (?)
return FullIterationSpace(ctx, dimensions)
def __init__(self, ctx: KernelCreationContext, dimensions: Sequence[Dimension]):
......@@ -132,6 +144,7 @@ class FullIterationSpace(IterationSpace):
class SparseIterationSpace(IterationSpace):
# TODO: To properly implement sparse iteration, we still need struct data types
def __init__(
self,
spatial_indices: tuple[PsTypedVariable, ...],
......
from itertools import chain
from ...simp import AssignmentCollection
from ..ast import PsBlock
from ..ast import PsBlock, PsKernelFunction
from ...enums import Target
from .context import KernelCreationContext
from .analysis import KernelAnalysis
......@@ -15,8 +14,6 @@ from .iteration_space import (
create_full_iteration_space,
)
# flake8: noqa
def create_kernel(assignments: AssignmentCollection, options: KernelCreationOptions):
ctx = KernelCreationContext(options)
......@@ -39,16 +36,28 @@ def create_kernel(assignments: AssignmentCollection, options: KernelCreationOpti
kernel_body = typify(kernel_body)
# Up to this point, all was target-agnostic, but now the target becomes relevant.
# Here we might hand off the compilation to a target-specific part of the compiler
# (CPU/CUDA/...), since these will likely also apply very different optimizations.
match options.target:
case Target.CPU:
from .platform import BasicCpu
# TODO: CPU platform should incorporate instruction set info, OpenMP, etc.
platform = BasicCpu(ctx)
case _:
# TODO: CUDA/HIP platform
# TODO: SYCL platform (?)
raise NotImplementedError("Target platform not implemented")
# 6. Add loops or device indexing
# This step translates the iteration space to actual index calculation code and is once again
# different in indexed and domain kernels.
kernel_ast = platform.apply_iteration_space(kernel_body, ispace)
# 7. Apply optimizations
# - Vectorization
# - OpenMP
# - Loop Splitting, Tiling, Blocking
kernel_ast = platform.optimize(kernel_ast)
# 8. Create and return kernel function.
function = PsKernelFunction(kernel_ast, options.target, name=options.function_name)
function.add_constraints(*ctx.constraints)
return function
from .basic_cpu import BasicCpu
__all__ = [
'BasicCpu'
]
from pystencils.nbackend.ast import PsBlock, PsLoop, PsSymbolExpr, PsExpression
from pystencils.nbackend.kernelcreation.iteration_space import (
IterationSpace,
FullIterationSpace,
)
from .platform import Platform
class BasicCpu(Platform):
def apply_iteration_space(self, block: PsBlock, ispace: IterationSpace) -> PsBlock:
if isinstance(ispace, FullIterationSpace):
return self._create_domain_loops(block, ispace)
else:
raise NotImplementedError("Iteration space not supported yet.")
def optimize(self, kernel: PsBlock) -> PsBlock:
return kernel
# Internals
def _create_domain_loops(
self, body: PsBlock, ispace: FullIterationSpace
) -> PsBlock:
dimensions = ispace.dimensions
outer_block = body
for dimension in dimensions[::-1]:
loop = PsLoop(
PsSymbolExpr(dimension.counter),
PsExpression(dimension.start),
PsExpression(dimension.stop),
PsExpression(dimension.step),
outer_block,
)
outer_block = PsBlock([loop])
return outer_block
from abc import ABC, abstractmethod
from ...ast import PsBlock
from ..context import KernelCreationContext
from ..iteration_space import IterationSpace
class Platform(ABC):
"""Abstract base class for all supported platforms.
The platform performs all target-dependent tasks during code generation:
- Translation of the iteration space to an index source (loop nest, GPU indexing, ...)
- Platform-specific optimizations (e.g. vectorization, OpenMP)
"""
def __init__(self, ctx: KernelCreationContext) -> None:
self._ctx = ctx
@abstractmethod
def apply_iteration_space(self, block: PsBlock, ispace: IterationSpace) -> PsBlock:
...
@abstractmethod
def optimize(self, kernel: PsBlock) -> PsBlock:
...
......@@ -9,7 +9,7 @@ from .context import KernelCreationContext
from ..types import PsAbstractType, PsNumericType
from ..typed_expressions import PsTypedVariable, PsTypedConstant, ExprOrConstant
from ..arrays import PsArrayAccess
from ..ast import PsAstNode, PsExpression, PsAssignment
from ..ast import PsAstNode, PsBlock, PsExpression, PsAssignment
class TypificationError(Exception):
......@@ -42,6 +42,9 @@ class Typifier(Mapper):
def __call__(self, node: NodeT) -> NodeT:
match node:
case PsBlock([*statements]):
node.statements = [self(s) for s in statements]
case PsExpression(expr):
node.expression, _ = self.rec(expr)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment