From 216461eb031becf6cd4fe8f75b57a1257a3abea2 Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Fri, 18 Feb 2022 01:41:21 +0100
Subject: [PATCH] Adjust arrays and properties references according to module
 context

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 src/pairs/code_gen/cgen.py      |  4 ++--
 src/pairs/ir/block.py           | 17 ++++++++++++++++-
 src/pairs/sim/arrays.py         |  4 ++--
 src/pairs/sim/cell_lists.py     |  4 ++--
 src/pairs/sim/interaction.py    |  4 ++--
 src/pairs/sim/lattice.py        |  4 ++--
 src/pairs/sim/pbc.py            |  4 ++--
 src/pairs/sim/properties.py     |  4 ++--
 src/pairs/sim/read_from_file.py |  4 ++--
 src/pairs/sim/variables.py      |  4 ++--
 src/pairs/sim/vtk.py            |  4 ++--
 11 files changed, 36 insertions(+), 21 deletions(-)

diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index e5bb5ae..7b05e17 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -236,11 +236,11 @@ class CGen:
                 module_params += decl if len(module_params) <= 0 else f", {decl}"
 
             for array in module.arrays():
-                decl = array.name()
+                decl = f"d_{array.name()}" if module.run_on_device else array.name()
                 module_params += decl if len(module_params) <= 0 else f", {decl}"
 
             for prop in module.properties():
-                decl = prop.name()
+                decl = f"d_{prop.name()}" if module.run_on_device else prop.name()
                 module_params += decl if len(module_params) <= 0 else f", {decl}"
 
             self.print(f"{module.name}({module_params});")
diff --git a/src/pairs/ir/block.py b/src/pairs/ir/block.py
index 2d65656..e0e6019 100644
--- a/src/pairs/ir/block.py
+++ b/src/pairs/ir/block.py
@@ -2,7 +2,7 @@ from pairs.ir.ast_node import ASTNode
 from pairs.ir.module import Module
 
 
-def pairs_block(func):
+def pairs_inline(func):
     def inner(*args, **kwargs):
         sim = args[0].sim # self.sim
         sim.init_block()
@@ -12,6 +12,21 @@ def pairs_block(func):
     return inner
 
 
+def pairs_host_block(func):
+    def inner(*args, **kwargs):
+        sim = args[0].sim # self.sim
+        sim.init_block()
+        func(*args, **kwargs)
+        return Module(sim,
+            name=sim._module_name,
+            block=Block(sim, sim._block),
+            resizes_to_check=sim._resizes_to_check,
+            check_properties_resize=sim._check_properties_resize,
+            run_on_device=False)
+
+    return inner
+
+
 def pairs_device_block(func):
     def inner(*args, **kwargs):
         sim = args[0].sim # self.sim
diff --git a/src/pairs/sim/arrays.py b/src/pairs/sim/arrays.py
index 5adcaec..39468c2 100644
--- a/src/pairs/sim/arrays.py
+++ b/src/pairs/sim/arrays.py
@@ -1,4 +1,4 @@
-from pairs.ir.block import pairs_block
+from pairs.ir.block import pairs_inline
 from pairs.ir.memory import Malloc
 from pairs.ir.arrays import ArrayDecl
 from pairs.sim.lowerable import Lowerable
@@ -8,7 +8,7 @@ class ArraysDecl(Lowerable):
     def __init__(self, sim):
         super().__init__(sim)
 
-    @pairs_block
+    @pairs_inline
     def lower(self):
         for a in self.sim.arrays.all():
             if a.is_static():
diff --git a/src/pairs/sim/cell_lists.py b/src/pairs/sim/cell_lists.py
index b5e5d18..01fd6ce 100644
--- a/src/pairs/sim/cell_lists.py
+++ b/src/pairs/sim/cell_lists.py
@@ -1,7 +1,7 @@
 from functools import reduce
 import math
 from pairs.ir.bin_op import BinOp
-from pairs.ir.block import pairs_device_block
+from pairs.ir.block import pairs_device_block, pairs_host_block
 from pairs.ir.branches import Branch, Filter
 from pairs.ir.cast import Cast
 from pairs.ir.loops import For, ParticleFor
@@ -36,7 +36,7 @@ class CellListsStencilBuild(Lowerable):
         super().__init__(sim)
         self.cell_lists = cell_lists
 
-    @pairs_device_block
+    @pairs_host_block
     def lower(self):
         sim = self.sim
         cl = self.cell_lists
diff --git a/src/pairs/sim/interaction.py b/src/pairs/sim/interaction.py
index 123d701..f597b94 100644
--- a/src/pairs/sim/interaction.py
+++ b/src/pairs/sim/interaction.py
@@ -1,5 +1,5 @@
 from pairs.ir.bin_op import BinOp
-from pairs.ir.block import Block, pairs_block
+from pairs.ir.block import Block, pairs_device_block
 from pairs.ir.branches import Branch, Filter
 from pairs.ir.loops import For, ParticleFor
 from pairs.ir.types import Types
@@ -60,7 +60,7 @@ class ParticleInteraction(Lowerable):
         yield self.i, self.j
         self.sim.leave()
 
-    @pairs_block
+    @pairs_device_block
     def lower(self):
         if self.nbody == 2:
             position = self.sim.position()
diff --git a/src/pairs/sim/lattice.py b/src/pairs/sim/lattice.py
index bc3e657..9ab1f9b 100644
--- a/src/pairs/sim/lattice.py
+++ b/src/pairs/sim/lattice.py
@@ -1,4 +1,4 @@
-from pairs.ir.block import pairs_block
+from pairs.ir.block import pairs_inline
 from pairs.ir.loops import For
 from pairs.ir.types import Types
 from pairs.sim.lowerable import Lowerable
@@ -12,7 +12,7 @@ class ParticleLattice(Lowerable):
         self.props = props
         self.positions = positions
 
-    @pairs_block
+    @pairs_inline
     def lower(self):
         index = None
         loop_indexes = []
diff --git a/src/pairs/sim/pbc.py b/src/pairs/sim/pbc.py
index f1b4571..bb73102 100644
--- a/src/pairs/sim/pbc.py
+++ b/src/pairs/sim/pbc.py
@@ -1,4 +1,4 @@
-from pairs.ir.block import pairs_device_block
+from pairs.ir.block import pairs_device_block, pairs_host_block
 from pairs.ir.branches import Branch, Filter
 from pairs.ir.loops import For, ParticleFor
 from pairs.ir.utils import Print
@@ -72,7 +72,7 @@ class SetupPBC(Lowerable):
         super().__init__(sim)
         self.pbc = pbc
 
-    @pairs_device_block
+    @pairs_host_block
     def lower(self):
         sim = self.sim
         ndims = sim.ndims()
diff --git a/src/pairs/sim/properties.py b/src/pairs/sim/properties.py
index 1818511..7fce292 100644
--- a/src/pairs/sim/properties.py
+++ b/src/pairs/sim/properties.py
@@ -1,4 +1,4 @@
-from pairs.ir.block import pairs_block, pairs_device_block
+from pairs.ir.block import pairs_device_block, pairs_inline
 from pairs.ir.loops import ParticleFor
 from pairs.ir.memory import Malloc, Realloc
 from pairs.ir.properties import RegisterProperty, UpdateProperty
@@ -14,7 +14,7 @@ class PropertiesAlloc(Lowerable):
         self.sim = sim
         self.realloc = realloc
 
-    @pairs_block
+    @pairs_inline
     def lower(self):
         capacity = sum(self.sim.properties.capacities)
         for p in self.sim.properties.all():
diff --git a/src/pairs/sim/read_from_file.py b/src/pairs/sim/read_from_file.py
index d20a485..7f37b5f 100644
--- a/src/pairs/sim/read_from_file.py
+++ b/src/pairs/sim/read_from_file.py
@@ -1,4 +1,4 @@
-from pairs.ir.block import pairs_block
+from pairs.ir.block import pairs_inline
 from pairs.ir.functions import Call_Int
 from pairs.ir.properties import PropertyList
 from pairs.ir.types import Types
@@ -14,7 +14,7 @@ class ReadFromFile(Lowerable):
         self.grid = MutableGrid(sim, sim.ndims())
         self.grid_buffer = self.sim.add_static_array("grid_buffer", [self.sim.ndims() * 2], Types.Double)
 
-    @pairs_block
+    @pairs_inline
     def lower(self):
         self.sim.nlocal.set(Call_Int(self.sim, "pairs::read_particle_data",
                             [self.filename, self.grid_buffer, self.props, self.props.length()]))
diff --git a/src/pairs/sim/variables.py b/src/pairs/sim/variables.py
index f7b512f..18847b3 100644
--- a/src/pairs/sim/variables.py
+++ b/src/pairs/sim/variables.py
@@ -1,4 +1,4 @@
-from pairs.ir.block import pairs_block
+from pairs.ir.block import pairs_inline
 from pairs.ir.variables import VarDecl
 from pairs.sim.lowerable import Lowerable
 
@@ -7,7 +7,7 @@ class VariablesDecl(Lowerable):
     def __init__(self, sim):
         super().__init__(sim)
 
-    @pairs_block
+    @pairs_inline
     def lower(self):
         for v in self.sim.vars.all():
             VarDecl(self.sim, v)
diff --git a/src/pairs/sim/vtk.py b/src/pairs/sim/vtk.py
index ad09013..dc19a8e 100644
--- a/src/pairs/sim/vtk.py
+++ b/src/pairs/sim/vtk.py
@@ -1,5 +1,5 @@
 from pairs.ir.ast_node import ASTNode
-from pairs.ir.block import pairs_block
+from pairs.ir.block import pairs_inline
 from pairs.ir.functions import Call_Void
 from pairs.ir.lit import Lit 
 from pairs.sim.lowerable import Lowerable
@@ -11,7 +11,7 @@ class VTKWrite(Lowerable):
         self.filename = filename
         self.timestep = Lit.cvt(sim, timestep)
 
-    @pairs_block
+    @pairs_inline
     def lower(self):
         nlocal = self.sim.nlocal
         npbc = self.sim.pbc.npbc
-- 
GitLab