From 6c88aa678973ef1ba94d139cc1d93e7db4651e73 Mon Sep 17 00:00:00 2001 From: Markus Holzer <markus.holzer@fau.de> Date: Thu, 10 Oct 2024 12:58:58 +0200 Subject: [PATCH] Fix add_subexpressions_for_field_reads --- src/pystencils/simp/simplifications.py | 2 +- tests/test_simplifications.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/pystencils/simp/simplifications.py b/src/pystencils/simp/simplifications.py index e1f42dc60..e8986aa4b 100644 --- a/src/pystencils/simp/simplifications.py +++ b/src/pystencils/simp/simplifications.py @@ -190,7 +190,7 @@ def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments field_reads.update(assignment.rhs.atoms(Field.Access)) if not field_reads: - return + return ac substitutions = dict() for fa in field_reads: diff --git a/tests/test_simplifications.py b/tests/test_simplifications.py index 2814ac14f..b7f72651b 100644 --- a/tests/test_simplifications.py +++ b/tests/test_simplifications.py @@ -146,6 +146,16 @@ def test_add_subexpressions_for_field_reads(): assert isinstance(ac3.subexpressions[0].lhs, TypedSymbol) assert ac3.subexpressions[0].lhs.dtype == BasicType("float32") + # added check for early out of add_subexpressions_for_field_reads is no fields appear on the rhs (See #92) + main = [Assignment(s[0, 0](0), 3.0), + Assignment(s[0, 0](1), 4.0)] + + ac4 = AssignmentCollection(main, subexpressions) + assert len(ac4.subexpressions) == 0 + ac5 = add_subexpressions_for_field_reads(ac4) + assert ac5 is not None + assert ac4 is ac5 + @pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU)) @pytest.mark.parametrize('dtype', ('float32', 'float64')) -- GitLab