Skip to content
Snippets Groups Projects
Commit 3a4d360d authored by Frederik Hennig's avatar Frederik Hennig
Browse files

update CppType to hold its positional and kw arguments and use them for cloning

parent 1a13ccaf
No related branches found
No related tags found
1 merge request!12Improve versatility and robustness of `cpptype`, and document it in the user guide
Pipeline #71532 passed
from __future__ import annotations
from typing import Any, Iterable
from typing import Any, Iterable, Sequence, Mapping
from abc import ABC
from dataclasses import dataclass
import string
from pystencils.types import PsType, PsPointerType, PsCustomType
from .headers import HeaderFile
......@@ -24,8 +28,71 @@ class VoidType(PsType):
void = VoidType()
class _TemplateArgFormatter(string.Formatter):
def format_field(self, arg, format_spec):
if isinstance(arg, PsType):
arg = arg.c_string()
return super().format_field(arg, format_spec)
def check_unused_args(
self, used_args: set[int | str], args: Sequence, kwargs: Mapping[str, Any]
) -> None:
max_args_len: int = max((k for k in used_args if isinstance(k, int)), default=-1) + 1
if len(args) > max_args_len:
raise ValueError(
f"Too many positional arguments: Expected {max_args_len}, but got {len(args)}"
)
extra_keys = set(kwargs.keys()) - used_args # type: ignore
if extra_keys:
raise ValueError(f"Extraneous keyword arguments: {extra_keys}")
@dataclass(frozen=True)
class _TemplateArgs:
pargs: tuple[Any, ...]
kwargs: tuple[tuple[str, Any], ...]
class CppType(PsCustomType, ABC):
includes: frozenset[HeaderFile]
template_string: str
def __new__( # type: ignore
cls,
*args,
ref: bool = False,
**kwargs,
) -> CppType | Ref:
if ref:
obj = cls(*args, **kwargs)
return Ref(obj)
else:
return super().__new__(cls)
def __init__(self, *template_args, const: bool = False, **template_kwargs):
# Support for cloning CppTypes
if template_args and isinstance(template_args[0], _TemplateArgs):
assert not template_kwargs
targs = template_args[0]
pargs = targs.pargs
kwargs = dict(targs.kwargs)
else:
pargs = template_args
kwargs = template_kwargs
targs = _TemplateArgs(
pargs, tuple(sorted(kwargs.items(), key=lambda t: t[0]))
)
formatter = _TemplateArgFormatter()
name = formatter.format(self.template_string, *pargs, **kwargs)
self._targs = targs
super().__init__(name, const=const)
def __args__(self) -> tuple[Any, ...]:
return (self._targs,)
@property
def required_headers(self) -> set[str]:
......@@ -42,35 +109,11 @@ def cpptype(typestr: str, include: str | HeaderFile | Iterable[str | HeaderFile]
else:
headers = list(include)
def _fixarg(template_arg):
if isinstance(template_arg, PsType):
return template_arg.c_string()
else:
return str(template_arg)
class TypeClass(CppType):
template_string = typestr
includes = frozenset(HeaderFile.parse(h) for h in headers)
class TypeClassFactory(TypeClass):
def __new__( # type: ignore
cls,
*template_args,
const: bool = False,
ref: bool = False,
**template_kwargs,
) -> PsType:
template_args = tuple(_fixarg(arg) for arg in template_args)
template_kwargs = {
key: _fixarg(value) for key, value in template_kwargs.items()
}
name = typestr.format(*template_args, **template_kwargs)
obj: PsType = TypeClass(name, const)
if ref:
obj = Ref(obj)
return obj
return TypeClassFactory
return TypeClass
class Ref(PsType):
......
import pytest
from pystencilssfg.lang import cpptype, HeaderFile, Ref, strip_ptr_ref
from pystencils import create_type
from pystencils.types import constify, deconstify
def test_cpptypes():
tclass = cpptype("std::vector< {T}, {Allocator} >", "<vector>")
tclass = cpptype("std::vector< {}, {} >", "<vector>")
vec_type = tclass(T=create_type("float32"), Allocator="std::allocator< float >")
vec_type = tclass(create_type("float32"), "std::allocator< float >")
assert str(vec_type).strip() == "std::vector< float, std::allocator< float > >"
assert (
tclass.includes
......@@ -14,8 +16,38 @@ def test_cpptypes():
== {HeaderFile("vector", system_header=True)}
)
# Cloning
assert deconstify(constify(vec_type)) == vec_type
# Duplicate Equality
assert tclass(create_type("float32"), "std::allocator< float >") == vec_type
# Not equal with different argument even though it produces the same string
assert tclass("float", "std::allocator< float >") != vec_type
# The same with keyword arguments
tclass = cpptype("std::vector< {T}, {Allocator} >", "<vector>")
vec_type = tclass(T=create_type("float32"), Allocator="std::allocator< float >")
assert str(vec_type).strip() == "std::vector< float, std::allocator< float > >"
assert deconstify(constify(vec_type)) == vec_type
def test_cpptype_invalid_construction():
tclass = cpptype("std::vector< {}, {Allocator} >", "<vector>")
with pytest.raises(IndexError):
_ = tclass(Allocator="SomeAlloc")
with pytest.raises(KeyError):
_ = tclass("int")
with pytest.raises(ValueError, match="Too many positional arguments"):
_ = tclass("int", "bogus", Allocator="SomeAlloc")
with pytest.raises(ValueError, match="Extraneous keyword arguments"):
_ = tclass("int", Allocator="SomeAlloc", bogus=2)
def test_cpptype_const():
tclass = cpptype("std::vector< {T} >", "<vector>")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment