From c306be593065f988be7386bd1aac6146feff6385 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Sat, 15 Feb 2025 11:21:49 +0100
Subject: [PATCH] further simplify Python implementation of
 DynamicBlockSizeLaunchConfig

---
 src/pystencils/codegen/gpu_indexing.py        | 104 +++++++-----------
 .../sympyextensions/integer_functions.py      |   4 +-
 src/pystencils/utils.py                       |  47 +++++++-
 3 files changed, 89 insertions(+), 66 deletions(-)

diff --git a/src/pystencils/codegen/gpu_indexing.py b/src/pystencils/codegen/gpu_indexing.py
index 80af8aba8..d32e4112a 100644
--- a/src/pystencils/codegen/gpu_indexing.py
+++ b/src/pystencils/codegen/gpu_indexing.py
@@ -19,7 +19,6 @@ from ..backend.ast.expressions import PsExpression
 
 
 dim3 = tuple[int, int, int]
-_Dim3Params = tuple[Parameter, Parameter, Parameter]
 _Dim3Lambda = tuple[Lambda, Lambda, Lambda]
 
 
@@ -73,15 +72,9 @@ class AutomaticLaunchConfiguration(GpuLaunchConfiguration):
 
     @property
     def parameters(self) -> frozenset[Parameter]:
-        """Parameters of this launch configuration"""
         return self._params
 
     def evaluate(self, **kwargs) -> tuple[dim3, dim3]:
-        """Compute block and grid size for a kernel launch.
-
-        Args:
-            kwargs: Valuation providing a value for each parameter listed in `parameters`
-        """
         block_size = tuple(int(bs(**kwargs)) for bs in self._block_size)
         grid_size = tuple(int(gs(**kwargs)) for gs in self._grid_size)
         return cast(dim3, block_size), cast(dim3, grid_size)
@@ -136,27 +129,38 @@ class ManualLaunchConfiguration(GpuLaunchConfiguration):
 
 
 class DynamicBlockSizeLaunchConfiguration(GpuLaunchConfiguration):
-    """GPU launch configuration that permits the user to set a block size dynamically."""
+    """GPU launch configuration that permits the user to set a block size and dynamically computes the grid size.
+    
+    The actual launch grid size is computed from the user-defined ``user_block_size`` and the number of work items
+    in the kernel's iteration space as follows.
+    For each dimension :math:`c \\in \\{ x, y, z \\}`,
+
+    - if ``user_block_size.c > num_work_items.c``, ``block_size = num_work_items.c`` and ``grid_size.c = 1``;
+    - otherwise, ``block_size.c = user_block_size.c`` and ``grid_size.c = ceil(num_work_items.c / block_size.c)``.
+    """
 
     def __init__(
         self,
-        block_size_expr: _Dim3Lambda,
-        grid_size_expr: _Dim3Lambda,
-        block_size_params: _Dim3Params,
+        num_work_items: _Dim3Lambda,
         default_block_size: dim3 | None = None,
     ) -> None:
-        self._block_size_expr = block_size_expr
-        self._grid_size_expr = grid_size_expr
+        self._num_work_items = num_work_items
 
-        self._block_size_params = block_size_params
         self._block_size: dim3 | None = default_block_size
 
         self._params: frozenset[Parameter] = frozenset().union(
-            *(lb.parameters for lb in chain(block_size_expr, grid_size_expr))
-        ) - set(self._block_size_params)
+            *(wit.parameters for wit in num_work_items)
+        )
+
+    @property
+    def num_work_items(self) -> _Dim3Lambda:
+        """Lambda expressions that compute the number of work items in each iteration space
+        dimension from kernel parameters."""
+        return self._num_work_items
 
     @property
     def block_size(self) -> dim3 | None:
+        """The desired GPU block size."""
         return self._block_size
 
     @block_size.setter
@@ -172,16 +176,23 @@ class DynamicBlockSizeLaunchConfiguration(GpuLaunchConfiguration):
         if self._block_size is None:
             raise AttributeError("No GPU block size was specified by the user!")
 
-        kwargs.update(
-            {
-                param.name: value
-                for param, value in zip(self._block_size_params, self._block_size)
-            }
+        from ..utils import div_ceil
+
+        num_work_items = cast(
+            dim3, tuple(int(wit(**kwargs)) for wit in self._num_work_items)
+        )
+        reduced_block_size = cast(
+            dim3,
+            tuple(min(wit, bs) for wit, bs in zip(num_work_items, self._block_size)),
+        )
+        grid_size = cast(
+            dim3,
+            tuple(
+                div_ceil(wit, bs) for wit, bs in zip(num_work_items, reduced_block_size)
+            ),
         )
 
-        block_size = tuple(int(bs(**kwargs)) for bs in self._block_size_expr)
-        grid_size = tuple(int(gs(**kwargs)) for gs in self._grid_size_expr)
-        return cast(dim3, block_size), cast(dim3, grid_size)
+        return reduced_block_size, grid_size
 
     def jit_cache_key(self) -> Any:
         return self._block_size
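
For illustration, here is a minimal pure-Python sketch of the launch size
computation that the new `evaluate` performs. The concrete values below are
hypothetical stand-ins, not part of the patch:

    # Sketch of evaluate()'s arithmetic with made-up values.

    def div_ceil(dividend: int, divisor: int) -> int:
        # Same formula as the div_ceil helper added to utils.py below.
        return (dividend + divisor - 1) // divisor

    num_work_items = (1000, 14, 1)   # hypothetical iteration space extent
    user_block_size = (128, 16, 1)   # hypothetical user-chosen block size

    # Clamp the block size to the number of work items per dimension ...
    reduced_block_size = tuple(
        min(wit, bs) for wit, bs in zip(num_work_items, user_block_size)
    )
    # ... then cover the remaining work items with the grid.
    grid_size = tuple(
        div_ceil(wit, bs) for wit, bs in zip(num_work_items, reduced_block_size)
    )

    assert reduced_block_size == (128, 14, 1)
    assert grid_size == (8, 1, 1)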
@@ -226,50 +237,17 @@ class GpuIndexing(ABC):
     def _get_linear3d_config_factory(
         self,
     ) -> Callable[[], DynamicBlockSizeLaunchConfiguration]:
-        work_items = self._get_work_items()
-        rank = len(work_items)
-
-        from ..backend.constants import PsConstant
-        from ..backend.ast.expressions import PsExpression, PsIntDiv
-
-        block_size_symbols = [
-            self._ctx.get_new_symbol(f"gpuBlockSize_{c}", self._ctx.index_dtype) for c in range(rank)
-        ]
-
-        block_size = [
-            Lambda.from_expression(self._ctx, self._factory.parse_index(bs_symb))
-            for bs_symb in block_size_symbols
-        ] + [
-            Lambda.from_expression(self._ctx, self._factory.parse_index(1))
-            for _ in range(3 - rank)
-        ]
-
-        def div_ceil(a: PsExpression, b: PsExpression):
-            return self._factory.parse_index(
-                PsIntDiv(a + b - PsExpression.make(PsConstant(1)), b)
-            )
-
-        grid_size = [
-            Lambda.from_expression(
-                self._ctx, div_ceil(witems, PsExpression.make(bsize))
-            )
-            for witems, bsize in zip(work_items, block_size_symbols)
-        ] + [
-            Lambda.from_expression(self._ctx, self._factory.parse_index(1))
-            for _ in range(3 - rank)
-        ]
-
-        from .driver import _symbol_to_param
+        work_items_expr = self._get_work_items()
+        rank = len(work_items_expr)
 
-        block_size_params = tuple(
-            _symbol_to_param(self._ctx, s) for s in block_size_symbols
+        num_work_items = cast(
+            _Dim3Lambda,
+            tuple(Lambda.from_expression(self._ctx, wit) for wit in work_items_expr),
         )
 
         def factory():
             return DynamicBlockSizeLaunchConfiguration(
-                cast(_Dim3Lambda, tuple(block_size)),
-                cast(_Dim3Lambda, tuple(grid_size)),
-                cast(_Dim3Params, block_size_params),
+                num_work_items,
                 self._get_default_block_size(rank),
             )
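
As a hedged sketch of the simplified interface: the constructor now only needs
the per-dimension work item lambdas. `FakeLambda` below is a hypothetical
duck-typed stand-in providing the two things the launch configuration actually
uses (a `parameters` attribute and a keyword-argument call); in the generator,
real `Lambda` objects are constructed instead.

    from pystencils.codegen.gpu_indexing import DynamicBlockSizeLaunchConfiguration

    class FakeLambda:
        """Hypothetical stand-in for pystencils' Lambda objects."""
        def __init__(self, fn, parameters=frozenset()):
            self._fn = fn
            self.parameters = parameters

        def __call__(self, **kwargs):
            return self._fn(**kwargs)

    # Work items of a hypothetical 2D kernel, padded to three dimensions
    # as in _get_linear3d_config_factory.
    num_work_items = (
        FakeLambda(lambda N, **_: N),
        FakeLambda(lambda M, **_: M),
        FakeLambda(lambda **_: 1),
    )

    cfg = DynamicBlockSizeLaunchConfiguration(
        num_work_items, default_block_size=(64, 4, 1)
    )
    print(cfg.evaluate(N=1000, M=30))  # ((64, 4, 1), (16, 8, 1))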
 
diff --git a/src/pystencils/sympyextensions/integer_functions.py b/src/pystencils/sympyextensions/integer_functions.py
index 42513ef9c..9d2c69502 100644
--- a/src/pystencils/sympyextensions/integer_functions.py
+++ b/src/pystencils/sympyextensions/integer_functions.py
@@ -140,10 +140,10 @@ class div_ceil(IntegerFunctionTwoArgsMixIn):
 
     @classmethod
     def eval(cls, arg1, arg2):
-        from ..utils import c_intdiv
+        from ..utils import div_ceil
 
         if is_integer_sequence((arg1, arg2)):
-            return c_intdiv(arg1 + arg2 - 1, arg2)
+            return div_ceil(arg1, arg2)
 
     def _eval_op(self, arg1, arg2):
         return self.eval(arg1, arg2)
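
The change above is behavior-preserving: for nonnegative integers,
`div_ceil(a, b)` is by definition the very `c_intdiv(a + b - 1, b)` expression
it replaces. A quick sanity check (assuming pystencils with this patch applied
is importable):

    from pystencils.utils import c_intdiv, div_ceil

    for a in range(50):
        for b in range(1, 10):
            # -(-a // b) is ceiling division for positive divisors.
            assert div_ceil(a, b) == c_intdiv(a + b - 1, b) == -(-a // b)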
diff --git a/src/pystencils/utils.py b/src/pystencils/utils.py
index a53eb8289..0049d0a2c 100644
--- a/src/pystencils/utils.py
+++ b/src/pystencils/utils.py
@@ -4,11 +4,13 @@ from itertools import groupby
 from collections import Counter
 from contextlib import contextmanager
 from tempfile import NamedTemporaryFile
-from typing import Mapping
+from typing import Mapping, overload
 
 import numpy as np
 import sympy as sp
 
+from numpy.typing import NDArray
+
 
 class DotDict(dict):
     """Normal dict with additional dot access for all keys"""
@@ -254,6 +256,24 @@ class ContextVar:
         return self.stack[-1]
 
 
+@overload
+def c_intdiv(num: int, denom: int) -> int: ...
+
+
+@overload
+def c_intdiv(
+    num: NDArray[np.integer], denom: NDArray[np.integer]
+) -> NDArray[np.integer]: ...
+
+
+@overload
+def c_intdiv(num: int, denom: NDArray[np.integer]) -> NDArray[np.integer]: ...
+
+
+@overload
+def c_intdiv(num: NDArray[np.integer], denom: int) -> NDArray[np.integer]: ...
+
+
 def c_intdiv(num, denom):
     """C-style integer division"""
     if isinstance(num, np.ndarray) or isinstance(denom, np.ndarray):
@@ -271,3 +291,28 @@ def c_rem(num, denom):
     """C-style integer remainder"""
     div = c_intdiv(num, denom)
     return num - div * denom
+
+
+@overload
+def div_ceil(dividend: int, divisor: int) -> int: ...
+
+
+@overload
+def div_ceil(
+    dividend: NDArray[np.integer], divisor: NDArray[np.integer]
+) -> NDArray[np.integer]: ...
+
+
+@overload
+def div_ceil(dividend: int, divisor: NDArray[np.integer]) -> NDArray[np.integer]: ...
+
+
+@overload
+def div_ceil(dividend: NDArray[np.integer], divisor: int) -> NDArray[np.integer]: ...
+
+
+def div_ceil(dividend, divisor):
+    """For nonnegative integer arguments, compute ``ceil(dividend / divisor)``.
+    The result is unspecified if either argument is negative."""
+
+    return c_intdiv(dividend + divisor - 1, divisor)
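
The overloads above only add static types; at runtime both helpers already
accept NumPy arrays through the `np.ndarray` branch of `c_intdiv`. A small
usage sketch:

    import numpy as np
    from pystencils.utils import c_intdiv, div_ceil

    # C-style division truncates toward zero, unlike Python's floor division.
    assert c_intdiv(-7, 2) == -3   # whereas -7 // 2 == -4
    assert div_ceil(7, 2) == 4

    # Array arguments, matching the NDArray overloads.
    extents = np.array([1000, 30, 1])
    blocks = np.array([128, 16, 1])
    print(div_ceil(extents, blocks))  # [8 2 1]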
-- 
GitLab