Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
pystencils
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
pycodegen
pystencils
Commits
6df2c640
Commit
6df2c640
authored
1 month ago
by
Frederik Hennig
Browse files
Options
Downloads
Patches
Plain Diff
insert casts in `add_subexpressions_for_field_reads`
parent
f8e5419f
No related branches found
No related tags found
1 merge request
!460
Fix data types in boundary handling. Fix deprecation checks.
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
src/pystencils/simp/simplifications.py
+32
-5
32 additions, 5 deletions
src/pystencils/simp/simplifications.py
with
32 additions
and
5 deletions
src/pystencils/simp/simplifications.py
+
32
−
5
View file @
6df2c640
from
__future__
import
annotations
from
typing
import
TYPE_CHECKING
from
itertools
import
chain
from
itertools
import
chain
from
typing
import
Callable
,
List
,
Sequence
,
Union
from
typing
import
Callable
,
List
,
Sequence
,
Union
from
collections
import
defaultdict
from
collections
import
defaultdict
import
sympy
as
sp
import
sympy
as
sp
from
..types
import
UserTypeSpec
from
..assignment
import
Assignment
from
..assignment
import
Assignment
from
..sympyextensions
import
subs_additive
,
is_constant
,
recursive_collect
from
..sympyextensions
import
subs_additive
,
is_constant
,
recursive_collect
,
tcast
from
..sympyextensions.typed_sympy
import
TypedSymbol
from
..sympyextensions.typed_sympy
import
TypedSymbol
if
TYPE_CHECKING
:
from
.assignment_collection
import
AssignmentCollection
# TODO rewrite with SymPy AST
# TODO rewrite with SymPy AST
# def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
# def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
...
@@ -170,14 +177,19 @@ def add_subexpressions_for_sums(ac):
...
@@ -170,14 +177,19 @@ def add_subexpressions_for_sums(ac):
return
ac
.
new_with_substitutions
(
substitutions
,
True
,
substitute_on_lhs
=
False
)
return
ac
.
new_with_substitutions
(
substitutions
,
True
,
substitute_on_lhs
=
False
)
def
add_subexpressions_for_field_reads
(
ac
,
subexpressions
=
True
,
main_assignments
=
True
,
data_type
=
None
):
def
add_subexpressions_for_field_reads
(
ac
:
AssignmentCollection
,
subexpressions
=
True
,
main_assignments
=
True
,
data_type
:
UserTypeSpec
|
None
=
None
):
r
"""
Substitutes field accesses on rhs of assignments with subexpressions
r
"""
Substitutes field accesses on rhs of assignments with subexpressions
Can change semantics of the update rule (which is the goal of this transformation)
Can change semantics of the update rule (which is the goal of this transformation)
This is useful if a field should be update in place - all values are loaded before into subexpression variables,
This is useful if a field should be update in place - all values are loaded before into subexpression variables,
then the new values are computed and written to the same field in-place.
then the new values are computed and written to the same field in-place.
Additionally, if a datatype is given to the function the rhs symbol of the new isolated field read will have
Additionally, if a datatype is given to the function the rhs symbol of the new isolated field read will have
this data type. This is useful for mixed precision kernels
this data type
, and an explicit cast is inserted
. This is useful for mixed precision kernels
"""
"""
field_reads
=
set
()
field_reads
=
set
()
to_iterate
=
[]
to_iterate
=
[]
...
@@ -201,8 +213,23 @@ def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments
...
@@ -201,8 +213,23 @@ def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments
substitutions
.
update
({
fa
:
TypedSymbol
(
lhs
.
name
,
data_type
)})
substitutions
.
update
({
fa
:
TypedSymbol
(
lhs
.
name
,
data_type
)})
else
:
else
:
substitutions
.
update
({
fa
:
lhs
})
substitutions
.
update
({
fa
:
lhs
})
return
ac
.
new_with_substitutions
(
substitutions
,
add_substitutions_as_subexpressions
=
True
,
substitute_on_lhs
=
False
,
sort_topologically
=
False
)
ac
=
ac
.
new_with_substitutions
(
substitutions
,
add_substitutions_as_subexpressions
=
False
,
substitute_on_lhs
=
False
,
sort_topologically
=
False
)
loads
:
list
[
Assignment
]
=
[]
for
fa
in
field_reads
:
rhs
=
fa
if
data_type
is
None
else
tcast
(
fa
,
data_type
)
loads
.
append
(
Assignment
(
substitutions
[
fa
],
rhs
)
)
ac
.
subexpressions
=
loads
+
ac
.
subexpressions
return
ac
def
transform_rhs
(
assignment_list
,
transformation
,
*
args
,
**
kwargs
):
def
transform_rhs
(
assignment_list
,
transformation
,
*
args
,
**
kwargs
):
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment