Skip to content
Snippets Groups Projects
Commit 3805b45e authored by Rafael Ravedutti's avatar Rafael Ravedutti
Browse files

Add variables declaration

parent ff03a25a
Branches
Tags
No related merge requests found
...@@ -8,20 +8,24 @@ class Variables: ...@@ -8,20 +8,24 @@ class Variables:
self.vars = [] self.vars = []
self.nvars = 0 self.nvars = 0
def add(self, v_name, v_type): def add(self, v_name, v_type, v_value=0):
v = Var(self.sim, v_name, v_type) v = Var(self.sim, v_name, v_type)
self.vars.append(v) self.vars.append(v)
return v return v
def all(self):
return self.vars
def find(self, v_name): def find(self, v_name):
return [v for v in self.vars if v.name() == v_name][0] return [v for v in self.vars if v.name() == v_name][0]
class Var: class Var:
def __init__(self, sim, var_name, var_type): def __init__(self, sim, var_name, var_type, init_value=0):
self.sim = sim self.sim = sim
self.var_name = var_name self.var_name = var_name
self.var_type = var_type self.var_type = var_type
self.var_init_value = init_value
def __str__(self): def __str__(self):
return f"Var<name: {self.var_name}, type: {self.var_type}>" return f"Var<name: {self.var_name}, type: {self.var_type}>"
...@@ -68,6 +72,9 @@ class Var: ...@@ -68,6 +72,9 @@ class Var:
def type(self): def type(self):
return self.var_type return self.var_type
def init_value(self):
return self.var_init_value
def scope(self): def scope(self):
return self.sim.global_scope return self.sim.global_scope
...@@ -79,3 +86,20 @@ class Var: ...@@ -79,3 +86,20 @@ class Var:
def transform(self, fn): def transform(self, fn):
return fn(self) return fn(self)
class VarDecl:
def __init__(self, sim, var):
self.sim = sim
self.var = var
self.sim.add_statement(self)
def children(self):
return []
def generate(self, mem=False):
self.sim.code_gen.generate_var_decl(
self.var.name(), self.var.type(), self.var.init_value())
def transform(self, fn):
return fn(self)
...@@ -111,3 +111,8 @@ class CGen: ...@@ -111,3 +111,8 @@ class CGen:
def generate_inline_expr(lhs, rhs, op): def generate_inline_expr(lhs, rhs, op):
return f"{lhs} {op} {rhs}" return f"{lhs} {op} {rhs}"
def generate_var_decl(v_name, v_type, v_value):
t = ('double' if v_type == Type_Float
else 'int' if v_type == Type_Int else 'bool')
printer.print(f"{t} {v_name} = {v_value};")
...@@ -2,7 +2,6 @@ from ast.branches import Branch, Filter ...@@ -2,7 +2,6 @@ from ast.branches import Branch, Filter
from ast.cast import Cast from ast.cast import Cast
from ast.data_types import Type_Int from ast.data_types import Type_Int
from ast.expr import Expr from ast.expr import Expr
from ast.layouts import Layout_SoA
from ast.loops import For, ParticleFor from ast.loops import For, ParticleFor
from functools import reduce from functools import reduce
from sim.resize import Resize from sim.resize import Resize
......
...@@ -13,6 +13,7 @@ from sim.lattice import ParticleLattice ...@@ -13,6 +13,7 @@ from sim.lattice import ParticleLattice
from sim.properties import PropertiesDecl, PropertiesResetVolatile from sim.properties import PropertiesDecl, PropertiesResetVolatile
from sim.setup_wrapper import SetupWrapper from sim.setup_wrapper import SetupWrapper
from sim.timestep import Timestep from sim.timestep import Timestep
from sim.variables import VariablesDecl
class ParticleSimulation: class ParticleSimulation:
...@@ -53,8 +54,8 @@ class ParticleSimulation: ...@@ -53,8 +54,8 @@ class ParticleSimulation:
def array(self, arr_name): def array(self, arr_name):
return self.arrays.find(arr_name) return self.arrays.find(arr_name)
def add_var(self, var_name, var_type): def add_var(self, var_name, var_type, init_value=0):
return self.vars.add(var_name, var_type) return self.vars.add(var_name, var_type, init_value)
def var(self, var_name): def var(self, var_name):
return self.vars.find(var_name) return self.vars.find(var_name)
...@@ -119,6 +120,7 @@ class ParticleSimulation: ...@@ -119,6 +120,7 @@ class ParticleSimulation:
def generate(self): def generate(self):
program = Block.from_list(self, [ program = Block.from_list(self, [
VariablesDecl(self).lower(),
PropertiesDecl(self).lower(), PropertiesDecl(self).lower(),
CellListsStencilBuild(self, self.cell_lists).lower(), CellListsStencilBuild(self, self.cell_lists).lower(),
self.setups.lower(), self.setups.lower(),
......
from ast.variables import VarDecl
class VariablesDecl:
def __init__(self, sim):
self.sim = sim
def lower(self):
self.sim.clear_block()
for v in self.sim.vars.all():
VarDecl(self.sim, v)
return self.sim.block
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment