diff --git a/ast/arrays.py b/ast/arrays.py index 198269318437a734be400b81841632f31962aeeb..e7fd2c707debeda3be8928b038dd0c403120627a 100644 --- a/ast/arrays.py +++ b/ast/arrays.py @@ -1,6 +1,7 @@ from ast.assign import Assign +from ast.ast_node import ASTNode +from ast.bin_op import BinOp, ASTTerm from ast.data_types import Type_Array -from ast.expr import BinOp from ast.layouts import Layout_AoS, Layout_SoA from ast.lit import as_lit_ast from ast.memory import Realloc @@ -35,9 +36,9 @@ class Arrays: return None -class Array: +class Array(ASTNode): def __init__(self, sim, a_name, a_sizes, a_type, a_layout=Layout_AoS): - self.sim = sim + super().__init__(sim) self.arr_name = a_name self.arr_sizes = \ [as_lit_ast(sim, a_sizes)] if not isinstance(a_sizes, list) \ @@ -65,9 +66,6 @@ class Array: def layout(self): return self.arr_layout - def scope(self): - return self.sim.global_scope - def ndims(self): return self.arr_ndims @@ -77,12 +75,6 @@ class Array: def alloc_size(self): return reduce((lambda x, y: x * y), [s for s in self.arr_sizes]) - def children(self): - return [] - - def transform(self, fn): - return fn(self) - class ArrayStatic(Array): def __init__(self, sim, a_name, a_sizes, a_type, a_layout=Layout_AoS): @@ -110,7 +102,7 @@ class ArrayND(Array): return Realloc(self.sim, self, self.alloc_size()) -class ArrayAccess: +class ArrayAccess(ASTTerm): last_acc = 0 def new_id(): @@ -118,7 +110,7 @@ class ArrayAccess: return ArrayAccess.last_acc - 1 def __init__(self, sim, array, index): - self.sim = sim + super().__init__(sim) self.acc_id = ArrayAccess.new_id() self.array = array self.indexes = [as_lit_ast(sim, index)] @@ -130,15 +122,6 @@ class ArrayAccess: def __str__(self): return f"ArrayAccess<array: {self.array}, indexes: {self.indexes}>" - def __add__(self, other): - return BinOp(self.sim, self, other, '+') - - def __mul__(self, other): - return BinOp(self.sim, self, other, '*') - - def __rmul__(self, other): - return BinOp(self.sim, other, self, '*') - def __getitem__(self, index): assert self.index is None, "Number of indexes higher than array dimension!" self.indexes.append(as_lit_ast(self.sim, index)) @@ -206,14 +189,8 @@ class ArrayAccess: return fn(self) -class ArrayDecl: +class ArrayDecl(ASTNode): def __init__(self, sim, array): - self.sim = sim + super().__init__(sim) self.array = array self.sim.add_statement(self) - - def children(self): - return [] - - def transform(self, fn): - return fn(self) diff --git a/ast/assign.py b/ast/assign.py index 59f11e4e35c7bcfaaf4c1aa2a1d0d6f074068ae4..8ce1d7b8ac8e3027dc3f2b2e4a78cefb210e04b6 100644 --- a/ast/assign.py +++ b/ast/assign.py @@ -1,11 +1,12 @@ +from ast.ast_node import ASTNode from ast.data_types import Type_Vector from ast.lit import as_lit_ast from functools import reduce -class Assign: +class Assign(ASTNode): def __init__(self, sim, dest, src): - self.sim = sim + super().__init__(sim) self.parent_block = None self.type = dest.type() src = as_lit_ast(sim, src) @@ -14,7 +15,7 @@ class Assign: self.assignments = [] for i in range(0, sim.dimensions): - from ast.expr import BinOp + from ast.bin_op import BinOp dsrc = (src if (not isinstance(src, BinOp) or src.type() != Type_Vector) else src[i]) diff --git a/ast/ast_node.py b/ast/ast_node.py new file mode 100644 index 0000000000000000000000000000000000000000..b0641f45c16a64efe7eb70ef1279b67f19691da3 --- /dev/null +++ b/ast/ast_node.py @@ -0,0 +1,24 @@ +from ast.data_types import Type_Invalid + + +class ASTNode: + def __init__(self, sim): + self.sim = sim + + def __str__(self): + return "ASTNode<>" + + def type(self): + return Type_Invalid + + def is_mutable(self): + return False + + def scope(self): + return self.sim.global_scope + + def children(self): + return [] + + def transform(self, fn): + return fn(self) diff --git a/ast/expr.py b/ast/bin_op.py similarity index 80% rename from ast/expr.py rename to ast/bin_op.py index 93db1df3775f001ccd9315f8d5fd6e02aeaa0730..4cc8dd56a03177e322d882ccc037445ebc43bf5f 100644 --- a/ast/expr.py +++ b/ast/bin_op.py @@ -1,11 +1,13 @@ +from ast.ast_node import ASTNode from ast.assign import Assign from ast.data_types import Type_Float, Type_Bool, Type_Vector from ast.lit import as_lit_ast from ast.properties import Property -class BinOpDef: +class BinOpDef(ASTNode): def __init__(self, bin_op): + super().__init__(bin_op.sim) self.bin_op = bin_op self.bin_op.sim.add_statement(self) @@ -20,7 +22,7 @@ class BinOpDef: return fn(self) -class BinOp: +class BinOp(ASTNode): # BinOp kinds Kind_Scalar = 0 Kind_Vector = 1 @@ -38,7 +40,7 @@ class BinOp: return op.inline_rec() def __init__(self, sim, lhs, rhs, op, mem=False): - self.sim = sim + super().__init__(sim) self.bin_op_id = BinOp.new_id() self.lhs = as_lit_ast(sim, lhs) self.rhs = as_lit_ast(sim, rhs) @@ -56,51 +58,6 @@ class BinOp: def __str__(self): return f"BinOp<a: {self.lhs.id()}, b: {self.rhs.id()}, op: {self.op}>" - def __add__(self, other): - return BinOp(self.sim, self, other, '+') - - def __radd__(self, other): - return BinOp(self.sim, other, self, '+') - - def __sub__(self, other): - return BinOp(self.sim, self, other, '-') - - def __mul__(self, other): - return BinOp(self.sim, self, other, '*') - - def __rmul__(self, other): - return BinOp(self.sim, other, self, '*') - - def __truediv__(self, other): - return BinOp(self.sim, self, other, '/') - - def __rtruediv__(self, other): - return BinOp(self.sim, other, self, '/') - - def __lt__(self, other): - return BinOp(self.sim, self, other, '<') - - def __le__(self, other): - return BinOp(self.sim, self, other, '<=') - - def __gt__(self, other): - return BinOp(self.sim, self, other, '>') - - def __ge__(self, other): - return BinOp(self.sim, self, other, '>=') - - def and_op(self, other): - return BinOp(self.sim, self, other, '&&') - - def cmp(lhs, rhs): - return BinOp(lhs.sim, lhs, rhs, '==') - - def neq(lhs, rhs): - return BinOp(lhs.sim, lhs, rhs, '!=') - - def inv(self): - return BinOp(self.sim, 1.0, self, '/') - def match(self, bin_op): return self.lhs == bin_op.lhs and \ self.rhs == bin_op.rhs and \ @@ -230,3 +187,103 @@ class BinOp: self.bin_op_vector_index_mapping = {i: e.transform(fn) for i, e in self.bin_op_vector_index_mapping.items()} return fn(self) + def __add__(self, other): + return BinOp(self.sim, self, other, '+') + + def __radd__(self, other): + return BinOp(self.sim, other, self, '+') + + def __sub__(self, other): + return BinOp(self.sim, self, other, '-') + + def __mul__(self, other): + return BinOp(self.sim, self, other, '*') + + def __rmul__(self, other): + return BinOp(self.sim, other, self, '*') + + def __truediv__(self, other): + return BinOp(self.sim, self, other, '/') + + def __rtruediv__(self, other): + return BinOp(self.sim, other, self, '/') + + def __lt__(self, other): + return BinOp(self.sim, self, other, '<') + + def __le__(self, other): + return BinOp(self.sim, self, other, '<=') + + def __gt__(self, other): + return BinOp(self.sim, self, other, '>') + + def __ge__(self, other): + return BinOp(self.sim, self, other, '>=') + + def and_op(self, other): + return BinOp(self.sim, self, other, '&&') + + def cmp(lhs, rhs): + return BinOp(lhs.sim, lhs, rhs, '==') + + def neq(lhs, rhs): + return BinOp(lhs.sim, lhs, rhs, '!=') + + def inv(self): + return BinOp(self.sim, 1.0, self, '/') + + def __mod__(self, other): + return BinOp(self.sim, self, other, '%') + + +class ASTTerm(ASTNode): + def __init__(self, sim): + super().__init__(sim) + + def __add__(self, other): + return BinOp(self.sim, self, other, '+') + + def __radd__(self, other): + return BinOp(self.sim, other, self, '+') + + def __sub__(self, other): + return BinOp(self.sim, self, other, '-') + + def __mul__(self, other): + return BinOp(self.sim, self, other, '*') + + def __rmul__(self, other): + return BinOp(self.sim, other, self, '*') + + def __truediv__(self, other): + return BinOp(self.sim, self, other, '/') + + def __rtruediv__(self, other): + return BinOp(self.sim, other, self, '/') + + def __lt__(self, other): + return BinOp(self.sim, self, other, '<') + + def __le__(self, other): + return BinOp(self.sim, self, other, '<=') + + def __gt__(self, other): + return BinOp(self.sim, self, other, '>') + + def __ge__(self, other): + return BinOp(self.sim, self, other, '>=') + + def and_op(self, other): + return BinOp(self.sim, self, other, '&&') + + def cmp(lhs, rhs): + return BinOp(lhs.sim, lhs, rhs, '==') + + def neq(lhs, rhs): + return BinOp(lhs.sim, lhs, rhs, '!=') + + def inv(self): + return BinOp(self.sim, 1.0, self, '/') + + def __mod__(self, other): + return BinOp(self.sim, self, other, '%') diff --git a/ast/block.py b/ast/block.py index 4ed52ee9c8c38fac3f188683476dfb8fd75a54ad..45fbf74a8e11a4cc44c190e8d4f5be3c568f356a 100644 --- a/ast/block.py +++ b/ast/block.py @@ -1,9 +1,10 @@ +from ast.ast_node import ASTNode from ast.visitor import Visitor -class Block: +class Block(ASTNode): def __init__(self, sim, stmts): - self.sim = sim + super().__init__(sim) self.level = 0 self.expressions = [] diff --git a/ast/branches.py b/ast/branches.py index a3ff5f1edf5f079b102c647e6e254567a98fe69f..ab31b0b17dbe0ab049c79d63d26026d0fa6f32ff 100644 --- a/ast/branches.py +++ b/ast/branches.py @@ -1,8 +1,9 @@ +from ast.ast_node import ASTNode from ast.block import Block from ast.lit import as_lit_ast -class Branch: +class Branch(ASTNode): def __init__(self, sim, cond, one_way=False, blk_if=None, blk_else=None): self.sim = sim self.parent_block = None diff --git a/ast/cast.py b/ast/cast.py index 8036676441d3e652af15a3e6caf7b2ddb88b520e..5bf1ba93212eaa50c8ca0450c426dd2fed6c9f61 100644 --- a/ast/cast.py +++ b/ast/cast.py @@ -1,9 +1,10 @@ +from ast.ast_node import ASTNode from ast.data_types import Type_Int, Type_Float -class Cast: +class Cast(ASTNode): def __init__(self, sim, expr, cast_type): - self.sim = sim + super().__init__(sim) self.expr = expr self.cast_type = cast_type diff --git a/ast/lit.py b/ast/lit.py index b04a969b020905af961b12ed6255f5494c24f3f4..8ad0d8392c4063c8c366048c9ab14fe9925ebcf4 100644 --- a/ast/lit.py +++ b/ast/lit.py @@ -1,21 +1,18 @@ -from ast.data_types import Type_Invalid, Type_Int, Type_Float, Type_Bool -from ast.data_types import Type_Vector +from ast.ast_node import ASTNode +from ast.data_types import Type_Invalid, Type_Int, Type_Float, Type_Bool, Type_Vector def is_literal(a): - return (isinstance(a, int) or - isinstance(a, float) or - isinstance(a, bool) or - isinstance(a, list)) + return isinstance(a, (int, float, bool, list)) def as_lit_ast(sim, a): return Lit(sim, a) if is_literal(a) else a -class Lit: +class Lit(ASTNode): def __init__(self, sim, value): - self.sim = sim + super().__init__(sim) self.value = value self.lit_type = Type_Invalid @@ -47,15 +44,3 @@ class Lit: def type(self): return self.lit_type - - def is_mutable(self): - return False - - def scope(self): - return self.sim.global_scope - - def children(self): - return [] - - def transform(self, fn): - return fn(self) diff --git a/ast/loops.py b/ast/loops.py index d99c15595e223cd34dc79593d8b4b52780cd0516..902cec345ad65dc2968c8b9d7861e90f46a5c1cb 100644 --- a/ast/loops.py +++ b/ast/loops.py @@ -1,11 +1,12 @@ +from ast.ast_node import ASTNode +from ast.bin_op import BinOp, ASTTerm from ast.block import Block from ast.branches import Filter from ast.data_types import Type_Int -from ast.expr import BinOp from ast.lit import as_lit_ast -class Iter(): +class Iter(ASTTerm): last_iter = 0 def new_id(): @@ -13,7 +14,7 @@ class Iter(): return Iter.last_iter - 1 def __init__(self, sim, loop): - self.sim = sim + super().__init__(sim) self.loop = loop self.iter_id = Iter.new_id() @@ -23,26 +24,9 @@ class Iter(): def type(self): return Type_Int - def is_mutable(self): - return False - def scope(self): return self.loop.block - def __add__(self, other): - return BinOp(self.sim, self, other, '+') - - def __sub__(self, other): - return BinOp(self.sim, self, other, '-') - - def __mul__(self, other): - from ast.expr import BinOp - return BinOp(self.sim, self, other, '*') - - def __rmul__(self, other): - from ast.expr import BinOp - return BinOp(self.sim, other, self, '*') - def __eq__(self, other): if isinstance(other, Iter): return self.iter_id == other.iter_id @@ -52,23 +36,13 @@ class Iter(): def __req__(self, other): return self.__cmp__(other) - def __mod__(self, other): - from ast.expr import BinOp - return BinOp(self.sim, self, other, '%') - def __str__(self): return f"Iter<{self.iter_id}>" - def children(self): - return [] - - def transform(self, fn): - return fn(self) - -class For(): +class For(ASTNode): def __init__(self, sim, range_min, range_max, block=None): - self.sim = sim + super().__init__(sim) self.iterator = Iter(sim, self) self.min = as_lit_ast(sim, range_min) self.max = as_lit_ast(sim, range_max) @@ -108,10 +82,9 @@ class ParticleFor(For): return f"ParticleFor<>" -class While(): +class While(ASTNode): def __init__(self, sim, cond, block=None): - from ast.expr import BinOp - self.sim = sim + super().__init__(sim) self.parent_block = None self.cond = BinOp.inline(cond) self.block = Block(sim, []) if block is None else block diff --git a/ast/math.py b/ast/math.py index cc38fb84f6bfc18358259194cefeabdc1d94e65a..c872b5b1810183db3e6156a207d6538bdbb4aecc 100644 --- a/ast/math.py +++ b/ast/math.py @@ -1,9 +1,10 @@ +from ast.ast_node import ASTNode from ast.data_types import Type_Int, Type_Float -class Sqrt: +class Sqrt(ASTNode): def __init__(self, sim, expr, cast_type): - self.sim = sim + super().__init__(sim) self.expr = expr def __str__(self): diff --git a/ast/memory.py b/ast/memory.py index bb26d4c37146a25605f3a8f6e3eaa165a926edab..ef8abfed3c2f7a736861d04c4e188e457e56af16 100644 --- a/ast/memory.py +++ b/ast/memory.py @@ -1,12 +1,13 @@ -from ast.expr import BinOp +from ast.ast_node import ASTNode +from ast.bin_op import BinOp from ast.sizeof import Sizeof from functools import reduce import operator -class Malloc: +class Malloc(ASTNode): def __init__(self, sim, array, sizes, decl=False): - self.sim = sim + super().__init__(sim) self.parent_block = None self.array = array self.decl = decl @@ -23,9 +24,9 @@ class Malloc: return fn(self) -class Realloc: +class Realloc(ASTNode): def __init__(self, sim, array, size): - self.sim = sim + super().__init__(sim) self.parent_block = None self.array = array self.prim_size = Sizeof(sim, array.type()) diff --git a/ast/operators.py b/ast/operators.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ast/properties.py b/ast/properties.py index bbb6465a5d550ca66d60246ba69088186d8ac631..44f141cca71eb234fe0c40ffb189faa45925d9e0 100644 --- a/ast/properties.py +++ b/ast/properties.py @@ -1,3 +1,4 @@ +from ast.ast_node import ASTNode from ast.layouts import Layout_AoS @@ -38,16 +39,15 @@ class Properties: return None -class Property: +class Property(ASTNode): def __init__(self, sim, name, dtype, default, volatile, layout=Layout_AoS): - self.sim = sim + super().__init__(sim) self.prop_name = name self.prop_type = dtype self.prop_layout = layout self.default_value = default self.volatile = volatile self.mutable = True - self.flattened = False def __str__(self): return f"Property<{self.prop_name}>" @@ -70,12 +70,6 @@ class Property: def scope(self): return self.sim.global_scope - def __getitem__(self, expr_ast): - from ast.expr import BinOp - return BinOp(self.sim, self, expr_ast, '[]', True) - - def children(self): - return [] - - def transform(self, fn): - return fn(self) + def __getitem__(self, expr): + from ast.bin_op import BinOp + return BinOp(self.sim, self, expr, '[]', True) diff --git a/ast/select.py b/ast/select.py index bdb154660cab9456f9573ed6bc977b19dadef41b..702bb8fb4a16e54fed93e717728e3439f35bdd21 100644 --- a/ast/select.py +++ b/ast/select.py @@ -1,10 +1,11 @@ -from ast.expr import BinOp +from ast.ast_node import ASTNode +from ast.bin_op import BinOp from ast.lit import as_lit_ast -class Select: +class Select(ASTNode): def __init__(self, sim, cond, expr_if, expr_else): - self.sim = sim + super().__init__(sim) self.cond = as_lit_ast(sim, cond) self.expr_if = BinOp.inline(as_lit_ast(sim, expr_if)) self.expr_else = BinOp.inline(as_lit_ast(sim, expr_else)) diff --git a/ast/sizeof.py b/ast/sizeof.py index e18f89ef312c99d76148881654e8fe65b2aeb231..7a551066007ab0add9324b24d750d8b9c275440e 100644 --- a/ast/sizeof.py +++ b/ast/sizeof.py @@ -1,26 +1,11 @@ +from ast.bin_op import ASTTerm from ast.data_types import Type_Int -from ast.expr import BinOp -class Sizeof: +class Sizeof(ASTTerm): def __init__(self, sim, data_type): - self.sim = sim + super().__init__(sim) self.data_type = data_type - def __mul__(self, other): - return BinOp(self.sim, self, other, '*') - def type(self): return Type_Int - - def is_mutable(self): - return False - - def scope(self): - return self.sim.global_scope - - def children(self): - return [] - - def transform(self, fn): - return fn(self) diff --git a/ast/transform.py b/ast/transform.py index 612b0cfe44508f4a60c74ce2660fc94ce248c5aa..428d372831d164864e81ef7de940c63a729b171b 100644 --- a/ast/transform.py +++ b/ast/transform.py @@ -1,6 +1,6 @@ from ast.arrays import ArrayAccess +from ast.bin_op import BinOp from ast.data_types import Type_Int, Type_Vector -from ast.expr import BinOp from ast.layouts import Layout_AoS, Layout_SoA from ast.lit import Lit from ast.loops import Iter @@ -33,9 +33,6 @@ class Transform: ast.map_vector_index(i, flat_index) - if isinstance(ast, Property): - ast.flattened = True - return ast def simplify(ast): @@ -130,11 +127,3 @@ class Transform: Transform.reuse_expressions[acc_id].append(ast) return ast - - def move_loop_invariant_expressions(ast): - if isinstance(ast, BinOp): - scope = ast.scope() - if scope.level > 0: - scope.add_expression(ast) - - return ast diff --git a/ast/utils.py b/ast/utils.py index 922dcbe8507b730090c5bb1f49fb51d434e6f86a..2bc3e43710663a915942cd43f99809418bb6a6f3 100644 --- a/ast/utils.py +++ b/ast/utils.py @@ -1,13 +1,10 @@ -class Print: +from ast.ast_node import ASTNode + + +class Print(ASTNode): def __init__(self, sim, string): - self.sim = sim + super().__init__(sim) self.string = string def __str__(self): return f"Print<{self.string}>" - - def children(self): - return [] - - def transform(self, fn): - return fn(self) diff --git a/ast/variables.py b/ast/variables.py index 5f1271949fe2c82e361ae4fc4c8a2668d79683d7..34f9a9d9e818619f2c157a0d65d82ce9a3a73ec8 100644 --- a/ast/variables.py +++ b/ast/variables.py @@ -1,5 +1,6 @@ +from ast.ast_node import ASTNode from ast.assign import Assign -from ast.expr import BinOp +from ast.bin_op import ASTTerm class Variables: @@ -23,9 +24,9 @@ class Variables: return None -class Var: +class Var(ASTTerm): def __init__(self, sim, var_name, var_type, init_value=0): - self.sim = sim + super().__init__(sim) self.var_name = var_name self.var_type = var_type self.var_init_value = init_value @@ -35,42 +36,6 @@ class Var: def __str__(self): return f"Var<name: {self.var_name}, type: {self.var_type}>" - def __add__(self, other): - return BinOp(self.sim, self, other, '+') - - def __radd__(self, other): - return BinOp(self.sim, other, self, '+') - - def __sub__(self, other): - return BinOp(self.sim, self, other, '-') - - def __mul__(self, other): - return BinOp(self.sim, self, other, '*') - - def __rmul__(self, other): - return BinOp(self.sim, other, self, '*') - - def __truediv__(self, other): - return BinOp(self.sim, self, other, '/') - - def __rtruediv__(self, other): - return BinOp(self.sim, other, self, '/') - - def __lt__(self, other): - return BinOp(self.sim, self, other, '<') - - def __le__(self, other): - return BinOp(self.sim, self, other, '<=') - - def __gt__(self, other): - return BinOp(self.sim, self, other, '>') - - def __ge__(self, other): - return BinOp(self.sim, self, other, '>=') - - def inv(self): - return BinOp(self.sim, 1.0, self, '/') - def set(self, other): return self.sim.add_statement(Assign(self.sim, self, other)) @@ -98,24 +63,9 @@ class Var: def is_mutable(self): return self.mutable - def scope(self): - return self.sim.global_scope - - def children(self): - return [] - - def transform(self, fn): - return fn(self) - -class VarDecl: +class VarDecl(ASTNode): def __init__(self, sim, var): - self.sim = sim + super().__init__(sim) self.var = var self.sim.add_statement(self) - - def children(self): - return [] - - def transform(self, fn): - return fn(self) diff --git a/code_gen/cgen.py b/code_gen/cgen.py index 96e0d9d132fbab3539d133f0250157b0f753b6d2..ea87d2a6cffda99eb69d73386ad1eadc7ee89a83 100644 --- a/code_gen/cgen.py +++ b/code_gen/cgen.py @@ -3,7 +3,7 @@ from ast.arrays import ArrayAccess, ArrayDecl from ast.block import Block from ast.branches import Branch from ast.cast import Cast -from ast.expr import BinOp, BinOpDef +from ast.bin_op import BinOp, BinOpDef from ast.data_types import Type_Int, Type_Float, Type_Vector from ast.lit import Lit from ast.loops import For, Iter, ParticleFor, While diff --git a/graph/graphviz.py b/graph/graphviz.py index 5a17e9accf8abd1a43133643591b1ee85e2d40ba..21a7e8195ecb3e03fa92860d9f49ecbeee01528c 100644 --- a/graph/graphviz.py +++ b/graph/graphviz.py @@ -1,5 +1,5 @@ from ast.arrays import Array -from ast.expr import BinOp, BinOpDef +from ast.bin_op import BinOp, BinOpDef from ast.lit import Lit from ast.loops import Iter from ast.properties import Property @@ -7,6 +7,7 @@ from ast.variables import Var from ast.visitor import Visitor from graphviz import Digraph + class ASTGraph: def __init__(self, ast_node, filename, ref="AST", max_depth=0): self.graph = Digraph(ref, filename=filename, node_attr={'color': 'lightblue2', 'style': 'filled'}) diff --git a/sim/cell_lists.py b/sim/cell_lists.py index fe69d06e487072309a791513db1ef58082f4024e..4fe79862d356b4f4c29c97c462a82af20b72017d 100644 --- a/sim/cell_lists.py +++ b/sim/cell_lists.py @@ -1,7 +1,7 @@ +from ast.bin_op import BinOp from ast.branches import Branch, Filter from ast.cast import Cast from ast.data_types import Type_Int -from ast.expr import BinOp from ast.loops import For, ParticleFor from ast.utils import Print from functools import reduce diff --git a/sim/properties.py b/sim/properties.py index f07f9d58029d79e9c90f8b6dbaa92509bb6e7bb6..81901fb64217c142c7eb456be2c6bcfd54208e20 100644 --- a/sim/properties.py +++ b/sim/properties.py @@ -18,10 +18,7 @@ class PropertiesAlloc: if p.type() == Type_Float: sizes = [capacity] elif p.type() == Type_Vector: - if p.flattened: - sizes = [capacity * self.sim.dimensions] - else: - sizes = [capacity, self.sim.dimensions] + sizes = [capacity * self.sim.dimensions] else: raise Exception("Invalid property type!") diff --git a/sim/resize.py b/sim/resize.py index eada64289851d2dc4f3c2033342eb06949ec2878..48a1f01ddc2ad5042263929148e3fcea9a402e64 100644 --- a/sim/resize.py +++ b/sim/resize.py @@ -26,11 +26,9 @@ class Resize: if properties.is_capacity(self.capacity_var): capacity = sum(self.sim.properties.capacities) for p in properties.all(): - sizes = capacity if p.type() == Type_Vector: - if p.flattened: - sizes = capacity * self.sim.dimensions - else: - sizes = capacity * self.sim.dimensions + sizes = capacity * self.sim.dimensions + else: + sizes = capacity Realloc(self.sim, p, sizes) diff --git a/sim/timestep.py b/sim/timestep.py index 86e5c95a6816ccecd91bfee0b2adbf6021b74eb3..0d541330bce00d894ae82442f6c19f4c3649ae77 100644 --- a/sim/timestep.py +++ b/sim/timestep.py @@ -1,5 +1,5 @@ +from ast.bin_op import BinOp from ast.block import Block -from ast.expr import BinOp from ast.branches import Branch from ast.loops import For