diff --git a/code_gen/cgen.py b/code_gen/cgen.py index 0e6fa3ba0f757f45937b4e7686c4412d841c3d33..93ea927c2a9aff76d6a56663fc2e761c420a6a9e 100644 --- a/code_gen/cgen.py +++ b/code_gen/cgen.py @@ -3,7 +3,7 @@ from ir.arrays import Array, ArrayAccess, ArrayDecl from ir.block import Block from ir.branches import Branch from ir.cast import Cast -from ir.bin_op import BinOp, BinOpDef +from ir.bin_op import BinOp, Decl, VectorAccess from ir.data_types import Type_Int, Type_Float, Type_String, Type_Vector from ir.functions import Call from ir.layouts import Layout_AoS, Layout_SoA, Layout_Invalid @@ -11,7 +11,7 @@ from ir.lit import Lit from ir.loops import For, Iter, ParticleFor, While from ir.math import Ceil, Sqrt from ir.memory import Malloc, Realloc -from ir.properties import Property, PropertyList, RegisterProperty, UpdateProperty +from ir.properties import Property, PropertyAccess, PropertyList, RegisterProperty, UpdateProperty from ir.select import Select from ir.sizeof import Sizeof from ir.utils import Print @@ -56,7 +56,7 @@ class CGen: self.print("}") self.print.end() - def generate_statement(self, ast_node): + def generate_statement(self, ast_node, bypass_checking=False): if isinstance(ast_node, ArrayDecl): tkw = CGen.type2keyword(ast_node.array.type()) size = self.generate_expression(BinOp.inline(ast_node.array.alloc_size())) @@ -76,29 +76,44 @@ class CGen: self.print.add_ind(-4) - if isinstance(ast_node, BinOpDef): - bin_op = ast_node.bin_op - - if not isinstance(bin_op, BinOp) or not ast_node.used: - return None - - if bin_op.inlined is False and bin_op.operator() != '[]' and bin_op.generated is False: - if bin_op.kind() == BinOp.Kind_Scalar: - lhs = self.generate_expression(bin_op.lhs, bin_op.mem) - rhs = self.generate_expression(bin_op.rhs) - tkw = CGen.type2keyword(bin_op.type()) - self.print(f"const {tkw} e{bin_op.id()} = {lhs} {bin_op.operator()} {rhs};") - - elif bin_op.kind() == BinOp.Kind_Vector: - for i in bin_op.vector_indexes: - lhs = self.generate_expression(bin_op.lhs, bin_op.mem, index=i) - rhs = self.generate_expression(bin_op.rhs, index=i) - self.print(f"const double e{bin_op.id()}_{i} = {lhs} {bin_op.operator()} {rhs};") + # TODO: Why there are Decls for other types? + if isinstance(ast_node, Decl): + if isinstance(ast_node.elem, BinOp): + bin_op = ast_node.elem + if not bypass_checking and (not isinstance(bin_op, BinOp) or not ast_node.used): + return None + + if bin_op.inlined is False and bin_op.generated is False: + if bin_op.is_vector_kind(): + for i in bin_op.indexes(): + lhs = self.generate_expression(bin_op.lhs, bin_op.mem, index=i) + rhs = self.generate_expression(bin_op.rhs, index=i) + self.print(f"const double e{bin_op.id()}_{i} = {lhs} {bin_op.operator()} {rhs};") + + else: + lhs = self.generate_expression(bin_op.lhs, bin_op.mem) + rhs = self.generate_expression(bin_op.rhs) + tkw = CGen.type2keyword(bin_op.type()) + self.print(f"const {tkw} e{bin_op.id()} = {lhs} {bin_op.operator()} {rhs};") + + ast_node.elem.generated = True + + if isinstance(ast_node.elem, PropertyAccess): + prop_access = ast_node.elem + prop_name = prop_access.prop.name() + acc_ref = f"p{prop_access.id()}" + + if prop_access.is_vector_kind(): + for i in prop_access.indexes(): + i_expr = self.generate_expression(prop_access.get_index_expression(i)) + self.print(f"const double {acc_ref}_{i} = {prop_name}[{i_expr}];") else: - raise Exception("Invalid BinOp kind!") + tkw = CGen.type2keyword(prop_access.type()) + index_g = self.generate_expression(prop_access.index) + self.print(f"const {tkw} {acc_ref} = {prop_name}[{index_g}];") - bin_op.generated = True + ast_node.elem.generated = True if isinstance(ast_node, Branch): cond = self.generate_expression(ast_node.cond) @@ -200,41 +215,34 @@ class CGen: return ast_node.name() if isinstance(ast_node, ArrayAccess): - index = self.generate_expression(ast_node.index) array_name = ast_node.array.name() + acc_index = self.generate_expression(ast_node.index) if mem: - return f"{array_name}[{index}]" + return f"{array_name}[{acc_index}]" acc_ref = f"a{ast_node.id()}" - if ast_node.generated is False: + if not ast_node.generated: tkw = CGen.type2keyword(ast_node.type()) - self.print(f"const {tkw} {acc_ref} = {array_name}[{index}];") + self.print(f"const {tkw} {acc_ref} = {array_name}[{acc_index}];") ast_node.generated = True return acc_ref if isinstance(ast_node, BinOp): - if isinstance(ast_node.lhs, BinOp) and ast_node.lhs.kind() == BinOp.Kind_Vector and ast_node.operator() == '[]': - return self.generate_expression(ast_node.lhs, ast_node.mem, self.generate_expression(ast_node.rhs)) - lhs = self.generate_expression(ast_node.lhs, mem, index) rhs = self.generate_expression(ast_node.rhs, index=index) - if ast_node.operator() == '[]': - idx = self.generate_expression(ast_node.mapped_vector_index(index)) if ast_node.is_vector_property_access() else rhs - return f"{lhs}[{idx}]" if ast_node.mem else f"{lhs}_{idx}" - if ast_node.inlined is True: assert ast_node.type() != Type_Vector, "Vector operations cannot be inlined!" return f"({lhs} {ast_node.operator()} {rhs})" # Some expressions can be defined on-the-fly during transformations, hence they do not have - # a definition statement in the tree, so we generate them right before use + # a declaration statement in the tree, so we generate them right before use if not ast_node.generated: - self.generate_statement(ast_node.definition()) + self.generate_statement(ast_node.declaration(), bypass_checking=True) - if ast_node.kind() == BinOp.Kind_Vector: + if ast_node.is_vector_kind(): assert index is not None, "Index must be set for vector reference!" return f"e{ast_node.id()}[{index}]" if ast_node.mem else f"e{ast_node.id()}_{index}" @@ -268,6 +276,24 @@ class CGen: if isinstance(ast_node, Property): return ast_node.name() + if isinstance(ast_node, PropertyAccess): + assert not ast_node.is_vector_kind() or index is not None, "Index must be set for vector property access!" + prop_name = ast_node.prop.name() + + if mem: + index_expr = ast_node.index if not ast_node.is_vector_kind() else ast_node.get_index_expression(index) + index_g = self.generate_expression(index_expr) + return f"{prop_name}[{index_g}]" + + if not ast_node.generated: + self.generate_statement(ast_node.declaration(), bypass_checking=True) + + acc_ref = f"p{ast_node.id()}" + if ast_node.is_vector_kind(): + acc_ref += f"_{index}" + + return acc_ref + if isinstance(ast_node, PropertyList): tid = CGen.temp_id list_ref = f"prop_list_{tid}" @@ -295,3 +321,6 @@ class CGen: if isinstance(ast_node, Var): return ast_node.name() + + if isinstance(ast_node, VectorAccess): + return self.generate_expression(ast_node.expr, mem, self.generate_expression(ast_node.index)) diff --git a/graph/graphviz.py b/graph/graphviz.py index 30c0df916ab1495300d8a5ba7965e68ab67f4b9e..e078b421b5875b6fbcf8dc924aaf478332785c4e 100644 --- a/graph/graphviz.py +++ b/graph/graphviz.py @@ -1,5 +1,5 @@ from ir.arrays import Array -from ir.bin_op import BinOp, BinOpDef +from ir.bin_op import BinOp, Decl from ir.lit import Lit from ir.loops import Iter from ir.properties import Property @@ -17,12 +17,12 @@ class ASTGraph: def render(self): def generate_edges_for_node(ast_node, graph, generated): node_id = id(ast_node) - if not isinstance(ast_node, BinOpDef) and node_id not in generated: + if not isinstance(ast_node, Decl) and node_id not in generated: node_ref = f"n{id(ast_node)}" generated.append(node_id) graph.node(node_ref, label=ASTGraph.get_node_label(ast_node)) for child in ast_node.children(): - if not isinstance(child, BinOpDef): + if not isinstance(child, Decl): child_ref = f"n{id(child)}" graph.node(child_ref, label=ASTGraph.get_node_label(child)) graph.edge(node_ref, child_ref) diff --git a/ir/assign.py b/ir/assign.py index 997cd4c8bb40546b68708ec517737f52efeee3e1..4465699badf5f75b9916bf2953aac56ef2eff0e4 100644 --- a/ir/assign.py +++ b/ir/assign.py @@ -1,6 +1,7 @@ from ir.ast_node import ASTNode from ir.data_types import Type_Vector from ir.lit import as_lit_ast +from ir.vector_expr import VectorExpression from functools import reduce @@ -14,8 +15,7 @@ class Assign(ASTNode): self.assignments = [] for i in range(0, sim.ndims()): - from ir.bin_op import BinOp - dim_src = src if not isinstance(src, BinOp) or src.type() != Type_Vector else src[i] + dim_src = src if not isinstance(src, VectorExpression) or src.type() != Type_Vector else src[i] self.assignments.append((dest[i], dim_src)) else: self.assignments = [(dest, src)] diff --git a/ir/bin_op.py b/ir/bin_op.py index 34081bb35293f95282b91a10c5287790e4fbb59f..809510cd8c244b493668aa38fd512ecd64b4be8d 100644 --- a/ir/bin_op.py +++ b/ir/bin_op.py @@ -2,28 +2,24 @@ from ir.ast_node import ASTNode from ir.assign import Assign from ir.data_types import Type_Float, Type_Bool, Type_Vector from ir.lit import as_lit_ast -from ir.properties import Property +from ir.vector_expr import VectorExpression -class BinOpDef(ASTNode): - def __init__(self, bin_op): - super().__init__(bin_op.sim) - self.bin_op = bin_op - self.bin_op.sim.add_statement(self) - self.used = not bin_op.sim.check_bin_ops_usage +class Decl(ASTNode): + def __init__(self, sim, elem): + super().__init__(sim) + self.elem = elem + self.used = not sim.check_decl_usage + sim.add_statement(self) def __str__(self): - return f"BinOpDef<bin_op: self.bin_op>" + return f"Decl<elem: self.elem>" def children(self): - return [self.bin_op] - + return [self.elem] -class BinOp(ASTNode): - # BinOp kinds - Kind_Scalar = 0 - Kind_Vector = 1 +class BinOp(VectorExpression): last_bin_op = 0 def new_id(): @@ -48,9 +44,7 @@ class BinOp(ASTNode): self.bin_op_type = BinOp.infer_type(self.lhs, self.rhs, self.op) self.bin_op_scope = None self.terminals = set() - self._vector_indexes = set() - self.vector_index_mapping = {} - self.bin_op_def = BinOpDef(self) + self.decl = Decl(sim, self) def reassign(self, lhs, rhs, op): assert self.generated is False, "Error on reassign: BinOp {} already generated!".format(self.bin_op_id) @@ -65,9 +59,7 @@ class BinOp(ASTNode): return f"BinOp<a: {a}, b: {b}, op: {self.op}>" def match(self, bin_op): - return self.lhs == bin_op.lhs and \ - self.rhs == bin_op.rhs and \ - self.op == bin_op.operator() + return self.lhs == bin_op.lhs and self.rhs == bin_op.rhs and self.op == bin_op.operator() def x(self): return self.__getitem__(0) @@ -78,40 +70,6 @@ class BinOp(ASTNode): def z(self): return self.__getitem__(2) - def map_vector_index(self, index, expr): - self.vector_index_mapping[index] = expr - - def mapped_vector_index(self, index): - mapping = self.vector_index_mapping - return mapping[index] if index in mapping else as_lit_ast(self.sim, index) - - def mapped_expressions(self): - return self.vector_index_mapping.values() - - @property - def vector_indexes(self): - return self._vector_indexes - - def propagate_vector_access(self, index): - self.vector_indexes.add(index) - - if isinstance(self.lhs, BinOp) and self.lhs.kind() == BinOp.Kind_Vector: - self.lhs.propagate_vector_access(index) - - if isinstance(self.rhs, BinOp) and self.rhs.kind() == BinOp.Kind_Vector: - self.rhs.propagate_vector_access(index) - - def __getitem__(self, index): - assert self.type() == Type_Vector, "Cannot use operator [] on specified type!" - self.propagate_vector_access(index) - return BinOp(self.sim, self, as_lit_ast(self.sim, index), '[]', self.mem) - - def is_property_access(self): - return isinstance(self.lhs, Property) and self.operator() == '[]' - - def is_vector_property_access(self): - return self.is_property_access() and self.type() == Type_Vector - def set(self, other): assert self.mem is True, "Invalid assignment: lvalue expected!" return self.sim.add_statement(Assign(self.sim, self, other)) @@ -132,9 +90,6 @@ class BinOp(ASTNode): return Type_Bool if op == '[]': - if isinstance(lhs, Property): - return lhs_type - if lhs_type == Type_Vector: return Type_Float @@ -168,15 +123,12 @@ class BinOp(ASTNode): def type(self): return self.bin_op_type - def definition(self): - return self.bin_op_def + def declaration(self): + return self.decl def operator(self): return self.op - def kind(self): - return BinOp.Kind_Vector if self.type() == Type_Vector else BinOp.Kind_Scalar - def add_terminal(self, terminal): self.terminals.add(terminal) @@ -189,7 +141,11 @@ class BinOp(ASTNode): return self.bin_op_scope def children(self): - return [self.lhs, self.rhs] + return [self.lhs, self.rhs] + list(super().children()) + + def __getitem__(self, index): + super().__getitem__(index) + return VectorAccess(self.sim, self, as_lit_ast(self.sim, index)) def __add__(self, other): return BinOp(self.sim, self, other, '+') @@ -291,3 +247,23 @@ class ASTTerm(ASTNode): def __mod__(self, other): return BinOp(self.sim, self, other, '%') + + +class VectorAccess(ASTTerm): + def __init__(self, sim, expr, index): + super().__init__(sim) + self.expr = expr + self.index = index + + def type(self): + return Type_Float + + def set(self, other): + return self.sim.add_statement(Assign(self.sim, self, other)) + + def add(self, other): + return self.sim.add_statement(Assign(self.sim, self, self + other)) + + def sub(self, other): + return self.sim.add_statement(Assign(self.sim, self, self - other)) + diff --git a/ir/block.py b/ir/block.py index 89ffcd8b4d72bfd99aed60c346afaf4a65d8ab77..b136db8b0c3afb826cf92d3c6824428f94cb97db 100644 --- a/ir/block.py +++ b/ir/block.py @@ -6,6 +6,8 @@ class Block(ASTNode): super().__init__(sim) self.level = 0 self.variants = set() + self.props_accessed = {} + self.props_to_sync = set() if isinstance(stmts, Block): self.stmts = stmts.statements() diff --git a/ir/mutator.py b/ir/mutator.py index b856fbf103b558993f5a7e23169f0e35c467ba09..942173689eb2fe8af502f496bb22379741048025 100644 --- a/ir/mutator.py +++ b/ir/mutator.py @@ -33,11 +33,7 @@ class Mutator: def mutate_BinOp(self, ast_node): ast_node.lhs = self.mutate(ast_node.lhs) ast_node.rhs = self.mutate(ast_node.rhs) - ast_node.vector_index_mapping = {i: self.mutate(e) for i, e in ast_node.vector_index_mapping.items()} - return ast_node - - def mutate_BinOpDef(self, ast_node): - ast_node.bin_op = self.mutate(ast_node.bin_op) + ast_node.expressions = {i: self.mutate(e) for i, e in ast_node.expressions.items()} return ast_node def mutate_Block(self, ast_node): @@ -50,6 +46,10 @@ class Mutator: ast_node.block_else = None if ast_node.block_else is None else self.mutate(ast_node.block_else) return ast_node + def mutate_Decl(self, ast_node): + ast_node.elem = self.mutate(ast_node.elem) + return ast_node + def mutate_Filter(self, ast_node): return self.mutate_Branch(ast_node) @@ -65,6 +65,12 @@ class Mutator: def mutate_ParticleFor(self, ast_node): return self.mutate_For(ast_node) + def mutate_PropertyAccess(self, ast_node): + ast_node.prop = self.mutate(ast_node.prop) + ast_node.index = self.mutate(ast_node.index) + ast_node.expressions = {i: self.mutate(e) for i, e in ast_node.expressions.items()} + return ast_node + def mutate_Malloc(self, ast_node): ast_node.array = self.mutate(ast_node.array) ast_node.size = self.mutate(ast_node.size) @@ -89,6 +95,10 @@ class Mutator: ast_node.block = self.mutate(ast_node.block) return ast_node + def mutate_VectorAccess(self, ast_node): + ast_node.expr = self.mutate(ast_node.expr) + return ast_node + def mutate_While(self, ast_node): ast_node.cond = self.mutate(ast_node.cond) ast_node.block = self.mutate(ast_node.block) diff --git a/ir/properties.py b/ir/properties.py index 1ff71b909fb4a333b445b83ee678160d9b26b855..2d5be9225b79cca1edd6e835b099d48aa3d6106f 100644 --- a/ir/properties.py +++ b/ir/properties.py @@ -1,6 +1,10 @@ from ir.ast_node import ASTNode +from ir.assign import Assign +from ir.bin_op import BinOp, Decl, ASTTerm, VectorAccess +from ir.data_types import Type_Vector from ir.layouts import Layout_AoS from ir.lit import as_lit_ast +from ir.vector_expr import VectorExpression class Properties: @@ -74,12 +78,81 @@ class Property(ASTNode): def default(self): return self.default_value + def ndims(self): + return 1 if self.prop_type != Type_Vector else 2 + + def sizes(self): + return [self.sim.particle_capacity] if self.prop_type != Type_Vector else [self.sim.ndims(), self.sim.particle_capacity] + def scope(self): return self.sim.global_scope def __getitem__(self, expr): - from ir.bin_op import BinOp - return BinOp(self.sim, self, expr, '[]', True) + return PropertyAccess(self.sim, self, expr) + + +class PropertyAccess(ASTTerm, VectorExpression): + last_prop_acc = 0 + + def new_id(): + PropertyAccess.last_prop_acc += 1 + return PropertyAccess.last_prop_acc - 1 + + def __init__(self, sim, prop, index): + super().__init__(sim) + self.acc_id = PropertyAccess.new_id() + self.prop = prop + self.index = as_lit_ast(sim, index) + self.generated = False + self.terminals = set() + self.decl = Decl(sim, self) + + def __str__(self): + return f"PropertyAccess<prop: {self.prop}, index: {self.index}>" + + def vector_index(self, v_index): + sizes = self.prop.sizes() + layout = self.prop.layout() + index = self.index * sizes[0] + v_index if layout == Layout_AoS else \ + v_index * sizes[1] + self.index if layout == Layout_SoA else \ + None + + assert index is not None, "Invalid data layout" + return index + + def propagate_through(self): + return [] + + def set(self, other): + return self.sim.add_statement(Assign(self.sim, self, other)) + + def add(self, other): + return self.sim.add_statement(Assign(self.sim, self, self + other)) + + def sub(self, other): + return self.sim.add_statement(Assign(self.sim, self, self - other)) + + def id(self): + return self.acc_id + + def type(self): + return self.prop.type() + + def declaration(self): + return self.decl + + def add_terminal(self, terminal): + self.terminals.add(terminal) + + def scope(self): + return self.index.scope() + + def children(self): + return [self.prop, self.index] + list(super().children()) + + def __getitem__(self, index): + super().__getitem__(index) + return VectorAccess(self.sim, self, as_lit_ast(self.sim, index)) class PropertyList(ASTNode): diff --git a/ir/vector_expr.py b/ir/vector_expr.py new file mode 100644 index 0000000000000000000000000000000000000000..bd3bb0beb9dfd9c7258ad280af34aea9fa7ec6fa --- /dev/null +++ b/ir/vector_expr.py @@ -0,0 +1,49 @@ +from ir.ast_node import ASTNode +from ir.data_types import Type_Vector +from ir.lit import Lit + + +class VectorExpression(ASTNode): + def __init__(self, sim): + super().__init__(sim) + self.vector_indexes = set() + self.expressions = {} + + def vector_expressions(self): + return self.expressions.values() + + def indexes(self): + yield from self.vector_indexes + + def get_index_expression(self, index): + index_value = index.value if isinstance(index, Lit) else index + return self.expressions[index_value] if index_value in self.expressions else None + + def vector_index(self, v_index): + return None + + def propagate_vector_index(self, index): + self.vector_indexes.add(index) + index_expr = self.vector_index(index) + + if index_expr is not None: + self.expressions[index] = index_expr + + for p in self.propagate_through(): + if isinstance(p, VectorExpression) and p.is_vector_kind(): + p.propagate_vector_index(index) + + def is_vector_kind(self): + return self.type() == Type_Vector + + # Default is to propagate through children, but this can be overridden + def propagate_through(self): + return self.children() + + def children(self): + return self.vector_expressions() + + def __getitem__(self, index): + assert self.type() == Type_Vector, "Cannot use operator [] on specified type!" + self.propagate_vector_index(index) + return self diff --git a/log/logbook_RR.md b/log/logbook_RR.md index 06ab43ea9c46d835f2db3223d400715ae4e5210e..86432e3eab143a347677293ae9a149334680b143 100644 --- a/log/logbook_RR.md +++ b/log/logbook_RR.md @@ -19,10 +19,8 @@ The Agenda section is a scratchpad area for planning and Todo list ------------------------------------------------------------------------------> # Agenda -* Implement neighbor lists * Use variables to accumulate forces (reduction) * Express simulation specific kernels (cell lists, PBC, neighbor lists, comm) in a cleaner way with new syntax -* Runtime functions (VTK printing, read and write to files) * MPI support * GPU support * LLVM support diff --git a/sim/particle_simulation.py b/sim/particle_simulation.py index ef2a6e418bca63d60a0ed64608f1fe6d01354ba2..8232700153c848975c7cb62d315a9b7493e7fb2c 100644 --- a/sim/particle_simulation.py +++ b/sim/particle_simulation.py @@ -20,7 +20,6 @@ from sim.setup_wrapper import SetupWrapper from sim.timestep import Timestep from sim.variables import VariablesDecl from sim.vtk import VTKWrite -from transformations.flatten import flatten_property_accesses from transformations.prioritize_scalar_ops import prioritaze_scalar_ops from transformations.set_used_bin_ops import set_used_bin_ops from transformations.simplify import simplify_expressions @@ -45,7 +44,7 @@ class ParticleSimulation: self.scope = [] self.nested_count = 0 self.nest = False - self.check_bin_ops_usage = True + self.check_decl_usage = True self.block = Block(self, []) self.setups = SetupWrapper() self.kernels = KernelWrapper() @@ -212,13 +211,12 @@ class ParticleSimulation: # Transformations prioritaze_scalar_ops(program) - flatten_property_accesses(program) simplify_expressions(program) - move_loop_invariant_code(program) + #move_loop_invariant_code(program) set_used_bin_ops(program) # For this part on, all bin ops are generated without usage verification - self.check_bin_ops_usage = False + self.check_decl_usage = False ASTGraph(self.kernels.lower(), "kernels").render() self.code_gen.generate_program(program) diff --git a/sim/pbc.py b/sim/pbc.py index 1f066be079ec89809b1ec3c8f4eb370823f6506c..dd23bca5212d74a41096f12ea8720021d07fa7c4 100644 --- a/sim/pbc.py +++ b/sim/pbc.py @@ -89,13 +89,14 @@ class SetupPBC: npbc.set(0) for d in range(0, ndims): for i in For(sim, 0, nlocal + npbc): - last_id = nlocal + npbc + #last_id = nlocal + npbc # TODO: VecFilter? for _ in Filter(sim, positions[i][d] < grid.min(d) + cutneigh): for capacity_exceeded in Branch(sim, npbc >= pbc_capacity): if capacity_exceeded: resize.set(Select(sim, resize > npbc, resize + 1, npbc)) else: + last_id = nlocal + npbc pbc_map[npbc].set(i) pbc_mult[npbc][d].set(1) positions[last_id][d].set(positions[i][d] + grid.length(d)) @@ -111,6 +112,7 @@ class SetupPBC: if capacity_exceeded: resize.set(Select(sim, resize > npbc, resize + 1, npbc)) else: + last_id = nlocal + npbc pbc_map[npbc].set(i) pbc_mult[npbc][d].set(-1) positions[last_id][d].set(positions[i][d] - grid.length(d)) diff --git a/transformations/LICM.py b/transformations/LICM.py index 5b42d8c072ff70397dd303432bcc877eab3f71ab..4a655b79689f961653b3aed348b6513adda5637a 100644 --- a/transformations/LICM.py +++ b/transformations/LICM.py @@ -1,6 +1,7 @@ from ir.bin_op import BinOp from ir.loops import For, While from ir.mutator import Mutator +from ir.properties import PropertyAccess from ir.visitor import Visitor @@ -39,12 +40,7 @@ class SetBlockVariants(Mutator): def mutate_BinOp(self, ast_node): ast_node.lhs = self.mutate(ast_node.lhs) - - # For property accesses, we only want to include the property name, and not - # the index that is also present in the expression - if not ast_node.is_property_access(): - ast_node.rhs = self.mutate(ast_node.rhs) - + ast_node.rhs = self.mutate(ast_node.rhs) return ast_node def mutate_ArrayAccess(self, ast_node): @@ -66,6 +62,12 @@ class SetBlockVariants(Mutator): def mutate_Property(self, ast_node): return self.push_variant(ast_node) + def mutate_PropertyAccess(self, ast_node): + # For property accesses, we only want to include the property name, and not + # the index that is also present in the access node + ast_node.prop = self.mutate(ast_node.prop) + return ast_node + def mutate_Var(self, ast_node): return self.push_variant(ast_node) @@ -91,10 +93,10 @@ class SetParentBlock(Visitor): def visit_Assign(self, ast_node): self.set_parent_block(ast_node) - def visit_BinOpDef(self, ast_node): + def visit_Branch(self, ast_node): self.set_parent_block(ast_node) - def visit_Branch(self, ast_node): + def visit_Decl(self, ast_node): self.set_parent_block(ast_node) def visit_Filter(self, ast_node): @@ -124,16 +126,21 @@ class SetParentBlock(Visitor): class SetBinOpTerminals(Visitor): def __init__(self, ast): super().__init__(ast) - self.bin_ops = [] + self.elems = [] def push_terminal(self, ast_node): - for bin_op in self.bin_ops: - bin_op.add_terminal(ast_node.name()) + for e in self.elems: + e.add_terminal(ast_node.name()) def visit_BinOp(self, ast_node): - self.bin_ops.append(ast_node) + self.elems.append(ast_node) self.visit_children(ast_node) - self.bin_ops.pop() + self.elems.pop() + + def visit_PropertyAccess(self, ast_node): + self.elems.append(ast_node) + self.visit_children(ast_node) + self.elems.pop() def visit_Array(self, ast_node): self.push_terminal(ast_node) @@ -148,6 +155,7 @@ class SetBinOpTerminals(Visitor): def visit_Property(self, ast_node): self.push_terminal(ast_node) + def visit_Var(self, ast_node): self.push_terminal(ast_node) @@ -174,11 +182,11 @@ class LICM(Mutator): self.loops.pop() return ast_node - def mutate_BinOpDef(self, ast_node): - if self.loops and isinstance(ast_node.bin_op, BinOp): + def mutate_Decl(self, ast_node): + if self.loops: last_loop = self.loops[-1] #print(f"variants = {last_loop.block.variants}, terminals = {ast_node.bin_op.terminals}") - if not last_loop.block.variants.intersection(ast_node.bin_op.terminals): + if isinstance(ast_node.elem, (BinOp, PropertyAccess)) and not last_loop.block.variants.intersection(ast_node.elem.terminals): #print(f'lifting {ast_node.bin_op.id()}') self.lifts[id(last_loop)].append(ast_node) return None diff --git a/transformations/flatten.py b/transformations/flatten.py deleted file mode 100644 index 7b0ac8e739d394ee0d782bef84ae7db3ffecc10a..0000000000000000000000000000000000000000 --- a/transformations/flatten.py +++ /dev/null @@ -1,35 +0,0 @@ -from ir.layouts import Layout_AoS, Layout_SoA -from ir.mutator import Mutator - - -class FlattenPropertyAccesses(Mutator): - def __init__(self, ast): - super().__init__(ast) - - def mutate_BinOp(self, ast_node): - ast_node.lhs = self.mutate(ast_node.lhs) - ast_node.rhs = self.mutate(ast_node.rhs) - - if ast_node.is_vector_property_access(): - layout = ast_node.lhs.layout() - - for i in ast_node.vector_indexes: - flat_index = None - - if layout == Layout_AoS: - flat_index = ast_node.rhs * ast_node.sim.ndims() + i - - elif layout == Layout_SoA: - flat_index = i * ast_node.sim.particle_capacity + ast_node.rhs - - else: - raise Exception("Invalid property layout!") - - ast_node.map_vector_index(i, flat_index) - - return ast_node - - -def flatten_property_accesses(ast_node): - flatten = FlattenPropertyAccesses(ast_node) - flatten.mutate() diff --git a/transformations/set_used_bin_ops.py b/transformations/set_used_bin_ops.py index d26d1bc7a11633e876238fd48e1cb69ca2c56502..404cb65d11c6365685499380f459fef0d87cc4f2 100644 --- a/transformations/set_used_bin_ops.py +++ b/transformations/set_used_bin_ops.py @@ -1,4 +1,3 @@ -from ir.bin_op import BinOp from ir.visitor import Visitor @@ -7,16 +6,16 @@ class SetUsedBinOps(Visitor): super().__init__(ast) self.bin_ops = [] - def visit_BinOpDef(self, ast_node): - pass - def visit_BinOp(self, ast_node): - ast_node.bin_op_def.used = True + ast_node.decl.used = True self.visit_children(ast_node) - # TODO: These expressions could be automatically included in visitor traversal - for vidxs in ast_node.mapped_expressions(): - self.visit(vidxs) + def visit_Decl(self, ast_node): + pass + + def visit_PropertyAccess(self, ast_node): + ast_node.decl.used = True + self.visit_children(ast_node) def set_used_bin_ops(ast): set_used_binops = SetUsedBinOps(ast) diff --git a/transformations/simplify.py b/transformations/simplify.py index 83d6cac1d216553e149b84ded7239486a38749da..69803d816b6758ca26bd01f0fa1d7012df69b819 100644 --- a/transformations/simplify.py +++ b/transformations/simplify.py @@ -11,7 +11,7 @@ class SimplifyExpressions(Mutator): sim = ast_node.lhs.sim ast_node.lhs = self.mutate(ast_node.lhs) ast_node.rhs = self.mutate(ast_node.rhs) - ast_node.vector_index_mapping = {i: self.mutate(e) for i, e in ast_node.vector_index_mapping.items()} + ast_node.expressions = {i: self.mutate(e) for i, e in ast_node.expressions.items()} if ast_node.op in ['+', '-'] and ast_node.rhs == 0: return ast_node.lhs