diff --git a/integration/test_classes.py b/integration/test_classes.py index 39421001743bbd84564a5cbbfb5ef50953b597bf..d298a6c775edcbe4886a24ae8ec41eb9a474cf13 100644 --- a/integration/test_classes.py +++ b/integration/test_classes.py @@ -70,7 +70,7 @@ with SourceFileGenerator(sfg_config) as ctx: cls.add_constructor( SfgConstructor( cls, - [SrcObject("std::vector< int > &", "stuff")], + [SrcObject("stuff", "std::vector< int > &")], ["stuff_(stuff)"], visibility=SfgVisibility.PUBLIC ) diff --git a/src/pystencilssfg/composer.py b/src/pystencilssfg/composer.py index 0596eb6d0add3b909b202cbd7f0c796fa8bc674f..277ce8ddcb85c5e25530e0aff5a5ccb7b827f25f 100644 --- a/src/pystencilssfg/composer.py +++ b/src/pystencilssfg/composer.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, Sequence from abc import ABC, abstractmethod +import numpy as np from pystencils import Field from pystencils.astnodes import KernelFunction @@ -10,6 +11,7 @@ from .tree import ( SfgKernelCallNode, SfgStatements, SfgFunctionParams, + SfgRequireIncludes, SfgSequence, SfgBlock, ) @@ -20,8 +22,15 @@ from .source_components import ( SfgHeaderInclude, SfgKernelNamespace, SfgKernelHandle, + SfgClass, + SfgConstructor, + SfgMemberVariable, + SfgClassKeyword, + SfgVisibility, ) -from .source_concepts import SrcField, TypedSymbolOrObject, SrcVector +from .source_concepts import SrcObject, SrcField, TypedSymbolOrObject, SrcVector +from .types import cpp_typename, SrcType +from .exceptions import SfgException if TYPE_CHECKING: from .context import SfgContext @@ -70,14 +79,22 @@ class SfgComposer: return kns def include(self, header_file: str): - system_header = False - if header_file.startswith("<") and header_file.endswith(">"): - header_file = header_file[1:-1] - system_header = True + self._ctx.add_include(parse_include(header_file)) - self._ctx.add_include( - SfgHeaderInclude(header_file, system_header=system_header) - ) + def numpy_struct( + self, name: str, dtype: np.dtype, add_constructor: bool = True + ) -> SfgClass: + """Add a numpy structured data type as a C++ struct + + Returns: + The created class object + """ + if self._ctx.get_class(name) is not None: + raise SfgException(f"Class with name {name} already exists.") + + cls = struct_from_numpy_dtype(name, dtype, add_constructor=add_constructor) + self._ctx.add_class(cls) + return cls def kernel_function( self, name: str, ast_or_kernel_handle: KernelFunction | SfgKernelHandle @@ -141,6 +158,9 @@ class SfgComposer: """Use inside a function body to add parameters to the function.""" return SfgFunctionParams(args) + def require(self, *includes: str | SfgHeaderInclude) -> SfgRequireIncludes: + return SfgRequireIncludes(list(parse_include(incl) for incl in includes)) + @property def branch(self) -> SfgBranchBuilder: """Use inside a function body to create an if/else conditonal branch. @@ -296,3 +316,51 @@ class SfgBranchBuilder(SfgNodeBuilder): def resolve(self) -> SfgCallTreeNode: assert self._cond is not None return SfgBranch(self._cond, self._branch_true, self._branch_false) + + +def parse_include(incl: str | SfgHeaderInclude): + if isinstance(incl, SfgHeaderInclude): + return incl + + system_header = False + if incl.startswith("<") and incl.endswith(">"): + incl = incl[1:-1] + system_header = True + + return SfgHeaderInclude(incl, system_header=system_header) + + +def struct_from_numpy_dtype( + struct_name: str, dtype: np.dtype, add_constructor: bool = True +): + cls = SfgClass(struct_name, class_keyword=SfgClassKeyword.STRUCT) + + fields = dtype.fields + if fields is None: + raise SfgException(f"Numpy dtype {dtype} is not a structured type.") + + constr_params = [] + constr_inits = [] + + for member_name, type_info in fields.items(): + member_type = SrcType(cpp_typename(type_info[0])) + + member = SfgMemberVariable( + member_name, member_type, cls, visibility=SfgVisibility.DEFAULT + ) + + arg = SrcObject(f"{member_name}_", member_type) + + cls.add_member_variable(member) + + constr_params.append(arg) + constr_inits.append(f"{member}({arg})") + + if add_constructor: + cls.add_constructor( + SfgConstructor( + cls, constr_params, constr_inits, visibility=SfgVisibility.DEFAULT + ) + ) + + return cls diff --git a/src/pystencilssfg/source_components.py b/src/pystencilssfg/source_components.py index 1bf7242ce6ca8099d04706d058e3c97ac30b3ace..1346059a08cf0744d1ce5b17c0925a57fdb7a292 100644 --- a/src/pystencilssfg/source_components.py +++ b/src/pystencilssfg/source_components.py @@ -172,7 +172,9 @@ class SfgKernelHandle: class SfgFunction: - def __init__(self, name: str, tree: SfgCallTreeNode, return_type: SrcType = SrcType("void")): + def __init__( + self, name: str, tree: SfgCallTreeNode, return_type: SrcType = SrcType("void") + ): self._name = name self._tree = tree self._return_type = return_type @@ -248,7 +250,7 @@ class SfgMemberVariable(SrcObject, SfgClassMember): cls: SfgClass, visibility: SfgVisibility = SfgVisibility.PRIVATE, ): - SrcObject.__init__(self, type, name) + SrcObject.__init__(self, name, type) SfgClassMember.__init__(self, cls, visibility) @@ -261,7 +263,7 @@ class SfgMethod(SfgFunction, SfgClassMember): visibility: SfgVisibility = SfgVisibility.PUBLIC, return_type: SrcType = SrcType("void"), inline: bool = False, - const: bool = False + const: bool = False, ): SfgFunction.__init__(self, name, tree, return_type=return_type) SfgClassMember.__init__(self, cls, visibility) @@ -324,6 +326,10 @@ class SfgClass: def class_name(self) -> str: return self._class_name + @property + def src_type(self) -> SrcType: + return SrcType(self._class_name) + @property def base_classes(self) -> tuple[str, ...]: return self._bases_classes diff --git a/src/pystencilssfg/source_concepts/cpp/__init__.py b/src/pystencilssfg/source_concepts/cpp/__init__.py index ed0b13ee7213ba30c58f0e24864c69ea947c4016..b6e587860a4528f10108a85cef2cebd3511e3582 100644 --- a/src/pystencilssfg/source_concepts/cpp/__init__.py +++ b/src/pystencilssfg/source_concepts/cpp/__init__.py @@ -1,7 +1,7 @@ from .std_mdspan import StdMdspan, mdspan_ref -from .std_vector import std_vector, std_vector_ref +from .std_vector import StdVector, std_vector_ref __all__ = [ - "StdMdspan", "std_vector", "std_vector_ref", + "StdMdspan", "StdVector", "std_vector_ref", "mdspan_ref" ] diff --git a/src/pystencilssfg/source_concepts/cpp/std_mdspan.py b/src/pystencilssfg/source_concepts/cpp/std_mdspan.py index 0bf1b823d38c2457dace50126cfc45d0b3977162..2e57c5261a6474fac309a39205d50ad29e6fd211 100644 --- a/src/pystencilssfg/source_concepts/cpp/std_mdspan.py +++ b/src/pystencilssfg/source_concepts/cpp/std_mdspan.py @@ -25,7 +25,7 @@ class StdMdspan(SrcField): extents_str = f"std::extents< {extents_type_str}, {', '.join(str(e) for e in extents)} >" typestring = f"std::mdspan< {cpp_typestr}, {extents_str} > {'&' if reference else ''}" - super().__init__(SrcType(typestring), identifer) + super().__init__(identifer, SrcType(typestring)) self._extents = extents diff --git a/src/pystencilssfg/source_concepts/cpp/std_vector.py b/src/pystencilssfg/source_concepts/cpp/std_vector.py index 9e80f620434305256ae30fbbe82a74a412ebce49..a63220607c895f7577035015a512a7f67d6ea7b0 100644 --- a/src/pystencilssfg/source_concepts/cpp/std_vector.py +++ b/src/pystencilssfg/source_concepts/cpp/std_vector.py @@ -1,26 +1,36 @@ from typing import Union +from pystencils.field import Field, FieldType from pystencils.typing import FieldPointerSymbol, FieldStrideSymbol, FieldShapeSymbol from ...tree import SfgStatements from ..source_objects import SrcField, SrcVector from ..source_objects import TypedSymbolOrObject from ...types import SrcType, PsType, cpp_typename -from ...source_components import SfgHeaderInclude +from ...source_components import SfgHeaderInclude, SfgClass from ...exceptions import SfgException -class std_vector(SrcVector, SrcField): - def __init__(self, identifer: str, T: Union[SrcType, PsType], unsafe: bool = False): - typestring = f"std::vector< {cpp_typename(T)} >" - super(std_vector, self).__init__(SrcType(typestring), identifer) +class StdVector(SrcVector, SrcField): + def __init__( + self, + identifer: str, + T: Union[SrcType, PsType], + unsafe: bool = False, + reference: bool = True, + ): + typestring = f"std::vector< {cpp_typename(T)} > {'&' if reference else ''}" + super(StdVector, self).__init__(identifer, SrcType(typestring)) self._element_type = T self._unsafe = unsafe @property def required_includes(self) -> set[SfgHeaderInclude]: - return {SfgHeaderInclude("vector", system_header=True)} + return { + SfgHeaderInclude("cassert", system_header=True), + SfgHeaderInclude("vector", system_header=True), + } def extract_ptr(self, ptr_symbol: FieldPointerSymbol): if ptr_symbol.dtype != self._element_type: @@ -28,38 +38,72 @@ class std_vector(SrcVector, SrcField): mapping = f"{ptr_symbol.dtype} {ptr_symbol.name} = ({ptr_symbol.dtype}) {self._identifier}.data();" else: raise SfgException( - "Field type and std::vector element type do not match, and unsafe extraction was not enabled.") + "Field type and std::vector element type do not match, and unsafe extraction was not enabled." + ) else: - mapping = f"{ptr_symbol.dtype} {ptr_symbol.name} = {self._identifier}.data();" + mapping = ( + f"{ptr_symbol.dtype} {ptr_symbol.name} = {self._identifier}.data();" + ) return SfgStatements(mapping, (ptr_symbol,), (self,)) - def extract_size(self, coordinate: int, size: Union[int, FieldShapeSymbol]) -> SfgStatements: + def extract_size( + self, coordinate: int, size: Union[int, FieldShapeSymbol] + ) -> SfgStatements: if coordinate > 0: - raise SfgException(f"Cannot extract size in coordinate {coordinate} from std::vector") + if isinstance(size, FieldShapeSymbol): + raise SfgException( + f"Cannot extract size in coordinate {coordinate} from std::vector!" + ) + elif size != 1: + raise SfgException( + f"Cannot map field with size {size} in coordinate {coordinate} to std::vector!" + ) + else: + # trivial trailing index dimensions are OK -> do nothing + return SfgStatements( + f"// {self._identifier}.size({coordinate}) == 1", (), () + ) if isinstance(size, FieldShapeSymbol): return SfgStatements( - f"{size.dtype} {size.name} = {self._identifier}.size();", - (size, ), - (self, ) + f"{size.dtype} {size.name} = ({size.dtype}) {self._identifier}.size();", + (size,), + (self,), ) else: return SfgStatements( - f"assert( {self._identifier}.size() == {size} );", - (), (self, ) + f"assert( {self._identifier}.size() == {size} );", (), (self,) ) - def extract_stride(self, coordinate: int, stride: Union[int, FieldStrideSymbol]) -> SfgStatements: - if coordinate > 0: - raise SfgException(f"Cannot extract stride in coordinate {coordinate} from std::vector") + def extract_stride( + self, coordinate: int, stride: Union[int, FieldStrideSymbol] + ) -> SfgStatements: + if coordinate == 1: + if stride != 1: + raise SfgException( + "Can only map fields with trivial index stride onto std::vector!" + ) + + if coordinate > 1: + raise SfgException( + f"Cannot extract stride in coordinate {coordinate} from std::vector" + ) if isinstance(stride, FieldStrideSymbol): - return SfgStatements(f"{stride.dtype} {stride.name} = 1;", (stride, ), ()) + return SfgStatements(f"{stride.dtype} {stride.name} = 1;", (stride,), ()) + elif stride != 1: + raise SfgException( + "Can only map fields with trivial strides onto std::vector!" + ) else: - return SfgStatements(f"assert( 1 == {stride} );", (), ()) + return SfgStatements( + f"// {self._identifier}.stride({coordinate}) == 1", (), () + ) - def extract_component(self, destination: TypedSymbolOrObject, coordinate: int) -> SfgStatements: + def extract_component( + self, destination: TypedSymbolOrObject, coordinate: int + ) -> SfgStatements: if self._unsafe: mapping = f"{destination.dtype} {destination.name} = {self._identifier}[{coordinate}];" else: @@ -68,7 +112,8 @@ class std_vector(SrcVector, SrcField): return SfgStatements(mapping, (destination,), (self,)) -class std_vector_ref(std_vector): - def __init__(self, identifer: str, T: Union[SrcType, PsType]): - typestring = f"std::vector< {T} > &" - super(std_vector_ref, self).__init__(identifer, SrcType(typestring)) +def std_vector_ref(field: Field, src_struct: SfgClass): + if field.field_type != FieldType.INDEXED: + raise ValueError("Can only create std::vector for index fields") + + return StdVector(field.name, src_struct.src_type, unsafe=True, reference=True) diff --git a/src/pystencilssfg/source_concepts/source_objects.py b/src/pystencilssfg/source_concepts/source_objects.py index 2f5a392f8cd6a45922543547f000acb12d37d78b..f86227e176de88d8c8f538050ea281c270e8b4f3 100644 --- a/src/pystencilssfg/source_concepts/source_objects.py +++ b/src/pystencilssfg/source_concepts/source_objects.py @@ -19,9 +19,9 @@ class SrcObject: Two objects are identical if they have the same identifier and type string.""" - def __init__(self, src_type: SrcType, identifier: str): - self._src_type = src_type + def __init__(self, identifier: str, src_type: SrcType): self._identifier = identifier + self._src_type = src_type @property def identifier(self): @@ -44,28 +44,37 @@ class SrcObject: return hash((self._identifier, self._src_type)) def __eq__(self, other: object) -> bool: - return (isinstance(other, SrcObject) - and self._identifier == other._identifier - and self._src_type == other._src_type) + return ( + isinstance(other, SrcObject) + and self._identifier == other._identifier + and self._src_type == other._src_type + ) + + def __str__(self) -> str: + return self.name TypedSymbolOrObject: TypeAlias = Union[TypedSymbol, SrcObject] class SrcField(SrcObject, ABC): - def __init__(self, src_type: SrcType, identifier: str): - super().__init__(src_type, identifier) + def __init__(self, identifier: str, src_type: SrcType): + super().__init__(identifier, src_type) @abstractmethod def extract_ptr(self, ptr_symbol: FieldPointerSymbol) -> SfgStatements: pass @abstractmethod - def extract_size(self, coordinate: int, size: Union[int, FieldShapeSymbol]) -> SfgStatements: + def extract_size( + self, coordinate: int, size: Union[int, FieldShapeSymbol] + ) -> SfgStatements: pass @abstractmethod - def extract_stride(self, coordinate: int, stride: Union[int, FieldStrideSymbol]) -> SfgStatements: + def extract_stride( + self, coordinate: int, stride: Union[int, FieldStrideSymbol] + ) -> SfgStatements: pass def extract_parameters(self, field: Field) -> SfgSequence: @@ -76,11 +85,13 @@ class SrcField(SrcObject, ABC): return make_sequence( self.extract_ptr(ptr), *(self.extract_size(c, s) for c, s in enumerate(field.shape)), - *(self.extract_stride(c, s) for c, s in enumerate(field.strides)) + *(self.extract_stride(c, s) for c, s in enumerate(field.strides)), ) class SrcVector(SrcObject, ABC): @abstractmethod - def extract_component(self, destination: TypedSymbolOrObject, coordinate: int) -> SfgStatements: + def extract_component( + self, destination: TypedSymbolOrObject, coordinate: int + ) -> SfgStatements: pass diff --git a/src/pystencilssfg/tree/__init__.py b/src/pystencilssfg/tree/__init__.py index db5b8b7a3fd92808ec9e0a168d3a59dc4b482bbc..d25a5ebbd9d69f1b13b4c5937cfbf688f9edd9d2 100644 --- a/src/pystencilssfg/tree/__init__.py +++ b/src/pystencilssfg/tree/__init__.py @@ -5,6 +5,7 @@ from .basic_nodes import ( SfgSequence, SfgStatements, SfgFunctionParams, + SfgRequireIncludes ) from .conditional import SfgBranch, SfgCondition, IntEven, IntOdd @@ -15,6 +16,7 @@ __all__ = [ "SfgBlock", "SfgStatements", "SfgFunctionParams", + "SfgRequireIncludes", "SfgCondition", "SfgBranch", "IntEven", diff --git a/src/pystencilssfg/tree/basic_nodes.py b/src/pystencilssfg/tree/basic_nodes.py index ee4860d2afb8a8414012e5ec5dd7d70c95734ffe..437afc959a94aeb51efe7193e3ae6a4d4e28e1ab 100644 --- a/src/pystencilssfg/tree/basic_nodes.py +++ b/src/pystencilssfg/tree/basic_nodes.py @@ -4,12 +4,11 @@ from typing import TYPE_CHECKING, Sequence from abc import ABC, abstractmethod from itertools import chain -from ..source_components import SfgKernelHandle +from ..source_components import SfgHeaderInclude, SfgKernelHandle from ..source_concepts.source_objects import SrcObject, TypedSymbolOrObject if TYPE_CHECKING: from ..context import SfgContext - from ..source_components import SfgHeaderInclude class SfgCallTreeNode(ABC): @@ -57,6 +56,14 @@ class SfgCallTreeLeaf(SfgCallTreeNode, ABC): ... +class SfgEmptyNode(SfgCallTreeLeaf): + def __init__(self): + super().__init__() + + def get_code(self, ctx: SfgContext) -> str: + return "" + + class SfgStatements(SfgCallTreeLeaf): """Represents (a sequence of) statements in the source language. @@ -108,7 +115,7 @@ class SfgStatements(SfgCallTreeLeaf): return self._code_string -class SfgFunctionParams(SfgCallTreeLeaf): +class SfgFunctionParams(SfgEmptyNode): def __init__(self, parameters: Sequence[TypedSymbolOrObject]): super().__init__() self._params = set(parameters) @@ -126,8 +133,19 @@ class SfgFunctionParams(SfgCallTreeLeaf): def required_includes(self) -> set[SfgHeaderInclude]: return self._required_includes - def get_code(self, ctx: SfgContext) -> str: - return "" + +class SfgRequireIncludes(SfgEmptyNode): + def __init__(self, includes: Sequence[SfgHeaderInclude]): + super().__init__() + self._required_includes = set(includes) + + @property + def required_parameters(self) -> set[TypedSymbolOrObject]: + return set() + + @property + def required_includes(self) -> set[SfgHeaderInclude]: + return self._required_includes class SfgSequence(SfgCallTreeNode): diff --git a/src/pystencilssfg/types.py b/src/pystencilssfg/types.py index a0ea1b1ab92e2472281ecd9000056dd61ace4b18..250b0cc73eb68e5a733c0847c467b027d7290668 100644 --- a/src/pystencilssfg/types.py +++ b/src/pystencilssfg/types.py @@ -14,14 +14,19 @@ for example via `create_type`. (Note that, while `create_type` does accept strings, they are excluded here for reasons of safety. It is discouraged to use strings for type specifications when working with pystencils!) + +PsType is a temporary solution and will be removed in the future +in favor of the consolidated pystencils backend typing system. """ SrcType = NewType('SrcType', str) -"""Nonprimitive C/C++-Types occuring during source file generation. +"""C/C++-Types occuring during source file generation. -Nonprimitive C/C++ types are represented by their names. -When necessary, the SFG package checks equality of types by these name strings; it does +When necessary, the SFG package checks equality of types by their name strings; it does not care about typedefs, aliases, namespaces, etc! + +SrcType is a temporary solution and will be removed in the future +in favor of the consolidated pystencils backend typing system. """ diff --git a/src/pystencilssfg/visitors/dispatcher.py b/src/pystencilssfg/visitors/dispatcher.py index 48bfda997b61c9450d3562c83c04d39714306de7..0b47534d683e824ca28f3474e337d12b2820f924 100644 --- a/src/pystencilssfg/visitors/dispatcher.py +++ b/src/pystencilssfg/visitors/dispatcher.py @@ -4,8 +4,6 @@ from types import MethodType from functools import wraps -from ..tree.basic_nodes import SfgCallTreeNode - V = TypeVar("V") R = TypeVar("R") P = ParamSpec("P") @@ -27,7 +25,7 @@ class VisitorDispatcher(Generic[V, R]): return decorate - def __call__(self, instance: V, node: SfgCallTreeNode, *args, **kwargs) -> R: + def __call__(self, instance: V, node: object, *args, **kwargs) -> R: for cls in node.__class__.mro(): if cls in self._dispatch_dict: return self._dispatch_dict[cls](instance, node, *args, **kwargs)