From 0a252c24d88e6ef441b8731f64b538b7b1a8aa6a Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 3 Apr 2024 16:50:49 +0200
Subject: [PATCH] move CanonicalizeSymbols pass to optimize_cpu

---
 src/pystencils/backend/kernelcreation/cpu_optimization.py | 5 ++++-
 src/pystencils/kernelcreation.py                          | 4 ----
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/pystencils/backend/kernelcreation/cpu_optimization.py b/src/pystencils/backend/kernelcreation/cpu_optimization.py
index 21285a7fc..b0156c7e8 100644
--- a/src/pystencils/backend/kernelcreation/cpu_optimization.py
+++ b/src/pystencils/backend/kernelcreation/cpu_optimization.py
@@ -3,7 +3,7 @@ from typing import cast
 
 from .context import KernelCreationContext
 from ..platforms import GenericCpu
-from ..transformations import HoistLoopInvariantDeclarations
+from ..transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations
 from ..ast.structural import PsBlock
 
 from ...config import CpuOptimConfig
@@ -17,6 +17,9 @@ def optimize_cpu(
 ) -> PsBlock:
     """Carry out CPU-specific optimizations according to the given configuration."""
 
+    canonicalize = CanonicalizeSymbols(ctx, True)
+    kernel_ast = cast(PsBlock, canonicalize(kernel_ast))
+
     hoist_invariants = HoistLoopInvariantDeclarations(ctx)
     kernel_ast = cast(PsBlock, hoist_invariants(kernel_ast))
 
diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py
index 4fa7f98cf..0f6941cf5 100644
--- a/src/pystencils/kernelcreation.py
+++ b/src/pystencils/kernelcreation.py
@@ -26,7 +26,6 @@ from .backend.kernelcreation.iteration_space import (
 
 from .backend.ast.analysis import collect_required_headers, collect_undefined_symbols
 from .backend.transformations import (
-    CanonicalizeSymbols,
     EliminateConstants,
     EraseAnonymousStructTypes,
     SelectFunctions,
@@ -99,9 +98,6 @@ def create_kernel(
     kernel_ast = platform.materialize_iteration_space(kernel_body, ispace)
 
     #   Simplifying transformations
-    canonicalize = CanonicalizeSymbols(ctx, True)
-    kernel_ast = cast(PsBlock, canonicalize(kernel_ast))
-
     elim_constants = EliminateConstants(ctx, extract_constant_exprs=True)
     kernel_ast = cast(PsBlock, elim_constants(kernel_ast))
 
-- 
GitLab