Skip to content
Snippets Groups Projects

Reduction Support

Status: Open. Richard Angersbach requested to merge rangersbach/reductions into v2.0-dev
Compare and Show latest version
4 files changed
+20
-9
Preferences
Compare changes
Files
4
@@ -28,7 +28,8 @@ from ..ast.expressions import (
from ..ast.expressions import PsLt, PsAnd
from ...types import PsSignedIntegerType, PsIeeeFloatType
from ..literals import PsLiteral
from ..functions import PsMathFunction, MathFunctions, CFunction, PsReductionFunction, ReductionFunctions
from ..functions import PsMathFunction, MathFunctions, CFunction, PsReductionFunction, ReductionFunctions, \
NumericLimitsFunctions
if TYPE_CHECKING:
from ...codegen import GpuIndexingConfig, GpuThreadsRange
@@ -64,7 +65,7 @@ class CudaPlatform(GenericGpu):
@property
def required_headers(self) -> set[str]:
return {'"gpu_defines.h"'}
return {'"gpu_defines.h"', "<cuda/std/limits>"}
def materialize_iteration_space(
self, body: PsBlock, ispace: IterationSpace
@@ -83,6 +84,10 @@ class CudaPlatform(GenericGpu):
dtype = call.get_dtype()
arg_types = (dtype,) * func.num_args
if isinstance(dtype, PsScalarType) and func in (NumericLimitsFunctions.Min, NumericLimitsFunctions.Max):
return PsLiteralExpr(
PsLiteral(f"::cuda::std::numeric_limits<{dtype.c_string()}>::{func.function_name}()", dtype))
if isinstance(dtype, PsIeeeFloatType):
match func:
case (
@@ -155,8 +160,9 @@ class CudaPlatform(GenericGpu):
assert isinstance(ptr_expr, PsSymbolExpr) and isinstance(ptr_expr.dtype, PsPointerType)
assert isinstance(symbol_expr, PsSymbolExpr) and isinstance(symbol_expr.dtype, PsScalarType)
call.function = CFunction(f"atomic{op.name}", [ptr_expr.dtype, symbol_expr.dtype], PsCustomType("void"))
call.args = [ptr_expr, symbol_expr]
call.function = CFunction(f"atomic{op.name}", [ptr_expr.dtype, symbol_expr.dtype],
PsCustomType("void"))
call.args = (ptr_expr, symbol_expr)
if not isinstance(symbol_expr.dtype, PsIeeeFloatType) or symbol_expr.dtype.width not in (32, 64):
NotImplementedError("atomicMul is only available for float32/64 datatypes")