Skip to content
Snippets Groups Projects
Commit 826ee8e2 authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Try supporting pointer dtypes for reductions in cupy gpu jit

parent 4c7fd409
No related branches found
No related tags found
1 merge request!438Reduction Support
......@@ -11,7 +11,7 @@ except ImportError:
from ..codegen import Target
from ..field import FieldType
from ..types import PsType
from ..types import PsType, PsPointerType
from .jit import JitBase, JitError, KernelWrapper
from ..codegen import (
Kernel,
......@@ -183,6 +183,9 @@ class CupyKernelWrapper(KernelWrapper):
kparam.dtype,
)
break
elif isinstance(kparam.dtype, PsPointerType):
val = kwargs[kparam.name]
args.append(val)
else:
# scalar parameter
val: Any = kwargs[kparam.name]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment