Skip to content
Snippets Groups Projects

Draft: feat: implement `__cuda_array_interface__`

Closed Stephan Seitz requested to merge seitz/pystencils:cuda-array-interface into master
Files
2
@@ -2,17 +2,19 @@ import numpy as np
from pystencils.backends.cbackend import get_headers
from pystencils.backends.cuda_backend import generate_cuda
from pystencils.typing import StructType
from pystencils.field import FieldType
from pystencils.include import get_pycuda_include_path, get_pystencils_include_path
from pystencils.kernel_wrapper import KernelWrapper
from pystencils.typing import StructType
from pystencils.typing.typed_sympy import FieldPointerSymbol
from pystencils.utils import DotDict
# When true, "-use_fast_math" is appended to the nvcc options used to compile kernels.
USE_FAST_MATH = True
# Selects the backend used to compile and launch kernels: pycuda's SourceModule
# when true, cupy.RawKernel otherwise.
USE_PYCUDA = False
def get_cubic_interpolation_include_paths():
    """Return the include directories of the CubicInterpolationCUDA sources.

    Both paths are built relative to the directory that contains this module.

    Returns:
        list of two paths: ``CubicInterpolationCUDA/code`` and its ``internal`` subdirectory
    """
    # Imported once; the previous text carried the same import twice.
    from os.path import dirname, join

    code_dir = join(dirname(__file__), "CubicInterpolationCUDA", "code")
    return [code_dir, join(code_dir, "internal")]
@@ -32,8 +34,9 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen
Returns:
compiled kernel as Python function
"""
import pycuda.autoinit # NOQA
from pycuda.compiler import SourceModule
if USE_PYCUDA:
Please register or sign in to reply
import pycuda.autoinit # NOQA
from pycuda.compiler import SourceModule
if argument_dict is None:
argument_dict = {}
@@ -42,17 +45,23 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen
includes = "\n".join([f"#include {include_file}" for include_file in header_list])
code = includes + "\n"
code += "#define FUNC_PREFIX __global__\n"
code += "#define FUNC_PREFIX extern \"C\" __global__\n"
code += "#define RESTRICT __restrict__\n\n"
code += str(generate_cuda(kernel_function_node, custom_backend=custom_backend))
nvcc_options = ["-w", "-std=c++11", "-Wno-deprecated-gpu-targets"]
nvcc_options = ["-w", "-std=c++11"]
if USE_FAST_MATH:
nvcc_options.append("-use_fast_math")
mod = SourceModule(code, options=nvcc_options, include_dirs=[
get_pystencils_include_path(), get_pycuda_include_path()])
func = mod.get_function(kernel_function_node.function_name)
if USE_PYCUDA:
nvcc_options.append("-Wno-deprecated-gpu-targets")
mod = SourceModule(code, options=nvcc_options, include_dirs=[
get_pystencils_include_path(), get_pycuda_include_path()])
func = mod.get_function(kernel_function_node.function_name)
else:
import cupy
nvcc_options.append("-I" + get_pystencils_include_path())
func = cupy.RawKernel(code, kernel_function_node.function_name, options=tuple(nvcc_options), jitify=True)
parameters = kernel_function_node.get_parameters()
@@ -78,13 +87,19 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen
args = _build_numpy_argument_list(parameters, full_arguments)
cache[key] = (args, block_and_thread_numbers)
cache_values.append(kwargs) # keep objects alive such that ids remain unique
func(*args, **block_and_thread_numbers)
if USE_PYCUDA:
func(*args, **block_and_thread_numbers)
else:
func(block_and_thread_numbers['grid'],
block_and_thread_numbers['block'], [cupy.asarray(a) for a in args])
# import pycuda.driver as cuda
# cuda.Context.synchronize() # useful for debugging, to get errors right after kernel was called
ast = kernel_function_node
parameters = kernel_function_node.get_parameters()
wrapper = KernelWrapper(wrapper, parameters, ast)
wrapper.num_regs = func.num_regs
if USE_PYCUDA:
wrapper.num_regs = func.num_regs
return wrapper
@@ -95,7 +110,15 @@ def _build_numpy_argument_list(parameters, argument_dict):
for param in parameters:
if param.is_field_pointer:
array = argument_dict[param.field_name]
actual_type = array.dtype
if hasattr(array, "__cuda_array_interface__"):
interface = DotDict()
interface.shape = array.__cuda_array_interface__["shape"]
interface.dtype = np.dtype(array.__cuda_array_interface__["typestr"])
interface.strides = array.__cuda_array_interface__["strides"]
interface.data = array.__cuda_array_interface__["data"][0]
else:
interface = array
actual_type = interface.dtype
expected_type = param.fields[0].dtype.numpy_dtype
if expected_type != actual_type:
raise ValueError("Data type mismatch for field '%s'. Expected '%s' got '%s'." %
@@ -136,6 +159,14 @@ def _check_arguments(parameter_specification, argument_dict):
except KeyError:
raise KeyError("Missing field parameter for kernel call " + str(symbolic_field))
field_interface = DotDict()
if hasattr(field_arr, "__cuda_array_interface__"):
field_interface.shape = field_arr.__cuda_array_interface__["shape"]
field_interface.dtype = np.dtype(field_arr.__cuda_array_interface__["typestr"])
field_interface.strides = field_arr.__cuda_array_interface__["strides"]
field_interface.data = field_arr.__cuda_array_interface__["data"][0]
field_arr = field_interface
if symbolic_field.has_fixed_shape:
symbolic_field_shape = tuple(int(i) for i in symbolic_field.shape)
if isinstance(symbolic_field.dtype, StructType):
@@ -147,7 +178,7 @@ def _check_arguments(parameter_specification, argument_dict):
symbolic_field_strides = tuple(int(i) * field_arr.dtype.itemsize for i in symbolic_field.strides)
if isinstance(symbolic_field.dtype, StructType):
symbolic_field_strides = symbolic_field_strides[:-1]
if symbolic_field_strides != field_arr.strides:
if field_arr.strides and symbolic_field_strides != field_arr.strides:
raise ValueError("Passed array '%s' has strides %s which does not match expected strides %s" %
(symbolic_field.name, str(field_arr.strides), str(symbolic_field_strides)))
Loading