From 59dce1c1c70b55a7c4a27b15751c744cf0e8f88d Mon Sep 17 00:00:00 2001 From: Markus Holzer <markus.holzer@fau.de> Date: Thu, 24 Aug 2023 09:47:30 +0200 Subject: [PATCH] Extension to field read extraction --- pystencils/simp/simplifications.py | 17 +++++++++++++++-- pystencils_tests/test_simplifications.py | 20 ++++++++++++-------- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/pystencils/simp/simplifications.py b/pystencils/simp/simplifications.py index 5ed8c4ea9..e1f42dc60 100644 --- a/pystencils/simp/simplifications.py +++ b/pystencils/simp/simplifications.py @@ -8,6 +8,7 @@ from pystencils.assignment import Assignment from pystencils.astnodes import Node from pystencils.field import Field from pystencils.sympyextensions import subs_additive, is_constant, recursive_collect +from pystencils.typing import TypedSymbol def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]: @@ -168,12 +169,14 @@ def add_subexpressions_for_sums(ac): return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False) -def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True): +def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True, data_type=None): r"""Substitutes field accesses on rhs of assignments with subexpressions Can change semantics of the update rule (which is the goal of this transformation) This is useful if a field should be update in place - all values are loaded before into subexpression variables, then the new values are computed and written to the same field in-place. + Additionally, if a datatype is given to the function the rhs symbol of the new isolated field read will have + this data type. This is useful for mixed precision kernels """ field_reads = set() to_iterate = [] @@ -185,7 +188,17 @@ def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments for assignment in to_iterate: if hasattr(assignment, 'lhs') and hasattr(assignment, 'rhs'): field_reads.update(assignment.rhs.atoms(Field.Access)) - substitutions = {fa: next(ac.subexpression_symbol_generator) for fa in field_reads} + + if not field_reads: + return + + substitutions = dict() + for fa in field_reads: + lhs = next(ac.subexpression_symbol_generator) + if data_type is not None: + substitutions.update({fa: TypedSymbol(lhs.name, data_type)}) + else: + substitutions.update({fa: lhs}) return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, substitute_on_lhs=False, sort_topologically=False) diff --git a/pystencils_tests/test_simplifications.py b/pystencils_tests/test_simplifications.py index 61d009d03..2814ac14f 100644 --- a/pystencils_tests/test_simplifications.py +++ b/pystencils_tests/test_simplifications.py @@ -4,14 +4,14 @@ import pytest import pystencils.config import sympy as sp import pystencils as ps -import numpy as np +from pystencils import Assignment, AssignmentCollection, fields from pystencils.simp import subexpression_substitution_in_main_assignments from pystencils.simp import add_subexpressions_for_divisions from pystencils.simp import add_subexpressions_for_sums from pystencils.simp import add_subexpressions_for_field_reads from pystencils.simp.simplifications import add_subexpressions_for_constants -from pystencils import Assignment, AssignmentCollection, fields +from pystencils.typing import BasicType, TypedSymbol a, b, c, d, x, y, z = sp.symbols("a b c d x y z") s0, s1, s2, s3 = sp.symbols("s_:4") @@ -133,14 +133,18 @@ def test_add_subexpressions_for_sums(): def test_add_subexpressions_for_field_reads(): s, v = fields("s(5), v(5): double[2D]") subexpressions = [] - main = [ - Assignment(s[0, 0](0), 3 * v[0, 0](0)), - Assignment(s[0, 0](1), 10 * v[0, 0](1)) - ] + + main = [Assignment(s[0, 0](0), 3 * v[0, 0](0)), + Assignment(s[0, 0](1), 10 * v[0, 0](1))] + ac = AssignmentCollection(main, subexpressions) assert len(ac.subexpressions) == 0 - ac = add_subexpressions_for_field_reads(ac) - assert len(ac.subexpressions) == 2 + ac2 = add_subexpressions_for_field_reads(ac) + assert len(ac2.subexpressions) == 2 + ac3 = add_subexpressions_for_field_reads(ac, data_type="float32") + assert len(ac3.subexpressions) == 2 + assert isinstance(ac3.subexpressions[0].lhs, TypedSymbol) + assert ac3.subexpressions[0].lhs.dtype == BasicType("float32") @pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU)) -- GitLab