Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
pystencils
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
pycodegen
pystencils
Commits
c83b06a5
Commit
c83b06a5
authored
10 months ago
by
Frederik Hennig
Browse files
Options
Downloads
Patches
Plain Diff
Introduce DereferencableTo type hint and implement inference hooks for (nested) arrays
parent
fbcf566f
No related branches found
No related tags found
1 merge request
!418
Nesting of Type Contexts, Type Hints, and Improved Array Typing
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
src/pystencils/backend/kernelcreation/typification.py
+40
-7
40 additions, 7 deletions
src/pystencils/backend/kernelcreation/typification.py
tests/nbackend/kernelcreation/test_typification.py
+55
-1
55 additions, 1 deletion
tests/nbackend/kernelcreation/test_typification.py
with
95 additions
and
8 deletions
src/pystencils/backend/kernelcreation/typification.py
+
40
−
7
View file @
c83b06a5
...
...
@@ -73,9 +73,16 @@ class TypeHint:
@dataclass
(
frozen
=
True
)
class
ToDefault
(
TypeHint
):
"""
Indicates to fall back to a default type.
"""
default_dtype
:
PsType
@dataclass
(
frozen
=
True
)
class
DereferencableTo
(
TypeHint
):
"""
Indicates that the type has to be dereferencable to the given base type.
"""
base_type
:
PsType
|
TypeHint
InferenceHook
=
Callable
[[
PsType
|
TypeHint
],
PsType
|
None
]
"""
An inference hook is a callback that is attached to a type context,
to be called once type information about that context is known.
...
...
@@ -83,8 +90,8 @@ The inference hook will then try to use that information to resolve nested type
and potentially the context it is attached to as well.
When called with a `PsType`, that type is the target type of the context to which the hook is attached.
The hook has to use this type to resolve any nested type contexts and return
`None`;
i
f it cannot resolve its nested contexts, it must raise a TypificationError.
The hook has to use this type to resolve any nested type contexts and
must
return
either `None` or the same data type.
I
f it cannot resolve its nested contexts, it must raise a TypificationError.
When called with a `TypeHint`, the inference hook has to attempt to resolve its nested contexts.
If it succeeds, it has to return the data type that must be applied to the outer context.
...
...
@@ -200,8 +207,9 @@ class TypeContext:
self
.
apply_dtype
(
default_dtype
)
case
_
:
for
i
,
hook
in
enumerate
(
self
.
_inference_hooks
):
self
.
_target_type
=
hook
(
hint
)
if
self
.
_target_type
is
not
None
:
target_type
=
hook
(
hint
)
if
target_type
is
not
None
:
self
.
_target_type
=
self
.
_fix_constness
(
target_type
)
# That hook was successful; remove it so it is not called a second time
del
self
.
_inference_hooks
[
i
]
...
...
@@ -523,12 +531,25 @@ class Typifier:
arr_tc
=
TypeContext
()
self
.
visit_expr
(
arr
,
arr_tc
)
if
not
isinstance
(
arr_tc
.
target_type
,
PsDereferencableType
):
if
arr_tc
.
target_type
is
None
:
def
subscript_hook
(
type_or_hint
:
PsType
|
TypeHint
)
->
PsType
|
None
:
# Whatever type the enclosing context is to be,
# the type of `arr` has to be dereferencable to it
arr_tc
.
apply_hint
(
DereferencableTo
(
type_or_hint
))
# Now we know the type of the array
# -> pass its dereferenced version to the outer context
assert
isinstance
(
arr_tc
.
target_type
,
PsDereferencableType
)
return
arr_tc
.
target_type
.
base_type
tc
.
hook
(
subscript_hook
)
elif
not
isinstance
(
arr_tc
.
target_type
,
PsDereferencableType
):
raise
TypificationError
(
"
Type of subscript base is not subscriptable.
"
)
tc
.
apply_dtype
(
arr_tc
.
target_type
.
base_type
,
expr
)
else
:
tc
.
apply_dtype
(
arr_tc
.
target_type
.
base_type
,
expr
)
index_tc
=
TypeContext
()
self
.
visit_expr
(
idx
,
index_tc
)
...
...
@@ -683,13 +704,25 @@ class Typifier:
)
items_tc
.
apply_dtype
(
deconstify
(
elem_type
))
tc
.
infer_dtype
(
expr
)
case
DereferencableTo
(
elem_type_or_hint
):
if
isinstance
(
elem_type_or_hint
,
PsType
):
items_tc
.
apply_dtype
(
deconstify
(
elem_type_or_hint
))
else
:
items_tc
.
apply_hint
(
elem_type_or_hint
)
tc
.
infer_dtype
(
expr
)
return
PsArrayType
(
deconstify
(
items_tc
.
get_target_type
()),
len
(
items
))
case
ToDefault
():
items_tc
.
apply_hint
(
type_or_hint
)
tc
.
infer_dtype
(
expr
)
return
PsArrayType
(
deconstify
(
items_tc
.
get_target_type
()),
len
(
items
))
case
TypeHint
():
# Can't deal with any other type hints
return
None
case
other_type
:
raise
TypificationError
(
f
"
Cannot apply type
{
other_type
}
to array initializer
{
expr
}
.
"
...
...
This diff is collapsed.
Click to expand it.
tests/nbackend/kernelcreation/test_typification.py
+
55
−
1
View file @
c83b06a5
...
...
@@ -29,6 +29,7 @@ from pystencils.backend.ast.expressions import (
PsLt
,
PsCall
,
PsTernary
,
PsArrayInitList
,
)
from
pystencils.backend.constants
import
PsConstant
from
pystencils.backend.functions
import
CFunction
...
...
@@ -220,7 +221,7 @@ def test_constant_decls():
typify
=
Typifier
(
ctx
)
x
,
y
=
sp
.
symbols
(
"
x, y
"
)
decl
=
freeze
(
Assignment
(
x
,
3.0
))
decl
=
typify
(
decl
)
assert
ctx
.
get_symbol
(
"
x
"
).
dtype
==
Fp
(
16
)
...
...
@@ -250,6 +251,59 @@ def test_constant_array_decls():
assert
ctx
.
get_symbol
(
"
y
"
).
dtype
==
Arr
(
Arr
(
Fp
(
16
),
4
),
2
)
def
test_inline_arrays_1d
():
ctx
=
KernelCreationContext
(
default_dtype
=
Fp
(
16
))
freeze
=
FreezeExpressions
(
ctx
)
typify
=
Typifier
(
ctx
)
x
,
y
=
sp
.
symbols
(
"
x, y
"
)
idx
=
TypedSymbol
(
"
idx
"
,
Int
(
32
))
arr
:
PsArrayInitList
=
cast
(
PsArrayInitList
,
freeze
(
sp
.
Tuple
(
1
,
2
,
3
,
4
)))
decl
=
PsDeclaration
(
freeze
(
x
),
freeze
(
y
)
+
PsSubscript
(
arr
,
freeze
(
idx
)))
# The array elements should learn their type from the context, which gets it from `y`
decl
=
typify
(
decl
)
assert
decl
.
lhs
.
dtype
==
Fp
(
16
,
const
=
True
)
assert
decl
.
rhs
.
dtype
==
Fp
(
16
,
const
=
True
)
assert
arr
.
dtype
==
Arr
(
Fp
(
16
),
4
,
const
=
True
)
for
item
in
arr
.
items
:
assert
item
.
dtype
==
Fp
(
16
,
const
=
True
)
def
test_inline_arrays_3d
():
ctx
=
KernelCreationContext
(
default_dtype
=
Fp
(
16
))
freeze
=
FreezeExpressions
(
ctx
)
typify
=
Typifier
(
ctx
)
x
,
y
=
sp
.
symbols
(
"
x, y
"
)
idx
=
[
TypedSymbol
(
f
"
idx_
{
i
}
"
,
Int
(
32
))
for
i
in
range
(
3
)]
arr
=
freeze
(
sp
.
Tuple
(((
1
,
2
),
(
3
,
4
),
(
5
,
6
)),
((
5
,
6
),
(
7
,
8
),
(
9
,
10
))))
decl
=
PsDeclaration
(
freeze
(
x
),
freeze
(
y
)
+
PsSubscript
(
PsSubscript
(
PsSubscript
(
arr
,
freeze
(
idx
[
0
])),
freeze
(
idx
[
1
])),
freeze
(
idx
[
2
]),
),
)
# The array elements should learn their type from the context, which gets it from `y`
decl
=
typify
(
decl
)
assert
decl
.
lhs
.
dtype
==
Fp
(
16
,
const
=
True
)
assert
decl
.
rhs
.
dtype
==
Fp
(
16
,
const
=
True
)
assert
arr
.
dtype
==
Arr
(
Arr
(
Arr
(
Fp
(
16
),
2
),
3
),
2
,
const
=
True
)
for
item
in
arr
.
items
:
assert
item
.
dtype
==
Arr
(
Arr
(
Fp
(
16
),
2
),
3
,
const
=
True
)
for
iitem
in
item
.
items
:
assert
iitem
.
dtype
==
Arr
(
Fp
(
16
),
2
,
const
=
True
)
for
iiitem
in
iitem
.
items
:
assert
iiitem
.
dtype
==
Fp
(
16
,
const
=
True
)
def
test_lhs_inference
():
ctx
=
KernelCreationContext
(
default_dtype
=
create_numeric_type
(
np
.
float64
))
freeze
=
FreezeExpressions
(
ctx
)
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment