Skip to content
Snippets Groups Projects
Commit 826ee8e2 authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Try supporting pointer dtypes for reductions in cupy gpu jit

parent 4c7fd409
1 merge request!438Reduction Support
......@@ -11,7 +11,7 @@ except ImportError:
from ..codegen import Target
from ..field import FieldType
from ..types import PsType
from ..types import PsType, PsPointerType
from .jit import JitBase, JitError, KernelWrapper
from ..codegen import (
Kernel,
......@@ -183,6 +183,9 @@ class CupyKernelWrapper(KernelWrapper):
kparam.dtype,
)
break
elif isinstance(kparam.dtype, PsPointerType):
val = kwargs[kparam.name]
args.append(val)
else:
# scalar parameter
val: Any = kwargs[kparam.name]
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment