diff --git a/ast/arrays.py b/ast/arrays.py index c8019352fa1b7d6e5f1be5f3027373dcd91f461a..56f2f03a1b1081fc0f594e87bc689d5c7011d617 100644 --- a/ast/arrays.py +++ b/ast/arrays.py @@ -80,9 +80,6 @@ class Array: def children(self): return [] - def generate(self, mem=False, index=None): - return self.arr_name - def transform(self, fn): return fn(self) @@ -110,7 +107,7 @@ class ArrayND(Array): f"type: {self.arr_type}>") def realloc(self): - return Realloc(self.sim, self, self.arr_type, self.alloc_size()) + return Realloc(self.sim, self, self.alloc_size()) class ArrayAccess: @@ -174,6 +171,9 @@ class ArrayAccess: def add(self, other): return self.sim.add_statement(Assign(self.sim, self, self + other)) + def id(self): + return self.acc_id + def type(self): return self.array.type() # return self.array.type() if self.index is None else Type_Array @@ -196,15 +196,6 @@ class ArrayAccess: def children(self): return [self.array] + self.indexes - def generate(self, mem=False, index=None): - agen = self.array.generate() - igen = self.index.generate() - if mem is False and self.generated is False: - self.sim.code_gen.generate_array_access(self.acc_id, self.array.type(), agen, igen) - self.generated = True - - return self.sim.code_gen.generate_array_access_ref(self.acc_id, agen, igen, mem) - def transform(self, fn): self.array = self.array.transform(fn) self.indexes = [i.transform(fn) for i in self.indexes] @@ -220,8 +211,5 @@ class ArrayDecl: def children(self): return [] - def generate(self, mem=False, index=None): - self.sim.code_gen.generate_array_decl(self.array.name(), self.array.type(), BinOp.inline(self.array.alloc_size()).generate()) - def transform(self, fn): return fn(self) diff --git a/ast/assign.py b/ast/assign.py index 98b78b16c03da1a8884d3ed4f10c4da97609ba0a..59f11e4e35c7bcfaaf4c1aa2a1d0d6f074068ae4 100644 --- a/ast/assign.py +++ b/ast/assign.py @@ -8,7 +8,6 @@ class Assign: self.sim = sim self.parent_block = None self.type = dest.type() - self.generated = False src = as_lit_ast(sim, src) if dest.type() == Type_Vector: @@ -32,15 +31,6 @@ class Assign: [self.assignments[i][0], self.assignments[i][1]] for i in range(0, len(self.assignments))]) - def generate(self): - if self.generated is False: - for dest, src in self.assignments: - d = dest.generate(True) - s = src.generate() - self.sim.code_gen.generate_assignment(d, s) - - self.generated = True - def transform(self, fn): self.assignments = [( self.assignments[i][0].transform(fn), diff --git a/ast/block.py b/ast/block.py index de9e32bd9967ce26244aa4f3dc2f448ad5a85726..4ed52ee9c8c38fac3f188683476dfb8fd75a54ad 100644 --- a/ast/block.py +++ b/ast/block.py @@ -50,17 +50,6 @@ class Block: def children(self): return self.stmts - def generate(self): - self.sim.code_gen.generate_block_preamble() - - for expr in self.expressions: - expr.generate() - - for stmt in self.stmts: - stmt.generate() - - self.sim.code_gen.generate_block_epilogue() - def transform(self, fn): for i in range(0, len(self.stmts)): self.stmts[i] = self.stmts[i].transform(fn) diff --git a/ast/branches.py b/ast/branches.py index aa9dcd5180fece375f7141b43fc29d58160bc8ab..a3ff5f1edf5f079b102c647e6e254567a98fe69f 100644 --- a/ast/branches.py +++ b/ast/branches.py @@ -36,17 +36,6 @@ class Branch: return [self.cond, self.block_if] + \ ([] if self.block_else is None else [self.block_else]) - def generate(self): - cond_gen = self.cond.generate() - self.sim.code_gen.generate_if(cond_gen) - self.block_if.generate() - - if self.block_else is not None: - self.sim.code_gen.generate_else() - self.block_else.generate() - - self.sim.code_gen.generate_endif() - def transform(self, fn): self.cond = self.cond.transform(fn) self.block_if = self.block_if.transform(fn) diff --git a/ast/cast.py b/ast/cast.py index aa101806675efb33e99c8daffd3fbcd31a0ce858..8036676441d3e652af15a3e6caf7b2ddb88b520e 100644 --- a/ast/cast.py +++ b/ast/cast.py @@ -28,9 +28,6 @@ class Cast: def children(self): return [self.expr] - def generate(self, mem=False, index=None): - return self.sim.code_gen.generate_cast(self.cast_type, self.expr.generate()) - def transform(self, fn): self.expr = self.expr.transform(fn) return fn(self) diff --git a/ast/expr.py b/ast/expr.py index 723759f9c2bd65fbfcf09b43f9b0ac642d39a2e2..371ca0e909305fdd7200783cd4bec36ab0905380 100644 --- a/ast/expr.py +++ b/ast/expr.py @@ -15,29 +15,6 @@ class BinOpDef: def children(self): return [self.bin_op] - def generate(self, mem=False): - bin_op = self.bin_op - - if not isinstance(bin_op, BinOp): - return None - - if bin_op.inlined is False and bin_op.operator() != '[]' and bin_op.generated is False: - if bin_op.kind() == BinOp.Kind_Scalar: - lhs = bin_op.lhs.generate(bin_op.mem) - rhs = bin_op.rhs.generate() - bin_op.sim.code_gen.generate_expr(bin_op.id(), bin_op.type(), lhs, rhs, bin_op.op) - - elif bin_op.kind() == BinOp.Kind_Vector: - for i in bin_op.vector_indexes(): - lhs = bin_op.lhs.generate(bin_op.mem, index=i) - rhs = bin_op.rhs.generate(index=i) - bin_op.sim.code_gen.generate_vec_expr(bin_op.id(), i, lhs, rhs, bin_op.operator(), bin_op.mem) - - else: - raise Exception("Invalid BinOp kind!") - - bin_op.generated = True - def transform(self, fn): self.bin_op = self.bin_op.transform(fn) return fn(self) @@ -224,6 +201,9 @@ class BinOp: def type(self): return self.bin_op_type + def definition(self): + return self.bin_op_def + def operator(self): return self.op @@ -244,32 +224,6 @@ class BinOp: def children(self): return [self.lhs, self.rhs] - def generate(self, mem=False, index=None): - if isinstance(self.lhs, BinOp) and self.lhs.kind() == BinOp.Kind_Vector and self.op == '[]': - return self.lhs.generate(self.mem, self.rhs.generate()) - - lhs = self.lhs.generate(mem, index) - rhs = self.rhs.generate(index=index) - - if self.op == '[]': - idx = self.mapped_vector_index(index).generate() if self.is_vector_property_access() else rhs - return self.sim.code_gen.generate_expr_access(lhs, idx, self.mem) - - if self.inlined is True: - assert self.bin_op_type != Type_Vector, "Vector operations cannot be inlined!" - return self.sim.code_gen.generate_inline_expr(lhs, rhs, self.op) - - # Some expressions can be defined on-the-fly during transformations, hence they do not have - # a definition statement, so we generate them right before usage - if not self.generated: - self.bin_op_def.generate() - - if self.kind() == BinOp.Kind_Vector: - assert index is not None, "Index must be set for vector reference!" - return self.sim.code_gen.generate_vec_expr_ref(self.id(), index, self.mem) - - return self.sim.code_gen.generate_expr_ref(self.id()) - def transform(self, fn): self.lhs = self.lhs.transform(fn) self.rhs = self.rhs.transform(fn) diff --git a/ast/lit.py b/ast/lit.py index 58b25a7c2a0dbdbab12bf8f00eab394f81f45218..b04a969b020905af961b12ed6255f5494c24f3f4 100644 --- a/ast/lit.py +++ b/ast/lit.py @@ -57,9 +57,5 @@ class Lit: def children(self): return [] - def generate(self, mem=False, index=None): - assert mem is False, "Literal is not lvalue!" - return self.value - def transform(self, fn): return fn(self) diff --git a/ast/loops.py b/ast/loops.py index bd81620473b8ea31648c32f5de0ec04906261633..96ccca0c04fa175741979bb3e02a226569e20ef5 100644 --- a/ast/loops.py +++ b/ast/loops.py @@ -17,6 +17,9 @@ class Iter(): self.loop = loop self.iter_id = Iter.new_id() + def id(self): + return self.iter_id + def type(self): return Type_Int @@ -62,10 +65,6 @@ class Iter(): def children(self): return [] - def generate(self, mem=False, index=None): - assert mem is False, "Iterator is not lvalue!" - return f"i{self.iter_id}" - def transform(self, fn): return fn(self) @@ -97,14 +96,6 @@ class For(): def children(self): return [self.iterator, self.block] - def generate(self): - it_id = self.iterator.generate() - rmin = self.min.generate() - rmax = self.max.generate() - self.sim.code_gen.generate_for_preamble(it_id, rmin, rmax) - self.block.generate() - self.sim.code_gen.generate_for_epilogue() - def transform(self, fn): self.iterator = self.iterator.transform(fn) self.block = self.block.transform(fn) @@ -119,12 +110,6 @@ class ParticleFor(For): def __str__(self): return f"ParticleFor<>" - def generate(self): - upper_range = self.sim.nlocal if self.local_only else self.sim.nlocal + self.sim.pbc.npbc - self.sim.code_gen.generate_for_preamble(self.iterator.generate(), 0, upper_range.generate()) - self.block.generate() - self.sim.code_gen.generate_for_epilogue() - class While(): def __init__(self, sim, cond, block=None): @@ -149,11 +134,6 @@ class While(): def children(self): return [self.cond, self.block] - def generate(self): - self.sim.code_gen.generate_while_preamble(self.cond.generate()) - self.block.generate() - self.sim.code_gen.generate_while_epilogue() - def transform(self, fn): self.cond = self.cond.transform(fn) self.block = self.block.transform(fn) diff --git a/ast/math.py b/ast/math.py index 656543c97c4b200eef0568f2014a43087f33495a..cc38fb84f6bfc18358259194cefeabdc1d94e65a 100644 --- a/ast/math.py +++ b/ast/math.py @@ -21,9 +21,6 @@ class Sqrt: def children(self): return [self.expr] - def generate(self, mem=False, index=None): - return self.sim.code_gen.generate_sqrt(self.expr.generate()) - def transform(self, fn): self.expr = self.expr.transform(fn) return fn(self) diff --git a/ast/memory.py b/ast/memory.py index 205bd445033f83276ecff3c87fa459424be1aaa9..bb26d4c37146a25605f3a8f6e3eaa165a926edab 100644 --- a/ast/memory.py +++ b/ast/memory.py @@ -5,22 +5,18 @@ import operator class Malloc: - def __init__(self, sim, array, a_type, sizes, decl=False): + def __init__(self, sim, array, sizes, decl=False): self.sim = sim self.parent_block = None self.array = array - self.array_type = a_type self.decl = decl - self.prim_size = Sizeof(sim, a_type) + self.prim_size = Sizeof(sim, array.type()) self.size = BinOp.inline(self.prim_size * (reduce(operator.mul, sizes) if isinstance(sizes, list) else sizes)) self.sim.add_statement(self) def children(self): return [self.array, self.size] - def generate(self, mem=False, index=None): - self.sim.code_gen.generate_malloc(self.array.generate(), self.array_type, self.size.generate(), self.decl) - def transform(self, fn): self.array = self.array.transform(fn) self.size = self.size.transform(fn) @@ -28,21 +24,17 @@ class Malloc: class Realloc: - def __init__(self, sim, array, a_type, size): + def __init__(self, sim, array, size): self.sim = sim self.parent_block = None self.array = array - self.array_type = a_type - self.prim_size = Sizeof(sim, a_type) + self.prim_size = Sizeof(sim, array.type()) self.size = BinOp.inline(self.prim_size * size) self.sim.add_statement(self) def children(self): return [self.array, self.size] - def generate(self, mem=False, index=None): - self.sim.code_gen.generate_realloc(self.array.generate(), self.array_type, self.size.generate()) - def transform(self, fn): self.array = self.array.transform(fn) self.size = self.size.transform(fn) diff --git a/ast/properties.py b/ast/properties.py index 0c6b2a44bb600e1b8ba22d3345a3234c21bd5710..bbb6465a5d550ca66d60246ba69088186d8ac631 100644 --- a/ast/properties.py +++ b/ast/properties.py @@ -77,8 +77,5 @@ class Property: def children(self): return [] - def generate(self, mem=False, index=None): - return self.prop_name - def transform(self, fn): return fn(self) diff --git a/ast/select.py b/ast/select.py index 8ba489c307459b8eb5a7696f875d34597d967934..bdb154660cab9456f9573ed6bc977b19dadef41b 100644 --- a/ast/select.py +++ b/ast/select.py @@ -12,9 +12,6 @@ class Select: def children(self): return [self.cond, self.expr_if, self.expr_else] - def generate(self): - return self.sim.code_gen.generate_select(self.cond.generate(), self.expr_if.generate(), self.expr_else.generate()) - def transform(self, fn): self.cond = self.cond.transform(fn) self.expr_if = self.expr_if.transform(fn) diff --git a/ast/sizeof.py b/ast/sizeof.py index ee2598771d10a9a8efb00f8ba2f4b1e33441240f..e18f89ef312c99d76148881654e8fe65b2aeb231 100644 --- a/ast/sizeof.py +++ b/ast/sizeof.py @@ -22,8 +22,5 @@ class Sizeof: def children(self): return [] - def generate(self, mem=False, index=None): - return self.sim.code_gen.generate_sizeof(self.data_type) - def transform(self, fn): return fn(self) diff --git a/ast/utils.py b/ast/utils.py index 55d630b88023c1e8ac9f6f4bd98b1be9c30c4b06..922dcbe8507b730090c5bb1f49fb51d434e6f86a 100644 --- a/ast/utils.py +++ b/ast/utils.py @@ -9,8 +9,5 @@ class Print: def children(self): return [] - def generate(self): - self.sim.code_gen.generate_print(self.string) - def transform(self, fn): return fn(self) diff --git a/ast/variables.py b/ast/variables.py index 9604b28cdc97a4fb88ad12fbaee2829af875c4b3..5f1271949fe2c82e361ae4fc4c8a2668d79683d7 100644 --- a/ast/variables.py +++ b/ast/variables.py @@ -104,9 +104,6 @@ class Var: def children(self): return [] - def generate(self, mem=False, index=None): - return self.var_name - def transform(self, fn): return fn(self) @@ -120,9 +117,5 @@ class VarDecl: def children(self): return [] - def generate(self, mem=False, index=None): - self.sim.code_gen.generate_var_decl( - self.var.name(), self.var.type(), self.var.init_value()) - def transform(self, fn): return fn(self) diff --git a/code_gen/cgen.py b/code_gen/cgen.py index 6482e91c49f9c4dd3c49de192dd607fe53d8d05b..96e0d9d132fbab3539d133f0250157b0f753b6d2 100644 --- a/code_gen/cgen.py +++ b/code_gen/cgen.py @@ -1,4 +1,21 @@ +from ast.assign import Assign +from ast.arrays import ArrayAccess, ArrayDecl +from ast.block import Block +from ast.branches import Branch +from ast.cast import Cast +from ast.expr import BinOp, BinOpDef from ast.data_types import Type_Int, Type_Float, Type_Vector +from ast.lit import Lit +from ast.loops import For, Iter, ParticleFor, While +from ast.math import Sqrt +from ast.memory import Malloc, Realloc +from ast.properties import Property +from ast.select import Select +from ast.sizeof import Sizeof +from ast.utils import Print +from ast.variables import Var, VarDecl +from sim.timestep import Timestep +from sim.vtk import VTKWrite from code_gen.printer import printer @@ -10,170 +27,261 @@ class CGen: else 'bool' ) - def generate_program_preamble(): + def generate_program(sim, ast_node): printer.print("#include <stdio.h>") printer.print("#include <stdlib.h>") printer.print("#include <stdbool.h>") printer.print("") printer.print("int main() {") - - def generate_program_epilogue(): - printer.print("}") - - def generate_block_preamble(): - printer.add_ind(4) - - def generate_block_epilogue(): - printer.add_ind(-4) - - def generate_cast(ctype, expr): - tkw = CGen.type2keyword(ctype) - return f"({tkw})({expr})" - - def generate_if(cond): - printer.print(f"if({cond}) {{") - - def generate_else(): - printer.print("} else {") - - def generate_endif(): - printer.print("}") - - def generate_assignment(dest, src): - printer.print(f"{dest} = {src};") - - def generate_array_decl(array, a_type, size): - tkw = CGen.type2keyword(a_type) - printer.print(f"{tkw} {array}[{size}];") - - def generate_array_access_ref(acc_id, array, index, mem=False): - if mem: - return f"{array}[{index}]" - - return f"a{acc_id}" - - def generate_array_access(acc_id, acc_type, array, index): - ref = CGen.generate_array_access_ref(acc_id, array, index) - tkw = CGen.type2keyword(acc_type) - acc = f"const {tkw} {ref} = {array}[{index}];" - printer.print(acc) - - def generate_malloc(array, a_type, size, decl): - tkw = CGen.type2keyword(a_type) - if decl: - printer.print(f"{tkw} *{array} = ({tkw} *) malloc({size});") - else: - printer.print(f"{array} = ({tkw} *) malloc({size});") - - def generate_realloc(array, a_type, size): - tkw = CGen.type2keyword(a_type) - printer.print(f"{array} = ({tkw} *) realloc({array}, {size});") - - def generate_sizeof(data_type): - tkw = CGen.type2keyword(data_type) - return f"sizeof({tkw})" - - def generate_for_preamble(iter_id, rmin, rmax): - printer.print(f"for(int {iter_id} = {rmin}; {iter_id} < {rmax}; {iter_id}++) {{") - - def generate_for_epilogue(): + CGen.generate_statement(sim, ast_node) printer.print("}") - def generate_while_preamble(cond): - printer.print(f"while({cond}) {{") - - def generate_while_epilogue(): - printer.print("}") - - def generate_expr_ref(expr_id): - return f"e{expr_id}" - - def generate_expr(expr_id, expr_type, lhs, rhs, op): - ref = CGen.generate_expr_ref(expr_id) - tkw = CGen.type2keyword(expr_type) - printer.print(f"const {tkw} {ref} = {lhs} {op} {rhs};") - - def generate_expr_access(lhs, rhs, mem): - return f"{lhs}[{rhs}]" if mem else f"{lhs}_{rhs}" + def generate_statement(sim, ast_node): + if isinstance(ast_node, ArrayDecl): + tkw = CGen.type2keyword(ast_node.array.type()) + size = CGen.generate_expression(sim, BinOp.inline(ast_node.array.alloc_size())) + printer.print(f"{tkw} {ast_node.array.name()}[{size}];") - def generate_vec_expr_ref(expr_id, index, mem): - return f"e{expr_id}[{index}]" if mem else f"e{expr_id}_{index}" + if isinstance(ast_node, Assign): + for assign_dest, assign_src in ast_node.assignments: + dest = CGen.generate_expression(sim, assign_dest, mem=True) + src = CGen.generate_expression(sim, assign_src) + printer.print(f"{dest} = {src};") + + if isinstance(ast_node, Block): + printer.add_ind(4) + + for stmt in ast_node.statements(): + CGen.generate_statement(sim, stmt) + + printer.add_ind(-4) + + if isinstance(ast_node, BinOpDef): + bin_op = ast_node.bin_op + + if not isinstance(bin_op, BinOp): + return None + + if bin_op.inlined is False and bin_op.operator() != '[]' and bin_op.generated is False: + if bin_op.kind() == BinOp.Kind_Scalar: + lhs = CGen.generate_expression(sim, bin_op.lhs, bin_op.mem) + rhs = CGen.generate_expression(sim, bin_op.rhs) + tkw = CGen.type2keyword(bin_op.type()) + printer.print(f"const {tkw} e{bin_op.id()} = {lhs} {bin_op.operator()} {rhs};") + + elif bin_op.kind() == BinOp.Kind_Vector: + for i in bin_op.vector_indexes(): + lhs = CGen.generate_expression(sim, bin_op.lhs, bin_op.mem, index=i) + rhs = CGen.generate_expression(sim, bin_op.rhs, index=i) + printer.print(f"const double e{bin_op.id()}_{i} = {lhs} {bin_op.operator()} {rhs};") + + else: + raise Exception("Invalid BinOp kind!") - def generate_vec_expr(expr_id, index, lhs, rhs, op, mem): - ref = CGen.generate_vec_expr_ref(expr_id, index, mem) - printer.print(f"const double {ref} = {lhs} {op} {rhs};") + bin_op.generated = True + + if isinstance(ast_node, Branch): + cond = CGen.generate_expression(sim, ast_node.cond) + printer.print(f"if({cond}) {{") + CGen.generate_statement(sim, ast_node.block_if) + + if ast_node.block_else is not None: + printer.print("} else {") + CGen.generate_statement(sim, ast_node.block_else) + + printer.print("}") + + if isinstance(ast_node, For): + iterator = CGen.generate_expression(sim, ast_node.iterator) + lower_range = None + upper_range = None + + if isinstance(ast_node, ParticleFor): + n = sim.nlocal if ast_node.local_only else sim.nlocal + sim.pbc.npbc + lower_range = 0 + upper_range = CGen.generate_expression(sim, n) + + else: + lower_range = CGen.generate_expression(sim, ast_node.min) + upper_range = CGen.generate_expression(sim, ast_node.max) + + printer.print(f"for(int {iterator} = {lower_range}; {iterator} < {upper_range}; {iterator}++) {{") + CGen.generate_statement(sim, ast_node.block) + printer.print("}") - def generate_inline_expr(lhs, rhs, op): - return f"({lhs} {op} {rhs})" - def generate_var_decl(v_name, v_type, v_value): - tkw = CGen.type2keyword(v_type) - printer.print(f"{tkw} {v_name} = {v_value};") - - def generate_sqrt(expr): - return f"sqrt({expr})" - - def generate_select(cond, expr_if, expr_else): - return f"({cond}) ? ({expr_if}) : ({expr_else})" - - def generate_print(string): - printer.print(f"fprintf(stdout, \"{string}\\n\");") - printer.print(f"fflush(stdout);") - - def generate_vtk_writing(id, filename, start, n, timestep): + if isinstance(ast_node, Malloc): + tkw = CGen.type2keyword(ast_node.array.type()) + size = CGen.generate_expression(sim, ast_node.size) + array_name = ast_node.array.name() + + if ast_node.decl: + printer.print(f"{tkw} *{array_name} = ({tkw} *) malloc({size});") + else: + printer.print(f"{array_name} = ({tkw} *) malloc({size});") + + if isinstance(ast_node, Print): + printer.print(f"fprintf(stdout, \"{ast_node.string}\\n\");") + printer.print(f"fflush(stdout);") + + if isinstance(ast_node, Realloc): + tkw = CGen.type2keyword(ast_node.array.type()) + size = CGen.generate_expression(sim, ast_node.size) + array_name = ast_node.array.name() + printer.print(f"{array_name} = ({tkw} *) realloc({array_name}, {size});") + + if isinstance(ast_node, Timestep): + CGen.generate_statement(sim, ast_node.block) + + if isinstance(ast_node, VarDecl): + tkw = CGen.type2keyword(ast_node.var.type()) + printer.print(f"{tkw} {ast_node.var.name()} = {ast_node.var.init_value()};") + + if isinstance(ast_node, VTKWrite): + nlocal = CGen.generate_expression(sim, sim.nlocal) + npbc = CGen.generate_expression(sim, sim.pbc.npbc) + nall = CGen.generate_expression(sim, sim.nlocal + sim.pbc.npbc) + timestep = CGen.generate_expression(sim, ast_node.timestep) + CGen.generate_vtk_writing(ast_node.vtk_id * 2, f"{ast_node.filename}_local", 0, nlocal, nlocal, timestep) + CGen.generate_vtk_writing(ast_node.vtk_id * 2 + 1, f"{ast_node.filename}_pbc", nlocal, nall, npbc, timestep) + + if isinstance(ast_node, While): + cond = CGen.generate_expression(sim, ast_node.cond) + printer.print(f"while({cond}) {{") + CGen.generate_statement(sim, ast_node.block) + printer.print("}") + + def generate_expression(sim, ast_node, mem=False, index=None): + if isinstance(ast_node, ArrayAccess): + index = CGen.generate_expression(sim, ast_node.index) + array_name = ast_node.array.name() + + if mem: + return f"{array_name}[{index}]" + + acc_ref = f"a{ast_node.id()}" + if ast_node.generated is False: + tkw = CGen.type2keyword(ast_node.type()) + printer.print(f"const {tkw} {acc_ref} = {array_name}[{index}];") + ast_node.generated = True + + return acc_ref + + if isinstance(ast_node, BinOp): + if isinstance(ast_node.lhs, BinOp) and ast_node.lhs.kind() == BinOp.Kind_Vector and ast_node.operator() == '[]': + return CGen.generate_expression(sim, ast_node.lhs, ast_node.mem, CGen.generate_expression(sim, ast_node.rhs)) + + lhs = CGen.generate_expression(sim, ast_node.lhs, mem, index) + rhs = CGen.generate_expression(sim, ast_node.rhs, index=index) + + if ast_node.operator() == '[]': + idx = CGen.generate_expression(sim, ast_node.mapped_vector_index(index)) if ast_node.is_vector_property_access() else rhs + return f"{lhs}[{idx}]" if ast_node.mem else f"{lhs}_{idx}" + + if ast_node.inlined is True: + assert ast_node.type() != Type_Vector, "Vector operations cannot be inlined!" + return f"({lhs} {ast_node.operator()} {rhs})" + + # Some expressions can be defined on-the-fly during transformations, hence they do not have + # a definition statement in the tree, so we generate them right before usage + if not ast_node.generated: + CGen.generate_statement(sim, ast_node.definition()) + + if ast_node.kind() == BinOp.Kind_Vector: + assert index is not None, "Index must be set for vector reference!" + return f"e{ast_node.id()}[{index}]" if ast_node.mem else f"e{ast_node.id()}_{index}" + + return f"e{ast_node.id()}" + + if isinstance(ast_node, Cast): + tkw = CGen.type2keyword(ast_node.cast_type) + expr = CGen.generate_expression(sim, ast_node.expr) + return f"({tkw})({expr})" + + if isinstance(ast_node, Iter): + assert mem is False, "Iterator is not lvalue!" + return f"i{ast_node.id()}" + + if isinstance(ast_node, Lit): + assert mem is False, "Literal is not lvalue!" + return ast_node.value + + if isinstance(ast_node, Property): + return ast_node.name() + + if isinstance(ast_node, Sizeof): + assert mem is False, "Sizeof expression is not lvalue!" + tkw = CGen.type2keyword(ast_node.data_type) + return f"sizeof({tkw})" + + if isinstance(ast_node, Sqrt): + assert mem is False, "Square root call is not lvalue!" + expr = CGen.generate_expression(sim, ast_node.expr) + return f"sqrt({expr})" + + if isinstance(ast_node, Select): + assert mem is False, "Select expression is not lvalue!" + cond = CGen.generate_expression(sim, ast_node.cond) + expr_if = CGen.generate_expression(sim, ast_node.expr_if) + expr_else = CGen.generate_expression(sim, ast_node.expr_else) + return f"({cond}) ? ({expr_if}) : ({expr_else})" + + if isinstance(ast_node, Var): + return ast_node.name() + + def generate_vtk_writing(id, filename, start, end, n, timestep): # TODO: Do this in a more elegant way, without hard coded stuff header = "# vtk DataFile Version 2.0\n" \ "Particle data\n" \ "ASCII\n" \ "DATASET UNSTRUCTURED_GRID\n" - end = start + n filename_var = f"filename{id}" filehandle_var = f"vtk{id}" printer.print(f"char {filename_var}[128];") - printer.print(f"snprintf({filename_var}, sizeof {filename_var}, \"{filename}_%d.vtk\", {timestep.generate()});") + printer.print(f"snprintf({filename_var}, sizeof {filename_var}, \"{filename}_%d.vtk\", {timestep});") printer.print(f"FILE *{filehandle_var} = fopen({filename_var}, \"w\");") for line in header.split('\n'): if len(line) > 0: printer.print(f"fwrite(\"{line}\\n\", 1, {len(line) + 1}, {filehandle_var});") # Write positions - printer.print(f"fprintf({filehandle_var}, \"POINTS %d double\\n\", {n.generate()});") - CGen.generate_for_preamble("i", start.generate(), end.generate()) + printer.print(f"fprintf({filehandle_var}, \"POINTS %d double\\n\", {n});") + printer.print(f"for(int i = {start}; i < {end}; i++) {{") printer.add_ind(4) printer.print(f"fprintf({filehandle_var}, \"%.4f %.4f %.4f\\n\", position[i * 3], position[i * 3 + 1], position[i * 3 + 2]);") printer.add_ind(-4) - CGen.generate_for_epilogue() + printer.print("}") printer.print(f"fwrite(\"\\n\\n\", 1, 2, {filehandle_var});") # Write cells - printer.print(f"fprintf({filehandle_var}, \"CELLS %d %d\\n\", {n.generate()}, {n.generate()} * 2);") - CGen.generate_for_preamble("i", start.generate(), end.generate()) + printer.print(f"fprintf({filehandle_var}, \"CELLS %d %d\\n\", {n}, {n} * 2);") + printer.print(f"for(int i = {start}; i < {end}; i++) {{") printer.add_ind(4) - printer.print(f"fprintf({filehandle_var}, \"1 %d\\n\", i - {start.generate()});") + printer.print(f"fprintf({filehandle_var}, \"1 %d\\n\", i - {start});") printer.add_ind(-4) - CGen.generate_for_epilogue() + printer.print("}") printer.print(f"fwrite(\"\\n\\n\", 1, 2, {filehandle_var});") # Write cell types - printer.print(f"fprintf({filehandle_var}, \"CELL_TYPES %d\\n\", {n.generate()});") - CGen.generate_for_preamble("i", start.generate(), end.generate()) + printer.print(f"fprintf({filehandle_var}, \"CELL_TYPES %d\\n\", {n});") + printer.print(f"for(int i = {start}; i < {end}; i++) {{") printer.add_ind(4) printer.print(f"fwrite(\"1\\n\", 1, 2, {filehandle_var});") printer.add_ind(-4) - CGen.generate_for_epilogue() + printer.print("}") printer.print(f"fwrite(\"\\n\\n\", 1, 2, {filehandle_var});") # Write masses - printer.print(f"fprintf({filehandle_var}, \"POINT_DATA %d\\n\", {n.generate()});") + printer.print(f"fprintf({filehandle_var}, \"POINT_DATA %d\\n\", {n});") printer.print(f"fprintf({filehandle_var}, \"SCALARS mass double\\n\");") printer.print(f"fprintf({filehandle_var}, \"LOOKUP_TABLE default\\n\");") - CGen.generate_for_preamble("i", start.generate(), end.generate()) + printer.print(f"for(int i = {start}; i < {end}; i++) {{") printer.add_ind(4) #printer.print(f"fprintf({filehandle_var}, \"%4.f\\n\", mass[i]);") printer.print(f"fprintf({filehandle_var}, \"1.0\\n\");") printer.add_ind(-4) - CGen.generate_for_epilogue() + printer.print("}") printer.print(f"fwrite(\"\\n\\n\", 1, 2, {filehandle_var});") printer.print(f"fclose({filehandle_var});") diff --git a/sim/arrays.py b/sim/arrays.py index 000ad3f9045913cb913889b7c4c4bd36f09160b3..0d019545305606cad2c318c6bd626fe8d79c411b 100644 --- a/sim/arrays.py +++ b/sim/arrays.py @@ -12,6 +12,6 @@ class ArraysDecl: if a.is_static(): ArrayDecl(self.sim, a) else: - Malloc(self.sim, a, a.type(), a.alloc_size(), True) + Malloc(self.sim, a, a.alloc_size(), True) return self.sim.block diff --git a/sim/particle_simulation.py b/sim/particle_simulation.py index 5d982910f7b4c5d4a6e32da6803a4a189b4fb732..ddabe68d080cfcc07b88b6eb4046e12f11465953 100644 --- a/sim/particle_simulation.py +++ b/sim/particle_simulation.py @@ -200,6 +200,5 @@ class ParticleSimulation: #Transform.apply(program, Transform.reuse_array_access_expressions) #Transform.apply(program, Transform.move_loop_invariant_expressions) - self.code_gen.generate_program_preamble() - program.generate() - self.code_gen.generate_program_epilogue() + self.code_gen.generate_program(self, program) + diff --git a/sim/properties.py b/sim/properties.py index e9b679119a2c64b681db119ee92fb50e705eb34c..f07f9d58029d79e9c90f8b6dbaa92509bb6e7bb6 100644 --- a/sim/properties.py +++ b/sim/properties.py @@ -26,9 +26,9 @@ class PropertiesAlloc: raise Exception("Invalid property type!") if self.realloc: - Realloc(self.sim, p, p.type(), sizes) + Realloc(self.sim, p, sizes) else: - Malloc(self.sim, p, p.type(), sizes, True) + Malloc(self.sim, p, sizes, True) return self.sim.block diff --git a/sim/resize.py b/sim/resize.py index 7a2d93b2500695eb6e22435af8f93b2539834cc3..eada64289851d2dc4f3c2033342eb06949ec2878 100644 --- a/sim/resize.py +++ b/sim/resize.py @@ -33,4 +33,4 @@ class Resize: else: sizes = capacity * self.sim.dimensions - Realloc(self.sim, p, p.type(), sizes) + Realloc(self.sim, p, sizes) diff --git a/sim/timestep.py b/sim/timestep.py index 5a676486e16c5529ebdd85ea95811cd13c2bcd8c..86e5c95a6816ccecd91bfee0b2adbf6021b74eb3 100644 --- a/sim/timestep.py +++ b/sim/timestep.py @@ -42,8 +42,5 @@ class Timestep: def as_block(self): return Block(self.sim, [self.timestep_loop]) - def generate(self): - self.block.generate() - def transform(self, fn): self.block = self.block.transform(fn) diff --git a/sim/vtk.py b/sim/vtk.py index 53e4df733c4ebf9d7b1bc2f360185877724b6ca4..c1620f85e0bc79238e28ca8b849741ea73540399 100644 --- a/sim/vtk.py +++ b/sim/vtk.py @@ -5,22 +5,13 @@ class VTKWrite: def __init__(self, sim, filename, timestep): self.sim = sim + self.vtk_id = VTKWrite.vtk_id self.filename = filename self.timestep = as_lit_ast(sim, timestep) + VTKWrite.vtk_id += 1 def children(self): return [] - def generate(self): - self.sim.code_gen.generate_vtk_writing( - VTKWrite.vtk_id * 2, f"{self.filename}_local", - as_lit_ast(self.sim, 0), self.sim.nlocal, self.timestep) - - self.sim.code_gen.generate_vtk_writing( - VTKWrite.vtk_id * 2 + 1, f"{self.filename}_pbc", - self.sim.nlocal, self.sim.pbc.npbc, self.timestep) - - VTKWrite.vtk_id += 1 - def transform(self, fn): return fn(self)