From 602559e7a5fe8d3a515d530a34cfcd989d246f9a Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Thu, 9 Nov 2023 17:52:44 +0900
Subject: [PATCH] cleaned up cpp typenames

---
 pystencilssfg/__init__.py                     |  6 ++-
 pystencilssfg/source_concepts/__init__.py     |  5 +--
 .../source_concepts/cpp/std_mdspan.py         |  7 ++-
 .../source_concepts/cpp/std_vector.py         | 13 +++---
 .../source_concepts/source_objects.py         | 27 ++----------
 pystencilssfg/types.py                        | 44 +++++++++++++++----
 6 files changed, 54 insertions(+), 48 deletions(-)

diff --git a/pystencilssfg/__init__.py b/pystencilssfg/__init__.py
index 04329f1..a56913f 100644
--- a/pystencilssfg/__init__.py
+++ b/pystencilssfg/__init__.py
@@ -1,7 +1,9 @@
 from .context import SourceFileGenerator, SfgContext
 from .kernel_namespace import SfgKernelNamespace, SfgKernelHandle
 
+from .types import PsType, SrcType
+
 __all__ = [
-    'SourceFileGenerator', 'SfgContext',
-    'SfgKernelNamespace', 'SfgKernelHandle'
+    SourceFileGenerator, SfgContext, SfgKernelNamespace, SfgKernelHandle,
+    PsType, SrcType
 ]
diff --git a/pystencilssfg/source_concepts/__init__.py b/pystencilssfg/source_concepts/__init__.py
index 10ba677..0531fd2 100644
--- a/pystencilssfg/source_concepts/__init__.py
+++ b/pystencilssfg/source_concepts/__init__.py
@@ -1,6 +1,5 @@
-from .source_objects import SrcObject, SrcField, SrcVector, PsType, SrcType, TypedSymbolOrObject
+from .source_objects import SrcObject, SrcField, SrcVector, TypedSymbolOrObject
 
 __all__ = [
-    SrcObject, SrcField, SrcVector,
-    PsType, SrcType, TypedSymbolOrObject
+    SrcObject, SrcField, SrcVector, TypedSymbolOrObject
 ]
\ No newline at end of file
diff --git a/pystencilssfg/source_concepts/cpp/std_mdspan.py b/pystencilssfg/source_concepts/cpp/std_mdspan.py
index 2b99b86..dcb5265 100644
--- a/pystencilssfg/source_concepts/cpp/std_mdspan.py
+++ b/pystencilssfg/source_concepts/cpp/std_mdspan.py
@@ -5,16 +5,15 @@ from pystencils.typing import FieldPointerSymbol, FieldStrideSymbol, FieldShapeS
 from ...tree import SfgStatements
 from ..source_objects import SrcField
 from ...source_components.header_include import SfgHeaderInclude
-from ..source_objects import PsType
+from ...types import PsType, cpp_typename
 from ...exceptions import SfgException
 
 class std_mdspan(SrcField):
     dynamic_extent = "std::dynamic_extent"
 
     def __init__(self, identifer: str, T: PsType, extents: Tuple[int, str], extents_type: PsType = int, reference: bool = False):
-        from pystencils.typing import create_type
-        T = create_type(T)
-        extents_type = create_type(extents_type)
+        T = cpp_typename(T)
+        extents_type = cpp_typename(extents_type)
 
         typestring = f"std::mdspan< {T}, std::extents< {extents_type}, {', '.join(str(e) for e in extents)} > > {'&' if reference else ''}"
         super().__init__(typestring, identifer)
diff --git a/pystencilssfg/source_concepts/cpp/std_vector.py b/pystencilssfg/source_concepts/cpp/std_vector.py
index 5f2e8f0..d01e881 100644
--- a/pystencilssfg/source_concepts/cpp/std_vector.py
+++ b/pystencilssfg/source_concepts/cpp/std_vector.py
@@ -1,16 +1,17 @@
 from typing import Set, Union, Tuple
 
-from pystencils.typing import FieldPointerSymbol, FieldStrideSymbol, FieldShapeSymbol, create_type
+from pystencils.typing import FieldPointerSymbol, FieldStrideSymbol, FieldShapeSymbol
 
 from ...tree import SfgStatements
 from ..source_objects import SrcField, SrcVector
-from ..source_objects import SrcObject, SrcType, TypedSymbolOrObject
+from ..source_objects import SrcObject, TypedSymbolOrObject
+from ...types import SrcType, PsType, cpp_typename
 from ...source_components.header_include import SfgHeaderInclude
 from ...exceptions import SfgException
 
 class std_vector(SrcVector, SrcField):
-    def __init__(self, identifer: str, T: SrcType, unsafe: bool = False):
-        typestring = f"std::vector< {T} >"
+    def __init__(self, identifer: str, T: Union[SrcType, PsType], unsafe: bool = False):
+        typestring = f"std::vector< {cpp_typename(T)} >"
         super(SrcObject, self).__init__(identifer, typestring)
 
         self._element_type = T
@@ -56,7 +57,6 @@ class std_vector(SrcVector, SrcField):
         else:
             return SfgStatements(f"assert( 1 == {stride} );", (), ())
 
-
     def extract_component(self, destination: TypedSymbolOrObject, coordinate: int):
         if self._unsafe:
             mapping = f"{destination.dtype} {destination.name} = {self._identifier}[{coordinate}];"
@@ -66,8 +66,7 @@ class std_vector(SrcVector, SrcField):
         return SfgStatements(mapping, (destination,), (self,))
 
 
-
 class std_vector_ref(std_vector):
-    def __init__(self, identifer: str, T: SrcType):
+    def __init__(self, identifer: str, T: Union[SrcType, PsType]):
         typestring = f"std::vector< {T} > &"
         super(SrcObject, self).__init__(identifer, typestring)
diff --git a/pystencilssfg/source_concepts/source_objects.py b/pystencilssfg/source_concepts/source_objects.py
index 244ff75..ef08675 100644
--- a/pystencilssfg/source_concepts/source_objects.py
+++ b/pystencilssfg/source_concepts/source_objects.py
@@ -6,33 +6,12 @@ if TYPE_CHECKING:
     from ..source_components import SfgHeaderInclude
     from ..tree import SfgStatements, SfgSequence
 
-from numpy import dtype
-
 from abc import ABC, abstractmethod
 
 from pystencils import TypedSymbol, Field
-from pystencils.typing import AbstractType, FieldPointerSymbol, FieldStrideSymbol, FieldShapeSymbol
-
-PsType: TypeAlias = Union[type, dtype, AbstractType]
-"""Types used in interacting with pystencils.
-
-PsType represents various ways of specifying types within pystencils.
-In particular, it encompasses most ways to construct an instance of `AbstractType`,
-for example via `create_type`.
-
-(Note that, while `create_type` does accept strings, they are excluded here for
-reasons of safety. It is discouraged to use strings for type specifications when working
-with pystencils!)
-"""
-
-SrcType = NewType('SrcType', str)
-"""Nonprimitive C/C++-Types occuring during source file generation.
-
-Nonprimitive C/C++ types are represented by their names.
-When necessary, the SFG package checks equality of types by these name strings; it does
-not care about typedefs, aliases, namespaces, etc!
-"""
+from pystencils.typing import FieldPointerSymbol, FieldStrideSymbol, FieldShapeSymbol
 
+from ..types import SrcType
 
 class SrcObject:
     """C/C++ object of nonprimitive type.
@@ -100,7 +79,7 @@ class SrcField(SrcObject, ABC):
         )
 
 
-class SrcVector(SrcObject):
+class SrcVector(SrcObject, ABC):
     @abstractmethod
     def extract_component(self, destination: TypedSymbolOrObject, coordinate: int):
         pass
diff --git a/pystencilssfg/types.py b/pystencilssfg/types.py
index 9e309e8..fe12909 100644
--- a/pystencilssfg/types.py
+++ b/pystencilssfg/types.py
@@ -1,12 +1,40 @@
-from pystencils.typing import AbstractType, BasicType, StructType, PointerType
+from typing import Union, TypeAlias, NewType
+import numpy as np
 
+from pystencils.typing import AbstractType, numpy_name_to_c
 
-class SrcType:
-    """Valid C/C++-Type occuring during source file generation.
 
-    Nonprimitive C/C++ types are represented by their names.
-    When necessary, the SFG package checks equality of types by these name strings; it does
-    not care about typedefs, aliases, namespaces, etc!
-    """
-    
+PsType: TypeAlias = Union[type, np.dtype, AbstractType]
+"""Types used in interacting with pystencils.
+
+PsType represents various ways of specifying types within pystencils.
+In particular, it encompasses most ways to construct an instance of `AbstractType`,
+for example via `create_type`.
+
+(Note that, while `create_type` does accept strings, they are excluded here for
+reasons of safety. It is discouraged to use strings for type specifications when working
+with pystencils!)
+"""
+
+SrcType = NewType('SrcType', str)
+"""Nonprimitive C/C++-Types occuring during source file generation.
+
+Nonprimitive C/C++ types are represented by their names.
+When necessary, the SFG package checks equality of types by these name strings; it does
+not care about typedefs, aliases, namespaces, etc!
+"""
+
+
+def cpp_typename(type_obj: Union[str, SrcType, PsType]):
+    """Accepts type specifications in various ways and returns a valid typename to be used in code."""
+    # if isinstance(type_obj, str):
+    #     return type_obj
+    if isinstance(type_obj, str):
+        return type_obj
+    elif isinstance(type_obj, AbstractType):
+        return str(type_obj)
+    elif isinstance(type_obj, np.dtype) or isinstance(type_obj, type):
+        return numpy_name_to_c(np.dtype(type_obj).name)
+    else:
+        raise ValueError(f"Don't know how to interpret type object {type_obj}.")
 
-- 
GitLab