Skip to content
Snippets Groups Projects
Commit 603ad7ef authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Weirdly cache result

parent 0ef3d7a9
No related branches found
No related tags found
No related merge requests found
......@@ -7,11 +7,151 @@ import numpy as np
import sympy as sp
import pystencils as ps
import pystencils.cache
import pystencils_autodiff._layout_fixer
from pystencils_autodiff.backends import AVAILABLE_BACKENDS
from pystencils_autodiff.transformations import add_fixed_constant_boundary_handling
@pystencils.cache.disk_cache_no_fallback
def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
"""
Performs the automatic backward differentiation in a more fancy way with write accesses
like in the forward pass (only flipped).
It is called "transposed-mode forward-mode algorithmic differentiation" (TF-MAD).
See this presentation https://autodiff-workshop.github.io/slides/Hueckelheim_nips_autodiff_CNN_PDE.pdf or that
paper https://www.tandfonline.com/doi/full/10.1080/10556788.2018.1435654?scroll=top&needAccess=true
for more information
"""
forward_assignments = self._forward_assignments
if hasattr(forward_assignments, 'new_without_subexpressions'):
forward_assignments = forward_assignments.new_without_subexpressions()
if hasattr(forward_assignments, 'main_assignments'):
forward_assignments = forward_assignments.main_assignments
if not hasattr(forward_assignments, 'free_symbols'):
forward_assignments = ps.AssignmentCollection(
forward_assignments, [])
read_field_accesses = sorted([
a for a in forward_assignments.free_symbols if isinstance(a, ps.Field.Access)], key=lambda x: str(x))
write_field_accesses = sorted([a.lhs for a in forward_assignments], key=lambda x: str(x))
read_fields = {s.field for s in read_field_accesses}
write_fields = {s.field for s in write_field_accesses}
self._forward_read_accesses = read_field_accesses
self._forward_write_accesses = write_field_accesses
self._forward_input_fields = sorted(list(read_fields), key=lambda x: str(x))
self._forward_output_fields = sorted(list(write_fields), key=lambda x: str(x))
read_field_accesses = [s for s in forward_assignments.free_symbols if isinstance(s, ps.Field.Access)]
write_field_accesses = [a.lhs for a in forward_assignments if isinstance(a.lhs, ps.Field.Access)]
assert write_field_accesses, "No write accesses found"
# for every field create a corresponding diff field
diff_read_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix)
for f in read_fields}
diff_write_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix)
for f in write_fields}
assert all(isinstance(w, ps.Field.Access)
for w in write_field_accesses), \
"Please check if your assignments are a AssignmentCollection or main_assignments only"
backward_assignment_dict = collections.OrderedDict()
# for each output of forward operation
for _, forward_assignment in enumerate(forward_assignments.main_assignments):
# we have only one assignment
diff_write_field = diff_write_fields[forward_assignment.lhs.field]
diff_write_index = forward_assignment.lhs.index
# TODO: simplify implementation. use matrix notation like in 'transposed' mode
for forward_read_field in self._forward_input_fields:
if forward_read_field in self._constant_fields:
continue
diff_read_field = diff_read_fields[forward_read_field]
if diff_read_field.index_dimensions == 0:
diff_read_field_sum = 0
for ra in read_field_accesses:
if ra.field != forward_read_field:
continue # ignore constant fields in differentiation
# TF-MAD requires flipped stencils
inverted_offset = tuple(-v - w for v,
w in zip(ra.offsets, forward_assignment.lhs.offsets))
diff_read_field_sum += sp.diff(forward_assignment.rhs, ra) * \
diff_write_field[inverted_offset](*diff_write_index)
if forward_read_field in self._time_constant_fields:
# Accumulate in case of time_constant_fields
assignment = ps.Assignment(
diff_read_field.center(), diff_read_field.center() + diff_read_field_sum)
else:
# If time dependent, we just need to assign the sum for the current time step
assignment = ps.Assignment(
diff_read_field.center(), diff_read_field_sum)
# We can have contributions from multiple forward assignments
if assignment.lhs in backward_assignment_dict:
backward_assignment_dict[assignment.lhs].append(assignment.rhs)
else:
backward_assignment_dict[assignment.lhs] = [assignment.rhs]
elif diff_read_field.index_dimensions == 1:
diff_read_field_sum = [0] * diff_read_field.index_shape[0]
for ra in read_field_accesses:
if ra.field != forward_read_field:
continue # ignore constant fields in differentiation
# TF-MAD requires flipped stencils
inverted_offset = tuple(-v - w for v,
w in zip(ra.offsets, write_field_accesses[0].offsets))
diff_read_field_sum[ra.index[0]] += sp.diff(forward_assignment.rhs, ra) * \
diff_write_field[inverted_offset]
for index in range(diff_read_field.index_shape[0]):
if forward_read_field in self._time_constant_fields:
# Accumulate in case of time_constant_fields
assignment = ps.Assignment(
diff_read_field.center_vector[index],
diff_read_field.center_vector[index] + diff_read_field_sum[index])
else:
# If time dependent, we just need to assign the sum for the current time step
assignment = ps.Assignment(
diff_read_field.center_vector[index], diff_read_field_sum[index])
if assignment.lhs in backward_assignment_dict:
backward_assignment_dict[assignment.lhs].append(assignment.rhs)
else:
backward_assignment_dict[assignment.lhs] = [assignment.rhs]
else:
raise NotImplementedError()
backward_assignments = [ps.Assignment(k, sp.Add(*v)) for k, v in backward_assignment_dict.items()]
try:
if self._do_common_subexpression_elimination:
backward_assignments = ps.simp.sympy_cse_on_assignment_list(
backward_assignments)
except Exception:
pass
# print("Common subexpression elimination failed")
# print(err)
main_assignments = [a for a in backward_assignments if isinstance(a.lhs, ps.Field.Access)]
subexpressions = [a for a in backward_assignments if not isinstance(a.lhs, ps.Field.Access)]
backward_assignments = ps.AssignmentCollection(main_assignments, subexpressions)
assert _has_exclusive_writes(backward_assignments), "Backward assignments don't have exclusive writes!"
self._backward_field_map = {**diff_read_fields, **diff_write_fields}
return backward_assignments
class AutoDiffBoundaryHandling(str, Enum):
"""
Strategies for in-kernel boundary handling for AutoDiffOp
......@@ -112,12 +252,39 @@ Backward:
self._backward_input_fields = list(backward_assignments.free_fields)
self._backward_output_fields = list(backward_assignments.bound_fields)
else:
# if no_caching:
if diff_mode == 'transposed':
self._create_backward_assignments(diff_fields_prefix)
elif diff_mode == 'transposed-forward':
self._create_backward_assignments_tf_mad(diff_fields_prefix)
self._backward_assignments = None
self._backward_field_map = None
backward_assignments = _create_backward_assignments_tf_mad(self, diff_fields_prefix)
self._backward_assignments = backward_assignments
if self._backward_field_map:
self._backward_input_fields = [
self._backward_field_map[f] for f in self._forward_output_fields]
self._backward_output_fields = [
self._backward_field_map[f] for f in self._forward_input_fields]
else:
self._forward_assignments = forward_assignments
self._forward_read_accesses = None
self._forward_write_accesses = None
self._forward_input_fields = list(forward_assignments.free_fields)
self._forward_output_fields = list(forward_assignments.bound_fields)
self._backward_assignments = backward_assignments
self._backward_field_map = None
self._backward_input_fields = list(backward_assignments.free_fields)
self._backward_output_fields = list(backward_assignments.bound_fields)
self._backward_field_map = None
else:
raise NotImplementedError()
# else:
# # self.backward_assignments = create_backward_assignments(forward_assignments,
# # diff_fields_prefix,
# # time_constant_fields,
# # constant_fields,
# # diff_mode=diff_mode,
# do_common_subexpression_elimination=do_common_subexpression_elimination)
def __hash__(self):
return hash((str(self.forward_assignments), str(self.backward_assignments)))
......@@ -129,6 +296,23 @@ Backward:
def __str__(self):
return self.__repr__()
def __setstate__(self, state):
forward_assignments = state['forward_assignments']
backward_assignments = state['backward_assignments']
self._forward_assignments = forward_assignments
self._forward_read_accesses = None
self._forward_write_accesses = None
self._forward_input_fields = list(forward_assignments.free_fields)
self._forward_output_fields = list(forward_assignments.bound_fields)
self._backward_assignments = backward_assignments
self._backward_field_map = None
self._backward_input_fields = list(backward_assignments.free_fields)
self._backward_output_fields = list(backward_assignments.bound_fields)
self._backward_field_map = None
def __getstate__(self):
return {'forward_assignments': self._forward_assignments, 'backward_assignments': self.backward_assignments}
def _create_backward_assignments(self, diff_fields_prefix):
"""
Performs automatic differentiation in the traditional adjoint/tangent way.
......@@ -214,147 +398,6 @@ Backward:
self._backward_output_fields = [
self._backward_field_map[f] for f in self._forward_input_fields]
def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
"""
Performs the automatic backward differentiation in a more fancy way with write accesses
like in the forward pass (only flipped).
It is called "transposed-mode forward-mode algorithmic differentiation" (TF-MAD).
See this presentation https://autodiff-workshop.github.io/slides/Hueckelheim_nips_autodiff_CNN_PDE.pdf or that
paper https://www.tandfonline.com/doi/full/10.1080/10556788.2018.1435654?scroll=top&needAccess=true
for more information
"""
forward_assignments = self._forward_assignments
if hasattr(forward_assignments, 'new_without_subexpressions'):
forward_assignments = forward_assignments.new_without_subexpressions()
if hasattr(forward_assignments, 'main_assignments'):
forward_assignments = forward_assignments.main_assignments
if not hasattr(forward_assignments, 'free_symbols'):
forward_assignments = ps.AssignmentCollection(
forward_assignments, [])
read_field_accesses = sorted([
a for a in forward_assignments.free_symbols if isinstance(a, ps.Field.Access)], key=lambda x: str(x))
write_field_accesses = sorted([a.lhs for a in forward_assignments], key=lambda x: str(x))
read_fields = {s.field for s in read_field_accesses}
write_fields = {s.field for s in write_field_accesses}
self._forward_read_accesses = read_field_accesses
self._forward_write_accesses = write_field_accesses
self._forward_input_fields = sorted(list(read_fields), key=lambda x: str(x))
self._forward_output_fields = sorted(list(write_fields), key=lambda x: str(x))
read_field_accesses = [s for s in forward_assignments.free_symbols if isinstance(s, ps.Field.Access)]
write_field_accesses = [a.lhs for a in forward_assignments if isinstance(a.lhs, ps.Field.Access)]
assert write_field_accesses, "No write accesses found"
# for every field create a corresponding diff field
diff_read_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix)
for f in read_fields}
diff_write_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix)
for f in write_fields}
assert all(isinstance(w, ps.Field.Access)
for w in write_field_accesses), \
"Please check if your assignments are a AssignmentCollection or main_assignments only"
backward_assignment_dict = collections.OrderedDict()
# for each output of forward operation
for _, forward_assignment in enumerate(forward_assignments.main_assignments):
# we have only one assignment
diff_write_field = diff_write_fields[forward_assignment.lhs.field]
diff_write_index = forward_assignment.lhs.index
# TODO: simplify implementation. use matrix notation like in 'transposed' mode
for forward_read_field in self._forward_input_fields:
if forward_read_field in self._constant_fields:
continue
diff_read_field = diff_read_fields[forward_read_field]
if diff_read_field.index_dimensions == 0:
diff_read_field_sum = 0
for ra in read_field_accesses:
if ra.field != forward_read_field:
continue # ignore constant fields in differentiation
# TF-MAD requires flipped stencils
inverted_offset = tuple(-v - w for v,
w in zip(ra.offsets, forward_assignment.lhs.offsets))
diff_read_field_sum += sp.diff(forward_assignment.rhs, ra) * \
diff_write_field[inverted_offset](*diff_write_index)
if forward_read_field in self._time_constant_fields:
# Accumulate in case of time_constant_fields
assignment = ps.Assignment(
diff_read_field.center(), diff_read_field.center() + diff_read_field_sum)
else:
# If time dependent, we just need to assign the sum for the current time step
assignment = ps.Assignment(
diff_read_field.center(), diff_read_field_sum)
# We can have contributions from multiple forward assignments
if assignment.lhs in backward_assignment_dict:
backward_assignment_dict[assignment.lhs].append(assignment.rhs)
else:
backward_assignment_dict[assignment.lhs] = [assignment.rhs]
elif diff_read_field.index_dimensions == 1:
diff_read_field_sum = [0] * diff_read_field.index_shape[0]
for ra in read_field_accesses:
if ra.field != forward_read_field:
continue # ignore constant fields in differentiation
# TF-MAD requires flipped stencils
inverted_offset = tuple(-v - w for v,
w in zip(ra.offsets, write_field_accesses[0].offsets))
diff_read_field_sum[ra.index[0]] += sp.diff(forward_assignment.rhs, ra) * \
diff_write_field[inverted_offset]
for index in range(diff_read_field.index_shape[0]):
if forward_read_field in self._time_constant_fields:
# Accumulate in case of time_constant_fields
assignment = ps.Assignment(
diff_read_field.center_vector[index],
diff_read_field.center_vector[index] + diff_read_field_sum[index])
else:
# If time dependent, we just need to assign the sum for the current time step
assignment = ps.Assignment(
diff_read_field.center_vector[index], diff_read_field_sum[index])
if assignment.lhs in backward_assignment_dict:
backward_assignment_dict[assignment.lhs].append(assignment.rhs)
else:
backward_assignment_dict[assignment.lhs] = [assignment.rhs]
else:
raise NotImplementedError()
backward_assignments = [ps.Assignment(k, sp.Add(*v)) for k, v in backward_assignment_dict.items()]
try:
if self._do_common_subexpression_elimination:
backward_assignments = ps.simp.sympy_cse_on_assignment_list(
backward_assignments)
except Exception:
pass
# print("Common subexpression elimination failed")
# print(err)
main_assignments = [a for a in backward_assignments if isinstance(a.lhs, ps.Field.Access)]
subexpressions = [a for a in backward_assignments if not isinstance(a.lhs, ps.Field.Access)]
backward_assignments = ps.AssignmentCollection(main_assignments, subexpressions)
assert _has_exclusive_writes(backward_assignments), "Backward assignments don't have exclusive writes!"
self._backward_assignments = backward_assignments
self._backward_field_map = {**diff_read_fields, **diff_write_fields}
self._backward_input_fields = [
self._backward_field_map[f] for f in self._forward_output_fields]
self._backward_output_fields = [
self._backward_field_map[f] for f in self._forward_input_fields]
@property
def forward_assignments(self):
return self._forward_assignments
......@@ -629,6 +672,7 @@ Backward:
return op
@pystencils.cache.disk_cache_no_fallback
def create_backward_assignments(forward_assignments,
diff_fields_prefix="diff",
time_constant_fields=[],
......@@ -640,7 +684,9 @@ def create_backward_assignments(forward_assignments,
time_constant_fields=time_constant_fields,
constant_fields=constant_fields,
diff_mode=diff_mode,
do_common_subexpression_elimination=do_common_sub_expression_elimination)
do_common_subexpression_elimination=do_common_sub_expression_elimination,
no_chaching=True
)
backward_assignments = auto_diff.backward_assignments
return backward_assignments
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment