Skip to content
Snippets Groups Projects
Commit dd8f421d authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Introduce functions to be unfolded by platform into code blocks for reduction init and write-back

parent 53807242
No related branches found
No related tags found
1 merge request: !438 "Reduction Support"
Pipeline #73127 failed
......@@ -30,6 +30,7 @@ from typing import Any, Sequence, TYPE_CHECKING
from abc import ABC
from enum import Enum
from ..sympyextensions import ReductionOp
from ..types import PsType
from .exceptions import PsInternalCompilerError
......@@ -134,6 +135,48 @@ class PsMathFunction(PsFunction):
return hash(self._func)
class ReductionFunctions(Enum):
    """Placeholder functions for the individual steps of kernels with reductions.

    The value of every member is a ``(function_name, num_args)`` pair.
    The backend offers no implementation for these functions; each platform
    has to materialize them to a concrete implementation.
    """

    InitLocalCopy = ("InitLocalCopy", 2)
    WriteBackToPtr = ("WriteBackToPtr", 2)

    def __init__(self, func_name: str, num_args: int) -> None:
        # Enum passes the unpacked member value to __init__, so both entries
        # of the tuple end up as plain attributes on the member.
        self.function_name, self.num_args = func_name, num_args
class PsReductionFunction(PsFunction):
    """Reduction step function bound to a specific reduction operation.

    The name and argument count passed to the base class are taken from the
    given `ReductionFunctions` member.

    Args:
        func: The reduction step represented by this function
        op: The reduction operation this step is performed for
    """

    def __init__(self, func: ReductionFunctions, op: ReductionOp) -> None:
        super().__init__(func.function_name, func.num_args)
        self._func = func
        self._op = op

    @property
    def func(self) -> ReductionFunctions:
        """The reduction step represented by this function."""
        return self._func

    @property
    def op(self) -> ReductionOp:
        """The reduction operation this function was created for."""
        return self._op

    def __str__(self) -> str:
        return f"{self._func.function_name}"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, PsReductionFunction):
            return False

        # The operation is part of the function's identity: the same step for
        # two different reduction operations must not compare equal.
        return self._func == other._func and self._op == other._op

    def __hash__(self) -> int:
        # Kept consistent with __eq__: hash over both the step and the operation.
        return hash((self._func, self._op))
class CFunction(PsFunction):
"""A concrete C function.
......
......@@ -50,7 +50,7 @@ from ..ast.expressions import (
PsNot,
)
from ..ast.vector import PsVecBroadcast, PsVecMemAcc
from ..functions import PsMathFunction, CFunction
from ..functions import PsMathFunction, CFunction, PsReductionFunction
from ..ast.util import determine_memory_object
from ..exceptions import TypificationError
......@@ -590,7 +590,7 @@ class Typifier:
case PsCall(function, args):
match function:
case PsMathFunction():
case PsMathFunction() | PsReductionFunction():
for arg in args:
self.visit_expr(arg, tc)
tc.infer_dtype(expr)
......
from abc import ABC, abstractmethod
from typing import Sequence
from pystencils.backend.ast.expressions import PsCall
from ..ast.expressions import PsCall, PsMemAcc, PsConstantExpr
from ..functions import CFunction, PsMathFunction, MathFunctions, NumericLimitsFunctions
from ..ast import PsAstNode
from ..functions import CFunction, PsMathFunction, MathFunctions, NumericLimitsFunctions, ReductionFunctions, \
PsReductionFunction
from ..literals import PsLiteral
from ...types import PsIntegerType, PsIeeeFloatType, PsScalarType
from ...compound_op_mapping import compound_op_to_expr
from ...types import PsIntegerType, PsIeeeFloatType, PsScalarType, PsPointerType
from .platform import Platform
from ..exceptions import MaterializationError
......@@ -18,7 +21,7 @@ from ..kernelcreation.iteration_space import (
)
from ..constants import PsConstant
from ..ast.structural import PsDeclaration, PsLoop, PsBlock
from ..ast.structural import PsDeclaration, PsLoop, PsBlock, PsAssignment
from ..ast.expressions import (
PsSymbolExpr,
PsExpression,
......@@ -56,6 +59,36 @@ class GenericCpu(Platform):
else:
raise MaterializationError(f"Unknown type of iteration space: {ispace}")
def unfold_function(
self, call: PsCall
) -> PsAstNode:
assert isinstance(call.function, PsReductionFunction)
func = call.function.func
match func:
case ReductionFunctions.InitLocalCopy:
symbol_expr, init_val = call.args
assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(init_val, PsExpression)
return PsDeclaration(symbol_expr, init_val)
case ReductionFunctions.WriteBackToPtr:
ptr_expr, symbol_expr = call.args
op = call.function.op
assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptr_expr.dtype, PsPointerType)
assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(symbol_expr.dtype, PsScalarType)
ptr_access = PsMemAcc(ptr_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype)))
# TODO: can this be avoided somehow?
potential_call = compound_op_to_expr(op, ptr_access, symbol_expr)
if isinstance(potential_call, PsCall):
potential_call.dtype = symbol_expr.dtype
potential_call = self.select_function(potential_call)
return PsAssignment(ptr_access, potential_call)
def select_function(self, call: PsCall) -> PsExpression:
assert isinstance(call.function, PsMathFunction)
......
from abc import ABC, abstractmethod
from typing import Any
from ..ast import PsAstNode
from ..ast.structural import PsBlock
from ..ast.expressions import PsCall, PsExpression
......@@ -40,3 +41,13 @@ class Platform(ABC):
If no viable implementation exists, raise a `MaterializationError`.
"""
pass
@abstractmethod
def unfold_function(self, call: PsCall) -> PsAstNode:
    """Unfold the given function call into an AST node implementing it.

    Implementations are expected to raise a `MaterializationError`
    if no viable implementation exists.
    """
from ..platforms import Platform
from ..ast import PsAstNode
from ..ast.expressions import PsCall
from ..functions import PsMathFunction
from ..functions import PsMathFunction, PsReductionFunction
class SelectFunctions:
......@@ -19,5 +19,7 @@ class SelectFunctions:
if isinstance(node, PsCall) and isinstance(node.function, PsMathFunction):
return self._platform.select_function(node)
elif isinstance(node, PsCall) and isinstance(node.function, PsReductionFunction):
return self._platform.unfold_function(node)
else:
return node
......@@ -7,14 +7,14 @@ from .config import CreateKernelConfig, OpenMpConfig, VectorizationConfig, AUTO
from .kernel import Kernel, GpuKernel, GpuThreadsRange
from .properties import PsSymbolProperty, FieldShape, FieldStride, FieldBasePtr
from .parameters import Parameter
from ..compound_op_mapping import compound_op_to_expr
from ..backend.ast.expressions import PsSymbolExpr, PsMemAcc, PsConstantExpr
from ..backend.functions import PsReductionFunction, ReductionFunctions
from ..backend.ast.expressions import PsSymbolExpr, PsCall
from ..types import create_numeric_type, PsIntegerType, PsScalarType
from ..backend.memory import PsSymbol
from ..backend.ast import PsAstNode
from ..backend.ast.structural import PsBlock, PsLoop, PsAssignment, PsDeclaration
from ..backend.ast.structural import PsBlock, PsLoop
from ..backend.ast.analysis import collect_undefined_symbols, collect_required_headers
from ..backend.kernelcreation import (
KernelCreationContext,
......@@ -156,19 +156,20 @@ class DefaultKernelCreationDriver:
# Extensions for reductions
for symbol, reduction_info in self._ctx.symbols_reduction_info.items():
# Init local reduction variable copy
kernel_ast.statements = [PsDeclaration(PsSymbolExpr(symbol),
reduction_info.init_val)] + kernel_ast.statements
typify = Typifier(self._ctx)
symbol_expr = typify(PsSymbolExpr(symbol))
ptr_symbol_expr = typify(PsSymbolExpr(reduction_info.ptr_symbol))
init_val = typify(reduction_info.init_val)
# Write back result to reduction target variable
ptr_access = PsMemAcc(PsSymbolExpr(reduction_info.ptr_symbol),
PsConstantExpr(PsConstant(0)))
kernel_ast.statements += [PsAssignment(
ptr_access, compound_op_to_expr(reduction_info.op, ptr_access, PsSymbolExpr(symbol)))]
init_local_copy = PsCall(PsReductionFunction(ReductionFunctions.InitLocalCopy, reduction_info.op),
[symbol_expr, init_val])
write_back_ptr = PsCall(PsReductionFunction(ReductionFunctions.WriteBackToPtr, reduction_info.op),
[ptr_symbol_expr, symbol_expr])
# TODO: only newly introduced nodes
typify = Typifier(self._ctx)
kernel_ast = typify(kernel_ast)
# Init local reduction variable copy
kernel_ast.statements = [init_local_copy] + kernel_ast.statements
# Write back result to reduction target variable
kernel_ast.statements += [write_back_ptr]
# Target-Specific optimizations
if self._cfg.target.is_cpu():
......
......@@ -32,11 +32,11 @@ def test_reduction(dtype, op):
config = ps.CreateKernelConfig(target=ps.Target.GPU) if gpu_avail else ps.CreateKernelConfig(cpu_openmp=True)
ast_reduction = ps.create_kernel([red_assign], config, default_dtype=dtype)
ps.show_code(ast_reduction)
# code_reduction = ps.get_code_str(ast_reduction)
kernel_reduction = ast_reduction.compile()
ps.show_code(ast_reduction)
array = np.full((SIZE,), INIT_ARR, dtype=dtype)
reduction_array = np.full((1,), INIT_W, dtype=dtype)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment