Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision
  • 66-absolute-access-is-probably-not-copied-correctly-after-_eval_subs
  • const_fix
  • fhennig/v2.0-deprecations
  • fma
  • gpu_bufferfield_fix
  • gpu_liveness_opts
  • holzer-master-patch-46757
  • hyteg
  • improved_comm
  • master
  • target_dh_refactoring
  • v2.0-dev
  • vectorization_sqrt_fix
  • zikeliml/124-rework-tutorials
  • zikeliml/Task-96-dotExporterForAST
  • last/Kerncraft
  • last/LLVM
  • last/OpenCL
  • release/0.2.1
  • release/0.2.10
  • release/0.2.11
  • release/0.2.12
  • release/0.2.13
  • release/0.2.14
  • release/0.2.15
  • release/0.2.2
  • release/0.2.3
  • release/0.2.4
  • release/0.2.6
  • release/0.2.7
  • release/0.2.8
  • release/0.2.9
  • release/0.3.0
  • release/0.3.1
  • release/0.3.2
  • release/0.3.3
  • release/0.3.4
  • release/0.4.0
  • release/0.4.1
  • release/0.4.2
  • release/0.4.3
  • release/0.4.4
  • release/1.0
  • release/1.0.1
  • release/1.1
  • release/1.1.1
  • release/1.2
  • release/1.3
  • release/1.3.1
  • release/1.3.2
  • release/1.3.3
  • release/1.3.4
  • release/1.3.5
  • release/1.3.6
  • release/1.3.7
  • release/2.0.dev0
56 results

Target

Select target project
  • anirudh.jonnalagadda/pystencils
  • hyteg/pystencils
  • jbadwaik/pystencils
  • jngrad/pystencils
  • itischler/pystencils
  • ob28imeq/pystencils
  • hoenig/pystencils
  • Bindgen/pystencils
  • hammer/pystencils
  • da15siwa/pystencils
  • holzer/pystencils
  • alexander.reinauer/pystencils
  • ec93ujoh/pystencils
  • Harke/pystencils
  • seitz/pystencils
  • pycodegen/pystencils
16 results
Select Git revision
  • 66-absolute-access-is-probably-not-copied-correctly-after-_eval_subs
  • const_fix
  • fhennig/v2.0-deprecations
  • fma
  • gpu_bufferfield_fix
  • gpu_liveness_opts
  • holzer-master-patch-46757
  • hyteg
  • improved_comm
  • master
  • target_dh_refactoring
  • v2.0-dev
  • vectorization_sqrt_fix
  • zikeliml/124-rework-tutorials
  • zikeliml/Task-96-dotExporterForAST
  • last/Kerncraft
  • last/LLVM
  • last/OpenCL
  • release/0.2.1
  • release/0.2.10
  • release/0.2.11
  • release/0.2.12
  • release/0.2.13
  • release/0.2.14
  • release/0.2.15
  • release/0.2.2
  • release/0.2.3
  • release/0.2.4
  • release/0.2.6
  • release/0.2.7
  • release/0.2.8
  • release/0.2.9
  • release/0.3.0
  • release/0.3.1
  • release/0.3.2
  • release/0.3.3
  • release/0.3.4
  • release/0.4.0
  • release/0.4.1
  • release/0.4.2
  • release/0.4.3
  • release/0.4.4
  • release/1.0
  • release/1.0.1
  • release/1.1
  • release/1.1.1
  • release/1.2
  • release/1.3
  • release/1.3.1
  • release/1.3.2
  • release/1.3.3
  • release/1.3.4
  • release/1.3.5
  • release/1.3.6
  • release/1.3.7
  • release/2.0.dev0
56 results
Show changes
Showing
with 2038 additions and 503 deletions
......@@ -3,7 +3,7 @@ from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple, Union
import numpy as np
from pystencils.enums import Target, Backend
from ..codegen import Target
from pystencils.field import Field, FieldType
......@@ -18,7 +18,6 @@ class DataHandling(ABC):
"""
_GPU_LIKE_TARGETS = [Target.GPU]
_GPU_LIKE_BACKENDS = [Backend.CUDA]
# ---------------------------- Adding and accessing data -----------------------------------------------------------
@property
......@@ -83,7 +82,7 @@ class DataHandling(ABC):
>>> dh = create_data_handling((20, 30))
>>> x, y =dh.add_arrays('x, y(9)')
>>> print(dh.fields)
{'x': x: double[22,32], 'y': y(9): double[22,32]}
{'x': x: float64[22,32], 'y': y(9): float64[22,32]}
>>> assert x == dh.fields['x']
>>> assert dh.fields['x'].shape == (22, 32)
>>> assert dh.fields['y'].index_shape == (9,)
......
......@@ -7,10 +7,9 @@ import waLBerla as wlb
from pystencils.datahandling.blockiteration import block_iteration, sliced_block_iteration
from pystencils.datahandling.datahandling_interface import DataHandling
from pystencils.enums import Backend
from pystencils.field import Field, FieldType
from pystencils.typing.typed_sympy import FieldPointerSymbol
from pystencils.utils import DotDict
from pystencils.codegen.properties import FieldBasePtr
from pystencils import Target
......@@ -253,15 +252,15 @@ class ParallelDataHandling(DataHandling):
kernel_function(**arg_dict)
def get_kernel_kwargs(self, kernel_function, **kwargs):
if kernel_function.ast.backend == Backend.CUDA:
if kernel_function.ast.target.is_gpu():
name_map = self._field_name_to_gpu_data_name
to_array = wlb.gpu.toGpuArray
else:
name_map = self._field_name_to_cpu_data_name
to_array = wlb.field.toArray
data_used_in_kernel = [(name_map[p.symbol.field_name], self.fields[p.symbol.field_name])
data_used_in_kernel = [(name_map[p.field_name], self.fields[p.field_name])
for p in kernel_function.parameters if
isinstance(p.symbol, FieldPointerSymbol) and p.symbol.field_name not in kwargs]
p.get_properties(FieldBasePtr) and p.field_name not in kwargs]
result = []
for block in self.blocks:
......
......@@ -6,7 +6,7 @@ import numpy as np
from pystencils.datahandling.blockiteration import SerialBlock
from pystencils.datahandling.datahandling_interface import DataHandling
from pystencils.enums import Target
from ..codegen import Target
from pystencils.field import (Field, FieldType, create_numpy_array_with_layout,
layout_string_to_tuple, spatial_layout_string_to_tuple)
from pystencils.gpu.gpu_array_handler import GPUArrayHandler, GPUNotAvailableHandler
......@@ -254,12 +254,12 @@ class SerialDataHandling(DataHandling):
self.to_gpu(name)
def run_kernel(self, kernel_function, **kwargs):
arrays = self.gpu_arrays if kernel_function.ast.backend in self._GPU_LIKE_BACKENDS else self.cpu_arrays
arrays = self.gpu_arrays if kernel_function.target.is_gpu() else self.cpu_arrays
kernel_function(**{**arrays, **kwargs})
def get_kernel_kwargs(self, kernel_function, **kwargs):
result = {}
result.update(self.gpu_arrays if kernel_function.ast.backend in self._GPU_LIKE_BACKENDS else self.cpu_arrays)
result.update(self.gpu_arrays if kernel_function.target.is_gpu() else self.cpu_arrays)
result.update(kwargs)
return [result]
......@@ -291,7 +291,10 @@ class SerialDataHandling(DataHandling):
def synchronization_function(self, names, stencil=None, target=None, functor=None, **_):
if target is None:
target = self.default_target
assert target in (Target.CPU, Target.GPU)
if not (target.is_cpu() or target == Target.CUDA):
raise ValueError(f"Unsupported target: {target}")
if not hasattr(names, '__len__') or type(names) is str:
names = [names]
......@@ -325,7 +328,7 @@ class SerialDataHandling(DataHandling):
values_per_cell = values_per_cell[0]
if len(filtered_stencil) > 0:
if target == Target.CPU:
if target.is_cpu():
if functor is None:
from pystencils.slicing import get_periodic_boundary_functor
functor = get_periodic_boundary_functor
......@@ -344,11 +347,11 @@ class SerialDataHandling(DataHandling):
if target == Target.CPU:
def result_functor():
for arr_name, func in zip(names, result):
func(pdfs=self.cpu_arrays[arr_name])
func(self.cpu_arrays[arr_name])
else:
def result_functor():
for arr_name, func in zip(names, result):
func(pdfs=self.gpu_arrays[arr_name])
func(self.gpu_arrays[arr_name])
return result_functor
......
from .types import (
PsIeeeFloatType,
PsIntegerType,
PsSignedIntegerType,
PsStructType,
UserTypeSpec,
create_type,
)
from pystencils.sympyextensions.typed_sympy import TypedSymbol, DynamicType
class SympyDefaults:
def __init__(self):
self.numeric_dtype = PsIeeeFloatType(64)
"""Default data type for numerical computations"""
self.index_dtype: PsIntegerType = PsSignedIntegerType(64)
"""Default data type for indices."""
self.spatial_counter_names = ("ctr_0", "ctr_1", "ctr_2")
"""Names of the default spatial counters"""
self.spatial_counters = (
TypedSymbol("ctr_0", DynamicType.INDEX_TYPE),
TypedSymbol("ctr_1", DynamicType.INDEX_TYPE),
TypedSymbol("ctr_2", DynamicType.INDEX_TYPE),
)
"""Default spatial counters"""
self.index_struct_coordinate_names = ("x", "y", "z")
"""Default names of spatial coordinate members in index list structures"""
self.sparse_counter_name = "sparse_idx"
"""Name of the default sparse iteration counter"""
self.sparse_counter = TypedSymbol(
self.sparse_counter_name, DynamicType.INDEX_TYPE
)
"""Default sparse iteration counter."""
def field_shape_name(self, field_name: str, coord: int):
return f"_size_{field_name}_{coord}"
def field_stride_name(self, field_name: str, coord: int):
return f"_stride_{field_name}_{coord}"
def field_pointer_name(self, field_name: str):
return f"_data_{field_name}"
def index_struct(self, index_dtype: UserTypeSpec, dim: int) -> PsStructType:
idx_type = create_type(index_dtype)
return PsStructType(
[(name, idx_type) for name in self.index_struct_coordinate_names[:dim]]
)
DEFAULTS = SympyDefaults()
"""Default names and symbols used throughout code generation"""
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional
import sympy as sp
from pystencils.astnodes import KernelFunction
from pystencils.enums import Backend
from pystencils.kernel_wrapper import KernelWrapper
from .codegen import Kernel
from .jit import KernelWrapper
def to_dot(expr: sp.Expr, graph_style: Optional[Dict[str, Any]] = None, short=True):
"""Show a sympy or pystencils AST as dot graph"""
from pystencils.astnodes import Node
try:
import graphviz
except ImportError:
......@@ -18,12 +16,15 @@ def to_dot(expr: sp.Expr, graph_style: Optional[Dict[str, Any]] = None, short=Tr
graph_style = {} if graph_style is None else graph_style
if isinstance(expr, Node):
from pystencils.backends.dot import print_dot
return graphviz.Source(print_dot(expr, short=short, graph_attr=graph_style))
else:
# if isinstance(expr, Node):
# from pystencils.backends.dot import print_dot
# return graphviz.Source(print_dot(expr, short=short, graph_attr=graph_style))
if isinstance(expr, sp.Basic):
from sympy.printing.dot import dotprint
return graphviz.Source(dotprint(expr, graph_attr=graph_style))
else:
# TODO Implement dot / graphviz exporter for new backend AST
raise NotImplementedError("Printing of AST nodes for the new backend is not implemented yet")
def highlight_cpp(code: str):
......@@ -41,33 +42,27 @@ def highlight_cpp(code: str):
return HTML(highlight(code, CppLexer(), HtmlFormatter()))
def get_code_obj(ast: Union[KernelFunction, KernelWrapper], custom_backend=None):
def get_code_obj(ast: KernelWrapper | Kernel, custom_backend=None):
"""Returns an object to display generated code (C/C++ or CUDA)
Can either be displayed as HTML in Jupyter notebooks or printed as normal string.
"""
from pystencils.backends.cbackend import generate_c
if isinstance(ast, KernelWrapper):
ast = ast.ast
if ast.backend not in {Backend.C, Backend.CUDA}:
raise NotImplementedError(f'get_code_obj is not implemented for backend {ast.backend}')
dialect = ast.backend
func = ast.kernel_function
else:
func = ast
class CodeDisplay:
def __init__(self, ast_input):
self.ast = ast_input
def _repr_html_(self):
return highlight_cpp(generate_c(self.ast, dialect=dialect, custom_backend=custom_backend)).__html__()
return highlight_cpp(func.get_c_code()).__html__()
def __str__(self):
return generate_c(self.ast, dialect=dialect, custom_backend=custom_backend)
return func.get_c_code()
def __repr__(self):
return generate_c(self.ast, dialect=dialect, custom_backend=custom_backend)
return CodeDisplay(ast)
return func.get_c_code()
return CodeDisplay()
def get_code_str(ast, custom_backend=None):
......@@ -87,7 +82,7 @@ def _isnotebook():
return False
def show_code(ast: Union[KernelFunction, KernelWrapper], custom_backend=None):
def show_code(ast: KernelWrapper | Kernel, custom_backend=None):
code = get_code_obj(ast, custom_backend)
if _isnotebook():
......
from enum import Enum, auto
from .codegen import Target as _Target
from warnings import warn
class Target(Enum):
"""
The Target enumeration represents all possible targets that can be used for the code generation.
"""
CPU = auto()
"""
Target CPU architecture.
"""
GPU = auto()
"""
Target GPU architecture.
"""
warn(
"Importing anything from `pystencils.enums` is deprecated and the module will be removed in pystencils 2.1. "
"Import from `pystencils` instead.",
FutureWarning
)
class Backend(Enum):
"""
The Backend enumeration represents all possible backends that can be used for the code generation.
Backends and targets must be combined with care. For example CPU as a target and CUDA as a backend makes no sense.
"""
C = auto()
"""
Use the C Backend of pystencils.
"""
CUDA = auto()
"""
Use the CUDA backend to generate code for NVIDIA GPUs.
"""
Target = _Target
......@@ -7,7 +7,7 @@ from pystencils.fd import Diff
from pystencils.fd.derivative import diff_args
from pystencils.fd.spatial import fd_stencils_standard
from pystencils.field import Field
from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.simp import AssignmentCollection
from pystencils.sympyextensions import fast_subs
FieldOrFieldAccess = Union[Field, Field.Access]
......
......@@ -3,10 +3,10 @@ from typing import Tuple
import sympy as sp
from pystencils.astnodes import LoopOverCoordinate
from pystencils.fd import Diff
from pystencils.field import Field
from pystencils.transformations import generic_visit
from pystencils.sympyextensions.astnodes import generic_visit
from pystencils.sympyextensions.typed_sympy import is_loop_counter_symbol
from .derivation import FiniteDifferenceStencilDerivation
from .derivative import diff_args
......@@ -112,7 +112,7 @@ def discretize_spatial_staggered(expr, dx, stencil=fd_stencils_standard):
elif isinstance(e, Field.Access):
return (e.neighbor(coordinate, sign) + e) / 2
elif isinstance(e, sp.Symbol):
loop_idx = LoopOverCoordinate.is_loop_counter_symbol(e)
loop_idx = is_loop_counter_symbol(e)
return e + sign / 2 if loop_idx == coordinate else e
else:
new_args = [staggered_visitor(a, coordinate, sign) for a in e.args]
......
from __future__ import annotations
import functools
import hashlib
import operator
......@@ -5,21 +7,28 @@ import pickle
import re
from enum import Enum
from itertools import chain
from typing import List, Optional, Sequence, Set, Tuple, Union
from typing import List, Optional, Sequence, Set, Tuple
from warnings import warn
import numpy as np
import sympy as sp
from sympy.core.cache import cacheit
import pystencils
from pystencils.alignedarray import aligned_empty
from pystencils.typing import StructType, TypedSymbol, BasicType, create_type
from pystencils.typing.typed_sympy import FieldShapeSymbol, FieldStrideSymbol
from pystencils.stencil import (
direction_string_to_offset, inverse_direction, offset_to_direction_string)
from pystencils.sympyextensions import is_integer_sequence
from .defaults import DEFAULTS
from .alignedarray import aligned_empty
from .spatial_coordinates import x_staggered_vector, x_vector
from .stencil import (
direction_string_to_offset,
inverse_direction,
offset_to_direction_string,
)
from .types import PsType, PsStructType, create_type
from .sympyextensions.typed_sympy import TypedSymbol, DynamicType
from .sympyextensions import is_integer_sequence
from .types import UserTypeSpec
__all__ = ['Field', 'fields', 'FieldType', 'Field']
__all__ = ["Field", "fields", "FieldType", "Field"]
class FieldType(Enum):
......@@ -61,7 +70,10 @@ class FieldType(Enum):
@staticmethod
def is_staggered(field):
assert isinstance(field, Field)
return field.field_type == FieldType.STAGGERED or field.field_type == FieldType.STAGGERED_FLUX
return (
field.field_type == FieldType.STAGGERED
or field.field_type == FieldType.STAGGERED_FLUX
)
@staticmethod
def is_staggered_flux(field):
......@@ -75,7 +87,7 @@ class Field:
This Field class knows about the dimension, memory layout (strides) and optionally about the size of an array.
Creating Fields:
The preferred method to create fields is the `fields` function.
The preferred method to create fields is the `fields <pystencils.field.fields>` function.
Alternatively one can use one of the static functions `Field.create_generic`, `Field.create_from_numpy_array`
and `Field.create_fixed_size`. Don't instantiate the Field directly!
Fields can be created with known or unknown shapes:
......@@ -119,17 +131,36 @@ class Field:
>>> stencil = np.array([[0,0], [0,1], [0,-1]])
>>> src, dst = fields("src(3), dst(3) : double[2D]")
>>> assignments = [Assignment(dst[0,0](i), src[-offset](i)) for i, offset in enumerate(stencil)];
Args:
field_name: The field's name
field_type: The kind of the field
dtype: Data type of the field's entries
layout: Linearization order of the field's spatial dimensions
shape: Total shape (spatial and index) of the field
strides: Linearization strides of the field
"""
@staticmethod
def create_generic(field_name, spatial_dimensions, dtype=np.float64, index_dimensions=0, layout='numpy',
index_shape=None, field_type=FieldType.GENERIC) -> 'Field':
def create_generic(
field_name,
spatial_dimensions,
dtype: UserTypeSpec | DynamicType = DynamicType.NUMERIC_TYPE,
index_dimensions=0,
layout="numpy",
index_shape=None,
field_type=FieldType.GENERIC,
) -> "Field":
"""
Creates a generic field where the field size is not fixed i.e. can be called with arrays of different sizes
Creates a generic field where the field size is not fixed i.e. can be called with arrays of different sizes.
**Field Element Type** By default, the data type of the field entries is left undetermined until
code generation, at which point it is set to the default numerical type of the kernel.
You can specify a concrete type using the `dtype` parameter.
Args:
field_name: symbolic name for the field
dtype: numpy data type of the array the kernel is called with later
dtype: Data type of the field entries
spatial_dimensions: see documentation of Field
index_dimensions: see documentation of Field
layout: tuple specifying the loop ordering of the spatial dimensions e.g. (2, 1, 0 ) means that
......@@ -149,27 +180,60 @@ class Field:
layout = spatial_layout_string_to_tuple(layout, dim=spatial_dimensions)
total_dimensions = spatial_dimensions + index_dimensions
shape: tuple[TypedSymbol | int, ...]
if index_shape is None or len(index_shape) == 0:
shape = tuple([FieldShapeSymbol([field_name], i) for i in range(total_dimensions)])
shape = tuple(
[
TypedSymbol(
DEFAULTS.field_shape_name(field_name, i), DynamicType.INDEX_TYPE
)
for i in range(total_dimensions)
]
)
else:
shape = tuple([FieldShapeSymbol([field_name], i) for i in range(spatial_dimensions)] + list(index_shape))
strides = tuple([FieldStrideSymbol(field_name, i) for i in range(total_dimensions)])
np_data_type = np.dtype(dtype)
if np_data_type.fields is not None:
shape = tuple(
[
TypedSymbol(
DEFAULTS.field_shape_name(field_name, i), DynamicType.INDEX_TYPE
)
for i in range(spatial_dimensions)
]
+ list(index_shape)
)
strides: tuple[TypedSymbol | int, ...] = tuple(
[
TypedSymbol(
DEFAULTS.field_stride_name(field_name, i), DynamicType.INDEX_TYPE
)
for i in range(total_dimensions)
]
)
if not isinstance(dtype, DynamicType):
dtype = create_type(dtype)
if isinstance(dtype, PsStructType):
if index_dimensions != 0:
raise ValueError("Structured arrays/fields are not allowed to have an index dimension")
raise ValueError(
"Structured arrays/fields are not allowed to have an index dimension"
)
shape += (1,)
strides += (1,)
if field_type == FieldType.STAGGERED and index_dimensions == 0:
raise ValueError("A staggered field needs at least one index dimension")
return Field(field_name, field_type, dtype, layout, shape, strides)
@staticmethod
def create_from_numpy_array(field_name: str, array: np.ndarray, index_dimensions: int = 0,
field_type=FieldType.GENERIC) -> 'Field':
def create_from_numpy_array(
field_name: str,
array: np.ndarray,
index_dimensions: int = 0,
field_type=FieldType.GENERIC,
) -> Field:
"""Creates a field based on the layout, data type, and shape of a given numpy array.
Kernels created for these kind of fields can only be called with arrays of the same layout, shape and type.
......@@ -182,7 +246,9 @@ class Field:
"""
spatial_dimensions = len(array.shape) - index_dimensions
if spatial_dimensions < 1:
raise ValueError("Too many index dimensions. At least one spatial dimension required")
raise ValueError(
"Too many index dimensions. At least one spatial dimension required"
)
full_layout = get_layout_of_array(array)
spatial_layout = tuple([i for i in full_layout if i < spatial_dimensions])
......@@ -194,20 +260,31 @@ class Field:
numpy_dtype = np.dtype(array.dtype)
if numpy_dtype.fields is not None:
if index_dimensions != 0:
raise ValueError("Structured arrays/fields are not allowed to have an index dimension")
raise ValueError(
"Structured arrays/fields are not allowed to have an index dimension"
)
shape += (1,)
strides += (1,)
if field_type == FieldType.STAGGERED and index_dimensions == 0:
raise ValueError("A staggered field needs at least one index dimension")
return Field(field_name, field_type, array.dtype, spatial_layout, shape, strides)
return Field(
field_name, field_type, array.dtype, spatial_layout, shape, strides
)
@staticmethod
def create_fixed_size(field_name: str, shape: Tuple[int, ...], index_dimensions: int = 0,
dtype=np.float64, layout: str = 'numpy', strides: Optional[Sequence[int]] = None,
field_type=FieldType.GENERIC) -> 'Field':
def create_fixed_size(
field_name: str,
shape: tuple[int, ...],
index_dimensions: int = 0,
dtype: UserTypeSpec | DynamicType = DynamicType.NUMERIC_TYPE,
layout: str | tuple[int, ...] = "numpy",
memory_strides: None | Sequence[int] = None,
strides: Optional[Sequence[int]] = None,
field_type=FieldType.GENERIC,
) -> Field:
"""
Creates a field with fixed sizes i.e. can be called only with arrays of the same size and layout
Creates a field with fixed sizes i.e. can be called only with arrays of the same size and layout.
Args:
field_name: symbolic name for the field
......@@ -215,43 +292,90 @@ class Field:
index_dimensions: how many of the trailing dimensions are interpreted as index (as opposed to spatial)
dtype: numpy data type of the array the kernel is called with later
layout: full layout of array, not only spatial dimensions
strides: strides in bytes or None to automatically compute them from shape (assuming no padding)
memory_strides: Linearization strides for each dimension;
i.e. the number of elements to skip to get from one index to the next in the respective dimension.
field_type: kind of field
"""
if strides is not None:
warn(
"The `strides` parameter to `Field.create_fixed_size` is deprecated "
"and will be removed in pystencils 2.1. "
"Use `memory_strides` instead; "
"beware that `memory_strides` takes the number of *elements* to skip, "
"instead of the number of bytes.",
FutureWarning
)
if memory_strides is not None:
raise ValueError("Cannot specify `memory_strides` and deprecated parameter `strides` at the same time.")
if isinstance(dtype, DynamicType):
raise ValueError("Cannot specify the deprecated parameter `strides` together with a `DynamicType`. "
"Set `memory_strides` instead.")
np_type = create_type(dtype).numpy_dtype
assert np_type is not None
memory_strides = tuple([s // np_type.itemsize for s in strides])
spatial_dimensions = len(shape) - index_dimensions
assert spatial_dimensions >= 1
if isinstance(layout, str):
layout = layout_string_to_tuple(layout, spatial_dimensions + index_dimensions)
layout = layout_string_to_tuple(
layout, spatial_dimensions + index_dimensions
)
if not isinstance(dtype, DynamicType):
dtype = create_type(dtype)
shape_tuple = tuple(int(s) for s in shape)
strides_tuple: tuple[int, ...]
shape = tuple(int(s) for s in shape)
if strides is None:
strides = compute_strides(shape, layout)
strides_tuple = compute_strides(shape_tuple, layout)
else:
assert len(strides) == len(shape)
strides = tuple([s // np.dtype(dtype).itemsize for s in strides])
assert len(strides) == len(shape_tuple)
strides_tuple = tuple(strides)
numpy_dtype = np.dtype(dtype)
if numpy_dtype.fields is not None:
if isinstance(dtype, PsStructType):
if index_dimensions != 0:
raise ValueError("Structured arrays/fields are not allowed to have an index dimension")
shape += (1,)
strides += (1,)
raise ValueError(
"Structured arrays/fields are not allowed to have an index dimension"
)
shape_tuple += (1,)
strides_tuple += (1,)
if field_type == FieldType.STAGGERED and index_dimensions == 0:
raise ValueError("A staggered field needs at least one index dimension")
spatial_layout = list(layout)
for i in range(spatial_dimensions, len(layout)):
spatial_layout.remove(i)
return Field(field_name, field_type, dtype, tuple(spatial_layout), shape, strides)
def __init__(self, field_name, field_type, dtype, layout, shape, strides):
return Field(
field_name,
field_type,
dtype,
tuple(spatial_layout),
shape_tuple,
strides_tuple,
)
def __init__(
self,
field_name: str,
field_type: FieldType,
dtype: UserTypeSpec | DynamicType,
layout: tuple[int, ...],
shape,
strides,
):
"""Do not use directly. Use static create* methods"""
self._field_name = field_name
assert isinstance(field_type, FieldType)
assert len(shape) == len(strides)
self.field_type = field_type
self._dtype = create_type(dtype)
self._dtype: PsType | DynamicType = (
create_type(dtype) if not isinstance(dtype, DynamicType) else dtype
)
self._layout = normalize_layout(layout)
self.shape = shape
self.strides = strides
......@@ -263,9 +387,23 @@ class Field:
def new_field_with_different_name(self, new_name):
if self.has_fixed_shape:
return Field(new_name, self.field_type, self._dtype, self._layout, self.shape, self.strides)
return Field(
new_name,
self.field_type,
self._dtype,
self._layout,
self.shape,
self.strides,
)
else:
return Field(new_name, self.field_type, self.dtype, self.layout, self.shape, self.strides)
return Field(
new_name,
self.field_type,
self.dtype,
self.layout,
self.shape,
self.strides,
)
@property
def spatial_dimensions(self) -> int:
......@@ -292,7 +430,7 @@ class Field:
@property
def spatial_shape(self) -> Tuple[int, ...]:
return self.shape[:self.spatial_dimensions]
return self.shape[: self.spatial_dimensions]
@property
def has_fixed_shape(self):
......@@ -308,31 +446,34 @@ class Field:
@property
def spatial_strides(self):
return self.strides[:self.spatial_dimensions]
return self.strides[: self.spatial_dimensions]
@property
def index_strides(self):
return self.strides[self.spatial_dimensions:]
@property
def dtype(self):
def dtype(self) -> PsType | DynamicType:
return self._dtype
@property
def itemsize(self):
return self.dtype.numpy_dtype.itemsize
def itemsize(self) -> int | None:
if isinstance(self.dtype, PsType):
return self.dtype.itemsize
else:
return None
def __repr__(self):
if any(isinstance(s, sp.Symbol) for s in self.spatial_shape):
spatial_shape_str = f'{self.spatial_dimensions}d'
spatial_shape_str = f"{self.spatial_dimensions}d"
else:
spatial_shape_str = ','.join(str(i) for i in self.spatial_shape)
index_shape_str = ','.join(str(i) for i in self.index_shape)
spatial_shape_str = ",".join(str(i) for i in self.spatial_shape)
index_shape_str = ",".join(str(i) for i in self.index_shape)
if self.index_shape:
return f'{self._field_name}({index_shape_str}): {self.dtype}[{spatial_shape_str}]'
return f"{self._field_name}({index_shape_str}): {self.dtype}[{spatial_shape_str}]"
else:
return f'{self._field_name}: {self.dtype}[{spatial_shape_str}]'
return f"{self._field_name}: {self.dtype}[{spatial_shape_str}]"
def __str__(self):
return self.name
......@@ -353,12 +494,26 @@ class Field:
elif len(index_shape) == 1:
return sp.Matrix([self(i) for i in range(index_shape[0])])
elif len(index_shape) == 2:
return sp.Matrix([[self(i, j) for j in range(index_shape[1])] for i in range(index_shape[0])])
return sp.Matrix(
[
[self(i, j) for j in range(index_shape[1])]
for i in range(index_shape[0])
]
)
elif len(index_shape) == 3:
return sp.Array([[[self(i, j, k) for k in range(index_shape[2])]
for j in range(index_shape[1])] for i in range(index_shape[0])])
return sp.Array(
[
[
[self(i, j, k) for k in range(index_shape[2])]
for j in range(index_shape[1])
]
for i in range(index_shape[0])
]
)
else:
raise NotImplementedError("center_vector is not implemented for more than 3 index dimensions")
raise NotImplementedError(
"center_vector is not implemented for more than 3 index dimensions"
)
@property
def center(self):
......@@ -374,12 +529,20 @@ class Field:
if self.index_dimensions == 0:
return sp.Matrix([self.__getitem__(offset)])
elif self.index_dimensions == 1:
return sp.Matrix([self.__getitem__(offset)(i) for i in range(self.index_shape[0])])
return sp.Matrix(
[self.__getitem__(offset)(i) for i in range(self.index_shape[0])]
)
elif self.index_dimensions == 2:
return sp.Matrix([[self.__getitem__(offset)(i, k) for k in range(self.index_shape[1])]
for i in range(self.index_shape[0])])
return sp.Matrix(
[
[self.__getitem__(offset)(i, k) for k in range(self.index_shape[1])]
for i in range(self.index_shape[0])
]
)
else:
raise NotImplementedError("neighbor_vector is not implemented for more than 2 index dimensions")
raise NotImplementedError(
"neighbor_vector is not implemented for more than 2 index dimensions"
)
def __getitem__(self, offset):
if type(offset) is np.ndarray:
......@@ -389,7 +552,9 @@ class Field:
if type(offset) is not tuple:
offset = (offset,)
if len(offset) != self.spatial_dimensions:
raise ValueError(f"Wrong number of spatial indices: Got {len(offset)}, expected {self.spatial_dimensions}")
raise ValueError(
f"Wrong number of spatial indices: Got {len(offset)}, expected {self.spatial_dimensions}"
)
return Field.Access(self, offset)
def absolute_access(self, offset, index):
......@@ -412,7 +577,9 @@ class Field:
offset = tuple(direction_string_to_offset(offset, self.spatial_dimensions))
offset = tuple([o * sp.Rational(1, 2) for o in offset])
if len(offset) != self.spatial_dimensions:
raise ValueError(f"Wrong number of spatial indices: Got {len(offset)}, expected {self.spatial_dimensions}")
raise ValueError(
f"Wrong number of spatial indices: Got {len(offset)}, expected {self.spatial_dimensions}"
)
prefactor = 1
neighbor_vec = [0] * len(offset)
......@@ -426,25 +593,33 @@ class Field:
if FieldType.is_staggered_flux(self):
prefactor = -1
if neighbor not in self.staggered_stencil:
raise ValueError(f"{offset_orig} is not a valid neighbor for the {self.staggered_stencil_name} stencil")
raise ValueError(
f"{offset_orig} is not a valid neighbor for the {self.staggered_stencil_name} stencil"
)
offset = tuple(sp.Matrix(offset) - sp.Rational(1, 2) * sp.Matrix(neighbor_vec))
idx = self.staggered_stencil.index(neighbor)
if self.index_dimensions == 1: # this field stores a scalar value at each staggered position
if (
self.index_dimensions == 1
): # this field stores a scalar value at each staggered position
if index is not None:
raise ValueError("Cannot specify an index for a scalar staggered field")
return prefactor * Field.Access(self, offset, (idx,))
else: # this field stores a vector or tensor at each staggered position
if index is None:
raise ValueError(f"Wrong number of indices: Got 0, expected {self.index_dimensions - 1}")
raise ValueError(
f"Wrong number of indices: Got 0, expected {self.index_dimensions - 1}"
)
if type(index) is np.ndarray:
index = tuple(index)
if type(index) is not tuple:
index = (index,)
if self.index_dimensions != len(index) + 1:
raise ValueError(f"Wrong number of indices: Got {len(index)}, expected {self.index_dimensions - 1}")
raise ValueError(
f"Wrong number of indices: Got {len(index)}, expected {self.index_dimensions - 1}"
)
return prefactor * Field.Access(self, offset, (idx, *index))
......@@ -455,30 +630,54 @@ class Field:
if self.index_dimensions == 1:
return sp.Matrix([self.staggered_access(offset)])
elif self.index_dimensions == 2:
return sp.Matrix([self.staggered_access(offset, i) for i in range(self.index_shape[1])])
return sp.Matrix(
[self.staggered_access(offset, i) for i in range(self.index_shape[1])]
)
elif self.index_dimensions == 3:
return sp.Matrix([[self.staggered_access(offset, (i, k)) for k in range(self.index_shape[2])]
for i in range(self.index_shape[1])])
return sp.Matrix(
[
[
self.staggered_access(offset, (i, k))
for k in range(self.index_shape[2])
]
for i in range(self.index_shape[1])
]
)
else:
raise NotImplementedError("staggered_vector_access is not implemented for more than 3 index dimensions")
raise NotImplementedError(
"staggered_vector_access is not implemented for more than 3 index dimensions"
)
@property
def staggered_stencil(self):
assert FieldType.is_staggered(self)
stencils = {
2: {
2: ["W", "S"], # D2Q5
4: ["W", "S", "SW", "NW"] # D2Q9
},
2: {2: ["W", "S"], 4: ["W", "S", "SW", "NW"]}, # D2Q5 # D2Q9
3: {
3: ["W", "S", "B"], # D3Q7
7: ["W", "S", "B", "BSW", "TSW", "BNW", "TNW"], # D3Q15
9: ["W", "S", "B", "SW", "NW", "BW", "TW", "BS", "TS"], # D3Q19
13: ["W", "S", "B", "SW", "NW", "BW", "TW", "BS", "TS", "BSW", "TSW", "BNW", "TNW"] # D3Q27
}
13: [
"W",
"S",
"B",
"SW",
"NW",
"BW",
"TW",
"BS",
"TS",
"BSW",
"TSW",
"BNW",
"TNW",
], # D3Q27
},
}
if not self.index_shape[0] in stencils[self.spatial_dimensions]:
raise ValueError(f"No known stencil has {self.index_shape[0]} staggered points")
raise ValueError(
f"No known stencil has {self.index_shape[0]} staggered points"
)
return stencils[self.spatial_dimensions][self.index_shape[0]]
@property
......@@ -491,13 +690,15 @@ class Field:
return Field.Access(self, center)(*args, **kwargs)
def hashable_contents(self):
return (self._layout,
self.shape,
self.strides,
self.field_type,
self._field_name,
self.latex_name,
self._dtype)
return (
self._layout,
self.shape,
self.strides,
self.field_type,
self._field_name,
self.latex_name,
self._dtype,
)
def __hash__(self):
return hash(self.hashable_contents())
......@@ -509,36 +710,53 @@ class Field:
@property
def physical_coordinates(self):
if hasattr(self.coordinate_transform, '__call__'):
return self.coordinate_transform(self.coordinate_origin + pystencils.x_vector(self.spatial_dimensions))
if hasattr(self.coordinate_transform, "__call__"):
return self.coordinate_transform(
self.coordinate_origin + x_vector(self.spatial_dimensions)
)
else:
return self.coordinate_transform @ (self.coordinate_origin + pystencils.x_vector(self.spatial_dimensions))
return self.coordinate_transform @ (
self.coordinate_origin + x_vector(self.spatial_dimensions)
)
@property
def physical_coordinates_staggered(self):
return self.coordinate_transform @ \
(self.coordinate_origin + pystencils.x_staggered_vector(self.spatial_dimensions))
return self.coordinate_transform @ (
self.coordinate_origin + x_staggered_vector(self.spatial_dimensions)
)
def index_to_physical(self, index_coordinates: sp.Matrix, staggered=False):
if staggered:
index_coordinates = sp.Matrix([0.5] * len(self.coordinate_origin)) + index_coordinates
if hasattr(self.coordinate_transform, '__call__'):
index_coordinates = (
sp.Matrix([0.5] * len(self.coordinate_origin)) + index_coordinates
)
if hasattr(self.coordinate_transform, "__call__"):
return self.coordinate_transform(self.coordinate_origin + index_coordinates)
else:
return self.coordinate_transform @ (self.coordinate_origin + index_coordinates)
return self.coordinate_transform @ (
self.coordinate_origin + index_coordinates
)
def physical_to_index(self, physical_coordinates: sp.Matrix, staggered=False):
if hasattr(self.coordinate_transform, '__call__'):
if hasattr(self.coordinate_transform, 'inv'):
return self.coordinate_transform.inv()(physical_coordinates) - self.coordinate_origin
if hasattr(self.coordinate_transform, "__call__"):
if hasattr(self.coordinate_transform, "inv"):
return (
self.coordinate_transform.inv()(physical_coordinates)
- self.coordinate_origin
)
else:
idx = sp.Matrix(sp.symbols(f'index_coordinates:{self.ndim}', real=True))
idx = sp.Matrix(sp.symbols(f"index_coordinates:{self.ndim}", real=True))
rtn = sp.solve(self.index_to_physical(idx) - physical_coordinates, idx)
assert rtn, f'Could not find inverese of coordinate_transform: {self.index_to_physical(idx)}'
assert (
rtn
), f"Could not find inverese of coordinate_transform: {self.index_to_physical(idx)}"
return rtn
else:
rtn = self.coordinate_transform.inv() @ physical_coordinates - self.coordinate_origin
rtn = (
self.coordinate_transform.inv() @ physical_coordinates
- self.coordinate_origin
)
if staggered:
rtn = sp.Matrix([i - 0.5 for i in rtn])
......@@ -567,16 +785,40 @@ class Field:
>>> central_y_component.at_index(0) # change component
v_C^0
"""
_iterable = False # see https://i10git.cs.fau.de/pycodegen/pystencils/-/merge_requests/166#note_10680
__match_args__ = ("field", "offsets", "index")
# for the type checker
_field: Field
_offsets: tuple[int | sp.Basic, ...]
_offsetName: str
_superscript: None | str
_index: tuple[int | sp.Basic, ...] | str
_indirect_addressing_fields: set[Field]
_is_absolute_access: bool
def __new__(cls, name, *args, **kwargs):
obj = Field.Access.__xnew_cached_(cls, name, *args, **kwargs)
return obj
def __new_stage2__(self, field, offsets=(0, 0, 0), idx=None, is_absolute_access=False, dtype=None):
def __new_stage2__( # type: ignore
self,
field: Field,
offsets: tuple[int, ...] = (0, 0, 0),
idx: None | tuple[int, ...] | str = None,
is_absolute_access: bool = False,
dtype: PsType | None = None,
):
field_name = field.name
offsets_and_index = (*offsets, *idx) if idx is not None else offsets
constant_offsets = not any([isinstance(o, sp.Basic) and not o.is_Integer for o in offsets_and_index])
constant_offsets = not any(
[
isinstance(o, sp.Basic) and not o.is_Integer
for o in offsets_and_index
]
)
if not idx:
idx = tuple([0] * field.index_dimensions)
......@@ -590,31 +832,36 @@ class Field:
else:
idx_str = ",".join([str(e) for e in idx])
superscript = idx_str
if field.has_fixed_index_shape and not isinstance(field.dtype, StructType):
if field.has_fixed_index_shape and not isinstance(
field.dtype, PsStructType
):
for i, bound in zip(idx, field.index_shape):
if i >= bound:
raise ValueError("Field index out of bounds")
else:
offset_name = hashlib.md5(pickle.dumps(offsets_and_index)).hexdigest()[:12]
offset_name = hashlib.md5(pickle.dumps(offsets_and_index)).hexdigest()[
:12
]
superscript = None
symbol_name = f"{field_name}_{offset_name}"
if superscript is not None:
symbol_name += "^" + superscript
if dtype:
obj: Field.Access
if dtype is not None:
obj = super(Field.Access, self).__xnew__(self, symbol_name, dtype)
else:
obj = super(Field.Access, self).__xnew__(self, symbol_name, field.dtype)
obj._field = field
obj._offsets = []
_offsets: list[sp.Basic | int] = []
for o in offsets:
if isinstance(o, sp.Basic):
obj._offsets.append(o)
_offsets.append(o)
else:
obj._offsets.append(int(o))
obj._offsets = tuple(sp.sympify(obj._offsets))
_offsets.append(int(o))
obj._offsets = tuple(sp.sympify(_offsets))
obj._offsetName = offset_name
obj._superscript = superscript
obj._index = idx
......@@ -622,19 +869,33 @@ class Field:
obj._indirect_addressing_fields = set()
for e in chain(obj._offsets, obj._index):
if isinstance(e, sp.Basic):
obj._indirect_addressing_fields.update(a.field for a in e.atoms(Field.Access))
obj._indirect_addressing_fields.update(
a.field for a in e.atoms(Field.Access)
)
obj._is_absolute_access = is_absolute_access
return obj
def __getnewargs__(self):
return self.field, self.offsets, self.index, self.is_absolute_access, self.dtype
return (
self.field,
self.offsets,
self.index,
self.is_absolute_access,
self.dtype,
)
def __getnewargs_ex__(self):
return (self.field, self.offsets, self.index, self.is_absolute_access, self.dtype), {}
return (
self.field,
self.offsets,
self.index,
self.is_absolute_access,
self.dtype,
), {}
# noinspection SpellCheckingInspection
__xnew__ = staticmethod(__new_stage2__)
__xnew__ = staticmethod(__new_stage2__) # type: ignore
# noinspection SpellCheckingInspection
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
......@@ -648,20 +909,34 @@ class Field:
idx = ()
if len(idx) != self.field.index_dimensions:
raise ValueError(f"Wrong number of indices: Got {len(idx)}, expected {self.field.index_dimensions}")
raise ValueError(
f"Wrong number of indices: Got {len(idx)}, expected {self.field.index_dimensions}"
)
if len(idx) == 1 and isinstance(idx[0], str):
dtype = BasicType(self.field.dtype.numpy_dtype[idx[0]])
return Field.Access(self.field, self._offsets, idx,
is_absolute_access=self.is_absolute_access, dtype=dtype)
struct_type = self.field.dtype
assert isinstance(struct_type, PsStructType)
dtype = struct_type.get_member(idx[0]).dtype
return Field.Access(
self.field,
self._offsets,
idx,
is_absolute_access=self.is_absolute_access,
dtype=dtype,
)
else:
return Field.Access(self.field, self._offsets, idx,
is_absolute_access=self.is_absolute_access, dtype=self.dtype)
return Field.Access(
self.field,
self._offsets,
idx,
is_absolute_access=self.is_absolute_access,
dtype=self.dtype,
)
def __getitem__(self, *idx):
return self.__call__(*idx)
@property
def field(self) -> 'Field':
def field(self) -> "Field":
"""Field that the Access points to"""
return self._field
......@@ -673,7 +948,7 @@ class Field:
@property
def required_ghost_layers(self) -> int:
"""Largest spatial distance that is accessed."""
return int(np.max(np.abs(self._offsets)))
return int(np.max(np.abs(self._offsets))) # type: ignore
@property
def nr_of_coordinates(self):
......@@ -695,7 +970,7 @@ class Field:
"""Value of index coordinates as tuple."""
return self._index
def neighbor(self, coord_id: int, offset: int) -> 'Field.Access':
def neighbor(self, coord_id: int, offset: int) -> "Field.Access":
"""Returns a new Access with changed spatial coordinates.
Args:
......@@ -709,10 +984,15 @@ class Field:
"""
offset_list = list(self.offsets)
offset_list[coord_id] += offset
return Field.Access(self.field, tuple(offset_list), self.index,
is_absolute_access=self.is_absolute_access, dtype=self.dtype)
def get_shifted(self, *shift) -> 'Field.Access':
return Field.Access(
self.field,
tuple(offset_list),
self.index,
is_absolute_access=self.is_absolute_access,
dtype=self.dtype,
)
def get_shifted(self, *shift) -> "Field.Access":
"""Returns a new Access with changed spatial coordinates
Example:
......@@ -720,13 +1000,15 @@ class Field:
>>> f[0,0].get_shifted(1, 1)
f_NE
"""
return Field.Access(self.field,
tuple(a + b for a, b in zip(shift, self.offsets)),
self.index,
is_absolute_access=self.is_absolute_access,
dtype=self.dtype)
def at_index(self, *idx_tuple) -> 'Field.Access':
return Field.Access(
self.field,
tuple(a + b for a, b in zip(shift, self.offsets)),
self.index,
is_absolute_access=self.is_absolute_access,
dtype=self.dtype,
)
def at_index(self, *idx_tuple) -> "Field.Access":
"""Returns new Access with changed index.
Example:
......@@ -734,15 +1016,22 @@ class Field:
>>> f(0).at_index(8)
f_C^8
"""
return Field.Access(self.field, self.offsets, idx_tuple,
is_absolute_access=self.is_absolute_access, dtype=self.dtype)
return Field.Access(
self.field,
self.offsets,
idx_tuple,
is_absolute_access=self.is_absolute_access,
dtype=self.dtype,
)
def _eval_subs(self, old, new):
return Field.Access(self.field,
tuple(sp.sympify(a).subs(old, new) for a in self.offsets),
tuple(sp.sympify(a).subs(old, new) for a in self.index),
is_absolute_access=self.is_absolute_access,
dtype=self.dtype)
return Field.Access(
self.field,
tuple(sp.sympify(a).subs(old, new) for a in self.offsets),
tuple(sp.sympify(a).subs(old, new) for a in self.index),
is_absolute_access=self.is_absolute_access,
dtype=self.dtype,
)
@property
def is_absolute_access(self) -> bool:
......@@ -750,30 +1039,43 @@ class Field:
return self._is_absolute_access
@property
def indirect_addressing_fields(self) -> Set['Field']:
def indirect_addressing_fields(self) -> Set["Field"]:
"""Returns a set of fields that the access depends on.
e.g. f[index_field[1, 0]], the outer access to f depends on index_field
"""
e.g. f[index_field[1, 0]], the outer access to f depends on index_field
"""
return self._indirect_addressing_fields
def _hashable_content(self):
super_class_contents = super(Field.Access, self)._hashable_content()
return (super_class_contents, self._field.hashable_contents(), *self._index,
*self._offsets, self._is_absolute_access)
return (
super_class_contents,
self._field.hashable_contents(),
*self._index,
*self._offsets,
self._is_absolute_access,
)
def _staggered_offset(self, offsets, index):
assert FieldType.is_staggered(self._field)
neighbor = self._field.staggered_stencil[index]
neighbor = direction_string_to_offset(neighbor, self._field.spatial_dimensions)
return [(o + sp.Rational(int(neighbor[i]), 2)) for i, o in enumerate(offsets)]
neighbor = direction_string_to_offset(
neighbor, self._field.spatial_dimensions
)
return [
(o + sp.Rational(int(neighbor[i]), 2)) for i, o in enumerate(offsets)
]
def _latex(self, _):
n = self._field.latex_name if self._field.latex_name else self._field.name
offset_str = ",".join([sp.latex(o) for o in self.offsets])
if FieldType.is_staggered(self._field):
offset_str = ",".join([sp.latex(self._staggered_offset(self.offsets, self.index[0])[i])
for i in range(len(self.offsets))])
offset_str = ",".join(
[
sp.latex(self._staggered_offset(self.offsets, self.index[0])[i])
for i in range(len(self.offsets))
]
)
if self.is_absolute_access:
offset_str = f"\\mathbf{offset_str}"
elif self.field.spatial_dimensions > 1:
......@@ -794,8 +1096,12 @@ class Field:
n = self._field.latex_name if self._field.latex_name else self._field.name
offset_str = ",".join([sp.latex(o) for o in self.offsets])
if FieldType.is_staggered(self._field):
offset_str = ",".join([sp.latex(self._staggered_offset(self.offsets, self.index[0])[i])
for i in range(len(self.offsets))])
offset_str = ",".join(
[
sp.latex(self._staggered_offset(self.offsets, self.index[0])[i])
for i in range(len(self.offsets))
]
)
if self.is_absolute_access:
offset_str = f"[abs]{offset_str}"
......@@ -811,12 +1117,36 @@ class Field:
return f"{n}[{offset_str}]"
def fields(description=None, index_dimensions=0, layout=None,
field_type=FieldType.GENERIC, **kwargs) -> Union[Field, List[Field]]:
def fields(
description=None,
index_dimensions=0,
layout=None,
field_type=FieldType.GENERIC,
**kwargs,
) -> Field | list[Field]:
"""Creates pystencils fields from a string description.
The description must be a string of the form
``"name(index-shape) [name(index-shape) ...]: <data-type>[<dimension-or-shape>]"``,
where:
- ``name`` is the name of the respective field
- ``(index-shape)`` is a tuple of integers describing the shape of the tensor on each field node
(can be omitted for scalar fields)
- ``<data-type>`` is the numerical data type of the field's entries;
this can be any type parseable by `create_type`,
as well as ``dyn`` for `DynamicType.NUMERIC_TYPE`
and ``dynidx`` for `DynamicType.INDEX_TYPE`.
- ``<dimension-or-shape>`` can be a dimensionality (e.g. ``1D``, ``2D``, ``3D``)
or a tuple of integers defining the spatial shape of the field.
Examples:
Create a 2D scalar and vector field:
Create a 3D scalar field of default numeric type:
>>> f = fields("f(1): [2D]")
>>> str(f.dtype)
'DynamicType.NUMERIC_TYPE'
Create a 2D scalar and vector field of 64-bit float type:
>>> s, v = fields("s, v(2): double[2D]")
>>> assert s.spatial_dimensions == 2 and s.index_dimensions == 0
>>> assert (v.spatial_dimensions, v.index_dimensions, v.index_shape) == (2, 1, (2,))
......@@ -834,7 +1164,7 @@ def fields(description=None, index_dimensions=0, layout=None,
Format string can be left out, field names are taken from keyword arguments.
>>> fields(f1=arr_s, f2=arr_s)
[f1: double[20,20], f2: double[20,20]]
[f1: float64[20,20], f2: float64[20,20]]
The keyword names ``index_dimension`` and ``layout`` have special meaning, don't use them for field names
>>> f = fields(f=arr_v, index_dimensions=1)
......@@ -842,35 +1172,70 @@ def fields(description=None, index_dimensions=0, layout=None,
>>> f = fields("pdfs(19) : float32[3D]", layout='fzyx')
>>> f.layout
(2, 1, 0)
Returns:
Sequence of fields created from the description
"""
result = []
if description:
field_descriptions, dtype, shape = _parse_description(description)
layout = 'numpy' if layout is None else layout
layout = "numpy" if layout is None else layout
for field_name, idx_shape in field_descriptions:
if field_name in kwargs:
arr = kwargs[field_name]
idx_shape_of_arr = () if not len(idx_shape) else arr.shape[-len(idx_shape):]
idx_shape_of_arr = (
() if not len(idx_shape) else arr.shape[-len(idx_shape):]
)
assert idx_shape_of_arr == idx_shape
f = Field.create_from_numpy_array(field_name, kwargs[field_name], index_dimensions=len(idx_shape),
field_type=field_type)
f = Field.create_from_numpy_array(
field_name,
kwargs[field_name],
index_dimensions=len(idx_shape),
field_type=field_type,
)
elif isinstance(shape, tuple):
f = Field.create_fixed_size(field_name, shape + idx_shape, dtype=dtype,
index_dimensions=len(idx_shape), layout=layout, field_type=field_type)
f = Field.create_fixed_size(
field_name,
shape + idx_shape,
dtype=dtype,
index_dimensions=len(idx_shape),
layout=layout,
field_type=field_type,
)
elif isinstance(shape, int):
f = Field.create_generic(field_name, spatial_dimensions=shape, dtype=dtype,
index_shape=idx_shape, layout=layout, field_type=field_type)
f = Field.create_generic(
field_name,
spatial_dimensions=shape,
dtype=dtype,
index_shape=idx_shape,
layout=layout,
field_type=field_type,
)
elif shape is None:
f = Field.create_generic(field_name, spatial_dimensions=2, dtype=dtype,
index_shape=idx_shape, layout=layout, field_type=field_type)
f = Field.create_generic(
field_name,
spatial_dimensions=2,
dtype=dtype,
index_shape=idx_shape,
layout=layout,
field_type=field_type,
)
else:
assert False
result.append(f)
else:
assert layout is None, "Layout can not be specified when creating Field from numpy array"
assert (
layout is None
), "Layout can not be specified when creating Field from numpy array"
for field_name, arr in kwargs.items():
result.append(Field.create_from_numpy_array(field_name, arr, index_dimensions=index_dimensions,
field_type=field_type))
result.append(
Field.create_from_numpy_array(
field_name,
arr,
index_dimensions=index_dimensions,
field_type=field_type,
)
)
if len(result) == 0:
raise ValueError("Could not parse field description")
......@@ -880,16 +1245,27 @@ def fields(description=None, index_dimensions=0, layout=None,
return result
def get_layout_from_strides(strides: Sequence[int], index_dimension_ids: Optional[List[int]] = None):
def get_layout_from_strides(
strides: Sequence[int], index_dimension_ids: Optional[List[int]] = None
):
index_dimension_ids = [] if index_dimension_ids is None else index_dimension_ids
coordinates = list(range(len(strides)))
relevant_strides = [stride for i, stride in enumerate(strides) if i not in index_dimension_ids]
result = [x for (y, x) in sorted(zip(relevant_strides, coordinates), key=lambda pair: pair[0], reverse=True)]
relevant_strides = [
stride for i, stride in enumerate(strides) if i not in index_dimension_ids
]
result = [
x
for (y, x) in sorted(
zip(relevant_strides, coordinates), key=lambda pair: pair[0], reverse=True
)
]
return normalize_layout(result)
def get_layout_of_array(arr: np.ndarray, index_dimension_ids: Optional[List[int]] = None):
""" Returns a list indicating the memory layout (linearization order) of the numpy array.
def get_layout_of_array(
arr: np.ndarray, index_dimension_ids: Optional[List[int]] = None
):
"""Returns a list indicating the memory layout (linearization order) of the numpy array.
Examples:
>>> get_layout_of_array(np.zeros([3,3,3]))
......@@ -906,7 +1282,9 @@ def get_layout_of_array(arr: np.ndarray, index_dimension_ids: Optional[List[int]
return get_layout_from_strides(arr.strides, index_dimension_ids)
def create_numpy_array_with_layout(shape, layout, alignment=False, byte_offset=0, **kwargs):
def create_numpy_array_with_layout(
shape, layout, alignment=False, byte_offset=0, **kwargs
):
"""Creates numpy array with given memory layout.
Args:
......@@ -930,7 +1308,10 @@ def create_numpy_array_with_layout(shape, layout, alignment=False, byte_offset=0
if cur_layout[i] != layout[i]:
index_to_swap_with = cur_layout.index(layout[i])
swaps.append((i, index_to_swap_with))
cur_layout[i], cur_layout[index_to_swap_with] = cur_layout[index_to_swap_with], cur_layout[i]
cur_layout[i], cur_layout[index_to_swap_with] = (
cur_layout[index_to_swap_with],
cur_layout[i],
)
assert tuple(cur_layout) == tuple(layout)
shape = list(shape)
......@@ -938,7 +1319,7 @@ def create_numpy_array_with_layout(shape, layout, alignment=False, byte_offset=0
shape[a], shape[b] = shape[b], shape[a]
if not alignment:
res = np.empty(shape, order='c', **kwargs)
res = np.empty(shape, order="c", **kwargs)
else:
res = aligned_empty(shape, alignment, byte_offset=byte_offset, **kwargs)
......@@ -948,28 +1329,45 @@ def create_numpy_array_with_layout(shape, layout, alignment=False, byte_offset=0
def spatial_layout_string_to_tuple(layout_str: str, dim: int) -> Tuple[int, ...]:
if layout_str in ('fzyx', 'zyxf'):
assert dim <= 3
if dim <= 0:
raise ValueError("Dimensionality must be positive")
layout_str = layout_str.lower()
if layout_str in ("fzyx", "zyxf", "soa", "aos"):
if dim > 3:
raise ValueError(
f"Invalid spatial dimensionality for layout descriptor {layout_str}: May be at most 3."
)
return tuple(reversed(range(dim)))
if layout_str in ('fzyx', 'f', 'reverse_numpy', 'SoA'):
if layout_str in ("f", "reverse_numpy"):
return tuple(reversed(range(dim)))
elif layout_str in ('c', 'numpy', 'AoS'):
elif layout_str in ("c", "numpy"):
return tuple(range(dim))
raise ValueError("Unknown layout descriptor " + layout_str)
def layout_string_to_tuple(layout_str, dim):
def layout_string_to_tuple(layout_str, dim) -> tuple[int, ...]:
if dim <= 0:
raise ValueError("Dimensionality must be positive")
layout_str = layout_str.lower()
if layout_str == 'fzyx' or layout_str == 'soa':
assert dim <= 4
if layout_str == "fzyx" or layout_str == "soa":
if dim > 4:
raise ValueError(
f"Invalid total dimensionality for layout descriptor {layout_str}: May be at most 4."
)
return tuple(reversed(range(dim)))
elif layout_str == 'zyxf' or layout_str == 'aos':
assert dim <= 4
elif layout_str == "zyxf" or layout_str == "aos":
if dim > 4:
raise ValueError(
f"Invalid total dimensionality for layout descriptor {layout_str}: May be at most 4."
)
return tuple(reversed(range(dim - 1))) + (dim - 1,)
elif layout_str == 'f' or layout_str == 'reverse_numpy':
elif layout_str == "f" or layout_str == "reverse_numpy":
return tuple(reversed(range(dim)))
elif layout_str == 'c' or layout_str == 'numpy':
elif layout_str == "c" or layout_str == "numpy":
return tuple(range(dim))
raise ValueError("Unknown layout descriptor " + layout_str)
......@@ -1004,7 +1402,8 @@ def compute_strides(shape, layout):
# ---------------------------------------- Parsing of string in fields() function --------------------------------------
field_description_regex = re.compile(r"""
field_description_regex = re.compile(
r"""
\s* # ignore leading white spaces
(\w+) # identifier is a sequence of alphanumeric characters, is stored in first group
(?: # optional index specification e.g. (1, 4, 2)
......@@ -1015,9 +1414,12 @@ field_description_regex = re.compile(r"""
\s*
)?
\s*,?\s* # ignore trailing white spaces and comma
""", re.VERBOSE)
""",
re.VERBOSE,
)
type_description_regex = re.compile(r"""
type_description_regex = re.compile(
r"""
\s*
(\w+)? # optional dtype
\s*
......@@ -1025,7 +1427,9 @@ type_description_regex = re.compile(r"""
([^\]]+)
\]
\s*
""", re.VERBOSE | re.IGNORECASE)
""",
re.VERBOSE | re.IGNORECASE,
)
def _parse_part1(d):
......@@ -1043,24 +1447,30 @@ def _parse_description(description):
result = type_description_regex.match(d)
if result:
data_type_str, size_info = result.group(1), result.group(2).strip().lower()
if data_type_str is None:
data_type_str = 'float64'
data_type_str = data_type_str.lower().strip()
if data_type_str is not None:
data_type_str = data_type_str.lower().strip()
if data_type_str:
match data_type_str:
case "dyn": dtype = DynamicType.NUMERIC_TYPE
case "dynidx": dtype = DynamicType.INDEX_TYPE
case _: dtype = create_type(data_type_str)
else:
dtype = DynamicType.NUMERIC_TYPE
if not data_type_str:
data_type_str = 'float64'
if size_info.endswith('d'):
if size_info.endswith("d"):
size_info = int(size_info[:-1])
else:
size_info = tuple(int(e) for e in size_info.split(","))
return data_type_str, size_info
return dtype, size_info
else:
raise ValueError("Could not parse field description")
if ':' in description:
field_description, field_info = description.split(':')
if ":" in description:
field_description, field_info = description.split(":")
else:
field_description, field_info = description, 'float64[2D]'
field_description, field_info = description, "float64[2D]"
fields_info = [e for e in _parse_part1(field_description)]
if not field_info:
......
from pystencils.gpu.gpu_array_handler import GPUArrayHandler, GPUNotAvailableHandler
from pystencils.gpu.gpujit import make_python_function
from pystencils.gpu.kernelcreation import create_cuda_kernel, created_indexed_cuda_kernel
from .indexing import AbstractIndexing, BlockIndexing, LineIndexing
__all__ = ['GPUArrayHandler', 'GPUNotAvailableHandler',
'create_cuda_kernel', 'created_indexed_cuda_kernel', 'make_python_function',
'AbstractIndexing', 'BlockIndexing', 'LineIndexing']
......@@ -2,48 +2,73 @@ import numpy as np
from itertools import product
from pystencils import CreateKernelConfig, create_kernel
from pystencils.gpu import make_python_function
from pystencils import Assignment, Field
from pystencils.enums import Target
from pystencils import Assignment, Field, Target
from pystencils.slicing import get_periodic_boundary_src_dst_slices, normalize_slice
def create_copy_kernel(domain_size, from_slice, to_slice, index_dimensions=0, index_dim_shape=1, dtype=np.float64):
def create_copy_kernel(
domain_size,
src_slice,
dst_slice,
index_dimensions=0,
index_dim_shape=1,
dtype=np.float64,
):
"""Copies a rectangular part of an array to another non-overlapping part"""
f = Field.create_generic("pdfs", len(domain_size), index_dimensions=index_dimensions, dtype=dtype)
normalized_from_slice = normalize_slice(from_slice, f.spatial_shape)
normalized_to_slice = normalize_slice(to_slice, f.spatial_shape)
field = Field.create_generic(
"field", len(domain_size), index_dimensions=index_dimensions, dtype=dtype
)
normalized_src_slice = normalize_slice(src_slice, field.spatial_shape)
normalized_dst_slice = normalize_slice(dst_slice, field.spatial_shape)
offset = [s1.start - s2.start for s1, s2 in zip(normalized_from_slice, normalized_to_slice)]
assert offset == [s1.stop - s2.stop for s1, s2 in zip(normalized_from_slice, normalized_to_slice)], \
"Slices have to have same size"
offset = [
s1.start - s2.start
for s1, s2 in zip(normalized_src_slice, normalized_dst_slice)
]
assert offset == [
s1.stop - s2.stop for s1, s2 in zip(normalized_src_slice, normalized_dst_slice)
], "Slices have to have same size"
update_eqs = []
if index_dimensions < 2:
index_dim_shape = [index_dim_shape]
for i in product(*[range(d) for d in index_dim_shape]):
eq = Assignment(f(*i), f[tuple(offset)](*i))
eq = Assignment(field(*i), field[tuple(offset)](*i))
update_eqs.append(eq)
config = CreateKernelConfig(target=Target.GPU, iteration_slice=to_slice, skip_independence_check=True)
config = CreateKernelConfig(
target=Target.GPU, iteration_slice=dst_slice, skip_independence_check=True
)
ast = create_kernel(update_eqs, config=config)
return ast
def get_periodic_boundary_functor(stencil, domain_size, index_dimensions=0, index_dim_shape=1, ghost_layers=1,
thickness=None, dtype=np.float64, target=Target.GPU):
def get_periodic_boundary_functor(
stencil,
domain_size,
index_dimensions=0,
index_dim_shape=1,
ghost_layers=1,
thickness=None,
dtype=np.float64,
target=Target.GPU,
):
assert target in {Target.GPU}
src_dst_slice_tuples = get_periodic_boundary_src_dst_slices(stencil, ghost_layers, thickness)
src_dst_slice_tuples = get_periodic_boundary_src_dst_slices(
stencil, ghost_layers, thickness
)
kernels = []
for src_slice, dst_slice in src_dst_slice_tuples:
ast = create_copy_kernel(domain_size, src_slice, dst_slice, index_dimensions, index_dim_shape, dtype)
kernels.append(make_python_function(ast))
ast = create_copy_kernel(
domain_size, src_slice, dst_slice, index_dimensions, index_dim_shape, dtype
)
kernels.append(ast.compile())
def functor(pdfs, **_):
def functor(field, **_):
for kernel in kernels:
kernel(pdfs=pdfs)
kernel(field=field)
return functor
from typing import overload
from .backend.ast import PsAstNode
from .backend.emission import CAstPrinter, IRAstPrinter, EmissionError
from .codegen import Kernel
from .codegen.driver import StageResult, CodegenIntermediates
from abc import ABC, abstractmethod
_UNABLE_TO_DISPLAY_CPP = """
<div>
<b>Unable to display C code for this abstract syntax tree</b>
<p>
This intermediate abstract syntax tree contains nodes that cannot be
printed as valid C code.
</p>
</div>
"""
_GRAPHVIZ_NOT_IMPLEMENTED = """
<div>
<b>AST Visualization Unavailable</b>
<p>
AST visualization using GraphViz is not implemented yet.
</p>
</div>
"""
_ERR_MSG = """
<div style="font-family: monospace; background-color: #EEEEEE; white-space: nowrap; overflow-x: scroll">
{}
</div>
"""
class CodeInspectionBase(ABC):
def __init__(self) -> None:
self._ir_printer = IRAstPrinter(annotate_constants=False)
self._c_printer = CAstPrinter()
def _ir_tab(self, ir_obj: PsAstNode | Kernel):
import ipywidgets as widgets
ir = self._ir_printer(ir_obj)
ir_tab = widgets.HTML(self._highlight_as_cpp(ir))
self._apply_tab_layout(ir_tab)
return ir_tab
def _cpp_tab(self, ir_obj: PsAstNode | Kernel):
import ipywidgets as widgets
try:
cpp = self._c_printer(ir_obj)
cpp_tab = widgets.HTML(self._highlight_as_cpp(cpp))
except EmissionError as e:
cpp_tab = widgets.VBox(
children=[
widgets.HTML(_UNABLE_TO_DISPLAY_CPP),
widgets.Accordion(
children=[widgets.HTML(_ERR_MSG.format(e.args[0]))],
titles=["Error Details"],
),
]
)
self._apply_tab_layout(cpp_tab)
return cpp_tab
def _graphviz_tab(self, ir_obj: PsAstNode | Kernel):
import ipywidgets as widgets
graphviz_tab = widgets.HTML(_GRAPHVIZ_NOT_IMPLEMENTED)
self._apply_tab_layout(graphviz_tab)
return graphviz_tab
def _apply_tab_layout(self, tab):
tab.layout.display = "inline-block"
tab.layout.padding = "0 15pt 0 0"
def _highlight_as_cpp(self, code: str) -> str:
from pygments import highlight
from pygments.formatters import HtmlFormatter
from pygments.lexers import CppLexer
formatter = HtmlFormatter(
prestyles="white-space: pre;",
)
html_code = highlight(code, CppLexer(), formatter)
return html_code
def _ipython_display_(self):
from IPython.display import display
display(self._widget())
@abstractmethod
def _widget(self): ...
class AstInspection(CodeInspectionBase):
"""Inspect an abstract syntax tree produced by the code generation backend.
**Interactive:** This class can be used in Jupyter notebooks to interactively
explore an abstract syntax tree.
"""
def __init__(
self,
ast: PsAstNode,
show_ir: bool = True,
show_cpp: bool = True,
show_graph: bool = True,
):
super().__init__()
self._ast = ast
self._show_ir = show_ir
self._show_cpp = show_cpp
self._show_graph = show_graph
def _widget(self):
import ipywidgets as widgets
tabs = []
if self._show_ir:
tabs.append(self._ir_tab(self._ast))
if self._show_cpp:
tabs.append(self._cpp_tab(self._ast))
if self._show_graph:
tabs.append(self._graphviz_tab(self._ast))
tabs = widgets.Tab(children=tabs)
tabs.titles = ["IR Code", "C Code", "AST Visualization"]
tabs.layout.height = "250pt"
return tabs
class KernelInspection(CodeInspectionBase):
def __init__(
self,
kernel: Kernel,
show_ir: bool = True,
show_cpp: bool = True,
show_graph: bool = True,
) -> None:
super().__init__()
self._kernel = kernel
self._show_ir = show_ir
self._show_cpp = show_cpp
self._show_graph = show_graph
def _widget(self):
import ipywidgets as widgets
tabs = []
if self._show_ir:
tabs.append(self._ir_tab(self._kernel))
if self._show_cpp:
tabs.append(self._cpp_tab(self._kernel))
if self._show_graph:
tabs.append(self._graphviz_tab(self._kernel))
tabs = widgets.Tab(children=tabs)
tabs.titles = ["IR Code", "C Code", "AST Visualization"]
tabs.layout.height = "250pt"
return tabs
class IntermediatesInspection:
def __init__(
self,
intermediates: CodegenIntermediates,
show_ir: bool = True,
show_cpp: bool = True,
show_graph: bool = True,
):
self._intermediates = intermediates
self._show_ir = show_ir
self._show_cpp = show_cpp
self._show_graph = show_graph
def _ipython_display_(self):
from IPython.display import display
import ipywidgets as widgets
stages = self._intermediates.available_stages
previews: list[AstInspection] = [
AstInspection(
stage.ast,
show_ir=self._show_ir,
show_cpp=self._show_cpp,
show_graph=self._show_graph,
)
for stage in stages
]
labels: list[str] = [stage.label for stage in stages]
code_views = [p._widget() for p in previews]
for v in code_views:
v.layout.width = "100%"
select_label = widgets.HTML("<div><b>Code Generator Stages</b></div>")
select = widgets.Select(options=labels)
select.layout.height = "250pt"
selection_box = widgets.VBox([select_label, select])
selection_box.layout.overflow = "visible"
preview_label = widgets.HTML("<div><b>Preview</b></div>")
preview_stack = widgets.Stack(children=code_views)
preview_stack.layout.overflow = "hidden"
preview_box = widgets.VBox([preview_label, preview_stack])
widgets.jslink((select, "index"), (preview_stack, "selected_index"))
grid = widgets.GridBox(
[selection_box, preview_box],
layout=widgets.Layout(grid_template_columns="max-content auto"),
)
display(grid)
@overload
def inspect(obj: PsAstNode): ...
@overload
def inspect(obj: Kernel): ...
@overload
def inspect(obj: StageResult): ...
@overload
def inspect(obj: CodegenIntermediates): ...
def inspect(obj, show_ir: bool = True, show_cpp: bool = True, show_graph: bool = True):
"""Interactively inspect various products of the code generator.
When run inside a Jupyter notebook, this function displays an inspection widget
for the following types of objects:
- `PsAstNode`
- `Kernel`
- `StageResult`
- `CodegenIntermediates`
"""
from IPython.display import display
match obj:
case PsAstNode():
preview = AstInspection(
obj, show_ir=show_ir, show_cpp=show_cpp, show_graph=show_cpp
)
case Kernel():
preview = KernelInspection(
obj, show_ir=show_ir, show_cpp=show_cpp, show_graph=show_cpp
)
case StageResult(ast, _):
preview = AstInspection(
ast, show_ir=show_ir, show_cpp=show_cpp, show_graph=show_cpp
)
case CodegenIntermediates():
preview = IntermediatesInspection(
obj, show_ir=show_ir, show_cpp=show_cpp, show_graph=show_cpp
)
case _:
raise ValueError(f"Cannot inspect object of type {type(obj)}")
display(preview)
# TODO #47 move to a module functions
import numpy as np
import sympy as sp
from pystencils.typing import CastFunc, collate_types, create_type, get_type_of_expression
from pystencils.sympyextensions import is_integer_sequence
class IntegerFunctionTwoArgsMixIn(sp.Function):
is_integer = True
def __new__(cls, arg1, arg2):
args = []
for a in (arg1, arg2):
if isinstance(a, sp.Number) or isinstance(a, int):
args.append(CastFunc(a, create_type("int")))
elif isinstance(a, np.generic):
args.append(CastFunc(a, a.dtype))
else:
args.append(a)
for a in args:
try:
dtype = get_type_of_expression(a)
if not dtype.is_int():
raise ValueError("Argument to integer function is not an int but " + str(dtype))
except NotImplementedError:
raise ValueError("Integer functions can only be constructed with typed expressions")
return super().__new__(cls, *args)
def _eval_evalf(self, *pargs, **kwargs):
arg1 = self.args[0].evalf(*pargs, **kwargs) if hasattr(self.args[0], 'evalf') else self.args[0]
arg2 = self.args[1].evalf(*pargs, **kwargs) if hasattr(self.args[1], 'evalf') else self.args[1]
return self._eval_op(arg1, arg2)
def _eval_op(self, arg1, arg2):
return self
# noinspection PyPep8Naming
class bitwise_xor(IntegerFunctionTwoArgsMixIn):
pass
# noinspection PyPep8Naming
class bit_shift_right(IntegerFunctionTwoArgsMixIn):
pass
# noinspection PyPep8Naming
class bit_shift_left(IntegerFunctionTwoArgsMixIn):
pass
# noinspection PyPep8Naming
class bitwise_and(IntegerFunctionTwoArgsMixIn):
pass
# noinspection PyPep8Naming
class bitwise_or(IntegerFunctionTwoArgsMixIn):
pass
# noinspection PyPep8Naming
class int_div(IntegerFunctionTwoArgsMixIn):
def _eval_op(self, arg1, arg2):
return int(arg1 // arg2)
# noinspection PyPep8Naming
class int_power_of_2(IntegerFunctionTwoArgsMixIn):
pass
# noinspection PyPep8Naming
class modulo_floor(sp.Function):
"""Returns the next smaller integer divisible by given divisor.
Examples:
>>> modulo_floor(9, 4)
8
>>> modulo_floor(11, 4)
8
>>> modulo_floor(12, 4)
12
>>> from pystencils import TypedSymbol
>>> a, b = TypedSymbol("a", "int64"), TypedSymbol("b", "int32")
>>> modulo_floor(a, b).to_c(str)
'(int64_t)((a) / (b)) * (b)'
"""
nargs = 2
is_integer = True
def __new__(cls, integer, divisor):
if is_integer_sequence((integer, divisor)):
return (int(integer) // int(divisor)) * divisor
else:
return super().__new__(cls, integer, divisor)
def to_c(self, print_func):
dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1])))
assert dtype.is_int()
return "({dtype})(({0}) / ({1})) * ({1})".format(print_func(self.args[0]),
print_func(self.args[1]), dtype=dtype)
# noinspection PyPep8Naming
class modulo_ceil(sp.Function):
"""Returns the next bigger integer divisible by given divisor.
Examples:
>>> modulo_ceil(9, 4)
12
>>> modulo_ceil(11, 4)
12
>>> modulo_ceil(12, 4)
12
>>> from pystencils import TypedSymbol
>>> a, b = TypedSymbol("a", "int64"), TypedSymbol("b", "int32")
>>> modulo_ceil(a, b).to_c(str)
'((a) % (b) == 0 ? a : ((int64_t)((a) / (b))+1) * (b))'
"""
nargs = 2
is_integer = True
def __new__(cls, integer, divisor):
if is_integer_sequence((integer, divisor)):
return integer if integer % divisor == 0 else ((integer // divisor) + 1) * divisor
else:
return super().__new__(cls, integer, divisor)
def to_c(self, print_func):
dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1])))
assert dtype.is_int()
code = "(({0}) % ({1}) == 0 ? {0} : (({dtype})(({0}) / ({1}))+1) * ({1}))"
return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype)
# noinspection PyPep8Naming
class div_ceil(sp.Function):
"""Integer division that is always rounded up
Examples:
>>> div_ceil(9, 4)
3
>>> div_ceil(8, 4)
2
>>> from pystencils import TypedSymbol
>>> a, b = TypedSymbol("a", "int64"), TypedSymbol("b", "int32")
>>> div_ceil(a, b).to_c(str)
'( (a) % (b) == 0 ? (int64_t)(a) / (int64_t)(b) : ( (int64_t)(a) / (int64_t)(b) ) +1 )'
"""
nargs = 2
is_integer = True
def __new__(cls, integer, divisor):
if is_integer_sequence((integer, divisor)):
return integer // divisor if integer % divisor == 0 else (integer // divisor) + 1
else:
return super().__new__(cls, integer, divisor)
def to_c(self, print_func):
dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1])))
assert dtype.is_int()
code = "( ({0}) % ({1}) == 0 ? ({dtype})({0}) / ({dtype})({1}) : ( ({dtype})({0}) / ({dtype})({1}) ) +1 )"
return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype)
# noinspection PyPep8Naming
class div_floor(sp.Function):
"""Integer division
Examples:
>>> div_floor(9, 4)
2
>>> div_floor(8, 4)
2
>>> from pystencils import TypedSymbol
>>> a, b = TypedSymbol("a", "int64"), TypedSymbol("b", "int32")
>>> div_floor(a, b).to_c(str)
'((int64_t)(a) / (int64_t)(b))'
"""
nargs = 2
is_integer = True
def __new__(cls, integer, divisor):
if is_integer_sequence((integer, divisor)):
return integer // divisor
else:
return super().__new__(cls, integer, divisor)
def to_c(self, print_func):
dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1])))
assert dtype.is_int()
code = "(({dtype})({0}) / ({dtype})({1}))"
return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype)
"""
JIT compilation is realized by subclasses of `JitBase`.
A JIT compiler may freely be created and configured by the user.
It can then be passed to `create_kernel` using the ``jit`` argument of
`CreateKernelConfig`, in which case it is hooked into the `Kernel.compile` method
of the generated kernel function::
my_jit = MyJit()
kernel = create_kernel(ast, CreateKernelConfig(jit=my_jit))
func = kernel.compile()
Otherwise, a JIT compiler may also be created free-standing, with the same effect::
my_jit = MyJit()
kernel = create_kernel(ast)
func = my_jit.compile(kernel)
For GPU kernels, a basic JIT-compiler based on cupy is available (`CupyJit`).
For CPU kernels, at the moment there is only `LegacyCpuJit`, which is a wrapper
around the legacy CPU compiler wrapper used by pystencils 1.3.x.
It is due to be replaced in the near future.
"""
from .jit import JitBase, NoJit, KernelWrapper
from .legacy_cpu import LegacyCpuJit
from .cpu import CpuJit
from .gpu_cupy import CupyJit, CupyKernelWrapper, LaunchGrid
no_jit = NoJit()
"""Disables just-in-time compilation for a kernel."""
__all__ = [
"JitBase",
"KernelWrapper",
"LegacyCpuJit",
"CpuJit",
"NoJit",
"no_jit",
"CupyJit",
"CupyKernelWrapper",
"LaunchGrid"
]
from .compiler_info import GccInfo, ClangInfo
from .cpujit import CpuJit
__all__ = [
"GccInfo",
"ClangInfo",
"CpuJit"
]
from __future__ import annotations
from typing import Sequence
from abc import ABC, abstractmethod
from dataclasses import dataclass
from ...codegen.target import Target
@dataclass
class CompilerInfo(ABC):
"""Base class for compiler infos."""
openmp: bool = True
"""Enable/disable OpenMP compilation"""
optlevel: str | None = "fast"
"""Compiler optimization level"""
cxx_standard: str = "c++11"
"""C++ language standard to be compiled with"""
target: Target = Target.CurrentCPU
"""Hardware target to compile for.
Here, `Target.CurrentCPU` represents the current hardware,
which is reflected by ``-march=native`` in GNU-like compilers.
"""
@abstractmethod
def cxx(self) -> str:
"""Path to the executable of this compiler"""
@abstractmethod
def cxxflags(self) -> list[str]:
"""Compiler flags affecting C++ compilation"""
@abstractmethod
def linker_flags(self) -> list[str]:
"""Flags affecting linkage of the extension module"""
@abstractmethod
def include_flags(self, include_dirs: Sequence[str]) -> list[str]:
"""Convert a list of include directories into corresponding compiler flags"""
@abstractmethod
def restrict_qualifier(self) -> str:
"""*restrict* memory qualifier recognized by this compiler"""
class _GnuLikeCliCompiler(CompilerInfo):
def cxxflags(self) -> list[str]:
flags = ["-DNDEBUG", f"-std={self.cxx_standard}", "-fPIC"]
if self.optlevel is not None:
flags.append(f"-O{self.optlevel}")
if self.openmp:
flags.append("-fopenmp")
match self.target:
case Target.CurrentCPU:
flags.append("-march=native")
case Target.X86_SSE:
flags += ["-march=x86-64-v2"]
case Target.X86_AVX:
flags += ["-march=x86-64-v3"]
case Target.X86_AVX512:
flags += ["-march=x86-64-v4"]
case Target.X86_AVX512_FP16:
flags += ["-march=x86-64-v4", "-mavx512fp16"]
return flags
def linker_flags(self) -> list[str]:
return ["-shared"]
def include_flags(self, include_dirs: Sequence[str]) -> list[str]:
return [f"-I{d}" for d in include_dirs]
def restrict_qualifier(self) -> str:
return "__restrict__"
class GccInfo(_GnuLikeCliCompiler):
"""Compiler info for the GNU Compiler Collection C++ compiler (``g++``)."""
def cxx(self) -> str:
return "g++"
@dataclass
class ClangInfo(_GnuLikeCliCompiler):
"""Compiler info for the LLVM C++ compiler (``clang``)."""
llvm_version: int | None = None
"""Major version number of the LLVM installation providing the compiler."""
def cxx(self) -> str:
if self.llvm_version is None:
return "clang"
else:
return f"clang-{self.llvm_version}"
def linker_flags(self) -> list[str]:
return super().linker_flags() + ["-lstdc++"]
from __future__ import annotations
from types import ModuleType
from pathlib import Path
import subprocess
from copy import copy
from abc import ABC, abstractmethod
from ...codegen.config import _AUTO_TYPE, AUTO
from ..jit import JitError, JitBase, KernelWrapper
from ...codegen import Kernel
from .compiler_info import CompilerInfo, GccInfo
class CpuJit(JitBase):
"""Just-in-time compiler for CPU kernels.
**Creation**
To configure and create a CPU JIT compiler instance, use the `create` factory method.
**Implementation Details**
The `CpuJit` class acts as an orchestrator between two components:
- The *extension module builder* produces the code of the dynamically built extension module
that contains the kernel and its invocation wrappers;
- The *compiler info* describes the host compiler used to compile and link that extension module.
Args:
compiler_info: The compiler info object defining the capabilities
and command-line interface of the host compiler
ext_module_builder: Extension module builder object used to generate the kernel extension module
objcache: Directory to cache the generated code files and compiled modules in.
If `None`, a temporary directory will be used, and compilation results will not be cached.
"""
@staticmethod
def create(
compiler_info: CompilerInfo | None = None,
objcache: str | Path | _AUTO_TYPE | None = AUTO,
) -> CpuJit:
"""Configure and create a CPU JIT compiler object.
Args:
compiler_info: Compiler info object defining capabilities and interface of the host compiler.
If `None`, a default compiler configuration will be determined from the current OS and runtime
environment.
objcache: Directory used for caching compilation results.
If set to `AUTO`, a persistent cache directory in the current user's home will be used.
If set to `None`, compilation results will not be cached--this may impact performance.
Returns:
The CPU just-in-time compiler.
"""
if objcache is AUTO:
from appdirs import AppDirs
dirs = AppDirs(appname="pystencils")
objcache = Path(dirs.user_cache_dir) / "cpujit"
elif objcache is not None:
assert not isinstance(objcache, _AUTO_TYPE)
objcache = Path(objcache)
if compiler_info is None:
compiler_info = GccInfo()
from .cpujit_pybind11 import Pybind11KernelModuleBuilder
modbuilder = Pybind11KernelModuleBuilder(compiler_info)
return CpuJit(compiler_info, modbuilder, objcache)
def __init__(
self,
compiler_info: CompilerInfo,
ext_module_builder: ExtensionModuleBuilderBase,
objcache: Path | None,
):
self._compiler_info = copy(compiler_info)
self._ext_module_builder = ext_module_builder
self._objcache = objcache
# Include Directories
import sysconfig
from ...include import get_pystencils_include_path
include_dirs = [
sysconfig.get_path("include"),
get_pystencils_include_path(),
] + self._ext_module_builder.include_dirs()
# Compiler Flags
self._cxx = self._compiler_info.cxx()
self._cxx_fixed_flags = (
self._compiler_info.cxxflags()
+ self._compiler_info.include_flags(include_dirs)
+ self._compiler_info.linker_flags()
)
def compile(self, kernel: Kernel) -> KernelWrapper:
"""Compile the given kernel to an executable function.
Args:
kernel: The kernel object to be compiled.
Returns:
Wrapper object around the compiled function
"""
# Get the Code
module_name = f"{kernel.name}_jit"
cpp_code = self._ext_module_builder.render_module(kernel, module_name)
# Get compiler information
import sysconfig
so_abi = sysconfig.get_config_var("SOABI")
lib_suffix = f"{so_abi}.so"
# Compute Code Hash
code_utf8: bytes = cpp_code.encode("utf-8")
compiler_utf8: bytes = (" ".join([self._cxx] + self._cxx_fixed_flags)).encode("utf-8")
import hashlib
module_hash = hashlib.sha256(code_utf8 + compiler_utf8)
module_stem = f"module_{module_hash.hexdigest()}"
def compile_and_load(module_dir: Path):
cpp_file = module_dir / f"{module_stem}.cpp"
if not cpp_file.exists():
cpp_file.write_bytes(code_utf8)
lib_file = module_dir / f"{module_stem}.{lib_suffix}"
if not lib_file.exists():
self._compile_extension_module(cpp_file, lib_file)
module = self._load_extension_module(module_name, lib_file)
return module
if self._objcache is not None:
module_dir = self._objcache
# Lock module
import fasteners
lockfile = module_dir / f"{module_stem}.lock"
with fasteners.InterProcessLock(lockfile):
module = compile_and_load(module_dir)
else:
from tempfile import TemporaryDirectory
with TemporaryDirectory() as tmpdir:
module_dir = Path(tmpdir)
module = compile_and_load(module_dir)
return self._ext_module_builder.get_wrapper(kernel, module)
def _compile_extension_module(self, src_file: Path, libfile: Path):
args = (
[self._cxx]
+ self._cxx_fixed_flags
+ ["-o", str(libfile), str(src_file)]
)
result = subprocess.run(args, capture_output=True)
if result.returncode != 0:
raise JitError(
"Compilation failed: C++ compiler terminated with an error.\n"
+ result.stderr.decode()
)
def _load_extension_module(self, module_name: str, module_loc: Path) -> ModuleType:
from importlib import util as iutil
spec = iutil.spec_from_file_location(name=module_name, location=module_loc)
if spec is None:
raise JitError(
"Unable to load kernel extension module -- this is probably a bug."
)
mod = iutil.module_from_spec(spec)
spec.loader.exec_module(mod) # type: ignore
return mod
class ExtensionModuleBuilderBase(ABC):
"""Base class for CPU extension module builders."""
@staticmethod
@abstractmethod
def include_dirs() -> list[str]:
"""List of directories that must be on the include path when compiling
generated extension modules."""
@abstractmethod
def render_module(self, kernel: Kernel, module_name: str) -> str:
"""Produce the extension module code for the given kernel."""
@abstractmethod
def get_wrapper(
self, kernel: Kernel, extension_module: ModuleType
) -> KernelWrapper:
"""Produce the invocation wrapper for the given kernel
and its compiled extension module."""
from __future__ import annotations
from types import ModuleType
from typing import Sequence, cast
from pathlib import Path
from textwrap import indent
from pystencils.jit.jit import KernelWrapper
from ...types import PsPointerType, PsType
from ...field import Field
from ...sympyextensions import DynamicType
from ...codegen import Kernel, Parameter
from ...codegen.properties import FieldBasePtr, FieldShape, FieldStride
from .compiler_info import CompilerInfo
from .cpujit import ExtensionModuleBuilderBase
_module_template = Path(__file__).parent / "pybind11_kernel_module.tmpl.cpp"
class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
@staticmethod
def include_dirs() -> list[str]:
import pybind11 as pb11
pybind11_include = pb11.get_include()
return [pybind11_include]
def __init__(
self,
compiler_info: CompilerInfo,
):
self._compiler_info = compiler_info
self._actual_field_types: dict[Field, PsType]
self._param_binds: list[str]
self._public_params: list[str]
self._param_check_lines: list[str]
self._extraction_lines: list[str]
def render_module(self, kernel: Kernel, module_name: str) -> str:
self._actual_field_types = dict()
self._param_binds = []
self._public_params = []
self._param_check_lines = []
self._extraction_lines = []
self._handle_params(kernel.parameters)
kernel_def = self._get_kernel_definition(kernel)
kernel_args = [param.name for param in kernel.parameters]
includes = [f"#include {h}" for h in sorted(kernel.required_headers)]
from string import Template
templ = Template(_module_template.read_text())
code_str = templ.substitute(
includes="\n".join(includes),
restrict_qualifier=self._compiler_info.restrict_qualifier(),
module_name=module_name,
kernel_name=kernel.name,
param_binds=", ".join(self._param_binds),
public_params=", ".join(self._public_params),
param_check_lines=indent("\n".join(self._param_check_lines), prefix=" "),
extraction_lines=indent("\n".join(self._extraction_lines), prefix=" "),
kernel_args=", ".join(kernel_args),
kernel_definition=kernel_def,
)
return code_str
def get_wrapper(self, kernel: Kernel, extension_module: ModuleType) -> KernelWrapper:
return Pybind11KernelWrapper(kernel, extension_module)
def _get_kernel_definition(self, kernel: Kernel) -> str:
from ...backend.emission import CAstPrinter
printer = CAstPrinter()
return printer(kernel)
def _add_field_param(self, ptr_param: Parameter):
field: Field = ptr_param.fields[0]
ptr_type = ptr_param.dtype
assert isinstance(ptr_type, PsPointerType)
if isinstance(field.dtype, DynamicType):
elem_type = ptr_type.base_type
else:
elem_type = field.dtype
self._actual_field_types[field] = elem_type
param_bind = f'py::arg("{field.name}").noconvert()'
self._param_binds.append(param_bind)
kernel_param = f"py::array_t< {elem_type.c_string()} > & {field.name}"
self._public_params.append(kernel_param)
expect_shape = "(" + ", ".join((str(s) if isinstance(s, int) else "*") for s in field.shape) + ")"
for coord, size in enumerate(field.shape):
if isinstance(size, int):
self._param_check_lines.append(
f"checkFieldShape(\"{field.name}\", \"{expect_shape}\", {field.name}, {coord}, {size});"
)
expect_strides = "(" + ", ".join((str(s) if isinstance(s, int) else "*") for s in field.strides) + ")"
for coord, stride in enumerate(field.strides):
if isinstance(stride, int):
self._param_check_lines.append(
f"checkFieldStride(\"{field.name}\", \"{expect_strides}\", {field.name}, {coord}, {stride});"
)
def _add_scalar_param(self, sc_param: Parameter):
param_bind = f'py::arg("{sc_param.name}")'
self._param_binds.append(param_bind)
kernel_param = f"{sc_param.dtype.c_string()} {sc_param.name}"
self._public_params.append(kernel_param)
def _extract_base_ptr(self, ptr_param: Parameter, ptr_prop: FieldBasePtr):
field_name = ptr_prop.field.name
assert isinstance(ptr_param.dtype, PsPointerType)
data_method = "data()" if ptr_param.dtype.base_type.const else "mutable_data()"
extraction = f"{ptr_param.dtype.c_string()} {ptr_param.name} {{ {field_name}.{data_method} }};"
self._extraction_lines.append(extraction)
def _extract_shape(self, shape_param: Parameter, shape_prop: FieldShape):
field_name = shape_prop.field.name
coord = shape_prop.coordinate
extraction = f"{shape_param.dtype.c_string()} {shape_param.name} {{ {field_name}.shape({coord}) }};"
self._extraction_lines.append(extraction)
def _extract_stride(self, stride_param: Parameter, stride_prop: FieldStride):
field = stride_prop.field
field_name = field.name
coord = stride_prop.coordinate
field_type = self._actual_field_types[field]
assert field_type.itemsize is not None
extraction = (
f"{stride_param.dtype.c_string()} {stride_param.name} "
f"{{ {field_name}.strides({coord}) / {field_type.itemsize} }};"
)
self._extraction_lines.append(extraction)
def _handle_params(self, parameters: Sequence[Parameter]):
for param in parameters:
if param.get_properties(FieldBasePtr):
self._add_field_param(param)
for param in parameters:
if ptr_props := param.get_properties(FieldBasePtr):
self._extract_base_ptr(param, cast(FieldBasePtr, ptr_props.pop()))
elif shape_props := param.get_properties(FieldShape):
self._extract_shape(param, cast(FieldShape, shape_props.pop()))
elif stride_props := param.get_properties(FieldStride):
self._extract_stride(param, cast(FieldStride, stride_props.pop()))
else:
self._add_scalar_param(param)
class Pybind11KernelWrapper(KernelWrapper):
def __init__(self, kernel: Kernel, jit_module: ModuleType):
super().__init__(kernel)
self._module = jit_module
self._check_params = getattr(jit_module, "check_params")
self._invoke = getattr(jit_module, "invoke")
def __call__(self, **kwargs) -> None:
self._check_params(**kwargs)
return self._invoke(**kwargs)
#include "pybind11/pybind11.h"
#include "pybind11/numpy.h"
#include <array>
#include <string>
#include <sstream>
${includes}
namespace py = pybind11;
#define RESTRICT ${restrict_qualifier}
namespace internal {
${kernel_definition}
}
std::string tuple_to_str(const ssize_t * data, const size_t N){
std::stringstream acc;
acc << "(";
for(size_t i = 0; i < N; ++i){
acc << data[i];
if(i + 1 < N){
acc << ", ";
}
}
acc << ")";
return acc.str();
}
template< typename T >
void checkFieldShape(const std::string& fieldName, const std::string& expected, const py::array_t< T > & arr, size_t coord, size_t desired) {
auto panic = [&](){
std::stringstream err;
err << "Invalid shape of argument " << fieldName
<< ". Expected " << expected
<< ", but got " << tuple_to_str(arr.shape(), arr.ndim())
<< ".";
throw py::value_error{ err.str() };
};
if(arr.ndim() <= coord){
panic();
}
if(arr.shape(coord) != desired){
panic();
}
}
template< typename T >
void checkFieldStride(const std::string fieldName, const std::string& expected, const py::array_t< T > & arr, size_t coord, size_t desired) {
auto panic = [&](){
std::stringstream err;
err << "Invalid strides of argument " << fieldName
<< ". Expected " << expected
<< ", but got " << tuple_to_str(arr.strides(), arr.ndim())
<< ".";
throw py::value_error{ err.str() };
};
if(arr.ndim() <= coord){
panic();
}
if(arr.strides(coord) / sizeof(T) != desired){
panic();
}
}
void check_params_${kernel_name} (${public_params}) {
${param_check_lines}
}
void run_${kernel_name} (${public_params}) {
${extraction_lines}
internal::${kernel_name}(${kernel_args});
}
PYBIND11_MODULE(${module_name}, m) {
m.def("check_params", &check_params_${kernel_name}, py::kw_only(), ${param_binds});
m.def("invoke", &run_${kernel_name}, py::kw_only(), ${param_binds});
}
from __future__ import annotations
from typing import Any, cast
from os import path
import hashlib
from itertools import chain
from textwrap import indent
import numpy as np
from ..codegen import (
Kernel,
Parameter,
)
from ..codegen.properties import FieldBasePtr, FieldShape, FieldStride
from ..types import (
PsType,
PsUnsignedIntegerType,
PsSignedIntegerType,
PsIeeeFloatType,
PsPointerType,
)
from ..types.quick import Fp, SInt, UInt
from ..field import Field
class PsKernelExtensioNModule:
"""Replacement for `pystencils.cpu.cpujit.ExtensionModuleCode`.
Conforms to its interface for plug-in to `compile_and_load`.
"""
def __init__(
self, module_name: str = "generated", custom_backend: Any = None
) -> None:
self._module_name = module_name
if custom_backend is not None:
raise Exception(
"The `custom_backend` parameter exists only for interface compatibility and cannot be set."
)
self._kernels: dict[str, Kernel] = dict()
self._code_string: str | None = None
self._code_hash: str | None = None
@property
def module_name(self) -> str:
return self._module_name
def add_function(self, kernel_function: Kernel, name: str | None = None):
if name is None:
name = kernel_function.name
self._kernels[name] = kernel_function
def create_code_string(self, restrict_qualifier: str, function_prefix: str):
code = ""
# Collect headers
headers = {"<stdint.h>"}
for kernel in self._kernels.values():
headers |= kernel.required_headers
header_list = sorted(headers)
header_list.insert(0, '"Python.h"')
from pystencils.include import get_pystencils_include_path
ps_incl_path = get_pystencils_include_path()
ps_headers = []
for header in header_list:
header = header[1:-1]
header_path = path.join(ps_incl_path, header)
if path.exists(header_path):
ps_headers.append(header_path)
header_hash = b"".join(
[hashlib.sha256(open(h, "rb").read()).digest() for h in ps_headers]
)
# Prelude: Includes and definitions
includes = "\n".join(f"#include {header}" for header in header_list)
code += includes
code += "\n"
code += f"#define RESTRICT {restrict_qualifier}\n"
code += f"#define FUNC_PREFIX {function_prefix}\n"
code += "\n"
# Kernels and call wrappers
from ..backend.emission import CAstPrinter
printer = CAstPrinter(func_prefix="FUNC_PREFIX")
for name, kernel in self._kernels.items():
old_name = kernel.name
kernel.name = f"kernel_{name}"
code += printer(kernel)
code += "\n"
code += emit_call_wrapper(name, kernel)
code += "\n"
kernel.name = old_name
self._code_hash = (
"mod_" + hashlib.sha256(code.encode() + header_hash).hexdigest()
)
code += create_module_boilerplate_code(self._code_hash, self._kernels.keys())
self._code_string = code
def get_hash_of_code(self):
assert self._code_string is not None, "The code must be generated first"
return self._code_hash
def write_to_file(self, file):
assert self._code_string is not None, "The code must be generated first"
print(self._code_string, file=file)
def emit_call_wrapper(function_name: str, kernel: Kernel) -> str:
builder = CallWrapperBuilder()
builder.extract_params(kernel.parameters)
# for c in kernel.constraints:
# builder.check_constraint(c)
builder.call(kernel, kernel.parameters)
return builder.resolve(function_name)
template_module_boilerplate = """
static PyMethodDef method_definitions[] = {{
{method_definitions}
{{NULL, NULL, 0, NULL}}
}};
static struct PyModuleDef module_definition = {{
PyModuleDef_HEAD_INIT,
"{module_name}", /* name of module */
NULL, /* module documentation, may be NULL */
-1, /* size of per-interpreter state of the module,
or -1 if the module keeps state in global variables. */
method_definitions
}};
PyMODINIT_FUNC
PyInit_{module_name}(void)
{{
return PyModule_Create(&module_definition);
}}
"""
def create_module_boilerplate_code(module_name, names):
method_definition = (
'{{"{name}", (PyCFunction){name}, METH_VARARGS | METH_KEYWORDS, ""}},'
)
method_definitions = "\n".join(
[method_definition.format(name=name) for name in names]
)
return template_module_boilerplate.format(
module_name=module_name, method_definitions=method_definitions
)
class CallWrapperBuilder:
TMPL_EXTRACT_SCALAR = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
{target_type} {name} = ({target_type}) {extract_function}( obj_{name} );
if( PyErr_Occurred() ) {{ return NULL; }}
"""
TMPL_EXTRACT_ARRAY = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
Py_buffer buffer_{name};
int buffer_{name}_res = PyObject_GetBuffer(obj_{name}, &buffer_{name}, PyBUF_STRIDES | PyBUF_WRITABLE | PyBUF_FORMAT);
if (buffer_{name}_res == -1) {{ return NULL; }}
"""
TMPL_CHECK_ARRAY_TYPE = """
if(!({cond})) {{
PyErr_SetString(PyExc_TypeError, "Wrong {what} of array {name}. Expected {expected}");
return NULL;
}}
"""
KWCHECK = """
if( !kwargs || !PyDict_Check(kwargs) ) {{
PyErr_SetString(PyExc_TypeError, "No keyword arguments passed");
return NULL;
}}
"""
def __init__(self) -> None:
self._buffer_types: dict[Field, PsType] = dict()
self._array_extractions: dict[Field, str] = dict()
self._array_frees: dict[Field, str] = dict()
self._array_assoc_var_extractions: dict[Parameter, str] = dict()
self._scalar_extractions: dict[Parameter, str] = dict()
self._constraint_checks: list[str] = []
self._call: str | None = None
def _scalar_extractor(self, dtype: PsType) -> str:
match dtype:
case Fp(32) | Fp(64):
return "PyFloat_AsDouble"
case SInt():
return "PyLong_AsLong"
case UInt():
return "PyLong_AsUnsignedLong"
case _:
raise ValueError(f"Don't know how to cast Python objects to {dtype}")
def _type_char(self, dtype: PsType) -> str | None:
if isinstance(
dtype, (PsUnsignedIntegerType, PsSignedIntegerType, PsIeeeFloatType)
):
np_dtype = dtype.NUMPY_TYPES[dtype.width]
return np.dtype(np_dtype).char
else:
return None
def get_field_buffer(self, field: Field) -> str:
"""Get the Python buffer object for the given field."""
return f"buffer_{field.name}"
def extract_field(self, field: Field) -> None:
"""Add the necessary code to extract the NumPy array for a given field"""
if field not in self._array_extractions:
extraction_code = self.TMPL_EXTRACT_ARRAY.format(name=field.name)
actual_dtype = self._buffer_types[field]
# Check array type
type_char = self._type_char(actual_dtype)
if type_char is not None:
dtype_cond = f"buffer_{field.name}.format[0] == '{type_char}'"
extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format(
cond=dtype_cond,
what="data type",
name=field.name,
expected=str(actual_dtype),
)
# Check item size
itemsize = actual_dtype.itemsize
item_size_cond = f"buffer_{field.name}.itemsize == {itemsize}"
extraction_code += self.TMPL_CHECK_ARRAY_TYPE.format(
cond=item_size_cond, what="itemsize", name=field.name, expected=itemsize
)
self._array_extractions[field] = extraction_code
release_code = f"PyBuffer_Release(&buffer_{field.name});"
self._array_frees[field] = release_code
def extract_scalar(self, param: Parameter) -> str:
if param not in self._scalar_extractions:
extract_func = self._scalar_extractor(param.dtype)
code = self.TMPL_EXTRACT_SCALAR.format(
name=param.name,
target_type=param.dtype.c_string(),
extract_function=extract_func,
)
self._scalar_extractions[param] = code
return param.name
def extract_array_assoc_var(self, param: Parameter) -> str:
if param not in self._array_assoc_var_extractions:
field = param.fields[0]
buffer = self.get_field_buffer(field)
buffer_dtype = self._buffer_types[field]
code: str | None = None
for prop in param.properties:
match prop:
case FieldBasePtr():
code = f"{param.dtype.c_string()} {param.name} = ({param.dtype}) {buffer}.buf;"
break
case FieldShape(_, coord):
code = f"{param.dtype.c_string()} {param.name} = {buffer}.shape[{coord}];"
break
case FieldStride(_, coord):
code = (
f"{param.dtype.c_string()} {param.name} = "
f"{buffer}.strides[{coord}] / {buffer_dtype.itemsize};"
)
break
assert code is not None
self._array_assoc_var_extractions[param] = code
return param.name
def extract_params(self, params: tuple[Parameter, ...]) -> None:
for param in params:
if ptr_props := param.get_properties(FieldBasePtr):
prop: FieldBasePtr = cast(FieldBasePtr, ptr_props.pop())
field = prop.field
actual_field_type: PsType
from .. import DynamicType
if isinstance(field.dtype, DynamicType):
ptr_type = param.dtype
assert isinstance(ptr_type, PsPointerType)
actual_field_type = ptr_type.base_type
else:
actual_field_type = field.dtype
self._buffer_types[prop.field] = actual_field_type
self.extract_field(prop.field)
for param in params:
if param.is_field_parameter:
self.extract_array_assoc_var(param)
else:
self.extract_scalar(param)
# def check_constraint(self, constraint: KernelParamsConstraint):
# variables = constraint.get_parameters()
# for var in variables:
# self.extract_parameter(var)
# cond = constraint.to_code()
# code = f"""
# if(!({cond}))
# {{
# PyErr_SetString(PyExc_ValueError, "Violated constraint: {constraint}");
# return NULL;
# }}
# """
# self._constraint_checks.append(code)
def call(self, kernel: Kernel, params: tuple[Parameter, ...]):
param_list = ", ".join(p.name for p in params)
self._call = f"{kernel.name} ({param_list});"
def resolve(self, function_name) -> str:
assert self._call is not None
body = "\n\n".join(
chain(
[self.KWCHECK],
self._scalar_extractions.values(),
self._array_extractions.values(),
self._array_assoc_var_extractions.values(),
self._constraint_checks,
[self._call],
self._array_frees.values(),
["Py_RETURN_NONE;"],
)
)
code = f"static PyObject * {function_name}(PyObject * self, PyObject * args, PyObject * kwargs)\n"
code += "{\n" + indent(body, prefix=" ") + "\n}\n"
return code