From b4dd0c8c55d26f87b4467f814c526fafc2ced76b Mon Sep 17 00:00:00 2001 From: zy69guqi <richard.angersbach@fau.de> Date: Fri, 17 Jan 2025 14:29:10 +0100 Subject: [PATCH] Fix usage of numerical limits for init value of reduction --- src/pystencils/backend/functions.py | 8 ++++++-- src/pystencils/backend/kernelcreation/freeze.py | 4 ++-- src/pystencils/backend/platforms/generic_cpu.py | 2 +- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py index 736345395..18c2277cf 100644 --- a/src/pystencils/backend/functions.py +++ b/src/pystencils/backend/functions.py @@ -100,8 +100,12 @@ class NumericLimitsFunctions(Enum): Each platform has to materialize these functions to a concrete implementation. """ - min = ("min", 0) - max = ("max", 0) + Min = ("min", 0) + Max = ("max", 0) + + def __init__(self, func_name, num_args): + self.function_name = func_name + self.num_args = num_args class PsMathFunction(PsFunction): diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index ae728dd49..9a34303e2 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -206,10 +206,10 @@ class FreezeExpressions: # TODO: unsure if sp.Min & sp.Max are mapped by map_Min/map_Max afterwards case "min": op = sp.Min - init_val = NumericLimitsFunctions("min") + init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), []) case "max": op = sp.Max - init_val = NumericLimitsFunctions("max") + init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Max), []) case _: raise FreezeError(f"Unsupported reduced assignment: {expr.op}.") diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index 620cf9cfb..ae59d0423 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -62,7 +62,7 @@ class GenericCpu(Platform): dtype = call.get_dtype() arg_types = (dtype,) * func.num_args - if isinstance(dtype, PsScalarType) and func in (NumericLimitsFunctions.min, NumericLimitsFunctions.max): + if isinstance(dtype, PsScalarType) and func in (NumericLimitsFunctions.Min, NumericLimitsFunctions.Max): cfunc = CFunction(f"{dtype.c_string()}_{func.function_name}".capitalize(), arg_types, dtype) call.function = cfunc return call -- GitLab