From 2c507d796f0c5abbff32386f4268d3bdb988c6fa Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Thu, 20 Mar 2025 18:18:10 +0100
Subject: [PATCH] Fix required headers for cuda/hip platforms

---
 src/pystencils/backend/platforms/cuda.py | 4 +++-
 src/pystencils/backend/platforms/hip.py  | 5 ++++-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py
index 98ff3e3d3..c05c45f04 100644
--- a/src/pystencils/backend/platforms/cuda.py
+++ b/src/pystencils/backend/platforms/cuda.py
@@ -8,4 +8,6 @@ class CudaPlatform(GenericGpu):
 
     @property
     def required_headers(self) -> set[str]:
-        return set()
+        return {
+            '"gpu_atomics.h"',
+        }
diff --git a/src/pystencils/backend/platforms/hip.py b/src/pystencils/backend/platforms/hip.py
index c758995a0..65d844bbb 100644
--- a/src/pystencils/backend/platforms/hip.py
+++ b/src/pystencils/backend/platforms/hip.py
@@ -8,4 +8,7 @@ class HipPlatform(GenericGpu):
 
     @property
     def required_headers(self) -> set[str]:
-        return {'"pystencils_runtime/hip.h"'}
+        return {
+            '"gpu_atomics.h"',
+            '"pystencils_runtime/hip.h"',
+        }
-- 
GitLab