From 0fb11858f2c65fc46c8dca469c75c28bf283dfdb Mon Sep 17 00:00:00 2001 From: zy69guqi <richard.angersbach@fau.de> Date: Thu, 30 Jan 2025 20:06:00 +0100 Subject: [PATCH] Add CUDA backend for numeric limits --- src/pystencils/backend/platforms/cuda.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py index bf5b91b82..ef3c11598 100644 --- a/src/pystencils/backend/platforms/cuda.py +++ b/src/pystencils/backend/platforms/cuda.py @@ -28,7 +28,8 @@ from ..ast.expressions import ( from ..ast.expressions import PsLt, PsAnd from ...types import PsSignedIntegerType, PsIeeeFloatType from ..literals import PsLiteral -from ..functions import PsMathFunction, MathFunctions, CFunction, PsReductionFunction, ReductionFunctions +from ..functions import PsMathFunction, MathFunctions, CFunction, PsReductionFunction, ReductionFunctions, \ + NumericLimitsFunctions if TYPE_CHECKING: from ...codegen import GpuIndexingConfig, GpuThreadsRange @@ -64,7 +65,7 @@ class CudaPlatform(GenericGpu): @property def required_headers(self) -> set[str]: - return {'"gpu_defines.h"'} + return {'"gpu_defines.h"', "<cuda/std/limits>"} def materialize_iteration_space( self, body: PsBlock, ispace: IterationSpace @@ -83,6 +84,9 @@ class CudaPlatform(GenericGpu): dtype = call.get_dtype() arg_types = (dtype,) * func.num_args + if isinstance(dtype, PsScalarType) and func in (NumericLimitsFunctions.Min, NumericLimitsFunctions.Max): + return PsLiteralExpr(PsLiteral(f"::cuda::std::numeric_limits<{dtype.c_string()}>::{func.function_name}()", dtype)) + if isinstance(dtype, PsIeeeFloatType): match func: case ( -- GitLab