Skip to content
Snippets Groups Projects

Fix data types in boundary handling. Fix deprecation checks.

Merged Frederik Hennig requested to merge fhennig/patches-for-lbmpy into v2.0-dev
1 file
+ 32
5
Compare changes
  • Side-by-side
  • Inline
from __future__ import annotations
from typing import TYPE_CHECKING
from itertools import chain
from typing import Callable, List, Sequence, Union
from collections import defaultdict
import sympy as sp
from ..types import UserTypeSpec
from ..assignment import Assignment
from ..sympyextensions import subs_additive, is_constant, recursive_collect
from ..sympyextensions import subs_additive, is_constant, recursive_collect, tcast
from ..sympyextensions.typed_sympy import TypedSymbol
if TYPE_CHECKING:
from .assignment_collection import AssignmentCollection
# TODO rewrite with SymPy AST
# def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
@@ -170,14 +177,19 @@ def add_subexpressions_for_sums(ac):
return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False)
def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True, data_type=None):
def add_subexpressions_for_field_reads(
ac: AssignmentCollection,
subexpressions=True,
main_assignments=True,
data_type: UserTypeSpec | None = None
):
r"""Substitutes field accesses on rhs of assignments with subexpressions
Can change semantics of the update rule (which is the goal of this transformation)
This is useful if a field should be update in place - all values are loaded before into subexpression variables,
then the new values are computed and written to the same field in-place.
Additionally, if a datatype is given to the function the rhs symbol of the new isolated field read will have
this data type. This is useful for mixed precision kernels
this data type, and an explicit cast is inserted. This is useful for mixed precision kernels
"""
field_reads = set()
to_iterate = []
@@ -201,8 +213,23 @@ def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments
substitutions.update({fa: TypedSymbol(lhs.name, data_type)})
else:
substitutions.update({fa: lhs})
return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True,
substitute_on_lhs=False, sort_topologically=False)
ac = ac.new_with_substitutions(
substitutions,
add_substitutions_as_subexpressions=False,
substitute_on_lhs=False,
sort_topologically=False
)
loads: list[Assignment] = []
for fa in field_reads:
rhs = fa if data_type is None else tcast(fa, data_type)
loads.append(
Assignment(substitutions[fa], rhs)
)
ac.subexpressions = loads + ac.subexpressions
return ac
def transform_rhs(assignment_list, transformation, *args, **kwargs):
Loading