From ecac066bdc53ea194cffc26e5eca2a86d53a861a Mon Sep 17 00:00:00 2001 From: Rafael Ravedutti Lucio Machado <rafael.r.ravedutti@fau.de> Date: Fri, 8 Jan 2021 20:43:56 +0100 Subject: [PATCH] Refactor code, add ASTNode and ASTTerm classes Signed-off-by: Rafael Ravedutti Lucio Machado <rafael.r.ravedutti@fau.de> --- ast/arrays.py | 39 ++-------- ast/assign.py | 7 +- ast/ast_node.py | 24 ++++++ ast/{expr.py => bin_op.py} | 153 +++++++++++++++++++++++++------------ ast/block.py | 5 +- ast/branches.py | 3 +- ast/cast.py | 5 +- ast/lit.py | 25 ++---- ast/loops.py | 43 ++--------- ast/math.py | 5 +- ast/memory.py | 11 +-- ast/operators.py | 0 ast/properties.py | 18 ++--- ast/select.py | 7 +- ast/sizeof.py | 21 +---- ast/transform.py | 13 +--- ast/utils.py | 13 ++-- ast/variables.py | 62 ++------------- code_gen/cgen.py | 2 +- graph/graphviz.py | 3 +- sim/cell_lists.py | 2 +- sim/properties.py | 5 +- sim/resize.py | 8 +- sim/timestep.py | 2 +- 24 files changed, 205 insertions(+), 271 deletions(-) create mode 100644 ast/ast_node.py rename ast/{expr.py => bin_op.py} (80%) create mode 100644 ast/operators.py diff --git a/ast/arrays.py b/ast/arrays.py index 1982693..e7fd2c7 100644 --- a/ast/arrays.py +++ b/ast/arrays.py @@ -1,6 +1,7 @@ from ast.assign import Assign +from ast.ast_node import ASTNode +from ast.bin_op import BinOp, ASTTerm from ast.data_types import Type_Array -from ast.expr import BinOp from ast.layouts import Layout_AoS, Layout_SoA from ast.lit import as_lit_ast from ast.memory import Realloc @@ -35,9 +36,9 @@ class Arrays: return None -class Array: +class Array(ASTNode): def __init__(self, sim, a_name, a_sizes, a_type, a_layout=Layout_AoS): - self.sim = sim + super().__init__(sim) self.arr_name = a_name self.arr_sizes = \ [as_lit_ast(sim, a_sizes)] if not isinstance(a_sizes, list) \ @@ -65,9 +66,6 @@ class Array: def layout(self): return self.arr_layout - def scope(self): - return self.sim.global_scope - def ndims(self): return self.arr_ndims @@ -77,12 +75,6 @@ class Array: def alloc_size(self): return reduce((lambda x, y: x * y), [s for s in self.arr_sizes]) - def children(self): - return [] - - def transform(self, fn): - return fn(self) - class ArrayStatic(Array): def __init__(self, sim, a_name, a_sizes, a_type, a_layout=Layout_AoS): @@ -110,7 +102,7 @@ class ArrayND(Array): return Realloc(self.sim, self, self.alloc_size()) -class ArrayAccess: +class ArrayAccess(ASTTerm): last_acc = 0 def new_id(): @@ -118,7 +110,7 @@ class ArrayAccess: return ArrayAccess.last_acc - 1 def __init__(self, sim, array, index): - self.sim = sim + super().__init__(sim) self.acc_id = ArrayAccess.new_id() self.array = array self.indexes = [as_lit_ast(sim, index)] @@ -130,15 +122,6 @@ class ArrayAccess: def __str__(self): return f"ArrayAccess<array: {self.array}, indexes: {self.indexes}>" - def __add__(self, other): - return BinOp(self.sim, self, other, '+') - - def __mul__(self, other): - return BinOp(self.sim, self, other, '*') - - def __rmul__(self, other): - return BinOp(self.sim, other, self, '*') - def __getitem__(self, index): assert self.index is None, "Number of indexes higher than array dimension!" self.indexes.append(as_lit_ast(self.sim, index)) @@ -206,14 +189,8 @@ class ArrayAccess: return fn(self) -class ArrayDecl: +class ArrayDecl(ASTNode): def __init__(self, sim, array): - self.sim = sim + super().__init__(sim) self.array = array self.sim.add_statement(self) - - def children(self): - return [] - - def transform(self, fn): - return fn(self) diff --git a/ast/assign.py b/ast/assign.py index 59f11e4..8ce1d7b 100644 --- a/ast/assign.py +++ b/ast/assign.py @@ -1,11 +1,12 @@ +from ast.ast_node import ASTNode from ast.data_types import Type_Vector from ast.lit import as_lit_ast from functools import reduce -class Assign: +class Assign(ASTNode): def __init__(self, sim, dest, src): - self.sim = sim + super().__init__(sim) self.parent_block = None self.type = dest.type() src = as_lit_ast(sim, src) @@ -14,7 +15,7 @@ class Assign: self.assignments = [] for i in range(0, sim.dimensions): - from ast.expr import BinOp + from ast.bin_op import BinOp dsrc = (src if (not isinstance(src, BinOp) or src.type() != Type_Vector) else src[i]) diff --git a/ast/ast_node.py b/ast/ast_node.py new file mode 100644 index 0000000..b0641f4 --- /dev/null +++ b/ast/ast_node.py @@ -0,0 +1,24 @@ +from ast.data_types import Type_Invalid + + +class ASTNode: + def __init__(self, sim): + self.sim = sim + + def __str__(self): + return "ASTNode<>" + + def type(self): + return Type_Invalid + + def is_mutable(self): + return False + + def scope(self): + return self.sim.global_scope + + def children(self): + return [] + + def transform(self, fn): + return fn(self) diff --git a/ast/expr.py b/ast/bin_op.py similarity index 80% rename from ast/expr.py rename to ast/bin_op.py index 93db1df..4cc8dd5 100644 --- a/ast/expr.py +++ b/ast/bin_op.py @@ -1,11 +1,13 @@ +from ast.ast_node import ASTNode from ast.assign import Assign from ast.data_types import Type_Float, Type_Bool, Type_Vector from ast.lit import as_lit_ast from ast.properties import Property -class BinOpDef: +class BinOpDef(ASTNode): def __init__(self, bin_op): + super().__init__(bin_op.sim) self.bin_op = bin_op self.bin_op.sim.add_statement(self) @@ -20,7 +22,7 @@ class BinOpDef: return fn(self) -class BinOp: +class BinOp(ASTNode): # BinOp kinds Kind_Scalar = 0 Kind_Vector = 1 @@ -38,7 +40,7 @@ class BinOp: return op.inline_rec() def __init__(self, sim, lhs, rhs, op, mem=False): - self.sim = sim + super().__init__(sim) self.bin_op_id = BinOp.new_id() self.lhs = as_lit_ast(sim, lhs) self.rhs = as_lit_ast(sim, rhs) @@ -56,51 +58,6 @@ class BinOp: def __str__(self): return f"BinOp<a: {self.lhs.id()}, b: {self.rhs.id()}, op: {self.op}>" - def __add__(self, other): - return BinOp(self.sim, self, other, '+') - - def __radd__(self, other): - return BinOp(self.sim, other, self, '+') - - def __sub__(self, other): - return BinOp(self.sim, self, other, '-') - - def __mul__(self, other): - return BinOp(self.sim, self, other, '*') - - def __rmul__(self, other): - return BinOp(self.sim, other, self, '*') - - def __truediv__(self, other): - return BinOp(self.sim, self, other, '/') - - def __rtruediv__(self, other): - return BinOp(self.sim, other, self, '/') - - def __lt__(self, other): - return BinOp(self.sim, self, other, '<') - - def __le__(self, other): - return BinOp(self.sim, self, other, '<=') - - def __gt__(self, other): - return BinOp(self.sim, self, other, '>') - - def __ge__(self, other): - return BinOp(self.sim, self, other, '>=') - - def and_op(self, other): - return BinOp(self.sim, self, other, '&&') - - def cmp(lhs, rhs): - return BinOp(lhs.sim, lhs, rhs, '==') - - def neq(lhs, rhs): - return BinOp(lhs.sim, lhs, rhs, '!=') - - def inv(self): - return BinOp(self.sim, 1.0, self, '/') - def match(self, bin_op): return self.lhs == bin_op.lhs and \ self.rhs == bin_op.rhs and \ @@ -230,3 +187,103 @@ class BinOp: self.bin_op_vector_index_mapping = {i: e.transform(fn) for i, e in self.bin_op_vector_index_mapping.items()} return fn(self) + def __add__(self, other): + return BinOp(self.sim, self, other, '+') + + def __radd__(self, other): + return BinOp(self.sim, other, self, '+') + + def __sub__(self, other): + return BinOp(self.sim, self, other, '-') + + def __mul__(self, other): + return BinOp(self.sim, self, other, '*') + + def __rmul__(self, other): + return BinOp(self.sim, other, self, '*') + + def __truediv__(self, other): + return BinOp(self.sim, self, other, '/') + + def __rtruediv__(self, other): + return BinOp(self.sim, other, self, '/') + + def __lt__(self, other): + return BinOp(self.sim, self, other, '<') + + def __le__(self, other): + return BinOp(self.sim, self, other, '<=') + + def __gt__(self, other): + return BinOp(self.sim, self, other, '>') + + def __ge__(self, other): + return BinOp(self.sim, self, other, '>=') + + def and_op(self, other): + return BinOp(self.sim, self, other, '&&') + + def cmp(lhs, rhs): + return BinOp(lhs.sim, lhs, rhs, '==') + + def neq(lhs, rhs): + return BinOp(lhs.sim, lhs, rhs, '!=') + + def inv(self): + return BinOp(self.sim, 1.0, self, '/') + + def __mod__(self, other): + return BinOp(self.sim, self, other, '%') + + +class ASTTerm(ASTNode): + def __init__(self, sim): + super().__init__(sim) + + def __add__(self, other): + return BinOp(self.sim, self, other, '+') + + def __radd__(self, other): + return BinOp(self.sim, other, self, '+') + + def __sub__(self, other): + return BinOp(self.sim, self, other, '-') + + def __mul__(self, other): + return BinOp(self.sim, self, other, '*') + + def __rmul__(self, other): + return BinOp(self.sim, other, self, '*') + + def __truediv__(self, other): + return BinOp(self.sim, self, other, '/') + + def __rtruediv__(self, other): + return BinOp(self.sim, other, self, '/') + + def __lt__(self, other): + return BinOp(self.sim, self, other, '<') + + def __le__(self, other): + return BinOp(self.sim, self, other, '<=') + + def __gt__(self, other): + return BinOp(self.sim, self, other, '>') + + def __ge__(self, other): + return BinOp(self.sim, self, other, '>=') + + def and_op(self, other): + return BinOp(self.sim, self, other, '&&') + + def cmp(lhs, rhs): + return BinOp(lhs.sim, lhs, rhs, '==') + + def neq(lhs, rhs): + return BinOp(lhs.sim, lhs, rhs, '!=') + + def inv(self): + return BinOp(self.sim, 1.0, self, '/') + + def __mod__(self, other): + return BinOp(self.sim, self, other, '%') diff --git a/ast/block.py b/ast/block.py index 4ed52ee..45fbf74 100644 --- a/ast/block.py +++ b/ast/block.py @@ -1,9 +1,10 @@ +from ast.ast_node import ASTNode from ast.visitor import Visitor -class Block: +class Block(ASTNode): def __init__(self, sim, stmts): - self.sim = sim + super().__init__(sim) self.level = 0 self.expressions = [] diff --git a/ast/branches.py b/ast/branches.py index a3ff5f1..ab31b0b 100644 --- a/ast/branches.py +++ b/ast/branches.py @@ -1,8 +1,9 @@ +from ast.ast_node import ASTNode from ast.block import Block from ast.lit import as_lit_ast -class Branch: +class Branch(ASTNode): def __init__(self, sim, cond, one_way=False, blk_if=None, blk_else=None): self.sim = sim self.parent_block = None diff --git a/ast/cast.py b/ast/cast.py index 8036676..5bf1ba9 100644 --- a/ast/cast.py +++ b/ast/cast.py @@ -1,9 +1,10 @@ +from ast.ast_node import ASTNode from ast.data_types import Type_Int, Type_Float -class Cast: +class Cast(ASTNode): def __init__(self, sim, expr, cast_type): - self.sim = sim + super().__init__(sim) self.expr = expr self.cast_type = cast_type diff --git a/ast/lit.py b/ast/lit.py index b04a969..8ad0d83 100644 --- a/ast/lit.py +++ b/ast/lit.py @@ -1,21 +1,18 @@ -from ast.data_types import Type_Invalid, Type_Int, Type_Float, Type_Bool -from ast.data_types import Type_Vector +from ast.ast_node import ASTNode +from ast.data_types import Type_Invalid, Type_Int, Type_Float, Type_Bool, Type_Vector def is_literal(a): - return (isinstance(a, int) or - isinstance(a, float) or - isinstance(a, bool) or - isinstance(a, list)) + return isinstance(a, (int, float, bool, list)) def as_lit_ast(sim, a): return Lit(sim, a) if is_literal(a) else a -class Lit: +class Lit(ASTNode): def __init__(self, sim, value): - self.sim = sim + super().__init__(sim) self.value = value self.lit_type = Type_Invalid @@ -47,15 +44,3 @@ class Lit: def type(self): return self.lit_type - - def is_mutable(self): - return False - - def scope(self): - return self.sim.global_scope - - def children(self): - return [] - - def transform(self, fn): - return fn(self) diff --git a/ast/loops.py b/ast/loops.py index d99c155..902cec3 100644 --- a/ast/loops.py +++ b/ast/loops.py @@ -1,11 +1,12 @@ +from ast.ast_node import ASTNode +from ast.bin_op import BinOp, ASTTerm from ast.block import Block from ast.branches import Filter from ast.data_types import Type_Int -from ast.expr import BinOp from ast.lit import as_lit_ast -class Iter(): +class Iter(ASTTerm): last_iter = 0 def new_id(): @@ -13,7 +14,7 @@ class Iter(): return Iter.last_iter - 1 def __init__(self, sim, loop): - self.sim = sim + super().__init__(sim) self.loop = loop self.iter_id = Iter.new_id() @@ -23,26 +24,9 @@ class Iter(): def type(self): return Type_Int - def is_mutable(self): - return False - def scope(self): return self.loop.block - def __add__(self, other): - return BinOp(self.sim, self, other, '+') - - def __sub__(self, other): - return BinOp(self.sim, self, other, '-') - - def __mul__(self, other): - from ast.expr import BinOp - return BinOp(self.sim, self, other, '*') - - def __rmul__(self, other): - from ast.expr import BinOp - return BinOp(self.sim, other, self, '*') - def __eq__(self, other): if isinstance(other, Iter): return self.iter_id == other.iter_id @@ -52,23 +36,13 @@ class Iter(): def __req__(self, other): return self.__cmp__(other) - def __mod__(self, other): - from ast.expr import BinOp - return BinOp(self.sim, self, other, '%') - def __str__(self): return f"Iter<{self.iter_id}>" - def children(self): - return [] - - def transform(self, fn): - return fn(self) - -class For(): +class For(ASTNode): def __init__(self, sim, range_min, range_max, block=None): - self.sim = sim + super().__init__(sim) self.iterator = Iter(sim, self) self.min = as_lit_ast(sim, range_min) self.max = as_lit_ast(sim, range_max) @@ -108,10 +82,9 @@ class ParticleFor(For): return f"ParticleFor<>" -class While(): +class While(ASTNode): def __init__(self, sim, cond, block=None): - from ast.expr import BinOp - self.sim = sim + super().__init__(sim) self.parent_block = None self.cond = BinOp.inline(cond) self.block = Block(sim, []) if block is None else block diff --git a/ast/math.py b/ast/math.py index cc38fb8..c872b5b 100644 --- a/ast/math.py +++ b/ast/math.py @@ -1,9 +1,10 @@ +from ast.ast_node import ASTNode from ast.data_types import Type_Int, Type_Float -class Sqrt: +class Sqrt(ASTNode): def __init__(self, sim, expr, cast_type): - self.sim = sim + super().__init__(sim) self.expr = expr def __str__(self): diff --git a/ast/memory.py b/ast/memory.py index bb26d4c..ef8abfe 100644 --- a/ast/memory.py +++ b/ast/memory.py @@ -1,12 +1,13 @@ -from ast.expr import BinOp +from ast.ast_node import ASTNode +from ast.bin_op import BinOp from ast.sizeof import Sizeof from functools import reduce import operator -class Malloc: +class Malloc(ASTNode): def __init__(self, sim, array, sizes, decl=False): - self.sim = sim + super().__init__(sim) self.parent_block = None self.array = array self.decl = decl @@ -23,9 +24,9 @@ class Malloc: return fn(self) -class Realloc: +class Realloc(ASTNode): def __init__(self, sim, array, size): - self.sim = sim + super().__init__(sim) self.parent_block = None self.array = array self.prim_size = Sizeof(sim, array.type()) diff --git a/ast/operators.py b/ast/operators.py new file mode 100644 index 0000000..e69de29 diff --git a/ast/properties.py b/ast/properties.py index bbb6465..44f141c 100644 --- a/ast/properties.py +++ b/ast/properties.py @@ -1,3 +1,4 @@ +from ast.ast_node import ASTNode from ast.layouts import Layout_AoS @@ -38,16 +39,15 @@ class Properties: return None -class Property: +class Property(ASTNode): def __init__(self, sim, name, dtype, default, volatile, layout=Layout_AoS): - self.sim = sim + super().__init__(sim) self.prop_name = name self.prop_type = dtype self.prop_layout = layout self.default_value = default self.volatile = volatile self.mutable = True - self.flattened = False def __str__(self): return f"Property<{self.prop_name}>" @@ -70,12 +70,6 @@ class Property: def scope(self): return self.sim.global_scope - def __getitem__(self, expr_ast): - from ast.expr import BinOp - return BinOp(self.sim, self, expr_ast, '[]', True) - - def children(self): - return [] - - def transform(self, fn): - return fn(self) + def __getitem__(self, expr): + from ast.bin_op import BinOp + return BinOp(self.sim, self, expr, '[]', True) diff --git a/ast/select.py b/ast/select.py index bdb1546..702bb8f 100644 --- a/ast/select.py +++ b/ast/select.py @@ -1,10 +1,11 @@ -from ast.expr import BinOp +from ast.ast_node import ASTNode +from ast.bin_op import BinOp from ast.lit import as_lit_ast -class Select: +class Select(ASTNode): def __init__(self, sim, cond, expr_if, expr_else): - self.sim = sim + super().__init__(sim) self.cond = as_lit_ast(sim, cond) self.expr_if = BinOp.inline(as_lit_ast(sim, expr_if)) self.expr_else = BinOp.inline(as_lit_ast(sim, expr_else)) diff --git a/ast/sizeof.py b/ast/sizeof.py index e18f89e..7a55106 100644 --- a/ast/sizeof.py +++ b/ast/sizeof.py @@ -1,26 +1,11 @@ +from ast.bin_op import ASTTerm from ast.data_types import Type_Int -from ast.expr import BinOp -class Sizeof: +class Sizeof(ASTTerm): def __init__(self, sim, data_type): - self.sim = sim + super().__init__(sim) self.data_type = data_type - def __mul__(self, other): - return BinOp(self.sim, self, other, '*') - def type(self): return Type_Int - - def is_mutable(self): - return False - - def scope(self): - return self.sim.global_scope - - def children(self): - return [] - - def transform(self, fn): - return fn(self) diff --git a/ast/transform.py b/ast/transform.py index 612b0cf..428d372 100644 --- a/ast/transform.py +++ b/ast/transform.py @@ -1,6 +1,6 @@ from ast.arrays import ArrayAccess +from ast.bin_op import BinOp from ast.data_types import Type_Int, Type_Vector -from ast.expr import BinOp from ast.layouts import Layout_AoS, Layout_SoA from ast.lit import Lit from ast.loops import Iter @@ -33,9 +33,6 @@ class Transform: ast.map_vector_index(i, flat_index) - if isinstance(ast, Property): - ast.flattened = True - return ast def simplify(ast): @@ -130,11 +127,3 @@ class Transform: Transform.reuse_expressions[acc_id].append(ast) return ast - - def move_loop_invariant_expressions(ast): - if isinstance(ast, BinOp): - scope = ast.scope() - if scope.level > 0: - scope.add_expression(ast) - - return ast diff --git a/ast/utils.py b/ast/utils.py index 922dcbe..2bc3e43 100644 --- a/ast/utils.py +++ b/ast/utils.py @@ -1,13 +1,10 @@ -class Print: +from ast.ast_node import ASTNode + + +class Print(ASTNode): def __init__(self, sim, string): - self.sim = sim + super().__init__(sim) self.string = string def __str__(self): return f"Print<{self.string}>" - - def children(self): - return [] - - def transform(self, fn): - return fn(self) diff --git a/ast/variables.py b/ast/variables.py index 5f12719..34f9a9d 100644 --- a/ast/variables.py +++ b/ast/variables.py @@ -1,5 +1,6 @@ +from ast.ast_node import ASTNode from ast.assign import Assign -from ast.expr import BinOp +from ast.bin_op import ASTTerm class Variables: @@ -23,9 +24,9 @@ class Variables: return None -class Var: +class Var(ASTTerm): def __init__(self, sim, var_name, var_type, init_value=0): - self.sim = sim + super().__init__(sim) self.var_name = var_name self.var_type = var_type self.var_init_value = init_value @@ -35,42 +36,6 @@ class Var: def __str__(self): return f"Var<name: {self.var_name}, type: {self.var_type}>" - def __add__(self, other): - return BinOp(self.sim, self, other, '+') - - def __radd__(self, other): - return BinOp(self.sim, other, self, '+') - - def __sub__(self, other): - return BinOp(self.sim, self, other, '-') - - def __mul__(self, other): - return BinOp(self.sim, self, other, '*') - - def __rmul__(self, other): - return BinOp(self.sim, other, self, '*') - - def __truediv__(self, other): - return BinOp(self.sim, self, other, '/') - - def __rtruediv__(self, other): - return BinOp(self.sim, other, self, '/') - - def __lt__(self, other): - return BinOp(self.sim, self, other, '<') - - def __le__(self, other): - return BinOp(self.sim, self, other, '<=') - - def __gt__(self, other): - return BinOp(self.sim, self, other, '>') - - def __ge__(self, other): - return BinOp(self.sim, self, other, '>=') - - def inv(self): - return BinOp(self.sim, 1.0, self, '/') - def set(self, other): return self.sim.add_statement(Assign(self.sim, self, other)) @@ -98,24 +63,9 @@ class Var: def is_mutable(self): return self.mutable - def scope(self): - return self.sim.global_scope - - def children(self): - return [] - - def transform(self, fn): - return fn(self) - -class VarDecl: +class VarDecl(ASTNode): def __init__(self, sim, var): - self.sim = sim + super().__init__(sim) self.var = var self.sim.add_statement(self) - - def children(self): - return [] - - def transform(self, fn): - return fn(self) diff --git a/code_gen/cgen.py b/code_gen/cgen.py index 96e0d9d..ea87d2a 100644 --- a/code_gen/cgen.py +++ b/code_gen/cgen.py @@ -3,7 +3,7 @@ from ast.arrays import ArrayAccess, ArrayDecl from ast.block import Block from ast.branches import Branch from ast.cast import Cast -from ast.expr import BinOp, BinOpDef +from ast.bin_op import BinOp, BinOpDef from ast.data_types import Type_Int, Type_Float, Type_Vector from ast.lit import Lit from ast.loops import For, Iter, ParticleFor, While diff --git a/graph/graphviz.py b/graph/graphviz.py index 5a17e9a..21a7e81 100644 --- a/graph/graphviz.py +++ b/graph/graphviz.py @@ -1,5 +1,5 @@ from ast.arrays import Array -from ast.expr import BinOp, BinOpDef +from ast.bin_op import BinOp, BinOpDef from ast.lit import Lit from ast.loops import Iter from ast.properties import Property @@ -7,6 +7,7 @@ from ast.variables import Var from ast.visitor import Visitor from graphviz import Digraph + class ASTGraph: def __init__(self, ast_node, filename, ref="AST", max_depth=0): self.graph = Digraph(ref, filename=filename, node_attr={'color': 'lightblue2', 'style': 'filled'}) diff --git a/sim/cell_lists.py b/sim/cell_lists.py index fe69d06..4fe7986 100644 --- a/sim/cell_lists.py +++ b/sim/cell_lists.py @@ -1,7 +1,7 @@ +from ast.bin_op import BinOp from ast.branches import Branch, Filter from ast.cast import Cast from ast.data_types import Type_Int -from ast.expr import BinOp from ast.loops import For, ParticleFor from ast.utils import Print from functools import reduce diff --git a/sim/properties.py b/sim/properties.py index f07f9d5..81901fb 100644 --- a/sim/properties.py +++ b/sim/properties.py @@ -18,10 +18,7 @@ class PropertiesAlloc: if p.type() == Type_Float: sizes = [capacity] elif p.type() == Type_Vector: - if p.flattened: - sizes = [capacity * self.sim.dimensions] - else: - sizes = [capacity, self.sim.dimensions] + sizes = [capacity * self.sim.dimensions] else: raise Exception("Invalid property type!") diff --git a/sim/resize.py b/sim/resize.py index eada642..48a1f01 100644 --- a/sim/resize.py +++ b/sim/resize.py @@ -26,11 +26,9 @@ class Resize: if properties.is_capacity(self.capacity_var): capacity = sum(self.sim.properties.capacities) for p in properties.all(): - sizes = capacity if p.type() == Type_Vector: - if p.flattened: - sizes = capacity * self.sim.dimensions - else: - sizes = capacity * self.sim.dimensions + sizes = capacity * self.sim.dimensions + else: + sizes = capacity Realloc(self.sim, p, sizes) diff --git a/sim/timestep.py b/sim/timestep.py index 86e5c95..0d54133 100644 --- a/sim/timestep.py +++ b/sim/timestep.py @@ -1,5 +1,5 @@ +from ast.bin_op import BinOp from ast.block import Block -from ast.expr import BinOp from ast.branches import Branch from ast.loops import For -- GitLab