From 352cb06ec92a845d4fede6fb62b5cebebab8e9d4 Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Wed, 29 Nov 2023 21:21:37 +0100
Subject: [PATCH] Make device copies and status change in one function

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 runtime/pairs.cpp                    | 72 ++++++++++++++++++-----
 runtime/pairs.hpp                    | 87 +++++++++++++++++++---------
 runtime/thermo.hpp                   |  4 +-
 runtime/vtk.hpp                      |  4 +-
 src/pairs/analysis/devices.py        | 27 +++++++--
 src/pairs/analysis/modules.py        |  3 +-
 src/pairs/code_gen/cgen.py           | 66 +++++----------------
 src/pairs/ir/device.py               | 63 ++++----------------
 src/pairs/ir/kernel.py               | 61 ++++++++++++++-----
 src/pairs/ir/module.py               | 44 +++++++++++---
 src/pairs/sim/arrays.py              |  5 --
 src/pairs/sim/comm.py                |  8 +--
 src/pairs/transformations/devices.py | 37 ++++++------
 13 files changed, 277 insertions(+), 204 deletions(-)

diff --git a/runtime/pairs.cpp b/runtime/pairs.cpp
index cbd8a28..352653f 100644
--- a/runtime/pairs.cpp
+++ b/runtime/pairs.cpp
@@ -109,8 +109,9 @@ FeatureProperty &PairsSimulation::getFeaturePropertyByName(std::string name) {
     return *fp;
 }
 
-void PairsSimulation::copyArrayToDevice(Array &array) {
+void PairsSimulation::copyArrayToDevice(Array &array, bool write) {
     int array_id = array.getId();
+
     if(!array_flags->isDeviceFlagSet(array_id)) {
         if(array.isStatic()) {
             PAIRS_DEBUG("Copying static array %s to device\n", array.getName().c_str());
@@ -119,11 +120,18 @@ void PairsSimulation::copyArrayToDevice(Array &array) {
             PAIRS_DEBUG("Copying array %s to device\n", array.getName().c_str());
             pairs::copy_to_device(array.getHostPointer(), array.getDevicePointer(), array.getSize());
         }
+
+        array_flags->setDeviceFlag(array_id);
+    }
+
+    if(write) {
+        array_flags->clearHostFlag(array_id);
     }
 }
 
-void PairsSimulation::copyArrayToHost(Array &array) {
+void PairsSimulation::copyArrayToHost(Array &array, bool write) {
     int array_id = array.getId();
+
     if(!array_flags->isHostFlagSet(array_id)) {
         if(array.isStatic()) {
             PAIRS_DEBUG("Copying static array %s to host\n", array.getName().c_str());
@@ -132,34 +140,68 @@ void PairsSimulation::copyArrayToHost(Array &array) {
             PAIRS_DEBUG("Copying array %s to host\n", array.getName().c_str());
             pairs::copy_to_host(array.getDevicePointer(), array.getHostPointer(), array.getSize());
         }
+
+        array_flags->setHostFlag(array_id);
+    }
+
+    if(write) {
+        array_flags->clearDeviceFlag(array_id);
     }
 }
 
-void PairsSimulation::copyPropertyToDevice(Property &prop) {
-    if(!prop_flags->isDeviceFlagSet(prop.getId())) {
+void PairsSimulation::copyPropertyToDevice(Property &prop, bool write) {
+    int prop_id = prop.getId();
+
+    if(!prop_flags->isDeviceFlagSet(prop_id)) {
         PAIRS_DEBUG("Copying property %s to device\n", prop.getName().c_str());
         pairs::copy_to_device(prop.getHostPointer(), prop.getDevicePointer(), prop.getTotalSize());
+        prop_flags->setDeviceFlag(prop_id);
+    }
+
+    if(write) {
+        prop_flags->clearHostFlag(prop_id);
     }
 }
 
-void PairsSimulation::copyPropertyToHost(Property &prop) {
-    if(!prop_flags->isHostFlagSet(prop.getId())) {
+void PairsSimulation::copyPropertyToHost(Property &prop, bool write) {
+    int prop_id = prop.getId();
+
+    if(!prop_flags->isHostFlagSet(prop_id)) {
         PAIRS_DEBUG("Copying property %s to host\n", prop.getName().c_str());
         pairs::copy_to_host(prop.getDevicePointer(), prop.getHostPointer(), prop.getTotalSize());
+        prop_flags->setHostFlag(prop_id);
+    }
+
+    if(write) {
+        prop_flags->clearDeviceFlag(prop_id);
     }
 }
 
-void PairsSimulation::copyContactPropertyToDevice(ContactProperty &contact_prop) {
-    if(!contact_prop_flags->isDeviceFlagSet(contact_prop.getId())) {
+void PairsSimulation::copyContactPropertyToDevice(ContactProperty &contact_prop, bool write) {
+    int prop_id = contact_prop.getId();
+
+    if(!contact_prop_flags->isDeviceFlagSet(prop_id)) {
         PAIRS_DEBUG("Copying contact property %s to device\n", contact_prop.getName().c_str());
         pairs::copy_to_device(contact_prop.getHostPointer(), contact_prop.getDevicePointer(), contact_prop.getTotalSize());
+        contact_prop_flags->setDeviceFlag(prop_id);
+    }
+
+    if(write) {
+        contact_prop_flags->clearHostFlag(prop_id);
     }
 }
 
-void PairsSimulation::copyContactPropertyToHost(ContactProperty &contact_prop) {
+void PairsSimulation::copyContactPropertyToHost(ContactProperty &contact_prop, bool write) {
+    int prop_id = contact_prop.getId();
+
     if(!contact_prop_flags->isHostFlagSet(contact_prop.getId())) {
         PAIRS_DEBUG("Copying contact property %s to host\n", contact_prop.getName().c_str());
         pairs::copy_to_host(contact_prop.getDevicePointer(), contact_prop.getHostPointer(), contact_prop.getTotalSize());
+        contact_prop_flags->setHostFlag(prop_id);
+    }
+
+    if(write) {
+        contact_prop_flags->clearDeviceFlag(prop_id);
     }
 }
 
@@ -172,7 +214,7 @@ void PairsSimulation::communicateSizes(int dim, const int *send_sizes, int *recv
     auto nsend_id = getArrayByHostPointer(send_sizes).getId();
     auto nrecv_id = getArrayByHostPointer(recv_sizes).getId();
 
-    copyArrayToHost(nsend_id);
+    copyArrayToHost(nsend_id, false);
     array_flags->setHostFlag(nrecv_id);
     array_flags->clearDeviceFlag(nrecv_id);
     this->getDomainPartitioner()->communicateSizes(dim, send_sizes, recv_sizes);
@@ -191,11 +233,11 @@ void PairsSimulation::communicateData(
     auto nsend_id = getArrayByHostPointer(nsend).getId();
     auto nrecv_id = getArrayByHostPointer(nrecv).getId();
 
-    copyArrayToHost(send_buf_id);
-    copyArrayToHost(send_offsets_id);
-    copyArrayToHost(recv_offsets_id);
-    copyArrayToHost(nsend_id);
-    copyArrayToHost(nrecv_id);
+    copyArrayToHost(send_buf_id, false);
+    copyArrayToHost(send_offsets_id, false);
+    copyArrayToHost(recv_offsets_id, false);
+    copyArrayToHost(nsend_id, false);
+    copyArrayToHost(nrecv_id, false);
     array_flags->setHostFlag(recv_buf_id);
     array_flags->clearDeviceFlag(recv_buf_id);
     this->getDomainPartitioner()->communicateData(dim, elem_size, send_buf, send_offsets, nsend, recv_buf, recv_offsets, nrecv);
diff --git a/runtime/pairs.hpp b/runtime/pairs.hpp
index aa6be15..ec26c6a 100644
--- a/runtime/pairs.hpp
+++ b/runtime/pairs.hpp
@@ -122,45 +122,78 @@ public:
     void setArrayDeviceFlag(Array &array) { array_flags->setDeviceFlag(array.getId()); }
     void clearArrayDeviceFlag(array_t id) { clearArrayDeviceFlag(getArray(id)); }
     void clearArrayDeviceFlag(Array &array) { array_flags->clearDeviceFlag(array.getId()); }
-    void copyArrayToDevice(array_t id) { copyArrayToDevice(getArray(id)); }
-    void copyArrayToDevice(Array &array);
+    void copyArrayToDevice(array_t id, bool write) { copyArrayToDevice(getArray(id), write); }
+    void copyArrayToDevice(Array &array, bool write);
 
     void setArrayHostFlag(array_t id) { setArrayHostFlag(getArray(id)); }
     void setArrayHostFlag(Array &array) { array_flags->setHostFlag(array.getId()); }
     void clearArrayHostFlag(array_t id) { clearArrayHostFlag(getArray(id)); }
     void clearArrayHostFlag(Array &array) { array_flags->clearHostFlag(array.getId()); }
-    void copyArrayToHost(array_t id) { copyArrayToHost(getArray(id)); }
-    void copyArrayToHost(Array &array);
+    void copyArrayToHost(array_t id, bool write) { copyArrayToHost(getArray(id), write); }
+    void copyArrayToHost(Array &array, bool write);
 
     void setPropertyDeviceFlag(property_t id) { setPropertyDeviceFlag(getProperty(id)); }
     void setPropertyDeviceFlag(Property &prop) { prop_flags->setDeviceFlag(prop.getId()); }
     void clearPropertyDeviceFlag(property_t id) { clearPropertyDeviceFlag(getProperty(id)); }
     void clearPropertyDeviceFlag(Property &prop) { prop_flags->clearDeviceFlag(prop.getId()); }
-    void copyPropertyToDevice(property_t id) { copyPropertyToDevice(getProperty(id)); }
-    void copyPropertyToDevice(Property &prop);
+    void copyPropertyToDevice(property_t id, bool write) { copyPropertyToDevice(getProperty(id), write); }
+    void copyPropertyToDevice(Property &prop, bool write);
 
     void setPropertyHostFlag(property_t id) { setPropertyHostFlag(getProperty(id)); }
     void setPropertyHostFlag(Property &prop) { prop_flags->setHostFlag(prop.getId()); }
     void clearPropertyHostFlag(property_t id) { clearPropertyHostFlag(getProperty(id)); }
     void clearPropertyHostFlag(Property &prop) { prop_flags->clearHostFlag(prop.getId()); }
-    void copyPropertyToHost(property_t id) { copyPropertyToHost(getProperty(id)); }
-    void copyPropertyToHost(Property &prop);
-
-    void setContactPropertyDeviceFlag(property_t id) { setContactPropertyDeviceFlag(getContactProperty(id)); }
-    void setContactPropertyDeviceFlag(ContactProperty &prop) { contact_prop_flags->setDeviceFlag(prop.getId()); }
-    void clearContactPropertyDeviceFlag(property_t id) { clearContactPropertyDeviceFlag(getContactProperty(id)); }
-    void clearContactPropertyDeviceFlag(ContactProperty &prop) { contact_prop_flags->clearDeviceFlag(prop.getId()); }
-    void copyContactPropertyToDevice(property_t id) { copyContactPropertyToDevice(getContactProperty(id)); }
-    void copyContactPropertyToDevice(ContactProperty &prop);
-
-    void setContactPropertyHostFlag(property_t id) { setContactPropertyHostFlag(getContactProperty(id)); }
-    void setContactPropertyHostFlag(ContactProperty &prop) { contact_prop_flags->setHostFlag(prop.getId()); }
-    void clearContactPropertyHostFlag(property_t id) { clearContactPropertyHostFlag(getContactProperty(id)); }
-    void clearContactPropertyHostFlag(ContactProperty &prop) { contact_prop_flags->clearHostFlag(prop.getId()); }
-    void copyContactPropertyToHost(property_t id) { copyContactPropertyToHost(getContactProperty(id)); }
-    void copyContactPropertyToHost(ContactProperty &prop);
-
-    void copyFeaturePropertyToDevice(property_t id) { copyFeaturePropertyToDevice(getFeatureProperty(id)); }
+    void copyPropertyToHost(property_t id, bool write) { copyPropertyToHost(getProperty(id), write); }
+    void copyPropertyToHost(Property &prop, bool write);
+
+    void setContactPropertyDeviceFlag(property_t id) {
+        setContactPropertyDeviceFlag(getContactProperty(id));
+    }
+
+    void setContactPropertyDeviceFlag(ContactProperty &prop) {
+        contact_prop_flags->setDeviceFlag(prop.getId());
+    }
+
+    void clearContactPropertyDeviceFlag(property_t id) {
+        clearContactPropertyDeviceFlag(getContactProperty(id));
+    }
+
+    void clearContactPropertyDeviceFlag(ContactProperty &prop) {
+        contact_prop_flags->clearDeviceFlag(prop.getId());
+    }
+
+    void copyContactPropertyToDevice(property_t id, bool write) {
+        copyContactPropertyToDevice(getContactProperty(id), write);
+    }
+
+    void copyContactPropertyToDevice(ContactProperty &prop, bool write);
+
+    void setContactPropertyHostFlag(property_t id) {
+        setContactPropertyHostFlag(getContactProperty(id));
+    }
+
+    void setContactPropertyHostFlag(ContactProperty &prop) {
+        contact_prop_flags->setHostFlag(prop.getId());
+    }
+
+    void clearContactPropertyHostFlag(property_t id) {
+        clearContactPropertyHostFlag(getContactProperty(id));
+    }
+
+    void clearContactPropertyHostFlag(ContactProperty &prop) {
+        contact_prop_flags->clearHostFlag(prop.getId());
+    }
+
+    void copyContactPropertyToHost(property_t id, bool write) {
+        copyContactPropertyToHost(getContactProperty(id), write);
+    }
+
+    void copyContactPropertyToHost(ContactProperty &prop, bool write);
+
+    void copyFeaturePropertyToDevice(property_t id) {
+        copyFeaturePropertyToDevice(getFeatureProperty(id));
+    }
+
     void copyFeaturePropertyToDevice(FeatureProperty &feature_prop);
 
     void communicateSizes(int dim, const int *send_sizes, int *recv_sizes);
@@ -240,7 +273,7 @@ void PairsSimulation::reallocArray(array_t id, T_ptr **h_ptr, T_ptr **d_ptr, siz
     *h_ptr = (T_ptr *) new_h_ptr;
     *d_ptr = (T_ptr *) new_d_ptr;
     if(array_flags->isDeviceFlagSet(id)) {
-        copyArrayToDevice(id);
+        copyArrayToDevice(id, false);
     }
 }
 
@@ -310,7 +343,7 @@ void PairsSimulation::reallocProperty(property_t id, T_ptr **h_ptr, T_ptr **d_pt
     *h_ptr = (T_ptr *) new_h_ptr;
     *d_ptr = (T_ptr *) new_d_ptr;
     if(prop_flags->isDeviceFlagSet(id)) {
-        copyPropertyToDevice(id);
+        copyPropertyToDevice(id, false);
     }
 }
 
@@ -380,7 +413,7 @@ void PairsSimulation::reallocContactProperty(property_t id, T_ptr **h_ptr, T_ptr
     *h_ptr = (T_ptr *) new_h_ptr;
     *d_ptr = (T_ptr *) new_d_ptr;
     if(contact_prop_flags->isDeviceFlagSet(id)) {
-        copyContactPropertyToDevice(id);
+        copyContactPropertyToDevice(id, false);
     }
 }
 
diff --git a/runtime/thermo.hpp b/runtime/thermo.hpp
index 8f00fd2..993e9c8 100644
--- a/runtime/thermo.hpp
+++ b/runtime/thermo.hpp
@@ -26,8 +26,8 @@ double compute_thermo(PairsSimulation *ps, int nlocal, double xprd, double yprd,
     //const double e_scale = 0.5;
     double t = 0.0, p;
 
-    ps->copyPropertyToHost(masses);
-    ps->copyPropertyToHost(velocities);
+    ps->copyPropertyToHost(masses, false);
+    ps->copyPropertyToHost(velocities, false);
 
     for(int i = 0; i < nlocal; i++) {
         t += masses(i) * (  velocities(i, 0) * velocities(i, 0) +
diff --git a/runtime/vtk.hpp b/runtime/vtk.hpp
index 3e69500..527f074 100644
--- a/runtime/vtk.hpp
+++ b/runtime/vtk.hpp
@@ -29,8 +29,8 @@ void vtk_write_data(PairsSimulation *ps, const char *filename, int start, int en
     filename_oss << timestep << ".vtk";
     std::ofstream out_file(filename_oss.str());
 
-    ps->copyPropertyToHost(masses);
-    ps->copyPropertyToHost(positions);
+    ps->copyPropertyToHost(masses, false);
+    ps->copyPropertyToHost(positions, false);
 
     for(int i = start; i < end; i++) {
         if(flags(i) & FLAGS_INFINITE) {
diff --git a/src/pairs/analysis/devices.py b/src/pairs/analysis/devices.py
index 154683a..07f3a17 100644
--- a/src/pairs/analysis/devices.py
+++ b/src/pairs/analysis/devices.py
@@ -48,10 +48,11 @@ class FetchKernelReferences(Visitor):
         self.writing = writing_state
 
     def visit_Assign(self, ast_node):
+        self.writing = False
+        self.visit(ast_node._src)
         self.writing = True
         self.visit(ast_node._dest)
         self.writing = False
-        self.visit(ast_node._src)
 
     def visit_AtomicAdd(self, ast_node):
         self.writing = True
@@ -75,11 +76,25 @@ class FetchKernelReferences(Visitor):
         self.visit_children(ast_node)
         self.kernel_stack.pop()
 
-        ast_node.add_array_access([a for a in self.kernel_used_array_accesses[kernel_id] if a not in self.kernel_decls[kernel_id]])
-        ast_node.add_scalar_op([b for b in self.kernel_used_scalar_ops[kernel_id] if b not in self.kernel_decls[kernel_id] and not b.in_place])
-        ast_node.add_vector_op([b for b in self.kernel_used_vector_ops[kernel_id] if b not in self.kernel_decls[kernel_id] and not b.in_place])
-        ast_node.add_matrix_op([b for b in self.kernel_used_matrix_ops[kernel_id] if b not in self.kernel_decls[kernel_id] and not b.in_place])
-        ast_node.add_quaternion_op([b for b in self.kernel_used_quat_ops[kernel_id] if b not in self.kernel_decls[kernel_id] and not b.in_place])
+        ast_node.add_array_access(
+            [a for a in self.kernel_used_array_accesses[kernel_id] \
+             if a not in self.kernel_decls[kernel_id]])
+
+        ast_node.add_scalar_op(
+            [b for b in self.kernel_used_scalar_ops[kernel_id] \
+            if b not in self.kernel_decls[kernel_id] and not b.in_place])
+
+        ast_node.add_vector_op(
+            [b for b in self.kernel_used_vector_ops[kernel_id] \
+            if b not in self.kernel_decls[kernel_id] and not b.in_place])
+
+        ast_node.add_matrix_op(
+            [b for b in self.kernel_used_matrix_ops[kernel_id] \
+            if b not in self.kernel_decls[kernel_id] and not b.in_place])
+
+        ast_node.add_quaternion_op(
+            [b for b in self.kernel_used_quat_ops[kernel_id] \
+            if b not in self.kernel_decls[kernel_id] and not b.in_place])
 
     def visit_PropertyAccess(self, ast_node):
         # Visit property and save current writing state
diff --git a/src/pairs/analysis/modules.py b/src/pairs/analysis/modules.py
index 4a8ebad..5b97619 100644
--- a/src/pairs/analysis/modules.py
+++ b/src/pairs/analysis/modules.py
@@ -18,10 +18,11 @@ class FetchModulesReferences(Visitor):
         self.writing = writing_state
 
     def visit_Assign(self, ast_node):
+        self.writing = False
+        self.visit(ast_node._src)
         self.writing = True
         self.visit(ast_node._dest)
         self.writing = False
-        self.visit(ast_node._src)
 
     def visit_AtomicAdd(self, ast_node):
         self.writing = True
diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index dba78f3..cf16325 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -8,7 +8,7 @@ from pairs.ir.cast import Cast
 from pairs.ir.contexts import Contexts
 from pairs.ir.declaration import Decl
 from pairs.ir.scalars import ScalarOp
-from pairs.ir.device import ClearArrayFlag, ClearPropertyFlag, CopyArray, CopyProperty, CopyVar, DeviceStaticRef, SetArrayFlag, SetPropertyFlag, HostRef
+from pairs.ir.device import CopyArray, CopyContactProperty, CopyProperty, CopyVar, DeviceStaticRef, HostRef
 from pairs.ir.features import FeatureProperty, FeaturePropertyAccess, RegisterFeatureProperty
 from pairs.ir.functions import Call
 from pairs.ir.kernel import KernelLaunch
@@ -456,64 +456,28 @@ class CGen:
         if isinstance(ast_node, CopyArray):
             array_id = ast_node.array.id()
             array_name = ast_node.array.name()
+            ctx_suffix = "Device" if ast_node.context() == Contexts.Device else "Host"
+            write = "true" if ast_node.write else "false"
+            self.print(f"pairs->copyArrayTo{ctx_suffix}({array_id}, {write}); // {array_name}")
 
-            if ast_node.context() == Contexts.Device:
-                self.print(f"pairs->copyArrayToDevice({array_id}); // {array_name}")
-            else:
-                self.print(f"pairs->copyArrayToHost({array_id}); // {array_name}")
+        if isinstance(ast_node, CopyContactProperty):
+            prop_id = ast_node.contact_prop.id()
+            prop_name = ast_node.contact_prop.name()
+            write = "true" if ast_node.write else "false"
+            ctx_suffix = "Device" if ast_node.context() == Contexts.Device else "Host"
+            self.print(f"pairs->copyContactPropertyTo{ctx_suffix}({prop_id}, {write}); // {prop_name}")
 
         if isinstance(ast_node, CopyProperty):
             prop_id = ast_node.prop.id()
             prop_name = ast_node.prop.name()
-
-            if ast_node.context() == Contexts.Device:
-                self.print(f"pairs->copyPropertyToDevice({prop_id}); // {prop_name}")
-            else:
-                self.print(f"pairs->copyPropertyToHost({prop_id}); // {prop_name}")
+            write = "true" if ast_node.write else "false"
+            ctx_suffix = "Device" if ast_node.context() == Contexts.Device else "Host"
+            self.print(f"pairs->copyPropertyTo{ctx_suffix}({prop_id}, {write}); // {prop_name}")
 
         if isinstance(ast_node, CopyVar):
             var_name = ast_node.variable.name()
-
-            if ast_node.context() == Contexts.Device:
-                self.print(f"rv_{var_name}.copyToDevice();")
-            else:
-                self.print(f"rv_{var_name}.copyToHost();")
-
-        if isinstance(ast_node, ClearArrayFlag):
-            array_id = ast_node.array.id()
-            array_name = ast_node.array.name()
-
-            if ast_node.context() == Contexts.Device:
-                self.print(f"pairs->clearArrayDeviceFlag({array_id}); // {array_name}")
-            else:
-                self.print(f"pairs->clearArrayHostFlag({array_id}); // {array_name}")
-
-        if isinstance(ast_node, ClearPropertyFlag):
-            prop_id = ast_node.prop.id()
-            prop_name = ast_node.prop.name()
-
-            if ast_node.context() == Contexts.Device:
-                self.print(f"pairs->clearPropertyDeviceFlag({prop_id}); // {prop_name}")
-            else:
-                self.print(f"pairs->clearPropertyHostFlag({prop_id}); // {prop_name}")
-
-        if isinstance(ast_node, SetArrayFlag):
-            array_id = ast_node.array.id()
-            array_name = ast_node.array.name()
-
-            if ast_node.context() == Contexts.Device:
-                self.print(f"pairs->setArrayDeviceFlag({array_id}); // {array_name}")
-            else:
-                self.print(f"pairs->setArrayHostFlag({array_id}); // {array_name}")
-
-        if isinstance(ast_node, SetPropertyFlag):
-            prop_id = ast_node.prop.id()
-            prop_name = ast_node.prop.name()
-
-            if ast_node.context() == Contexts.Device:
-                self.print(f"pairs->setPropertyDeviceFlag({prop_id}); // {prop_name}")
-            else:
-                self.print(f"pairs->setPropertyHostFlag({prop_id}); // {prop_name}")
+            ctx_suffix = "Device" if ast_node.context() == Contexts.Device else "Host"
+            self.print(f"rv_{var_name}.copyTo{ctx_suffix}();")
 
         if isinstance(ast_node, For):
             iterator = self.generate_expression(ast_node.iterator)
diff --git a/src/pairs/ir/device.py b/src/pairs/ir/device.py
index cef8e0f..96af505 100644
--- a/src/pairs/ir/device.py
+++ b/src/pairs/ir/device.py
@@ -30,10 +30,11 @@ class DeviceStaticRef(ASTNode):
 
 
 class CopyArray(ASTNode):
-    def __init__(self, sim, array, ctx):
+    def __init__(self, sim, array, ctx, write):
         super().__init__(sim)
         self.array = array
         self.ctx = ctx
+        self.write = write
         self.sim.add_statement(self)
 
     def context(self):
@@ -44,10 +45,11 @@ class CopyArray(ASTNode):
 
 
 class CopyProperty(ASTNode):
-    def __init__(self, sim, prop, ctx):
+    def __init__(self, sim, prop, ctx, write):
         super().__init__(sim)
         self.prop = prop
         self.ctx = ctx
+        self.write = write
         self.sim.add_statement(self)
 
     def context(self):
@@ -57,39 +59,12 @@ class CopyProperty(ASTNode):
         return [self.prop]
 
 
-class CopyVar(ASTNode):
-    def __init__(self, sim, variable, ctx):
+class CopyContactProperty(ASTNode):
+    def __init__(self, sim, prop, ctx, write):
         super().__init__(sim)
-        self.variable = variable
-        self.ctx = ctx
-        self.sim.add_statement(self)
-
-    def context(self):
-        return self.ctx
-
-    def children(self):
-        return [self.variable]
-
-
-class ClearArrayFlag(ASTNode):
-    def __init__(self, sim, array, ctx):
-        super().__init__(sim)
-        self.array = array
-        self.ctx = ctx
-        self.sim.add_statement(self)
-
-    def context(self):
-        return self.ctx
-
-    def children(self):
-        return [self.array]
-
-
-class ClearPropertyFlag(ASTNode):
-    def __init__(self, sim, prop, ctx):
-        super().__init__(sim)
-        self.prop = prop
+        self.contact_prop = prop
         self.ctx = ctx
+        self.write = write
         self.sim.add_statement(self)
 
     def context(self):
@@ -99,24 +74,10 @@ class ClearPropertyFlag(ASTNode):
         return [self.prop]
 
 
-class SetArrayFlag(ASTNode):
-    def __init__(self, sim, array, ctx):
-        super().__init__(sim)
-        self.array = array
-        self.ctx = ctx
-        self.sim.add_statement(self)
-
-    def context(self):
-        return self.ctx
-
-    def children(self):
-        return [self.array]
-
-
-class SetPropertyFlag(ASTNode):
-    def __init__(self, sim, prop, ctx):
+class CopyVar(ASTNode):
+    def __init__(self, sim, variable, ctx):
         super().__init__(sim)
-        self.prop = prop
+        self.variable = variable
         self.ctx = ctx
         self.sim.add_statement(self)
 
@@ -124,4 +85,4 @@ class SetPropertyFlag(ASTNode):
         return self.ctx
 
     def children(self):
-        return [self.prop]
+        return [self.variable]
diff --git a/src/pairs/ir/kernel.py b/src/pairs/ir/kernel.py
index a29773d..12ed846 100644
--- a/src/pairs/ir/kernel.py
+++ b/src/pairs/ir/kernel.py
@@ -85,66 +85,99 @@ class Kernel(ASTNode):
     def add_array(self, array, write=False):
         array_list = array if isinstance(array, list) else [array]
         character = 'w' if write else 'r'
+
         for a in array_list:
-            assert isinstance(a, Array), "Kernel.add_array(): Element is not of type Array."
-            self._arrays[a] = character if a not in self._arrays else self._arrays[a] + character
+            assert isinstance(a, Array), \
+                "Kernel.add_array(): Element is not of type Array."
+
+            self._arrays[a] = character if a not in self._arrays else \
+                              self._arrays[a] + character
 
     def add_variable(self, variable, write=False):
         variable_list = variable if isinstance(variable, list) else [variable]
         character = 'w' if write else 'r'
+
         for v in variable_list:
             if not v.temporary():
-                assert isinstance(v, Var), "Kernel.add_variable(): Element is not of type Var."
-                self._variables[v] = character if v not in self._variables else self._variables[v] + character
+                assert isinstance(v, Var), \
+                    "Kernel.add_variable(): Element is not of type Var."
+
+                self._variables[v] = character if v not in self._variables else \
+                                     self._variables[v] + character
 
     def add_property(self, prop, write=False):
         prop_list = prop if isinstance(prop, list) else [prop]
         character = 'w' if write else 'r'
+
         for p in prop_list:
-            assert isinstance(p, Property), "Kernel.add_property(): Element is not of type Property."
-            self._properties[p] = character if p not in self._properties else self._properties[p] + character
+            assert isinstance(p, Property), \
+                "Kernel.add_property(): Element is not of type Property."
+
+            self._properties[p] = character if p not in self._properties else \
+                                  self._properties[p] + character
 
     def add_contact_property(self, contact_prop, write=False):
         contact_prop_list = contact_prop if isinstance(contact_prop, list) else [contact_prop]
         character = 'w' if write else 'r'
+
         for cp in contact_prop_list:
-            assert isinstance(cp, ContactProperty), "Kernel.add_contact_property(): Element is not of type ContactProperty."
-            self._contact_properties[cp] = character if cp not in self._contact_properties else self._contact_properties[cp] + character
+            assert isinstance(cp, ContactProperty), \
+                "Kernel.add_contact_property(): Element is not of type ContactProperty."
+
+            self._contact_properties[cp] = character if cp not in self._contact_properties else \
+                                           self._contact_properties[cp] + character
 
     def add_feature_property(self, feature_prop):
         feature_prop_list = feature_prop if isinstance(feature_prop, list) else [feature_prop]
+
         for fp in feature_prop_list:
-            assert isinstance(fp, FeatureProperty), "Kernel.add_feature_property(): Element is not of type FeatureProperty."
+            assert isinstance(fp, FeatureProperty), \
+                "Kernel.add_feature_property(): Element is not of type FeatureProperty."
+
             self._feature_properties[fp] = 'r'
 
     def add_array_access(self, array_access):
         array_access_list = array_access if isinstance(array_access, list) else [array_access]
         for a in array_access_list:
-            assert isinstance(a, ArrayAccess), "Kernel.add_array_access(): Element is not of type ArrayAccess."
+            assert isinstance(a, ArrayAccess), \
+                "Kernel.add_array_access(): Element is not of type ArrayAccess."
+
             self._array_accesses.add(a)
 
     def add_scalar_op(self, scalar_op):
         scalar_op_list = scalar_op if isinstance(scalar_op, list) else [scalar_op]
+
         for b in scalar_op_list:
-            assert isinstance(b, ScalarOp), "Kernel.add_scalar_op(): Element is not of type ScalarOp."
+            assert isinstance(b, ScalarOp), \
+                "Kernel.add_scalar_op(): Element is not of type ScalarOp."
+
             self._scalar_ops.append(b)
 
     def add_vector_op(self, vector_op):
         vector_op_list = vector_op if isinstance(vector_op, list) else [vector_op]
+
         for b in vector_op_list:
-            assert isinstance(b, VectorOp), "Kernel.add_vector_op(): Element is not of type VectorOp."
+            assert isinstance(b, VectorOp), \
+                "Kernel.add_vector_op(): Element is not of type VectorOp."
+
             self._vector_ops.append(b)
 
     def add_matrix_op(self, matrix_op):
         matrix_op_list = matrix_op if isinstance(matrix_op, list) else [matrix_op]
+
         for b in matrix_op_list:
-            assert isinstance(b, MatrixOp), "Kernel.add_matrix_op(): Element is not of type MatrixOp."
+            assert isinstance(b, MatrixOp), \
+                "Kernel.add_matrix_op(): Element is not of type MatrixOp."
+
             self._matrix_ops.append(b)
 
     def add_quaternion_op(self, quat_op):
         quat_op_list = quat_op if isinstance(quat_op, list) else [quat_op]
+
         for b in quat_op_list:
-            assert isinstance(b, QuaternionOp), "Kernel.add_quaternion_op(): Element is not of type QuaternionOp."
+            assert isinstance(b, QuaternionOp), \
+                "Kernel.add_quaternion_op(): Element is not of type QuaternionOp."
+
             self._quat_ops.append(b)
 
     def children(self):
diff --git a/src/pairs/ir/module.py b/src/pairs/ir/module.py
index 71bd663..abbc401 100644
--- a/src/pairs/ir/module.py
+++ b/src/pairs/ir/module.py
@@ -86,6 +86,13 @@ class Module(ASTNode):
     def write_properties(self):
         return {p for p in self._properties if 'w' in self._properties[p]}
 
+    def contact_properties_to_synchronize(self):
+        #return {cp for cp in self._contact_properties if self._contact_properties[cp][0] == 'r'}
+        return {cp for cp in self._contact_properties}
+
+    def write_contact_properties(self):
+        return {cp for cp in self._contact_properties if 'w' in self._contact_properties[cp]}
+
     def arrays_to_synchronize(self):
         #return {a for a in self._arrays if a.sync() and self._arrays[a][0] == 'r'}
         return {a for a in self._arrays if a.sync()}
@@ -96,35 +103,54 @@ class Module(ASTNode):
     def add_array(self, array, write=False):
         array_list = array if isinstance(array, list) else [array]
         character = 'w' if write else 'r'
+
         for a in array_list:
-            assert isinstance(a, Array), "Module.add_array(): given element is not of type Array!"
-            self._arrays[a] = character if a not in self._arrays else self._arrays[a] + character
+            assert isinstance(a, Array), \
+                "Module.add_array(): given element is not of type Array!"
+
+            self._arrays[a] = character if a not in self._arrays else \
+                              self._arrays[a] + character
 
     def add_variable(self, variable, write=False):
         variable_list = variable if isinstance(variable, list) else [variable]
         character = 'w' if write else 'r'
+
         for v in variable_list:
-            assert isinstance(v, Var), "Module.add_variable(): given element is not of type Var!"
-            self._variables[v] = character if v not in self._variables else self._variables[v] + character
+            assert isinstance(v, Var), \
+                "Module.add_variable(): given element is not of type Var!"
+
+            self._variables[v] = character if v not in self._variables else \
+                                 self._variables[v] + character
 
     def add_property(self, prop, write=False):
         prop_list = prop if isinstance(prop, list) else [prop]
         character = 'w' if write else 'r'
+
         for p in prop_list:
-            assert isinstance(p, Property), "Module.add_property(): given element is not of type Property!"
-            self._properties[p] = character if p not in self._properties else self._properties[p] + character
+            assert isinstance(p, Property), \
+                "Module.add_property(): given element is not of type Property!"
+
+            self._properties[p] = character if p not in self._properties else \
+                                  self._properties[p] + character
 
     def add_contact_property(self, contact_prop, write=False):
         contact_prop_list = contact_prop if isinstance(contact_prop, list) else [contact_prop]
         character = 'w' if write else 'r'
+
         for cp in contact_prop_list:
-            assert isinstance(cp, ContactProperty), "Module.add_contact_property(): given element is not of type ContactProperty!"
-            self._contact_properties[cp] = character if cp not in self._contact_properties else self._contact_properties[cp] + character
+            assert isinstance(cp, ContactProperty), \
+                "Module.add_contact_property(): given element is not of type ContactProperty!"
+
+            self._contact_properties[cp] = character if cp not in self._contact_properties else \
+                                           self._contact_properties[cp] + character
 
     def add_feature_property(self, feature_prop):
         feature_prop_list = feature_prop if isinstance(feature_prop, list) else [feature_prop]
+
         for fp in feature_prop_list:
-            assert isinstance(fp, FeatureProperty), "Module.add_feature_property(): given element is not of type FeatureProperty!"
+            assert isinstance(fp, FeatureProperty), \
+                "Module.add_feature_property(): given element is not of type FeatureProperty!"
+
             self._feature_properties[fp] = 'r'
 
     def add_host_reference(self, elem):
diff --git a/src/pairs/sim/arrays.py b/src/pairs/sim/arrays.py
index d3fa32d..7ea490b 100644
--- a/src/pairs/sim/arrays.py
+++ b/src/pairs/sim/arrays.py
@@ -1,6 +1,5 @@
 from pairs.ir.block import pairs_inline
 from pairs.ir.contexts import Contexts
-from pairs.ir.device import ClearArrayFlag
 from pairs.ir.memory import Malloc
 from pairs.ir.arrays import DeclareStaticArray, RegisterArray
 from pairs.sim.lowerable import FinalLowerable
@@ -17,7 +16,3 @@ class DeclareArrays(FinalLowerable):
                 DeclareStaticArray(self.sim, a)
 
             RegisterArray(self.sim, a, a.alloc_size())
-
-            if not a.sync():
-                ClearArrayFlag(self.sim, self.sim.resizes, Contexts.Host)
-                ClearArrayFlag(self.sim, self.sim.resizes, Contexts.Device)
diff --git a/src/pairs/sim/comm.py b/src/pairs/sim/comm.py
index c13c349..5152eef 100644
--- a/src/pairs/sim/comm.py
+++ b/src/pairs/sim/comm.py
@@ -17,10 +17,10 @@ class Comm:
         self.sim = sim
         self.dom_part = dom_part
         self.nsend_all      = sim.add_var('nsend_all', Types.Int32)
-        self.send_capacity  = sim.add_var('send_capacity', Types.Int32, 100000)
-        self.recv_capacity  = sim.add_var('recv_capacity', Types.Int32, 100000)
-        self.elem_capacity  = sim.add_var('elem_capacity', Types.Int32, 10)
-        self.neigh_capacity = sim.add_var('neigh_capacity', Types.Int32, 6)
+        self.send_capacity  = sim.add_var('send_capacity', Types.Int32, 200000)
+        self.recv_capacity  = sim.add_var('recv_capacity', Types.Int32, 200000)
+        self.elem_capacity  = sim.add_var('elem_capacity', Types.Int32, 40)
+        self.neigh_capacity = sim.add_var('neigh_capacity', Types.Int32, 10)
         self.nsend          = sim.add_array('nsend', [self.neigh_capacity], Types.Int32)
         self.send_offsets   = sim.add_array('send_offsets', [self.neigh_capacity], Types.Int32)
         self.send_buffer    = sim.add_array('send_buffer', [self.send_capacity, self.elem_capacity], Types.Real)
diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py
index c2d02fc..45764a5 100644
--- a/src/pairs/transformations/devices.py
+++ b/src/pairs/transformations/devices.py
@@ -4,7 +4,7 @@ from pairs.ir.block import Block
 from pairs.ir.branches import Filter
 from pairs.ir.cast import Cast
 from pairs.ir.contexts import Contexts
-from pairs.ir.device import ClearArrayFlag, ClearPropertyFlag, CopyArray, CopyProperty, CopyVar, DeviceStaticRef, SetArrayFlag, SetPropertyFlag, HostRef
+from pairs.ir.device import CopyArray, CopyContactProperty, CopyProperty, CopyVar, DeviceStaticRef, HostRef
 from pairs.ir.kernel import Kernel, KernelLaunch
 from pairs.ir.lit import Lit
 from pairs.ir.loops import For
@@ -34,34 +34,34 @@ class AddDeviceCopies(Mutator):
                     copy_context = Contexts.Device if s.module.run_on_device else Contexts.Host
                     clear_context = Contexts.Host if s.module.run_on_device else Contexts.Device
 
-                    for a in s.module.arrays_to_synchronize():
-                        new_stmts += [CopyArray(s.sim, a, copy_context)]
+                    for array in s.module.arrays_to_synchronize():
+                        write = array in s.module.write_arrays()
+                        new_stmts += [CopyArray(s.sim, array, copy_context, write)]
 
-                    for p in s.module.properties_to_synchronize():
-                        new_stmts += [CopyProperty(s.sim, p, copy_context)]
+                    for prop in s.module.properties_to_synchronize():
+                        write = prop in s.module.write_properties()
+                        new_stmts += [CopyProperty(s.sim, prop, copy_context, write)]
 
-                    for a in s.module.write_arrays():
-                        new_stmts += [SetArrayFlag(s.sim, a, copy_context), ClearArrayFlag(s.sim, a, clear_context)]
-
-                    for p in s.module.write_properties():
-                        new_stmts += [SetPropertyFlag(s.sim, p, copy_context), ClearPropertyFlag(s.sim, p, clear_context)]
+                    for contact_prop in s.module.contact_properties_to_synchronize():
+                        write = prop in s.module.write_contact_properties()
+                        new_stmts += [CopyContactProperty(s.sim, contact_prop, copy_context, write)]
 
                     if self.module_resizes[s.module] and s.module.run_on_device:
-                        new_stmts += [CopyArray(s.sim, s.sim.resizes, Contexts.Device)]
+                        new_stmts += [CopyArray(s.sim, s.sim.resizes, Contexts.Device, False)]
 
                     if s.module.run_on_device:
-                        for v in s.module.variables_to_synchronize():
-                            new_stmts += [CopyVar(s.sim, v, Contexts.Device)]
+                        for var in s.module.variables_to_synchronize():
+                            new_stmts += [CopyVar(s.sim, var, Contexts.Device)]
 
                 new_stmts.append(s)
 
                 if isinstance(s, ModuleCall):
                     if s.module.run_on_device:
-                        for v in s.module.variables_to_synchronize():
-                            new_stmts += [CopyVar(s.sim, v, Contexts.Host)]
+                        for var in s.module.variables_to_synchronize():
+                            new_stmts += [CopyVar(s.sim, var, Contexts.Host)]
 
                         if self.module_resizes[s.module]:
-                            new_stmts += [CopyArray(s.sim, s.sim.resizes, Contexts.Host)]
+                            new_stmts += [CopyArray(s.sim, s.sim.resizes, Contexts.Host, False)]
 
         ast_node.stmts = new_stmts
         return ast_node
@@ -83,7 +83,10 @@ class AddDeviceKernels(Mutator):
                         kernel_name = f"{ast_node.name}_kernel{kernel_id}"
                         kernel = ast_node.sim.find_kernel_by_name(kernel_name)
                         if kernel is None:
-                            kernel_body = Filter(ast_node.sim, ScalarOp.inline(s.iterator < s.max.copy(True)), s.block)
+                            kernel_body = Filter(ast_node.sim,
+                                                 ScalarOp.inline(s.iterator < s.max.copy(True)),
+                                                 s.block)
+
                             kernel = Kernel(ast_node.sim, kernel_name, kernel_body, s.iterator)
                             kernel_id += 1
 
-- 
GitLab