diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py index d3501b0011a3ae64b0ce2f469c14c1ed707814ed..bc3d1c2c0248e72e6f61f533529717d8688f4b25 100644 --- a/src/pairs/code_gen/cgen.py +++ b/src/pairs/code_gen/cgen.py @@ -12,11 +12,12 @@ from pairs.ir.lit import Lit from pairs.ir.loops import For, Iter, ParticleFor, While from pairs.ir.math import Ceil, Sqrt from pairs.ir.memory import Malloc, Realloc +from pairs.ir.module import Module_Call from pairs.ir.properties import Property, PropertyAccess, PropertyList, RegisterProperty, UpdateProperty from pairs.ir.select import Select from pairs.ir.sizeof import Sizeof from pairs.ir.utils import Print -from pairs.ir.variables import Var, VarDecl +from pairs.ir.variables import Var, VarDecl, Deref from pairs.sim.timestep import Timestep from pairs.code_gen.printer import Printer @@ -51,12 +52,46 @@ class CGen: self.print("") self.print("using namespace pairs;") self.print("") - self.print("int main() {") - self.print(" PairsSim *ps = new PairsSim();") - self.generate_statement(ast_node) - self.print("}") + for module in self.sim.modules(): + self.generate_module(module) self.print.end() + def generate_module(self, module): + if module.name == 'main': + self.print("int main() {") + self.print(" PairsSim *ps = new PairsSim();") + self.generate_statement(module.block) + self.print(" return 0;") + self.print("}") + + else: + module_params = "" + for var in module.read_only_variables(): + type_kw = CGen.type2keyword(var.type()) + decl = f"{type_kw} {var.name()}" + module_params += decl if len(module_params) <= 0 else f", {decl}" + + for var in module.write_variables(): + type_kw = CGen.type2keyword(var.type()) + decl = f"{type_kw} *{var.name()}" + module_params += decl if len(module_params) <= 0 else f", {decl}" + + for array in module.arrays(): + type_kw = CGen.type2keyword(array.type()) + decl = f"{type_kw} *{array.name()}" + module_params += decl if len(module_params) <= 0 else f", {decl}" + + for prop in module.properties(): + type_kw = CGen.type2keyword(prop.type()) + decl = f"{type_kw} *{prop.name()}" + module_params += decl if len(module_params) <= 0 else f", {decl}" + + self.print(f"void {module.name}({module_params}) {{") + self.print.add_indent(4) + self.generate_statement(module.block) + self.print.add_indent(-4) + self.print("}") + def generate_statement(self, ast_node, bypass_checking=False): if isinstance(ast_node, ArrayDecl): tkw = CGen.type2keyword(ast_node.array.type()) @@ -70,10 +105,10 @@ class CGen: self.print(f"{dest} = {src};") if isinstance(ast_node, Block): - self.print.add_ind(4) + self.print.add_indent(4) for stmt in ast_node.statements(): self.generate_statement(stmt) - self.print.add_ind(-4) + self.print.add_indent(-4) # TODO: Why there are Decls for other types? if isinstance(ast_node, Decl): @@ -133,9 +168,9 @@ class CGen: self.print(f"pairs::copy_to_device({ast_node.prop.name()})") if isinstance(ast_node, KernelBlock): - self.print.add_ind(-4) + self.print.add_indent(-4) self.generate_statement(ast_node.block) - self.print.add_ind(4) # Workaround for fixing indentation of kernels + self.print.add_indent(4) # Workaround for fixing indentation of kernels if isinstance(ast_node, For): iterator = self.generate_expression(ast_node.iterator) @@ -166,6 +201,26 @@ class CGen: else: self.print(f"{array_name} = ({tkw} *) malloc({size});") + if isinstance(ast_node, Module_Call): + module_params = "" + for var in module.read_only_variables(): + decl = var.name() + module_params += decl if len(module_params) <= 0 else f", {decl}" + + for var in module.write_variables(): + decl = f"&{var.name()}" + module_params += decl if len(module_params) <= 0 else f", {decl}" + + for array in module.arrays(): + decl = array.name() + module_params += decl if len(module_params) <= 0 else f", {decl}" + + for prop in module.properties(): + decl = prop.name() + module_params += decl if len(module_params) <= 0 else f", {decl}" + + self.print(f"{module.name}({module_params});") + if isinstance(ast_node, Print): self.print(f"fprintf(stdout, \"{ast_node.string}\\n\");") self.print(f"fflush(stdout);") @@ -269,6 +324,10 @@ class CGen: expr = self.generate_expression(ast_node.expr) return f"ceil({expr})" + if isinstance(ast_node, Deref): + var = self.generate_expression(ast_node.var) + return f"*{var}" + if isinstance(ast_node, Iter): assert mem is False, "Iterator is not lvalue!" return f"i{ast_node.id()}" diff --git a/src/pairs/code_gen/printer.py b/src/pairs/code_gen/printer.py index 95299469b67fe70a6cb89c308a8a91990ee5cb0a..4d73b7679f7f8c52ff1878339dbc0479cb9db215 100644 --- a/src/pairs/code_gen/printer.py +++ b/src/pairs/code_gen/printer.py @@ -4,7 +4,7 @@ class Printer: self.stream = None self.indent = 0 - def add_ind(self, offset): + def add_indent(self, offset): self.indent += offset def start(self): diff --git a/src/pairs/ir/arrays.py b/src/pairs/ir/arrays.py index 49aeb88718aad1ff486d5209a89e18e4424b3ab0..a8721b21b7e1207e5a82d7a0a2a77f1bb46fa97d 100644 --- a/src/pairs/ir/arrays.py +++ b/src/pairs/ir/arrays.py @@ -160,18 +160,6 @@ class ArrayAccess(ASTTerm): return self.array.type() # return self.array.type() if self.index is None else Type_Array - def scope(self): - if self.index is None: - scope = None - for i in self.indexes: - index_scp = i.scope() - if scope is None or index_scp > scope: - scope = index_scp - - return scope - - return self.index.scope() - def children(self): if self.index is not None: return [self.array, self.index] diff --git a/src/pairs/ir/ast_node.py b/src/pairs/ir/ast_node.py index 94e75fa1b260d4b0d6456d1b1c8c147102398873..71ce1c1a353f0f0aa6df78e2d8a511c674d6eb87 100644 --- a/src/pairs/ir/ast_node.py +++ b/src/pairs/ir/ast_node.py @@ -12,8 +12,5 @@ class ASTNode: def type(self): return Type_Invalid - def scope(self): - return self.sim.global_scope - def children(self): return [] diff --git a/src/pairs/ir/bin_op.py b/src/pairs/ir/bin_op.py index 95f48dc000aa6b6129b8ba1c420be61bba1d5317..8990ad9850466509a7d7c4318abafcbdc5f1a410 100644 --- a/src/pairs/ir/bin_op.py +++ b/src/pairs/ir/bin_op.py @@ -42,7 +42,6 @@ class BinOp(VectorExpression): self.inlined = False self.generated = False self.bin_op_type = BinOp.infer_type(self.lhs, self.rhs, self.op) - self.bin_op_scope = None self.terminals = set() self.decl = Decl(sim, self) @@ -132,14 +131,6 @@ class BinOp(VectorExpression): def add_terminal(self, terminal): self.terminals.add(terminal) - def scope(self): - if self.bin_op_scope is None: - lhs_scp = self.lhs.scope() - rhs_scp = self.rhs.scope() - self.bin_op_scope = lhs_scp if lhs_scp > rhs_scp else rhs_scp - - return self.bin_op_scope - def children(self): return [self.lhs, self.rhs] + list(super().children()) diff --git a/src/pairs/ir/block.py b/src/pairs/ir/block.py index b9c04b829c14dc76ee5e56a869075333ebab2840..4789b3852603d41b5b002967e84da4d9edb63dd3 100644 --- a/src/pairs/ir/block.py +++ b/src/pairs/ir/block.py @@ -1,4 +1,5 @@ from pairs.ir.ast_node import ASTNode +from pairs.ir.module import Module def pairs_block(func): @@ -16,7 +17,9 @@ def pairs_device_block(func): sim = args[0].sim # self.sim sim.clear_block() func(*args, **kwargs) - return KernelBlock(sim, sim.block) + module = Module(sim, block=KernelBlock(sim, sim.block)) + sim.add_module(module) + return module return inner diff --git a/src/pairs/ir/branches.py b/src/pairs/ir/branches.py index 75bbfe7f8afe166ed7ec525f50b30931504bbfa1..f6c3295d8aa8fc639335d926c1abca9001a3c299 100644 --- a/src/pairs/ir/branches.py +++ b/src/pairs/ir/branches.py @@ -17,14 +17,14 @@ class Branch(ASTNode): def __iter__(self): self.sim.add_statement(self) self.switch = True - self.sim.enter_scope(self) + self.sim.enter(self) yield self.switch - self.sim.leave_scope() + self.sim.leave() self.switch = False - self.sim.enter_scope(self) + self.sim.enter(self) yield self.switch - self.sim.leave_scope() + self.sim.leave() def add_statement(self, stmt): if self.switch: @@ -43,9 +43,9 @@ class Filter(Branch): def __iter__(self): self.sim.add_statement(self) - self.sim.enter_scope(self) + self.sim.enter(self) yield - self.sim.leave_scope() + self.sim.leave() def add_statement(self, stmt): self.block_if.add_statement(stmt) diff --git a/src/pairs/ir/cast.py b/src/pairs/ir/cast.py index 061eb47e88f80381cd0cd7030b210381c4b741d4..fe6674ae8d07ef55b69ab6a9210b4d9ebde18978 100644 --- a/src/pairs/ir/cast.py +++ b/src/pairs/ir/cast.py @@ -20,8 +20,5 @@ class Cast(ASTTerm): def type(self): return self.cast_type - def scope(self): - return self.expr.scope() - def children(self): return [self.expr] diff --git a/src/pairs/ir/functions.py b/src/pairs/ir/functions.py index 8b01b90856327a0423484bb9b5305d10a637de07..0d30f494ef70bc6e6f949481c41a9558d40e5d21 100644 --- a/src/pairs/ir/functions.py +++ b/src/pairs/ir/functions.py @@ -22,6 +22,7 @@ class Call(ASTTerm): def children(self): return self.params + class Call_Int(Call): def __init__(self, sim, func_name, parameters): super().__init__(sim, func_name, parameters, Type_Int) diff --git a/src/pairs/ir/loops.py b/src/pairs/ir/loops.py index 9f187393a1c3b7cda0fc761be64a1739cd85e2af..099c5c432f13d4389ece066b953dafc377389981 100644 --- a/src/pairs/ir/loops.py +++ b/src/pairs/ir/loops.py @@ -27,9 +27,6 @@ class Iter(ASTTerm): def type(self): return Type_Int - def scope(self): - return self.loop.block - def __eq__(self, other): if isinstance(other, Iter): return self.iter_id == other.iter_id @@ -59,9 +56,9 @@ class For(ASTNode): def __iter__(self): self.sim.add_statement(self) - self.sim.enter_scope(self) + self.sim.enter(self) yield self.iterator - self.sim.leave_scope() + self.sim.leave() def add_statement(self, stmt): self.block.add_statement(stmt) @@ -90,9 +87,9 @@ class While(ASTNode): def __iter__(self): self.sim.add_statement(self) - self.sim.enter_scope(self) + self.sim.enter(self) yield - self.sim.leave_scope() + self.sim.leave() def add_statement(self, stmt): self.block.add_statement(stmt) diff --git a/src/pairs/ir/math.py b/src/pairs/ir/math.py index 9ef762f1f75d4071baf651d9b5f288dbfe0447a7..09e41713a16ff157841a3d607ecfc65d4093a676 100644 --- a/src/pairs/ir/math.py +++ b/src/pairs/ir/math.py @@ -13,9 +13,6 @@ class Sqrt(ASTTerm): def type(self): return self.expr.type() - def scope(self): - return self.expr.scope() - def children(self): return [self.expr] @@ -32,8 +29,5 @@ class Ceil(ASTTerm): def type(self): return Type_Int - def scope(self): - return self.expr.scope() - def children(self): return [self.expr] diff --git a/src/pairs/ir/module.py b/src/pairs/ir/module.py new file mode 100644 index 0000000000000000000000000000000000000000..900e565869b82f35ec1ae6fb1b81fa44fdd12fd0 --- /dev/null +++ b/src/pairs/ir/module.py @@ -0,0 +1,72 @@ +from pairs.ir.arrays import Array +from pairs.ir.ast_node import ASTNode +from pairs.ir.properties import Property +from pairs.ir.variables import Var + + +class Module(ASTNode): + last_module = 0 + + def __init__(self, sim, name=None, block=None): + super().__init__(sim) + self._name = name if name is not None else "module_" + str(Module.last_module) + self._variables = {} + self._arrays = set() + self._properties = set() + self._block = block + sim.add_module(self) + Module.last_module += 1 + + @property + def name(self): + return self._name + + @property + def block(self): + return self._block + + def variables(self): + return self._variables + + def read_only_variables(self): + return [v for v in self._variables if not self._variables[v]] + + def write_variables(self): + return [v for v in self._variables if self._variables[v]] + + def arrays(self): + return self._arrays + + def properties(self): + return self._properties + + def add_array(self, array, write=False): + array_list = array if isinstance(array, list) else [array] + for a in array_list: + assert isinstance(a, Array), "Module.add_array(): given element is not of type Array!" + self._arrays.add(a) + + def add_variable(self, variable, write=False): + variable_list = variable if isinstance(variable, list) else [variable] + for v in variable_list: + assert isinstance(v, Var), "Module.add_variable(): given element is not of type Var!" + if v not in self._variables: + self._variables[v] = write + else: + self._variables[v] = self._variables[v] or write + + def add_property(self, prop, write=False): + prop_list = prop if isinstance(prop, list) else [prop] + for p in prop_list: + assert isinstance(p, Property), "Module.add_property(): given element is not of type Property!" + self._properties.add(p) + + def children(self): + return [self._block] + + +class Module_Call(ASTNode): + def __init__(self, sim, module): + assert isinstance(module, Module), "Module_Call(): given parameter is not of type Module!" + super().__init__(sim) + self.module = module diff --git a/src/pairs/ir/mutator.py b/src/pairs/ir/mutator.py index cb394cf368f5f905da7199a22ef0d23c575985e0..09c9f739b3c5df202fbd01c26b84e937c2c05b3d 100644 --- a/src/pairs/ir/mutator.py +++ b/src/pairs/ir/mutator.py @@ -84,6 +84,10 @@ class Mutator: ast_node.size = self.mutate(ast_node.size) return ast_node + def mutate_Module(self, ast_node): + ast_node._block = self.mutate(ast_node._block) + return ast_node + def mutate_Realloc(self, ast_node): ast_node.array = self.mutate(ast_node.array) ast_node.size = self.mutate(ast_node.size) diff --git a/src/pairs/ir/properties.py b/src/pairs/ir/properties.py index c17c78a59198f94831281f879565e10c6e3dad3b..d23a1547edfbf984c7b6bdf3febff11affbaa316 100644 --- a/src/pairs/ir/properties.py +++ b/src/pairs/ir/properties.py @@ -84,9 +84,6 @@ class Property(ASTNode): def sizes(self): return [self.sim.particle_capacity] if self.prop_type != Type_Vector else [self.sim.ndims(), self.sim.particle_capacity] - def scope(self): - return self.sim.global_scope - def __getitem__(self, expr): return PropertyAccess(self.sim, self, expr) @@ -144,9 +141,6 @@ class PropertyAccess(ASTTerm, VectorExpression): def add_terminal(self, terminal): self.terminals.add(terminal) - def scope(self): - return self.index.scope() - def children(self): return [self.prop, self.index] + list(super().children()) diff --git a/src/pairs/ir/variables.py b/src/pairs/ir/variables.py index 102dc2d7a439dd34ff0f1704d7aae4bbe89e1d4f..2990e91a7fbf435a2f98ee51b7294e70d3eb1f9d 100644 --- a/src/pairs/ir/variables.py +++ b/src/pairs/ir/variables.py @@ -67,3 +67,19 @@ class VarDecl(ASTNode): super().__init__(sim) self.var = var self.sim.add_statement(self) + + +class Deref(ASTTerm): + def __init__(self, sim, var): + super().__init__(sim) + self._var = var + + def __str__(self): + return f"Deref<var: {self.var.name()}>" + + @property + def var(self): + return self._var + + def type(self): + return self._var.type() diff --git a/src/pairs/sim/interaction.py b/src/pairs/sim/interaction.py index 9658ee778e7e48f680663d61c53c90b0fd0b4eee..074f416692e97121bf9d9d2bfef62108b80c3765 100644 --- a/src/pairs/sim/interaction.py +++ b/src/pairs/sim/interaction.py @@ -57,9 +57,9 @@ class ParticleInteraction(Lowerable): def __iter__(self): self.sim.add_statement(self) - self.sim.enter_scope(self) + self.sim.enter(self) yield self.i, self.j - self.sim.leave_scope() + self.sim.leave() @pairs_block def lower(self): diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py index db46e6a7ed545e76a255ecc4c40b6f3e58d11cfb..70f11b088c04b2db747eed7499fccff47797dc10 100644 --- a/src/pairs/sim/simulation.py +++ b/src/pairs/sim/simulation.py @@ -3,6 +3,7 @@ from pairs.ir.block import Block, KernelBlock from pairs.ir.branches import Filter from pairs.ir.data_types import Type_Int, Type_Float, Type_Vector from pairs.ir.layouts import Layout_AoS +from pairs.ir.module import Module from pairs.ir.properties import Properties from pairs.ir.symbols import Symbol from pairs.ir.variables import Variables @@ -20,12 +21,14 @@ from pairs.sim.timestep import Timestep from pairs.sim.variables import VariablesDecl from pairs.sim.vtk import VTKWrite from pairs.transformations.add_device_copies import add_device_copies +from pairs.transformations.fetch_modules_references import fetch_modules_references from pairs.transformations.prioritize_scalar_ops import prioritize_scalar_ops from pairs.transformations.set_used_bin_ops import set_used_bin_ops from pairs.transformations.simplify import simplify_expressions from pairs.transformations.LICM import move_loop_invariant_code from pairs.transformations.lower import lower_everything from pairs.transformations.merge_adjacent_blocks import merge_adjacent_blocks +from pairs.transformations.replace_modules_by_calls import replace_modules_by_calls from pairs.transformations.replace_symbols import replace_symbols @@ -33,7 +36,6 @@ class Simulation: def __init__(self, code_gen, dims=3, timesteps=100): self.code_gen = code_gen self.code_gen.assign_simulation(self) - self.global_scope = None self.position_prop = None self.properties = Properties(self) self.vars = Variables(self) @@ -52,6 +54,7 @@ class Simulation: self.block = Block(self, []) self.setups = Block(self, []) self.kernels = Block(self, []) + self.module_list = [] self.dims = dims self.ntimesteps = timesteps self.expr_id = 0 @@ -61,6 +64,13 @@ class Simulation: self.nparticles = self.nlocal + self.nghost self.properties.add_capacity(self.particle_capacity) + def add_module(self, module): + assert isinstance(module, Module), "add_module(): Given parameter is not of type Module!" + self.module_list.append(module) + + def modules(self): + return self.module_list + def ndims(self): return self.dims @@ -157,7 +167,7 @@ class Simulation: self.block = Block(self, []) def build_kernel_block_with_statements(self): - self.kernels.add_statement(KernelBlock(self, self.block)) + self.kernels.add_statement(Module(self, block=KernelBlock(self, self.block))) def add_statement(self, stmt): if not self.scope: @@ -175,10 +185,10 @@ class Simulation: for _ in range(0, self.nested_count): self.scope.pop() - def enter_scope(self, scope): + def enter(self, scope): self.scope.append(scope) - def leave_scope(self): + def leave(self): if not self.nest: self.scope.pop() else: @@ -212,8 +222,7 @@ class Simulation: PropertiesAlloc(self), ]) - program = Block.merge_blocks(decls, body) - self.global_scope = program + program = Module(self, name='main', block=Block.merge_blocks(decls, body)) # Transformations lower_everything(program) @@ -222,8 +231,10 @@ class Simulation: prioritize_scalar_ops(program) simplify_expressions(program) move_loop_invariant_code(program) + fetch_modules_references(program) set_used_bin_ops(program) add_device_copies(program) + replace_modules_by_calls(program) # For this part on, all bin ops are generated without usage verification self.check_decl_usage = False diff --git a/src/pairs/transformations/fetch_modules_references.py b/src/pairs/transformations/fetch_modules_references.py new file mode 100644 index 0000000000000000000000000000000000000000..534ddfbd3eedef8b204154d48f109aeebca77121 --- /dev/null +++ b/src/pairs/transformations/fetch_modules_references.py @@ -0,0 +1,62 @@ +from pairs.ir.module import Module +from pairs.ir.mutator import Mutator +from pairs.ir.variables import Deref +from pairs.ir.visitor import Visitor + + +class FetchModulesReferences(Visitor): + def __init__(self, ast): + super().__init__(ast) + self.module_stack = [] + self.writing = False + + def visit_Assign(self, ast_node): + self.writing = True + for c in ast_node.destinations(): + self.visit(c) + + self.writing = False + for c in ast_node.sources(): + self.visit(c) + + def visit_Module(self, ast_node): + self.module_stack.append(ast_node) + self.visit_children(ast_node) + self.module_stack.pop() + + def visit_Array(self, ast_node): + for m in self.module_stack: + m.add_array(ast_node) + + def visit_Property(self, ast_node): + for m in self.module_stack: + m.add_property(ast_node) + + def visit_Var(self, ast_node): + for m in self.module_stack: + m.add_variable(ast_node, self.writing) + + +class AddDereferencesToWriteVariables(Mutator): + def __init__(self, ast): + super().__init__(ast) + self.module_stack = [] + + def mutate_Module(self, ast_node): + self.module_stack.append(ast_node) + ast_node._block = self.mutate(ast_node._block) + self.module_stack.pop() + return ast_node + + def mutate_Var(self, ast_node): + if ast_node in self.module_stack[-1].write_variables(): + return Deref(ast_node.sim, ast_node) + + return ast_node + + +def fetch_modules_references(ast): + fetch_refs = FetchModulesReferences(ast) + fetch_refs.visit() + add_derefs_to_write_vars = AddDereferencesToWriteVariables(ast) + add_derefs_to_write_vars.mutate() diff --git a/src/pairs/transformations/replace_modules_by_calls.py b/src/pairs/transformations/replace_modules_by_calls.py new file mode 100644 index 0000000000000000000000000000000000000000..4b2e6c7a405eb8bc091f3aca1c113c46c15d4c68 --- /dev/null +++ b/src/pairs/transformations/replace_modules_by_calls.py @@ -0,0 +1,15 @@ +from pairs.ir.module import Module_Call +from pairs.ir.mutator import Mutator + + +class ReplaceModulesByCalls(Mutator): + def __init__(self, ast): + super().__init__(ast) + + def mutate_Module(self, ast_node): + return Module_Call(ast_node.sim, ast_node) + + +def replace_modules_by_calls(ast): + replace = ReplaceModulesByCalls(ast) + replace.mutate()