diff --git a/tests/backends/test_torch_native_compilation.py b/tests/backends/test_torch_native_compilation.py index 64b30a737eec344ca4cc3aab08eeeebc650dbf02..c3916ff497d2e5ad7fce815bc38bcb357496796d 100644 --- a/tests/backends/test_torch_native_compilation.py +++ b/tests/backends/test_torch_native_compilation.py @@ -205,3 +205,36 @@ def test_execute_torch(target): z_torch = np.copy(z) assert np.allclose(z_torch[1:-1, 1:-1], z_pystencils[1:-1, 1:-1], atol=1e-6) + + +def test_reproducability(): + from sympy.core.cache import clear_cache + + output_0 = None + for i in range(10): + module_name = "Ololol" + + target = 'gpu' + + z, y, x = pystencils.fields("z, y, x: [20,40]") + a = sympy.Symbol('a') + + forward_assignments = pystencils.AssignmentCollection({ + z[0, 0]: x[0, 0] * sympy.log(a * x[0, 0] * y[0, 0]) + }) + + backward_assignments = create_backward_assignments(forward_assignments) + + forward_ast = pystencils.create_kernel(forward_assignments, target) + forward_ast.function_name = 'forward' + backward_ast = pystencils.create_kernel(backward_assignments, target) + backward_ast.function_name = 'backward' + new_output = str(TorchModule(module_name, [forward_ast, backward_ast])) + TorchModule(module_name, [forward_ast, backward_ast]).compile() + + clear_cache() + + if not output_0: + output_0 = new_output + + assert output_0 == new_output