diff --git a/pystencils_tests/test_print_infinity.py b/pystencils_tests/test_print_infinity.py deleted file mode 100644 index 62c83e68a8e87ab01a702608b3d7e53884511255..0000000000000000000000000000000000000000 --- a/pystencils_tests/test_print_infinity.py +++ /dev/null @@ -1,25 +0,0 @@ -import pytest - -import pystencils -from sympy import oo - - -@pytest.mark.parametrize('type', ('float32', 'float64', 'int64')) -@pytest.mark.parametrize('negative', (False, 'Negative')) -@pytest.mark.parametrize('target', (pystencils.Target.CPU, pystencils.Target.GPU)) -def test_print_infinity(type, negative, target): - - x = pystencils.fields(f'x: {type}[1d]') - - if negative: - assignment = pystencils.Assignment(x.center, -oo) - else: - assignment = pystencils.Assignment(x.center, oo) - ast = pystencils.create_kernel(assignment, data_type=type, target=target) - - if target == pystencils.Target.GPU: - pytest.importorskip('cupy') - - ast.compile() - - print(ast.compile().code) diff --git a/pystencils_tests/test_printing.py b/pystencils_tests/test_printing.py new file mode 100644 index 0000000000000000000000000000000000000000..c105cbb8c0ea1f6cbfa705423c6c23b2d5484290 --- /dev/null +++ b/pystencils_tests/test_printing.py @@ -0,0 +1,52 @@ +import pytest +import sympy as sp + +import pystencils +from pystencils.backends.cbackend import CBackend + + +class UnsupportedNode(pystencils.astnodes.Node): + + def __init__(self): + super().__init__() + + +@pytest.mark.parametrize('type', ('float32', 'float64', 'int64')) +@pytest.mark.parametrize('negative', (False, 'Negative')) +@pytest.mark.parametrize('target', (pystencils.Target.CPU, pystencils.Target.GPU)) +def test_print_infinity(type, negative, target): + + x = pystencils.fields(f'x: {type}[1d]') + + if negative: + assignment = pystencils.Assignment(x.center, -sp.oo) + else: + assignment = pystencils.Assignment(x.center, sp.oo) + ast = pystencils.create_kernel(assignment, data_type=type, target=target) + + if target == pystencils.Target.GPU: + pytest.importorskip('cupy') + + ast.compile() + + print(ast.compile().code) + + +def test_print_unsupported_node(): + with pytest.raises(NotImplementedError, match='CBackend does not support node of type UnsupportedNode'): + CBackend()(UnsupportedNode()) + + +def test_print_subtraction(): + a, b = sp.symbols("a b") + + x = pystencils.fields(f'x: double[3d]') + y = pystencils.fields(f'y: double[3d]') + + config = pystencils.CreateKernelConfig(target=pystencils.Target.CPU) + update = pystencils.Assignment(x.center, y.center - b) + + ast = pystencils.create_kernel(update, config=config) + + code = pystencils.get_code_str(ast) + assert "-1.0" not in code