From 6d56ff4333f4b3e18ee438f6dff74122b917ac99 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Wed, 7 Aug 2019 20:23:07 +0200
Subject: [PATCH] Also test DiffModes.TRANSPOSED

---
 src/pystencils_autodiff/autodiff.py |  4 ++--
 tests/test_autodiff.py              | 25 ++++++++++++++++---------
 2 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/src/pystencils_autodiff/autodiff.py b/src/pystencils_autodiff/autodiff.py
index e6c8623..a81929b 100644
--- a/src/pystencils_autodiff/autodiff.py
+++ b/src/pystencils_autodiff/autodiff.py
@@ -121,9 +121,9 @@ Backward:
         write_fields = {s.field for s in write_field_accesses}
 
         # for every field create a corresponding diff field
-        diff_read_fields = {f: f.new_field_with_different_name(diff_fields_prefix + f.name)
+        diff_read_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix)
                             for f in read_fields}
-        diff_write_fields = {f: f.new_field_with_different_name(diff_fields_prefix + f.name)
+        diff_write_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix)
                              for f in write_fields}
 
         # Translate field accesses from standard to diff fields
diff --git a/tests/test_autodiff.py b/tests/test_autodiff.py
index 971aac8..3ddc3a7 100644
--- a/tests/test_autodiff.py
+++ b/tests/test_autodiff.py
@@ -1,8 +1,8 @@
-import pytest
 import sympy as sp
 
 import pystencils as ps
 import pystencils_autodiff
+from pystencils_autodiff.autodiff import DiffModes
 
 
 def test_simple_2d_check_assignment_collection():
@@ -20,18 +20,24 @@ def test_simple_2d_check_assignment_collection():
     print(repr(jac))
     assert repr(jac) == 'Matrix([[log(x_C*y_C) + 1, y_C/x_C]])'
 
-    pystencils_autodiff.create_backward_assignments(
-        forward_assignments)
-    pystencils_autodiff.create_backward_assignments(
-        pystencils_autodiff.create_backward_assignments(forward_assignments))
+    for diff_mode in DiffModes:
+        pystencils_autodiff.create_backward_assignments(
+            forward_assignments, diff_mode=diff_mode)
+        pystencils_autodiff.create_backward_assignments(
+            pystencils_autodiff.create_backward_assignments(forward_assignments), diff_mode=diff_mode)
+
+    result1 = pystencils_autodiff.create_backward_assignments(
+        forward_assignments, diff_mode=DiffModes.TRANSPOSED)
+    result2 = pystencils_autodiff.create_backward_assignments(
+        forward_assignments, diff_mode=DiffModes.TF_MAD)
+    assert result1 == result2
 
 
 def test_simple_2d_check_raw_assignments():
     # use simply example
     z, x, y = ps.fields("z, y, x: [2d]")
 
-    forward_assignments = \
-        [ps.Assignment(z[0, 0], x[0, 0]*sp.log(x[0, 0]*y[0, 0]))]
+    forward_assignments = [ps.Assignment(z[0, 0], x[0, 0]*sp.log(x[0, 0]*y[0, 0]))]
 
     jac = pystencils_autodiff.get_jacobian_of_assignments(
         forward_assignments, [x[0, 0], y[0, 0]])
@@ -39,8 +45,9 @@ def test_simple_2d_check_raw_assignments():
     assert jac.shape == (1, 2)
     assert repr(jac) == 'Matrix([[log(x_C*y_C) + 1, y_C/x_C]])'
 
-    pystencils_autodiff.create_backward_assignments(
-        forward_assignments)
+    for diff_mode in DiffModes:
+        pystencils_autodiff.create_backward_assignments(
+            forward_assignments, diff_mode=diff_mode)
 
 
 def main():
-- 
GitLab