diff --git a/tests/test_tfmad.py b/tests/test_tfmad.py index 46279dda99ecac0472ba21bec1f68785064c13d1..5534819c6672552b86abfb933f9f427f82f9692a 100644 --- a/tests/test_tfmad.py +++ b/tests/test_tfmad.py @@ -239,8 +239,8 @@ def test_tfmad_gradient_check_torch_native(with_offsets, with_cuda): [dict[f] for f in auto_diff.forward_input_fields]), atol=1e-4, raise_exception=True) -# @pytest.mark.parametrize('with_offsets, with_cuda', itertools.product((False, True), repeat=2)) -@pytest.mark.parametrize('with_offsets, with_cuda, gradient_check', ((False, False, True),)) +# @pytest.mark.parametrize('with_offsets, with_cuda, gradient_check', ((True, False, False),)) +@pytest.mark.parametrize('with_offsets, with_cuda', itertools.product((False, True), repeat=3)) def test_tfmad_gradient_check_tensorflow_native(with_offsets, with_cuda, gradient_check): pytest.importorskip('tensorflow') import tensorflow as tf