From 3a4d360d94e647eb8c53e5517ff7f1261b48d8bf Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 8 Jan 2025 10:33:50 +0100
Subject: [PATCH] update CppType to hold its positional and kw arguments and
 use them for cloning

---
 src/pystencilssfg/lang/types.py | 97 ++++++++++++++++++++++++---------
 tests/lang/test_types.py        | 36 +++++++++++-
 2 files changed, 104 insertions(+), 29 deletions(-)

diff --git a/src/pystencilssfg/lang/types.py b/src/pystencilssfg/lang/types.py
index a41246d..e437ae6 100644
--- a/src/pystencilssfg/lang/types.py
+++ b/src/pystencilssfg/lang/types.py
@@ -1,6 +1,10 @@
 from __future__ import annotations
-from typing import Any, Iterable
+from typing import Any, Iterable, Sequence, Mapping
 from abc import ABC
+from dataclasses import dataclass
+
+import string
+
 from pystencils.types import PsType, PsPointerType, PsCustomType
 from .headers import HeaderFile
 
@@ -24,8 +28,71 @@ class VoidType(PsType):
 void = VoidType()
 
 
+class _TemplateArgFormatter(string.Formatter):
+
+    def format_field(self, arg, format_spec):
+        if isinstance(arg, PsType):
+            arg = arg.c_string()
+        return super().format_field(arg, format_spec)
+
+    def check_unused_args(
+        self, used_args: set[int | str], args: Sequence, kwargs: Mapping[str, Any]
+    ) -> None:
+        max_args_len: int = max((k for k in used_args if isinstance(k, int)), default=-1) + 1
+        if len(args) > max_args_len:
+            raise ValueError(
+                f"Too many positional arguments: Expected {max_args_len}, but got {len(args)}"
+            )
+
+        extra_keys = set(kwargs.keys()) - used_args  # type: ignore
+        if extra_keys:
+            raise ValueError(f"Extraneous keyword arguments: {extra_keys}")
+
+
+@dataclass(frozen=True)
+class _TemplateArgs:
+    pargs: tuple[Any, ...]
+    kwargs: tuple[tuple[str, Any], ...]
+
+
 class CppType(PsCustomType, ABC):
     includes: frozenset[HeaderFile]
+    template_string: str
+
+    def __new__(  # type: ignore
+        cls,
+        *args,
+        ref: bool = False,
+        **kwargs,
+    ) -> CppType | Ref:
+        if ref:
+            obj = cls(*args, **kwargs)
+            return Ref(obj)
+        else:
+            return super().__new__(cls)
+
+    def __init__(self, *template_args, const: bool = False, **template_kwargs):
+        #   Support for cloning CppTypes
+        if template_args and isinstance(template_args[0], _TemplateArgs):
+            assert not template_kwargs
+            targs = template_args[0]
+            pargs = targs.pargs
+            kwargs = dict(targs.kwargs)
+        else:
+            pargs = template_args
+            kwargs = template_kwargs
+            targs = _TemplateArgs(
+                pargs, tuple(sorted(kwargs.items(), key=lambda t: t[0]))
+            )
+
+        formatter = _TemplateArgFormatter()
+        name = formatter.format(self.template_string, *pargs, **kwargs)
+
+        self._targs = targs
+        super().__init__(name, const=const)
+
+    def __args__(self) -> tuple[Any, ...]:
+        return (self._targs,)
 
     @property
     def required_headers(self) -> set[str]:
@@ -42,35 +109,11 @@ def cpptype(typestr: str, include: str | HeaderFile | Iterable[str | HeaderFile]
     else:
         headers = list(include)
 
-    def _fixarg(template_arg):
-        if isinstance(template_arg, PsType):
-            return template_arg.c_string()
-        else:
-            return str(template_arg)
-
     class TypeClass(CppType):
+        template_string = typestr
         includes = frozenset(HeaderFile.parse(h) for h in headers)
 
-    class TypeClassFactory(TypeClass):
-        def __new__(  # type: ignore
-            cls,
-            *template_args,
-            const: bool = False,
-            ref: bool = False,
-            **template_kwargs,
-        ) -> PsType:
-            template_args = tuple(_fixarg(arg) for arg in template_args)
-            template_kwargs = {
-                key: _fixarg(value) for key, value in template_kwargs.items()
-            }
-
-            name = typestr.format(*template_args, **template_kwargs)
-            obj: PsType = TypeClass(name, const)
-            if ref:
-                obj = Ref(obj)
-            return obj
-
-    return TypeClassFactory
+    return TypeClass
 
 
 class Ref(PsType):
diff --git a/tests/lang/test_types.py b/tests/lang/test_types.py
index 79e9240..35713ea 100644
--- a/tests/lang/test_types.py
+++ b/tests/lang/test_types.py
@@ -1,12 +1,14 @@
+import pytest
+
 from pystencilssfg.lang import cpptype, HeaderFile, Ref, strip_ptr_ref
 from pystencils import create_type
 from pystencils.types import constify, deconstify
 
 
 def test_cpptypes():
-    tclass = cpptype("std::vector< {T}, {Allocator} >", "<vector>")
+    tclass = cpptype("std::vector< {}, {} >", "<vector>")
 
-    vec_type = tclass(T=create_type("float32"), Allocator="std::allocator< float >")
+    vec_type = tclass(create_type("float32"), "std::allocator< float >")
     assert str(vec_type).strip() == "std::vector< float, std::allocator< float > >"
     assert (
         tclass.includes
@@ -14,8 +16,38 @@ def test_cpptypes():
         == {HeaderFile("vector", system_header=True)}
     )
 
+    #   Cloning
     assert deconstify(constify(vec_type)) == vec_type
 
+    #   Duplicate Equality
+    assert tclass(create_type("float32"), "std::allocator< float >") == vec_type
+    #   Not equal with different argument even though it produces the same string
+    assert tclass("float", "std::allocator< float >") != vec_type
+
+    #   The same with keyword arguments
+    tclass = cpptype("std::vector< {T}, {Allocator} >", "<vector>")
+
+    vec_type = tclass(T=create_type("float32"), Allocator="std::allocator< float >")
+    assert str(vec_type).strip() == "std::vector< float, std::allocator< float > >"
+
+    assert deconstify(constify(vec_type)) == vec_type
+
+
+def test_cpptype_invalid_construction():
+    tclass = cpptype("std::vector< {}, {Allocator} >", "<vector>")
+
+    with pytest.raises(IndexError):
+        _ = tclass(Allocator="SomeAlloc")
+
+    with pytest.raises(KeyError):
+        _ = tclass("int")
+
+    with pytest.raises(ValueError, match="Too many positional arguments"):
+        _ = tclass("int", "bogus", Allocator="SomeAlloc")
+
+    with pytest.raises(ValueError, match="Extraneous keyword arguments"):
+        _ = tclass("int", Allocator="SomeAlloc", bogus=2)
+
 
 def test_cpptype_const():
     tclass = cpptype("std::vector< {T} >", "<vector>")
-- 
GitLab