Skip to content
Snippets Groups Projects
Commit ac3b5e34 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Introduce Type Uniquing Mechanism

 - Introduce metaclass PsTypeMeta
 - Refactor __args__ protocol to exclude const
 - Move PsType and PsTypeMeta to types/meta.py
 - Rename basic_types.py to types.py
 - Adapt test cases to check for identity
parent d6eb671a
Branches
Tags
2 merge requests!379Type System Refactor,!374Uniqueness of Data Type Instances
Pipeline #64886 passed
Showing with 227 additions and 171 deletions
......@@ -21,7 +21,7 @@ from .astnode import PsAstNode, PsLeafMixIn
class PsExpression(PsAstNode, ABC):
"""Base class for all expressions.
**Types:** Each expression should be annotated with its type.
Upon construction, the `dtype` property of most expression nodes is unset;
only constant expressions, symbol expressions, and array accesses immediately inherit their type from
......@@ -271,7 +271,7 @@ class PsVectorArrayAccess(PsArrayAccess):
@property
def alignment(self) -> int:
return self._alignment
def get_vector_type(self) -> PsVectorType:
return cast(PsVectorType, self._dtype)
......
......@@ -7,10 +7,10 @@ from .exceptions import PsInternalCompilerError
class PsConstant:
"""Type-safe representation of typed numerical constants.
This class models constants in the backend representation of kernels.
A constant may be *untyped*, in which case its ``value`` may be any Python object.
If the constant is *typed* (i.e. its ``dtype`` is not ``None``), its data type is used
to check the validity of its ``value`` and to convert it into the type's internal representation.
......@@ -36,19 +36,19 @@ class PsConstant:
def interpret_as(self, dtype: PsNumericType) -> PsConstant:
"""Interprets this *untyped* constant with the given data type.
If this constant is already typed, raises an error.
"""
if self._dtype is not None:
raise PsInternalCompilerError(
f"Cannot interpret already typed constant {self} with type {dtype}"
)
return PsConstant(self._value, dtype)
def reinterpret_as(self, dtype: PsNumericType) -> PsConstant:
"""Reinterprets this constant with the given data type.
Other than `interpret_as`, this method also works on typed constants.
"""
return PsConstant(self._value, dtype)
......@@ -60,7 +60,7 @@ class PsConstant:
@property
def dtype(self) -> PsNumericType | None:
"""This constant's data type, or ``None`` if it is untyped.
The data type of a constant always has ``const == True``.
"""
return self._dtype
......
......@@ -56,7 +56,7 @@ NodeT = TypeVar("NodeT", bound=PsAstNode)
class TypeContext:
"""Typing context, with support for type inference and checking.
Instances of this class are used to propagate and check data types across expression subtrees
of the AST. Each type context has:
......@@ -185,7 +185,7 @@ class TypeContext:
def _compatible(self, dtype: PsType):
"""Checks whether the given data type is compatible with the context's target type.
If the target type is ``const``, they must be equal up to const qualification;
if the target type is not ``const``, `dtype` must match it exactly.
"""
......@@ -248,7 +248,7 @@ class Typifier:
Some expressions (`PsSymbolExpr`, `PsArrayAccess`) encapsulate symbols and inherit their data types, but
not necessarily their const-qualification.
A symbol with non-``const`` type may occur in a `PsSymbolExpr` with ``const`` type,
A symbol with non-``const`` type may occur in a `PsSymbolExpr` with ``const`` type,
and an array base pointer with non-``const`` base type may be nested in a ``const`` `PsArrayAccess`,
but not vice versa.
"""
......@@ -321,7 +321,7 @@ class Typifier:
def visit_expr(self, expr: PsExpression, tc: TypeContext) -> None:
"""Recursive processing of expression nodes.
This method opens, expands, and closes typing contexts according to the respective expression's
typing rules. It may add or check restrictions only when opening or closing a type context.
......@@ -394,7 +394,7 @@ class Typifier:
f"Unable to determine type of argument to AddressOf: {arg}"
)
ptr_type = PsPointerType(arg_tc.target_type, True)
ptr_type = PsPointerType(arg_tc.target_type, const=True)
tc.apply_dtype(ptr_type, expr)
case PsLookup(aggr, member_name):
......
from pystencils.backend.functions import CFunction, PsMathFunction
from pystencils.types.basic_types import PsType
from pystencils.types.types import PsType
from .platform import Platform
from ..kernelcreation.iteration_space import (
......@@ -56,8 +56,10 @@ class GenericGpu(Platform):
]
return indices[:dim]
def select_function(self, math_function: PsMathFunction, dtype: PsType) -> CFunction:
def select_function(
self, math_function: PsMathFunction, dtype: PsType
) -> CFunction:
raise NotImplementedError()
# Internals
......
......@@ -26,6 +26,6 @@ class AddressOf(sp.Function):
@property
def dtype(self):
if hasattr(self.args[0], 'dtype'):
return PsPointerType(self.args[0].dtype, const=True, restrict=True)
return PsPointerType(self.args[0].dtype, restrict=True, const=True)
else:
raise ValueError(f'pystencils supports only non void pointers. Current address_of type: {self.args[0]}')
......@@ -172,7 +172,7 @@ class FieldPointerSymbol(TypedSymbol):
def __new_stage2__(cls, field_name, field_dtype: PsType, const: bool):
name = f"_data_{field_name}"
dtype = PsPointerType(field_dtype, const=const, restrict=True)
dtype = PsPointerType(field_dtype, restrict=True, const=const)
obj = super(FieldPointerSymbol, cls).__xnew__(cls, name, dtype)
obj.field_name = field_name
return obj
......
......@@ -8,8 +8,9 @@ For more user-friendly and less verbose access to the type modelling system, ref
the `pystencils.types.quick` submodule.
"""
from .basic_types import (
PsType,
from .meta import PsType, constify, deconstify
from .types import (
PsCustomType,
PsStructType,
PsNumericType,
......@@ -23,8 +24,6 @@ from .basic_types import (
PsUnsignedIntegerType,
PsSignedIntegerType,
PsIeeeFloatType,
constify,
deconstify,
)
from .quick import UserTypeSpec, create_type, create_numeric_type
......
from __future__ import annotations
from abc import ABCMeta, abstractmethod
from typing import TypeVar, Any, cast
import numpy as np
class PsTypeMeta(ABCMeta):
_instances: dict[Any, PsType] = dict()
def __call__(cls, *args: Any, const: bool = False, **kwargs: Any) -> Any:
obj = super(PsTypeMeta, cls).__call__(*args, const=const, **kwargs)
canonical_args = obj.__args__()
key = (cls, canonical_args, const)
if key in cls._instances:
obj = cls._instances[key]
else:
cls._instances[key] = obj
return obj
class PsType(metaclass=PsTypeMeta):
"""Base class for all pystencils types.
Args:
const: Const-qualification of this type
**Implementation details for subclasses:**
`PsType` and its metaclass ``PsTypeMeta`` together implement a uniquing mechanism to ensure that of each type,
only one instance ever exists in the public.
For this to work, subclasses have to adhere to several rules:
- All instances of `PsType` must be immutable.
- The `const` argument must be the last keyword argument to ``__init__`` and must be passed to the superclass
``__init__``.
- The `__args__` method must return a tuple of positional arguments excluding the `const` property,
which, when passed to the class's constructor, create an identically-behaving instance.
"""
def __init__(self, const: bool = False):
self._const = const
self._requalified: PsType | None = None
@property
def const(self) -> bool:
return self._const
# -------------------------------------------------------------------------------------------
# Optional Info
# -------------------------------------------------------------------------------------------
@property
def required_headers(self) -> set[str]:
"""The set of header files required when this type occurs in generated code."""
return set()
@property
def itemsize(self) -> int | None:
"""If this type has a valid in-memory size, return that size."""
return None
@property
def numpy_dtype(self) -> np.dtype | None:
"""A np.dtype object representing this data type.
Available both for backward compatibility and for interaction with the numpy-based runtime system.
"""
return None
# -------------------------------------------------------------------------------------------
# Internal operations
# -------------------------------------------------------------------------------------------
@abstractmethod
def __args__(self) -> tuple[Any, ...]:
"""Arguments to this type, excluding the const-qualifier.
The tuple returned by this method is used to serialize, deserialize, and check equality of types.
For each instantiable subclass ``MyType`` of ``PsType``, the following must hold:
```
t = MyType(< arguments >)
assert MyType(*t.__args__()) == t
```
"""
pass
def _const_string(self) -> str:
return "const " if self._const else ""
@abstractmethod
def c_string(self) -> str:
pass
# -------------------------------------------------------------------------------------------
# Dunder Methods
# -------------------------------------------------------------------------------------------
def __eq__(self, other: object) -> bool:
if self is other:
return True
if type(self) is not type(other):
return False
other = cast(PsType, other)
return self._const == other._const and self.__args__() == other.__args__()
def __str__(self) -> str:
return self.c_string()
def __hash__(self) -> int:
return hash((type(self), self.__args__()))
T = TypeVar("T", bound=PsType)
def constify(t: T) -> T:
"""Adds the const qualifier to a given type."""
if not t.const:
if t._requalified is None:
t._requalified = type(t)(*t.__args__(), const=True) # type: ignore
return cast(T, t._requalified)
else:
return t
def deconstify(t: T) -> T:
"""Removes the const qualifier from a given type."""
if t.const:
if t._requalified is None:
t._requalified = type(t)(*t.__args__(), const=False) # type: ignore
return cast(T, t._requalified)
else:
return t
import numpy as np
from .basic_types import (
from .types import (
PsType,
PsPointerType,
PsStructType,
......@@ -76,13 +76,13 @@ def parse_type_string(s: str) -> PsType:
base_type = parse_type_string(base)
match suffix.split():
case []:
return PsPointerType(base_type, const=False, restrict=False)
return PsPointerType(base_type, restrict=False, const=False)
case ["const"]:
return PsPointerType(base_type, const=True, restrict=False)
return PsPointerType(base_type, restrict=False, const=True)
case ["restrict"]:
return PsPointerType(base_type, const=False, restrict=True)
return PsPointerType(base_type, restrict=True, const=False)
case ["const", "restrict"] | ["restrict", "const"]:
return PsPointerType(base_type, const=True, restrict=True)
return PsPointerType(base_type, restrict=True, const=True)
case _:
raise ValueError(f"Could not parse token '{s}' as C type.")
......
......@@ -4,7 +4,7 @@ from __future__ import annotations
import numpy as np
from .basic_types import (
from .types import (
PsType,
PsCustomType,
PsNumericType,
......
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import final, TypeVar, Any, Sequence, cast
from typing import final, Any, Sequence
from dataclasses import dataclass
from copy import copy
import numpy as np
from .exception import PsTypeError
class PsType(ABC):
"""Base class for all pystencils types.
Args:
const: Const-qualification of this type
"""
def __init__(self, const: bool = False):
self._const = const
@property
def const(self) -> bool:
return self._const
# -------------------------------------------------------------------------------------------
# Optional Info
# -------------------------------------------------------------------------------------------
@property
def required_headers(self) -> set[str]:
"""The set of header files required when this type occurs in generated code."""
return set()
@property
def itemsize(self) -> int | None:
"""If this type has a valid in-memory size, return that size."""
return None
@property
def numpy_dtype(self) -> np.dtype | None:
"""A np.dtype object representing this data type.
Available both for backward compatibility and for interaction with the numpy-based runtime system.
"""
return None
# -------------------------------------------------------------------------------------------
# Internal operations
# -------------------------------------------------------------------------------------------
@abstractmethod
def __args__(self) -> tuple[Any, ...]:
"""Arguments to this type.
The tuple returned by this method is used to serialize, deserialize, and check equality of types.
For each instantiable subclass ``MyType`` of ``PsType``, the following must hold:
```
t = MyType(< arguments >)
assert MyType(*t.__args__()) == t
```
"""
pass
def _const_string(self) -> str:
return "const " if self._const else ""
@abstractmethod
def c_string(self) -> str:
pass
# -------------------------------------------------------------------------------------------
# Dunder Methods
# -------------------------------------------------------------------------------------------
def __eq__(self, other: object) -> bool:
if self is other:
return True
if type(self) is not type(other):
return False
other = cast(PsType, other)
return self.__args__() == other.__args__()
def __str__(self) -> str:
return self.c_string()
def __hash__(self) -> int:
return hash((type(self), self.__args__()))
from .meta import PsType, constify, deconstify
class PsCustomType(PsType):
......@@ -154,17 +72,17 @@ class PsPointerType(PsDereferencableType):
__match_args__ = ("base_type",)
def __init__(self, base_type: PsType, const: bool = False, restrict: bool = True):
def __init__(self, base_type: PsType, restrict: bool = True, const: bool = False):
super().__init__(base_type, const)
self._restrict = restrict
def __args__(self) -> tuple[Any, ...]:
"""
>>> t = PsPointerType(PsBoolType(), const=True)
>>> t = PsPointerType(PsBoolType())
>>> t == PsPointerType(*t.__args__())
True
"""
return (self._base_type, self._const, self._restrict)
return (self._base_type, self._restrict)
@property
def restrict(self) -> bool:
......@@ -190,11 +108,11 @@ class PsArrayType(PsDereferencableType):
def __args__(self) -> tuple[Any, ...]:
"""
>>> t = PsArrayType(PsBoolType(), 13, const=True)
>>> t = PsArrayType(PsBoolType(), 13)
>>> t == PsArrayType(*t.__args__())
True
"""
return (self._base_type, self._length, self._const)
return (self._base_type, self._length)
@property
def length(self) -> int | None:
......@@ -246,7 +164,7 @@ class PsStructType(PsType):
>>> t == PsStructType(*t.__args__())
True
"""
return (self._members, self._name, self._const)
return (self._members, self._name)
@property
def members(self) -> tuple[PsStructType.Member, ...]:
......@@ -394,11 +312,11 @@ class PsVectorType(PsNumericType):
def __args__(self) -> tuple[Any, ...]:
"""
>>> t = PsVectorType(PsBoolType(), 8, True)
>>> t = PsVectorType(PsBoolType(), 8)
>>> t == PsVectorType(*t.__args__())
True
"""
return (self._scalar_type, self._vector_entries, self._const)
return (self._scalar_type, self._vector_entries)
@property
def scalar_type(self) -> PsScalarType:
......@@ -474,11 +392,11 @@ class PsBoolType(PsScalarType):
def __args__(self) -> tuple[Any, ...]:
"""
>>> t = PsBoolType(True)
>>> t = PsBoolType()
>>> t == PsBoolType(*t.__args__())
True
"""
return (self._const,)
return ()
@property
def width(self) -> int:
......@@ -494,7 +412,9 @@ class PsBoolType(PsScalarType):
def create_literal(self, value: Any) -> str:
if not isinstance(value, self.NUMPY_TYPE):
raise PsTypeError(f"Given value {value} is not of required type {self.NUMPY_TYPE}")
raise PsTypeError(
f"Given value {value} is not of required type {self.NUMPY_TYPE}"
)
if value == np.True_:
return "true"
......@@ -513,7 +433,7 @@ class PsBoolType(PsScalarType):
def c_string(self) -> str:
return "bool"
class PsIntegerType(PsScalarType, ABC):
"""Signed and unsigned integer types.
......@@ -561,18 +481,20 @@ class PsIntegerType(PsScalarType, ABC):
unsigned_suffix = "" if self.signed else "u"
# TODO: cast literal to correct type?
return str(value) + unsigned_suffix
def create_constant(self, value: Any) -> Any:
np_type = self.NUMPY_TYPES[self._width]
if isinstance(value, (int, np.integer)):
iinfo = np.iinfo(np_type) # type: ignore
if value < iinfo.min or value > iinfo.max:
raise PsTypeError(f"Could not interpret {value} as {self}: Value is out of bounds.")
raise PsTypeError(
f"Could not interpret {value} as {self}: Value is out of bounds."
)
return np_type(value)
raise PsTypeError(f"Could not interpret {value} as {repr(self)}")
def c_string(self) -> str:
prefix = "" if self._signed else "u"
return f"{self._const_string()}{prefix}int{self._width}_t"
......@@ -599,11 +521,11 @@ class PsSignedIntegerType(PsIntegerType):
def __args__(self) -> tuple[Any, ...]:
"""
>>> t = PsSignedIntegerType(32, True)
>>> t = PsSignedIntegerType(32)
>>> t == PsSignedIntegerType(*t.__args__())
True
"""
return (self._width, self._const)
return (self._width,)
@final
......@@ -624,11 +546,11 @@ class PsUnsignedIntegerType(PsIntegerType):
def __args__(self) -> tuple[Any, ...]:
"""
>>> t = PsUnsignedIntegerType(32, True)
>>> t = PsUnsignedIntegerType(32)
>>> t == PsUnsignedIntegerType(*t.__args__())
True
"""
return (self._width, self._const)
return (self._width,)
@final
......@@ -656,11 +578,11 @@ class PsIeeeFloatType(PsScalarType):
def __args__(self) -> tuple[Any, ...]:
"""
>>> t = PsIeeeFloatType(32, True)
>>> t = PsIeeeFloatType(32)
>>> t == PsIeeeFloatType(*t.__args__())
True
"""
return (self._width, self._const)
return (self._width,)
@property
def width(self) -> int:
......@@ -702,7 +624,9 @@ class PsIeeeFloatType(PsScalarType):
if isinstance(value, (int, float, np.floating)):
finfo = np.finfo(np_type) # type: ignore
if value < finfo.min or value > finfo.max:
raise PsTypeError(f"Could not interpret {value} as {self}: Value is out of bounds.")
raise PsTypeError(
f"Could not interpret {value} as {self}: Value is out of bounds."
)
return np_type(value)
raise PsTypeError(f"Could not interpret {value} as {repr(self)}")
......@@ -720,26 +644,3 @@ class PsIeeeFloatType(PsScalarType):
def __repr__(self) -> str:
return f"PsIeeeFloatType( width={self.width}, const={self.const} )"
T = TypeVar("T", bound=PsType)
def constify(t: T) -> T:
"""Adds the const qualifier to a given type."""
if not t.const:
t_copy = copy(t)
t_copy._const = True
return t_copy
else:
return t
def deconstify(t: T) -> T:
"""Removes the const qualifier from a given type."""
if t.const:
t_copy = copy(t)
t_copy._const = False
return t_copy
else:
return t
......@@ -289,3 +289,6 @@ def test_typify_constant_clones():
assert expr_clone.operand1.dtype is None
assert cast(PsConstantExpr, expr_clone.operand1).constant.dtype is None
test_lhs_constness()
\ No newline at end of file
......@@ -19,13 +19,13 @@ def test_widths(Type):
def test_parsing_positive():
assert create_type("const uint32_t * restrict") == Ptr(
assert create_type("const uint32_t * restrict") is Ptr(
UInt(32, const=True), restrict=True
)
assert create_type("float * * const") == Ptr(Ptr(Fp(32), restrict=False), const=True, restrict=False)
assert create_type("float * * restrict const") == Ptr(Ptr(Fp(32), restrict=False), const=True, restrict=True)
assert create_type("uint16 * const") == Ptr(UInt(16), const=True, restrict=False)
assert create_type("uint64 const * const") == Ptr(UInt(64, const=True), const=True, restrict=False)
assert create_type("float * * const") is Ptr(Ptr(Fp(32), restrict=False), const=True, restrict=False)
assert create_type("float * * restrict const") is Ptr(Ptr(Fp(32), restrict=False), const=True, restrict=True)
assert create_type("uint16 * const") is Ptr(UInt(16), const=True, restrict=False)
assert create_type("uint64 const * const") is Ptr(UInt(64, const=True), const=True, restrict=False)
def test_parsing_negative():
......@@ -45,14 +45,14 @@ def test_parsing_negative():
def test_numpy():
import numpy as np
assert create_type(np.single) == create_type(np.float32) == PsIeeeFloatType(32)
assert create_type(np.single) is create_type(np.float32) is PsIeeeFloatType(32)
assert (
create_type(float)
== create_type(np.double)
== create_type(np.float64)
== PsIeeeFloatType(64)
is create_type(np.double)
is create_type(np.float64)
is PsIeeeFloatType(64)
)
assert create_type(int) == create_type(np.int64) == PsSignedIntegerType(64)
assert create_type(int) is create_type(np.int64) is PsSignedIntegerType(64)
@pytest.mark.parametrize(
......@@ -102,10 +102,21 @@ def test_numpy_translation(numpy_type):
def test_constify():
t = PsCustomType("std::shared_ptr< Custom >")
assert deconstify(t) == t
assert deconstify(constify(t)) == t
assert deconstify(t) is t
assert deconstify(constify(t)) is t
s = PsCustomType("Field", const=True)
assert constify(s) == s
assert constify(s) is s
i32 = create_type(np.int32)
i32_2 = PsSignedIntegerType(32)
assert i32 is i32_2
assert constify(i32) is constify(i32_2)
i32_const = PsSignedIntegerType(32, const=True)
assert i32_const is not i32
assert i32_const is constify(i32)
def test_struct_types():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment