From e2face7c9c9dbfd6981b99740c18526d38f86782 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Tue, 28 Jan 2025 13:13:33 +0100 Subject: [PATCH] add test for field types --- tests/jit/test_cpujit.py | 35 ++++++++++++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/tests/jit/test_cpujit.py b/tests/jit/test_cpujit.py index 0e17fa6a1..3d269d837 100644 --- a/tests/jit/test_cpujit.py +++ b/tests/jit/test_cpujit.py @@ -1,16 +1,21 @@ +import pytest + import sympy as sp import numpy as np from pystencils import create_kernel, Assignment, fields from pystencils.jit import CpuJit -def test_basic_cpu_kernel(tmp_path): - jit = CpuJit.create(objcache=tmp_path) +@pytest.fixture +def cpu_jit(tmp_path) -> CpuJit: + return CpuJit.create(objcache=tmp_path) + +def test_basic_cpu_kernel(cpu_jit): f, g = fields("f, g: [2D]") asm = Assignment(f.center(), 2.0 * g.center()) ker = create_kernel(asm) - kfunc = jit.compile(ker) + kfunc = cpu_jit.compile(ker) rng = np.random.default_rng() f_arr = rng.random(size=(34, 26), dtype="float64") @@ -19,3 +24,27 @@ def test_basic_cpu_kernel(tmp_path): kfunc(f=f_arr, g=g_arr) np.testing.assert_almost_equal(g_arr, 2.0 * f_arr) + + +def test_argument_type_error(cpu_jit): + f, g = fields("f, g: [2D]") + c = sp.Symbol("c") + asm = Assignment(f.center(), c * g.center()) + ker = create_kernel(asm) + kfunc = cpu_jit.compile(ker) + + arr_fp16 = np.zeros((23, 12), dtype="float16") + arr_fp32 = np.zeros((23, 12), dtype="float32") + arr_fp64 = np.zeros((23, 12), dtype="float64") + + with pytest.raises(TypeError): + kfunc(f=arr_fp32, g=arr_fp64, c=2.0) + + with pytest.raises(TypeError): + kfunc(f=arr_fp64, g=arr_fp32, c=2.0) + + with pytest.raises(TypeError): + kfunc(f=arr_fp16, g=arr_fp16, c=2.0) + + # Wrong scalar types are OK, though + kfunc(f=arr_fp64, g=arr_fp64, c=np.float16(1.0)) -- GitLab