Skip to content
Snippets Groups Projects
Commit 6df2c640 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

insert casts in `add_subexpressions_for_field_reads`

parent f8e5419f
No related branches found
No related tags found
1 merge request!460Fix data types in boundary handling. Fix deprecation checks.
from __future__ import annotations
from typing import TYPE_CHECKING
from itertools import chain from itertools import chain
from typing import Callable, List, Sequence, Union from typing import Callable, List, Sequence, Union
from collections import defaultdict from collections import defaultdict
import sympy as sp import sympy as sp
from ..types import UserTypeSpec
from ..assignment import Assignment from ..assignment import Assignment
from ..sympyextensions import subs_additive, is_constant, recursive_collect from ..sympyextensions import subs_additive, is_constant, recursive_collect, tcast
from ..sympyextensions.typed_sympy import TypedSymbol from ..sympyextensions.typed_sympy import TypedSymbol
if TYPE_CHECKING:
from .assignment_collection import AssignmentCollection
# TODO rewrite with SymPy AST # TODO rewrite with SymPy AST
# def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]: # def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
...@@ -170,14 +177,19 @@ def add_subexpressions_for_sums(ac): ...@@ -170,14 +177,19 @@ def add_subexpressions_for_sums(ac):
return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False) return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False)
def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True, data_type=None): def add_subexpressions_for_field_reads(
ac: AssignmentCollection,
subexpressions=True,
main_assignments=True,
data_type: UserTypeSpec | None = None
):
r"""Substitutes field accesses on rhs of assignments with subexpressions r"""Substitutes field accesses on rhs of assignments with subexpressions
Can change semantics of the update rule (which is the goal of this transformation) Can change semantics of the update rule (which is the goal of this transformation)
This is useful if a field should be update in place - all values are loaded before into subexpression variables, This is useful if a field should be update in place - all values are loaded before into subexpression variables,
then the new values are computed and written to the same field in-place. then the new values are computed and written to the same field in-place.
Additionally, if a datatype is given to the function the rhs symbol of the new isolated field read will have Additionally, if a datatype is given to the function the rhs symbol of the new isolated field read will have
this data type. This is useful for mixed precision kernels this data type, and an explicit cast is inserted. This is useful for mixed precision kernels
""" """
field_reads = set() field_reads = set()
to_iterate = [] to_iterate = []
...@@ -201,8 +213,23 @@ def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments ...@@ -201,8 +213,23 @@ def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments
substitutions.update({fa: TypedSymbol(lhs.name, data_type)}) substitutions.update({fa: TypedSymbol(lhs.name, data_type)})
else: else:
substitutions.update({fa: lhs}) substitutions.update({fa: lhs})
return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True,
substitute_on_lhs=False, sort_topologically=False) ac = ac.new_with_substitutions(
substitutions,
add_substitutions_as_subexpressions=False,
substitute_on_lhs=False,
sort_topologically=False
)
loads: list[Assignment] = []
for fa in field_reads:
rhs = fa if data_type is None else tcast(fa, data_type)
loads.append(
Assignment(substitutions[fa], rhs)
)
ac.subexpressions = loads + ac.subexpressions
return ac
def transform_rhs(assignment_list, transformation, *args, **kwargs): def transform_rhs(assignment_list, transformation, *args, **kwargs):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment