Skip to content
Snippets Groups Projects
Commit b9d654ae authored by Markus Holzer's avatar Markus Holzer Committed by Jan Hönig
Browse files

Revision3

parent f38af3f3
No related merge requests found
...@@ -638,6 +638,7 @@ def move_constants_before_loop(ast_node): ...@@ -638,6 +638,7 @@ def move_constants_before_loop(ast_node):
new_symbol = TypedSymbol(sp.Dummy().name, child.lhs.dtype) new_symbol = TypedSymbol(sp.Dummy().name, child.lhs.dtype)
target.insert_before(ast.SympyAssignment(new_symbol, child.rhs, is_const=child.is_const), target.insert_before(ast.SympyAssignment(new_symbol, child.rhs, is_const=child.is_const),
child_to_insert_before) child_to_insert_before)
block.append(ast.SympyAssignment(child.lhs, new_symbol, is_const=child.is_const))
def split_inner_loop(ast_node: ast.Node, symbol_groups): def split_inner_loop(ast_node: ast.Node, symbol_groups):
......
...@@ -25,7 +25,9 @@ def test_symbol_renaming(): ...@@ -25,7 +25,9 @@ def test_symbol_renaming():
loops = block.atoms(LoopOverCoordinate) loops = block.atoms(LoopOverCoordinate)
assert len(loops) == 2 assert len(loops) == 2
assert len(block.args[2].body.args) == 1
assert len(block.args[3].body.args) == 2
for loop in loops: for loop in loops:
assert len(loop.body.args) == 1
assert len(loop.parent.args) == 4 # 2 loops + 2 subexpressions assert len(loop.parent.args) == 4 # 2 loops + 2 subexpressions
assert loop.parent.args[0].lhs.name != loop.parent.args[1].lhs.name assert loop.parent.args[0].lhs.name != loop.parent.args[1].lhs.name
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment