From a97975445120e6aa1213fec405de33d01ab8bb01 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Mon, 16 Dec 2019 12:02:49 +0100
Subject: [PATCH] Make test_tfmad_gradient_check_torch_native also pass with
 CUDA

with_cuda needed to be True, not 'with_cuda'.
---
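Note (patch comment, ignored by `git am`): a minimal, self-contained sketch of
the boolean parametrization this patch switches to. The test name, tensor
setup, and skip logic below are illustrative assumptions and not code from
tests/test_tfmad.py; the point is only that with_cuda is now a plain bool
(the old value, the non-empty string 'with_cuda', is truthy but not True).

    import pytest


    @pytest.mark.parametrize('with_cuda', (False, True))
    def test_boolean_cuda_parametrization(with_cuda):
        torch = pytest.importorskip('torch')
        # Skip the CUDA variant on machines without a GPU.
        if with_cuda and not torch.cuda.is_available():
            pytest.skip('CUDA device not available')
        device = 'cuda' if with_cuda else 'cpu'
        # gradcheck needs double precision and inputs with requires_grad=True.
        x = torch.ones(4, dtype=torch.float64, device=device, requires_grad=True)
        assert torch.autograd.gradcheck(lambda t: (t * t).sum(), (x,), atol=1e-4)
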
 tests/test_tfmad.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/test_tfmad.py b/tests/test_tfmad.py
index 74a9aa8..ab53db9 100644
--- a/tests/test_tfmad.py
+++ b/tests/test_tfmad.py
@@ -231,7 +231,7 @@ def test_tfmad_gradient_check_torch_native(with_offsets, with_cuda):
         [dict[f] for f in auto_diff.forward_input_fields]), atol=1e-4, raise_exception=True)
 
 
-@pytest.mark.parametrize('with_cuda', (False, 'with_cuda'))
+@pytest.mark.parametrize('with_cuda', (False, True))
 def test_tfmad_gradient_check_two_outputs(with_cuda):
     torch = pytest.importorskip('torch')
     import torch
@@ -274,9 +274,9 @@ def test_tfmad_gradient_check_two_outputs(with_cuda):
     dict = {
         a: a_tensor,
         b: b_tensor,
-        out1_tensor: out1_tensor,
-        out2_tensor: out2_tensor,
-        out3_tensor: out3_tensor,
+        out1: out1_tensor,
+        out2: out2_tensor,
+        out3: out3_tensor,
     }
     torch.autograd.gradcheck(function.apply, tuple(
         [dict[f] for f in auto_diff.forward_input_fields]), atol=1e-4, raise_exception=True)
-- 
GitLab