From 4b1f3f5374dcef0c0e56f44c9679782080278bd7 Mon Sep 17 00:00:00 2001
From: Christoph Alt <christoph.alt@fau.de>
Date: Tue, 15 Aug 2023 09:44:27 +0200
Subject: [PATCH] fixed _add_launch_bound and also added some small tests

---
 pystencils_benchmark/gpu/benchmark.py | 10 +++++++---
 tests/test_launch_bounds.py           | 19 +++++++++++++++++++
 2 files changed, 26 insertions(+), 3 deletions(-)
 create mode 100644 tests/test_launch_bounds.py

diff --git a/pystencils_benchmark/gpu/benchmark.py b/pystencils_benchmark/gpu/benchmark.py
index 5a4852c..96fb58c 100644
--- a/pystencils_benchmark/gpu/benchmark.py
+++ b/pystencils_benchmark/gpu/benchmark.py
@@ -18,9 +18,13 @@ from pystencils_benchmark.enums import Compiler
 
 
 def _add_launch_bound(code: str, launch_bounds: tuple) -> str:
-    lb_str = f"__launch_bounds__({','.join(str(lb) for lb in launch_bounds)})"
-    splitted = code.split("void")
-    return splitted[0] + lb_str + "".join(splitted[1:])
+    """Return *code* with ``__launch_bounds__(...)`` put after its first ``void ``."""
+    lb_str = f"__launch_bounds__({', '.join(str(lb) for lb in launch_bounds)}) "
+    head, void_kw, tail = code.partition("void ")
+    if not void_kw:
+        # no "void " anchor: hand the code back instead of appending a stray qualifier
+        return code
+    return head + void_kw + lb_str + tail
 
 
 def generate_benchmark(kernel_asts: Union[KernelFunction, List[KernelFunction]],
diff --git a/tests/test_launch_bounds.py b/tests/test_launch_bounds.py
new file mode 100644
index 0000000..48af06d
--- /dev/null
+++ b/tests/test_launch_bounds.py
@@ -0,0 +1,19 @@
+import numpy as np
+import pystencils as ps
+from pystencils_benchmark.gpu.benchmark import kernel_header, _add_launch_bound, kernel_source
+
+
+def test_launch_bounds():
+    """Check the strings from kernel_header/kernel_source after adding launch bounds."""
+    size = 4000000
+    a, b, c = ps.fields(a=np.ones(size), b=np.ones(size), c=np.ones(size))
+
+    @ps.kernel_config(ps.CreateKernelConfig(target=ps.Target.GPU))
+    def vadd():
+        a[0] @= b[0] + c[0]
+    ast = ps.create_kernel(**vadd)
+    bounds = (256, 2)
+    expected = "void __launch_bounds__(256, 2)"
+    for render in (kernel_header, kernel_source):
+        # header first, then source: each must carry the qualifier after ``void``
+        assert expected in _add_launch_bound(render(ast), bounds)
-- 
GitLab