diff --git a/src/pystencils_autodiff/_autodiff.py b/src/pystencils_autodiff/_autodiff.py index 716178571e3e161427a2795798830de64a2f4c5b..29677281fe9d9ff0a642445edd35244056fb5ba0 100644 --- a/src/pystencils_autodiff/_autodiff.py +++ b/src/pystencils_autodiff/_autodiff.py @@ -9,9 +9,13 @@ import sympy as sp import pystencils as ps import pystencils.cache import pystencils_autodiff._layout_fixer +from pystencils.interpolation_astnodes import InterpolatorAccess +from pystencils.math_optimizations import ReplaceOptim, optimize_assignments from pystencils_autodiff.backends import AVAILABLE_BACKENDS from pystencils_autodiff.transformations import add_fixed_constant_boundary_handling +REMOVE_CASTS = ReplaceOptim(lambda x: isinstance(x, pystencils.data_types.cast_func), lambda x: x.args[0]) + @pystencils.cache.disk_cache_no_fallback def _create_backward_assignments_tf_mad(self, diff_fields_prefix): @@ -35,25 +39,28 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix): if not hasattr(forward_assignments, 'free_symbols'): forward_assignments = ps.AssignmentCollection( forward_assignments, []) + forward_assignments = ps.AssignmentCollection(optimize_assignments(forward_assignments, [ + REMOVE_CASTS + ])) read_field_accesses = sorted([ - a for a in forward_assignments.free_symbols if isinstance(a, ps.Field.Access)], key=lambda x: str(x)) + a for a in forward_assignments.free_symbols if isinstance(a, ps.Field.Access)] + + list(forward_assignments.atoms(InterpolatorAccess)), key=str) write_field_accesses = sorted([a.lhs for a in forward_assignments], key=lambda x: str(x)) read_fields = {s.field for s in read_field_accesses} write_fields = {s.field for s in write_field_accesses} self._forward_read_accesses = read_field_accesses self._forward_write_accesses = write_field_accesses - self._forward_input_fields = sorted(list(read_fields), key=lambda x: str(x)) - self._forward_output_fields = sorted(list(write_fields), key=lambda x: str(x)) + self._forward_input_fields = sorted(list(read_fields), key=str) + self._forward_output_fields = sorted(list(write_fields), key=str) - read_field_accesses = [s for s in forward_assignments.free_symbols if isinstance(s, ps.Field.Access)] write_field_accesses = [a.lhs for a in forward_assignments if isinstance(a.lhs, ps.Field.Access)] assert write_field_accesses, "No write accesses found" # for every field create a corresponding diff field diff_read_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix) - for f in read_fields} + for f in read_fields if f not in self._constant_fields} diff_write_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix) for f in write_fields} @@ -81,11 +88,23 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix): if ra.field != forward_read_field: continue # ignore constant fields in differentiation - # TF-MAD requires flipped stencils - inverted_offset = tuple(-v - w for v, - w in zip(ra.offsets, forward_assignment.lhs.offsets)) - diff_read_field_sum += sp.diff(forward_assignment.rhs, ra) * \ - diff_write_field[inverted_offset](*diff_write_index) + if isinstance(ra, InterpolatorAccess): + out_symbols = sp.symbols(f'o:{len(ra.offsets)}') + in_symbols = pystencils.x_vector(len(ra.offsets)) + maybe_solution = sp.solve(sp.Matrix(ra.offsets) - sp.Matrix(out_symbols), in_symbols) + assert maybe_solution, f'Could not solve for {in_symbols} when trying to derive for interpolator' # noqa + inverted_offset = tuple(maybe_solution.values()) + + inverted_offset = [foo.subs({o: s for o, s in zip(out_symbols, in_symbols)}) + for foo in inverted_offset] + diff_read_field_sum += (sp.diff(forward_assignment.rhs, ra) * + diff_write_field.interpolated_access(inverted_offset)) + else: + # TF-MAD requires flipped stencils + inverted_offset = tuple(-v - w for v, + w in zip(ra.offsets, forward_assignment.lhs.offsets)) + diff_read_field_sum += sp.diff(forward_assignment.rhs, ra) * \ + diff_write_field[inverted_offset](*diff_write_index) if self.time_constant_fields is not None and forward_read_field in self._time_constant_fields: # Accumulate in case of time_constant_fields assignment = ps.Assignment( @@ -111,8 +130,8 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix): # TF-MAD requires flipped stencils inverted_offset = tuple(-v - w for v, w in zip(ra.offsets, write_field_accesses[0].offsets)) - diff_read_field_sum[ra.index[0]] += sp.diff(forward_assignment.rhs, ra) * \ - diff_write_field[inverted_offset] + diff_read_field_sum[ra.index[0] + ] += sp.diff(forward_assignment.rhs, ra) * diff_write_field[inverted_offset] for index in range(diff_read_field.index_shape[0]): if forward_read_field in self._time_constant_fields: @@ -135,7 +154,6 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix): backward_assignments = [ps.Assignment(k, sp.Add(*v)) for k, v in backward_assignment_dict.items()] try: - if self._do_common_subexpression_elimination: backward_assignments = ps.simp.sympy_cse_on_assignment_list( backward_assignments) @@ -282,7 +300,7 @@ Backward: # do_common_subexpression_elimination=do_common_subexpression_elimination) def __hash__(self): - return hash((str(self.forward_assignments), str(self.backward_assignments))) + return hash((str(self.forward_assignments), str(self.backward_assignments), str(self.constant_fields))) def __repr__(self): return self._REPR_TEMPLATE.render(forward_assignments=str(self.forward_assignments), @@ -774,3 +792,6 @@ def get_jacobian_of_assignments(assignments, diff_variables): rhs = sp.Matrix([e.rhs for e in assignments]) return rhs.jacobian(diff_variables) + + rhs = sp.Matrix([e.rhs for e in assignments]) + return rhs.jacobian(diff_variables)