From 826ee8e26f2bb09b95bb473b347d06cd6a36207c Mon Sep 17 00:00:00 2001 From: Richard Angersbach <iwia025h@csnhr.nhr.fau.de> Date: Tue, 4 Feb 2025 16:30:13 +0100 Subject: [PATCH] Try supporting pointer dtypes for reductions in cupy gpu jit --- src/pystencils/jit/gpu_cupy.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/pystencils/jit/gpu_cupy.py b/src/pystencils/jit/gpu_cupy.py index c208ac219..467e86be7 100644 --- a/src/pystencils/jit/gpu_cupy.py +++ b/src/pystencils/jit/gpu_cupy.py @@ -11,7 +11,7 @@ except ImportError: from ..codegen import Target from ..field import FieldType -from ..types import PsType +from ..types import PsType, PsPointerType from .jit import JitBase, JitError, KernelWrapper from ..codegen import ( Kernel, @@ -183,6 +183,9 @@ class CupyKernelWrapper(KernelWrapper): kparam.dtype, ) break + elif isinstance(kparam.dtype, PsPointerType): + val = kwargs[kparam.name] + args.append(val) else: # scalar parameter val: Any = kwargs[kparam.name] -- GitLab