diff --git a/src/pystencils/jit/cpu/cpujit_pybind11.py b/src/pystencils/jit/cpu/cpujit_pybind11.py index b68ed9c29ef7899f8a731a95ebfb6d5cf5580e66..eff3a061f2f4fef4d03a56183553ebedd33d0c6d 100644 --- a/src/pystencils/jit/cpu/cpujit_pybind11.py +++ b/src/pystencils/jit/cpu/cpujit_pybind11.py @@ -96,16 +96,18 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase): kernel_param = f"py::array_t< {elem_type.c_string()} > & {field.name}" self._public_params.append(kernel_param) + expect_shape = "(" + ", ".join((str(s) if isinstance(s, int) else "*") for s in field.shape) + ")" for coord, size in enumerate(field.shape): if isinstance(size, int): self._param_check_lines.append( - f"checkFieldShape(\"{field.name}\", {field.name}, {coord}, {size});" + f"checkFieldShape(\"{field.name}\", \"{expect_shape}\", {field.name}, {coord}, {size});" ) + expect_strides = "(" + ", ".join((str(s) if isinstance(s, int) else "*") for s in field.strides) + ")" for coord, stride in enumerate(field.strides): if isinstance(stride, int): self._param_check_lines.append( - f"checkFieldStride(\"{field.name}\", {field.name}, {coord}, {stride});" + f"checkFieldStride(\"{field.name}\", \"{expect_strides}\", {field.name}, {coord}, {stride});" ) def _add_scalar_param(self, sc_param: Parameter): diff --git a/src/pystencils/jit/cpu/pybind11_kernel_module.tmpl.cpp b/src/pystencils/jit/cpu/pybind11_kernel_module.tmpl.cpp index acc9d05dea9f7df70f97e8ce3ce8c95fa745470b..ef945586f460298395642776ced79991ed3e626b 100644 --- a/src/pystencils/jit/cpu/pybind11_kernel_module.tmpl.cpp +++ b/src/pystencils/jit/cpu/pybind11_kernel_module.tmpl.cpp @@ -17,11 +17,27 @@ ${kernel_definition} } +std::string tuple_to_str(const ssize_t * data, const size_t N){ + std::stringstream acc; + acc << "("; + for(size_t i = 0; i < N; ++i){ + acc << data[i]; + if(i + 1 < N){ + acc << ", "; + } + } + acc << ")"; + return acc.str(); +} + template< typename T > -void checkFieldShape(const std::string fieldName, const py::array_t< T > & arr, size_t coord, size_t desired) { +void checkFieldShape(const std::string& fieldName, const std::string& expected, const py::array_t< T > & arr, size_t coord, size_t desired) { auto panic = [&](){ std::stringstream err; - err << "Invalid shape of argument " << fieldName; + err << "Invalid shape of argument " << fieldName + << ". Expected " << expected + << ", but got " << tuple_to_str(arr.shape(), arr.ndim()) + << "."; throw py::value_error{ err.str() }; }; @@ -35,10 +51,13 @@ void checkFieldShape(const std::string fieldName, const py::array_t< T > & arr, } template< typename T > -void checkFieldStride(const std::string fieldName, const py::array_t< T > & arr, size_t coord, size_t desired) { +void checkFieldStride(const std::string fieldName, const std::string& expected, const py::array_t< T > & arr, size_t coord, size_t desired) { auto panic = [&](){ std::stringstream err; - err << "Invalid strides of argument " << fieldName; + err << "Invalid strides of argument " << fieldName + << ". Expected " << expected + << ", but got " << tuple_to_str(arr.strides(), arr.ndim()) + << "."; throw py::value_error{ err.str() }; }; diff --git a/tests/jit/test_cpujit.py b/tests/jit/test_cpujit.py index c8e75f8625129891e12ea592d7efd8bbac02636b..bfa4c98975cc3865429e734cb6c2997ec9074622 100644 --- a/tests/jit/test_cpujit.py +++ b/tests/jit/test_cpujit.py @@ -8,7 +8,7 @@ from pystencils.jit import CpuJit @pytest.fixture def cpu_jit(tmp_path) -> CpuJit: - return CpuJit.create(objcache=".jit") + return CpuJit.create(objcache=tmp_path) def test_basic_cpu_kernel(cpu_jit): @@ -68,3 +68,30 @@ def test_fixed_shape(cpu_jit): with pytest.raises(ValueError): kfunc(f=a, g=b) + + +def test_fixed_index_shape(cpu_jit): + f, g = fields("f(3), g(2, 2): [2D]") + + asm = Assignment(f.center(1), g.center(0, 0) + g.center(0, 1) + g.center(1, 0) + g.center(1, 1)) + ker = create_kernel(asm) + kfunc = cpu_jit.compile(ker) + + f_arr = np.zeros((12, 14, 3)) + g_arr = np.zeros((12, 14, 2, 2)) + kfunc(f=f_arr, g=g_arr) + + with pytest.raises(ValueError): + f_arr = np.zeros((12, 14, 2)) + g_arr = np.zeros((12, 14, 2, 2)) + kfunc(f=f_arr, g=g_arr) + + with pytest.raises(ValueError): + f_arr = np.zeros((12, 14, 3)) + g_arr = np.zeros((12, 14, 4)) + kfunc(f=f_arr, g=g_arr) + + with pytest.raises(ValueError): + f_arr = np.zeros((12, 14, 3)) + g_arr = np.zeros((12, 14, 1, 3)) + kfunc(f=f_arr, g=g_arr)