NOTE(review): this patch was reconstructed from a copy whose line breaks and
indentation had been collapsed. Leading whitespace on context lines is a best
guess, so apply with `git apply --ignore-whitespace` and confirm the result.
The `index` line of runtime/runtime_var.hpp was dropped because that file's
content differs from the originally recorded blob: RuntimeVar now forbids
copies and supports moves so its device buffer is freed exactly once.

diff --git a/runtime/devices/cuda.cu b/runtime/devices/cuda.cu
index f2cd32a550db6d911df0beab57595e0c0717f977..9bb443742b3b09cb9769398f4aee080c5b0d132d 100644
--- a/runtime/devices/cuda.cu
+++ b/runtime/devices/cuda.cu
@@ -25,6 +25,10 @@ __host__ void *device_realloc(void *ptr, size_t size) {
     return new_ptr;
 }
 
+__host__ void device_free(void *ptr) {
+    CUDA_ASSERT(cudaFree(ptr));
+}
+
 __host__ void copy_to_device(const void *h_ptr, void *d_ptr, size_t count) {
     CUDA_ASSERT(cudaMemcpy(d_ptr, h_ptr, count, cudaMemcpyHostToDevice));
 }
diff --git a/runtime/devices/device.hpp b/runtime/devices/device.hpp
index 9d3c9c9e685ee2bf58a161345170e99b1f9f7f93..07d671b2c32b8980a2cdab919fc417f1b3d7b024 100644
--- a/runtime/devices/device.hpp
+++ b/runtime/devices/device.hpp
@@ -13,6 +13,7 @@ namespace pairs {
 void cuda_assert(cudaError_t err, const char *file, int line);
 __host__ void *device_alloc(size_t size);
 __host__ void *device_realloc(void *ptr, size_t size);
+__host__ void device_free(void *ptr);
 __host__ void copy_to_device(const void *h_ptr, void *d_ptr, size_t count);
 __host__ void copy_to_host(const void *d_ptr, void *h_ptr, size_t count);
 __host__ void copy_static_symbol_to_device(void *h_ptr, const void *d_ptr, size_t count);
diff --git a/runtime/devices/dummy.cpp b/runtime/devices/dummy.cpp
index 8b8c80a85027c2c4a861dfce9e0cb57d66175b65..c93712d07eacb134765192e1d2ac8ae777af2936 100644
--- a/runtime/devices/dummy.cpp
+++ b/runtime/devices/dummy.cpp
@@ -4,6 +4,7 @@ namespace pairs {
 
 void *device_alloc(size_t size) { return nullptr; }
 void *device_realloc(void *ptr, size_t size) { return nullptr; }
+void device_free(void *ptr) {}
 void copy_to_device(void const *h_ptr, void *d_ptr, size_t count) {}
 void copy_to_host(void const *d_ptr, void *h_ptr, size_t count) {}
 void copy_static_symbol_to_device(void *h_ptr, const void *d_ptr, size_t count) {}
diff --git a/runtime/pairs.hpp b/runtime/pairs.hpp
index 7c5c80d65503466a8c2f937e18a62e81d64fdd9a..7be160e2609e3a8eb5e5f74eb8fda1533c2a2458 100644
--- a/runtime/pairs.hpp
+++ b/runtime/pairs.hpp
@@ -7,6 +7,7 @@
 #include "device_flags.hpp"
 #include "pairs_common.hpp"
 #include "property.hpp"
+#include "runtime_var.hpp"
 #include "vector3.hpp"
 #include "devices/device.hpp"
 #include "domain/regular_6d_stencil.hpp"
@@ -40,6 +41,11 @@ public:
     void initDomain(int *argc, char ***argv, real_t xmin, real_t xmax, real_t ymin, real_t ymax, real_t zmin, real_t zmax);
     Regular6DStencil *getDomainPartitioner() { return dom_part; }
 
+    template<typename T>
+    RuntimeVar<T> addDeviceVariable(T *h_ptr) {
+        return RuntimeVar<T>(h_ptr);
+    }
+
     template<typename T_ptr> void addArray(array_t id, std::string name, T_ptr **h_ptr, std::nullptr_t, size_t size);
     template<typename T_ptr> void addArray(array_t id, std::string name, T_ptr **h_ptr, T_ptr **d_ptr, size_t size);
     template<typename T_ptr> void addStaticArray(array_t id, std::string name, T_ptr *h_ptr, std::nullptr_t, size_t size);
diff --git a/runtime/runtime_var.hpp b/runtime/runtime_var.hpp
new file mode 100644
--- /dev/null
+++ b/runtime/runtime_var.hpp
@@ -0,0 +1,63 @@
+#pragma once
+
+#include "devices/device.hpp"
+
+namespace pairs {
+
+// Couples a host variable with a device-side buffer of sizeof(T) bytes.
+// The host pointer is stored as given and never freed here; the device
+// buffer is allocated on construction and released on destruction.
+template<typename T>
+class RuntimeVar{
+protected:
+    T *h_ptr, *d_ptr;
+
+    // Frees the device buffer at most once and forgets the pointer.
+    void release() {
+        if(d_ptr != nullptr) {
+            pairs::device_free(d_ptr);
+            d_ptr = nullptr;
+        }
+    }
+
+public:
+    RuntimeVar(T *ptr) {
+        h_ptr = ptr;
+        d_ptr = (T *) pairs::device_alloc(sizeof(T));
+    }
+
+    // Copies are disabled: two objects sharing d_ptr would both free it.
+    RuntimeVar(const RuntimeVar &) = delete;
+    RuntimeVar &operator=(const RuntimeVar &) = delete;
+
+    // Moving transfers the device buffer and empties the source, so a
+    // RuntimeVar can still be returned by value.
+    RuntimeVar(RuntimeVar &&other) noexcept : h_ptr(other.h_ptr), d_ptr(other.d_ptr) {
+        other.h_ptr = nullptr;
+        other.d_ptr = nullptr;
+    }
+
+    RuntimeVar &operator=(RuntimeVar &&other) noexcept {
+        if(this != &other) {
+            release();
+            h_ptr = other.h_ptr;
+            d_ptr = other.d_ptr;
+            other.h_ptr = nullptr;
+            other.d_ptr = nullptr;
+        }
+
+        return *this;
+    }
+
+    ~RuntimeVar() {
+        release();
+    }
+
+    // Transfer sizeof(T) bytes between the host variable and the device buffer.
+    inline void copyToDevice() { pairs::copy_to_device(h_ptr, d_ptr, sizeof(T)); }
+    inline void copyToHost() { pairs::copy_to_host(d_ptr, h_ptr, sizeof(T)); }
+    inline T *getHostPointer() { return h_ptr; }
+    inline T *getDevicePointer() { return d_ptr; }
+};
+
+}
diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index 0eb4834b492440256e5712dec022854a72f5fe1b..9cdb4064914b97b66aa2c14503284f56c9845a7b 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -503,7 +503,7 @@ class CGen:
             self.print(f"{tkw} {ast_node.var.name()} = {ast_node.var.init_value()};")
 
             if self.target.is_gpu() and ast_node.var.device_flag:
-                self.print(f"RuntimeVar *rv_{ast_node.var.name()} = pairs->addDeviceVariable(&({ast_node.var.name()}));")
+                self.print(f"RuntimeVar<{tkw}> rv_{ast_node.var.name()} = pairs->addDeviceVariable(&({ast_node.var.name()}));")
                 #self.print(f"{tkw} *d_{ast_node.var.name()} = pairs->addDeviceVariable(&({ast_node.var.name()}));")
 
         if isinstance(ast_node, While):