Skip to content
Snippets Groups Projects
Commit 4dcb81ac authored by Frederik Hennig's avatar Frederik Hennig
Browse files

improve error strings for shape and stride checks. Add more test cases.

parent e8c8ea8e
No related branches found
No related tags found
1 merge request!445Object-Oriented CPU JIT API and Prototype Implementation
Pipeline #72926 passed
...@@ -96,16 +96,18 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase): ...@@ -96,16 +96,18 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
kernel_param = f"py::array_t< {elem_type.c_string()} > & {field.name}" kernel_param = f"py::array_t< {elem_type.c_string()} > & {field.name}"
self._public_params.append(kernel_param) self._public_params.append(kernel_param)
expect_shape = "(" + ", ".join((str(s) if isinstance(s, int) else "*") for s in field.shape) + ")"
for coord, size in enumerate(field.shape): for coord, size in enumerate(field.shape):
if isinstance(size, int): if isinstance(size, int):
self._param_check_lines.append( self._param_check_lines.append(
f"checkFieldShape(\"{field.name}\", {field.name}, {coord}, {size});" f"checkFieldShape(\"{field.name}\", \"{expect_shape}\", {field.name}, {coord}, {size});"
) )
expect_strides = "(" + ", ".join((str(s) if isinstance(s, int) else "*") for s in field.strides) + ")"
for coord, stride in enumerate(field.strides): for coord, stride in enumerate(field.strides):
if isinstance(stride, int): if isinstance(stride, int):
self._param_check_lines.append( self._param_check_lines.append(
f"checkFieldStride(\"{field.name}\", {field.name}, {coord}, {stride});" f"checkFieldStride(\"{field.name}\", \"{expect_strides}\", {field.name}, {coord}, {stride});"
) )
def _add_scalar_param(self, sc_param: Parameter): def _add_scalar_param(self, sc_param: Parameter):
......
...@@ -17,11 +17,27 @@ ${kernel_definition} ...@@ -17,11 +17,27 @@ ${kernel_definition}
} }
std::string tuple_to_str(const ssize_t * data, const size_t N){
std::stringstream acc;
acc << "(";
for(size_t i = 0; i < N; ++i){
acc << data[i];
if(i + 1 < N){
acc << ", ";
}
}
acc << ")";
return acc.str();
}
template< typename T > template< typename T >
void checkFieldShape(const std::string fieldName, const py::array_t< T > & arr, size_t coord, size_t desired) { void checkFieldShape(const std::string& fieldName, const std::string& expected, const py::array_t< T > & arr, size_t coord, size_t desired) {
auto panic = [&](){ auto panic = [&](){
std::stringstream err; std::stringstream err;
err << "Invalid shape of argument " << fieldName; err << "Invalid shape of argument " << fieldName
<< ". Expected " << expected
<< ", but got " << tuple_to_str(arr.shape(), arr.ndim())
<< ".";
throw py::value_error{ err.str() }; throw py::value_error{ err.str() };
}; };
...@@ -35,10 +51,13 @@ void checkFieldShape(const std::string fieldName, const py::array_t< T > & arr, ...@@ -35,10 +51,13 @@ void checkFieldShape(const std::string fieldName, const py::array_t< T > & arr,
} }
template< typename T > template< typename T >
void checkFieldStride(const std::string fieldName, const py::array_t< T > & arr, size_t coord, size_t desired) { void checkFieldStride(const std::string fieldName, const std::string& expected, const py::array_t< T > & arr, size_t coord, size_t desired) {
auto panic = [&](){ auto panic = [&](){
std::stringstream err; std::stringstream err;
err << "Invalid strides of argument " << fieldName; err << "Invalid strides of argument " << fieldName
<< ". Expected " << expected
<< ", but got " << tuple_to_str(arr.strides(), arr.ndim())
<< ".";
throw py::value_error{ err.str() }; throw py::value_error{ err.str() };
}; };
......
...@@ -8,7 +8,7 @@ from pystencils.jit import CpuJit ...@@ -8,7 +8,7 @@ from pystencils.jit import CpuJit
@pytest.fixture @pytest.fixture
def cpu_jit(tmp_path) -> CpuJit: def cpu_jit(tmp_path) -> CpuJit:
return CpuJit.create(objcache=".jit") return CpuJit.create(objcache=tmp_path)
def test_basic_cpu_kernel(cpu_jit): def test_basic_cpu_kernel(cpu_jit):
...@@ -68,3 +68,30 @@ def test_fixed_shape(cpu_jit): ...@@ -68,3 +68,30 @@ def test_fixed_shape(cpu_jit):
with pytest.raises(ValueError): with pytest.raises(ValueError):
kfunc(f=a, g=b) kfunc(f=a, g=b)
def test_fixed_index_shape(cpu_jit):
f, g = fields("f(3), g(2, 2): [2D]")
asm = Assignment(f.center(1), g.center(0, 0) + g.center(0, 1) + g.center(1, 0) + g.center(1, 1))
ker = create_kernel(asm)
kfunc = cpu_jit.compile(ker)
f_arr = np.zeros((12, 14, 3))
g_arr = np.zeros((12, 14, 2, 2))
kfunc(f=f_arr, g=g_arr)
with pytest.raises(ValueError):
f_arr = np.zeros((12, 14, 2))
g_arr = np.zeros((12, 14, 2, 2))
kfunc(f=f_arr, g=g_arr)
with pytest.raises(ValueError):
f_arr = np.zeros((12, 14, 3))
g_arr = np.zeros((12, 14, 4))
kfunc(f=f_arr, g=g_arr)
with pytest.raises(ValueError):
f_arr = np.zeros((12, 14, 3))
g_arr = np.zeros((12, 14, 1, 3))
kfunc(f=f_arr, g=g_arr)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment