From d8ae900242264392479f4405678d9a1f1b177890 Mon Sep 17 00:00:00 2001 From: zy69guqi <richard.angersbach@fau.de> Date: Tue, 4 Feb 2025 17:49:37 +0100 Subject: [PATCH] Use predefined macro values for numeric limits in cuda backend --- src/pystencils/backend/platforms/cuda.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py index 73c4b3b47..fa246c128 100644 --- a/src/pystencils/backend/platforms/cuda.py +++ b/src/pystencils/backend/platforms/cuda.py @@ -65,7 +65,7 @@ class CudaPlatform(GenericGpu): @property def required_headers(self) -> set[str]: - return {'"gpu_defines.h"', "<cuda/std/limits>"} + return {'"gpu_defines.h"'} def materialize_iteration_space( self, body: PsBlock, ispace: IterationSpace @@ -85,8 +85,11 @@ class CudaPlatform(GenericGpu): arg_types = (dtype,) * func.num_args if isinstance(dtype, PsScalarType) and func in (NumericLimitsFunctions.Min, NumericLimitsFunctions.Max): + assert isinstance(dtype, PsIeeeFloatType) + defines = { NumericLimitsFunctions.Min: "NEG_INFINITY", NumericLimitsFunctions.Max: "POS_INFINITY" } + return PsLiteralExpr( - PsLiteral(f"::cuda::std::numeric_limits<{dtype.c_string()}>::{func.function_name}()", dtype)) + PsLiteral(defines[func.function_name], dtype)) if isinstance(dtype, PsIeeeFloatType): match func: -- GitLab