Skip to content
Snippets Groups Projects
Commit 2f69de28 authored by Markus Holzer's avatar Markus Holzer
Browse files

Small changes for better readability

parent 2216f7d3
Branches
No related tags found
No related merge requests found
Pipeline #53869 failed
......@@ -428,6 +428,7 @@ class KernelInfo:
def generate_kernel_invocation_code(self, **kwargs):
ast = self.ast
ast_params = self.parameters
fnc_name = ast.function_name
is_cpu = self.ast.target == Target.CPU
call_parameters = ", ".join([p.symbol.name for p in ast_params])
......@@ -446,20 +447,18 @@ class KernelInfo:
indexing_dict = ast.indexing.call_parameters(spatial_shape_symbols)
sp_printer_c = CudaSympyPrinter()
block = tuple(sp_printer_c.doprint(e) for e in indexing_dict['block'])
grid = tuple(sp_printer_c.doprint(e) for e in indexing_dict['grid'])
kernel_launch = f"internal_{ast.function_name}::{ast.function_name}<<<_grid, _block, 0, {stream}>>>({call_parameters});"
kernel_call_lines = [
f"dim3 _block(uint32_t({block[0]}), uint32_t({block[1]}), uint32_t({block[2]}));",
f"dim3 _grid(uint32_t({grid[0]}), uint32_t({grid[1]}), uint32_t({grid[2]}));",
kernel_launch]
f"dim3 _block(uint64_c({block[0]}), uint64_c({block[1]}), uint64_c({block[2]}));",
f"dim3 _grid(uint64_c({grid[0]}), uint64_c({grid[1]}), uint64_c({grid[2]}));",
f"internal_{fnc_name}::{fnc_name}<<<_grid, _block, 0, {stream}>>>({call_parameters});"
]
return "\n".join(kernel_call_lines)
else:
return f"internal_{ast.function_name}::{ast.function_name}({call_parameters});"
return f"internal_{fnc_name}::{fnc_name}({call_parameters});"
def get_vectorize_instruction_set(generation_context):
......
......@@ -35,6 +35,7 @@ class KernelInfo:
def generate_kernel_invocation_code(self, **kwargs):
ast = self.ast
ast_params = self.parameters
fnc_name = ast.function_name
is_cpu = self.ast.target == Target.CPU
call_parameters = ", ".join([p.symbol.name for p in ast_params])
......@@ -53,15 +54,15 @@ class KernelInfo:
indexing_dict = ast.indexing.call_parameters(spatial_shape_symbols)
sp_printer_c = CudaSympyPrinter()
block = tuple(sp_printer_c.doprint(e) for e in indexing_dict['block'])
grid = tuple(sp_printer_c.doprint(e) for e in indexing_dict['grid'])
kernel_call_lines = [
"dim3 _block(int(%s), int(%s), int(%s));" % tuple(sp_printer_c.doprint(e)
for e in indexing_dict['block']),
"dim3 _grid(int(%s), int(%s), int(%s));" % tuple(sp_printer_c.doprint(e)
for e in indexing_dict['grid']),
"internal_%s::%s<<<_grid, _block, 0, %s>>>(%s);" % (ast.function_name, ast.function_name,
stream, call_parameters),
f"dim3 _block(uint64_c({block[0]}), uint64_c({block[1]}), uint64_c({block[2]}));",
f"dim3 _grid(uint64_c({grid[0]}), uint64_c({grid[1]}), uint64_c({grid[2]}));",
f"internal_{fnc_name}::{fnc_name}<<<_grid, _block, 0, {stream}>>>({call_parameters});"
]
return "\n".join(kernel_call_lines)
else:
return f"internal_{ast.function_name}::{ast.function_name}({call_parameters});"
return f"internal_{fnc_name}::{fnc_name}({call_parameters});"
......@@ -172,6 +172,7 @@ class KernelCallNode(AbstractKernelSelectionNode):
def get_code(self, **kwargs):
ast = self.ast
ast_params = self.parameters
fnc_name = ast.function_name
is_cpu = self.ast.target == Target.CPU
call_parameters = ", ".join([p.symbol.name for p in ast_params])
......@@ -190,20 +191,18 @@ class KernelCallNode(AbstractKernelSelectionNode):
indexing_dict = ast.indexing.call_parameters(spatial_shape_symbols)
sp_printer_c = CudaSympyPrinter()
block = tuple(sp_printer_c.doprint(e) for e in indexing_dict['block'])
grid = tuple(sp_printer_c.doprint(e) for e in indexing_dict['grid'])
kernel_launch = f"internal_{ast.function_name}::{ast.function_name}<<<_grid, _block, 0, {stream}>>>({call_parameters});"
kernel_call_lines = [
f"dim3 _block(uint32_t({block[0]}), uint32_t({block[1]}), uint32_t({block[2]}));",
f"dim3 _grid(uint32_t({grid[0]}), uint32_t({grid[1]}), uint32_t({grid[2]}));",
kernel_launch]
f"dim3 _block(uint64_c({block[0]}), uint64_c({block[1]}), uint64_c({block[2]}));",
f"dim3 _grid(uint64_c({grid[0]}), uint64_c({grid[1]}), uint64_c({grid[2]}));",
f"internal_{fnc_name}::{fnc_name}<<<_grid, _block, 0, {stream}>>>({call_parameters});"
]
return "\n".join(kernel_call_lines)
else:
return f"internal_{ast.function_name}::{ast.function_name}({call_parameters});"
return f"internal_{fnc_name}::{fnc_name}({call_parameters});"
class SimpleBooleanCondition(AbstractConditionNode):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment