Skip to content
Snippets Groups Projects
Commit 4c508c11 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Print shape of tensor in tfmad_checkgradient

parent bb88ec51
No related branches found
No related tags found
No related merge requests found
...@@ -59,6 +59,7 @@ def test_tfmad_two_stencils(): ...@@ -59,6 +59,7 @@ def test_tfmad_two_stencils():
@pytest.mark.skipif("NO_TENSORFLOW_TEST" in os.environ, reason="Requires Tensorflow") @pytest.mark.skipif("NO_TENSORFLOW_TEST" in os.environ, reason="Requires Tensorflow")
def test_tfmad_gradient_check(): def test_tfmad_gradient_check():
a, b, out = ps.fields("a, b, out: double[21,13]") a, b, out = ps.fields("a, b, out: double[21,13]")
print(a.shape)
cont = ps.fd.Diff(a, 0) - ps.fd.Diff(a, 1) - ps.fd.Diff(b, 0) + ps.fd.Diff(b, 1) cont = ps.fd.Diff(a, 0) - ps.fd.Diff(a, 1) - ps.fd.Diff(b, 0) + ps.fd.Diff(b, 1)
discretize = ps.fd.Discretization2ndOrder(dx=1) discretize = ps.fd.Discretization2ndOrder(dx=1)
...@@ -248,12 +249,12 @@ def test_tfmad_two_outputs(): ...@@ -248,12 +249,12 @@ def test_tfmad_two_outputs():
def main(): def main():
test_tfmad_stencil() # test_tfmad_stencil()
test_tfmad_two_stencils() # test_tfmad_two_stencils()
test_tfmad_gradient_check_torch() # test_tfmad_gradient_check_torch()
test_tfmad_gradient_check() test_tfmad_gradient_check()
test_tfmad_vector_input_data() # test_tfmad_vector_input_data()
test_tfmad_two_outputs() # test_tfmad_two_outputs()
if __name__ == '__main__': if __name__ == '__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment