diff --git a/tests/test_tfmad.py b/tests/test_tfmad.py
index 46279dda99ecac0472ba21bec1f68785064c13d1..5534819c6672552b86abfb933f9f427f82f9692a 100644
--- a/tests/test_tfmad.py
+++ b/tests/test_tfmad.py
@@ -239,8 +239,8 @@ def test_tfmad_gradient_check_torch_native(with_offsets, with_cuda):
         [dict[f] for f in auto_diff.forward_input_fields]), atol=1e-4, raise_exception=True)
 
 
-# @pytest.mark.parametrize('with_offsets, with_cuda', itertools.product((False, True), repeat=2))
-@pytest.mark.parametrize('with_offsets, with_cuda, gradient_check', ((False, False, True),))
+# @pytest.mark.parametrize('with_offsets, with_cuda, gradient_check', ((True, False, False),))
+@pytest.mark.parametrize('with_offsets, with_cuda', itertools.product((False, True), repeat=3))
 def test_tfmad_gradient_check_tensorflow_native(with_offsets, with_cuda, gradient_check):
     pytest.importorskip('tensorflow')
     import tensorflow as tf