
Compare revisions

Changes are shown as if the source revision was being merged into the target revision.


Showing with 945 additions and 245 deletions
File moved
@@ -256,9 +256,7 @@ class Field:
        self.shape = shape
        self.strides = strides
        self.latex_name: Optional[str] = None
-       self.coordinate_origin: tuple[float, sp.Symbol] = sp.Matrix(tuple(
-           0 for _ in range(self.spatial_dimensions)
-       ))
+       self.coordinate_origin = sp.Matrix([0] * self.spatial_dimensions)
        self.coordinate_transform = sp.eye(self.spatial_dimensions)
        if field_type == FieldType.STAGGERED:
            assert self.staggered_stencil
@@ -267,8 +265,7 @@ class Field:
        if self.has_fixed_shape:
            return Field(new_name, self.field_type, self._dtype, self._layout, self.shape, self.strides)
        else:
-           return Field.create_generic(new_name, self.spatial_dimensions, self.dtype.numpy_dtype,
-                                       self.index_dimensions, self._layout, self.index_shape, self.field_type)
+           return Field(new_name, self.field_type, self.dtype, self.layout, self.shape, self.strides)

    @property
    def spatial_dimensions(self) -> int:
@@ -951,24 +948,35 @@ def create_numpy_array_with_layout(shape, layout, alignment=False, byte_offset=0
def spatial_layout_string_to_tuple(layout_str: str, dim: int) -> Tuple[int, ...]:
-   if layout_str in ('fzyx', 'zyxf'):
-       assert dim <= 3
-       return tuple(reversed(range(dim)))
-   if layout_str in ('fzyx', 'f', 'reverse_numpy', 'SoA'):
+   if dim <= 0:
+       raise ValueError("Dimensionality must be positive")
+   layout_str = layout_str.lower()
+   if layout_str in ('fzyx', 'zyxf', 'soa', 'aos'):
+       if dim > 3:
+           raise ValueError(f"Invalid spatial dimensionality for layout descriptor {layout_str}: May be at most 3.")
+       return tuple(reversed(range(dim)))
+   if layout_str in ('f', 'reverse_numpy'):
        return tuple(reversed(range(dim)))
-   elif layout_str in ('c', 'numpy', 'AoS'):
+   elif layout_str in ('c', 'numpy'):
        return tuple(range(dim))
    raise ValueError("Unknown layout descriptor " + layout_str)

def layout_string_to_tuple(layout_str, dim):
+   if dim <= 0:
+       raise ValueError("Dimensionality must be positive")
    layout_str = layout_str.lower()
    if layout_str == 'fzyx' or layout_str == 'soa':
-       assert dim <= 4
+       if dim > 4:
+           raise ValueError(f"Invalid total dimensionality for layout descriptor {layout_str}: May be at most 4.")
        return tuple(reversed(range(dim)))
    elif layout_str == 'zyxf' or layout_str == 'aos':
-       assert dim <= 4
+       if dim > 4:
+           raise ValueError(f"Invalid total dimensionality for layout descriptor {layout_str}: May be at most 4.")
        return tuple(reversed(range(dim - 1))) + (dim - 1,)
    elif layout_str == 'f' or layout_str == 'reverse_numpy':
        return tuple(reversed(range(dim)))
...
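For quick reference, here is what the reworked layout helpers return; this is an illustrative sketch, not part of the diff, and it assumes the functions remain importable from pystencils.field:

from pystencils.field import layout_string_to_tuple, spatial_layout_string_to_tuple

spatial_layout_string_to_tuple('fzyx', dim=3)   # (2, 1, 0): slowest to fastest spatial coordinate
layout_string_to_tuple('zyxf', dim=4)           # (2, 1, 0, 3): the index coordinate stays innermost
layout_string_to_tuple('soa', dim=5)            # now raises ValueError (at most 4 total dimensions)
spatial_layout_string_to_tuple('fzyx', dim=0)   # now raises ValueError (dimensionality must be positive)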
File moved
File moved
import abc
from functools import partial
import math
+from typing import List, Tuple

import sympy as sp
from sympy.core.cache import cacheit

-from pystencils.astnodes import Block, Conditional
+from pystencils.astnodes import Block, Conditional, SympyAssignment
from pystencils.typing import TypedSymbol, create_type
from pystencils.integer_functions import div_ceil, div_floor
-from pystencils.slicing import normalize_slice
from pystencils.sympyextensions import is_integer_sequence, prod
@@ -33,12 +33,37 @@ GRID_DIM = [ThreadIndexingSymbol("gridDim." + coord, create_type("int32")) for c

class AbstractIndexing(abc.ABC):
    """
-   Abstract base class for all Indexing classes. An Indexing class defines how a multidimensional
-   field is mapped to GPU's block and grid system. It calculates indices based on GPU's thread and block indices
-   and computes the number of blocks and threads a kernel is started with. The Indexing class is created with
-   a pystencils field, a slice to iterate over, and further optional parameters that must have default values.
+   Abstract base class for all Indexing classes. An Indexing class defines how an iteration space is mapped
+   to GPU's block and grid system. It calculates indices based on GPU's thread and block indices
+   and computes the number of blocks and threads a kernel is started with.
+   The Indexing class is created with an iteration space that is given as list of slices to determine start, stop
+   and the step size for each coordinate. Further the data_layout is given as tuple to determine the fast and slow
+   coordinates. This is important to get an optimal mapping of coordinates to GPU threads.
    """

+   def __init__(self, iteration_space: Tuple[slice], data_layout: Tuple):
+       for iter_space in iteration_space:
+           assert isinstance(iter_space, slice), f"iteration_space must be of type Tuple[slice], " \
+                                                 f"not tuple of type {type(iter_space)}"
+       self._iteration_space = iteration_space
+       self._data_layout = data_layout
+       self._dim = len(iteration_space)

+   @property
+   def iteration_space(self):
+       """Iteration space to loop over"""
+       return self._iteration_space

+   @property
+   def data_layout(self):
+       """Data layout of the kernels arrays"""
+       return self._data_layout

+   @property
+   def dim(self):
+       """Number of spatial dimensions"""
+       return self._dim
    @property
    @abc.abstractmethod
    def coordinates(self):
@@ -50,6 +75,16 @@ class AbstractIndexing(abc.ABC):
        """Sympy symbols for GPU's block and thread indices, and block and grid dimensions. """
        return BLOCK_IDX + THREAD_IDX + BLOCK_DIM + GRID_DIM

+   @abc.abstractmethod
+   def get_loop_ctr_assignments(self, loop_counter_symbols) -> List[SympyAssignment]:
+       """Adds assignments for the loop counter symbols depending on the gpu threads.
+
+       Args:
+           loop_counter_symbols: typed symbols representing the loop counters
+       Returns:
+           assignments for the loop counters
+       """
    @abc.abstractmethod
    def call_parameters(self, arr_shape):
        """Determine grid and block size for kernel call.
@@ -92,8 +127,9 @@ class BlockIndexing(AbstractIndexing):
    """Generic indexing scheme that maps sub-blocks of an array to GPU blocks.

    Args:
-       field: pystencils field (common to all Indexing classes)
-       iteration_slice: slice that defines rectangular subarea which is iterated over
+       iteration_space: list of slices to determine start, stop and the step size for each coordinate
+       data_layout: tuple specifying loop order with innermost loop last.
+           This is the same format as returned by `Field.layout`.
        permute_block_size_dependent_on_layout: if True the block_size is permuted such that the fastest coordinate
            gets the largest amount of threads
        compile_time_block_size: compile in concrete block size, otherwise the gpu variable 'blockDim' is used
@@ -102,14 +138,16 @@ class BlockIndexing(AbstractIndexing):
        device_number: device number of the used GPU. By default, the zeroth device is used.
    """

-   def __init__(self, field, iteration_slice,
-                block_size=(16, 16, 1), permute_block_size_dependent_on_layout=True, compile_time_block_size=False,
+   def __init__(self, iteration_space: Tuple[slice], data_layout: Tuple[int],
+                block_size=(128, 2, 1), permute_block_size_dependent_on_layout=True, compile_time_block_size=False,
                 maximum_block_size=(1024, 1024, 64), device_number=None):
-       if field.spatial_dimensions > 3:
-           raise NotImplementedError("This indexing scheme supports at most 3 spatial dimensions")
+       super(BlockIndexing, self).__init__(iteration_space, data_layout)
+
+       if self._dim > 4:
+           raise NotImplementedError("This indexing scheme supports at most 4 spatial dimensions")

-       if permute_block_size_dependent_on_layout:
-           block_size = self.permute_block_size_according_to_layout(block_size, field.layout)
+       if permute_block_size_dependent_on_layout and self._dim < 4:
+           block_size = self.permute_block_size_according_to_layout(block_size, data_layout)

        self._block_size = block_size
        if maximum_block_size == 'auto':
@@ -124,9 +162,6 @@ class BlockIndexing(AbstractIndexing):
            maximum_block_size = tuple(da[f"MaxBlockDim{c}"] for c in ["X", "Y", "Z"])

        self._maximum_block_size = maximum_block_size
-       self._iterationSlice = normalize_slice(iteration_slice, field.spatial_shape)
-       self._dim = field.spatial_dimensions
-       self._symbolic_shape = [e if isinstance(e, sp.Basic) else None for e in field.spatial_shape]
        self._compile_time_block_size = compile_time_block_size
        self._device_number = device_number
@@ -140,17 +175,28 @@ class BlockIndexing(AbstractIndexing):
    @property
    def coordinates(self):
-       offsets = _get_start_from_slice(self._iterationSlice)
-       coordinates = [c + off for c, off in zip(self.cuda_indices, offsets)]
-       return coordinates[:self._dim]
+       if self._dim < 4:
+           coordinates = [c + iter_slice.start for c, iter_slice in zip(self.cuda_indices, self._iteration_space)]
+           return coordinates[:self._dim]
+       else:
+           coordinates = list()
+           width = self._iteration_space[1].stop - self.iteration_space[1].start
+           coordinates.append(div_floor(self.cuda_indices[0], width))
+           coordinates.append(sp.Mod(self.cuda_indices[0], width))
+           coordinates.append(self.cuda_indices[1] + self.iteration_space[2].start)
+           coordinates.append(self.cuda_indices[2] + self.iteration_space[3].start)
+           return coordinates

+   def get_loop_ctr_assignments(self, loop_counter_symbols):
+       return _loop_ctr_assignments(loop_counter_symbols, self.coordinates, self._iteration_space)
    def call_parameters(self, arr_shape):
-       substitution_dict = {sym: value for sym, value in zip(self._symbolic_shape, arr_shape) if sym is not None}
-
-       widths = [end - start for start, end in zip(_get_start_from_slice(self._iterationSlice),
-                                                   _get_end_from_slice(self._iterationSlice, arr_shape))]
-       widths = sp.Matrix(widths).subs(substitution_dict)
+       numeric_iteration_slice = _get_numeric_iteration_slice(self._iteration_space, arr_shape)
+       widths = _get_widths(numeric_iteration_slice)
+       if len(widths) > 3:
+           widths = [widths[0] * widths[1], widths[2], widths[3]]

        extend_bs = (1,) * (3 - len(self._block_size))
        block_size = self._block_size + extend_bs
        if not self._compile_time_block_size:
@@ -171,20 +217,30 @@ class BlockIndexing(AbstractIndexing):
    def guard(self, kernel_content, arr_shape):
        arr_shape = arr_shape[:self._dim]
-       end = _get_end_from_slice(self._iterationSlice, arr_shape)
-
-       conditions = [c < e for c, e in zip(self.coordinates, end)]
-       for cuda_index, iter_slice in zip(self.cuda_indices, self._iterationSlice):
-           if isinstance(iter_slice, slice) and iter_slice.step > 1:
-               conditions.append(sp.Eq(sp.Mod(cuda_index, iter_slice.step), 0))
+       if len(self._iteration_space) - 1 == len(arr_shape):
+           numeric_iteration_slice = _get_numeric_iteration_slice(self._iteration_space[1:], arr_shape)
+           numeric_iteration_slice = [self.iteration_space[0]] + numeric_iteration_slice
+       else:
+           assert len(self._iteration_space) == len(arr_shape), "Iteration space must be equal to the array shape"
+           numeric_iteration_slice = _get_numeric_iteration_slice(self._iteration_space, arr_shape)
+       end = [s.stop if s.stop != 0 else 1 for s in numeric_iteration_slice]
+       for i, s in enumerate(numeric_iteration_slice):
+           if s.step and s.step != 1:
+               end[i] = div_ceil(s.stop - s.start, s.step) + s.start
+
+       if self._dim < 4:
+           conditions = [c < e for c, e in zip(self.coordinates, end)]
+       else:
+           end = [end[0] * end[1], end[2], end[3]]
+           coordinates = [c + iter_slice.start for c, iter_slice in zip(self.cuda_indices, self._iteration_space[1:])]
+           conditions = [c < e for c, e in zip(coordinates, end)]

        condition = conditions[0]
        for c in conditions[1:]:
            condition = sp.And(condition, c)
        return Block([Conditional(condition, kernel_content)])

-   def iteration_space(self, arr_shape):
-       return _iteration_space(self._iterationSlice, arr_shape)
+   def numeric_iteration_space(self, arr_shape):
+       return _get_numeric_iteration_slice(self._iteration_space, arr_shape)

    def limit_block_size_by_register_restriction(self, block_size, required_registers_per_thread):
        """Shrinks the block_size if there are too many registers used per block.
@@ -246,38 +302,44 @@ class LineIndexing(AbstractIndexing):
    The fastest coordinate is indexed with thread_idx.x, the remaining coordinates are mapped to block_idx.{x,y,z}
    This indexing scheme supports up to 4 spatial dimensions, where the innermost dimensions is not larger than the
    maximum amount of threads allowed in a GPU block (which depends on device).
+
+   Args:
+       iteration_space: list of slices to determine start, stop and the step size for each coordinate
+       data_layout: tuple to determine the fast and slow coordinates.
    """

-   def __init__(self, field, iteration_slice):
-       available_indices = [THREAD_IDX[0]] + BLOCK_IDX
-       if field.spatial_dimensions > 4:
+   def __init__(self, iteration_space: Tuple[slice], data_layout: Tuple):
+       super(LineIndexing, self).__init__(iteration_space, data_layout)
+
+       if len(iteration_space) > 4:
            raise NotImplementedError("This indexing scheme supports at most 4 spatial dimensions")

-       coordinates = available_indices[:field.spatial_dimensions]
+   @property
+   def cuda_indices(self):
+       available_indices = [THREAD_IDX[0]] + BLOCK_IDX
+       coordinates = available_indices[:self.dim]

-       fastest_coordinate = field.layout[-1]
+       fastest_coordinate = self.data_layout[-1]
        coordinates[0], coordinates[fastest_coordinate] = coordinates[fastest_coordinate], coordinates[0]

-       self._coordinates = coordinates
-       self._iterationSlice = normalize_slice(iteration_slice, field.spatial_shape)
-       self._symbolicShape = [e if isinstance(e, sp.Basic) else None for e in field.spatial_shape]
+       return coordinates

    @property
    def coordinates(self):
-       return [i + offset for i, offset in zip(self._coordinates, _get_start_from_slice(self._iterationSlice))]
+       return [i + o.start for i, o in zip(self.cuda_indices, self._iteration_space)]

-   def call_parameters(self, arr_shape):
-       substitution_dict = {sym: value for sym, value in zip(self._symbolicShape, arr_shape) if sym is not None}
+   def get_loop_ctr_assignments(self, loop_counter_symbols):
+       return _loop_ctr_assignments(loop_counter_symbols, self.coordinates, self._iteration_space)

-       widths = [end - start for start, end in zip(_get_start_from_slice(self._iterationSlice),
-                                                   _get_end_from_slice(self._iterationSlice, arr_shape))]
-       widths = sp.Matrix(widths).subs(substitution_dict)
+   def call_parameters(self, arr_shape):
+       numeric_iteration_slice = _get_numeric_iteration_slice(self._iteration_space, arr_shape)
+       widths = _get_widths(numeric_iteration_slice)

        def get_shape_of_cuda_idx(cuda_idx):
-           if cuda_idx not in self._coordinates:
+           if cuda_idx not in self.cuda_indices:
                return 1
            else:
-               idx = self._coordinates.index(cuda_idx)
+               idx = self.cuda_indices.index(cuda_idx)
                return widths[idx]

        return {'block': tuple([get_shape_of_cuda_idx(idx) for idx in THREAD_IDX]),
@@ -292,50 +354,66 @@ class LineIndexing(AbstractIndexing):
    def symbolic_parameters(self):
        return set()

-   def iteration_space(self, arr_shape):
-       return _iteration_space(self._iterationSlice, arr_shape)
+   def numeric_iteration_space(self, arr_shape):
+       return _get_numeric_iteration_slice(self._iteration_space, arr_shape)
# -------------------------------------- Helper functions --------------------------------------------------------------

-def _get_start_from_slice(iteration_slice):
+def _get_numeric_iteration_slice(iteration_slice, arr_shape):
    res = []
-   for slice_component in iteration_slice:
-       if type(slice_component) is slice:
-           res.append(slice_component.start if slice_component.start is not None else 0)
-       else:
-           assert isinstance(slice_component, int)
-           res.append(slice_component)
+   for slice_component, shape in zip(iteration_slice, arr_shape):
+       result_slice = slice_component
+       if not isinstance(result_slice.start, int):
+           start = result_slice.start
+           assert len(start.free_symbols) == 1
+           start = start.subs({symbol: shape for symbol in start.free_symbols})
+           result_slice = slice(start, result_slice.stop, result_slice.step)
+       if not isinstance(result_slice.stop, int):
+           stop = result_slice.stop
+           assert len(stop.free_symbols) == 1
+           stop = stop.subs({symbol: shape for symbol in stop.free_symbols})
+           result_slice = slice(result_slice.start, stop, result_slice.step)
+       assert isinstance(result_slice.step, int)
+       res.append(result_slice)
    return res

-def _get_end_from_slice(iteration_slice, arr_shape):
-   iter_slice = normalize_slice(iteration_slice, arr_shape)
-   res = []
-   for slice_component in iter_slice:
-       if type(slice_component) is slice:
-           res.append(slice_component.stop)
-       else:
-           assert isinstance(slice_component, int)
-           res.append(slice_component + 1)
-   return res
+def _get_widths(iteration_slice):
+   widths = []
+   for iter_slice in iteration_slice:
+       step = iter_slice.step
+       assert isinstance(step, int), f"Step can only be of type int not of type {type(step)}"
+       start = iter_slice.start
+       stop = iter_slice.stop
+       if step == 1:
+           if stop - start == 0:
+               widths.append(1)
+           else:
+               widths.append(stop - start)
+       else:
+           width = (stop - start) / step
+           if isinstance(width, int):
+               widths.append(width)
+           elif isinstance(width, float):
+               widths.append(math.ceil(width))
+           else:
+               widths.append(div_ceil(stop - start, step))
+   return widths

-def _get_steps_from_slice(iteration_slice):
-   res = []
-   for slice_component in iteration_slice:
-       if type(slice_component) is slice:
-           res.append(slice_component.step)
-       else:
-           res.append(1)
-   return res
-
-def _iteration_space(iteration_slice, arr_shape):
-   starts = _get_start_from_slice(iteration_slice)
-   ends = _get_end_from_slice(iteration_slice, arr_shape)
-   steps = _get_steps_from_slice(iteration_slice)
-   return [slice(start, end, step) for start, end, step in zip(starts, ends, steps)]
+def _loop_ctr_assignments(loop_counter_symbols, coordinates, iteration_space):
+   loop_ctr_assignments = []
+   for loop_counter, coordinate, iter_slice in zip(loop_counter_symbols, coordinates, iteration_space):
+       if isinstance(iter_slice, slice) and iter_slice.step > 1:
+           offset = (iter_slice.step * iter_slice.start) - iter_slice.start
+           loop_ctr_assignments.append(SympyAssignment(loop_counter, coordinate * iter_slice.step - offset))
+       elif iter_slice.start == iter_slice.stop:
+           loop_ctr_assignments.append(SympyAssignment(loop_counter, 0))
+       else:
+           loop_ctr_assignments.append(SympyAssignment(loop_counter, coordinate))
+   return loop_ctr_assignments

def indexing_creator_from_params(gpu_indexing, gpu_indexing_params):
...
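To make the new slice-based interface concrete, here is a small sketch with made-up numbers (not part of the diff): the indexing classes now receive fully numeric slices, and the launch extent per coordinate is the ceil-divided slice width, exactly as _get_widths computes above.

import math

iteration_space = (slice(1, 31, 1), slice(1, 31, 2))   # hypothetical 2D interior, step 2 in the second coordinate
widths = [math.ceil((s.stop - s.start) / s.step) for s in iteration_space]
print(widths)                                           # [30, 15]; block and grid sizes are derived from these extents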
-from typing import Union
-import numpy as np
+import sympy as sp

from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment
from pystencils.config import CreateKernelConfig
@@ -11,14 +9,13 @@ from pystencils.enums import Target, Backend
from pystencils.gpu.gpujit import make_python_function
from pystencils.node_collection import NodeCollection
from pystencils.gpu.indexing import indexing_creator_from_params
-from pystencils.simp.assignment_collection import AssignmentCollection
+from pystencils.slicing import normalize_slice
from pystencils.transformations import (
-   get_base_buffer_index, get_common_field, parse_base_pointer_info,
+   get_base_buffer_index, get_common_field, get_common_indexed_element, parse_base_pointer_info,
    resolve_buffer_accesses, resolve_field_accesses, unify_shape_symbols)

-def create_cuda_kernel(assignments: Union[AssignmentCollection, NodeCollection],
-                       config: CreateKernelConfig):
+def create_cuda_kernel(assignments: NodeCollection, config: CreateKernelConfig):
    function_name = config.function_name
    indexing_creator = indexing_creator_from_params(config.gpu_indexing, config.gpu_indexing_params)
@@ -39,7 +36,9 @@ def create_cuda_kernel(assignments: Union[AssignmentCollection, NodeCollection],
    field_accesses = set()
    num_buffer_accesses = 0
+   indexed_elements = set()
    for eq in assignments:
+       indexed_elements.update(eq.atoms(sp.Indexed))
        field_accesses.update(eq.atoms(Field.Access))
        field_accesses = {e for e in field_accesses if not e.is_absolute_access}
        num_buffer_accesses += sum(1 for access in eq.atoms(Field.Access) if FieldType.is_buffer(access.field))
@@ -64,17 +63,28 @@ def create_cuda_kernel(assignments: Union[AssignmentCollection, NodeCollection],
            iteration_slice.append(slice(ghost_layers[i][0],
                                         -ghost_layers[i][1] if ghost_layers[i][1] > 0 else None))

-       indexing = indexing_creator(field=common_field, iteration_slice=iteration_slice)
-       coord_mapping = indexing.coordinates
-
-       cell_idx_assignments = [SympyAssignment(LoopOverCoordinate.get_loop_counter_symbol(i), value)
-                               for i, value in enumerate(coord_mapping)]
-       cell_idx_symbols = [LoopOverCoordinate.get_loop_counter_symbol(i) for i, _ in enumerate(coord_mapping)]
-       assignments = cell_idx_assignments + assignments
+       iteration_space = normalize_slice(iteration_slice, common_shape)
+   else:
+       iteration_space = normalize_slice(iteration_slice, common_shape)

+   iteration_space = tuple([s if isinstance(s, slice) else slice(s, s + 1, 1) for s in iteration_space])
+   loop_counter_symbols = [LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(len(iteration_space))]
+
+   if len(indexed_elements) > 0:
+       common_indexed_element = get_common_indexed_element(indexed_elements)
+       index = common_indexed_element.indices[0].atoms(TypedSymbol)
+       assert len(index) == 1, "index expressions must only contain one symbol representing the index"
+       indexing = indexing_creator(iteration_space=(slice(0, common_indexed_element.shape[0], 1), *iteration_space),
+                                   data_layout=common_field.layout)
+       extended_ctrs = [index.pop(), *loop_counter_symbols]
+       loop_counter_assignments = indexing.get_loop_ctr_assignments(extended_ctrs)
+   else:
+       indexing = indexing_creator(iteration_space=iteration_space, data_layout=common_field.layout)
+       loop_counter_assignments = indexing.get_loop_ctr_assignments(loop_counter_symbols)
+
+   assignments = loop_counter_assignments + assignments

-   block = indexing.guard(Block(assignments), common_shape)
+   block = Block(assignments)
+   block = indexing.guard(block, common_shape)
    unify_shape_symbols(block, common_shape=common_shape, fields=fields_without_buffers)

    ast = KernelFunction(block,
@@ -86,16 +96,18 @@ def create_cuda_kernel(assignments: Union[AssignmentCollection, NodeCollection],
                         assignments=assignments)

    ast.global_variables.update(indexing.index_variables)

-   base_pointer_spec = [['spatialInner0']]
+   base_pointer_spec = config.base_pointer_specification
+   if base_pointer_spec is None:
+       base_pointer_spec = []
    base_pointer_info = {f.name: parse_base_pointer_info(base_pointer_spec, [2, 1, 0],
                                                         f.spatial_dimensions, f.index_dimensions)
                         for f in all_fields}

-   coord_mapping = {f.name: cell_idx_symbols for f in all_fields}
+   coord_mapping = {f.name: loop_counter_symbols for f in all_fields}

    if any(FieldType.is_buffer(f) for f in all_fields):
-       iteration_space = indexing.iteration_space(common_shape)
-       resolve_buffer_accesses(ast, get_base_buffer_index(ast, cell_idx_symbols, iteration_space), read_only_fields)
+       base_buffer_index = get_base_buffer_index(ast, loop_counter_symbols, iteration_space)
+       resolve_buffer_accesses(ast, base_buffer_index, read_only_fields)

    resolve_field_accesses(ast, read_only_fields, field_to_base_pointer_info=base_pointer_info,
                           field_to_fixed_coordinates=coord_mapping)
@@ -114,40 +126,41 @@ def create_cuda_kernel(assignments: Union[AssignmentCollection, NodeCollection],

    return ast
-def created_indexed_cuda_kernel(assignments: Union[AssignmentCollection, NodeCollection],
-                                config: CreateKernelConfig):
+def created_indexed_cuda_kernel(assignments: NodeCollection, config: CreateKernelConfig):
    index_fields = config.index_fields
    function_name = config.function_name
    coordinate_names = config.coordinate_names
    indexing_creator = indexing_creator_from_params(config.gpu_indexing, config.gpu_indexing_params)

    fields_written = assignments.bound_fields
    fields_read = assignments.rhs_fields
-   assignments = assignments.all_assignments
-   assignments = add_types(assignments, config)

    all_fields = fields_read.union(fields_written)
    read_only_fields = set([f.name for f in fields_read - fields_written])

+   # extract the index fields based on the name. The original index field might have been modified
+   index_fields = [idx_field for idx_field in index_fields if idx_field.name in [f.name for f in all_fields]]
+   non_index_fields = [f for f in all_fields if f not in index_fields]
+   spatial_coordinates = {f.spatial_dimensions for f in non_index_fields}
+   assert len(spatial_coordinates) == 1, f"Non-index fields do not have the same number of spatial coordinates "  \
+                                         f"Non index fields are {non_index_fields}, spatial coordinates are " \
+                                         f"{spatial_coordinates}"
+   spatial_coordinates = list(spatial_coordinates)[0]
+
+   assignments = assignments.all_assignments
+   assignments = add_types(assignments, config)

    for index_field in index_fields:
        index_field.field_type = FieldType.INDEXED
        assert FieldType.is_indexed(index_field)
        assert index_field.spatial_dimensions == 1, "Index fields have to be 1D"

-   non_index_fields = [f for f in all_fields if f not in index_fields]
-   spatial_coordinates = {f.spatial_dimensions for f in non_index_fields}
-   assert len(spatial_coordinates) == 1, "Non-index fields do not have the same number of spatial coordinates"
-   spatial_coordinates = list(spatial_coordinates)[0]

    def get_coordinate_symbol_assignment(name):
        for ind_f in index_fields:
            assert isinstance(ind_f.dtype, StructType), "Index fields have to have a struct data type"
            data_type = ind_f.dtype
            if data_type.has_element(name):
                rhs = ind_f[0](name)
-               lhs = TypedSymbol(name, np.int64)
+               lhs = TypedSymbol(name, data_type.get_element_type(name))
                return SympyAssignment(lhs, rhs)
        raise ValueError(f"Index {name} not found in any of the passed index fields")
@@ -156,8 +169,12 @@ def created_indexed_cuda_kernel(assignments: Union[AssignmentCollection, NodeCol
    coordinate_typed_symbols = [eq.lhs for eq in coordinate_symbol_assignments]

    idx_field = list(index_fields)[0]
-   indexing = indexing_creator(field=idx_field,
-                               iteration_slice=[slice(None, None, None)] * len(idx_field.spatial_shape))
+
+   iteration_space = normalize_slice(tuple([slice(None, None, None)]) * len(idx_field.spatial_shape),
+                                     idx_field.spatial_shape)
+
+   indexing = indexing_creator(iteration_space=iteration_space,
+                               data_layout=idx_field.layout)

    function_body = Block(coordinate_symbol_assignments + assignments)
    function_body = indexing.guard(function_body, get_common_field(index_fields).spatial_shape)
...
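For orientation, a sketch (with made-up ghost-layer numbers, not part of the diff) of how the iteration space handed to the indexing object is built in the code above; normalize_slice is the helper imported from pystencils.slicing:

from pystencils.slicing import normalize_slice

ghost_layers = [(1, 1), (1, 1)]        # hypothetical: one ghost layer on each border of a 2D field
common_shape = (34, 34)
iteration_slice = [slice(gl[0], -gl[1] if gl[1] > 0 else None) for gl in ghost_layers]
iteration_space = normalize_slice(iteration_slice, common_shape)
# start and stop are now concrete (roughly slice(1, 33, ...) per coordinate), so the
# indexing classes can work with plain numbers instead of a field object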
/*
Copyright 2010-2011, D. E. Shaw Research. All rights reserved.
Copyright 2019-2025, Michael Kuron.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions, and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions, and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#pragma once
#if defined(_MSC_VER) && defined(_M_ARM64)
#define __ARM_NEON
#endif
#ifdef __ARM_NEON
#include <arm_neon.h>
#if defined(__ARM_FEATURE_SVE)
#include <arm_sve.h>
#include <arm_neon_sve_bridge.h>
#endif
#else
#include <emmintrin.h> // SSE2
#include <wmmintrin.h> // AES
#ifdef __AVX__
@@ -8,6 +53,7 @@
#include <immintrin.h> // FMA
#endif
#endif
+#endif

#include <cstdint>
#include <array>
#include <map>
@@ -21,6 +67,14 @@

typedef std::uint32_t uint32;
typedef std::uint64_t uint64;
#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_SVE_BITS) && __ARM_FEATURE_SVE_BITS > 0
typedef svfloat32_t svfloat32_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
typedef svfloat64_t svfloat64_st __attribute__((arm_sve_vector_bits(__ARM_FEATURE_SVE_BITS)));
#elif defined(__ARM_FEATURE_SVE)
typedef svfloat32_t svfloat32_st;
typedef svfloat64_t svfloat64_st;
#endif
template <typename T, std::size_t Alignment>
class AlignedAllocator
{
@@ -36,7 +90,14 @@ public:
        if (n == 0) {
            return nullptr;
        }
-       void * const p = _mm_malloc(n*sizeof(T), Alignment);
+#ifdef _WIN32
+       void * const p = _aligned_malloc(n*sizeof(T), Alignment);
+#else
+       void * p;
+       if (posix_memalign(&p, Alignment, n*sizeof(T)) != 0) {
+           p = nullptr;
+       }
+#endif
        if (p == nullptr) {
            throw std::bad_alloc();
        }
@@ -44,7 +105,11 @@ public:
    }

    void deallocate(T * const p, const std::size_t n) const {
-       _mm_free(p);
+#ifdef _WIN32
+       _aligned_free(p);
+#else
+       free(p);
+#endif
    }
};
@@ -54,7 +119,7 @@ using AlignedMap = std::map<Key, T, std::less<Key>, AlignedAllocator<std::pair<c

#if defined(__AES__) || defined(_MSC_VER)
QUALIFIERS __m128i aesni_keygen_assist(__m128i temp1, __m128i temp2) {
    __m128i temp3;
-   temp2 = _mm_shuffle_epi32(temp2 ,0xff);
+   temp2 = _mm_shuffle_epi32(temp2, 0xff);
    temp3 = _mm_slli_si128(temp1, 0x4);
    temp1 = _mm_xor_si128(temp1, temp3);
    temp3 = _mm_slli_si128(temp3, 0x4);
@@ -241,6 +306,19 @@ QUALIFIERS __m128d _uniform_double_hq(__m128i x, __m128i y)
    return rs;
}
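// The following helper transposes a 4x4 matrix of 32-bit lanes held in four __m128i registers;
// aesni_float4/aesni_double2 below use it to repack the counters before and after the AES rounds.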
QUALIFIERS void transpose128(__m128i & R0, __m128i & R1, __m128i & R2, __m128i & R3)
{
__m128i T0, T1, T2, T3;
T0 = _mm_unpacklo_epi32(R0, R1);
T1 = _mm_unpacklo_epi32(R2, R3);
T2 = _mm_unpackhi_epi32(R0, R1);
T3 = _mm_unpackhi_epi32(R2, R3);
R0 = _mm_unpacklo_epi64(T0, T1);
R1 = _mm_unpackhi_epi64(T0, T1);
R2 = _mm_unpacklo_epi64(T2, T3);
R3 = _mm_unpackhi_epi64(T2, T3);
}
QUALIFIERS void aesni_float4(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i ctr3,
                             uint32 key0, uint32 key1, uint32 key2, uint32 key3,
@@ -249,12 +327,12 @@ QUALIFIERS void aesni_float4(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i c
    // pack input and call AES
    __m128i k128 = _mm_set_epi32(key3, key2, key1, key0);
    __m128i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
-   _MY_TRANSPOSE4_EPI32(ctr[0], ctr[1], ctr[2], ctr[3]);
+   transpose128(ctr[0], ctr[1], ctr[2], ctr[3]);
    for (int i = 0; i < 4; ++i)
    {
        ctr[i] = aesni1xm128i(ctr[i], k128);
    }
-   _MY_TRANSPOSE4_EPI32(ctr[0], ctr[1], ctr[2], ctr[3]);
+   transpose128(ctr[0], ctr[1], ctr[2], ctr[3]);

    // convert uint32 to float
    rnd1 = _my_cvtepu32_ps(ctr[0]);
@@ -287,12 +365,12 @@ QUALIFIERS void aesni_double2(__m128i ctr0, __m128i ctr1, __m128i ctr2, __m128i
    // pack input and call AES
    __m128i k128 = _mm_set_epi32(key3, key2, key1, key0);
    __m128i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
-   _MY_TRANSPOSE4_EPI32(ctr[0], ctr[1], ctr[2], ctr[3]);
+   transpose128(ctr[0], ctr[1], ctr[2], ctr[3]);
    for (int i = 0; i < 4; ++i)
    {
        ctr[i] = aesni1xm128i(ctr[i], k128);
    }
-   _MY_TRANSPOSE4_EPI32(ctr[0], ctr[1], ctr[2], ctr[3]);
+   transpose128(ctr[0], ctr[1], ctr[2], ctr[3]);

    rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
    rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
@@ -408,6 +486,20 @@ QUALIFIERS __m256d _uniform_double_hq(__m256i x, __m256i y)
}
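// AVX2 variant: the same 4x4 32-bit transpose, applied independently within each 128-bit lane of the
// __m256i registers, so every lane carries a complete counter block for aesni1xm128i.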
QUALIFIERS void transpose128(__m256i & R0, __m256i & R1, __m256i & R2, __m256i & R3)
{
__m256i T0, T1, T2, T3;
T0 = _mm256_unpacklo_epi32(R0, R1);
T1 = _mm256_unpacklo_epi32(R2, R3);
T2 = _mm256_unpackhi_epi32(R0, R1);
T3 = _mm256_unpackhi_epi32(R2, R3);
R0 = _mm256_unpacklo_epi64(T0, T1);
R1 = _mm256_unpackhi_epi64(T0, T1);
R2 = _mm256_unpacklo_epi64(T2, T3);
R3 = _mm256_unpackhi_epi64(T2, T3);
}
QUALIFIERS void aesni_float4(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i ctr3,
                             uint32 key0, uint32 key1, uint32 key2, uint32 key3,
                             __m256 & rnd1, __m256 & rnd2, __m256 & rnd3, __m256 & rnd4)
@@ -415,33 +507,12 @@ QUALIFIERS void aesni_float4(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i c
    // pack input and call AES
    __m256i k256 = _mm256_set_epi32(key3, key2, key1, key0, key3, key2, key1, key0);
    __m256i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
-   __m128i a[4], b[4];
-   for (int i = 0; i < 4; ++i)
-   {
-       a[i] = _mm256_extractf128_si256(ctr[i], 0);
-       b[i] = _mm256_extractf128_si256(ctr[i], 1);
-   }
-   _MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
-   _MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
-   for (int i = 0; i < 4; ++i)
-   {
-       ctr[i] = _my256_set_m128i(b[i], a[i]);
-   }
+   transpose128(ctr[0], ctr[1], ctr[2], ctr[3]);
    for (int i = 0; i < 4; ++i)
    {
        ctr[i] = aesni1xm128i(ctr[i], k256);
    }
-   for (int i = 0; i < 4; ++i)
-   {
-       a[i] = _mm256_extractf128_si256(ctr[i], 0);
-       b[i] = _mm256_extractf128_si256(ctr[i], 1);
-   }
-   _MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
-   _MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
-   for (int i = 0; i < 4; ++i)
-   {
-       ctr[i] = _my256_set_m128i(b[i], a[i]);
-   }
+   transpose128(ctr[0], ctr[1], ctr[2], ctr[3]);

    // convert uint32 to float
    rnd1 = _my256_cvtepu32_ps(ctr[0]);
@@ -474,33 +545,12 @@ QUALIFIERS void aesni_double2(__m256i ctr0, __m256i ctr1, __m256i ctr2, __m256i
    // pack input and call AES
    __m256i k256 = _mm256_set_epi32(key3, key2, key1, key0, key3, key2, key1, key0);
    __m256i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
-   __m128i a[4], b[4];
-   for (int i = 0; i < 4; ++i)
-   {
-       a[i] = _mm256_extractf128_si256(ctr[i], 0);
-       b[i] = _mm256_extractf128_si256(ctr[i], 1);
-   }
-   _MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
-   _MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
-   for (int i = 0; i < 4; ++i)
-   {
-       ctr[i] = _my256_set_m128i(b[i], a[i]);
-   }
+   transpose128(ctr[0], ctr[1], ctr[2], ctr[3]);
    for (int i = 0; i < 4; ++i)
    {
        ctr[i] = aesni1xm128i(ctr[i], k256);
    }
-   for (int i = 0; i < 4; ++i)
-   {
-       a[i] = _mm256_extractf128_si256(ctr[i], 0);
-       b[i] = _mm256_extractf128_si256(ctr[i], 1);
-   }
-   _MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
-   _MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
-   for (int i = 0; i < 4; ++i)
-   {
-       ctr[i] = _my256_set_m128i(b[i], a[i]);
-   }
+   transpose128(ctr[0], ctr[1], ctr[2], ctr[3]);

    rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
    rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
@@ -551,7 +601,7 @@ QUALIFIERS void aesni_double2(uint32 ctr0, __m256i ctr1, uint32 ctr2, uint32 ctr
#endif

-#ifdef __AVX512F__
+#if defined(__AVX512F__) || defined(__AVX10_512BIT__)
QUALIFIERS const std::array<__m512i,11> & aesni_roundkeys(const __m512i & k512) {
    alignas(64) std::array<uint32,16> a;
    _mm512_store_si512((__m512i*) a.data(), k512);
@@ -622,6 +672,20 @@ QUALIFIERS __m512d _uniform_double_hq(__m512i x, __m512i y)
}
QUALIFIERS void transpose128(__m512i & R0, __m512i & R1, __m512i & R2, __m512i & R3)
{
__m512i T0, T1, T2, T3;
T0 = _mm512_unpacklo_epi32(R0, R1);
T1 = _mm512_unpacklo_epi32(R2, R3);
T2 = _mm512_unpackhi_epi32(R0, R1);
T3 = _mm512_unpackhi_epi32(R2, R3);
R0 = _mm512_unpacklo_epi64(T0, T1);
R1 = _mm512_unpackhi_epi64(T0, T1);
R2 = _mm512_unpacklo_epi64(T2, T3);
R3 = _mm512_unpackhi_epi64(T2, T3);
}
QUALIFIERS void aesni_float4(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i ctr3,
                             uint32 key0, uint32 key1, uint32 key2, uint32 key3,
                             __m512 & rnd1, __m512 & rnd2, __m512 & rnd3, __m512 & rnd4)
@@ -630,41 +694,12 @@ QUALIFIERS void aesni_float4(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i c
    __m512i k512 = _mm512_set_epi32(key3, key2, key1, key0, key3, key2, key1, key0,
                                    key3, key2, key1, key0, key3, key2, key1, key0);
    __m512i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
-   __m128i a[4], b[4], c[4], d[4];
-   for (int i = 0; i < 4; ++i)
-   {
-       a[i] = _mm512_extracti32x4_epi32(ctr[i], 0);
-       b[i] = _mm512_extracti32x4_epi32(ctr[i], 1);
-       c[i] = _mm512_extracti32x4_epi32(ctr[i], 2);
-       d[i] = _mm512_extracti32x4_epi32(ctr[i], 3);
-   }
-   _MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
-   _MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
-   _MY_TRANSPOSE4_EPI32(c[0], c[1], c[2], c[3]);
-   _MY_TRANSPOSE4_EPI32(d[0], d[1], d[2], d[3]);
-   for (int i = 0; i < 4; ++i)
-   {
-       ctr[i] = _my512_set_m128i(d[i], c[i], b[i], a[i]);
-   }
+   transpose128(ctr[0], ctr[1], ctr[2], ctr[3]);
    for (int i = 0; i < 4; ++i)
    {
        ctr[i] = aesni1xm128i(ctr[i], k512);
    }
-   for (int i = 0; i < 4; ++i)
-   {
-       a[i] = _mm512_extracti32x4_epi32(ctr[i], 0);
-       b[i] = _mm512_extracti32x4_epi32(ctr[i], 1);
-       c[i] = _mm512_extracti32x4_epi32(ctr[i], 2);
-       d[i] = _mm512_extracti32x4_epi32(ctr[i], 3);
-   }
-   _MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
-   _MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
-   _MY_TRANSPOSE4_EPI32(c[0], c[1], c[2], c[3]);
-   _MY_TRANSPOSE4_EPI32(d[0], d[1], d[2], d[3]);
-   for (int i = 0; i < 4; ++i)
-   {
-       ctr[i] = _my512_set_m128i(d[i], c[i], b[i], a[i]);
-   }
+   transpose128(ctr[0], ctr[1], ctr[2], ctr[3]);

    // convert uint32 to float
    rnd1 = _mm512_cvtepu32_ps(ctr[0]);
@@ -687,41 +722,12 @@ QUALIFIERS void aesni_double2(__m512i ctr0, __m512i ctr1, __m512i ctr2, __m512i
    __m512i k512 = _mm512_set_epi32(key3, key2, key1, key0, key3, key2, key1, key0,
                                    key3, key2, key1, key0, key3, key2, key1, key0);
    __m512i ctr[4] = {ctr0, ctr1, ctr2, ctr3};
-   __m128i a[4], b[4], c[4], d[4];
-   for (int i = 0; i < 4; ++i)
-   {
-       a[i] = _mm512_extracti32x4_epi32(ctr[i], 0);
-       b[i] = _mm512_extracti32x4_epi32(ctr[i], 1);
-       c[i] = _mm512_extracti32x4_epi32(ctr[i], 2);
-       d[i] = _mm512_extracti32x4_epi32(ctr[i], 3);
-   }
-   _MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
-   _MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
-   _MY_TRANSPOSE4_EPI32(c[0], c[1], c[2], c[3]);
-   _MY_TRANSPOSE4_EPI32(d[0], d[1], d[2], d[3]);
-   for (int i = 0; i < 4; ++i)
-   {
-       ctr[i] = _my512_set_m128i(d[i], c[i], b[i], a[i]);
-   }
+   transpose128(ctr[0], ctr[1], ctr[2], ctr[3]);
    for (int i = 0; i < 4; ++i)
    {
        ctr[i] = aesni1xm128i(ctr[i], k512);
    }
-   for (int i = 0; i < 4; ++i)
-   {
-       a[i] = _mm512_extracti32x4_epi32(ctr[i], 0);
-       b[i] = _mm512_extracti32x4_epi32(ctr[i], 1);
-       c[i] = _mm512_extracti32x4_epi32(ctr[i], 2);
-       d[i] = _mm512_extracti32x4_epi32(ctr[i], 3);
-   }
-   _MY_TRANSPOSE4_EPI32(a[0], a[1], a[2], a[3]);
-   _MY_TRANSPOSE4_EPI32(b[0], b[1], b[2], b[3]);
-   _MY_TRANSPOSE4_EPI32(c[0], c[1], c[2], c[3]);
-   _MY_TRANSPOSE4_EPI32(d[0], d[1], d[2], d[3]);
-   for (int i = 0; i < 4; ++i)
-   {
-       ctr[i] = _my512_set_m128i(d[i], c[i], b[i], a[i]);
-   }
+   transpose128(ctr[0], ctr[1], ctr[2], ctr[3]);

    rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
    rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
@@ -771,3 +777,466 @@ QUALIFIERS void aesni_double2(uint32 ctr0, __m512i ctr1, uint32 ctr2, uint32 ctr
}
#endif
#if defined(__ARM_NEON)
QUALIFIERS uint32x4_t aesni_keygen_assist(uint32x4_t temp1, uint32x4_t temp2) {
uint32x4_t temp3;
temp2 = vdupq_laneq_u32(temp2, 3);
temp3 = vextq_u32(vdupq_n_u32(0), temp1, 4 - 1);
temp1 = veorq_u32(temp1, temp3);
temp3 = vextq_u32(vdupq_n_u32(0), temp3, 4 - 1);
temp1 = veorq_u32(temp1, temp3);
temp3 = vextq_u32(vdupq_n_u32(0), temp3, 4 - 1);
temp1 = veorq_u32(temp1, temp3);
temp1 = veorq_u32(temp1, temp2);
return temp1;
}
QUALIFIERS uint32x4_t aesni_keygen_assist(uint32x4_t k, const unsigned char imm8)
{
uint8x16_t a = vreinterpretq_u8_u32(k);
a = vaeseq_u8(a, vdupq_n_u8(0));
uint8x16_t dest = {
a[0x4], a[0x1], a[0xE], a[0xB],
a[0x1], a[0xE], a[0xB], a[0x4],
a[0xC], a[0x9], a[0x6], a[0x3],
a[0x9], a[0x6], a[0x3], a[0xC],
};
return vreinterpretq_u32_u8(veorq_u8(dest, vreinterpretq_u8_u32(uint32x4_t{0, imm8, 0, imm8})));
}
QUALIFIERS std::array<uint32x4_t,11> aesni_keygen(uint32x4_t k) {
std::array<uint32x4_t,11> rk;
uint32x4_t tmp;
rk[0] = k;
tmp = aesni_keygen_assist(k, 0x1);
k = aesni_keygen_assist(k, tmp);
rk[1] = k;
tmp = aesni_keygen_assist(k, 0x2);
k = aesni_keygen_assist(k, tmp);
rk[2] = k;
tmp = aesni_keygen_assist(k, 0x4);
k = aesni_keygen_assist(k, tmp);
rk[3] = k;
tmp = aesni_keygen_assist(k, 0x8);
k = aesni_keygen_assist(k, tmp);
rk[4] = k;
tmp = aesni_keygen_assist(k, 0x10);
k = aesni_keygen_assist(k, tmp);
rk[5] = k;
tmp = aesni_keygen_assist(k, 0x20);
k = aesni_keygen_assist(k, tmp);
rk[6] = k;
tmp = aesni_keygen_assist(k, 0x40);
k = aesni_keygen_assist(k, tmp);
rk[7] = k;
tmp = aesni_keygen_assist(k, 0x80);
k = aesni_keygen_assist(k, tmp);
rk[8] = k;
tmp = aesni_keygen_assist(k, 0x1b);
k = aesni_keygen_assist(k, tmp);
rk[9] = k;
tmp = aesni_keygen_assist(k, 0x36);
k = aesni_keygen_assist(k, tmp);
rk[10] = k;
return rk;
}
QUALIFIERS const std::array<uint32x4_t,11> & aesni_roundkeys(const uint32x4_t & k128) {
alignas(16) std::array<uint32,4> a;
vst1q_u32((uint32_t*) a.data(), k128);
static AlignedMap<std::array<uint32,4>, std::array<uint32x4_t,11>> roundkeys;
if(roundkeys.find(a) == roundkeys.end()) {
auto rk = aesni_keygen(k128);
roundkeys[a] = rk;
}
return roundkeys[a];
}
QUALIFIERS uint32x4_t aesni1xm128i(const uint32x4_t & in, const uint32x4_t & k0) {
auto k = aesni_roundkeys(k0);
uint8x16_t x = vaesmcq_u8(vaeseq_u8(vreinterpretq_u8_u32(in), vreinterpretq_u8_u32(k[0])));
x = vaesmcq_u8(vaeseq_u8(x, vreinterpretq_u8_u32(k[1])));
x = vaesmcq_u8(vaeseq_u8(x, vreinterpretq_u8_u32(k[2])));
x = vaesmcq_u8(vaeseq_u8(x, vreinterpretq_u8_u32(k[3])));
x = vaesmcq_u8(vaeseq_u8(x, vreinterpretq_u8_u32(k[4])));
x = vaesmcq_u8(vaeseq_u8(x, vreinterpretq_u8_u32(k[5])));
x = vaesmcq_u8(vaeseq_u8(x, vreinterpretq_u8_u32(k[6])));
x = vaesmcq_u8(vaeseq_u8(x, vreinterpretq_u8_u32(k[7])));
x = vaesmcq_u8(vaeseq_u8(x, vreinterpretq_u8_u32(k[8])));
x = vaeseq_u8(x, vreinterpretq_u8_u32(k[9]));
x = veorq_u8(x, vreinterpretq_u8_u32(k[10]));
return vreinterpretq_u32_u8(x);
}
QUALIFIERS void aesni_float4(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
float & rnd1, float & rnd2, float & rnd3, float & rnd4)
{
// pack input and call AES
uint32x4_t c128 = {ctr0, ctr1, ctr2, ctr3};
uint32x4_t k128 = {key0, key1, key2, key3};
c128 = aesni1xm128i(c128, k128);
// convert uint32 to float
float32x4_t rs = vcvtq_f32_u32(c128);
// calculate rs * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
rs = vfmaq_f32(vdupq_n_f32(TWOPOW32_INV_FLOAT/2.0), vdupq_n_f32(TWOPOW32_INV_FLOAT), rs);
// store result
rnd1 = vgetq_lane_f32(rs, 0);
rnd2 = vgetq_lane_f32(rs, 1);
rnd3 = vgetq_lane_f32(rs, 2);
rnd4 = vgetq_lane_f32(rs, 3);
}
QUALIFIERS void aesni_double2(uint32 ctr0, uint32 ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
double & rnd1, double & rnd2)
{
// pack input and call AES
uint32x4_t c128 = {ctr0, ctr1, ctr2, ctr3};
uint32x4_t k128 = {key0, key1, key2, key3};
c128 = aesni1xm128i(c128, k128);
// convert 32 to 64 bit and put 0th and 2nd element into x, 1st and 3rd element into y
uint32x4_t x = vandq_u32(c128, uint32x4_t{0xffffffff, 0, 0xffffffff, 0});
uint32x4_t y = vandq_u32(c128, uint32x4_t{0, 0xffffffff, 0, 0xffffffff});
y = vextq_u32(y, vdupq_n_u32(0), 1);
// calculate z = x ^ (y << (53 - 32))
uint64x2_t z = vshlq_n_u64(vreinterpretq_u64_u32(y), 53 - 32);
z = veorq_u64(vreinterpretq_u64_u32(x), z);
// convert uint64 to double
float64x2_t rs = vcvtq_f64_u64(z);
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
rs = vfmaq_f64(vdupq_n_f64(TWOPOW53_INV_DOUBLE/2.0), vdupq_n_f64(TWOPOW53_INV_DOUBLE), rs);
// store result
rnd1 = vgetq_lane_f64(rs, 0);
rnd2 = vgetq_lane_f64(rs, 1);
}
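// Turns two 32-bit random words per 64-bit lane into one uniform double in
// (0,1): z = x ^ (y << 21) provides 53 random bits, which are then scaled by
// TWOPOW53_INV_DOUBLE and offset by half of it. The template parameter picks
// the upper (high = true) or lower pair of 32-bit lanes of the inputs.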
template<bool high>
QUALIFIERS float64x2_t _uniform_double_hq(uint32x4_t x, uint32x4_t y)
{
// convert 32 to 64 bit
if (high)
{
x = vzip2q_u32(x, vdupq_n_u32(0));
y = vzip2q_u32(y, vdupq_n_u32(0));
}
else
{
x = vzip1q_u32(x, vdupq_n_u32(0));
y = vzip1q_u32(y, vdupq_n_u32(0));
}
// calculate z = x ^ (y << (53 - 32))
uint64x2_t z = vshlq_n_u64(vreinterpretq_u64_u32(y), 53 - 32);
z = veorq_u64(vreinterpretq_u64_u32(x), z);
// convert uint64 to double
float64x2_t rs = vcvtq_f64_u64(z);
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
rs = vfmaq_f64(vdupq_n_f64(TWOPOW53_INV_DOUBLE/2.0), vdupq_n_f64(TWOPOW53_INV_DOUBLE), rs);
return rs;
}
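// In-register 4x4 transpose of 32-bit lanes: turns four vectors that each
// hold one counter word of four parallel streams (SoA) into four complete
// 128-bit counter blocks (AoS), and back.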
QUALIFIERS void transpose128(uint32x4_t & R0, uint32x4_t & R1, uint32x4_t & R2, uint32x4_t & R3)
{
uint32x4_t T0, T1, T2, T3;
T0 = vreinterpretq_u32_u64(vtrn1q_u64(vreinterpretq_u64_u32(R0), vreinterpretq_u64_u32(R2)));
T1 = vreinterpretq_u32_u64(vtrn1q_u64(vreinterpretq_u64_u32(R1), vreinterpretq_u64_u32(R3)));
T2 = vreinterpretq_u32_u64(vtrn2q_u64(vreinterpretq_u64_u32(R0), vreinterpretq_u64_u32(R2)));
T3 = vreinterpretq_u32_u64(vtrn2q_u64(vreinterpretq_u64_u32(R1), vreinterpretq_u64_u32(R3)));
R0 = vtrn1q_u32(T0, T1);
R1 = vtrn2q_u32(T0, T1);
R2 = vtrn1q_u32(T2, T3);
R3 = vtrn2q_u32(T2, T3);
}
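// SIMD variant: transpose the four counter-word vectors into 128-bit blocks,
// encrypt each block, transpose back, and map every 32-bit word to a float
// in (0,1).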
QUALIFIERS void aesni_float4(uint32x4_t ctr0, uint32x4_t ctr1, uint32x4_t ctr2, uint32x4_t ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
float32x4_t & rnd1, float32x4_t & rnd2, float32x4_t & rnd3, float32x4_t & rnd4)
{
uint32x4_t k128 = {key0, key1, key2, key3};
uint32x4_t ctr[4] = {ctr0, ctr1, ctr2, ctr3};
transpose128(ctr[0], ctr[1], ctr[2], ctr[3]);
for (int i = 0; i < 4; ++i)
{
ctr[i] = aesni1xm128i(ctr[i], k128);
}
transpose128(ctr[0], ctr[1], ctr[2], ctr[3]);
// convert uint32 to float
rnd1 = vcvtq_f32_u32(ctr[0]);
rnd2 = vcvtq_f32_u32(ctr[1]);
rnd3 = vcvtq_f32_u32(ctr[2]);
rnd4 = vcvtq_f32_u32(ctr[3]);
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
rnd1 = vfmaq_f32(vdupq_n_f32(TWOPOW32_INV_FLOAT/2.0), vdupq_n_f32(TWOPOW32_INV_FLOAT), rnd1);
rnd2 = vfmaq_f32(vdupq_n_f32(TWOPOW32_INV_FLOAT/2.0), vdupq_n_f32(TWOPOW32_INV_FLOAT), rnd2);
rnd3 = vfmaq_f32(vdupq_n_f32(TWOPOW32_INV_FLOAT/2.0), vdupq_n_f32(TWOPOW32_INV_FLOAT), rnd3);
rnd4 = vfmaq_f32(vdupq_n_f32(TWOPOW32_INV_FLOAT/2.0), vdupq_n_f32(TWOPOW32_INV_FLOAT), rnd4);
}
QUALIFIERS void aesni_double2(uint32x4_t ctr0, uint32x4_t ctr1, uint32x4_t ctr2, uint32x4_t ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
float64x2_t & rnd1lo, float64x2_t & rnd1hi, float64x2_t & rnd2lo, float64x2_t & rnd2hi)
{
uint32x4_t k128 = {key0, key1, key2, key3};
uint32x4_t ctr[4] = {ctr0, ctr1, ctr2, ctr3};
transpose128(ctr[0], ctr[1], ctr[2], ctr[3]);
for (int i = 0; i < 4; ++i)
{
ctr[i] = aesni1xm128i(ctr[i], k128);
}
transpose128(ctr[0], ctr[1], ctr[2], ctr[3]);
rnd1lo = _uniform_double_hq<false>(ctr[0], ctr[1]);
rnd1hi = _uniform_double_hq<true>(ctr[0], ctr[1]);
rnd2lo = _uniform_double_hq<false>(ctr[2], ctr[3]);
rnd2hi = _uniform_double_hq<true>(ctr[2], ctr[3]);
}
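// Convenience overloads with one vector counter (ctr1) and three scalar
// counter words: the scalars are broadcast and the call is forwarded to the
// full SIMD implementation above.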
QUALIFIERS void aesni_float4(uint32 ctr0, uint32x4_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
float32x4_t & rnd1, float32x4_t & rnd2, float32x4_t & rnd3, float32x4_t & rnd4)
{
uint32x4_t ctr0v = vdupq_n_u32(ctr0);
uint32x4_t ctr2v = vdupq_n_u32(ctr2);
uint32x4_t ctr3v = vdupq_n_u32(ctr3);
aesni_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, key2, key3, rnd1, rnd2, rnd3, rnd4);
}
#ifndef _MSC_VER
QUALIFIERS void aesni_float4(uint32 ctr0, int32x4_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
float32x4_t & rnd1, float32x4_t & rnd2, float32x4_t & rnd3, float32x4_t & rnd4)
{
aesni_float4(ctr0, vreinterpretq_u32_s32(ctr1), ctr2, ctr3, key0, key1, key2, key3, rnd1, rnd2, rnd3, rnd4);
}
#endif
QUALIFIERS void aesni_double2(uint32 ctr0, uint32x4_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
float64x2_t & rnd1lo, float64x2_t & rnd1hi, float64x2_t & rnd2lo, float64x2_t & rnd2hi)
{
uint32x4_t ctr0v = vdupq_n_u32(ctr0);
uint32x4_t ctr2v = vdupq_n_u32(ctr2);
uint32x4_t ctr3v = vdupq_n_u32(ctr3);
aesni_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, key2, key3, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
QUALIFIERS void aesni_double2(uint32 ctr0, uint32x4_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
float64x2_t & rnd1, float64x2_t & rnd2)
{
uint32x4_t ctr0v = vdupq_n_u32(ctr0);
uint32x4_t ctr2v = vdupq_n_u32(ctr2);
uint32x4_t ctr3v = vdupq_n_u32(ctr3);
float64x2_t ignore;
aesni_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, key2, key3, rnd1, ignore, rnd2, ignore);
}
#ifndef _MSC_VER
QUALIFIERS void aesni_double2(uint32 ctr0, int32x4_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
float64x2_t & rnd1, float64x2_t & rnd2)
{
aesni_double2(ctr0, vreinterpretq_u32_s32(ctr1), ctr2, ctr3, key0, key1, key2, key3, rnd1, rnd2);
}
#endif
#endif
#if defined(__ARM_FEATURE_SVE)
template<bool high>
QUALIFIERS svfloat64_t _uniform_double_hq(svuint32_t x, svuint32_t y)
{
// convert 32 to 64 bit
if (high)
{
x = svzip2_u32(x, svdup_u32(0));
y = svzip2_u32(y, svdup_u32(0));
}
else
{
x = svzip1_u32(x, svdup_u32(0));
y = svzip1_u32(y, svdup_u32(0));
}
// calculate z = x ^ (y << (53 - 32))
svuint64_t z = svlsl_n_u64_x(svptrue_b64(), svreinterpret_u64_u32(y), 53 - 32);
z = sveor_u64_x(svptrue_b64(), svreinterpret_u64_u32(x), z);
// convert uint64 to double
svfloat64_t rs = svcvt_f64_u64_x(svptrue_b64(), z);
// calculate rs * TWOPOW53_INV_DOUBLE + (TWOPOW53_INV_DOUBLE/2.0)
rs = svmad_f64_x(svptrue_b64(), rs, svdup_f64(TWOPOW53_INV_DOUBLE), svdup_f64(TWOPOW53_INV_DOUBLE/2.0));
return rs;
}
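// SVE counterpart of the 4x4 lane transpose: after the transpose, every
// 128-bit segment of the four output vectors holds one complete counter
// block, ready for per-segment AES encryption.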
QUALIFIERS void transpose128(svuint32x4_t & R)
{
svuint32_t T0, T1, T2, T3;
T0 = svreinterpret_u32_u64(svtrn1_u64(svreinterpret_u64_u32(svget4_u32(R, 0)), svreinterpret_u64_u32(svget4_u32(R, 2))));
T1 = svreinterpret_u32_u64(svtrn1_u64(svreinterpret_u64_u32(svget4_u32(R, 1)), svreinterpret_u64_u32(svget4_u32(R, 3))));
T2 = svreinterpret_u32_u64(svtrn2_u64(svreinterpret_u64_u32(svget4_u32(R, 0)), svreinterpret_u64_u32(svget4_u32(R, 2))));
T3 = svreinterpret_u32_u64(svtrn2_u64(svreinterpret_u64_u32(svget4_u32(R, 1)), svreinterpret_u64_u32(svget4_u32(R, 3))));
R = svset4_u32(R, 0, svtrn1_u32(T0, T1));
R = svset4_u32(R, 1, svtrn2_u32(T0, T1));
R = svset4_u32(R, 2, svtrn1_u32(T2, T3));
R = svset4_u32(R, 3, svtrn2_u32(T2, T3));
}
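// With SVE2-AES the whole vector is encrypted at once (svaese/svaesmc operate
// on each 128-bit segment independently, with the round key broadcast to all
// segments). Without it, each group of four 32-bit lanes is compacted into a
// NEON register, encrypted by the NEON routine above, and selected back into
// place under the matching predicate.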
QUALIFIERS svuint32_t aesni1xm128i(const svuint32_t & in, const uint32x4_t & k0)
{
#ifdef __ARM_FEATURE_SVE2_AES
auto k = aesni_roundkeys(k0);
svuint8_t x = svaesmc_u8(svaese_u8(svreinterpret_u8_u32(in), svdup_neonq_u8(vreinterpretq_u8_u32(k[0]))));
x = svaesmc_u8(svaese_u8(x, svdup_neonq_u8(vreinterpretq_u8_u32(k[1]))));
x = svaesmc_u8(svaese_u8(x, svdup_neonq_u8(vreinterpretq_u8_u32(k[2]))));
x = svaesmc_u8(svaese_u8(x, svdup_neonq_u8(vreinterpretq_u8_u32(k[3]))));
x = svaesmc_u8(svaese_u8(x, svdup_neonq_u8(vreinterpretq_u8_u32(k[4]))));
x = svaesmc_u8(svaese_u8(x, svdup_neonq_u8(vreinterpretq_u8_u32(k[5]))));
x = svaesmc_u8(svaese_u8(x, svdup_neonq_u8(vreinterpretq_u8_u32(k[6]))));
x = svaesmc_u8(svaese_u8(x, svdup_neonq_u8(vreinterpretq_u8_u32(k[7]))));
x = svaesmc_u8(svaese_u8(x, svdup_neonq_u8(vreinterpretq_u8_u32(k[8]))));
x = svaese_u8(x, svdup_neonq_u8(vreinterpretq_u8_u32(k[9])));
x = sveor_u8_x(svptrue_b8(), x, svdup_neonq_u8(vreinterpretq_u8_u32(k[10])));
return svreinterpret_u32_u8(x);
#else
svuint32_t x;
for (int i = 0; i < svcntw(); i += 4)
{
svbool_t pred = svbic_z(svptrue_b32(), svwhilelt_b32_u32(0, i+4), svwhilelt_b32_u32(0, i));
uint32x4_t a = aesni1xm128i(svget_neonq_u32(svcompact_u32(pred, in)), k0);
x = svsel_u32(pred, svdup_neonq_u32(a), x);
}
return x;
#endif
}
QUALIFIERS void aesni_float4(svuint32_t ctr0, svuint32_t ctr1, svuint32_t ctr2, svuint32_t ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4)
{
uint32x4_t k128 = {key0, key1, key2, key3};
svuint32x4_t ctr = svcreate4_u32(ctr0, ctr1, ctr2, ctr3);
transpose128(ctr);
ctr = svset4_u32(ctr, 0, aesni1xm128i(svget4_u32(ctr, 0), k128));
ctr = svset4_u32(ctr, 1, aesni1xm128i(svget4_u32(ctr, 1), k128));
ctr = svset4_u32(ctr, 2, aesni1xm128i(svget4_u32(ctr, 2), k128));
ctr = svset4_u32(ctr, 3, aesni1xm128i(svget4_u32(ctr, 3), k128));
transpose128(ctr);
// convert uint32 to float
rnd1 = svcvt_f32_u32_x(svptrue_b32(), svget4_u32(ctr, 0));
rnd2 = svcvt_f32_u32_x(svptrue_b32(), svget4_u32(ctr, 1));
rnd3 = svcvt_f32_u32_x(svptrue_b32(), svget4_u32(ctr, 2));
rnd4 = svcvt_f32_u32_x(svptrue_b32(), svget4_u32(ctr, 3));
// calculate rnd * TWOPOW32_INV_FLOAT + (TWOPOW32_INV_FLOAT/2.0f)
rnd1 = svmad_f32_x(svptrue_b32(), rnd1, svdup_f32(TWOPOW32_INV_FLOAT), svdup_f32(TWOPOW32_INV_FLOAT/2.0));
rnd2 = svmad_f32_x(svptrue_b32(), rnd2, svdup_f32(TWOPOW32_INV_FLOAT), svdup_f32(TWOPOW32_INV_FLOAT/2.0));
rnd3 = svmad_f32_x(svptrue_b32(), rnd3, svdup_f32(TWOPOW32_INV_FLOAT), svdup_f32(TWOPOW32_INV_FLOAT/2.0));
rnd4 = svmad_f32_x(svptrue_b32(), rnd4, svdup_f32(TWOPOW32_INV_FLOAT), svdup_f32(TWOPOW32_INV_FLOAT/2.0));
}
QUALIFIERS void aesni_double2(svuint32_t ctr0, svuint32_t ctr1, svuint32_t ctr2, svuint32_t ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
svfloat64_st & rnd1lo, svfloat64_st & rnd1hi, svfloat64_st & rnd2lo, svfloat64_st & rnd2hi)
{
uint32x4_t k128 = {key0, key1, key2, key3};
svuint32x4_t ctr = svcreate4_u32(ctr0, ctr1, ctr2, ctr3);
transpose128(ctr);
ctr = svset4_u32(ctr, 0, aesni1xm128i(svget4_u32(ctr, 0), k128));
ctr = svset4_u32(ctr, 1, aesni1xm128i(svget4_u32(ctr, 1), k128));
ctr = svset4_u32(ctr, 2, aesni1xm128i(svget4_u32(ctr, 2), k128));
ctr = svset4_u32(ctr, 3, aesni1xm128i(svget4_u32(ctr, 3), k128));
transpose128(ctr);
rnd1lo = _uniform_double_hq<false>(svget4_u32(ctr, 0), svget4_u32(ctr, 1));
rnd1hi = _uniform_double_hq<true>(svget4_u32(ctr, 0), svget4_u32(ctr, 1));
rnd2lo = _uniform_double_hq<false>(svget4_u32(ctr, 2), svget4_u32(ctr, 3));
rnd2hi = _uniform_double_hq<true>(svget4_u32(ctr, 2), svget4_u32(ctr, 3));
}
QUALIFIERS void aesni_float4(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4)
{
svuint32_t ctr0v = svdup_u32(ctr0);
svuint32_t ctr2v = svdup_u32(ctr2);
svuint32_t ctr3v = svdup_u32(ctr3);
aesni_float4(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, key2, key3, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void aesni_float4(uint32 ctr0, svint32_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
svfloat32_st & rnd1, svfloat32_st & rnd2, svfloat32_st & rnd3, svfloat32_st & rnd4)
{
aesni_float4(ctr0, svreinterpret_u32_s32(ctr1), ctr2, ctr3, key0, key1, key2, key3, rnd1, rnd2, rnd3, rnd4);
}
QUALIFIERS void aesni_double2(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
svfloat64_st & rnd1lo, svfloat64_st & rnd1hi, svfloat64_st & rnd2lo, svfloat64_st & rnd2hi)
{
svuint32_t ctr0v = svdup_u32(ctr0);
svuint32_t ctr2v = svdup_u32(ctr2);
svuint32_t ctr3v = svdup_u32(ctr3);
aesni_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, key2, key3, rnd1lo, rnd1hi, rnd2lo, rnd2hi);
}
QUALIFIERS void aesni_double2(uint32 ctr0, svuint32_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
svfloat64_st & rnd1, svfloat64_st & rnd2)
{
svuint32_t ctr0v = svdup_u32(ctr0);
svuint32_t ctr2v = svdup_u32(ctr2);
svuint32_t ctr3v = svdup_u32(ctr3);
svfloat64_st ignore;
aesni_double2(ctr0v, ctr1, ctr2v, ctr3v, key0, key1, key2, key3, rnd1, ignore, rnd2, ignore);
}
QUALIFIERS void aesni_double2(uint32 ctr0, svint32_t ctr1, uint32 ctr2, uint32 ctr3,
uint32 key0, uint32 key1, uint32 key2, uint32 key3,
svfloat64_st & rnd1, svfloat64_st & rnd2)
{
aesni_double2(ctr0, svreinterpret_u32_s32(ctr1), ctr2, ctr3, key0, key1, key2, key3, rnd1, rnd2);
}
#endif
#undef QUALIFIERS
#undef TWOPOW53_INV_DOUBLE
#undef TWOPOW32_INV_FLOAT
/*
Copyright 2021-2023, Michael Kuron.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions, and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions, and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#pragma once
#if defined(_MSC_VER)
#define __ARM_NEON
#endif
#include <cstddef>
#ifdef __ARM_NEON
#include <arm_neon.h>
#endif
...
/*
Copyright 2023, Markus Holzer.
Copyright 2023, Michael Kuron.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions, and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions, and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#pragma once
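// Bit patterns of the IEEE-754 single-precision infinities (0x7f800000 = +inf,
// 0xff800000 = -inf), reinterpreted with the device-side __int_as_float
// intrinsic.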
#define POS_INFINITY __int_as_float(0x7f800000)
#define INFINITY POS_INFINITY
#define NEG_INFINITY __int_as_float(0xff800000)
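// hipRTC compiles without the standard headers, so the fixed-width integer
// names are mapped onto the compiler-provided __hip_* builtin types here.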
#ifdef __HIPCC_RTC__
typedef __hip_uint8_t uint8_t;
typedef __hip_int8_t int8_t;
typedef __hip_uint16_t uint16_t;
typedef __hip_int16_t int16_t;
#endif
/*
Copyright 2023, Markus Holzer.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions, and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions, and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
/// Half precision support. Experimental. Use carefully.
///
/// This feature is experimental, since it strictly depends on the underlying architecture and compiler support.
...