From 2c507d796f0c5abbff32386f4268d3bdb988c6fa Mon Sep 17 00:00:00 2001 From: zy69guqi <richard.angersbach@fau.de> Date: Thu, 20 Mar 2025 18:18:10 +0100 Subject: [PATCH] Fix required headers for cuda/hip platforms --- src/pystencils/backend/platforms/cuda.py | 4 +++- src/pystencils/backend/platforms/hip.py | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py index 98ff3e3d3..c05c45f04 100644 --- a/src/pystencils/backend/platforms/cuda.py +++ b/src/pystencils/backend/platforms/cuda.py @@ -8,4 +8,6 @@ class CudaPlatform(GenericGpu): @property def required_headers(self) -> set[str]: - return set() + return { + '"gpu_atomics.h"', + } diff --git a/src/pystencils/backend/platforms/hip.py b/src/pystencils/backend/platforms/hip.py index c758995a0..65d844bbb 100644 --- a/src/pystencils/backend/platforms/hip.py +++ b/src/pystencils/backend/platforms/hip.py @@ -8,4 +8,7 @@ class HipPlatform(GenericGpu): @property def required_headers(self) -> set[str]: - return {'"pystencils_runtime/hip.h"'} + return { + '"gpu_atomics.h"', + '"pystencils_runtime/hip.h"', + } -- GitLab