Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
pystencils
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
pycodegen
pystencils
Commits
bddd3a37
Commit
bddd3a37
authored
1 year ago
by
Frederik Hennig
Browse files
Options
Downloads
Patches
Plain Diff
refactor field and array handling in context
parent
0dbf7137
No related branches found
No related tags found
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
src/pystencils/backend/kernelcreation/context.py
+86
-46
86 additions, 46 deletions
src/pystencils/backend/kernelcreation/context.py
tests/nbackend/kernelcreation/test_options.py
+1
-1
1 addition, 1 deletion
tests/nbackend/kernelcreation/test_options.py
with
87 additions
and
47 deletions
src/pystencils/backend/kernelcreation/context.py
+
86
−
46
View file @
bddd3a37
from
__future__
import
annotations
from
types
import
EllipsisType
from
...field
import
Field
,
FieldType
from
...sympyextensions.typed_sympy
import
TypedSymbol
,
BasicType
,
StructType
from
..arrays
import
PsLinearizedArray
...
...
@@ -43,15 +45,18 @@ class KernelCreationContext:
or full iteration space.
"""
def
__init__
(
self
,
default_dtype
:
PsNumericType
=
PbDefaults
.
numeric_dtype
,
index_dtype
:
PsIntegerType
=
PbDefaults
.
index_dtype
):
def
__init__
(
self
,
default_dtype
:
PsNumericType
=
PbDefaults
.
numeric_dtype
,
index_dtype
:
PsIntegerType
=
PbDefaults
.
index_dtype
,
):
self
.
_default_dtype
=
default_dtype
self
.
_index_dtype
=
index_dtype
self
.
_arrays
:
dict
[
Field
,
PsLinearizedArray
]
=
dict
()
self
.
_constraints
:
list
[
PsKernelConstraint
]
=
[]
self
.
_field_arrays
:
dict
[
Field
,
PsLinearizedArray
]
=
dict
()
self
.
_fields_collection
=
FieldsInKernel
()
self
.
_ispace
:
IterationSpace
|
None
=
None
@property
...
...
@@ -76,7 +81,22 @@ class KernelCreationContext:
return
self
.
_fields_collection
def
add_field
(
self
,
field
:
Field
):
"""
Add the given field to the context
'
s fields collection
"""
"""
Add the given field to the context
'
s fields collection.
This method adds the passed ``field`` to the context
'
s field collection, which is
accesible through the `fields` member, and creates an array representation of the field,
which is retrievable through `get_array`.
Before adding the field to the collection, various sanity and constraint checks are applied.
"""
if
field
in
self
.
_field_arrays
:
# Field was already added
return
arr_shape
:
list
[
EllipsisType
|
int
]
|
None
=
None
arr_strides
:
list
[
EllipsisType
|
int
]
|
None
=
None
# Check field constraints and add to collection
match
field
.
field_type
:
case
FieldType
.
GENERIC
|
FieldType
.
STAGGERED
|
FieldType
.
STAGGERED_FLUX
:
self
.
_fields_collection
.
domain_fields
.
add
(
field
)
...
...
@@ -87,6 +107,23 @@ class KernelCreationContext:
f
"
Invalid spatial shape of buffer field
{
field
.
name
}
:
{
field
.
spatial_dimensions
}
.
"
"
Buffer fields must be one-dimensional.
"
)
if
field
.
index_dimensions
>
1
:
raise
KernelConstraintsError
(
f
"
Invalid index shape of buffer field
{
field
.
name
}
:
{
field
.
spatial_dimensions
}
.
"
"
Buffer fields can have at most one index dimension.
"
)
num_entries
=
field
.
index_shape
[
0
]
if
field
.
index_shape
else
1
if
not
isinstance
(
num_entries
,
int
):
raise
KernelConstraintsError
(
f
"
Invalid index shape of buffer field
{
field
.
name
}
:
{
field
.
spatial_dimensions
}
.
"
"
Buffer fields cannot have variable index shape.
"
)
arr_shape
=
[...,
num_entries
]
arr_strides
=
[
num_entries
,
1
]
self
.
_fields_collection
.
buffer_fields
.
add
(
field
)
case
FieldType
.
INDEXED
:
...
...
@@ -103,48 +140,51 @@ class KernelCreationContext:
case
_
:
assert
False
,
"
unreachable code
"
def
get_array
(
self
,
field
:
Field
)
->
PsLinearizedArray
:
if
field
not
in
self
.
_arrays
:
if
field
.
field_type
==
FieldType
.
BUFFER
:
# Buffers are always contiguous
assert
field
.
spatial_dimensions
==
1
assert
field
.
index_dimensions
<=
1
num_entries
=
field
.
index_shape
[
0
]
if
field
.
index_shape
else
1
# For non-buffer fields, determine shape and strides
arr_shape
=
[...,
num_entries
]
arr_strides
=
[
num_entries
,
1
]
else
:
arr_shape
=
[
(
Ellipsis
if
isinstance
(
s
,
TypedSymbol
)
else
s
)
# TODO: Field should also use ellipsis
for
s
in
field
.
shape
]
arr_strides
=
[
(
Ellipsis
if
isinstance
(
s
,
TypedSymbol
)
else
s
)
# TODO: Field should also use ellipsis
for
s
in
field
.
strides
]
# The frontend doesn't quite agree with itself on how to model
# fields with trivial index dimensions. Sometimes the index_shape is empty,
# sometimes its (1,). This is canonicalized here.
if
not
field
.
index_shape
:
arr_shape
+=
[
1
]
arr_strides
+=
[
1
]
assert
isinstance
(
field
.
dtype
,
(
BasicType
,
StructType
))
element_type
=
make_type
(
field
.
dtype
.
numpy_dtype
)
arr
=
PsLinearizedArray
(
field
.
name
,
element_type
,
arr_shape
,
arr_strides
,
self
.
index_dtype
)
self
.
_arrays
[
field
]
=
arr
return
self
.
_arrays
[
field
]
if
arr_shape
is
None
:
arr_shape
=
[
(
Ellipsis
if
isinstance
(
s
,
TypedSymbol
)
else
s
)
# TODO: Field should also use ellipsis
for
s
in
field
.
shape
]
arr_strides
=
[
(
Ellipsis
if
isinstance
(
s
,
TypedSymbol
)
else
s
)
# TODO: Field should also use ellipsis
for
s
in
field
.
strides
]
# The frontend doesn't quite agree with itself on how to model
# fields with trivial index dimensions. Sometimes the index_shape is empty,
# sometimes its (1,). This is canonicalized here.
if
not
field
.
index_shape
:
arr_shape
+=
[
1
]
arr_strides
+=
[
1
]
# Add array
assert
arr_strides
is
not
None
assert
isinstance
(
field
.
dtype
,
(
BasicType
,
StructType
))
element_type
=
make_type
(
field
.
dtype
.
numpy_dtype
)
arr
=
PsLinearizedArray
(
field
.
name
,
element_type
,
arr_shape
,
arr_strides
,
self
.
index_dtype
)
self
.
_field_arrays
[
field
]
=
arr
def
get_array
(
self
,
field
:
Field
)
->
PsLinearizedArray
:
"""
Retrieve the underlying array for a given field.
If the given field was not previously registered using `add_field`,
this method internally calls `add_field` to check the field for consistency.
"""
if
field
not
in
self
.
_field_arrays
:
self
.
add_field
(
field
)
return
self
.
_field_arrays
[
field
]
# Iteration Space
...
...
This diff is collapsed.
Click to expand it.
tests/nbackend/kernelcreation/test_options.py
+
1
−
1
View file @
bddd3a37
...
...
@@ -2,7 +2,7 @@ import pytest
from
pystencils.field
import
Field
,
FieldType
from
pystencils.backend.types.quick
import
*
from
pystencils.
kernelcreation
import
(
from
pystencils.
config
import
(
CreateKernelConfig
,
PsOptionsError
,
)
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment