From 463bb37d977110510b9562ea50dd2853600cfc27 Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Fri, 18 Feb 2022 00:05:20 +0100
Subject: [PATCH] Add device allocations

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 examples/lj_func.py           |  4 ++--
 src/pairs/analysis/modules.py |  6 ++++++
 src/pairs/code_gen/cgen.py    | 20 ++++++++++++++++----
 src/pairs/ir/arrays.py        |  1 +
 src/pairs/ir/properties.py    |  1 +
 src/pairs/ir/variables.py     |  1 +
 src/pairs/sim/simulation.py   |  1 +
 7 files changed, 28 insertions(+), 6 deletions(-)

diff --git a/examples/lj_func.py b/examples/lj_func.py
index 61c8119..c7e5618 100644
--- a/examples/lj_func.py
+++ b/examples/lj_func.py
@@ -31,6 +31,6 @@ psim.periodic(2.8)
 psim.vtk_output("output/test")
 psim.compute(lj, cutoff_radius, {'sigma6': sigma6, 'epsilon': epsilon})
 psim.compute(euler, symbols={'dt': dt})
-psim.target(pairs.target_cpu())
-#psim.target(pairs.target_gpu())
+#psim.target(pairs.target_cpu())
+psim.target(pairs.target_gpu())
 psim.generate()
diff --git a/src/pairs/analysis/modules.py b/src/pairs/analysis/modules.py
index 4da4846..6cf31a5 100644
--- a/src/pairs/analysis/modules.py
+++ b/src/pairs/analysis/modules.py
@@ -24,11 +24,17 @@ class FetchModulesReferences(Visitor):
     def visit_Array(self, ast_node):
         for m in self.module_stack:
             m.add_array(ast_node, self.writing)
+            if m.run_on_device:
+                ast_node.device_flag = True
 
     def visit_Property(self, ast_node):
         for m in self.module_stack:
             m.add_property(ast_node, self.writing)
+            if m.run_on_device:
+                ast_node.device_flag = True
 
     def visit_Var(self, ast_node):
         for m in self.module_stack:
             m.add_variable(ast_node, self.writing)
+            if m.run_on_device:
+                ast_node.device_flag = True
diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index 7191b0a..e5bb5ae 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -25,14 +25,18 @@ from pairs.code_gen.printer import Printer
 class CGen:
     temp_id = 0
 
-    def __init__(self, output, debug=False):
+    def __init__(self, output, target, debug=False):
         self.sim = None
+        self.target = None
         self.debug = debug
         self.print = Printer(output)
 
     def assign_simulation(self, sim):
         self.sim = sim
 
+    def assign_target(self, target):
+        self.target = target
+
     def generate_program(self, ast_node):
         self.print.start()
         self.print("#include <math.h>")
@@ -101,7 +105,7 @@ class CGen:
             t = ast_node.array.type()
             tkw = Types.ctype2keyword(t)
             size = self.generate_expression(BinOp.inline(ast_node.array.alloc_size()))
-            if ast_node.array.is_static() and ast_node.array.init_value is not None:
+            if ast_node.array.init_value is not None:
                 v_str = str(ast_node.array.init_value)
                 if t == Types.Int64:
                     v_str += "LL"
@@ -180,10 +184,12 @@ class CGen:
             self.print(f"{call};")
 
         if isinstance(ast_node, CopyToDevice):
-            self.print(f"pairs::copy_to_device({ast_node.prop.name()})")
+            array_name = ast_node.prop.name()
+            self.print(f"pairs::copy_to_device({array_name}, d_{array_name})")
 
         if isinstance(ast_node, CopyToHost):
-            self.print(f"pairs::copy_to_host({ast_node.prop.name()})")
+            array_name = ast_node.prop.name()
+            self.print(f"pairs::copy_to_host(d_{array_name}, {array_name})")
 
         if isinstance(ast_node, For):
             iterator = self.generate_expression(ast_node.iterator)
@@ -211,8 +217,12 @@ class CGen:
 
             if ast_node.decl:
                 self.print(f"{tkw} *{array_name} = ({tkw} *) malloc({size});")
+                if self.target.is_gpu() and ast_node.array.device_flag:
+                    self.print(f"{tkw} *d_{array_name} = ({tkw} *) pairs::device_alloc({size});")
             else:
                 self.print(f"{array_name} = ({tkw} *) malloc({size});")
+                if self.target.is_gpu() and ast_node.array.device_flag:
+                    self.print(f"d_{array_name} = ({tkw} *) pairs::device_alloc({size});")
 
         if isinstance(ast_node, ModuleCall):
             module = ast_node.module
@@ -244,6 +254,8 @@ class CGen:
             size = self.generate_expression(ast_node.size)
             array_name = ast_node.array.name()
             self.print(f"{array_name} = ({tkw} *) realloc({array_name}, {size});")
+            if self.target.is_gpu() and ast_node.array.device_flag:
+                self.print(f"d_{array_name} = ({tkw} *) pairs::device_realloc(d_{array_name}, {size});")
 
         if isinstance(ast_node, RegisterProperty):
             p = ast_node.property()
diff --git a/src/pairs/ir/arrays.py b/src/pairs/ir/arrays.py
index 6573c12..8316fc0 100644
--- a/src/pairs/ir/arrays.py
+++ b/src/pairs/ir/arrays.py
@@ -46,6 +46,7 @@ class Array(ASTNode):
         self.arr_layout = a_layout
         self.arr_ndims = len(self.arr_sizes)
         self.static = False
+        self.device_flag = False
         for var in [s for s in self.arr_sizes if isinstance(s, Var)]:
             var.add_bonded_array(self)
 
diff --git a/src/pairs/ir/properties.py b/src/pairs/ir/properties.py
index 5ee78f4..7d85b19 100644
--- a/src/pairs/ir/properties.py
+++ b/src/pairs/ir/properties.py
@@ -57,6 +57,7 @@ class Property(ASTNode):
         self.prop_layout = layout
         self.default_value = default
         self.volatile = volatile
+        self.device_flag = False
         Property.last_prop_id += 1
 
     def __str__(self):
diff --git a/src/pairs/ir/variables.py b/src/pairs/ir/variables.py
index d354535..da13df1 100644
--- a/src/pairs/ir/variables.py
+++ b/src/pairs/ir/variables.py
@@ -33,6 +33,7 @@ class Var(ASTTerm):
         self.var_init_value = init_value
         self.mutable = True
         self.var_bonded_arrays = []
+        self.device_flag = False
 
     def __str__(self):
         return f"Var<{self.var_name}>"
diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py
index d5ad6fe..7cca2fc 100644
--- a/src/pairs/sim/simulation.py
+++ b/src/pairs/sim/simulation.py
@@ -209,6 +209,7 @@ class Simulation:
 
     def target(self, target):
         self._target = target
+        self.code_gen.assign_target(target)
 
     def generate(self):
         assert self._target is not None, "Target not specified!"
-- 
GitLab