Showing with 2909 additions and 108 deletions
from typing import cast
from ..kernelcreation import KernelCreationContext
from ..ast import PsAstNode
from ..ast.structural import PsBlock, PsLoop, PsConditional, PsDeclaration, PsAssignment
from ..ast.expressions import (
PsExpression,
PsSymbolExpr,
PsConstantExpr,
PsLiteralExpr,
PsCall,
PsBufferAcc,
PsSubscript,
PsLookup,
PsUnOp,
PsBinOp,
PsArrayInitList,
)
from ..ast.util import determine_memory_object
from ...types import PsDereferencableType
from ..memory import PsSymbol
from ..functions import PsMathFunction
__all__ = ["HoistLoopInvariantDeclarations"]
class HoistContext:
def __init__(self) -> None:
self.hoisted_nodes: list[PsDeclaration] = []
self.assigned_symbols: set[PsSymbol] = set()
self.mutated_symbols: set[PsSymbol] = set()
self.invariant_symbols: set[PsSymbol] = set()
def _is_invariant(self, expr: PsExpression) -> bool:
def args_invariant(expr):
return all(
self._is_invariant(cast(PsExpression, arg)) for arg in expr.children
)
match expr:
case PsSymbolExpr(symbol):
return (symbol not in self.assigned_symbols) or (
symbol in self.invariant_symbols
)
case PsConstantExpr() | PsLiteralExpr():
return True
case PsCall(func):
return isinstance(func, PsMathFunction) and args_invariant(expr)
case PsSubscript() | PsLookup():
return determine_memory_object(expr)[1] and args_invariant(expr)
case PsBufferAcc(ptr, _):
# Regular pointer derefs are never invariant, since we cannot reason about aliasing
ptr_type = cast(PsDereferencableType, ptr.get_dtype())
return ptr_type.base_type.const and args_invariant(expr)
case PsUnOp() | PsBinOp() | PsArrayInitList():
return args_invariant(expr)
case _:
return False
class HoistLoopInvariantDeclarations:
"""Hoist loop-invariant declarations out of the loop nest.
This transformation moves loop-invariant symbol declarations outside of the loop
nest to prevent their repeated execution within the loops.
If this transformation results in the complete elimination of a loop body, the respective loop
is removed.
`HoistLoopInvariantDeclarations` assumes that symbols are canonical;
in particular, each symbol may have at most one declaration.
To ensure this, a `CanonicalizeSymbols` pass should be run before `HoistLoopInvariantDeclarations`.
`HoistLoopInvariantDeclarations` assumes that all `PsMathFunction` s are pure (have no side effects),
but makes no such assumption about instances of `CFunction`.
"""
def __init__(self, ctx: KernelCreationContext):
self._ctx = ctx
def __call__(self, node: PsAstNode) -> PsAstNode:
return self.visit(node)
def visit(self, node: PsAstNode) -> PsAstNode:
"""Search the outermost loop and start the hoisting cascade there."""
match node:
case PsLoop():
temp_block = PsBlock([node])
temp_block = cast(PsBlock, self.visit(temp_block))
if temp_block.statements == [node]:
return node
else:
return temp_block
case PsBlock(statements):
statements_new: list[PsAstNode] = []
for stmt in statements:
if isinstance(stmt, PsLoop):
loop = stmt
hc = self._hoist(loop)
statements_new += hc.hoisted_nodes
if loop.body.statements:
statements_new.append(loop)
else:
self.visit(stmt)
statements_new.append(stmt)
node.statements = statements_new
return node
case PsConditional(_, then, els):
self.visit(then)
if els is not None:
self.visit(els)
return node
case _:
# if the node is none of the above, end the search
return node
# end match
def _hoist(self, loop: PsLoop) -> HoistContext:
"""Hoist invariant declarations out of the given loop."""
hc = HoistContext()
hc.assigned_symbols.add(loop.counter.symbol)
hc.mutated_symbols.add(loop.counter.symbol)
self._prepare_hoist(loop.body, hc)
self._hoist_from_block(loop.body, hc)
return hc
def _prepare_hoist(self, node: PsAstNode, hc: HoistContext):
"""Collect all symbols assigned within a loop body,
and recursively apply loop-invariant code motion to any nested loops."""
match node:
case PsExpression():
return
case PsDeclaration(PsSymbolExpr(lhs_symb), _):
hc.assigned_symbols.add(lhs_symb)
case PsAssignment(PsSymbolExpr(lhs_symb), _):
hc.assigned_symbols.add(lhs_symb)
hc.mutated_symbols.add(lhs_symb)
case PsAssignment(_, _):
return
case PsBlock(statements):
statements_new: list[PsAstNode] = []
for stmt in statements:
if isinstance(stmt, PsLoop):
loop = stmt
nested_hc = self._hoist(loop)
hc.assigned_symbols |= nested_hc.assigned_symbols
hc.mutated_symbols |= nested_hc.mutated_symbols
statements_new += nested_hc.hoisted_nodes
if loop.body.statements:
statements_new.append(loop)
else:
self._prepare_hoist(stmt, hc)
statements_new.append(stmt)
node.statements = statements_new
case _:
for c in node.children:
self._prepare_hoist(c, hc)
def _hoist_from_block(self, block: PsBlock, hc: HoistContext):
"""Hoist invariant declarations from the given block, and any directly nested blocks.
This method processes only statements of the given block, and any blocks directly nested inside it.
It does not descend into control structures like conditionals and nested loops.
"""
statements_new: list[PsAstNode] = []
for node in block.statements:
if isinstance(node, PsDeclaration):
lhs_symb = cast(PsSymbolExpr, node.lhs).symbol
if lhs_symb not in hc.mutated_symbols and hc._is_invariant(node.rhs):
hc.hoisted_nodes.append(node)
hc.invariant_symbols.add(node.declared_symbol)
else:
statements_new.append(node)
else:
if isinstance(node, PsBlock):
self._hoist_from_block(node, hc)
statements_new.append(node)
block.statements = statements_new
import numpy as np
from enum import Enum, auto
from typing import cast, Callable, overload
from ...types import PsVectorType, PsScalarType
from ..kernelcreation import KernelCreationContext
from ..constants import PsConstant
from ..ast import PsAstNode
from ..ast.structural import PsLoop, PsBlock, PsDeclaration
from ..ast.expressions import PsExpression, PsTernary, PsGt
from ..ast.vector import PsVecBroadcast
from ..ast.analysis import collect_undefined_symbols
from .ast_vectorizer import VectorizationAxis, VectorizationContext, AstVectorizer
from .rewrite import substitute_symbols
class LoopVectorizer:
"""Vectorize loops.
The loop vectorizer provides methods to vectorize single loops inside an AST
using a given number of vector lanes.
During vectorization, the loop body is transformed using the `AstVectorizer`,
the loop's limits are adapted according to the number of vector lanes,
and a block treating trailing iterations is optionally added.
Args:
ctx: The current kernel creation context
lanes: The number of vector lanes to use
trailing_iters: Mode for the treatment of trailing iterations
"""
class TrailingItersTreatment(Enum):
"""How to treat trailing iterations during loop vectorization."""
SCALAR_LOOP = auto()
"""Cover trailing iterations using a scalar remainder loop."""
MASKED_BLOCK = auto()
"""Cover trailing iterations using a masked block."""
NONE = auto()
"""Assume that the loop iteration count is a multiple of the number of lanes
and do not cover any trailing iterations"""
def __init__(
self,
ctx: KernelCreationContext,
lanes: int,
trailing_iters: TrailingItersTreatment = TrailingItersTreatment.SCALAR_LOOP,
):
self._ctx = ctx
self._lanes = lanes
self._trailing_iters = trailing_iters
from ..kernelcreation import Typifier
from .eliminate_constants import EliminateConstants
self._typify = Typifier(ctx)
self._vectorize_ast = AstVectorizer(ctx)
self._fold = EliminateConstants(ctx)
@overload
def vectorize_select_loops(
self, node: PsBlock, predicate: Callable[[PsLoop], bool]
) -> PsBlock: ...
@overload
def vectorize_select_loops(
self, node: PsLoop, predicate: Callable[[PsLoop], bool]
) -> PsLoop | PsBlock: ...
@overload
def vectorize_select_loops(
self, node: PsAstNode, predicate: Callable[[PsLoop], bool]
) -> PsAstNode: ...
def vectorize_select_loops(
self, node: PsAstNode, predicate: Callable[[PsLoop], bool]
) -> PsAstNode:
"""Select and vectorize loops from a syntax tree according to a predicate.
Finds each loop inside the subtree and evaluates ``predicate`` on it.
If ``predicate(loop)`` evaluates to `True`, the loop is vectorized.
Loops nested inside a vectorized loop will not be processed.
Args:
node: Root of the subtree to process
predicate: Callback telling the vectorizer which loops to vectorize
"""
match node:
case PsLoop() if predicate(node):
return self.vectorize_loop(node)
case PsExpression():
return node
case _:
node.children = [
self.vectorize_select_loops(c, predicate) for c in node.children
]
return node
def __call__(self, loop: PsLoop) -> PsLoop | PsBlock:
return self.vectorize_loop(loop)
def vectorize_loop(self, loop: PsLoop) -> PsLoop | PsBlock:
"""Vectorize the given loop."""
scalar_ctr_expr = loop.counter
scalar_ctr = scalar_ctr_expr.symbol
# Prepare vector counter
vector_ctr_dtype = PsVectorType(
cast(PsScalarType, scalar_ctr_expr.get_dtype()), self._lanes
)
vector_ctr = self._ctx.duplicate_symbol(scalar_ctr, vector_ctr_dtype)
step_multiplier_val = np.array(
range(self._lanes), dtype=scalar_ctr_expr.get_dtype().numpy_dtype
)
step_multiplier = PsExpression.make(
PsConstant(step_multiplier_val, vector_ctr_dtype)
)
vector_counter_decl = self._type_fold(
PsDeclaration(
PsExpression.make(vector_ctr),
PsVecBroadcast(self._lanes, scalar_ctr_expr)
+ step_multiplier * PsVecBroadcast(self._lanes, loop.step),
)
)
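# Illustration: with 4 lanes and loop step `s`, the vector counter declared above holds
# (ctr, ctr + s, ctr + 2*s, ctr + 3*s), i.e. the scalar counter broadcast to all lanes
# plus a per-lane multiple of the step.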
# Prepare axis
axis = VectorizationAxis(scalar_ctr, vector_ctr, step=loop.step)
# Prepare vectorization context
vc = VectorizationContext(self._ctx, self._lanes, axis)
# Generate vectorized loop body
simd_body = self._vectorize_ast(loop.body, vc)
if vector_ctr in collect_undefined_symbols(simd_body):
simd_body.statements.insert(0, vector_counter_decl)
# Build new loop limits
simd_start = loop.start.clone()
simd_step = self._ctx.get_new_symbol(
f"__{scalar_ctr.name}_simd_step", scalar_ctr.get_dtype()
)
simd_step_decl = self._type_fold(
PsDeclaration(
PsExpression.make(simd_step),
loop.step.clone() * PsExpression.make(PsConstant(self._lanes)),
)
)
# Each iteration must satisfy `ctr + step * (lanes - 1) < stop`
simd_stop = self._ctx.get_new_symbol(
f"__{scalar_ctr.name}_simd_stop", scalar_ctr.get_dtype()
)
simd_stop_decl = self._type_fold(
PsDeclaration(
PsExpression.make(simd_stop),
loop.stop.clone()
- (
PsExpression.make(PsConstant(self._lanes))
- PsExpression.make(PsConstant(1))
)
* loop.step.clone(),
)
)
simd_loop = PsLoop(
PsExpression.make(scalar_ctr),
simd_start,
PsExpression.make(simd_stop),
PsExpression.make(simd_step),
simd_body,
)
# Treat trailing iterations
match self._trailing_iters:
case LoopVectorizer.TrailingItersTreatment.SCALAR_LOOP:
trailing_start = self._ctx.get_new_symbol(
f"__{scalar_ctr.name}_trailing_start", scalar_ctr.get_dtype()
)
trailing_start_decl = self._type_fold(
PsDeclaration(
PsExpression.make(trailing_start),
PsTernary(
# If at least one vectorized iteration took place...
PsGt(
PsExpression.make(simd_stop),
simd_start.clone(),
),
# start from the first counter value of the form simd_start + k * simd_step that lies at or beyond simd_stop
(
(
PsExpression.make(simd_stop)
- simd_start.clone()
- PsExpression.make(PsConstant(1))
)
/ PsExpression.make(simd_step)
+ PsExpression.make(PsConstant(1))
)
* PsExpression.make(simd_step)
+ simd_start.clone(),
# otherwise start at the loop's original start
simd_start.clone(),
),
)
)
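# Worked example (illustration only): for start = 0, step = 1, lanes = 4 and stop = 13,
# we get simd_step = 4 and simd_stop = 10, so the SIMD loop runs for ctr = 0, 4, 8 and
# covers iterations 0..11; the expression above then yields
# trailing_start = ((10 - 0 - 1) / 4 + 1) * 4 + 0 = 12, and the scalar remainder loop
# handles iteration 12 only.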
trailing_ctr = self._ctx.duplicate_symbol(scalar_ctr)
trailing_loop_body = substitute_symbols(
loop.body.clone(), {scalar_ctr: PsExpression.make(trailing_ctr)}
)
trailing_loop = PsLoop(
PsExpression.make(trailing_ctr),
PsExpression.make(trailing_start),
loop.stop.clone(),
loop.step.clone(),
trailing_loop_body,
)
return PsBlock(
[
simd_stop_decl,
simd_step_decl,
simd_loop,
trailing_start_decl,
trailing_loop,
]
)
case LoopVectorizer.TrailingItersTreatment.MASKED_BLOCK:
raise NotImplementedError()
case LoopVectorizer.TrailingItersTreatment.NONE:
return PsBlock(
[
simd_stop_decl,
simd_step_decl,
simd_loop,
]
)
@overload
def _type_fold(self, node: PsExpression) -> PsExpression:
pass
@overload
def _type_fold(self, node: PsDeclaration) -> PsDeclaration:
pass
@overload
def _type_fold(self, node: PsAstNode) -> PsAstNode:
pass
def _type_fold(self, node: PsAstNode) -> PsAstNode:
return self._fold(self._typify(node))
from __future__ import annotations
from typing import cast
from functools import reduce
import operator
from ..kernelcreation import KernelCreationContext, Typifier
from ..constants import PsConstant
from ..memory import PsSymbol, PsBuffer, BufferBasePtr
from ..ast.structural import PsAstNode
from ..ast.expressions import (
PsBufferAcc,
PsLookup,
PsExpression,
PsMemAcc,
PsAddressOf,
PsCast,
PsSymbolExpr,
)
from ...types import PsStructType, PsPointerType, PsUnsignedIntegerType
class LowerToC:
"""Lower high-level IR constructs to C language concepts.
This pass replaces a number of IR constructs that have no direct counterpart in the C language
by lower-level AST nodes. These include:
- *Linearization of Buffer Accesses:* `PsBufferAcc` buffer accesses are linearized according to
their buffers' stride information and replaced by `PsMemAcc`.
- *Erasure of Anonymous Structs:*
For buffers whose element type is an anonymous struct, the struct type is erased from the base pointer,
making it a pointer to uint8_t.
Member lookups on accesses into these buffers are then transformed using type casts.
"""
def __init__(self, ctx: KernelCreationContext) -> None:
self._ctx = ctx
self._substitutions: dict[PsSymbol, PsSymbol] = dict()
self._typify = Typifier(ctx)
from .eliminate_constants import EliminateConstants
self._fold = EliminateConstants(self._ctx)
def __call__(self, node: PsAstNode) -> PsAstNode:
self._substitutions = dict()
node = self.visit(node)
for old, new in self._substitutions.items():
self._ctx.replace_symbol(old, new)
return node
def visit(self, node: PsAstNode) -> PsAstNode:
match node:
case PsBufferAcc(bptr, indices):
# Linearize
buf = node.buffer
# The typifier may assign a different data type to each index, so cast to the buffer's index type where necessary
def maybe_cast(i: PsExpression):
if i.get_dtype() != buf.index_type:
return PsCast(buf.index_type, i)
else:
return i
summands: list[PsExpression] = [
maybe_cast(cast(PsExpression, self.visit(idx).clone()))
* PsExpression.make(stride)
for idx, stride in zip(indices, buf.strides, strict=True)
]
linearized_idx: PsExpression = (
summands[0]
if len(summands) == 1
else reduce(operator.add, summands)
)
mem_acc = PsMemAcc(bptr.clone(), linearized_idx)
return self._fold(
self._typify.typify_expression(
mem_acc, target_type=buf.element_type
)[0]
)
case PsLookup(aggr, member_name) if isinstance(
aggr, PsBufferAcc
) and isinstance(
aggr.buffer.element_type, PsStructType
) and aggr.buffer.element_type.anonymous:
# Need to lower this buffer-lookup
linearized_acc = self.visit(aggr)
return self._lower_anon_lookup(
cast(PsMemAcc, linearized_acc), aggr.buffer, member_name
)
case _:
node.children = [self.visit(c) for c in node.children]
return node
def _lower_anon_lookup(
self, aggr: PsMemAcc, buf: PsBuffer, member_name: str
) -> PsExpression:
struct_type = cast(PsStructType, buf.element_type)
struct_size = struct_type.itemsize
assert isinstance(aggr.pointer, PsSymbolExpr)
bp = aggr.pointer.symbol
bp_type = bp.get_dtype()
assert isinstance(bp_type, PsPointerType)
# Need to keep track of base pointers already seen, since symbols must be unique
if bp not in self._substitutions:
erased_type = PsPointerType(
PsUnsignedIntegerType(8, const=bp_type.base_type.const),
const=bp_type.const,
restrict=bp_type.restrict,
)
type_erased_bp = PsSymbol(bp.name, erased_type)
type_erased_bp.add_property(BufferBasePtr(buf))
self._substitutions[bp] = type_erased_bp
else:
type_erased_bp = self._substitutions[bp]
base_index = aggr.offset * PsExpression.make(
PsConstant(struct_size, self._ctx.index_dtype)
)
member = struct_type.find_member(member_name)
assert member is not None
np_struct = struct_type.numpy_dtype
assert np_struct is not None
assert np_struct.fields is not None
member_offset = np_struct.fields[member_name][1]
byte_index = base_index + PsExpression.make(
PsConstant(member_offset, self._ctx.index_dtype)
)
type_erased_access = PsMemAcc(PsExpression.make(type_erased_bp), byte_index)
deref = PsMemAcc(
PsCast(PsPointerType(member.dtype), PsAddressOf(type_erased_access)),
PsExpression.make(PsConstant(0)),
)
deref = self._typify(deref)
return deref
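# Worked example for the anonymous-struct lowering above (illustration only; names are made up):
# for a struct element of 12 bytes whose member `val` (a double) sits at byte offset 4, the
# lookup on cell `i` becomes, in C-like form, `*(double *)(&data[i * 12 + 4])`, where `data`
# is the type-erased uint8_t base pointer.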
from typing import Sequence
from ..kernelcreation import KernelCreationContext, Typifier
from ..kernelcreation.ast_factory import AstFactory, IndexParsable
from ..ast.structural import PsLoop, PsBlock, PsConditional, PsDeclaration
from ..ast.expressions import PsExpression, PsConstantExpr, PsGe, PsLt
from ..constants import PsConstant
from .canonical_clone import CanonicalClone, CloneContext
from .eliminate_constants import EliminateConstants
class ReshapeLoops:
"""Various transformations for reshaping loop nests."""
def __init__(self, ctx: KernelCreationContext) -> None:
self._ctx = ctx
self._typify = Typifier(ctx)
self._factory = AstFactory(ctx)
self._canon_clone = CanonicalClone(ctx)
self._elim_constants = EliminateConstants(ctx)
def peel_loop_front(
self, loop: PsLoop, num_iterations: int, omit_range_check: bool = False
) -> tuple[Sequence[PsBlock], PsLoop]:
"""Peel off iterations from the front of a loop.
Removes ``num_iterations`` from the front of the given loop and returns them as a sequence of
independent blocks.
Args:
loop: The loop node from which to peel iterations
num_iterations: The number of iterations to peel off
omit_range_check: If set to `True`, assume that the peeled-off iterations will always
be executed, and omit their enclosing conditional.
Returns:
Tuple containing the peeled-off iterations as a sequence of blocks,
and the remaining loop.
"""
peeled_iters: list[PsBlock] = []
for i in range(num_iterations):
cc = CloneContext(self._ctx)
cc.symbol_decl(loop.counter.symbol)
peeled_ctr = self._factory.parse_index(
cc.get_replacement(loop.counter.symbol)
)
peeled_idx = self._elim_constants(
self._typify(loop.start + PsExpression.make(PsConstant(i)) * loop.step)
)
counter_decl = PsDeclaration(peeled_ctr, peeled_idx)
peeled_block = self._canon_clone.visit(loop.body, cc)
if omit_range_check:
peeled_block.statements = [counter_decl] + peeled_block.statements
else:
iter_condition = PsLt(peeled_ctr, loop.stop)
peeled_block.statements = [
counter_decl,
PsConditional(iter_condition, PsBlock(peeled_block.statements)),
]
peeled_iters.append(peeled_block)
loop.start = self._elim_constants(
self._typify(
loop.start + PsExpression.make(PsConstant(num_iterations)) * loop.step
)
)
return peeled_iters, loop
def peel_loop_back(
self, loop: PsLoop, num_iterations: int, omit_range_check: bool = False
) -> tuple[PsLoop, Sequence[PsBlock]]:
"""Peel off iterations from the back of a loop.
Removes ``num_iterations`` from the back of the given loop and returns them as a sequence of
independent blocks.
Args:
loop: The loop node from which to peel iterations
num_iterations: The number of iterations to peel off
omit_range_check: If set to `True`, assume that the peeled-off iterations will always
be executed, and omit their enclosing conditional.
Returns:
Tuple containing the modified loop and the peeled-off iterations (sequence of blocks).
"""
if not (
isinstance(loop.step, PsConstantExpr) and loop.step.constant.value == 1
):
raise NotImplementedError(
"Peeling iterations from the back of loops is only implemented"
"for loops with unit step. Implementation is deferred until"
"loop range canonicalization is available (also needed for the"
"vectorizer)."
)
peeled_iters: list[PsBlock] = []
for i in range(num_iterations)[::-1]:
cc = CloneContext(self._ctx)
cc.symbol_decl(loop.counter.symbol)
peeled_ctr = self._factory.parse_index(
cc.get_replacement(loop.counter.symbol)
)
peeled_idx = self._typify(loop.stop - PsExpression.make(PsConstant(i + 1)))
counter_decl = PsDeclaration(peeled_ctr, peeled_idx)
peeled_block = self._canon_clone.visit(loop.body, cc)
if omit_range_check:
peeled_block.statements = [counter_decl] + peeled_block.statements
else:
iter_condition = PsGe(peeled_ctr, loop.start)
peeled_block.statements = [
counter_decl,
PsConditional(iter_condition, PsBlock(peeled_block.statements)),
]
peeled_iters.append(peeled_block)
loop.stop = self._elim_constants(
self._typify(loop.stop - PsExpression.make(PsConstant(num_iterations)))
)
return loop, peeled_iters
def cut_loop(
self, loop: PsLoop, cutting_points: Sequence[IndexParsable]
) -> Sequence[PsLoop | PsBlock]:
"""Cut a loop at the given cutting points.
Cut the given loop at the iterations specified by the given cutting points,
producing ``n`` new subtrees representing the iterations
``(loop.start:cutting_points[0]), (cutting_points[0]:cutting_points[1]), ..., (cutting_points[-1]:loop.stop)``.
Resulting subtrees representing zero iterations are dropped; subtrees representing exactly one iteration are
returned without the trivial loop structure.
Currently, `cut_loop` performs no checks to ensure that the given cutting points are in fact inside
the loop's iteration range.
Returns:
Sequence of ``n`` subtrees representing the respective iteration ranges
"""
if not (
isinstance(loop.step, PsConstantExpr) and loop.step.constant.value == 1
):
raise NotImplementedError(
"Loop cutting for loops with step != 1 is not implemented"
)
result: list[PsLoop | PsBlock] = []
new_start = loop.start
cutting_points = [self._factory.parse_index(idx) for idx in cutting_points] + [
loop.stop
]
for new_end in cutting_points:
if new_end.structurally_equal(new_start):
continue
num_iters = self._elim_constants(self._typify(new_end - new_start))
skip = False
if isinstance(num_iters, PsConstantExpr):
if num_iters.constant.value == 0:
skip = True
elif num_iters.constant.value == 1:
skip = True
cc = CloneContext(self._ctx)
cc.symbol_decl(loop.counter.symbol)
local_counter = self._factory.parse_index(
cc.get_replacement(loop.counter.symbol)
)
ctr_decl = PsDeclaration(
local_counter,
new_start,
)
cloned_body = self._canon_clone.visit(loop.body, cc)
cloned_body.statements = [ctr_decl] + cloned_body.statements
result.append(cloned_body)
if not skip:
loop_clone = self._canon_clone(loop)
loop_clone.start = new_start.clone()
loop_clone.stop = new_end.clone()
result.append(loop_clone)
new_start = new_end
return result
from typing import overload
from ..memory import PsSymbol
from ..ast import PsAstNode
from ..ast.structural import PsBlock
from ..ast.expressions import PsExpression, PsSymbolExpr
@overload
def substitute_symbols(node: PsBlock, subs: dict[PsSymbol, PsExpression]) -> PsBlock:
pass
@overload
def substitute_symbols(
node: PsExpression, subs: dict[PsSymbol, PsExpression]
) -> PsExpression:
pass
@overload
def substitute_symbols(
node: PsAstNode, subs: dict[PsSymbol, PsExpression]
) -> PsAstNode:
pass
def substitute_symbols(
node: PsAstNode, subs: dict[PsSymbol, PsExpression]
) -> PsAstNode:
"""Substitute expressions for symbols throughout a subtree."""
match node:
case PsSymbolExpr(symb) if symb in subs:
return subs[symb].clone()
case _:
node.children = [substitute_symbols(c, subs) for c in node.children]
return node
from ..platforms import Platform
from ..ast import PsAstNode
from ..ast.expressions import PsCall
from ..functions import PsMathFunction
class SelectFunctions:
"""Traverse the AST to replace all instances of `PsMathFunction` by their implementation
provided by the given `Platform`."""
def __init__(self, platform: Platform):
self._platform = platform
def __call__(self, node: PsAstNode) -> PsAstNode:
return self.visit(node)
def visit(self, node: PsAstNode) -> PsAstNode:
node.children = [self.visit(c) for c in node.children]
if isinstance(node, PsCall) and isinstance(node.function, PsMathFunction):
return self._platform.select_function(node)
else:
return node
from __future__ import annotations
from typing import cast
from ..kernelcreation import KernelCreationContext
from ..memory import PsSymbol
from ..ast.structural import PsAstNode, PsDeclaration, PsAssignment, PsStatement
from ..ast.expressions import PsExpression, PsCall, PsCast, PsLiteral
from ...types import PsCustomType, PsVectorType, constify, deconstify
from ..ast.expressions import PsSymbolExpr, PsConstantExpr, PsUnOp, PsBinOp
from ..ast.vector import PsVecMemAcc
from ..exceptions import MaterializationError
from ..functions import CFunction, PsMathFunction
from ..platforms import GenericVectorCpu
__all__ = ["SelectIntrinsics"]
class SelectionContext:
def __init__(self, ctx: KernelCreationContext, platform: GenericVectorCpu):
self._ctx = ctx
self._platform = platform
self._intrin_symbols: dict[PsSymbol, PsSymbol] = dict()
self._lane_mask: PsSymbol | None = None
def get_intrin_symbol(self, symb: PsSymbol) -> PsSymbol:
if symb not in self._intrin_symbols:
assert isinstance(symb.dtype, PsVectorType)
intrin_type = self._platform.type_intrinsic(deconstify(symb.dtype))
if symb.dtype.const:
intrin_type = constify(intrin_type)
replacement = self._ctx.duplicate_symbol(symb, intrin_type)
self._intrin_symbols[symb] = replacement
return self._intrin_symbols[symb]
class SelectIntrinsics:
"""Lower IR vector types to intrinsic vector types, and IR vector operations to intrinsic vector operations.
This transformation will replace all vectorial IR elements by conforming implementations using
compiler intrinsics for the given execution platform.
Args:
ctx: The current kernel creation context
platform: Platform object representing the target hardware, which provides the intrinsics
use_builtin_convertvector: If `True`, type conversions between SIMD
vectors use the compiler builtin ``__builtin_convertvector``
instead of intrinsics. It is supported by Clang >= 3.7, GCC >= 9.1,
and ICX. Not supported by ICC or MSVC. Activate if you need type
conversions not natively supported by your CPU, e.g. conversion from
64bit integer to double on an x86 AVX machine. Defaults to `False`.
Raises:
MaterializationError: If a vector type or operation cannot be represented by intrinsics
on the given platform
"""
def __init__(
self,
ctx: KernelCreationContext,
platform: GenericVectorCpu,
use_builtin_convertvector: bool = False,
):
self._ctx = ctx
self._platform = platform
self._use_builtin_convertvector = use_builtin_convertvector
def __call__(self, node: PsAstNode) -> PsAstNode:
return self.visit(node, SelectionContext(self._ctx, self._platform))
def visit(self, node: PsAstNode, sc: SelectionContext) -> PsAstNode:
match node:
case PsExpression() if isinstance(node.dtype, PsVectorType):
return self.visit_expr(node, sc)
case PsDeclaration(lhs, rhs) if isinstance(lhs.dtype, PsVectorType):
lhs_new = cast(PsSymbolExpr, self.visit_expr(lhs, sc))
rhs_new = self.visit_expr(rhs, sc)
return PsDeclaration(lhs_new, rhs_new)
case PsAssignment(lhs, rhs) if isinstance(lhs, PsVecMemAcc):
new_rhs = self.visit_expr(rhs, sc)
return PsStatement(self._platform.vector_store(lhs, new_rhs))
case _:
node.children = [self.visit(c, sc) for c in node.children]
return node
def visit_expr(self, expr: PsExpression, sc: SelectionContext) -> PsExpression:
if not isinstance(expr.dtype, PsVectorType):
return expr
match expr:
case PsSymbolExpr(symb):
return PsSymbolExpr(sc.get_intrin_symbol(symb))
case PsConstantExpr(c):
return self._platform.constant_intrinsic(c)
case PsCast(target_type, operand) if self._use_builtin_convertvector:
assert isinstance(target_type, PsVectorType)
op = self.visit_expr(operand, sc)
rtype = PsCustomType(
f"{target_type.scalar_type.c_string()} __attribute__((__vector_size__({target_type.itemsize})))"
)
target_type_literal = PsExpression.make(PsLiteral(rtype.name, rtype))
func = CFunction(
"__builtin_convertvector", (op.get_dtype(), rtype), target_type
)
intrinsic = func(op, target_type_literal)
intrinsic.dtype = func.return_type
return intrinsic
case PsUnOp(operand):
op = self.visit_expr(operand, sc)
return self._platform.op_intrinsic(expr, [op])
case PsBinOp(operand1, operand2):
op1 = self.visit_expr(operand1, sc)
op2 = self.visit_expr(operand2, sc)
return self._platform.op_intrinsic(expr, [op1, op2])
case PsVecMemAcc():
return self._platform.vector_load(expr)
case PsCall(function, args) if isinstance(function, PsMathFunction):
arguments = [self.visit_expr(a, sc) for a in args]
return self._platform.math_func_intrinsic(expr, arguments)
case _:
raise MaterializationError(
f"Unable to select intrinsic implementation for {expr}"
)
import sympy as sp
from .sympyextensions.bit_masks import flag_cond as _flag_cond
# from pystencils.typing import get_type_of_expression
from warnings import warn
warn(
"Importing the `pystencils.bit_masks` module is deprecated. "
"Import `flag_cond` from `pystencils.sympyextensions` instead."
)
# noinspection PyPep8Naming
flag_cond = _flag_cond
class flag_cond(sp.Function):
"""Evaluates a flag condition on a bit mask, and returns the value of one of two expressions,
depending on whether the flag is set.
Three argument version:
```
flag_cond(flag_bit, mask, expr) = expr if (flag_bit is set in mask) else 0
```
Four argument version:
```
flag_cond(flag_bit, mask, expr_then, expr_else) = expr_then if (flag_bit is set in mask) else expr_else
```
"""
nargs = (3, 4)
def __new__(cls, flag_bit, mask_expression, *expressions):
# TODO Jan reintroduce checking
# flag_dtype = get_type_of_expression(flag_bit)
# if not flag_dtype.is_int():
# raise ValueError('Argument flag_bit must be of integer type.')
#
# mask_dtype = get_type_of_expression(mask_expression)
# if not mask_dtype.is_int():
# raise ValueError('Argument mask_expression must be of integer type.')
return super().__new__(cls, flag_bit, mask_expression, *expressions)
def to_c(self, print_func):
flag_bit = self.args[0]
mask = self.args[1]
then_expression = self.args[2]
flag_bit_code = print_func(flag_bit)
mask_code = print_func(mask)
then_code = print_func(then_expression)
code = f"(({mask_code}) >> ({flag_bit_code}) & 1) * ({then_code})"
if len(self.args) > 3:
else_expression = self.args[3]
else_code = print_func(else_expression)
code += f" + (({mask_code}) >> ({flag_bit_code}) ^ 1) * ({else_code})"
return code
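# Example of the generated code (illustration only): for flag_cond(fb, m, a, b), to_c emits
# ((m) >> (fb) & 1) * (a) + ((m) >> (fb) ^ 1) * (b), i.e. a branch-free select on the flag
# bit; with the three-argument form the second summand is omitted and the result defaults to 0.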
-from typing import Any, List, Tuple
+from typing import Any, List, Tuple, Sequence
-from pystencils.astnodes import SympyAssignment
+from pystencils.assignment import Assignment
 from pystencils.boundaries.boundaryhandling import BoundaryOffsetInfo
-from pystencils.typing import create_type
+from pystencils.types import create_type
 class Boundary:
@@ -14,7 +14,7 @@ class Boundary:
 def __init__(self, name=None):
 self._name = name
-def __call__(self, field, direction_symbol, index_field) -> List[SympyAssignment]:
+def __call__(self, field, direction_symbol, index_field) -> List[Assignment]:
 """Defines the boundary behavior and must therefore be implemented by all boundaries.
 Here the boundary is defined as a list of sympy assignments, from which a boundary kernel is generated.
@@ -30,7+30,7 @@ class Boundary:
 raise NotImplementedError("Boundary class has to overwrite __call__")
 @property
-def additional_data(self) -> Tuple[str, Any]:
+def additional_data(self) -> Sequence[Tuple[str, Any]]:
 """Return a list of (name, type) tuples for additional data items required in this boundary
 These data items can either be initialized in separate kernel see additional_data_kernel_init or by
 Python callbacks - see additional_data_callback """
@@ -63,13 +63,13 @@ class Neumann(Boundary):
 neighbor = BoundaryOffsetInfo.offset_from_dir(direction_symbol, field.spatial_dimensions)
 if field.index_dimensions == 0:
-return [SympyAssignment(field.center, field[neighbor])]
+return [Assignment(field.center, field[neighbor])]
 else:
 from itertools import product
 if not field.has_fixed_index_shape:
 raise NotImplementedError("Neumann boundary works only for fields with fixed index shape")
 index_iter = product(*(range(i) for i in field.index_shape))
-return [SympyAssignment(field(*idx), field[neighbor](*idx)) for idx in index_iter]
+return [Assignment(field(*idx), field[neighbor](*idx)) for idx in index_iter]
 def __hash__(self):
 # All boundaries of these class behave equal -> should also be equal
@@ -103,11 +103,11 @@ class Dirichlet(Boundary):
 def __call__(self, field, direction_symbol, index_field, **kwargs):
 if field.index_dimensions == 0:
-return [SympyAssignment(field.center, index_field("value") if self.additional_data else self._value)]
+return [Assignment(field.center, index_field("value") if self.additional_data else self._value)]
 elif field.index_dimensions == 1:
 assert not self.additional_data
 if not field.has_fixed_index_shape:
 raise NotImplementedError("Field needs fixed index shape")
 assert len(self._value) == field.index_shape[0], "Dirichlet value does not match index shape of field"
-return [SympyAssignment(field(i), self._value[i]) for i in range(field.index_shape[0])]
+return [Assignment(field(i), self._value[i]) for i in range(field.index_shape[0])]
 raise NotImplementedError("Dirichlet boundary not implemented for fields with more than one index dimension")
@@ -4,14 +4,15 @@ import numpy as np
 import sympy as sp
 from pystencils import create_kernel, CreateKernelConfig, Target
-from pystencils.astnodes import SympyAssignment
+from pystencils.assignment import Assignment
-from pystencils.backends.cbackend import CustomCodeNode
 from pystencils.boundaries.createindexlist import (
 create_boundary_index_array, numpy_data_type_for_boundary_object)
-from pystencils.typing import TypedSymbol, create_type
+from pystencils.sympyextensions import TypedSymbol
+from pystencils.types import PsIntegerType
+from pystencils.types.quick import Arr, SInt
 from pystencils.gpu.gpu_array_handler import GPUArrayHandler
-from pystencils.field import Field
+from pystencils.field import Field, FieldType
-from pystencils.typing.typed_sympy import FieldPointerSymbol
+from pystencils.codegen.properties import FieldBasePtr
 try:
 # noinspection PyPep8Naming
@@ -35,12 +36,9 @@ class FlagInterface:
 >>> dh = create_data_handling((4, 5))
 >>> fi = FlagInterface(dh, 'flag_field', np.uint8)
 >>> assert dh.has_data('flag_field')
->>> fi.reserve_next_flag()
-2
->>> fi.reserve_flag(4)
-4
->>> fi.reserve_next_flag()
-8
+>>> assert fi.reserve_next_flag() == 2
+>>> assert fi.reserve_flag(4) == 4
+>>> assert fi.reserve_next_flag() == 8
 """
 def __init__(self, data_handling, flag_field_name, dtype=DEFAULT_FLAG_TYPE):
@@ -248,7 +246,7 @@ class BoundaryHandling:
 kwargs['indexField'] = idx_arr
 data_used_in_kernel = (p.fields[0].name
 for p in self._boundary_object_to_boundary_info[b_obj].kernel.parameters
-if isinstance(p.symbol, FieldPointerSymbol) and p.fields[0].name not in kwargs)
+if bool(p.get_properties(FieldBasePtr)) and p.fields[0].name not in kwargs)
 kwargs.update({name: b[name] for name in data_used_in_kernel})
 self._boundary_object_to_boundary_info[b_obj].kernel(**kwargs)
@@ -264,7 +262,7 @@ class BoundaryHandling:
 arguments['indexField'] = idx_arr
 data_used_in_kernel = (p.fields[0].name
 for p in self._boundary_object_to_boundary_info[b_obj].kernel.parameters
-if isinstance(p.symbol, FieldPointerSymbol) and p.field_name not in arguments)
+if bool(p.get_properties(FieldBasePtr)) and p.fields[0].name not in arguments)
 arguments.update({name: b[name] for name in data_used_in_kernel if name not in arguments})
 kernel = self._boundary_object_to_boundary_info[b_obj].kernel
@@ -304,7 +302,8 @@ class BoundaryHandling:
 def _add_boundary(self, boundary_obj, flag=None):
 if boundary_obj not in self._boundary_object_to_boundary_info:
 sym_index_field = Field.create_generic('indexField', spatial_dimensions=1,
-dtype=numpy_data_type_for_boundary_object(boundary_obj, self.dim))
+dtype=numpy_data_type_for_boundary_object(boundary_obj, self.dim),
+field_type=FieldType.INDEXED)
 ast = self._create_boundary_kernel(self._data_handling.fields[self._field_name],
 sym_index_field, boundary_obj)
 if flag is None:
@@ -404,51 +403,63 @@ class BoundaryDataSetter:
 return self.index_array[item]
-class BoundaryOffsetInfo(CustomCodeNode):
+class BoundaryOffsetInfo:
 # --------------------------- Functions to be used by boundaries --------------------------
 @staticmethod
 def offset_from_dir(dir_idx, dim):
 return tuple([sp.IndexedBase(symbol, shape=(1,))[dir_idx]
-for symbol in BoundaryOffsetInfo._offset_symbols(dim)])
+for symbol in BoundaryOffsetInfo._untyped_offset_symbols(dim)])
 @staticmethod
 def inv_dir(dir_idx):
-return sp.IndexedBase(BoundaryOffsetInfo.INV_DIR_SYMBOL, shape=(1,))[dir_idx]
+return sp.IndexedBase(BoundaryOffsetInfo._untyped_inv_dir_symbol(), shape=(1,))[dir_idx]
 # ---------------------------------- Internal ---------------------------------------------
-def __init__(self, stencil):
+def __init__(self, stencil, index_dtype: PsIntegerType = SInt(32)) -> None:
-dim = len(stencil[0])
+self._stencil = stencil
+self._dim = len(stencil[0])
+self._index_dtype = index_dtype
-offset_sym = BoundaryOffsetInfo._offset_symbols(dim)
+def get_array_declarations(self) -> list[Assignment]:
-code = "\n"
+asms = []
-for i in range(dim):
+for i, offset_symb in enumerate(self._offset_symbols(self._dim)):
-offset_str = ", ".join([str(d[i]) for d in stencil])
+offsets = tuple(d[i] for d in self._stencil)
-code += "const int32_t %s [] = { %s };\n" % (offset_sym[i].name, offset_str)
+asms.append(Assignment(offset_symb, offsets))
 inv_dirs = []
-for direction in stencil:
+for direction in self._stencil:
 inverse_dir = tuple([-i for i in direction])
-inv_dirs.append(str(stencil.index(inverse_dir)))
+inv_dirs.append(str(self._stencil.index(inverse_dir)))
-code += "const int32_t %s [] = { %s };\n" % (self.INV_DIR_SYMBOL.name, ", ".join(inv_dirs))
+asms.append(Assignment(self._inv_dir_symbol(), tuple(inv_dirs)))
-offset_symbols = BoundaryOffsetInfo._offset_symbols(dim)
+return asms
-super(BoundaryOffsetInfo, self).__init__(code, symbols_read=set(),
-symbols_defined=set(offset_symbols + [self.INV_DIR_SYMBOL]))
+def _offset_symbols(self, dim, dtype: PsIntegerType = SInt(32)):
+return [TypedSymbol(f"c{d}", Arr(dtype, len(self._stencil))) for d in ['x', 'y', 'z'][:dim]]
 @staticmethod
-def _offset_symbols(dim):
+def _untyped_offset_symbols(dim):
-return [TypedSymbol(f"c{d}", create_type(np.int32)) for d in ['x', 'y', 'z'][:dim]]
+return [sp.Symbol(f"c{d}") for d in ['x', 'y', 'z'][:dim]]
-INV_DIR_SYMBOL = TypedSymbol("invdir", np.int32)
+def _inv_dir_symbol(self, dtype: PsIntegerType = SInt(32)):
+return TypedSymbol("invdir", Arr(dtype, len(self._stencil)))
+@staticmethod
+def _untyped_inv_dir_symbol(dtype: PsIntegerType = SInt(32)):
+return sp.Symbol("invdir")
 def create_boundary_kernel(field, index_field, stencil, boundary_functor, target=Target.CPU, **kernel_creation_args):
-elements = [BoundaryOffsetInfo(stencil)]
+# TODO: reconsider how to control the index_dtype in boundary kernels
-dir_symbol = TypedSymbol("dir", np.int32)
+config = CreateKernelConfig(index_field=index_field, target=target, index_dtype=SInt(32), **kernel_creation_args)
-elements += [SympyAssignment(dir_symbol, index_field[0]('dir'))]
+offset_info = BoundaryOffsetInfo(stencil, config.index_dtype)
+elements = offset_info.get_array_declarations()
+dir_symbol = TypedSymbol("dir", config.index_dtype)
+elements += [Assignment(dir_symbol, index_field[0]('dir'))]
 elements += boundary_functor(field, direction_symbol=dir_symbol, index_field=index_field)
-config = CreateKernelConfig(index_fields=[index_field], target=target, **kernel_creation_args)
 return create_kernel(elements, config=config)
 import warnings
 import numpy as np
+from pystencils.types.quick import SInt
 try:
@@ -21,14 +22,14 @@ if cython_funcs_available:
 boundary_index_array_coordinate_names = ["x", "y", "z"]
 direction_member_name = "dir"
-default_index_array_dtype = np.int32
+default_index_array_dtype = SInt(32)
 def numpy_data_type_for_boundary_object(boundary_object, dim):
 coordinate_names = boundary_index_array_coordinate_names[:dim]
 return np.dtype(
-[(name, default_index_array_dtype) for name in coordinate_names]
+[(name, default_index_array_dtype.numpy_dtype) for name in coordinate_names]
-+ [(direction_member_name, default_index_array_dtype)]
++ [(direction_member_name, default_index_array_dtype.numpy_dtype)]
 + [(i[0], i[1].numpy_dtype) for i in boundary_object.additional_data],
 align=True,
 )
@@ -56,8 +57,8 @@ def _create_index_list_python(
 : len(flag_field_arr.shape)
 ]
 index_arr_dtype = np.dtype(
-[(name, default_index_array_dtype) for name in coordinate_names]
+[(name, default_index_array_dtype.numpy_dtype) for name in coordinate_names]
-+ [(direction_member_name, default_index_array_dtype)]
++ [(direction_member_name, default_index_array_dtype.numpy_dtype)]
 )
 # boundary cells are extracted via np.where. To ensure continous memory access in the compute kernel these cells
@@ -147,11 +148,11 @@ def create_boundary_index_list(
 dim = len(flag_field.shape)
 coordinate_names = boundary_index_array_coordinate_names[:dim]
 index_arr_dtype = np.dtype(
-[(name, default_index_array_dtype) for name in coordinate_names]
+[(name, default_index_array_dtype.numpy_dtype) for name in coordinate_names]
-+ [(direction_member_name, default_index_array_dtype)]
++ [(direction_member_name, default_index_array_dtype.numpy_dtype)]
 )
-stencil = np.array(stencil, dtype=default_index_array_dtype)
+stencil = np.array(stencil, dtype=default_index_array_dtype.numpy_dtype)
 args = (
 flag_field,
 nr_of_ghost_layers,
 ...
 import sympy as sp
 from pystencils.boundaries.boundaryhandling import DEFAULT_FLAG_TYPE
-from pystencils.typing import TypedSymbol, create_type
+from pystencils.sympyextensions import TypedSymbol
+from pystencils.types import create_type
 from pystencils.field import Field
-from pystencils.integer_functions import bitwise_and
+from pystencils.sympyextensions.integer_functions import bitwise_and
 def add_neumann_boundary(eqs, fields, flag_field, boundary_flag="neumann_flag", inverse_flag=False):
 ...
from .target import Target
from .config import (
CreateKernelConfig,
AUTO,
)
from .parameters import Parameter
from .kernel import Kernel, GpuKernel, GpuThreadsRange
from .driver import create_kernel, get_driver
__all__ = [
"Target",
"CreateKernelConfig",
"AUTO",
"Parameter",
"Kernel",
"GpuKernel",
"GpuThreadsRange",
"create_kernel",
"get_driver",
]
from __future__ import annotations
from warnings import warn
from abc import ABC
from collections.abc import Collection
from typing import TYPE_CHECKING, Sequence, Generic, TypeVar, Callable, Any, cast
from dataclasses import dataclass, InitVar, fields
from .target import Target
from ..field import Field, FieldType
from ..types import (
PsIntegerType,
UserTypeSpec,
PsScalarType,
create_type,
)
from ..defaults import DEFAULTS
if TYPE_CHECKING:
from ..jit import JitBase
Option_T = TypeVar("Option_T")
"""Type variable for option values"""
Arg_T = TypeVar("Arg_T")
"""Type variable for option arguments"""
class Option(Generic[Option_T, Arg_T]):
"""Option descriptor.
This descriptor is used to model configuration options.
It maintains a default value for the option that is used when no value
was specified by the user.
In configuration options, the value `None` stands for ``unset``.
It can therefore not be used to set an option to mean "not any" or "empty";
for these, special values need to be used.
The Option allows a validator function to be specified,
which will be called to perform sanity checks on user-provided values.
Through the validator, options may also be set from arguments of a different type (``Arg_T``)
than their value type (``Option_T``). If ``Arg_T`` is different from ``Option_T``,
the validator must perform the conversion from the former to the latter.
.. note::
``Arg_T`` must always be a supertype of ``Option_T``.
"""
def __init__(
self,
default: Option_T | None = None,
validator: Callable[[Any, Arg_T | None], Option_T | None] | None = None,
) -> None:
self._default = default
self._validator = validator
self._name: str
self._lookup: str
def validate(self, validator: Callable[[Any, Any], Any] | None):
self._validator = validator
return validator
@property
def default(self) -> Option_T | None:
return self._default
def get(self, obj) -> Option_T | None:
val = getattr(obj, self._lookup, None)
if val is None:
return self._default
else:
return val
def is_set(self, obj) -> bool:
return getattr(obj, self._lookup, None) is not None
def __set_name__(self, owner: ConfigBase, name: str):
self._name = name
self._lookup = f"_{name}"
def __get__(self, obj: ConfigBase, objtype: type[ConfigBase] | None = None) -> Option_T | None:
if obj is None:
return None
return getattr(obj, self._lookup, None)
def __set__(self, obj: ConfigBase, arg: Arg_T | None):
if arg is not None and self._validator is not None:
value = self._validator(obj, arg)
else:
value = cast(Option_T, arg)
setattr(obj, self._lookup, value)
def __delete__(self, obj):
delattr(obj, self._lookup)
class BasicOption(Option[Option_T, Option_T]):
"Subclass of Option where ``Arg_T == Option_T``."
class ConfigBase(ABC):
"""Base class for configuration categories.
This class implements query and retrieval mechanism for configuration options,
as well as deepcopy functionality for categories.
Subclasses of `ConfigBase` must be `dataclasses`,
and all of their instance fields must have one of two descriptor types:
- Either `Option`, for scalar options;
- Or `Category` for option subcategories.
`Option` fields must be assigned immutable values, but are otherwise unconstrained.
`Category` subobjects must be subclasses of `ConfigBase`.
**Retrieval** Options set to `None` are considered *unset*, i.e. the user has not provided a value.
Through the `Option` descriptor, these options can still have a default value.
To retrieve either the user-set value if one exists, or the default value otherwise, use `get_option`.
**Deep-Copy** When a configuration object is copied, all of its subcategories must be copied along with it,
such that changes in the original do not affect the copy, and vice versa.
Such a deep copy is performed by the `copy <ConfigBase.copy>` method.
"""
def get_option(self, name: str) -> Any:
"""Get the value set for the specified option, or the option's default value if none has been set."""
descr: Option = type(self).__dict__[name]
return descr.get(self)
def is_option_set(self, name: str) -> bool:
descr: Option = type(self).__dict__[name]
return descr.is_set(self)
def override(self, other: ConfigBase):
for f in fields(self): # type: ignore
fvalue = getattr(self, f.name)
if isinstance(fvalue, ConfigBase): # type: ignore
fvalue.override(getattr(other, f.name))
else:
new_val = getattr(other, f.name)
if new_val is not None:
setattr(self, f.name, new_val)
def copy(self):
"""Perform a semi-deep copy of this configuration object.
This will recursively copy any config subobjects
(categories, i.e. subclasses of `ConfigBase` wrapped in the `Category` descriptor)
nested in this configuration object. Any other fields will be copied by reference.
"""
# IMPLEMENTATION NOTES
#
# We do not need to call `copy` on any subcategories here, since the `Category`
# descriptor already calls `copy` in its `__set__` method,
# which is invoked during the constructor call in the `return` statement.
# Calling `copy` here would result in copying category objects twice.
#
# We cannot use the standard library `copy.copy` here, since it merely duplicates
# the instance dictionary and does not call the constructor.
config_fields = fields(self) # type: ignore
kwargs = dict()
for field in config_fields:
val = getattr(self, field.name)
kwargs[field.name] = val
return type(self)(**kwargs)
Category_T = TypeVar("Category_T", bound=ConfigBase)
"""Type variable for option categories."""
class Category(Generic[Category_T]):
"""Descriptor for a category of options.
This descriptor makes sure that when an entire category is set to an object,
that object is copied immediately such that later changes to the original
do not affect this configuration.
"""
def __init__(self, default: Category_T):
self._default = default
def __set_name__(self, owner: ConfigBase, name: str):
self._name = name
self._lookup = f"_{name}"
def __get__(self, obj: ConfigBase, objtype: type[ConfigBase] | None = None) -> Category_T:
if obj is None:
return None
cat = getattr(obj, self._lookup, None)
if cat is None:
cat = self._default.copy()
setattr(obj, self._lookup, cat)
return cast(Category_T, cat)
def __set__(self, obj: ConfigBase, cat: Category_T | None):
setattr(obj, self._lookup, cat.copy() if cat is not None else None)
class _AUTO_TYPE:
def __repr__(self) -> str:
return "AUTO" # for pretty-printing in the docs
AUTO = _AUTO_TYPE()
"""Special value that can be passed to some options for invoking automatic behaviour."""
@dataclass
class OpenMpOptions(ConfigBase):
"""Configuration options controlling automatic OpenMP instrumentation."""
enable: BasicOption[bool] = BasicOption(False)
"""Enable OpenMP instrumentation"""
nesting_depth: BasicOption[int] = BasicOption(0)
"""Nesting depth of the loop that should be parallelized. Must be a nonnegative number."""
collapse: BasicOption[int] = BasicOption()
"""Argument to the OpenMP ``collapse`` clause"""
schedule: BasicOption[str] = BasicOption("static")
"""Argument to the OpenMP ``schedule`` clause"""
num_threads: BasicOption[int] = BasicOption()
"""Set the number of OpenMP threads to execute the parallel region."""
omit_parallel_construct: BasicOption[bool] = BasicOption(False)
"""If set to ``True``, the OpenMP ``parallel`` construct is omitted, producing just a ``#pragma omp for``.
Use this option only if you intend to wrap the kernel into an external ``#pragma omp parallel`` region.
"""
@dataclass
class VectorizationOptions(ConfigBase):
"""Configuration for the auto-vectorizer."""
enable: BasicOption[bool] = BasicOption(False)
"""Enable intrinsic vectorization."""
lanes: BasicOption[int] = BasicOption()
"""Number of SIMD lanes to be used in vectorization.
If set to `None` (the default), the number of lanes is chosen automatically such that the broadest vector registers available on the target are used.
If the CPU architecture specified in `target <CreateKernelConfig.target>` does not support some
operation contained in the kernel with the given number of lanes, an error will be raised.
"""
use_nontemporal_stores: BasicOption[bool | Collection[str | Field]] = BasicOption(
False
)
"""Enable nontemporal (streaming) stores.
If set to `True` and the selected CPU supports streaming stores, the vectorizer will generate
nontemporal store instructions for all stores.
If set to a collection of fields (or field names), streaming stores will only be generated for
the given fields.
"""
assume_aligned: BasicOption[bool] = BasicOption(False)
"""Assume field pointer alignment.
If set to `True`, the vectorizer will assume that the address of the first inner entry
(after ghost layers) of each field is aligned at the necessary byte boundary.
"""
assume_inner_stride_one: BasicOption[bool] = BasicOption(False)
"""Assume stride associated with the innermost spatial coordinate of all fields is one.
If set to `True`, the vectorizer will replace the stride of the innermost spatial coordinate
with unity, thus enabling vectorization. If any fields already have a fixed innermost stride
that is not equal to one, an error will be raised.
"""
@staticmethod
def default_lanes(target: Target, dtype: PsScalarType):
if not target.is_vector_cpu():
raise ValueError(f"Given target {target} is no vector CPU target.")
assert dtype.itemsize is not None
match target:
case Target.X86_SSE:
return 128 // (dtype.itemsize * 8)
case Target.X86_AVX:
return 256 // (dtype.itemsize * 8)
case Target.X86_AVX512 | Target.X86_AVX512_FP16:
return 512 // (dtype.itemsize * 8)
case _:
raise NotImplementedError(
f"No default number of lanes known for {dtype} on {target}"
)
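# A minimal sketch (illustrative, not part of the library) of how `default_lanes`
# resolves the SIMD width: the vector register width of the target divided by the
# element size. It assumes `create_type` (used by the validators further below)
# is available in this module.
def _default_lanes_demo() -> None:  # pragma: no cover
    """Illustrative only: AVX (256 bit) holds 4 float64 lanes; AVX512 holds 16 float32 lanes."""
    assert VectorizationOptions.default_lanes(Target.X86_AVX, create_type("float64")) == 4
    assert VectorizationOptions.default_lanes(Target.X86_AVX512, create_type("float32")) == 16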
@dataclass
class CpuOptions(ConfigBase):
"""Configuration options specific to CPU targets."""
openmp: Category[OpenMpOptions] = Category(OpenMpOptions())
"""Options governing OpenMP-instrumentation.
"""
vectorize: Category[VectorizationOptions] = Category(VectorizationOptions())
"""Options governing intrinsic vectorization.
"""
loop_blocking: BasicOption[tuple[int, ...]] = BasicOption()
"""Block sizes for loop blocking.
If set, the kernel's loops will be tiled according to the given block sizes.
"""
use_cacheline_zeroing: BasicOption[bool] = BasicOption(False)
"""Enable cache-line zeroing.
If set to `True` and the selected CPU supports cacheline zeroing, the CPU optimizer will attempt
to produce cacheline zeroing instructions where possible.
"""
@dataclass
class GpuOptions(ConfigBase):
"""Configuration options specific to GPU targets."""
omit_range_check: BasicOption[bool] = BasicOption(False)
"""If set to `True`, omit the iteration counter range check.
By default, the code generator introduces a check if the iteration counters computed from GPU block and thread
indices are within the prescribed loop range.
This check can be discarded through this option, at your own peril.
"""
block_size: BasicOption[tuple[int, int, int]] = BasicOption()
"""Desired block size for the execution of GPU kernels. May be overridden later by the runtime system."""
manual_launch_grid: BasicOption[bool] = BasicOption(False)
"""Always require a manually specified launch grid when running this kernel.
If set to `True`, the code generator will not attempt to infer the size of
the launch grid from the kernel.
The launch grid will then have to be specified manually at runtime.
"""
@dataclass
class SyclOptions(ConfigBase):
"""Options specific to the `SYCL <Target.SYCL>` target."""
automatic_block_size: BasicOption[bool] = BasicOption(True)
"""If set to `True`, let the SYCL runtime decide on the block size.
If set to `True`, the kernel is generated for execution via
`parallel_for <https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#_parallel_for_invoke>`_
-dispatch using
a flat ``sycl::range``. In this case, the GPU block size will be inferred by the SYCL runtime.
If set to `False`, the kernel will receive an ``nd_item`` and has to be executed using
`parallel_for <https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#_parallel_for_invoke>`_
with an ``nd_range``. This allows manual specification of the block size.
"""
GhostLayerSpec = _AUTO_TYPE | int | Sequence[int | tuple[int, int]]
IterationSliceSpec = int | slice | tuple[int | slice, ...]
@dataclass
class CreateKernelConfig(ConfigBase):
"""Options for create_kernel."""
target: BasicOption[Target] = BasicOption(Target.GenericCPU)
"""The code generation target."""
jit: BasicOption[JitBase] = BasicOption()
"""Just-in-time compiler used to compile and load the kernel for invocation from the current Python environment.
If left at `None`, a default just-in-time compiler will be inferred from the `target` parameter.
To explicitly disable JIT compilation, pass `pystencils.no_jit <pystencils.jit.no_jit>`.
"""
function_name: BasicOption[str] = BasicOption("kernel")
"""Name of the generated function"""
ghost_layers: BasicOption[GhostLayerSpec] = BasicOption()
"""Specifies the number of ghost layers of the iteration region.
Options:
- :py:data:`AUTO <pystencils.config.AUTO>`: Required ghost layers are inferred from field accesses
- `int`: A uniform number of ghost layers in each spatial coordinate is applied
- ``Sequence[int | tuple[int, int]]``: Ghost layers are specified for each spatial coordinate.
In each coordinate, a single integer specifies the ghost layers at both the lower and upper iteration limit,
while a pair of integers specifies the lower and upper ghost layers separately.
When manually specifying ghost layers, it is the user's responsibility to avoid out-of-bounds memory accesses.
.. note::
At most one of `ghost_layers`, `iteration_slice`, and `index_field` may be set.
"""
iteration_slice: BasicOption[IterationSliceSpec] = BasicOption()
"""Specifies the kernel's iteration slice.
Example:
>>> cfg = CreateKernelConfig(
... iteration_slice=ps.make_slice[3:14, 2:-2]
... )
>>> cfg.iteration_slice
(slice(3, 14, None), slice(2, -2, None))
.. note::
At most one of `ghost_layers`, `iteration_slice`, and `index_field` may be set.
"""
index_field: BasicOption[Field] = BasicOption()
"""Index field for a sparse kernel.
If this option is set, a sparse kernel with the given field as index field will be generated.
.. note::
At most one of `ghost_layers`, `iteration_slice`, and `index_field` may be set.
"""
"""Data Types"""
index_dtype: Option[PsIntegerType, UserTypeSpec] = Option(DEFAULTS.index_dtype)
"""Data type used for all index calculations."""
default_dtype: Option[PsScalarType, UserTypeSpec] = Option(DEFAULTS.numeric_dtype)
"""Default numeric data type.
This data type will be applied to all untyped symbols.
"""
"""Analysis"""
allow_double_writes: BasicOption[bool] = BasicOption(False)
"""
If set to `True`, do not check that each field is written at only a single location.
This is required, for example, for kernels compiled with loop step sizes > 1 that
handle multiple cells at once. Use with care!
"""
skip_independence_check: BasicOption[bool] = BasicOption(False)
"""
By default, the assignment list is checked for read/write independence: fields may only be written at
locations where they are also read. This guarantees thread safety. In some cases, e.g. for
periodicity kernels, this cannot be assured, and the check needs to be deactivated. Use with care!
"""
"""Target-Specific Options"""
cpu: Category[CpuOptions] = Category(CpuOptions())
"""Options for CPU kernels. See `CpuOptions`."""
gpu: Category[GpuOptions] = Category(GpuOptions())
"""Options for GPU Kernels. See `GpuOptions`."""
sycl: Category[SyclOptions] = Category(SyclOptions())
"""Options for SYCL kernels. See `SyclOptions`."""
@index_dtype.validate
def validate_index_type(self, spec: UserTypeSpec):
dtype = create_type(spec)
if not isinstance(dtype, PsIntegerType):
raise ValueError("index_dtype must be an integer type")
return dtype
@default_dtype.validate
def validate_default_dtype(self, spec: UserTypeSpec):
dtype = create_type(spec)
if not isinstance(dtype, PsScalarType):
raise ValueError("default_dtype must be a scalar numeric type")
return dtype
@index_field.validate
def validate_index_field(self, idx_field: Field):
if idx_field.field_type != FieldType.INDEXED:
raise ValueError(
"Only fields of type FieldType.INDEXED can be used as index fields"
)
return idx_field
# Deprecated Options
data_type: InitVar[UserTypeSpec | None] = None
"""Deprecated; use `default_dtype` instead"""
cpu_openmp: InitVar[bool | int | None] = None
"""Deprecated; use `cpu.openmp <CpuOptions.openmp>` instead."""
cpu_vectorize_info: InitVar[dict | None] = None
"""Deprecated; use `cpu.vectorize <CpuOptions.vectorize>` instead."""
gpu_indexing_params: InitVar[dict | None] = None
"""Deprecated; set options in the `gpu` category instead."""
# Getters
def get_target(self) -> Target:
t: Target = self.get_option("target")
match t:
case Target.CurrentCPU:
return Target.auto_cpu()
case _:
return t
def get_jit(self) -> JitBase:
"""Returns either the user-specified JIT compiler, or infers one from the target if none is given."""
jit: JitBase | None = self.get_option("jit")
if jit is None:
if self.get_target().is_cpu():
from ..jit import LegacyCpuJit
return LegacyCpuJit()
elif self.get_target() == Target.CUDA:
try:
from ..jit.gpu_cupy import CupyJit
if self.gpu is not None and self.gpu.block_size is not None:
return CupyJit(self.gpu.block_size)
else:
return CupyJit()
except ImportError:
from ..jit import no_jit
return no_jit
elif self.get_target() == Target.SYCL:
from ..jit import no_jit
return no_jit
else:
raise NotImplementedError(
f"No default JIT compiler implemented yet for target {self.target}"
)
else:
return jit
# Postprocessing
def __post_init__(self, *args):
# Check deprecated options
self._check_deprecations(*args)
def _check_deprecations(
self,
data_type: UserTypeSpec | None,
cpu_openmp: bool | int | None,
cpu_vectorize_info: dict | None,
gpu_indexing_params: dict | None,
): # pragma: no cover
if data_type is not None:
_deprecated_option("data_type", "default_dtype")
warn(
"Setting the deprecated `data_type` will override the value of `default_dtype`. "
"Set `default_dtype` instead.",
UserWarning,
)
self.default_dtype = data_type
if cpu_openmp is not None:
_deprecated_option("cpu_openmp", "cpu_optim.openmp")
warn(
"Setting the deprecated `cpu_openmp` option will override any options "
"passed in the `cpu.openmp` category.",
UserWarning,
)
deprecated_omp = OpenMpOptions()
match cpu_openmp:
case True:
deprecated_omp.enable = True
case False:
deprecated_omp.enable = False
case int():
deprecated_omp.enable = True
deprecated_omp.num_threads = cpu_openmp
case _:
raise ValueError(
f"Invalid option for `cpu_openmp`: {cpu_openmp}"
)
self.cpu.openmp = deprecated_omp
if cpu_vectorize_info is not None:
_deprecated_option("cpu_vectorize_info", "cpu_optim.vectorize")
if "instruction_set" in cpu_vectorize_info:
if self.target != Target.GenericCPU:
raise ValueError(
"Setting 'instruction_set' in the deprecated 'cpu_vectorize_info' option is only "
"valid if `target == Target.CPU`."
)
isa = cpu_vectorize_info["instruction_set"]
vec_target: Target
match isa:
case "best":
vec_target = Target.available_vector_cpu_targets().pop()
case "sse":
vec_target = Target.X86_SSE
case "avx":
vec_target = Target.X86_AVX
case "avx512":
vec_target = Target.X86_AVX512
case "avx512vl":
vec_target = Target.X86_AVX512 | Target._VL
case _:
raise ValueError(
f'Value {isa} in `cpu_vectorize_info["instruction_set"]` is not supported.'
)
warn(
f"Value {isa} for `instruction_set` in deprecated `cpu_vectorize_info` "
"will override the `target` option. "
f"Set `target` to {vec_target} instead.",
UserWarning,
)
self.target = vec_target
warn(
"Setting the deprecated `cpu_vectorize_info` will override any options "
"passed in the `cpu.vectorize` category.",
UserWarning,
)
deprecated_vec_opts = VectorizationOptions(
enable=True,
assume_inner_stride_one=cpu_vectorize_info.get(
"assume_inner_stride_one", False
),
assume_aligned=cpu_vectorize_info.get("assume_aligned", False),
use_nontemporal_stores=cpu_vectorize_info.get("nontemporal", False),
)
self.cpu.vectorize = deprecated_vec_opts
if gpu_indexing_params is not None:
_deprecated_option("gpu_indexing_params", "gpu_indexing")
warn(
"Setting the deprecated `gpu_indexing_params` will override any options "
"passed in the `gpu` category."
)
self.gpu = GpuOptions(
block_size=gpu_indexing_params.get("block_size", None)
)
def _deprecated_option(name, instead): # pragma: no cover
from warnings import warn
warn(
f"The `{name}` option of CreateKernelConfig is deprecated and will be removed in pystencils 2.1. "
f"Use `{instead}` instead.",
FutureWarning,
)
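# A minimal usage sketch (illustrative, not part of the library): the nested option
# categories defined above are set attribute-wise, and each config instance owns an
# independent copy of its category defaults. The string type spec and option names
# mirror the definitions in this module.
def _config_usage_demo() -> CreateKernelConfig:  # pragma: no cover
    """Illustrative only: build a CPU config with OpenMP and AVX vectorization enabled."""
    cfg = CreateKernelConfig(target=Target.X86_AVX, default_dtype="float32")
    cfg.cpu.openmp.enable = True
    cfg.cpu.openmp.num_threads = 4
    cfg.cpu.vectorize.enable = True
    cfg.cpu.vectorize.assume_inner_stride_one = True
    return cfg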
from __future__ import annotations
from typing import cast, Sequence, Iterable, TYPE_CHECKING
from dataclasses import dataclass, replace
from .target import Target
from .config import (
CreateKernelConfig,
VectorizationOptions,
AUTO,
_AUTO_TYPE,
GhostLayerSpec,
IterationSliceSpec,
)
from .kernel import Kernel, GpuKernel, GpuThreadsRange
from .properties import PsSymbolProperty, FieldShape, FieldStride, FieldBasePtr
from .parameters import Parameter
from ..field import Field
from ..types import PsIntegerType, PsScalarType
from ..backend.memory import PsSymbol
from ..backend.ast import PsAstNode
from ..backend.ast.structural import PsBlock, PsLoop
from ..backend.ast.analysis import collect_undefined_symbols, collect_required_headers
from ..backend.kernelcreation import (
KernelCreationContext,
KernelAnalysis,
FreezeExpressions,
Typifier,
)
from ..backend.constants import PsConstant
from ..backend.kernelcreation.iteration_space import (
create_sparse_iteration_space,
create_full_iteration_space,
FullIterationSpace,
)
from ..backend.platforms import (
Platform,
GenericCpu,
GenericVectorCpu,
GenericGpu,
)
from ..backend.exceptions import VectorizationError
from ..backend.transformations import (
EliminateConstants,
LowerToC,
SelectFunctions,
CanonicalizeSymbols,
HoistLoopInvariantDeclarations,
)
from ..simp import AssignmentCollection
from sympy.codegen.ast import AssignmentBase
if TYPE_CHECKING:
from ..jit import JitBase
__all__ = ["create_kernel"]
def create_kernel(
assignments: AssignmentCollection | Sequence[AssignmentBase] | AssignmentBase,
config: CreateKernelConfig | None = None,
**kwargs,
) -> Kernel:
"""Create a kernel function from a set of assignments.
Args:
assignments: The kernel's sequence of assignments, expressed using SymPy
config: The configuration for the kernel translator
kwargs: If ``config`` is not set, it is created from the keyword arguments;
if it is set, its options will be overridden by any keyword arguments.
Returns:
The numerical kernel in pystencils' internal representation, ready to be
exported or compiled
"""
if not config:
config = CreateKernelConfig()
if kwargs:
config = replace(config, **kwargs)
driver = DefaultKernelCreationDriver(config)
return driver(assignments)
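# A minimal usage sketch (illustrative, not part of the module): `create_kernel`
# accepts either a pre-built config object or plain keyword arguments; `assignments`
# stands for any valid SymPy assignment collection and is an assumption here.
def _create_kernel_demo(assignments) -> Kernel:  # pragma: no cover
    """Illustrative only: create a CPU kernel with OpenMP instrumentation enabled."""
    # Keyword-argument form (equivalent for simple cases):
    #   create_kernel(assignments, target=Target.GenericCPU)
    cfg = CreateKernelConfig(target=Target.GenericCPU)
    cfg.cpu.openmp.enable = True
    return create_kernel(assignments, cfg)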
def get_driver(
cfg: CreateKernelConfig, *, retain_intermediates: bool = False
) -> DefaultKernelCreationDriver:
"""Create a code generation driver object from the given configuration.
Args:
cfg: Configuration for the code generator
retain_intermediates: If `True`, instructs the driver to keep copies of
the intermediate results of its stages for later inspection.
"""
return DefaultKernelCreationDriver(cfg, retain_intermediates)
class DefaultKernelCreationDriver:
"""Drives the default kernel creation sequence.
Args:
cfg: Configuration for the code generator
retain_intermediates: If `True`, instructs the driver to keep copies of
the intermediate results of its stages for later inspection.
"""
def __init__(self, cfg: CreateKernelConfig, retain_intermediates: bool = False):
self._cfg = cfg
# Data Type Options
idx_dtype: PsIntegerType = cfg.get_option("index_dtype")
default_dtype: PsScalarType = cfg.get_option("default_dtype")
# Iteration Space Options
num_ispace_options_set = (
int(cfg.is_option_set("ghost_layers"))
+ int(cfg.is_option_set("iteration_slice"))
+ int(cfg.is_option_set("index_field"))
)
if num_ispace_options_set > 1:
raise ValueError(
"At most one of the options 'ghost_layers' 'iteration_slice' and 'index_field' may be set."
)
self._ghost_layers: GhostLayerSpec | None = cfg.get_option("ghost_layers")
self._iteration_slice: IterationSliceSpec | None = cfg.get_option(
"iteration_slice"
)
self._index_field: Field | None = cfg.get_option("index_field")
if num_ispace_options_set == 0:
self._ghost_layers = AUTO
# Create the context
self._ctx = KernelCreationContext(
default_dtype=default_dtype,
index_dtype=idx_dtype,
)
self._target = cfg.get_target()
self._platform = self._get_platform()
self._intermediates: CodegenIntermediates | None
if retain_intermediates:
self._intermediates = CodegenIntermediates()
else:
self._intermediates = None
@property
def intermediates(self) -> CodegenIntermediates | None:
return self._intermediates
def __call__(
self,
assignments: AssignmentCollection | Sequence[AssignmentBase] | AssignmentBase,
) -> Kernel:
kernel_body = self.parse_kernel_body(assignments)
match self._platform:
case GenericCpu():
kernel_ast = self._platform.materialize_iteration_space(
kernel_body, self._ctx.get_iteration_space()
)
case GenericGpu():
kernel_ast, gpu_threads = self._platform.materialize_iteration_space(
kernel_body, self._ctx.get_iteration_space()
)
if self._intermediates is not None:
self._intermediates.materialized_ispace = kernel_ast.clone()
# Fold and extract constants
elim_constants = EliminateConstants(self._ctx, extract_constant_exprs=True)
kernel_ast = cast(PsBlock, elim_constants(kernel_ast))
if self._intermediates is not None:
self._intermediates.constants_eliminated = kernel_ast.clone()
# Target-Specific optimizations
if self._target.is_cpu():
kernel_ast = self._transform_for_cpu(kernel_ast)
# Note: After this point, the AST may contain intrinsics, so type-dependent
# transformations cannot be run any more
# Lowering
lower_to_c = LowerToC(self._ctx)
kernel_ast = cast(PsBlock, lower_to_c(kernel_ast))
select_functions = SelectFunctions(self._platform)
kernel_ast = cast(PsBlock, select_functions(kernel_ast))
if self._intermediates is not None:
self._intermediates.lowered = kernel_ast.clone()
# Late canonicalization pass: Canonicalize new symbols introduced by LowerToC
canonicalize = CanonicalizeSymbols(self._ctx, True)
kernel_ast = cast(PsBlock, canonicalize(kernel_ast))
if self._target.is_cpu():
return create_cpu_kernel_function(
self._ctx,
self._platform,
kernel_ast,
self._cfg.get_option("function_name"),
self._target,
self._cfg.get_jit(),
)
else:
return create_gpu_kernel_function(
self._ctx,
self._platform,
kernel_ast,
gpu_threads,
self._cfg.get_option("function_name"),
self._target,
self._cfg.get_jit(),
)
def parse_kernel_body(
self,
assignments: AssignmentCollection | Sequence[AssignmentBase] | AssignmentBase,
) -> PsBlock:
if isinstance(assignments, AssignmentBase):
assignments = [assignments]
if not isinstance(assignments, AssignmentCollection):
assignments = AssignmentCollection(assignments) # type: ignore
_ = _parse_simplification_hints(assignments)
analysis = KernelAnalysis(
self._ctx,
not self._cfg.skip_independence_check,
not self._cfg.allow_double_writes,
)
analysis(assignments)
if self._index_field is not None:
ispace = create_sparse_iteration_space(
self._ctx, assignments, index_field=self._cfg.index_field
)
else:
gls: GhostLayerSpec | None
if self._ghost_layers == AUTO:
infer_gls = True
gls = None
else:
assert not isinstance(self._ghost_layers, _AUTO_TYPE)
infer_gls = False
gls = self._ghost_layers
ispace = create_full_iteration_space(
self._ctx,
assignments,
ghost_layers=gls,
iteration_slice=self._iteration_slice,
infer_ghost_layers=infer_gls,
)
self._ctx.set_iteration_space(ispace)
freeze = FreezeExpressions(self._ctx)
kernel_body = freeze(assignments)
typify = Typifier(self._ctx)
kernel_body = typify(kernel_body)
if self._intermediates is not None:
self._intermediates.parsed_body = kernel_body.clone()
return kernel_body
def _transform_for_cpu(self, kernel_ast: PsBlock) -> PsBlock:
canonicalize = CanonicalizeSymbols(self._ctx, True)
kernel_ast = cast(PsBlock, canonicalize(kernel_ast))
if self._intermediates is not None:
self._intermediates.cpu_canonicalize = kernel_ast.clone()
hoist_invariants = HoistLoopInvariantDeclarations(self._ctx)
kernel_ast = cast(PsBlock, hoist_invariants(kernel_ast))
if self._intermediates is not None:
self._intermediates.cpu_hoist_invariants = kernel_ast.clone()
cpu_cfg = self._cfg.cpu
if cpu_cfg is None:
return kernel_ast
if cpu_cfg.loop_blocking:
raise NotImplementedError("Loop blocking not implemented yet.")
kernel_ast = self._vectorize(kernel_ast)
kernel_ast = self._add_openmp(kernel_ast)
if cpu_cfg.use_cacheline_zeroing:
raise NotImplementedError("CL-zeroing not implemented yet")
return kernel_ast
def _add_openmp(self, kernel_ast: PsBlock) -> PsBlock:
omp_options = self._cfg.cpu.openmp
enable_omp: bool = omp_options.get_option("enable")
if enable_omp:
from ..backend.transformations import AddOpenMP
add_omp = AddOpenMP(
self._ctx,
nesting_depth=omp_options.get_option("nesting_depth"),
num_threads=omp_options.get_option("num_threads"),
schedule=omp_options.get_option("schedule"),
collapse=omp_options.get_option("collapse"),
omit_parallel=omp_options.get_option("omit_parallel_construct"),
)
kernel_ast = cast(PsBlock, add_omp(kernel_ast))
if self._intermediates is not None:
self._intermediates.cpu_openmp = kernel_ast.clone()
return kernel_ast
def _vectorize(self, kernel_ast: PsBlock) -> PsBlock:
vec_options = self._cfg.cpu.vectorize
enable_vec = vec_options.get_option("enable")
if not enable_vec:
return kernel_ast
from ..backend.transformations import LoopVectorizer, SelectIntrinsics
assert isinstance(self._platform, GenericVectorCpu)
ispace = self._ctx.get_iteration_space()
if not isinstance(ispace, FullIterationSpace):
raise VectorizationError(
"Unable to vectorize kernel: The kernel is not using a dense iteration space."
)
inner_loop_coord = ispace.loop_order[-1]
inner_loop_dim = ispace.dimensions[inner_loop_coord]
# Apply stride (TODO: and alignment) assumptions
assume_unit_stride: bool = vec_options.get_option("assume_inner_stride_one")
if assume_unit_stride:
for field in self._ctx.fields:
buf = self._ctx.get_buffer(field)
inner_stride = buf.strides[inner_loop_coord]
if isinstance(inner_stride, PsConstant):
if inner_stride.value != 1:
raise VectorizationError(
f"Unable to apply assumption 'assume_inner_stride_one': "
f"Field {field} has fixed stride {inner_stride} "
f"set in the inner coordinate {inner_loop_coord}."
)
else:
buf.strides[inner_loop_coord] = PsConstant(1, buf.index_type)
# TODO: Communicate assumption to runtime system via a precondition
# Call loop vectorizer
num_lanes: int | None = vec_options.get_option("lanes")
if num_lanes is None:
num_lanes = VectorizationOptions.default_lanes(
self._target, cast(PsScalarType, self._ctx.default_dtype)
)
vectorizer = LoopVectorizer(self._ctx, num_lanes)
def loop_predicate(loop: PsLoop):
return loop.counter.symbol == inner_loop_dim.counter
kernel_ast = vectorizer.vectorize_select_loops(kernel_ast, loop_predicate)
if self._intermediates is not None:
self._intermediates.cpu_vectorize = kernel_ast.clone()
select_intrin = SelectIntrinsics(self._ctx, self._platform)
kernel_ast = cast(PsBlock, select_intrin(kernel_ast))
if self._intermediates is not None:
self._intermediates.cpu_select_intrins = kernel_ast.clone()
return kernel_ast
def _get_platform(self) -> Platform:
if Target._CPU in self._target:
if Target._X86 in self._target:
from ..backend.platforms.x86 import X86VectorArch, X86VectorCpu
arch: X86VectorArch
if Target._SSE in self._target:
arch = X86VectorArch.SSE
elif Target._AVX in self._target:
arch = X86VectorArch.AVX
elif Target._AVX512 in self._target:
if Target._FP16 in self._target:
arch = X86VectorArch.AVX512_FP16
else:
arch = X86VectorArch.AVX512
else:
assert False, "unreachable code"
return X86VectorCpu(self._ctx, arch)
elif self._target == Target.GenericCPU:
return GenericCpu(self._ctx)
else:
raise NotImplementedError(
f"No platform is currently available for CPU target {self._target}"
)
elif Target._GPU in self._target:
gpu_opts = self._cfg.gpu
omit_range_check: bool = gpu_opts.get_option("omit_range_check")
match self._target:
case Target.SYCL:
from ..backend.platforms import SyclPlatform
auto_block_size: bool = self._cfg.sycl.get_option("automatic_block_size")
return SyclPlatform(
self._ctx,
omit_range_check=omit_range_check,
automatic_block_size=auto_block_size,
)
case Target.CUDA:
from ..backend.platforms import CudaPlatform
manual_grid = gpu_opts.get_option("manual_launch_grid")
return CudaPlatform(
self._ctx,
omit_range_check=omit_range_check,
manual_launch_grid=manual_grid,
)
raise NotImplementedError(
f"Code generation for target {self._target} not implemented"
)
def create_cpu_kernel_function(
ctx: KernelCreationContext,
platform: Platform,
body: PsBlock,
function_name: str,
target_spec: Target,
jit: JitBase,
) -> Kernel:
undef_symbols = collect_undefined_symbols(body)
params = _get_function_params(ctx, undef_symbols)
req_headers = _get_headers(ctx, platform, body)
kfunc = Kernel(body, target_spec, function_name, params, req_headers, jit)
kfunc.metadata.update(ctx.metadata)
return kfunc
def create_gpu_kernel_function(
ctx: KernelCreationContext,
platform: Platform,
body: PsBlock,
threads_range: GpuThreadsRange | None,
function_name: str,
target_spec: Target,
jit: JitBase,
) -> GpuKernel:
undef_symbols = collect_undefined_symbols(body)
if threads_range is not None:
for threads in threads_range.num_work_items:
undef_symbols |= collect_undefined_symbols(threads)
params = _get_function_params(ctx, undef_symbols)
req_headers = _get_headers(ctx, platform, body)
kfunc = GpuKernel(
body,
threads_range,
target_spec,
function_name,
params,
req_headers,
jit,
)
kfunc.metadata.update(ctx.metadata)
return kfunc
def _get_function_params(
ctx: KernelCreationContext, symbols: Iterable[PsSymbol]
) -> list[Parameter]:
params: list[Parameter] = []
from pystencils.backend.memory import BufferBasePtr
for symb in symbols:
props: set[PsSymbolProperty] = set()
for prop in symb.properties:
match prop:
case FieldShape() | FieldStride():
props.add(prop)
case BufferBasePtr(buf):
field = ctx.find_field(buf.name)
props.add(FieldBasePtr(field))
params.append(Parameter(symb.name, symb.get_dtype(), props))
params.sort(key=lambda p: p.name)
return params
def _get_headers(
ctx: KernelCreationContext, platform: Platform, body: PsBlock
) -> set[str]:
req_headers = collect_required_headers(body)
req_headers |= platform.required_headers
req_headers |= ctx.required_headers
return req_headers
@dataclass
class StageResult:
ast: PsAstNode
label: str
class StageResultSlot:
def __init__(self, description: str | None = None):
self._description = description
self._name: str
self._lookup: str
def __set_name__(self, owner, name: str):
self._name = name
self._lookup = f"_{name}"
def __get__(self, obj, objtype=None) -> StageResult | None:
if obj is None:
return None
ast = getattr(obj, self._lookup, None)
if ast is not None:
descr = self._name if self._description is None else self._description
return StageResult(ast, descr)
else:
return None
def __set__(self, obj, val: PsAstNode | None):
setattr(obj, self._lookup, val)
class CodegenIntermediates:
"""Intermediate results produced by the code generator."""
parsed_body = StageResultSlot("Freeze & Type Deduction")
materialized_ispace = StageResultSlot("Iteration Space Materialization")
constants_eliminated = StageResultSlot("Constant Elimination")
cpu_canonicalize = StageResultSlot("CPU: Symbol Canonicalization")
cpu_hoist_invariants = StageResultSlot("CPU: Hoisting of Loop Invariants")
cpu_vectorize = StageResultSlot("CPU: Vectorization")
cpu_select_intrins = StageResultSlot("CPU: Intrinsics Selection")
cpu_openmp = StageResultSlot("CPU: OpenMP Instrumentation")
lowered = StageResultSlot("C Language Lowering")
@property
def available_stages(self) -> Sequence[StageResult]:
all_results: list[StageResult | None] = [
getattr(self, name)
for name, slot in CodegenIntermediates.__dict__.items()
if isinstance(slot, StageResultSlot)
]
return tuple(filter(lambda r: r is not None, all_results)) # type: ignore
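# A minimal sketch (illustrative, not part of the module): running the driver with
# intermediate retention and listing the labels of all recorded code generation
# stages. `assignments` stands for any valid kernel assignment collection.
def _list_codegen_stages(assignments) -> list[str]:  # pragma: no cover
    """Illustrative only: return the labels of all intermediate stages produced for `assignments`."""
    driver = get_driver(CreateKernelConfig(), retain_intermediates=True)
    driver(assignments)
    assert driver.intermediates is not None
    return [stage.label for stage in driver.intermediates.available_stages]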
def create_staggered_kernel(
assignments, target: Target = Target.CPU, gpu_exclusive_conditions=False, **kwargs
):
raise NotImplementedError(
"Staggered kernels are not yet implemented for pystencils 2.0"
)
# Internals
def _parse_simplification_hints(ac: AssignmentCollection):
if "split_groups" in ac.simplification_hints:
raise NotImplementedError(
"Loop splitting was requested, but is not implemented yet"
)
from __future__ import annotations
from warnings import warn
from typing import Callable, Sequence, Any, TYPE_CHECKING
from itertools import chain
from .target import Target
from .parameters import Parameter
from ..backend.ast.structural import PsBlock
from ..backend.ast.expressions import PsExpression
from ..field import Field
from .._deprecation import _deprecated
if TYPE_CHECKING:
from ..jit import JitBase
class Kernel:
"""A pystencils kernel.
The kernel object is the final result of the translation process.
It is immutable; its AST should not be altered any further, as this might
invalidate information about the kernel that is already stored in the kernel object.
"""
def __init__(
self,
body: PsBlock,
target: Target,
name: str,
parameters: Sequence[Parameter],
required_headers: set[str],
jit: JitBase,
):
self._body: PsBlock = body
self._target = target
self._name = name
self._params = tuple(parameters)
self._required_headers = required_headers
self._jit = jit
self._metadata: dict[str, Any] = dict()
@property
def metadata(self) -> dict[str, Any]:
return self._metadata
@property
def body(self) -> PsBlock:
return self._body
@property
def target(self) -> Target:
return self._target
@property
def name(self) -> str:
return self._name
@name.setter
def name(self, n: str):
self._name = n
@property
def function_name(self) -> str: # pragma: no cover
_deprecated("function_name", "name")
return self._name
@function_name.setter
def function_name(self, n: str): # pragma: no cover
_deprecated("function_name", "name")
self._name = n
@property
def parameters(self) -> tuple[Parameter, ...]:
return self._params
def get_parameters(self) -> tuple[Parameter, ...]: # pragma: no cover
_deprecated("Kernel.get_parameters", "Kernel.parameters")
return self.parameters
def get_fields(self) -> set[Field]:
return set(chain.from_iterable(p.fields for p in self._params))
@property
def fields_accessed(self) -> set[Field]: # pragma: no cover
warn(
"`fields_accessed` is deprecated and will be removed in a future version of pystencils. "
"Use `get_fields` instead.",
DeprecationWarning,
)
return self.get_fields()
@property
def required_headers(self) -> set[str]:
return self._required_headers
def get_c_code(self) -> str:
from ..backend.emission import CAstPrinter
printer = CAstPrinter()
return printer(self)
def get_ir_code(self) -> str:
from ..backend.emission import IRAstPrinter
printer = IRAstPrinter()
return printer(self)
def compile(self) -> Callable[..., None]:
"""Invoke the underlying just-in-time compiler to obtain the kernel as an executable Python function."""
return self._jit.compile(self)
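# A minimal invocation sketch (illustrative, not part of the module): compiled kernels
# are typically called with the field arrays passed as keyword arguments whose names
# match the kernel's field names. The keyword-argument calling convention is an
# assumption of this sketch, not guaranteed by the class below.
def _invoke_kernel_demo(kernel: Kernel, **field_arrays) -> None:  # pragma: no cover
    """Illustrative only: JIT-compile `kernel` and execute it once on the given arrays."""
    kfunc = kernel.compile()
    kfunc(**field_arrays)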
class GpuKernel(Kernel):
"""Internal representation of a kernel function targeted at CUDA GPUs."""
def __init__(
self,
body: PsBlock,
threads_range: GpuThreadsRange | None,
target: Target,
name: str,
parameters: Sequence[Parameter],
required_headers: set[str],
jit: JitBase,
):
super().__init__(body, target, name, parameters, required_headers, jit)
self._threads_range = threads_range
@property
def threads_range(self) -> GpuThreadsRange | None:
"""Object exposing the total size of the launch grid this kernel expects to be executed with."""
return self._threads_range
class GpuThreadsRange:
"""Number of threads required by a GPU kernel, in order (x, y, z)."""
def __init__(
self,
num_work_items: Sequence[PsExpression],
):
self._dim = len(num_work_items)
self._num_work_items = tuple(num_work_items)
# @property
# def grid_size(self) -> tuple[PsExpression, ...]:
# return self._grid_size
# @property
# def block_size(self) -> tuple[PsExpression, ...]:
# return self._block_size
@property
def num_work_items(self) -> tuple[PsExpression, ...]:
"""Number of work items in (x, y, z)-order."""
return self._num_work_items
@property
def dim(self) -> int:
return self._dim
def __str__(self) -> str:
rep = "GpuThreadsRange { "
rep += "; ".join(f"{x}: {w}" for x, w in zip("xyz", self._num_work_items))
rep += " }"
return rep
def _repr_html_(self) -> str:
return str(self)
from __future__ import annotations
from warnings import warn
from typing import Sequence, Iterable
from .properties import (
PsSymbolProperty,
_FieldProperty,
FieldShape,
FieldStride,
FieldBasePtr,
)
from ..types import PsType
from ..field import Field
from ..sympyextensions import TypedSymbol
class Parameter:
"""Parameter to an output object of the code generator."""
__match_args__ = ("name", "dtype", "properties")
def __init__(
self, name: str, dtype: PsType, properties: Iterable[PsSymbolProperty] = ()
):
self._name = name
self._dtype = dtype
self._properties: frozenset[PsSymbolProperty] = (
frozenset(properties) if properties is not None else frozenset()
)
self._fields: tuple[Field, ...] = tuple(
sorted(
set(
p.field # type: ignore
for p in filter(
lambda p: isinstance(p, _FieldProperty), self._properties
)
),
key=lambda f: f.name,
)
)
@property
def name(self):
return self._name
@property
def dtype(self):
return self._dtype
def _hashable_contents(self):
return (self._name, self._dtype, self._properties)
# TODO: Need?
def __hash__(self) -> int:
return hash(self._hashable_contents())
def __eq__(self, other: object) -> bool:
if not isinstance(other, Parameter):
return False
return (
type(self) is type(other)
and self._hashable_contents() == other._hashable_contents()
)
def __str__(self) -> str:
return self._name
def __repr__(self) -> str:
return f"{type(self).__name__}(name = {self._name}, dtype = {self._dtype})"
@property
def symbol(self) -> TypedSymbol:
return TypedSymbol(self.name, self.dtype)
@property
def fields(self) -> Sequence[Field]:
"""Set of fields associated with this parameter."""
return self._fields
def get_properties(
self, prop_type: type[PsSymbolProperty] | tuple[type[PsSymbolProperty], ...]
) -> set[PsSymbolProperty]:
"""Retrieve all properties of the given type(s) attached to this parameter"""
return set(filter(lambda p: isinstance(p, prop_type), self._properties))
@property
def properties(self) -> frozenset[PsSymbolProperty]:
return self._properties
@property
def is_field_parameter(self) -> bool:
return bool(self._fields)
# Deprecated legacy properties
# These are kept mostly for the legacy waLBerla code generation system
@property
def is_field_pointer(self) -> bool: # pragma: no cover
warn(
"`is_field_pointer` is deprecated and will be removed in a future version of pystencils. "
"Use `param.get_properties(FieldBasePtr)` instead.",
DeprecationWarning,
)
return bool(self.get_properties(FieldBasePtr))
@property
def is_field_stride(self) -> bool: # pragma: no cover
warn(
"`is_field_stride` is deprecated and will be removed in a future version of pystencils. "
"Use `param.get_properties(FieldStride)` instead.",
DeprecationWarning,
)
return bool(self.get_properties(FieldStride))
@property
def is_field_shape(self) -> bool: # pragma: no cover
warn(
"`is_field_shape` is deprecated and will be removed in a future version of pystencils. "
"Use `param.get_properties(FieldShape)` instead.",
DeprecationWarning,
)
return bool(self.get_properties(FieldShape))
@property
def field_name(self) -> str: # pragma: no cover
warn(
"`field_name` is deprecated and will be removed in a future version of pystencils. "
"Use `param.fields[0].name` instead.",
DeprecationWarning,
)
return self._fields[0].name
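# A minimal sketch (illustrative, not part of the module): kernel parameters can be
# classified via their attached symbol properties, e.g. to find the base-pointer
# parameter belonging to each field of a generated kernel.
def _field_base_pointer_params(params: Sequence[Parameter]) -> dict[str, Parameter]:  # pragma: no cover
    """Illustrative only: map field names to their base-pointer parameters."""
    result: dict[str, Parameter] = {}
    for param in params:
        for prop in param.get_properties(FieldBasePtr):
            result[prop.field.name] = param  # type: ignore
    return result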
from __future__ import annotations
from dataclasses import dataclass
from ..field import Field
@dataclass(frozen=True)
class PsSymbolProperty:
"""Base class for symbol properties, which can be used to add additional information to symbols"""
@dataclass(frozen=True)
class UniqueSymbolProperty(PsSymbolProperty):
"""Base class for unique properties, of which only one instance may be registered at a time."""
@dataclass(frozen=True)
class FieldShape(PsSymbolProperty):
"""Symbol acts as a shape parameter to a field."""
field: Field
coordinate: int
@dataclass(frozen=True)
class FieldStride(PsSymbolProperty):
"""Symbol acts as a stride parameter to a field."""
field: Field
coordinate: int
@dataclass(frozen=True)
class FieldBasePtr(UniqueSymbolProperty):
"""Symbol acts as a base pointer to a field."""
field: Field
FieldProperty = FieldShape | FieldStride | FieldBasePtr
_FieldProperty = (FieldShape, FieldStride, FieldBasePtr)
from __future__ import annotations
from enum import Flag, auto
from warnings import warn
from functools import cache
class Target(Flag):
"""
The Target enumeration represents all possible targets that can be used for code generation.
"""
# ------------------ Component Flags - Do Not Use Directly! -------------------------------------------
_CPU = auto()
_VECTOR = auto()
_X86 = auto()
_SSE = auto()
_AVX = auto()
_AVX512 = auto()
_VL = auto()
_FP16 = auto()
_ARM = auto()
_NEON = auto()
_SVE = auto()
_GPU = auto()
_CUDA = auto()
_SYCL = auto()
_AUTOMATIC = auto()
# ------------------ Actual Targets -------------------------------------------------------------------
CurrentCPU = _CPU | _AUTOMATIC
"""
Auto-best CPU target.
`CurrentCPU` causes the code generator to automatically select a CPU target according to CPUs found
on the current machine and runtime environment.
"""
GenericCPU = _CPU
"""Generic CPU target.
Generate the kernel for a generic multicore CPU architecture. This opens up all architecture-independent
optimizations including OpenMP, but no vectorization.
"""
CPU = GenericCPU
"""Alias for backward-compatibility"""
X86_SSE = _CPU | _VECTOR | _X86 | _SSE
"""x86 architecture with SSE vector extensions."""
X86_AVX = _CPU | _VECTOR | _X86 | _AVX
"""x86 architecture with AVX vector extensions."""
X86_AVX512 = _CPU | _VECTOR | _X86 | _AVX512
"""x86 architecture with AVX512 vector extensions."""
X86_AVX512_FP16 = _CPU | _VECTOR | _X86 | _AVX512 | _FP16
"""x86 architecture with AVX512 vector extensions and fp16-support."""
ARM_NEON = _CPU | _VECTOR | _ARM | _NEON
"""ARM architecture with NEON vector extensions"""
ARM_SVE = _CPU | _VECTOR | _ARM | _SVE
"""ARM architecture with SVE vector extensions"""
CurrentGPU = _GPU | _AUTOMATIC
"""Auto-best GPU target.
`CurrentGPU` causes the code generator to automatically select a GPU target according to GPU devices
found on the current machine and runtime environment.
"""
CUDA = _GPU | _CUDA
"""Generic CUDA GPU target.
Generate a CUDA kernel for a generic Nvidia GPU.
"""
GPU = CUDA
"""Alias for `Target.CUDA`, for backward compatibility."""
SYCL = _GPU | _SYCL
"""SYCL kernel target.
Generate a function to be called within a SYCL parallel command.
"""
def is_automatic(self) -> bool:
return Target._AUTOMATIC in self
def is_cpu(self) -> bool:
return Target._CPU in self
def is_vector_cpu(self) -> bool:
return self.is_cpu() and Target._VECTOR in self
def is_gpu(self) -> bool:
return Target._GPU in self
@staticmethod
def auto_cpu() -> Target:
"""Return the most capable vector CPU target available on the current machine."""
avail_targets = _available_vector_targets()
if avail_targets:
return avail_targets.pop()
else:
return Target.GenericCPU
@staticmethod
def available_targets() -> list[Target]:
targets = [Target.GenericCPU]
try:
import cupy # noqa: F401
targets.append(Target.CUDA)
except ImportError:
pass
targets += Target.available_vector_cpu_targets()
return targets
@staticmethod
def available_vector_cpu_targets() -> list[Target]:
"""Returns a list of available vector CPU targets, ordered from least to most capable."""
return _available_vector_targets()
@cache
def _available_vector_targets() -> list[Target]:
"""Returns available vector targets, sorted from leat to most capable."""
targets: list[Target] = []
import platform
if platform.machine() in ["x86_64", "x86", "AMD64", "i386"]:
try:
from cpuinfo import get_cpu_info
except ImportError:
warn(
"Unable to determine available x86 vector CPU targets for this system: "
"py-cpuinfo is not available.",
UserWarning,
)
return []
flags = set(get_cpu_info()["flags"])
if {"sse", "sse2", "ssse3", "sse4_1", "sse4_2"} < flags:
targets.append(Target.X86_SSE)
if {"avx", "avx2"} < flags:
targets.append(Target.X86_AVX)
if {"avx512f"} < flags:
targets.append(Target.X86_AVX512)
if {"avx512_fp16"} < flags:
targets.append(Target.X86_AVX512_FP16)
else:
warn(
"Unable to determine available vector CPU targets for this system: "
f"unknown platform {platform.machine()}.",
UserWarning,
)
return targets
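# A minimal sketch (illustrative, not part of the module): querying target
# capabilities and picking a default target for the current machine using only the
# queries defined above.
def _pick_default_target() -> Target:  # pragma: no cover
    """Illustrative only: prefer a CUDA GPU if available, otherwise the best vector CPU."""
    available = Target.available_targets()
    if Target.CUDA in available:
        return Target.CUDA
    return Target.auto_cpu()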
@@ -3,7 +3,7 @@ import warnings
 from typing import Tuple, Union
 from .datahandling_interface import DataHandling
-from ..enums import Target
+from ..codegen.target import Target
 from .serial_datahandling import SerialDataHandling
 try:
...