From e8f5f8aef9a1191f19b71138aabcb28dc55a08e9 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 5 Mar 2025 15:56:48 +0100
Subject: [PATCH] small fixes to cuda_invoke

---
 src/pystencilssfg/composer/basic_composer.py | 24 +++++++++++++++-----
 1 file changed, 18 insertions(+), 6 deletions(-)

diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py
index 8a07674..686b60d 100644
--- a/src/pystencilssfg/composer/basic_composer.py
+++ b/src/pystencilssfg/composer/basic_composer.py
@@ -493,7 +493,7 @@ class SfgBasicComposer(SfgIComposer):
                 ]
                 block_size_var = dim3(const=True).var("__block_size")
 
-                nodes: list[SfgCallTreeNode] = [
+                nodes = [
                     self.init(grid_size_var)(*grid_size_entries),
                     self.init(block_size_var)(*block_size_entries),
                     _render_invocation(grid_size_var, block_size_var, stream),
@@ -502,12 +502,26 @@ class SfgBasicComposer(SfgIComposer):
                 return SfgBlock(SfgSequence(nodes))
 
             case DynamicBlockSizeLaunchConfiguration():
-                block_size = kwargs["block_size"]
+                user_block_size: ExprLike | None = kwargs["block_size"]
                 stream = kwargs["stream"]
 
+                if user_block_size is None:
+                    if launch_config.block_size is None:
+                        raise ValueError(
+                            "Neither a user-defined nor a default block size was defined."
+                        )
+
+                    block_size_init_args = tuple(
+                        str(bs) for bs in launch_config.block_size
+                    )
+                else:
+                    block_size_init_args = (user_block_size,)
+
+                block_size_var = dim3(const=True).var("__block_size")
+
                 from ..lang.cpp import std
 
-                witem_types = [lmb.return_type for lmb in launch_config.num_work_items]
+                witem_types = [wit.return_type for wit in launch_config.num_work_items]
                 work_items_entries = [
                     self.expr_from_lambda(wit) for wit in launch_config.num_work_items
                 ]
@@ -519,8 +533,6 @@ class SfgBasicComposer(SfgIComposer):
                 def _div_ceil(a: ExprLike, b: ExprLike):
                     return AugExpr.format("({a} + {b} - 1) / {b}", a=a, b=b)
 
-                block_size_var = dim3(const=True).var("__block_size")
-
                 reduced_block_size_entries = [
                     _min(work_items_var.get(i), bs)
                     for i, bs in enumerate(
@@ -542,7 +554,7 @@ class SfgBasicComposer(SfgIComposer):
                 grid_size_var = dim3(const=True).var("__grid_size")
 
                 nodes = [
-                    self.init(block_size_var)(block_size),
+                    self.init(block_size_var)(*block_size_init_args),
                     self.init(work_items_var)(*work_items_entries),
                     self.init(reduced_block_size_var)(*reduced_block_size_entries),
                     self.init(grid_size_var)(*grid_size_entries),
-- 
GitLab