diff --git a/runtime/devices/device.hpp b/runtime/devices/device.hpp index 07d671b2c32b8980a2cdab919fc417f1b3d7b024..3ef88ef29d9a46c3e15b5a0a8ec9aca76486dfc5 100644 --- a/runtime/devices/device.hpp +++ b/runtime/devices/device.hpp @@ -19,6 +19,21 @@ __host__ void copy_to_host(const void *d_ptr, void *h_ptr, size_t count); __host__ void copy_static_symbol_to_device(void *h_ptr, const void *d_ptr, size_t count); __host__ void copy_static_symbol_to_host(void *d_ptr, const void *h_ptr, size_t count); +inline __host__ int host_atomic_add(int *addr, int val) { + *addr += val; + return *addr - val; +} + +inline __host__ int host_atomic_add_resize_check(int *addr, int val, int *resize, int capacity) { + const int add_res = *addr + val; + if(add_res >= capacity) { + *resize = add_res; + return *addr; + } + + return host_atomic_add(addr, val); +} + #ifdef PAIRS_TARGET_CUDA __device__ int atomic_add(int *addr, int val) { return atomicAdd(addr, val); } __device__ int atomic_add_resize_check(int *addr, int val, int *resize, int capacity) { @@ -31,8 +46,10 @@ __device__ int atomic_add_resize_check(int *addr, int val, int *resize, int capa return atomic_add(addr, val); } #else -int atomic_add(int *addr, int val); -int atomic_add_resize_check(int *addr, int val, int *resize, int capacity); +inline int atomic_add(int *addr, int val) { return host_atomic_add(addr, val); } +inline int atomic_add_resize_check(int *addr, int val, int *resize, int capacity) { + return host_atomic_add_resize_check(addr, val, resize, capacity); +} #endif } diff --git a/runtime/devices/dummy.cpp b/runtime/devices/dummy.cpp index c93712d07eacb134765192e1d2ac8ae777af2936..de6a8df0d985ec20770dcf352da871acf404fd31 100644 --- a/runtime/devices/dummy.cpp +++ b/runtime/devices/dummy.cpp @@ -9,19 +9,5 @@ void copy_to_device(void const *h_ptr, void *d_ptr, size_t count) {} void copy_to_host(void const *d_ptr, void *h_ptr, size_t count) {} void copy_static_symbol_to_device(void *h_ptr, const void *d_ptr, size_t count) {} void copy_static_symbol_to_host(void *d_ptr, const void *h_ptr, size_t count) {} -int atomic_add(int *addr, int val) { - *addr += val; - return *addr - val; -} - -int atomic_add_resize_check(int *addr, int val, int *resize, int capacity) { - const int add_res = *addr + val; - if(add_res >= capacity) { - *resize = add_res; - return *addr; - } - - return atomic_add(addr, val); -} } diff --git a/src/pairs/analysis/modules.py b/src/pairs/analysis/modules.py index 8bf493048f0a587c2502a1d1ff784062735640dc..c75ded481f573c55395e198146dc3b353982de2f 100644 --- a/src/pairs/analysis/modules.py +++ b/src/pairs/analysis/modules.py @@ -29,6 +29,10 @@ class FetchModulesReferences(Visitor): self.writing = False self.visit(ast_node.value) + for m in self.module_stack: + if m.run_on_device: + ast_node.device_flag = True + if ast_node.resize is not None: self.visit(ast_node.resize) self.visit(ast_node.capacity) diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py index 270a8b1fbdb3c42f83f02b65646cb714a5374731..9e40f1d21be03578d700428a9cda7ac112253d94 100644 --- a/src/pairs/code_gen/cgen.py +++ b/src/pairs/code_gen/cgen.py @@ -255,13 +255,14 @@ class CGen: value = self.generate_expression(atomic_add.value) tkw = Types.c_keyword(atomic_add.type()) acc_ref = f"atm_add{atomic_add.id()}" + prefix = "" if ast_node.elem.device_flag else "host_" if atomic_add.check_for_resize(): resize = self.generate_expression(atomic_add.resize) capacity = self.generate_expression(atomic_add.capacity) - self.print(f"const {tkw} {acc_ref} = pairs::atomic_add_resize_check(&({elem}), {value}, &({resize}), {capacity});") + self.print(f"const {tkw} {acc_ref} = pairs::{prefix}atomic_add_resize_check(&({elem}), {value}, &({resize}), {capacity});") else: - self.print(f"const {tkw} {acc_ref} = pairs::atomic_add(&({elem}), {value});") + self.print(f"const {tkw} {acc_ref} = pairs::{prefix}atomic_add(&({elem}), {value});") if isinstance(ast_node, Branch): cond = self.generate_expression(ast_node.cond) diff --git a/src/pairs/ir/atomic.py b/src/pairs/ir/atomic.py index c492b35d844f90436708c6810e12aeb3e8a9704d..9de1c9eb43def2ad4e21be2352a6bf4578654a7f 100644 --- a/src/pairs/ir/atomic.py +++ b/src/pairs/ir/atomic.py @@ -16,6 +16,7 @@ class AtomicAdd(ASTTerm): self.value = Lit.cvt(sim, value) self.resize = None self.capacity = None + self.device_flag = False def __str__(self): return f"AtomicAdd<{self.elem, self.val}>"