Skip to content
Snippets Groups Projects
Commit a91a2f9f authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Remove AutodiffOp.backward_fields_map

parent ddba65e5
Branches
Tags
No related merge requests found
......@@ -26,6 +26,7 @@ def _create_backward_assignments_tf_mad(self, diff_fields_prefix):
"""
forward_assignments = self._forward_assignments
if hasattr(forward_assignments, 'new_without_subexpressions'):
forward_assignments = forward_assignments.new_without_subexpressions()
if hasattr(forward_assignments, 'main_assignments'):
......@@ -249,7 +250,6 @@ Backward:
self._forward_input_fields = list(sorted(forward_assignments.free_fields, key=lambda x: str(x)))
self._forward_output_fields = list(sorted(forward_assignments.bound_fields, key=lambda x: str(x)))
self._backward_assignments = backward_assignments
self._backward_field_map = None
self._backward_input_fields = list(sorted(backward_assignments.free_fields, key=lambda x: str(x)))
self._backward_output_fields = list(sorted(backward_assignments.bound_fields, key=lambda x: str(x)))
else:
......@@ -261,22 +261,16 @@ Backward:
self._backward_field_map = None
backward_assignments = _create_backward_assignments_tf_mad(self, diff_fields_prefix)
self._backward_assignments = backward_assignments
if self._backward_field_map:
self._backward_input_fields = [
self._backward_field_map[f] for f in self._forward_output_fields]
self._backward_output_fields = [
self._backward_field_map[f] for f in self._forward_input_fields]
else:
self._forward_assignments = forward_assignments
self._forward_read_accesses = None
self._forward_write_accesses = None
self._forward_input_fields = list(sorted(forward_assignments.free_fields, key=lambda x: str(x)))
self._forward_output_fields = list(sorted(forward_assignments.bound_fields, key=lambda x: str(x)))
self._backward_assignments = backward_assignments
self._backward_field_map = None
self._backward_input_fields = list(sorted(backward_assignments.free_fields, key=lambda x: str(x)))
self._backward_output_fields = list(sorted(backward_assignments.bound_fields, key=lambda x: str(x)))
self._backward_field_map = None
self._forward_assignments = forward_assignments
self._forward_read_accesses = None
self._forward_write_accesses = None
self._forward_input_fields = list(sorted(forward_assignments.free_fields, key=lambda x: str(x)))
self._forward_output_fields = list(sorted(forward_assignments.bound_fields, key=lambda x: str(x)))
self._backward_assignments = backward_assignments
self._backward_field_map = None
self._backward_input_fields = list(sorted(backward_assignments.free_fields, key=lambda x: str(x)))
self._backward_output_fields = list(sorted(backward_assignments.bound_fields, key=lambda x: str(x)))
else:
raise NotImplementedError()
# else:
......@@ -528,10 +522,6 @@ Backward:
self._backward_kernel_gpu = self.backward_ast_gpu.compile()
return self._backward_kernel_gpu
@property
def backward_fields_map(self):
return self._backward_field_map
@property
def backward_input_fields(self):
return self._backward_input_fields
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment