Skip to content
Snippets Groups Projects
Commit e4f83f25 authored by Rafael Ravedutti Lucio Machado's avatar Rafael Ravedutti Lucio Machado
Browse files

Refactor code with python standards

parent 4d90016d
Branches
Tags
No related merge requests found
......@@ -5,6 +5,7 @@ from ast.lit import is_literal, LitAST
from ast.memory import ReallocAST
from functools import reduce
class Arrays:
def __init__(self, sim):
self.sim = sim
......@@ -19,16 +20,22 @@ class Arrays:
def find(self, a_name):
return [a for a in self.arrays if a.name() == a_name][0]
class ArrayND:
def __init__(self, sim, arr_name, arr_sizes, arr_type):
self.sim = sim
self.arr_name = arr_name
self.arr_sizes = [arr_sizes] if not isinstance(arr_sizes, list) else [LitAST(s) if is_literal(s) else s for s in arr_sizes]
self.arr_sizes = ([arr_sizes] if not isinstance(arr_sizes, list)
else [LitAST(s) if is_literal(s) else s
for s in arr_sizes])
self.arr_type = arr_type
self.arr_ndims = len(self.arr_sizes)
def __str__(self):
return f"ArrayND<name: {self.arr_name}, sizes: {self.arr_sizes}, type: {self.arr_type}>"
return f"""ArrayND<
name: {self.arr_name},
sizes: {self.arr_sizes},
type: {self.arr_type}>"""
def __getitem__(self, expr_ast):
return ArrayAccess(self.sim, self, expr_ast)
......@@ -57,6 +64,7 @@ class ArrayND:
def transform(self, fn):
return fn(self)
class ArrayAccess:
def __init__(self, sim, array, index):
self.sim = sim
......@@ -80,19 +88,23 @@ class ArrayAccess:
return self.sim.capture_statement(AssignAST(self.sim, self, other))
def add(self, other):
return self.sim.capture_statement(AssignAST(self.sim, self, self + other))
return self.sim.capture_statement(
AssignAST(self.sim, self, self + other))
def type(self):
return self.array.type() if len(self.indexes) == self.array.ndims() else Type_Array
return (self.array.type() if len(self.indexes) == self.array.ndims()
else Type_Array)
def generate(self, mem=False):
index = None
sizes = self.array.sizes()
for s in range(0, len(sizes)):
index = self.indexes[s] if index is None else index * sizes[s] + self.indexes[s]
for s in range(0, len(sizes)):
index = (self.indexes[s] if index is None
else index * sizes[s] + self.indexes[s])
index = LitAST(index) if is_literal(index) else index
return self.sim.code_gen.generate_array_access(self.array.generate(), index.generate())
return self.sim.code_gen.generate_array_access(
self.array.generate(), index.generate())
def transform(self, fn):
self.array = self.array.transform(fn)
......
from ast.data_types import Type_Vector
from ast.lit import is_literal, LitAST
class AssignAST:
def __init__(self, sim, dest, src):
self.sim = sim
......@@ -13,7 +14,11 @@ class AssignAST:
for i in range(0, sim.dimensions):
from ast.expr import ExprAST
self.assignments.append((dest[i], src if not isinstance(src, ExprAST) or src.type() != Type_Vector else src[i]))
dsrc = (src if (not isinstance(src, ExprAST) or
src.type() != Type_Vector)
else src[i])
self.assignments.append((dest[i], dsrc))
else:
self.assignments = [(dest, src)]
......@@ -30,5 +35,9 @@ class AssignAST:
self.generated = True
def transform(self, fn):
self.assignments = [(self.assignments[i][0].transform(fn), self.assignments[i][1].transform(fn)) for i in range(0, len(self.assignments))]
self.assignments = [(
self.assignments[i][0].transform(fn),
self.assignments[i][1].transform(fn))
for i in range(0, len(self.assignments))]
return fn(self)
......@@ -31,8 +31,10 @@ class BlockAST:
return fn(self)
def merge_blocks(block1, block2):
assert isinstance(block1, BlockAST), "First block type is not BlockAST!"
assert isinstance(block2, BlockAST), "Second block type is not BlockAST!"
assert isinstance(block1, BlockAST), \
"First block type is not BlockAST!"
assert isinstance(block2, BlockAST), \
"Second block type is not BlockAST!"
return BlockAST(block1.sim, block1.statements() + block2.statements())
def from_list(sim, block_list):
......@@ -40,7 +42,8 @@ class BlockAST:
result_block = BlockAST(sim, [])
for block in block_list:
assert isinstance(block, BlockAST), "Element in list is not BlockAST!"
assert isinstance(block, BlockAST), \
"Element in list is not BlockAST!"
result_block = BlockAST.merge_blocks(result_block, block)
return result_block
from ast.block import BlockAST
from ast.lit import is_literal, LitAST
class BranchAST:
def __init__(self, sim, cond, block_if, block_else):
self.sim = sim
......@@ -9,13 +10,16 @@ class BranchAST:
self.block_else = block_else
def if_stmt(sim, cond, body):
return BranchAST(sim, cond, body if isinstance(body, BlockAST) else BlockAST(sim, body), None)
return BranchAST(sim, cond, (body if isinstance(body, BlockAST)
else BlockAST(sim, body)), None)
def if_else_stmt(sim, cond, body_if, body_else):
return BranchAST(sim, cond,
body_if if isinstance(body_if, BlockAST) else BlockAST(sim, body_if),
body_else if isinstance(body_else, BlockAST) else BlockAST(sim, body_else)
)
return BranchAST(
sim, cond,
(body_if if isinstance(body_if, BlockAST)
else BlockAST(sim, body_if)),
(body_else if isinstance(body_else, BlockAST)
else BlockAST(sim, body_else)))
def generate(self):
self.sim.code_gen.generate_if(self.cond.generate())
......@@ -30,5 +34,6 @@ class BranchAST:
def transform(self, fn):
self.cond = self.cond.transform(fn)
self.block_if = self.block_if.transform(fn)
self.block_else = None if self.block_else is None else self.block_else.transform(fn)
self.block_else = (None if self.block_else is None
else self.block_else.transform(fn))
return fn(self)
from ast.data_types import Type_Int, Type_Float
class CastAST:
def __init__(self, sim, expr, cast_type):
self.sim = sim
......
from ast.assign import AssignAST
from ast.data_types import Type_Int, Type_Float, Type_Bool, Type_Vector
from ast.lit import is_literal, LitAST
from ast.loops import IterAST
from ast.properties import Property
from code_gen.printer import printer
class ExprAST:
def __init__(self, sim, lhs, rhs, op, mem=False):
self.sim = sim
......@@ -63,7 +63,8 @@ class ExprAST:
return ExprAST(self.sim, 1.0, self, '/')
def __getitem__(self, index):
assert self.lhs.type() == Type_Vector, "Cannot use operator [] on specified type!"
assert self.lhs.type() == Type_Vector, \
"Cannot use operator [] on specified type!"
index_ast = index if not is_literal(index) else LitAST(index)
return ExprVecAST(self.sim, self, index_ast)
......@@ -76,7 +77,8 @@ class ExprAST:
def add(self, other):
assert self.mem is True, "Invalid assignment: lvalue expected!"
return self.sim.capture_statement(AssignAST(self.sim, self, self + other))
return self.sim.capture_statement(
AssignAST(self.sim, self, self + other))
def infer_type(lhs, rhs, op):
lhs_type = lhs.type()
......@@ -116,21 +118,29 @@ class ExprAST:
ename = f"e{self.expr_id}"
if self.generated is False:
assert self.expr_type != Type_Vector, "Vector code must be generated through ExprVecAST class!"
t = 'double' if self.expr_type == Type_Float else 'int' if self.expr_type == Type_Int else 'bool'
assert self.expr_type != Type_Vector, \
"Vector code must be generated through ExprVecAST class!"
t = ('double' if self.expr_type == Type_Float
else 'int' if self.expr_type == Type_Int else 'bool')
printer.print(f"const {t} {ename} = {lexpr} {self.op} {rexpr};")
self.generated = True
return ename
def generate_inline(self, mem=False):
lexpr = self.lhs.generate_inline(mem) if isinstance(self.lhs, ExprAST) else self.lhs.generate(mem)
rexpr = self.rhs.generate_inline() if isinstance(self.rhs, ExprAST) else self.rhs.generate()
lexpr = (self.lhs.generate_inline(mem) if isinstance(self.lhs, ExprAST)
else self.lhs.generate(mem))
rexpr = (self.rhs.generate_inline() if isinstance(self.rhs, ExprAST)
else self.rhs.generate())
if self.op == '[]':
return f"{lexpr}[{rexpr}]" if self.mem else f"{lexpr}_{rexpr}"
assert self.expr_type != Type_Vector, "Vector code must be generated through ExprVecAST class!"
assert self.expr_type != Type_Vector, \
"Vector code must be generated through ExprVecAST class!"
return f"{lexpr} {self.op} {rexpr}"
def transform(self, fn):
......@@ -138,16 +148,21 @@ class ExprAST:
self.rhs = self.rhs.transform(fn)
return fn(self)
class ExprVecAST():
def __init__(self, sim, expr, index):
self.sim = sim
self.expr = expr
self.index = index
self.lhs = expr.lhs if not isinstance(expr.lhs, ExprAST) else ExprVecAST(sim, expr.lhs, index)
self.rhs = expr.rhs if not isinstance(expr.rhs, ExprAST) else ExprVecAST(sim, expr.rhs, index)
self.lhs = (expr.lhs if not isinstance(expr.lhs, ExprAST)
else ExprVecAST(sim, expr.lhs, index))
self.rhs = (expr.rhs if not isinstance(expr.rhs, ExprAST)
else ExprVecAST(sim, expr.rhs, index))
def __str__(self):
return f"ExprVecAST<a: {self.lhs}, b: {self.rhs}, op: {self.expr.op}, i: {self.index}>"
return f"""ExprVecAST<
a: {self.lhs}, b: {self.rhs}, op: {self.expr.op},
i: {self.index}>"""
def __sub__(self, other):
return ExprAST(self.sim, self, other, '-')
......@@ -170,11 +185,14 @@ class ExprVecAST():
expr = self.expr.generate()
return f"{expr}[{iexpr}]"
ename = f"e{self.expr.expr_id}[{iexpr}]" if self.expr.mem else f"e{self.expr.expr_id}_{iexpr}"
ename = (f"e{self.expr.expr_id}[{iexpr}]" if self.expr.mem
else f"e{self.expr.expr_id}_{iexpr}")
if self.expr.generated_vector_index(iexpr):
lexpr = self.lhs.generate(mem)
rexpr = self.rhs.generate()
printer.print(f"const double {ename} = {lexpr} {self.expr.op} {rexpr};")
printer.print(
f"const double {ename} = {lexpr} {self.expr.op} {rexpr};")
self.expr.vec_generated.append(iexpr)
return ename
......
from ast.data_types import Type_Invalid, Type_Int, Type_Float, Type_Bool, Type_Vector
from ast.data_types import Type_Invalid, Type_Int, Type_Float, Type_Bool
from ast.data_types import Type_Vector
def is_literal(a):
return isinstance(a, int) or isinstance(a, float) or isinstance(a, bool) or isinstance(a, list)
return (isinstance(a, int) or
isinstance(a, float) or
isinstance(a, bool) or
isinstance(a, list))
class LitAST:
def __init__(self, value):
......
from ast.data_types import Type_Int
from ast.lit import is_literal, LitAST
class IterAST():
def __init__(self, sim):
self.sim = sim
......@@ -40,13 +41,14 @@ class IterAST():
def transform(self, fn):
return fn(self)
class ForAST():
def __init__(self, sim, range_min, range_max, body=None):
self.sim = sim
self.iterator = IterAST(sim)
self.min = LitAST(range_min) if is_literal(range_min) else range_min;
self.max = LitAST(range_max) if is_literal(range_max) else range_max;
self.body = body;
self.min = LitAST(range_min) if is_literal(range_min) else range_min
self.max = LitAST(range_max) if is_literal(range_max) else range_max
self.body = body
def __str__(self):
return f"For<min: {self.min}, max: {self.max}>"
......@@ -70,16 +72,19 @@ class ForAST():
self.body = self.body.transform(fn)
return fn(self)
class ParticleForAST(ForAST):
def __init__(self, sim, body=None):
super().__init__(sim, 0, 0, body)
def generate(self):
it_id = self.iterator.generate()
self.sim.code_gen.generate_for_preamble(it_id, 0, self.sim.nparticles.generate())
self.sim.code_gen.generate_for_preamble(
it_id, 0, self.sim.nparticles.generate())
self.body.generate()
self.sim.code_gen.generate_for_epilogue()
class NeighborForAST(ForAST):
def __init__(self, sim, particle_iter, body=None):
super().__init__(sim, 0, 0, body)
......@@ -87,10 +92,12 @@ class NeighborForAST(ForAST):
def generate(self):
it_id = self.iterator.generate()
self.sim.code_gen.generate_for_preamble(it_id, 0, f"neighbors[{self.particle_iter.generate()}]")
self.sim.code_gen.generate_for_preamble(
it_id, 0, f"neighbors[{self.particle_iter.generate()}]")
self.body.generate()
self.sim.code_gen.generate_for_epilogue()
class WhileAST():
def __init__(self, sim, cond, body=None):
self.sim = sim
......@@ -102,7 +109,8 @@ class WhileAST():
def generate(self):
from ast.expr import ExprAST
cond_gen = self.cond.generate() if not isinstance(self.cond, ExprAST) else self.cond.generate_inline()
cond_gen = (self.cond.generate() if not isinstance(self.cond, ExprAST)
else self.cond.generate_inline())
self.sim.code_gen.generate_while_preamble(cond_gen)
self.body.generate()
self.sim.code_gen.generate_while_epilogue()
......
......@@ -5,7 +5,8 @@ class ReallocAST:
self.size = size
def generate(self, mem=False):
self.sim.code_gen.generate_realloc(self.array.generate(), self.size.generate())
self.sim.code_gen.generate_realloc(
self.array.generate(), self.size.generate())
def transform(self, fn):
self.array = self.array.transform(fn)
......
from ast.data_types import Type_Float, Type_Vector
class Properties:
def __init__(self, sim):
self.sim = sim
......@@ -25,6 +26,7 @@ class Properties:
def find(self, p_name):
return [p for p in self.props if p.name() == p_name][0]
class Property:
def __init__(self, sim, prop_name, prop_type, default_value, volatile):
self.sim = sim
......@@ -53,6 +55,7 @@ class Property:
def transform(self, fn):
return fn(self)
class PropertyDeclAST:
def __init__(self, sim, prop, size):
self.sim = sim
......@@ -62,19 +65,21 @@ class PropertyDeclAST:
return f"PropertyDecl<{self.prop.name}>"
def generate(self, mem=False):
nparticles = self.sim.nparticles
sizes = []
if self.prop.prop_type == Type_Float:
sizes = [self.sim.nparticles.generate()]
sizes = [nparticles.generate()]
elif self.prop.prop_type == Type_Vector:
if self.prop.flattened:
sizes = [(self.sim.nparticles * self.sim.dimensions).generate()]
sizes = [(nparticles * self.sim.dimensions).generate()]
else:
sizes = [self.sim.nparticles.generate(), self.sim.dimensions]
sizes = [nparticles.generate(), self.sim.dimensions]
else:
raise Exception("Invalid property type!")
self.sim.code_gen.generate_array_decl(self.prop.prop_name, self.prop.prop_type, sizes)
self.sim.code_gen.generate_array_decl(
self.prop.prop_name, self.prop.prop_type, sizes)
def transform(self, fn):
return fn(self)
......@@ -3,18 +3,27 @@ from ast.expr import ExprAST, ExprVecAST
from ast.lit import LitAST
from ast.properties import Property
class Transform:
flattened_list = []
def flatten(ast):
if isinstance(ast, ExprVecAST):
if ast.expr.op == '[]' and ast.expr.type() == Type_Vector:
item = [f for f in Transform.flattened_list if f[0] == ast.index and f[1] == ast.expr.rhs]
item = [f for f in Transform.flattened_list
if f[0] == ast.index and f[1] == ast.expr.rhs]
if item:
return item[0][2]
new_expr = ExprAST(ast.expr.sim, ast.expr.lhs, ast.expr.rhs * ast.expr.sim.dimensions + ast.index, '[]', ast.expr.mem)
Transform.flattened_list.append((ast.index, ast.expr.rhs, new_expr))
new_expr = ExprAST(
ast.expr.sim,
ast.expr.lhs,
ast.expr.rhs * ast.expr.sim.dimensions + ast.index,
'[]',
ast.expr.mem)
Transform.flattened_list.append(
(ast.index, ast.expr.rhs, new_expr))
return new_expr
if isinstance(ast, Property):
......
from ast.assign import AssignAST
from ast.expr import ExprAST
class Variables:
def __init__(self, sim):
self.sim = sim
......@@ -15,6 +16,7 @@ class Variables:
def find(self, v_name):
return [v for v in self.vars if v.name() == v_name][0]
class Var:
def __init__(self, sim, var_name, var_type):
self.sim = sim
......@@ -58,7 +60,8 @@ class Var:
return self.sim.capture_statement(AssignAST(self.sim, self, other))
def add(self, other):
return self.sim.capture_statement(AssignAST(self.sim, self, self + other))
return self.sim.capture_statement(
AssignAST(self.sim, self, self + other))
def name(self):
return self.var_name
......
from ast.data_types import Type_Int, Type_Float
from code_gen.printer import printer
class CGen:
def generate_program_preamble():
printer.print("int main() {")
......@@ -15,7 +16,9 @@ class CGen:
printer.add_ind(-4)
def generate_cast(ctype, expr):
t = 'double' if ctype == Type_Float else 'int' if ctype == Type_Int else 'bool'
t = ('double' if ctype == Type_Float
else 'int' if ctype == Type_Int else 'bool')
return f"({t})({expr})"
def generate_if(cond):
......@@ -31,7 +34,9 @@ class CGen:
printer.print(f"{dest} = {src};")
def generate_array_decl(array, a_type, sizes):
t = 'double' if a_type == Type_Float else 'int' if a_type == Type_Int else 'bool'
t = ('double' if a_type == Type_Float
else 'int' if a_type == Type_Int else 'bool')
gen_str = f"{t} {array}"
for s in sizes:
gen_str += f"[{s}]"
......@@ -45,7 +50,8 @@ class CGen:
printer.print(f"{array} = realloc({size});")
def generate_for_preamble(iter_id, rmin, rmax):
printer.print(f"for(int {iter_id} = {rmin}; {iter_id} < {rmax}; {iter_id}++) {{")
printer.print(
f"for(int {iter_id} = {rmin}; {iter_id} < {rmax}; {iter_id}++) {{")
def generate_for_epilogue():
printer.print("}")
......
class Printer:
def __init__(self):
self.indent = 0
self.indent = 0
def add_ind(self, offset):
self.indent += offset
......@@ -8,4 +8,5 @@ class Printer:
def print(self, text):
print(self.indent * ' ' + text)
printer = Printer()
from code_gen.cgen import CGen
from sim.particle_simulation import ParticleSimulation
def simulation(dims=3, timesteps=100):
return ParticleSimulation(CGen, dims, timesteps)
......@@ -11,7 +11,7 @@ psim = pt.simulation()
mass = psim.add_real_property('mass', 1.0)
position = psim.add_vector_property('position')
velocity = psim.add_vector_property('velocity')
force = psim.add_vector_property('force', volatile=True)
force = psim.add_vector_property('force', vol=True)
grid_config = [[0.0, 4.0], [0.0, 4.0], [0.0, 4.0]]
psim.setup_grid(grid_config)
......
from ast.assign import AssignAST
from ast.block import BlockAST
from ast.branches import BranchAST
from ast.cast import CastAST
......@@ -9,19 +8,32 @@ from functools import reduce
from sim.resize import Resize
import math
class CellLists:
def __init__(self, sim, spacing, cutoff_radius):
self.sim = sim
self.spacing = spacing
self.nneighbor_cells = [math.ceil(cutoff_radius / (spacing if not isinstance(spacing, list) else spacing[d])) for d in range(0, sim.dimensions)]
self.nstencil = reduce((lambda x, y: x * y), [self.nneighbor_cells[d] * 2 + 1 for d in range(0, sim.dimensions)])
self.ncells = self.sim.add_array('ncells', self.sim.dimensions, Type_Int)
self.ncells_total = self.sim.add_var('ncells_total', Type_Int)
self.nneighbor_cells = [
math.ceil(cutoff_radius / (
spacing if not isinstance(spacing, list)
else spacing[d]))
for d in range(0, sim.dimensions)]
self.nstencil = reduce((lambda x, y: x * y), [
self.nneighbor_cells[d] * 2 + 1 for d in range(0, sim.dimensions)])
self.ncells_all = self.sim.add_var('ncells_all', Type_Int)
self.cell_capacity = self.sim.add_var('cell_capacity', Type_Int)
self.cell_particles = self.sim.add_array('cell_particles', [self.ncells_total, self.cell_capacity], Type_Int)
self.cell_sizes = self.sim.add_array('cell_sizes', self.ncells_total, Type_Int)
self.ncells = self.sim.add_array(
'ncells', self.sim.dimensions, Type_Int)
self.cell_particles = self.sim.add_array(
'cell_particles', [self.ncells_all, self.cell_capacity], Type_Int)
self.cell_sizes = self.sim.add_array(
'cell_sizes', self.ncells_all, Type_Int)
self.stencil = self.sim.add_array('stencil', self.nstencil, Type_Int)
class CellListsBuild:
def __init__(self, sim, cell_lists):
self.sim = sim
......@@ -29,26 +41,34 @@ class CellListsBuild:
def lower(self):
cl = self.cell_lists
cfg = cl.sim.grid_config
positions = self.sim.property('position')
reset_loop = ForAST(self.sim, 0, cl.ncells_total)
reset_loop.set_body(BlockAST(self.sim, [cl.cell_sizes[reset_loop.iter()].set(0)]))
reset_loop = ForAST(self.sim, 0, cl.ncells_all)
reset_loop.set_body(BlockAST(self.sim,
[cl.cell_sizes[reset_loop.iter()].set(0)]))
fill_loop = ParticleForAST(self.sim)
cell_index = [CastAST.int(self.sim, (positions[fill_loop.iter()][d] - cl.sim.grid_config[d][0]) / cl.spacing) for d in range(0, self.sim.dimensions)]
cell_index = [
CastAST.int(self.sim,
(positions[fill_loop.iter()][d] - cfg[d][0]) /
cl.spacing)
for d in range(0, self.sim.dimensions)]
flat_index = None
for d in range(0, self.sim.dimensions):
flat_index = cell_index[d] if flat_index is None else flat_index * cl.ncells[d] + cell_index[d]
flat_index = (cell_index[d] if flat_index is None
else flat_index * cl.ncells[d] + cell_index[d])
cell_size = cl.cell_sizes[flat_index]
resize = Resize(self.sim, cl.cell_capacity, cl.cell_particles, [reset_loop, fill_loop])
resize = Resize(self.sim, cl.cell_capacity, cl.cell_particles,
[reset_loop, fill_loop])
fill_loop.set_body(BlockAST(self.sim, [
BranchAST.if_stmt(self.sim, ExprAST.and_op(flat_index >= 0, flat_index <= cl.ncells_total), [
resize.check(cell_size, [
cl.cell_particles[flat_index][cell_size].set(fill_loop.iter())
]),
cl.cell_sizes[flat_index].set(cell_size + 1)
])
]))
BranchAST.if_stmt(self.sim, ExprAST.and_op(
flat_index >= 0, flat_index <= cl.ncells_all), [
resize.check(cell_size, [
cl.cell_particles[flat_index][cell_size].set(
fill_loop.iter())]),
cl.cell_sizes[flat_index].set(cell_size + 1)])]))
return resize.block()
......@@ -2,6 +2,7 @@ from ast.assign import AssignAST
from ast.block import BlockAST
from ast.loops import ForAST
class ParticleLattice():
def __init__(self, sim, config, spacing, props, positions):
self.sim = sim
......@@ -12,27 +13,31 @@ class ParticleLattice():
self.positions = positions
def lower(self):
dims = self.sim.dimensions
assignments = []
loops = []
index = None
nparticles = 1
nparticles = 1
for i in range(0, self.sim.dimensions):
n = int((self.config[i][1] - self.config[i][0]) / self.spacing[i] - 0.001) + 1
loops.append(ForAST(self.sim, 0, n))
for i in range(0, dims):
dim_cfg = self.config[i]
n = int((dim_cfg[1] - dim_cfg[0]) / self.spacing[i] - 0.001) + 1
loops.append(ForAST(self.sim, 0, n))
if i > 0:
loops[i - 1].set_body(BlockAST(self.sim, [loops[i]]))
index = loops[i].iter() if index is None else index * n + loops[i].iter()
index = (loops[i].iter() if index is None
else index * n + loops[i].iter())
nparticles *= n
for i in range(0, self.sim.dimensions):
for i in range(0, dims):
pos = self.config[i][0] + self.spacing[i] * loops[i].iter()
assignments.append(AssignAST(self.sim, self.positions[index][i], pos))
assignments.append(
AssignAST(self.sim, self.positions[index][i], pos))
particle_props = self.sim.properties.defaults()
for p in self.props:
particle_props[p] = self.props[p]
loops[self.sim.dimensions - 1].set_body(BlockAST(self.sim, assignments))
loops[dims - 1].set_body(BlockAST(self.sim, assignments))
return (BlockAST(self.sim, loops[0]), nparticles)
from ast.arrays import Arrays
from ast.assign import AssignAST
from ast.block import BlockAST
from ast.branches import BranchAST
from ast.data_types import Type_Int, Type_Float, Type_Vector
from ast.expr import ExprAST
from ast.loops import ForAST, ParticleForAST, NeighborForAST
from ast.loops import ParticleForAST, NeighborForAST
from ast.properties import Properties
from ast.transform import Transform
from ast.variables import Variables
......@@ -13,6 +11,7 @@ from sim.lattice import ParticleLattice
from sim.properties import PropertiesDecl, PropertiesResetVolatile
from sim.timestep import Timestep
class ParticleSimulation:
def __init__(self, code_gen, dims=3, timesteps=100):
self.code_gen = code_gen
......@@ -30,11 +29,11 @@ class ParticleSimulation:
self.expr_id = 0
self.iter_id = 0
def add_real_property(self, prop_name, value=0.0, volatile=False):
return self.properties.add(prop_name, Type_Float, value, volatile)
def add_real_property(self, prop_name, value=0.0, vol=False):
return self.properties.add(prop_name, Type_Float, value, vol)
def add_vector_property(self, prop_name, value=[0.0, 0.0, 0.0], volatile=False):
return self.properties.add(prop_name, Type_Vector, value, volatile)
def add_vector_property(self, prop_name, value=[0.0, 0.0, 0.0], vol=False):
return self.properties.add(prop_name, Type_Vector, value, vol)
def property(self, prop_name):
return self.properties.find(prop_name)
......@@ -64,7 +63,8 @@ class ParticleSimulation:
def create_particle_lattice(self, config, spacing, props={}):
positions = self.property('position')
block, nparticles = ParticleLattice(self, config, spacing, props, positions).lower()
block, nparticles = ParticleLattice(
self, config, spacing, props, positions).lower()
self.setup_blocks.append(block)
self.nparticles += nparticles
......@@ -74,12 +74,14 @@ class ParticleSimulation:
i.set_body(BlockAST(self, [j]))
if cutoff_radius is not None and position is not None:
delta = position[i.iter()] - position[j.iter()]
rsq = delta[0] * delta[0] + delta[1] * delta[1] + delta[2] * delta[2]
dp = position[i.iter()] - position[j.iter()]
rsq = dp[0] * dp[0] + dp[1] * dp[1] + dp[2] * dp[2]
self.start_capture()
yield i.iter(), j.iter(), delta, rsq
yield i.iter(), j.iter(), dp, rsq
self.stop_capture()
j.set_body(BlockAST(self, [BranchAST(self, rsq < cutoff_radius, BlockAST(self, self.capture_buffer.copy()), None)]))
j.set_body(BlockAST(self, [
BranchAST(self, rsq < cutoff_radius,
BlockAST(self, self.capture_buffer.copy()), None)]))
else:
yield i.iter(), j.iter()
......@@ -117,7 +119,9 @@ class ParticleSimulation:
program = BlockAST.merge_blocks(
PropertiesDecl(self).lower(),
BlockAST.merge_blocks(BlockAST.from_list(self, self.setup_blocks), timestep_loop.as_block()))
BlockAST.merge_blocks(
BlockAST.from_list(self, self.setup_blocks),
timestep_loop.as_block()))
program.transform(Transform.flatten)
program.transform(Transform.simplify)
......
......@@ -2,6 +2,7 @@ from ast.block import BlockAST
from ast.properties import PropertyDeclAST
from ast.loops import ParticleForAST
class PropertiesDecl:
def __init__(self, sim):
self.sim = sim
......@@ -13,6 +14,7 @@ class PropertiesDecl:
return BlockAST(self.sim, decls)
class PropertiesResetVolatile:
def __init__(self, sim):
self.sim = sim
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment