From 232254d1a4f26b15fa8561e89653d1b4e3ab8dd3 Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Fri, 6 Aug 2021 17:22:34 +0200
Subject: [PATCH] Fix results with runtime version

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 code_gen/cgen.py  | 10 ++++++++--
 ir/cast.py        |  4 ++--
 ir/math.py        | 23 +++++++++++++++++++++--
 sim/cell_lists.py | 13 +++++--------
 4 files changed, 36 insertions(+), 14 deletions(-)

diff --git a/code_gen/cgen.py b/code_gen/cgen.py
index d29be3a..3a00f31 100644
--- a/code_gen/cgen.py
+++ b/code_gen/cgen.py
@@ -9,7 +9,7 @@ from ir.functions import Call
 from ir.layouts import Layout_AoS, Layout_SoA, Layout_Invalid
 from ir.lit import Lit
 from ir.loops import For, Iter, ParticleFor, While
-from ir.math import Sqrt
+from ir.math import Ceil, Sqrt
 from ir.memory import Malloc, Realloc
 from ir.properties import Property, PropertyList, RegisterProperty
 from ir.select import Select
@@ -40,9 +40,10 @@ class CGen:
 
     def generate_program(self, ast_node):
         self.print.start()
+        self.print("#include <math.h>")
+        self.print("#include <stdbool.h>")
         self.print("#include <stdio.h>")
         self.print("#include <stdlib.h>")
-        self.print("#include <stdbool.h>")
         self.print("//---")
         self.print("#include \"runtime/pairs.hpp\"")
         self.print("#include \"runtime/read_from_file.hpp\"")
@@ -247,6 +248,11 @@ class CGen:
             expr = self.generate_expression(ast_node.expr)
             return f"({tkw})({expr})"
 
+        if isinstance(ast_node, Ceil):
+            assert mem is False, "Ceil call is not lvalue!"
+            expr = self.generate_expression(ast_node.expr)
+            return f"ceil({expr})"
+
         if isinstance(ast_node, Iter):
             assert mem is False, "Iterator is not lvalue!"
             return f"i{ast_node.id()}"
diff --git a/ir/cast.py b/ir/cast.py
index b85a596..7207b10 100644
--- a/ir/cast.py
+++ b/ir/cast.py
@@ -1,8 +1,8 @@
-from ir.ast_node import ASTNode
+from ir.bin_op import ASTTerm
 from ir.data_types import Type_Int, Type_Float
 
 
-class Cast(ASTNode):
+class Cast(ASTTerm):
     def __init__(self, sim, expr, cast_type):
         super().__init__(sim)
         self.expr = expr
diff --git a/ir/math.py b/ir/math.py
index 034f4fa..93d6269 100644
--- a/ir/math.py
+++ b/ir/math.py
@@ -1,8 +1,8 @@
-from ir.ast_node import ASTNode
+from ir.bin_op import ASTTerm
 from ir.data_types import Type_Int, Type_Float
 
 
-class Sqrt(ASTNode):
+class Sqrt(ASTTerm):
     def __init__(self, sim, expr, cast_type):
         super().__init__(sim)
         self.expr = expr
@@ -18,3 +18,22 @@ class Sqrt(ASTNode):
 
     def children(self):
         return [self.expr]
+
+
+class Ceil(ASTTerm):
+    def __init__(self, sim, expr):
+        assert expr.type() == Type_Float, "Expression must be of floating-point type!"
+        super().__init__(sim)
+        self.expr = expr
+
+    def __str__(self):
+        return f"Ceil<expr: {self.expr}>"
+
+    def type(self):
+        return Type_Int
+
+    def scope(self):
+        return self.expr.scope()
+
+    def children(self):
+        return [self.expr]
diff --git a/sim/cell_lists.py b/sim/cell_lists.py
index 76d7224..40ec518 100644
--- a/sim/cell_lists.py
+++ b/sim/cell_lists.py
@@ -2,6 +2,7 @@ from ir.bin_op import BinOp
 from ir.branches import Branch, Filter
 from ir.cast import Cast
 from ir.data_types import Type_Int
+from ir.math import Ceil
 from ir.loops import For, ParticleFor
 from ir.utils import Print
 from functools import reduce
@@ -13,13 +14,9 @@ class CellLists:
     def __init__(self, sim, grid, spacing, cutoff_radius):
         self.sim = sim
         self.grid = grid
-        self.spacing = spacing
+        self.spacing = spacing if isinstance(spacing, list) else [spacing for d in range(sim.ndims())]
         self.cutoff_radius = cutoff_radius
-
-        self.nneighbor_cells = [
-            math.ceil(cutoff_radius / (spacing if not isinstance(spacing, list) else spacing[d])) for d in range(sim.ndims())
-        ]
-
+        self.nneighbor_cells = [math.ceil(cutoff_radius / self.spacing[d]) for d in range(sim.ndims())]
         self.nstencil = self.sim.add_var('nstencil', Type_Int)
         self.nstencil_max = reduce((lambda x, y: x * y), [self.nneighbor_cells[d] * 2 + 1 for d in range(sim.ndims())])
         self.ncells = self.sim.add_var('ncells', Type_Int, 1)
@@ -46,7 +43,7 @@ class CellListsStencilBuild:
         cl.sim.add_statement(Print(cl.sim, "CellListsStencilBuild"))
 
         for d in range(cl.sim.ndims()):
-            cl.dim_ncells[d].set(Cast.int(cl.sim, (grid.max(d) - grid.min(d)) / cl.spacing))
+            cl.dim_ncells[d].set(Ceil(cl.sim, (grid.max(d) - grid.min(d)) / cl.spacing[d]) + 2)
             nall *= cl.dim_ncells[d]
 
         cl.ncells.set(nall)
@@ -84,7 +81,7 @@ class CellListsBuild:
 
             for i in ParticleFor(cl.sim, local_only=False):
                 cell_index = [
-                    Cast.int(cl.sim, (positions[i][d] - grid.min(d)) / cl.spacing)
+                    Cast.int(cl.sim, (positions[i][d] - grid.min(d)) / cl.spacing[d])
                     for d in range(0, cl.sim.ndims())]
 
                 flat_idx = None
-- 
GitLab