diff --git a/src/pystencils_autodiff/autodiff.py b/src/pystencils_autodiff/autodiff.py index e6c8623c8d740c0c87a8523213675aef849cf7fd..a81929b9d906acdb359e34b48ab8eeefe0d3c90f 100644 --- a/src/pystencils_autodiff/autodiff.py +++ b/src/pystencils_autodiff/autodiff.py @@ -121,9 +121,9 @@ Backward: write_fields = {s.field for s in write_field_accesses} # for every field create a corresponding diff field - diff_read_fields = {f: f.new_field_with_different_name(diff_fields_prefix + f.name) + diff_read_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix) for f in read_fields} - diff_write_fields = {f: f.new_field_with_different_name(diff_fields_prefix + f.name) + diff_write_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix) for f in write_fields} # Translate field accesses from standard to diff fields diff --git a/tests/test_autodiff.py b/tests/test_autodiff.py index 971aac8f72449460930091fba44c70f11d45d9b6..3ddc3a791aab283a7d3fe582f888fbecf2bbf263 100644 --- a/tests/test_autodiff.py +++ b/tests/test_autodiff.py @@ -1,8 +1,8 @@ -import pytest import sympy as sp import pystencils as ps import pystencils_autodiff +from pystencils_autodiff.autodiff import DiffModes def test_simple_2d_check_assignment_collection(): @@ -20,18 +20,24 @@ def test_simple_2d_check_assignment_collection(): print(repr(jac)) assert repr(jac) == 'Matrix([[log(x_C*y_C) + 1, y_C/x_C]])' - pystencils_autodiff.create_backward_assignments( - forward_assignments) - pystencils_autodiff.create_backward_assignments( - pystencils_autodiff.create_backward_assignments(forward_assignments)) + for diff_mode in DiffModes: + pystencils_autodiff.create_backward_assignments( + forward_assignments, diff_mode=diff_mode) + pystencils_autodiff.create_backward_assignments( + pystencils_autodiff.create_backward_assignments(forward_assignments), diff_mode=diff_mode) + + result1 = pystencils_autodiff.create_backward_assignments( + forward_assignments, diff_mode=DiffModes.TRANSPOSED) + result2 = pystencils_autodiff.create_backward_assignments( + forward_assignments, diff_mode=DiffModes.TF_MAD) + assert result1 == result2 def test_simple_2d_check_raw_assignments(): # use simply example z, x, y = ps.fields("z, y, x: [2d]") - forward_assignments = \ - [ps.Assignment(z[0, 0], x[0, 0]*sp.log(x[0, 0]*y[0, 0]))] + forward_assignments = [ps.Assignment(z[0, 0], x[0, 0]*sp.log(x[0, 0]*y[0, 0]))] jac = pystencils_autodiff.get_jacobian_of_assignments( forward_assignments, [x[0, 0], y[0, 0]]) @@ -39,8 +45,9 @@ def test_simple_2d_check_raw_assignments(): assert jac.shape == (1, 2) assert repr(jac) == 'Matrix([[log(x_C*y_C) + 1, y_C/x_C]])' - pystencils_autodiff.create_backward_assignments( - forward_assignments) + for diff_mode in DiffModes: + pystencils_autodiff.create_backward_assignments( + forward_assignments, diff_mode=diff_mode) def main():