diff --git a/runtime/domain/block_forest.hpp b/runtime/domain/block_forest.hpp index 1c7bf537901f0681b2ddab336f39dafda1abd61a..d5d9d5467faef6aba3736d97a3fceb5149040d3b 100644 --- a/runtime/domain/block_forest.hpp +++ b/runtime/domain/block_forest.hpp @@ -42,7 +42,7 @@ private: real_t *subdom; const bool globalPBC[3]; int world_size, rank, nranks, total_aabbs; - bool balance_workload; + bool balance_workload = false; public: BlockForest( diff --git a/src/pairs/analysis/devices.py b/src/pairs/analysis/devices.py index 9fdddc2e53a8a50a6ad50338fff9857d84b30814..23a3beda1a391d0d987a5e588fbdd5bf3fca033a 100644 --- a/src/pairs/analysis/devices.py +++ b/src/pairs/analysis/devices.py @@ -18,6 +18,7 @@ class MarkCandidateLoops(Visitor): if self.device_module: if ast_node.not_kernel: self.visit(ast_node.block) + ast_node.mark_iter_as_ref_candidate() else: if not isinstance(ast_node.min, Lit) or not isinstance(ast_node.max, Lit): ast_node.mark_as_kernel_candidate() @@ -195,3 +196,8 @@ class FetchKernelReferences(Visitor): # Variables only have a device version when changed within kernels if self.writing: ast_node.device_flag = True + + def visit_Iter(self, ast_node): + for k in self.kernel_stack: + if ast_node.is_ref_candidate(): + k.add_iter(ast_node, self.writing) diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py index a7166be8709c99e30455318baad083d7d65e7876..c9e2c2ef57be2037d197834373a3d09420777520 100644 --- a/src/pairs/code_gen/cgen.py +++ b/src/pairs/code_gen/cgen.py @@ -626,6 +626,11 @@ class CGen: decl = f"{type_kw} *{var.name()}" kernel_params += f", {decl}" + for it in kernel.iters(): + type_kw = Types.c_keyword(self.sim, it.type()) + decl = f"{type_kw} {it.name()}" + kernel_params += f", {decl}" + for array in kernel.arrays(): if array.is_static(): continue @@ -982,6 +987,9 @@ class CGen: for var in kernel.write_variables(): kernel_params += f", {var.name()}" + for it in kernel.iters(): + kernel_params += f", {it.name()}" + for array in kernel.arrays(): if array.is_static(): continue diff --git a/src/pairs/ir/kernel.py b/src/pairs/ir/kernel.py index 04def29cf9153a8ec7f5048f79579bd062186dd0..5faaee406c9e140f734fbdf3d43e0904063daee7 100644 --- a/src/pairs/ir/kernel.py +++ b/src/pairs/ir/kernel.py @@ -9,6 +9,7 @@ from pairs.ir.properties import Property, ContactProperty from pairs.ir.quaternions import QuaternionOp from pairs.ir.variables import Var from pairs.ir.vectors import VectorOp +from pairs.ir.loops import Iter class Kernel(ASTNode): @@ -19,6 +20,7 @@ class Kernel(ASTNode): self._id = Kernel.last_kernel self._name = name if name is not None else "kernel" + str(Kernel.last_kernel) self._variables = {} + self._iters = {} self._arrays = {} self._properties = {} self._contact_properties = {} @@ -50,6 +52,9 @@ class Kernel(ASTNode): def variables(self): return self._variables + def iters(self): + return self._iters + def read_only_variables(self): return [var for var in self._variables if self._variables[var] == Actions.ReadOnly] @@ -99,6 +104,17 @@ class Kernel(ASTNode): action = Actions.NoAction if var not in self._variables else self._variables[var] self._variables[var] = Actions.update_rule(action, new_op) + + def add_iter(self, iter, write=False): + iter_list = iter if isinstance(iter, list) else [iter] + new_op = 'w' if write else 'r' + + for it in iter_list: + assert isinstance(it, Iter), \ + "Kernel.add_iter(): Element is not of type Iter." + + action = Actions.NoAction if it not in self._iters else self._iters[it] + self._iters[it] = Actions.update_rule(action, new_op) def add_property(self, prop, write=False): prop_list = prop if isinstance(prop, list) else [prop] diff --git a/src/pairs/ir/loops.py b/src/pairs/ir/loops.py index 228bd40e9a69edd70ef36b33dacf82a8aba34fb0..8842818627c7da514ec9a24ae85c84a5f08cd747 100644 --- a/src/pairs/ir/loops.py +++ b/src/pairs/ir/loops.py @@ -18,6 +18,7 @@ class Iter(ASTTerm): super().__init__(sim, ScalarOp) self.loop = loop self.iter_id = Iter.new_id() + self._ref_candidate = False def id(self): return self.iter_id @@ -27,7 +28,16 @@ class Iter(ASTTerm): def type(self): return Types.Int32 - + + def mark_as_ref_candidate(self): + self._ref_candidate = True + + def is_ref_candidate(self): + return self._ref_candidate + + def __hash__(self): + return hash(self.iter_id) + def __eq__(self, other): return isinstance(other, Iter) and self.iter_id == other.iter_id @@ -64,6 +74,9 @@ class For(ASTNode): def mark_as_kernel_candidate(self): self._kernel_candidate = True + def mark_iter_as_ref_candidate(self): + self.iterator.mark_as_ref_candidate() + def is_kernel_candidate(self): return self._kernel_candidate