Skip to content
Snippets Groups Projects

Fundamental GPU Support

Merged Frederik Hennig requested to merge fhennig/sycl into backend-rework
Compare and Show latest version
2 files
+ 78
7
Preferences
Compare changes
Files
2
@@ -10,6 +10,7 @@ except ImportError:
from ...enums import Target
from ...kernel_wrapper import KernelWrapper
from ...field import FieldType
from ...types import PsType
from .jit import JitBase, JitError
@@ -42,6 +43,7 @@ class CupyKernelWrapper(KernelWrapper):
self._kfunc = kfunc
self._raw_kernel = raw_kernel
self._block_size = block_size
self._args_cache: dict[Any, tuple] = dict()
@property
def kernel_function(self) -> GpuKernelFunction:
@@ -52,7 +54,7 @@ class CupyKernelWrapper(KernelWrapper):
return self._raw_kernel
def __call__(self, **kwargs: Any):
    """Launch the wrapped kernel with the given keyword arguments.

    The kernel argument tuple and the launch grid are obtained from
    ``_get_cached_args``; the device to run on is determined from the
    kernel arguments by ``_get_device``.
    """
    # Only the cached lookup is needed: building the arguments through
    # `_get_args` first and then overwriting the result would do the
    # (uncached) work on every call for nothing.
    kernel_args, launch_grid = self._get_cached_args(**kwargs)
    device = self._get_device(kernel_args)
    with cp.cuda.Device(device):
        self._raw_kernel(launch_grid.grid, launch_grid.block, kernel_args)
@@ -63,6 +65,25 @@ class CupyKernelWrapper(KernelWrapper):
raise JitError("Could not determine CUDA device to execute on")
return devices.pop()
def _get_cached_args(self, **kwargs):
    """Return the result of ``_get_args(**kwargs)``, memoized per argument set.

    The cache key is built from the keyword arguments:

    - CuPy arrays contribute their name, device pointer, dtype, strides and
      shape. The device pointer is part of the key so that two distinct
      arrays with the same layout do not share a cache entry.
    - All other values contribute their name, type and the value itself.

    If any non-array value is unhashable, no reliable key can be built and
    the arguments are computed without touching the cache.

    NOTE(review): cached entries are never evicted, and the cached argument
    tuples appear to hold references to the arrays passed in, so those
    arrays stay alive as long as this wrapper does -- confirm this is
    acceptable for long-running callers.
    """
    key_parts = []
    for name, value in kwargs.items():
        if isinstance(value, cp.ndarray):
            key_parts.append(
                (name, value.data.ptr, value.dtype, value.strides, value.shape)
            )
            continue

        try:
            hash(value)
        except TypeError:
            # Unhashable argument: bypass the cache instead of keying on
            # id(), which may be reused by a different object later on.
            return self._get_args(**kwargs)
        # Include the type so that e.g. 1, 1.0 and True stay distinct.
        key_parts.append((name, type(value), value))

    # Use the tuple itself as the key (not its hash), so that a hash
    # collision can never return the arguments of a different call.
    key = tuple(key_parts)

    try:
        return self._args_cache[key]
    except KeyError:
        args = self._get_args(**kwargs)
        self._args_cache[key] = args
        return args
def _get_args(self, **kwargs) -> tuple[tuple, LaunchGrid]:
args = []
valuation: dict[str, Any] = dict()
@@ -75,6 +96,48 @@ class CupyKernelWrapper(KernelWrapper):
args.append(arg)
valuation[name] = arg
field_shapes = set()
index_shapes = set()
def check_shape(field_ptr: FieldPointerParam, arr: cp.ndarray):
field = field_ptr.field
if field.has_fixed_shape:
expected_shape = tuple(int(s) for s in field.shape)
actual_shape = arr.shape
if expected_shape != actual_shape:
raise ValueError(
f"Array kernel argument {field.name} had unexpected shape:\n"
f" Expected {expected_shape}, but got {actual_shape}"
)
expected_strides = tuple(int(s) for s in field.strides)
actual_strides = tuple(s // arr.dtype.itemsize for s in arr.strides)
if expected_strides != actual_strides:
raise ValueError(
f"Array kernel argument {field.name} had unexpected strides:\n"
f" Expected {expected_strides}, but got {actual_strides}"
)
match field.field_type:
case FieldType.GENERIC:
field_shapes.add(arr.shape)
if len(field_shapes) > 1:
raise ValueError(
"Incompatible array shapes:"
"All arrays passed for generic fields to a kernel must have the same shape."
)
case FieldType.INDEXED:
index_shapes.add(arr.shape)
if len(index_shapes) > 1:
raise ValueError(
"Incompatible array shapes:"
"All arrays passed for index fields to a kernel must have the same shape."
)
# Collect parameter values
# TODO: Check array sizes
arr: cp.ndarray
@@ -88,6 +151,7 @@ class CupyKernelWrapper(KernelWrapper):
f"Data type mismatch at array argument {field.name}:"
f"Expected {field.dtype}, got {arr.dtype}"
)
check_shape(kparam, arr)
args.append(arr)
case FieldShapeParam(name, dtype, field, coord):