Skip to content
Snippets Groups Projects
Commit ffd1b9d7 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

fix requires_grad

parent e91b33d6
No related branches found
No related tags found
No related merge requests found
......@@ -73,6 +73,9 @@ def create_autograd_function(autodiff_obj, use_cuda):
output_tensors = OrderedDict({f.name:
field_to_tensor_dict.get(f, kwargs[f.name])
for f in autodiff_obj.forward_output_fields})
for o in output_tensors.values():
if isinstance(o, torch.Tensor):
o.requires_grad = True
field_to_tensor_dict.update(kwargs)
self.saved_for_backward = kwargs
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment