Skip to content
Snippets Groups Projects
Commit ba697180 authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Rewire existing code extraction of fields to support reduction pointer extraction

parent f1c556e6
No related branches found
No related tags found
1 merge request!438Reduction Support
......@@ -199,9 +199,9 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
"""
def __init__(self) -> None:
self._array_buffers: dict[Field, str] = dict()
self._array_extractions: dict[Field, str] = dict()
self._array_frees: dict[Field, str] = dict()
self._array_buffers: dict[Any, str] = dict()
self._array_extractions: dict[Any, str] = dict()
self._array_frees: dict[Any, str] = dict()
self._array_assoc_var_extractions: dict[Parameter, str] = dict()
self._scalar_extractions: dict[Parameter, str] = dict()
......@@ -235,36 +235,37 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
else:
return None
def extract_field(self, field: Field) -> str:
def extract_buffer(self, buffer: Any, name: str, dtype: PsType) -> str:
"""Adds an array, and returns the name of the underlying Py_Buffer."""
if field not in self._array_extractions:
extraction_code = self.TMPL_EXTRACT_ARRAY.format(name=field.name)
if buffer not in self._array_extractions:
extraction_code = self.TMPL_EXTRACT_ARRAY.format(name=name)
# Check array type
type_char = self._type_char(field.dtype)
type_char = self._type_char(dtype)
if type_char is not None:
dtype_cond = f"buffer_{field.name}.format[0] == '{type_char}'"
dtype_cond = f"buffer_{name}.format[0] == '{type_char}'"
extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format(
cond=dtype_cond,
what="data type",
name=field.name,
expected=str(field.dtype),
name=name,
expected=str(dtype),
)
# Check item size
itemsize = field.dtype.itemsize
item_size_cond = f"buffer_{field.name}.itemsize == {itemsize}"
extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format(
cond=item_size_cond, what="itemsize", name=field.name, expected=itemsize
)
itemsize = dtype.itemsize
if itemsize is not None: # itemsize of pointer not known (TODO?)
item_size_cond = f"buffer_{name}.itemsize == {itemsize}"
extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format(
cond=item_size_cond, what="itemsize", name=name, expected=itemsize
)
self._array_buffers[field] = f"buffer_{field.name}"
self._array_extractions[field] = extraction_code
self._array_buffers[buffer] = f"buffer_{name}"
self._array_extractions[buffer] = extraction_code
release_code = f"PyBuffer_Release(&buffer_{field.name});"
self._array_frees[field] = release_code
release_code = f"PyBuffer_Release(&buffer_{name});"
self._array_frees[buffer] = release_code
return self._array_buffers[field]
return self._array_buffers[buffer]
def extract_scalar(self, param: Parameter) -> str:
if param not in self._scalar_extractions:
......@@ -280,14 +281,20 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
def extract_reduction_ptr(self, param: Parameter) -> str:
if param not in self._reduction_ptrs:
# TODO: implement
pass
ptr = param.reduction_pointer
buffer = self.extract_buffer(ptr, param.name, param.dtype)
code = f"{param.dtype.c_string()} {param.name} = ({param.dtype}) {buffer}.buf;"
assert code is not None
self._array_assoc_var_extractions[param] = code
return param.name
def extract_array_assoc_var(self, param: Parameter) -> str:
if param not in self._array_assoc_var_extractions:
field = param.fields[0]
buffer = self.extract_field(field)
buffer = self.extract_buffer(field, field.name, field.dtype)
code: str | None = None
for prop in param.properties:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment