From 6f8144c35407f8882946ceb8fbf2149102dde51c Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Fri, 4 Nov 2022 02:30:15 +0100
Subject: [PATCH] Add runtime code for device variables

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 runtime/devices/cuda.cu    |  4 ++++
 runtime/devices/device.hpp |  1 +
 runtime/devices/dummy.cpp  |  1 +
 runtime/pairs.hpp          |  6 ++++++
 runtime/runtime_var.hpp    | 28 ++++++++++++++++++++++++++++
 src/pairs/code_gen/cgen.py |  2 +-
 6 files changed, 41 insertions(+), 1 deletion(-)
 create mode 100644 runtime/runtime_var.hpp

diff --git a/runtime/devices/cuda.cu b/runtime/devices/cuda.cu
index f2cd32a..9bb4437 100644
--- a/runtime/devices/cuda.cu
+++ b/runtime/devices/cuda.cu
@@ -25,6 +25,10 @@ __host__ void *device_realloc(void *ptr, size_t size) {
     return new_ptr;
 }
 
+__host__ void device_free(void *ptr) {
+    CUDA_ASSERT(cudaFree(ptr));
+}
+
 __host__ void copy_to_device(const void *h_ptr, void *d_ptr, size_t count) {
     CUDA_ASSERT(cudaMemcpy(d_ptr, h_ptr, count, cudaMemcpyHostToDevice));
 }
diff --git a/runtime/devices/device.hpp b/runtime/devices/device.hpp
index 9d3c9c9..07d671b 100644
--- a/runtime/devices/device.hpp
+++ b/runtime/devices/device.hpp
@@ -13,6 +13,7 @@ namespace pairs {
 void cuda_assert(cudaError_t err, const char *file, int line);
 __host__ void *device_alloc(size_t size);
 __host__ void *device_realloc(void *ptr, size_t size);
+__host__ void device_free(void *ptr);
 __host__ void copy_to_device(const void *h_ptr, void *d_ptr, size_t count);
 __host__ void copy_to_host(const void *d_ptr, void *h_ptr, size_t count);
 __host__ void copy_static_symbol_to_device(void *h_ptr, const void *d_ptr, size_t count);
diff --git a/runtime/devices/dummy.cpp b/runtime/devices/dummy.cpp
index 8b8c80a..c93712d 100644
--- a/runtime/devices/dummy.cpp
+++ b/runtime/devices/dummy.cpp
@@ -4,6 +4,7 @@ namespace pairs {
 
 void *device_alloc(size_t size) { return nullptr; }
 void *device_realloc(void *ptr, size_t size) { return nullptr; }
+void device_free(void *ptr) {}
 void copy_to_device(void const *h_ptr, void *d_ptr, size_t count) {}
 void copy_to_host(void const *d_ptr, void *h_ptr, size_t count) {}
 void copy_static_symbol_to_device(void *h_ptr, const void *d_ptr, size_t count) {}
diff --git a/runtime/pairs.hpp b/runtime/pairs.hpp
index 7c5c80d..7be160e 100644
--- a/runtime/pairs.hpp
+++ b/runtime/pairs.hpp
@@ -7,6 +7,7 @@
 #include "device_flags.hpp"
 #include "pairs_common.hpp"
 #include "property.hpp"
+#include "runtime_var.hpp"
 #include "vector3.hpp"
 #include "devices/device.hpp"
 #include "domain/regular_6d_stencil.hpp"
@@ -40,6 +41,11 @@ public:
     void initDomain(int *argc, char ***argv, real_t xmin, real_t xmax, real_t ymin, real_t ymax, real_t zmin, real_t zmax);
     Regular6DStencil *getDomainPartitioner() { return dom_part; }
 
+    template<typename T>
+    RuntimeVar<T> addDeviceVariable(T *h_ptr) {
+       return RuntimeVar<T>(h_ptr); 
+    }
+
     template<typename T_ptr> void addArray(array_t id, std::string name, T_ptr **h_ptr, std::nullptr_t, size_t size);
     template<typename T_ptr> void addArray(array_t id, std::string name, T_ptr **h_ptr, T_ptr **d_ptr, size_t size);
     template<typename T_ptr> void addStaticArray(array_t id, std::string name, T_ptr *h_ptr, std::nullptr_t, size_t size);
diff --git a/runtime/runtime_var.hpp b/runtime/runtime_var.hpp
new file mode 100644
index 0000000..7cf3aea
--- /dev/null
+++ b/runtime/runtime_var.hpp
@@ -0,0 +1,28 @@
+#include "devices/device.hpp"
+
+#pragma once
+
+namespace pairs {
+
+template<typename T>
+class RuntimeVar{
+protected:
+    T *h_ptr, *d_ptr;
+
+public:
+    RuntimeVar(T *ptr) {
+        h_ptr = ptr;
+        d_ptr = (T *) pairs::device_alloc(sizeof(T));
+    }
+
+    ~RuntimeVar() {
+        pairs::device_free(d_ptr);
+    }
+
+    inline void copyToDevice() { pairs::copy_to_device(h_ptr, d_ptr, sizeof(T)); }
+    inline void copyToHost() { pairs::copy_to_host(d_ptr, h_ptr, sizeof(T)); }
+    inline T *getHostPointer() { return h_ptr; }
+    inline T *getDevicePointer() { return d_ptr; }
+};
+
+}
diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index 0eb4834..9cdb406 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -503,7 +503,7 @@ class CGen:
             self.print(f"{tkw} {ast_node.var.name()} = {ast_node.var.init_value()};")
 
             if self.target.is_gpu() and ast_node.var.device_flag:
-                self.print(f"RuntimeVar *rv_{ast_node.var.name()} = pairs->addDeviceVariable(&({ast_node.var.name()}));")
+                self.print(f"RuntimeVar<{tkw}> rv_{ast_node.var.name()} = pairs->addDeviceVariable(&({ast_node.var.name()}));")
                 #self.print(f"{tkw} *d_{ast_node.var.name()} = pairs->addDeviceVariable(&({ast_node.var.name()}));")
 
         if isinstance(ast_node, While):
-- 
GitLab