Skip to content
Snippets Groups Projects

Draft: feat: implement `__cuda_array_interface__`

Closed Stephan Seitz requested to merge seitz/pystencils:cuda-array-interface into master
Files
2
@@ -2,17 +2,19 @@ import numpy as np
@@ -2,17 +2,19 @@ import numpy as np
from pystencils.backends.cbackend import get_headers
from pystencils.backends.cbackend import get_headers
from pystencils.backends.cuda_backend import generate_cuda
from pystencils.backends.cuda_backend import generate_cuda
from pystencils.typing import StructType
from pystencils.field import FieldType
from pystencils.field import FieldType
from pystencils.include import get_pycuda_include_path, get_pystencils_include_path
from pystencils.include import get_pycuda_include_path, get_pystencils_include_path
from pystencils.kernel_wrapper import KernelWrapper
from pystencils.kernel_wrapper import KernelWrapper
 
from pystencils.typing import StructType
from pystencils.typing.typed_sympy import FieldPointerSymbol
from pystencils.typing.typed_sympy import FieldPointerSymbol
 
from pystencils.utils import DotDict
USE_FAST_MATH = True
USE_FAST_MATH = True
 
USE_PYCUDA = False
def get_cubic_interpolation_include_paths():
def get_cubic_interpolation_include_paths():
from os.path import join, dirname
from os.path import dirname, join
return [join(dirname(__file__), "CubicInterpolationCUDA", "code"),
return [join(dirname(__file__), "CubicInterpolationCUDA", "code"),
join(dirname(__file__), "CubicInterpolationCUDA", "code", "internal")]
join(dirname(__file__), "CubicInterpolationCUDA", "code", "internal")]
@@ -32,9 +34,6 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen
@@ -32,9 +34,6 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen
Returns:
Returns:
compiled kernel as Python function
compiled kernel as Python function
"""
"""
import pycuda.autoinit # NOQA
from pycuda.compiler import SourceModule
if argument_dict is None:
if argument_dict is None:
argument_dict = {}
argument_dict = {}
@@ -42,17 +41,25 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen
@@ -42,17 +41,25 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen
includes = "\n".join([f"#include {include_file}" for include_file in header_list])
includes = "\n".join([f"#include {include_file}" for include_file in header_list])
code = includes + "\n"
code = includes + "\n"
code += "#define FUNC_PREFIX __global__\n"
code += "#define FUNC_PREFIX extern \"C\" __global__\n"
code += "#define RESTRICT __restrict__\n\n"
code += "#define RESTRICT __restrict__\n\n"
code += str(generate_cuda(kernel_function_node, custom_backend=custom_backend))
code += str(generate_cuda(kernel_function_node, custom_backend=custom_backend))
nvcc_options = ["-w", "-std=c++11", "-Wno-deprecated-gpu-targets"]
nvcc_options = ["-w", "-std=c++11"]
if USE_FAST_MATH:
if USE_FAST_MATH:
nvcc_options.append("-use_fast_math")
nvcc_options.append("-use_fast_math")
mod = SourceModule(code, options=nvcc_options, include_dirs=[
if USE_PYCUDA:
get_pystencils_include_path(), get_pycuda_include_path()])
import pycuda.autoinit # NOQA
func = mod.get_function(kernel_function_node.function_name)
from pycuda.compiler import SourceModule
 
nvcc_options.append("-Wno-deprecated-gpu-targets")
 
mod = SourceModule(code, options=nvcc_options, include_dirs=[
 
get_pystencils_include_path(), get_pycuda_include_path()])
 
func = mod.get_function(kernel_function_node.function_name)
 
else:
 
import cupy
 
nvcc_options.append("-I" + get_pystencils_include_path())
 
func = cupy.RawKernel(code, kernel_function_node.function_name, options=tuple(nvcc_options), jitify=True)
parameters = kernel_function_node.get_parameters()
parameters = kernel_function_node.get_parameters()
@@ -78,13 +85,19 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen
@@ -78,13 +85,19 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen
args = _build_numpy_argument_list(parameters, full_arguments)
args = _build_numpy_argument_list(parameters, full_arguments)
cache[key] = (args, block_and_thread_numbers)
cache[key] = (args, block_and_thread_numbers)
cache_values.append(kwargs) # keep objects alive such that ids remain unique
cache_values.append(kwargs) # keep objects alive such that ids remain unique
func(*args, **block_and_thread_numbers)
if USE_PYCUDA:
 
func(*args, **block_and_thread_numbers)
 
else:
 
func(block_and_thread_numbers['grid'],
 
block_and_thread_numbers['block'], [cupy.asarray(a) for a in args])
 
# import pycuda.driver as cuda
# import pycuda.driver as cuda
# cuda.Context.synchronize() # useful for debugging, to get errors right after kernel was called
# cuda.Context.synchronize() # useful for debugging, to get errors right after kernel was called
ast = kernel_function_node
ast = kernel_function_node
parameters = kernel_function_node.get_parameters()
parameters = kernel_function_node.get_parameters()
wrapper = KernelWrapper(wrapper, parameters, ast)
wrapper = KernelWrapper(wrapper, parameters, ast)
wrapper.num_regs = func.num_regs
if USE_PYCUDA:
 
wrapper.num_regs = func.num_regs
return wrapper
return wrapper
@@ -95,7 +108,15 @@ def _build_numpy_argument_list(parameters, argument_dict):
@@ -95,7 +108,15 @@ def _build_numpy_argument_list(parameters, argument_dict):
for param in parameters:
for param in parameters:
if param.is_field_pointer:
if param.is_field_pointer:
array = argument_dict[param.field_name]
array = argument_dict[param.field_name]
actual_type = array.dtype
if hasattr(array, "__cuda_array_interface__"):
 
interface = DotDict()
 
interface.shape = array.__cuda_array_interface__["shape"]
 
interface.dtype = np.dtype(array.__cuda_array_interface__["typestr"])
 
interface.strides = array.__cuda_array_interface__["strides"]
 
interface.data = array.__cuda_array_interface__["data"][0]
 
else:
 
interface = array
 
actual_type = interface.dtype
expected_type = param.fields[0].dtype.numpy_dtype
expected_type = param.fields[0].dtype.numpy_dtype
if expected_type != actual_type:
if expected_type != actual_type:
raise ValueError("Data type mismatch for field '%s'. Expected '%s' got '%s'." %
raise ValueError("Data type mismatch for field '%s'. Expected '%s' got '%s'." %
@@ -136,6 +157,14 @@ def _check_arguments(parameter_specification, argument_dict):
@@ -136,6 +157,14 @@ def _check_arguments(parameter_specification, argument_dict):
except KeyError:
except KeyError:
raise KeyError("Missing field parameter for kernel call " + str(symbolic_field))
raise KeyError("Missing field parameter for kernel call " + str(symbolic_field))
 
if hasattr(field_arr, "__cuda_array_interface__"):
 
field_interface = DotDict()
 
field_interface.shape = field_arr.__cuda_array_interface__["shape"]
 
field_interface.dtype = np.dtype(field_arr.__cuda_array_interface__["typestr"])
 
field_interface.strides = field_arr.__cuda_array_interface__["strides"]
 
field_interface.data = field_arr.__cuda_array_interface__["data"][0]
 
field_arr = field_interface
 
if symbolic_field.has_fixed_shape:
if symbolic_field.has_fixed_shape:
symbolic_field_shape = tuple(int(i) for i in symbolic_field.shape)
symbolic_field_shape = tuple(int(i) for i in symbolic_field.shape)
if isinstance(symbolic_field.dtype, StructType):
if isinstance(symbolic_field.dtype, StructType):
@@ -147,7 +176,7 @@ def _check_arguments(parameter_specification, argument_dict):
@@ -147,7 +176,7 @@ def _check_arguments(parameter_specification, argument_dict):
symbolic_field_strides = tuple(int(i) * field_arr.dtype.itemsize for i in symbolic_field.strides)
symbolic_field_strides = tuple(int(i) * field_arr.dtype.itemsize for i in symbolic_field.strides)
if isinstance(symbolic_field.dtype, StructType):
if isinstance(symbolic_field.dtype, StructType):
symbolic_field_strides = symbolic_field_strides[:-1]
symbolic_field_strides = symbolic_field_strides[:-1]
if symbolic_field_strides != field_arr.strides:
if field_arr.strides and symbolic_field_strides != field_arr.strides:
raise ValueError("Passed array '%s' has strides %s which does not match expected strides %s" %
raise ValueError("Passed array '%s' has strides %s which does not match expected strides %s" %
(symbolic_field.name, str(field_arr.strides), str(symbolic_field_strides)))
(symbolic_field.name, str(field_arr.strides), str(symbolic_field_strides)))
Loading