diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py index bbb608f5c0963a73441fd2a6f6239d747b45fe69..b5b3478e4267e451bf6628dcc04829ea99666635 100644 --- a/src/pystencils/backend/platforms/cuda.py +++ b/src/pystencils/backend/platforms/cuda.py @@ -1,11 +1,32 @@ from __future__ import annotations from .generic_gpu import GenericGpu +from ..ast.expressions import PsExpression, PsLiteralExpr +from ..functions import PsFunction, NumericLimitsFunctions +from ..literals import PsLiteral +from ...types import PsType, PsIeeeFloatType class CudaPlatform(GenericGpu): - """Platform for the CUDA GPU taret.""" + """Platform for the CUDA GPU target.""" @property def required_headers(self) -> set[str]: - return super().required_headers + return super().required_headers | { + '"npp.h"', + } + + def resolve_numeric_limits(self, func: PsMathFunction, dtype: PsType) -> PsExpression: + assert isinstance(dtype, PsIeeeFloatType) + + match func: + case NumericLimitsFunctions.Min: + define = f"NPP_MINABS_{dtype.width}F" + case NumericLimitsFunctions.Max: + define = f"NPP_MAXABS_{dtype.width}F" + case _: + raise MaterializationError( + f"Cannot materialize call to function {func}" + ) + + return PsLiteralExpr(PsLiteral(define, dtype)) diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py index 4f97264b089de1d7263d885f9d13d1498a32fa41..787b390fe0fcb66a4696eaf73c55f618027847d2 100644 --- a/src/pystencils/backend/platforms/generic_gpu.py +++ b/src/pystencils/backend/platforms/generic_gpu.py @@ -9,7 +9,7 @@ from ..ast import PsAstNode from ..constants import PsConstant from ...compound_op_mapping import compound_op_to_expr from ...sympyextensions.reduction import ReductionOp -from ...types import constify, deconstify, PsPointerType, PsScalarType, PsCustomType +from ...types import constify, deconstify, PsPointerType, PsScalarType, PsCustomType, PsType from ...types.quick import UInt, SInt from ..exceptions import MaterializationError from .platform import Platform @@ -203,9 +203,12 @@ class GenericGpu(Platform): def required_headers(self) -> set[str]: return { '"gpu_atomics.h"', - "<cmath>", } + @abstractmethod + def resolve_numeric_limits(self, func: PsMathFunction, dtype: PsType) -> PsExpression: + pass + def __init__( self, ctx: KernelCreationContext, @@ -369,19 +372,7 @@ class GenericGpu(Platform): arg_types = (dtype,) * func.num_args if isinstance(dtype, PsScalarType) and func in NumericLimitsFunctions: - assert isinstance(dtype, PsIeeeFloatType) - - match func: - case NumericLimitsFunctions.Min: - define = "-INFINITY" - case NumericLimitsFunctions.Max: - define = "INFINITY" - case _: - raise MaterializationError( - f"Cannot materialize call to function {func}" - ) - - return PsLiteralExpr(PsLiteral(define, dtype)) + return self.resolve_numeric_limits(func, dtype) if isinstance(dtype, PsIeeeFloatType) and func in MathFunctions: match func: diff --git a/src/pystencils/backend/platforms/hip.py b/src/pystencils/backend/platforms/hip.py index c5e8b3882de8600e4bec8c72e71c7a2032eb1f6e..60e249aebeffeaafc35ad61254606391f16639d6 100644 --- a/src/pystencils/backend/platforms/hip.py +++ b/src/pystencils/backend/platforms/hip.py @@ -1,13 +1,28 @@ from __future__ import annotations from .generic_gpu import GenericGpu +from ..ast.expressions import PsExpression, PsLiteralExpr +from ..functions import PsMathFunction +from ..literals import PsLiteral +from ...types import PsType, PsIeeeFloatType class HipPlatform(GenericGpu): - """Platform for the HIP GPU taret.""" + """Platform for the HIP GPU target.""" @property def required_headers(self) -> set[str]: return super().required_headers | { '"pystencils_runtime/hip.h"', + "<limits>" } + + def resolve_numeric_limits(self, func: PsMathFunction, dtype: PsType) -> PsExpression: + assert isinstance(dtype, PsIeeeFloatType) + + return PsLiteralExpr( + PsLiteral( + f"std::numeric_limits<{dtype.c_string()}>::{func.function_name}()", + dtype, + ) + )