From dcccfc94a350d92a4f7a342f6dafd12a7e408bca Mon Sep 17 00:00:00 2001 From: Rafael Ravedutti <rafaelravedutti@gmail.com> Date: Thu, 24 Jun 2021 02:24:15 +0200 Subject: [PATCH] Add neighbor lists and consider ArrayND in LICM Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com> --- ir/loops.py | 28 ++++++++++++++--------- new_syntax.py | 1 + particle.py | 1 + sim/cell_lists.py | 1 + sim/neighbor_lists.py | 46 ++++++++++++++++++++++++++++++++++++++ sim/particle_simulation.py | 9 +++++++- transformations/LICM.py | 14 +++++++++++- 7 files changed, 87 insertions(+), 13 deletions(-) create mode 100644 sim/neighbor_lists.py diff --git a/ir/loops.py b/ir/loops.py index b4ce423..f035afa 100644 --- a/ir/loops.py +++ b/ir/loops.py @@ -102,22 +102,28 @@ class While(ASTNode): class NeighborFor(): - def __init__(self, sim, particle, cell_lists): + def __init__(self, sim, particle, cell_lists, neighbor_lists=None): self.sim = sim self.particle = particle self.cell_lists = cell_lists + self.neighbor_lists = neighbor_lists def __str__(self): return f"NeighborFor<particle: {self.particle}>" def __iter__(self): - cl = self.cell_lists - for s in For(self.sim, 0, cl.nstencil): - neigh_cell = cl.particle_cell[self.particle] + cl.stencil[s] - for _ in Filter(self.sim, - BinOp.and_op(neigh_cell >= 0, - neigh_cell <= cl.ncells_all)): - for nc in For(self.sim, 0, cl.cell_sizes[neigh_cell]): - it = cl.cell_particles[neigh_cell][nc] - for _ in Filter(self.sim, BinOp.neq(it, self.particle)): - yield it + if self.neighbor_lists is None: + cl = self.cell_lists + for s in For(self.sim, 0, cl.nstencil): + neigh_cell = cl.particle_cell[self.particle] + cl.stencil[s] + for _ in Filter(self.sim, + BinOp.and_op(neigh_cell >= 0, + neigh_cell <= cl.ncells_all)): + for nc in For(self.sim, 0, cl.cell_sizes[neigh_cell]): + it = cl.cell_particles[neigh_cell][nc] + for _ in Filter(self.sim, BinOp.neq(it, self.particle)): + yield it + else: + neighbor_lists = self.neighbor_lists + for k in For(self.sim, 0, neighbor_lists.numneighs[self.particle]): + yield neighbor_lists.neighborlists[self.particle][k] diff --git a/new_syntax.py b/new_syntax.py index 28a1e06..bdbeb0d 100644 --- a/new_syntax.py +++ b/new_syntax.py @@ -184,6 +184,7 @@ psim.add_vector_property('velocity') psim.add_vector_property('force', vol=True) psim.from_file("data/minimd_setup_4x4x4.input", ['mass', 'position', 'velocity']) psim.create_cell_lists(2.8, 2.8) +psim.create_neighbor_lists() psim.periodic(2.8) psim.vtk_output("output/test") diff --git a/particle.py b/particle.py index 312a132..8680d2b 100644 --- a/particle.py +++ b/particle.py @@ -30,6 +30,7 @@ force = psim.add_vector_property('force', vol=True) #psim.create_particle_lattice(grid, spacing=[0.82323, 0.82323, 0.82323]) psim.from_file("data/minimd_setup_4x4x4.input", ['mass', 'position', 'velocity']) psim.create_cell_lists(2.8, 2.8) +psim.create_neighbor_lists() psim.periodic(2.8) psim.vtk_output("output/test") diff --git a/sim/cell_lists.py b/sim/cell_lists.py index 923c4de..4c1400a 100644 --- a/sim/cell_lists.py +++ b/sim/cell_lists.py @@ -14,6 +14,7 @@ class CellLists: self.sim = sim self.grid = grid self.spacing = spacing + self.cutoff_radius = cutoff_radius self.nneighbor_cells = [ math.ceil(cutoff_radius / ( diff --git a/sim/neighbor_lists.py b/sim/neighbor_lists.py new file mode 100644 index 0000000..f2fa90b --- /dev/null +++ b/sim/neighbor_lists.py @@ -0,0 +1,46 @@ +from ir.branches import Branch, Filter +from ir.data_types import Type_Int +from ir.loops import For, ParticleFor, NeighborFor +from ir.utils import Print +from sim.resize import Resize + + +class NeighborLists: + def __init__(self, cell_lists): + self.sim = cell_lists.sim + self.cell_lists = cell_lists + self.capacity = self.sim.add_var('neighborlist_capacity', Type_Int, 32) + self.neighborlists = self.sim.add_array('neighborlists', [self.sim.particle_capacity, self.capacity], Type_Int) + self.numneighs = self.sim.add_array('numneighs', self.sim.particle_capacity, Type_Int) + + +class NeighborListsBuild: + def __init__(self, neighbor_lists): + self.neighbor_lists = neighbor_lists + + def lower(self): + neighbor_lists = self.neighbor_lists + sim = neighbor_lists.sim + cell_lists = neighbor_lists.cell_lists + cutoff_radius = cell_lists.cutoff_radius + position = sim.property('position') + + sim.clear_block() + sim.add_statement(Print(sim, "NeighborListsBuild")) + for resize in Resize(sim, neighbor_lists.capacity): + for i in ParticleFor(sim): + neighbor_lists.numneighs[i].set(0) + for j in NeighborFor(sim, i, cell_lists): + # TODO: find a way to not repeat this (already present in particle_pairs) + dp = position[i] - position[j] + rsq = dp.x() * dp.x() + dp.y() * dp.y() + dp.z() * dp.z() + for _ in Filter(sim, rsq < cutoff_radius): + numneighs = neighbor_lists.numneighs[i] + for cond in Branch(sim, numneighs >= neighbor_lists.capacity): + if cond: + resize.set(numneighs) + else: + neighbor_lists.neighborlists[i][numneighs].set(j) + neighbor_lists.numneighs[i].set(numneighs + 1) + + return sim.block diff --git a/sim/particle_simulation.py b/sim/particle_simulation.py index 20efef8..5b82cd4 100644 --- a/sim/particle_simulation.py +++ b/sim/particle_simulation.py @@ -12,6 +12,7 @@ from sim.cell_lists import CellLists, CellListsBuild, CellListsStencilBuild from sim.grid import Grid2D, Grid3D from sim.kernel_wrapper import KernelWrapper from sim.lattice import ParticleLattice +from sim.neighbor_lists import NeighborLists, NeighborListsBuild from sim.pbc import PBC, UpdatePBC, EnforcePBC, SetupPBC from sim.properties import PropertiesAlloc, PropertiesResetVolatile from sim.read_from_file import ReadFromFile @@ -38,6 +39,7 @@ class ParticleSimulation: self.nghost = self.add_var('nghost', Type_Int) self.grid = None self.cell_lists = None + self.neighbor_lists = None self.pbc = None self.scope = [] self.nested_count = 0 @@ -114,6 +116,10 @@ class ParticleSimulation: self.cell_lists = CellLists(self, self.grid, spacing, cutoff_radius) return self.cell_lists + def create_neighbor_lists(self): + self.neighbor_lists = NeighborLists(self.cell_lists) + return self.neighbor_lists + def periodic(self, cutneigh, flags=[1, 1, 1]): self.pbc = PBC(self, self.grid, cutneigh, flags) self.properties.add_capacity(self.pbc.pbc_capacity) @@ -122,7 +128,7 @@ class ParticleSimulation: def particle_pairs(self, cutoff_radius=None, position=None): self.clear_block() for i in ParticleFor(self): - for j in NeighborFor(self, i, self.cell_lists): + for j in NeighborFor(self, i, self.cell_lists, self.neighbor_lists): if cutoff_radius is not None and position is not None: dp = position[i] - position[j] rsq = dp.x() * dp.x() + dp.y() * dp.y() + dp.z() * dp.z() @@ -177,6 +183,7 @@ class ParticleSimulation: (EnforcePBC(self.pbc).lower(), 20), (SetupPBC(self.pbc).lower(), UpdatePBC(self.pbc).lower(), 20), (CellListsBuild(self.cell_lists).lower(), 20), + (NeighborListsBuild(self.neighbor_lists).lower(), 20), PropertiesResetVolatile(self).lower(), self.kernels.lower() ]) diff --git a/transformations/LICM.py b/transformations/LICM.py index 74f9969..5b42d8c 100644 --- a/transformations/LICM.py +++ b/transformations/LICM.py @@ -38,9 +38,13 @@ class SetBlockVariants(Mutator): return ast_node def mutate_BinOp(self, ast_node): + ast_node.lhs = self.mutate(ast_node.lhs) + # For property accesses, we only want to include the property name, and not # the index that is also present in the expression - ast_node.lhs = self.mutate(ast_node.lhs) + if not ast_node.is_property_access(): + ast_node.rhs = self.mutate(ast_node.rhs) + return ast_node def mutate_ArrayAccess(self, ast_node): @@ -52,6 +56,10 @@ class SetBlockVariants(Mutator): def mutate_Array(self, ast_node): return self.push_variant(ast_node) + # TODO: Array should be enough + def mutate_ArrayND(self, ast_node): + return self.push_variant(ast_node) + def mutate_Iter(self, ast_node): return self.push_variant(ast_node) @@ -130,6 +138,10 @@ class SetBinOpTerminals(Visitor): def visit_Array(self, ast_node): self.push_terminal(ast_node) + # TODO: Array should be enough + def visit_ArrayND(self, ast_node): + self.push_terminal(ast_node) + def visit_Iter(self, ast_node): self.push_terminal(ast_node) -- GitLab