From 6df2c640b0f88cf1882604e5917ea21f40687ba9 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Fri, 11 Apr 2025 12:02:14 +0200 Subject: [PATCH] insert casts in `add_subexpressions_for_field_reads` --- src/pystencils/simp/simplifications.py | 37 ++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/src/pystencils/simp/simplifications.py b/src/pystencils/simp/simplifications.py index 9368c8f51..baecf6cb4 100644 --- a/src/pystencils/simp/simplifications.py +++ b/src/pystencils/simp/simplifications.py @@ -1,13 +1,20 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + from itertools import chain from typing import Callable, List, Sequence, Union from collections import defaultdict import sympy as sp +from ..types import UserTypeSpec from ..assignment import Assignment -from ..sympyextensions import subs_additive, is_constant, recursive_collect +from ..sympyextensions import subs_additive, is_constant, recursive_collect, tcast from ..sympyextensions.typed_sympy import TypedSymbol +if TYPE_CHECKING: + from .assignment_collection import AssignmentCollection + # TODO rewrite with SymPy AST # def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]: @@ -170,14 +177,19 @@ def add_subexpressions_for_sums(ac): return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False) -def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True, data_type=None): +def add_subexpressions_for_field_reads( + ac: AssignmentCollection, + subexpressions=True, + main_assignments=True, + data_type: UserTypeSpec | None = None +): r"""Substitutes field accesses on rhs of assignments with subexpressions Can change semantics of the update rule (which is the goal of this transformation) This is useful if a field should be update in place - all values are loaded before into subexpression variables, then the new values are computed and written to the same field in-place. Additionally, if a datatype is given to the function the rhs symbol of the new isolated field read will have - this data type. This is useful for mixed precision kernels + this data type, and an explicit cast is inserted. This is useful for mixed precision kernels """ field_reads = set() to_iterate = [] @@ -201,8 +213,23 @@ def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments substitutions.update({fa: TypedSymbol(lhs.name, data_type)}) else: substitutions.update({fa: lhs}) - return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, - substitute_on_lhs=False, sort_topologically=False) + + ac = ac.new_with_substitutions( + substitutions, + add_substitutions_as_subexpressions=False, + substitute_on_lhs=False, + sort_topologically=False + ) + + loads: list[Assignment] = [] + for fa in field_reads: + rhs = fa if data_type is None else tcast(fa, data_type) + loads.append( + Assignment(substitutions[fa], rhs) + ) + + ac.subexpressions = loads + ac.subexpressions + return ac def transform_rhs(assignment_list, transformation, *args, **kwargs): -- GitLab