From a7432bdd313ae9fd692de4dbb1c2d56e49453a64 Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Tue, 16 Nov 2021 02:01:04 +0100
Subject: [PATCH] Introduce modules as first step for GPU kernels

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 src/pairs/code_gen/cgen.py                    | 77 ++++++++++++++++---
 src/pairs/code_gen/printer.py                 |  2 +-
 src/pairs/ir/arrays.py                        | 12 ---
 src/pairs/ir/ast_node.py                      |  3 -
 src/pairs/ir/bin_op.py                        |  9 ---
 src/pairs/ir/block.py                         |  5 +-
 src/pairs/ir/branches.py                      | 12 +--
 src/pairs/ir/cast.py                          |  3 -
 src/pairs/ir/functions.py                     |  1 +
 src/pairs/ir/loops.py                         | 11 +--
 src/pairs/ir/math.py                          |  6 --
 src/pairs/ir/module.py                        | 72 +++++++++++++++++
 src/pairs/ir/mutator.py                       |  4 +
 src/pairs/ir/properties.py                    |  6 --
 src/pairs/ir/variables.py                     | 16 ++++
 src/pairs/sim/interaction.py                  |  4 +-
 src/pairs/sim/simulation.py                   | 23 ++++--
 .../fetch_modules_references.py               | 62 +++++++++++++++
 .../replace_modules_by_calls.py               | 15 ++++
 19 files changed, 272 insertions(+), 71 deletions(-)
 create mode 100644 src/pairs/ir/module.py
 create mode 100644 src/pairs/transformations/fetch_modules_references.py
 create mode 100644 src/pairs/transformations/replace_modules_by_calls.py

diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index d3501b0..bc3d1c2 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -12,11 +12,12 @@ from pairs.ir.lit import Lit
 from pairs.ir.loops import For, Iter, ParticleFor, While
 from pairs.ir.math import Ceil, Sqrt
 from pairs.ir.memory import Malloc, Realloc
+from pairs.ir.module import Module_Call
 from pairs.ir.properties import Property, PropertyAccess, PropertyList, RegisterProperty, UpdateProperty
 from pairs.ir.select import Select
 from pairs.ir.sizeof import Sizeof
 from pairs.ir.utils import Print
-from pairs.ir.variables import Var, VarDecl
+from pairs.ir.variables import Var, VarDecl, Deref
 from pairs.sim.timestep import Timestep
 from pairs.code_gen.printer import Printer
 
@@ -51,12 +52,46 @@ class CGen:
         self.print("")
         self.print("using namespace pairs;")
         self.print("")
-        self.print("int main() {")
-        self.print("    PairsSim *ps = new PairsSim();")
-        self.generate_statement(ast_node)
-        self.print("}")
+        for module in self.sim.modules():
+            self.generate_module(module)
         self.print.end()
 
+    def generate_module(self, module):
+        if module.name == 'main':
+            self.print("int main() {")
+            self.print("    PairsSim *ps = new PairsSim();")
+            self.generate_statement(module.block)
+            self.print("    return 0;")
+            self.print("}")
+
+        else:
+            module_params = ""
+            for var in module.read_only_variables():
+                type_kw = CGen.type2keyword(var.type())
+                decl = f"{type_kw} {var.name()}"
+                module_params += decl if len(module_params) <= 0 else f", {decl}"
+
+            for var in module.write_variables():
+                type_kw = CGen.type2keyword(var.type())
+                decl = f"{type_kw} *{var.name()}"
+                module_params += decl if len(module_params) <= 0 else f", {decl}"
+
+            for array in module.arrays():
+                type_kw = CGen.type2keyword(array.type())
+                decl = f"{type_kw} *{array.name()}"
+                module_params += decl if len(module_params) <= 0 else f", {decl}"
+
+            for prop in module.properties():
+                type_kw = CGen.type2keyword(prop.type())
+                decl = f"{type_kw} *{prop.name()}"
+                module_params += decl if len(module_params) <= 0 else f", {decl}"
+
+            self.print(f"void {module.name}({module_params}) {{")
+            self.print.add_indent(4)
+            self.generate_statement(module.block)
+            self.print.add_indent(-4)
+            self.print("}")
+
     def generate_statement(self, ast_node, bypass_checking=False):
         if isinstance(ast_node, ArrayDecl):
             tkw = CGen.type2keyword(ast_node.array.type())
@@ -70,10 +105,10 @@ class CGen:
                 self.print(f"{dest} = {src};")
 
         if isinstance(ast_node, Block):
-            self.print.add_ind(4)
+            self.print.add_indent(4)
             for stmt in ast_node.statements():
                 self.generate_statement(stmt)
-            self.print.add_ind(-4)
+            self.print.add_indent(-4)
 
         # TODO: Why there are Decls for other types?
         if isinstance(ast_node, Decl):
@@ -133,9 +168,9 @@ class CGen:
             self.print(f"pairs::copy_to_device({ast_node.prop.name()})")
 
         if isinstance(ast_node, KernelBlock):
-            self.print.add_ind(-4)
+            self.print.add_indent(-4)
             self.generate_statement(ast_node.block)
-            self.print.add_ind(4) # Workaround for fixing indentation of kernels
+            self.print.add_indent(4) # Workaround for fixing indentation of kernels
 
         if isinstance(ast_node, For):
             iterator = self.generate_expression(ast_node.iterator)
@@ -166,6 +201,26 @@ class CGen:
             else:
                 self.print(f"{array_name} = ({tkw} *) malloc({size});")
 
+        if isinstance(ast_node, Module_Call):
+            module_params = ""
+            for var in module.read_only_variables():
+                decl = var.name()
+                module_params += decl if len(module_params) <= 0 else f", {decl}"
+
+            for var in module.write_variables():
+                decl = f"&{var.name()}"
+                module_params += decl if len(module_params) <= 0 else f", {decl}"
+
+            for array in module.arrays():
+                decl = array.name()
+                module_params += decl if len(module_params) <= 0 else f", {decl}"
+
+            for prop in module.properties():
+                decl = prop.name()
+                module_params += decl if len(module_params) <= 0 else f", {decl}"
+
+            self.print(f"{module.name}({module_params});")
+
         if isinstance(ast_node, Print):
             self.print(f"fprintf(stdout, \"{ast_node.string}\\n\");")
             self.print(f"fflush(stdout);")
@@ -269,6 +324,10 @@ class CGen:
             expr = self.generate_expression(ast_node.expr)
             return f"ceil({expr})"
 
+        if isinstance(ast_node, Deref):
+            var = self.generate_expression(ast_node.var)
+            return f"*{var}"
+
         if isinstance(ast_node, Iter):
             assert mem is False, "Iterator is not lvalue!"
             return f"i{ast_node.id()}"
diff --git a/src/pairs/code_gen/printer.py b/src/pairs/code_gen/printer.py
index 9529946..4d73b76 100644
--- a/src/pairs/code_gen/printer.py
+++ b/src/pairs/code_gen/printer.py
@@ -4,7 +4,7 @@ class Printer:
         self.stream = None
         self.indent = 0
 
-    def add_ind(self, offset):
+    def add_indent(self, offset):
         self.indent += offset
 
     def start(self):
diff --git a/src/pairs/ir/arrays.py b/src/pairs/ir/arrays.py
index 49aeb88..a8721b2 100644
--- a/src/pairs/ir/arrays.py
+++ b/src/pairs/ir/arrays.py
@@ -160,18 +160,6 @@ class ArrayAccess(ASTTerm):
         return self.array.type()
         # return self.array.type() if self.index is None else Type_Array
 
-    def scope(self):
-        if self.index is None:
-            scope = None
-            for i in self.indexes:
-                index_scp = i.scope()
-                if scope is None or index_scp > scope:
-                    scope = index_scp
-
-            return scope
-
-        return self.index.scope()
-
     def children(self):
         if self.index is not None:
             return [self.array, self.index]
diff --git a/src/pairs/ir/ast_node.py b/src/pairs/ir/ast_node.py
index 94e75fa..71ce1c1 100644
--- a/src/pairs/ir/ast_node.py
+++ b/src/pairs/ir/ast_node.py
@@ -12,8 +12,5 @@ class ASTNode:
     def type(self):
         return Type_Invalid
 
-    def scope(self):
-        return self.sim.global_scope
-
     def children(self):
         return []
diff --git a/src/pairs/ir/bin_op.py b/src/pairs/ir/bin_op.py
index 95f48dc..8990ad9 100644
--- a/src/pairs/ir/bin_op.py
+++ b/src/pairs/ir/bin_op.py
@@ -42,7 +42,6 @@ class BinOp(VectorExpression):
         self.inlined = False
         self.generated = False
         self.bin_op_type = BinOp.infer_type(self.lhs, self.rhs, self.op)
-        self.bin_op_scope = None
         self.terminals = set()
         self.decl = Decl(sim, self)
 
@@ -132,14 +131,6 @@ class BinOp(VectorExpression):
     def add_terminal(self, terminal):
         self.terminals.add(terminal)
 
-    def scope(self):
-        if self.bin_op_scope is None:
-            lhs_scp = self.lhs.scope()
-            rhs_scp = self.rhs.scope()
-            self.bin_op_scope = lhs_scp if lhs_scp > rhs_scp else rhs_scp
-
-        return self.bin_op_scope
-
     def children(self):
         return [self.lhs, self.rhs] + list(super().children())
 
diff --git a/src/pairs/ir/block.py b/src/pairs/ir/block.py
index b9c04b8..4789b38 100644
--- a/src/pairs/ir/block.py
+++ b/src/pairs/ir/block.py
@@ -1,4 +1,5 @@
 from pairs.ir.ast_node import ASTNode
+from pairs.ir.module import Module
 
 
 def pairs_block(func):
@@ -16,7 +17,9 @@ def pairs_device_block(func):
         sim = args[0].sim # self.sim
         sim.clear_block()
         func(*args, **kwargs)
-        return KernelBlock(sim, sim.block)
+        module = Module(sim, block=KernelBlock(sim, sim.block))
+        sim.add_module(module)
+        return module
 
     return inner
 
diff --git a/src/pairs/ir/branches.py b/src/pairs/ir/branches.py
index 75bbfe7..f6c3295 100644
--- a/src/pairs/ir/branches.py
+++ b/src/pairs/ir/branches.py
@@ -17,14 +17,14 @@ class Branch(ASTNode):
     def __iter__(self):
         self.sim.add_statement(self)
         self.switch = True
-        self.sim.enter_scope(self)
+        self.sim.enter(self)
         yield self.switch
-        self.sim.leave_scope()
+        self.sim.leave()
 
         self.switch = False
-        self.sim.enter_scope(self)
+        self.sim.enter(self)
         yield self.switch
-        self.sim.leave_scope()
+        self.sim.leave()
 
     def add_statement(self, stmt):
         if self.switch:
@@ -43,9 +43,9 @@ class Filter(Branch):
 
     def __iter__(self):
         self.sim.add_statement(self)
-        self.sim.enter_scope(self)
+        self.sim.enter(self)
         yield
-        self.sim.leave_scope()
+        self.sim.leave()
 
     def add_statement(self, stmt):
         self.block_if.add_statement(stmt)
diff --git a/src/pairs/ir/cast.py b/src/pairs/ir/cast.py
index 061eb47..fe6674a 100644
--- a/src/pairs/ir/cast.py
+++ b/src/pairs/ir/cast.py
@@ -20,8 +20,5 @@ class Cast(ASTTerm):
     def type(self):
         return self.cast_type
 
-    def scope(self):
-        return self.expr.scope()
-
     def children(self):
         return [self.expr]
diff --git a/src/pairs/ir/functions.py b/src/pairs/ir/functions.py
index 8b01b90..0d30f49 100644
--- a/src/pairs/ir/functions.py
+++ b/src/pairs/ir/functions.py
@@ -22,6 +22,7 @@ class Call(ASTTerm):
     def children(self):
         return self.params
 
+
 class Call_Int(Call):
     def __init__(self, sim, func_name, parameters):
         super().__init__(sim, func_name, parameters, Type_Int)
diff --git a/src/pairs/ir/loops.py b/src/pairs/ir/loops.py
index 9f18739..099c5c4 100644
--- a/src/pairs/ir/loops.py
+++ b/src/pairs/ir/loops.py
@@ -27,9 +27,6 @@ class Iter(ASTTerm):
     def type(self):
         return Type_Int
 
-    def scope(self):
-        return self.loop.block
-
     def __eq__(self, other):
         if isinstance(other, Iter):
             return self.iter_id == other.iter_id
@@ -59,9 +56,9 @@ class For(ASTNode):
 
     def __iter__(self):
         self.sim.add_statement(self)
-        self.sim.enter_scope(self)
+        self.sim.enter(self)
         yield self.iterator
-        self.sim.leave_scope()
+        self.sim.leave()
 
     def add_statement(self, stmt):
         self.block.add_statement(stmt)
@@ -90,9 +87,9 @@ class While(ASTNode):
 
     def __iter__(self):
         self.sim.add_statement(self)
-        self.sim.enter_scope(self)
+        self.sim.enter(self)
         yield
-        self.sim.leave_scope()
+        self.sim.leave()
 
     def add_statement(self, stmt):
         self.block.add_statement(stmt)
diff --git a/src/pairs/ir/math.py b/src/pairs/ir/math.py
index 9ef762f..09e4171 100644
--- a/src/pairs/ir/math.py
+++ b/src/pairs/ir/math.py
@@ -13,9 +13,6 @@ class Sqrt(ASTTerm):
     def type(self):
         return self.expr.type()
 
-    def scope(self):
-        return self.expr.scope()
-
     def children(self):
         return [self.expr]
 
@@ -32,8 +29,5 @@ class Ceil(ASTTerm):
     def type(self):
         return Type_Int
 
-    def scope(self):
-        return self.expr.scope()
-
     def children(self):
         return [self.expr]
diff --git a/src/pairs/ir/module.py b/src/pairs/ir/module.py
new file mode 100644
index 0000000..900e565
--- /dev/null
+++ b/src/pairs/ir/module.py
@@ -0,0 +1,72 @@
+from pairs.ir.arrays import Array
+from pairs.ir.ast_node import ASTNode
+from pairs.ir.properties import Property
+from pairs.ir.variables import Var
+
+
+class Module(ASTNode):
+    last_module = 0
+
+    def __init__(self, sim, name=None, block=None):
+        super().__init__(sim)
+        self._name = name if name is not None else "module_" + str(Module.last_module)
+        self._variables = {}
+        self._arrays = set()
+        self._properties = set()
+        self._block = block
+        sim.add_module(self)
+        Module.last_module += 1
+
+    @property
+    def name(self):
+        return self._name
+
+    @property
+    def block(self):
+        return self._block
+
+    def variables(self):
+        return self._variables
+
+    def read_only_variables(self):
+        return [v for v in self._variables if not self._variables[v]]
+
+    def write_variables(self):
+        return [v for v in self._variables if self._variables[v]]
+
+    def arrays(self):
+        return self._arrays
+
+    def properties(self):
+        return self._properties
+
+    def add_array(self, array, write=False):
+        array_list = array if isinstance(array, list) else [array]
+        for a in array_list:
+            assert isinstance(a, Array), "Module.add_array(): given element is not of type Array!"
+            self._arrays.add(a)
+
+    def add_variable(self, variable, write=False):
+        variable_list = variable if isinstance(variable, list) else [variable]
+        for v in variable_list:
+            assert isinstance(v, Var), "Module.add_variable(): given element is not of type Var!"
+            if v not in self._variables:
+                self._variables[v] = write
+            else:
+                self._variables[v] = self._variables[v] or write
+
+    def add_property(self, prop, write=False):
+        prop_list = prop if isinstance(prop, list) else [prop]
+        for p in prop_list:
+            assert isinstance(p, Property), "Module.add_property(): given element is not of type Property!"
+            self._properties.add(p)
+
+    def children(self):
+        return [self._block]
+
+
+class Module_Call(ASTNode):
+    def __init__(self, sim, module):
+        assert isinstance(module, Module), "Module_Call(): given parameter is not of type Module!"
+        super().__init__(sim)
+        self.module = module
diff --git a/src/pairs/ir/mutator.py b/src/pairs/ir/mutator.py
index cb394cf..09c9f73 100644
--- a/src/pairs/ir/mutator.py
+++ b/src/pairs/ir/mutator.py
@@ -84,6 +84,10 @@ class Mutator:
         ast_node.size = self.mutate(ast_node.size)
         return ast_node
 
+    def mutate_Module(self, ast_node):
+        ast_node._block = self.mutate(ast_node._block)
+        return ast_node
+
     def mutate_Realloc(self, ast_node):
         ast_node.array = self.mutate(ast_node.array)
         ast_node.size = self.mutate(ast_node.size)
diff --git a/src/pairs/ir/properties.py b/src/pairs/ir/properties.py
index c17c78a..d23a154 100644
--- a/src/pairs/ir/properties.py
+++ b/src/pairs/ir/properties.py
@@ -84,9 +84,6 @@ class Property(ASTNode):
     def sizes(self):
         return [self.sim.particle_capacity] if self.prop_type != Type_Vector else [self.sim.ndims(), self.sim.particle_capacity]
 
-    def scope(self):
-        return self.sim.global_scope
-
     def __getitem__(self, expr):
         return PropertyAccess(self.sim, self, expr)
 
@@ -144,9 +141,6 @@ class PropertyAccess(ASTTerm, VectorExpression):
     def add_terminal(self, terminal):
         self.terminals.add(terminal)
 
-    def scope(self):
-        return self.index.scope()
-
     def children(self):
         return [self.prop, self.index] + list(super().children())
 
diff --git a/src/pairs/ir/variables.py b/src/pairs/ir/variables.py
index 102dc2d..2990e91 100644
--- a/src/pairs/ir/variables.py
+++ b/src/pairs/ir/variables.py
@@ -67,3 +67,19 @@ class VarDecl(ASTNode):
         super().__init__(sim)
         self.var = var
         self.sim.add_statement(self)
+
+
+class Deref(ASTTerm):
+    def __init__(self, sim, var):
+        super().__init__(sim)
+        self._var = var
+
+    def __str__(self):
+        return f"Deref<var: {self.var.name()}>"
+
+    @property
+    def var(self):
+        return self._var
+
+    def type(self):
+        return self._var.type()
diff --git a/src/pairs/sim/interaction.py b/src/pairs/sim/interaction.py
index 9658ee7..074f416 100644
--- a/src/pairs/sim/interaction.py
+++ b/src/pairs/sim/interaction.py
@@ -57,9 +57,9 @@ class ParticleInteraction(Lowerable):
 
     def __iter__(self):
         self.sim.add_statement(self)
-        self.sim.enter_scope(self)
+        self.sim.enter(self)
         yield self.i, self.j
-        self.sim.leave_scope()
+        self.sim.leave()
 
     @pairs_block
     def lower(self):
diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py
index db46e6a..70f11b0 100644
--- a/src/pairs/sim/simulation.py
+++ b/src/pairs/sim/simulation.py
@@ -3,6 +3,7 @@ from pairs.ir.block import Block, KernelBlock
 from pairs.ir.branches import Filter
 from pairs.ir.data_types import Type_Int, Type_Float, Type_Vector
 from pairs.ir.layouts import Layout_AoS
+from pairs.ir.module import Module
 from pairs.ir.properties import Properties
 from pairs.ir.symbols import Symbol
 from pairs.ir.variables import Variables
@@ -20,12 +21,14 @@ from pairs.sim.timestep import Timestep
 from pairs.sim.variables import VariablesDecl
 from pairs.sim.vtk import VTKWrite
 from pairs.transformations.add_device_copies import add_device_copies
+from pairs.transformations.fetch_modules_references import fetch_modules_references
 from pairs.transformations.prioritize_scalar_ops import prioritize_scalar_ops
 from pairs.transformations.set_used_bin_ops import set_used_bin_ops
 from pairs.transformations.simplify import simplify_expressions
 from pairs.transformations.LICM import move_loop_invariant_code
 from pairs.transformations.lower import lower_everything
 from pairs.transformations.merge_adjacent_blocks import merge_adjacent_blocks
+from pairs.transformations.replace_modules_by_calls import replace_modules_by_calls
 from pairs.transformations.replace_symbols import replace_symbols
 
 
@@ -33,7 +36,6 @@ class Simulation:
     def __init__(self, code_gen, dims=3, timesteps=100):
         self.code_gen = code_gen
         self.code_gen.assign_simulation(self)
-        self.global_scope = None
         self.position_prop = None
         self.properties = Properties(self)
         self.vars = Variables(self)
@@ -52,6 +54,7 @@ class Simulation:
         self.block = Block(self, [])
         self.setups = Block(self, [])
         self.kernels = Block(self, [])
+        self.module_list = []
         self.dims = dims
         self.ntimesteps = timesteps
         self.expr_id = 0
@@ -61,6 +64,13 @@ class Simulation:
         self.nparticles = self.nlocal + self.nghost
         self.properties.add_capacity(self.particle_capacity)
 
+    def add_module(self, module):
+        assert isinstance(module, Module), "add_module(): Given parameter is not of type Module!"
+        self.module_list.append(module)
+
+    def modules(self):
+        return self.module_list
+
     def ndims(self):
         return self.dims
 
@@ -157,7 +167,7 @@ class Simulation:
         self.block = Block(self, [])
 
     def build_kernel_block_with_statements(self):
-        self.kernels.add_statement(KernelBlock(self, self.block))
+        self.kernels.add_statement(Module(self, block=KernelBlock(self, self.block)))
 
     def add_statement(self, stmt):
         if not self.scope:
@@ -175,10 +185,10 @@ class Simulation:
         for _ in range(0, self.nested_count):
             self.scope.pop()
 
-    def enter_scope(self, scope):
+    def enter(self, scope):
         self.scope.append(scope)
 
-    def leave_scope(self):
+    def leave(self):
         if not self.nest:
             self.scope.pop()
         else:
@@ -212,8 +222,7 @@ class Simulation:
             PropertiesAlloc(self),
         ])
 
-        program = Block.merge_blocks(decls, body)
-        self.global_scope = program
+        program = Module(self, name='main', block=Block.merge_blocks(decls, body))
 
         # Transformations
         lower_everything(program)
@@ -222,8 +231,10 @@ class Simulation:
         prioritize_scalar_ops(program)
         simplify_expressions(program)
         move_loop_invariant_code(program)
+        fetch_modules_references(program)
         set_used_bin_ops(program)
         add_device_copies(program)
+        replace_modules_by_calls(program)
 
         # For this part on, all bin ops are generated without usage verification
         self.check_decl_usage = False
diff --git a/src/pairs/transformations/fetch_modules_references.py b/src/pairs/transformations/fetch_modules_references.py
new file mode 100644
index 0000000..534ddfb
--- /dev/null
+++ b/src/pairs/transformations/fetch_modules_references.py
@@ -0,0 +1,62 @@
+from pairs.ir.module import Module
+from pairs.ir.mutator import Mutator
+from pairs.ir.variables import Deref
+from pairs.ir.visitor import Visitor
+
+
+class FetchModulesReferences(Visitor):
+    def __init__(self, ast):
+        super().__init__(ast)
+        self.module_stack = []
+        self.writing = False
+
+    def visit_Assign(self, ast_node):
+        self.writing = True
+        for c in ast_node.destinations():
+            self.visit(c)
+
+        self.writing = False
+        for c in ast_node.sources():
+            self.visit(c)
+
+    def visit_Module(self, ast_node):
+        self.module_stack.append(ast_node)
+        self.visit_children(ast_node)
+        self.module_stack.pop()
+
+    def visit_Array(self, ast_node):
+        for m in self.module_stack:
+            m.add_array(ast_node)
+
+    def visit_Property(self, ast_node):
+        for m in self.module_stack:
+            m.add_property(ast_node)
+
+    def visit_Var(self, ast_node):
+        for m in self.module_stack:
+            m.add_variable(ast_node, self.writing)
+
+
+class AddDereferencesToWriteVariables(Mutator):
+    def __init__(self, ast):
+        super().__init__(ast)
+        self.module_stack = []
+
+    def mutate_Module(self, ast_node):
+        self.module_stack.append(ast_node)
+        ast_node._block = self.mutate(ast_node._block)
+        self.module_stack.pop()
+        return ast_node
+
+    def mutate_Var(self, ast_node):
+        if ast_node in self.module_stack[-1].write_variables():
+            return Deref(ast_node.sim, ast_node)
+
+        return ast_node
+
+
+def fetch_modules_references(ast):
+    fetch_refs = FetchModulesReferences(ast)
+    fetch_refs.visit()
+    add_derefs_to_write_vars = AddDereferencesToWriteVariables(ast)
+    add_derefs_to_write_vars.mutate()
diff --git a/src/pairs/transformations/replace_modules_by_calls.py b/src/pairs/transformations/replace_modules_by_calls.py
new file mode 100644
index 0000000..4b2e6c7
--- /dev/null
+++ b/src/pairs/transformations/replace_modules_by_calls.py
@@ -0,0 +1,15 @@
+from pairs.ir.module import Module_Call
+from pairs.ir.mutator import Mutator
+
+
+class ReplaceModulesByCalls(Mutator):
+    def __init__(self, ast):
+        super().__init__(ast)
+
+    def mutate_Module(self, ast_node):
+        return Module_Call(ast_node.sim, ast_node)
+
+
+def replace_modules_by_calls(ast):
+    replace = ReplaceModulesByCalls(ast)
+    replace.mutate()
-- 
GitLab