diff --git a/ast/loops.py b/ast/loops.py index 915a3365a20e5940d4086f4e160b3e0ced629a20..d7ab372d6b63431ba01490da06b639ca9c2305ae 100644 --- a/ast/loops.py +++ b/ast/loops.py @@ -21,11 +21,17 @@ class Iter(): return Type_Int def is_mutable(self): - return False + # FIXME: This should be set to False, but currently the flattening transformation is reusing + # expressions that are not alive anymore (used first within if, then outside it), causing + # the generated code to be uncompilable + return True def scope(self): return self.loop.block + def __add__(self, other): + return Expr(self.sim, self, other, '+') + def __sub__(self, other): return Expr(self.sim, self, other, '-') diff --git a/sim/particle_simulation.py b/sim/particle_simulation.py index 32f3666f471f876e22c3166f0c260c9eee31d3d0..da346ddbcdc8ba738ba5800b21f71bcecd46b785 100644 --- a/sim/particle_simulation.py +++ b/sim/particle_simulation.py @@ -175,22 +175,17 @@ class ParticleSimulation: def generate(self): timestep = Timestep(self, self.ntimesteps, [ (EnforcePBC(self.pbc).lower(), 20), - (SetupPBC(self.pbc).lower(), 20), - UpdatePBC(self.pbc).lower(), + (SetupPBC(self.pbc).lower(), UpdatePBC(self.pbc).lower(), 20), (CellListsBuild(self.cell_lists).lower(), 20), PropertiesResetVolatile(self).lower(), self.kernels.lower() ]) - timestep.add(Block(self, VTKWrite(self, self.vtk_file, timestep.timestep()))) + timestep.add(Block(self, VTKWrite(self, self.vtk_file, timestep.timestep() + 1))) body = Block.from_list(self, [ CellListsStencilBuild(self.cell_lists).lower(), self.setups.lower(), - EnforcePBC(self.pbc).lower(), - SetupPBC(self.pbc).lower(), - UpdatePBC(self.pbc).lower(), - CellListsBuild(self.cell_lists).lower(), Block(self, VTKWrite(self, self.vtk_file, 0)), timestep.as_block() ]) diff --git a/sim/pbc.py b/sim/pbc.py index bdc709ef5ba5f46878eb347fbdf9ecc31bad0b52..528cc0f60283f3a69ea4f813f99779bb1371d533 100644 --- a/sim/pbc.py +++ b/sim/pbc.py @@ -95,20 +95,13 @@ class SetupPBC: if capacity_exceeded: resize.set(Select(sim, resize > npbc, resize + 1, npbc)) else: + pbc_map[npbc].set(i) pbc_mult[npbc][d].set(1) + positions[nlocal + npbc][d].set(positions[i][d] + grid.length(d)) - for is_local in Branch(sim, i < nlocal): - # TODO: VecFilter.others generator? - if is_local: - pbc_map[npbc].set(i) - else: - pbc_map[npbc].set(pbc_map[i - nlocal]) - - for d_ in [x for x in range(0, ndims) if x != d]: - if is_local: - pbc_mult[npbc][d_].set(0) - else: - pbc_mult[npbc][d_].set(pbc_mult[i - nlocal][d_]) + for d_ in [x for x in range(0, ndims) if x != d]: + pbc_mult[npbc][d_].set(0) + positions[nlocal + npbc][d_].set(positions[i][d_]) npbc.add(1) @@ -117,18 +110,13 @@ class SetupPBC: if capacity_exceeded: resize.set(Select(sim, resize > npbc, resize + 1, npbc)) else: + pbc_map[npbc].set(i) pbc_mult[npbc][d].set(-1) - for is_local in Branch(sim, i < nlocal): - if is_local: - pbc_map[npbc].set(i) - else: - pbc_map[npbc].set(pbc_map[i - nlocal]) - - for d_ in [x for x in range(0, ndims) if x != d]: - if is_local: - pbc_mult[npbc][d_].set(0) - else: - pbc_mult[npbc][d_].set(pbc_mult[i - nlocal][d_]) + positions[nlocal + npbc][d].set(positions[i][d] - grid.length(d)) + + for d_ in [x for x in range(0, ndims) if x != d]: + pbc_mult[npbc][d_].set(0) + positions[nlocal + npbc][d_].set(positions[i][d_]) npbc.add(1) diff --git a/sim/timestep.py b/sim/timestep.py index 3599933bbd2f34a2217418f0f05ebc56dd31b00a..cf6b4dfaf8ec3a33de0189a50af2ec26465a9c74 100644 --- a/sim/timestep.py +++ b/sim/timestep.py @@ -8,29 +8,34 @@ class Timestep: def __init__(self, sim, nsteps, item_list=None): self.sim = sim self.block = Block(sim, []) - self.timestep_loop = For(sim, 1, nsteps + 1, self.block) + self.timestep_loop = For(sim, 0, nsteps + 1, self.block) if item_list is not None: for item in item_list: if isinstance(item, tuple): - self.add(item[0], item[1]) + if len(item) >= 3: + self.add(item[0], item[2], item[1]) + else: + self.add(item[0], item[1]) else: self.add(item) def timestep(self): return self.timestep_loop.iter() - def add(self, item, exec_every=0): - assert exec_every >= 0, \ - "exec_every parameter must be higher or equal than zero!" - + def add(self, item, exec_every=0, item_else=None): + assert exec_every >= 0, "exec_every parameter must be higher or equal than zero!" stmts = item if not isinstance(item, Block) else item.statements() + stmts_else = None ts = self.timestep_loop.iter() + + if item_else is not None: + stmts_else = item_else if not isinstance(item_else, Block) else item_else.statements() + if exec_every > 0: self.block.add_statement( - Branch(self.sim, - Expr.cmp(ts % exec_every, 0), - True, Block(self.sim, stmts), None)) + Branch(self.sim, Expr.cmp(ts % exec_every, 0), True if stmts_else is None else False, + Block(self.sim, stmts), None if stmts_else is None else Block(self.sim, stmts_else))) else: self.block.add_statement(stmts)