Skip to content
Snippets Groups Projects
Commit 59dce1c1 authored by Markus Holzer's avatar Markus Holzer Committed by Christoph Alt
Browse files

Extension to field read extraction

parent 6da822aa
1 merge request!346Extension to field read extraction
...@@ -8,6 +8,7 @@ from pystencils.assignment import Assignment ...@@ -8,6 +8,7 @@ from pystencils.assignment import Assignment
from pystencils.astnodes import Node from pystencils.astnodes import Node
from pystencils.field import Field from pystencils.field import Field
from pystencils.sympyextensions import subs_additive, is_constant, recursive_collect from pystencils.sympyextensions import subs_additive, is_constant, recursive_collect
from pystencils.typing import TypedSymbol
def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]: def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
...@@ -168,12 +169,14 @@ def add_subexpressions_for_sums(ac): ...@@ -168,12 +169,14 @@ def add_subexpressions_for_sums(ac):
return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False) return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False)
def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True): def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True, data_type=None):
r"""Substitutes field accesses on rhs of assignments with subexpressions r"""Substitutes field accesses on rhs of assignments with subexpressions
Can change semantics of the update rule (which is the goal of this transformation) Can change semantics of the update rule (which is the goal of this transformation)
This is useful if a field should be update in place - all values are loaded before into subexpression variables, This is useful if a field should be update in place - all values are loaded before into subexpression variables,
then the new values are computed and written to the same field in-place. then the new values are computed and written to the same field in-place.
Additionally, if a datatype is given to the function the rhs symbol of the new isolated field read will have
this data type. This is useful for mixed precision kernels
""" """
field_reads = set() field_reads = set()
to_iterate = [] to_iterate = []
...@@ -185,7 +188,17 @@ def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments ...@@ -185,7 +188,17 @@ def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments
for assignment in to_iterate: for assignment in to_iterate:
if hasattr(assignment, 'lhs') and hasattr(assignment, 'rhs'): if hasattr(assignment, 'lhs') and hasattr(assignment, 'rhs'):
field_reads.update(assignment.rhs.atoms(Field.Access)) field_reads.update(assignment.rhs.atoms(Field.Access))
substitutions = {fa: next(ac.subexpression_symbol_generator) for fa in field_reads}
if not field_reads:
return
substitutions = dict()
for fa in field_reads:
lhs = next(ac.subexpression_symbol_generator)
if data_type is not None:
substitutions.update({fa: TypedSymbol(lhs.name, data_type)})
else:
substitutions.update({fa: lhs})
return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True,
substitute_on_lhs=False, sort_topologically=False) substitute_on_lhs=False, sort_topologically=False)
......
...@@ -4,14 +4,14 @@ import pytest ...@@ -4,14 +4,14 @@ import pytest
import pystencils.config import pystencils.config
import sympy as sp import sympy as sp
import pystencils as ps import pystencils as ps
import numpy as np
from pystencils import Assignment, AssignmentCollection, fields
from pystencils.simp import subexpression_substitution_in_main_assignments from pystencils.simp import subexpression_substitution_in_main_assignments
from pystencils.simp import add_subexpressions_for_divisions from pystencils.simp import add_subexpressions_for_divisions
from pystencils.simp import add_subexpressions_for_sums from pystencils.simp import add_subexpressions_for_sums
from pystencils.simp import add_subexpressions_for_field_reads from pystencils.simp import add_subexpressions_for_field_reads
from pystencils.simp.simplifications import add_subexpressions_for_constants from pystencils.simp.simplifications import add_subexpressions_for_constants
from pystencils import Assignment, AssignmentCollection, fields from pystencils.typing import BasicType, TypedSymbol
a, b, c, d, x, y, z = sp.symbols("a b c d x y z") a, b, c, d, x, y, z = sp.symbols("a b c d x y z")
s0, s1, s2, s3 = sp.symbols("s_:4") s0, s1, s2, s3 = sp.symbols("s_:4")
...@@ -133,14 +133,18 @@ def test_add_subexpressions_for_sums(): ...@@ -133,14 +133,18 @@ def test_add_subexpressions_for_sums():
def test_add_subexpressions_for_field_reads(): def test_add_subexpressions_for_field_reads():
s, v = fields("s(5), v(5): double[2D]") s, v = fields("s(5), v(5): double[2D]")
subexpressions = [] subexpressions = []
main = [
Assignment(s[0, 0](0), 3 * v[0, 0](0)), main = [Assignment(s[0, 0](0), 3 * v[0, 0](0)),
Assignment(s[0, 0](1), 10 * v[0, 0](1)) Assignment(s[0, 0](1), 10 * v[0, 0](1))]
]
ac = AssignmentCollection(main, subexpressions) ac = AssignmentCollection(main, subexpressions)
assert len(ac.subexpressions) == 0 assert len(ac.subexpressions) == 0
ac = add_subexpressions_for_field_reads(ac) ac2 = add_subexpressions_for_field_reads(ac)
assert len(ac.subexpressions) == 2 assert len(ac2.subexpressions) == 2
ac3 = add_subexpressions_for_field_reads(ac, data_type="float32")
assert len(ac3.subexpressions) == 2
assert isinstance(ac3.subexpressions[0].lhs, TypedSymbol)
assert ac3.subexpressions[0].lhs.dtype == BasicType("float32")
@pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU)) @pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU))
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment