Skip to content
Snippets Groups Projects
Commit 6d56ff43 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Also test DiffModes.TRANSPOSED

parent 80b6ae21
No related branches found
No related tags found
No related merge requests found
Pipeline #17023 failed
...@@ -121,9 +121,9 @@ Backward: ...@@ -121,9 +121,9 @@ Backward:
write_fields = {s.field for s in write_field_accesses} write_fields = {s.field for s in write_field_accesses}
# for every field create a corresponding diff field # for every field create a corresponding diff field
diff_read_fields = {f: f.new_field_with_different_name(diff_fields_prefix + f.name) diff_read_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix)
for f in read_fields} for f in read_fields}
diff_write_fields = {f: f.new_field_with_different_name(diff_fields_prefix + f.name) diff_write_fields = {f: pystencils_autodiff.AdjointField(f, diff_fields_prefix)
for f in write_fields} for f in write_fields}
# Translate field accesses from standard to diff fields # Translate field accesses from standard to diff fields
......
import pytest
import sympy as sp import sympy as sp
import pystencils as ps import pystencils as ps
import pystencils_autodiff import pystencils_autodiff
from pystencils_autodiff.autodiff import DiffModes
def test_simple_2d_check_assignment_collection(): def test_simple_2d_check_assignment_collection():
...@@ -20,18 +20,24 @@ def test_simple_2d_check_assignment_collection(): ...@@ -20,18 +20,24 @@ def test_simple_2d_check_assignment_collection():
print(repr(jac)) print(repr(jac))
assert repr(jac) == 'Matrix([[log(x_C*y_C) + 1, y_C/x_C]])' assert repr(jac) == 'Matrix([[log(x_C*y_C) + 1, y_C/x_C]])'
pystencils_autodiff.create_backward_assignments( for diff_mode in DiffModes:
forward_assignments) pystencils_autodiff.create_backward_assignments(
pystencils_autodiff.create_backward_assignments( forward_assignments, diff_mode=diff_mode)
pystencils_autodiff.create_backward_assignments(forward_assignments)) pystencils_autodiff.create_backward_assignments(
pystencils_autodiff.create_backward_assignments(forward_assignments), diff_mode=diff_mode)
result1 = pystencils_autodiff.create_backward_assignments(
forward_assignments, diff_mode=DiffModes.TRANSPOSED)
result2 = pystencils_autodiff.create_backward_assignments(
forward_assignments, diff_mode=DiffModes.TF_MAD)
assert result1 == result2
def test_simple_2d_check_raw_assignments(): def test_simple_2d_check_raw_assignments():
# use simply example # use simply example
z, x, y = ps.fields("z, y, x: [2d]") z, x, y = ps.fields("z, y, x: [2d]")
forward_assignments = \ forward_assignments = [ps.Assignment(z[0, 0], x[0, 0]*sp.log(x[0, 0]*y[0, 0]))]
[ps.Assignment(z[0, 0], x[0, 0]*sp.log(x[0, 0]*y[0, 0]))]
jac = pystencils_autodiff.get_jacobian_of_assignments( jac = pystencils_autodiff.get_jacobian_of_assignments(
forward_assignments, [x[0, 0], y[0, 0]]) forward_assignments, [x[0, 0], y[0, 0]])
...@@ -39,8 +45,9 @@ def test_simple_2d_check_raw_assignments(): ...@@ -39,8 +45,9 @@ def test_simple_2d_check_raw_assignments():
assert jac.shape == (1, 2) assert jac.shape == (1, 2)
assert repr(jac) == 'Matrix([[log(x_C*y_C) + 1, y_C/x_C]])' assert repr(jac) == 'Matrix([[log(x_C*y_C) + 1, y_C/x_C]])'
pystencils_autodiff.create_backward_assignments( for diff_mode in DiffModes:
forward_assignments) pystencils_autodiff.create_backward_assignments(
forward_assignments, diff_mode=diff_mode)
def main(): def main():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment