From 232254d1a4f26b15fa8561e89653d1b4e3ab8dd3 Mon Sep 17 00:00:00 2001 From: Rafael Ravedutti <rafaelravedutti@gmail.com> Date: Fri, 6 Aug 2021 17:22:34 +0200 Subject: [PATCH] Fix results with runtime version Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com> --- code_gen/cgen.py | 10 ++++++++-- ir/cast.py | 4 ++-- ir/math.py | 23 +++++++++++++++++++++-- sim/cell_lists.py | 13 +++++-------- 4 files changed, 36 insertions(+), 14 deletions(-) diff --git a/code_gen/cgen.py b/code_gen/cgen.py index d29be3a..3a00f31 100644 --- a/code_gen/cgen.py +++ b/code_gen/cgen.py @@ -9,7 +9,7 @@ from ir.functions import Call from ir.layouts import Layout_AoS, Layout_SoA, Layout_Invalid from ir.lit import Lit from ir.loops import For, Iter, ParticleFor, While -from ir.math import Sqrt +from ir.math import Ceil, Sqrt from ir.memory import Malloc, Realloc from ir.properties import Property, PropertyList, RegisterProperty from ir.select import Select @@ -40,9 +40,10 @@ class CGen: def generate_program(self, ast_node): self.print.start() + self.print("#include <math.h>") + self.print("#include <stdbool.h>") self.print("#include <stdio.h>") self.print("#include <stdlib.h>") - self.print("#include <stdbool.h>") self.print("//---") self.print("#include \"runtime/pairs.hpp\"") self.print("#include \"runtime/read_from_file.hpp\"") @@ -247,6 +248,11 @@ class CGen: expr = self.generate_expression(ast_node.expr) return f"({tkw})({expr})" + if isinstance(ast_node, Ceil): + assert mem is False, "Ceil call is not lvalue!" + expr = self.generate_expression(ast_node.expr) + return f"ceil({expr})" + if isinstance(ast_node, Iter): assert mem is False, "Iterator is not lvalue!" return f"i{ast_node.id()}" diff --git a/ir/cast.py b/ir/cast.py index b85a596..7207b10 100644 --- a/ir/cast.py +++ b/ir/cast.py @@ -1,8 +1,8 @@ -from ir.ast_node import ASTNode +from ir.bin_op import ASTTerm from ir.data_types import Type_Int, Type_Float -class Cast(ASTNode): +class Cast(ASTTerm): def __init__(self, sim, expr, cast_type): super().__init__(sim) self.expr = expr diff --git a/ir/math.py b/ir/math.py index 034f4fa..93d6269 100644 --- a/ir/math.py +++ b/ir/math.py @@ -1,8 +1,8 @@ -from ir.ast_node import ASTNode +from ir.bin_op import ASTTerm from ir.data_types import Type_Int, Type_Float -class Sqrt(ASTNode): +class Sqrt(ASTTerm): def __init__(self, sim, expr, cast_type): super().__init__(sim) self.expr = expr @@ -18,3 +18,22 @@ class Sqrt(ASTNode): def children(self): return [self.expr] + + +class Ceil(ASTTerm): + def __init__(self, sim, expr): + assert expr.type() == Type_Float, "Expression must be of floating-point type!" + super().__init__(sim) + self.expr = expr + + def __str__(self): + return f"Ceil<expr: {self.expr}>" + + def type(self): + return Type_Int + + def scope(self): + return self.expr.scope() + + def children(self): + return [self.expr] diff --git a/sim/cell_lists.py b/sim/cell_lists.py index 76d7224..40ec518 100644 --- a/sim/cell_lists.py +++ b/sim/cell_lists.py @@ -2,6 +2,7 @@ from ir.bin_op import BinOp from ir.branches import Branch, Filter from ir.cast import Cast from ir.data_types import Type_Int +from ir.math import Ceil from ir.loops import For, ParticleFor from ir.utils import Print from functools import reduce @@ -13,13 +14,9 @@ class CellLists: def __init__(self, sim, grid, spacing, cutoff_radius): self.sim = sim self.grid = grid - self.spacing = spacing + self.spacing = spacing if isinstance(spacing, list) else [spacing for d in range(sim.ndims())] self.cutoff_radius = cutoff_radius - - self.nneighbor_cells = [ - math.ceil(cutoff_radius / (spacing if not isinstance(spacing, list) else spacing[d])) for d in range(sim.ndims()) - ] - + self.nneighbor_cells = [math.ceil(cutoff_radius / self.spacing[d]) for d in range(sim.ndims())] self.nstencil = self.sim.add_var('nstencil', Type_Int) self.nstencil_max = reduce((lambda x, y: x * y), [self.nneighbor_cells[d] * 2 + 1 for d in range(sim.ndims())]) self.ncells = self.sim.add_var('ncells', Type_Int, 1) @@ -46,7 +43,7 @@ class CellListsStencilBuild: cl.sim.add_statement(Print(cl.sim, "CellListsStencilBuild")) for d in range(cl.sim.ndims()): - cl.dim_ncells[d].set(Cast.int(cl.sim, (grid.max(d) - grid.min(d)) / cl.spacing)) + cl.dim_ncells[d].set(Ceil(cl.sim, (grid.max(d) - grid.min(d)) / cl.spacing[d]) + 2) nall *= cl.dim_ncells[d] cl.ncells.set(nall) @@ -84,7 +81,7 @@ class CellListsBuild: for i in ParticleFor(cl.sim, local_only=False): cell_index = [ - Cast.int(cl.sim, (positions[i][d] - grid.min(d)) / cl.spacing) + Cast.int(cl.sim, (positions[i][d] - grid.min(d)) / cl.spacing[d]) for d in range(0, cl.sim.ndims())] flat_idx = None -- GitLab