From 29c418ae9827059b27a93f99bc5729ae7ac07782 Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Sun, 28 Mar 2021 04:19:04 +0200
Subject: [PATCH] Solve several issues for new syntax

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 code_gen/cgen.py            |  32 +++---
 graph/graphviz.py           |  14 +--
 {ast => ir}/__init__.py     |   0
 {ast => ir}/arrays.py       |  16 +--
 {ast => ir}/assign.py       |   8 +-
 {ast => ir}/ast_node.py     |   2 +-
 {ast => ir}/bin_op.py       |  10 +-
 {ast => ir}/block.py        |   2 +-
 {ast => ir}/branches.py     |   6 +-
 {ast => ir}/cast.py         |   4 +-
 {ast => ir}/data_types.py   |   0
 {ast => ir}/layouts.py      |   0
 {ast => ir}/lit.py          |   4 +-
 {ast => ir}/loops.py        |  12 +--
 {ast => ir}/math.py         |   4 +-
 {ast => ir}/memory.py       |   6 +-
 {ast => ir}/mutator.py      |   0
 {ast => ir}/operators.py    |   0
 {ast => ir}/properties.py   |   6 +-
 {ast => ir}/select.py       |   6 +-
 {ast => ir}/sizeof.py       |   4 +-
 {ast => ir}/transform.py    |  14 +--
 {ast => ir}/utils.py        |   2 +-
 {ast => ir}/variables.py    |   6 +-
 {ast => ir}/visitor.py      |   0
 new_syntax.py               | 192 ++++++++++++++++++++++++++++++++++++
 new_syntax/test.py          |  36 -------
 particle.py                 |   2 +-
 sim/arrays.py               |   4 +-
 sim/cell_lists.py           |  12 +--
 sim/kernel_wrapper.py       |   2 +-
 sim/lattice.py              |   4 +-
 sim/particle_simulation.py  |  16 +--
 sim/pbc.py                  |  10 +-
 sim/properties.py           |   8 +-
 sim/read_from_file.py       |   4 +-
 sim/resize.py               |  10 +-
 sim/setup_wrapper.py        |   2 +-
 sim/timestep.py             |   8 +-
 sim/variables.py            |   2 +-
 sim/vtk.py                  |   4 +-
 transformations/LICM.py     |   8 +-
 transformations/flatten.py  |   4 +-
 transformations/simplify.py |   6 +-
 44 files changed, 324 insertions(+), 168 deletions(-)
 rename {ast => ir}/__init__.py (100%)
 rename {ast => ir}/arrays.py (94%)
 rename {ast => ir}/assign.py (85%)
 rename {ast => ir}/ast_node.py (89%)
 rename {ast => ir}/bin_op.py (97%)
 rename {ast => ir}/block.py (98%)
 rename {ast => ir}/branches.py (93%)
 rename {ast => ir}/cast.py (87%)
 rename {ast => ir}/data_types.py (100%)
 rename {ast => ir}/layouts.py (100%)
 rename {ast => ir}/lit.py (89%)
 rename {ast => ir}/loops.py (93%)
 rename {ast => ir}/math.py (81%)
 rename {ast => ir}/memory.py (89%)
 rename {ast => ir}/mutator.py (100%)
 rename {ast => ir}/operators.py (100%)
 rename {ast => ir}/properties.py (93%)
 rename {ast => ir}/select.py (79%)
 rename {ast => ir}/sizeof.py (72%)
 rename {ast => ir}/transform.py (93%)
 rename {ast => ir}/utils.py (84%)
 rename {ast => ir}/variables.py (94%)
 rename {ast => ir}/visitor.py (100%)
 create mode 100644 new_syntax.py
 delete mode 100644 new_syntax/test.py

diff --git a/code_gen/cgen.py b/code_gen/cgen.py
index 8fccf24..6235723 100644
--- a/code_gen/cgen.py
+++ b/code_gen/cgen.py
@@ -1,19 +1,19 @@
-from ast.assign import Assign
-from ast.arrays import ArrayAccess, ArrayDecl
-from ast.block import Block
-from ast.branches import Branch
-from ast.cast import Cast
-from ast.bin_op import BinOp, BinOpDef
-from ast.data_types import Type_Int, Type_Float, Type_Vector
-from ast.lit import Lit
-from ast.loops import For, Iter, ParticleFor, While
-from ast.math import Sqrt
-from ast.memory import Malloc, Realloc
-from ast.properties import Property
-from ast.select import Select
-from ast.sizeof import Sizeof
-from ast.utils import Print
-from ast.variables import Var, VarDecl
+from ir.assign import Assign
+from ir.arrays import ArrayAccess, ArrayDecl
+from ir.block import Block
+from ir.branches import Branch
+from ir.cast import Cast
+from ir.bin_op import BinOp, BinOpDef
+from ir.data_types import Type_Int, Type_Float, Type_Vector
+from ir.lit import Lit
+from ir.loops import For, Iter, ParticleFor, While
+from ir.math import Sqrt
+from ir.memory import Malloc, Realloc
+from ir.properties import Property
+from ir.select import Select
+from ir.sizeof import Sizeof
+from ir.utils import Print
+from ir.variables import Var, VarDecl
 from sim.timestep import Timestep
 from sim.vtk import VTKWrite
 from code_gen.printer import Printer
diff --git a/graph/graphviz.py b/graph/graphviz.py
index 7af2c1e..30c0df9 100644
--- a/graph/graphviz.py
+++ b/graph/graphviz.py
@@ -1,10 +1,10 @@
-from ast.arrays import Array
-from ast.bin_op import BinOp, BinOpDef
-from ast.lit import Lit
-from ast.loops import Iter
-from ast.properties import Property
-from ast.variables import Var
-from ast.visitor import Visitor
+from ir.arrays import Array
+from ir.bin_op import BinOp, BinOpDef
+from ir.lit import Lit
+from ir.loops import Iter
+from ir.properties import Property
+from ir.variables import Var
+from ir.visitor import Visitor
 from graphviz import Digraph
 
 
diff --git a/ast/__init__.py b/ir/__init__.py
similarity index 100%
rename from ast/__init__.py
rename to ir/__init__.py
diff --git a/ast/arrays.py b/ir/arrays.py
similarity index 94%
rename from ast/arrays.py
rename to ir/arrays.py
index 1ea9014..27ffac1 100644
--- a/ast/arrays.py
+++ b/ir/arrays.py
@@ -1,11 +1,11 @@
-from ast.assign import Assign
-from ast.ast_node import ASTNode
-from ast.bin_op import BinOp, ASTTerm
-from ast.data_types import Type_Array
-from ast.layouts import Layout_AoS, Layout_SoA
-from ast.lit import as_lit_ast
-from ast.memory import Realloc
-from ast.variables import Var
+from ir.assign import Assign
+from ir.ast_node import ASTNode
+from ir.bin_op import BinOp, ASTTerm
+from ir.data_types import Type_Array
+from ir.layouts import Layout_AoS, Layout_SoA
+from ir.lit import as_lit_ast
+from ir.memory import Realloc
+from ir.variables import Var
 from functools import reduce
 
 
diff --git a/ast/assign.py b/ir/assign.py
similarity index 85%
rename from ast/assign.py
rename to ir/assign.py
index 08f6a72..388e79c 100644
--- a/ast/assign.py
+++ b/ir/assign.py
@@ -1,6 +1,6 @@
-from ast.ast_node import ASTNode
-from ast.data_types import Type_Vector
-from ast.lit import as_lit_ast
+from ir.ast_node import ASTNode
+from ir.data_types import Type_Vector
+from ir.lit import as_lit_ast
 from functools import reduce
 
 
@@ -14,7 +14,7 @@ class Assign(ASTNode):
             self.assignments = []
 
             for i in range(0, sim.dimensions):
-                from ast.bin_op import BinOp
+                from ir.bin_op import BinOp
                 dim_src = src if not isinstance(src, BinOp) or src.type() != Type_Vector else src[i]
                 self.assignments.append((dest[i], dim_src))
         else:
diff --git a/ast/ast_node.py b/ir/ast_node.py
similarity index 89%
rename from ast/ast_node.py
rename to ir/ast_node.py
index aa1c220..cbfc0f4 100644
--- a/ast/ast_node.py
+++ b/ir/ast_node.py
@@ -1,4 +1,4 @@
-from ast.data_types import Type_Invalid
+from ir.data_types import Type_Invalid
 
 
 class ASTNode:
diff --git a/ast/bin_op.py b/ir/bin_op.py
similarity index 97%
rename from ast/bin_op.py
rename to ir/bin_op.py
index 993c3ba..237a8cc 100644
--- a/ast/bin_op.py
+++ b/ir/bin_op.py
@@ -1,8 +1,8 @@
-from ast.ast_node import ASTNode
-from ast.assign import Assign
-from ast.data_types import Type_Float, Type_Bool, Type_Vector
-from ast.lit import as_lit_ast
-from ast.properties import Property
+from ir.ast_node import ASTNode
+from ir.assign import Assign
+from ir.data_types import Type_Float, Type_Bool, Type_Vector
+from ir.lit import as_lit_ast
+from ir.properties import Property
 
 
 class BinOpDef(ASTNode):
diff --git a/ast/block.py b/ir/block.py
similarity index 98%
rename from ast/block.py
rename to ir/block.py
index 58a9092..89ffcd8 100644
--- a/ast/block.py
+++ b/ir/block.py
@@ -1,4 +1,4 @@
-from ast.ast_node import ASTNode
+from ir.ast_node import ASTNode
 
 
 class Block(ASTNode):
diff --git a/ast/branches.py b/ir/branches.py
similarity index 93%
rename from ast/branches.py
rename to ir/branches.py
index 209dd12..224e671 100644
--- a/ast/branches.py
+++ b/ir/branches.py
@@ -1,6 +1,6 @@
-from ast.ast_node import ASTNode
-from ast.block import Block
-from ast.lit import as_lit_ast
+from ir.ast_node import ASTNode
+from ir.block import Block
+from ir.lit import as_lit_ast
 
 
 class Branch(ASTNode):
diff --git a/ast/cast.py b/ir/cast.py
similarity index 87%
rename from ast/cast.py
rename to ir/cast.py
index 39a8538..b85a596 100644
--- a/ast/cast.py
+++ b/ir/cast.py
@@ -1,5 +1,5 @@
-from ast.ast_node import ASTNode
-from ast.data_types import Type_Int, Type_Float
+from ir.ast_node import ASTNode
+from ir.data_types import Type_Int, Type_Float
 
 
 class Cast(ASTNode):
diff --git a/ast/data_types.py b/ir/data_types.py
similarity index 100%
rename from ast/data_types.py
rename to ir/data_types.py
diff --git a/ast/layouts.py b/ir/layouts.py
similarity index 100%
rename from ast/layouts.py
rename to ir/layouts.py
diff --git a/ast/lit.py b/ir/lit.py
similarity index 89%
rename from ast/lit.py
rename to ir/lit.py
index 8ad0d83..89a55fe 100644
--- a/ast/lit.py
+++ b/ir/lit.py
@@ -1,5 +1,5 @@
-from ast.ast_node import ASTNode
-from ast.data_types import Type_Invalid, Type_Int, Type_Float, Type_Bool, Type_Vector
+from ir.ast_node import ASTNode
+from ir.data_types import Type_Invalid, Type_Int, Type_Float, Type_Bool, Type_Vector
 
 
 def is_literal(a):
diff --git a/ast/loops.py b/ir/loops.py
similarity index 93%
rename from ast/loops.py
rename to ir/loops.py
index 82d85ad..a1af262 100644
--- a/ast/loops.py
+++ b/ir/loops.py
@@ -1,9 +1,9 @@
-from ast.ast_node import ASTNode
-from ast.bin_op import BinOp, ASTTerm
-from ast.block import Block
-from ast.branches import Filter
-from ast.data_types import Type_Int
-from ast.lit import as_lit_ast
+from ir.ast_node import ASTNode
+from ir.bin_op import BinOp, ASTTerm
+from ir.block import Block
+from ir.branches import Filter
+from ir.data_types import Type_Int
+from ir.lit import as_lit_ast
 
 
 class Iter(ASTTerm):
diff --git a/ast/math.py b/ir/math.py
similarity index 81%
rename from ast/math.py
rename to ir/math.py
index 4e610cb..034f4fa 100644
--- a/ast/math.py
+++ b/ir/math.py
@@ -1,5 +1,5 @@
-from ast.ast_node import ASTNode
-from ast.data_types import Type_Int, Type_Float
+from ir.ast_node import ASTNode
+from ir.data_types import Type_Int, Type_Float
 
 
 class Sqrt(ASTNode):
diff --git a/ast/memory.py b/ir/memory.py
similarity index 89%
rename from ast/memory.py
rename to ir/memory.py
index e853a0d..21b0ace 100644
--- a/ast/memory.py
+++ b/ir/memory.py
@@ -1,6 +1,6 @@
-from ast.ast_node import ASTNode
-from ast.bin_op import BinOp
-from ast.sizeof import Sizeof
+from ir.ast_node import ASTNode
+from ir.bin_op import BinOp
+from ir.sizeof import Sizeof
 from functools import reduce
 import operator
 
diff --git a/ast/mutator.py b/ir/mutator.py
similarity index 100%
rename from ast/mutator.py
rename to ir/mutator.py
diff --git a/ast/operators.py b/ir/operators.py
similarity index 100%
rename from ast/operators.py
rename to ir/operators.py
diff --git a/ast/properties.py b/ir/properties.py
similarity index 93%
rename from ast/properties.py
rename to ir/properties.py
index fee0151..3adc4e2 100644
--- a/ast/properties.py
+++ b/ir/properties.py
@@ -1,5 +1,5 @@
-from ast.ast_node import ASTNode
-from ast.layouts import Layout_AoS
+from ir.ast_node import ASTNode
+from ir.layouts import Layout_AoS
 
 
 class Properties:
@@ -67,5 +67,5 @@ class Property(ASTNode):
         return self.sim.global_scope
 
     def __getitem__(self, expr):
-        from ast.bin_op import BinOp
+        from ir.bin_op import BinOp
         return BinOp(self.sim, self, expr, '[]', True)
diff --git a/ast/select.py b/ir/select.py
similarity index 79%
rename from ast/select.py
rename to ir/select.py
index 1fe9d08..0079d2f 100644
--- a/ast/select.py
+++ b/ir/select.py
@@ -1,6 +1,6 @@
-from ast.ast_node import ASTNode
-from ast.bin_op import BinOp
-from ast.lit import as_lit_ast
+from ir.ast_node import ASTNode
+from ir.bin_op import BinOp
+from ir.lit import as_lit_ast
 
 
 class Select(ASTNode):
diff --git a/ast/sizeof.py b/ir/sizeof.py
similarity index 72%
rename from ast/sizeof.py
rename to ir/sizeof.py
index 7a55106..9d042bb 100644
--- a/ast/sizeof.py
+++ b/ir/sizeof.py
@@ -1,5 +1,5 @@
-from ast.bin_op import ASTTerm
-from ast.data_types import Type_Int
+from ir.bin_op import ASTTerm
+from ir.data_types import Type_Int
 
 
 class Sizeof(ASTTerm):
diff --git a/ast/transform.py b/ir/transform.py
similarity index 93%
rename from ast/transform.py
rename to ir/transform.py
index 428d372..a9c5eb0 100644
--- a/ast/transform.py
+++ b/ir/transform.py
@@ -1,10 +1,10 @@
-from ast.arrays import ArrayAccess
-from ast.bin_op import BinOp
-from ast.data_types import Type_Int, Type_Vector
-from ast.layouts import Layout_AoS, Layout_SoA
-from ast.lit import Lit
-from ast.loops import Iter
-from ast.properties import Property
+from ir.arrays import ArrayAccess
+from ir.bin_op import BinOp
+from ir.data_types import Type_Int, Type_Vector
+from ir.layouts import Layout_AoS, Layout_SoA
+from ir.lit import Lit
+from ir.loops import Iter
+from ir.properties import Property
 
 
 class Transform:
diff --git a/ast/utils.py b/ir/utils.py
similarity index 84%
rename from ast/utils.py
rename to ir/utils.py
index 2bc3e43..a4446f1 100644
--- a/ast/utils.py
+++ b/ir/utils.py
@@ -1,4 +1,4 @@
-from ast.ast_node import ASTNode
+from ir.ast_node import ASTNode
 
 
 class Print(ASTNode):
diff --git a/ast/variables.py b/ir/variables.py
similarity index 94%
rename from ast/variables.py
rename to ir/variables.py
index aed0821..b4b40da 100644
--- a/ast/variables.py
+++ b/ir/variables.py
@@ -1,6 +1,6 @@
-from ast.ast_node import ASTNode
-from ast.assign import Assign
-from ast.bin_op import ASTTerm 
+from ir.ast_node import ASTNode
+from ir.assign import Assign
+from ir.bin_op import ASTTerm 
 
 
 class Variables:
diff --git a/ast/visitor.py b/ir/visitor.py
similarity index 100%
rename from ast/visitor.py
rename to ir/visitor.py
diff --git a/new_syntax.py b/new_syntax.py
new file mode 100644
index 0000000..7f9003f
--- /dev/null
+++ b/new_syntax.py
@@ -0,0 +1,192 @@
+import part_prot as pt
+from ir.assign import Assign
+from ir.bin_op import BinOp
+import ast
+import inspect
+
+
+def delta(i, j):
+    return position[i] - position[j]
+
+
+def rsq(i, j):
+    dp = delta(i, j)
+    return dp.x() * dp.x() + dp.y() * dp.y() + dp.z() * dp.z()
+
+
+def lj(i, j):
+    sr2 = 1.0 / rsq
+    sr6 = sr2 * sr2 * sr2 * sigma6
+    force[i] += delta * 48.0 * sr6 * (sr6 - 0.5) * sr2 * epsilon
+
+
+def euler(i):
+    velocity[i] += dt * force[i] / mass[i]
+    position[i] += dt * velocity[i]
+
+
+class UndefinedSymbol():
+    def __init__(self, symbol_id):
+        self.symbol_id = symbol_id
+
+
+class FetchParticleFuncInfo(ast.NodeVisitor):
+    def __init__(self):
+        self._params = []
+
+    def visit_arg(self, node):
+        self._params.append(node.arg)
+
+    def nparams(self):
+        return len(self._params)
+
+    def params(self):
+        return self._params
+
+
+class BuildParticleIR(ast.NodeVisitor):
+    def get_op(op):
+        if isinstance(op, ast.Add):
+            return '+'
+
+        if isinstance(op, ast.Sub):
+            return '-'
+
+        if isinstance(op, ast.Mult):
+            return '*'
+
+        if isinstance(op, ast.Div):
+            return '/'
+
+        raise Exception("Invalid operator: {}".format(ast.dump(op)))
+
+    def parse_function_and_get_return_value(func, args):
+        return None
+
+    def __init__(self, sim, ctx_symbols={}, ctx_calls=[]):
+        self.sim = sim
+        self.ctx_symbols = ctx_symbols
+        self.ctx_calls = ctx_calls
+
+    def add_symbols(self, symbols):
+        self.ctx_symbols.update(symbols)
+
+    def visit_Assign(self, node):
+        assert len(node.targets) == 1, "Only one target is allowed on assignments!"
+        lhs = self.visit(node.targets[0])
+        rhs = self.visit(node.value)
+
+        if isinstance(lhs, UndefinedSymbol):
+            self.add_symbols({lhs.symbol_id: rhs})
+        else:
+            print(lhs)
+            lhs = rhs
+
+    def visit_AugAssign(self, node):
+        lhs = self.visit(node.target)
+        rhs = self.visit(node.value)
+
+        if isinstance(lhs, UndefinedSymbol):
+            self.add_symbols({lhs.symbol_id: rhs})
+        else:
+            print(lhs)
+            lhs += rhs
+
+    def visit_BinOp(self, node):
+        print(ast.dump(node))
+        lhs = self.visit(node.left)
+        assert not isinstance(lhs, UndefinedSymbol), f"Undefined lhs used in BinOp: {lhs.symbol_id}"
+        rhs = self.visit(node.right)
+        assert not isinstance(rhs, UndefinedSymbol), f"Undefined rhs used in BinOp: {rhs.symbol_id}"
+        return BinOp(self.sim, lhs, rhs, BuildParticleIR.get_op(node.op))
+
+    def visit_Call(self, node):
+        func = self.visit(node.func)
+        args = [self.visit(a) for a in node.args]
+
+        for c in self.ctx_calls:
+            if c['func'] == func and len(c['args']) == len(args) and all([c['args'][a] == args[a] for a in range(0, len(args))]):
+                return c['value']
+
+        value = BuildParticleIR.parse_function_and_get_return_value(func, args)
+        self.ctx_calls.append({'func': func, 'args': args, 'value': value})
+        return value
+
+    def visit_Index(self, node):
+        return self.visit(node.value)
+
+    def visit_Name(self, node):
+        as_sym = self.ctx_symbols[node.id] if node.id in self.ctx_symbols else None
+        if as_sym is not None:
+            return as_sym
+
+        as_array = self.sim.array(node.id)
+        if as_array is not None:
+            return as_array
+
+        as_prop = self.sim.property(node.id)
+        if as_prop is not None:
+            return as_prop
+
+        as_var = self.sim.var(node.id)
+        if as_var is not None:
+            return as_var
+
+        return UndefinedSymbol(node.id)
+
+    def visit_Num(self, node):
+        return node.n
+
+    def visit_Subscript(self, node):
+        print(ast.dump(node))
+        return self.visit(node.value)[self.visit(node.slice)]
+
+
+def add_kernel(sim, func, cutoff_radius=None, position=None, symbols={}):
+    src = inspect.getsource(func)
+    tree = ast.parse(src, mode='exec')
+    print(ast.dump(ast.parse(src, mode='exec')))
+
+    # Fetch function info
+    info = FetchParticleFuncInfo()
+    info.visit(tree)
+    params = info.params()
+    nparams = info.nparams()
+
+    # Start building IR
+    ir = BuildParticleIR(sim, symbols)
+
+    if nparams == 1:
+        for i in sim.particles():
+            ir.add_symbols({params[0]: i})
+            ir.visit(tree)
+
+    elif nparams == 2:
+        for i, j, delta, rsq in psim.particle_pairs(cutoff_radius, sim.property(position)):
+            ir.add_symbols({params[0]: i, params[1]: j, 'delta': delta, 'rsq': rsq})
+            ir.visit(tree)
+    else:
+        raise Exception(f"Invalid number of parameters: {nparams}")
+
+
+dt = 0.005
+cutoff_radius = 2.5
+skin = 0.3
+sigma = 1.0
+epsilon = 1.0
+sigma6 = sigma ** 6
+
+psim = pt.simulation("lj_ns")
+psim.add_real_property('mass', 1.0)
+psim.add_vector_property('position')
+psim.add_vector_property('velocity')
+psim.add_vector_property('force', vol=True)
+psim.from_file("data/minimd_setup_4x4x4.input", ['mass', 'position', 'velocity'])
+psim.create_cell_lists(2.8, 2.8)
+psim.periodic(2.8)
+psim.vtk_output("output/test")
+
+add_kernel(psim, lj, cutoff_radius, 'position', {'sigma6': sigma6, 'epsilon': epsilon})
+add_kernel(psim, euler, symbols={'dt': dt})
+
+psim.generate()
diff --git a/new_syntax/test.py b/new_syntax/test.py
deleted file mode 100644
index 48e25a3..0000000
--- a/new_syntax/test.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import ast
-import inspect
-
-
-def lj(i, j):
-    sr2 = 1.0 / rsq(i, j)
-    sr6 = sr2 * sr2 * sr2 * sigma6
-    force[i] += delta(i, j) * 48.0 * sr6 * (sr6 - 0.5) * sr2 * epsilon
-
-
-def euler(i):
-    velocity[i] += dt * force[i] / mass[i]
-    position[i] += dt * velocity[i]
-
-
-class BuildParticleAST(ast.NodeVisitor):
-    def __init__(self, sim):
-        self.sim = sim
-        self.block = Block([])
-        self.temp_values = {}
-
-    def visit_Assign(self, node):
-        print(node.targets[0].id)
-        print(node.value)
-
-    def visit_AugAssign(self, node):
-        print(node.targets[0].id)
-        print(node.value)
-
-
-lj_src = inspect.getsource(lj)
-#print(ast.dump(ast.parse(lj_src, mode='eval'), indent=4))
-tree = ast.parse(lj_src, mode='exec')
-build = BuildParticleAST()
-build.visit(tree)
-#print(ast.dump(ast.parse(lj_src, mode='exec')))
diff --git a/particle.py b/particle.py
index e3bd789..2dc621c 100644
--- a/particle.py
+++ b/particle.py
@@ -1,5 +1,5 @@
 import part_prot as pt
-from ast.layouts import Layout_SoA
+from ir.layouts import Layout_SoA
 
 dt = 0.005
 cutoff_radius = 2.5
diff --git a/sim/arrays.py b/sim/arrays.py
index 0d01954..7406019 100644
--- a/sim/arrays.py
+++ b/sim/arrays.py
@@ -1,5 +1,5 @@
-from ast.memory import Malloc
-from ast.arrays import ArrayDecl
+from ir.memory import Malloc
+from ir.arrays import ArrayDecl
 
 
 class ArraysDecl:
diff --git a/sim/cell_lists.py b/sim/cell_lists.py
index 4fe7986..923c4de 100644
--- a/sim/cell_lists.py
+++ b/sim/cell_lists.py
@@ -1,9 +1,9 @@
-from ast.bin_op import BinOp
-from ast.branches import Branch, Filter
-from ast.cast import Cast
-from ast.data_types import Type_Int
-from ast.loops import For, ParticleFor
-from ast.utils import Print
+from ir.bin_op import BinOp
+from ir.branches import Branch, Filter
+from ir.cast import Cast
+from ir.data_types import Type_Int
+from ir.loops import For, ParticleFor
+from ir.utils import Print
 from functools import reduce
 from sim.resize import Resize
 import math
diff --git a/sim/kernel_wrapper.py b/sim/kernel_wrapper.py
index d8146fb..7ac32e3 100644
--- a/sim/kernel_wrapper.py
+++ b/sim/kernel_wrapper.py
@@ -1,4 +1,4 @@
-from ast.block import Block
+from ir.block import Block
 
 
 class KernelWrapper():
diff --git a/sim/lattice.py b/sim/lattice.py
index 4ea51d1..7067141 100644
--- a/sim/lattice.py
+++ b/sim/lattice.py
@@ -1,5 +1,5 @@
-from ast.data_types import Type_Vector
-from ast.loops import For
+from ir.data_types import Type_Vector
+from ir.loops import For
 
 
 class ParticleLattice():
diff --git a/sim/particle_simulation.py b/sim/particle_simulation.py
index b21b5e3..3dfa57c 100644
--- a/sim/particle_simulation.py
+++ b/sim/particle_simulation.py
@@ -1,11 +1,11 @@
-from ast.arrays import Arrays
-from ast.block import Block
-from ast.branches import Filter
-from ast.data_types import Type_Int, Type_Float, Type_Vector
-from ast.layouts import Layout_AoS
-from ast.loops import ParticleFor, NeighborFor
-from ast.properties import Properties
-from ast.variables import Variables
+from ir.arrays import Arrays
+from ir.block import Block
+from ir.branches import Filter
+from ir.data_types import Type_Int, Type_Float, Type_Vector
+from ir.layouts import Layout_AoS
+from ir.loops import ParticleFor, NeighborFor
+from ir.properties import Properties
+from ir.variables import Variables
 from graph.graphviz import ASTGraph
 from sim.arrays import ArraysDecl
 from sim.cell_lists import CellLists, CellListsBuild, CellListsStencilBuild
diff --git a/sim/pbc.py b/sim/pbc.py
index f6e7f66..e996941 100644
--- a/sim/pbc.py
+++ b/sim/pbc.py
@@ -1,8 +1,8 @@
-from ast.branches import Branch, Filter
-from ast.data_types import Type_Int
-from ast.loops import For, ParticleFor
-from ast.utils import Print
-from ast.select import Select
+from ir.branches import Branch, Filter
+from ir.data_types import Type_Int
+from ir.loops import For, ParticleFor
+from ir.utils import Print
+from ir.select import Select
 from sim.resize import Resize
 
 
diff --git a/sim/properties.py b/sim/properties.py
index 81901fb..b9a87e9 100644
--- a/sim/properties.py
+++ b/sim/properties.py
@@ -1,7 +1,7 @@
-from ast.data_types import Type_Float, Type_Vector
-from ast.loops import ParticleFor
-from ast.memory import Malloc, Realloc
-from ast.utils import Print
+from ir.data_types import Type_Float, Type_Vector
+from ir.loops import ParticleFor
+from ir.memory import Malloc, Realloc
+from ir.utils import Print
 
 
 class PropertiesAlloc:
diff --git a/sim/read_from_file.py b/sim/read_from_file.py
index 42608a1..546966c 100644
--- a/sim/read_from_file.py
+++ b/sim/read_from_file.py
@@ -1,5 +1,5 @@
-from ast.data_types import Type_Int, Type_Float, Type_Vector
-from ast.loops import For
+from ir.data_types import Type_Int, Type_Float, Type_Vector
+from ir.loops import For
 from sim.grid import Grid
 
 class ReadFromFile():
diff --git a/sim/resize.py b/sim/resize.py
index 48a1f01..afeb9cd 100644
--- a/sim/resize.py
+++ b/sim/resize.py
@@ -1,8 +1,8 @@
-from ast.branches import Filter
-from ast.data_types import Type_Int, Type_Float, Type_Vector
-from ast.loops import While
-from ast.memory import Realloc
-from ast.utils import Print
+from ir.branches import Filter
+from ir.data_types import Type_Int, Type_Float, Type_Vector
+from ir.loops import While
+from ir.memory import Realloc
+from ir.utils import Print
 
 class Resize:
     def __init__(self, sim, capacity_var, grow_fn=None):
diff --git a/sim/setup_wrapper.py b/sim/setup_wrapper.py
index bce0281..bc0d535 100644
--- a/sim/setup_wrapper.py
+++ b/sim/setup_wrapper.py
@@ -1,4 +1,4 @@
-from ast.block import Block
+from ir.block import Block
 
 
 class SetupWrapper():
diff --git a/sim/timestep.py b/sim/timestep.py
index 74aa241..f15ee7c 100644
--- a/sim/timestep.py
+++ b/sim/timestep.py
@@ -1,7 +1,7 @@
-from ast.bin_op import BinOp
-from ast.block import Block
-from ast.branches import Branch
-from ast.loops import For
+from ir.bin_op import BinOp
+from ir.block import Block
+from ir.branches import Branch
+from ir.loops import For
 
 
 class Timestep:
diff --git a/sim/variables.py b/sim/variables.py
index 2f616b1..4d37e32 100644
--- a/sim/variables.py
+++ b/sim/variables.py
@@ -1,4 +1,4 @@
-from ast.variables import VarDecl
+from ir.variables import VarDecl
 
 
 class VariablesDecl:
diff --git a/sim/vtk.py b/sim/vtk.py
index 284533a..c6125c4 100644
--- a/sim/vtk.py
+++ b/sim/vtk.py
@@ -1,5 +1,5 @@
-from ast.lit import as_lit_ast
-from ast.ast_node import ASTNode
+from ir.lit import as_lit_ast
+from ir.ast_node import ASTNode
 
 
 class VTKWrite(ASTNode):
diff --git a/transformations/LICM.py b/transformations/LICM.py
index ffa3327..74f9969 100644
--- a/transformations/LICM.py
+++ b/transformations/LICM.py
@@ -1,7 +1,7 @@
-from ast.bin_op import BinOp
-from ast.loops import For, While
-from ast.mutator import Mutator
-from ast.visitor import Visitor
+from ir.bin_op import BinOp
+from ir.loops import For, While
+from ir.mutator import Mutator
+from ir.visitor import Visitor
 
 
 class SetBlockVariants(Mutator):
diff --git a/transformations/flatten.py b/transformations/flatten.py
index b414079..d673c19 100644
--- a/transformations/flatten.py
+++ b/transformations/flatten.py
@@ -1,5 +1,5 @@
-from ast.layouts import Layout_AoS, Layout_SoA
-from ast.mutator import Mutator
+from ir.layouts import Layout_AoS, Layout_SoA
+from ir.mutator import Mutator
 
 
 class FlattenPropertyAccesses(Mutator):
diff --git a/transformations/simplify.py b/transformations/simplify.py
index d59dd59..83d6cac 100644
--- a/transformations/simplify.py
+++ b/transformations/simplify.py
@@ -1,6 +1,6 @@
-from ast.data_types import Type_Int
-from ast.lit import Lit
-from ast.mutator import Mutator
+from ir.data_types import Type_Int
+from ir.lit import Lit
+from ir.mutator import Mutator
 
 
 class SimplifyExpressions(Mutator):
-- 
GitLab