Skip to content
Snippets Groups Projects
Commit 94332a50 authored by Rafael Ravedutti's avatar Rafael Ravedutti
Browse files

Add device synchronization

parent cfcfb992
Branches
Tags
No related merge requests found
...@@ -29,6 +29,11 @@ __host__ void device_free(void *ptr) { ...@@ -29,6 +29,11 @@ __host__ void device_free(void *ptr) {
CUDA_ASSERT(cudaFree(ptr)); CUDA_ASSERT(cudaFree(ptr));
} }
__host__ void device_synchronize() {
CUDA_ASSERT(cudaPeekAtLastError());
CUDA_ASSERT(cudaDeviceSynchronize());
}
__host__ void copy_to_device(const void *h_ptr, void *d_ptr, size_t count) { __host__ void copy_to_device(const void *h_ptr, void *d_ptr, size_t count) {
CUDA_ASSERT(cudaMemcpy(d_ptr, h_ptr, count, cudaMemcpyHostToDevice)); CUDA_ASSERT(cudaMemcpy(d_ptr, h_ptr, count, cudaMemcpyHostToDevice));
} }
......
...@@ -14,6 +14,7 @@ void cuda_assert(cudaError_t err, const char *file, int line); ...@@ -14,6 +14,7 @@ void cuda_assert(cudaError_t err, const char *file, int line);
__host__ void *device_alloc(size_t size); __host__ void *device_alloc(size_t size);
__host__ void *device_realloc(void *ptr, size_t size); __host__ void *device_realloc(void *ptr, size_t size);
__host__ void device_free(void *ptr); __host__ void device_free(void *ptr);
__host__ void device_synchronize();
__host__ void copy_to_device(const void *h_ptr, void *d_ptr, size_t count); __host__ void copy_to_device(const void *h_ptr, void *d_ptr, size_t count);
__host__ void copy_to_host(const void *d_ptr, void *h_ptr, size_t count); __host__ void copy_to_host(const void *d_ptr, void *h_ptr, size_t count);
__host__ void copy_static_symbol_to_device(void *h_ptr, const void *d_ptr, size_t count); __host__ void copy_static_symbol_to_device(void *h_ptr, const void *d_ptr, size_t count);
......
...@@ -5,6 +5,7 @@ namespace pairs { ...@@ -5,6 +5,7 @@ namespace pairs {
void *device_alloc(size_t size) { return nullptr; } void *device_alloc(size_t size) { return nullptr; }
void *device_realloc(void *ptr, size_t size) { return nullptr; } void *device_realloc(void *ptr, size_t size) { return nullptr; }
void device_free(void *ptr) {} void device_free(void *ptr) {}
void device_synchronize() {}
void copy_to_device(void const *h_ptr, void *d_ptr, size_t count) {} void copy_to_device(void const *h_ptr, void *d_ptr, size_t count) {}
void copy_to_host(void const *d_ptr, void *h_ptr, size_t count) {} void copy_to_host(void const *d_ptr, void *h_ptr, size_t count) {}
void copy_static_symbol_to_device(void *h_ptr, const void *d_ptr, size_t count) {} void copy_static_symbol_to_device(void *h_ptr, const void *d_ptr, size_t count) {}
......
...@@ -112,6 +112,7 @@ public: ...@@ -112,6 +112,7 @@ public:
real_t *recv_buf, const int *recv_offsets, const int *nrecv); real_t *recv_buf, const int *recv_offsets, const int *nrecv);
void fillCommunicationArrays(int neighbor_ranks[], int pbc[], real_t subdom[]); void fillCommunicationArrays(int neighbor_ranks[], int pbc[], real_t subdom[]);
void sync() { device_synchronize(); }
}; };
template<typename T_ptr> template<typename T_ptr>
......
...@@ -97,34 +97,34 @@ class CGen: ...@@ -97,34 +97,34 @@ class CGen:
self.print("}") self.print("}")
else: else:
module_params = "" module_params = "PairsSimulation *pairs"
for var in module.read_only_variables(): for var in module.read_only_variables():
type_kw = Types.c_keyword(var.type()) type_kw = Types.c_keyword(var.type())
decl = f"{type_kw} {var.name()}" decl = f"{type_kw} {var.name()}"
module_params += decl if len(module_params) <= 0 else f", {decl}" module_params += f", {decl}"
for var in module.write_variables(): for var in module.write_variables():
type_kw = Types.c_keyword(var.type()) type_kw = Types.c_keyword(var.type())
decl = f"{type_kw} *{var.name()}" decl = f"{type_kw} *{var.name()}"
module_params += decl if len(module_params) <= 0 else f", {decl}" module_params += f", {decl}"
for array in module.arrays(): for array in module.arrays():
type_kw = Types.c_keyword(array.type()) type_kw = Types.c_keyword(array.type())
decl = f"{type_kw} *{array.name()}" decl = f"{type_kw} *{array.name()}"
module_params += decl if len(module_params) <= 0 else f", {decl}" module_params += f", {decl}"
if array in module.host_references(): if array in module.host_references():
decl = f"{type_kw} *h_{array.name()}" decl = f"{type_kw} *h_{array.name()}"
module_params += decl if len(module_params) <= 0 else f", {decl}" module_params += f", {decl}"
for prop in module.properties(): for prop in module.properties():
type_kw = Types.c_keyword(prop.type()) type_kw = Types.c_keyword(prop.type())
decl = f"{type_kw} *{prop.name()}" decl = f"{type_kw} *{prop.name()}"
module_params += decl if len(module_params) <= 0 else f", {decl}" module_params += f", {decl}"
if prop in module.host_references(): if prop in module.host_references():
decl = f"{type_kw} *h_{prop.name()}" decl = f"{type_kw} *h_{prop.name()}"
module_params += decl if len(module_params) <= 0 else f", {decl}" module_params += f", {decl}"
self.print(f"void {module.name}({module_params}) {{") self.print(f"void {module.name}({module_params}) {{")
...@@ -396,35 +396,36 @@ class CGen: ...@@ -396,35 +396,36 @@ class CGen:
self.print(f"if({nblocks} > 0 && {threads_per_block} > 0) {{") self.print(f"if({nblocks} > 0 && {threads_per_block} > 0) {{")
self.print.add_indent(4) self.print.add_indent(4)
self.print(f"{kernel.name}<<<{nblocks}, {threads_per_block}>>>({kernel_params});") self.print(f"{kernel.name}<<<{nblocks}, {threads_per_block}>>>({kernel_params});")
self.print("pairs->sync();")
self.print.add_indent(-4) self.print.add_indent(-4)
self.print("}") self.print("}")
if isinstance(ast_node, ModuleCall): if isinstance(ast_node, ModuleCall):
module = ast_node.module module = ast_node.module
module_params = "" module_params = "pairs"
device_cond = module.run_on_device and self.target.is_gpu() device_cond = module.run_on_device and self.target.is_gpu()
for var in module.read_only_variables(): for var in module.read_only_variables():
decl = var.name() decl = var.name()
module_params += decl if len(module_params) <= 0 else f", {decl}" module_params += f", {decl}"
for var in module.write_variables(): for var in module.write_variables():
decl = f"rv_{var.name()}.getDevicePointer()" if device_cond and var.device_flag else f"&{var.name()}" decl = f"rv_{var.name()}.getDevicePointer()" if device_cond and var.device_flag else f"&{var.name()}"
module_params += decl if len(module_params) <= 0 else f", {decl}" module_params += f", {decl}"
for array in module.arrays(): for array in module.arrays():
decl = f"d_{array.name()}" if device_cond else array.name() decl = f"d_{array.name()}" if device_cond else array.name()
module_params += decl if len(module_params) <= 0 else f", {decl}" module_params += decl if len(module_params) <= 0 else f", {decl}"
if array in module.host_references(): if array in module.host_references():
decl = array.name() decl = array.name()
module_params += decl if len(module_params) <= 0 else f", {decl}" module_params += f", {decl}"
for prop in module.properties(): for prop in module.properties():
decl = f"d_{prop.name()}" if device_cond else prop.name() decl = f"d_{prop.name()}" if device_cond else prop.name()
module_params += decl if len(module_params) <= 0 else f", {decl}" module_params += f", {decl}"
if prop in module.host_references(): if prop in module.host_references():
decl = prop.name() decl = prop.name()
module_params += decl if len(module_params) <= 0 else f", {decl}" module_params += f", {decl}"
self.print(f"{module.name}({module_params});") self.print(f"{module.name}({module_params});")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment