
Iteration Slices: Extended GPU support + bugfixes

Merged: Frederik Hennig requested to merge fhennig/gpu-iteration-spaces into v2.0-dev
All threads resolved. 23 files changed: +584 −191
@@ -41,6 +41,7 @@ class CupyKernelWrapper(KernelWrapper):
         self._kfunc: GpuKernelFunction = kfunc
         self._raw_kernel = raw_kernel
         self._block_size = block_size
+        self._num_blocks: tuple[int, int, int] | None = None
         self._args_cache: dict[Any, tuple] = dict()

     @property
@@ -59,6 +60,14 @@ class CupyKernelWrapper(KernelWrapper):
     def block_size(self, bs: tuple[int, int, int]):
         self._block_size = bs

+    @property
+    def num_blocks(self) -> tuple[int, int, int] | None:
+        return self._num_blocks
+
+    @num_blocks.setter
+    def num_blocks(self, nb: tuple[int, int, int] | None):
+        self._num_blocks = nb
+
     def __call__(self, **kwargs: Any):
         kernel_args, launch_grid = self._get_cached_args(**kwargs)
         device = self._get_device(kernel_args)
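Note: together with the `_num_blocks` attribute added above, these accessors let callers pin the launch grid explicitly instead of relying on the automatic derivation further below. A minimal usage sketch, assuming `kernel` is a CupyKernelWrapper returned by the cupy JIT; the argument names `f`, `g` are purely illustrative:

    kernel.block_size = (128, 1, 1)   # threads per block (existing property)
    kernel.num_blocks = (16, 8, 1)    # blocks per grid (new in this MR); None restores automatic derivation
    kernel(f=f_arr, g=g_arr)          # keyword names depend on the kernel's fields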
@@ -72,7 +81,7 @@ class CupyKernelWrapper(KernelWrapper):
         return devices.pop()

     def _get_cached_args(self, **kwargs):
-        key = (self._block_size,) + tuple((k, id(v)) for k, v in kwargs.items())
+        key = (self._block_size, self._num_blocks) + tuple((k, id(v)) for k, v in kwargs.items())

         if key not in self._args_cache:
             args = self._get_args(**kwargs)
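The extended cache key matters because the cached tuple also carries the launch grid: without `self._num_blocks` in the key, assigning a new grid size after the first call would silently reuse the previously computed launch configuration. The following standalone sketch mirrors the keying pattern; the class and method names are illustrative, not pystencils API:

    class LaunchConfigCache:
        """Toy version of the keying scheme used in _get_cached_args."""

        def __init__(self):
            self._block_size = (128, 1, 1)
            self._num_blocks = None
            self._cache = dict()

        def get(self, **kwargs):
            # Key on every knob that influences the result, plus the identity
            # of each argument object, mirroring the diff above.
            key = (self._block_size, self._num_blocks) + tuple(
                (k, id(v)) for k, v in kwargs.items()
            )
            if key not in self._cache:
                self._cache[key] = ("args computed for", key)
            return self._cache[key]

    cache = LaunchConfigCache()
    data = [1, 2, 3]
    first = cache.get(f=data)
    cache._num_blocks = (4, 4, 1)   # changing the grid size ...
    second = cache.get(f=data)      # ... produces a different key, so nothing stale is reused
    assert first != second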
@@ -185,25 +194,36 @@ class CupyKernelWrapper(KernelWrapper):
         symbolic_threads_range = self._kfunc.threads_range

-        threads_range: list[int] = [
-            evaluate_expression(expr, valuation)
-            for expr in symbolic_threads_range.num_work_items
-        ]
+        if self._num_blocks is not None:
+            launch_grid = LaunchGrid(self._num_blocks, self._block_size)

-        if symbolic_threads_range.dim < 3:
-            threads_range += [1] * (3 - symbolic_threads_range.dim)
+        elif symbolic_threads_range is not None:
+            threads_range: list[int] = [
+                evaluate_expression(expr, valuation)
+                for expr in symbolic_threads_range.num_work_items
+            ]

-        def div_ceil(a, b):
-            return a // b if a % b == 0 else a // b + 1
+            if symbolic_threads_range.dim < 3:
+                threads_range += [1] * (3 - symbolic_threads_range.dim)

-        #   TODO: Refine this?
-        grid_size = tuple(
-            div_ceil(threads, tpb)
-            for threads, tpb in zip(threads_range, self._block_size)
-        )
-        assert len(grid_size) == 3
+            def div_ceil(a, b):
+                return a // b if a % b == 0 else a // b + 1

-        launch_grid = LaunchGrid(grid_size, self._block_size)
+            #   TODO: Refine this?
+            num_blocks = tuple(
+                div_ceil(threads, tpb)
+                for threads, tpb in zip(threads_range, self._block_size)
+            )
+            assert len(num_blocks) == 3
+
+            launch_grid = LaunchGrid(num_blocks, self._block_size)
+
+        else:
+            raise JitError(
+                "Unable to determine launch grid for GPU kernel invocation: "
+                "No manual grid size was specified, and the number of threads could not "
+                "be determined automatically."
+            )

         return tuple(args), launch_grid
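For reference, the automatic fallback path boils down to: evaluate the symbolic thread counts, pad them to three dimensions, and divide by the block size with rounding up. A standalone sketch of that arithmetic, without any pystencils objects:

    def div_ceil(a: int, b: int) -> int:
        # Ceiling division on integers, as in the diff.
        return a // b if a % b == 0 else a // b + 1

    def derive_num_blocks(threads_range, block_size):
        # Pad the evaluated thread counts to 3D, then compute blocks per grid dimension.
        threads_range = list(threads_range) + [1] * (3 - len(threads_range))
        num_blocks = tuple(
            div_ceil(threads, tpb) for threads, tpb in zip(threads_range, block_size)
        )
        assert len(num_blocks) == 3
        return num_blocks

    # E.g. a 2D iteration space of 1000 x 1000 work items with 16 x 16 x 1 threads per block:
    print(derive_num_blocks([1000, 1000], (16, 16, 1)))   # -> (63, 63, 1)

When `num_blocks` is set manually, this computation is skipped entirely and the user-supplied grid is passed to LaunchGrid as-is; if neither a manual grid nor a thread count is available, the new JitError branch reports the failure instead of launching with an undefined grid.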