Skip to content
Snippets Groups Projects

Extension to field read extraction

Merged Markus Holzer requested to merge holzer/pystencils:ExtensionSimp into master
2 files
+ 27
10
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -8,6 +8,7 @@ from pystencils.assignment import Assignment
from pystencils.astnodes import Node
from pystencils.field import Field
from pystencils.sympyextensions import subs_additive, is_constant, recursive_collect
from pystencils.typing import TypedSymbol
def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
@@ -168,12 +169,14 @@ def add_subexpressions_for_sums(ac):
return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False)
def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True):
def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True, data_type=None):
r"""Substitutes field accesses on rhs of assignments with subexpressions
Can change semantics of the update rule (which is the goal of this transformation)
This is useful if a field should be update in place - all values are loaded before into subexpression variables,
then the new values are computed and written to the same field in-place.
Additionally, if a datatype is given to the function the rhs symbol of the new isolated field read will have
this data type. This is useful for mixed precision kernels
"""
field_reads = set()
to_iterate = []
@@ -185,7 +188,17 @@ def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments
for assignment in to_iterate:
if hasattr(assignment, 'lhs') and hasattr(assignment, 'rhs'):
field_reads.update(assignment.rhs.atoms(Field.Access))
substitutions = {fa: next(ac.subexpression_symbol_generator) for fa in field_reads}
if not field_reads:
return
substitutions = dict()
for fa in field_reads:
lhs = next(ac.subexpression_symbol_generator)
if data_type is not None:
substitutions.update({fa: TypedSymbol(lhs.name, data_type)})
else:
substitutions.update({fa: lhs})
return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True,
substitute_on_lhs=False, sort_topologically=False)
Loading