diff --git a/src/pystencils/backend/memory.py b/src/pystencils/backend/memory.py
index 7a5d62f691d81a0f251329c47216f65a981ef291..0e9b21d6c8268bec7fef97ac6fcb2b4b0e37f8f4 100644
--- a/src/pystencils/backend/memory.py
+++ b/src/pystencils/backend/memory.py
@@ -89,8 +89,13 @@ class PsSymbol:
         return f"PsSymbol({repr(self._name)}, {repr(self._dtype)})"
 
 
+class BackendPrivateProperty:
+    """Mix-in marker for symbol properties that are private to the backend
+    and should not be exported as kernel parameters."""
+
+
 @dataclass(frozen=True)
-class BufferBasePtr(UniqueSymbolProperty):
+class BufferBasePtr(UniqueSymbolProperty, BackendPrivateProperty):
     """Symbol acts as a base pointer to a buffer."""
 
     buffer: PsBuffer
@@ -120,12 +125,12 @@ class PsBuffer:
         strides: Sequence[PsSymbol | PsConstant],
     ):
         bptr_type = base_ptr.get_dtype()
-        
+
         if not isinstance(bptr_type, PsPointerType):
             raise ValueError(
                 f"Type of buffer base pointer {base_ptr} was not a pointer type: {bptr_type}"
             )
-        
+
         if bptr_type.base_type != element_type:
             raise ValueError(
                 f"Base type of primary buffer base pointer {base_ptr} "
diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py
index 4f78d344de861261ef89db6e10b51a93026a3dcd..cff6f935f2f89e09e2d0c8323e3d245c2b091c23 100644
--- a/src/pystencils/backend/platforms/cuda.py
+++ b/src/pystencils/backend/platforms/cuda.py
@@ -46,7 +46,7 @@ GRID_DIM = [
 ]
 
 
-class DenseThreadIdxMapping(ABC):
+class ThreadToIndexMapping(ABC):
 
     @abstractmethod
     def __call__(self, ispace: FullIterationSpace) -> dict[PsSymbol, PsExpression]:
@@ -57,7 +57,7 @@ class DenseThreadIdxMapping(ABC):
         """
 
 
-class Linear3DMapping(DenseThreadIdxMapping):
+class Linear3DMapping(ThreadToIndexMapping):
     """3D globally linearized mapping, where each thread is assigned a work item according to
     its location in the global launch grid."""
 
@@ -86,7 +86,7 @@ class Linear3DMapping(DenseThreadIdxMapping):
         return block_idx * block_size + thread_idx
 
 
-class Blockwise4DMapping(DenseThreadIdxMapping):
+class Blockwise4DMapping(ThreadToIndexMapping):
     """Blockwise index mapping for up to 4D iteration spaces, where the outer three dimensions
     are mapped to block indices."""
 
@@ -122,12 +122,12 @@ class CudaPlatform(GenericGpu):
         self,
         ctx: KernelCreationContext,
         omit_range_check: bool = False,
-        dense_idx_mapping: DenseThreadIdxMapping | None = None,
+        thread_mapping: ThreadToIndexMapping | None = None,
     ) -> None:
         super().__init__(ctx)
 
         self._omit_range_check = omit_range_check
-        self._dense_idx_mapping = dense_idx_mapping
+        self._thread_mapping = thread_mapping
 
         self._typify = Typifier(ctx)
 
@@ -227,8 +227,8 @@ class CudaPlatform(GenericGpu):
         #     threads_range = None
 
         idx_mapper = (
-            self._dense_idx_mapping
-            if self._dense_idx_mapping is not None
+            self._thread_mapping
+            if self._thread_mapping is not None
             else Linear3DMapping()
         )
         ctr_mapping = idx_mapper(ispace)
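With the rename, callers select the thread-to-index mapping explicitly; when `thread_mapping` is `None`, the platform still falls back to `Linear3DMapping`. A hedged usage sketch:

```python
from pystencils.backend.kernelcreation import KernelCreationContext
from pystencils.backend.platforms.cuda import CudaPlatform, Blockwise4DMapping

ctx = KernelCreationContext()

# Request the blockwise 4D scheme instead of the Linear3D default
platform = CudaPlatform(
    ctx,
    omit_range_check=False,
    thread_mapping=Blockwise4DMapping(),
)
```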
diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py
index f66f9aa0ee5ea350ba4304c03ced1ceae346282f..7491ec8e96a57c7d53637f3bb0db990147e9c127 100644
--- a/src/pystencils/backend/platforms/generic_gpu.py
+++ b/src/pystencils/backend/platforms/generic_gpu.py
@@ -1,19 +1,9 @@
 from __future__ import annotations
-from typing import TYPE_CHECKING
 from abc import abstractmethod
 
-from ..ast.expressions import PsExpression
 from ..ast.structural import PsBlock
-from ..kernelcreation.iteration_space import (
-    IterationSpace,
-    FullIterationSpace,
-    SparseIterationSpace,
-)
+from ..kernelcreation.iteration_space import IterationSpace
 from .platform import Platform
-from ..exceptions import MaterializationError
-
-if TYPE_CHECKING:
-    from ...codegen.kernel import GpuThreadsRange
 
 
 class GenericGpu(Platform):
@@ -22,40 +12,3 @@ class GenericGpu(Platform):
         self, body: PsBlock, ispace: IterationSpace
     ) -> PsBlock:
         pass
-
-    @classmethod
-    def threads_from_ispace(cls, ispace: IterationSpace) -> GpuThreadsRange:
-        from ...codegen.kernel import GpuThreadsRange
-
-        if isinstance(ispace, FullIterationSpace):
-            return cls._threads_from_full_ispace(ispace)
-        elif isinstance(ispace, SparseIterationSpace):
-            work_items = (PsExpression.make(ispace.index_list.shape[0]),)
-            return GpuThreadsRange(work_items)
-        else:
-            assert False
-
-    @classmethod
-    def _threads_from_full_ispace(cls, ispace: FullIterationSpace) -> GpuThreadsRange:
-        from ...codegen.kernel import GpuThreadsRange
-
-        dimensions = ispace.dimensions_in_loop_order()[::-1]
-        if len(dimensions) > 3:
-            raise NotImplementedError(
-                f"Cannot create a GPU threads range for an {len(dimensions)}-dimensional iteration space"
-            )
-
-        from ..ast.analysis import collect_undefined_symbols as collect
-
-        for dim in dimensions:
-            symbs = collect(dim.start) | collect(dim.stop) | collect(dim.step)
-            for ctr in ispace.counters:
-                if ctr in symbs:
-                    raise MaterializationError(
-                        "Unable to construct GPU threads range for iteration space: "
-                        f"Limits of dimension counter {dim.counter.name} "
-                        f"depend on another dimension's counter {ctr.name}"
-                    )
-
-        work_items = [ispace.actual_iterations(dim) for dim in dimensions]
-        return GpuThreadsRange(work_items)
diff --git a/src/pystencils/codegen/__init__.py b/src/pystencils/codegen/__init__.py
index e13f911dd9c2f8dd1a7b264a79dcdcd51cbef003..fc1b70ca0a4dafe4ed099024cd2feb05bf61746b 100644
--- a/src/pystencils/codegen/__init__.py
+++ b/src/pystencils/codegen/__init__.py
@@ -4,7 +4,7 @@ from .config import (
     AUTO,
 )
 from .parameters import Parameter
-from .kernel import Kernel, GpuKernel, GpuThreadsRange
+from .kernel import Kernel, GpuKernel
 from .driver import create_kernel, get_driver
 
 __all__ = [
@@ -14,7 +14,6 @@ __all__ = [
     "Parameter",
     "Kernel",
     "GpuKernel",
-    "GpuThreadsRange",
     "create_kernel",
     "get_driver",
 ]
diff --git a/src/pystencils/codegen/config.py b/src/pystencils/codegen/config.py
index 96ab13ea0f5a8c8715952280b939bca32e4b884d..53a2718524fa4d78150f6626963722a4eaf718c0 100644
--- a/src/pystencils/codegen/config.py
+++ b/src/pystencils/codegen/config.py
@@ -86,7 +86,9 @@ class Option(Generic[Option_T, Arg_T]):
         self._name = name
         self._lookup = f"_{name}"
 
-    def __get__(self, obj: ConfigBase, objtype: type[ConfigBase] | None = None) -> Option_T | None:
+    def __get__(
+        self, obj: ConfigBase, objtype: type[ConfigBase] | None = None
+    ) -> Option_T | None:
         if obj is None:
             return None
 
@@ -194,7 +196,9 @@ class Category(Generic[Category_T]):
         self._name = name
         self._lookup = f"_{name}"
 
-    def __get__(self, obj: ConfigBase, objtype: type[ConfigBase] | None = None) -> Category_T:
+    def __get__(
+        self, obj: ConfigBase, objtype: type[ConfigBase] | None = None
+    ) -> Category_T:
         if obj is None:
             return None
 
@@ -365,6 +369,9 @@ class GpuIndexingScheme(Enum):
 class GpuOptions(ConfigBase):
     """Configuration options specific to GPU targets."""
 
+    indexing_scheme: Option[GpuIndexingScheme, str] = Option(GpuIndexingScheme.Linear3D)
+    """Thread indexing scheme for dense GPU kernels."""
+
     omit_range_check: BasicOption[bool] = BasicOption(False)
     """If set to `True`, omit the iteration counter range check.
     
@@ -384,6 +391,31 @@ class GpuOptions(ConfigBase):
     The launch grid will then have to be specified manually at runtime.
     """
 
+    @indexing_scheme.validate
+    def _validate_idx_scheme(self, val: str | GpuIndexingScheme):
+        if isinstance(val, GpuIndexingScheme):
+            return val
+
+        match val.lower():
+            case "block":
+                warn(
+                    "GPU indexing scheme name `block` is deprecated and will be removed in pystencils 2.1. "
+                    "Use `Linear3D` instead."
+                )
+                return GpuIndexingScheme.Linear3D
+            case "line":
+                warn(
+                    "GPU indexing scheme name `line` is deprecated and will be removed in pystencils 2.1. "
+                    "Use `Blockwise4D` instead."
+                )
+                return GpuIndexingScheme.Blockwise4D
+            case "linear3d":
+                return GpuIndexingScheme.Linear3D
+            case "blockwise4d":
+                return GpuIndexingScheme.Blockwise4D
+            case _:
+                raise ValueError(f"Invalid GPU indexing scheme: {val}")
+
 
 @dataclass
 class SyclOptions(ConfigBase):
@@ -536,6 +568,9 @@ class CreateKernelConfig(ConfigBase):
     cpu_vectorize_info: InitVar[dict | None] = None
     """Deprecated; use `cpu.vectorize <CpuOptions.vectorize>` instead."""
 
+    gpu_indexing: InitVar[str | None] = None
+    """Deprecated; use `gpu.indexing_scheme` instead."""
+
     gpu_indexing_params: InitVar[dict | None] = None
     """Deprecated; set options in the `gpu` category instead."""
 
@@ -594,6 +629,7 @@ class CreateKernelConfig(ConfigBase):
         data_type: UserTypeSpec | None,
         cpu_openmp: bool | int | None,
         cpu_vectorize_info: dict | None,
+        gpu_indexing: str | None,
         gpu_indexing_params: dict | None,
     ):  # pragma: no cover
         if data_type is not None:
@@ -623,9 +659,7 @@ class CreateKernelConfig(ConfigBase):
                     deprecated_omp.enable = True
                     deprecated_omp.num_threads = cpu_openmp
                 case _:
-                    raise ValueError(
-                        f"Invalid option for `cpu_openmp`: {cpu_openmp}"
-                    )
+                    raise ValueError(f"Invalid option for `cpu_openmp`: {cpu_openmp}")
 
             self.cpu.openmp = deprecated_omp
 
@@ -682,11 +716,20 @@ class CreateKernelConfig(ConfigBase):
 
             self.cpu.vectorize = deprecated_vec_opts
 
+        if gpu_indexing is not None:
+            _deprecated_option("gpu_indexing", "gpu.indexing_scheme")
+            warn(
+                "Setting the deprecated `gpu_indexing` will override the `gpu.indexing_scheme` option",
+                UserWarning,
+            )
+            self.gpu.indexing_scheme = gpu_indexing
+
         if gpu_indexing_params is not None:
-            _deprecated_option("gpu_indexing_params", "gpu_indexing")
+            _deprecated_option("gpu_indexing_params", "gpu")
             warn(
                 "Setting the deprecated `gpu_indexing_params` will override any options "
-                "passed in the `gpu` category."
+                "passed in the `gpu` category.",
+                UserWarning,
             )
 
             self.gpu = GpuOptions(
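A hedged sketch of the new option in use; string values are normalized by the validator added above, and the deprecated spellings warn before mapping to their replacements:

```python
from pystencils.codegen.config import CreateKernelConfig, GpuIndexingScheme

cfg = CreateKernelConfig()
cfg.gpu.indexing_scheme = "linear3d"   # normalized by the validator
assert cfg.gpu.indexing_scheme == GpuIndexingScheme.Linear3D

# Deprecated alias: warns and maps to Linear3D
cfg.gpu.indexing_scheme = "block"
```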
diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py
index 6f44e718d9aaebd393dacb9a99ae88702f07ccaa..152fceba832e5be49df9405a5d32989be315dc4e 100644
--- a/src/pystencils/codegen/driver.py
+++ b/src/pystencils/codegen/driver.py
@@ -10,10 +10,12 @@ from .config import (
     _AUTO_TYPE,
     GhostLayerSpec,
     IterationSliceSpec,
+    GpuIndexingScheme,
 )
-from .kernel import Kernel, GpuKernel, GpuThreadsRange
-from .properties import PsSymbolProperty, FieldShape, FieldStride, FieldBasePtr
+from .kernel import Kernel, GpuKernel
+from .properties import PsSymbolProperty, FieldBasePtr
 from .parameters import Parameter
+from .gpu_indexing import GpuIndexing, GpuLaunchGridConstraints
 
 from ..field import Field
 from ..types import PsIntegerType, PsScalarType
@@ -145,6 +147,7 @@ class DefaultKernelCreationDriver:
         )
 
         self._target = cfg.get_target()
+        self._gpu_indexing: GpuIndexing | None = self._get_gpu_indexing()
         self._platform = self._get_platform()
 
         self._intermediates: CodegenIntermediates | None
@@ -169,9 +172,11 @@ class DefaultKernelCreationDriver:
                     kernel_body, self._ctx.get_iteration_space()
                 )
             case GenericGpu():
-                kernel_ast, gpu_threads = self._platform.materialize_iteration_space(
+                kernel_ast = self._platform.materialize_iteration_space(
                     kernel_body, self._ctx.get_iteration_space()
                 )
+            case _:
+                assert False, "unexpected platform"
 
         if self._intermediates is not None:
             self._intermediates.materialized_ispace = kernel_ast.clone()
@@ -219,7 +224,7 @@ class DefaultKernelCreationDriver:
                 self._ctx,
                 self._platform,
                 kernel_ast,
-                gpu_threads,
+                self._gpu_indexing,
                 self._cfg.get_option("function_name"),
                 self._target,
                 self._cfg.get_jit(),
@@ -395,6 +400,20 @@ class DefaultKernelCreationDriver:
 
         return kernel_ast
 
+    def _get_gpu_indexing(self) -> GpuIndexing | None:
+        if self._target != Target.CUDA:
+            return None
+
+        idx_scheme = self._cfg.gpu.get_option("indexing_scheme")
+
+        match idx_scheme:
+            case None | GpuIndexingScheme.Linear3D:
+                from .gpu_indexing import Linear3DGpuIndexing
+
+                return Linear3DGpuIndexing(self._ctx)
+            case _:
+                raise NotImplementedError(
+                    f"GPU indexing scheme not yet supported: {idx_scheme}"
+                )
+
     def _get_platform(self) -> Platform:
         if Target._CPU in self._target:
             if Target._X86 in self._target:
@@ -430,7 +449,9 @@ class DefaultKernelCreationDriver:
                 case Target.SYCL:
                     from ..backend.platforms import SyclPlatform
 
-                    auto_block_size: bool = self._cfg.sycl.get_option("automatic_block_size")
+                    auto_block_size: bool = self._cfg.sycl.get_option(
+                        "automatic_block_size"
+                    )
 
                     return SyclPlatform(
                         self._ctx,
@@ -440,12 +461,16 @@ class DefaultKernelCreationDriver:
                 case Target.CUDA:
                     from ..backend.platforms import CudaPlatform
 
-                    manual_grid = gpu_opts.get_option("manual_launch_grid")
+                    thread_mapping = (
+                        self._gpu_indexing.get_thread_mapping()
+                        if self._gpu_indexing is not None
+                        else None
+                    )
 
                     return CudaPlatform(
                         self._ctx,
                         omit_range_check=omit_range_check,
-                        manual_launch_grid=manual_grid,
+                        thread_mapping=thread_mapping,
                     )
 
         raise NotImplementedError(
@@ -475,23 +500,25 @@ def create_gpu_kernel_function(
     ctx: KernelCreationContext,
     platform: Platform,
     body: PsBlock,
-    threads_range: GpuThreadsRange | None,
+    indexing: GpuIndexing | None,
     function_name: str,
     target_spec: Target,
     jit: JitBase,
 ) -> GpuKernel:
     undef_symbols = collect_undefined_symbols(body)
 
-    if threads_range is not None:
-        for threads in threads_range.num_work_items:
-            undef_symbols |= collect_undefined_symbols(threads)
+    launch_grid_constraints = (
+        indexing.get_launch_grid_constraints()
+        if indexing is not None
+        else GpuLaunchGridConstraints()
+    )
 
     params = _get_function_params(ctx, undef_symbols)
     req_headers = _get_headers(ctx, platform, body)
 
     kfunc = GpuKernel(
         body,
-        threads_range,
+        launch_grid_constraints,
         target_spec,
         function_name,
         params,
@@ -507,17 +534,19 @@ def _get_function_params(
 ) -> list[Parameter]:
     params: list[Parameter] = []
 
-    from pystencils.backend.memory import BufferBasePtr
+    from pystencils.backend.memory import BufferBasePtr, BackendPrivateProperty
 
     for symb in symbols:
         props: set[PsSymbolProperty] = set()
         for prop in symb.properties:
             match prop:
-                case FieldShape() | FieldStride():
-                    props.add(prop)
                 case BufferBasePtr(buf):
                     field = ctx.find_field(buf.name)
                     props.add(FieldBasePtr(field))
+                case BackendPrivateProperty():
+                    pass
+                case _:
+                    props.add(prop)
         params.append(Parameter(symb.name, symb.get_dtype(), props))
 
     params.sort(key=lambda p: p.name)
diff --git a/src/pystencils/codegen/errors.py b/src/pystencils/codegen/errors.py
new file mode 100644
index 0000000000000000000000000000000000000000..eceb53f611d92a2f3e8c5c4c9105d5fc8e4aa507
--- /dev/null
+++ b/src/pystencils/codegen/errors.py
@@ -0,0 +1,2 @@
+class CodegenError(Exception):
+    """Exception that indicates a fatal error in the code generation driver."""
diff --git a/src/pystencils/codegen/lambdas.py b/src/pystencils/codegen/functions.py
similarity index 77%
rename from src/pystencils/codegen/lambdas.py
rename to src/pystencils/codegen/functions.py
index dd0fb571df51fc606ccf77c35e177c830b3e9fec..2779fa289e04cda9bc47fd46e48ff0ada9a98ad1 100644
--- a/src/pystencils/codegen/lambdas.py
+++ b/src/pystencils/codegen/functions.py
@@ -6,17 +6,26 @@ import numpy as np
 from .parameters import Parameter
 from ..types import PsType
 
+from ..backend.kernelcreation import KernelCreationContext
 from ..backend.ast.expressions import PsExpression
 
 
 class Lambda:
     """A one-line function emitted by the code generator as an auxiliary object."""
 
+    @staticmethod
+    def from_expression(ctx: KernelCreationContext, expr: PsExpression):
+        from ..backend.ast.analysis import collect_undefined_symbols
+        from .driver import _get_function_params
+
+        params = _get_function_params(ctx, collect_undefined_symbols(expr))
+        return Lambda(expr, params)
+
     def __init__(self, expr: PsExpression, params: Sequence[Parameter]):
         self._expr = expr
         self._params = tuple(params)
         self._return_type = expr.get_dtype()
-    
+
     @property
     def parameters(self) -> tuple[Parameter, ...]:
         """Parameters to this lambda"""
@@ -29,10 +38,11 @@ class Lambda:
 
     def __call__(self, **kwargs) -> np.generic:
         """Evaluate this lambda with the given arguments.
-        
+
         The lambda must receive a value for each parameter listed in `parameters`.
         """
         from ..backend.ast.expressions import evaluate_expression
+
         return evaluate_expression(self._expr, kwargs)
 
     def __str__(self) -> str:
@@ -41,5 +51,6 @@ class Lambda:
     def c_code(self) -> str:
         """Print the C code of this lambda"""
         from ..backend.emission import CAstPrinter
+
         printer = CAstPrinter()
         return printer(self._expr)
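A small usage sketch for the renamed module, assuming (as the patch itself does) that `AstFactory.parse_index` accepts plain Python integers:

```python
from pystencils.backend.kernelcreation import KernelCreationContext, AstFactory
from pystencils.codegen.functions import Lambda

ctx = KernelCreationContext()
factory = AstFactory(ctx)

lam = Lambda.from_expression(ctx, factory.parse_index(42))

assert not lam.parameters   # a constant has no undefined symbols
print(lam())                # evaluates the expression -> 42
print(lam.c_code())         # C rendering of the expression
```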
diff --git a/src/pystencils/codegen/gpu_indexing.py b/src/pystencils/codegen/gpu_indexing.py
index 1d9bf9c2c10ea36297b249c264a9f6a50267a614..2b84ef00751ef12bc07f5642e71f7a9805e6184a 100644
--- a/src/pystencils/codegen/gpu_indexing.py
+++ b/src/pystencils/codegen/gpu_indexing.py
@@ -1,9 +1,21 @@
 from __future__ import annotations
 
+from abc import ABC, abstractmethod
+from typing import cast
 from itertools import chain
 
-from .lambdas import Lambda
+from .functions import Lambda
 from .parameters import Parameter
+from .properties import GpuBlockSize
+from .errors import CodegenError
+
+from ..backend.kernelcreation import (
+    KernelCreationContext,
+    FullIterationSpace,
+    SparseIterationSpace,
+)
+from ..backend.platforms.cuda import ThreadToIndexMapping
+from ..backend.ast.expressions import PsExpression
 
 
 _ConstraintTriple = tuple[Lambda | None, Lambda | None, Lambda | None]
@@ -11,7 +23,7 @@ _ConstraintTriple = tuple[Lambda | None, Lambda | None, Lambda | None]
 
 class GpuLaunchGridConstraints:
     """Constraints on the number of threads and blocks on the GPU launch grid for a given kernel.
-    
+
-    This constraints set determines all or some of
-    the number of threads on a GPU block as well as the number of blocks on the GPU grid,
-    statically or depending on runtime parameters.
+    This set of constraints determines all or some of
+    the number of threads per GPU block, as well as the number of blocks on the GPU grid,
+    either statically or in terms of runtime parameters.
@@ -49,3 +61,93 @@ class GpuLaunchGridConstraints:
     def grid_size(self) -> _ConstraintTriple:
         """Constraints on the number of blocks on the grid"""
         return self._grid_size
+
+
+class GpuIndexing(ABC):
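+    """Base class for GPU indexing schemes.
+
+    Provides the thread-to-index mapping used by the GPU platform,
+    and the launch grid constraints exported with the generated kernel."""
+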
+    @abstractmethod
+    def get_thread_mapping(self) -> ThreadToIndexMapping | None: ...
+
+    @abstractmethod
+    def get_launch_grid_constraints(self) -> GpuLaunchGridConstraints: ...
+
+
+class Linear3DGpuIndexing(GpuIndexing):
+
+    def __init__(self, ctx: KernelCreationContext) -> None:
+        self._ctx = ctx
+
+        from ..backend.kernelcreation import AstFactory
+
+        self._factory = AstFactory(self._ctx)
+
+    def get_thread_mapping(self) -> ThreadToIndexMapping:
+        from ..backend.platforms.cuda import Linear3DMapping
+
+        return Linear3DMapping()
+
+    def get_launch_grid_constraints(self) -> GpuLaunchGridConstraints:
+        work_items = self._get_work_items()
+        rank = len(work_items)
+
+        from ..backend.constants import PsConstant
+        from ..backend.ast.expressions import PsExpression, PsIntDiv
+
+        block_size_constraints = [None] * rank + [
+            Lambda(self._factory.parse_index(1), ()) for _ in range(3 - rank)
+        ]
+
+        block_size_symbols = [
+            self._ctx.get_new_symbol(f"gpuBlockSize_{c}") for c in range(rank)
+        ]
+        for c, bs in enumerate(block_size_symbols):
+            bs.add_property(GpuBlockSize(c))
+
+        def div_ceil(a: PsExpression, b: PsExpression):
+            return self._factory.parse_index(
+                PsIntDiv(a + b - PsExpression.make(PsConstant(1)), b)
+            )
+
+        grid_size_constraints = [
+            Lambda.from_expression(
+                self._ctx, div_ceil(witems, PsExpression.make(bsize))
+            )
+            for witems, bsize in zip(work_items, block_size_symbols)
+        ] + [
+            Lambda.from_expression(self._ctx, self._factory.parse_index(1))
+            for _ in range(3 - rank)
+        ]
+
+        return GpuLaunchGridConstraints(
+            block_size=cast(_ConstraintTriple, tuple(block_size_constraints)),
+            grid_size=cast(_ConstraintTriple, tuple(grid_size_constraints)),
+        )
+
+    def _get_work_items(self) -> tuple[PsExpression, ...]:
+        ispace = self._ctx.get_iteration_space()
+        match ispace:
+            case FullIterationSpace():
+                dimensions = ispace.dimensions_in_loop_order()[::-1]
+                if len(dimensions) > 3:
+                    raise NotImplementedError(
+                        f"Cannot create a GPU threads range for an {len(dimensions)}-dimensional iteration space"
+                    )
+
+                from ..backend.ast.analysis import collect_undefined_symbols as collect
+
+                for i, dim in enumerate(dimensions):
+                    symbs = collect(dim.start) | collect(dim.stop) | collect(dim.step)
+                    for ctr in ispace.counters:
+                        if ctr in symbs:
+                            raise CodegenError(
+                                "Unable to construct GPU launch grid constraints for this kernel: "
+                                f"Limits in dimension {i} "
+                                f"depend on another dimension's counter {ctr.name}"
+                            )
+
+                return tuple(ispace.actual_iterations(dim) for dim in dimensions)
+
+            case SparseIterationSpace():
+                return (self._factory.parse_index(ispace.index_list.shape[0]),)
+
+            case _:
+                assert False, "unexpected iteration space"
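For reference, the grid-size constraint built above is plain ceiling division, `ceil(work_items / block_size)`, expressed with integer operations only. In plain Python:

```python
def div_ceil(a: int, b: int) -> int:
    # Same formula as the PsIntDiv construction above
    return (a + b - 1) // b

assert div_ceil(1000, 128) == 8   # 1000 work items at 128 threads/block
assert div_ceil(1024, 128) == 8
```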
diff --git a/src/pystencils/codegen/kernel.py b/src/pystencils/codegen/kernel.py
index 3adc47876dc36af02ee307dde25ad5d7250cd3fb..8038f24b017bf3b2a51cc0adb87959f8a9ea3c5b 100644
--- a/src/pystencils/codegen/kernel.py
+++ b/src/pystencils/codegen/kernel.py
@@ -6,8 +6,8 @@ from itertools import chain
 
 from .target import Target
 from .parameters import Parameter
+from .gpu_indexing import GpuLaunchGridConstraints
 from ..backend.ast.structural import PsBlock
-from ..backend.ast.expressions import PsExpression
 from ..field import Field
 
 from .._deprecation import _deprecated
@@ -118,7 +118,7 @@ class GpuKernel(Kernel):
     def __init__(
         self,
         body: PsBlock,
-        threads_range: GpuThreadsRange | None,
+        launch_grid_constraints: GpuLaunchGridConstraints,
         target: Target,
         name: str,
         parameters: Sequence[Parameter],
@@ -126,46 +126,9 @@ class GpuKernel(Kernel):
         jit: JitBase,
     ):
         super().__init__(body, target, name, parameters, required_headers, jit)
-        self._threads_range = threads_range
+        self._launch_grid_constraints = launch_grid_constraints
 
     @property
-    def threads_range(self) -> GpuThreadsRange | None:
+    def launch_grid_constraints(self) -> GpuLaunchGridConstraints:
         """Object exposing the total size of the launch grid this kernel expects to be executed with."""
-        return self._threads_range
-
-
-class GpuThreadsRange:
-    """Number of threads required by a GPU kernel, in order (x, y, z)."""
-
-    def __init__(
-        self,
-        num_work_items: Sequence[PsExpression],
-    ):
-        self._dim = len(num_work_items)
-        self._num_work_items = tuple(num_work_items)
-
-    # @property
-    # def grid_size(self) -> tuple[PsExpression, ...]:
-    #     return self._grid_size
-
-    # @property
-    # def block_size(self) -> tuple[PsExpression, ...]:
-    #     return self._block_size
-
-    @property
-    def num_work_items(self) -> tuple[PsExpression, ...]:
-        """Number of work items in (x, y, z)-order."""
-        return self._num_work_items
-
-    @property
-    def dim(self) -> int:
-        return self._dim
-
-    def __str__(self) -> str:
-        rep = "GpuThreadsRange { "
-        rep += "; ".join(f"{x}: {w}" for x, w in zip("xyz", self._num_work_items))
-        rep += " }"
-        return rep
-
-    def _repr_html_(self) -> str:
-        return str(self)
+        return self._launch_grid_constraints
diff --git a/src/pystencils/codegen/properties.py b/src/pystencils/codegen/properties.py
index d377fb3d35d99b59c4f364cc4d066b736bfd9140..df76489db175fb7fc576755a1008edb47f142493 100644
--- a/src/pystencils/codegen/properties.py
+++ b/src/pystencils/codegen/properties.py
@@ -39,3 +39,8 @@ class FieldBasePtr(UniqueSymbolProperty):
 
 FieldProperty = FieldShape | FieldStride | FieldBasePtr
 _FieldProperty = (FieldShape, FieldStride, FieldBasePtr)
+
+
+@dataclass(frozen=True)
+class GpuBlockSize(UniqueSymbolProperty):
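+    """Symbol acts as one coordinate of the GPU block size."""
+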
+    coordinate: int
diff --git a/src/pystencils/jit/gpu_cupy.py b/src/pystencils/jit/gpu_cupy.py
index a407bb75e08bfde9911070aef03b4a1769a6221a..afdbd5097dbd3b33957e1dc342d51ca7e03992c4 100644
--- a/src/pystencils/jit/gpu_cupy.py
+++ b/src/pystencils/jit/gpu_cupy.py
@@ -40,7 +40,7 @@ class CupyKernelWrapper(KernelWrapper):
         self._kfunc: GpuKernel = kfunc
         self._raw_kernel = raw_kernel
         self._block_size = block_size
-        self._num_blocks: tuple[int, int, int] | None = None
+        self._grid_size: tuple[int, int, int] | None = None
         self._args_cache: dict[Any, tuple] = dict()
 
     @property
@@ -61,11 +61,11 @@ class CupyKernelWrapper(KernelWrapper):
 
     @property
     def num_blocks(self) -> tuple[int, int, int] | None:
-        return self._num_blocks
+        return self._grid_size
 
     @num_blocks.setter
     def num_blocks(self, nb: tuple[int, int, int] | None):
-        self._num_blocks = nb
+        self._grid_size = nb
 
     def __call__(self, **kwargs: Any):
         kernel_args, launch_grid = self._get_cached_args(**kwargs)
@@ -80,7 +80,9 @@ class CupyKernelWrapper(KernelWrapper):
         return devices.pop()
 
     def _get_cached_args(self, **kwargs):
-        key = (self._block_size, self._num_blocks) + tuple((k, id(v)) for k, v in kwargs.items())
+        key = (self._block_size, self._grid_size) + tuple(
+            (k, id(v)) for k, v in kwargs.items()
+        )
 
         if key not in self._args_cache:
             args = self._get_args(**kwargs)
@@ -164,6 +166,7 @@ class CupyKernelWrapper(KernelWrapper):
                             elem_dtype: PsType
 
                             from .. import DynamicType
+
                             if isinstance(field.dtype, DynamicType):
                                 assert isinstance(kparam.dtype, PsPointerType)
                                 elem_dtype = kparam.dtype.base_type
@@ -199,42 +202,48 @@ class CupyKernelWrapper(KernelWrapper):
                 add_arg(kparam.name, val, kparam.dtype)
 
         #   Determine launch grid
-        from ..backend.ast.expressions import evaluate_expression
-
-        symbolic_threads_range = self._kfunc.threads_range
 
-        if self._num_blocks is not None:
-            launch_grid = LaunchGrid(self._num_blocks, self._block_size)
+        from ..codegen.properties import GpuBlockSize
 
-        elif symbolic_threads_range is not None:
-            threads_range: list[int] = [
-                evaluate_expression(expr, valuation)
-                for expr in symbolic_threads_range.num_work_items
-            ]
+        constraints = self._kfunc.launch_grid_constraints
 
-            if symbolic_threads_range.dim < 3:
-                threads_range += [1] * (3 - symbolic_threads_range.dim)
-
-            def div_ceil(a, b):
-                return a // b if a % b == 0 else a // b + 1
-
-            #   TODO: Refine this?
-            num_blocks = tuple(
-                div_ceil(threads, tpb)
-                for threads, tpb in zip(threads_range, self._block_size)
+        for cparam in constraints.parameters:
+            for prop in cparam.properties:
+                match prop:
+                    case GpuBlockSize(coord):
+                        valuation[cparam.name] = self._block_size[coord]
+                        break
+            else:
+                valuation[cparam.name] = kwargs[cparam.name]
+
+
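+        #   For each coordinate, evaluate the block-size constraint if one
+        #   exists; otherwise keep the user-specified block size.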
+        launch_block_size = [
+            (
+                int(bsize_constr(**valuation))
+                if bsize_constr is not None
+                else self._block_size[coord]
             )
-            assert len(num_blocks) == 3
-
-            launch_grid = LaunchGrid(num_blocks, self._block_size)
-
-        else:
-            raise JitError(
-                "Unable to determine launch grid for GPU kernel invocation: "
-                "No manual grid size was specified, and the number of threads could not "
-                "be determined automatically."
+            for coord, bsize_constr in enumerate(constraints.block_size)
+        ]
+
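+        #   Likewise for the grid size; a None constraint means the grid size
+        #   must have been set manually via `num_blocks`.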
+        launch_grid_size = [
+            (
+                int(gsize_constr(**valuation))
+                if gsize_constr is not None
+                else self._grid_size[coord]
             )
+            for coord, gsize_constr in enumerate(constraints.grid_size)
+        ]
 
-        return tuple(args), launch_grid
+        return tuple(args), LaunchGrid(
+            tuple(launch_grid_size), tuple(launch_block_size)
+        )
 
 
 class CupyJit(JitBase):
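A worked example of the new resolution logic, with illustrative numbers: under the Linear3D constraints, a kernel over 1000 × 60 × 1 work items with a user-set block size of (128, 1, 1) launches an (8, 60, 1) grid.

```python
def div_ceil(a, b):
    return (a + b - 1) // b

block_size = (128, 1, 1)     # as held in CupyKernelWrapper._block_size
work_items = (1000, 60, 1)   # hypothetical iteration space extents

grid_size = tuple(div_ceil(w, b) for w, b in zip(work_items, block_size))
assert grid_size == (8, 60, 1)   # blocks in (x, y, z) order
```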
diff --git a/tests/nbackend/kernelcreation/platform/test_gpu_platforms.py b/tests/nbackend/kernelcreation/platform/test_gpu_platforms.py
deleted file mode 100644
index da2b3a5ad3a0e224bc47a5dd0fa4f16b0ccde520..0000000000000000000000000000000000000000
--- a/tests/nbackend/kernelcreation/platform/test_gpu_platforms.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import pytest
-
-from pystencils.field import Field
-
-from pystencils.backend.kernelcreation import (
-    KernelCreationContext,
-    FullIterationSpace
-)
-
-from pystencils.backend.ast.structural import PsBlock, PsComment
-
-from pystencils.backend.platforms import CudaPlatform, SyclPlatform
-
-
-@pytest.mark.parametrize("layout", ["fzyx", "zyxf", "c", "f"])
-@pytest.mark.parametrize("platform_class", [CudaPlatform, SyclPlatform])
-def test_thread_range(platform_class, layout):
-    ctx = KernelCreationContext()
-
-    body = PsBlock([PsComment("Kernel body goes here")])
-    platform = platform_class(ctx)
-
-    dim = 3
-    archetype_field = Field.create_generic("field", spatial_dimensions=dim, layout=layout)
-    ispace = FullIterationSpace.create_with_ghost_layers(ctx, 1, archetype_field)
-
-    _, threads_range = platform.materialize_iteration_space(body, ispace)
-
-    assert threads_range.dim == dim
-    
-    match layout:
-        case "fzyx" | "zyxf" | "f":
-            indexing_order = [0, 1, 2]
-        case "c":
-            indexing_order = [2, 1, 0]
-
-    for i in range(dim):
-        #   Slowest to fastest coordinate
-        coordinate = indexing_order[i]
-        dimension = ispace.dimensions[coordinate]
-        witems = threads_range.num_work_items[i]
-        desired = dimension.stop - dimension.start
-        assert witems.structurally_equal(desired)