Commit 603ad7ef authored by Stephan Seitz

Weirdly cache result

parent 0ef3d7a9
@@ -7,11 +7,151 @@ import numpy as np
import sympy as sp
import pystencils as ps
import pystencils.cache
import pystencils_autodiff._layout_fixer
from pystencils_autodiff.backends import AVAILABLE_BACKENDS
from pystencils_autodiff.transformations import add_fixed_constant_boundary_handling
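# Note: the TF-MAD construction below lives at module level and takes the op as an
# explicit first argument, so that its result can be memoized by the disk-cache
# decorator from pystencils.cache.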
@pystencils.cache.disk_cache_no_fallback
def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
"""
Performs the automatic backward differentiation in a more fancy way with write accesses
like in the forward pass (only flipped).
It is called "transposed-mode forward-mode algorithmic differentiation" (TF-MAD).
See this presentation https://autodiff-workshop.github.io/slides/Hueckelheim_nips_autodiff_CNN_PDE.pdf or that
paper https://www.tandfonline.com/doi/full/10.1080/10556788.2018.1435654?scroll=top&needAccess=true
for more information
"""
forward_assignments = self._forward_assignments
if hasattr(forward_assignments, 'new_without_subexpressions'):
forward_assignments = forward_assignments.new_without_subexpressions()
if hasattr(forward_assignments, 'main_assignments'):
forward_assignments = forward_assignments.main_assignments
if not hasattr(forward_assignments, 'free_symbols'):
forward_assignments = ps.AssignmentCollection(
forward_assignments, [])
read_field_accesses = sorted([
a for a in forward_assignments.free_symbols if isinstance(a, ps.Field.Access)], key=lambda x: str(x))
write_field_accesses = sorted([a.lhs for a in forward_assignments], key=lambda x: str(x))
read_fields = {s.field for s in read_field_accesses}
write_fields = {s.field for s in write_field_accesses}
self._forward_read_accesses = read_field_accesses
self._forward_write_accesses = write_field_accesses
self._forward_input_fields = sorted(list(read_fields), key=lambda x: str(x))
self._forward_output_fields = sorted(list(write_fields), key=lambda x: str(x))
read_field_accesses = [s for s in forward_assignments.free_symbols if isinstance(s, ps.Field.Access)]
write_field_accesses = [a.lhs for a in forward_assignments if isinstance(a.lhs, ps.Field.Access)]
assert write_field_accesses, "No write accesses found"
# for every field create a corresponding diff field
diff_read_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix)
for f in read_fields}
diff_write_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix)
for f in write_fields}
assert all(isinstance(w, ps.Field.Access)
for w in write_field_accesses), \
"Please check if your assignments are a AssignmentCollection or main_assignments only"
backward_assignment_dict = collections.OrderedDict()
# for each output of the forward operation
for forward_assignment in forward_assignments.main_assignments:
# we have only one assignment
diff_write_field = diff_write_fields[forward_assignment.lhs.field]
diff_write_index = forward_assignment.lhs.index
# TODO: simplify implementation. use matrix notation like in 'transposed' mode
for forward_read_field in self._forward_input_fields:
if forward_read_field in self._constant_fields:
continue
diff_read_field = diff_read_fields[forward_read_field]
if diff_read_field.index_dimensions == 0:
diff_read_field_sum = 0
for ra in read_field_accesses:
if ra.field != forward_read_field:
continue  # only accesses of the current read field contribute to its adjoint
# TF-MAD requires flipped stencils
inverted_offset = tuple(-v - w for v,
w in zip(ra.offsets, forward_assignment.lhs.offsets))
diff_read_field_sum += sp.diff(forward_assignment.rhs, ra) * \
diff_write_field[inverted_offset](*diff_write_index)
if forward_read_field in self._time_constant_fields:
# Accumulate in case of time_constant_fields
assignment = ps.Assignment(
diff_read_field.center(), diff_read_field.center() + diff_read_field_sum)
else:
# If time dependent, we just need to assign the sum for the current time step
assignment = ps.Assignment(
diff_read_field.center(), diff_read_field_sum)
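# (Illustration: a time-constant field, e.g. a coefficient field that is read in
# every time step, collects gradient contributions from all steps, hence the
# accumulation onto the existing adjoint value above; a time-dependent field only
# carries the current step's contribution.)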
# We can have contributions from multiple forward assignments
if assignment.lhs in backward_assignment_dict:
backward_assignment_dict[assignment.lhs].append(assignment.rhs)
else:
backward_assignment_dict[assignment.lhs] = [assignment.rhs]
elif diff_read_field.index_dimensions == 1:
diff_read_field_sum = [0] * diff_read_field.index_shape[0]
for ra in read_field_accesses:
if ra.field != forward_read_field:
continue  # only accesses of the current read field contribute to its adjoint
# TF-MAD requires flipped stencils
inverted_offset = tuple(-v - w for v,
w in zip(ra.offsets, write_field_accesses[0].offsets))
diff_read_field_sum[ra.index[0]] += sp.diff(forward_assignment.rhs, ra) * \
diff_write_field[inverted_offset]
for index in range(diff_read_field.index_shape[0]):
if forward_read_field in self._time_constant_fields:
# Accumulate in case of time_constant_fields
assignment = ps.Assignment(
diff_read_field.center_vector[index],
diff_read_field.center_vector[index] + diff_read_field_sum[index])
else:
# If time dependent, we just need to assign the sum for the current time step
assignment = ps.Assignment(
diff_read_field.center_vector[index], diff_read_field_sum[index])
if assignment.lhs in backward_assignment_dict:
backward_assignment_dict[assignment.lhs].append(assignment.rhs)
else:
backward_assignment_dict[assignment.lhs] = [assignment.rhs]
else:
raise NotImplementedError()
backward_assignments = [ps.Assignment(k, sp.Add(*v)) for k, v in backward_assignment_dict.items()]
try:
if self._do_common_subexpression_elimination:
backward_assignments = ps.simp.sympy_cse_on_assignment_list(
backward_assignments)
except Exception:
# Common subexpression elimination may fail for some expressions;
# in that case fall back to the unsimplified backward assignments.
pass
main_assignments = [a for a in backward_assignments if isinstance(a.lhs, ps.Field.Access)]
subexpressions = [a for a in backward_assignments if not isinstance(a.lhs, ps.Field.Access)]
backward_assignments = ps.AssignmentCollection(main_assignments, subexpressions)
assert _has_exclusive_writes(backward_assignments), "Backward assignments don't have exclusive writes!"
self._backward_field_map = {**diff_read_fields, **diff_write_fields}
return backward_assignments
class AutoDiffBoundaryHandling(str, Enum):
"""
Strategies for in-kernel boundary handling for AutoDiffOp
@@ -112,12 +252,39 @@ Backward:
self._backward_input_fields = list(backward_assignments.free_fields)
self._backward_output_fields = list(backward_assignments.bound_fields)
else:
# if no_caching:
if diff_mode == 'transposed':
self._create_backward_assignments(diff_fields_prefix)
elif diff_mode == 'transposed-forward':
self._create_backward_assignments_tf_mad(diff_fields_prefix)
self._backward_assignments = None
self._backward_field_map = None
backward_assignments = _create_backward_assignments_tf_mad(self, diff_fields_prefix)
self._backward_assignments = backward_assignments
if self._backward_field_map:
self._backward_input_fields = [
self._backward_field_map[f] for f in self._forward_output_fields]
self._backward_output_fields = [
self._backward_field_map[f] for f in self._forward_input_fields]
else:
self._forward_assignments = forward_assignments
self._forward_read_accesses = None
self._forward_write_accesses = None
self._forward_input_fields = list(forward_assignments.free_fields)
self._forward_output_fields = list(forward_assignments.bound_fields)
self._backward_assignments = backward_assignments
self._backward_field_map = None
self._backward_input_fields = list(backward_assignments.free_fields)
self._backward_output_fields = list(backward_assignments.bound_fields)
else:
raise NotImplementedError()
# else:
# # self.backward_assignments = create_backward_assignments(forward_assignments,
# # diff_fields_prefix,
# # time_constant_fields,
# # constant_fields,
# # diff_mode=diff_mode,
# do_common_subexpression_elimination=do_common_subexpression_elimination)
def __hash__(self):
return hash((str(self.forward_assignments), str(self.backward_assignments)))
@@ -129,6 +296,23 @@ Backward:
def __str__(self):
return self.__repr__()
def __setstate__(self, state):
forward_assignments = state['forward_assignments']
backward_assignments = state['backward_assignments']
self._forward_assignments = forward_assignments
self._forward_read_accesses = None
self._forward_write_accesses = None
self._forward_input_fields = list(forward_assignments.free_fields)
self._forward_output_fields = list(forward_assignments.bound_fields)
self._backward_assignments = backward_assignments
self._backward_field_map = None
self._backward_input_fields = list(backward_assignments.free_fields)
self._backward_output_fields = list(backward_assignments.bound_fields)
def __getstate__(self):
return {'forward_assignments': self._forward_assignments, 'backward_assignments': self.backward_assignments}
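# Note: pickling keeps only the two assignment collections; the access lists and
# the backward field map are reset to None when the state is restored (presumably
# to keep instances cheap to serialize for the disk-cached helpers above).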
def _create_backward_assignments(self, diff_fields_prefix):
"""
Performs automatic differentiation in the traditional adjoint/tangent way.
@@ -214,147 +398,6 @@ Backward:
self._backward_output_fields = [
self._backward_field_map[f] for f in self._forward_input_fields]
def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
"""
Performs the automatic backward differentiation in a more fancy way with write accesses
like in the forward pass (only flipped).
It is called "transposed-mode forward-mode algorithmic differentiation" (TF-MAD).
See this presentation https://autodiff-workshop.github.io/slides/Hueckelheim_nips_autodiff_CNN_PDE.pdf or that
paper https://www.tandfonline.com/doi/full/10.1080/10556788.2018.1435654?scroll=top&needAccess=true
for more information
"""
forward_assignments = self._forward_assignments
if hasattr(forward_assignments, 'new_without_subexpressions'):
forward_assignments = forward_assignments.new_without_subexpressions()
if hasattr(forward_assignments, 'main_assignments'):
forward_assignments = forward_assignments.main_assignments
if not hasattr(forward_assignments, 'free_symbols'):
forward_assignments = ps.AssignmentCollection(
forward_assignments, [])
read_field_accesses = sorted([
a for a in forward_assignments.free_symbols if isinstance(a, ps.Field.Access)], key=lambda x: str(x))
write_field_accesses = sorted([a.lhs for a in forward_assignments], key=lambda x: str(x))
read_fields = {s.field for s in read_field_accesses}
write_fields = {s.field for s in write_field_accesses}
self._forward_read_accesses = read_field_accesses
self._forward_write_accesses = write_field_accesses
self._forward_input_fields = sorted(list(read_fields), key=lambda x: str(x))
self._forward_output_fields = sorted(list(write_fields), key=lambda x: str(x))
read_field_accesses = [s for s in forward_assignments.free_symbols if isinstance(s, ps.Field.Access)]
write_field_accesses = [a.lhs for a in forward_assignments if isinstance(a.lhs, ps.Field.Access)]
assert write_field_accesses, "No write accesses found"
# for every field create a corresponding diff field
diff_read_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix)
for f in read_fields}
diff_write_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix)
for f in write_fields}
assert all(isinstance(w, ps.Field.Access)
for w in write_field_accesses), \
"Please check if your assignments are a AssignmentCollection or main_assignments only"
backward_assignment_dict = collections.OrderedDict()
# for each output of forward operation
for _, forward_assignment in enumerate(forward_assignments.main_assignments):
# we have only one assignment
diff_write_field = diff_write_fields[forward_assignment.lhs.field]
diff_write_index = forward_assignment.lhs.index
# TODO: simplify implementation. use matrix notation like in 'transposed' mode
for forward_read_field in self._forward_input_fields:
if forward_read_field in self._constant_fields:
continue
diff_read_field = diff_read_fields[forward_read_field]
if diff_read_field.index_dimensions == 0:
diff_read_field_sum = 0
for ra in read_field_accesses:
if ra.field != forward_read_field:
continue # ignore constant fields in differentiation
# TF-MAD requires flipped stencils
inverted_offset = tuple(-v - w for v,
w in zip(ra.offsets, forward_assignment.lhs.offsets))
diff_read_field_sum += sp.diff(forward_assignment.rhs, ra) * \
diff_write_field[inverted_offset](*diff_write_index)
if forward_read_field in self._time_constant_fields:
# Accumulate in case of time_constant_fields
assignment = ps.Assignment(
diff_read_field.center(), diff_read_field.center() + diff_read_field_sum)
else:
# If time dependent, we just need to assign the sum for the current time step
assignment = ps.Assignment(
diff_read_field.center(), diff_read_field_sum)
# We can have contributions from multiple forward assignments
if assignment.lhs in backward_assignment_dict:
backward_assignment_dict[assignment.lhs].append(assignment.rhs)
else:
backward_assignment_dict[assignment.lhs] = [assignment.rhs]
elif diff_read_field.index_dimensions == 1:
diff_read_field_sum = [0] * diff_read_field.index_shape[0]
for ra in read_field_accesses:
if ra.field != forward_read_field:
continue # ignore constant fields in differentiation
# TF-MAD requires flipped stencils
inverted_offset = tuple(-v - w for v,
w in zip(ra.offsets, write_field_accesses[0].offsets))
diff_read_field_sum[ra.index[0]] += sp.diff(forward_assignment.rhs, ra) * \
diff_write_field[inverted_offset]
for index in range(diff_read_field.index_shape[0]):
if forward_read_field in self._time_constant_fields:
# Accumulate in case of time_constant_fields
assignment = ps.Assignment(
diff_read_field.center_vector[index],
diff_read_field.center_vector[index] + diff_read_field_sum[index])
else:
# If time dependent, we just need to assign the sum for the current time step
assignment = ps.Assignment(
diff_read_field.center_vector[index], diff_read_field_sum[index])
if assignment.lhs in backward_assignment_dict:
backward_assignment_dict[assignment.lhs].append(assignment.rhs)
else:
backward_assignment_dict[assignment.lhs] = [assignment.rhs]
else:
raise NotImplementedError()
backward_assignments = [ps.Assignment(k, sp.Add(*v)) for k, v in backward_assignment_dict.items()]
try:
if self._do_common_subexpression_elimination:
backward_assignments = ps.simp.sympy_cse_on_assignment_list(
backward_assignments)
except Exception:
pass
# print("Common subexpression elimination failed")
# print(err)
main_assignments = [a for a in backward_assignments if isinstance(a.lhs, ps.Field.Access)]
subexpressions = [a for a in backward_assignments if not isinstance(a.lhs, ps.Field.Access)]
backward_assignments = ps.AssignmentCollection(main_assignments, subexpressions)
assert _has_exclusive_writes(backward_assignments), "Backward assignments don't have exclusive writes!"
self._backward_assignments = backward_assignments
self._backward_field_map = {**diff_read_fields, **diff_write_fields}
self._backward_input_fields = [
self._backward_field_map[f] for f in self._forward_output_fields]
self._backward_output_fields = [
self._backward_field_map[f] for f in self._forward_input_fields]
@property
def forward_assignments(self):
return self._forward_assignments
@@ -629,6 +672,7 @@ Backward:
return op
@pystencils.cache.disk_cache_no_fallback
def create_backward_assignments(forward_assignments,
diff_fields_prefix="diff",
time_constant_fields=[],
@@ -640,7 +684,9 @@ def create_backward_assignments(forward_assignments,
time_constant_fields=time_constant_fields,
constant_fields=constant_fields,
diff_mode=diff_mode,
do_common_subexpression_elimination=do_common_sub_expression_elimination)
do_common_subexpression_elimination=do_common_sub_expression_elimination,
no_caching=True
)
backward_assignments = auto_diff.backward_assignments
return backward_assignments
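# Usage sketch (illustrative only; the field and assignment names are made up and
# not part of this file):
#
#     x, y = ps.fields("x, y: float64[2D]")
#     forward = [ps.Assignment(y[0, 0], 2 * x[1, 0] + x[-1, 0])]
#     backward = create_backward_assignments(forward, diff_mode='transposed-forward')
#
# The returned AssignmentCollection writes into the adjoint field of x (the forward
# input) and reads the adjoint field of y (the forward output), both named with
# `diff_fields_prefix`.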