From 6d56ff4333f4b3e18ee438f6dff74122b917ac99 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Wed, 7 Aug 2019 20:23:07 +0200 Subject: [PATCH] Also test DiffModes.TRANSPOSED --- src/pystencils_autodiff/autodiff.py | 4 ++-- tests/test_autodiff.py | 25 ++++++++++++++++--------- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/src/pystencils_autodiff/autodiff.py b/src/pystencils_autodiff/autodiff.py index e6c8623..a81929b 100644 --- a/src/pystencils_autodiff/autodiff.py +++ b/src/pystencils_autodiff/autodiff.py @@ -121,9 +121,9 @@ Backward: write_fields = {s.field for s in write_field_accesses} # for every field create a corresponding diff field - diff_read_fields = {f: f.new_field_with_different_name(diff_fields_prefix + f.name) + diff_read_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix) for f in read_fields} - diff_write_fields = {f: f.new_field_with_different_name(diff_fields_prefix + f.name) + diff_write_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix) for f in write_fields} # Translate field accesses from standard to diff fields diff --git a/tests/test_autodiff.py b/tests/test_autodiff.py index 971aac8..3ddc3a7 100644 --- a/tests/test_autodiff.py +++ b/tests/test_autodiff.py @@ -1,8 +1,8 @@ -import pytest import sympy as sp import pystencils as ps import pystencils_autodiff +from pystencils_autodiff.autodiff import DiffModes def test_simple_2d_check_assignment_collection(): @@ -20,18 +20,24 @@ def test_simple_2d_check_assignment_collection(): print(repr(jac)) assert repr(jac) == 'Matrix([[log(x_C*y_C) + 1, y_C/x_C]])' - pystencils_autodiff.create_backward_assignments( - forward_assignments) - pystencils_autodiff.create_backward_assignments( - pystencils_autodiff.create_backward_assignments(forward_assignments)) + for diff_mode in DiffModes: + pystencils_autodiff.create_backward_assignments( + forward_assignments, diff_mode=diff_mode) + pystencils_autodiff.create_backward_assignments( + pystencils_autodiff.create_backward_assignments(forward_assignments), diff_mode=diff_mode) + + result1 = pystencils_autodiff.create_backward_assignments( + forward_assignments, diff_mode=DiffModes.TRANSPOSED) + result2 = pystencils_autodiff.create_backward_assignments( + forward_assignments, diff_mode=DiffModes.TF_MAD) + assert result1 == result2 def test_simple_2d_check_raw_assignments(): # use simply example z, x, y = ps.fields("z, y, x: [2d]") - forward_assignments = \ - [ps.Assignment(z[0, 0], x[0, 0]*sp.log(x[0, 0]*y[0, 0]))] + forward_assignments = [ps.Assignment(z[0, 0], x[0, 0]*sp.log(x[0, 0]*y[0, 0]))] jac = pystencils_autodiff.get_jacobian_of_assignments( forward_assignments, [x[0, 0], y[0, 0]]) @@ -39,8 +45,9 @@ def test_simple_2d_check_raw_assignments(): assert jac.shape == (1, 2) assert repr(jac) == 'Matrix([[log(x_C*y_C) + 1, y_C/x_C]])' - pystencils_autodiff.create_backward_assignments( - forward_assignments) + for diff_mode in DiffModes: + pystencils_autodiff.create_backward_assignments( + forward_assignments, diff_mode=diff_mode) def main(): -- GitLab