diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py index e5bb5aee2bb5f032f66b0377ae65b1bbf158994c..7b05e17aca1ec49d134aed4f2b78a615ec4846e8 100644 --- a/src/pairs/code_gen/cgen.py +++ b/src/pairs/code_gen/cgen.py @@ -236,11 +236,11 @@ class CGen: module_params += decl if len(module_params) <= 0 else f", {decl}" for array in module.arrays(): - decl = array.name() + decl = f"d_{array.name()}" if module.run_on_device else array.name() module_params += decl if len(module_params) <= 0 else f", {decl}" for prop in module.properties(): - decl = prop.name() + decl = f"d_{prop.name()}" if module.run_on_device else prop.name() module_params += decl if len(module_params) <= 0 else f", {decl}" self.print(f"{module.name}({module_params});") diff --git a/src/pairs/ir/block.py b/src/pairs/ir/block.py index 2d656560778f99d1a3e1d627883e9f4496dcfc82..e0e60194b876cfd3f8b92ec6428e316c52f4cfd7 100644 --- a/src/pairs/ir/block.py +++ b/src/pairs/ir/block.py @@ -2,7 +2,7 @@ from pairs.ir.ast_node import ASTNode from pairs.ir.module import Module -def pairs_block(func): +def pairs_inline(func): def inner(*args, **kwargs): sim = args[0].sim # self.sim sim.init_block() @@ -12,6 +12,21 @@ def pairs_block(func): return inner +def pairs_host_block(func): + def inner(*args, **kwargs): + sim = args[0].sim # self.sim + sim.init_block() + func(*args, **kwargs) + return Module(sim, + name=sim._module_name, + block=Block(sim, sim._block), + resizes_to_check=sim._resizes_to_check, + check_properties_resize=sim._check_properties_resize, + run_on_device=False) + + return inner + + def pairs_device_block(func): def inner(*args, **kwargs): sim = args[0].sim # self.sim diff --git a/src/pairs/sim/arrays.py b/src/pairs/sim/arrays.py index 5adcaecccb9d2eb69ff8d24cf20a3e1ed2a7c4d3..39468c2e5aef993b27684d3a2cec329ba3de058b 100644 --- a/src/pairs/sim/arrays.py +++ b/src/pairs/sim/arrays.py @@ -1,4 +1,4 @@ -from pairs.ir.block import pairs_block +from pairs.ir.block import pairs_inline from pairs.ir.memory import Malloc from pairs.ir.arrays import ArrayDecl from pairs.sim.lowerable import Lowerable @@ -8,7 +8,7 @@ class ArraysDecl(Lowerable): def __init__(self, sim): super().__init__(sim) - @pairs_block + @pairs_inline def lower(self): for a in self.sim.arrays.all(): if a.is_static(): diff --git a/src/pairs/sim/cell_lists.py b/src/pairs/sim/cell_lists.py index b5e5d180667261d7a389e30491c9055646e73ca2..01fd6cea091345ebd6e6a51497609fb505b0aadc 100644 --- a/src/pairs/sim/cell_lists.py +++ b/src/pairs/sim/cell_lists.py @@ -1,7 +1,7 @@ from functools import reduce import math from pairs.ir.bin_op import BinOp -from pairs.ir.block import pairs_device_block +from pairs.ir.block import pairs_device_block, pairs_host_block from pairs.ir.branches import Branch, Filter from pairs.ir.cast import Cast from pairs.ir.loops import For, ParticleFor @@ -36,7 +36,7 @@ class CellListsStencilBuild(Lowerable): super().__init__(sim) self.cell_lists = cell_lists - @pairs_device_block + @pairs_host_block def lower(self): sim = self.sim cl = self.cell_lists diff --git a/src/pairs/sim/interaction.py b/src/pairs/sim/interaction.py index 123d701bbc69ece69d12aabae92a674fdcd98376..f597b942b1811f3f42ae6b31515f0aac8282a027 100644 --- a/src/pairs/sim/interaction.py +++ b/src/pairs/sim/interaction.py @@ -1,5 +1,5 @@ from pairs.ir.bin_op import BinOp -from pairs.ir.block import Block, pairs_block +from pairs.ir.block import Block, pairs_device_block from pairs.ir.branches import Branch, Filter from pairs.ir.loops import For, ParticleFor from pairs.ir.types import Types @@ -60,7 +60,7 @@ class ParticleInteraction(Lowerable): yield self.i, self.j self.sim.leave() - @pairs_block + @pairs_device_block def lower(self): if self.nbody == 2: position = self.sim.position() diff --git a/src/pairs/sim/lattice.py b/src/pairs/sim/lattice.py index bc3e657ba5762f9c93ed9c14ae1401bf750ab2c0..9ab1f9b0549c0976211395c74a3e63a05edc545b 100644 --- a/src/pairs/sim/lattice.py +++ b/src/pairs/sim/lattice.py @@ -1,4 +1,4 @@ -from pairs.ir.block import pairs_block +from pairs.ir.block import pairs_inline from pairs.ir.loops import For from pairs.ir.types import Types from pairs.sim.lowerable import Lowerable @@ -12,7 +12,7 @@ class ParticleLattice(Lowerable): self.props = props self.positions = positions - @pairs_block + @pairs_inline def lower(self): index = None loop_indexes = [] diff --git a/src/pairs/sim/pbc.py b/src/pairs/sim/pbc.py index f1b457123743d5cadca8a940e763a07f4cedbbe5..bb73102d96f2df534c798d3b69a1f8dcd11b0c82 100644 --- a/src/pairs/sim/pbc.py +++ b/src/pairs/sim/pbc.py @@ -1,4 +1,4 @@ -from pairs.ir.block import pairs_device_block +from pairs.ir.block import pairs_device_block, pairs_host_block from pairs.ir.branches import Branch, Filter from pairs.ir.loops import For, ParticleFor from pairs.ir.utils import Print @@ -72,7 +72,7 @@ class SetupPBC(Lowerable): super().__init__(sim) self.pbc = pbc - @pairs_device_block + @pairs_host_block def lower(self): sim = self.sim ndims = sim.ndims() diff --git a/src/pairs/sim/properties.py b/src/pairs/sim/properties.py index 1818511e7471e45f74fba9d2efed9da364791190..7fce292dc525111f1c974cbd9534ceea0692994c 100644 --- a/src/pairs/sim/properties.py +++ b/src/pairs/sim/properties.py @@ -1,4 +1,4 @@ -from pairs.ir.block import pairs_block, pairs_device_block +from pairs.ir.block import pairs_device_block, pairs_inline from pairs.ir.loops import ParticleFor from pairs.ir.memory import Malloc, Realloc from pairs.ir.properties import RegisterProperty, UpdateProperty @@ -14,7 +14,7 @@ class PropertiesAlloc(Lowerable): self.sim = sim self.realloc = realloc - @pairs_block + @pairs_inline def lower(self): capacity = sum(self.sim.properties.capacities) for p in self.sim.properties.all(): diff --git a/src/pairs/sim/read_from_file.py b/src/pairs/sim/read_from_file.py index d20a485ca65581f14909a7d1d6f995df70a77f1d..7f37b5fcf4263e53cc713793d1b6ec516e4032b8 100644 --- a/src/pairs/sim/read_from_file.py +++ b/src/pairs/sim/read_from_file.py @@ -1,4 +1,4 @@ -from pairs.ir.block import pairs_block +from pairs.ir.block import pairs_inline from pairs.ir.functions import Call_Int from pairs.ir.properties import PropertyList from pairs.ir.types import Types @@ -14,7 +14,7 @@ class ReadFromFile(Lowerable): self.grid = MutableGrid(sim, sim.ndims()) self.grid_buffer = self.sim.add_static_array("grid_buffer", [self.sim.ndims() * 2], Types.Double) - @pairs_block + @pairs_inline def lower(self): self.sim.nlocal.set(Call_Int(self.sim, "pairs::read_particle_data", [self.filename, self.grid_buffer, self.props, self.props.length()])) diff --git a/src/pairs/sim/variables.py b/src/pairs/sim/variables.py index f7b512fe96ae50d9854be229c3ba1d202051db9c..18847b3201c86b57bdd754585bb2e8ab0a698db2 100644 --- a/src/pairs/sim/variables.py +++ b/src/pairs/sim/variables.py @@ -1,4 +1,4 @@ -from pairs.ir.block import pairs_block +from pairs.ir.block import pairs_inline from pairs.ir.variables import VarDecl from pairs.sim.lowerable import Lowerable @@ -7,7 +7,7 @@ class VariablesDecl(Lowerable): def __init__(self, sim): super().__init__(sim) - @pairs_block + @pairs_inline def lower(self): for v in self.sim.vars.all(): VarDecl(self.sim, v) diff --git a/src/pairs/sim/vtk.py b/src/pairs/sim/vtk.py index ad090133a90ca95429fde2e8392de7d47bb9c485..dc19a8ed88576c9d2df799322fb8c6ef13647484 100644 --- a/src/pairs/sim/vtk.py +++ b/src/pairs/sim/vtk.py @@ -1,5 +1,5 @@ from pairs.ir.ast_node import ASTNode -from pairs.ir.block import pairs_block +from pairs.ir.block import pairs_inline from pairs.ir.functions import Call_Void from pairs.ir.lit import Lit from pairs.sim.lowerable import Lowerable @@ -11,7 +11,7 @@ class VTKWrite(Lowerable): self.filename = filename self.timestep = Lit.cvt(sim, timestep) - @pairs_block + @pairs_inline def lower(self): nlocal = self.sim.nlocal npbc = self.sim.pbc.npbc