Skip to content
Snippets Groups Projects
Commit 826ee8e2 authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Try supporting pointer dtypes for reductions in cupy gpu jit

parent 4c7fd409
No related branches found
No related tags found
1 merge request!438Reduction Support
...@@ -11,7 +11,7 @@ except ImportError: ...@@ -11,7 +11,7 @@ except ImportError:
from ..codegen import Target from ..codegen import Target
from ..field import FieldType from ..field import FieldType
from ..types import PsType from ..types import PsType, PsPointerType
from .jit import JitBase, JitError, KernelWrapper from .jit import JitBase, JitError, KernelWrapper
from ..codegen import ( from ..codegen import (
Kernel, Kernel,
...@@ -183,6 +183,9 @@ class CupyKernelWrapper(KernelWrapper): ...@@ -183,6 +183,9 @@ class CupyKernelWrapper(KernelWrapper):
kparam.dtype, kparam.dtype,
) )
break break
elif isinstance(kparam.dtype, PsPointerType):
val = kwargs[kparam.name]
args.append(val)
else: else:
# scalar parameter # scalar parameter
val: Any = kwargs[kparam.name] val: Any = kwargs[kparam.name]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment