Skip to content
Snippets Groups Projects
Commit d8e498fa authored by Martin Bauer's avatar Martin Bauer
Browse files

Workaround for sympy bug in placeholder_function

parent 27a131fb
No related branches found
No related tags found
No related merge requests found
...@@ -2,6 +2,7 @@ import sympy as sp ...@@ -2,6 +2,7 @@ import sympy as sp
from typing import List from typing import List
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.astnodes import Node from pystencils.astnodes import Node
from pystencils.sympyextensions import is_constant
from pystencils.transformations import generic_visit from pystencils.transformations import generic_visit
...@@ -37,11 +38,11 @@ def to_placeholder_function(expr, name): ...@@ -37,11 +38,11 @@ def to_placeholder_function(expr, name):
assignments = [Assignment(sp.Symbol(name), expr)] assignments = [Assignment(sp.Symbol(name), expr)]
assignments += [Assignment(symbol, derivative) assignments += [Assignment(symbol, derivative)
for symbol, derivative in zip(derivative_symbols, derivatives) for symbol, derivative in zip(derivative_symbols, derivatives)
if not derivative.is_constant()] if not is_constant(derivative)]
def fdiff(_, index): def fdiff(_, index):
result = derivatives[index - 1] result = derivatives[index - 1]
return result if result.is_constant() else derivative_symbols[index - 1] return result if is_constant(result) else derivative_symbols[index - 1]
func = type(name, (sp.Function, PlaceholderFunction), func = type(name, (sp.Function, PlaceholderFunction),
{'fdiff': fdiff, {'fdiff': fdiff,
......
...@@ -172,6 +172,13 @@ def fast_subs(expression: T, substitutions: Dict, ...@@ -172,6 +172,13 @@ def fast_subs(expression: T, substitutions: Dict,
return visit(expression) return visit(expression)
def is_constant(expr):
"""Simple version of checking if a sympy expression is constant.
Works also for piecewise defined functions - sympy's is_constant() has a problem there, see:
https://github.com/sympy/sympy/issues/16662
"""
return len(expr.free_symbols) == 0
def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr, def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
required_match_replacement: Optional[Union[int, float]] = 0.5, required_match_replacement: Optional[Union[int, float]] = 0.5,
required_match_original: Optional[Union[int, float]] = None) -> sp.Expr: required_match_original: Optional[Union[int, float]] = None) -> sp.Expr:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment