Skip to content
Snippets Groups Projects
Commit e2face7c authored by Frederik Hennig's avatar Frederik Hennig
Browse files

add test for field types

parent fb3243dc
No related branches found
No related tags found
1 merge request!445Object-Oriented CPU JIT API and Prototype Implementation
Pipeline #72918 passed
import pytest
import sympy as sp import sympy as sp
import numpy as np import numpy as np
from pystencils import create_kernel, Assignment, fields from pystencils import create_kernel, Assignment, fields
from pystencils.jit import CpuJit from pystencils.jit import CpuJit
def test_basic_cpu_kernel(tmp_path): @pytest.fixture
jit = CpuJit.create(objcache=tmp_path) def cpu_jit(tmp_path) -> CpuJit:
return CpuJit.create(objcache=tmp_path)
def test_basic_cpu_kernel(cpu_jit):
f, g = fields("f, g: [2D]") f, g = fields("f, g: [2D]")
asm = Assignment(f.center(), 2.0 * g.center()) asm = Assignment(f.center(), 2.0 * g.center())
ker = create_kernel(asm) ker = create_kernel(asm)
kfunc = jit.compile(ker) kfunc = cpu_jit.compile(ker)
rng = np.random.default_rng() rng = np.random.default_rng()
f_arr = rng.random(size=(34, 26), dtype="float64") f_arr = rng.random(size=(34, 26), dtype="float64")
...@@ -19,3 +24,27 @@ def test_basic_cpu_kernel(tmp_path): ...@@ -19,3 +24,27 @@ def test_basic_cpu_kernel(tmp_path):
kfunc(f=f_arr, g=g_arr) kfunc(f=f_arr, g=g_arr)
np.testing.assert_almost_equal(g_arr, 2.0 * f_arr) np.testing.assert_almost_equal(g_arr, 2.0 * f_arr)
def test_argument_type_error(cpu_jit):
f, g = fields("f, g: [2D]")
c = sp.Symbol("c")
asm = Assignment(f.center(), c * g.center())
ker = create_kernel(asm)
kfunc = cpu_jit.compile(ker)
arr_fp16 = np.zeros((23, 12), dtype="float16")
arr_fp32 = np.zeros((23, 12), dtype="float32")
arr_fp64 = np.zeros((23, 12), dtype="float64")
with pytest.raises(TypeError):
kfunc(f=arr_fp32, g=arr_fp64, c=2.0)
with pytest.raises(TypeError):
kfunc(f=arr_fp64, g=arr_fp32, c=2.0)
with pytest.raises(TypeError):
kfunc(f=arr_fp16, g=arr_fp16, c=2.0)
# Wrong scalar types are OK, though
kfunc(f=arr_fp64, g=arr_fp64, c=np.float16(1.0))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment