diff --git a/src/pystencils_autodiff/_autodiff.py b/src/pystencils_autodiff/_autodiff.py index 2e4ea6c160e4e135b3273bbb0200d0348bb83b14..a15afbd01a194df7f5e2597b485b5ac205b0cd1c 100644 --- a/src/pystencils_autodiff/_autodiff.py +++ b/src/pystencils_autodiff/_autodiff.py @@ -7,11 +7,151 @@ import numpy as np import sympy as sp import pystencils as ps +import pystencils.cache import pystencils_autodiff._layout_fixer from pystencils_autodiff.backends import AVAILABLE_BACKENDS from pystencils_autodiff.transformations import add_fixed_constant_boundary_handling +@pystencils.cache.disk_cache_no_fallback +def _create_backward_assignments_tf_mad(self, diff_fields_prefix): + """ + Performs the automatic backward differentiation in a more fancy way with write accesses + like in the forward pass (only flipped). + It is called "transposed-mode forward-mode algorithmic differentiation" (TF-MAD). + + See this presentation https://autodiff-workshop.github.io/slides/Hueckelheim_nips_autodiff_CNN_PDE.pdf or that + paper https://www.tandfonline.com/doi/full/10.1080/10556788.2018.1435654?scroll=top&needAccess=true + for more information + """ + + forward_assignments = self._forward_assignments + if hasattr(forward_assignments, 'new_without_subexpressions'): + forward_assignments = forward_assignments.new_without_subexpressions() + if hasattr(forward_assignments, 'main_assignments'): + forward_assignments = forward_assignments.main_assignments + + if not hasattr(forward_assignments, 'free_symbols'): + forward_assignments = ps.AssignmentCollection( + forward_assignments, []) + + read_field_accesses = sorted([ + a for a in forward_assignments.free_symbols if isinstance(a, ps.Field.Access)], key=lambda x: str(x)) + write_field_accesses = sorted([a.lhs for a in forward_assignments], key=lambda x: str(x)) + read_fields = {s.field for s in read_field_accesses} + write_fields = {s.field for s in write_field_accesses} + + self._forward_read_accesses = read_field_accesses + 
self._forward_write_accesses = write_field_accesses + self._forward_input_fields = sorted(list(read_fields), key=lambda x: str(x)) + self._forward_output_fields = sorted(list(write_fields), key=lambda x: str(x)) + + read_field_accesses = [s for s in forward_assignments.free_symbols if isinstance(s, ps.Field.Access)] + write_field_accesses = [a.lhs for a in forward_assignments if isinstance(a.lhs, ps.Field.Access)] + assert write_field_accesses, "No write accesses found" + + # for every field create a corresponding diff field + diff_read_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix) + for f in read_fields} + diff_write_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix) + for f in write_fields} + + assert all(isinstance(w, ps.Field.Access) + for w in write_field_accesses), \ + "Please check if your assignments are a AssignmentCollection or main_assignments only" + + backward_assignment_dict = collections.OrderedDict() + # for each output of forward operation + for _, forward_assignment in enumerate(forward_assignments.main_assignments): + # we have only one assignment + diff_write_field = diff_write_fields[forward_assignment.lhs.field] + diff_write_index = forward_assignment.lhs.index + + # TODO: simplify implementation. 
use matrix notation like in 'transposed' mode + for forward_read_field in self._forward_input_fields: + if forward_read_field in self._constant_fields: + continue + diff_read_field = diff_read_fields[forward_read_field] + + if diff_read_field.index_dimensions == 0: + + diff_read_field_sum = 0 + for ra in read_field_accesses: + if ra.field != forward_read_field: + continue # ignore constant fields in differentiation + + # TF-MAD requires flipped stencils + inverted_offset = tuple(-v - w for v, + w in zip(ra.offsets, forward_assignment.lhs.offsets)) + diff_read_field_sum += sp.diff(forward_assignment.rhs, ra) * \ + diff_write_field[inverted_offset](*diff_write_index) + if forward_read_field in self._time_constant_fields: + # Accumulate in case of time_constant_fields + assignment = ps.Assignment( + diff_read_field.center(), diff_read_field.center() + diff_read_field_sum) + else: + # If time dependent, we just need to assign the sum for the current time step + assignment = ps.Assignment( + diff_read_field.center(), diff_read_field_sum) + + # We can have contributions from multiple forward assignments + if assignment.lhs in backward_assignment_dict: + backward_assignment_dict[assignment.lhs].append(assignment.rhs) + else: + backward_assignment_dict[assignment.lhs] = [assignment.rhs] + + elif diff_read_field.index_dimensions == 1: + + diff_read_field_sum = [0] * diff_read_field.index_shape[0] + for ra in read_field_accesses: + if ra.field != forward_read_field: + continue # ignore constant fields in differentiation + + # TF-MAD requires flipped stencils + inverted_offset = tuple(-v - w for v, + w in zip(ra.offsets, write_field_accesses[0].offsets)) + diff_read_field_sum[ra.index[0]] += sp.diff(forward_assignment.rhs, ra) * \ + diff_write_field[inverted_offset] + + for index in range(diff_read_field.index_shape[0]): + if forward_read_field in self._time_constant_fields: + # Accumulate in case of time_constant_fields + assignment = ps.Assignment( + 
diff_read_field.center_vector[index], + diff_read_field.center_vector[index] + diff_read_field_sum[index]) + else: + # If time dependent, we just need to assign the sum for the current time step + assignment = ps.Assignment( + diff_read_field.center_vector[index], diff_read_field_sum[index]) + + if assignment.lhs in backward_assignment_dict: + backward_assignment_dict[assignment.lhs].append(assignment.rhs) + else: + backward_assignment_dict[assignment.lhs] = [assignment.rhs] + else: + raise NotImplementedError() + + backward_assignments = [ps.Assignment(k, sp.Add(*v)) for k, v in backward_assignment_dict.items()] + + try: + + if self._do_common_subexpression_elimination: + backward_assignments = ps.simp.sympy_cse_on_assignment_list( + backward_assignments) + except Exception: + pass + # print("Common subexpression elimination failed") + # print(err) + main_assignments = [a for a in backward_assignments if isinstance(a.lhs, ps.Field.Access)] + subexpressions = [a for a in backward_assignments if not isinstance(a.lhs, ps.Field.Access)] + backward_assignments = ps.AssignmentCollection(main_assignments, subexpressions) + + assert _has_exclusive_writes(backward_assignments), "Backward assignments don't have exclusive writes!" 
+ self._backward_field_map = {**diff_read_fields, **diff_write_fields} + + return backward_assignments + + class AutoDiffBoundaryHandling(str, Enum): """ Strategies for in-kernel boundary handling for AutoDiffOp @@ -112,12 +252,39 @@ Backward: self._backward_input_fields = list(backward_assignments.free_fields) self._backward_output_fields = list(backward_assignments.bound_fields) else: + # if no_caching: if diff_mode == 'transposed': self._create_backward_assignments(diff_fields_prefix) elif diff_mode == 'transposed-forward': - self._create_backward_assignments_tf_mad(diff_fields_prefix) + self._backward_assignments = None + self._backward_field_map = None + backward_assignments = _create_backward_assignments_tf_mad(self, diff_fields_prefix) + self._backward_assignments = backward_assignments + if self._backward_field_map: + self._backward_input_fields = [ + self._backward_field_map[f] for f in self._forward_output_fields] + self._backward_output_fields = [ + self._backward_field_map[f] for f in self._forward_input_fields] + else: + self._forward_assignments = forward_assignments + self._forward_read_accesses = None + self._forward_write_accesses = None + self._forward_input_fields = list(forward_assignments.free_fields) + self._forward_output_fields = list(forward_assignments.bound_fields) + self._backward_assignments = backward_assignments + self._backward_field_map = None + self._backward_input_fields = list(backward_assignments.free_fields) + self._backward_output_fields = list(backward_assignments.bound_fields) + self._backward_field_map = None else: raise NotImplementedError() + # else: + # # self.backward_assignments = create_backward_assignments(forward_assignments, + # # diff_fields_prefix, + # # time_constant_fields, + # # constant_fields, + # # diff_mode=diff_mode, + # do_common_subexpression_elimination=do_common_subexpression_elimination) def __hash__(self): return hash((str(self.forward_assignments), str(self.backward_assignments))) @@ -129,6 +296,23 
@@ Backward: def __str__(self): return self.__repr__() + def __setstate__(self, state): + forward_assignments = state['forward_assignments'] + backward_assignments = state['backward_assignments'] + self._forward_assignments = forward_assignments + self._forward_read_accesses = None + self._forward_write_accesses = None + self._forward_input_fields = list(forward_assignments.free_fields) + self._forward_output_fields = list(forward_assignments.bound_fields) + self._backward_assignments = backward_assignments + self._backward_field_map = None + self._backward_input_fields = list(backward_assignments.free_fields) + self._backward_output_fields = list(backward_assignments.bound_fields) + self._backward_field_map = None + + def __getstate__(self): + return {'forward_assignments': self._forward_assignments, 'backward_assignments': self.backward_assignments} + def _create_backward_assignments(self, diff_fields_prefix): """ Performs automatic differentiation in the traditional adjoint/tangent way. @@ -214,147 +398,6 @@ Backward: self._backward_output_fields = [ self._backward_field_map[f] for f in self._forward_input_fields] - def _create_backward_assignments_tf_mad(self, diff_fields_prefix): - """ - Performs the automatic backward differentiation in a more fancy way with write accesses - like in the forward pass (only flipped). - It is called "transposed-mode forward-mode algorithmic differentiation" (TF-MAD). 
- - See this presentation https://autodiff-workshop.github.io/slides/Hueckelheim_nips_autodiff_CNN_PDE.pdf or that - paper https://www.tandfonline.com/doi/full/10.1080/10556788.2018.1435654?scroll=top&needAccess=true - for more information - """ - - forward_assignments = self._forward_assignments - if hasattr(forward_assignments, 'new_without_subexpressions'): - forward_assignments = forward_assignments.new_without_subexpressions() - if hasattr(forward_assignments, 'main_assignments'): - forward_assignments = forward_assignments.main_assignments - - if not hasattr(forward_assignments, 'free_symbols'): - forward_assignments = ps.AssignmentCollection( - forward_assignments, []) - - read_field_accesses = sorted([ - a for a in forward_assignments.free_symbols if isinstance(a, ps.Field.Access)], key=lambda x: str(x)) - write_field_accesses = sorted([a.lhs for a in forward_assignments], key=lambda x: str(x)) - read_fields = {s.field for s in read_field_accesses} - write_fields = {s.field for s in write_field_accesses} - - self._forward_read_accesses = read_field_accesses - self._forward_write_accesses = write_field_accesses - self._forward_input_fields = sorted(list(read_fields), key=lambda x: str(x)) - self._forward_output_fields = sorted(list(write_fields), key=lambda x: str(x)) - - read_field_accesses = [s for s in forward_assignments.free_symbols if isinstance(s, ps.Field.Access)] - write_field_accesses = [a.lhs for a in forward_assignments if isinstance(a.lhs, ps.Field.Access)] - assert write_field_accesses, "No write accesses found" - - # for every field create a corresponding diff field - diff_read_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix) - for f in read_fields} - diff_write_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix) - for f in write_fields} - - assert all(isinstance(w, ps.Field.Access) - for w in write_field_accesses), \ - "Please check if your assignments are a AssignmentCollection or main_assignments only" - 
- backward_assignment_dict = collections.OrderedDict() - # for each output of forward operation - for _, forward_assignment in enumerate(forward_assignments.main_assignments): - # we have only one assignment - diff_write_field = diff_write_fields[forward_assignment.lhs.field] - diff_write_index = forward_assignment.lhs.index - - # TODO: simplify implementation. use matrix notation like in 'transposed' mode - for forward_read_field in self._forward_input_fields: - if forward_read_field in self._constant_fields: - continue - diff_read_field = diff_read_fields[forward_read_field] - - if diff_read_field.index_dimensions == 0: - - diff_read_field_sum = 0 - for ra in read_field_accesses: - if ra.field != forward_read_field: - continue # ignore constant fields in differentiation - - # TF-MAD requires flipped stencils - inverted_offset = tuple(-v - w for v, - w in zip(ra.offsets, forward_assignment.lhs.offsets)) - diff_read_field_sum += sp.diff(forward_assignment.rhs, ra) * \ - diff_write_field[inverted_offset](*diff_write_index) - if forward_read_field in self._time_constant_fields: - # Accumulate in case of time_constant_fields - assignment = ps.Assignment( - diff_read_field.center(), diff_read_field.center() + diff_read_field_sum) - else: - # If time dependent, we just need to assign the sum for the current time step - assignment = ps.Assignment( - diff_read_field.center(), diff_read_field_sum) - - # We can have contributions from multiple forward assignments - if assignment.lhs in backward_assignment_dict: - backward_assignment_dict[assignment.lhs].append(assignment.rhs) - else: - backward_assignment_dict[assignment.lhs] = [assignment.rhs] - - elif diff_read_field.index_dimensions == 1: - - diff_read_field_sum = [0] * diff_read_field.index_shape[0] - for ra in read_field_accesses: - if ra.field != forward_read_field: - continue # ignore constant fields in differentiation - - # TF-MAD requires flipped stencils - inverted_offset = tuple(-v - w for v, - w in 
zip(ra.offsets, write_field_accesses[0].offsets)) - diff_read_field_sum[ra.index[0]] += sp.diff(forward_assignment.rhs, ra) * \ - diff_write_field[inverted_offset] - - for index in range(diff_read_field.index_shape[0]): - if forward_read_field in self._time_constant_fields: - # Accumulate in case of time_constant_fields - assignment = ps.Assignment( - diff_read_field.center_vector[index], - diff_read_field.center_vector[index] + diff_read_field_sum[index]) - else: - # If time dependent, we just need to assign the sum for the current time step - assignment = ps.Assignment( - diff_read_field.center_vector[index], diff_read_field_sum[index]) - - if assignment.lhs in backward_assignment_dict: - backward_assignment_dict[assignment.lhs].append(assignment.rhs) - else: - backward_assignment_dict[assignment.lhs] = [assignment.rhs] - else: - raise NotImplementedError() - - backward_assignments = [ps.Assignment(k, sp.Add(*v)) for k, v in backward_assignment_dict.items()] - - try: - - if self._do_common_subexpression_elimination: - backward_assignments = ps.simp.sympy_cse_on_assignment_list( - backward_assignments) - except Exception: - pass - # print("Common subexpression elimination failed") - # print(err) - main_assignments = [a for a in backward_assignments if isinstance(a.lhs, ps.Field.Access)] - subexpressions = [a for a in backward_assignments if not isinstance(a.lhs, ps.Field.Access)] - backward_assignments = ps.AssignmentCollection(main_assignments, subexpressions) - - assert _has_exclusive_writes(backward_assignments), "Backward assignments don't have exclusive writes!" 
- - self._backward_assignments = backward_assignments - self._backward_field_map = {**diff_read_fields, **diff_write_fields} - self._backward_input_fields = [ - self._backward_field_map[f] for f in self._forward_output_fields] - self._backward_output_fields = [ - self._backward_field_map[f] for f in self._forward_input_fields] - @property def forward_assignments(self): return self._forward_assignments @@ -629,6 +672,7 @@ Backward: return op +@pystencils.cache.disk_cache_no_fallback def create_backward_assignments(forward_assignments, diff_fields_prefix="diff", time_constant_fields=[], @@ -640,7 +684,9 @@ def create_backward_assignments(forward_assignments, time_constant_fields=time_constant_fields, constant_fields=constant_fields, diff_mode=diff_mode, - do_common_subexpression_elimination=do_common_sub_expression_elimination) + do_common_subexpression_elimination=do_common_sub_expression_elimination, + no_caching=True + ) backward_assignments = auto_diff.backward_assignments return backward_assignments