diff --git a/tests/test_tfmad.py b/tests/test_tfmad.py index 74a9aa8a049b0a5cf426c37b0bbda99bd9ad6474..ab53db949742dc8f376d9f8839de311937947d61 100644 --- a/tests/test_tfmad.py +++ b/tests/test_tfmad.py @@ -231,7 +231,7 @@ def test_tfmad_gradient_check_torch_native(with_offsets, with_cuda): [dict[f] for f in auto_diff.forward_input_fields]), atol=1e-4, raise_exception=True) -@pytest.mark.parametrize('with_cuda', (False, 'with_cuda')) +@pytest.mark.parametrize('with_cuda', (False, True)) def test_tfmad_gradient_check_two_outputs(with_cuda): torch = pytest.importorskip('torch') import torch @@ -274,9 +274,9 @@ def test_tfmad_gradient_check_two_outputs(with_cuda): dict = { a: a_tensor, b: b_tensor, - out1_tensor: out1_tensor, - out2_tensor: out2_tensor, - out3_tensor: out3_tensor, + out1: out1_tensor, + out2: out2_tensor, + out3: out3_tensor, } torch.autograd.gradcheck(function.apply, tuple( [dict[f] for f in auto_diff.forward_input_fields]), atol=1e-4, raise_exception=True)