Skip to content
Snippets Groups Projects
Commit 232254d1 authored by Rafael Ravedutti's avatar Rafael Ravedutti
Browse files

Fix results with runtime version

parent ae45fc9b
Branches
Tags
No related merge requests found
......@@ -9,7 +9,7 @@ from ir.functions import Call
from ir.layouts import Layout_AoS, Layout_SoA, Layout_Invalid
from ir.lit import Lit
from ir.loops import For, Iter, ParticleFor, While
from ir.math import Sqrt
from ir.math import Ceil, Sqrt
from ir.memory import Malloc, Realloc
from ir.properties import Property, PropertyList, RegisterProperty
from ir.select import Select
......@@ -40,9 +40,10 @@ class CGen:
def generate_program(self, ast_node):
self.print.start()
self.print("#include <math.h>")
self.print("#include <stdbool.h>")
self.print("#include <stdio.h>")
self.print("#include <stdlib.h>")
self.print("#include <stdbool.h>")
self.print("//---")
self.print("#include \"runtime/pairs.hpp\"")
self.print("#include \"runtime/read_from_file.hpp\"")
......@@ -247,6 +248,11 @@ class CGen:
expr = self.generate_expression(ast_node.expr)
return f"({tkw})({expr})"
if isinstance(ast_node, Ceil):
assert mem is False, "Ceil call is not lvalue!"
expr = self.generate_expression(ast_node.expr)
return f"ceil({expr})"
if isinstance(ast_node, Iter):
assert mem is False, "Iterator is not lvalue!"
return f"i{ast_node.id()}"
......
from ir.ast_node import ASTNode
from ir.bin_op import ASTTerm
from ir.data_types import Type_Int, Type_Float
class Cast(ASTNode):
class Cast(ASTTerm):
def __init__(self, sim, expr, cast_type):
super().__init__(sim)
self.expr = expr
......
from ir.ast_node import ASTNode
from ir.bin_op import ASTTerm
from ir.data_types import Type_Int, Type_Float
class Sqrt(ASTNode):
class Sqrt(ASTTerm):
def __init__(self, sim, expr, cast_type):
super().__init__(sim)
self.expr = expr
......@@ -18,3 +18,22 @@ class Sqrt(ASTNode):
def children(self):
return [self.expr]
class Ceil(ASTTerm):
def __init__(self, sim, expr):
assert expr.type() == Type_Float, "Expression must be of floating-point type!"
super().__init__(sim)
self.expr = expr
def __str__(self):
return f"Ceil<expr: {self.expr}>"
def type(self):
return Type_Int
def scope(self):
return self.expr.scope()
def children(self):
return [self.expr]
......@@ -2,6 +2,7 @@ from ir.bin_op import BinOp
from ir.branches import Branch, Filter
from ir.cast import Cast
from ir.data_types import Type_Int
from ir.math import Ceil
from ir.loops import For, ParticleFor
from ir.utils import Print
from functools import reduce
......@@ -13,13 +14,9 @@ class CellLists:
def __init__(self, sim, grid, spacing, cutoff_radius):
self.sim = sim
self.grid = grid
self.spacing = spacing
self.spacing = spacing if isinstance(spacing, list) else [spacing for d in range(sim.ndims())]
self.cutoff_radius = cutoff_radius
self.nneighbor_cells = [
math.ceil(cutoff_radius / (spacing if not isinstance(spacing, list) else spacing[d])) for d in range(sim.ndims())
]
self.nneighbor_cells = [math.ceil(cutoff_radius / self.spacing[d]) for d in range(sim.ndims())]
self.nstencil = self.sim.add_var('nstencil', Type_Int)
self.nstencil_max = reduce((lambda x, y: x * y), [self.nneighbor_cells[d] * 2 + 1 for d in range(sim.ndims())])
self.ncells = self.sim.add_var('ncells', Type_Int, 1)
......@@ -46,7 +43,7 @@ class CellListsStencilBuild:
cl.sim.add_statement(Print(cl.sim, "CellListsStencilBuild"))
for d in range(cl.sim.ndims()):
cl.dim_ncells[d].set(Cast.int(cl.sim, (grid.max(d) - grid.min(d)) / cl.spacing))
cl.dim_ncells[d].set(Ceil(cl.sim, (grid.max(d) - grid.min(d)) / cl.spacing[d]) + 2)
nall *= cl.dim_ncells[d]
cl.ncells.set(nall)
......@@ -84,7 +81,7 @@ class CellListsBuild:
for i in ParticleFor(cl.sim, local_only=False):
cell_index = [
Cast.int(cl.sim, (positions[i][d] - grid.min(d)) / cl.spacing)
Cast.int(cl.sim, (positions[i][d] - grid.min(d)) / cl.spacing[d])
for d in range(0, cl.sim.ndims())]
flat_idx = None
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment