Skip to content
Snippets Groups Projects
Commit 303171de authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Use cuda

parent 990ee0b0
Branches
Tags
No related merge requests found
......@@ -62,7 +62,8 @@ def create_autograd_function(autodiff_obj, use_cuda):
kwargs[field.name] = torch.zeros(
field.shape,
dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
device=next(chain(args, kwargs.values())).device)
device='cuda' if use_cuda else 'cpu') # use device of tensor
output_tensors = OrderedDict({f.name:
field_to_tensor_dict.get(f, kwargs[f.name])
for f in autodiff_obj.forward_output_fields})
......
......@@ -55,7 +55,7 @@ class Header(JinjaCppFile):
class TorchTensorDestructuring(DestructuringBindingsForFieldClass):
CLASS_TO_MEMBER_DICT = {
FieldPointerSymbol: "data_ptr<{dtype}>()",
FieldPointerSymbol: "data<{dtype}>()",
FieldShapeSymbol: "size({dim})",
FieldStrideSymbol: "strides()[{dim}]"
}
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment