Skip to content
Snippets Groups Projects
Commit 4f6f5580 authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Rewire existing code extraction of fields to support reduction pointer extraction

parent 4e748308
No related branches found
No related tags found
1 merge request!438Reduction Support
...@@ -199,9 +199,9 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ ...@@ -199,9 +199,9 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
""" """
def __init__(self) -> None: def __init__(self) -> None:
self._array_buffers: dict[Field, str] = dict() self._array_buffers: dict[Any, str] = dict()
self._array_extractions: dict[Field, str] = dict() self._array_extractions: dict[Any, str] = dict()
self._array_frees: dict[Field, str] = dict() self._array_frees: dict[Any, str] = dict()
self._array_assoc_var_extractions: dict[Parameter, str] = dict() self._array_assoc_var_extractions: dict[Parameter, str] = dict()
self._scalar_extractions: dict[Parameter, str] = dict() self._scalar_extractions: dict[Parameter, str] = dict()
...@@ -235,36 +235,37 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ ...@@ -235,36 +235,37 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
else: else:
return None return None
def extract_field(self, field: Field) -> str: def extract_buffer(self, buffer: Any, name: str, dtype: PsType) -> str:
"""Adds an array, and returns the name of the underlying Py_Buffer.""" """Adds an array, and returns the name of the underlying Py_Buffer."""
if field not in self._array_extractions: if buffer not in self._array_extractions:
extraction_code = self.TMPL_EXTRACT_ARRAY.format(name=field.name) extraction_code = self.TMPL_EXTRACT_ARRAY.format(name=name)
# Check array type # Check array type
type_char = self._type_char(field.dtype) type_char = self._type_char(dtype)
if type_char is not None: if type_char is not None:
dtype_cond = f"buffer_{field.name}.format[0] == '{type_char}'" dtype_cond = f"buffer_{name}.format[0] == '{type_char}'"
extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format( extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format(
cond=dtype_cond, cond=dtype_cond,
what="data type", what="data type",
name=field.name, name=name,
expected=str(field.dtype), expected=str(dtype),
) )
# Check item size # Check item size
itemsize = field.dtype.itemsize itemsize = dtype.itemsize
item_size_cond = f"buffer_{field.name}.itemsize == {itemsize}" if itemsize is not None: # itemsize of pointer not known (TODO?)
extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format( item_size_cond = f"buffer_{name}.itemsize == {itemsize}"
cond=item_size_cond, what="itemsize", name=field.name, expected=itemsize extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format(
) cond=item_size_cond, what="itemsize", name=name, expected=itemsize
)
self._array_buffers[field] = f"buffer_{field.name}" self._array_buffers[buffer] = f"buffer_{name}"
self._array_extractions[field] = extraction_code self._array_extractions[buffer] = extraction_code
release_code = f"PyBuffer_Release(&buffer_{field.name});" release_code = f"PyBuffer_Release(&buffer_{name});"
self._array_frees[field] = release_code self._array_frees[buffer] = release_code
return self._array_buffers[field] return self._array_buffers[buffer]
def extract_scalar(self, param: Parameter) -> str: def extract_scalar(self, param: Parameter) -> str:
if param not in self._scalar_extractions: if param not in self._scalar_extractions:
...@@ -280,14 +281,20 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ ...@@ -280,14 +281,20 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
def extract_reduction_ptr(self, param: Parameter) -> str: def extract_reduction_ptr(self, param: Parameter) -> str:
if param not in self._reduction_ptrs: if param not in self._reduction_ptrs:
# TODO: implement ptr = param.reduction_pointer
pass buffer = self.extract_buffer(ptr, param.name, param.dtype)
code = f"{param.dtype.c_string()} {param.name} = ({param.dtype}) {buffer}.buf;"
assert code is not None
self._array_assoc_var_extractions[param] = code
return param.name return param.name
def extract_array_assoc_var(self, param: Parameter) -> str: def extract_array_assoc_var(self, param: Parameter) -> str:
if param not in self._array_assoc_var_extractions: if param not in self._array_assoc_var_extractions:
field = param.fields[0] field = param.fields[0]
buffer = self.extract_field(field) buffer = self.extract_buffer(field, field.name, field.dtype)
code: str | None = None code: str | None = None
for prop in param.properties: for prop in param.properties:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment