Skip to content
Snippets Groups Projects
Commit e8f5f8ae authored by Frederik Hennig
Browse files

small fixes to cuda_invoke

parent 2e175f00
1 merge request: !24 "Extend Support for CUDA and HIP kernel invocations"
Pipeline #75137 failed with stages in 2 minutes and 3 seconds
......@@ -493,7 +493,7 @@ class SfgBasicComposer(SfgIComposer):
]
block_size_var = dim3(const=True).var("__block_size")
nodes: list[SfgCallTreeNode] = [
nodes = [
self.init(grid_size_var)(*grid_size_entries),
self.init(block_size_var)(*block_size_entries),
_render_invocation(grid_size_var, block_size_var, stream),
......@@ -502,12 +502,26 @@ class SfgBasicComposer(SfgIComposer):
return SfgBlock(SfgSequence(nodes))
case DynamicBlockSizeLaunchConfiguration():
block_size = kwargs["block_size"]
user_block_size: ExprLike | None = kwargs["block_size"]
stream = kwargs["stream"]
if user_block_size is None:
if launch_config.block_size is None:
raise ValueError(
"Neither a user-defined nor a default block size was defined."
)
block_size_init_args = tuple(
str(bs) for bs in launch_config.block_size
)
else:
block_size_init_args = (user_block_size,)
block_size_var = dim3(const=True).var("__block_size")
from ..lang.cpp import std
witem_types = [lmb.return_type for lmb in launch_config.num_work_items]
witem_types = [wit.return_type for wit in launch_config.num_work_items]
work_items_entries = [
self.expr_from_lambda(wit) for wit in launch_config.num_work_items
]
......@@ -519,8 +533,6 @@ class SfgBasicComposer(SfgIComposer):
def _div_ceil(a: ExprLike, b: ExprLike):
return AugExpr.format("({a} + {b} - 1) / {b}", a=a, b=b)
block_size_var = dim3(const=True).var("__block_size")
reduced_block_size_entries = [
_min(work_items_var.get(i), bs)
for i, bs in enumerate(
......@@ -542,7 +554,7 @@ class SfgBasicComposer(SfgIComposer):
grid_size_var = dim3(const=True).var("__grid_size")
nodes = [
self.init(block_size_var)(block_size),
self.init(block_size_var)(*block_size_init_args),
self.init(work_items_var)(*work_items_entries),
self.init(reduced_block_size_var)(*reduced_block_size_entries),
self.init(grid_size_var)(*grid_size_entries),
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment