Skip to content
Snippets Groups Projects
Commit 0c1ef7fb authored by Frederik Hennig's avatar Frederik Hennig
Browse files

introduce basic pattern for function support + some cleanup

parent 85cd82c2
No related branches found
No related tags found
No related merge requests found
Pipeline #61400 failed
"""
Functions supported by pystencils.
Every supported function might require handling logic in the following modules:
- In `freeze.FreezeExpressions`, a case in `map_Function` or a separate mapper method to catch its frontend variant
- In each backend platform, a case in `materialize_functions` to map the function onto a concrete C/C++ implementation
- If very special typing rules apply, a case in `typification.Typifier`.
In most cases, typification of function applications will require no special handling.
TODO: Maybe add a way for the user to register additional functions
TODO: Figure out the best way to describe function signatures and overloads for typing
"""
import pymbolic.primitives as pb
from abc import ABC, abstractmethod
class PsFunction(pb.FunctionSymbol, ABC):
@property
@abstractmethod
def arg_count(self) -> int:
"Number of arguments this function takes"
...@@ -118,3 +118,10 @@ class FreezeExpressions(SympyToPymbolicMapper): ...@@ -118,3 +118,10 @@ class FreezeExpressions(SympyToPymbolicMapper):
index = summands[0] if len(summands) == 1 else pb.Sum(summands) index = summands[0] if len(summands) == 1 else pb.Sum(summands)
return PsArrayAccess(ptr, index) return PsArrayAccess(ptr, index)
def map_Function(self, func: sp.Function):
"""Map a SymPy function to a backend-supported function symbol.
SymPy functions are frozen to an instance of `nbackend.functions.PsFunction`.
"""
raise NotImplementedError()
...@@ -9,7 +9,6 @@ from .freeze import FreezeExpressions ...@@ -9,7 +9,6 @@ from .freeze import FreezeExpressions
from .typification import Typifier from .typification import Typifier
from .options import KernelCreationOptions from .options import KernelCreationOptions
from .iteration_space import ( from .iteration_space import (
IterationSpace,
create_sparse_iteration_space, create_sparse_iteration_space,
create_full_iteration_space, create_full_iteration_space,
) )
...@@ -21,11 +20,10 @@ def create_kernel(assignments: AssignmentCollection, options: KernelCreationOpti ...@@ -21,11 +20,10 @@ def create_kernel(assignments: AssignmentCollection, options: KernelCreationOpti
analysis = KernelAnalysis(ctx) analysis = KernelAnalysis(ctx)
analysis(assignments) analysis(assignments)
ispace: IterationSpace = ( if len(ctx.fields.index_fields) > 0 or ctx.options.index_field is not None:
create_sparse_iteration_space(ctx, assignments) ispace = create_sparse_iteration_space(ctx, assignments)
if len(ctx.fields.index_fields) > 0 or ctx.options.index_field is not None else:
else create_full_iteration_space(ctx, assignments) ispace = create_full_iteration_space(ctx, assignments)
)
ctx.set_iteration_space(ispace) ctx.set_iteration_space(ispace)
...@@ -37,22 +35,22 @@ def create_kernel(assignments: AssignmentCollection, options: KernelCreationOpti ...@@ -37,22 +35,22 @@ def create_kernel(assignments: AssignmentCollection, options: KernelCreationOpti
match options.target: match options.target:
case Target.CPU: case Target.CPU:
from .platform import BasicCpuGen from .platform import BasicCpu
# TODO: CPU platform should incorporate instruction set info, OpenMP, etc. # TODO: CPU platform should incorporate instruction set info, OpenMP, etc.
platform_generator = BasicCpuGen(ctx) platform = BasicCpu(ctx)
case _: case _:
# TODO: CUDA/HIP platform # TODO: CUDA/HIP platform
# TODO: SYCL platform (?) # TODO: SYCL platform (?)
raise NotImplementedError("Target platform not implemented") raise NotImplementedError("Target platform not implemented")
kernel_ast = platform_generator.materialize_iteration_space(kernel_body, ispace) kernel_ast = platform.materialize_iteration_space(kernel_body, ispace)
# 7. Apply optimizations # 7. Apply optimizations
# - Vectorization # - Vectorization
# - OpenMP # - OpenMP
# - Loop Splitting, Tiling, Blocking # - Loop Splitting, Tiling, Blocking
kernel_ast = platform_generator.optimize(kernel_ast) kernel_ast = platform.optimize(kernel_ast)
function = PsKernelFunction(kernel_ast, options.target, name=options.function_name) function = PsKernelFunction(kernel_ast, options.target, name=options.function_name)
function.add_constraints(*ctx.constraints) function.add_constraints(*ctx.constraints)
......
from .basic_cpu import BasicCpuGen from .basic_cpu import BasicCpu
__all__ = [ __all__ = [
'BasicCpuGen' 'BasicCpu'
] ]
from .platform import PlatformGen from .platform import Platform
from ..iteration_space import ( from ..iteration_space import (
IterationSpace, IterationSpace,
...@@ -11,7 +11,7 @@ from ...typed_expressions import PsTypedConstant ...@@ -11,7 +11,7 @@ from ...typed_expressions import PsTypedConstant
from ...arrays import PsArrayAccess from ...arrays import PsArrayAccess
class BasicCpuGen(PlatformGen): class BasicCpu(Platform):
def materialize_iteration_space( def materialize_iteration_space(
self, body: PsBlock, ispace: IterationSpace self, body: PsBlock, ispace: IterationSpace
) -> PsBlock: ) -> PsBlock:
......
...@@ -6,7 +6,7 @@ from ..context import KernelCreationContext ...@@ -6,7 +6,7 @@ from ..context import KernelCreationContext
from ..iteration_space import IterationSpace from ..iteration_space import IterationSpace
class PlatformGen(ABC): class Platform(ABC):
"""Abstract base class for all supported platforms. """Abstract base class for all supported platforms.
The platform performs all target-dependent tasks during code generation: The platform performs all target-dependent tasks during code generation:
......
...@@ -115,7 +115,9 @@ class Typifier(Mapper): ...@@ -115,7 +115,9 @@ class Typifier(Mapper):
) -> tuple[PsArrayAccess, PsNumericType]: ) -> tuple[PsArrayAccess, PsNumericType]:
self._check_target_type(access, access.dtype, target_type) self._check_target_type(access, access.dtype, target_type)
index, _ = self.rec(access.index_tuple[0], self._ctx.options.index_dtype) index, _ = self.rec(access.index_tuple[0], self._ctx.options.index_dtype)
return PsArrayAccess(access.base_ptr, index), cast(PsNumericType, deconstify(access.dtype)) return PsArrayAccess(access.base_ptr, index), cast(
PsNumericType, deconstify(access.dtype)
)
# Arithmetic Expressions # Arithmetic Expressions
...@@ -156,6 +158,16 @@ class Typifier(Mapper): ...@@ -156,6 +158,16 @@ class Typifier(Mapper):
new_args, dtype = self._homogenize(expr, expr.children, target_type) new_args, dtype = self._homogenize(expr, expr.children, target_type)
return pb.Product(new_args), dtype return pb.Product(new_args), dtype
def map_call(
self, expr: pb.Call, target_type: PsNumericType | None
) -> tuple[pb.Call, PsNumericType]:
"""
TODO: Figure out the best way to typify functions
- How to propagate target_type in the face of multiple overloads?
"""
raise NotImplementedError()
def _check_target_type( def _check_target_type(
self, self,
expr: ExprOrConstant, expr: ExprOrConstant,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment