Skip to content
Snippets Groups Projects
Commit d15f54b3 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Enable autodiff for gradients

parent 198df690
Branches
Tags
No related merge requests found
...@@ -9,9 +9,13 @@ import sympy as sp ...@@ -9,9 +9,13 @@ import sympy as sp
import pystencils as ps import pystencils as ps
import pystencils.cache import pystencils.cache
import pystencils_autodiff._layout_fixer import pystencils_autodiff._layout_fixer
from pystencils.interpolation_astnodes import InterpolatorAccess
from pystencils.math_optimizations import ReplaceOptim, optimize_assignments
from pystencils_autodiff.backends import AVAILABLE_BACKENDS from pystencils_autodiff.backends import AVAILABLE_BACKENDS
from pystencils_autodiff.transformations import add_fixed_constant_boundary_handling from pystencils_autodiff.transformations import add_fixed_constant_boundary_handling
REMOVE_CASTS = ReplaceOptim(lambda x: isinstance(x, pystencils.data_types.cast_func), lambda x: x.args[0])
@pystencils.cache.disk_cache_no_fallback @pystencils.cache.disk_cache_no_fallback
def _create_backward_assignments_tf_mad(self, diff_fields_prefix): def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
...@@ -35,25 +39,28 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix): ...@@ -35,25 +39,28 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
if not hasattr(forward_assignments, 'free_symbols'): if not hasattr(forward_assignments, 'free_symbols'):
forward_assignments = ps.AssignmentCollection( forward_assignments = ps.AssignmentCollection(
forward_assignments, []) forward_assignments, [])
forward_assignments = ps.AssignmentCollection(optimize_assignments(forward_assignments, [
REMOVE_CASTS
]))
read_field_accesses = sorted([ read_field_accesses = sorted([
a for a in forward_assignments.free_symbols if isinstance(a, ps.Field.Access)], key=lambda x: str(x)) a for a in forward_assignments.free_symbols if isinstance(a, ps.Field.Access)]
+ list(forward_assignments.atoms(InterpolatorAccess)), key=str)
write_field_accesses = sorted([a.lhs for a in forward_assignments], key=lambda x: str(x)) write_field_accesses = sorted([a.lhs for a in forward_assignments], key=lambda x: str(x))
read_fields = {s.field for s in read_field_accesses} read_fields = {s.field for s in read_field_accesses}
write_fields = {s.field for s in write_field_accesses} write_fields = {s.field for s in write_field_accesses}
self._forward_read_accesses = read_field_accesses self._forward_read_accesses = read_field_accesses
self._forward_write_accesses = write_field_accesses self._forward_write_accesses = write_field_accesses
self._forward_input_fields = sorted(list(read_fields), key=lambda x: str(x)) self._forward_input_fields = sorted(list(read_fields), key=str)
self._forward_output_fields = sorted(list(write_fields), key=lambda x: str(x)) self._forward_output_fields = sorted(list(write_fields), key=str)
read_field_accesses = [s for s in forward_assignments.free_symbols if isinstance(s, ps.Field.Access)]
write_field_accesses = [a.lhs for a in forward_assignments if isinstance(a.lhs, ps.Field.Access)] write_field_accesses = [a.lhs for a in forward_assignments if isinstance(a.lhs, ps.Field.Access)]
assert write_field_accesses, "No write accesses found" assert write_field_accesses, "No write accesses found"
# for every field create a corresponding diff field # for every field create a corresponding diff field
diff_read_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix) diff_read_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix)
for f in read_fields} for f in read_fields if f not in self._constant_fields}
diff_write_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix) diff_write_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix)
for f in write_fields} for f in write_fields}
...@@ -81,11 +88,23 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix): ...@@ -81,11 +88,23 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
if ra.field != forward_read_field: if ra.field != forward_read_field:
continue # ignore constant fields in differentiation continue # ignore constant fields in differentiation
# TF-MAD requires flipped stencils if isinstance(ra, InterpolatorAccess):
inverted_offset = tuple(-v - w for v, out_symbols = sp.symbols(f'o:{len(ra.offsets)}')
w in zip(ra.offsets, forward_assignment.lhs.offsets)) in_symbols = pystencils.x_vector(len(ra.offsets))
diff_read_field_sum += sp.diff(forward_assignment.rhs, ra) * \ maybe_solution = sp.solve(sp.Matrix(ra.offsets) - sp.Matrix(out_symbols), in_symbols)
diff_write_field[inverted_offset](*diff_write_index) assert maybe_solution, f'Could not solve for {in_symbols} when trying to derive for interpolator' # noqa
inverted_offset = tuple(maybe_solution.values())
inverted_offset = [foo.subs({o: s for o, s in zip(out_symbols, in_symbols)})
for foo in inverted_offset]
diff_read_field_sum += (sp.diff(forward_assignment.rhs, ra) *
diff_write_field.interpolated_access(inverted_offset))
else:
# TF-MAD requires flipped stencils
inverted_offset = tuple(-v - w for v,
w in zip(ra.offsets, forward_assignment.lhs.offsets))
diff_read_field_sum += sp.diff(forward_assignment.rhs, ra) * \
diff_write_field[inverted_offset](*diff_write_index)
if self.time_constant_fields is not None and forward_read_field in self._time_constant_fields: if self.time_constant_fields is not None and forward_read_field in self._time_constant_fields:
# Accumulate in case of time_constant_fields # Accumulate in case of time_constant_fields
assignment = ps.Assignment( assignment = ps.Assignment(
...@@ -111,8 +130,8 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix): ...@@ -111,8 +130,8 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
# TF-MAD requires flipped stencils # TF-MAD requires flipped stencils
inverted_offset = tuple(-v - w for v, inverted_offset = tuple(-v - w for v,
w in zip(ra.offsets, write_field_accesses[0].offsets)) w in zip(ra.offsets, write_field_accesses[0].offsets))
diff_read_field_sum[ra.index[0]] += sp.diff(forward_assignment.rhs, ra) * \ diff_read_field_sum[ra.index[0]
diff_write_field[inverted_offset] ] += sp.diff(forward_assignment.rhs, ra) * diff_write_field[inverted_offset]
for index in range(diff_read_field.index_shape[0]): for index in range(diff_read_field.index_shape[0]):
if forward_read_field in self._time_constant_fields: if forward_read_field in self._time_constant_fields:
...@@ -135,7 +154,6 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix): ...@@ -135,7 +154,6 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
backward_assignments = [ps.Assignment(k, sp.Add(*v)) for k, v in backward_assignment_dict.items()] backward_assignments = [ps.Assignment(k, sp.Add(*v)) for k, v in backward_assignment_dict.items()]
try: try:
if self._do_common_subexpression_elimination: if self._do_common_subexpression_elimination:
backward_assignments = ps.simp.sympy_cse_on_assignment_list( backward_assignments = ps.simp.sympy_cse_on_assignment_list(
backward_assignments) backward_assignments)
...@@ -282,7 +300,7 @@ Backward: ...@@ -282,7 +300,7 @@ Backward:
# do_common_subexpression_elimination=do_common_subexpression_elimination) # do_common_subexpression_elimination=do_common_subexpression_elimination)
def __hash__(self): def __hash__(self):
return hash((str(self.forward_assignments), str(self.backward_assignments))) return hash((str(self.forward_assignments), str(self.backward_assignments), str(self.constant_fields)))
def __repr__(self): def __repr__(self):
return self._REPR_TEMPLATE.render(forward_assignments=str(self.forward_assignments), return self._REPR_TEMPLATE.render(forward_assignments=str(self.forward_assignments),
...@@ -774,3 +792,6 @@ def get_jacobian_of_assignments(assignments, diff_variables): ...@@ -774,3 +792,6 @@ def get_jacobian_of_assignments(assignments, diff_variables):
rhs = sp.Matrix([e.rhs for e in assignments]) rhs = sp.Matrix([e.rhs for e in assignments])
return rhs.jacobian(diff_variables) return rhs.jacobian(diff_variables)
rhs = sp.Matrix([e.rhs for e in assignments])
return rhs.jacobian(diff_variables)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment