Skip to content
Snippets Groups Projects
Commit 0c1ef7fb authored by Frederik Hennig's avatar Frederik Hennig
Browse files

introduce basic pattern for function support + some cleanup

parent 85cd82c2
No related merge requests found
Pipeline #61400 failed with stages
in 2 minutes and 42 seconds
"""
Functions supported by pystencils.
Every supported function might require handling logic in the following modules:
- In `freeze.FreezeExpressions`, a case in `map_Function` or a separate mapper method to catch its frontend variant
- In each backend platform, a case in `materialize_functions` to map the function onto a concrete C/C++ implementation
- If very special typing rules apply, a case in `typification.Typifier`.
In most cases, typification of function applications will require no special handling.
TODO: Maybe add a way for the user to register additional functions
TODO: Figure out the best way to describe function signatures and overloads for typing
"""
import pymbolic.primitives as pb
from abc import ABC, abstractmethod
class PsFunction(pb.FunctionSymbol, ABC):
@property
@abstractmethod
def arg_count(self) -> int:
"Number of arguments this function takes"
......@@ -118,3 +118,10 @@ class FreezeExpressions(SympyToPymbolicMapper):
index = summands[0] if len(summands) == 1 else pb.Sum(summands)
return PsArrayAccess(ptr, index)
def map_Function(self, func: sp.Function):
"""Map a SymPy function to a backend-supported function symbol.
SymPy functions are frozen to an instance of `nbackend.functions.PsFunction`.
"""
raise NotImplementedError()
......@@ -9,7 +9,6 @@ from .freeze import FreezeExpressions
from .typification import Typifier
from .options import KernelCreationOptions
from .iteration_space import (
IterationSpace,
create_sparse_iteration_space,
create_full_iteration_space,
)
......@@ -21,11 +20,10 @@ def create_kernel(assignments: AssignmentCollection, options: KernelCreationOpti
analysis = KernelAnalysis(ctx)
analysis(assignments)
ispace: IterationSpace = (
create_sparse_iteration_space(ctx, assignments)
if len(ctx.fields.index_fields) > 0 or ctx.options.index_field is not None
else create_full_iteration_space(ctx, assignments)
)
if len(ctx.fields.index_fields) > 0 or ctx.options.index_field is not None:
ispace = create_sparse_iteration_space(ctx, assignments)
else:
ispace = create_full_iteration_space(ctx, assignments)
ctx.set_iteration_space(ispace)
......@@ -37,22 +35,22 @@ def create_kernel(assignments: AssignmentCollection, options: KernelCreationOpti
match options.target:
case Target.CPU:
from .platform import BasicCpuGen
from .platform import BasicCpu
# TODO: CPU platform should incorporate instruction set info, OpenMP, etc.
platform_generator = BasicCpuGen(ctx)
platform = BasicCpu(ctx)
case _:
# TODO: CUDA/HIP platform
# TODO: SYCL platform (?)
raise NotImplementedError("Target platform not implemented")
kernel_ast = platform_generator.materialize_iteration_space(kernel_body, ispace)
kernel_ast = platform.materialize_iteration_space(kernel_body, ispace)
# 7. Apply optimizations
# - Vectorization
# - OpenMP
# - Loop Splitting, Tiling, Blocking
kernel_ast = platform_generator.optimize(kernel_ast)
kernel_ast = platform.optimize(kernel_ast)
function = PsKernelFunction(kernel_ast, options.target, name=options.function_name)
function.add_constraints(*ctx.constraints)
......
from .basic_cpu import BasicCpuGen
from .basic_cpu import BasicCpu
__all__ = [
'BasicCpuGen'
'BasicCpu'
]
from .platform import PlatformGen
from .platform import Platform
from ..iteration_space import (
IterationSpace,
......@@ -11,7 +11,7 @@ from ...typed_expressions import PsTypedConstant
from ...arrays import PsArrayAccess
class BasicCpuGen(PlatformGen):
class BasicCpu(Platform):
def materialize_iteration_space(
self, body: PsBlock, ispace: IterationSpace
) -> PsBlock:
......
......@@ -6,7 +6,7 @@ from ..context import KernelCreationContext
from ..iteration_space import IterationSpace
class PlatformGen(ABC):
class Platform(ABC):
"""Abstract base class for all supported platforms.
The platform performs all target-dependent tasks during code generation:
......
......@@ -115,7 +115,9 @@ class Typifier(Mapper):
) -> tuple[PsArrayAccess, PsNumericType]:
self._check_target_type(access, access.dtype, target_type)
index, _ = self.rec(access.index_tuple[0], self._ctx.options.index_dtype)
return PsArrayAccess(access.base_ptr, index), cast(PsNumericType, deconstify(access.dtype))
return PsArrayAccess(access.base_ptr, index), cast(
PsNumericType, deconstify(access.dtype)
)
# Arithmetic Expressions
......@@ -156,6 +158,16 @@ class Typifier(Mapper):
new_args, dtype = self._homogenize(expr, expr.children, target_type)
return pb.Product(new_args), dtype
def map_call(
self, expr: pb.Call, target_type: PsNumericType | None
) -> tuple[pb.Call, PsNumericType]:
"""
TODO: Figure out the best way to typify functions
- How to propagate target_type in the face of multiple overloads?
"""
raise NotImplementedError()
def _check_target_type(
self,
expr: ExprOrConstant,
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment