From c31d407433bb075a42691866e87e59208bf99d90 Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Thu, 20 Mar 2025 18:56:41 +0100
Subject: [PATCH] Try fixing required headers for cuda and hip for reductions

---
 src/pystencils/backend/platforms/cuda.py        |  4 +---
 src/pystencils/backend/platforms/generic_gpu.py | 12 ++++++++++--
 src/pystencils/backend/platforms/hip.py         |  3 +--
 3 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py
index c05c45f04..bbb608f5c 100644
--- a/src/pystencils/backend/platforms/cuda.py
+++ b/src/pystencils/backend/platforms/cuda.py
@@ -8,6 +8,4 @@ class CudaPlatform(GenericGpu):
 
     @property
     def required_headers(self) -> set[str]:
-        return {
-            '"gpu_atomics.h"',
-        }
+        return super().required_headers
diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py
index 2a12d6b7b..4f97264b0 100644
--- a/src/pystencils/backend/platforms/generic_gpu.py
+++ b/src/pystencils/backend/platforms/generic_gpu.py
@@ -198,6 +198,14 @@ class GenericGpu(Platform):
         thread_mapping: Callback object which defines the mapping of thread indices onto iteration space points
     """
 
+    @property
+    @abstractmethod
+    def required_headers(self) -> set[str]:
+        return {
+            '"gpu_atomics.h"',
+            "<cmath>",
+        }
+
     def __init__(
         self,
         ctx: KernelCreationContext,
@@ -365,9 +373,9 @@ class GenericGpu(Platform):
 
             match func:
                 case NumericLimitsFunctions.Min:
-                    define = "NEG_INFINITY"
+                    define = "-INFINITY"
                 case NumericLimitsFunctions.Max:
-                    define = "POS_INFINITY"
+                    define = "INFINITY"
                 case _:
                     raise MaterializationError(
                         f"Cannot materialize call to function {func}"
diff --git a/src/pystencils/backend/platforms/hip.py b/src/pystencils/backend/platforms/hip.py
index 65d844bbb..c5e8b3882 100644
--- a/src/pystencils/backend/platforms/hip.py
+++ b/src/pystencils/backend/platforms/hip.py
@@ -8,7 +8,6 @@ class HipPlatform(GenericGpu):
 
     @property
     def required_headers(self) -> set[str]:
-        return {
-            '"gpu_atomics.h"',
+        return super().required_headers | {
             '"pystencils_runtime/hip.h"',
         }
-- 
GitLab