Skip to content
Snippets Groups Projects
Commit e8c8ea8e authored by Frederik Hennig's avatar Frederik Hennig
Browse files

add checks for constant shape and strides

parent e2face7c
No related branches found
No related tags found
1 merge request!445Object-Oriented CPU JIT API and Prototype Implementation
Pipeline #72920 passed
...@@ -183,10 +183,12 @@ class CpuJitKernelWrapper(KernelWrapper): ...@@ -183,10 +183,12 @@ class CpuJitKernelWrapper(KernelWrapper):
def __init__(self, kernel: Kernel, jit_module: ModuleType): def __init__(self, kernel: Kernel, jit_module: ModuleType):
super().__init__(kernel) super().__init__(kernel)
self._module = jit_module self._module = jit_module
self._wrapper_func = getattr(jit_module, kernel.function_name) self._check_params = getattr(jit_module, "check_params")
self._invoke = getattr(jit_module, "invoke")
def __call__(self, **kwargs) -> None: def __call__(self, **kwargs) -> None:
return self._wrapper_func(**kwargs) self._check_params(**kwargs)
return self._invoke(**kwargs)
class ExtensionModuleBuilderBase(ABC): class ExtensionModuleBuilderBase(ABC):
......
...@@ -37,12 +37,14 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase): ...@@ -37,12 +37,14 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
self._actual_field_types: dict[Field, PsType] self._actual_field_types: dict[Field, PsType]
self._param_binds: list[str] self._param_binds: list[str]
self._public_params: list[str] self._public_params: list[str]
self._param_check_lines: list[str]
self._extraction_lines: list[str] self._extraction_lines: list[str]
def __call__(self, kernel: Kernel, module_name: str) -> str: def __call__(self, kernel: Kernel, module_name: str) -> str:
self._actual_field_types = dict() self._actual_field_types = dict()
self._param_binds = [] self._param_binds = []
self._public_params = [] self._public_params = []
self._param_check_lines = []
self._extraction_lines = [] self._extraction_lines = []
self._handle_params(kernel.parameters) self._handle_params(kernel.parameters)
...@@ -61,6 +63,7 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase): ...@@ -61,6 +63,7 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
kernel_name=kernel.function_name, kernel_name=kernel.function_name,
param_binds=", ".join(self._param_binds), param_binds=", ".join(self._param_binds),
public_params=", ".join(self._public_params), public_params=", ".join(self._public_params),
param_check_lines=indent("\n".join(self._param_check_lines), prefix=" "),
extraction_lines=indent("\n".join(self._extraction_lines), prefix=" "), extraction_lines=indent("\n".join(self._extraction_lines), prefix=" "),
kernel_args=", ".join(kernel_args), kernel_args=", ".join(kernel_args),
kernel_definition=kernel_def, kernel_definition=kernel_def,
...@@ -93,6 +96,18 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase): ...@@ -93,6 +96,18 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
kernel_param = f"py::array_t< {elem_type.c_string()} > & {field.name}" kernel_param = f"py::array_t< {elem_type.c_string()} > & {field.name}"
self._public_params.append(kernel_param) self._public_params.append(kernel_param)
for coord, size in enumerate(field.shape):
if isinstance(size, int):
self._param_check_lines.append(
f"checkFieldShape(\"{field.name}\", {field.name}, {coord}, {size});"
)
for coord, stride in enumerate(field.strides):
if isinstance(stride, int):
self._param_check_lines.append(
f"checkFieldStride(\"{field.name}\", {field.name}, {coord}, {stride});"
)
def _add_scalar_param(self, sc_param: Parameter): def _add_scalar_param(self, sc_param: Parameter):
param_bind = f'py::arg("{sc_param.name}")' param_bind = f'py::arg("{sc_param.name}")'
if self._strict_scalar_types: if self._strict_scalar_types:
......
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
#include "pybind11/numpy.h" #include "pybind11/numpy.h"
#include <array>
#include <string>
#include <sstream>
${includes} ${includes}
namespace py = pybind11; namespace py = pybind11;
...@@ -13,11 +17,50 @@ ${kernel_definition} ...@@ -13,11 +17,50 @@ ${kernel_definition}
} }
void callwrapper_${kernel_name} (${public_params}) { template< typename T >
void checkFieldShape(const std::string fieldName, const py::array_t< T > & arr, size_t coord, size_t desired) {
auto panic = [&](){
std::stringstream err;
err << "Invalid shape of argument " << fieldName;
throw py::value_error{ err.str() };
};
if(arr.ndim() <= coord){
panic();
}
if(arr.shape(coord) != desired){
panic();
}
}
template< typename T >
void checkFieldStride(const std::string fieldName, const py::array_t< T > & arr, size_t coord, size_t desired) {
auto panic = [&](){
std::stringstream err;
err << "Invalid strides of argument " << fieldName;
throw py::value_error{ err.str() };
};
if(arr.ndim() <= coord){
panic();
}
if(arr.strides(coord) / sizeof(T) != desired){
panic();
}
}
void check_params_${kernel_name} (${public_params}) {
${param_check_lines}
}
void run_${kernel_name} (${public_params}) {
${extraction_lines} ${extraction_lines}
internal::${kernel_name}(${kernel_args}); internal::${kernel_name}(${kernel_args});
} }
PYBIND11_MODULE(${module_name}, m) { PYBIND11_MODULE(${module_name}, m) {
m.def("${kernel_name}", &callwrapper_${kernel_name}, py::kw_only(), ${param_binds}); m.def("check_params", &check_params_${kernel_name}, py::kw_only(), ${param_binds});
m.def("invoke", &run_${kernel_name}, py::kw_only(), ${param_binds});
} }
...@@ -2,13 +2,13 @@ import pytest ...@@ -2,13 +2,13 @@ import pytest
import sympy as sp import sympy as sp
import numpy as np import numpy as np
from pystencils import create_kernel, Assignment, fields from pystencils import create_kernel, Assignment, fields, Field
from pystencils.jit import CpuJit from pystencils.jit import CpuJit
@pytest.fixture @pytest.fixture
def cpu_jit(tmp_path) -> CpuJit: def cpu_jit(tmp_path) -> CpuJit:
return CpuJit.create(objcache=tmp_path) return CpuJit.create(objcache=".jit")
def test_basic_cpu_kernel(cpu_jit): def test_basic_cpu_kernel(cpu_jit):
...@@ -48,3 +48,23 @@ def test_argument_type_error(cpu_jit): ...@@ -48,3 +48,23 @@ def test_argument_type_error(cpu_jit):
# Wrong scalar types are OK, though # Wrong scalar types are OK, though
kfunc(f=arr_fp64, g=arr_fp64, c=np.float16(1.0)) kfunc(f=arr_fp64, g=arr_fp64, c=np.float16(1.0))
def test_fixed_shape(cpu_jit):
a = np.zeros((12, 23), dtype="float64")
b = np.zeros((13, 21), dtype="float64")
f = Field.create_from_numpy_array("f", a)
g = Field.create_from_numpy_array("g", a)
asm = Assignment(f.center(), 2.0 * g.center())
ker = create_kernel(asm)
kfunc = cpu_jit.compile(ker)
kfunc(f=a, g=a)
with pytest.raises(ValueError):
kfunc(f=b, g=a)
with pytest.raises(ValueError):
kfunc(f=a, g=b)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment