Skip to content
Snippets Groups Projects
Commit 0ff6bb94 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

add foreign_ast and CppMethodCall. Add proper index translation for SYCL.

parent 33df1906
No related branches found
No related tags found
1 merge request!384Fundamental GPU Support
Pipeline #67000 failed
......@@ -13,6 +13,7 @@ from .config import (
CpuOptimConfig,
VectorizationConfig,
OpenMpConfig,
GpuIndexingConfig,
)
from .kernel_decorator import kernel, kernel_config
from .kernelcreation import create_kernel
......@@ -45,6 +46,7 @@ __all__ = [
"CreateKernelConfig",
"CpuOptimConfig",
"VectorizationConfig",
"GpuIndexingConfig",
"OpenMpConfig",
"create_kernel",
"KernelFunction",
......
......@@ -48,6 +48,8 @@ from .ast.expressions import (
PsLe,
)
from .extensions.foreign_ast import PsForeignExpression
from .symbols import PsSymbol
from ..types import PsScalarType, PsArrayType
......@@ -334,6 +336,12 @@ class CAstPrinter:
items_str = ", ".join(self.visit(item, pc) for item in items)
pc.pop_op()
return "{ " + items_str + " }"
case PsForeignExpression(children):
pc.push_op(Ops.Weakest, LR.Middle)
foreign_code = node.get_code(self.visit(c, pc) for c in children)
pc.pop_op()
return foreign_code
case _:
raise NotImplementedError(f"Don't know how to print {node}")
......
from __future__ import annotations
from typing import Iterable, cast
from pystencils.backend.ast.astnode import PsAstNode
from ..ast.expressions import PsExpression
from .foreign_ast import PsForeignExpression
from ...types import PsType
class CppMethodCall(PsForeignExpression):
"""C++ method call on an expression."""
def __init__(
self, obj: PsExpression, method: str, return_type: PsType, args: Iterable = ()
):
self._method = method
self._return_type = return_type
children = [obj] + list(args)
super().__init__(children, return_type)
def structurally_equal(self, other: PsAstNode) -> bool:
if not isinstance(other, CppMethodCall):
return False
return super().structurally_equal(other) and self._method == other._method
def clone(self) -> CppMethodCall:
return CppMethodCall(
cast(PsExpression, self.children[0]),
self._method,
self._return_type,
self.children[1:],
)
def get_code(self, children_code: Iterable[str]) -> str:
cs = list(children_code)
obj_code = cs[0]
args_code = cs[1:]
args = ", ".join(args_code)
return f"({obj_code}).{self._method}({args})"
from __future__ import annotations
from typing import Iterable
from abc import ABC, abstractmethod
from pystencils.backend.ast.astnode import PsAstNode
from ..ast.expressions import PsExpression
from ..ast.util import failing_cast
from ...types import PsType
class PsForeignExpression(PsExpression, ABC):
"""Base class for foreign expressions.
Foreign expressions are expressions whose properties are not modelled by the pystencils AST,
and which pystencils therefore does not understand.
Support for foreign expressions by the code generator is therefore very limited;
as a rule of thumb, only printing is supported.
Type checking and most transformations will fail when encountering a `PsForeignExpression`.
There are many situations where non-supported expressions are needed;
the most common use case is C++ syntax.
"""
__match_args__ = ("children",)
def __init__(
self, children: Iterable[PsExpression], dtype: PsType | None = None
):
self._children = list(children)
super().__init__(dtype)
@abstractmethod
def get_code(self, children_code: Iterable[str]) -> str:
"""Print this expression, with the given code for each of its children."""
pass
def get_children(self) -> tuple[PsAstNode, ...]:
return tuple(self._children)
def set_child(self, idx: int, c: PsAstNode):
self._children[idx] = failing_cast(PsExpression, c)
def __repr__(self) -> str:
return f"{type(self)}({self._children})"
......@@ -8,8 +8,10 @@ from ..kernelcreation.iteration_space import (
from ..ast.structural import PsDeclaration
from ..ast.expressions import (
PsExpression,
PsSymbolExpr,
PsSubscript,
)
from ..extensions.cpp import CppMethodCall
from ..kernelcreation.context import KernelCreationContext
from ..constants import PsConstant
......@@ -25,9 +27,6 @@ class SyclPlatform(Platform):
super().__init__(ctx)
self._cfg = indexing_cfg if indexing_cfg is not None else GpuIndexingConfig()
if not self._cfg.sycl_automatic_block_size:
raise ValueError("The SYCL code generator supports only automatic block sizes at the moment.")
@property
def required_headers(self) -> set[str]:
return {"<sycl/sycl.hpp>"}
......@@ -68,8 +67,9 @@ class SyclPlatform(Platform):
self, body: PsBlock, ispace: FullIterationSpace
) -> PsBlock:
rank = ispace.rank
id_type = PsCustomType(f"sycl::id< {rank} >", const=True)
id_type = self._id_type(rank)
id_symbol = PsExpression.make(self._ctx.get_symbol("id", id_type))
id_decl = self._id_declaration(rank, id_symbol)
# Determine loop order by permuting dimensions
archetype_field = ispace.archetype_field
......@@ -80,7 +80,7 @@ class SyclPlatform(Platform):
else:
dimensions = ispace.dimensions
unpackings = []
unpackings = [id_decl]
for i, dim in enumerate(dimensions[::-1]):
coord = PsExpression.make(PsConstant(i, self._ctx.index_dtype))
work_item_idx = PsSubscript(id_symbol, coord)
......@@ -112,12 +112,22 @@ class SyclPlatform(Platform):
return body
def _id_type(self, rank: int):
def _item_type(self, rank: int):
if not self._cfg.sycl_automatic_block_size:
return PsCustomType(f"sycl::nd_item< {rank} >", const=True)
else:
return PsCustomType(f"sycl::item< {rank} >", const=True)
def _id_type(self, rank: int):
return PsCustomType(f"sycl::id< {rank} >", const=True)
def _id_declaration(self, rank: int, id: PsSymbolExpr) -> PsDeclaration:
item_type = self._item_type(rank)
item = PsExpression.make(self._ctx.get_symbol("sycl_item", item_type))
if not self._cfg.sycl_automatic_block_size:
rhs = CppMethodCall(item, "get_global_id", self._id_type(rank))
else:
rhs = CppMethodCall(item, "get_id", self._id_type(rank))
def _get_id(self):
# TODO: Need AST support for member functions to model `get_id()` / `get_global_id()`
raise NotImplementedError()
return PsDeclaration(id, rhs)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment