From 3a4d360d94e647eb8c53e5517ff7f1261b48d8bf Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Wed, 8 Jan 2025 10:33:50 +0100 Subject: [PATCH] update CppType to hold its positional and kw arguments and use them for cloning --- src/pystencilssfg/lang/types.py | 97 ++++++++++++++++++++++++--------- tests/lang/test_types.py | 36 +++++++++++- 2 files changed, 104 insertions(+), 29 deletions(-) diff --git a/src/pystencilssfg/lang/types.py b/src/pystencilssfg/lang/types.py index a41246d..e437ae6 100644 --- a/src/pystencilssfg/lang/types.py +++ b/src/pystencilssfg/lang/types.py @@ -1,6 +1,10 @@ from __future__ import annotations -from typing import Any, Iterable +from typing import Any, Iterable, Sequence, Mapping from abc import ABC +from dataclasses import dataclass + +import string + from pystencils.types import PsType, PsPointerType, PsCustomType from .headers import HeaderFile @@ -24,8 +28,71 @@ class VoidType(PsType): void = VoidType() +class _TemplateArgFormatter(string.Formatter): + + def format_field(self, arg, format_spec): + if isinstance(arg, PsType): + arg = arg.c_string() + return super().format_field(arg, format_spec) + + def check_unused_args( + self, used_args: set[int | str], args: Sequence, kwargs: Mapping[str, Any] + ) -> None: + max_args_len: int = max((k for k in used_args if isinstance(k, int)), default=-1) + 1 + if len(args) > max_args_len: + raise ValueError( + f"Too many positional arguments: Expected {max_args_len}, but got {len(args)}" + ) + + extra_keys = set(kwargs.keys()) - used_args # type: ignore + if extra_keys: + raise ValueError(f"Extraneous keyword arguments: {extra_keys}") + + +@dataclass(frozen=True) +class _TemplateArgs: + pargs: tuple[Any, ...] + kwargs: tuple[tuple[str, Any], ...] + + class CppType(PsCustomType, ABC): includes: frozenset[HeaderFile] + template_string: str + + def __new__( # type: ignore + cls, + *args, + ref: bool = False, + **kwargs, + ) -> CppType | Ref: + if ref: + obj = cls(*args, **kwargs) + return Ref(obj) + else: + return super().__new__(cls) + + def __init__(self, *template_args, const: bool = False, **template_kwargs): + # Support for cloning CppTypes + if template_args and isinstance(template_args[0], _TemplateArgs): + assert not template_kwargs + targs = template_args[0] + pargs = targs.pargs + kwargs = dict(targs.kwargs) + else: + pargs = template_args + kwargs = template_kwargs + targs = _TemplateArgs( + pargs, tuple(sorted(kwargs.items(), key=lambda t: t[0])) + ) + + formatter = _TemplateArgFormatter() + name = formatter.format(self.template_string, *pargs, **kwargs) + + self._targs = targs + super().__init__(name, const=const) + + def __args__(self) -> tuple[Any, ...]: + return (self._targs,) @property def required_headers(self) -> set[str]: @@ -42,35 +109,11 @@ def cpptype(typestr: str, include: str | HeaderFile | Iterable[str | HeaderFile] else: headers = list(include) - def _fixarg(template_arg): - if isinstance(template_arg, PsType): - return template_arg.c_string() - else: - return str(template_arg) - class TypeClass(CppType): + template_string = typestr includes = frozenset(HeaderFile.parse(h) for h in headers) - class TypeClassFactory(TypeClass): - def __new__( # type: ignore - cls, - *template_args, - const: bool = False, - ref: bool = False, - **template_kwargs, - ) -> PsType: - template_args = tuple(_fixarg(arg) for arg in template_args) - template_kwargs = { - key: _fixarg(value) for key, value in template_kwargs.items() - } - - name = typestr.format(*template_args, **template_kwargs) - obj: PsType = TypeClass(name, const) - if ref: - obj = Ref(obj) - return obj - - return TypeClassFactory + return TypeClass class Ref(PsType): diff --git a/tests/lang/test_types.py b/tests/lang/test_types.py index 79e9240..35713ea 100644 --- a/tests/lang/test_types.py +++ b/tests/lang/test_types.py @@ -1,12 +1,14 @@ +import pytest + from pystencilssfg.lang import cpptype, HeaderFile, Ref, strip_ptr_ref from pystencils import create_type from pystencils.types import constify, deconstify def test_cpptypes(): - tclass = cpptype("std::vector< {T}, {Allocator} >", "<vector>") + tclass = cpptype("std::vector< {}, {} >", "<vector>") - vec_type = tclass(T=create_type("float32"), Allocator="std::allocator< float >") + vec_type = tclass(create_type("float32"), "std::allocator< float >") assert str(vec_type).strip() == "std::vector< float, std::allocator< float > >" assert ( tclass.includes @@ -14,8 +16,38 @@ def test_cpptypes(): == {HeaderFile("vector", system_header=True)} ) + # Cloning assert deconstify(constify(vec_type)) == vec_type + # Duplicate Equality + assert tclass(create_type("float32"), "std::allocator< float >") == vec_type + # Not equal with different argument even though it produces the same string + assert tclass("float", "std::allocator< float >") != vec_type + + # The same with keyword arguments + tclass = cpptype("std::vector< {T}, {Allocator} >", "<vector>") + + vec_type = tclass(T=create_type("float32"), Allocator="std::allocator< float >") + assert str(vec_type).strip() == "std::vector< float, std::allocator< float > >" + + assert deconstify(constify(vec_type)) == vec_type + + +def test_cpptype_invalid_construction(): + tclass = cpptype("std::vector< {}, {Allocator} >", "<vector>") + + with pytest.raises(IndexError): + _ = tclass(Allocator="SomeAlloc") + + with pytest.raises(KeyError): + _ = tclass("int") + + with pytest.raises(ValueError, match="Too many positional arguments"): + _ = tclass("int", "bogus", Allocator="SomeAlloc") + + with pytest.raises(ValueError, match="Extraneous keyword arguments"): + _ = tclass("int", Allocator="SomeAlloc", bogus=2) + def test_cpptype_const(): tclass = cpptype("std::vector< {T} >", "<vector>") -- GitLab