Skip to content
Snippets Groups Projects

Reduction Support

Open Richard Angersbach requested to merge rangersbach/reductions into v2.0-dev
Compare and Show latest version
3 files
+ 66
3
Preferences
Compare changes
Files
3
@@ -2,7 +2,8 @@ from __future__ import annotations
from warnings import warn
from typing import TYPE_CHECKING
from ...types import constify
from ..ast import PsAstNode
from ...types import constify, PsPointerType, PsScalarType, PsCustomType
from ..exceptions import MaterializationError
from .generic_gpu import GenericGpu
@@ -22,12 +23,12 @@ from ..ast.expressions import (
PsCast,
PsCall,
PsLookup,
PsBufferAcc,
PsBufferAcc, PsSymbolExpr
)
from ..ast.expressions import PsLt, PsAnd
from ...types import PsSignedIntegerType, PsIeeeFloatType
from ..literals import PsLiteral
from ..functions import PsMathFunction, MathFunctions, CFunction
from ..functions import PsMathFunction, MathFunctions, CFunction, PsReductionFunction, ReductionFunctions
if TYPE_CHECKING:
from ...codegen import GpuIndexingConfig, GpuThreadsRange
@@ -134,6 +135,34 @@ class CudaPlatform(GenericGpu):
f"No implementation available for function {func} on data type {dtype}"
)
def unfold_function(
self, call: PsCall
) -> PsAstNode:
assert isinstance(call.function, PsReductionFunction)
func = call.function.func
match func:
case ReductionFunctions.InitLocalCopy:
symbol_expr, init_val = call.args
assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(init_val, PsExpression)
return PsDeclaration(symbol_expr, init_val)
case ReductionFunctions.WriteBackToPtr:
ptr_expr, symbol_expr = call.args
op = call.function.op
assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptr_expr.dtype, PsPointerType)
assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(symbol_expr.dtype, PsScalarType)
call.function = CFunction(f"atomic{op.name}", [ptr_expr.dtype, symbol_expr.dtype], PsCustomType("void"))
call.args = [ptr_expr, symbol_expr]
if not isinstance(symbol_expr.dtype, PsIeeeFloatType) or symbol_expr.dtype.width not in (32, 64):
NotImplementedError("atomicMul is only available for float32/64 datatypes")
return call
# Internals
def _prepend_dense_translation(