Skip to content
Snippets Groups Projects
Commit 0845d667 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Fix symbol canonicalization to not duplicate when marking as updated

parent 3ee5d9b6
Branches
No related tags found
No related merge requests found
......@@ -33,7 +33,7 @@ class CanonContext:
return replacement
def mark_as_updated(self, symb: PsSymbol):
self.updated_symbols.add(self.deduplicate(symb))
self.updated_symbols.add(symb)
def is_live(self, symb: PsSymbol) -> bool:
return symb in self.live_symbols_map
......
......@@ -86,3 +86,31 @@ def test_do_not_constify():
assert ctx.find_symbol("x").dtype.const
assert not ctx.find_symbol("z").dtype.const
def test_loop_counters():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
canonicalize = CanonicalizeSymbols(ctx)
f = Field.create_generic("f", 2, index_shape=(1,))
g = Field.create_generic("g", 2, index_shape=(1,))
ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:, :], archetype_field=f)
ctx.set_iteration_space(ispace)
asm = Assignment(f.center(0), 2 * g.center(0))
body = PsBlock([factory.parse_sympy(asm)])
loops = factory.loops_from_ispace(ispace, body)
loops_copy = loops.clone()
ast = PsBlock([loops, loops_copy])
ast = canonicalize(ast)
assert loops_copy.counter.symbol.name == "ctr_0"
assert not loops_copy.counter.symbol.get_dtype().const
assert loops.counter.symbol.name == "ctr_0__0"
assert not loops.counter.symbol.get_dtype().const
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment