diff --git a/tests/jit/test_cpujit.py b/tests/jit/test_cpujit.py index 0e17fa6a14c66d104eea1e3a40366e8664040f59..3d269d837f5cab2a12cfeb8b3b0be141763f9490 100644 --- a/tests/jit/test_cpujit.py +++ b/tests/jit/test_cpujit.py @@ -1,16 +1,21 @@ +import pytest + import sympy as sp import numpy as np from pystencils import create_kernel, Assignment, fields from pystencils.jit import CpuJit -def test_basic_cpu_kernel(tmp_path): - jit = CpuJit.create(objcache=tmp_path) +@pytest.fixture +def cpu_jit(tmp_path) -> CpuJit: + return CpuJit.create(objcache=tmp_path) + +def test_basic_cpu_kernel(cpu_jit): f, g = fields("f, g: [2D]") asm = Assignment(f.center(), 2.0 * g.center()) ker = create_kernel(asm) - kfunc = jit.compile(ker) + kfunc = cpu_jit.compile(ker) rng = np.random.default_rng() f_arr = rng.random(size=(34, 26), dtype="float64") @@ -19,3 +24,27 @@ def test_basic_cpu_kernel(tmp_path): kfunc(f=f_arr, g=g_arr) np.testing.assert_almost_equal(g_arr, 2.0 * f_arr) + + +def test_argument_type_error(cpu_jit): + f, g = fields("f, g: [2D]") + c = sp.Symbol("c") + asm = Assignment(f.center(), c * g.center()) + ker = create_kernel(asm) + kfunc = cpu_jit.compile(ker) + + arr_fp16 = np.zeros((23, 12), dtype="float16") + arr_fp32 = np.zeros((23, 12), dtype="float32") + arr_fp64 = np.zeros((23, 12), dtype="float64") + + with pytest.raises(TypeError): + kfunc(f=arr_fp32, g=arr_fp64, c=2.0) + + with pytest.raises(TypeError): + kfunc(f=arr_fp64, g=arr_fp32, c=2.0) + + with pytest.raises(TypeError): + kfunc(f=arr_fp16, g=arr_fp16, c=2.0) + + # Wrong scalar types are OK, though + kfunc(f=arr_fp64, g=arr_fp64, c=np.float16(1.0))