diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py index 1a8fdc4821f6f62cc37b58ca1f4beb91f3c51d87..1af8917ccb7a4d07b700e43ae22958bf45817aa3 100644 --- a/src/pystencils/backend/platforms/cuda.py +++ b/src/pystencils/backend/platforms/cuda.py @@ -90,8 +90,7 @@ class CudaPlatform(GenericGpu): assert isinstance(dtype, PsIeeeFloatType) defines = { NumericLimitsFunctions.Min: "NEG_INFINITY", NumericLimitsFunctions.Max: "POS_INFINITY" } - return PsLiteralExpr( - PsLiteral(defines[func.function_name], dtype)) + return PsLiteralExpr(PsLiteral(defines[func], dtype)) if isinstance(dtype, PsIeeeFloatType): match func: diff --git a/src/pystencils/include/gpu_defines.h b/src/pystencils/include/gpu_defines.h index 04eeace47c0f1b2659883d4b44e464d587adb2a7..8f961e25b6ea1a4b616af4dbba437c5dbb66d524 100644 --- a/src/pystencils/include/gpu_defines.h +++ b/src/pystencils/include/gpu_defines.h @@ -1,8 +1,10 @@ #pragma once #define POS_INFINITY __int_as_float(0x7f800000) -#define INFINITY POS_INFINITY #define NEG_INFINITY __int_as_float(0xff800000) +#ifndef INFINITY +#define INFINITY POS_INFINITY +#endif #ifdef __HIPCC_RTC__ typedef __hip_uint8_t uint8_t;