Skip to content
Snippets Groups Projects
Commit da7adf35 authored by Martin Bauer's avatar Martin Bauer
Browse files

CPU backend now tests if the data type of array parameters is correct

parent 2e42e5ba
No related branches found
No related tags found
No related merge requests found
......@@ -248,7 +248,7 @@ template_extract_array = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
Py_buffer buffer_{name};
int buffer_{name}_res = PyObject_GetBuffer(obj_{name}, &buffer_{name}, PyBUF_STRIDES | PyBUF_WRITABLE);
int buffer_{name}_res = PyObject_GetBuffer(obj_{name}, &buffer_{name}, PyBUF_STRIDES | PyBUF_WRITABLE | PyBUF_FORMAT);
if (buffer_{name}_res == -1) {{ return NULL; }}
"""
......@@ -333,26 +333,38 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
post_call_code += template_release_buffer.format(name=field.name)
parameters.append("({dtype} *)buffer_{name}.buf".format(dtype=str(field.dtype), name=field.name))
if insert_checks and field.has_fixed_shape:
shape_cond = ["buffer_{name}.shape[{i}] == {s}".format(s=s, name=field.name, i=i)
for i, s in enumerate(field.spatial_shape)]
shape_cond = " && ".join(shape_cond)
pre_call_code += template_check_array.format(cond=shape_cond, what="shape", name=field.name,
expected=str(field.shape))
item_size = field.dtype.numpy_dtype.itemsize
expected_strides = [e * item_size for e in field.spatial_strides]
stride_check_code = "(buffer_{name}.strides[{i}] == {s} || buffer_{name}.shape[{i}]<=1)"
strides_cond = " && ".join([stride_check_code.format(s=s, i=i, name=field.name)
for i, s in enumerate(expected_strides)])
pre_call_code += template_check_array.format(cond=strides_cond, what="strides", name=field.name,
expected=str(expected_strides))
if insert_checks and not field.has_fixed_shape:
if FieldType.is_generic(field):
variable_sized_normal_fields.add(field)
elif FieldType.is_indexed(field):
variable_sized_index_fields.add(field)
if insert_checks:
np_dtype = field.dtype.numpy_dtype
item_size = np_dtype.itemsize
if np_dtype.isbuiltin and FieldType.is_generic(field):
dtype_cond = "buffer_{name}.format[0] == '{format}'".format(name=field.name,
format=field.dtype.numpy_dtype.char)
pre_call_code += template_check_array.format(cond=dtype_cond, what="data type", name=field.name,
expected=str(field.dtype.numpy_dtype))
item_size_cond = "buffer_{name}.itemsize == {size}".format(name=field.name, size=item_size)
pre_call_code += template_check_array.format(cond=item_size_cond, what="itemsize", name=field.name,
expected=item_size)
if field.has_fixed_shape:
shape_cond = ["buffer_{name}.shape[{i}] == {s}".format(s=s, name=field.name, i=i)
for i, s in enumerate(field.spatial_shape)]
shape_cond = " && ".join(shape_cond)
pre_call_code += template_check_array.format(cond=shape_cond, what="shape", name=field.name,
expected=str(field.shape))
expected_strides = [e * item_size for e in field.spatial_strides]
stride_check_code = "(buffer_{name}.strides[{i}] == {s} || buffer_{name}.shape[{i}]<=1)"
strides_cond = " && ".join([stride_check_code.format(s=s, i=i, name=field.name)
for i, s in enumerate(expected_strides)])
pre_call_code += template_check_array.format(cond=strides_cond, what="strides", name=field.name,
expected=str(expected_strides))
else:
if FieldType.is_generic(field):
variable_sized_normal_fields.add(field)
elif FieldType.is_indexed(field):
variable_sized_index_fields.add(field)
elif param.is_field_stride:
field = param.fields[0]
item_size = field.dtype.numpy_dtype.itemsize
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment