diff --git a/pystencils/typing/leaf_typing.py b/pystencils/typing/leaf_typing.py index 6c30a6abfe3189970055a02e1a3d70cd037c2da2..9e7065b0a5960febbcc954f70040b9097031041e 100644 --- a/pystencils/typing/leaf_typing.py +++ b/pystencils/typing/leaf_typing.py @@ -6,6 +6,7 @@ import numpy as np import sympy as sp from sympy import Piecewise +from sympy.core.numbers import NegativeOne from sympy.core.relational import Relational from sympy.functions.elementary.piecewise import ExprCondPair from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction @@ -228,6 +229,15 @@ class TypeAdder: new_func = expr.func(*new_args) if new_args else expr return CastFunc(new_func, collated_type), collated_type elif isinstance(expr, (sp.Add, sp.Mul, sp.Abs, sp.Min, sp.Max, DivFunc, sp.UnevaluatedExpr)): + # Subtraction is realised a multiplication with -1 in SymPy. Thus we exclude the coefficient in this case + # and resolve the typing entirely with the expression itself + if isinstance(expr, sp.Mul): + c, e = expr.as_coeff_Mul() + if c == NegativeOne(): + args_types = self.figure_out_type(e) + new_args = [NegativeOne(), args_types[0]] + return expr.func(*new_args, evaluate=False), args_types[1] + args_types = [self.figure_out_type(arg) for arg in expr.args] collated_type = collate_types([t for _, t in args_types]) if isinstance(collated_type, PointerType): diff --git a/pystencils_tests/test_print_unsupported_node.py b/pystencils_tests/test_print_unsupported_node.py deleted file mode 100644 index 674a7b390512a4d7d12d1a733fbd759a460499d8..0000000000000000000000000000000000000000 --- a/pystencils_tests/test_print_unsupported_node.py +++ /dev/null @@ -1,24 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de> -# -# Distributed under terms of the GPLv3 license. - -""" - -""" -import pytest - -import pystencils -from pystencils.backends.cbackend import CBackend - - -class UnsupportedNode(pystencils.astnodes.Node): - - def __init__(self): - super().__init__() - - -def test_print_unsupported_node(): - with pytest.raises(NotImplementedError, match='CBackend does not support node of type UnsupportedNode'): - CBackend()(UnsupportedNode()) diff --git a/pystencils_tests/test_printing.py b/pystencils_tests/test_printing.py index c105cbb8c0ea1f6cbfa705423c6c23b2d5484290..6490f59e43c0c42eb64077b3c27dc99d250f04eb 100644 --- a/pystencils_tests/test_printing.py +++ b/pystencils_tests/test_printing.py @@ -37,15 +37,16 @@ def test_print_unsupported_node(): CBackend()(UnsupportedNode()) -def test_print_subtraction(): - a, b = sp.symbols("a b") - - x = pystencils.fields(f'x: double[3d]') - y = pystencils.fields(f'y: double[3d]') +@pytest.mark.parametrize('dtype', ('float32', 'float64')) +@pytest.mark.parametrize('target', (pystencils.Target.CPU, pystencils.Target.GPU)) +def test_print_subtraction(dtype, target): + a, b, c = sp.symbols("a b c") - config = pystencils.CreateKernelConfig(target=pystencils.Target.CPU) - update = pystencils.Assignment(x.center, y.center - b) + x = pystencils.fields(f'x: {dtype}[3d]') + y = pystencils.fields(f'y: {dtype}[3d]') + config = pystencils.CreateKernelConfig(target=target, data_type=dtype) + update = pystencils.Assignment(x.center, y.center - a * b ** 8 + b * -1 / 42.0 - 2 * c ** 4) ast = pystencils.create_kernel(update, config=config) code = pystencils.get_code_str(ast)