diff --git a/src/pystencils/jit/cpu/cpujit.py b/src/pystencils/jit/cpu/cpujit.py index f9c39529ae2519f2fbab45fc60d4c522b7b09ec2..b3a9e48aaf678189dbcbb514b5e61e3c115fbaca 100644 --- a/src/pystencils/jit/cpu/cpujit.py +++ b/src/pystencils/jit/cpu/cpujit.py @@ -183,10 +183,12 @@ class CpuJitKernelWrapper(KernelWrapper): def __init__(self, kernel: Kernel, jit_module: ModuleType): super().__init__(kernel) self._module = jit_module - self._wrapper_func = getattr(jit_module, kernel.function_name) + self._check_params = getattr(jit_module, "check_params") + self._invoke = getattr(jit_module, "invoke") def __call__(self, **kwargs) -> None: - return self._wrapper_func(**kwargs) + self._check_params(**kwargs) + return self._invoke(**kwargs) class ExtensionModuleBuilderBase(ABC): diff --git a/src/pystencils/jit/cpu/cpujit_pybind11.py b/src/pystencils/jit/cpu/cpujit_pybind11.py index aee2e9f996d3147fd06435aaac60bbcb728083a2..b68ed9c29ef7899f8a731a95ebfb6d5cf5580e66 100644 --- a/src/pystencils/jit/cpu/cpujit_pybind11.py +++ b/src/pystencils/jit/cpu/cpujit_pybind11.py @@ -37,12 +37,14 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase): self._actual_field_types: dict[Field, PsType] self._param_binds: list[str] self._public_params: list[str] + self._param_check_lines: list[str] self._extraction_lines: list[str] def __call__(self, kernel: Kernel, module_name: str) -> str: self._actual_field_types = dict() self._param_binds = [] self._public_params = [] + self._param_check_lines = [] self._extraction_lines = [] self._handle_params(kernel.parameters) @@ -61,6 +63,7 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase): kernel_name=kernel.function_name, param_binds=", ".join(self._param_binds), public_params=", ".join(self._public_params), + param_check_lines=indent("\n".join(self._param_check_lines), prefix=" "), extraction_lines=indent("\n".join(self._extraction_lines), prefix=" "), kernel_args=", ".join(kernel_args), kernel_definition=kernel_def, @@ -93,6 +96,18 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase): kernel_param = f"py::array_t< {elem_type.c_string()} > & {field.name}" self._public_params.append(kernel_param) + for coord, size in enumerate(field.shape): + if isinstance(size, int): + self._param_check_lines.append( + f"checkFieldShape(\"{field.name}\", {field.name}, {coord}, {size});" + ) + + for coord, stride in enumerate(field.strides): + if isinstance(stride, int): + self._param_check_lines.append( + f"checkFieldStride(\"{field.name}\", {field.name}, {coord}, {stride});" + ) + def _add_scalar_param(self, sc_param: Parameter): param_bind = f'py::arg("{sc_param.name}")' if self._strict_scalar_types: diff --git a/src/pystencils/jit/cpu/pybind11_kernel_module.tmpl.cpp b/src/pystencils/jit/cpu/pybind11_kernel_module.tmpl.cpp index 3ee5c6973a761a6acbabf7864df947de38e4289c..acc9d05dea9f7df70f97e8ce3ce8c95fa745470b 100644 --- a/src/pystencils/jit/cpu/pybind11_kernel_module.tmpl.cpp +++ b/src/pystencils/jit/cpu/pybind11_kernel_module.tmpl.cpp @@ -1,6 +1,10 @@ #include "pybind11/pybind11.h" #include "pybind11/numpy.h" +#include <array> +#include <string> +#include <sstream> + ${includes} namespace py = pybind11; @@ -13,11 +17,50 @@ ${kernel_definition} } -void callwrapper_${kernel_name} (${public_params}) { +template< typename T > +void checkFieldShape(const std::string fieldName, const py::array_t< T > & arr, size_t coord, size_t desired) { + auto panic = [&](){ + std::stringstream err; + err << "Invalid shape of argument " << fieldName; + throw py::value_error{ err.str() }; + }; + + if(arr.ndim() <= coord){ + panic(); + } + + if(arr.shape(coord) != desired){ + panic(); + } +} + +template< typename T > +void checkFieldStride(const std::string fieldName, const py::array_t< T > & arr, size_t coord, size_t desired) { + auto panic = [&](){ + std::stringstream err; + err << "Invalid strides of argument " << fieldName; + throw py::value_error{ err.str() }; + }; + + if(arr.ndim() <= coord){ + panic(); + } + + if(arr.strides(coord) / sizeof(T) != desired){ + panic(); + } +} + +void check_params_${kernel_name} (${public_params}) { +${param_check_lines} +} + +void run_${kernel_name} (${public_params}) { ${extraction_lines} internal::${kernel_name}(${kernel_args}); } PYBIND11_MODULE(${module_name}, m) { - m.def("${kernel_name}", &callwrapper_${kernel_name}, py::kw_only(), ${param_binds}); + m.def("check_params", &check_params_${kernel_name}, py::kw_only(), ${param_binds}); + m.def("invoke", &run_${kernel_name}, py::kw_only(), ${param_binds}); } diff --git a/tests/jit/test_cpujit.py b/tests/jit/test_cpujit.py index 3d269d837f5cab2a12cfeb8b3b0be141763f9490..c8e75f8625129891e12ea592d7efd8bbac02636b 100644 --- a/tests/jit/test_cpujit.py +++ b/tests/jit/test_cpujit.py @@ -2,13 +2,13 @@ import pytest import sympy as sp import numpy as np -from pystencils import create_kernel, Assignment, fields +from pystencils import create_kernel, Assignment, fields, Field from pystencils.jit import CpuJit @pytest.fixture def cpu_jit(tmp_path) -> CpuJit: - return CpuJit.create(objcache=tmp_path) + return CpuJit.create(objcache=".jit") def test_basic_cpu_kernel(cpu_jit): @@ -48,3 +48,23 @@ def test_argument_type_error(cpu_jit): # Wrong scalar types are OK, though kfunc(f=arr_fp64, g=arr_fp64, c=np.float16(1.0)) + + +def test_fixed_shape(cpu_jit): + a = np.zeros((12, 23), dtype="float64") + b = np.zeros((13, 21), dtype="float64") + + f = Field.create_from_numpy_array("f", a) + g = Field.create_from_numpy_array("g", a) + + asm = Assignment(f.center(), 2.0 * g.center()) + ker = create_kernel(asm) + kfunc = cpu_jit.compile(ker) + + kfunc(f=a, g=a) + + with pytest.raises(ValueError): + kfunc(f=b, g=a) + + with pytest.raises(ValueError): + kfunc(f=a, g=b)