Skip to content
Snippets Groups Projects
Commit c76b897f authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Drop ReductionFunctions and introduce PsReductionWriteBack

parent c4305590
No related branches found
No related tags found
1 merge request!438Reduction Support
......@@ -95,45 +95,31 @@ class PsMathFunction(PsFunction):
return hash(self._func)
class ReductionFunctions(Enum):
"""Function representing different steps in kernels with reductions supported by the backend.
class PsReductionWriteBack(PsFunction):
"""Function representing a reduction kernel's write-back step supported by the backend.
Each platform has to materialize these functions to a concrete implementation.
Each platform has to materialize this function to a concrete implementation.
"""
WriteBackToPtr = ("WriteBackToPtr", 2)
def __init__(self, func_name, num_args):
self.function_name = func_name
self.num_args = num_args
class PsReductionFunction(PsFunction):
def __init__(self, func: ReductionFunctions, reduction_op: ReductionOp) -> None:
super().__init__(func.function_name, func.num_args)
self._func = func
def __init__(self, reduction_op: ReductionOp) -> None:
super().__init__("WriteBackToPtr", 2)
self._reduction_op = reduction_op
@property
def func(self) -> ReductionFunctions:
return self._func
@property
def reduction_op(self) -> ReductionOp:
return self._reduction_op
def __str__(self) -> str:
return f"{self._func.function_name}"
return f"{super().name}"
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsReductionFunction):
if not isinstance(other, PsReductionWriteBack):
return False
return self._func == other._func and self._reduction_op == other._reduction_op
return self._reduction_op == other._reduction_op
def __hash__(self) -> int:
return hash(self._func)
return hash(self._reduction_op)
class ConstantFunctions(Enum):
......
......@@ -8,11 +8,10 @@ from ..ast import PsAstNode
from ..functions import (
CFunction,
MathFunctions,
ReductionFunctions,
PsMathFunction,
PsReductionFunction,
PsConstantFunction,
ConstantFunctions,
PsReductionWriteBack,
)
from ..reduction_op_mapping import reduction_op_to_expr
from ...sympyextensions import ReductionOp
......@@ -69,14 +68,9 @@ class GenericCpu(Platform):
self, call: PsCall
) -> PsExpression | tuple[tuple[PsStructuralNode, ...], PsAstNode]:
call_func = call.function
assert isinstance(call_func, (PsReductionFunction | PsMathFunction | PsConstantFunction))
assert isinstance(call_func, (PsReductionWriteBack | PsMathFunction | PsConstantFunction))
func = call_func.func
if (
isinstance(call_func, PsReductionFunction)
and func is ReductionFunctions.WriteBackToPtr
):
if isinstance(call_func, PsReductionWriteBack):
ptr_expr, symbol_expr = call.args
op = call_func.reduction_op
......@@ -103,6 +97,7 @@ class GenericCpu(Platform):
return potential_call
dtype = call.get_dtype()
func = call_func.func
arg_types = (dtype,) * call.function.arg_count
expr: PsExpression | None = None
......
......@@ -52,8 +52,7 @@ from ..literals import PsLiteral
from ..functions import (
MathFunctions,
CFunction,
ReductionFunctions,
PsReductionFunction,
PsReductionWriteBack,
PsMathFunction,
PsConstantFunction,
ConstantFunctions,
......@@ -296,20 +295,16 @@ class GenericGpu(Platform):
self, call: PsCall
) -> PsExpression | tuple[tuple[PsStructuralNode, ...], PsAstNode]:
call_func = call.function
assert isinstance(call_func, (PsReductionFunction | PsMathFunction | PsConstantFunction))
assert isinstance(call_func, (PsReductionWriteBack | PsMathFunction | PsConstantFunction))
func = call_func.func
if (
isinstance(call_func, PsReductionFunction)
and func is ReductionFunctions.WriteBackToPtr
):
if isinstance(call_func, PsReductionWriteBack):
ptr_expr, symbol_expr = call.args
op = call_func.reduction_op
return self.resolve_reduction(ptr_expr, symbol_expr, op)
dtype = call.get_dtype()
func = call_func.func
arg_types = (dtype,) * call.function.arg_count
expr: PsExpression | None = None
......
......@@ -3,7 +3,7 @@ from ..exceptions import MaterializationError
from ..platforms import Platform
from ..ast import PsAstNode
from ..ast.expressions import PsCall, PsExpression
from ..functions import PsMathFunction, PsConstantFunction, PsReductionFunction
from ..functions import PsMathFunction, PsConstantFunction, PsReductionWriteBack
class SelectFunctions:
......@@ -22,7 +22,7 @@ class SelectFunctions:
if isinstance(node, PsAssignment):
rhs = node.rhs
if isinstance(rhs, PsCall) and isinstance(
rhs.function, PsReductionFunction
rhs.function, PsReductionWriteBack
):
resolved_func = self._platform.select_function(rhs)
......
......@@ -26,7 +26,7 @@ from ..types import PsIntegerType, PsScalarType
from ..backend.memory import PsSymbol
from ..backend.ast import PsAstNode
from ..backend.functions import PsReductionFunction, ReductionFunctions
from ..backend.functions import PsReductionWriteBack
from ..backend.ast.expressions import (
PsExpression,
PsSymbolExpr,
......@@ -308,9 +308,7 @@ class DefaultKernelCreationDriver:
ptr_symbol_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype))
)
write_back_ptr = PsCall(
PsReductionFunction(
ReductionFunctions.WriteBackToPtr, reduction_info.op
),
PsReductionWriteBack(reduction_info.op),
[ptr_symbol_expr, symbol_expr],
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment