Skip to content
Snippets Groups Projects

Reduction Support

Open Richard Angersbach requested to merge rangersbach/reductions into v2.0-dev
Compare and Show latest version
5 files
+ 72
10
Preferences
Compare changes
Files
5
@@ -3,7 +3,7 @@ from warnings import warn
from typing import TYPE_CHECKING
from ..ast import PsAstNode
from ...types import constify
from ...types import constify, PsPointerType, PsScalarType, PsCustomType
from ..exceptions import MaterializationError
from .generic_gpu import GenericGpu
@@ -23,12 +23,13 @@ from ..ast.expressions import (
PsCast,
PsCall,
PsLookup,
PsBufferAcc,
PsBufferAcc, PsSymbolExpr
)
from ..ast.expressions import PsLt, PsAnd
from ...types import PsSignedIntegerType, PsIeeeFloatType
from ..literals import PsLiteral
from ..functions import PsMathFunction, MathFunctions, CFunction
from ..functions import PsMathFunction, MathFunctions, CFunction, PsReductionFunction, ReductionFunctions, \
NumericLimitsFunctions
if TYPE_CHECKING:
from ...codegen import GpuIndexingConfig, GpuThreadsRange
@@ -64,7 +65,7 @@ class CudaPlatform(GenericGpu):
@property
def required_headers(self) -> set[str]:
return {'"gpu_defines.h"'}
return {'"gpu_defines.h"', "<cuda/std/limits>"}
def materialize_iteration_space(
self, body: PsBlock, ispace: IterationSpace
@@ -83,6 +84,10 @@ class CudaPlatform(GenericGpu):
dtype = call.get_dtype()
arg_types = (dtype,) * func.num_args
if isinstance(dtype, PsScalarType) and func in (NumericLimitsFunctions.Min, NumericLimitsFunctions.Max):
return PsLiteralExpr(
PsLiteral(f"::cuda::std::numeric_limits<{dtype.c_string()}>::{func.function_name}()", dtype))
if isinstance(dtype, PsIeeeFloatType):
match func:
case (
@@ -138,7 +143,31 @@ class CudaPlatform(GenericGpu):
def unfold_function(
self, call: PsCall
) -> PsAstNode:
pass
assert isinstance(call.function, PsReductionFunction)
func = call.function.func
match func:
case ReductionFunctions.InitLocalCopy:
symbol_expr, init_val = call.args
assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(init_val, PsExpression)
return PsDeclaration(symbol_expr, init_val)
case ReductionFunctions.WriteBackToPtr:
ptr_expr, symbol_expr = call.args
op = call.function.op
assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptr_expr.dtype, PsPointerType)
assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(symbol_expr.dtype, PsScalarType)
call.function = CFunction(f"atomic{op.name}", [ptr_expr.dtype, symbol_expr.dtype],
PsCustomType("void"))
call.args = (ptr_expr, symbol_expr)
if not isinstance(symbol_expr.dtype, PsIeeeFloatType) or symbol_expr.dtype.width not in (32, 64):
NotImplementedError("atomicMul is only available for float32/64 datatypes")
return call
# Internals