diff --git a/src/pystencils/simp/simplifications.py b/src/pystencils/simp/simplifications.py index e1f42dc60467c2cd8a7fa874afd79323e77111a7..e8986aa4b951a3735e79167cb4d2754754be267b 100644 --- a/src/pystencils/simp/simplifications.py +++ b/src/pystencils/simp/simplifications.py @@ -190,7 +190,7 @@ def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments field_reads.update(assignment.rhs.atoms(Field.Access)) if not field_reads: - return + return ac substitutions = dict() for fa in field_reads: diff --git a/tests/test_simplifications.py b/tests/test_simplifications.py index 2814ac14f5fb5734e0d23a92d8e0aaf975691b36..b7f72651bf84fda273c2867cfef65c90c3b6780a 100644 --- a/tests/test_simplifications.py +++ b/tests/test_simplifications.py @@ -146,6 +146,16 @@ def test_add_subexpressions_for_field_reads(): assert isinstance(ac3.subexpressions[0].lhs, TypedSymbol) assert ac3.subexpressions[0].lhs.dtype == BasicType("float32") + # added check for early out of add_subexpressions_for_field_reads is no fields appear on the rhs (See #92) + main = [Assignment(s[0, 0](0), 3.0), + Assignment(s[0, 0](1), 4.0)] + + ac4 = AssignmentCollection(main, subexpressions) + assert len(ac4.subexpressions) == 0 + ac5 = add_subexpressions_for_field_reads(ac4) + assert ac5 is not None + assert ac4 is ac5 + @pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU)) @pytest.mark.parametrize('dtype', ('float32', 'float64'))