From 352cb06ec92a845d4fede6fb62b5cebebab8e9d4 Mon Sep 17 00:00:00 2001 From: Rafael Ravedutti <rafaelravedutti@gmail.com> Date: Wed, 29 Nov 2023 21:21:37 +0100 Subject: [PATCH] Make device copies and status change in one function Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com> --- runtime/pairs.cpp | 72 ++++++++++++++++++----- runtime/pairs.hpp | 87 +++++++++++++++++++--------- runtime/thermo.hpp | 4 +- runtime/vtk.hpp | 4 +- src/pairs/analysis/devices.py | 27 +++++++-- src/pairs/analysis/modules.py | 3 +- src/pairs/code_gen/cgen.py | 66 +++++---------------- src/pairs/ir/device.py | 63 ++++---------------- src/pairs/ir/kernel.py | 61 ++++++++++++++----- src/pairs/ir/module.py | 44 +++++++++++--- src/pairs/sim/arrays.py | 5 -- src/pairs/sim/comm.py | 8 +-- src/pairs/transformations/devices.py | 37 ++++++------ 13 files changed, 277 insertions(+), 204 deletions(-) diff --git a/runtime/pairs.cpp b/runtime/pairs.cpp index cbd8a28..352653f 100644 --- a/runtime/pairs.cpp +++ b/runtime/pairs.cpp @@ -109,8 +109,9 @@ FeatureProperty &PairsSimulation::getFeaturePropertyByName(std::string name) { return *fp; } -void PairsSimulation::copyArrayToDevice(Array &array) { +void PairsSimulation::copyArrayToDevice(Array &array, bool write) { int array_id = array.getId(); + if(!array_flags->isDeviceFlagSet(array_id)) { if(array.isStatic()) { PAIRS_DEBUG("Copying static array %s to device\n", array.getName().c_str()); @@ -119,11 +120,18 @@ void PairsSimulation::copyArrayToDevice(Array &array) { PAIRS_DEBUG("Copying array %s to device\n", array.getName().c_str()); pairs::copy_to_device(array.getHostPointer(), array.getDevicePointer(), array.getSize()); } + + array_flags->setDeviceFlag(array_id); + } + + if(write) { + array_flags->clearHostFlag(array_id); } } -void PairsSimulation::copyArrayToHost(Array &array) { +void PairsSimulation::copyArrayToHost(Array &array, bool write) { int array_id = array.getId(); + if(!array_flags->isHostFlagSet(array_id)) { if(array.isStatic()) { PAIRS_DEBUG("Copying static array %s to host\n", array.getName().c_str()); @@ -132,34 +140,68 @@ void PairsSimulation::copyArrayToHost(Array &array) { PAIRS_DEBUG("Copying array %s to host\n", array.getName().c_str()); pairs::copy_to_host(array.getDevicePointer(), array.getHostPointer(), array.getSize()); } + + array_flags->setHostFlag(array_id); + } + + if(write) { + array_flags->clearDeviceFlag(array_id); } } -void PairsSimulation::copyPropertyToDevice(Property &prop) { - if(!prop_flags->isDeviceFlagSet(prop.getId())) { +void PairsSimulation::copyPropertyToDevice(Property &prop, bool write) { + int prop_id = prop.getId(); + + if(!prop_flags->isDeviceFlagSet(prop_id)) { PAIRS_DEBUG("Copying property %s to device\n", prop.getName().c_str()); pairs::copy_to_device(prop.getHostPointer(), prop.getDevicePointer(), prop.getTotalSize()); + prop_flags->setDeviceFlag(prop_id); + } + + if(write) { + prop_flags->clearHostFlag(prop_id); } } -void PairsSimulation::copyPropertyToHost(Property &prop) { - if(!prop_flags->isHostFlagSet(prop.getId())) { +void PairsSimulation::copyPropertyToHost(Property &prop, bool write) { + int prop_id = prop.getId(); + + if(!prop_flags->isHostFlagSet(prop_id)) { PAIRS_DEBUG("Copying property %s to host\n", prop.getName().c_str()); pairs::copy_to_host(prop.getDevicePointer(), prop.getHostPointer(), prop.getTotalSize()); + prop_flags->setHostFlag(prop_id); + } + + if(write) { + prop_flags->clearDeviceFlag(prop_id); } } -void PairsSimulation::copyContactPropertyToDevice(ContactProperty &contact_prop) { - if(!contact_prop_flags->isDeviceFlagSet(contact_prop.getId())) { +void PairsSimulation::copyContactPropertyToDevice(ContactProperty &contact_prop, bool write) { + int prop_id = contact_prop.getId(); + + if(!contact_prop_flags->isDeviceFlagSet(prop_id)) { PAIRS_DEBUG("Copying contact property %s to device\n", contact_prop.getName().c_str()); pairs::copy_to_device(contact_prop.getHostPointer(), contact_prop.getDevicePointer(), contact_prop.getTotalSize()); + contact_prop_flags->setDeviceFlag(prop_id); + } + + if(write) { + contact_prop_flags->clearHostFlag(prop_id); } } -void PairsSimulation::copyContactPropertyToHost(ContactProperty &contact_prop) { +void PairsSimulation::copyContactPropertyToHost(ContactProperty &contact_prop, bool write) { + int prop_id = contact_prop.getId(); + if(!contact_prop_flags->isHostFlagSet(contact_prop.getId())) { PAIRS_DEBUG("Copying contact property %s to host\n", contact_prop.getName().c_str()); pairs::copy_to_host(contact_prop.getDevicePointer(), contact_prop.getHostPointer(), contact_prop.getTotalSize()); + contact_prop_flags->setHostFlag(prop_id); + } + + if(write) { + contact_prop_flags->clearDeviceFlag(prop_id); } } @@ -172,7 +214,7 @@ void PairsSimulation::communicateSizes(int dim, const int *send_sizes, int *recv auto nsend_id = getArrayByHostPointer(send_sizes).getId(); auto nrecv_id = getArrayByHostPointer(recv_sizes).getId(); - copyArrayToHost(nsend_id); + copyArrayToHost(nsend_id, false); array_flags->setHostFlag(nrecv_id); array_flags->clearDeviceFlag(nrecv_id); this->getDomainPartitioner()->communicateSizes(dim, send_sizes, recv_sizes); @@ -191,11 +233,11 @@ void PairsSimulation::communicateData( auto nsend_id = getArrayByHostPointer(nsend).getId(); auto nrecv_id = getArrayByHostPointer(nrecv).getId(); - copyArrayToHost(send_buf_id); - copyArrayToHost(send_offsets_id); - copyArrayToHost(recv_offsets_id); - copyArrayToHost(nsend_id); - copyArrayToHost(nrecv_id); + copyArrayToHost(send_buf_id, false); + copyArrayToHost(send_offsets_id, false); + copyArrayToHost(recv_offsets_id, false); + copyArrayToHost(nsend_id, false); + copyArrayToHost(nrecv_id, false); array_flags->setHostFlag(recv_buf_id); array_flags->clearDeviceFlag(recv_buf_id); this->getDomainPartitioner()->communicateData(dim, elem_size, send_buf, send_offsets, nsend, recv_buf, recv_offsets, nrecv); diff --git a/runtime/pairs.hpp b/runtime/pairs.hpp index aa6be15..ec26c6a 100644 --- a/runtime/pairs.hpp +++ b/runtime/pairs.hpp @@ -122,45 +122,78 @@ public: void setArrayDeviceFlag(Array &array) { array_flags->setDeviceFlag(array.getId()); } void clearArrayDeviceFlag(array_t id) { clearArrayDeviceFlag(getArray(id)); } void clearArrayDeviceFlag(Array &array) { array_flags->clearDeviceFlag(array.getId()); } - void copyArrayToDevice(array_t id) { copyArrayToDevice(getArray(id)); } - void copyArrayToDevice(Array &array); + void copyArrayToDevice(array_t id, bool write) { copyArrayToDevice(getArray(id), write); } + void copyArrayToDevice(Array &array, bool write); void setArrayHostFlag(array_t id) { setArrayHostFlag(getArray(id)); } void setArrayHostFlag(Array &array) { array_flags->setHostFlag(array.getId()); } void clearArrayHostFlag(array_t id) { clearArrayHostFlag(getArray(id)); } void clearArrayHostFlag(Array &array) { array_flags->clearHostFlag(array.getId()); } - void copyArrayToHost(array_t id) { copyArrayToHost(getArray(id)); } - void copyArrayToHost(Array &array); + void copyArrayToHost(array_t id, bool write) { copyArrayToHost(getArray(id), write); } + void copyArrayToHost(Array &array, bool write); void setPropertyDeviceFlag(property_t id) { setPropertyDeviceFlag(getProperty(id)); } void setPropertyDeviceFlag(Property &prop) { prop_flags->setDeviceFlag(prop.getId()); } void clearPropertyDeviceFlag(property_t id) { clearPropertyDeviceFlag(getProperty(id)); } void clearPropertyDeviceFlag(Property &prop) { prop_flags->clearDeviceFlag(prop.getId()); } - void copyPropertyToDevice(property_t id) { copyPropertyToDevice(getProperty(id)); } - void copyPropertyToDevice(Property &prop); + void copyPropertyToDevice(property_t id, bool write) { copyPropertyToDevice(getProperty(id), write); } + void copyPropertyToDevice(Property &prop, bool write); void setPropertyHostFlag(property_t id) { setPropertyHostFlag(getProperty(id)); } void setPropertyHostFlag(Property &prop) { prop_flags->setHostFlag(prop.getId()); } void clearPropertyHostFlag(property_t id) { clearPropertyHostFlag(getProperty(id)); } void clearPropertyHostFlag(Property &prop) { prop_flags->clearHostFlag(prop.getId()); } - void copyPropertyToHost(property_t id) { copyPropertyToHost(getProperty(id)); } - void copyPropertyToHost(Property &prop); - - void setContactPropertyDeviceFlag(property_t id) { setContactPropertyDeviceFlag(getContactProperty(id)); } - void setContactPropertyDeviceFlag(ContactProperty &prop) { contact_prop_flags->setDeviceFlag(prop.getId()); } - void clearContactPropertyDeviceFlag(property_t id) { clearContactPropertyDeviceFlag(getContactProperty(id)); } - void clearContactPropertyDeviceFlag(ContactProperty &prop) { contact_prop_flags->clearDeviceFlag(prop.getId()); } - void copyContactPropertyToDevice(property_t id) { copyContactPropertyToDevice(getContactProperty(id)); } - void copyContactPropertyToDevice(ContactProperty &prop); - - void setContactPropertyHostFlag(property_t id) { setContactPropertyHostFlag(getContactProperty(id)); } - void setContactPropertyHostFlag(ContactProperty &prop) { contact_prop_flags->setHostFlag(prop.getId()); } - void clearContactPropertyHostFlag(property_t id) { clearContactPropertyHostFlag(getContactProperty(id)); } - void clearContactPropertyHostFlag(ContactProperty &prop) { contact_prop_flags->clearHostFlag(prop.getId()); } - void copyContactPropertyToHost(property_t id) { copyContactPropertyToHost(getContactProperty(id)); } - void copyContactPropertyToHost(ContactProperty &prop); - - void copyFeaturePropertyToDevice(property_t id) { copyFeaturePropertyToDevice(getFeatureProperty(id)); } + void copyPropertyToHost(property_t id, bool write) { copyPropertyToHost(getProperty(id), write); } + void copyPropertyToHost(Property &prop, bool write); + + void setContactPropertyDeviceFlag(property_t id) { + setContactPropertyDeviceFlag(getContactProperty(id)); + } + + void setContactPropertyDeviceFlag(ContactProperty &prop) { + contact_prop_flags->setDeviceFlag(prop.getId()); + } + + void clearContactPropertyDeviceFlag(property_t id) { + clearContactPropertyDeviceFlag(getContactProperty(id)); + } + + void clearContactPropertyDeviceFlag(ContactProperty &prop) { + contact_prop_flags->clearDeviceFlag(prop.getId()); + } + + void copyContactPropertyToDevice(property_t id, bool write) { + copyContactPropertyToDevice(getContactProperty(id), write); + } + + void copyContactPropertyToDevice(ContactProperty &prop, bool write); + + void setContactPropertyHostFlag(property_t id) { + setContactPropertyHostFlag(getContactProperty(id)); + } + + void setContactPropertyHostFlag(ContactProperty &prop) { + contact_prop_flags->setHostFlag(prop.getId()); + } + + void clearContactPropertyHostFlag(property_t id) { + clearContactPropertyHostFlag(getContactProperty(id)); + } + + void clearContactPropertyHostFlag(ContactProperty &prop) { + contact_prop_flags->clearHostFlag(prop.getId()); + } + + void copyContactPropertyToHost(property_t id, bool write) { + copyContactPropertyToHost(getContactProperty(id), write); + } + + void copyContactPropertyToHost(ContactProperty &prop, bool write); + + void copyFeaturePropertyToDevice(property_t id) { + copyFeaturePropertyToDevice(getFeatureProperty(id)); + } + void copyFeaturePropertyToDevice(FeatureProperty &feature_prop); void communicateSizes(int dim, const int *send_sizes, int *recv_sizes); @@ -240,7 +273,7 @@ void PairsSimulation::reallocArray(array_t id, T_ptr **h_ptr, T_ptr **d_ptr, siz *h_ptr = (T_ptr *) new_h_ptr; *d_ptr = (T_ptr *) new_d_ptr; if(array_flags->isDeviceFlagSet(id)) { - copyArrayToDevice(id); + copyArrayToDevice(id, false); } } @@ -310,7 +343,7 @@ void PairsSimulation::reallocProperty(property_t id, T_ptr **h_ptr, T_ptr **d_pt *h_ptr = (T_ptr *) new_h_ptr; *d_ptr = (T_ptr *) new_d_ptr; if(prop_flags->isDeviceFlagSet(id)) { - copyPropertyToDevice(id); + copyPropertyToDevice(id, false); } } @@ -380,7 +413,7 @@ void PairsSimulation::reallocContactProperty(property_t id, T_ptr **h_ptr, T_ptr *h_ptr = (T_ptr *) new_h_ptr; *d_ptr = (T_ptr *) new_d_ptr; if(contact_prop_flags->isDeviceFlagSet(id)) { - copyContactPropertyToDevice(id); + copyContactPropertyToDevice(id, false); } } diff --git a/runtime/thermo.hpp b/runtime/thermo.hpp index 8f00fd2..993e9c8 100644 --- a/runtime/thermo.hpp +++ b/runtime/thermo.hpp @@ -26,8 +26,8 @@ double compute_thermo(PairsSimulation *ps, int nlocal, double xprd, double yprd, //const double e_scale = 0.5; double t = 0.0, p; - ps->copyPropertyToHost(masses); - ps->copyPropertyToHost(velocities); + ps->copyPropertyToHost(masses, false); + ps->copyPropertyToHost(velocities, false); for(int i = 0; i < nlocal; i++) { t += masses(i) * ( velocities(i, 0) * velocities(i, 0) + diff --git a/runtime/vtk.hpp b/runtime/vtk.hpp index 3e69500..527f074 100644 --- a/runtime/vtk.hpp +++ b/runtime/vtk.hpp @@ -29,8 +29,8 @@ void vtk_write_data(PairsSimulation *ps, const char *filename, int start, int en filename_oss << timestep << ".vtk"; std::ofstream out_file(filename_oss.str()); - ps->copyPropertyToHost(masses); - ps->copyPropertyToHost(positions); + ps->copyPropertyToHost(masses, false); + ps->copyPropertyToHost(positions, false); for(int i = start; i < end; i++) { if(flags(i) & FLAGS_INFINITE) { diff --git a/src/pairs/analysis/devices.py b/src/pairs/analysis/devices.py index 154683a..07f3a17 100644 --- a/src/pairs/analysis/devices.py +++ b/src/pairs/analysis/devices.py @@ -48,10 +48,11 @@ class FetchKernelReferences(Visitor): self.writing = writing_state def visit_Assign(self, ast_node): + self.writing = False + self.visit(ast_node._src) self.writing = True self.visit(ast_node._dest) self.writing = False - self.visit(ast_node._src) def visit_AtomicAdd(self, ast_node): self.writing = True @@ -75,11 +76,25 @@ class FetchKernelReferences(Visitor): self.visit_children(ast_node) self.kernel_stack.pop() - ast_node.add_array_access([a for a in self.kernel_used_array_accesses[kernel_id] if a not in self.kernel_decls[kernel_id]]) - ast_node.add_scalar_op([b for b in self.kernel_used_scalar_ops[kernel_id] if b not in self.kernel_decls[kernel_id] and not b.in_place]) - ast_node.add_vector_op([b for b in self.kernel_used_vector_ops[kernel_id] if b not in self.kernel_decls[kernel_id] and not b.in_place]) - ast_node.add_matrix_op([b for b in self.kernel_used_matrix_ops[kernel_id] if b not in self.kernel_decls[kernel_id] and not b.in_place]) - ast_node.add_quaternion_op([b for b in self.kernel_used_quat_ops[kernel_id] if b not in self.kernel_decls[kernel_id] and not b.in_place]) + ast_node.add_array_access( + [a for a in self.kernel_used_array_accesses[kernel_id] \ + if a not in self.kernel_decls[kernel_id]]) + + ast_node.add_scalar_op( + [b for b in self.kernel_used_scalar_ops[kernel_id] \ + if b not in self.kernel_decls[kernel_id] and not b.in_place]) + + ast_node.add_vector_op( + [b for b in self.kernel_used_vector_ops[kernel_id] \ + if b not in self.kernel_decls[kernel_id] and not b.in_place]) + + ast_node.add_matrix_op( + [b for b in self.kernel_used_matrix_ops[kernel_id] \ + if b not in self.kernel_decls[kernel_id] and not b.in_place]) + + ast_node.add_quaternion_op( + [b for b in self.kernel_used_quat_ops[kernel_id] \ + if b not in self.kernel_decls[kernel_id] and not b.in_place]) def visit_PropertyAccess(self, ast_node): # Visit property and save current writing state diff --git a/src/pairs/analysis/modules.py b/src/pairs/analysis/modules.py index 4a8ebad..5b97619 100644 --- a/src/pairs/analysis/modules.py +++ b/src/pairs/analysis/modules.py @@ -18,10 +18,11 @@ class FetchModulesReferences(Visitor): self.writing = writing_state def visit_Assign(self, ast_node): + self.writing = False + self.visit(ast_node._src) self.writing = True self.visit(ast_node._dest) self.writing = False - self.visit(ast_node._src) def visit_AtomicAdd(self, ast_node): self.writing = True diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py index dba78f3..cf16325 100644 --- a/src/pairs/code_gen/cgen.py +++ b/src/pairs/code_gen/cgen.py @@ -8,7 +8,7 @@ from pairs.ir.cast import Cast from pairs.ir.contexts import Contexts from pairs.ir.declaration import Decl from pairs.ir.scalars import ScalarOp -from pairs.ir.device import ClearArrayFlag, ClearPropertyFlag, CopyArray, CopyProperty, CopyVar, DeviceStaticRef, SetArrayFlag, SetPropertyFlag, HostRef +from pairs.ir.device import CopyArray, CopyContactProperty, CopyProperty, CopyVar, DeviceStaticRef, HostRef from pairs.ir.features import FeatureProperty, FeaturePropertyAccess, RegisterFeatureProperty from pairs.ir.functions import Call from pairs.ir.kernel import KernelLaunch @@ -456,64 +456,28 @@ class CGen: if isinstance(ast_node, CopyArray): array_id = ast_node.array.id() array_name = ast_node.array.name() + ctx_suffix = "Device" if ast_node.context() == Contexts.Device else "Host" + write = "true" if ast_node.write else "false" + self.print(f"pairs->copyArrayTo{ctx_suffix}({array_id}, {write}); // {array_name}") - if ast_node.context() == Contexts.Device: - self.print(f"pairs->copyArrayToDevice({array_id}); // {array_name}") - else: - self.print(f"pairs->copyArrayToHost({array_id}); // {array_name}") + if isinstance(ast_node, CopyContactProperty): + prop_id = ast_node.contact_prop.id() + prop_name = ast_node.contact_prop.name() + write = "true" if ast_node.write else "false" + ctx_suffix = "Device" if ast_node.context() == Contexts.Device else "Host" + self.print(f"pairs->copyContactPropertyTo{ctx_suffix}({prop_id}, {write}); // {prop_name}") if isinstance(ast_node, CopyProperty): prop_id = ast_node.prop.id() prop_name = ast_node.prop.name() - - if ast_node.context() == Contexts.Device: - self.print(f"pairs->copyPropertyToDevice({prop_id}); // {prop_name}") - else: - self.print(f"pairs->copyPropertyToHost({prop_id}); // {prop_name}") + write = "true" if ast_node.write else "false" + ctx_suffix = "Device" if ast_node.context() == Contexts.Device else "Host" + self.print(f"pairs->copyPropertyTo{ctx_suffix}({prop_id}, {write}); // {prop_name}") if isinstance(ast_node, CopyVar): var_name = ast_node.variable.name() - - if ast_node.context() == Contexts.Device: - self.print(f"rv_{var_name}.copyToDevice();") - else: - self.print(f"rv_{var_name}.copyToHost();") - - if isinstance(ast_node, ClearArrayFlag): - array_id = ast_node.array.id() - array_name = ast_node.array.name() - - if ast_node.context() == Contexts.Device: - self.print(f"pairs->clearArrayDeviceFlag({array_id}); // {array_name}") - else: - self.print(f"pairs->clearArrayHostFlag({array_id}); // {array_name}") - - if isinstance(ast_node, ClearPropertyFlag): - prop_id = ast_node.prop.id() - prop_name = ast_node.prop.name() - - if ast_node.context() == Contexts.Device: - self.print(f"pairs->clearPropertyDeviceFlag({prop_id}); // {prop_name}") - else: - self.print(f"pairs->clearPropertyHostFlag({prop_id}); // {prop_name}") - - if isinstance(ast_node, SetArrayFlag): - array_id = ast_node.array.id() - array_name = ast_node.array.name() - - if ast_node.context() == Contexts.Device: - self.print(f"pairs->setArrayDeviceFlag({array_id}); // {array_name}") - else: - self.print(f"pairs->setArrayHostFlag({array_id}); // {array_name}") - - if isinstance(ast_node, SetPropertyFlag): - prop_id = ast_node.prop.id() - prop_name = ast_node.prop.name() - - if ast_node.context() == Contexts.Device: - self.print(f"pairs->setPropertyDeviceFlag({prop_id}); // {prop_name}") - else: - self.print(f"pairs->setPropertyHostFlag({prop_id}); // {prop_name}") + ctx_suffix = "Device" if ast_node.context() == Contexts.Device else "Host" + self.print(f"rv_{var_name}.copyTo{ctx_suffix}();") if isinstance(ast_node, For): iterator = self.generate_expression(ast_node.iterator) diff --git a/src/pairs/ir/device.py b/src/pairs/ir/device.py index cef8e0f..96af505 100644 --- a/src/pairs/ir/device.py +++ b/src/pairs/ir/device.py @@ -30,10 +30,11 @@ class DeviceStaticRef(ASTNode): class CopyArray(ASTNode): - def __init__(self, sim, array, ctx): + def __init__(self, sim, array, ctx, write): super().__init__(sim) self.array = array self.ctx = ctx + self.write = write self.sim.add_statement(self) def context(self): @@ -44,10 +45,11 @@ class CopyArray(ASTNode): class CopyProperty(ASTNode): - def __init__(self, sim, prop, ctx): + def __init__(self, sim, prop, ctx, write): super().__init__(sim) self.prop = prop self.ctx = ctx + self.write = write self.sim.add_statement(self) def context(self): @@ -57,39 +59,12 @@ class CopyProperty(ASTNode): return [self.prop] -class CopyVar(ASTNode): - def __init__(self, sim, variable, ctx): +class CopyContactProperty(ASTNode): + def __init__(self, sim, prop, ctx, write): super().__init__(sim) - self.variable = variable - self.ctx = ctx - self.sim.add_statement(self) - - def context(self): - return self.ctx - - def children(self): - return [self.variable] - - -class ClearArrayFlag(ASTNode): - def __init__(self, sim, array, ctx): - super().__init__(sim) - self.array = array - self.ctx = ctx - self.sim.add_statement(self) - - def context(self): - return self.ctx - - def children(self): - return [self.array] - - -class ClearPropertyFlag(ASTNode): - def __init__(self, sim, prop, ctx): - super().__init__(sim) - self.prop = prop + self.contact_prop = prop self.ctx = ctx + self.write = write self.sim.add_statement(self) def context(self): @@ -99,24 +74,10 @@ class ClearPropertyFlag(ASTNode): return [self.prop] -class SetArrayFlag(ASTNode): - def __init__(self, sim, array, ctx): - super().__init__(sim) - self.array = array - self.ctx = ctx - self.sim.add_statement(self) - - def context(self): - return self.ctx - - def children(self): - return [self.array] - - -class SetPropertyFlag(ASTNode): - def __init__(self, sim, prop, ctx): +class CopyVar(ASTNode): + def __init__(self, sim, variable, ctx): super().__init__(sim) - self.prop = prop + self.variable = variable self.ctx = ctx self.sim.add_statement(self) @@ -124,4 +85,4 @@ class SetPropertyFlag(ASTNode): return self.ctx def children(self): - return [self.prop] + return [self.variable] diff --git a/src/pairs/ir/kernel.py b/src/pairs/ir/kernel.py index a29773d..12ed846 100644 --- a/src/pairs/ir/kernel.py +++ b/src/pairs/ir/kernel.py @@ -85,66 +85,99 @@ class Kernel(ASTNode): def add_array(self, array, write=False): array_list = array if isinstance(array, list) else [array] character = 'w' if write else 'r' + for a in array_list: - assert isinstance(a, Array), "Kernel.add_array(): Element is not of type Array." - self._arrays[a] = character if a not in self._arrays else self._arrays[a] + character + assert isinstance(a, Array), \ + "Kernel.add_array(): Element is not of type Array." + + self._arrays[a] = character if a not in self._arrays else \ + self._arrays[a] + character def add_variable(self, variable, write=False): variable_list = variable if isinstance(variable, list) else [variable] character = 'w' if write else 'r' + for v in variable_list: if not v.temporary(): - assert isinstance(v, Var), "Kernel.add_variable(): Element is not of type Var." - self._variables[v] = character if v not in self._variables else self._variables[v] + character + assert isinstance(v, Var), \ + "Kernel.add_variable(): Element is not of type Var." + + self._variables[v] = character if v not in self._variables else \ + self._variables[v] + character def add_property(self, prop, write=False): prop_list = prop if isinstance(prop, list) else [prop] character = 'w' if write else 'r' + for p in prop_list: - assert isinstance(p, Property), "Kernel.add_property(): Element is not of type Property." - self._properties[p] = character if p not in self._properties else self._properties[p] + character + assert isinstance(p, Property), \ + "Kernel.add_property(): Element is not of type Property." + + self._properties[p] = character if p not in self._properties else \ + self._properties[p] + character def add_contact_property(self, contact_prop, write=False): contact_prop_list = contact_prop if isinstance(contact_prop, list) else [contact_prop] character = 'w' if write else 'r' + for cp in contact_prop_list: - assert isinstance(cp, ContactProperty), "Kernel.add_contact_property(): Element is not of type ContactProperty." - self._contact_properties[cp] = character if cp not in self._contact_properties else self._contact_properties[cp] + character + assert isinstance(cp, ContactProperty), \ + "Kernel.add_contact_property(): Element is not of type ContactProperty." + + self._contact_properties[cp] = character if cp not in self._contact_properties else \ + self._contact_properties[cp] + character def add_feature_property(self, feature_prop): feature_prop_list = feature_prop if isinstance(feature_prop, list) else [feature_prop] + for fp in feature_prop_list: - assert isinstance(fp, FeatureProperty), "Kernel.add_feature_property(): Element is not of type FeatureProperty." + assert isinstance(fp, FeatureProperty), \ + "Kernel.add_feature_property(): Element is not of type FeatureProperty." + self._feature_properties[fp] = 'r' def add_array_access(self, array_access): array_access_list = array_access if isinstance(array_access, list) else [array_access] for a in array_access_list: - assert isinstance(a, ArrayAccess), "Kernel.add_array_access(): Element is not of type ArrayAccess." + assert isinstance(a, ArrayAccess), \ + "Kernel.add_array_access(): Element is not of type ArrayAccess." + self._array_accesses.add(a) def add_scalar_op(self, scalar_op): scalar_op_list = scalar_op if isinstance(scalar_op, list) else [scalar_op] + for b in scalar_op_list: - assert isinstance(b, ScalarOp), "Kernel.add_scalar_op(): Element is not of type ScalarOp." + assert isinstance(b, ScalarOp), \ + "Kernel.add_scalar_op(): Element is not of type ScalarOp." + self._scalar_ops.append(b) def add_vector_op(self, vector_op): vector_op_list = vector_op if isinstance(vector_op, list) else [vector_op] + for b in vector_op_list: - assert isinstance(b, VectorOp), "Kernel.add_vector_op(): Element is not of type VectorOp." + assert isinstance(b, VectorOp), \ + "Kernel.add_vector_op(): Element is not of type VectorOp." + self._vector_ops.append(b) def add_matrix_op(self, matrix_op): matrix_op_list = matrix_op if isinstance(matrix_op, list) else [matrix_op] + for b in matrix_op_list: - assert isinstance(b, MatrixOp), "Kernel.add_matrix_op(): Element is not of type MatrixOp." + assert isinstance(b, MatrixOp), \ + "Kernel.add_matrix_op(): Element is not of type MatrixOp." + self._matrix_ops.append(b) def add_quaternion_op(self, quat_op): quat_op_list = quat_op if isinstance(quat_op, list) else [quat_op] + for b in quat_op_list: - assert isinstance(b, QuaternionOp), "Kernel.add_quaternion_op(): Element is not of type QuaternionOp." + assert isinstance(b, QuaternionOp), \ + "Kernel.add_quaternion_op(): Element is not of type QuaternionOp." + self._quat_ops.append(b) def children(self): diff --git a/src/pairs/ir/module.py b/src/pairs/ir/module.py index 71bd663..abbc401 100644 --- a/src/pairs/ir/module.py +++ b/src/pairs/ir/module.py @@ -86,6 +86,13 @@ class Module(ASTNode): def write_properties(self): return {p for p in self._properties if 'w' in self._properties[p]} + def contact_properties_to_synchronize(self): + #return {cp for cp in self._contact_properties if self._contact_properties[cp][0] == 'r'} + return {cp for cp in self._contact_properties} + + def write_contact_properties(self): + return {cp for cp in self._contact_properties if 'w' in self._contact_properties[cp]} + def arrays_to_synchronize(self): #return {a for a in self._arrays if a.sync() and self._arrays[a][0] == 'r'} return {a for a in self._arrays if a.sync()} @@ -96,35 +103,54 @@ class Module(ASTNode): def add_array(self, array, write=False): array_list = array if isinstance(array, list) else [array] character = 'w' if write else 'r' + for a in array_list: - assert isinstance(a, Array), "Module.add_array(): given element is not of type Array!" - self._arrays[a] = character if a not in self._arrays else self._arrays[a] + character + assert isinstance(a, Array), \ + "Module.add_array(): given element is not of type Array!" + + self._arrays[a] = character if a not in self._arrays else \ + self._arrays[a] + character def add_variable(self, variable, write=False): variable_list = variable if isinstance(variable, list) else [variable] character = 'w' if write else 'r' + for v in variable_list: - assert isinstance(v, Var), "Module.add_variable(): given element is not of type Var!" - self._variables[v] = character if v not in self._variables else self._variables[v] + character + assert isinstance(v, Var), \ + "Module.add_variable(): given element is not of type Var!" + + self._variables[v] = character if v not in self._variables else \ + self._variables[v] + character def add_property(self, prop, write=False): prop_list = prop if isinstance(prop, list) else [prop] character = 'w' if write else 'r' + for p in prop_list: - assert isinstance(p, Property), "Module.add_property(): given element is not of type Property!" - self._properties[p] = character if p not in self._properties else self._properties[p] + character + assert isinstance(p, Property), \ + "Module.add_property(): given element is not of type Property!" + + self._properties[p] = character if p not in self._properties else \ + self._properties[p] + character def add_contact_property(self, contact_prop, write=False): contact_prop_list = contact_prop if isinstance(contact_prop, list) else [contact_prop] character = 'w' if write else 'r' + for cp in contact_prop_list: - assert isinstance(cp, ContactProperty), "Module.add_contact_property(): given element is not of type ContactProperty!" - self._contact_properties[cp] = character if cp not in self._contact_properties else self._contact_properties[cp] + character + assert isinstance(cp, ContactProperty), \ + "Module.add_contact_property(): given element is not of type ContactProperty!" + + self._contact_properties[cp] = character if cp not in self._contact_properties else \ + self._contact_properties[cp] + character def add_feature_property(self, feature_prop): feature_prop_list = feature_prop if isinstance(feature_prop, list) else [feature_prop] + for fp in feature_prop_list: - assert isinstance(fp, FeatureProperty), "Module.add_feature_property(): given element is not of type FeatureProperty!" + assert isinstance(fp, FeatureProperty), \ + "Module.add_feature_property(): given element is not of type FeatureProperty!" + self._feature_properties[fp] = 'r' def add_host_reference(self, elem): diff --git a/src/pairs/sim/arrays.py b/src/pairs/sim/arrays.py index d3fa32d..7ea490b 100644 --- a/src/pairs/sim/arrays.py +++ b/src/pairs/sim/arrays.py @@ -1,6 +1,5 @@ from pairs.ir.block import pairs_inline from pairs.ir.contexts import Contexts -from pairs.ir.device import ClearArrayFlag from pairs.ir.memory import Malloc from pairs.ir.arrays import DeclareStaticArray, RegisterArray from pairs.sim.lowerable import FinalLowerable @@ -17,7 +16,3 @@ class DeclareArrays(FinalLowerable): DeclareStaticArray(self.sim, a) RegisterArray(self.sim, a, a.alloc_size()) - - if not a.sync(): - ClearArrayFlag(self.sim, self.sim.resizes, Contexts.Host) - ClearArrayFlag(self.sim, self.sim.resizes, Contexts.Device) diff --git a/src/pairs/sim/comm.py b/src/pairs/sim/comm.py index c13c349..5152eef 100644 --- a/src/pairs/sim/comm.py +++ b/src/pairs/sim/comm.py @@ -17,10 +17,10 @@ class Comm: self.sim = sim self.dom_part = dom_part self.nsend_all = sim.add_var('nsend_all', Types.Int32) - self.send_capacity = sim.add_var('send_capacity', Types.Int32, 100000) - self.recv_capacity = sim.add_var('recv_capacity', Types.Int32, 100000) - self.elem_capacity = sim.add_var('elem_capacity', Types.Int32, 10) - self.neigh_capacity = sim.add_var('neigh_capacity', Types.Int32, 6) + self.send_capacity = sim.add_var('send_capacity', Types.Int32, 200000) + self.recv_capacity = sim.add_var('recv_capacity', Types.Int32, 200000) + self.elem_capacity = sim.add_var('elem_capacity', Types.Int32, 40) + self.neigh_capacity = sim.add_var('neigh_capacity', Types.Int32, 10) self.nsend = sim.add_array('nsend', [self.neigh_capacity], Types.Int32) self.send_offsets = sim.add_array('send_offsets', [self.neigh_capacity], Types.Int32) self.send_buffer = sim.add_array('send_buffer', [self.send_capacity, self.elem_capacity], Types.Real) diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py index c2d02fc..45764a5 100644 --- a/src/pairs/transformations/devices.py +++ b/src/pairs/transformations/devices.py @@ -4,7 +4,7 @@ from pairs.ir.block import Block from pairs.ir.branches import Filter from pairs.ir.cast import Cast from pairs.ir.contexts import Contexts -from pairs.ir.device import ClearArrayFlag, ClearPropertyFlag, CopyArray, CopyProperty, CopyVar, DeviceStaticRef, SetArrayFlag, SetPropertyFlag, HostRef +from pairs.ir.device import CopyArray, CopyContactProperty, CopyProperty, CopyVar, DeviceStaticRef, HostRef from pairs.ir.kernel import Kernel, KernelLaunch from pairs.ir.lit import Lit from pairs.ir.loops import For @@ -34,34 +34,34 @@ class AddDeviceCopies(Mutator): copy_context = Contexts.Device if s.module.run_on_device else Contexts.Host clear_context = Contexts.Host if s.module.run_on_device else Contexts.Device - for a in s.module.arrays_to_synchronize(): - new_stmts += [CopyArray(s.sim, a, copy_context)] + for array in s.module.arrays_to_synchronize(): + write = array in s.module.write_arrays() + new_stmts += [CopyArray(s.sim, array, copy_context, write)] - for p in s.module.properties_to_synchronize(): - new_stmts += [CopyProperty(s.sim, p, copy_context)] + for prop in s.module.properties_to_synchronize(): + write = prop in s.module.write_properties() + new_stmts += [CopyProperty(s.sim, prop, copy_context, write)] - for a in s.module.write_arrays(): - new_stmts += [SetArrayFlag(s.sim, a, copy_context), ClearArrayFlag(s.sim, a, clear_context)] - - for p in s.module.write_properties(): - new_stmts += [SetPropertyFlag(s.sim, p, copy_context), ClearPropertyFlag(s.sim, p, clear_context)] + for contact_prop in s.module.contact_properties_to_synchronize(): + write = prop in s.module.write_contact_properties() + new_stmts += [CopyContactProperty(s.sim, contact_prop, copy_context, write)] if self.module_resizes[s.module] and s.module.run_on_device: - new_stmts += [CopyArray(s.sim, s.sim.resizes, Contexts.Device)] + new_stmts += [CopyArray(s.sim, s.sim.resizes, Contexts.Device, False)] if s.module.run_on_device: - for v in s.module.variables_to_synchronize(): - new_stmts += [CopyVar(s.sim, v, Contexts.Device)] + for var in s.module.variables_to_synchronize(): + new_stmts += [CopyVar(s.sim, var, Contexts.Device)] new_stmts.append(s) if isinstance(s, ModuleCall): if s.module.run_on_device: - for v in s.module.variables_to_synchronize(): - new_stmts += [CopyVar(s.sim, v, Contexts.Host)] + for var in s.module.variables_to_synchronize(): + new_stmts += [CopyVar(s.sim, var, Contexts.Host)] if self.module_resizes[s.module]: - new_stmts += [CopyArray(s.sim, s.sim.resizes, Contexts.Host)] + new_stmts += [CopyArray(s.sim, s.sim.resizes, Contexts.Host, False)] ast_node.stmts = new_stmts return ast_node @@ -83,7 +83,10 @@ class AddDeviceKernels(Mutator): kernel_name = f"{ast_node.name}_kernel{kernel_id}" kernel = ast_node.sim.find_kernel_by_name(kernel_name) if kernel is None: - kernel_body = Filter(ast_node.sim, ScalarOp.inline(s.iterator < s.max.copy(True)), s.block) + kernel_body = Filter(ast_node.sim, + ScalarOp.inline(s.iterator < s.max.copy(True)), + s.block) + kernel = Kernel(ast_node.sim, kernel_name, kernel_body, s.iterator) kernel_id += 1 -- GitLab