diff --git a/src/pairs/coupling/parse_cpp.py b/src/pairs/coupling/parse_cpp.py index 1659115392a89f6b5c12306f5442bed7be3fb24f..cafe9d306dd7ae8fa9cb01eee4dac658567e8ada 100644 --- a/src/pairs/coupling/parse_cpp.py +++ b/src/pairs/coupling/parse_cpp.py @@ -151,7 +151,7 @@ def map_kernel_to_simulation(sim, node): contactNormal = sim.add_var('contactNormal', Type_Vector) penetrationDepth = sim.add_var('penetrationDepth', Type_Float) - self.clear_block() + self.init_block() pairs = ParticleInteraction(sim, 2) for i, j in pairs: return map_method_tree(sim, node, { diff --git a/src/pairs/ir/bin_op.py b/src/pairs/ir/bin_op.py index 8990ad9850466509a7d7c4318abafcbdc5f1a410..cad9d522cd2f518fb2a226bea77e7bfdeebddc6b 100644 --- a/src/pairs/ir/bin_op.py +++ b/src/pairs/ir/bin_op.py @@ -227,6 +227,9 @@ class ASTTerm(ASTNode): def and_op(self, other): return BinOp(self.sim, self, other, '&&') + def or_op(self, other): + return BinOp(self.sim, self, other, '||') + def cmp(lhs, rhs): return BinOp(lhs.sim, lhs, rhs, '==') diff --git a/src/pairs/ir/block.py b/src/pairs/ir/block.py index 38092da8671f581584218f4b7d506ef3b3f4acf8..212c0a4ca920fab568a4431fd9887c578784403a 100644 --- a/src/pairs/ir/block.py +++ b/src/pairs/ir/block.py @@ -5,9 +5,9 @@ from pairs.ir.module import Module def pairs_block(func): def inner(*args, **kwargs): sim = args[0].sim # self.sim - sim.clear_block() + sim.init_block() func(*args, **kwargs) - return sim.block + return sim._block return inner @@ -15,9 +15,9 @@ def pairs_block(func): def pairs_device_block(func): def inner(*args, **kwargs): sim = args[0].sim # self.sim - sim.clear_block() + sim.init_block() func(*args, **kwargs) - return Module(sim, block=KernelBlock(sim, sim.block)) + return Module(sim, block=KernelBlock(sim, sim._block)) return inner @@ -26,9 +26,9 @@ def pairs_device_block(func): def pairs_host_block(func): def inner(*args, **kwargs): sim = args[0].sim # self.sim - sim.clear_block() + sim.init_block() func(*args, **kwargs) - return KernelBlock(sim, sim.block, run_on_host=True) + return KernelBlock(sim, sim._block, run_on_host=True) return inner diff --git a/src/pairs/ir/branches.py b/src/pairs/ir/branches.py index f6c3295d8aa8fc639335d926c1abca9001a3c299..aa41913921cc90fd5122b0dbffdd4d7cd7c2b82f 100644 --- a/src/pairs/ir/branches.py +++ b/src/pairs/ir/branches.py @@ -9,10 +9,7 @@ class Branch(ASTNode): self.cond = as_lit_ast(sim, cond) self.switch = True self.block_if = Block(sim, []) if blk_if is None else blk_if - self.block_else = \ - None if one_way \ - else Block(sim, []) if blk_else is None \ - else blk_else + self.block_else = None if one_way else Block(sim, []) if blk_else is None else blk_else def __iter__(self): self.sim.add_statement(self) @@ -33,8 +30,7 @@ class Branch(ASTNode): self.block_else.add_statement(stmt) def children(self): - return [self.cond, self.block_if] + \ - ([] if self.block_else is None else [self.block_else]) + return [self.cond, self.block_if] + ([] if self.block_else is None else [self.block_else]) class Filter(Branch): diff --git a/src/pairs/ir/module.py b/src/pairs/ir/module.py index 752152a858ca609fdd3d5bca9a183c0422028d6e..70aac81683a1c3258ccc02ffa4e969a589fd5dc9 100644 --- a/src/pairs/ir/module.py +++ b/src/pairs/ir/module.py @@ -7,13 +7,15 @@ from pairs.ir.variables import Var class Module(ASTNode): last_module = 0 - def __init__(self, sim, name=None, block=None): + def __init__(self, sim, name=None, block=None, resizes_to_check={}, check_properties_resize=False): super().__init__(sim) self._name = name if name is not None else "module_" + str(Module.last_module) self._variables = {} self._arrays = set() self._properties = set() self._block = block + self._resizes_to_check = resizes_to_check + self._check_properties_resize = check_properties_resize sim.add_module(self) Module.last_module += 1 diff --git a/src/pairs/ir/transform.py b/src/pairs/ir/transform.py deleted file mode 100644 index 4705552594e6ddf4276f84bad8f67a9ad46a16a2..0000000000000000000000000000000000000000 --- a/src/pairs/ir/transform.py +++ /dev/null @@ -1,129 +0,0 @@ -from pairs.ir.arrays import ArrayAccess -from pairs.ir.bin_op import BinOp -from pairs.ir.data_types import Type_Int, Type_Vector -from pairs.ir.layouts import Layout_AoS, Layout_SoA -from pairs.ir.lit import Lit -from pairs.ir.loops import Iter -from pairs.ir.properties import Property - - -class Transform: - reuse_expressions = {} - - def apply(ast, fn): - ast.transform(fn) - Transform.reuse_expressions = {} - - def flatten(ast): - if isinstance(ast, BinOp): - if ast.is_vector_property_access(): - layout = ast.lhs.layout() - - for i in ast.vector_indexes(): - flat_index = None - - if layout == Layout_AoS: - flat_index = ast.rhs * ast.sim.ndims() + i - - elif layout == Layout_SoA: - flat_index = i * ast.sim.particle_capacity + ast.rhs - - else: - raise Exception("Invalid property layout!") - - ast.map_vector_index(i, flat_index) - - return ast - - def simplify(ast): - if isinstance(ast, BinOp): - sim = ast.lhs.sim - - if ast.op in ['+', '-'] and ast.rhs == 0: - return ast.lhs - - if ast.op in ['+'] and ast.lhs == 0: - return ast.rhs - - if ast.op in ['*', '/'] and ast.rhs == 1: - return ast.lhs - - if ast.op == '*' and ast.lhs == 1: - return ast.rhs - - if ast.op == '*' and ast.lhs == 0: - return Lit(sim, 0 if ast.type() == Type_Int else 0.0) - - return ast - - def reuse_index_expressions(ast): - if isinstance(ast, BinOp): - iter_id = None - - if isinstance(ast.lhs, Iter): - iter_id = ast.lhs.iter_id - - if isinstance(ast.rhs, Iter): - iter_id = ast.rhs.iter_id - - if iter_id is not None: - if iter_id in Transform.reuse_expressions: - item = [e for e in Transform.reuse_expressions[iter_id] - if ast.match(e)] - if item: - return item[0] - - else: - Transform.reuse_expressions[iter_id] = [] - - Transform.reuse_expressions[iter_id].append(ast) - - return ast - - def reuse_expr_expressions(ast): - if isinstance(ast, BinOp): - expr_id = None - - if isinstance(ast.lhs, BinOp): - expr_id = ast.lhs.expr_id - - if isinstance(ast.rhs, BinOp): - expr_id = ast.rhs.expr_id - - if expr_id is not None: - if expr_id in Transform.reuse_expressions: - item = [e for e in Transform.reuse_expressions[expr_id] - if ast.match(e)] - if item: - return item[0] - - else: - Transform.reuse_expressions[expr_id] = [] - - Transform.reuse_expressions[expr_id].append(ast) - - return ast - - def reuse_array_access_expressions(ast): - if isinstance(ast, BinOp): - acc_id = None - - if isinstance(ast.lhs, ArrayAccess): - acc_id = ast.lhs.acc_id - - if isinstance(ast.rhs, ArrayAccess): - acc_id = ast.rhs.acc_id - - if acc_id is not None: - if acc_id in Transform.reuse_expressions: - item = [e for e in Transform.reuse_expressions[acc_id] - if ast.match(e)] - if item: - return item[0] - - else: - Transform.reuse_expressions[acc_id] = [] - - Transform.reuse_expressions[acc_id].append(ast) - - return ast diff --git a/src/pairs/mapping/funcs.py b/src/pairs/mapping/funcs.py index 57bb43c9f6f74df26be4c64fcbe480565ac0f961..164a3de70e6a52d300e2bfe5c7fd3e72f56d8c82 100644 --- a/src/pairs/mapping/funcs.py +++ b/src/pairs/mapping/funcs.py @@ -136,7 +136,7 @@ def compute(sim, func, cutoff_radius=None, symbols={}): ir = BuildParticleIR(sim, symbols) assert nparams > 0, "Number of parameters from compute functions must be higher than zero!" - sim.clear_block() + sim.init_block() if nparams == 1: for i in ParticleFor(sim): ir.add_symbols({params[0]: i}) diff --git a/src/pairs/sim/cell_lists.py b/src/pairs/sim/cell_lists.py index df6f2dc4fcfd2218396b18cddb4d0b975efa1864..82a9f7e6658daf5e7749469318f620877468cb8a 100644 --- a/src/pairs/sim/cell_lists.py +++ b/src/pairs/sim/cell_lists.py @@ -44,15 +44,14 @@ class CellListsStencilBuild(Lowerable): index = None nall = 1 + sim.module_name("build_cell_lists_stencil") + sim.check_resize(cl.ncells_capacity, cl.ncells) + for d in range(sim.ndims()): cl.dim_ncells[d].set(Ceil(sim, (grid.max(d) - grid.min(d)) / cl.spacing[d]) + 2) nall *= cl.dim_ncells[d] cl.ncells.set(nall) - for resize in Resize(sim, cl.ncells_capacity): - for _ in Filter(sim, cl.ncells >= cl.ncells_capacity): - resize.set(cl.ncells) - for _ in sim.nest_mode(): cl.nstencil.set(0) for d in range(sim.ndims()): @@ -75,28 +74,24 @@ class CellListsBuild(Lowerable): cl = self.cell_lists grid = sim.grid positions = sim.position() - - for resize in Resize(sim, cl.cell_capacity): - for c in For(sim, 0, cl.ncells): - cl.cell_sizes[c].set(0) - - for i in ParticleFor(sim, local_only=False): - cell_index = [ - Cast.int(sim, (positions[i][d] - grid.min(d)) / cl.spacing[d]) - for d in range(0, sim.ndims())] - - flat_idx = None - for d in range(0, sim.ndims()): - flat_idx = (cell_index[d] if flat_idx is None - else flat_idx * cl.dim_ncells[d] + cell_index[d]) - - cell_size = cl.cell_sizes[flat_idx] - for _ in Filter(sim, BinOp.and_op(flat_idx >= 0, flat_idx <= cl.ncells)): - for cond in Branch(sim, cell_size >= cl.cell_capacity): - if cond: - resize.set(cell_size) - else: - cl.cell_particles[flat_idx][cell_size].set(i) - cl.particle_cell[i].set(flat_idx) - - cl.cell_sizes[flat_idx].set(cell_size + 1) + sim.module_name("build_cell_lists") + sim.check_resize(cl.cell_capacity, cl.cell_sizes) + + for c in For(sim, 0, cl.ncells): + cl.cell_sizes[c].set(0) + + for i in ParticleFor(sim, local_only=False): + cell_index = [ + Cast.int(sim, (positions[i][d] - grid.min(d)) / cl.spacing[d]) + for d in range(0, sim.ndims())] + + flat_idx = None + for d in range(0, sim.ndims()): + flat_idx = (cell_index[d] if flat_idx is None + else flat_idx * cl.dim_ncells[d] + cell_index[d]) + + cell_size = cl.cell_sizes[flat_idx] + for _ in Filter(sim, BinOp.and_op(flat_idx >= 0, flat_idx <= cl.ncells)): + cl.particle_cell[i].set(flat_idx) + cl.cell_particles[flat_idx][cell_size].set(i) + cl.cell_sizes[flat_idx].set(cell_size + 1) diff --git a/src/pairs/sim/neighbor_lists.py b/src/pairs/sim/neighbor_lists.py index d55a1122786a53ef9f9b907cb0144dcb3d3cc3e1..54bcc2e0fbf6a0c6a8c00c8a838a3f065da1e4ac 100644 --- a/src/pairs/sim/neighbor_lists.py +++ b/src/pairs/sim/neighbor_lists.py @@ -29,17 +29,13 @@ class NeighborListsBuild(Lowerable): cell_lists = neighbor_lists.cell_lists cutoff_radius = cell_lists.cutoff_radius position = sim.position() + sim.module_name("neighbor_lists_build") + sim.check_resize(neighbor_lists.capacity, neighbor_lists.numneighs) - for resize in Resize(sim, neighbor_lists.capacity): - for i in ParticleFor(sim): - neighbor_lists.numneighs[i].set(0) + for i in ParticleFor(sim): + neighbor_lists.numneighs[i].set(0) - pairs = ParticleInteraction(sim, 2, cutoff_radius, bypass_neighbor_lists=True) - for i, j in pairs: - numneighs = neighbor_lists.numneighs[i] - for cond in Branch(sim, numneighs >= neighbor_lists.capacity): - if cond: - resize.set(numneighs) - else: - neighbor_lists.neighborlists[i][numneighs].set(j) - neighbor_lists.numneighs[i].set(numneighs + 1) + for i, j in ParticleInteraction(sim, 2, cutoff_radius, bypass_neighbor_lists=True): + numneighs = neighbor_lists.numneighs[i] + neighbor_lists.neighborlists[i][numneighs].set(j) + neighbor_lists.numneighs[i].set(numneighs + 1) diff --git a/src/pairs/sim/pbc.py b/src/pairs/sim/pbc.py index 87b8aadb34007759fe3465691e7f44daa97d19a1..90e76d77a11909024e8744078d3e3161a51f4500 100644 --- a/src/pairs/sim/pbc.py +++ b/src/pairs/sim/pbc.py @@ -35,6 +35,7 @@ class UpdatePBC(Lowerable): pbc_mult = self.pbc.pbc_mult positions = self.pbc.sim.position() nlocal = self.pbc.sim.nlocal + sim.module_name("update_pbc") for i in For(sim, 0, npbc): # TODO: allow syntax: @@ -54,6 +55,7 @@ class EnforcePBC(Lowerable): ndims = sim.ndims() grid = self.pbc.grid positions = sim.position() + sim.module_name("enforce_pbc") for i in ParticleFor(sim): # TODO: VecFilter? @@ -82,39 +84,32 @@ class SetupPBC(Lowerable): pbc_mult = self.pbc.pbc_mult positions = self.pbc.sim.position() nlocal = self.pbc.sim.nlocal - - for resize in Resize(sim, pbc_capacity): - npbc.set(0) - for d in range(0, ndims): - for i in For(sim, 0, nlocal + npbc): - last_id = nlocal + npbc - # TODO: VecFilter? - for _ in Filter(sim, positions[i][d] < grid.min(d) + cutneigh): - for capacity_exceeded in Branch(sim, npbc >= pbc_capacity): - if capacity_exceeded: - resize.set(Select(sim, resize > npbc, resize + 1, npbc)) - else: - pbc_map[npbc].set(i) - pbc_mult[npbc][d].set(1) - positions[last_id][d].set(positions[i][d] + grid.length(d)) - - for d_ in [x for x in range(0, ndims) if x != d]: - pbc_mult[npbc][d_].set(0) - positions[last_id][d_].set(positions[i][d_]) - - npbc.add(1) - - for _ in Filter(sim, positions[i][d] > grid.max(d) - cutneigh): - for capacity_exceeded in Branch(sim, npbc >= pbc_capacity): - if capacity_exceeded: - resize.set(Select(sim, resize > npbc, resize + 1, npbc)) - else: - pbc_map[npbc].set(i) - pbc_mult[npbc][d].set(-1) - positions[last_id][d].set(positions[i][d] - grid.length(d)) - - for d_ in [x for x in range(0, ndims) if x != d]: - pbc_mult[npbc][d_].set(0) - positions[last_id][d_].set(positions[i][d_]) - - npbc.add(1) + sim.module_name("setup_pbc") + sim.check_resize(pbc_capacity, npbc) + + npbc.set(0) + for d in range(0, ndims): + for i in For(sim, 0, nlocal + npbc): + last_id = nlocal + npbc + # TODO: VecFilter? + for _ in Filter(sim, positions[i][d] < grid.min(d) + cutneigh): + pbc_map[npbc].set(i) + pbc_mult[npbc][d].set(1) + positions[last_id][d].set(positions[i][d] + grid.length(d)) + + for d_ in [x for x in range(0, ndims) if x != d]: + pbc_mult[npbc][d_].set(0) + positions[last_id][d_].set(positions[i][d_]) + + npbc.add(1) + + for _ in Filter(sim, positions[i][d] > grid.max(d) - cutneigh): + pbc_map[npbc].set(i) + pbc_mult[npbc][d].set(-1) + positions[last_id][d].set(positions[i][d] - grid.length(d)) + + for d_ in [x for x in range(0, ndims) if x != d]: + pbc_mult[npbc][d_].set(0) + positions[last_id][d_].set(positions[i][d_]) + + npbc.add(1) diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py index 734a7b181b3181203a1941f73ad64f254233bcf4..8d6717b77e8843df65d947720527fc3b691d4ba5 100644 --- a/src/pairs/sim/simulation.py +++ b/src/pairs/sim/simulation.py @@ -21,28 +21,28 @@ from pairs.sim.timestep import Timestep from pairs.sim.variables import VariablesDecl from pairs.sim.vtk import VTKWrite from pairs.transformations.add_device_copies import add_device_copies -from pairs.transformations.fetch_modules_references import fetch_modules_references from pairs.transformations.prioritize_scalar_ops import prioritize_scalar_ops from pairs.transformations.set_used_bin_ops import set_used_bin_ops from pairs.transformations.simplify import simplify_expressions from pairs.transformations.LICM import move_loop_invariant_code from pairs.transformations.lower import lower_everything from pairs.transformations.merge_adjacent_blocks import merge_adjacent_blocks -from pairs.transformations.replace_modules_by_calls import replace_modules_by_calls +from pairs.transformations.modules import modularize from pairs.transformations.replace_symbols import replace_symbols class Simulation: - def __init__(self, code_gen, dims=3, timesteps=100): + def __init__(self, code_gen, dims=3, timesteps=100, particle_capacity=10000): self.code_gen = code_gen self.code_gen.assign_simulation(self) self.position_prop = None self.properties = Properties(self) self.vars = Variables(self) self.arrays = Arrays(self) - self.particle_capacity = self.add_var('particle_capacity', Type_Int, 10000) + self.particle_capacity = self.add_var('particle_capacity', Type_Int, particle_capacity) self.nlocal = self.add_var('nlocal', Type_Int) self.nghost = self.add_var('nghost', Type_Int) + self.resizes = self.add_array('resizes', 3, Type_Int) self.grid = None self.cell_lists = None self.neighbor_lists = None @@ -51,15 +51,17 @@ class Simulation: self.nested_count = 0 self.nest = False self.check_decl_usage = True - self.block = Block(self, []) + self._block = Block(self, []) self.setups = Block(self, []) self.kernels = Block(self, []) self.module_list = [] + self._check_properties_resize = False + self._resizes_to_check = {} + self._module_name = None self.dims = dims self.ntimesteps = timesteps self.expr_id = 0 self.iter_id = 0 - self.temp_id = 1000 self.vtk_file = None self.nparticles = self.nlocal + self.nghost self.properties.add_capacity(self.particle_capacity) @@ -108,14 +110,6 @@ class Simulation: assert self.var(var_name) is None, f"Variable already defined: {var_name}" return self.vars.add(var_name, var_type, init_value) - def add_or_reuse_var(self, var_name, var_type, init_value=0): - existing_var = self.var(var_name) - if existing_var is not None: - assert existing_var.type() == var_type, f"Cannot reuse variable {var_name}: types differ!" - return existing_var - - return self.vars.add(var_name, var_type, init_value) - def add_symbol(self, sym_type): return Symbol(self, sym_type) @@ -157,15 +151,36 @@ class Simulation: def compute(self, func, cutoff_radius=None, symbols={}): return compute(self, func, cutoff_radius, symbols) - def clear_block(self): - self.block = Block(self, []) + def init_block(self): + self._block = Block(self, []) + self._check_properties_resize = False + self._resizes_to_check = {} + self._module_name = None + + def module_name(self, name): + self._module_name = name + + def check_properties_resize(self): + self._check_properties_resize = True + + def check_resize(self, capacity, size): + size_array = [size] if not isinstance(size, list) else size + if capacity not in self._resizes_to_check: + self._resizes_to_check[capacity] = size_array + else: + self._resizes_to_check[capacity] += size_array def build_kernel_block_with_statements(self): - self.kernels.add_statement(Module(self, block=KernelBlock(self, self.block))) + self.kernels.add_statement( + Module(self, + name=self._module_name, + block=KernelBlock(self, self._block), + resizes_to_check=self._resizes_to_check, + check_properties_resize=self._check_properties_resize)) def add_statement(self, stmt): if not self.scope: - self.block.add_statement(stmt) + self._block.add_statement(stmt) else: self.scope[-1].add_statement(stmt) @@ -192,6 +207,7 @@ class Simulation: self.vtk_file = filename def generate(self): + # For timestep in Timestep(self): timestep = Timestep(self, self.ntimesteps, [ (EnforcePBC(self, self.pbc), 20), (SetupPBC(self, self.pbc), UpdatePBC(self, self.pbc), 20), @@ -225,10 +241,9 @@ class Simulation: prioritize_scalar_ops(program) simplify_expressions(program) move_loop_invariant_code(program) - fetch_modules_references(program) set_used_bin_ops(program) + modularize(program) add_device_copies(program) - replace_modules_by_calls(program) # For this part on, all bin ops are generated without usage verification self.check_decl_usage = False diff --git a/src/pairs/transformations/fetch_modules_references.py b/src/pairs/transformations/fetch_modules_references.py deleted file mode 100644 index d9fa43d92fdaaf7a714ce2c47b323ed9dffbd973..0000000000000000000000000000000000000000 --- a/src/pairs/transformations/fetch_modules_references.py +++ /dev/null @@ -1,63 +0,0 @@ -from pairs.ir.module import Module -from pairs.ir.mutator import Mutator -from pairs.ir.variables import Deref -from pairs.ir.visitor import Visitor - - -class FetchModulesReferences(Visitor): - def __init__(self, ast): - super().__init__(ast) - self.module_stack = [] - self.writing = False - - def visit_Assign(self, ast_node): - self.writing = True - for c in ast_node.destinations(): - self.visit(c) - - self.writing = False - for c in ast_node.sources(): - self.visit(c) - - def visit_Module(self, ast_node): - self.module_stack.append(ast_node) - self.visit_children(ast_node) - self.module_stack.pop() - - def visit_Array(self, ast_node): - for m in self.module_stack: - m.add_array(ast_node) - - def visit_Property(self, ast_node): - for m in self.module_stack: - m.add_property(ast_node) - - def visit_Var(self, ast_node): - for m in self.module_stack: - m.add_variable(ast_node, self.writing) - - -class AddDereferencesToWriteVariables(Mutator): - def __init__(self, ast): - super().__init__(ast) - self.module_stack = [] - - def mutate_Module(self, ast_node): - self.module_stack.append(ast_node) - ast_node._block = self.mutate(ast_node._block) - self.module_stack.pop() - return ast_node - - def mutate_Var(self, ast_node): - parent_module = self.module_stack[-1] - if parent_module.name != 'main' and ast_node in parent_module.write_variables(): - return Deref(ast_node.sim, ast_node) - - return ast_node - - -def fetch_modules_references(ast): - fetch_refs = FetchModulesReferences(ast) - fetch_refs.visit() - add_derefs_to_write_vars = AddDereferencesToWriteVariables(ast) - add_derefs_to_write_vars.mutate() diff --git a/src/pairs/transformations/modules.py b/src/pairs/transformations/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..8a34cec54af427cced9c68be3d78b532ffc4509e --- /dev/null +++ b/src/pairs/transformations/modules.py @@ -0,0 +1,190 @@ +from pairs.ir.bin_op import BinOp +from pairs.ir.branches import Branch +from pairs.ir.module import Module, Module_Call +from pairs.ir.mutator import Mutator +from pairs.ir.variables import Deref +from pairs.ir.visitor import Visitor + + +class FetchModulesReferences(Visitor): + def __init__(self, ast): + super().__init__(ast) + self.module_stack = [] + self.writing = False + + def visit_Assign(self, ast_node): + self.writing = True + for c in ast_node.destinations(): + self.visit(c) + + self.writing = False + for c in ast_node.sources(): + self.visit(c) + + def visit_Module(self, ast_node): + self.module_stack.append(ast_node) + self.visit_children(ast_node) + self.module_stack.pop() + + def visit_Array(self, ast_node): + for m in self.module_stack: + m.add_array(ast_node) + + def visit_Property(self, ast_node): + for m in self.module_stack: + m.add_property(ast_node) + + def visit_Var(self, ast_node): + for m in self.module_stack: + m.add_variable(ast_node, self.writing) + + +class AddDereferencesToWriteVariables(Mutator): + def __init__(self, ast): + super().__init__(ast) + self.module_stack = [] + + def mutate_Module(self, ast_node): + self.module_stack.append(ast_node) + ast_node._block = self.mutate(ast_node._block) + self.module_stack.pop() + return ast_node + + def mutate_Var(self, ast_node): + parent_module = self.module_stack[-1] + if parent_module.name != 'main' and ast_node in parent_module.write_variables(): + return Deref(ast_node.sim, ast_node) + + return ast_node + + +class AddResizeLogic(Mutator): + def __init__(self, ast): + super().__init__(ast) + self.block_stack = [] + self.module_stack = [] + self.module_resizes = {} + self.resizes_to_check = {} + self.check_properties_resize = False + self.match_capacity = None + self.update = {} + self.resize_buffers = {} + self.nresize_buffers = 0 + + def mutate_Array(self, ast_node): + for capacity, size in self.resizes_to_check.items(): + if size == ast_node.name(): + self.match_capacity = capacity + + return ast_node + + def mutate_Assignment(self, ast_node): + for dest, src in ast_node.assignments.items(): + if isinstance(dest, ArrayAccess): + self.match_capacity = None + ast_node.indexes = [self.mutate(i) for i in ast_node.indexes] + if ast_node.index is not None: + ast_node.index = self.mutate(ast_node.index) + + # Resize var is used in index, this statement should be checked for safety + if self.match_capacity is not None: + size = self.resizes_to_check[match_capacity] + check_value = self.update[size] if size in self.update else size + resize_id = self.resize_buffers[match_capacity] + return Branch(ast_node.sim, check_value < match_capacity, + Block(ast_node.sim, ast_node), + Block(ast_node.sim, ast_node.resizes[resize_id].set(check_value))) + + # Size is changed here, assigned value must be used for further checkings + for capacity, size in self.resizes_to_check.items(): + if size == dest.array.name(): + self.update[size] = src + + if isinstance(dest, Var): + # Size is changed here, assigned value must be used for further checkings + for capacity, size in self.resizes_to_check.items(): + if size == dest.name(): + self.update[size] = src + + return ast_node + + def mutate_Block(self, ast_node): + self.block_stack.append(ast_node) + ast_node.stmts = [self.mutate(s) for s in ast_node.stmts] + self.block_stack.pop() + return ast_node + + def mutate_Module(self, ast_node): + # Save current state + saved_resizes_to_check = self.resizes_to_check + saved_check_properties_resize = self.check_properties_resize + saved_update = self.update + saved_resize_buffers = self.resize_buffers + saved_nresize_buffers = self.nresize_buffers + + # Update state and keep traversing tree + self.module_resizes[ast_node] = [] + self.module_stack.append(ast_node) + for capacity in ast_node._resizes_to_check.keys(): + self.module_resizes[ast_node].append(self.nresize_buffers) + self.resize_buffers[capacity] = self.nresize_buffers + self.nresize_buffers += 1 + + self.resizes_to_check = ast_node._resizes_to_check + self.check_properties_resize = ast_node._check_properties_resize + self.update = {} + ast_node._block = self.mutate(ast_node._block) + self.module_stack.pop() + + # Restore saved state + self.resizes_to_check = saved_resizes_to_check + self.check_properties_resize = saved_check_properties_resize + self.update = saved_update + self.resize_buffers = saved_resize_buffers + self.nresize_buffers = saved_nresize_buffers + return ast_node + + def mutate_Var(self, ast_node): + for capacity, size in self.resizes_to_check.items(): + if size == ast_node.name(): + self.match_capacity = capacity + + return ast_node + + +class ReplaceModulesByCalls(Mutator): + def __init__(self, ast, module_resizes): + super().__init__(ast) + self.module_resizes = module_resizes + + def mutate_Module(self, ast_node): + ast_node._block = self.mutate(ast_node._block) + if ast_node.name == 'main': + return ast_node + + call = Module_Call(ast_node.sim, ast_node) + if self.module_resizes[ast_node]: + init_stmts = [] + reset_stmts = [] + branch_cond = None + + for r in self.module_resizes[ast_node]: + init_stmts.append(Assign(ast_node.resizes[r], 1)) + reset_stmts.append(Assign(ast_node.resizes[r], 0)) + cond = ast_node.resizes[r] > 0 + branch_cond = cond if branch_cond is None else BinOp.or_op(cond, branch_cond) + + return Block(ast_node.sim, init_stmts + Filter(ast_node.sim, branch_cond, reset_stmts + [call])) + + return call + + +def modularize(ast): + add_resize_logic = AddResizeLogic(ast) + add_resize_logic.mutate() + fetch_refs = FetchModulesReferences(ast) + fetch_refs.visit() + add_derefs_to_write_vars = AddDereferencesToWriteVariables(ast) + add_derefs_to_write_vars.mutate() + replace = ReplaceModulesByCalls(ast, add_resize_logic.module_resizes) + replace.mutate() diff --git a/src/pairs/transformations/replace_modules_by_calls.py b/src/pairs/transformations/replace_modules_by_calls.py deleted file mode 100644 index 1e81694b936fb9783ca86903cf890b9f420a87ae..0000000000000000000000000000000000000000 --- a/src/pairs/transformations/replace_modules_by_calls.py +++ /dev/null @@ -1,16 +0,0 @@ -from pairs.ir.module import Module_Call -from pairs.ir.mutator import Mutator - - -class ReplaceModulesByCalls(Mutator): - def __init__(self, ast): - super().__init__(ast) - - def mutate_Module(self, ast_node): - ast_node._block = self.mutate(ast_node._block) - return Module_Call(ast_node.sim, ast_node) if ast_node.name != 'main' else ast_node - - -def replace_modules_by_calls(ast): - replace = ReplaceModulesByCalls(ast) - replace.mutate()