From 25ebe5829217b650262410473c166226aa6f8dc6 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Mon, 16 Dec 2019 12:19:02 +0100
Subject: [PATCH] Allow use_cuda to be truthy (instead of only True/False)

---
 src/pystencils_autodiff/backends/_torch_native.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/pystencils_autodiff/backends/_torch_native.py b/src/pystencils_autodiff/backends/_torch_native.py
index 75f79bf..9fae28c 100644
--- a/src/pystencils_autodiff/backends/_torch_native.py
+++ b/src/pystencils_autodiff/backends/_torch_native.py
@@ -74,9 +74,12 @@ def create_autograd_function(autodiff_obj, use_cuda):
         return tuple(output_tensors.values())
 
     def backward(self, *grad_outputs):
+        nonlocal use_cuda
         if use_cuda:
+            use_cuda = True
             grad_outputs = [a.contiguous().cuda() for a in grad_outputs]
         else:
+            use_cuda = False
             grad_outputs = [a.contiguous().cpu() for a in grad_outputs]
 
         grad_fields = [f for f in autodiff_obj.backward_input_fields if f not in autodiff_obj.forward_input_fields]
@@ -84,7 +87,6 @@ def create_autograd_function(autodiff_obj, use_cuda):
         assert all(f.shape == grad_outputs[i].shape for i, f in enumerate(grad_fields))
         assert all(f.strides == tuple(grad_outputs[i].stride(j) for j in range(grad_outputs[i].ndim))
                    for i, f in enumerate(grad_fields))
-        assert use_cuda in (True, False), "use_cuda needs to be True or False"
         assert all(a.is_cuda == use_cuda for a in grad_outputs), (
             "Some of the tensors where on the wrong device. "
             f"Op was compiled for CUDA: {str(use_cuda)}")
--
GitLab
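
For context, a minimal sketch of the pattern this commit introduces (not part of the patch or of the pystencils_autodiff API; the helper name make_backward is invented for illustration): a closure captures whatever truthy or falsy value was passed as use_cuda and coerces it to a strict bool inside backward, so the later comparison against tensor.is_cuda (which is always a plain bool) still holds.

import torch


def make_backward(use_cuda):
    """use_cuda may be any truthy/falsy value, e.g. a device string or None."""
    def backward(*grad_outputs):
        nonlocal use_cuda
        if use_cuda:
            use_cuda = True   # coerce truthy values (e.g. "cuda:0") to a strict True
            grad_outputs = [g.contiguous().cuda() for g in grad_outputs]
        else:
            use_cuda = False  # coerce falsy values (e.g. None, 0) to a strict False
            grad_outputs = [g.contiguous().cpu() for g in grad_outputs]
        # tensor.is_cuda is a plain bool, so this check only holds after the coercion above
        assert all(g.is_cuda == use_cuda for g in grad_outputs)
        return grad_outputs
    return backward


backward_cpu = make_backward(None)        # falsy value -> CPU path
(grad,) = backward_cpu(torch.ones(4))
assert grad.is_cuda is False
# backward_gpu = make_backward("cuda:0")  # any truthy value now selects the CUDA path (needs a GPU)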