Skip to content
Snippets Groups Projects
Commit 974febd7 authored by Martin Bauer's avatar Martin Bauer
Browse files

Merge branch 'macos' into 'master'

OpenCL macOS support

See merge request !87
parents 33648089 885fc9c7
No related branches found
No related tags found
1 merge request!87OpenCL macOS support
Pipeline #19846 passed
...@@ -73,7 +73,7 @@ class OpenClSympyPrinter(CudaSympyPrinter): ...@@ -73,7 +73,7 @@ class OpenClSympyPrinter(CudaSympyPrinter):
function_name, dimension = tuple(symbol_name.split(".")) function_name, dimension = tuple(symbol_name.split("."))
dimension = self.DIMENSION_MAPPING[dimension] dimension = self.DIMENSION_MAPPING[dimension]
function_name = self.INDEXING_FUNCTION_MAPPING[function_name] function_name = self.INDEXING_FUNCTION_MAPPING[function_name]
return f"{function_name}({dimension})" return f"int({function_name}({dimension}))"
def _print_TextureAccess(self, node): def _print_TextureAccess(self, node):
raise NotImplementedError() raise NotImplementedError()
......
...@@ -30,6 +30,16 @@ def make_python_function(kernel_function_node, opencl_queue, opencl_ctx, argumen ...@@ -30,6 +30,16 @@ def make_python_function(kernel_function_node, opencl_queue, opencl_ctx, argumen
if argument_dict is None: if argument_dict is None:
argument_dict = {} argument_dict = {}
# check if double precision is supported and required
if any([d.double_fp_config == 0 for d in opencl_ctx.devices]):
for param in kernel_function_node.get_parameters():
if param.symbol.dtype.base_type:
if param.symbol.dtype.base_type.numpy_dtype == np.float64:
raise ValueError('OpenCL device does not support double precision')
else:
if param.symbol.dtype.numpy_dtype == np.float64:
raise ValueError('OpenCL device does not support double precision')
# Changing of kernel name necessary since compilation with default name "kernel" is not possible (OpenCL keyword!) # Changing of kernel name necessary since compilation with default name "kernel" is not possible (OpenCL keyword!)
kernel_function_node.function_name = "opencl_" + kernel_function_node.function_name kernel_function_node.function_name = "opencl_" + kernel_function_node.function_name
header_list = ['"opencl_stdint.h"'] + list(get_headers(kernel_function_node)) header_list = ['"opencl_stdint.h"'] + list(get_headers(kernel_function_node))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment