Skip to content
Snippets Groups Projects
Commit 7cbdd430 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add test_reproducability

parent 1a1d6d59
Branches
Tags
No related merge requests found
Pipeline #18103 passed
...@@ -205,3 +205,36 @@ def test_execute_torch(target): ...@@ -205,3 +205,36 @@ def test_execute_torch(target):
z_torch = np.copy(z) z_torch = np.copy(z)
assert np.allclose(z_torch[1:-1, 1:-1], z_pystencils[1:-1, 1:-1], atol=1e-6) assert np.allclose(z_torch[1:-1, 1:-1], z_pystencils[1:-1, 1:-1], atol=1e-6)
def test_reproducability():
from sympy.core.cache import clear_cache
output_0 = None
for i in range(10):
module_name = "Ololol"
target = 'gpu'
z, y, x = pystencils.fields("z, y, x: [20,40]")
a = sympy.Symbol('a')
forward_assignments = pystencils.AssignmentCollection({
z[0, 0]: x[0, 0] * sympy.log(a * x[0, 0] * y[0, 0])
})
backward_assignments = create_backward_assignments(forward_assignments)
forward_ast = pystencils.create_kernel(forward_assignments, target)
forward_ast.function_name = 'forward'
backward_ast = pystencils.create_kernel(backward_assignments, target)
backward_ast.function_name = 'backward'
new_output = str(TorchModule(module_name, [forward_ast, backward_ast]))
TorchModule(module_name, [forward_ast, backward_ast]).compile()
clear_cache()
if not output_0:
output_0 = new_output
assert output_0 == new_output
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment