Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision
  • fhennig/devel
  • master
  • rangersbach/c-interfacing
  • v0.1a1
  • v0.1a2
  • v0.1a3
  • v0.1a4
7 results

Target

Select target project
  • ob28imeq/pystencils-sfg
  • brendan-waters/pystencils-sfg
  • pycodegen/pystencils-sfg
3 results
Select Git revision
  • frontend-cleanup
  • lbwelding-features
  • master
  • refactor-indexing-params
  • unit_tests
  • v0.1a1
  • v0.1a2
  • v0.1a3
  • v0.1a4
9 results
Show changes
Showing
with 1976 additions and 283 deletions
from __future__ import annotations
from typing import Sequence, Iterable
import warnings
from dataclasses import dataclass
from abc import ABC, abstractmethod
import sympy as sp
from pystencils import Field
from pystencils.types import deconstify, PsType
from pystencils.codegen.properties import FieldBasePtr, FieldShape, FieldStride
from ..exceptions import SfgException
from ..config import CodeStyle
from .call_tree import SfgCallTreeNode, SfgSequence, SfgStatements
from ..lang.expressions import SfgKernelParamVar
from ..lang import (
SfgVar,
SupportsFieldExtraction,
SupportsVectorExtraction,
ExprLike,
AugExpr,
depends,
includes,
)
class PostProcessingContext:
def __init__(self) -> None:
self._live_variables: dict[str, SfgVar] = dict()
@property
def live_variables(self) -> set[SfgVar]:
return set(self._live_variables.values())
def get_live_variable(self, name: str) -> SfgVar | None:
return self._live_variables.get(name)
def _define(self, vars: Iterable[SfgVar], expr: str):
for var in vars:
if var.name in self._live_variables:
live_var = self._live_variables[var.name]
live_var_dtype = live_var.dtype
def_dtype = var.dtype
# A const definition conflicts with a non-const live variable
# A non-const definition is always OK, but then the types must be the same
if (def_dtype.const and not live_var_dtype.const) or (
deconstify(def_dtype) != deconstify(live_var_dtype)
):
warnings.warn(
f"Type conflict at variable definition: Expected type {live_var_dtype}, but got {def_dtype}.\n"
f" * At definition {expr}",
UserWarning,
)
del self._live_variables[var.name]
def _use(self, vars: Iterable[SfgVar]):
for var in vars:
if var.name in self._live_variables:
live_var = self._live_variables[var.name]
if var != live_var:
if var.dtype == live_var.dtype:
# This can only happen if the variables are SymbolLike,
# i.e. wrap a field-associated kernel parameter
# TODO: Once symbol properties are a thing, check and combine them here
warnings.warn(
"Encountered two non-identical variables with same name and data type:\n"
f" {var.name_and_type()}\n"
"and\n"
f" {live_var.name_and_type()}\n"
)
elif deconstify(var.dtype) == deconstify(live_var.dtype):
# Same type, just different constness
# One of them must be non-const -> keep the non-const one
if live_var.dtype.const and not var.dtype.const:
self._live_variables[var.name] = var
else:
raise SfgException(
"Encountered two variables with same name but different data types:\n"
f" {var.name_and_type()}\n"
"and\n"
f" {live_var.name_and_type()}"
)
else:
self._live_variables[var.name] = var
@dataclass(frozen=True)
class PostProcessingResult:
function_params: set[SfgVar]
class CallTreePostProcessing:
def __call__(self, ast: SfgCallTreeNode) -> PostProcessingResult:
live_vars = self.get_live_variables(ast)
return PostProcessingResult(live_vars)
def handle_sequence(self, seq: SfgSequence, ppc: PostProcessingContext):
def iter_nested_sequences(seq: SfgSequence):
for i in range(len(seq.children) - 1, -1, -1):
c = seq.children[i]
if isinstance(c, SfgDeferredNode):
c = c.expand(ppc)
seq[i] = c
if isinstance(c, SfgSequence):
iter_nested_sequences(c)
else:
if isinstance(c, SfgStatements):
ppc._define(c.defines, c.code_string)
ppc._use(self.get_live_variables(c))
iter_nested_sequences(seq)
def get_live_variables(self, node: SfgCallTreeNode) -> set[SfgVar]:
match node:
case SfgSequence():
ppc = PostProcessingContext()
self.handle_sequence(node, ppc)
return ppc.live_variables
case SfgDeferredNode():
raise SfgException("Deferred nodes can only occur inside a sequence.")
case _:
return node.depends.union(
*(self.get_live_variables(c) for c in node.children)
)
class SfgDeferredNode(SfgCallTreeNode, ABC):
"""Nodes of this type are inserted as placeholders into the kernel call tree
and need to be expanded at a later time.
Subclasses of SfgDeferredNode correspond to nodes that cannot be created yet
because information required for their construction is not yet known.
"""
@property
def children(self) -> Sequence[SfgCallTreeNode]:
raise SfgException(
"Invalid access into deferred node; deferred nodes must be expanded first."
)
@abstractmethod
def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode:
pass
def get_code(self, cstyle: CodeStyle) -> str:
raise SfgException(
"Invalid access into deferred node; deferred nodes must be expanded first."
)
class SfgDeferredParamSetter(SfgDeferredNode):
def __init__(self, param: SfgVar | sp.Symbol, rhs: ExprLike):
self._lhs = param
self._rhs = rhs
def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode:
live_var = ppc.get_live_variable(self._lhs.name)
if live_var is not None:
code = f"{live_var.dtype.c_string()} {live_var.name} = {self._rhs};"
return SfgStatements(
code, (live_var,), depends(self._rhs), includes(self._rhs)
)
else:
return SfgSequence([])
class SfgDeferredFieldMapping(SfgDeferredNode):
"""Deferred mapping of a pystencils field to a field data structure."""
# NOTE ON Scalar Fields
#
# pystencils permits explicit (`index_shape = (1,)`) and implicit (`index_shape = ()`)
# scalar fields. In order to handle both equivalently,
# we ignore the trivial explicit scalar dimension in field extraction.
# This makes sure that explicit D-dimensional scalar fields
# can be mapped onto D-dimensional data structures, and do not require that
# D+1st dimension.
def __init__(
self,
psfield: Field,
extraction: SupportsFieldExtraction,
cast_indexing_symbols: bool = True,
):
self._field = psfield
self._extraction = extraction
self._cast_indexing_symbols = cast_indexing_symbols
def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode:
# Find field pointer
ptr: SfgKernelParamVar | None = None
rank: int
if self._field.index_shape == (1,):
# explicit scalar field -> ignore index dimensions
rank = self._field.spatial_dimensions
else:
rank = len(self._field.shape)
shape: list[SfgKernelParamVar | str | None] = [None] * rank
strides: list[SfgKernelParamVar | str | None] = [None] * rank
for param in ppc.live_variables:
if isinstance(param, SfgKernelParamVar):
for prop in param.wrapped.properties:
match prop:
case FieldBasePtr(field) if field == self._field:
ptr = param
case FieldShape(field, coord) if field == self._field: # type: ignore
shape[coord] = param # type: ignore
case FieldStride(field, coord) if field == self._field: # type: ignore
strides[coord] = param # type: ignore
# Find constant or otherwise determined sizes
for coord, s in enumerate(self._field.shape[:rank]):
if shape[coord] is None:
shape[coord] = str(s)
# Find constant or otherwise determined strides
for coord, s in enumerate(self._field.strides[:rank]):
if strides[coord] is None:
strides[coord] = str(s)
# Now we have all the symbols, start extracting them
nodes = []
done: set[SfgKernelParamVar] = set()
if ptr is not None:
expr = self._extraction._extract_ptr()
nodes.append(
SfgStatements(
f"{ptr.dtype.c_string()} {ptr.name} {{ {expr} }};",
(ptr,),
depends(expr),
includes(expr),
)
)
def maybe_cast(expr: AugExpr, target_type: PsType) -> AugExpr:
if self._cast_indexing_symbols:
return AugExpr(target_type).bind(
"{}( {} )", deconstify(target_type).c_string(), expr
)
else:
return expr
def get_shape(coord, symb: SfgKernelParamVar | str):
expr = self._extraction._extract_size(coord)
if expr is None:
raise SfgException(
f"Cannot extract shape in coordinate {coord} from {self._extraction}"
)
if isinstance(symb, SfgKernelParamVar) and symb not in done:
done.add(symb)
expr = maybe_cast(expr, symb.dtype)
return SfgStatements(
f"{symb.dtype.c_string()} {symb.name} {{ {expr} }};",
(symb,),
depends(expr),
includes(expr),
)
else:
return SfgStatements(f"/* {expr} == {symb} */", (), ())
def get_stride(coord, symb: SfgKernelParamVar | str):
expr = self._extraction._extract_stride(coord)
if expr is None:
raise SfgException(
f"Cannot extract stride in coordinate {coord} from {self._extraction}"
)
if isinstance(symb, SfgKernelParamVar) and symb not in done:
done.add(symb)
expr = maybe_cast(expr, symb.dtype)
return SfgStatements(
f"{symb.dtype.c_string()} {symb.name} {{ {expr} }};",
(symb,),
depends(expr),
includes(expr),
)
else:
return SfgStatements(f"/* {expr} == {symb} */", (), ())
nodes += [get_shape(c, s) for c, s in enumerate(shape) if s is not None]
nodes += [get_stride(c, s) for c, s in enumerate(strides) if s is not None]
return SfgSequence(nodes)
class SfgDeferredVectorMapping(SfgDeferredNode):
def __init__(
self,
scalars: Sequence[sp.Symbol | SfgVar],
vector: SupportsVectorExtraction,
):
self._scalars = {sc.name: (i, sc) for i, sc in enumerate(scalars)}
self._vector = vector
def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode:
nodes = []
for param in ppc.live_variables:
if param.name in self._scalars:
idx, _ = self._scalars[param.name]
expr = self._vector._extract_component(idx)
nodes.append(
SfgStatements(
f"{param.dtype.c_string()} {param.name} {{ {expr} }};",
(param,),
depends(expr),
includes(expr),
)
)
return SfgSequence(nodes)
from __future__ import annotations
from enum import Enum, auto
from typing import (
Iterable,
TypeVar,
Generic,
)
from ..lang import HeaderFile
from .entities import (
SfgNamespace,
SfgKernelHandle,
SfgFunction,
SfgClassMember,
SfgVisibility,
SfgClass,
)
# =========================================================================================================
#
# SYNTACTICAL ELEMENTS
#
# These classes model *code elements*, which represent the actual syntax objects that populate the output
# files, their namespaces and class bodies.
#
# =========================================================================================================
SourceEntity_T = TypeVar(
"SourceEntity_T",
bound=SfgKernelHandle | SfgFunction | SfgClassMember | SfgClass,
covariant=True,
)
"""Source entities that may have declarations and definitions."""
class SfgEntityDecl(Generic[SourceEntity_T]):
"""Declaration of a function, class, method, or constructor"""
__match_args__ = ("entity",)
def __init__(self, entity: SourceEntity_T) -> None:
self._entity = entity
@property
def entity(self) -> SourceEntity_T:
return self._entity
class SfgEntityDef(Generic[SourceEntity_T]):
"""Definition of a function, class, method, or constructor"""
__match_args__ = ("entity",)
def __init__(self, entity: SourceEntity_T) -> None:
self._entity = entity
@property
def entity(self) -> SourceEntity_T:
return self._entity
SfgClassBodyElement = str | SfgEntityDecl[SfgClassMember] | SfgEntityDef[SfgClassMember]
"""Elements that may be placed in the visibility blocks of a class body."""
class SfgVisibilityBlock:
"""Visibility-qualified block inside a class definition body.
Visibility blocks host the code elements placed inside a class body:
method and constructor declarations,
in-class method and constructor definitions,
as well as variable declarations and definitions.
Args:
visibility: The visibility qualifier of this block
"""
__match_args__ = ("visibility", "elements")
def __init__(self, visibility: SfgVisibility) -> None:
self._vis = visibility
self._elements: list[SfgClassBodyElement] = []
self._cls: SfgClass | None = None
@property
def visibility(self) -> SfgVisibility:
return self._vis
@property
def elements(self) -> list[SfgClassBodyElement]:
return self._elements
@elements.setter
def elements(self, elems: Iterable[SfgClassBodyElement]):
self._elements = list(elems)
class SfgNamespaceBlock:
"""A C++ namespace block.
Args:
namespace: Namespace associated with this block
label: Label printed at the opening brace of this block.
This may be the namespace name, or a compressed qualified
name containing one or more of its parent namespaces.
"""
__match_args__ = (
"namespace",
"elements",
"label",
)
def __init__(self, namespace: SfgNamespace, label: str | None = None) -> None:
self._namespace = namespace
self._label = label if label is not None else namespace.name
self._elements: list[SfgNamespaceElement] = []
@property
def namespace(self) -> SfgNamespace:
return self._namespace
@property
def label(self) -> str:
return self._label
@property
def elements(self) -> list[SfgNamespaceElement]:
"""Sequence of source elements that make up the body of this namespace"""
return self._elements
@elements.setter
def elements(self, elems: Iterable[SfgNamespaceElement]):
self._elements = list(elems)
class SfgClassBody:
"""Body of a class definition."""
__match_args__ = ("associated_class", "visibility_blocks")
def __init__(
self,
cls: SfgClass,
default_block: SfgVisibilityBlock,
vis_blocks: Iterable[SfgVisibilityBlock],
) -> None:
self._cls = cls
assert default_block.visibility == SfgVisibility.DEFAULT
self._default_block = default_block
self._blocks = [self._default_block] + list(vis_blocks)
@property
def associated_class(self) -> SfgClass:
return self._cls
@property
def default(self) -> SfgVisibilityBlock:
return self._default_block
def append_visibility_block(self, block: SfgVisibilityBlock):
if block.visibility == SfgVisibility.DEFAULT:
raise ValueError(
"Can't add another block with DEFAULT visibility to this class body."
)
self._blocks.append(block)
@property
def visibility_blocks(self) -> tuple[SfgVisibilityBlock, ...]:
return tuple(self._blocks)
SfgNamespaceElement = (
str | SfgNamespaceBlock | SfgClassBody | SfgEntityDecl | SfgEntityDef
)
"""Elements that may be placed inside a namespace, including the global namespace."""
class SfgSourceFileType(Enum):
HEADER = auto()
TRANSLATION_UNIT = auto()
class SfgSourceFile:
"""A C++ source file.
Args:
name: Name of the file (without parent directories), e.g. ``Algorithms.cpp``
file_type: Type of the source file (header or translation unit)
prelude: Optionally, text of the prelude comment printed at the top of the file
"""
def __init__(
self, name: str, file_type: SfgSourceFileType, prelude: str | None = None
) -> None:
self._name: str = name
self._file_type: SfgSourceFileType = file_type
self._prelude: str | None = prelude
self._includes: list[HeaderFile] = []
self._elements: list[SfgNamespaceElement] = []
@property
def name(self) -> str:
"""Name of this source file"""
return self._name
@property
def file_type(self) -> SfgSourceFileType:
"""File type of this source file"""
return self._file_type
@property
def prelude(self) -> str | None:
"""Text of the prelude comment"""
return self._prelude
@prelude.setter
def prelude(self, text: str | None):
self._prelude = text
@property
def includes(self) -> list[HeaderFile]:
"""Sequence of header files to be included at the top of this file"""
return self._includes
@includes.setter
def includes(self, incl: Iterable[HeaderFile]):
self._includes = list(incl)
@property
def elements(self) -> list[SfgNamespaceElement]:
"""Sequence of source elements comprising the body of this file"""
return self._elements
@elements.setter
def elements(self, elems: Iterable[SfgNamespaceElement]):
self._elements = list(elems)
from .headers import HeaderFile
from .expressions import (
SfgVar,
SfgKernelParamVar,
AugExpr,
VarLike,
_VarLike,
ExprLike,
_ExprLike,
asvar,
depends,
includes,
CppClass,
cppclass,
)
from .extractions import SupportsFieldExtraction, SupportsVectorExtraction
from .types import cpptype, void, Ref, strip_ptr_ref
__all__ = [
"HeaderFile",
"SfgVar",
"SfgKernelParamVar",
"AugExpr",
"VarLike",
"_VarLike",
"ExprLike",
"_ExprLike",
"asvar",
"depends",
"includes",
"cpptype",
"CppClass",
"cppclass",
"void",
"Ref",
"strip_ptr_ref",
"SupportsFieldExtraction",
"SupportsVectorExtraction",
]
from .std_mdspan import StdMdspan, mdspan_ref
from .std_vector import StdVector, std_vector_ref
from .std_tuple import StdTuple, std_tuple_ref
from .std_tuple import StdTuple
from .std_span import StdSpan, std_span_ref
__all__ = [
"StdMdspan",
......@@ -8,5 +9,6 @@ __all__ = [
"StdVector",
"std_vector_ref",
"StdTuple",
"std_tuple_ref",
"StdSpan",
"std_span_ref",
]
from .std_span import StdSpan
from .std_mdspan import StdMdspan
from .std_vector import StdVector
from .std_tuple import StdTuple
span = StdSpan
mdspan = StdMdspan
vector = StdVector
tuple = StdTuple
from typing import cast
from sympy import Symbol
from pystencils import Field, DynamicType
from pystencils.types import (
PsType,
PsUnsignedIntegerType,
UserTypeSpec,
create_type,
)
from pystencilssfg.lang.expressions import AugExpr
from ...lang import SupportsFieldExtraction, cpptype, HeaderFile, ExprLike
class StdMdspan(AugExpr, SupportsFieldExtraction):
"""Represents an `std::mdspan` instance.
The `std::mdspan <https://en.cppreference.com/w/cpp/container/mdspan>`_
provides non-owning views into contiguous or strided n-dimensional arrays.
It has been added to the C++ STL with the C++23 standard.
As such, it is a natural data structure to target with pystencils kernels.
**Concerning Headers and Namespaces**
Since ``std::mdspan`` is not yet widely adopted
(libc++ ships it as of LLVM 18, but GCC libstdc++ does not include it yet),
you might have to manually include an implementation in your project
(you can get a reference implementation at https://github.com/kokkos/mdspan).
However, when working with a non-standard mdspan implementation,
the path to its the header and the namespace it is defined in will likely be different.
To tell pystencils-sfg which headers to include and which namespace to use for ``mdspan``,
use `StdMdspan.configure`;
for instance, adding this call before creating any ``mdspan`` objects will
set their namespace to `std::experimental`, and require ``<experimental/mdspan>`` to be imported:
>>> from pystencilssfg.lang.cpp import std
>>> std.mdspan.configure("std::experimental", "<experimental/mdspan>")
**Creation from pystencils fields**
Using `from_field`, ``mdspan`` objects can be created directly from `Field <pystencils.Field>` instances.
The `extents`_ of the ``mdspan`` type will be inferred from the field;
each fixed entry in the field's shape will become a fixed entry of the ``mdspan``'s extents.
The ``mdspan``'s `layout_policy`_ defaults to `std::layout_stride`_,
which might not be the optimal choice depending on the memory layout of your fields.
You may therefore override this by specifying the name of the desired layout policy.
To map pystencils field layout identifiers to layout policies, consult the following table:
+------------------------+--------------------------+
| pystencils Layout Name | ``mdspan`` Layout Policy |
+========================+==========================+
| ``"fzyx"`` | `std::layout_left`_ |
| ``"soa"`` | |
| ``"f"`` | |
| ``"reverse_numpy"`` | |
+------------------------+--------------------------+
| ``"c"`` | `std::layout_right`_ |
| ``"numpy"`` | |
+------------------------+--------------------------+
| ``"zyxf"`` | `std::layout_stride`_ |
| ``"aos"`` | |
+------------------------+--------------------------+
The array-of-structures (``"aos"``, ``"zyxf"``) layout has no equivalent layout policy in the C++ standard,
so it can only be mapped onto ``layout_stride``.
.. _extents: https://en.cppreference.com/w/cpp/container/mdspan/extents
.. _layout_policy: https://en.cppreference.com/w/cpp/named_req/LayoutMappingPolicy
.. _std::layout_left: https://en.cppreference.com/w/cpp/container/mdspan/layout_left
.. _std::layout_right: https://en.cppreference.com/w/cpp/container/mdspan/layout_right
.. _std::layout_stride: https://en.cppreference.com/w/cpp/container/mdspan/layout_stride
Args:
T: Element type of the mdspan
"""
dynamic_extent = "std::dynamic_extent"
_namespace = "std"
_template = cpptype("std::mdspan< {T}, {extents}, {layout_policy} >", "<mdspan>")
@classmethod
def configure(cls, namespace: str = "std", header: str | HeaderFile = "<mdspan>"):
"""Configure the namespace and header ``std::mdspan`` is defined in."""
cls._namespace = namespace
cls._template = cpptype(
f"{namespace}::mdspan< {{T}}, {{extents}}, {{layout_policy}} >", header
)
def __init__(
self,
T: UserTypeSpec,
extents: tuple[int | str, ...],
index_type: UserTypeSpec = PsUnsignedIntegerType(64),
layout_policy: str | None = None,
ref: bool = False,
const: bool = False,
):
T = create_type(T)
extents_type_str = create_type(index_type).c_string()
extents_str = f"{self._namespace}::extents< {extents_type_str}, {', '.join(str(e) for e in extents)} >"
if layout_policy is None:
layout_policy = f"{self._namespace}::layout_stride"
elif layout_policy in ("layout_left", "layout_right", "layout_stride"):
layout_policy = f"{self._namespace}::{layout_policy}"
dtype = self._template(
T=T, extents=extents_str, layout_policy=layout_policy, const=const, ref=ref
)
super().__init__(dtype)
self._element_type = T
self._extents_type = extents_str
self._layout_type = layout_policy
self._dim = len(extents)
@property
def element_type(self) -> PsType:
return self._element_type
@property
def extents_type(self) -> str:
return self._extents_type
@property
def layout_type(self) -> str:
return self._layout_type
def extent(self, r: int | ExprLike) -> AugExpr:
return AugExpr.format("{}.extent({})", self, r)
def stride(self, r: int | ExprLike) -> AugExpr:
return AugExpr.format("{}.stride({})", self, r)
def data_handle(self) -> AugExpr:
return AugExpr.format("{}.data_handle()", self)
# SupportsFieldExtraction protocol
def _extract_ptr(self) -> AugExpr:
return self.data_handle()
def _extract_size(self, coordinate: int) -> AugExpr | None:
if coordinate > self._dim:
return None
else:
return self.extent(coordinate)
def _extract_stride(self, coordinate: int) -> AugExpr | None:
if coordinate > self._dim:
return None
else:
return self.stride(coordinate)
@staticmethod
def from_field(
field: Field,
extents_type: UserTypeSpec = PsUnsignedIntegerType(64),
layout_policy: str | None = None,
ref: bool = False,
const: bool = False,
):
"""Creates a `std::mdspan` instance for a given pystencils field."""
if isinstance(field.dtype, DynamicType):
raise ValueError("Cannot map dynamically typed field to std::mdspan")
extents: list[str | int] = []
for s in field.spatial_shape:
extents.append(
StdMdspan.dynamic_extent if isinstance(s, Symbol) else cast(int, s)
)
for s in field.index_shape:
extents.append(StdMdspan.dynamic_extent if isinstance(s, Symbol) else s)
return StdMdspan(
field.dtype,
tuple(extents),
index_type=extents_type,
layout_policy=layout_policy,
ref=ref,
const=const,
).var(field.name)
def mdspan_ref(field: Field, extents_type: PsType = PsUnsignedIntegerType(64)):
from warnings import warn
warn(
"`mdspan_ref` is deprecated and will be removed in version 0.1. Use `std.mdspan.from_field` instead.",
FutureWarning,
)
return StdMdspan.from_field(field, extents_type, ref=True)
from pystencils import Field, DynamicType
from pystencils.types import UserTypeSpec, create_type, PsType
from ...lang import SupportsFieldExtraction, AugExpr, cpptype
class StdSpan(AugExpr, SupportsFieldExtraction):
_template = cpptype("std::span< {T} >", "<span>")
def __init__(self, T: UserTypeSpec, ref=False, const=False):
T = create_type(T)
dtype = self._template(T=T, const=const, ref=ref)
self._element_type = T
super().__init__(dtype)
@property
def element_type(self) -> PsType:
return self._element_type
def _extract_ptr(self) -> AugExpr:
return AugExpr.format("{}.data()", self)
def _extract_size(self, coordinate: int) -> AugExpr | None:
if coordinate > 0:
return None
else:
return AugExpr.format("{}.size()", self)
def _extract_stride(self, coordinate: int) -> AugExpr | None:
if coordinate > 0:
return None
else:
return AugExpr.format("1")
@staticmethod
def from_field(field: Field, ref: bool = False, const: bool = False):
if field.spatial_dimensions > 1 or field.index_shape not in ((), (1,)):
raise ValueError(
"Only one-dimensional fields with trivial index dimensions can be mapped onto `std::span`"
)
if isinstance(field.dtype, DynamicType):
raise ValueError("Cannot map dynamically typed field to std::span")
return StdSpan(field.dtype, ref=ref, const=const).var(field.name)
def std_span_ref(field: Field):
from warnings import warn
warn(
"`std_span_ref` is deprecated and will be removed in version 0.1. Use `std.span.from_field` instead.",
FutureWarning,
)
return StdSpan.from_field(field, ref=True)
from pystencils.types import UserTypeSpec, create_type
from ...lang import SupportsVectorExtraction, AugExpr, cpptype
class StdTuple(AugExpr, SupportsVectorExtraction):
_template = cpptype("std::tuple< {ts} >", "<tuple>")
def __init__(
self,
*element_types: UserTypeSpec,
const: bool = False,
ref: bool = False,
):
self._element_types = tuple(create_type(t) for t in element_types)
self._length = len(element_types)
elt_type_strings = tuple(t.c_string() for t in self._element_types)
dtype = self._template(ts=", ".join(elt_type_strings), const=const, ref=ref)
super().__init__(dtype)
def get(self, idx: int | str) -> AugExpr:
return AugExpr.format("std::get< {} >({})", idx, self)
def _extract_component(self, coordinate: int) -> AugExpr:
if coordinate < 0 or coordinate >= self._length:
raise ValueError(
f"Index {coordinate} out-of-bounds for std::tuple with {self._length} entries."
)
return self.get(coordinate)
from pystencils import Field, DynamicType
from pystencils.types import UserTypeSpec, create_type, PsType
from ...lang import SupportsFieldExtraction, SupportsVectorExtraction, AugExpr, cpptype
class StdVector(AugExpr, SupportsFieldExtraction, SupportsVectorExtraction):
_template = cpptype("std::vector< {T} >", "<vector>")
def __init__(
self,
T: UserTypeSpec,
unsafe: bool = False,
ref: bool = False,
const: bool = False,
):
T = create_type(T)
dtype = self._template(T=T, const=const, ref=ref)
super().__init__(dtype)
self._element_type = T
self._unsafe = unsafe
@property
def element_type(self) -> PsType:
return self._element_type
def _extract_ptr(self) -> AugExpr:
return AugExpr.format("{}.data()", self)
def _extract_size(self, coordinate: int) -> AugExpr | None:
if coordinate > 0:
return None
else:
return AugExpr.format("{}.size()", self)
def _extract_stride(self, coordinate: int) -> AugExpr | None:
if coordinate > 0:
return None
else:
return AugExpr.format("1")
def _extract_component(self, coordinate: int) -> AugExpr:
if self._unsafe:
return AugExpr.format("{}[{}]", self, coordinate)
else:
return AugExpr.format("{}.at({})", self, coordinate)
@staticmethod
def from_field(field: Field, ref: bool = True, const: bool = False):
if field.spatial_dimensions > 1 or field.index_shape not in ((), (1,)):
raise ValueError(
f"Cannot create std::vector from more-than-one-dimensional field {field}."
)
if isinstance(field.dtype, DynamicType):
raise ValueError("Cannot map dynamically typed field to std::vector")
return StdVector(field.dtype, unsafe=False, ref=ref, const=const).var(
field.name
)
def std_vector_ref(field: Field):
from warnings import warn
warn(
"`std_vector_ref` is deprecated and will be removed in version 0.1. Use `std.vector.from_field` instead.",
FutureWarning,
)
return StdVector.from_field(field, ref=True)
from .sycl_accessor import SyclAccessor
accessor = SyclAccessor
from pystencils import Field, DynamicType
from pystencils.types import UserTypeSpec, create_type
from ...lang import AugExpr, cpptype, SupportsFieldExtraction
class SyclAccessor(AugExpr, SupportsFieldExtraction):
"""Represent a
`SYCL Accessor <https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#subsec:accessors>`_.
.. note::
Sycl Accessor do not expose information about strides, so the linearization is done under
the assumption that the underlying memory is contiguous, as descibed
`here <https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#_multi_dimensional_objects_and_linearization>`_
""" # noqa: E501
_template = cpptype("sycl::accessor< {T}, {dims} >", "<sycl/sycl.hpp>")
def __init__(
self,
T: UserTypeSpec,
dimensions: int,
ref: bool = False,
const: bool = False,
):
T = create_type(T)
if dimensions > 3:
raise ValueError("sycl accessors can only have dims 1, 2 or 3")
dtype = self._template(T=T, dims=dimensions, const=const, ref=ref)
super().__init__(dtype)
self._dim = dimensions
self._inner_stride = 1
def _extract_ptr(self) -> AugExpr:
return AugExpr.format(
"{}.get_multi_ptr<sycl::access::decorated::no>().get()",
self,
)
def _extract_size(self, coordinate: int) -> AugExpr | None:
if coordinate > self._dim:
return None
else:
return AugExpr.format("{}.get_range().get({})", self, coordinate)
def _extract_stride(self, coordinate: int) -> AugExpr | None:
if coordinate > self._dim:
return None
elif coordinate == self._dim - 1:
return AugExpr.format("{}", self._inner_stride)
else:
exprs = []
args = []
for d in range(coordinate + 1, self._dim):
args.extend([self, d])
exprs.append("{}.get_range().get({})")
expr = " * ".join(exprs)
expr += " * {}"
return AugExpr.format(expr, *args, self._inner_stride)
@staticmethod
def from_field(field: Field, ref: bool = True):
"""Creates a `sycl::accessor &` for a given pystencils field."""
if isinstance(field.dtype, DynamicType):
raise ValueError("Cannot map dynamically typed field to sycl::accessor")
return SyclAccessor(
field.dtype,
field.spatial_dimensions + field.index_dimensions,
ref=ref,
).var(field.name)
from __future__ import annotations
from typing import Iterable, TypeAlias, Any, cast
from itertools import chain
import sympy as sp
from pystencils import TypedSymbol
from pystencils.codegen import Parameter
from pystencils.types import PsType, PsIntegerType, UserTypeSpec, create_type
from ..exceptions import SfgException
from .headers import HeaderFile
from .types import strip_ptr_ref, CppType, CppTypeFactory, cpptype
class SfgVar:
"""C++ Variable.
Args:
name: Name of the variable. Must be a valid C++ identifer.
dtype: Data type of the variable.
"""
__match_args__ = ("name", "dtype")
def __init__(
self,
name: str,
dtype: UserTypeSpec,
):
self._name = name
self._dtype = create_type(dtype)
@property
def name(self) -> str:
return self._name
@property
def dtype(self) -> PsType:
return self._dtype
def _args(self) -> tuple[Any, ...]:
return (self._name, self._dtype)
def __eq__(self, other: object) -> bool:
if not isinstance(other, SfgVar):
return False
return self._args() == other._args()
def __hash__(self) -> int:
return hash(self._args())
def name_and_type(self) -> str:
return f"{self._name}: {self._dtype}"
def __str__(self) -> str:
return self._name
def __repr__(self) -> str:
return self.name_and_type()
class SfgKernelParamVar(SfgVar):
__match_args__ = ("wrapped",)
"""Cast pystencils- or SymPy-native symbol-like objects as a `SfgVar`."""
def __init__(self, param: Parameter):
self._param = param
super().__init__(param.name, param.dtype)
@property
def wrapped(self) -> Parameter:
return self._param
def _args(self):
return (self._param,)
class DependentExpression:
"""Wrapper around a C++ expression code string,
annotated with a set of variables and a set of header files this expression depends on.
Args:
expr: C++ Code string of the expression
depends: Iterable of variables and/or `AugExpr` from which variable and header dependencies are collected
includes: Iterable of header files which this expression additionally depends on
"""
__match_args__ = ("expr", "depends")
def __init__(
self,
expr: str,
depends: Iterable[SfgVar | AugExpr],
includes: Iterable[HeaderFile] | None = None,
):
self._expr: str = expr
deps: set[SfgVar] = set()
incls: set[HeaderFile] = set(includes) if includes is not None else set()
for obj in depends:
if isinstance(obj, AugExpr):
deps |= obj.depends
incls |= obj.includes
else:
deps.add(obj)
self._depends = frozenset(deps)
self._includes = frozenset(incls)
@property
def expr(self) -> str:
return self._expr
@property
def depends(self) -> frozenset[SfgVar]:
return self._depends
@property
def includes(self) -> frozenset[HeaderFile]:
return self._includes
def __hash_contents__(self):
return (self._expr, self._depends, self._includes)
def __eq__(self, other: object):
if not isinstance(other, DependentExpression):
return False
return self.__hash_contents__() == other.__hash_contents__()
def __hash__(self):
return hash(self.__hash_contents__())
def __str__(self) -> str:
return self.expr
def __add__(self, other: DependentExpression):
return DependentExpression(
self.expr + other.expr,
self.depends | other.depends,
self._includes | other._includes,
)
class VarExpr(DependentExpression):
def __init__(self, var: SfgVar):
self._var = var
base_type = strip_ptr_ref(var.dtype)
incls: Iterable[HeaderFile]
match base_type:
case CppType():
incls = base_type.class_includes
case _:
incls = (
HeaderFile.parse(header) for header in var.dtype.required_headers
)
super().__init__(var.name, (var,), incls)
@property
def variable(self) -> SfgVar:
return self._var
class AugExpr:
"""C++ expression augmented with variable dependencies and a type-dependent interface.
`AugExpr` is the primary class for modelling C++ expressions in *pystencils-sfg*.
It stores both an expression's code string,
the set of variables (`SfgVar`) the expression depends on,
as well as any headers that must be included for the expression to be evaluated.
This dependency information is used by the composer and postprocessing system
to infer function parameter lists and automatic header inclusions.
**Construction and Binding**
Constructing an `AugExpr` is a two-step process comprising *construction* and *binding*.
An `AugExpr` can be constructed with our without an associated data type.
After construction, the `AugExpr` object is still *unbound*;
it does not yet hold any syntax.
Syntax binding can happen in two ways:
- Calling `var <AugExpr.var>` on an unbound `AugExpr` turns it into a *variable* with the given name.
This variable expression takes its set of required header files from the
`required_headers <pystencils.types.PsType.required_headers>` field of the data type of the `AugExpr`.
- Using `bind <AugExpr.bind>`, an unbound `AugExpr` can be bound to an arbitrary string
of code. The `bind` method mirrors the interface of `str.format` to combine sub-expressions
and collect their dependencies.
The `format <AugExpr.format>` static method is a wrapper around `bind` for expressions
without a type.
An `AugExpr` can be bound only once.
**C++ API Mirroring**
Subclasses of `AugExpr` can mimic C++ APIs by defining factory methods that
build expressions for C++ method calls, etc., from a list of argument expressions.
Args:
dtype: Optional, data type of this expression interface
"""
__match_args__ = ("expr", "dtype")
def __init__(self, dtype: UserTypeSpec | None = None):
self._dtype = create_type(dtype) if dtype is not None else None
self._bound: DependentExpression | None = None
self._is_variable = False
def var(self, name: str):
"""Bind an unbound `AugExpr` instance as a new variable of given name."""
v = SfgVar(name, self.get_dtype())
expr = VarExpr(v)
return self._bind(expr)
@staticmethod
def make(
code: str,
depends: Iterable[SfgVar | AugExpr],
dtype: UserTypeSpec | None = None,
):
return AugExpr(dtype)._bind(DependentExpression(code, depends))
@staticmethod
def format(fmt: str, *deps, **kwdeps) -> AugExpr:
"""Create a new `AugExpr` by combining existing expressions."""
return AugExpr().bind(fmt, *deps, **kwdeps)
def bind(
self,
fmt: str | AugExpr,
*deps,
require_headers: Iterable[str | HeaderFile] = (),
**kwdeps,
):
"""Bind an unbound `AugExpr` instance to an expression."""
if isinstance(fmt, AugExpr):
if bool(deps) or bool(kwdeps):
raise ValueError(
"Binding to another AugExpr does not permit additional arguments"
)
if fmt._bound is None:
raise ValueError("Cannot rebind to unbound AugExpr.")
self._bind(fmt._bound)
else:
dependencies: set[SfgVar] = set()
incls: set[HeaderFile] = set(HeaderFile.parse(h) for h in require_headers)
from pystencils.sympyextensions import is_constant
for expr in chain(deps, kwdeps.values()):
if isinstance(expr, _ExprLike):
dependencies |= depends(expr)
incls |= includes(expr)
elif isinstance(expr, sp.Expr) and not is_constant(expr):
raise ValueError(
f"Cannot parse SymPy expression as C++ expression: {expr}\n"
" * pystencils-sfg is currently unable to parse non-constant SymPy expressions "
"since they contain symbols without type information."
)
code = fmt.format(*deps, **kwdeps)
self._bind(DependentExpression(code, dependencies, incls))
return self
@property
def code(self) -> str:
if self._bound is None:
raise SfgException("No syntax bound to this AugExpr.")
return str(self._bound)
@property
def depends(self) -> frozenset[SfgVar]:
if self._bound is None:
raise SfgException("No syntax bound to this AugExpr.")
return self._bound.depends
@property
def includes(self) -> frozenset[HeaderFile]:
if self._bound is None:
raise SfgException("No syntax bound to this AugExpr.")
return self._bound.includes
@property
def dtype(self) -> PsType | None:
return self._dtype
def get_dtype(self) -> PsType:
if self._dtype is None:
raise SfgException("This AugExpr has no known data type.")
return self._dtype
@property
def is_variable(self) -> bool:
return isinstance(self._bound, VarExpr)
def as_variable(self) -> SfgVar:
if not isinstance(self._bound, VarExpr):
raise SfgException("This expression is not a variable")
return self._bound.variable
def __str__(self) -> str:
if self._bound is None:
return "/* [ERROR] unbound AugExpr */"
else:
return str(self._bound)
def __repr__(self) -> str:
return str(self)
def _bind(self, expr: DependentExpression):
if self._bound is not None:
raise SfgException("Attempting to bind an already-bound AugExpr.")
self._bound = expr
return self
def is_bound(self) -> bool:
return self._bound is not None
class CppClass(AugExpr):
"""Convenience base class for C++ API mirroring.
Example:
To reflect a C++ class (template) in pystencils-sfg, you may create a subclass
of `CppClass` like this:
>>> class MyClassTemplate(CppClass):
... template = lang.cpptype("mynamespace::MyClassTemplate< {T} >", "MyHeader.hpp")
Then use `AugExpr` initialization and binding to create variables or expressions with
this class:
>>> var = MyClassTemplate(T="float").var("myObj")
>>> var
myObj
>>> str(var.dtype).strip()
'mynamespace::MyClassTemplate< float >'
"""
template: CppTypeFactory
def __init__(self, *args, const: bool = False, ref: bool = False, **kwargs):
dtype = self.template(*args, **kwargs, const=const, ref=ref)
super().__init__(dtype)
def ctor_bind(self, *args):
fstr = self.get_dtype().c_string() + "{{" + ", ".join(["{}"] * len(args)) + "}}"
dtype = cast(CppType, self.get_dtype())
return self.bind(fstr, *args, require_headers=dtype.includes)
def cppclass(
template_str: str, include: str | HeaderFile | Iterable[str | HeaderFile] = ()
):
"""
Convience class decorator for CppClass.
It adds to the decorated class the variable ``template`` via `cpptype`
and sets `CppClass` as a base clase.
>>> @cppclass("MyClass", "MyClass.hpp")
... class MyClass:
... pass
"""
def wrapper(cls):
new_cls = type(cls.__name__, (cls, CppClass), {})
new_cls.template = cpptype(template_str, include)
return new_cls
return wrapper
_VarLike = (AugExpr, SfgVar, TypedSymbol)
VarLike: TypeAlias = AugExpr | SfgVar | TypedSymbol
"""Things that may act as a variable.
Variable-like objects are entities from pystencils and pystencils-sfg that define
a variable name and data type.
Any `VarLike` object can be transformed into a canonical representation (i.e. `SfgVar`)
using `asvar`.
"""
_ExprLike = (str, AugExpr, SfgVar, TypedSymbol)
ExprLike: TypeAlias = str | AugExpr | SfgVar | TypedSymbol
"""Things that may act as a C++ expression.
This type combines all objects that *pystencils-sfg* can handle in the place of C++
expressions. These include all valid variable types (`VarLike`), plain strings, and
complex expressions with variable dependency information (`AugExpr`).
The set of variables an expression depends on can be determined using `depends`.
"""
def asvar(var: VarLike) -> SfgVar:
"""Cast a variable-like object to its canonical representation,
Args:
var: Variable-like object
Returns:
SfgVar: Variable cast as `SfgVar`.
Raises:
ValueError: If given a non-variable `AugExpr`,
a `TypedSymbol <pystencils.TypedSymbol>`
with a `DynamicType <pystencils.sympyextensions.typed_sympy.DynamicType>`,
or any non-variable-like object.
"""
match var:
case SfgVar():
return var
case AugExpr():
return var.as_variable()
case TypedSymbol():
from pystencils import DynamicType
if isinstance(var.dtype, DynamicType):
raise ValueError(
f"Unable to cast dynamically typed symbol {var} to a variable.\n"
f"{var} has dynamic type {var.dtype}, which cannot be resolved to a type outside of a kernel."
)
return SfgVar(var.name, var.dtype)
case _:
raise ValueError(f"Invalid variable: {var}")
def depends(expr: ExprLike) -> set[SfgVar]:
"""Determine the set of variables an expression depends on.
Args:
expr: Expression-like object to examine
Returns:
set[SfgVar]: Set of variables the expression depends on
Raises:
ValueError: If the argument was not a valid expression
"""
match expr:
case None | str():
return set()
case SfgVar():
return {expr}
case TypedSymbol():
return {asvar(expr)}
case AugExpr():
return set(expr.depends)
case _:
raise ValueError(f"Invalid expression: {expr}")
def includes(obj: ExprLike | PsType) -> set[HeaderFile]:
"""Determine the set of header files an expression depends on.
Args:
expr: Expression-like object to examine
Returns:
set[HeaderFile]: Set of headers the expression depends on
Raises:
ValueError: If the argument was not a valid variable or expression
"""
if isinstance(obj, PsType):
obj = strip_ptr_ref(obj)
match obj:
case CppType():
return set(obj.includes)
case PsType():
headers = set(HeaderFile.parse(h) for h in obj.required_headers)
if isinstance(obj, PsIntegerType):
headers.add(HeaderFile.parse("<cstdint>"))
return headers
case SfgVar(_, dtype):
return includes(dtype)
case TypedSymbol():
return includes(asvar(obj))
case str():
return set()
case AugExpr():
return set(obj.includes)
case _:
raise ValueError(f"Invalid expression: {obj}")
from __future__ import annotations
from typing import Protocol, runtime_checkable
from abc import abstractmethod
from .expressions import AugExpr
@runtime_checkable
class SupportsFieldExtraction(Protocol):
"""Protocol for field pointer and indexing extraction.
Objects adhering to this protocol are understood to provide expressions
for the base pointer, shape, and stride properties of a field.
They can therefore be passed to `sfg.map_field <SfgBasicComposer.map_field>`.
"""
# how-to-guide begin
@abstractmethod
def _extract_ptr(self) -> AugExpr:
"""Extract the field base pointer.
Return an expression which represents the base pointer
of this field data structure.
:meta public:
"""
@abstractmethod
def _extract_size(self, coordinate: int) -> AugExpr | None:
"""Extract field size in a given coordinate.
If ``coordinate`` is valid for this field (i.e. smaller than its dimensionality),
return an expression representing the logical size of this field
in the given dimension.
Otherwise, return `None`.
:meta public:
"""
@abstractmethod
def _extract_stride(self, coordinate: int) -> AugExpr | None:
"""Extract field stride in a given coordinate.
If ``coordinate`` is valid for this field (i.e. smaller than its dimensionality),
return an expression representing the memory linearization stride of this field
in the given dimension.
Otherwise, return `None`.
:meta public:
"""
# how-to-guide end
@runtime_checkable
class SupportsVectorExtraction(Protocol):
"""Protocol for component extraction from a vector.
Objects adhering to this protocol are understood to provide
access to the entries of a vector
and can therefore be passed to `sfg.map_vector <SfgBasicComposer.map_vector>`.
"""
@abstractmethod
def _extract_component(self, coordinate: int) -> AugExpr: ...
from __future__ import annotations
from typing import Protocol
from .expressions import CppClass, cpptype, AugExpr
class Dim3Interface(CppClass):
"""Interface definition for the ``dim3`` struct of Cuda and HIP."""
def ctor(self, dim0=1, dim1=1, dim2=1):
"""Constructor invocation of ``dim3``"""
return self.ctor_bind(dim0, dim1, dim2)
@property
def x(self) -> AugExpr:
"""The `x` coordinate member."""
return AugExpr.format("{}.x", self)
@property
def y(self) -> AugExpr:
"""The `y` coordinate member."""
return AugExpr.format("{}.y", self)
@property
def z(self) -> AugExpr:
"""The `z` coordinate member."""
return AugExpr.format("{}.z", self)
@property
def dims(self) -> tuple[AugExpr, AugExpr, AugExpr]:
"""`x`, `y`, and `z` as a tuple."""
return (self.x, self.y, self.z)
class ProvidesGpuRuntimeAPI(Protocol):
"""Protocol definition for a GPU runtime API provider."""
dim3: type[Dim3Interface]
"""The ``dim3`` struct type for this GPU runtime"""
stream_t: type[AugExpr]
"""The ``stream_t`` type for this GPU runtime"""
class CudaAPI(ProvidesGpuRuntimeAPI):
"""Reflection of the CUDA runtime API"""
class dim3(Dim3Interface):
"""Implements `Dim3Interface` for CUDA"""
template = cpptype("dim3", "<cuda_runtime.h>")
class stream_t(CppClass):
template = cpptype("cudaStream_t", "<cuda_runtime.h>")
cuda = CudaAPI
"""Alias for `CudaAPI`"""
class HipAPI(ProvidesGpuRuntimeAPI):
"""Reflection of the HIP runtime API"""
class dim3(Dim3Interface):
"""Implements `Dim3Interface` for HIP"""
template = cpptype("dim3", "<hip/hip_runtime.h>")
class stream_t(CppClass):
template = cpptype("hipStream_t", "<hip/hip_runtime.h>")
hip = HipAPI
"""Alias for `HipAPI`"""
from __future__ import annotations
from dataclasses import dataclass
@dataclass(frozen=True)
class HeaderFile:
"""Represents a C++ header file."""
filepath: str
"""(Relative) path of this header file"""
system_header: bool = False
"""Whether or not this is a system header."""
def __str__(self) -> str:
if self.system_header:
return f"<{self.filepath}>"
else:
return self.filepath
@staticmethod
def parse(header: str | HeaderFile):
if isinstance(header, HeaderFile):
return header
system_header = False
if header.startswith('"') and header.endswith('"'):
header = header[1:-1]
if header.startswith("<") and header.endswith(">"):
header = header[1:-1]
system_header = True
return HeaderFile(header, system_header=system_header)
from __future__ import annotations
from typing import Any, Iterable, Sequence, Mapping, TypeVar, Generic
from abc import ABC
from dataclasses import dataclass
from itertools import chain
import string
from pystencils.types import PsType, PsPointerType, PsCustomType
from .headers import HeaderFile
class VoidType(PsType):
"""C++ void type."""
def __init__(self, const: bool = False):
super().__init__(False)
def __args__(self) -> tuple[Any, ...]:
return ()
def c_string(self) -> str:
return "void"
def __repr__(self) -> str:
return "VoidType()"
void = VoidType()
class _TemplateArgFormatter(string.Formatter):
def format_field(self, arg, format_spec):
if isinstance(arg, PsType):
arg = arg.c_string()
return super().format_field(arg, format_spec)
def check_unused_args(
self, used_args: set[int | str], args: Sequence, kwargs: Mapping[str, Any]
) -> None:
max_args_len: int = (
max((k for k in used_args if isinstance(k, int)), default=-1) + 1
)
if len(args) > max_args_len:
raise ValueError(
f"Too many positional arguments: Expected {max_args_len}, but got {len(args)}"
)
extra_keys = set(kwargs.keys()) - used_args # type: ignore
if extra_keys:
raise ValueError(f"Extraneous keyword arguments: {extra_keys}")
@dataclass(frozen=True)
class _TemplateArgs:
pargs: tuple[Any, ...]
kwargs: tuple[tuple[str, Any], ...]
class CppType(PsCustomType, ABC):
class_includes: frozenset[HeaderFile]
template_string: str
def __init__(self, *template_args, const: bool = False, **template_kwargs):
# Support for cloning CppTypes
if template_args and isinstance(template_args[0], _TemplateArgs):
assert not template_kwargs
targs = template_args[0]
pargs = targs.pargs
kwargs = dict(targs.kwargs)
else:
pargs = template_args
kwargs = template_kwargs
targs = _TemplateArgs(
pargs, tuple(sorted(kwargs.items(), key=lambda t: t[0]))
)
formatter = _TemplateArgFormatter()
name = formatter.format(self.template_string, *pargs, **kwargs)
self._targs = targs
self._includes = self.class_includes
for arg in chain(pargs, kwargs.values()):
match arg:
case CppType():
self._includes |= arg.includes
case PsType():
self._includes |= {
HeaderFile.parse(h) for h in arg.required_headers
}
super().__init__(name, const=const)
def __args__(self) -> tuple[Any, ...]:
return (self._targs,)
@property
def includes(self) -> frozenset[HeaderFile]:
return self._includes
@property
def required_headers(self) -> set[str]:
return set(str(h) for h in self.class_includes)
TypeClass_T = TypeVar("TypeClass_T", bound=CppType)
"""Python type variable bound to `CppType`."""
class CppTypeFactory(Generic[TypeClass_T]):
"""Type Factory returned by `cpptype`."""
def __init__(self, tclass: type[TypeClass_T]) -> None:
self._type_class = tclass
@property
def includes(self) -> frozenset[HeaderFile]:
"""Set of headers required by this factory's type"""
return self._type_class.class_includes
@property
def template_string(self) -> str:
"""Template string of this factory's type"""
return self._type_class.template_string
def __str__(self) -> str:
return f"Factory for {self.template_string}` defined in {self.includes}"
def __repr__(self) -> str:
return f"CppTypeFactory({self.template_string}, includes={{ {', '.join(str(i) for i in self.includes)} }})"
def __call__(self, *args, ref: bool = False, **kwargs) -> TypeClass_T | Ref:
"""Create a type object of this factory's C++ type template.
Args:
args, kwargs: Positional and keyword arguments are forwarded to the template string formatter
ref: If ``True``, return a reference type
Returns:
An instantiated type object
"""
obj = self._type_class(*args, **kwargs)
if ref:
return Ref(obj)
else:
return obj
def cpptype(
template_str: str, include: str | HeaderFile | Iterable[str | HeaderFile] = ()
) -> CppTypeFactory:
"""Describe a C++ type template, associated with a set of required header files.
This function allows users to define C++ type templates using
`Python format string syntax <https://docs.python.org/3/library/string.html#formatstrings>`_.
The types may furthermore be annotated with a set of header files that must be included
in order to use the type.
>>> opt_template = lang.cpptype("std::optional< {T} >", "<optional>")
>>> opt_template.template_string
'std::optional< {T} >'
This function returns a `CppTypeFactory` object, which in turn can be called to create
an instance of the C++ type template.
Therein, the ``template_str`` argument is treated as a Python format string:
The positional and keyword arguments passed to the returned type factory are passed
through machinery that is based on `str.format` to produce the actual type name.
>>> int_option = opt_template(T="int")
>>> int_option.c_string().strip()
'std::optional< int >'
The factory may also create reference types when the ``ref=True`` is specified.
>>> int_option_ref = opt_template(T="int", ref=True)
>>> int_option_ref.c_string().strip()
'std::optional< int >&'
Args:
template_str: Format string defining the type template
include: Either the name of a header file, or a sequence of names of header files
Returns:
CppTypeFactory: A factory used to instantiate the type template
"""
headers: list[str | HeaderFile]
if isinstance(include, (str, HeaderFile)):
headers = [
include,
]
else:
headers = list(include)
class TypeClass(CppType):
template_string = template_str
class_includes = frozenset(HeaderFile.parse(h) for h in headers)
return CppTypeFactory[TypeClass](TypeClass)
class Ref(PsType):
"""C++ reference type."""
__match_args__ = "base_type"
def __init__(self, base_type: PsType, const: bool = False):
super().__init__(False)
self._base_type = base_type
def __args__(self) -> tuple[Any, ...]:
return (self.base_type,)
@property
def base_type(self) -> PsType:
return self._base_type
def c_string(self) -> str:
base_str = self.base_type.c_string()
return base_str + "&"
def __repr__(self) -> str:
return f"Ref({repr(self.base_type)})"
def strip_ptr_ref(dtype: PsType):
match dtype:
case Ref():
return strip_ptr_ref(dtype.base_type)
case PsPointerType():
return strip_ptr_ref(dtype.base_type)
case _:
return dtype
from .source_objects import SrcObject, SrcField, SrcVector, TypedSymbolOrObject
__all__ = [
"SrcObject", "SrcField", "SrcVector", "TypedSymbolOrObject"
]
from typing import Union, cast
import numpy as np
from pystencils import Field
from pystencils.typing import FieldPointerSymbol, FieldStrideSymbol, FieldShapeSymbol
from ...tree import SfgStatements
from ..source_objects import SrcField
from ...source_components import SfgHeaderInclude
from ...types import PsType, cpp_typename, SrcType
from ...exceptions import SfgException
class StdMdspan(SrcField):
dynamic_extent = "std::dynamic_extent"
def __init__(self, identifer: str,
T: PsType,
extents: tuple[int | str, ...],
extents_type: PsType = int,
reference: bool = False):
cpp_typestr = cpp_typename(T)
extents_type_str = cpp_typename(extents_type)
extents_str = f"std::extents< {extents_type_str}, {', '.join(str(e) for e in extents)} >"
typestring = f"std::mdspan< {cpp_typestr}, {extents_str} > {'&' if reference else ''}"
super().__init__(identifer, SrcType(typestring))
self._extents = extents
@property
def required_includes(self) -> set[SfgHeaderInclude]:
return {SfgHeaderInclude("experimental/mdspan", system_header=True)}
def extract_ptr(self, ptr_symbol: FieldPointerSymbol):
return SfgStatements(
f"{ptr_symbol.dtype} {ptr_symbol.name} = {self._identifier}.data_handle();",
(ptr_symbol, ),
(self, )
)
def extract_size(self, coordinate: int, size: Union[int, FieldShapeSymbol]) -> SfgStatements:
dim = len(self._extents)
if coordinate >= dim:
if isinstance(size, FieldShapeSymbol):
raise SfgException(f"Cannot extract size in coordinate {coordinate} from a {dim}-dimensional mdspan!")
elif size != 1:
raise SfgException(
f"Cannot map field with size {size} in coordinate {coordinate} to {dim}-dimensional mdspan!")
else:
# trivial trailing index dimensions are OK -> do nothing
return SfgStatements(f"// {self._identifier}.extents().extent({coordinate}) == 1", (), ())
if isinstance(size, FieldShapeSymbol):
return SfgStatements(
f"{size.dtype} {size.name} = {self._identifier}.extents().extent({coordinate});",
(size, ),
(self, )
)
else:
return SfgStatements(
f"assert( {self._identifier}.extents().extent({coordinate}) == {size} );",
(), (self, )
)
def extract_stride(self, coordinate: int, stride: Union[int, FieldStrideSymbol]) -> SfgStatements:
if coordinate >= len(self._extents):
raise SfgException(
f"Cannot extract stride in coordinate {coordinate} from a {len(self._extents)}-dimensional mdspan")
if isinstance(stride, FieldStrideSymbol):
return SfgStatements(
f"{stride.dtype} {stride.name} = {self._identifier}.stride({coordinate});",
(stride, ),
(self, )
)
else:
return SfgStatements(
f"assert( {self._identifier}.stride({coordinate}) == {stride} );",
(), (self, )
)
def mdspan_ref(field: Field, extents_type: type = np.uint32):
"""Creates a `std::mdspan &` for a given pystencils field."""
from pystencils.field import layout_string_to_tuple
if field.layout != layout_string_to_tuple("soa", field.spatial_dimensions):
raise NotImplementedError("mdspan mapping is currently only available for structure-of-arrays fields")
extents: list[str | int] = []
for s in field.spatial_shape:
extents.append(StdMdspan.dynamic_extent if isinstance(s, FieldShapeSymbol) else cast(int, s))
if field.index_shape != (1,):
for s in field.index_shape:
extents += StdMdspan.dynamic_extent if isinstance(s, FieldShapeSymbol) else s
return StdMdspan(field.name, field.dtype,
tuple(extents),
extents_type=extents_type,
reference=True)
from typing import Sequence
from pystencils.typing import BasicType, TypedSymbol
from ...tree import SfgStatements
from ..source_objects import SrcVector
from ..source_objects import TypedSymbolOrObject
from ...types import SrcType, cpp_typename
from ...source_components import SfgHeaderInclude
class StdTuple(SrcVector):
def __init__(
self,
identifier: str,
element_types: Sequence[BasicType],
const: bool = False,
ref: bool = False,
):
self._element_types = element_types
self._length = len(element_types)
elt_type_strings = tuple(cpp_typename(t) for t in self._element_types)
src_type = f"{'const' if const else ''} std::tuple< {', '.join(elt_type_strings)} > {'&' if ref else ''}"
super().__init__(identifier, SrcType(src_type))
@property
def required_includes(self) -> set[SfgHeaderInclude]:
return {SfgHeaderInclude("tuple", system_header=True)}
def extract_component(self, destination: TypedSymbolOrObject, coordinate: int):
if coordinate < 0 or coordinate >= self._length:
raise ValueError(
f"Index {coordinate} out-of-bounds for std::tuple with {self._length} entries."
)
if destination.dtype != self._element_types[coordinate]:
raise ValueError(
f"Cannot extract type {destination.dtype} from std::tuple entry "
"of type {self._element_types[coordinate]}"
)
return SfgStatements(
f"{destination.dtype} {destination.name} = std::get< {coordinate} >({self.identifier});",
(destination,),
(self,),
)
def std_tuple_ref(
identifier: str, components: Sequence[TypedSymbol], const: bool = True
):
elt_types = tuple(c.dtype for c in components)
return StdTuple(identifier, elt_types, const=const, ref=True)
from typing import Union
from pystencils.field import Field, FieldType
from pystencils.typing import FieldPointerSymbol, FieldStrideSymbol, FieldShapeSymbol
from ...tree import SfgStatements
from ..source_objects import SrcField, SrcVector
from ..source_objects import TypedSymbolOrObject
from ...types import SrcType, PsType, cpp_typename
from ...source_components import SfgHeaderInclude, SfgClass
from ...exceptions import SfgException
class StdVector(SrcVector, SrcField):
def __init__(
self,
identifer: str,
T: Union[SrcType, PsType],
unsafe: bool = False,
reference: bool = True,
):
typestring = f"std::vector< {cpp_typename(T)} > {'&' if reference else ''}"
super(StdVector, self).__init__(identifer, SrcType(typestring))
self._element_type = T
self._unsafe = unsafe
@property
def required_includes(self) -> set[SfgHeaderInclude]:
return {
SfgHeaderInclude("cassert", system_header=True),
SfgHeaderInclude("vector", system_header=True),
}
def extract_ptr(self, ptr_symbol: FieldPointerSymbol):
if ptr_symbol.dtype != self._element_type:
if self._unsafe:
mapping = f"{ptr_symbol.dtype} {ptr_symbol.name} = ({ptr_symbol.dtype}) {self._identifier}.data();"
else:
raise SfgException(
"Field type and std::vector element type do not match, and unsafe extraction was not enabled."
)
else:
mapping = (
f"{ptr_symbol.dtype} {ptr_symbol.name} = {self._identifier}.data();"
)
return SfgStatements(mapping, (ptr_symbol,), (self,))
def extract_size(
self, coordinate: int, size: Union[int, FieldShapeSymbol]
) -> SfgStatements:
if coordinate > 0:
if isinstance(size, FieldShapeSymbol):
raise SfgException(
f"Cannot extract size in coordinate {coordinate} from std::vector!"
)
elif size != 1:
raise SfgException(
f"Cannot map field with size {size} in coordinate {coordinate} to std::vector!"
)
else:
# trivial trailing index dimensions are OK -> do nothing
return SfgStatements(
f"// {self._identifier}.size({coordinate}) == 1", (), ()
)
if isinstance(size, FieldShapeSymbol):
return SfgStatements(
f"{size.dtype} {size.name} = ({size.dtype}) {self._identifier}.size();",
(size,),
(self,),
)
else:
return SfgStatements(
f"assert( {self._identifier}.size() == {size} );", (), (self,)
)
def extract_stride(
self, coordinate: int, stride: Union[int, FieldStrideSymbol]
) -> SfgStatements:
if coordinate == 1:
if stride != 1:
raise SfgException(
"Can only map fields with trivial index stride onto std::vector!"
)
if coordinate > 1:
raise SfgException(
f"Cannot extract stride in coordinate {coordinate} from std::vector"
)
if isinstance(stride, FieldStrideSymbol):
return SfgStatements(f"{stride.dtype} {stride.name} = 1;", (stride,), ())
elif stride != 1:
raise SfgException(
"Can only map fields with trivial strides onto std::vector!"
)
else:
return SfgStatements(
f"// {self._identifier}.stride({coordinate}) == 1", (), ()
)
def extract_component(
self, destination: TypedSymbolOrObject, coordinate: int
) -> SfgStatements:
if self._unsafe:
mapping = f"{destination.dtype} {destination.name} = {self._identifier}[{coordinate}];"
else:
mapping = f"{destination.dtype} {destination.name} = {self._identifier}.at({coordinate});"
return SfgStatements(mapping, (destination,), (self,))
def std_vector_ref(field: Field, src_struct: SfgClass):
if field.field_type != FieldType.INDEXED:
raise ValueError("Can only create std::vector for index fields")
return StdVector(field.name, src_struct.src_type, unsafe=True, reference=True)