diff --git a/pystencils/nbackend/arrays.py b/pystencils/nbackend/arrays.py index 23f0ff7788e9464c8ed3860e78993b3ff4527d0f..47394d2ac633897ddfd62985a0e1b17f1c49b05b 100644 --- a/pystencils/nbackend/arrays.py +++ b/pystencils/nbackend/arrays.py @@ -118,33 +118,67 @@ class PsArrayAssocVar(PsTypedVariable, ABC): to a particular array. """ + init_arg_names: tuple[str, ...] = ("name", "dtype", "array") + __match_args__ = ("name", "dtype", "array") + def __init__(self, name: str, dtype: PsAbstractType, array: PsLinearizedArray): super().__init__(name, dtype) self._array = array + def __getinitargs__(self): + return self.name, self.dtype, self.array + @property def array(self) -> PsLinearizedArray: return self._array class PsArrayBasePointer(PsArrayAssocVar): + init_arg_names: tuple[str, ...] = ("name", "array") + __match_args__ = ("name", "array") + def __init__(self, name: str, array: PsLinearizedArray): dtype = PsPointerType(array.element_type) super().__init__(name, dtype, array) self._array = array + def __getinitargs__(self): + return self.name, self.array + class PsArrayShapeVar(PsArrayAssocVar): - def __init__(self, array: PsLinearizedArray, dimension: int, dtype: PsIntegerType): - name = f"{array}_size{dimension}" + init_arg_names: tuple[str, ...] = ("array", "coordinate", "dtype") + __match_args__ = ("array", "coordinate", "dtype") + + def __init__(self, array: PsLinearizedArray, coordinate: int, dtype: PsIntegerType): + name = f"{array}_size{coordinate}" super().__init__(name, dtype, array) + self._coordinate = coordinate + + @property + def coordinate(self) -> int: + return self._coordinate + + def __getinitargs__(self): + return self.array, self.coordinate, self.dtype class PsArrayStrideVar(PsArrayAssocVar): - def __init__(self, array: PsLinearizedArray, dimension: int, dtype: PsIntegerType): - name = f"{array}_size{dimension}" + init_arg_names: tuple[str, ...] = ("array", "coordinate", "dtype") + __match_args__ = ("array", "coordinate", "dtype") + + def __init__(self, array: PsLinearizedArray, coordinate: int, dtype: PsIntegerType): + name = f"{array}_size{coordinate}" super().__init__(name, dtype, array) + self._coordinate = coordinate + + @property + def coordinate(self) -> int: + return self._coordinate + + def __getinitargs__(self): + return self.array, self.coordinate, self.dtype class PsArrayAccess(pb.Subscript): diff --git a/pystencils/nbackend/ast/constraints.py b/pystencils/nbackend/ast/constraints.py index 68cbe347a3acc7c5716da86277f952323e26f8ae..d11fe1195a0df3c1f12cfb159ce1763a3dbf7c8f 100644 --- a/pystencils/nbackend/ast/constraints.py +++ b/pystencils/nbackend/ast/constraints.py @@ -2,6 +2,9 @@ from dataclasses import dataclass import pymbolic.primitives as pb from pymbolic.mapper.c_code import CCodeMapper +from pymbolic.mapper.dependency import DependencyMapper + +from ..typed_expressions import PsTypedVariable @dataclass @@ -9,5 +12,11 @@ class PsParamConstraint: condition: pb.Comparison message: str = "" - def print(self): + def print_c_condition(self): return CCodeMapper()(self.condition) + + def get_variables(self) -> set[PsTypedVariable]: + return DependencyMapper(False, False, False, False)(self.condition) + + def __str__(self) -> str: + return f"{self.message} [{self.condition}]" diff --git a/pystencils/nbackend/ast/kernelfunction.py b/pystencils/nbackend/ast/kernelfunction.py index fccd4f12a18292a53566ad3a7d93eef6fef637f4..e2f0ac8e678e3d794b01536afc76dac03245f9ab 100644 --- a/pystencils/nbackend/ast/kernelfunction.py +++ b/pystencils/nbackend/ast/kernelfunction.py @@ -126,3 +126,6 @@ class PsKernelFunction(PsAstNode): arrays = set(p.array for p in params_list if isinstance(p, PsArrayBasePointer)) return PsKernelParametersSpec(tuple(params_list), tuple(arrays), tuple(self._constraints)) + + def get_required_headers(self) -> set[str]: + raise NotImplementedError() diff --git a/pystencils/nbackend/c_printer.py b/pystencils/nbackend/emission.py similarity index 94% rename from pystencils/nbackend/c_printer.py rename to pystencils/nbackend/emission.py index 58a61d579991f6fa2c51ee2300b7941bf584ed0e..b315b172e4173aeaa2d2308928b94c52535e7556 100644 --- a/pystencils/nbackend/c_printer.py +++ b/pystencils/nbackend/emission.py @@ -6,6 +6,12 @@ from .ast import ast_visitor, PsAstNode, PsBlock, PsExpression, PsDeclaration, P from .ast.kernelfunction import PsKernelFunction +def emit_code(kernel: PsKernelFunction): + # TODO: Specialize for different targets + printer = CPrinter() + return printer.print(kernel) + + class CPrinter: def __init__(self, indent_width=3): self._indent_width = indent_width diff --git a/pystencils/nbackend/jit/cpu_extension_module.py b/pystencils/nbackend/jit/cpu_extension_module.py new file mode 100644 index 0000000000000000000000000000000000000000..e4d021c5ddbb3e952a5fa2c6877f934262e590aa --- /dev/null +++ b/pystencils/nbackend/jit/cpu_extension_module.py @@ -0,0 +1,286 @@ +from __future__ import annotations + +from typing import Any + +from os import path +import hashlib + +from itertools import chain + +from ..exceptions import PsInternalCompilerError +from ..ast import PsKernelFunction +from ..ast.constraints import PsParamConstraint +from ..typed_expressions import PsTypedVariable +from ..arrays import ( + PsLinearizedArray, + PsArrayAssocVar, + PsArrayBasePointer, + PsArrayShapeVar, + PsArrayStrideVar, +) +from ..types import PsAbstractType +from ..types.quick import Fp, SInt, UInt +from ..emission import emit_code + + +class PsKernelExtensioNModule: + """Replacement for `pystencils.cpu.cpujit.ExtensionModuleCode`. + Conforms to its interface for plug-in to `compile_and_load`. + """ + + def __init__( + self, module_name: str = "generated", custom_backend: Any = None + ) -> None: + self._module_name = module_name + + if custom_backend is not None: + raise PsInternalCompilerError( + "The `custom_backend` parameter exists only for interface compatibility and cannot be set." + ) + + self._kernels: dict[str, PsKernelFunction] = dict() + self._code_string: str | None = None + self._code_hash: str | None = None + + @property + def module_name(self) -> str: + return self._module_name + + def add_function(self, kernel_function: PsKernelFunction, name: str | None = None): + if name is None: + name = kernel_function.name + + self._kernels[name] = kernel_function + + def create_code_string(self, restrict_qualifier: str, function_prefix: str): + code = "" + + # Collect headers + headers = {"<math.h>", "<stdint.h>"} + for kernel in self._kernels.values(): + headers |= kernel.get_required_headers() + + header_list = sorted(headers) + header_list.insert(0, '"Python.h"') + + from pystencils.include import get_pystencils_include_path + + ps_incl_path = get_pystencils_include_path() + + ps_headers = [] + for header in header_list: + header = header[1:-1] + header_path = path.join(ps_incl_path, header) + if path.exists(header_path): + ps_headers.append(header_path) + + header_hash = b"".join( + [hashlib.sha256(open(h, "rb").read()).digest() for h in ps_headers] + ) + + # Prelude: Includes and definitions + + includes = "\n".join(f"#include {header}" for header in header_list) + + code += includes + code += "\n" + code += f"#define RESTRICT {restrict_qualifier}\n" + code += f"#define FUNC_PREFIX {function_prefix}\n" + code += "\n" + + # Kernels and call wrappers + + for name, kernel in self._kernels.items(): + old_name = kernel.name + kernel.name = f"kernel_{name}" + + code += emit_code(kernel) + code += "\n" + code += emit_call_wrapper(name, kernel) + code += "\n" + + kernel.name = old_name + + self._code_hash = ( + "mod_" + hashlib.sha256(code.encode() + header_hash).hexdigest() + ) + + from ...cpu.cpujit import create_module_boilerplate_code + + code += create_module_boilerplate_code(self._code_hash, self._kernels.keys()) + + def get_hash_of_code(self): + assert self._code_string is not None, "The code must be generated first" + return self._code_hash + + def write_to_file(self, file): + assert self._code_string is not None, "The code must be generated first" + print(self._code_string, file=file) + + +def emit_call_wrapper(function_name: str, kernel: PsKernelFunction) -> str: + builder = CallWrapperBuilder() + params_spec = kernel.get_parameters() + + for p in params_spec.params: + builder.extract_parameter(p) + + for c in params_spec.constraints: + builder.check_constraint(c) + + builder.call(kernel, params_spec.params) + + return builder.resolve(function_name) + + +class CallWrapperBuilder: + TMPL_EXTRACT_SCALAR = """ +PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}"); +if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }}; +{target_type} {name} = ({target_type}) {extract_function}( obj_{name} ); +if( PyErr_Occurred() ) {{ return NULL; }} +""" + + TMPL_EXTRACT_ARRAY = """ +PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}"); +if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }}; +Py_buffer buffer_{name}; +int buffer_{name}_res = PyObject_GetBuffer(obj_{name}, &buffer_{name}, PyBUF_STRIDES | PyBUF_WRITABLE | PyBUF_FORMAT); +if (buffer_{name}_res == -1) {{ return NULL; }} +""" + + KWCHECK = """ +if( !kwargs || !PyDict_Check(kwargs) ) {{ + PyErr_SetString(PyExc_TypeError, "No keyword arguments passed"); + return NULL; + }} +""" + + def __init__(self) -> None: + self._array_buffers: dict[PsLinearizedArray, str] = dict() + self._array_extractions: dict[PsLinearizedArray, str] = dict() + self._array_frees: dict[PsLinearizedArray, str] = dict() + + self._array_assoc_var_extractions: dict[PsArrayAssocVar, str] = dict() + self._scalar_extractions: dict[PsTypedVariable, str] = dict() + + self._constraint_checks: list[str] = [] + + self._call: str | None = None + + def _scalar_extractor(self, dtype: PsAbstractType) -> str: + match dtype: + case Fp(32) | Fp(64): + return "PyFloat_AsDouble" + case SInt(): + return "PyLong_AsLong" + case UInt(): + return "PyLong_AsUnsignedLong" + + case _: + raise PsInternalCompilerError( + f"Don't know how to cast Python objects to {dtype}" + ) + + def extract_array(self, arr: PsLinearizedArray) -> str: + """Adds an array, and returns the name of the underlying Py_Buffer.""" + if arr not in self._array_extractions: + extraction_code = self.TMPL_EXTRACT_ARRAY.format(name=arr.name) + self._array_buffers[arr] = f"buffer_{arr.name}" + self._array_extractions[arr] = extraction_code + + release_code = f"PyBuffer_Release(&buffer_{arr.name});" + self._array_frees[arr] = release_code + + return self._array_buffers[arr] + + def extract_scalar(self, variable: PsTypedVariable) -> str: + if variable not in self._scalar_extractions: + self.TMPL_EXTRACT_SCALAR.format() + + extract_func = self._scalar_extractor(variable.dtype) + code = self.TMPL_EXTRACT_SCALAR.format( + name=variable.name, + target_type=str(variable.dtype), + extract_function=extract_func, + ) + self._scalar_extractions[variable] = code + + return variable.name + + def extract_array_assoc_var(self, variable: PsArrayAssocVar) -> str: + if variable not in self._array_assoc_var_extractions: + arr = variable.array + buffer = self.extract_array(arr) + match variable: + case PsArrayBasePointer(): + code = f"{variable.dtype} {variable.name} = ({variable.dtype}) {buffer}.buf;" + case PsArrayShapeVar(): + coord = variable.coordinate + code = ( + f"{variable.dtype} {variable.name} = " + f"{buffer}.shape[{coord}] / {arr.element_type.itemsize};" + ) + case PsArrayStrideVar(): + coord = variable.coordinate + code = ( + f"{variable.dtype} {variable.name} = " + f"{buffer}.strides[{coord}] / {arr.element_type.itemsize};" + ) + case _: + assert False, "unreachable code" + + self._array_assoc_var_extractions[variable] = code + + return variable.name + + def extract_parameter(self, variable: PsTypedVariable): + match variable: + case PsArrayAssocVar(): + self.extract_array_assoc_var(variable) + case PsTypedVariable(): + self.extract_scalar(variable) + case _: + assert False, "Invalid variable encountered." + + def check_constraint(self, constraint: PsParamConstraint): + variables = constraint.get_variables() + + for var in variables: + self.extract_parameter(var) + + cond = constraint.print_c_condition() + + code = f""" +if(!({cond})) +{{ + PyErr_SetString(PyExc_ValueError, "Violated constraint: {constraint}"); + return NULL; +}} +""" + + self._constraint_checks.append(code) + + def call(self, kernel: PsKernelFunction, params: tuple[PsTypedVariable, ...]): + param_list = ", ".join(p.name for p in params) + self._call = f"{kernel.name} ({param_list});" + + def resolve(self, function_name) -> str: + assert self._call is not None + + body = "\n".join( + chain( + [self.KWCHECK], + self._scalar_extractions.values(), + self._array_extractions.values(), + self._array_assoc_var_extractions.values(), + [self._call], + self._array_frees.values(), + ["Py_RETURN_NONE;"], + ) + ) + + code = f"static PyObject * {function_name}(PyObject * self, PyObject * args, PyObject * kwargs)\n" + code += "{\n" + body + "\n}\n" + + return code diff --git a/pystencils/nbackend/typed_expressions.py b/pystencils/nbackend/typed_expressions.py index e878fbfcee0b03f4444382672a624fad10f5679e..657eaf4f0845f16a784418769597b07382d6c0c8 100644 --- a/pystencils/nbackend/typed_expressions.py +++ b/pystencils/nbackend/typed_expressions.py @@ -13,10 +13,18 @@ from .types import ( class PsTypedVariable(pb.Variable): + + init_arg_names: tuple[str, ...] = ("name", "dtype") + + __match_args__ = ("name", "dtype") + def __init__(self, name: str, dtype: PsAbstractType): super(PsTypedVariable, self).__init__(name) self._dtype = dtype + def __getinitargs__(self): + return self.name, self._dtype + @property def dtype(self) -> PsAbstractType: return self._dtype diff --git a/pystencils/nbackend/types/basic_types.py b/pystencils/nbackend/types/basic_types.py index 412189968ce7dcdc90db507a2e2fad5b91c3514a..45deb88a001b754338b21b46d44cfce131aeed59 100644 --- a/pystencils/nbackend/types/basic_types.py +++ b/pystencils/nbackend/types/basic_types.py @@ -185,6 +185,11 @@ class PsScalarType(PsNumericType, ABC): def is_float(self) -> bool: return isinstance(self, PsIeeeFloatType) + + @property + @abstractmethod + def itemsize(self) -> int: + """Size of this type's elements in bytes.""" class PsIntegerType(PsScalarType, ABC): @@ -216,6 +221,10 @@ class PsIntegerType(PsScalarType, ABC): @property def signed(self) -> bool: return self._signed + + @property + def itemsize(self) -> int: + return self.width // 8 def __eq__(self, other: object) -> bool: if not isinstance(other, PsIntegerType): @@ -320,6 +329,10 @@ class PsIeeeFloatType(PsScalarType): @property def width(self) -> int: return self._width + + @property + def itemsize(self) -> int: + return self.width // 8 def create_constant(self, value: Any) -> Any: np_type = self.NUMPY_TYPES[self._width] diff --git a/pystencils_tests/nbackend/test_basic_printing.py b/pystencils_tests/nbackend/test_basic_printing.py index 8679211146633fdd48bd8aadf0128cfe1fb9d012..4031c12feaf542d562dcd6e18e93de3e0b6dd6c3 100644 --- a/pystencils_tests/nbackend/test_basic_printing.py +++ b/pystencils_tests/nbackend/test_basic_printing.py @@ -5,7 +5,7 @@ from pystencils import Target from pystencils.nbackend.ast import * from pystencils.nbackend.typed_expressions import * from pystencils.nbackend.types.quick import * -from pystencils.nbackend.c_printer import CPrinter +from pystencils.nbackend.emission import CPrinter def test_basic_kernel():