diff --git a/src/pystencils_autodiff/backends/_torch_native.py b/src/pystencils_autodiff/backends/_torch_native.py index 848da232acfef0d54b9a31fad6dd085567c8747f..fa5c2997817b91383fe5c9b768235ab110fe8e03 100644 --- a/src/pystencils_autodiff/backends/_torch_native.py +++ b/src/pystencils_autodiff/backends/_torch_native.py @@ -73,6 +73,9 @@ def create_autograd_function(autodiff_obj, use_cuda): output_tensors = OrderedDict({f.name: field_to_tensor_dict.get(f, kwargs[f.name]) for f in autodiff_obj.forward_output_fields}) + for o in output_tensors.values(): + if isinstance(o, torch.Tensor): + o.requires_grad = True field_to_tensor_dict.update(kwargs) self.saved_for_backward = kwargs