From 216461eb031becf6cd4fe8f75b57a1257a3abea2 Mon Sep 17 00:00:00 2001 From: Rafael Ravedutti <rafaelravedutti@gmail.com> Date: Fri, 18 Feb 2022 01:41:21 +0100 Subject: [PATCH] Adjust arrays and properties references according to module context Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com> --- src/pairs/code_gen/cgen.py | 4 ++-- src/pairs/ir/block.py | 17 ++++++++++++++++- src/pairs/sim/arrays.py | 4 ++-- src/pairs/sim/cell_lists.py | 4 ++-- src/pairs/sim/interaction.py | 4 ++-- src/pairs/sim/lattice.py | 4 ++-- src/pairs/sim/pbc.py | 4 ++-- src/pairs/sim/properties.py | 4 ++-- src/pairs/sim/read_from_file.py | 4 ++-- src/pairs/sim/variables.py | 4 ++-- src/pairs/sim/vtk.py | 4 ++-- 11 files changed, 36 insertions(+), 21 deletions(-) diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py index e5bb5ae..7b05e17 100644 --- a/src/pairs/code_gen/cgen.py +++ b/src/pairs/code_gen/cgen.py @@ -236,11 +236,11 @@ class CGen: module_params += decl if len(module_params) <= 0 else f", {decl}" for array in module.arrays(): - decl = array.name() + decl = f"d_{array.name()}" if module.run_on_device else array.name() module_params += decl if len(module_params) <= 0 else f", {decl}" for prop in module.properties(): - decl = prop.name() + decl = f"d_{prop.name()}" if module.run_on_device else prop.name() module_params += decl if len(module_params) <= 0 else f", {decl}" self.print(f"{module.name}({module_params});") diff --git a/src/pairs/ir/block.py b/src/pairs/ir/block.py index 2d65656..e0e6019 100644 --- a/src/pairs/ir/block.py +++ b/src/pairs/ir/block.py @@ -2,7 +2,7 @@ from pairs.ir.ast_node import ASTNode from pairs.ir.module import Module -def pairs_block(func): +def pairs_inline(func): def inner(*args, **kwargs): sim = args[0].sim # self.sim sim.init_block() @@ -12,6 +12,21 @@ def pairs_block(func): return inner +def pairs_host_block(func): + def inner(*args, **kwargs): + sim = args[0].sim # self.sim + sim.init_block() + func(*args, **kwargs) + return Module(sim, + name=sim._module_name, + block=Block(sim, sim._block), + resizes_to_check=sim._resizes_to_check, + check_properties_resize=sim._check_properties_resize, + run_on_device=False) + + return inner + + def pairs_device_block(func): def inner(*args, **kwargs): sim = args[0].sim # self.sim diff --git a/src/pairs/sim/arrays.py b/src/pairs/sim/arrays.py index 5adcaec..39468c2 100644 --- a/src/pairs/sim/arrays.py +++ b/src/pairs/sim/arrays.py @@ -1,4 +1,4 @@ -from pairs.ir.block import pairs_block +from pairs.ir.block import pairs_inline from pairs.ir.memory import Malloc from pairs.ir.arrays import ArrayDecl from pairs.sim.lowerable import Lowerable @@ -8,7 +8,7 @@ class ArraysDecl(Lowerable): def __init__(self, sim): super().__init__(sim) - @pairs_block + @pairs_inline def lower(self): for a in self.sim.arrays.all(): if a.is_static(): diff --git a/src/pairs/sim/cell_lists.py b/src/pairs/sim/cell_lists.py index b5e5d18..01fd6ce 100644 --- a/src/pairs/sim/cell_lists.py +++ b/src/pairs/sim/cell_lists.py @@ -1,7 +1,7 @@ from functools import reduce import math from pairs.ir.bin_op import BinOp -from pairs.ir.block import pairs_device_block +from pairs.ir.block import pairs_device_block, pairs_host_block from pairs.ir.branches import Branch, Filter from pairs.ir.cast import Cast from pairs.ir.loops import For, ParticleFor @@ -36,7 +36,7 @@ class CellListsStencilBuild(Lowerable): super().__init__(sim) self.cell_lists = cell_lists - @pairs_device_block + @pairs_host_block def lower(self): sim = self.sim cl = self.cell_lists diff --git a/src/pairs/sim/interaction.py b/src/pairs/sim/interaction.py index 123d701..f597b94 100644 --- a/src/pairs/sim/interaction.py +++ b/src/pairs/sim/interaction.py @@ -1,5 +1,5 @@ from pairs.ir.bin_op import BinOp -from pairs.ir.block import Block, pairs_block +from pairs.ir.block import Block, pairs_device_block from pairs.ir.branches import Branch, Filter from pairs.ir.loops import For, ParticleFor from pairs.ir.types import Types @@ -60,7 +60,7 @@ class ParticleInteraction(Lowerable): yield self.i, self.j self.sim.leave() - @pairs_block + @pairs_device_block def lower(self): if self.nbody == 2: position = self.sim.position() diff --git a/src/pairs/sim/lattice.py b/src/pairs/sim/lattice.py index bc3e657..9ab1f9b 100644 --- a/src/pairs/sim/lattice.py +++ b/src/pairs/sim/lattice.py @@ -1,4 +1,4 @@ -from pairs.ir.block import pairs_block +from pairs.ir.block import pairs_inline from pairs.ir.loops import For from pairs.ir.types import Types from pairs.sim.lowerable import Lowerable @@ -12,7 +12,7 @@ class ParticleLattice(Lowerable): self.props = props self.positions = positions - @pairs_block + @pairs_inline def lower(self): index = None loop_indexes = [] diff --git a/src/pairs/sim/pbc.py b/src/pairs/sim/pbc.py index f1b4571..bb73102 100644 --- a/src/pairs/sim/pbc.py +++ b/src/pairs/sim/pbc.py @@ -1,4 +1,4 @@ -from pairs.ir.block import pairs_device_block +from pairs.ir.block import pairs_device_block, pairs_host_block from pairs.ir.branches import Branch, Filter from pairs.ir.loops import For, ParticleFor from pairs.ir.utils import Print @@ -72,7 +72,7 @@ class SetupPBC(Lowerable): super().__init__(sim) self.pbc = pbc - @pairs_device_block + @pairs_host_block def lower(self): sim = self.sim ndims = sim.ndims() diff --git a/src/pairs/sim/properties.py b/src/pairs/sim/properties.py index 1818511..7fce292 100644 --- a/src/pairs/sim/properties.py +++ b/src/pairs/sim/properties.py @@ -1,4 +1,4 @@ -from pairs.ir.block import pairs_block, pairs_device_block +from pairs.ir.block import pairs_device_block, pairs_inline from pairs.ir.loops import ParticleFor from pairs.ir.memory import Malloc, Realloc from pairs.ir.properties import RegisterProperty, UpdateProperty @@ -14,7 +14,7 @@ class PropertiesAlloc(Lowerable): self.sim = sim self.realloc = realloc - @pairs_block + @pairs_inline def lower(self): capacity = sum(self.sim.properties.capacities) for p in self.sim.properties.all(): diff --git a/src/pairs/sim/read_from_file.py b/src/pairs/sim/read_from_file.py index d20a485..7f37b5f 100644 --- a/src/pairs/sim/read_from_file.py +++ b/src/pairs/sim/read_from_file.py @@ -1,4 +1,4 @@ -from pairs.ir.block import pairs_block +from pairs.ir.block import pairs_inline from pairs.ir.functions import Call_Int from pairs.ir.properties import PropertyList from pairs.ir.types import Types @@ -14,7 +14,7 @@ class ReadFromFile(Lowerable): self.grid = MutableGrid(sim, sim.ndims()) self.grid_buffer = self.sim.add_static_array("grid_buffer", [self.sim.ndims() * 2], Types.Double) - @pairs_block + @pairs_inline def lower(self): self.sim.nlocal.set(Call_Int(self.sim, "pairs::read_particle_data", [self.filename, self.grid_buffer, self.props, self.props.length()])) diff --git a/src/pairs/sim/variables.py b/src/pairs/sim/variables.py index f7b512f..18847b3 100644 --- a/src/pairs/sim/variables.py +++ b/src/pairs/sim/variables.py @@ -1,4 +1,4 @@ -from pairs.ir.block import pairs_block +from pairs.ir.block import pairs_inline from pairs.ir.variables import VarDecl from pairs.sim.lowerable import Lowerable @@ -7,7 +7,7 @@ class VariablesDecl(Lowerable): def __init__(self, sim): super().__init__(sim) - @pairs_block + @pairs_inline def lower(self): for v in self.sim.vars.all(): VarDecl(self.sim, v) diff --git a/src/pairs/sim/vtk.py b/src/pairs/sim/vtk.py index ad09013..dc19a8e 100644 --- a/src/pairs/sim/vtk.py +++ b/src/pairs/sim/vtk.py @@ -1,5 +1,5 @@ from pairs.ir.ast_node import ASTNode -from pairs.ir.block import pairs_block +from pairs.ir.block import pairs_inline from pairs.ir.functions import Call_Void from pairs.ir.lit import Lit from pairs.sim.lowerable import Lowerable @@ -11,7 +11,7 @@ class VTKWrite(Lowerable): self.filename = filename self.timestep = Lit.cvt(sim, timestep) - @pairs_block + @pairs_inline def lower(self): nlocal = self.sim.nlocal npbc = self.sim.pbc.npbc -- GitLab