diff --git a/tests/frontend/test_simplifications.py b/tests/frontend/test_simplifications.py index 45cde724108fe7578d8ff2dc9b8a2509a9add728..771f82159630f3d96ce05298ed3e70ddea440b1b 100644 --- a/tests/frontend/test_simplifications.py +++ b/tests/frontend/test_simplifications.py @@ -147,6 +147,8 @@ def test_add_subexpressions_for_field_reads(): assert len(ac3.subexpressions) == 2 assert isinstance(ac3.subexpressions[0].lhs, TypedSymbol) assert ac3.subexpressions[0].lhs.dtype == create_type("float32") + assert isinstance(ac3.subexpressions[0].rhs, ps.tcast) + assert ac3.subexpressions[0].rhs.dtype == create_type("float32") # TODO: What does this test mean to accomplish? diff --git a/tests/runtime/test_boundary.py b/tests/runtime/test_boundary.py index 226510b83d8832a5a189552df5c8760235f0d598..422553bcafb0ca1278f70f63a725d6f1cba8f496 100644 --- a/tests/runtime/test_boundary.py +++ b/tests/runtime/test_boundary.py @@ -222,15 +222,17 @@ def test_boundary_data_setter(): assert np.all(data_setter.link_positions(1) == 6.) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) @pytest.mark.parametrize('with_indices', ('with_indices', False)) -def test_dirichlet(with_indices): +def test_dirichlet(dtype, with_indices): value = (1, 20, 3) if with_indices else 1 dh = SerialDataHandling(domain_size=(7, 7)) - src = dh.add_array('src', values_per_cell=3 if with_indices else 1) - dh.cpu_arrays.src[...] = np.random.rand(*src.shape) + src = dh.add_array('src', values_per_cell=3 if with_indices else 1, dtype=dtype) + rng = np.random.default_rng() + dh.cpu_arrays.src[...] = rng.random(src.shape, dtype=dtype) boundary_stencil = [(1, 0), (-1, 0), (0, 1), (0, -1)] - boundary_handling = BoundaryHandling(dh, src.name, boundary_stencil) + boundary_handling = BoundaryHandling(dh, src.name, boundary_stencil, default_dtype=dtype) dirichlet = Dirichlet(value) assert dirichlet.name == 'Dirichlet' dirichlet.name = "wall"