From 0ff6bb94ca1539156021bc89375f94af060b0d9e Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Thu, 20 Jun 2024 13:22:44 +0200 Subject: [PATCH] add foreign_ast and CppMethodCall. Add proper index translation for SYCL. --- src/pystencils/__init__.py | 2 + src/pystencils/backend/emission.py | 8 ++++ src/pystencils/backend/extensions/__init__.py | 0 src/pystencils/backend/extensions/cpp.py | 41 +++++++++++++++++ .../backend/extensions/foreign_ast.py | 45 +++++++++++++++++++ src/pystencils/backend/platforms/sycl.py | 28 ++++++++---- 6 files changed, 115 insertions(+), 9 deletions(-) create mode 100644 src/pystencils/backend/extensions/__init__.py create mode 100644 src/pystencils/backend/extensions/cpp.py create mode 100644 src/pystencils/backend/extensions/foreign_ast.py diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py index 3d3b7846a..051e02e7f 100644 --- a/src/pystencils/__init__.py +++ b/src/pystencils/__init__.py @@ -13,6 +13,7 @@ from .config import ( CpuOptimConfig, VectorizationConfig, OpenMpConfig, + GpuIndexingConfig, ) from .kernel_decorator import kernel, kernel_config from .kernelcreation import create_kernel @@ -45,6 +46,7 @@ __all__ = [ "CreateKernelConfig", "CpuOptimConfig", "VectorizationConfig", + "GpuIndexingConfig", "OpenMpConfig", "create_kernel", "KernelFunction", diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py index e8fc2a662..fdc81a47c 100644 --- a/src/pystencils/backend/emission.py +++ b/src/pystencils/backend/emission.py @@ -48,6 +48,8 @@ from .ast.expressions import ( PsLe, ) +from .extensions.foreign_ast import PsForeignExpression + from .symbols import PsSymbol from ..types import PsScalarType, PsArrayType @@ -334,6 +336,12 @@ class CAstPrinter: items_str = ", ".join(self.visit(item, pc) for item in items) pc.pop_op() return "{ " + items_str + " }" + + case PsForeignExpression(children): + pc.push_op(Ops.Weakest, LR.Middle) + foreign_code = 
class CppMethodCall(PsForeignExpression):
    """C++ method call on an expression.

    The node's first child is the object expression the method is invoked on;
    all remaining children are the call arguments, in order.
    It is printed as ``(<obj>).<method>(<arg0>, <arg1>, ...)``.

    Args:
        obj: Expression on which the method is called.
        method: Name of the method to call.
        return_type: Type forwarded to the base class as this expression's data type.
        args: Argument expressions of the call; defaults to no arguments.
    """

    def __init__(
        self, obj: PsExpression, method: str, return_type: PsType, args: Iterable = ()
    ):
        self._method = method
        self._return_type = return_type
        children = [obj, *args]
        super().__init__(children, return_type)

    def structurally_equal(self, other: PsAstNode) -> bool:
        """Return True if `other` is a `CppMethodCall` with the same method name
        that also passes the base class's structural comparison."""
        if not isinstance(other, CppMethodCall):
            return False

        return super().structurally_equal(other) and self._method == other._method

    def clone(self) -> CppMethodCall:
        """Return a copy of this call whose children are themselves clones.

        Children are mutable nodes (they can be altered through ``set_child``),
        so handing the very same child objects to the copy would let a later
        change inside the clone's subtree leak into the original, and vice versa.
        """
        obj, *args = (child.clone() for child in self.children)
        return CppMethodCall(
            cast(PsExpression, obj),
            self._method,
            self._return_type,
            args,
        )

    def get_code(self, children_code: Iterable[str]) -> str:
        """Assemble the C++ code for this call.

        Args:
            children_code: Printed code of each child, object first, then arguments.

        Returns:
            The string ``(<obj>).<method>(<args>)`` with arguments joined by ``", "``.
        """
        # The object expression always occupies the first slot (see __init__).
        obj_code, *args_code = children_code
        args = ", ".join(args_code)
        return f"({obj_code}).{self._method}({args})"
class PsForeignExpression(PsExpression, ABC):
    """Base class for foreign expressions.

    Foreign expressions are expressions whose properties are not modelled by the pystencils AST,
    and which pystencils therefore does not understand.
    Support for foreign expressions by the code generator is therefore very limited;
    as a rule of thumb, only printing is supported.
    Type checking and most transformations will fail when encountering a `PsForeignExpression`.

    There are many situations where non-supported expressions are needed;
    the most common use case is C++ syntax.

    Subclasses must implement `get_code`, which turns the already-printed
    children into the code of the whole expression.

    Args:
        children: Child expressions of this node; copied into an internal list.
        dtype: Data type forwarded to `PsExpression`; may be `None`.
    """

    # Allows `case PsForeignExpression(children):` in structural pattern matching.
    __match_args__ = ("children",)

    def __init__(
        self, children: Iterable[PsExpression], dtype: PsType | None = None
    ):
        self._children = list(children)
        super().__init__(dtype)

    @abstractmethod
    def get_code(self, children_code: Iterable[str]) -> str:
        """Print this expression, with the given code for each of its children.

        Args:
            children_code: Printed code of the children, in child order.
                This may be a single-pass iterable (e.g. a generator), so
                implementations should materialise it before iterating twice.

        Returns:
            The code representing this expression.
        """

    def get_children(self) -> tuple[PsAstNode, ...]:
        """Return the children as a tuple (a snapshot of the internal list)."""
        return tuple(self._children)

    def set_child(self, idx: int, c: PsAstNode):
        """Replace the child at position `idx` with `c`.

        `c` is passed through `failing_cast(PsExpression, c)` before being stored
        (presumably rejecting nodes that are not expressions - confirm against
        `failing_cast`).
        """
        self._children[idx] = failing_cast(PsExpression, c)

    def __repr__(self) -> str:
        # Use the bare class name; formatting `type(self)` itself would render
        # as "<class '...'>" and clutter debug output.
        return f"{type(self).__name__}({self._children})"
super().__init__(ctx) self._cfg = indexing_cfg if indexing_cfg is not None else GpuIndexingConfig() - if not self._cfg.sycl_automatic_block_size: - raise ValueError("The SYCL code generator supports only automatic block sizes at the moment.") - @property def required_headers(self) -> set[str]: return {"<sycl/sycl.hpp>"} @@ -68,8 +67,9 @@ class SyclPlatform(Platform): self, body: PsBlock, ispace: FullIterationSpace ) -> PsBlock: rank = ispace.rank - id_type = PsCustomType(f"sycl::id< {rank} >", const=True) + id_type = self._id_type(rank) id_symbol = PsExpression.make(self._ctx.get_symbol("id", id_type)) + id_decl = self._id_declaration(rank, id_symbol) # Determine loop order by permuting dimensions archetype_field = ispace.archetype_field @@ -80,7 +80,7 @@ class SyclPlatform(Platform): else: dimensions = ispace.dimensions - unpackings = [] + unpackings = [id_decl] for i, dim in enumerate(dimensions[::-1]): coord = PsExpression.make(PsConstant(i, self._ctx.index_dtype)) work_item_idx = PsSubscript(id_symbol, coord) @@ -112,12 +112,22 @@ class SyclPlatform(Platform): return body - def _id_type(self, rank: int): + def _item_type(self, rank: int): if not self._cfg.sycl_automatic_block_size: return PsCustomType(f"sycl::nd_item< {rank} >", const=True) else: return PsCustomType(f"sycl::item< {rank} >", const=True) + + def _id_type(self, rank: int): + return PsCustomType(f"sycl::id< {rank} >", const=True) + + def _id_declaration(self, rank: int, id: PsSymbolExpr) -> PsDeclaration: + item_type = self._item_type(rank) + item = PsExpression.make(self._ctx.get_symbol("sycl_item", item_type)) + + if not self._cfg.sycl_automatic_block_size: + rhs = CppMethodCall(item, "get_global_id", self._id_type(rank)) + else: + rhs = CppMethodCall(item, "get_id", self._id_type(rank)) - def _get_id(self): - # TODO: Need AST support for member functions to model `get_id()` / `get_global_id()` - raise NotImplementedError() + return PsDeclaration(id, rhs) -- GitLab