Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
pystencils
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
pycodegen
pystencils
Commits
f907b330
Commit
f907b330
authored
8 months ago
by
Frederik Hennig
Browse files
Options
Downloads
Patches
Plain Diff
fix minor type conflicts with predefined array types
parent
582784a2
No related branches found
No related tags found
1 merge request
!418
Nesting of Type Contexts, Type Hints, and Improved Array Typing
Pipeline
#69406
passed
8 months ago
Stage: Code Quality
Stage: Unit Tests
Stage: legacy_test
Stage: docs
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
src/pystencils/backend/kernelcreation/typification.py
+46
-21
46 additions, 21 deletions
src/pystencils/backend/kernelcreation/typification.py
tests/nbackend/kernelcreation/test_typification.py
+29
-0
29 additions, 0 deletions
tests/nbackend/kernelcreation/test_typification.py
with
75 additions
and
21 deletions
src/pystencils/backend/kernelcreation/typification.py
+
46
−
21
View file @
f907b330
...
@@ -73,12 +73,14 @@ class TypeHint:
...
@@ -73,12 +73,14 @@ class TypeHint:
@dataclass
(
frozen
=
True
)
@dataclass
(
frozen
=
True
)
class
ToDefault
(
TypeHint
):
class
ToDefault
(
TypeHint
):
"""
Indicates to fall back to a default type.
"""
"""
Indicates to fall back to a default type.
"""
default_dtype
:
PsType
default_dtype
:
PsType
@dataclass
(
frozen
=
True
)
@dataclass
(
frozen
=
True
)
class
DereferencableTo
(
TypeHint
):
class
DereferencableTo
(
TypeHint
):
"""
Indicates that the type has to be dereferencable to the given base type.
"""
"""
Indicates that the type has to be dereferencable to the given base type.
"""
base_type
:
PsType
|
TypeHint
base_type
:
PsType
|
TypeHint
...
@@ -111,7 +113,7 @@ class TypeContext:
...
@@ -111,7 +113,7 @@ class TypeContext:
- Additional restrictions may be added in the future.
- Additional restrictions may be added in the future.
**Target type**
**Target type**
Each typing context needs to be assigned its target type at some point.
Each typing context needs to be assigned its target type at some point.
The target type may be
The target type may be
...
@@ -120,8 +122,8 @@ class TypeContext:
...
@@ -120,8 +122,8 @@ class TypeContext:
- inferred from the type of the enclosing context via an inference hook, or as a last resort
- inferred from the type of the enclosing context via an inference hook, or as a last resort
- determined from a type hint applied to the enclosing context via an inference hook.
- determined from a type hint applied to the enclosing context via an inference hook.
**Expansion**
**Expansion**
Expression nodes are added to a type context using either `apply_dtype` or `infer_dtype`.
Expression nodes are added to a type context using either `apply_dtype` or `infer_dtype`.
In both cases, the context
'
s target type will be applied to the node,
In both cases, the context
'
s target type will be applied to the node,
unless it already has a conflicting type.
unless it already has a conflicting type.
...
@@ -160,7 +162,7 @@ class TypeContext:
...
@@ -160,7 +162,7 @@ class TypeContext:
@property
@property
def
target_type
(
self
)
->
PsType
|
None
:
def
target_type
(
self
)
->
PsType
|
None
:
return
self
.
_target_type
return
self
.
_target_type
def
get_target_type
(
self
)
->
PsType
:
def
get_target_type
(
self
)
->
PsType
:
assert
self
.
_target_type
is
not
None
assert
self
.
_target_type
is
not
None
return
self
.
_target_type
return
self
.
_target_type
...
@@ -209,7 +211,7 @@ class TypeContext:
...
@@ -209,7 +211,7 @@ class TypeContext:
def
apply_hint
(
self
,
hint
:
TypeHint
):
def
apply_hint
(
self
,
hint
:
TypeHint
):
"""
Attempt to resolve this type context from the given type hint.
"""
Attempt to resolve this type context from the given type hint.
If the hint is not sufficient to resolve the context, a `TypificationError` is raised.
If the hint is not sufficient to resolve the context, a `TypificationError` is raised.
"""
"""
assert
self
.
_target_type
is
None
assert
self
.
_target_type
is
None
...
@@ -230,7 +232,9 @@ class TypeContext:
...
@@ -230,7 +232,9 @@ class TypeContext:
# Now we have the target type
# Now we have the target type
self
.
_propagate_target_type
()
self
.
_propagate_target_type
()
else
:
else
:
raise
TypificationError
(
f
"
Unable to infer context type from hint
{
hint
}
"
)
raise
TypificationError
(
f
"
Unable to infer context type from hint
{
hint
}
"
)
def
infer_dtype
(
self
,
expr
:
PsExpression
):
def
infer_dtype
(
self
,
expr
:
PsExpression
):
"""
Infer the data type for the given expression.
"""
Infer the data type for the given expression.
...
@@ -250,7 +254,7 @@ class TypeContext:
...
@@ -250,7 +254,7 @@ class TypeContext:
def
_propagate_target_type
(
self
):
def
_propagate_target_type
(
self
):
"""
Propagates the target type to any registered inference hooks and applies it to any deferred nodes.
"""
Propagates the target type to any registered inference hooks and applies it to any deferred nodes.
Call after the target type of this context has been set.
Call after the target type of this context has been set.
"""
"""
assert
self
.
_target_type
is
not
None
assert
self
.
_target_type
is
not
None
...
@@ -551,7 +555,10 @@ class Typifier:
...
@@ -551,7 +555,10 @@ class Typifier:
self
.
visit_expr
(
arr
,
arr_tc
)
self
.
visit_expr
(
arr
,
arr_tc
)
if
arr_tc
.
target_type
is
None
:
if
arr_tc
.
target_type
is
None
:
def
subscript_hook
(
type_or_hint
:
PsType
|
TypeHint
)
->
PsType
|
None
:
def
subscript_hook
(
type_or_hint
:
PsType
|
TypeHint
,
)
->
PsType
|
None
:
# Whatever type the enclosing context is to be,
# Whatever type the enclosing context is to be,
# the type of `arr` has to be dereferencable to it
# the type of `arr` has to be dereferencable to it
arr_tc
.
apply_hint
(
DereferencableTo
(
type_or_hint
))
arr_tc
.
apply_hint
(
DereferencableTo
(
type_or_hint
))
...
@@ -560,7 +567,7 @@ class Typifier:
...
@@ -560,7 +567,7 @@ class Typifier:
# -> pass its dereferenced version to the outer context
# -> pass its dereferenced version to the outer context
assert
isinstance
(
arr_tc
.
target_type
,
PsDereferencableType
)
assert
isinstance
(
arr_tc
.
target_type
,
PsDereferencableType
)
return
arr_tc
.
target_type
.
base_type
return
arr_tc
.
target_type
.
base_type
tc
.
hook
(
subscript_hook
)
tc
.
hook
(
subscript_hook
)
elif
not
isinstance
(
arr_tc
.
target_type
,
PsDereferencableType
):
elif
not
isinstance
(
arr_tc
.
target_type
,
PsDereferencableType
):
...
@@ -667,7 +674,7 @@ class Typifier:
...
@@ -667,7 +674,7 @@ class Typifier:
if
args_tc
.
target_type
is
None
:
if
args_tc
.
target_type
is
None
:
args_tc
.
apply_hint
(
ToDefault
(
self
.
_ctx
.
default_dtype
))
args_tc
.
apply_hint
(
ToDefault
(
self
.
_ctx
.
default_dtype
))
if
not
isinstance
(
args_tc
.
target_type
,
PsNumericType
):
if
not
isinstance
(
args_tc
.
target_type
,
PsNumericType
):
raise
TypificationError
(
raise
TypificationError
(
f
"
Invalid type in arguments to relation
\n
"
f
"
Invalid type in arguments to relation
\n
"
...
@@ -707,6 +714,23 @@ class Typifier:
...
@@ -707,6 +714,23 @@ class Typifier:
case
PsArrayInitList
(
items
):
case
PsArrayInitList
(
items
):
items_tc
=
TypeContext
()
items_tc
=
TypeContext
()
def
propagate_elem_type
(
elem_type
:
PsType
,
length
:
int
|
None
):
if
length
is
not
None
and
length
!=
len
(
items
):
raise
TypificationError
(
"
Array size mismatch: Cannot typify initializer list with
"
f
"
{
len
(
items
)
}
items as
{
tc
.
target_type
}
"
)
items_tc
.
apply_dtype
(
deconstify
(
elem_type
))
# If the enclosing context already prescribes an array type,
# eagerly propagate it to the items-context
if
isinstance
(
tc
.
target_type
,
PsArrayType
):
inherit_arr_type
=
True
propagate_elem_type
(
tc
.
target_type
.
base_type
,
tc
.
target_type
.
length
)
else
:
inherit_arr_type
=
False
for
item
in
items
:
for
item
in
items
:
self
.
visit_expr
(
item
,
items_tc
)
self
.
visit_expr
(
item
,
items_tc
)
...
@@ -715,12 +739,7 @@ class Typifier:
...
@@ -715,12 +739,7 @@ class Typifier:
def
hook
(
type_or_hint
:
PsType
|
TypeHint
)
->
PsType
|
None
:
def
hook
(
type_or_hint
:
PsType
|
TypeHint
)
->
PsType
|
None
:
match
type_or_hint
:
match
type_or_hint
:
case
PsArrayType
(
elem_type
,
length
):
case
PsArrayType
(
elem_type
,
length
):
if
length
is
not
None
and
length
!=
len
(
items
):
propagate_elem_type
(
elem_type
,
length
)
raise
TypificationError
(
"
Array size mismatch: Cannot typify initializer list with
"
f
"
{
len
(
items
)
}
items as
{
tc
.
target_type
}
"
)
items_tc
.
apply_dtype
(
deconstify
(
elem_type
))
tc
.
infer_dtype
(
expr
)
tc
.
infer_dtype
(
expr
)
return
None
return
None
...
@@ -731,25 +750,31 @@ class Typifier:
...
@@ -731,25 +750,31 @@ class Typifier:
items_tc
.
apply_hint
(
elem_type_or_hint
)
items_tc
.
apply_hint
(
elem_type_or_hint
)
tc
.
infer_dtype
(
expr
)
tc
.
infer_dtype
(
expr
)
return
PsArrayType
(
deconstify
(
items_tc
.
get_target_type
()),
len
(
items
))
return
PsArrayType
(
deconstify
(
items_tc
.
get_target_type
()),
len
(
items
)
)
case
ToDefault
():
case
ToDefault
():
items_tc
.
apply_hint
(
type_or_hint
)
items_tc
.
apply_hint
(
type_or_hint
)
tc
.
infer_dtype
(
expr
)
tc
.
infer_dtype
(
expr
)
return
PsArrayType
(
deconstify
(
items_tc
.
get_target_type
()),
len
(
items
))
return
PsArrayType
(
deconstify
(
items_tc
.
get_target_type
()),
len
(
items
)
)
case
TypeHint
():
case
TypeHint
():
# Can't deal with any other type hints
# Can't deal with any other type hints
return
None
return
None
case
other_type
:
case
other_type
:
raise
TypificationError
(
raise
TypificationError
(
f
"
Cannot apply type
{
other_type
}
to array initializer
{
expr
}
.
"
f
"
Cannot apply type
{
other_type
}
to array initializer
{
expr
}
.
"
)
)
tc
.
hook
(
hook
)
tc
.
hook
(
hook
)
elif
inherit_arr_type
:
tc
.
infer_dtype
(
expr
)
else
:
else
:
arr_type
=
PsArrayType
(
items_tc
.
target_type
,
len
(
items
))
arr_type
=
PsArrayType
(
deconstify
(
items_tc
.
target_type
)
,
len
(
items
))
tc
.
apply_dtype
(
arr_type
,
expr
)
tc
.
apply_dtype
(
arr_type
,
expr
)
case
PsCast
(
dtype
,
arg
):
case
PsCast
(
dtype
,
arg
):
...
...
This diff is collapsed.
Click to expand it.
tests/nbackend/kernelcreation/test_typification.py
+
29
−
0
View file @
f907b330
...
@@ -284,6 +284,35 @@ def test_constant_array_decls():
...
@@ -284,6 +284,35 @@ def test_constant_array_decls():
assert
ctx
.
get_symbol
(
"
y
"
).
dtype
==
Arr
(
Arr
(
Fp
(
16
),
4
),
2
)
assert
ctx
.
get_symbol
(
"
y
"
).
dtype
==
Arr
(
Arr
(
Fp
(
16
),
4
),
2
)
def
test_array_decl_lhs_type_propagation
():
ctx
=
KernelCreationContext
()
freeze
=
FreezeExpressions
(
ctx
)
typify
=
Typifier
(
ctx
)
# Type of array initializer is figured out to be half [4],
# but LHS symbol has type `half []` without shape information
# Expected behavior: LHS type overrides inferred type
arr
=
TypedSymbol
(
"
arr
"
,
Arr
(
Fp
(
16
)))
decl
=
freeze
(
Assignment
(
arr
,
(
5
,
78
,
1
,
TypedSymbol
(
"
x
"
,
Fp
(
16
)))))
decl
=
typify
(
decl
)
assert
decl
.
rhs
.
dtype
==
constify
(
arr
.
dtype
)
def
test_array_decl_constness_conflict
():
ctx
=
KernelCreationContext
()
freeze
=
FreezeExpressions
(
ctx
)
typify
=
Typifier
(
ctx
)
# Type of array initializer is figured out to be half [4],
# but LHS symbol has fixed type `const half []`.
# This is still a valid declaration.
arr
=
TypedSymbol
(
"
arr
"
,
Arr
(
Fp
(
16
,
const
=
True
)))
decl
=
freeze
(
Assignment
(
arr
,
(
5
,
78
,
1
,
TypedSymbol
(
"
x
"
,
Fp
(
16
)))))
decl
=
typify
(
decl
)
assert
decl
.
rhs
.
dtype
==
constify
(
Arr
(
Fp
(
16
,
const
=
True
)))
def
test_inline_arrays_1d
():
def
test_inline_arrays_1d
():
ctx
=
KernelCreationContext
(
default_dtype
=
Fp
(
16
))
ctx
=
KernelCreationContext
(
default_dtype
=
Fp
(
16
))
freeze
=
FreezeExpressions
(
ctx
)
freeze
=
FreezeExpressions
(
ctx
)
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment