Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
pystencils
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
pycodegen
pystencils
Commits
d6eb671a
Commit
d6eb671a
authored
1 year ago
by
Frederik Hennig
Browse files
Options
Downloads
Patches
Plain Diff
refactor types hashing and equality
parent
191cc207
No related branches found
No related tags found
2 merge requests
!379
Type System Refactor
,
!374
Uniqueness of Data Type Instances
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
src/pystencils/types/basic_types.py
+97
-96
97 additions, 96 deletions
src/pystencils/types/basic_types.py
tests/nbackend/types/test_types.py
+4
-3
4 additions, 3 deletions
tests/nbackend/types/test_types.py
with
101 additions
and
99 deletions
src/pystencils/types/basic_types.py
+
97
−
96
View file @
d6eb671a
from
__future__
import
annotations
from
__future__
import
annotations
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
final
,
TypeVar
,
Any
,
Sequence
from
typing
import
final
,
TypeVar
,
Any
,
Sequence
,
cast
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
copy
import
copy
from
copy
import
copy
...
@@ -46,11 +46,22 @@ class PsType(ABC):
...
@@ -46,11 +46,22 @@ class PsType(ABC):
return
None
return
None
# -------------------------------------------------------------------------------------------
# -------------------------------------------------------------------------------------------
# Internal
virtual
operations
# Internal operations
# -------------------------------------------------------------------------------------------
# -------------------------------------------------------------------------------------------
def
_base_equal
(
self
,
other
:
PsType
)
->
bool
:
@abstractmethod
return
type
(
self
)
is
type
(
other
)
and
self
.
_const
==
other
.
_const
def
__args__
(
self
)
->
tuple
[
Any
,
...]:
"""
Arguments to this type.
The tuple returned by this method is used to serialize, deserialize, and check equality of types.
For each instantiable subclass ``MyType`` of ``PsType``, the following must hold:
```
t = MyType(< arguments >)
assert MyType(*t.__args__()) == t
```
"""
pass
def
_const_string
(
self
)
->
str
:
def
_const_string
(
self
)
->
str
:
return
"
const
"
if
self
.
_const
else
""
return
"
const
"
if
self
.
_const
else
""
...
@@ -63,16 +74,21 @@ class PsType(ABC):
...
@@ -63,16 +74,21 @@ class PsType(ABC):
# Dunder Methods
# Dunder Methods
# -------------------------------------------------------------------------------------------
# -------------------------------------------------------------------------------------------
@abstractmethod
def
__eq__
(
self
,
other
:
object
)
->
bool
:
def
__eq__
(
self
,
other
:
object
)
->
bool
:
pass
if
self
is
other
:
return
True
if
type
(
self
)
is
not
type
(
other
):
return
False
other
=
cast
(
PsType
,
other
)
return
self
.
__args__
()
==
other
.
__args__
()
def
__str__
(
self
)
->
str
:
def
__str__
(
self
)
->
str
:
return
self
.
c_string
()
return
self
.
c_string
()
@abstractmethod
def
__hash__
(
self
)
->
int
:
def
__hash__
(
self
)
->
int
:
pass
return
hash
((
type
(
self
),
self
.
__args__
()))
class
PsCustomType
(
PsType
):
class
PsCustomType
(
PsType
):
...
@@ -92,13 +108,13 @@ class PsCustomType(PsType):
...
@@ -92,13 +108,13 @@ class PsCustomType(PsType):
def
name
(
self
)
->
str
:
def
name
(
self
)
->
str
:
return
self
.
_name
return
self
.
_name
def
__
eq
__
(
self
,
other
:
object
)
->
bool
:
def
__
args
__
(
self
)
->
tuple
[
Any
,
...]
:
if
not
isinstance
(
other
,
PsCustomType
):
"""
return
False
>>>
t
=
PsCustomType
(
"
std::vector< int >
"
)
return
self
.
_base_equal
(
other
)
and
self
.
_name
==
other
.
_name
>>>
t
==
PsCustomType
(
*
t
.
__args__
())
True
def
__hash__
(
self
)
->
int
:
"""
return
hash
((
"
PsCustomType
"
,
self
.
_name
,
self
.
_const
)
)
return
(
self
.
_name
,
)
def
c_string
(
self
)
->
str
:
def
c_string
(
self
)
->
str
:
return
f
"
{
self
.
_const_string
()
}
{
self
.
_name
}
"
return
f
"
{
self
.
_const_string
()
}
{
self
.
_name
}
"
...
@@ -142,18 +158,18 @@ class PsPointerType(PsDereferencableType):
...
@@ -142,18 +158,18 @@ class PsPointerType(PsDereferencableType):
super
().
__init__
(
base_type
,
const
)
super
().
__init__
(
base_type
,
const
)
self
.
_restrict
=
restrict
self
.
_restrict
=
restrict
def
__args__
(
self
)
->
tuple
[
Any
,
...]:
"""
>>>
t
=
PsPointerType
(
PsBoolType
(),
const
=
True
)
>>>
t
==
PsPointerType
(
*
t
.
__args__
())
True
"""
return
(
self
.
_base_type
,
self
.
_const
,
self
.
_restrict
)
@property
@property
def
restrict
(
self
)
->
bool
:
def
restrict
(
self
)
->
bool
:
return
self
.
_restrict
return
self
.
_restrict
def
__eq__
(
self
,
other
:
object
)
->
bool
:
if
not
isinstance
(
other
,
PsPointerType
):
return
False
return
self
.
_base_equal
(
other
)
and
self
.
_base_type
==
other
.
_base_type
def
__hash__
(
self
)
->
int
:
return
hash
((
"
PsPointerType
"
,
self
.
_base_type
,
self
.
_restrict
,
self
.
_const
))
def
c_string
(
self
)
->
str
:
def
c_string
(
self
)
->
str
:
base_str
=
self
.
_base_type
.
c_string
()
base_str
=
self
.
_base_type
.
c_string
()
restrict_str
=
"
RESTRICT
"
if
self
.
_restrict
else
""
restrict_str
=
"
RESTRICT
"
if
self
.
_restrict
else
""
...
@@ -172,6 +188,14 @@ class PsArrayType(PsDereferencableType):
...
@@ -172,6 +188,14 @@ class PsArrayType(PsDereferencableType):
self
.
_length
=
length
self
.
_length
=
length
super
().
__init__
(
base_type
,
const
)
super
().
__init__
(
base_type
,
const
)
def
__args__
(
self
)
->
tuple
[
Any
,
...]:
"""
>>>
t
=
PsArrayType
(
PsBoolType
(),
13
,
const
=
True
)
>>>
t
==
PsArrayType
(
*
t
.
__args__
())
True
"""
return
(
self
.
_base_type
,
self
.
_length
,
self
.
_const
)
@property
@property
def
length
(
self
)
->
int
|
None
:
def
length
(
self
)
->
int
|
None
:
return
self
.
_length
return
self
.
_length
...
@@ -179,19 +203,6 @@ class PsArrayType(PsDereferencableType):
...
@@ -179,19 +203,6 @@ class PsArrayType(PsDereferencableType):
def
c_string
(
self
)
->
str
:
def
c_string
(
self
)
->
str
:
return
f
"
{
self
.
_base_type
.
c_string
()
}
[
{
str
(
self
.
_length
)
if
self
.
_length
is
not
None
else
''
}
]
"
return
f
"
{
self
.
_base_type
.
c_string
()
}
[
{
str
(
self
.
_length
)
if
self
.
_length
is
not
None
else
''
}
]
"
def
__eq__
(
self
,
other
:
object
)
->
bool
:
if
not
isinstance
(
other
,
PsArrayType
):
return
False
return
(
self
.
_base_equal
(
other
)
and
self
.
_base_type
==
other
.
_base_type
and
self
.
_length
==
other
.
_length
)
def
__hash__
(
self
)
->
int
:
return
hash
((
"
PsArrayType
"
,
self
.
_base_type
,
self
.
_length
,
self
.
_const
))
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
f
"
PsArrayType(element_type=
{
repr
(
self
.
_base_type
)
}
, size=
{
self
.
_length
}
, const=
{
self
.
_const
}
)
"
return
f
"
PsArrayType(element_type=
{
repr
(
self
.
_base_type
)
}
, size=
{
self
.
_length
}
, const=
{
self
.
_const
}
)
"
...
@@ -229,6 +240,14 @@ class PsStructType(PsType):
...
@@ -229,6 +240,14 @@ class PsStructType(PsType):
raise
ValueError
(
f
"
Duplicate struct member name:
{
member
.
name
}
"
)
raise
ValueError
(
f
"
Duplicate struct member name:
{
member
.
name
}
"
)
names
.
add
(
member
.
name
)
names
.
add
(
member
.
name
)
def
__args__
(
self
)
->
tuple
[
Any
,
...]:
"""
>>>
t
=
PsStructType
([(
"
idx
"
,
PsSignedIntegerType
(
32
)),
(
"
val
"
,
PsBoolType
())],
"
sname
"
)
>>>
t
==
PsStructType
(
*
t
.
__args__
())
True
"""
return
(
self
.
_members
,
self
.
_name
,
self
.
_const
)
@property
@property
def
members
(
self
)
->
tuple
[
PsStructType
.
Member
,
...]:
def
members
(
self
)
->
tuple
[
PsStructType
.
Member
,
...]:
return
self
.
_members
return
self
.
_members
...
@@ -276,19 +295,6 @@ class PsStructType(PsType):
...
@@ -276,19 +295,6 @@ class PsStructType(PsType):
else
:
else
:
return
self
.
_name
return
self
.
_name
def
__eq__
(
self
,
other
:
object
)
->
bool
:
if
not
isinstance
(
other
,
PsStructType
):
return
False
return
(
self
.
_base_equal
(
other
)
and
self
.
_name
==
other
.
_name
and
self
.
_members
==
other
.
_members
)
def
__hash__
(
self
)
->
int
:
return
hash
((
"
PsStructTupe
"
,
self
.
_name
,
self
.
_members
,
self
.
_const
))
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
members
=
"
,
"
.
join
(
f
"
{
m
.
dtype
}
{
m
.
name
}
"
for
m
in
self
.
_members
)
members
=
"
,
"
.
join
(
f
"
{
m
.
dtype
}
{
m
.
name
}
"
for
m
in
self
.
_members
)
name
=
"
<anonymous>
"
if
self
.
anonymous
else
f
"
name=
{
self
.
_name
}
"
name
=
"
<anonymous>
"
if
self
.
anonymous
else
f
"
name=
{
self
.
_name
}
"
...
@@ -386,6 +392,14 @@ class PsVectorType(PsNumericType):
...
@@ -386,6 +392,14 @@ class PsVectorType(PsNumericType):
self
.
_vector_entries
=
vector_entries
self
.
_vector_entries
=
vector_entries
self
.
_scalar_type
=
constify
(
scalar_type
)
if
const
else
deconstify
(
scalar_type
)
self
.
_scalar_type
=
constify
(
scalar_type
)
if
const
else
deconstify
(
scalar_type
)
def
__args__
(
self
)
->
tuple
[
Any
,
...]:
"""
>>>
t
=
PsVectorType
(
PsBoolType
(),
8
,
True
)
>>>
t
==
PsVectorType
(
*
t
.
__args__
())
True
"""
return
(
self
.
_scalar_type
,
self
.
_vector_entries
,
self
.
_const
)
@property
@property
def
scalar_type
(
self
)
->
PsScalarType
:
def
scalar_type
(
self
)
->
PsScalarType
:
return
self
.
_scalar_type
return
self
.
_scalar_type
...
@@ -437,21 +451,6 @@ class PsVectorType(PsNumericType):
...
@@ -437,21 +451,6 @@ class PsVectorType(PsNumericType):
[
element
]
*
self
.
_vector_entries
,
dtype
=
self
.
scalar_type
.
numpy_dtype
[
element
]
*
self
.
_vector_entries
,
dtype
=
self
.
scalar_type
.
numpy_dtype
)
)
def
__eq__
(
self
,
other
:
object
)
->
bool
:
if
not
isinstance
(
other
,
PsVectorType
):
return
False
return
(
self
.
_base_equal
(
other
)
and
self
.
_scalar_type
==
other
.
_scalar_type
and
self
.
_vector_entries
==
other
.
_vector_entries
)
def
__hash__
(
self
)
->
int
:
return
hash
(
(
"
PsVectorType
"
,
self
.
_scalar_type
,
self
.
_vector_entries
,
self
.
_const
)
)
def
c_string
(
self
)
->
str
:
def
c_string
(
self
)
->
str
:
raise
PsTypeError
(
"
Cannot retrieve C type string for generic vector types.
"
)
raise
PsTypeError
(
"
Cannot retrieve C type string for generic vector types.
"
)
...
@@ -473,6 +472,14 @@ class PsBoolType(PsScalarType):
...
@@ -473,6 +472,14 @@ class PsBoolType(PsScalarType):
def
__init__
(
self
,
const
:
bool
=
False
):
def
__init__
(
self
,
const
:
bool
=
False
):
super
().
__init__
(
const
)
super
().
__init__
(
const
)
def
__args__
(
self
)
->
tuple
[
Any
,
...]:
"""
>>>
t
=
PsBoolType
(
True
)
>>>
t
==
PsBoolType
(
*
t
.
__args__
())
True
"""
return
(
self
.
_const
,)
@property
@property
def
width
(
self
)
->
int
:
def
width
(
self
)
->
int
:
return
8
return
8
...
@@ -506,16 +513,7 @@ class PsBoolType(PsScalarType):
...
@@ -506,16 +513,7 @@ class PsBoolType(PsScalarType):
def
c_string
(
self
)
->
str
:
def
c_string
(
self
)
->
str
:
return
"
bool
"
return
"
bool
"
def
__eq__
(
self
,
other
:
object
)
->
bool
:
if
not
isinstance
(
other
,
PsBoolType
):
return
False
return
self
.
_base_equal
(
other
)
def
__hash__
(
self
)
->
int
:
return
hash
((
"
PsBoolType
"
,
self
.
_const
))
class
PsIntegerType
(
PsScalarType
,
ABC
):
class
PsIntegerType
(
PsScalarType
,
ABC
):
"""
Signed and unsigned integer types.
"""
Signed and unsigned integer types.
...
@@ -574,20 +572,7 @@ class PsIntegerType(PsScalarType, ABC):
...
@@ -574,20 +572,7 @@ class PsIntegerType(PsScalarType, ABC):
return
np_type
(
value
)
return
np_type
(
value
)
raise
PsTypeError
(
f
"
Could not interpret
{
value
}
as
{
repr
(
self
)
}
"
)
raise
PsTypeError
(
f
"
Could not interpret
{
value
}
as
{
repr
(
self
)
}
"
)
def
__eq__
(
self
,
other
:
object
)
->
bool
:
if
not
isinstance
(
other
,
PsIntegerType
):
return
False
return
(
self
.
_base_equal
(
other
)
and
self
.
_width
==
other
.
_width
and
self
.
_signed
==
other
.
_signed
)
def
__hash__
(
self
)
->
int
:
return
hash
((
"
PsIntegerType
"
,
self
.
_width
,
self
.
_signed
,
self
.
_const
))
def
c_string
(
self
)
->
str
:
def
c_string
(
self
)
->
str
:
prefix
=
""
if
self
.
_signed
else
"
u
"
prefix
=
""
if
self
.
_signed
else
"
u
"
return
f
"
{
self
.
_const_string
()
}{
prefix
}
int
{
self
.
_width
}
_t
"
return
f
"
{
self
.
_const_string
()
}{
prefix
}
int
{
self
.
_width
}
_t
"
...
@@ -612,6 +597,14 @@ class PsSignedIntegerType(PsIntegerType):
...
@@ -612,6 +597,14 @@ class PsSignedIntegerType(PsIntegerType):
def
__init__
(
self
,
width
:
int
,
const
:
bool
=
False
):
def
__init__
(
self
,
width
:
int
,
const
:
bool
=
False
):
super
().
__init__
(
width
,
True
,
const
)
super
().
__init__
(
width
,
True
,
const
)
def
__args__
(
self
)
->
tuple
[
Any
,
...]:
"""
>>>
t
=
PsSignedIntegerType
(
32
,
True
)
>>>
t
==
PsSignedIntegerType
(
*
t
.
__args__
())
True
"""
return
(
self
.
_width
,
self
.
_const
)
@final
@final
class
PsUnsignedIntegerType
(
PsIntegerType
):
class
PsUnsignedIntegerType
(
PsIntegerType
):
...
@@ -629,6 +622,14 @@ class PsUnsignedIntegerType(PsIntegerType):
...
@@ -629,6 +622,14 @@ class PsUnsignedIntegerType(PsIntegerType):
def
__init__
(
self
,
width
:
int
,
const
:
bool
=
False
):
def
__init__
(
self
,
width
:
int
,
const
:
bool
=
False
):
super
().
__init__
(
width
,
False
,
const
)
super
().
__init__
(
width
,
False
,
const
)
def
__args__
(
self
)
->
tuple
[
Any
,
...]:
"""
>>>
t
=
PsUnsignedIntegerType
(
32
,
True
)
>>>
t
==
PsUnsignedIntegerType
(
*
t
.
__args__
())
True
"""
return
(
self
.
_width
,
self
.
_const
)
@final
@final
class
PsIeeeFloatType
(
PsScalarType
):
class
PsIeeeFloatType
(
PsScalarType
):
...
@@ -653,6 +654,14 @@ class PsIeeeFloatType(PsScalarType):
...
@@ -653,6 +654,14 @@ class PsIeeeFloatType(PsScalarType):
super
().
__init__
(
const
)
super
().
__init__
(
const
)
self
.
_width
=
width
self
.
_width
=
width
def
__args__
(
self
)
->
tuple
[
Any
,
...]:
"""
>>>
t
=
PsIeeeFloatType
(
32
,
True
)
>>>
t
==
PsIeeeFloatType
(
*
t
.
__args__
())
True
"""
return
(
self
.
_width
,
self
.
_const
)
@property
@property
def
width
(
self
)
->
int
:
def
width
(
self
)
->
int
:
return
self
.
_width
return
self
.
_width
...
@@ -698,14 +707,6 @@ class PsIeeeFloatType(PsScalarType):
...
@@ -698,14 +707,6 @@ class PsIeeeFloatType(PsScalarType):
raise
PsTypeError
(
f
"
Could not interpret
{
value
}
as
{
repr
(
self
)
}
"
)
raise
PsTypeError
(
f
"
Could not interpret
{
value
}
as
{
repr
(
self
)
}
"
)
def
__eq__
(
self
,
other
:
object
)
->
bool
:
if
not
isinstance
(
other
,
PsIeeeFloatType
):
return
False
return
self
.
_base_equal
(
other
)
and
self
.
_width
==
other
.
_width
def
__hash__
(
self
)
->
int
:
return
hash
((
"
PsIeeeFloatType
"
,
self
.
_width
,
self
.
_const
))
def
c_string
(
self
)
->
str
:
def
c_string
(
self
)
->
str
:
match
self
.
_width
:
match
self
.
_width
:
case
16
:
case
16
:
...
...
This diff is collapsed.
Click to expand it.
tests/nbackend/types/test_types.py
+
4
−
3
View file @
d6eb671a
...
@@ -22,9 +22,10 @@ def test_parsing_positive():
...
@@ -22,9 +22,10 @@ def test_parsing_positive():
assert
create_type
(
"
const uint32_t * restrict
"
)
==
Ptr
(
assert
create_type
(
"
const uint32_t * restrict
"
)
==
Ptr
(
UInt
(
32
,
const
=
True
),
restrict
=
True
UInt
(
32
,
const
=
True
),
restrict
=
True
)
)
assert
create_type
(
"
float * * const
"
)
==
Ptr
(
Ptr
(
Fp
(
32
)),
const
=
True
)
assert
create_type
(
"
float * * const
"
)
==
Ptr
(
Ptr
(
Fp
(
32
),
restrict
=
False
),
const
=
True
,
restrict
=
False
)
assert
create_type
(
"
uint16 * const
"
)
==
Ptr
(
UInt
(
16
),
const
=
True
)
assert
create_type
(
"
float * * restrict const
"
)
==
Ptr
(
Ptr
(
Fp
(
32
),
restrict
=
False
),
const
=
True
,
restrict
=
True
)
assert
create_type
(
"
uint64 const * const
"
)
==
Ptr
(
UInt
(
64
,
const
=
True
),
const
=
True
)
assert
create_type
(
"
uint16 * const
"
)
==
Ptr
(
UInt
(
16
),
const
=
True
,
restrict
=
False
)
assert
create_type
(
"
uint64 const * const
"
)
==
Ptr
(
UInt
(
64
,
const
=
True
),
const
=
True
,
restrict
=
False
)
def
test_parsing_negative
():
def
test_parsing_negative
():
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment