Skip to content
Snippets Groups Projects
Commit 09bdc22b authored by Rafael Ravedutti Lucio Machado's avatar Rafael Ravedutti Lucio Machado
Browse files

Add cell lists build

parent 539feefd
No related branches found
No related tags found
No related merge requests found
class Array:
def __init__(self, sim, arr_name, arr_size, arr_type):
self.sim = sim
self.arr_name = arr_name
self.arr_size = arr_size
self.arr_type = arr_type
def __str__(self):
return f"Array<name: {self.arr_name}, size: {self.arr_size}, type: {self.arr_type}>"
def name(self):
return self.arr_name
def size(self):
return self.arr_size
def type(self):
return self.arr_type
def __getitem__(self, expr_ast):
from expr import ExprAST
return ExprAST(self.sim, self, expr_ast, '[]', True)
def generate(self, mem=False):
return self.arr_name
def transform(self, fn):
return fn(self)
...@@ -19,7 +19,7 @@ class AssignAST: ...@@ -19,7 +19,7 @@ class AssignAST:
self.assignments = [(dest, src)] self.assignments = [(dest, src)]
def __str__(self): def __str__(self):
return f"Assign<a: {dest}, b: {src}>" return f"Assign<{self.assignments}>"
def generate(self): def generate(self):
if self.generated is False: if self.generated is False:
......
...@@ -2,7 +2,19 @@ from printer import printer ...@@ -2,7 +2,19 @@ from printer import printer
class BlockAST: class BlockAST:
def __init__(self, stmts): def __init__(self, stmts):
self.stmts = stmts if isinstance(stmts, BlockAST):
self.stmts = stmts.statements()
else:
self.stmts = stmts
def add_statement(self, stmt):
if isinstance(stmt, list):
self.stmts = self.stmts + stmt
else:
self.stmts.append(stmt)
def statements(self):
return self.stmts
def generate(self): def generate(self):
printer.add_ind(4) printer.add_ind(4)
...@@ -15,3 +27,8 @@ class BlockAST: ...@@ -15,3 +27,8 @@ class BlockAST:
self.stmts[i] = self.stmts[i].transform(fn) self.stmts[i] = self.stmts[i].transform(fn)
return fn(self) return fn(self)
def merge_blocks(block1, block2):
assert isinstance(block1, BlockAST), "First block type is not BlockAST!"
assert isinstance(block2, BlockAST), "Second block type is not BlockAST!"
return BlockAST(block1.statements() + block2.statements())
from block import BlockAST
from lit import is_literal, LitAST
from printer import printer from printer import printer
class BranchAST: class BranchAST:
def __init__(self, cond, block_if, block_else): def __init__(self, cond, block_if, block_else):
self.cond = cond self.cond = LitAST(cond) if is_literal(cond) else cond
self.block_if = block_if self.block_if = block_if
self.block_else = block_else self.block_else = block_else
def if_stmt(cond, body):
return BranchAST(cond, body if isinstance(body, BlockAST) else BlockAST(body), None)
def if_else_stmt(cond, body_if, body_else):
return BranchAST(cond,
body_if if isinstance(body_if, BlockAST) else BlockAST(body_if),
body_else if isinstance(body_else, BlockAST) else BlockAST(body_else)
)
def generate(self): def generate(self):
cvname = self.cond.generate() cvname = self.cond.generate()
printer.print(f"if({cvname}) {{") printer.print(f"if({cvname}) {{")
......
from assign import AssignAST
from block import BlockAST
from branches import BranchAST
from data_types import Type_Int
from loops import ForAST, ParticleForAST
class CellLists:
def __init__(self, sim, spacing):
self.sim = sim
self.spacing = spacing
self.ncells = self.sim.add_array('ncells', self.sim.dimensions, Type_Int)
self.ncells_total = self.sim.add_var('ncells_total', Type_Int)
self.cell_capacity = self.sim.add_var('cell_capacity', Type_Int)
self.cell_particles = self.sim.add_array('cell_particles', self.ncells_total * self.cell_capacity, Type_Int)
self.cell_sizes = self.sim.add_array('cell_sizes', self.ncells_total, Type_Int)
def build(self):
positions = self.sim.property('position')
reset_loop = ForAST(self.sim, 0, self.ncells_total)
reset_loop.set_body(BlockAST([self.cell_sizes[reset_loop.iter()].set(0)]))
fill_loop = ParticleForAST(self.sim)
cell_index = [(positions[fill_loop.iter()][d] - self.sim.grid_config[d][0]) / self.spacing for d in range(0, self.sim.dimensions)]
flat_index = None
for d in range(0, self.sim.dimensions):
flat_index = cell_index[d] if flat_index is None else flat_index * self.ncells[d] + cell_index[d]
fill_loop.set_body(BlockAST([
BranchAST.if_stmt(flat_index >= 0 and flat_index <= self.ncells_total, [
self.cell_particles[flat_index * self.cell_capacity + self.cell_sizes[flat_index]].set(fill_loop.iter()),
self.cell_sizes[flat_index].set(self.cell_sizes[flat_index] + 1)
])
]))
return BlockAST([reset_loop, fill_loop])
Type_Invalid = -1 Type_Invalid = -1
Type_Int = 0 Type_Int = 0
Type_Float = 1 Type_Float = 1
Type_Vector = 2 Type_Bool = 2
Type_Vector = 3
...@@ -6,7 +6,8 @@ from printer import printer ...@@ -6,7 +6,8 @@ from printer import printer
from properties import Property from properties import Property
def is_expr(e): def is_expr(e):
return isinstance(e, ExprAST) or isinstance(e, IterAST) or isinstance(e, LitAST) from variables import Var
return isinstance(e, ExprAST) or isinstance(e, IterAST) or isinstance(e, LitAST) or isinstance(e, VarAST)
class ExprAST: class ExprAST:
def __init__(self, sim, lhs, rhs, op, mem=False): def __init__(self, sim, lhs, rhs, op, mem=False):
...@@ -47,6 +48,18 @@ class ExprAST: ...@@ -47,6 +48,18 @@ class ExprAST:
def __lt__(self, other): def __lt__(self, other):
return ExprAST(self.sim, self, other, '<') return ExprAST(self.sim, self, other, '<')
def __le__(self, other):
return ExprAST(self.sim, self, other, '<=')
def __gt__(self, other):
return ExprAST(self.sim, self, other, '>')
def __ge__(self, other):
return ExprAST(self.sim, self, other, '>=')
def cmp(lhs, rhs):
return ExprAST(lhs.sim, lhs, rhs, '==')
def inv(self): def inv(self):
return ExprAST(self.sim, 1.0, self, '/') return ExprAST(self.sim, 1.0, self, '/')
...@@ -60,11 +73,11 @@ class ExprAST: ...@@ -60,11 +73,11 @@ class ExprAST:
def set(self, other): def set(self, other):
assert self.mem is True, "Invalid assignment: lvalue expected!" assert self.mem is True, "Invalid assignment: lvalue expected!"
self.sim.produced_stmts.append(AssignAST(self.sim, self, other)) return self.sim.capture_statement(AssignAST(self.sim, self, other))
def add(self, other): def add(self, other):
assert self.mem is True, "Invalid assignment: lvalue expected!" assert self.mem is True, "Invalid assignment: lvalue expected!"
self.sim.produced_stmts.append(AssignAST(self.sim, self, self + other)) return self.sim.capture_statement(AssignAST(self.sim, self, self + other))
def infer_type(lhs, rhs, op): def infer_type(lhs, rhs, op):
lhs_type = lhs.type() lhs_type = lhs.type()
...@@ -124,6 +137,9 @@ class ExprVecAST(): ...@@ -124,6 +137,9 @@ class ExprVecAST():
def __str__(self): def __str__(self):
return f"ExprVecAST<a: {self.lhs}, b: {self.rhs}, op: {self.expr.op}, i: {self.index}>" return f"ExprVecAST<a: {self.lhs}, b: {self.rhs}, op: {self.expr.op}, i: {self.index}>"
def __sub__(self, other):
return ExprAST(self.sim, self, other, '-')
def __mul__(self, other): def __mul__(self, other):
return ExprAST(self.sim, self, other, '*') return ExprAST(self.sim, self, other, '*')
......
from data_types import Type_Invalid, Type_Int, Type_Float, Type_Vector from data_types import Type_Invalid, Type_Int, Type_Float, Type_Bool, Type_Vector
def is_literal(a): def is_literal(a):
return isinstance(a, int) or isinstance(a, float) or isinstance(a, list) return isinstance(a, int) or isinstance(a, float) or isinstance(a, bool) or isinstance(a, list)
class LitAST: class LitAST:
def __init__(self, value): def __init__(self, value):
...@@ -14,6 +14,9 @@ class LitAST: ...@@ -14,6 +14,9 @@ class LitAST:
if isinstance(value, float): if isinstance(value, float):
self.lit_type = Type_Float self.lit_type = Type_Float
if isinstance(value, bool):
self.lit_type = Type_Bool
if isinstance(value, list): if isinstance(value, list):
self.lit_type = Type_Vector self.lit_type = Type_Vector
......
from data_types import Type_Int from data_types import Type_Int
from lit import is_literal, LitAST
from printer import printer from printer import printer
class IterAST(): class IterAST():
...@@ -26,6 +27,10 @@ class IterAST(): ...@@ -26,6 +27,10 @@ class IterAST():
def __req__(self, other): def __req__(self, other):
return self.__cmp__(other) return self.__cmp__(other)
def __mod__(self, other):
from expr import ExprAST
return ExprAST(self.sim, self, other, '%')
def __str__(self): def __str__(self):
return f"Iter<{self.iter_id}>" return f"Iter<{self.iter_id}>"
...@@ -39,8 +44,8 @@ class IterAST(): ...@@ -39,8 +44,8 @@ class IterAST():
class ForAST(): class ForAST():
def __init__(self, sim, range_min, range_max, body=None): def __init__(self, sim, range_min, range_max, body=None):
self.iterator = IterAST(sim) self.iterator = IterAST(sim)
self.min = range_min; self.min = LitAST(range_min) if is_literal(range_min) else range_min;
self.max = range_max; self.max = LitAST(range_max) if is_literal(range_max) else range_max;
self.body = body; self.body = body;
def iter(self): def iter(self):
...@@ -51,7 +56,9 @@ class ForAST(): ...@@ -51,7 +56,9 @@ class ForAST():
def generate(self): def generate(self):
it_id = self.iterator.generate() it_id = self.iterator.generate()
printer.print(f"for(int {it_id} = {self.min}; {it_id} < {self.max}; {it_id}++) {{") rmin = self.min.generate()
rmax = self.max.generate()
printer.print(f"for(int {it_id} = {rmin}; {it_id} < {rmax}; {it_id}++) {{")
self.body.generate(); self.body.generate();
printer.print("}") printer.print("}")
......
...@@ -16,7 +16,6 @@ force = psim.add_vector_property('force', volatile=True) ...@@ -16,7 +16,6 @@ force = psim.add_vector_property('force', volatile=True)
grid_config = [[0.0, 4.0], [0.0, 4.0], [0.0, 4.0]] grid_config = [[0.0, 4.0], [0.0, 4.0], [0.0, 4.0]]
psim.setup_grid(grid_config) psim.setup_grid(grid_config)
psim.create_particle_lattice(grid_config, spacing=[1.0, 1.0, 1.0]) psim.create_particle_lattice(grid_config, spacing=[1.0, 1.0, 1.0])
psim.setup_cell_lists(cutoff_radius + skin)
for i, j, delta, rsq in psim.particle_pairs(cutoff_radius, position): for i, j, delta, rsq in psim.particle_pairs(cutoff_radius, position):
sr2 = 1.0 / rsq sr2 = 1.0 / rsq
......
from arrays import Array
from assign import AssignAST from assign import AssignAST
from block import BlockAST from block import BlockAST
from branches import BranchAST from branches import BranchAST
from cell_lists import CellLists
from data_types import Type_Int, Type_Float, Type_Vector from data_types import Type_Int, Type_Float, Type_Vector
from expr import ExprAST from expr import ExprAST
from loops import ForAST, ParticleForAST, NeighborForAST from loops import ForAST, ParticleForAST, NeighborForAST
from properties import Property from properties import Property
from printer import printer from printer import printer
from timestep import Timestep
from transform import Transform from transform import Transform
from variables import Var
class ParticleSimulation: class ParticleSimulation:
def __init__(self, dims=3, timesteps=100): def __init__(self, dims=3, timesteps=100):
self.properties = [] self.properties = []
self.vars = []
self.arrays = []
self.defaults = {} self.defaults = {}
self.setup = [] self.setup = []
self.grid_config = [] self.grid_config = []
self.setup_stmts = [] self.setup_stmts = []
self.timestep_stmts = [] self.captured_stmts = []
self.produced_stmts = [] self.capture_buffer = []
self.capture = False
self.dimensions = dims self.dimensions = dims
self.ntimesteps = timesteps self.ntimesteps = timesteps
self.expr_id = 0 self.expr_id = 0
...@@ -35,6 +42,25 @@ class ParticleSimulation: ...@@ -35,6 +42,25 @@ class ParticleSimulation:
def add_vector_property(self, prop_name, value=[0.0, 0.0, 0.0], volatile=False): def add_vector_property(self, prop_name, value=[0.0, 0.0, 0.0], volatile=False):
return self.add_property(prop_name, Type_Vector, value, volatile) return self.add_property(prop_name, Type_Vector, value, volatile)
def property(self, prop_name):
return [p for p in self.properties if p.name() == prop_name][0]
def add_array(self, array_name, array_size, array_type):
arr = Array(self, array_name, array_size, array_type)
self.arrays.append(arr)
return arr
def array(self, array_name):
return [a for a in self.arrays if a.name() == array_name][0]
def add_var(self, var_name, var_type):
var = Var(self, var_name, var_type)
self.vars.append(var)
return var
def var(self, var_name):
return [v for v in self.vars if v.name() == var_name][0]
def new_expr(self): def new_expr(self):
self.expr_id += 1 self.expr_id += 1
return self.expr_id - 1 return self.expr_id - 1
...@@ -47,7 +73,7 @@ class ParticleSimulation: ...@@ -47,7 +73,7 @@ class ParticleSimulation:
self.grid_config = config self.grid_config = config
def create_particle_lattice(self, config, spacing, props={}): def create_particle_lattice(self, config, spacing, props={}):
positions = [p for p in self.properties if p.name() == 'position'][0] positions = self.property('position')
assignments = [] assignments = []
loops = [] loops = []
index = None index = None
...@@ -74,16 +100,6 @@ class ParticleSimulation: ...@@ -74,16 +100,6 @@ class ParticleSimulation:
self.setup_stmts.append(loops[0]) self.setup_stmts.append(loops[0])
self.nparticles += nparticles self.nparticles += nparticles
def setup_cell_lists(self, cutoff_radius):
ncells = [
(self.grid_config[0][1] - self.grid_config[0][0]) / cutoff_radius,
(self.grid_config[1][1] - self.grid_config[1][0]) / cutoff_radius,
(self.grid_config[2][1] - self.grid_config[2][0]) / cutoff_radius
]
def set_timesteps(self, ts):
self.ntimesteps = ts
def particle_pairs(self, cutoff_radius=None, position=None): def particle_pairs(self, cutoff_radius=None, position=None):
i = ParticleForAST(self) i = ParticleForAST(self)
j = NeighborForAST(self, i.iter()) j = NeighborForAST(self, i.iter())
...@@ -92,22 +108,37 @@ class ParticleSimulation: ...@@ -92,22 +108,37 @@ class ParticleSimulation:
if cutoff_radius is not None and position is not None: if cutoff_radius is not None and position is not None:
delta = position[i.iter()] - position[j.iter()] delta = position[i.iter()] - position[j.iter()]
rsq = delta[0] * delta[0] + delta[1] * delta[1] + delta[2] * delta[2] rsq = delta[0] * delta[0] + delta[1] * delta[1] + delta[2] * delta[2]
self.start_capture()
yield i.iter(), j.iter(), delta, rsq yield i.iter(), j.iter(), delta, rsq
j.set_body(BlockAST([BranchAST(rsq < cutoff_radius, BlockAST(self.produced_stmts.copy()), None)])) self.stop_capture()
j.set_body(BlockAST([BranchAST(rsq < cutoff_radius, BlockAST(self.capture_buffer.copy()), None)]))
else: else:
yield i.iter(), j.iter() yield i.iter(), j.iter()
j.set_body(BlockAST(self.produced_stmts.copy())) j.set_body(BlockAST(self.capture_buffer.copy()))
self.timestep_stmts.append(i) self.captured_stmts.append(i)
self.produced_stmts = []
def particles(self): def particles(self):
i = ParticleForAST(self) i = ParticleForAST(self)
self.start_capture()
yield i.iter() yield i.iter()
i.set_body(BlockAST(self.produced_stmts.copy())) self.stop_capture()
self.timestep_stmts.append(i) i.set_body(BlockAST(self.capture_buffer.copy()))
self.produced_stmts = [] self.captured_stmts.append(i)
def start_capture(self):
self.capture_buffer = []
self.capture = True
def stop_capture(self):
self.capture = False
def capture_statement(self, stmt):
if self.capture is True:
self.capture_buffer.append(stmt)
return stmt
def generate_properties_decl(self): def generate_properties_decl(self):
for p in self.properties: for p in self.properties:
...@@ -125,16 +156,16 @@ class ParticleSimulation: ...@@ -125,16 +156,16 @@ class ParticleSimulation:
printer.print("int main() {") printer.print("int main() {")
printer.print(f" const int nparticles = {self.nparticles};") printer.print(f" const int nparticles = {self.nparticles};")
setup_block = BlockAST(self.setup_stmts) setup_block = BlockAST(self.setup_stmts)
setup_block.transform(Transform.flatten)
setup_block.transform(Transform.simplify)
reset_loop = ParticleForAST(self) reset_loop = ParticleForAST(self)
reset_loop.set_body(BlockAST([AssignAST(self, p[reset_loop.iter()], 0.0) for p in self.properties if p.volatile is True])) reset_loop.set_body(BlockAST([AssignAST(self, p[reset_loop.iter()], 0.0) for p in self.properties if p.volatile is True]))
self.timestep_stmts.insert(0, reset_loop) cell_lists = CellLists(self, 2.8)
timestep_block = BlockAST([ForAST(self, 0, self.ntimesteps, BlockAST(self.timestep_stmts))]) timestep_loop = Timestep(self, self.ntimesteps)
timestep_block.transform(Transform.flatten) timestep_loop.add(cell_lists.build(), 20)
timestep_block.transform(Transform.simplify) timestep_loop.add(reset_loop)
timestep_loop.add(self.captured_stmts)
program = BlockAST.merge_blocks(setup_block, timestep_loop.as_block())
program.transform(Transform.flatten)
program.transform(Transform.simplify)
self.generate_properties_decl() self.generate_properties_decl()
setup_block.generate() program.generate()
timestep_block.generate()
printer.print("}") printer.print("}")
from block import BlockAST
from expr import ExprAST
from branches import BranchAST
from loops import ForAST
class Timestep:
def __init__(self, sim, nsteps):
self.sim = sim
self.block = BlockAST([])
self.timestep_loop = ForAST(sim, 0, nsteps, self.block)
def add(self, item, exec_every=0):
assert exec_every >= 0, "Timestep frequency parameter must be higher or equal than zero!"
statements = item if not isinstance(item, BlockAST) else item.statements()
if exec_every > 0:
self.block.add_statement(BranchAST.if_stmt(ExprAST.cmp(self.timestep_loop.iter() % exec_every, 0), statements))
else:
self.block.add_statement(statements)
def as_block(self):
return BlockAST([self.timestep_loop])
def generate(self):
self.block.generate()
def transform(self, fn):
self.block = self.block.transform(fn)
from expr import ExprAST
class Var:
def __init__(self, sim, var_name, var_type):
self.sim = sim
self.var_name = var_name
self.var_type = var_type
def __str__(self):
return f"Var<name: {self.var_name}, type: {self.var_type}>"
def __add__(self, other):
return ExprAST(self.sim, self, other, '+')
def __radd__(self, other):
return ExprAST(self.sim, other, self, '+')
def __sub__(self, other):
return ExprAST(self.sim, self, other, '-')
def __mul__(self, other):
return ExprAST(self.sim, self, other, '*')
def __rmul__(self, other):
return ExprAST(self.sim, other, self, '*')
def __truediv__(self, other):
return ExprAST(self.sim, self, other, '/')
def __rtruediv__(self, other):
return ExprAST(self.sim, other, self, '/')
def __lt__(self, other):
return ExprAST(self.sim, self, other, '<')
def inv(self):
return ExprAST(self.sim, 1.0, self, '/')
def name(self):
return self.var_name
def type(self):
return self.var_type
def generate(self, mem=False):
return self.var_name
def transform(self, fn):
return fn(self)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment