Skip to content
Snippets Groups Projects
Commit 0845d667 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Fix symbol canonicalization to not duplicate when marking as updated

parent 3ee5d9b6
No related branches found
No related tags found
1 merge request!380Fix symbol canonicalization to not duplicate when marking as updated
Pipeline #65464 passed
......@@ -33,7 +33,7 @@ class CanonContext:
return replacement
def mark_as_updated(self, symb: PsSymbol):
self.updated_symbols.add(self.deduplicate(symb))
self.updated_symbols.add(symb)
def is_live(self, symb: PsSymbol) -> bool:
return symb in self.live_symbols_map
......
......@@ -86,3 +86,31 @@ def test_do_not_constify():
assert ctx.find_symbol("x").dtype.const
assert not ctx.find_symbol("z").dtype.const
def test_loop_counters():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
canonicalize = CanonicalizeSymbols(ctx)
f = Field.create_generic("f", 2, index_shape=(1,))
g = Field.create_generic("g", 2, index_shape=(1,))
ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:, :], archetype_field=f)
ctx.set_iteration_space(ispace)
asm = Assignment(f.center(0), 2 * g.center(0))
body = PsBlock([factory.parse_sympy(asm)])
loops = factory.loops_from_ispace(ispace, body)
loops_copy = loops.clone()
ast = PsBlock([loops, loops_copy])
ast = canonicalize(ast)
assert loops_copy.counter.symbol.name == "ctr_0"
assert not loops_copy.counter.symbol.get_dtype().const
assert loops.counter.symbol.name == "ctr_0__0"
assert not loops.counter.symbol.get_dtype().const
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment