From 4c726aa6aa2df6312252cb848ace95e790d62331 Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Fri, 24 Jan 2025 14:43:32 +0100
Subject: [PATCH] Prepare reduction test for GPU support

---
 tests/kernelcreation/test_reduction.py | 18 +++++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

diff --git a/tests/kernelcreation/test_reduction.py b/tests/kernelcreation/test_reduction.py
index b97343e72..b56a24a19 100644
--- a/tests/kernelcreation/test_reduction.py
+++ b/tests/kernelcreation/test_reduction.py
@@ -1,6 +1,7 @@
 import pytest
 import numpy as np
 import sympy as sp
+import cupy as cp
 
 import pystencils as ps
 from pystencils.sympyextensions import reduced_assign
@@ -18,6 +19,9 @@ SOLUTION = {
 @pytest.mark.parametrize('dtype', ["float64"])
 @pytest.mark.parametrize("op", ["+", "-", "*", "min", "max"])
 def test_reduction(dtype, op):
+
+    gpu_avail = True
+
     x = ps.fields(f'x: {dtype}[1d]')
     w = sp.Symbol("w")
 
@@ -25,7 +29,7 @@ def test_reduction(dtype, op):
 
     reduction_assignment = reduced_assign(w, op, x.center())
 
-    config = ps.CreateKernelConfig(cpu_openmp=True)
+    config = ps.CreateKernelConfig(target=ps.Target.GPU) if gpu_avail else ps.CreateKernelConfig(cpu_openmp=True)
 
     ast_reduction = ps.create_kernel([reduction_assignment], config, default_dtype=dtype)
     #code_reduction = ps.get_code_str(ast_reduction)
@@ -35,5 +39,13 @@ def test_reduction(dtype, op):
 
     array = np.full((SIZE,), INIT, dtype=dtype)
     reduction_array = np.zeros(1, dtype=dtype)
-    kernel_reduction(x=array, w=reduction_array)
-    assert np.allclose(reduction_array, SOLUTION[op])
\ No newline at end of file
+
+    if gpu_avail:
+        array_gpu = cp.asarray(array)
+        reduction_array_gpu = cp.asarray(reduction_array)
+
+        kernel_reduction(x=array_gpu, w=reduction_array_gpu)
+        assert np.allclose(reduction_array_gpu.get(), SOLUTION[op])
+    else:
+        kernel_reduction(x=array, w=reduction_array)
+        assert np.allclose(reduction_array, SOLUTION[op])
\ No newline at end of file
-- 
GitLab