From e8fd8409cb7c887e2b39a003b0ee355964534abb Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafael.r.ravedutti@fau.de>
Date: Mon, 20 May 2024 20:13:42 +0200
Subject: [PATCH] Fix bug with analysis and transformations in modular version

Signed-off-by: Rafael Ravedutti <rafael.r.ravedutti@fau.de>
---
 CMakeLists.txt                        |  1 +
 examples/md.py                        |  7 ++++---
 src/pairs/analysis/__init__.py        |  9 ++++++---
 src/pairs/sim/simulation.py           | 12 +++++-------
 src/pairs/transformations/__init__.py | 19 ++++++++++++-------
 5 files changed, 28 insertions(+), 20 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index bd6c307..3fce218 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -48,6 +48,7 @@ set(RUNTIME_COMMON_FILES
     runtime/stats.cpp
     runtime/thermo.cpp
     runtime/timing.cpp
+    runtime/vtk.cpp
     runtime/domain/block_forest.cpp
     runtime/domain/regular_6d_stencil.cpp)
 
diff --git a/examples/md.py b/examples/md.py
index 828321c..d06d664 100644
--- a/examples/md.py
+++ b/examples/md.py
@@ -36,6 +36,7 @@ rho = 0.8442
 temp = 1.44
 
 psim = pairs.simulation("md", [pairs.point_mass()], timesteps=200, double_prec=True, debug=True)
+#psim = pairs.simulation("md", [pairs.point_mass()], timesteps=200, double_prec=True, debug=True, generate_whole_program=True)
 
 if target == 'gpu':
     psim.target(pairs.target_gpu())
@@ -51,14 +52,14 @@ psim.add_feature_property('type', 'epsilon', pairs.real(), [sigma for i in range
 psim.add_feature_property('type', 'sigma6', pairs.real(), [epsilon for i in range(ntypes * ntypes)])
 
 psim.copper_fcc_lattice(nx, ny, nz, rho, temp, ntypes)
-psim.set_domain_partitioner(pairs.block_forest())
-#psim.set_domain_partitioner(pairs.regular_domain_partitioner())
+#psim.set_domain_partitioner(pairs.block_forest())
+psim.set_domain_partitioner(pairs.regular_domain_partitioner())
 psim.compute_thermo(100)
 
 psim.reneighbor_every(20)
 #psim.compute_half()
 psim.build_neighbor_lists(cutoff_radius + skin)
-#psim.vtk_output(f"output/md_{target}")
+psim.vtk_output(f"output/md_{target}")
 
 psim.compute(initial_integrate, symbols={'dt': dt}, pre_step=True, skip_first=True)
 psim.compute(lennard_jones, cutoff_radius)
diff --git a/src/pairs/analysis/__init__.py b/src/pairs/analysis/__init__.py
index 99e7a80..ba2204c 100644
--- a/src/pairs/analysis/__init__.py
+++ b/src/pairs/analysis/__init__.py
@@ -9,13 +9,16 @@ class Analysis:
     """Compiler analysis performed on P4IRS"""
 
     def __init__(self, ast):
-        self._ast = ast
+        self._ast_list = ast if isinstance(ast, list) else [ast]
 
     def apply(self, analysis):
         print(f"Performing analysis: {type(analysis).__name__}... ", end="")
         start = time.time()
-        analysis.set_ast(self._ast)
-        analysis.visit()
+
+        for ast in self._ast_list:
+            analysis.set_ast(ast)
+            analysis.visit()
+
         elapsed = time.time() - start
         print(f"{elapsed:.2f}s elapsed.")
 
diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py
index 5992b4c..b0581e3 100644
--- a/src/pairs/sim/simulation.py
+++ b/src/pairs/sim/simulation.py
@@ -4,7 +4,7 @@ from pairs.ir.branches import Filter
 from pairs.ir.features import Features, FeatureProperties
 from pairs.ir.kernel import Kernel
 from pairs.ir.layouts import Layouts
-from pairs.ir.module import Module
+from pairs.ir.module import Module, ModuleCall
 from pairs.ir.properties import Properties, ContactProperties
 from pairs.ir.symbols import Symbol
 from pairs.ir.types import Types
@@ -164,7 +164,7 @@ class Simulation:
             self.module_list.append(module)
 
     def modules(self):
-        """List simulation modudles, with main always in the last position"""
+        """List simulation modules, with main always in the last position"""
 
         sorted_mods = []
         main_mod = None
@@ -536,12 +536,10 @@ class Simulation:
                 ]))
 
             initialize_module = Module(self, name='initialize', block=all_setups)
-            initialize_transformations = Transformations(initialize_module, self._target)
-            initialize_transformations.apply_all()
-
             do_timestep_module = Module(self, name='do_timestep', block=timestep.as_block())
-            do_timestep_transformations = Transformations(do_timestep_module, self._target)
-            do_timestep_transformations.apply_all()
+
+            transformations = Transformations([initialize_module, do_timestep_module], self._target)
+            transformations.apply_all()
 
             # Generate library
             self.code_gen.generate_library(initialize_module, do_timestep_module)
diff --git a/src/pairs/transformations/__init__.py b/src/pairs/transformations/__init__.py
index 7d5cab5..e788f35 100644
--- a/src/pairs/transformations/__init__.py
+++ b/src/pairs/transformations/__init__.py
@@ -10,24 +10,29 @@ from pairs.transformations.modules import DereferenceWriteVariables, AddResizeLo
 
 
 class Transformations:
-    def __init__(self, ast, target):
-        self._ast = ast
+    def __init__(self, ast_list, target):
+        self._ast_list = ast_list if isinstance(ast_list, list) else [ast_list]
         self._target = target
         self._module_resizes = None
 
     def apply(self, transformation, data=None):
         print(f"Applying transformation: {type(transformation).__name__}... ", end="")
         start = time.time()
-        transformation.set_ast(self._ast)
-        if data is not None:
-            transformation.set_data(data)
 
-        self._ast = transformation.mutate()
+        new_ast_list = []
+        for ast in self._ast_list:
+            transformation.set_ast(ast)
+            if data is not None:
+                transformation.set_data(data)
+
+            new_ast_list.append(transformation.mutate())
+
+        self._ast_list = new_ast_list
         elapsed = time.time() - start
         print(f"{elapsed:.2f}s elapsed.")
 
     def analysis(self):
-        return Analysis(self._ast)
+        return Analysis(self._ast_list)
 
     def lower(self, lower_finals=False):
         nlowered = 1
-- 
GitLab