From ae595f96046f17685b915b57e94522c713c258bb Mon Sep 17 00:00:00 2001 From: Behzad Safaei <iwia103h@alex2.nhr.fau.de> Date: Wed, 29 Jan 2025 01:14:54 +0100 Subject: [PATCH] Fix For to Kernel transformation, fix whole-program gen --- examples/dem_sd.py | 1 + src/pairs/analysis/devices.py | 21 +++--- src/pairs/code_gen/cgen.py | 102 +++++++++++++++------------ src/pairs/mapping/funcs.py | 4 ++ src/pairs/transformations/devices.py | 6 +- 5 files changed, 79 insertions(+), 55 deletions(-) diff --git a/examples/dem_sd.py b/examples/dem_sd.py index 15ee699..94bc93f 100644 --- a/examples/dem_sd.py +++ b/examples/dem_sd.py @@ -181,4 +181,5 @@ psim.compute(linear_spring_dashpot, 'collisionTime_SI': collisionTime_SI}) psim.compute(euler, parameters={'dt' : pairs.real()}) +# psim.compute(euler, symbols={'dt' : dt_SI}) psim.generate() diff --git a/src/pairs/analysis/devices.py b/src/pairs/analysis/devices.py index 1c009af..29e554e 100644 --- a/src/pairs/analysis/devices.py +++ b/src/pairs/analysis/devices.py @@ -15,18 +15,21 @@ class MarkCandidateLoops(Visitor): self.device_module = False def visit_For(self, ast_node): - if self.device_module: - if ast_node.not_kernel: - self.visit(ast_node.block) - ast_node.mark_iter_as_ref_candidate() - else: - if not isinstance(ast_node.min, Lit) or not isinstance(ast_node.max, Lit): - ast_node.mark_as_kernel_candidate() + if self.device_module and not ast_node.not_kernel and (not isinstance(ast_node.min, Lit) or not isinstance(ast_node.max, Lit)): + ast_node.mark_as_kernel_candidate() + else: + ast_node.mark_iter_as_ref_candidate() + self.visit(ast_node.block) + def visit_Module(self, ast_node): - self.device_module = ast_node.run_on_device - self.visit_children(ast_node) + parent_runs_on_device = self.device_module + if ast_node.run_on_device: + self.device_module = True + self.visit_children(ast_node) + self.device_module = parent_runs_on_device + class FetchKernelReferences(Visitor): def __init__(self, ast=None): diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py index 7649f4f..6795e2d 100644 --- a/src/pairs/code_gen/cgen.py +++ b/src/pairs/code_gen/cgen.py @@ -372,11 +372,22 @@ class CGen: self.generate_pairs_object_structure() self.generate_module_headers() + self.print("namespace pairs::internal {") + self.print.add_indent(4) + for kernel in self.sim.kernels(): self.generate_kernel(kernel) for module in self.sim.modules(): - self.generate_module(module) + if module.name!='main': + self.generate_module(module) + + self.print.add_indent(-4) + self.print("}") + + for module in self.sim.modules(): + if module.name=='main': + self.generate_main(module) self.print.end() @@ -589,57 +600,58 @@ class CGen: if feature_prop in module.host_references(): self.print(f"{type_kw} *h_{feature_prop.name()} = pobj->{feature_prop.name()};") - def generate_module(self, module): - if module.name == 'main': - ndims = module.sim.ndims() - nprops = module.sim.properties.nprops() - ncontactprops = module.sim.contact_properties.nprops() - narrays = module.sim.arrays.narrays() - part = DomainPartitioners.c_keyword(module.sim.partitioner()) - - self.generate_full_object_names = True - self.print("int main(int argc, char **argv) {") - self.print(f" PairsRuntime *pairs_runtime = new PairsRuntime({nprops}, {ncontactprops}, {narrays}, {part});") - self.print(f" struct PairsObjects *pobj = new PairsObjects();") - - if module.sim._enable_profiler: - self.print(" LIKWID_MARKER_INIT;") - - self.generate_statement(module.block) - - if module.sim._enable_profiler: - self.print(" LIKWID_MARKER_CLOSE;") - - self.print(" pairs::print_timers(pairs_runtime);") - self.print(" pairs::print_stats(pairs_runtime, pobj->nlocal, pobj->nghost);") - self.print(" delete pobj;") - self.print(" delete pairs_runtime;") - self.print(" return 0;") - self.print("}") - self.generate_full_object_names = False + def generate_main(self, module): + assert module.name=='main' + + ndims = module.sim.ndims() + nprops = module.sim.properties.nprops() + ncontactprops = module.sim.contact_properties.nprops() + narrays = module.sim.arrays.narrays() + part = DomainPartitioners.c_keyword(module.sim.partitioner()) + + self.generate_full_object_names = True + self.print("int main(int argc, char **argv) {") + self.print(f" PairsRuntime *pairs_runtime = new PairsRuntime({nprops}, {ncontactprops}, {narrays}, {part});") + self.print(f" struct PairsObjects *pobj = new PairsObjects();") + + if module.sim._enable_profiler: + self.print(" LIKWID_MARKER_INIT;") + + self.generate_statement(module.block) + if module.sim._enable_profiler: + self.print(" LIKWID_MARKER_CLOSE;") + + self.print(" pairs::print_timers(pairs_runtime);") + self.print(" pairs::print_stats(pairs_runtime, pobj->nlocal, pobj->nghost);") + self.print(" delete pobj;") + self.print(" delete pairs_runtime;") + self.print(" return 0;") + self.print("}") + self.generate_full_object_names = False + + def generate_module(self, module): + module_params = ", ".join(f"{Types.c_keyword(self.sim, param.type())} {param.name()}" + for param in module.parameters()) + if not module.user_defined: + module_params = ", " + module_params if module_params else "" + self.print(f"void {module.name}(PairsRuntime *pairs_runtime, struct PairsObjects *pobj{module_params}) {{") else: - module_params = ", ".join(f"{Types.c_keyword(self.sim, param.type())} {param.name()}" - for param in module.parameters()) - if not module.user_defined: - module_params = ", " + module_params if module_params else "" - self.print(f"void {module.name}(PairsRuntime *pairs_runtime, struct PairsObjects *pobj{module_params}) {{") - else: - - self.print(f"void {module.name}({module_params}) {{") + + self.print(f"void {module.name}({module_params}) {{") - self.print.add_indent(4) + self.print.add_indent(4) - if self.debug: - self.print(f"PAIRS_DEBUG(\"\\n{module.name}\\n\");") + if self.debug: + self.print(f"PAIRS_DEBUG(\"\\n{module.name}\\n\");") - if not module.user_defined: - self.generate_module_declerations(module) + if not module.user_defined: + self.generate_module_declerations(module) - self.print.add_indent(-4) - self.generate_statement(module.block) - self.print("}") + self.print.add_indent(-4) + self.generate_statement(module.block) + self.print("}") def generate_kernel(self, kernel): kernel_params = "int range_start" diff --git a/src/pairs/mapping/funcs.py b/src/pairs/mapping/funcs.py index efcb825..6258aa1 100644 --- a/src/pairs/mapping/funcs.py +++ b/src/pairs/mapping/funcs.py @@ -286,6 +286,10 @@ class BuildParticleIR(ast.NodeVisitor): def compute(sim, func, cutoff_radius=None, symbols={}, parameters={}, pre_step=False, skip_first=False): + if sim._generate_whole_program: + assert not parameters, "Compute functions can't take custom parameters when generating whole program." + + src = inspect.getsource(func) tree = ast.parse(src, mode='exec') #print(ast.dump(ast.parse(src, mode='exec'))) diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py index dceb1e3..21898c1 100644 --- a/src/pairs/transformations/devices.py +++ b/src/pairs/transformations/devices.py @@ -88,6 +88,7 @@ class AddDeviceKernels(Mutator): super().__init__(ast) self._module_name = None self._kernel_id = 0 + self._device_module = False def create_kernel(self, sim, iterator, rmax, block): kernel_name = f"{self._module_name}_kernel{self._kernel_id}" @@ -101,7 +102,7 @@ class AddDeviceKernels(Mutator): return kernel def mutate_For(self, ast_node): - if ast_node.is_kernel_candidate(): + if ast_node.is_kernel_candidate() and self._device_module: kernel = self.create_kernel(ast_node.sim, ast_node.iterator, ast_node.max, ast_node.block) ast_node = KernelLaunch(ast_node.sim, kernel, ast_node.iterator, ast_node.min, ast_node.max) @@ -111,11 +112,14 @@ class AddDeviceKernels(Mutator): return ast_node def mutate_Module(self, ast_node): + parent_runs_on_device = self._device_module if ast_node.run_on_device: + self._device_module = True self._module_name = ast_node.name self._kernel_id = 0 ast_node._block = self.mutate(ast_node._block) + self._device_module = parent_runs_on_device return ast_node class AddHostReferencesToModules(Mutator): -- GitLab