diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index 936d5c5a6d5aa6988bcf1f2f615c872ef87f162c..76e2907361d054f7468279e2500469b408f0d477 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -1,7 +1,6 @@ from __future__ import annotations from typing import Sequence, TypeAlias from abc import ABC, abstractmethod -import numpy as np import sympy as sp from functools import reduce from warnings import warn @@ -33,10 +32,6 @@ from ..ir.source_components import ( SfgFunction, SfgKernelNamespace, SfgKernelHandle, - SfgClass, - SfgConstructor, - SfgMemberVariable, - SfgClassKeyword, SfgEntityDecl, SfgEntityDef, SfgNamespaceBlock, @@ -86,7 +81,7 @@ SequencerArg: TypeAlias = tuple | ExprLike | SfgCallTreeNode | SfgNodeBuilder class KernelsAdder: def __init__(self, ctx: SfgContext, loc: SfgNamespaceBlock): self._ctx = ctx - self._loc = SfgNamespaceBlock + self._loc = loc assert isinstance(loc.namespace, SfgKernelNamespace) self._kernel_namespace = loc.namespace @@ -110,6 +105,8 @@ class KernelsAdder: khandle = SfgKernelHandle(kernel_name, self._kernel_namespace, kernel) self._kernel_namespace.add_kernel(khandle) + self._loc.elements.append(SfgEntityDef(khandle)) + for header in kernel.required_headers: assert self._ctx.impl_file is not None self._ctx.impl_file.includes.append(HeaderFile.parse(header)) @@ -242,7 +239,7 @@ class SfgBasicComposer(SfgIComposer): self._cursor.write_impl(kns_block) return KernelsAdder(self._ctx, kns_block) - def include(self, header_file: str | HeaderFile, private: bool = False): + def include(self, header: str | HeaderFile, private: bool = False): """Include a header file. Args: @@ -262,7 +259,7 @@ class SfgBasicComposer(SfgIComposer): #include <vector> #include "custom.h" """ - header_file = HeaderFile.parse(header_file) + header_file = HeaderFile.parse(header) if private: if self._ctx.impl_file is None: @@ -273,21 +270,6 @@ class SfgBasicComposer(SfgIComposer): else: self._ctx.header_file.includes.append(header_file) - def numpy_struct( - self, name: str, dtype: np.dtype, add_constructor: bool = True - ) -> SfgClass: - """Add a numpy structured data type as a C++ struct - - Returns: - The created class object - """ - cls = self._struct_from_numpy_dtype( - name, dtype, add_constructor=add_constructor - ) - self._cursor.add_entity(cls) - self._cursor.write_header(SfgEntityDecl(cls)) - return cls - def kernel_function(self, name: str, kernel: Kernel | SfgKernelHandle): """Create a function comprising just a single kernel call. @@ -295,9 +277,11 @@ class SfgBasicComposer(SfgIComposer): ast_or_kernel_handle: Either a pystencils AST, or a kernel handle for an already registered AST. """ if isinstance(kernel, Kernel): - kernel = self.kernels.add(kernel, name) + khandle = self.kernels.add(kernel, name) + else: + khandle = kernel - self.function(name)(self.call(kernel)) + self.function(name)(self.call(khandle)) def function( self, @@ -536,41 +520,6 @@ class SfgBasicComposer(SfgIComposer): ] return SfgDeferredVectorMapping(components, rhs) - def _struct_from_numpy_dtype( - self, struct_name: str, dtype: np.dtype, add_constructor: bool = True - ): - cls = SfgClass( - struct_name, - self._cursor.current_namespace, - class_keyword=SfgClassKeyword.STRUCT, - ) - - fields = dtype.fields - if fields is None: - raise SfgException(f"Numpy dtype {dtype} is not a structured type.") - - constr_params = [] - constr_inits = [] - - for member_name, type_info in fields.items(): - member_type = create_type(type_info[0]) - - member = SfgMemberVariable(member_name, member_type) - - arg = SfgVar(f"{member_name}_", member_type) - - cls.default.append_member(member) - - constr_params.append(arg) - constr_inits.append(f"{member}({arg})") - - if add_constructor: - cls.default.append_member( - SfgEntityDef(SfgConstructor(constr_params, constr_inits)) - ) - - return cls - def make_statements(arg: ExprLike) -> SfgStatements: return SfgStatements(str(arg), (), depends(arg), includes(arg)) diff --git a/src/pystencilssfg/composer/class_composer.py b/src/pystencilssfg/composer/class_composer.py index 1e3e6a3085b3bc1a716930ccbc8d74381b33e05d..fa7d6f2583fd0de6539a67af13b67d05ae1029a1 100644 --- a/src/pystencilssfg/composer/class_composer.py +++ b/src/pystencilssfg/composer/class_composer.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import Sequence from itertools import takewhile, dropwhile +import numpy as np from pystencils.types import PsCustomType, UserTypeSpec, create_type @@ -196,6 +197,16 @@ class SfgClassComposer(SfgComposerMixIn): """ return self._class(class_name, SfgClassKeyword.STRUCT, bases) + def numpy_struct( + self, name: str, dtype: np.dtype, add_constructor: bool = True + ): + """Add a numpy structured data type as a C++ struct + + Returns: + The created class object + """ + return self._struct_from_numpy_dtype(name, dtype, add_constructor) + @property def public(self) -> SfgClassComposer.VisibilityBlockSequencer: """Create a `public` visibility block in a class body""" @@ -241,6 +252,8 @@ class SfgClassComposer(SfgComposerMixIn): # INTERNALS def _class(self, class_name: str, keyword: SfgClassKeyword, bases: Sequence[str]): + # TODO: Return a `CppClass` instance representing the generated class + if self._cursor.get_entity(class_name) is not None: raise ValueError( f"Another entity with name {class_name} already exists in the current namespace." @@ -288,3 +301,30 @@ class SfgClassComposer(SfgComposerMixIn): self._cursor.write_header(SfgEntityDef(cls)) return sequencer + + def _struct_from_numpy_dtype( + self, struct_name: str, dtype: np.dtype, add_constructor: bool = True + ): + fields = dtype.fields + if fields is None: + raise SfgException(f"Numpy dtype {dtype} is not a structured type.") + + members: list[SfgClassComposer.ConstructorBuilder | SfgVar] = [] + if add_constructor: + ctor = self.constructor() + members.append(ctor) + + for member_name, type_info in fields.items(): + member_type = create_type(type_info[0]) + + member = SfgVar(member_name, member_type) + members.append(member) + + if add_constructor: + arg = SfgVar(f"{member_name}_", member_type) + ctor.add_param(arg) + ctor.init(member)(arg) + + return self.struct( + struct_name, + )(*members) diff --git a/src/pystencilssfg/ir/__init__.py b/src/pystencilssfg/ir/__init__.py index 8eee39cfc2e5fa6e37f677dd780aa35bb9a796b6..f1760b7c9b9ecdaa4821adf29fbc9f2129e0bd46 100644 --- a/src/pystencilssfg/ir/__init__.py +++ b/src/pystencilssfg/ir/__init__.py @@ -15,8 +15,6 @@ from .call_tree import ( ) from .source_components import ( - SfgHeaderInclude, - SfgEmptyLines, SfgKernelNamespace, SfgKernelHandle, SfgKernelParamVar, @@ -25,7 +23,6 @@ from .source_components import ( SfgClassKeyword, SfgClassMember, SfgVisibilityBlock, - SfgInClassDefinition, SfgMemberVariable, SfgMethod, SfgConstructor, @@ -47,8 +44,6 @@ __all__ = [ "SfgBranch", "SfgSwitchCase", "SfgSwitch", - "SfgHeaderInclude", - "SfgEmptyLines", "SfgKernelNamespace", "SfgKernelHandle", "SfgKernelParamVar", @@ -57,7 +52,6 @@ __all__ = [ "SfgClassKeyword", "SfgClassMember", "SfgVisibilityBlock", - "SfgInClassDefinition", "SfgMemberVariable", "SfgMethod", "SfgConstructor", diff --git a/src/pystencilssfg/ir/source_components.py b/src/pystencilssfg/ir/source_components.py index 15b27fb428b3a66827185d550b93d0bab78475f8..b4d8aa77cbf769fb465460a54ace5755168a6d3e 100644 --- a/src/pystencilssfg/ir/source_components.py +++ b/src/pystencilssfg/ir/source_components.py @@ -404,18 +404,7 @@ class SfgConstructor(SfgClassMember): class SfgClass(SfgCodeEntity): - """Models a C++ class. - - ### Adding members to classes - - Members are never added directly to a class. Instead, they are added to - an [SfgVisibilityBlock][pystencilssfg.source_components.SfgVisibilityBlock] - which defines their syntactic position and visibility modifier in the code. - At the top of every class, there is a default visibility block - accessible through the `default` property. - To add members with custom visibility, create a new SfgVisibilityBlock, - add members to the block, and add the block using `append_visibility_block`. - """ + """A C++ class.""" __match_args__ = ("class_name",) @@ -524,12 +513,6 @@ class SfgClass(SfgCodeEntity): self._member_vars[variable.name] = variable -SourceEntity_T = TypeVar( - "SourceEntity_T", bound=SfgFunction | SfgClassMember | SfgClass, covariant=True -) -"""Source entities that may have declarations and definitions.""" - - # ========================================================================================================= # # SYNTACTICAL ELEMENTS @@ -540,6 +523,12 @@ SourceEntity_T = TypeVar( # ========================================================================================================= +SourceEntity_T = TypeVar( + "SourceEntity_T", bound=SfgKernelHandle | SfgFunction | SfgClassMember | SfgClass, covariant=True +) +"""Source entities that may have declarations and definitions.""" + + class SfgEntityDecl(Generic[SourceEntity_T]): """Declaration of a function, class, method, or constructor"""