From 210f768a5a5ffdab3531361a45bada18ce699358 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Fri, 14 Feb 2025 12:35:42 +0100
Subject: [PATCH] WIP refactor launch configuration and gpu indexing

---
 src/pystencils/codegen/config.py       |   2 +-
 src/pystencils/codegen/driver.py       |  35 +++--
 src/pystencils/codegen/gpu_indexing.py | 208 +++++++++++++++++++++----
 src/pystencils/jit/gpu_cupy.py         |   8 +-
 4 files changed, 200 insertions(+), 53 deletions(-)

diff --git a/src/pystencils/codegen/config.py b/src/pystencils/codegen/config.py
index 53a271852..47a64df64 100644
--- a/src/pystencils/codegen/config.py
+++ b/src/pystencils/codegen/config.py
@@ -380,7 +380,7 @@ class GpuOptions(ConfigBase):
     This check can be discarded through this option, at your own peril.
     """
 
-    block_size: BasicOption[tuple[int, int, int]] = BasicOption()
+    block_size: BasicOption[tuple[int, int, int] | _AUTO_TYPE] = BasicOption(AUTO)
     """Desired block size for the execution of GPU kernels. May be overridden later by the runtime system."""
 
     manual_launch_grid: BasicOption[bool] = BasicOption(False)
diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py
index b14fc0272..f7eb8ddb4 100644
--- a/src/pystencils/codegen/driver.py
+++ b/src/pystencils/codegen/driver.py
@@ -516,26 +516,27 @@ def create_gpu_kernel_function(
     return kfunc
 
 
-def _get_function_params(
-    ctx: KernelCreationContext, symbols: Iterable[PsSymbol]
-) -> list[Parameter]:
-    params: list[Parameter] = []
-
+def _symbol_to_param(ctx: KernelCreationContext, symbol: PsSymbol):
     from pystencils.backend.memory import BufferBasePtr, BackendPrivateProperty
 
-    for symb in symbols:
-        props: set[PsSymbolProperty] = set()
-        for prop in symb.properties:
-            match prop:
-                case BufferBasePtr(buf):
-                    field = ctx.find_field(buf.name)
-                    props.add(FieldBasePtr(field))
-                case BackendPrivateProperty():
-                    pass
-                case _:
-                    props.add(prop)
-        params.append(Parameter(symb.name, symb.get_dtype(), props))
+    props: set[PsSymbolProperty] = set()
+    for prop in symbol.properties:
+        match prop:
+            case BufferBasePtr(buf):
+                field = ctx.find_field(buf.name)
+                props.add(FieldBasePtr(field))
+            case BackendPrivateProperty():
+                pass
+            case _:
+                props.add(prop)
+
+    return Parameter(symbol.name, symbol.get_dtype(), props)
+
 
+def _get_function_params(
+    ctx: KernelCreationContext, symbols: Iterable[PsSymbol]
+) -> list[Parameter]:
+    params: list[Parameter] = [_symbol_to_param(ctx, s) for s in symbols]
     params.sort(key=lambda p: p.name)
     return params
 
diff --git a/src/pystencils/codegen/gpu_indexing.py b/src/pystencils/codegen/gpu_indexing.py
index 08134a622..24189bf63 100644
--- a/src/pystencils/codegen/gpu_indexing.py
+++ b/src/pystencils/codegen/gpu_indexing.py
@@ -1,13 +1,14 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import cast, Any
+from typing import cast, Any, Callable
 from itertools import chain
 
 from .functions import Lambda
 from .kernel import GpuKernel
 from .parameters import Parameter
 from .errors import CodegenError
+from .config import GpuIndexingScheme, _AUTO_TYPE
 
 from ..backend.kernelcreation import (
     KernelCreationContext,
@@ -18,6 +19,8 @@ from ..backend.platforms.cuda import ThreadToIndexMapping
 from ..backend.ast.expressions import PsExpression
 
 
+dim3 = tuple[int, int, int]
+_Dim3Params = tuple[Parameter, Parameter, Parameter]
 _Dim3Lambda = tuple[Lambda, Lambda, Lambda]
 
 
@@ -48,29 +51,179 @@ class GpuLaunchConfiguration:
         """Parameters to this set of constraints"""
         return self._params
 
-    @property
-    def parameter_values(self) -> dict[Parameter, Any]:
+    def get_valuation(self) -> dict[Parameter, Any]:
         """Values for all parameters that are specific to the launch grid configuration and not
         also kernel parameters."""
         return self._valuation
 
-    @property
-    def block_size(self) -> _Dim3Lambda:
-        """Constraints on the number of threads per block"""
+    def get_block_size(self) -> _Dim3Lambda:
         return self._block_size
 
-    @property
-    def grid_size(self) -> _Dim3Lambda:
-        """Constraints on the number of blocks on the grid"""
+    def get_grid_size(self) -> _Dim3Lambda:
         return self._grid_size
 
 
+class ManualLaunchConfiguration(GpuLaunchConfiguration):
+    """Manual GPU launch configuration.
+
+    This launch configuration requires the user to set block and grid size.
+    """
+
+    def __init__(
+        self,
+        block_size: _Dim3Lambda,
+        grid_size: _Dim3Lambda,
+        block_size_params: _Dim3Params,
+        grid_size_params: _Dim3Params,
+    ):
+        super().__init__(
+            cast(_Dim3Lambda, block_size),
+            cast(_Dim3Lambda, grid_size),
+            set(block_size_params).union(grid_size_params),
+        )
+        self._block_size_params = block_size_params
+        self._grid_size_params = grid_size_params
+
+        self._user_block_size: dim3 | None = None
+        self._user_grid_size: dim3 | None = None
+
+    @property
+    def block_size(self) -> dim3 | None:
+        return self._user_block_size
+
+    @block_size.setter
+    def block_size(self, val: dim3):
+        self._user_block_size = val
+
+    @property
+    def grid_size(self) -> dim3 | None:
+        return self._user_grid_size
+
+    @grid_size.setter
+    def grid_size(self, val: dim3):
+        self._user_grid_size = val
+
+    def get_valuation(self) -> dict[Parameter, Any]:
+        if self._user_block_size is None:
+            raise AttributeError("No GPU block size was specified")
+
+        if self._user_grid_size is None:
+            raise AttributeError("No GPU grid size was specified")
+
+        valuation: dict[Parameter, Any] = dict()
+
+        for bs_param, bs in zip(self._block_size_params, self._user_block_size):
+            valuation[bs_param] = bs
+
+        for gs_param, gs in zip(self._grid_size_params, self._user_grid_size):
+            valuation[gs_param] = gs
+
+        return valuation
+
+
+class GridFromBlockSizeConfiguration(GpuLaunchConfiguration):
+    """GPU launch configuration that computes the grid size from a user-defined block size."""
+
+    def __init__(
+        self,
+        block_size: _Dim3Lambda,
+        grid_size: _Dim3Lambda,
+        block_size_params: _Dim3Params,
+        default_block_size: dim3 | None = None,
+    ) -> None:
+        super().__init__(block_size, grid_size, set(block_size_params))
+
+        self._block_size_params = block_size_params
+        self._user_block_size: dim3 | None = default_block_size
+
+    @property
+    def block_size(self) -> dim3 | None:
+        return self._user_block_size
+
+    @block_size.setter
+    def block_size(self, val: dim3):
+        self._user_block_size = val
+
+    def get_valuation(self) -> dict[Parameter, Any]:
+        if self._user_block_size is None:
+            raise AttributeError("No GPU block size was specified")
+
+        valuation: dict[Parameter, Any] = dict()
+
+        for bs_param, bs in zip(self._block_size_params, self._user_block_size):
+            valuation[bs_param] = bs
+
+        return valuation
+
+
 class GpuIndexing(ABC):
-    @abstractmethod
-    def get_thread_mapping(self) -> ThreadToIndexMapping | None: ...
+    def __init__(
+        self,
+        ctx: KernelCreationContext,
+        scheme: GpuIndexingScheme,
+        block_size: dim3 | _AUTO_TYPE,
+        manual_launch_grid: bool,
+    ) -> None:
+        self._ctx = ctx
+        self._scheme = scheme
+        self._block_size = block_size
+        self._manual_launch_grid = manual_launch_grid
+
+        from ..backend.kernelcreation import AstFactory
 
-    @abstractmethod
-    def get_launch_config(self, kernel: GpuKernel) -> GpuLaunchConfiguration: ...
+        self._factory = AstFactory(self._ctx)
+
+    def get_thread_mapping(self):
+        from ..backend.platforms.cuda import Linear3DMapping, Blockwise4DMapping
+
+        match self._scheme:
+            case GpuIndexingScheme.Linear3D:
+                return Linear3DMapping()
+            case GpuIndexingScheme.Blockwise4D:
+                return Blockwise4DMapping()
+
+    def get_launch_config_factory(
+        self, scheme: GpuIndexingScheme
+    ) -> Callable[[], GpuLaunchConfiguration]:
+        if self._manual_launch_grid:
+            return self._manual_config_factory()
+
+        raise NotImplementedError()
+
+    def _manual_config_factory(self) -> Callable[[], ManualLaunchConfiguration]:
+        ctx = self._ctx
+
+        block_size_symbols = [
+            ctx.get_new_symbol(f"gpuBlockSize_{c}", ctx.index_dtype) for c in range(3)
+        ]
+        grid_size_symbols = [
+            ctx.get_new_symbol(f"gpuGridSize_{c}", ctx.index_dtype) for c in range(3)
+        ]
+
+        block_size = tuple(
+            Lambda.from_expression(ctx, PsExpression.make(bs))
+            for bs in block_size_symbols
+        )
+
+        grid_size = tuple(
+            Lambda.from_expression(ctx, PsExpression.make(gs))
+            for gs in grid_size_symbols
+        )
+
+        from .driver import _symbol_to_param
+
+        bs_params = [_symbol_to_param(ctx, s) for s in block_size_symbols]
+        gs_params = [_symbol_to_param(ctx, s) for s in grid_size_symbols]
+
+        def factory():
+            return ManualLaunchConfiguration(
+                cast(_Dim3Lambda, block_size),
+                cast(_Dim3Lambda, grid_size),
+                cast(_Dim3Params, bs_params),
+                cast(_Dim3Params, gs_params),
+            )
+
+        return factory
 
 
 class Linear3DGpuIndexing(GpuIndexing):
@@ -88,23 +241,6 @@ class Linear3DGpuIndexing(GpuIndexing):
         return Linear3DMapping()
 
     def get_launch_config(self, kernel: GpuKernel) -> GpuLaunchConfiguration:
-        block_size, grid_size = self._prepare_launch_grid()
-
-        kernel_params = set(kernel.parameters)
-        launch_config_params = (
-            set().union(
-                *(lb.parameters for lb in chain(block_size, grid_size))
-            )
-            - kernel_params
-        )
-
-        return GpuLaunchConfiguration(
-            block_size=cast(_Dim3Lambda, tuple(block_size)),
-            grid_size=cast(_Dim3Lambda, tuple(grid_size)),
-            config_parameters=launch_config_params,
-        )
-
-    def _prepare_launch_grid(self):
         work_items = self._get_work_items()
         rank = len(work_items)
 
@@ -138,7 +274,17 @@ class Linear3DGpuIndexing(GpuIndexing):
             for _ in range(3 - rank)
         ]
 
-        return block_size, grid_size
+        from .driver import _symbol_to_param
+
+        block_size_params = tuple(
+            _symbol_to_param(self._ctx, s) for s in block_size_symbols
+        )
+
+        return GridFromBlockSizeConfiguration(
+            cast(_Dim3Lambda, tuple(block_size)),
+            cast(_Dim3Lambda, tuple(grid_size)),
+            cast(tuple[Parameter, Parameter, Parameter], block_size_params),
+        )
 
     def _get_work_items(self) -> tuple[PsExpression, ...]:
         ispace = self._ctx.get_iteration_space()
diff --git a/src/pystencils/jit/gpu_cupy.py b/src/pystencils/jit/gpu_cupy.py
index f3f834f76..d4f1c0204 100644
--- a/src/pystencils/jit/gpu_cupy.py
+++ b/src/pystencils/jit/gpu_cupy.py
@@ -67,7 +67,7 @@ class CupyKernelWrapper(KernelWrapper):
         return devices.pop()
 
     def _get_cached_args(self, **kwargs):
-        launch_config_params = self._launch_config.parameter_values
+        launch_config_params = self._launch_config.get_valuation()
         key = tuple(
             (k, v) for k, v in launch_config_params.items()
         ) + tuple((k, id(v)) for k, v in kwargs.items())
@@ -195,7 +195,7 @@ class CupyKernelWrapper(KernelWrapper):
         launch_cfg_valuation.update(
             {
                 param.name: value
-                for param, value in self._launch_config.parameter_values.items()
+                for param, value in self._launch_config.get_valuation().items()
             }
         )
 
@@ -203,7 +203,7 @@ class CupyKernelWrapper(KernelWrapper):
             tuple[int, int, int],
             tuple(
                 int(component(**launch_cfg_valuation))
-                for component in self._launch_config.block_size
+                for component in self._launch_config.get_block_size()
             ),
         )
 
@@ -211,7 +211,7 @@ class CupyKernelWrapper(KernelWrapper):
             tuple[int, int, int],
             tuple(
                 int(component(**launch_cfg_valuation))
-                for component in self._launch_config.grid_size
+                for component in self._launch_config.get_grid_size()
             ),
         )
 
-- 
GitLab