From ecac066bdc53ea194cffc26e5eca2a86d53a861a Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti Lucio Machado <rafael.r.ravedutti@fau.de>
Date: Fri, 8 Jan 2021 20:43:56 +0100
Subject: [PATCH] Refactor code, add ASTNode and ASTTerm classes

Signed-off-by: Rafael Ravedutti Lucio Machado <rafael.r.ravedutti@fau.de>
---
 ast/arrays.py              |  39 ++--------
 ast/assign.py              |   7 +-
 ast/ast_node.py            |  24 ++++++
 ast/{expr.py => bin_op.py} | 153 +++++++++++++++++++++++++------------
 ast/block.py               |   5 +-
 ast/branches.py            |   3 +-
 ast/cast.py                |   5 +-
 ast/lit.py                 |  25 ++----
 ast/loops.py               |  43 ++---------
 ast/math.py                |   5 +-
 ast/memory.py              |  11 +--
 ast/operators.py           |   0
 ast/properties.py          |  18 ++---
 ast/select.py              |   7 +-
 ast/sizeof.py              |  21 +----
 ast/transform.py           |  13 +---
 ast/utils.py               |  13 ++--
 ast/variables.py           |  62 ++-------------
 code_gen/cgen.py           |   2 +-
 graph/graphviz.py          |   3 +-
 sim/cell_lists.py          |   2 +-
 sim/properties.py          |   5 +-
 sim/resize.py              |   8 +-
 sim/timestep.py            |   2 +-
 24 files changed, 205 insertions(+), 271 deletions(-)
 create mode 100644 ast/ast_node.py
 rename ast/{expr.py => bin_op.py} (80%)
 create mode 100644 ast/operators.py

diff --git a/ast/arrays.py b/ast/arrays.py
index 1982693..e7fd2c7 100644
--- a/ast/arrays.py
+++ b/ast/arrays.py
@@ -1,6 +1,7 @@
 from ast.assign import Assign
+from ast.ast_node import ASTNode
+from ast.bin_op import BinOp, ASTTerm
 from ast.data_types import Type_Array
-from ast.expr import BinOp
 from ast.layouts import Layout_AoS, Layout_SoA
 from ast.lit import as_lit_ast
 from ast.memory import Realloc
@@ -35,9 +36,9 @@ class Arrays:
         return None
 
 
-class Array:
+class Array(ASTNode):
     def __init__(self, sim, a_name, a_sizes, a_type, a_layout=Layout_AoS):
-        self.sim = sim
+        super().__init__(sim)
         self.arr_name = a_name
         self.arr_sizes = \
             [as_lit_ast(sim, a_sizes)] if not isinstance(a_sizes, list) \
@@ -65,9 +66,6 @@ class Array:
     def layout(self):
         return self.arr_layout
 
-    def scope(self):
-        return self.sim.global_scope
-
     def ndims(self):
         return self.arr_ndims
 
@@ -77,12 +75,6 @@ class Array:
     def alloc_size(self):
         return reduce((lambda x, y: x * y), [s for s in self.arr_sizes])
 
-    def children(self):
-        return []
-
-    def transform(self, fn):
-        return fn(self)
-
 
 class ArrayStatic(Array):
     def __init__(self, sim, a_name, a_sizes, a_type, a_layout=Layout_AoS):
@@ -110,7 +102,7 @@ class ArrayND(Array):
         return Realloc(self.sim, self, self.alloc_size())
 
 
-class ArrayAccess:
+class ArrayAccess(ASTTerm):
     last_acc = 0
 
     def new_id():
@@ -118,7 +110,7 @@ class ArrayAccess:
         return ArrayAccess.last_acc - 1
 
     def __init__(self, sim, array, index):
-        self.sim = sim
+        super().__init__(sim)
         self.acc_id = ArrayAccess.new_id()
         self.array = array
         self.indexes = [as_lit_ast(sim, index)]
@@ -130,15 +122,6 @@ class ArrayAccess:
     def __str__(self):
         return f"ArrayAccess<array: {self.array}, indexes: {self.indexes}>"
 
-    def __add__(self, other):
-        return BinOp(self.sim, self, other, '+')
-
-    def __mul__(self, other):
-        return BinOp(self.sim, self, other, '*')
-
-    def __rmul__(self, other):
-        return BinOp(self.sim, other, self, '*')
-
     def __getitem__(self, index):
         assert self.index is None, "Number of indexes higher than array dimension!"
         self.indexes.append(as_lit_ast(self.sim, index))
@@ -206,14 +189,8 @@ class ArrayAccess:
         return fn(self)
 
 
-class ArrayDecl:
+class ArrayDecl(ASTNode):
     def __init__(self, sim, array):
-        self.sim = sim
+        super().__init__(sim)
         self.array = array
         self.sim.add_statement(self)
-
-    def children(self):
-        return []
-
-    def transform(self, fn):
-        return fn(self)
diff --git a/ast/assign.py b/ast/assign.py
index 59f11e4..8ce1d7b 100644
--- a/ast/assign.py
+++ b/ast/assign.py
@@ -1,11 +1,12 @@
+from ast.ast_node import ASTNode
 from ast.data_types import Type_Vector
 from ast.lit import as_lit_ast
 from functools import reduce
 
 
-class Assign:
+class Assign(ASTNode):
     def __init__(self, sim, dest, src):
-        self.sim = sim
+        super().__init__(sim)
         self.parent_block = None
         self.type = dest.type()
         src = as_lit_ast(sim, src)
@@ -14,7 +15,7 @@ class Assign:
             self.assignments = []
 
             for i in range(0, sim.dimensions):
-                from ast.expr import BinOp
+                from ast.bin_op import BinOp
                 dsrc = (src if (not isinstance(src, BinOp) or
                                 src.type() != Type_Vector)
                         else src[i])
diff --git a/ast/ast_node.py b/ast/ast_node.py
new file mode 100644
index 0000000..b0641f4
--- /dev/null
+++ b/ast/ast_node.py
@@ -0,0 +1,24 @@
+from ast.data_types import Type_Invalid
+
+
+class ASTNode:
+    def __init__(self, sim):
+        self.sim = sim
+
+    def __str__(self):
+        return "ASTNode<>"
+
+    def type(self):
+        return Type_Invalid
+
+    def is_mutable(self):
+        return False
+
+    def scope(self):
+        return self.sim.global_scope
+
+    def children(self):
+        return []
+
+    def transform(self, fn):
+        return fn(self)
diff --git a/ast/expr.py b/ast/bin_op.py
similarity index 80%
rename from ast/expr.py
rename to ast/bin_op.py
index 93db1df..4cc8dd5 100644
--- a/ast/expr.py
+++ b/ast/bin_op.py
@@ -1,11 +1,13 @@
+from ast.ast_node import ASTNode
 from ast.assign import Assign
 from ast.data_types import Type_Float, Type_Bool, Type_Vector
 from ast.lit import as_lit_ast
 from ast.properties import Property
 
 
-class BinOpDef:
+class BinOpDef(ASTNode):
     def __init__(self, bin_op):
+        super().__init__(bin_op.sim)
         self.bin_op = bin_op
         self.bin_op.sim.add_statement(self)
 
@@ -20,7 +22,7 @@ class BinOpDef:
         return fn(self)
 
 
-class BinOp:
+class BinOp(ASTNode):
     # BinOp kinds
     Kind_Scalar = 0
     Kind_Vector = 1
@@ -38,7 +40,7 @@ class BinOp:
         return op.inline_rec()
 
     def __init__(self, sim, lhs, rhs, op, mem=False):
-        self.sim = sim
+        super().__init__(sim)
         self.bin_op_id = BinOp.new_id()
         self.lhs = as_lit_ast(sim, lhs)
         self.rhs = as_lit_ast(sim, rhs)
@@ -56,51 +58,6 @@ class BinOp:
     def __str__(self):
         return f"BinOp<a: {self.lhs.id()}, b: {self.rhs.id()}, op: {self.op}>"
 
-    def __add__(self, other):
-        return BinOp(self.sim, self, other, '+')
-
-    def __radd__(self, other):
-        return BinOp(self.sim, other, self, '+')
-
-    def __sub__(self, other):
-        return BinOp(self.sim, self, other, '-')
-
-    def __mul__(self, other):
-        return BinOp(self.sim, self, other, '*')
-
-    def __rmul__(self, other):
-        return BinOp(self.sim, other, self, '*')
-
-    def __truediv__(self, other):
-        return BinOp(self.sim, self, other, '/')
-
-    def __rtruediv__(self, other):
-        return BinOp(self.sim, other, self, '/')
-
-    def __lt__(self, other):
-        return BinOp(self.sim, self, other, '<')
-
-    def __le__(self, other):
-        return BinOp(self.sim, self, other, '<=')
-
-    def __gt__(self, other):
-        return BinOp(self.sim, self, other, '>')
-
-    def __ge__(self, other):
-        return BinOp(self.sim, self, other, '>=')
-
-    def and_op(self, other):
-        return BinOp(self.sim, self, other, '&&')
-
-    def cmp(lhs, rhs):
-        return BinOp(lhs.sim, lhs, rhs, '==')
-
-    def neq(lhs, rhs):
-        return BinOp(lhs.sim, lhs, rhs, '!=')
-
-    def inv(self):
-        return BinOp(self.sim, 1.0, self, '/')
-
     def match(self, bin_op):
         return self.lhs == bin_op.lhs and \
                self.rhs == bin_op.rhs and \
@@ -230,3 +187,103 @@ class BinOp:
         self.bin_op_vector_index_mapping = {i: e.transform(fn) for i, e in self.bin_op_vector_index_mapping.items()}
         return fn(self)
 
+    def __add__(self, other):
+        return BinOp(self.sim, self, other, '+')
+
+    def __radd__(self, other):
+        return BinOp(self.sim, other, self, '+')
+
+    def __sub__(self, other):
+        return BinOp(self.sim, self, other, '-')
+
+    def __mul__(self, other):
+        return BinOp(self.sim, self, other, '*')
+
+    def __rmul__(self, other):
+        return BinOp(self.sim, other, self, '*')
+
+    def __truediv__(self, other):
+        return BinOp(self.sim, self, other, '/')
+
+    def __rtruediv__(self, other):
+        return BinOp(self.sim, other, self, '/')
+
+    def __lt__(self, other):
+        return BinOp(self.sim, self, other, '<')
+
+    def __le__(self, other):
+        return BinOp(self.sim, self, other, '<=')
+
+    def __gt__(self, other):
+        return BinOp(self.sim, self, other, '>')
+
+    def __ge__(self, other):
+        return BinOp(self.sim, self, other, '>=')
+
+    def and_op(self, other):
+        return BinOp(self.sim, self, other, '&&')
+
+    def cmp(lhs, rhs):
+        return BinOp(lhs.sim, lhs, rhs, '==')
+
+    def neq(lhs, rhs):
+        return BinOp(lhs.sim, lhs, rhs, '!=')
+
+    def inv(self):
+        return BinOp(self.sim, 1.0, self, '/')
+
+    def __mod__(self, other):
+        return BinOp(self.sim, self, other, '%')
+
+
+class ASTTerm(ASTNode):
+    def __init__(self, sim):
+        super().__init__(sim)
+
+    def __add__(self, other):
+        return BinOp(self.sim, self, other, '+')
+
+    def __radd__(self, other):
+        return BinOp(self.sim, other, self, '+')
+
+    def __sub__(self, other):
+        return BinOp(self.sim, self, other, '-')
+
+    def __mul__(self, other):
+        return BinOp(self.sim, self, other, '*')
+
+    def __rmul__(self, other):
+        return BinOp(self.sim, other, self, '*')
+
+    def __truediv__(self, other):
+        return BinOp(self.sim, self, other, '/')
+
+    def __rtruediv__(self, other):
+        return BinOp(self.sim, other, self, '/')
+
+    def __lt__(self, other):
+        return BinOp(self.sim, self, other, '<')
+
+    def __le__(self, other):
+        return BinOp(self.sim, self, other, '<=')
+
+    def __gt__(self, other):
+        return BinOp(self.sim, self, other, '>')
+
+    def __ge__(self, other):
+        return BinOp(self.sim, self, other, '>=')
+
+    def and_op(self, other):
+        return BinOp(self.sim, self, other, '&&')
+
+    def cmp(lhs, rhs):
+        return BinOp(lhs.sim, lhs, rhs, '==')
+
+    def neq(lhs, rhs):
+        return BinOp(lhs.sim, lhs, rhs, '!=')
+
+    def inv(self):
+        return BinOp(self.sim, 1.0, self, '/')
+
+    def __mod__(self, other):
+        return BinOp(self.sim, self, other, '%')
diff --git a/ast/block.py b/ast/block.py
index 4ed52ee..45fbf74 100644
--- a/ast/block.py
+++ b/ast/block.py
@@ -1,9 +1,10 @@
+from ast.ast_node import ASTNode
 from ast.visitor import Visitor
 
 
-class Block:
+class Block(ASTNode):
     def __init__(self, sim, stmts):
-        self.sim = sim
+        super().__init__(sim)
         self.level = 0
         self.expressions = []
 
diff --git a/ast/branches.py b/ast/branches.py
index a3ff5f1..ab31b0b 100644
--- a/ast/branches.py
+++ b/ast/branches.py
@@ -1,8 +1,9 @@
+from ast.ast_node import ASTNode
 from ast.block import Block
 from ast.lit import as_lit_ast
 
 
-class Branch:
+class Branch(ASTNode):
     def __init__(self, sim, cond, one_way=False, blk_if=None, blk_else=None):
         self.sim = sim
         self.parent_block = None
diff --git a/ast/cast.py b/ast/cast.py
index 8036676..5bf1ba9 100644
--- a/ast/cast.py
+++ b/ast/cast.py
@@ -1,9 +1,10 @@
+from ast.ast_node import ASTNode
 from ast.data_types import Type_Int, Type_Float
 
 
-class Cast:
+class Cast(ASTNode):
     def __init__(self, sim, expr, cast_type):
-        self.sim = sim
+        super().__init__(sim)
         self.expr = expr
         self.cast_type = cast_type
 
diff --git a/ast/lit.py b/ast/lit.py
index b04a969..8ad0d83 100644
--- a/ast/lit.py
+++ b/ast/lit.py
@@ -1,21 +1,18 @@
-from ast.data_types import Type_Invalid, Type_Int, Type_Float, Type_Bool
-from ast.data_types import Type_Vector
+from ast.ast_node import ASTNode
+from ast.data_types import Type_Invalid, Type_Int, Type_Float, Type_Bool, Type_Vector
 
 
 def is_literal(a):
-    return (isinstance(a, int) or
-            isinstance(a, float) or
-            isinstance(a, bool) or
-            isinstance(a, list))
+    return isinstance(a, (int, float, bool, list))
 
 
 def as_lit_ast(sim, a):
     return Lit(sim, a) if is_literal(a) else a
 
 
-class Lit:
+class Lit(ASTNode):
     def __init__(self, sim, value):
-        self.sim = sim
+        super().__init__(sim)
         self.value = value
         self.lit_type = Type_Invalid
 
@@ -47,15 +44,3 @@ class Lit:
 
     def type(self):
         return self.lit_type
-
-    def is_mutable(self):
-        return False
-
-    def scope(self):
-        return self.sim.global_scope
-
-    def children(self):
-        return []
-
-    def transform(self, fn):
-        return fn(self)
diff --git a/ast/loops.py b/ast/loops.py
index d99c155..902cec3 100644
--- a/ast/loops.py
+++ b/ast/loops.py
@@ -1,11 +1,12 @@
+from ast.ast_node import ASTNode
+from ast.bin_op import BinOp, ASTTerm
 from ast.block import Block
 from ast.branches import Filter
 from ast.data_types import Type_Int
-from ast.expr import BinOp
 from ast.lit import as_lit_ast
 
 
-class Iter():
+class Iter(ASTTerm):
     last_iter = 0
 
     def new_id():
@@ -13,7 +14,7 @@ class Iter():
         return Iter.last_iter - 1
 
     def __init__(self, sim, loop):
-        self.sim = sim
+        super().__init__(sim)
         self.loop = loop
         self.iter_id = Iter.new_id()
 
@@ -23,26 +24,9 @@ class Iter():
     def type(self):
         return Type_Int
 
-    def is_mutable(self):
-        return False
-
     def scope(self):
         return self.loop.block
 
-    def __add__(self, other):
-        return BinOp(self.sim, self, other, '+')
-
-    def __sub__(self, other):
-        return BinOp(self.sim, self, other, '-')
-
-    def __mul__(self, other):
-        from ast.expr import BinOp
-        return BinOp(self.sim, self, other, '*')
-
-    def __rmul__(self, other):
-        from ast.expr import BinOp
-        return BinOp(self.sim, other, self, '*')
-
     def __eq__(self, other):
         if isinstance(other, Iter):
             return self.iter_id == other.iter_id
@@ -52,23 +36,13 @@ class Iter():
     def __req__(self, other):
         return self.__cmp__(other)
 
-    def __mod__(self, other):
-        from ast.expr import BinOp
-        return BinOp(self.sim, self, other, '%')
-
     def __str__(self):
         return f"Iter<{self.iter_id}>"
 
-    def children(self):
-        return []
-
-    def transform(self, fn):
-        return fn(self)
 
-
-class For():
+class For(ASTNode):
     def __init__(self, sim, range_min, range_max, block=None):
-        self.sim = sim
+        super().__init__(sim)
         self.iterator = Iter(sim, self)
         self.min = as_lit_ast(sim, range_min)
         self.max = as_lit_ast(sim, range_max)
@@ -108,10 +82,9 @@ class ParticleFor(For):
         return f"ParticleFor<>"
 
 
-class While():
+class While(ASTNode):
     def __init__(self, sim, cond, block=None):
-        from ast.expr import BinOp
-        self.sim = sim
+        super().__init__(sim)
         self.parent_block = None
         self.cond = BinOp.inline(cond)
         self.block = Block(sim, []) if block is None else block
diff --git a/ast/math.py b/ast/math.py
index cc38fb8..c872b5b 100644
--- a/ast/math.py
+++ b/ast/math.py
@@ -1,9 +1,10 @@
+from ast.ast_node import ASTNode
 from ast.data_types import Type_Int, Type_Float
 
 
-class Sqrt:
+class Sqrt(ASTNode):
     def __init__(self, sim, expr, cast_type):
-        self.sim = sim
+        super().__init__(sim)
         self.expr = expr
 
     def __str__(self):
diff --git a/ast/memory.py b/ast/memory.py
index bb26d4c..ef8abfe 100644
--- a/ast/memory.py
+++ b/ast/memory.py
@@ -1,12 +1,13 @@
-from ast.expr import BinOp
+from ast.ast_node import ASTNode
+from ast.bin_op import BinOp
 from ast.sizeof import Sizeof
 from functools import reduce
 import operator
 
 
-class Malloc:
+class Malloc(ASTNode):
     def __init__(self, sim, array, sizes, decl=False):
-        self.sim = sim
+        super().__init__(sim)
         self.parent_block = None
         self.array = array
         self.decl = decl
@@ -23,9 +24,9 @@ class Malloc:
         return fn(self)
 
 
-class Realloc:
+class Realloc(ASTNode):
     def __init__(self, sim, array, size):
-        self.sim = sim
+        super().__init__(sim)
         self.parent_block = None
         self.array = array
         self.prim_size = Sizeof(sim, array.type())
diff --git a/ast/operators.py b/ast/operators.py
new file mode 100644
index 0000000..e69de29
diff --git a/ast/properties.py b/ast/properties.py
index bbb6465..44f141c 100644
--- a/ast/properties.py
+++ b/ast/properties.py
@@ -1,3 +1,4 @@
+from ast.ast_node import ASTNode
 from ast.layouts import Layout_AoS
 
 
@@ -38,16 +39,15 @@ class Properties:
         return None
 
 
-class Property:
+class Property(ASTNode):
     def __init__(self, sim, name, dtype, default, volatile, layout=Layout_AoS):
-        self.sim = sim
+        super().__init__(sim)
         self.prop_name = name
         self.prop_type = dtype
         self.prop_layout = layout
         self.default_value = default
         self.volatile = volatile
         self.mutable = True
-        self.flattened = False
 
     def __str__(self):
         return f"Property<{self.prop_name}>"
@@ -70,12 +70,6 @@ class Property:
     def scope(self):
         return self.sim.global_scope
 
-    def __getitem__(self, expr_ast):
-        from ast.expr import BinOp
-        return BinOp(self.sim, self, expr_ast, '[]', True)
-
-    def children(self):
-        return []
-
-    def transform(self, fn):
-        return fn(self)
+    def __getitem__(self, expr):
+        from ast.bin_op import BinOp
+        return BinOp(self.sim, self, expr, '[]', True)
diff --git a/ast/select.py b/ast/select.py
index bdb1546..702bb8f 100644
--- a/ast/select.py
+++ b/ast/select.py
@@ -1,10 +1,11 @@
-from ast.expr import BinOp
+from ast.ast_node import ASTNode
+from ast.bin_op import BinOp
 from ast.lit import as_lit_ast
 
 
-class Select:
+class Select(ASTNode):
     def __init__(self, sim, cond, expr_if, expr_else):
-        self.sim = sim
+        super().__init__(sim)
         self.cond = as_lit_ast(sim, cond)
         self.expr_if = BinOp.inline(as_lit_ast(sim, expr_if))
         self.expr_else = BinOp.inline(as_lit_ast(sim, expr_else))
diff --git a/ast/sizeof.py b/ast/sizeof.py
index e18f89e..7a55106 100644
--- a/ast/sizeof.py
+++ b/ast/sizeof.py
@@ -1,26 +1,11 @@
+from ast.bin_op import ASTTerm
 from ast.data_types import Type_Int
-from ast.expr import BinOp
 
 
-class Sizeof:
+class Sizeof(ASTTerm):
     def __init__(self, sim, data_type):
-        self.sim = sim
+        super().__init__(sim)
         self.data_type = data_type
 
-    def __mul__(self, other):
-        return BinOp(self.sim, self, other, '*')
-
     def type(self):
         return Type_Int
-
-    def is_mutable(self):
-        return False
-
-    def scope(self):
-        return self.sim.global_scope
-
-    def children(self):
-        return []
-
-    def transform(self, fn):
-        return fn(self)
diff --git a/ast/transform.py b/ast/transform.py
index 612b0cf..428d372 100644
--- a/ast/transform.py
+++ b/ast/transform.py
@@ -1,6 +1,6 @@
 from ast.arrays import ArrayAccess
+from ast.bin_op import BinOp
 from ast.data_types import Type_Int, Type_Vector
-from ast.expr import BinOp
 from ast.layouts import Layout_AoS, Layout_SoA
 from ast.lit import Lit
 from ast.loops import Iter
@@ -33,9 +33,6 @@ class Transform:
 
                     ast.map_vector_index(i, flat_index)
 
-        if isinstance(ast, Property):
-            ast.flattened = True
-
         return ast
 
     def simplify(ast):
@@ -130,11 +127,3 @@ class Transform:
                 Transform.reuse_expressions[acc_id].append(ast)
 
         return ast
-
-    def move_loop_invariant_expressions(ast):
-        if isinstance(ast, BinOp):
-            scope = ast.scope()
-            if scope.level > 0:
-                scope.add_expression(ast)
-
-        return ast
diff --git a/ast/utils.py b/ast/utils.py
index 922dcbe..2bc3e43 100644
--- a/ast/utils.py
+++ b/ast/utils.py
@@ -1,13 +1,10 @@
-class Print:
+from ast.ast_node import ASTNode
+
+
+class Print(ASTNode):
     def __init__(self, sim, string):
-        self.sim = sim
+        super().__init__(sim)
         self.string = string
 
     def __str__(self):
         return f"Print<{self.string}>"
-
-    def children(self):
-        return []
-
-    def transform(self, fn):
-        return fn(self)
diff --git a/ast/variables.py b/ast/variables.py
index 5f12719..34f9a9d 100644
--- a/ast/variables.py
+++ b/ast/variables.py
@@ -1,5 +1,6 @@
+from ast.ast_node import ASTNode
 from ast.assign import Assign
-from ast.expr import BinOp
+from ast.bin_op import ASTTerm 
 
 
 class Variables:
@@ -23,9 +24,9 @@ class Variables:
 
         return None
 
-class Var:
+class Var(ASTTerm):
     def __init__(self, sim, var_name, var_type, init_value=0):
-        self.sim = sim
+        super().__init__(sim)
         self.var_name = var_name
         self.var_type = var_type
         self.var_init_value = init_value
@@ -35,42 +36,6 @@ class Var:
     def __str__(self):
         return f"Var<name: {self.var_name}, type: {self.var_type}>"
 
-    def __add__(self, other):
-        return BinOp(self.sim, self, other, '+')
-
-    def __radd__(self, other):
-        return BinOp(self.sim, other, self, '+')
-
-    def __sub__(self, other):
-        return BinOp(self.sim, self, other, '-')
-
-    def __mul__(self, other):
-        return BinOp(self.sim, self, other, '*')
-
-    def __rmul__(self, other):
-        return BinOp(self.sim, other, self, '*')
-
-    def __truediv__(self, other):
-        return BinOp(self.sim, self, other, '/')
-
-    def __rtruediv__(self, other):
-        return BinOp(self.sim, other, self, '/')
-
-    def __lt__(self, other):
-        return BinOp(self.sim, self, other, '<')
-
-    def __le__(self, other):
-        return BinOp(self.sim, self, other, '<=')
-
-    def __gt__(self, other):
-        return BinOp(self.sim, self, other, '>')
-
-    def __ge__(self, other):
-        return BinOp(self.sim, self, other, '>=')
-
-    def inv(self):
-        return BinOp(self.sim, 1.0, self, '/')
-
     def set(self, other):
         return self.sim.add_statement(Assign(self.sim, self, other))
 
@@ -98,24 +63,9 @@ class Var:
     def is_mutable(self):
         return self.mutable
 
-    def scope(self):
-        return self.sim.global_scope
-
-    def children(self):
-        return []
-
-    def transform(self, fn):
-        return fn(self)
 
-
-class VarDecl:
+class VarDecl(ASTNode):
     def __init__(self, sim, var):
-        self.sim = sim
+        super().__init__(sim)
         self.var = var
         self.sim.add_statement(self)
-
-    def children(self):
-        return []
-
-    def transform(self, fn):
-        return fn(self)
diff --git a/code_gen/cgen.py b/code_gen/cgen.py
index 96e0d9d..ea87d2a 100644
--- a/code_gen/cgen.py
+++ b/code_gen/cgen.py
@@ -3,7 +3,7 @@ from ast.arrays import ArrayAccess, ArrayDecl
 from ast.block import Block
 from ast.branches import Branch
 from ast.cast import Cast
-from ast.expr import BinOp, BinOpDef
+from ast.bin_op import BinOp, BinOpDef
 from ast.data_types import Type_Int, Type_Float, Type_Vector
 from ast.lit import Lit
 from ast.loops import For, Iter, ParticleFor, While
diff --git a/graph/graphviz.py b/graph/graphviz.py
index 5a17e9a..21a7e81 100644
--- a/graph/graphviz.py
+++ b/graph/graphviz.py
@@ -1,5 +1,5 @@
 from ast.arrays import Array
-from ast.expr import BinOp, BinOpDef
+from ast.bin_op import BinOp, BinOpDef
 from ast.lit import Lit
 from ast.loops import Iter
 from ast.properties import Property
@@ -7,6 +7,7 @@ from ast.variables import Var
 from ast.visitor import Visitor
 from graphviz import Digraph
 
+
 class ASTGraph:
     def __init__(self, ast_node, filename, ref="AST", max_depth=0):
         self.graph = Digraph(ref, filename=filename, node_attr={'color': 'lightblue2', 'style': 'filled'})
diff --git a/sim/cell_lists.py b/sim/cell_lists.py
index fe69d06..4fe7986 100644
--- a/sim/cell_lists.py
+++ b/sim/cell_lists.py
@@ -1,7 +1,7 @@
+from ast.bin_op import BinOp
 from ast.branches import Branch, Filter
 from ast.cast import Cast
 from ast.data_types import Type_Int
-from ast.expr import BinOp
 from ast.loops import For, ParticleFor
 from ast.utils import Print
 from functools import reduce
diff --git a/sim/properties.py b/sim/properties.py
index f07f9d5..81901fb 100644
--- a/sim/properties.py
+++ b/sim/properties.py
@@ -18,10 +18,7 @@ class PropertiesAlloc:
             if p.type() == Type_Float:
                 sizes = [capacity]
             elif p.type() == Type_Vector:
-                if p.flattened:
-                    sizes = [capacity * self.sim.dimensions]
-                else:
-                    sizes = [capacity, self.sim.dimensions]
+                sizes = [capacity * self.sim.dimensions]
             else:
                 raise Exception("Invalid property type!")
 
diff --git a/sim/resize.py b/sim/resize.py
index eada642..48a1f01 100644
--- a/sim/resize.py
+++ b/sim/resize.py
@@ -26,11 +26,9 @@ class Resize:
                 if properties.is_capacity(self.capacity_var):
                     capacity = sum(self.sim.properties.capacities)
                     for p in properties.all():
-                        sizes = capacity
                         if p.type() == Type_Vector:
-                            if p.flattened:
-                                sizes = capacity * self.sim.dimensions
-                            else:
-                                sizes = capacity * self.sim.dimensions
+                            sizes = capacity * self.sim.dimensions
+                        else:
+                            sizes = capacity
 
                         Realloc(self.sim, p, sizes)
diff --git a/sim/timestep.py b/sim/timestep.py
index 86e5c95..0d54133 100644
--- a/sim/timestep.py
+++ b/sim/timestep.py
@@ -1,5 +1,5 @@
+from ast.bin_op import BinOp
 from ast.block import Block
-from ast.expr import BinOp
 from ast.branches import Branch
 from ast.loops import For
 
-- 
GitLab