Skip to content
Snippets Groups Projects

Reduction Support

Open Richard Angersbach requested to merge rangersbach/reductions into v2.0-dev
Viewing commit ba697180
Show latest version
1 file
+ 30
23
Preferences
Compare changes
@@ -199,9 +199,9 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
"""
def __init__(self) -> None:
self._array_buffers: dict[Field, str] = dict()
self._array_extractions: dict[Field, str] = dict()
self._array_frees: dict[Field, str] = dict()
self._array_buffers: dict[Any, str] = dict()
self._array_extractions: dict[Any, str] = dict()
self._array_frees: dict[Any, str] = dict()
self._array_assoc_var_extractions: dict[Parameter, str] = dict()
self._scalar_extractions: dict[Parameter, str] = dict()
@@ -235,36 +235,37 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
else:
return None
def extract_field(self, field: Field) -> str:
def extract_buffer(self, buffer: Any, name: str, dtype: PsType) -> str:
"""Adds an array, and returns the name of the underlying Py_Buffer."""
if field not in self._array_extractions:
extraction_code = self.TMPL_EXTRACT_ARRAY.format(name=field.name)
if buffer not in self._array_extractions:
extraction_code = self.TMPL_EXTRACT_ARRAY.format(name=name)
# Check array type
type_char = self._type_char(field.dtype)
type_char = self._type_char(dtype)
if type_char is not None:
dtype_cond = f"buffer_{field.name}.format[0] == '{type_char}'"
dtype_cond = f"buffer_{name}.format[0] == '{type_char}'"
extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format(
cond=dtype_cond,
what="data type",
name=field.name,
expected=str(field.dtype),
name=name,
expected=str(dtype),
)
# Check item size
itemsize = field.dtype.itemsize
item_size_cond = f"buffer_{field.name}.itemsize == {itemsize}"
extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format(
cond=item_size_cond, what="itemsize", name=field.name, expected=itemsize
)
itemsize = dtype.itemsize
if itemsize is not None: # itemsize of pointer not known (TODO?)
item_size_cond = f"buffer_{name}.itemsize == {itemsize}"
extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format(
cond=item_size_cond, what="itemsize", name=name, expected=itemsize
)
self._array_buffers[field] = f"buffer_{field.name}"
self._array_extractions[field] = extraction_code
self._array_buffers[buffer] = f"buffer_{name}"
self._array_extractions[buffer] = extraction_code
release_code = f"PyBuffer_Release(&buffer_{field.name});"
self._array_frees[field] = release_code
release_code = f"PyBuffer_Release(&buffer_{name});"
self._array_frees[buffer] = release_code
return self._array_buffers[field]
return self._array_buffers[buffer]
def extract_scalar(self, param: Parameter) -> str:
if param not in self._scalar_extractions:
@@ -280,14 +281,20 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
def extract_reduction_ptr(self, param: Parameter) -> str:
if param not in self._reduction_ptrs:
# TODO: implement
pass
ptr = param.reduction_pointer
buffer = self.extract_buffer(ptr, param.name, param.dtype)
code = f"{param.dtype.c_string()} {param.name} = ({param.dtype}) {buffer}.buf;"
assert code is not None
self._array_assoc_var_extractions[param] = code
return param.name
def extract_array_assoc_var(self, param: Parameter) -> str:
if param not in self._array_assoc_var_extractions:
field = param.fields[0]
buffer = self.extract_field(field)
buffer = self.extract_buffer(field, field.name, field.dtype)
code: str | None = None
for prop in param.properties: