Skip to content
Snippets Groups Projects

Object-Oriented CPU JIT API and Prototype Implementation

Merged Frederik Hennig requested to merge fhennig/pybind11-jit into v2.0-dev
Viewing commit 4dcb81ac
Show latest version
3 files
+ 55
7
Preferences
Compare changes
Files
3
@@ -96,16 +96,18 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
kernel_param = f"py::array_t< {elem_type.c_string()} > & {field.name}"
self._public_params.append(kernel_param)
expect_shape = "(" + ", ".join((str(s) if isinstance(s, int) else "*") for s in field.shape) + ")"
for coord, size in enumerate(field.shape):
if isinstance(size, int):
self._param_check_lines.append(
f"checkFieldShape(\"{field.name}\", {field.name}, {coord}, {size});"
f"checkFieldShape(\"{field.name}\", \"{expect_shape}\", {field.name}, {coord}, {size});"
)
expect_strides = "(" + ", ".join((str(s) if isinstance(s, int) else "*") for s in field.strides) + ")"
for coord, stride in enumerate(field.strides):
if isinstance(stride, int):
self._param_check_lines.append(
f"checkFieldStride(\"{field.name}\", {field.name}, {coord}, {stride});"
f"checkFieldStride(\"{field.name}\", \"{expect_strides}\", {field.name}, {coord}, {stride});"
)
def _add_scalar_param(self, sc_param: Parameter):