Skip to content
Snippets Groups Projects
Commit b3b9c190 authored by Markus Holzer's avatar Markus Holzer Committed by Frederik Hennig
Browse files

[Fix] Printing of subtraction

parent 6c462e3d
No related merge requests found
......@@ -6,6 +6,7 @@ import numpy as np
import sympy as sp
from sympy import Piecewise
from sympy.core.numbers import NegativeOne
from sympy.core.relational import Relational
from sympy.functions.elementary.piecewise import ExprCondPair
from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction
......@@ -228,6 +229,15 @@ class TypeAdder:
new_func = expr.func(*new_args) if new_args else expr
return CastFunc(new_func, collated_type), collated_type
elif isinstance(expr, (sp.Add, sp.Mul, sp.Abs, sp.Min, sp.Max, DivFunc, sp.UnevaluatedExpr)):
# Subtraction is realised a multiplication with -1 in SymPy. Thus we exclude the coefficient in this case
# and resolve the typing entirely with the expression itself
if isinstance(expr, sp.Mul):
c, e = expr.as_coeff_Mul()
if c == NegativeOne():
args_types = self.figure_out_type(e)
new_args = [NegativeOne(), args_types[0]]
return expr.func(*new_args, evaluate=False), args_types[1]
args_types = [self.figure_out_type(arg) for arg in expr.args]
collated_type = collate_types([t for _, t in args_types])
if isinstance(collated_type, PointerType):
......
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import pytest
import pystencils
from pystencils.backends.cbackend import CBackend
class UnsupportedNode(pystencils.astnodes.Node):
def __init__(self):
super().__init__()
def test_print_unsupported_node():
with pytest.raises(NotImplementedError, match='CBackend does not support node of type UnsupportedNode'):
CBackend()(UnsupportedNode())
import pytest
import sympy as sp
import pystencils
from sympy import oo
from pystencils.backends.cbackend import CBackend
class UnsupportedNode(pystencils.astnodes.Node):
def __init__(self):
super().__init__()
@pytest.mark.parametrize('type', ('float32', 'float64', 'int64'))
......@@ -12,9 +19,9 @@ def test_print_infinity(type, negative, target):
x = pystencils.fields(f'x: {type}[1d]')
if negative:
assignment = pystencils.Assignment(x.center, -oo)
assignment = pystencils.Assignment(x.center, -sp.oo)
else:
assignment = pystencils.Assignment(x.center, oo)
assignment = pystencils.Assignment(x.center, sp.oo)
ast = pystencils.create_kernel(assignment, data_type=type, target=target)
if target == pystencils.Target.GPU:
......@@ -23,3 +30,24 @@ def test_print_infinity(type, negative, target):
ast.compile()
print(ast.compile().code)
def test_print_unsupported_node():
with pytest.raises(NotImplementedError, match='CBackend does not support node of type UnsupportedNode'):
CBackend()(UnsupportedNode())
@pytest.mark.parametrize('dtype', ('float32', 'float64'))
@pytest.mark.parametrize('target', (pystencils.Target.CPU, pystencils.Target.GPU))
def test_print_subtraction(dtype, target):
a, b, c = sp.symbols("a b c")
x = pystencils.fields(f'x: {dtype}[3d]')
y = pystencils.fields(f'y: {dtype}[3d]')
config = pystencils.CreateKernelConfig(target=target, data_type=dtype)
update = pystencils.Assignment(x.center, y.center - a * b ** 8 + b * -1 / 42.0 - 2 * c ** 4)
ast = pystencils.create_kernel(update, config=config)
code = pystencils.get_code_str(ast)
assert "-1.0" not in code
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment