Skip to content
Snippets Groups Projects
Commit 61d1bae6 authored by Martin Bauer's avatar Martin Bauer
Browse files

Tests and documentation for derivative module

parent beec6d3e
Branches
Tags
No related merge requests found
......@@ -130,4 +130,4 @@ pages:
tags:
- docker
only:
- master@software/pystencils
- master@pycodegen/pystencils
This diff is collapsed.
This diff is collapsed.
......@@ -15,6 +15,6 @@ It is a good idea to download them and run them directly to be able to play arou
/notebooks/05_tutorial_phasefield_spinodal_decomposition.ipynb
/notebooks/06_tutorial_phasefield_dentritic_growth.ipynb
/notebooks/demo_assignment_collection.ipynb
/notebooks/demo_derivatives.ipynb
/notebooks/demo_benchmark.ipynb
/notebooks/demo_wave_equation.ipynb
import sympy as sp
from collections import namedtuple, defaultdict
from pystencils import Field
from pystencils.sympyextensions import normalize_product, prod
......@@ -214,6 +213,11 @@ def diff_terms(expr):
This function yields different results than 'expr.atoms(Diff)' when nested derivatives are in the expression,
since this function only returns the outer derivatives
Example:
>>> x, y = sp.symbols("x, y")
>>> diff_terms( diff(x, 0, 0) )
{Diff(Diff(x, 0, -1), 0, -1)}
"""
result = set()
......
......@@ -11,6 +11,7 @@ from .derivation import FiniteDifferenceStencilDerivation
def fd_stencils_standard(indices, dx, fa):
order = len(indices)
assert all(i >= 0 for i in indices), "Can only discretize objects with (integer) subscripts"
if order == 1:
idx = indices[0]
return (fa.neighbor(idx, 1) - fa.neighbor(idx, -1)) / (2 * dx)
......@@ -122,7 +123,6 @@ def discretize_spatial(expr, dx, stencil=fd_stencils_standard):
def discretize_spatial_staggered(expr, dx, stencil=fd_stencils_standard):
def staggered_visitor(e, coordinate, sign):
if isinstance(e, Diff):
arg, *indices = diff_args(e)
......
import sympy as sp
import pystencils as ps
from sympy.printing.latex import LatexPrinter
from pystencils.fd import *
from sympy.abc import a, b, t, x, y, z
def test_derivative_basic():
x, y, z, t = sp.symbols("x y z t")
d = diff
op1, op2, op3 = DiffOperator(), DiffOperator(target=x), DiffOperator(target=x, superscript=1)
......@@ -18,4 +19,31 @@ def test_derivative_basic():
assert diff_term == dx**2 + 2 * dx * dy + dy**2 + 1
assert DiffOperator.apply(diff_term, t) == d(t, x, x) + 2 * d(t, x, y) + d(t, y, y) + t
assert ps.fd.Diff(0) == 0
expr = ps.fd.diff(ps.fd.diff(x, 0, 0), 1, 1)
assert expr.get_arg_recursive() == x
assert expr.change_arg_recursive(y).get_arg_recursive() == y
def test_derivative_expand_collect():
original = Diff(x*y*z)
result = combine_diff_products(combine_diff_products(expand_diff_products(original))).expand()
assert original == result
original = -3 * y * z * Diff(x) + 2 * x * z * Diff(y)
result = expand_diff_products(combine_diff_products(original)).expand()
assert original == result
original = a + b * Diff(x ** 2 * y * z)
expanded = expand_diff_products(original)
collect_res = combine_diff_products(combine_diff_products(combine_diff_products(expanded)))
assert collect_res == original
def test_diff_expand_using_linearity():
eps = sp.symbols("epsilon")
funcs = [a, b]
test = Diff(eps * Diff(a+b))
result = expand_diff_linear(test, functions=funcs)
assert result == eps * Diff(Diff(a)) + eps * Diff(Diff(b))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment