Skip to content
Snippets Groups Projects
Commit 2a95e87e authored by Markus Holzer's avatar Markus Holzer
Browse files

Implement test case

parent 6c462e3d
No related branches found
No related tags found
1 merge request!352[Fix] Printing of subtraction
import pytest import pytest
import sympy as sp
import pystencils import pystencils
from sympy import oo from pystencils.backends.cbackend import CBackend
class UnsupportedNode(pystencils.astnodes.Node):
def __init__(self):
super().__init__()
@pytest.mark.parametrize('type', ('float32', 'float64', 'int64')) @pytest.mark.parametrize('type', ('float32', 'float64', 'int64'))
...@@ -12,9 +19,9 @@ def test_print_infinity(type, negative, target): ...@@ -12,9 +19,9 @@ def test_print_infinity(type, negative, target):
x = pystencils.fields(f'x: {type}[1d]') x = pystencils.fields(f'x: {type}[1d]')
if negative: if negative:
assignment = pystencils.Assignment(x.center, -oo) assignment = pystencils.Assignment(x.center, -sp.oo)
else: else:
assignment = pystencils.Assignment(x.center, oo) assignment = pystencils.Assignment(x.center, sp.oo)
ast = pystencils.create_kernel(assignment, data_type=type, target=target) ast = pystencils.create_kernel(assignment, data_type=type, target=target)
if target == pystencils.Target.GPU: if target == pystencils.Target.GPU:
...@@ -23,3 +30,23 @@ def test_print_infinity(type, negative, target): ...@@ -23,3 +30,23 @@ def test_print_infinity(type, negative, target):
ast.compile() ast.compile()
print(ast.compile().code) print(ast.compile().code)
def test_print_unsupported_node():
with pytest.raises(NotImplementedError, match='CBackend does not support node of type UnsupportedNode'):
CBackend()(UnsupportedNode())
def test_print_subtraction():
a, b = sp.symbols("a b")
x = pystencils.fields(f'x: double[3d]')
y = pystencils.fields(f'y: double[3d]')
config = pystencils.CreateKernelConfig(target=pystencils.Target.CPU)
update = pystencils.Assignment(x.center, y.center - b)
ast = pystencils.create_kernel(update, config=config)
code = pystencils.get_code_str(ast)
assert "-1.0" not in code
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment