Skip to content
Snippets Groups Projects
Commit b3b9c190 authored by Markus Holzer's avatar Markus Holzer Committed by Frederik Hennig
Browse files

[Fix] Printing of subtraction

parent 6c462e3d
Branches
Tags
No related merge requests found
...@@ -6,6 +6,7 @@ import numpy as np ...@@ -6,6 +6,7 @@ import numpy as np
import sympy as sp import sympy as sp
from sympy import Piecewise from sympy import Piecewise
from sympy.core.numbers import NegativeOne
from sympy.core.relational import Relational from sympy.core.relational import Relational
from sympy.functions.elementary.piecewise import ExprCondPair from sympy.functions.elementary.piecewise import ExprCondPair
from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction
...@@ -228,6 +229,15 @@ class TypeAdder: ...@@ -228,6 +229,15 @@ class TypeAdder:
new_func = expr.func(*new_args) if new_args else expr new_func = expr.func(*new_args) if new_args else expr
return CastFunc(new_func, collated_type), collated_type return CastFunc(new_func, collated_type), collated_type
elif isinstance(expr, (sp.Add, sp.Mul, sp.Abs, sp.Min, sp.Max, DivFunc, sp.UnevaluatedExpr)): elif isinstance(expr, (sp.Add, sp.Mul, sp.Abs, sp.Min, sp.Max, DivFunc, sp.UnevaluatedExpr)):
# Subtraction is realised a multiplication with -1 in SymPy. Thus we exclude the coefficient in this case
# and resolve the typing entirely with the expression itself
if isinstance(expr, sp.Mul):
c, e = expr.as_coeff_Mul()
if c == NegativeOne():
args_types = self.figure_out_type(e)
new_args = [NegativeOne(), args_types[0]]
return expr.func(*new_args, evaluate=False), args_types[1]
args_types = [self.figure_out_type(arg) for arg in expr.args] args_types = [self.figure_out_type(arg) for arg in expr.args]
collated_type = collate_types([t for _, t in args_types]) collated_type = collate_types([t for _, t in args_types])
if isinstance(collated_type, PointerType): if isinstance(collated_type, PointerType):
......
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import pytest
import pystencils
from pystencils.backends.cbackend import CBackend
class UnsupportedNode(pystencils.astnodes.Node):
def __init__(self):
super().__init__()
def test_print_unsupported_node():
with pytest.raises(NotImplementedError, match='CBackend does not support node of type UnsupportedNode'):
CBackend()(UnsupportedNode())
import pytest import pytest
import sympy as sp
import pystencils import pystencils
from sympy import oo from pystencils.backends.cbackend import CBackend
class UnsupportedNode(pystencils.astnodes.Node):
def __init__(self):
super().__init__()
@pytest.mark.parametrize('type', ('float32', 'float64', 'int64')) @pytest.mark.parametrize('type', ('float32', 'float64', 'int64'))
...@@ -12,9 +19,9 @@ def test_print_infinity(type, negative, target): ...@@ -12,9 +19,9 @@ def test_print_infinity(type, negative, target):
x = pystencils.fields(f'x: {type}[1d]') x = pystencils.fields(f'x: {type}[1d]')
if negative: if negative:
assignment = pystencils.Assignment(x.center, -oo) assignment = pystencils.Assignment(x.center, -sp.oo)
else: else:
assignment = pystencils.Assignment(x.center, oo) assignment = pystencils.Assignment(x.center, sp.oo)
ast = pystencils.create_kernel(assignment, data_type=type, target=target) ast = pystencils.create_kernel(assignment, data_type=type, target=target)
if target == pystencils.Target.GPU: if target == pystencils.Target.GPU:
...@@ -23,3 +30,24 @@ def test_print_infinity(type, negative, target): ...@@ -23,3 +30,24 @@ def test_print_infinity(type, negative, target):
ast.compile() ast.compile()
print(ast.compile().code) print(ast.compile().code)
def test_print_unsupported_node():
with pytest.raises(NotImplementedError, match='CBackend does not support node of type UnsupportedNode'):
CBackend()(UnsupportedNode())
@pytest.mark.parametrize('dtype', ('float32', 'float64'))
@pytest.mark.parametrize('target', (pystencils.Target.CPU, pystencils.Target.GPU))
def test_print_subtraction(dtype, target):
a, b, c = sp.symbols("a b c")
x = pystencils.fields(f'x: {dtype}[3d]')
y = pystencils.fields(f'y: {dtype}[3d]')
config = pystencils.CreateKernelConfig(target=target, data_type=dtype)
update = pystencils.Assignment(x.center, y.center - a * b ** 8 + b * -1 / 42.0 - 2 * c ** 4)
ast = pystencils.create_kernel(update, config=config)
code = pystencils.get_code_str(ast)
assert "-1.0" not in code
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment