diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index 8a076745eac0d2cb98b88490469a58975de1e896..686b60db62499b455eda6d71ac755e7cb13a194e 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -493,7 +493,7 @@ class SfgBasicComposer(SfgIComposer): ] block_size_var = dim3(const=True).var("__block_size") - nodes: list[SfgCallTreeNode] = [ + nodes = [ self.init(grid_size_var)(*grid_size_entries), self.init(block_size_var)(*block_size_entries), _render_invocation(grid_size_var, block_size_var, stream), @@ -502,12 +502,26 @@ class SfgBasicComposer(SfgIComposer): return SfgBlock(SfgSequence(nodes)) case DynamicBlockSizeLaunchConfiguration(): - block_size = kwargs["block_size"] + user_block_size: ExprLike | None = kwargs["block_size"] stream = kwargs["stream"] + if user_block_size is None: + if launch_config.block_size is None: + raise ValueError( + "Neither a user-defined nor a default block size was defined." + ) + + block_size_init_args = tuple( + str(bs) for bs in launch_config.block_size + ) + else: + block_size_init_args = (user_block_size,) + + block_size_var = dim3(const=True).var("__block_size") + from ..lang.cpp import std - witem_types = [lmb.return_type for lmb in launch_config.num_work_items] + witem_types = [wit.return_type for wit in launch_config.num_work_items] work_items_entries = [ self.expr_from_lambda(wit) for wit in launch_config.num_work_items ] @@ -519,8 +533,6 @@ class SfgBasicComposer(SfgIComposer): def _div_ceil(a: ExprLike, b: ExprLike): return AugExpr.format("({a} + {b} - 1) / {b}", a=a, b=b) - block_size_var = dim3(const=True).var("__block_size") - reduced_block_size_entries = [ _min(work_items_var.get(i), bs) for i, bs in enumerate( @@ -542,7 +554,7 @@ class SfgBasicComposer(SfgIComposer): grid_size_var = dim3(const=True).var("__grid_size") nodes = [ - self.init(block_size_var)(block_size), + self.init(block_size_var)(*block_size_init_args), self.init(work_items_var)(*work_items_entries), self.init(reduced_block_size_var)(*reduced_block_size_entries), self.init(grid_size_var)(*grid_size_entries),