diff --git a/src/pystencilssfg/composer/gpu_composer.py b/src/pystencilssfg/composer/gpu_composer.py
index 2315a767438f7133ec14c0131eeb9e68aa664d43..06021dd9bd8e128cb2a84809ddbb67c2be43a51d 100644
--- a/src/pystencilssfg/composer/gpu_composer.py
+++ b/src/pystencilssfg/composer/gpu_composer.py
@@ -10,8 +10,9 @@ from pystencils.codegen.gpu_indexing import (
 )
 
 from .mixin import SfgComposerMixIn
-from .basic_composer import SfgBasicComposer, make_statements
+from .basic_composer import make_statements, make_sequence
 
+from ..context import SfgContext
 from ..ir import (
     SfgKernelHandle,
     SfgCallTreeNode,
@@ -99,18 +100,60 @@ class SfgGpuComposer(SfgComposerMixIn):
         stream: ExprLike | None = None,
     ) -> SfgCallTreeNode: ...
 
-    def gpu_invoke(self, kernel_handle: SfgKernelHandle, **kwargs) -> SfgCallTreeNode:
-        assert isinstance(
-            self, SfgBasicComposer
-        )  # for type checking this function body
+    def gpu_invoke(
+        self,
+        kernel_handle: SfgKernelHandle,
+        shared_memory_bytes: ExprLike = "0",
+        stream: ExprLike | None = None,
+        **kwargs,
+    ) -> SfgCallTreeNode:
+        builder = GpuInvocationBuilder(self._ctx, kernel_handle)
+        builder.shared_memory_bytes = shared_memory_bytes
+        builder.stream = stream
+
+        return builder(**kwargs)
+
+    def cuda_invoke(
+        self,
+        kernel_handle: SfgKernelHandle,
+        num_blocks: ExprLike,
+        threads_per_block: ExprLike,
+        stream: ExprLike | None,
+    ):
+        from warnings import warn
+
+        warn(
+            "cuda_invoke is deprecated and will be removed before version 0.1. "
+            "Use `gpu_invoke` instead.",
+            FutureWarning,
+        )
+
+        return self.gpu_invoke(
+            kernel_handle,
+            grid_size=num_blocks,
+            block_size=threads_per_block,
+            stream=stream,
+        )
+
+
+class GpuInvocationBuilder:
+    def __init__(
+        self,
+        ctx: SfgContext,
+        kernel_handle: SfgKernelHandle,
+    ):
+        self._ctx = ctx
+        self._kernel_handle = kernel_handle
 
         ker = kernel_handle.kernel
 
         if not isinstance(ker, GpuKernel):
-            raise ValueError(f"Non-GPU kernel was passed to `cuda_invoke`: {ker}")
+            raise ValueError(f"Non-GPU kernel was passed to `gpu_invoke`: {ker}")
 
         launch_config = ker.get_launch_configuration()
 
+        self._launch_config = launch_config
+
         gpu_api: type[ProvidesGpuRuntimeAPI]
         match ker.target:
             case Target.CUDA:
@@ -120,134 +163,145 @@ class SfgGpuComposer(SfgComposerMixIn):
             case _:
                 assert False, "unexpected GPU target"
 
-        dim3 = gpu_api.dim3
+        self._gpu_api = gpu_api
+        self._dim3 = gpu_api.dim3
 
-        grid_size: ExprLike
-        block_size: ExprLike
-        shared_memory_bytes: ExprLike = kwargs.get("shared_memory_bytes", "0")
-        stream: ExprLike | None = kwargs.get("stream", None)
+        self._shared_memory_bytes: ExprLike = "0"
+        self._stream: ExprLike | None
 
-        def _render_invocation(grid_size: ExprLike, block_size: ExprLike):
-            stmt_grid_size = make_statements(grid_size)
-            stmt_block_size = make_statements(block_size)
-            stmt_smem = (
-                make_statements(shared_memory_bytes)
-                if shared_memory_bytes is not None
-                else None
-            )
-            stmt_stream = make_statements(stream) if stream is not None else None
-
-            return self.seq(
-                "// clang-format off: "
-                "[pystencils-sfg] Formatting may add illegal spaces between angular brackets in `<<< >>>`.",
-                SfgGpuKernelInvocation(
-                    kernel_handle,
-                    stmt_grid_size,
-                    stmt_block_size,
-                    shared_memory_bytes=stmt_smem,
-                    stream=stmt_stream,
-                ),
-                "// clang-format on",
-            )
+    @property
+    def shared_memory_bytes(self) -> ExprLike:
+        return self._shared_memory_bytes
 
-        def to_uint32_t(expr: AugExpr) -> AugExpr:
-            return AugExpr("uint32_t").format("uint32_t({})", expr)
+    @shared_memory_bytes.setter
+    def shared_memory_bytes(self, bs: ExprLike):
+        self._shared_memory_bytes = bs
 
-        match launch_config:
-            case ManualLaunchConfiguration():
-                grid_size = kwargs["grid_size"]
-                block_size = kwargs["block_size"]
+    @property
+    def stream(self) -> ExprLike | None:
+        return self._stream
 
-                return _render_invocation(grid_size, block_size)
+    @stream.setter
+    def stream(self, s: ExprLike | None):
+        self._stream = s
 
+    def _render_invocation(
+        self, grid_size: ExprLike, block_size: ExprLike
+    ) -> SfgSequence:
+        stmt_grid_size = make_statements(grid_size)
+        stmt_block_size = make_statements(block_size)
+        stmt_smem = make_statements(self._shared_memory_bytes)
+        stmt_stream = (
+            make_statements(self._stream) if self._stream is not None else None
+        )
+
+        return make_sequence(
+            "// clang-format off: "
+            "[pystencils-sfg] Formatting may add illegal spaces between angular brackets in `<<< >>>`.",
+            SfgGpuKernelInvocation(
+                self._kernel_handle,
+                stmt_grid_size,
+                stmt_block_size,
+                shared_memory_bytes=stmt_smem,
+                stream=stmt_stream,
+            ),
+            "// clang-format on",
+        )
+
+    def __call__(self, **kwargs: ExprLike) -> SfgCallTreeNode:
+        match self._launch_config:
+            case ManualLaunchConfiguration():
+                return self._invoke_manual(**kwargs)
             case AutomaticLaunchConfiguration():
-                grid_size_entries = [
-                    to_uint32_t(self.expr_from_lambda(gs))
-                    for gs in launch_config._grid_size
-                ]
-                grid_size_var = dim3(const=True).var("__grid_size")
+                return self._invoke_automatic(**kwargs)
+            case DynamicBlockSizeLaunchConfiguration():
+                return self._invoke_dynamic(**kwargs)
+            case _:
+                raise ValueError(
+                    f"Unexpected launch configuration: {self._launch_config}"
+                )
 
-                block_size_entries = [
-                    to_uint32_t(self.expr_from_lambda(bs))
-                    for bs in launch_config._block_size
-                ]
-                block_size_var = dim3(const=True).var("__block_size")
+    def _invoke_manual(self, grid_size: ExprLike, block_size: ExprLike):
+        assert isinstance(self._launch_config, ManualLaunchConfiguration)
+        return self._render_invocation(grid_size, block_size)
 
-                nodes = [
-                    self.init(grid_size_var)(*grid_size_entries),
-                    self.init(block_size_var)(*block_size_entries),
-                    _render_invocation(grid_size_var, block_size_var),
-                ]
+    def _invoke_automatic(self):
+        assert isinstance(self._launch_config, AutomaticLaunchConfiguration)
 
-                return SfgBlock(SfgSequence(nodes))
+        from .composer import SfgComposer
 
-            case DynamicBlockSizeLaunchConfiguration():
-                user_block_size: ExprLike | None = kwargs.get("block_size", None)
+        sfg = SfgComposer(self._ctx)
 
-                block_size_init_args: tuple[ExprLike, ...]
-                if user_block_size is None:
-                    block_size_init_args = tuple(
-                        str(bs) for bs in launch_config.default_block_size
-                    )
-                else:
-                    block_size_init_args = (user_block_size,)
+        grid_size_entries = [
+            self._to_uint32_t(sfg.expr_from_lambda(gs))
+            for gs in self._launch_config._grid_size
+        ]
+        grid_size_var = self._dim3(const=True).var("__grid_size")
 
-                block_size_var = dim3(const=True).var("__block_size")
+        block_size_entries = [
+            self._to_uint32_t(sfg.expr_from_lambda(bs))
+            for bs in self._launch_config._block_size
+        ]
+        block_size_var = self._dim3(const=True).var("__block_size")
 
-                from ..lang.cpp import std
+        nodes = [
+            sfg.init(grid_size_var)(*grid_size_entries),
+            sfg.init(block_size_var)(*block_size_entries),
+            self._render_invocation(grid_size_var, block_size_var),
+        ]
 
-                work_items_entries = [
-                    self.expr_from_lambda(wit) for wit in launch_config.num_work_items
-                ]
-                work_items_var = std.tuple(
-                    "uint32_t", "uint32_t", "uint32_t", const=True
-                ).var("__work_items")
-
-                def _div_ceil(a: ExprLike, b: ExprLike):
-                    return AugExpr.format("({a} + {b} - 1) / {b}", a=a, b=b)
-
-                grid_size_entries = [
-                    _div_ceil(work_items_var.get(i), bs)
-                    for i, bs in enumerate(
-                        [
-                            block_size_var.x,
-                            block_size_var.y,
-                            block_size_var.z,
-                        ]
-                    )
-                ]
-                grid_size_var = dim3(const=True).var("__grid_size")
+        return SfgBlock(SfgSequence(nodes))
 
-                nodes = [
-                    self.init(block_size_var)(*block_size_init_args),
-                    self.init(work_items_var)(*work_items_entries),
-                    self.init(grid_size_var)(*grid_size_entries),
-                    _render_invocation(grid_size_var, block_size_var),
-                ]
+    def _invoke_dynamic(self, block_size: ExprLike | None = None):
+        assert isinstance(self._launch_config, DynamicBlockSizeLaunchConfiguration)
 
-                return SfgBlock(SfgSequence(nodes))
+        from .composer import SfgComposer
 
-            case _:
-                raise ValueError(f"Unexpected launch configuration: {launch_config}")
+        sfg = SfgComposer(self._ctx)
 
-    def cuda_invoke(
-        self,
-        kernel_handle: SfgKernelHandle,
-        num_blocks: ExprLike,
-        threads_per_block: ExprLike,
-        stream: ExprLike | None,
-    ):
-        from warnings import warn
+        block_size_init_args: tuple[ExprLike, ...]
+        if block_size is None:
+            block_size_init_args = tuple(
+                str(bs) for bs in self._launch_config.default_block_size
+            )
+        else:
+            block_size_init_args = (block_size,)
 
-        warn(
-            "cuda_invoke is deprecated and will be removed before version 0.1. "
-            "Use `gpu_invoke` instead.",
-            FutureWarning,
-        )
+        block_size_var = self._dim3(const=True).var("__block_size")
 
-        return self.gpu_invoke(
-            kernel_handle,
-            grid_size=num_blocks,
-            block_size=threads_per_block,
-            stream=stream,
+        from ..lang.cpp import std
+
+        work_items_entries = [
+            sfg.expr_from_lambda(wit) for wit in self._launch_config.num_work_items
+        ]
+        work_items_var = std.tuple("uint32_t", "uint32_t", "uint32_t", const=True).var(
+            "__work_items"
         )
+
+        def _div_ceil(a: ExprLike, b: ExprLike):
+            return AugExpr.format("({a} + {b} - 1) / {b}", a=a, b=b)
+
+        grid_size_entries = [
+            _div_ceil(work_items_var.get(i), bs)
+            for i, bs in enumerate(
+                [
+                    block_size_var.x,
+                    block_size_var.y,
+                    block_size_var.z,
+                ]
+            )
+        ]
+        grid_size_var = self._dim3(const=True).var("__grid_size")
+
+        nodes = [
+            sfg.init(block_size_var)(*block_size_init_args),
+            sfg.init(work_items_var)(*work_items_entries),
+            sfg.init(grid_size_var)(*grid_size_entries),
+            self._render_invocation(grid_size_var, block_size_var),
+        ]
+
+        return SfgBlock(SfgSequence(nodes))
+
+    @staticmethod
+    def _to_uint32_t(expr: AugExpr) -> AugExpr:
+        return AugExpr("uint32_t").format("uint32_t({})", expr)