Skip to content
Snippets Groups Projects

[Fix] Printing of subtraction

Merged Markus Holzer requested to merge holzer/pystencils:FixSubtraction into master
3 files
+ 63
24
Compare changes
  • Side-by-side
  • Inline
Files
3
@@ -6,6 +6,7 @@ import numpy as np
@@ -6,6 +6,7 @@ import numpy as np
import sympy as sp
import sympy as sp
from sympy import Piecewise
from sympy import Piecewise
 
from sympy.core.numbers import NegativeOne
from sympy.core.relational import Relational
from sympy.core.relational import Relational
from sympy.functions.elementary.piecewise import ExprCondPair
from sympy.functions.elementary.piecewise import ExprCondPair
from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction
from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction
@@ -228,6 +229,15 @@ class TypeAdder:
@@ -228,6 +229,15 @@ class TypeAdder:
new_func = expr.func(*new_args) if new_args else expr
new_func = expr.func(*new_args) if new_args else expr
return CastFunc(new_func, collated_type), collated_type
return CastFunc(new_func, collated_type), collated_type
elif isinstance(expr, (sp.Add, sp.Mul, sp.Abs, sp.Min, sp.Max, DivFunc, sp.UnevaluatedExpr)):
elif isinstance(expr, (sp.Add, sp.Mul, sp.Abs, sp.Min, sp.Max, DivFunc, sp.UnevaluatedExpr)):
 
# Subtraction is realised a multiplication with -1 in SymPy. Thus we exclude the coefficient in this case
 
# and resolve the typing entirely with the expression itself
 
if isinstance(expr, sp.Mul):
 
c, e = expr.as_coeff_Mul()
 
if c == NegativeOne():
 
args_types = self.figure_out_type(e)
 
new_args = [NegativeOne(), args_types[0]]
 
return expr.func(*new_args, evaluate=False), args_types[1]
 
args_types = [self.figure_out_type(arg) for arg in expr.args]
args_types = [self.figure_out_type(arg) for arg in expr.args]
collated_type = collate_types([t for _, t in args_types])
collated_type = collate_types([t for _, t in args_types])
if isinstance(collated_type, PointerType):
if isinstance(collated_type, PointerType):
Loading