Skip to content
Snippets Groups Projects
Commit efa0baf1 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Fix _torch_native

parent cb6afb11
Branches
Tags
No related merge requests found
...@@ -64,7 +64,7 @@ def create_autograd_function(autodiff_obj, use_cuda): ...@@ -64,7 +64,7 @@ def create_autograd_function(autodiff_obj, use_cuda):
kwargs[field.name] = torch.zeros( kwargs[field.name] = torch.zeros(
field.shape, field.shape,
dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype), dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
device=chain(args[0], kwargs.values()).device) device=next(chain(args, kwargs.values())).device)
output_tensors = OrderedDict({f.name: output_tensors = OrderedDict({f.name:
field_to_tensor_dict.get(f, kwargs[f.name]) field_to_tensor_dict.get(f, kwargs[f.name])
for f in autodiff_obj.forward_output_fields}) for f in autodiff_obj.forward_output_fields})
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment