From f77909540a9d26ce9dc1decde5c9508bc1f2d14a Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Mon, 24 Mar 2025 16:22:40 +0100
Subject: [PATCH] Use NPP library for numeric limits for CUDA, use std limits for HIP

---
Review notes (this section is ignored when the patch is applied):
- cuda.py: import PsMathFunction (used in the resolve_numeric_limits
  annotation) in place of the unused PsFunction, and import
  MaterializationError, which the fallback `case _` branch raises.
- hip.py: trailing comma after "<limits>" to match the other set entries.
- NOTE(review): the generic code removed below mapped Min/Max to
  -INFINITY/INFINITY. NPP_MINABS_* is the smallest positive value, and so is
  std::numeric_limits<T>::min() for floating-point T (assuming function_name
  is "min" for NumericLimitsFunctions.Min -- TODO confirm). Neither is the
  most negative value; confirm this change of meaning is intended.

 src/pystencils/backend/platforms/cuda.py      | 26 ++++++++++++++++++--
 .../backend/platforms/generic_gpu.py          | 21 +++++-----------
 src/pystencils/backend/platforms/hip.py       | 17 ++++++++++++-
 3 files changed, 46 insertions(+), 18 deletions(-)

diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py
index bbb608f5c..b5b3478e4 100644
--- a/src/pystencils/backend/platforms/cuda.py
+++ b/src/pystencils/backend/platforms/cuda.py
@@ -1,11 +1,33 @@
 from __future__ import annotations
 
 from .generic_gpu import GenericGpu
+from ..ast.expressions import PsExpression, PsLiteralExpr
+from ..exceptions import MaterializationError
+from ..functions import PsMathFunction, NumericLimitsFunctions
+from ..literals import PsLiteral
+from ...types import PsType, PsIeeeFloatType
 
 
 class CudaPlatform(GenericGpu):
-    """Platform for the CUDA GPU taret."""
+    """Platform for the CUDA GPU target."""
 
     @property
     def required_headers(self) -> set[str]:
-        return super().required_headers
+        return super().required_headers | {
+            '"npp.h"',
+        }
+
+    def resolve_numeric_limits(self, func: PsMathFunction, dtype: PsType) -> PsExpression:
+        assert isinstance(dtype, PsIeeeFloatType)
+
+        match func:
+            case NumericLimitsFunctions.Min:
+                define = f"NPP_MINABS_{dtype.width}F"
+            case NumericLimitsFunctions.Max:
+                define = f"NPP_MAXABS_{dtype.width}F"
+            case _:
+                raise MaterializationError(
+                    f"Cannot materialize call to function {func}"
+                )
+
+        return PsLiteralExpr(PsLiteral(define, dtype))
diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py
index 4f97264b0..787b390fe 100644
--- a/src/pystencils/backend/platforms/generic_gpu.py
+++ b/src/pystencils/backend/platforms/generic_gpu.py
@@ -9,7 +9,7 @@ from ..ast import PsAstNode
 from ..constants import PsConstant
 from ...compound_op_mapping import compound_op_to_expr
 from ...sympyextensions.reduction import ReductionOp
-from ...types import constify, deconstify, PsPointerType, PsScalarType, PsCustomType
+from ...types import constify, deconstify, PsPointerType, PsScalarType, PsCustomType, PsType
 from ...types.quick import UInt, SInt
 from ..exceptions import MaterializationError
 from .platform import Platform
@@ -203,9 +203,12 @@ class GenericGpu(Platform):
     def required_headers(self) -> set[str]:
         return {
             '"gpu_atomics.h"',
-            "<cmath>",
         }
 
+    @abstractmethod
+    def resolve_numeric_limits(self, func: PsMathFunction, dtype: PsType) -> PsExpression:
+        pass
+
     def __init__(
         self,
         ctx: KernelCreationContext,
@@ -369,19 +372,7 @@ class GenericGpu(Platform):
         arg_types = (dtype,) * func.num_args
 
         if isinstance(dtype, PsScalarType) and func in NumericLimitsFunctions:
-            assert isinstance(dtype, PsIeeeFloatType)
-
-            match func:
-                case NumericLimitsFunctions.Min:
-                    define = "-INFINITY"
-                case NumericLimitsFunctions.Max:
-                    define = "INFINITY"
-                case _:
-                    raise MaterializationError(
-                        f"Cannot materialize call to function {func}"
-                    )
-
-            return PsLiteralExpr(PsLiteral(define, dtype))
+            return self.resolve_numeric_limits(func, dtype)
 
         if isinstance(dtype, PsIeeeFloatType) and func in MathFunctions:
             match func:
diff --git a/src/pystencils/backend/platforms/hip.py b/src/pystencils/backend/platforms/hip.py
index c5e8b3882..60e249aeb 100644
--- a/src/pystencils/backend/platforms/hip.py
+++ b/src/pystencils/backend/platforms/hip.py
@@ -1,13 +1,28 @@
 from __future__ import annotations
 
 from .generic_gpu import GenericGpu
+from ..ast.expressions import PsExpression, PsLiteralExpr
+from ..functions import PsMathFunction
+from ..literals import PsLiteral
+from ...types import PsType, PsIeeeFloatType
 
 
 class HipPlatform(GenericGpu):
-    """Platform for the HIP GPU taret."""
+    """Platform for the HIP GPU target."""
 
     @property
     def required_headers(self) -> set[str]:
         return super().required_headers | {
             '"pystencils_runtime/hip.h"',
+            "<limits>",
         }
+
+    def resolve_numeric_limits(self, func: PsMathFunction, dtype: PsType) -> PsExpression:
+        assert isinstance(dtype, PsIeeeFloatType)
+
+        return PsLiteralExpr(
+            PsLiteral(
+                f"std::numeric_limits<{dtype.c_string()}>::{func.function_name}()",
+                dtype,
+            )
+        )
-- 
GitLab