Skip to content
Snippets Groups Projects
Commit 2e6f3efe authored by Stephan Seitz
Browse files

llvm: Use addressspace 1 (global memory) for nvvm_target

parent a526fe47
Branches
Tags
1 merge request: !53 Compile CUDA using the LLVM backend
...@@ -300,7 +300,7 @@ def ctypes_from_llvm(data_type): ...@@ -300,7 +300,7 @@ def ctypes_from_llvm(data_type):
raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type)) raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))
def to_llvm_type(data_type): def to_llvm_type(data_type, nvvm_target=False):
""" """
Transforms a given type into ctypes Transforms a given type into ctypes
:param data_type: Subclass of Type :param data_type: Subclass of Type
...@@ -309,7 +309,7 @@ def to_llvm_type(data_type): ...@@ -309,7 +309,7 @@ def to_llvm_type(data_type):
if not ir: if not ir:
raise _ir_importerror raise _ir_importerror
if isinstance(data_type, PointerType): if isinstance(data_type, PointerType):
return to_llvm_type(data_type.base_type).as_pointer() return to_llvm_type(data_type.base_type).as_pointer(1 if nvvm_target else 0)
else: else:
return to_llvm_type.map[data_type.numpy_dtype] return to_llvm_type.map[data_type.numpy_dtype]
......
...@@ -21,13 +21,13 @@ def _call_sreg(builder, name): ...@@ -21,13 +21,13 @@ def _call_sreg(builder, name):
return builder.call(fn, ()) return builder.call(fn, ())
def generate_llvm(ast_node, module=None, builder=None): def generate_llvm(ast_node, module=None, builder=None, target='cpu'):
"""Prints the ast as llvm code.""" """Prints the ast as llvm code."""
if module is None: if module is None:
module = lc.Module() module = lc.Module()
if builder is None: if builder is None:
builder = ir.IRBuilder() builder = ir.IRBuilder()
printer = LLVMPrinter(module, builder) printer = LLVMPrinter(module, builder, target=target)
return printer._print(ast_node) return printer._print(ast_node)
...@@ -173,7 +173,7 @@ class LLVMPrinter(Printer): ...@@ -173,7 +173,7 @@ class LLVMPrinter(Printer):
parameter_type = [] parameter_type = []
parameters = func.get_parameters() parameters = func.get_parameters()
for parameter in parameters: for parameter in parameters:
parameter_type.append(to_llvm_type(parameter.symbol.dtype)) parameter_type.append(to_llvm_type(parameter.symbol.dtype, nvvm_target=self.target == 'gpu'))
func_type = ir.FunctionType(return_type, tuple(parameter_type)) func_type = ir.FunctionType(return_type, tuple(parameter_type))
name = func.function_name name = func.function_name
fn = ir.Function(self.module, func_type, name) fn = ir.Function(self.module, func_type, name)
...@@ -307,7 +307,7 @@ class LLVMPrinter(Printer): ...@@ -307,7 +307,7 @@ class LLVMPrinter(Printer):
self.builder.branch(after_block) self.builder.branch(after_block)
self.builder.position_at_end(false_block) self.builder.position_at_end(false_block)
phi = self.builder.phi(to_llvm_type(get_type_of_expression(piece))) phi = self.builder.phi(to_llvm_type(get_type_of_expression(piece), nvvm_target=self.target == 'gpu'))
for (val, block) in phi_data: for (val, block) in phi_data:
phi.add_incoming(val, block) phi.add_incoming(val, block)
return phi return phi
......
...@@ -101,7 +101,7 @@ def make_python_function_incomplete_params(kernel_function_node, argument_dict, ...@@ -101,7 +101,7 @@ def make_python_function_incomplete_params(kernel_function_node, argument_dict,
def generate_and_jit(ast): def generate_and_jit(ast):
target = 'gpu' if ast._backend == 'llvm_gpu' else 'cpu' target = 'gpu' if ast._backend == 'llvm_gpu' else 'cpu'
gen = generate_llvm(ast) gen = generate_llvm(ast, target=target)
if isinstance(gen, ir.Module): if isinstance(gen, ir.Module):
return compile_llvm(gen, target) return compile_llvm(gen, target)
else: else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment