diff --git a/ast/arrays.py b/ast/arrays.py index 6908bbe8b0e340bc4d019e00dc755bd3dca77476..c8019352fa1b7d6e5f1be5f3027373dcd91f461a 100644 --- a/ast/arrays.py +++ b/ast/arrays.py @@ -1,6 +1,6 @@ from ast.assign import Assign from ast.data_types import Type_Array -from ast.expr import Expr +from ast.expr import BinOp from ast.layouts import Layout_AoS, Layout_SoA from ast.lit import as_lit_ast from ast.memory import Realloc @@ -80,7 +80,7 @@ class Array: def children(self): return [] - def generate(self, mem=False): + def generate(self, mem=False, index=None): return self.arr_name def transform(self, fn): @@ -134,13 +134,13 @@ class ArrayAccess: return f"ArrayAccess<array: {self.array}, indexes: {self.indexes}>" def __add__(self, other): - return Expr(self.sim, self, other, '+') + return BinOp(self.sim, self, other, '+') def __mul__(self, other): - return Expr(self.sim, self, other, '*') + return BinOp(self.sim, self, other, '*') def __rmul__(self, other): - return Expr(self.sim, other, self, '*') + return BinOp(self.sim, other, self, '*') def __getitem__(self, index): assert self.index is None, "Number of indexes higher than array dimension!" @@ -196,7 +196,7 @@ class ArrayAccess: def children(self): return [self.array] + self.indexes - def generate(self, mem=False): + def generate(self, mem=False, index=None): agen = self.array.generate() igen = self.index.generate() if mem is False and self.generated is False: @@ -220,10 +220,8 @@ class ArrayDecl: def children(self): return [] - def generate(self, mem=False): - self.sim.code_gen.generate_array_decl( - self.array.name(), self.array.type(), - self.array.alloc_size().generate_inline(recursive=True)) + def generate(self, mem=False, index=None): + self.sim.code_gen.generate_array_decl(self.array.name(), self.array.type(), BinOp.inline(self.array.alloc_size()).generate()) def transform(self, fn): return fn(self) diff --git a/ast/assign.py b/ast/assign.py index 0084a1ec6610450ed4cf6319ca22bf23f43a1ee0..98b78b16c03da1a8884d3ed4f10c4da97609ba0a 100644 --- a/ast/assign.py +++ b/ast/assign.py @@ -15,8 +15,8 @@ class Assign: self.assignments = [] for i in range(0, sim.dimensions): - from ast.expr import Expr - dsrc = (src if (not isinstance(src, Expr) or + from ast.expr import BinOp + dsrc = (src if (not isinstance(src, BinOp) or src.type() != Type_Vector) else src[i]) diff --git a/ast/cast.py b/ast/cast.py index d219d961b21908f1350beeeaf8609276bf95659f..aa101806675efb33e99c8daffd3fbcd31a0ce858 100644 --- a/ast/cast.py +++ b/ast/cast.py @@ -28,7 +28,7 @@ class Cast: def children(self): return [self.expr] - def generate(self, mem=False): + def generate(self, mem=False, index=None): return self.sim.code_gen.generate_cast(self.cast_type, self.expr.generate()) def transform(self, fn): diff --git a/ast/expr.py b/ast/expr.py index c7ec9f4ba3a429387d9a2624c0fd1053aaa978d0..723759f9c2bd65fbfcf09b43f9b0ac642d39a2e2 100644 --- a/ast/expr.py +++ b/ast/expr.py @@ -4,78 +4,130 @@ from ast.lit import as_lit_ast from ast.properties import Property -class Expr: - last_expr = 0 +class BinOpDef: + def __init__(self, bin_op): + self.bin_op = bin_op + self.bin_op.sim.add_statement(self) + + def __str__(self): + return f"BinOpDef<bin_op: self.bin_op>" + + def children(self): + return [self.bin_op] + + def generate(self, mem=False): + bin_op = self.bin_op + + if not isinstance(bin_op, BinOp): + return None + + if bin_op.inlined is False and bin_op.operator() != '[]' and bin_op.generated is False: + if bin_op.kind() == BinOp.Kind_Scalar: + lhs = bin_op.lhs.generate(bin_op.mem) + rhs = bin_op.rhs.generate() + bin_op.sim.code_gen.generate_expr(bin_op.id(), bin_op.type(), lhs, rhs, bin_op.op) + + elif bin_op.kind() == BinOp.Kind_Vector: + for i in bin_op.vector_indexes(): + lhs = bin_op.lhs.generate(bin_op.mem, index=i) + rhs = bin_op.rhs.generate(index=i) + bin_op.sim.code_gen.generate_vec_expr(bin_op.id(), i, lhs, rhs, bin_op.operator(), bin_op.mem) + + else: + raise Exception("Invalid BinOp kind!") + + bin_op.generated = True + + def transform(self, fn): + self.bin_op = self.bin_op.transform(fn) + return fn(self) + + +class BinOp: + # BinOp kinds + Kind_Scalar = 0 + Kind_Vector = 1 + + last_bin_op = 0 def new_id(): - Expr.last_expr += 1 - return Expr.last_expr - 1 + BinOp.last_bin_op += 1 + return BinOp.last_bin_op - 1 + + def inline(op): + if not isinstance(op, BinOp): + return op + + return op.inline_rec() def __init__(self, sim, lhs, rhs, op, mem=False): self.sim = sim - self.expr_id = Expr.new_id() + self.bin_op_id = BinOp.new_id() self.lhs = as_lit_ast(sim, lhs) self.rhs = as_lit_ast(sim, rhs) self.op = op self.mem = mem - self.expr_type = Expr.infer_type(self.lhs, self.rhs, self.op) - self.expr_scope = None - self.vec_generated = [] self.mutable = self.lhs.is_mutable() or self.rhs.is_mutable() # Value can change accross references + self.inlined = False self.generated = False + self.bin_op_type = BinOp.infer_type(self.lhs, self.rhs, self.op) + self.bin_op_scope = None + self.bin_op_vector_indexes = set() + self.bin_op_vector_index_mapping = {} + self.bin_op_def = BinOpDef(self) def __str__(self): - return f"Expr<a: {self.lhs}, b: {self.rhs}, op: {self.op}>" + return f"BinOp<a: {self.lhs.id()}, b: {self.rhs.id()}, op: {self.op}>" def __add__(self, other): - return Expr(self.sim, self, other, '+') + return BinOp(self.sim, self, other, '+') def __radd__(self, other): - return Expr(self.sim, other, self, '+') + return BinOp(self.sim, other, self, '+') def __sub__(self, other): - return Expr(self.sim, self, other, '-') + return BinOp(self.sim, self, other, '-') def __mul__(self, other): - return Expr(self.sim, self, other, '*') + return BinOp(self.sim, self, other, '*') def __rmul__(self, other): - return Expr(self.sim, other, self, '*') + return BinOp(self.sim, other, self, '*') def __truediv__(self, other): - return Expr(self.sim, self, other, '/') + return BinOp(self.sim, self, other, '/') def __rtruediv__(self, other): - return Expr(self.sim, other, self, '/') + return BinOp(self.sim, other, self, '/') def __lt__(self, other): - return Expr(self.sim, self, other, '<') + return BinOp(self.sim, self, other, '<') def __le__(self, other): - return Expr(self.sim, self, other, '<=') + return BinOp(self.sim, self, other, '<=') def __gt__(self, other): - return Expr(self.sim, self, other, '>') + return BinOp(self.sim, self, other, '>') def __ge__(self, other): - return Expr(self.sim, self, other, '>=') + return BinOp(self.sim, self, other, '>=') def and_op(self, other): - return Expr(self.sim, self, other, '&&') + return BinOp(self.sim, self, other, '&&') def cmp(lhs, rhs): - return Expr(lhs.sim, lhs, rhs, '==') + return BinOp(lhs.sim, lhs, rhs, '==') def neq(lhs, rhs): - return Expr(lhs.sim, lhs, rhs, '!=') + return BinOp(lhs.sim, lhs, rhs, '!=') def inv(self): - return Expr(self.sim, 1.0, self, '/') + return BinOp(self.sim, 1.0, self, '/') - def match(self, expr): - return self.lhs == expr.lhs and \ - self.rhs == expr.rhs and \ - self.op == expr.op + def match(self, bin_op): + return self.lhs == bin_op.lhs and \ + self.rhs == bin_op.rhs and \ + self.op == bin_op.operator() def x(self): return self.__getitem__(0) @@ -86,12 +138,35 @@ class Expr: def z(self): return self.__getitem__(2) + def map_vector_index(self, index, expr): + self.bin_op_vector_index_mapping[index] = expr + + def mapped_vector_index(self, index): + mapping = self.bin_op_vector_index_mapping + return mapping[index] if index in mapping else as_lit_ast(self.sim, index) + + def vector_indexes(self): + return self.bin_op_vector_indexes + + def propagate_vector_access(self, index): + self.bin_op_vector_indexes.add(index) + + if isinstance(self.lhs, BinOp) and self.lhs.kind() == BinOp.Kind_Vector: + self.lhs.propagate_vector_access(index) + + if isinstance(self.rhs, BinOp) and self.rhs.kind() == BinOp.Kind_Vector: + self.rhs.propagate_vector_access(index) + def __getitem__(self, index): - assert self.lhs.type() == Type_Vector, "Cannot use operator [] on specified type!" - return ExprVec(self.sim, self, as_lit_ast(self.sim, index)) + assert self.type() == Type_Vector, "Cannot use operator [] on specified type!" + self.propagate_vector_access(index) + return BinOp(self.sim, self, as_lit_ast(self.sim, index), '[]', self.mem) + + def is_property_access(self): + return isinstance(self.lhs, Property) and self.operator() == '[]' - def generated_vector_index(self, index): - return not [i for i in self.vec_generated if i == index] + def is_vector_property_access(self): + return self.is_property_access() and self.type() == Type_Vector def set(self, other): assert self.mem is True, "Invalid assignment: lvalue expected!" @@ -101,6 +176,10 @@ class Expr: assert self.mem is True, "Invalid assignment: lvalue expected!" return self.sim.add_statement(Assign(self.sim, self, self + other)) + def sub(self, other): + assert self.mem is True, "Invalid assignment: lvalue expected!" + return self.sim.add_statement(Assign(self.sim, self, self - other)) + def infer_type(lhs, rhs, op): lhs_type = lhs.type() rhs_type = rhs.type() @@ -126,153 +205,73 @@ class Expr: if lhs_type == Type_Float or rhs_type == Type_Float: return Type_Float - print(f"{lhs} ({lhs_type}) -- {rhs} ({rhs_type})\n") return None - def type(self): - return self.expr_type + def inline_rec(self): + self.inlined = True - def is_mutable(self): - return self.mutable + if isinstance(self.lhs, BinOp): + self.lhs.inline_rec() - def scope(self): - if self.expr_scope is None: - lhs_scp = self.lhs.scope() - rhs_scp = self.rhs.scope() - self.expr_scope = lhs_scp if lhs_scp > rhs_scp else rhs_scp - - return self.expr_scope + if isinstance(self.rhs, BinOp): + self.rhs.inline_rec() - def children(self): - return [self.lhs, self.rhs] + return self - def generate(self, mem=False): - lhs_expr = self.lhs.generate(mem) - rhs_expr = self.rhs.generate() - if self.op == '[]': - return self.sim.code_gen.generate_expr_access(lhs_expr, rhs_expr, self.mem) + def id(self): + return self.bin_op_id - if self.generated is False: - assert self.expr_type != Type_Vector, \ - "Vector code must be generated through ExprVec class!" - - self.sim.code_gen.generate_expr(self.expr_id, self.expr_type, lhs_expr, rhs_expr, self.op) - self.generated = True - - return self.sim.code_gen.generate_expr_ref(self.expr_id) - - def generate_inline(self, mem=False, recursive=False): - inline_lhs_expr = recursive and isinstance(self.lhs, Expr) - inline_rhs_expr = recursive and isinstance(self.rhs, Expr) - lhs_expr = \ - self.lhs.generate_inline(recursive, mem) if inline_lhs_expr \ - else self.lhs.generate(mem) - rhs_expr = \ - self.rhs.generate_inline(recursive) if inline_rhs_expr \ - else self.rhs.generate() - - if self.op == '[]': - return self.sim.code_gen.generate_expr_access(lhs_expr, rhs_expr, self.mem) - - assert self.expr_type != Type_Vector, \ - "Vector code must be generated through ExprVec class!" - return self.sim.code_gen.generate_inline_expr(lhs_expr, rhs_expr, self.op) - - def transform(self, fn): - self.lhs = self.lhs.transform(fn) - self.rhs = self.rhs.transform(fn) - return fn(self) - - -class ExprVec(): - def __init__(self, sim, expr, index): - self.sim = sim - self.expr = expr - self.index = index - self.expr_scope = None - self.mutable = self.expr.is_mutable() or self.index.is_mutable() - self.lhs = (expr.lhs if not isinstance(expr.lhs, Expr) - else ExprVec(sim, expr.lhs, index)) - self.rhs = (expr.rhs if not isinstance(expr.rhs, Expr) - else ExprVec(sim, expr.rhs, index)) - - def __str__(self): - return (f"ExprVec<a: {self.lhs}, b: {self.rhs}, " + - f"op: {self.expr.op} i: {self.index}>") - - def __add__(self, other): - return Expr(self.sim, self, other, '+') - - def __radd__(self, other): - return Expr(self.sim, other, self, '+') - - def __sub__(self, other): - return Expr(self.sim, self, other, '-') - - def __mul__(self, other): - return Expr(self.sim, self, other, '*') - - def __lt__(self, other): - return Expr(self.sim, self, other, '<') - - def __le__(self, other): - return Expr(self.sim, self, other, '<=') - - def __gt__(self, other): - return Expr(self.sim, self, other, '>') - - def __ge__(self, other): - return Expr(self.sim, self, other, '>=') - - def set(self, other): - return self.sim.add_statement(Assign(self.sim, self, other)) - - def add(self, other): - return self.sim.add_statement(Assign(self.sim, self, self + other)) + def type(self): + return self.bin_op_type - def sub(self, other): - return self.sim.add_statement(Assign(self.sim, self, self - other)) + def operator(self): + return self.op - def type(self): - return Type_Float + def kind(self): + return BinOp.Kind_Vector if self.type() == Type_Vector else BinOp.Kind_Scalar def is_mutable(self): return self.mutable def scope(self): - if self.expr_scope is None: - expr_scp = self.expr.scope() - index_scp = self.index.scope() - self.expr_scope = expr_scp if expr_scp > index_scp else index_scp + if self.bin_op_scope is None: + lhs_scp = self.lhs.scope() + rhs_scp = self.rhs.scope() + self.bin_op_scope = lhs_scp if lhs_scp > rhs_scp else rhs_scp - return self.expr_scope + return self.bin_op_scope def children(self): - return [self.lhs, self.rhs, self.index] + return [self.lhs, self.rhs] - def generate(self, mem=False): - if self.expr.type() != Type_Vector: - return self.expr.generate() + def generate(self, mem=False, index=None): + if isinstance(self.lhs, BinOp) and self.lhs.kind() == BinOp.Kind_Vector and self.op == '[]': + return self.lhs.generate(self.mem, self.rhs.generate()) + + lhs = self.lhs.generate(mem, index) + rhs = self.rhs.generate(index=index) - index_expr = self.index.generate() - if self.expr.op == '[]': - return self.sim.code_gen.generate_expr_access(self.expr.generate(), index_expr, True) + if self.op == '[]': + idx = self.mapped_vector_index(index).generate() if self.is_vector_property_access() else rhs + return self.sim.code_gen.generate_expr_access(lhs, idx, self.mem) - if self.expr.generated_vector_index(index_expr): - self.sim.code_gen.generate_vec_expr( - self.expr.expr_id, - index_expr, - self.lhs.generate(mem), - self.rhs.generate(), - self.expr.op, - self.expr.mem) + if self.inlined is True: + assert self.bin_op_type != Type_Vector, "Vector operations cannot be inlined!" + return self.sim.code_gen.generate_inline_expr(lhs, rhs, self.op) - self.expr.vec_generated.append(index_expr) + # Some expressions can be defined on-the-fly during transformations, hence they do not have + # a definition statement, so we generate them right before usage + if not self.generated: + self.bin_op_def.generate() - return self.sim.code_gen.generate_vec_expr_ref(self.expr.expr_id, index_expr, self.expr.mem) + if self.kind() == BinOp.Kind_Vector: + assert index is not None, "Index must be set for vector reference!" + return self.sim.code_gen.generate_vec_expr_ref(self.id(), index, self.mem) + + return self.sim.code_gen.generate_expr_ref(self.id()) def transform(self, fn): self.lhs = self.lhs.transform(fn) self.rhs = self.rhs.transform(fn) - self.index = self.index.transform(fn) return fn(self) + diff --git a/ast/lit.py b/ast/lit.py index ff09f2397178eeceb8a0dbc8cf9ef96f2c6d5f38..58b25a7c2a0dbdbab12bf8f00eab394f81f45218 100644 --- a/ast/lit.py +++ b/ast/lit.py @@ -57,11 +57,7 @@ class Lit: def children(self): return [] - def generate(self, mem=False): - assert mem is False, "Literal is not lvalue!" - return self.value - - def generate_inline(self, mem=False, recursive=False): + def generate(self, mem=False, index=None): assert mem is False, "Literal is not lvalue!" return self.value diff --git a/ast/loops.py b/ast/loops.py index d7ab372d6b63431ba01490da06b639ca9c2305ae..bd81620473b8ea31648c32f5de0ec04906261633 100644 --- a/ast/loops.py +++ b/ast/loops.py @@ -1,7 +1,7 @@ from ast.block import Block from ast.branches import Filter from ast.data_types import Type_Int -from ast.expr import Expr +from ast.expr import BinOp from ast.lit import as_lit_ast @@ -30,18 +30,18 @@ class Iter(): return self.loop.block def __add__(self, other): - return Expr(self.sim, self, other, '+') + return BinOp(self.sim, self, other, '+') def __sub__(self, other): - return Expr(self.sim, self, other, '-') + return BinOp(self.sim, self, other, '-') def __mul__(self, other): - from ast.expr import Expr - return Expr(self.sim, self, other, '*') + from ast.expr import BinOp + return BinOp(self.sim, self, other, '*') def __rmul__(self, other): - from ast.expr import Expr - return Expr(self.sim, other, self, '*') + from ast.expr import BinOp + return BinOp(self.sim, other, self, '*') def __eq__(self, other): if isinstance(other, Iter): @@ -53,8 +53,8 @@ class Iter(): return self.__cmp__(other) def __mod__(self, other): - from ast.expr import Expr - return Expr(self.sim, self, other, '%') + from ast.expr import BinOp + return BinOp(self.sim, self, other, '%') def __str__(self): return f"Iter<{self.iter_id}>" @@ -62,7 +62,7 @@ class Iter(): def children(self): return [] - def generate(self, mem=False): + def generate(self, mem=False, index=None): assert mem is False, "Iterator is not lvalue!" return f"i{self.iter_id}" @@ -128,9 +128,10 @@ class ParticleFor(For): class While(): def __init__(self, sim, cond, block=None): + from ast.expr import BinOp self.sim = sim self.parent_block = None - self.cond = cond + self.cond = BinOp.inline(cond) self.block = Block(sim, []) if block is None else block def __str__(self): @@ -149,10 +150,7 @@ class While(): return [self.cond, self.block] def generate(self): - from ast.expr import Expr - cond_gen = (self.cond.generate() if not isinstance(self.cond, Expr) - else self.cond.generate_inline()) - self.sim.code_gen.generate_while_preamble(cond_gen) + self.sim.code_gen.generate_while_preamble(self.cond.generate()) self.block.generate() self.sim.code_gen.generate_while_epilogue() @@ -177,9 +175,9 @@ class NeighborFor(): for s in For(self.sim, 0, cl.nstencil): neigh_cell = cl.particle_cell[self.particle] + cl.stencil[s] for _ in Filter(self.sim, - Expr.and_op(neigh_cell >= 0, - neigh_cell <= cl.ncells_all)): + BinOp.and_op(neigh_cell >= 0, + neigh_cell <= cl.ncells_all)): for nc in For(self.sim, 0, cl.cell_sizes[neigh_cell]): it = cl.cell_particles[neigh_cell][nc] - for _ in Filter(self.sim, Expr.neq(it, self.particle)): + for _ in Filter(self.sim, BinOp.neq(it, self.particle)): yield it diff --git a/ast/math.py b/ast/math.py index bee39951a0894197903867afb999b4b86af66c32..656543c97c4b200eef0568f2014a43087f33495a 100644 --- a/ast/math.py +++ b/ast/math.py @@ -21,7 +21,7 @@ class Sqrt: def children(self): return [self.expr] - def generate(self, mem=False): + def generate(self, mem=False, index=None): return self.sim.code_gen.generate_sqrt(self.expr.generate()) def transform(self, fn): diff --git a/ast/memory.py b/ast/memory.py index e32ecc1ba3e4a83b8b181a3da780acde77f6f463..205bd445033f83276ecff3c87fa459424be1aaa9 100644 --- a/ast/memory.py +++ b/ast/memory.py @@ -1,3 +1,4 @@ +from ast.expr import BinOp from ast.sizeof import Sizeof from functools import reduce import operator @@ -11,15 +12,14 @@ class Malloc: self.array_type = a_type self.decl = decl self.prim_size = Sizeof(sim, a_type) - self.size = self.prim_size * (reduce(operator.mul, sizes) if isinstance(sizes, list) else sizes) + self.size = BinOp.inline(self.prim_size * (reduce(operator.mul, sizes) if isinstance(sizes, list) else sizes)) self.sim.add_statement(self) def children(self): return [self.array, self.size] - def generate(self, mem=False): - self.sim.code_gen.generate_malloc( - self.array.generate(), self.array_type, self.size.generate_inline(recursive=True), self.decl) + def generate(self, mem=False, index=None): + self.sim.code_gen.generate_malloc(self.array.generate(), self.array_type, self.size.generate(), self.decl) def transform(self, fn): self.array = self.array.transform(fn) @@ -34,15 +34,14 @@ class Realloc: self.array = array self.array_type = a_type self.prim_size = Sizeof(sim, a_type) - self.size = self.prim_size * size + self.size = BinOp.inline(self.prim_size * size) self.sim.add_statement(self) def children(self): return [self.array, self.size] - def generate(self, mem=False): - self.sim.code_gen.generate_realloc( - self.array.generate(), self.array_type, self.size.generate_inline(recursive=True)) + def generate(self, mem=False, index=None): + self.sim.code_gen.generate_realloc(self.array.generate(), self.array_type, self.size.generate()) def transform(self, fn): self.array = self.array.transform(fn) diff --git a/ast/properties.py b/ast/properties.py index 3e261beb111a66f93cc82226472953e69fb01137..0c6b2a44bb600e1b8ba22d3345a3234c21bd5710 100644 --- a/ast/properties.py +++ b/ast/properties.py @@ -71,13 +71,13 @@ class Property: return self.sim.global_scope def __getitem__(self, expr_ast): - from ast.expr import Expr - return Expr(self.sim, self, expr_ast, '[]', True) + from ast.expr import BinOp + return BinOp(self.sim, self, expr_ast, '[]', True) def children(self): return [] - def generate(self, mem=False): + def generate(self, mem=False, index=None): return self.prop_name def transform(self, fn): diff --git a/ast/select.py b/ast/select.py index 0dae86bc0ee96e36c31ba216dda3a568187ab41c..8ba489c307459b8eb5a7696f875d34597d967934 100644 --- a/ast/select.py +++ b/ast/select.py @@ -1,4 +1,4 @@ -from ast.expr import Expr +from ast.expr import BinOp from ast.lit import as_lit_ast @@ -6,22 +6,14 @@ class Select: def __init__(self, sim, cond, expr_if, expr_else): self.sim = sim self.cond = as_lit_ast(sim, cond) - self.expr_if = as_lit_ast(sim, expr_if) - self.expr_else = as_lit_ast(sim, expr_else) + self.expr_if = BinOp.inline(as_lit_ast(sim, expr_if)) + self.expr_else = BinOp.inline(as_lit_ast(sim, expr_else)) def children(self): return [self.cond, self.expr_if, self.expr_else] def generate(self): - cond_gen = self.cond.generate() - expr_if_gen = \ - self.expr_if.generate_inline() if isinstance(self.expr_if, Expr) \ - else self.expr_if.generate() - expr_else_gen = \ - self.expr_else.generate_inline() if isinstance(self.expr_else, Expr) \ - else self.expr_else.generate() - - return self.sim.code_gen.generate_select(cond_gen, expr_if_gen, expr_else_gen) + return self.sim.code_gen.generate_select(self.cond.generate(), self.expr_if.generate(), self.expr_else.generate()) def transform(self, fn): self.cond = self.cond.transform(fn) diff --git a/ast/sizeof.py b/ast/sizeof.py index df171a81a50fb615f02d74ee4da6314b8eb670da..ee2598771d10a9a8efb00f8ba2f4b1e33441240f 100644 --- a/ast/sizeof.py +++ b/ast/sizeof.py @@ -1,5 +1,5 @@ from ast.data_types import Type_Int -from ast.expr import Expr +from ast.expr import BinOp class Sizeof: @@ -8,7 +8,7 @@ class Sizeof: self.data_type = data_type def __mul__(self, other): - return Expr(self.sim, self, other, '*') + return BinOp(self.sim, self, other, '*') def type(self): return Type_Int @@ -22,7 +22,7 @@ class Sizeof: def children(self): return [] - def generate(self, mem=False): + def generate(self, mem=False, index=None): return self.sim.code_gen.generate_sizeof(self.data_type) def transform(self, fn): diff --git a/ast/transform.py b/ast/transform.py index 1dad678fe46eee9210f9c9ac3c508814aadbfd50..612b0cfe44508f4a60c74ce2660fc94ce248c5aa 100644 --- a/ast/transform.py +++ b/ast/transform.py @@ -1,6 +1,6 @@ from ast.arrays import ArrayAccess from ast.data_types import Type_Int, Type_Vector -from ast.expr import Expr, ExprVec +from ast.expr import BinOp from ast.layouts import Layout_AoS, Layout_SoA from ast.lit import Lit from ast.loops import Iter @@ -8,40 +8,30 @@ from ast.properties import Property class Transform: - flattened_list = [] reuse_expressions = {} def apply(ast, fn): ast.transform(fn) - Transform.flattened_list = [] Transform.reuse_expressions = {} def flatten(ast): - if isinstance(ast, ExprVec): - if ast.expr.op == '[]' and ast.expr.type() == Type_Vector: - item = [f for f in Transform.flattened_list if - f[0] == ast.expr.lhs and - f[1] == ast.index and - f[2] == ast.expr.rhs and - not ast.expr.rhs.is_mutable()] - if item: - return item[0][3] - + if isinstance(ast, BinOp): + if ast.is_vector_property_access(): layout = ast.lhs.layout() - flat_index = None - if layout == Layout_AoS: - flat_index = ast.expr.rhs * ast.expr.sim.dimensions + ast.index + for i in ast.vector_indexes(): + flat_index = None - elif layout == Layout_SoA: - flat_index = ast.index * ast.expr.sim.particle_capacity + ast.expr.rhs + if layout == Layout_AoS: + flat_index = ast.rhs * ast.sim.dimensions + i - else: - raise Exception("Invalid property layout!") + elif layout == Layout_SoA: + flat_index = i * ast.sim.particle_capacity + ast.rhs + + else: + raise Exception("Invalid property layout!") - new_expr = Expr(ast.expr.sim, ast.expr.lhs, flat_index, '[]', ast.expr.mem) - Transform.flattened_list.append((ast.expr.lhs, ast.index, ast.expr.rhs, new_expr)) - return new_expr + ast.map_vector_index(i, flat_index) if isinstance(ast, Property): ast.flattened = True @@ -49,13 +39,13 @@ class Transform: return ast def simplify(ast): - if isinstance(ast, Expr): + if isinstance(ast, BinOp): sim = ast.lhs.sim if ast.op in ['+', '-'] and ast.rhs == 0: return ast.lhs - if ast.op in ['+', '-'] and ast.lhs == 0: + if ast.op in ['+'] and ast.lhs == 0: return ast.rhs if ast.op in ['*', '/'] and ast.rhs == 1: @@ -70,7 +60,7 @@ class Transform: return ast def reuse_index_expressions(ast): - if isinstance(ast, Expr): + if isinstance(ast, BinOp): iter_id = None if isinstance(ast.lhs, Iter): @@ -94,13 +84,13 @@ class Transform: return ast def reuse_expr_expressions(ast): - if isinstance(ast, Expr): + if isinstance(ast, BinOp): expr_id = None - if isinstance(ast.lhs, Expr): + if isinstance(ast.lhs, BinOp): expr_id = ast.lhs.expr_id - if isinstance(ast.rhs, Expr): + if isinstance(ast.rhs, BinOp): expr_id = ast.rhs.expr_id if expr_id is not None: @@ -118,7 +108,7 @@ class Transform: return ast def reuse_array_access_expressions(ast): - if isinstance(ast, Expr): + if isinstance(ast, BinOp): acc_id = None if isinstance(ast.lhs, ArrayAccess): @@ -142,7 +132,7 @@ class Transform: return ast def move_loop_invariant_expressions(ast): - if isinstance(ast, Expr): + if isinstance(ast, BinOp): scope = ast.scope() if scope.level > 0: scope.add_expression(ast) diff --git a/ast/variables.py b/ast/variables.py index 9a075eba90d8eda9ce78be3ec3a422a2b28a902a..9604b28cdc97a4fb88ad12fbaee2829af875c4b3 100644 --- a/ast/variables.py +++ b/ast/variables.py @@ -1,5 +1,5 @@ from ast.assign import Assign -from ast.expr import Expr +from ast.expr import BinOp class Variables: @@ -36,40 +36,40 @@ class Var: return f"Var<name: {self.var_name}, type: {self.var_type}>" def __add__(self, other): - return Expr(self.sim, self, other, '+') + return BinOp(self.sim, self, other, '+') def __radd__(self, other): - return Expr(self.sim, other, self, '+') + return BinOp(self.sim, other, self, '+') def __sub__(self, other): - return Expr(self.sim, self, other, '-') + return BinOp(self.sim, self, other, '-') def __mul__(self, other): - return Expr(self.sim, self, other, '*') + return BinOp(self.sim, self, other, '*') def __rmul__(self, other): - return Expr(self.sim, other, self, '*') + return BinOp(self.sim, other, self, '*') def __truediv__(self, other): - return Expr(self.sim, self, other, '/') + return BinOp(self.sim, self, other, '/') def __rtruediv__(self, other): - return Expr(self.sim, other, self, '/') + return BinOp(self.sim, other, self, '/') def __lt__(self, other): - return Expr(self.sim, self, other, '<') + return BinOp(self.sim, self, other, '<') def __le__(self, other): - return Expr(self.sim, self, other, '<=') + return BinOp(self.sim, self, other, '<=') def __gt__(self, other): - return Expr(self.sim, self, other, '>') + return BinOp(self.sim, self, other, '>') def __ge__(self, other): - return Expr(self.sim, self, other, '>=') + return BinOp(self.sim, self, other, '>=') def inv(self): - return Expr(self.sim, 1.0, self, '/') + return BinOp(self.sim, 1.0, self, '/') def set(self, other): return self.sim.add_statement(Assign(self.sim, self, other)) @@ -104,7 +104,7 @@ class Var: def children(self): return [] - def generate(self, mem=False): + def generate(self, mem=False, index=None): return self.var_name def transform(self, fn): @@ -120,7 +120,7 @@ class VarDecl: def children(self): return [] - def generate(self, mem=False): + def generate(self, mem=False, index=None): self.sim.code_gen.generate_var_decl( self.var.name(), self.var.type(), self.var.init_value()) diff --git a/sim/cell_lists.py b/sim/cell_lists.py index cd412cde096dc8f21f7ded474dc6023ee236d2cb..fe69d06e487072309a791513db1ef58082f4024e 100644 --- a/sim/cell_lists.py +++ b/sim/cell_lists.py @@ -1,7 +1,7 @@ from ast.branches import Branch, Filter from ast.cast import Cast from ast.data_types import Type_Int -from ast.expr import Expr +from ast.expr import BinOp from ast.loops import For, ParticleFor from ast.utils import Print from functools import reduce @@ -95,7 +95,7 @@ class CellListsBuild: else flat_idx * cl.ncells[d] + cell_index[d]) cell_size = cl.cell_sizes[flat_idx] - for _ in Filter(cl.sim, Expr.and_op(flat_idx >= 0, flat_idx <= cl.ncells_all)): + for _ in Filter(cl.sim, BinOp.and_op(flat_idx >= 0, flat_idx <= cl.ncells_all)): for cond in Branch(cl.sim, cell_size >= cl.cell_capacity): if cond: resize.set(cell_size) diff --git a/sim/particle_simulation.py b/sim/particle_simulation.py index da346ddbcdc8ba738ba5800b21f71bcecd46b785..5d982910f7b4c5d4a6e32da6803a4a189b4fb732 100644 --- a/sim/particle_simulation.py +++ b/sim/particle_simulation.py @@ -31,7 +31,6 @@ class ParticleSimulation: self.particle_capacity = self.add_var('particle_capacity', Type_Int, 10000) self.nlocal = self.add_var('nlocal', Type_Int) self.nghost = self.add_var('nghost', Type_Int) - self.nparticles = self.nlocal + self.nghost self.grid = None self.cell_lists = None self.pbc = None @@ -46,44 +45,39 @@ class ParticleSimulation: self.expr_id = 0 self.iter_id = 0 self.vtk_file = None + self.nparticles = self.nlocal + self.nghost self.properties.add_capacity(self.particle_capacity) def add_real_property(self, prop_name, value=0.0, vol=False): - assert self.property(prop_name) is None, \ - f"add_real_property(): Property already defined: {prop_name}" + assert self.property(prop_name) is None, f"Property already defined: {prop_name}" return self.properties.add(prop_name, Type_Float, value, vol) def add_vector_property(self, prop_name, value=[0.0, 0.0, 0.0], vol=False, layout=Layout_AoS): - assert self.property(prop_name) is None, \ - f"add_vector_property(): Property already defined: {prop_name}" + assert self.property(prop_name) is None, f"Property already defined: {prop_name}" return self.properties.add(prop_name, Type_Vector, value, vol, layout) def property(self, prop_name): return self.properties.find(prop_name) def add_array(self, arr_name, arr_sizes, arr_type, arr_layout=Layout_AoS): - assert self.array(arr_name) is None, \ - f"add_array(): Array already defined: {arr_name}" + assert self.array(arr_name) is None, f"Array already defined: {arr_name}" return self.arrays.add(arr_name, arr_sizes, arr_type, arr_layout) def add_static_array(self, arr_name, arr_sizes, arr_type, arr_layout=Layout_AoS): - assert self.array(arr_name) is None, \ - f"add_static_array(): Array already defined: {arr_name}" + assert self.array(arr_name) is None, f"Array already defined: {arr_name}" return self.arrays.add_static(arr_name, arr_sizes, arr_type, arr_layout) def array(self, arr_name): return self.arrays.find(arr_name) def add_var(self, var_name, var_type, init_value=0): - assert self.var(var_name) is None, \ - f"add_var(): Variable already defined: {var_name}" + assert self.var(var_name) is None, f"Variable already defined: {var_name}" return self.vars.add(var_name, var_type, init_value) def add_or_reuse_var(self, var_name, var_type, init_value=0): existing_var = self.var(var_name) if existing_var is not None: - assert existing_var.type() == var_type, \ - f"add_or_reuse_var(): Cannot reuse variable {var_name}: types differ!" + assert existing_var.type() == var_type, f"Cannot reuse variable {var_name}: types differ!" return existing_var return self.vars.add(var_name, var_type, init_value) @@ -200,7 +194,7 @@ class ParticleSimulation: self.global_scope = program Block.set_block_levels(program) Transform.apply(program, Transform.flatten) - Transform.apply(program, Transform.simplify) + #Transform.apply(program, Transform.simplify) #Transform.apply(program, Transform.reuse_index_expressions) #Transform.apply(program, Transform.reuse_expr_expressions) #Transform.apply(program, Transform.reuse_array_access_expressions) diff --git a/sim/pbc.py b/sim/pbc.py index 528cc0f60283f3a69ea4f813f99779bb1371d533..f6e7f66d9d0702b4f5728e6c35c65541409743aa 100644 --- a/sim/pbc.py +++ b/sim/pbc.py @@ -89,6 +89,7 @@ class SetupPBC: npbc.set(0) for d in range(0, ndims): for i in For(sim, 0, nlocal + npbc): + last_id = nlocal + npbc # TODO: VecFilter? for _ in Filter(sim, positions[i][d] < grid.min(d) + cutneigh): for capacity_exceeded in Branch(sim, npbc >= pbc_capacity): @@ -97,11 +98,11 @@ class SetupPBC: else: pbc_map[npbc].set(i) pbc_mult[npbc][d].set(1) - positions[nlocal + npbc][d].set(positions[i][d] + grid.length(d)) + positions[last_id][d].set(positions[i][d] + grid.length(d)) for d_ in [x for x in range(0, ndims) if x != d]: pbc_mult[npbc][d_].set(0) - positions[nlocal + npbc][d_].set(positions[i][d_]) + positions[last_id][d_].set(positions[i][d_]) npbc.add(1) @@ -112,11 +113,11 @@ class SetupPBC: else: pbc_map[npbc].set(i) pbc_mult[npbc][d].set(-1) - positions[nlocal + npbc][d].set(positions[i][d] - grid.length(d)) + positions[last_id][d].set(positions[i][d] - grid.length(d)) for d_ in [x for x in range(0, ndims) if x != d]: pbc_mult[npbc][d_].set(0) - positions[nlocal + npbc][d_].set(positions[i][d_]) + positions[last_id][d_].set(positions[i][d_]) npbc.add(1) diff --git a/sim/timestep.py b/sim/timestep.py index cf6b4dfaf8ec3a33de0189a50af2ec26465a9c74..5a676486e16c5529ebdd85ea95811cd13c2bcd8c 100644 --- a/sim/timestep.py +++ b/sim/timestep.py @@ -1,5 +1,5 @@ from ast.block import Block -from ast.expr import Expr +from ast.expr import BinOp from ast.branches import Branch from ast.loops import For @@ -34,7 +34,7 @@ class Timestep: if exec_every > 0: self.block.add_statement( - Branch(self.sim, Expr.cmp(ts % exec_every, 0), True if stmts_else is None else False, + Branch(self.sim, BinOp.cmp(ts % exec_every, 0), True if stmts_else is None else False, Block(self.sim, stmts), None if stmts_else is None else Block(self.sim, stmts_else))) else: self.block.add_statement(stmts)