-
Frederik Hennig authoredFrederik Hennig authored
cpu_extension_module.py 11.23 KiB
from __future__ import annotations
from typing import Any
from os import path
import hashlib
from itertools import chain
from textwrap import indent
import numpy as np
from ..exceptions import PsInternalCompilerError
from ..kernelfunction import (
KernelFunction,
KernelParameter,
)
from ..properties import FieldBasePtr, FieldShape, FieldStride
from ..constraints import KernelParamsConstraint
from ...types import (
PsType,
PsUnsignedIntegerType,
PsSignedIntegerType,
PsIeeeFloatType,
)
from ...types.quick import Fp, SInt, UInt
from ...field import Field
from ..emission import emit_code
class PsKernelExtensioNModule:
"""Replacement for `pystencils.cpu.cpujit.ExtensionModuleCode`.
Conforms to its interface for plug-in to `compile_and_load`.
"""
def __init__(
self, module_name: str = "generated", custom_backend: Any = None
) -> None:
self._module_name = module_name
if custom_backend is not None:
raise PsInternalCompilerError(
"The `custom_backend` parameter exists only for interface compatibility and cannot be set."
)
self._kernels: dict[str, KernelFunction] = dict()
self._code_string: str | None = None
self._code_hash: str | None = None
@property
def module_name(self) -> str:
return self._module_name
def add_function(self, kernel_function: KernelFunction, name: str | None = None):
if name is None:
name = kernel_function.name
self._kernels[name] = kernel_function
def create_code_string(self, restrict_qualifier: str, function_prefix: str):
code = ""
# Collect headers
headers = {"<stdint.h>"}
for kernel in self._kernels.values():
headers |= kernel.required_headers
header_list = sorted(headers)
header_list.insert(0, '"Python.h"')
from pystencils.include import get_pystencils_include_path
ps_incl_path = get_pystencils_include_path()
ps_headers = []
for header in header_list:
header = header[1:-1]
header_path = path.join(ps_incl_path, header)
if path.exists(header_path):
ps_headers.append(header_path)
header_hash = b"".join(
[hashlib.sha256(open(h, "rb").read()).digest() for h in ps_headers]
)
# Prelude: Includes and definitions
includes = "\n".join(f"#include {header}" for header in header_list)
code += includes
code += "\n"
code += f"#define RESTRICT {restrict_qualifier}\n"
code += f"#define FUNC_PREFIX {function_prefix}\n"
code += "\n"
# Kernels and call wrappers
for name, kernel in self._kernels.items():
old_name = kernel.name
kernel.name = f"kernel_{name}"
code += emit_code(kernel)
code += "\n"
code += emit_call_wrapper(name, kernel)
code += "\n"
kernel.name = old_name
self._code_hash = (
"mod_" + hashlib.sha256(code.encode() + header_hash).hexdigest()
)
code += create_module_boilerplate_code(self._code_hash, self._kernels.keys())
self._code_string = code
def get_hash_of_code(self):
assert self._code_string is not None, "The code must be generated first"
return self._code_hash
def write_to_file(self, file):
assert self._code_string is not None, "The code must be generated first"
print(self._code_string, file=file)
def emit_call_wrapper(function_name: str, kernel: KernelFunction) -> str:
builder = CallWrapperBuilder()
for p in kernel.parameters:
builder.extract_parameter(p)
for c in kernel.constraints:
builder.check_constraint(c)
builder.call(kernel, kernel.parameters)
return builder.resolve(function_name)
template_module_boilerplate = """
static PyMethodDef method_definitions[] = {{
{method_definitions}
{{NULL, NULL, 0, NULL}}
}};
static struct PyModuleDef module_definition = {{
PyModuleDef_HEAD_INIT,
"{module_name}", /* name of module */
NULL, /* module documentation, may be NULL */
-1, /* size of per-interpreter state of the module,
or -1 if the module keeps state in global variables. */
method_definitions
}};
PyMODINIT_FUNC
PyInit_{module_name}(void)
{{
return PyModule_Create(&module_definition);
}}
"""
def create_module_boilerplate_code(module_name, names):
method_definition = (
'{{"{name}", (PyCFunction){name}, METH_VARARGS | METH_KEYWORDS, ""}},'
)
method_definitions = "\n".join(
[method_definition.format(name=name) for name in names]
)
return template_module_boilerplate.format(
module_name=module_name, method_definitions=method_definitions
)
class CallWrapperBuilder:
TMPL_EXTRACT_SCALAR = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
{target_type} {name} = ({target_type}) {extract_function}( obj_{name} );
if( PyErr_Occurred() ) {{ return NULL; }}
"""
TMPL_EXTRACT_ARRAY = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
Py_buffer buffer_{name};
int buffer_{name}_res = PyObject_GetBuffer(obj_{name}, &buffer_{name}, PyBUF_STRIDES | PyBUF_WRITABLE | PyBUF_FORMAT);
if (buffer_{name}_res == -1) {{ return NULL; }}
"""
TMPL_CHECK_ARRAY_TYPE = """
if(!({cond})) {{
PyErr_SetString(PyExc_TypeError, "Wrong {what} of array {name}. Expected {expected}");
return NULL;
}}
"""
KWCHECK = """
if( !kwargs || !PyDict_Check(kwargs) ) {{
PyErr_SetString(PyExc_TypeError, "No keyword arguments passed");
return NULL;
}}
"""
def __init__(self) -> None:
self._array_buffers: dict[Field, str] = dict()
self._array_extractions: dict[Field, str] = dict()
self._array_frees: dict[Field, str] = dict()
self._array_assoc_var_extractions: dict[KernelParameter, str] = dict()
self._scalar_extractions: dict[KernelParameter, str] = dict()
self._constraint_checks: list[str] = []
self._call: str | None = None
def _scalar_extractor(self, dtype: PsType) -> str:
match dtype:
case Fp(32) | Fp(64):
return "PyFloat_AsDouble"
case SInt():
return "PyLong_AsLong"
case UInt():
return "PyLong_AsUnsignedLong"
case _:
raise PsInternalCompilerError(
f"Don't know how to cast Python objects to {dtype}"
)
def _type_char(self, dtype: PsType) -> str | None:
if isinstance(
dtype, (PsUnsignedIntegerType, PsSignedIntegerType, PsIeeeFloatType)
):
np_dtype = dtype.NUMPY_TYPES[dtype.width]
return np.dtype(np_dtype).char
else:
return None
def extract_field(self, field: Field) -> str:
"""Adds an array, and returns the name of the underlying Py_Buffer."""
if field not in self._array_extractions:
extraction_code = self.TMPL_EXTRACT_ARRAY.format(name=field.name)
# Check array type
type_char = self._type_char(field.dtype)
if type_char is not None:
dtype_cond = f"buffer_{field.name}.format[0] == '{type_char}'"
extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format(
cond=dtype_cond,
what="data type",
name=field.name,
expected=str(field.dtype),
)
# Check item size
itemsize = field.dtype.itemsize
item_size_cond = f"buffer_{field.name}.itemsize == {itemsize}"
extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format(
cond=item_size_cond, what="itemsize", name=field.name, expected=itemsize
)
self._array_buffers[field] = f"buffer_{field.name}"
self._array_extractions[field] = extraction_code
release_code = f"PyBuffer_Release(&buffer_{field.name});"
self._array_frees[field] = release_code
return self._array_buffers[field]
def extract_scalar(self, param: KernelParameter) -> str:
if param not in self._scalar_extractions:
extract_func = self._scalar_extractor(param.dtype)
code = self.TMPL_EXTRACT_SCALAR.format(
name=param.name,
target_type=str(param.dtype),
extract_function=extract_func,
)
self._scalar_extractions[param] = code
return param.name
def extract_array_assoc_var(self, param: KernelParameter) -> str:
if param not in self._array_assoc_var_extractions:
field = param.fields[0]
buffer = self.extract_field(field)
code: str | None = None
for prop in param.properties:
match prop:
case FieldBasePtr():
code = f"{param.dtype} {param.name} = ({param.dtype}) {buffer}.buf;"
break
case FieldShape(_, coord):
code = f"{param.dtype} {param.name} = {buffer}.shape[{coord}];"
break
case FieldStride(_, coord):
code = (
f"{param.dtype} {param.name} = "
f"{buffer}.strides[{coord}] / {field.dtype.itemsize};"
)
break
assert code is not None
self._array_assoc_var_extractions[param] = code
return param.name
def extract_parameter(self, param: KernelParameter):
if param.is_field_parameter:
self.extract_array_assoc_var(param)
else:
self.extract_scalar(param)
def check_constraint(self, constraint: KernelParamsConstraint):
variables = constraint.get_parameters()
for var in variables:
self.extract_parameter(var)
cond = constraint.to_code()
code = f"""
if(!({cond}))
{{
PyErr_SetString(PyExc_ValueError, "Violated constraint: {constraint}");
return NULL;
}}
"""
self._constraint_checks.append(code)
def call(self, kernel: KernelFunction, params: tuple[KernelParameter, ...]):
param_list = ", ".join(p.name for p in params)
self._call = f"{kernel.name} ({param_list});"
def resolve(self, function_name) -> str:
assert self._call is not None
body = "\n\n".join(
chain(
[self.KWCHECK],
self._scalar_extractions.values(),
self._array_extractions.values(),
self._array_assoc_var_extractions.values(),
self._constraint_checks,
[self._call],
self._array_frees.values(),
["Py_RETURN_NONE;"],
)
)
code = f"static PyObject * {function_name}(PyObject * self, PyObject * args, PyObject * kwargs)\n"
code += "{\n" + indent(body, prefix=" ") + "\n}\n"
return code