Skip to content
Snippets Groups Projects
Commit 4c508c11 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Print shape of tensor in tfmad_checkgradient

parent bb88ec51
No related branches found
No related tags found
No related merge requests found
......@@ -59,6 +59,7 @@ def test_tfmad_two_stencils():
@pytest.mark.skipif("NO_TENSORFLOW_TEST" in os.environ, reason="Requires Tensorflow")
def test_tfmad_gradient_check():
a, b, out = ps.fields("a, b, out: double[21,13]")
print(a.shape)
cont = ps.fd.Diff(a, 0) - ps.fd.Diff(a, 1) - ps.fd.Diff(b, 0) + ps.fd.Diff(b, 1)
discretize = ps.fd.Discretization2ndOrder(dx=1)
......@@ -248,12 +249,12 @@ def test_tfmad_two_outputs():
def main():
test_tfmad_stencil()
test_tfmad_two_stencils()
test_tfmad_gradient_check_torch()
# test_tfmad_stencil()
# test_tfmad_two_stencils()
# test_tfmad_gradient_check_torch()
test_tfmad_gradient_check()
test_tfmad_vector_input_data()
test_tfmad_two_outputs()
# test_tfmad_vector_input_data()
# test_tfmad_two_outputs()
if __name__ == '__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment