From ce3fb3e86673be620ccd17f5470815594e75ae42 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Wed, 23 Oct 2024 09:42:23 +0200
Subject: [PATCH] add another late constant folding pass

---
 src/pystencils/kernelcreation.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py
index 57d972389..7d9ac7aa4 100644
--- a/src/pystencils/kernelcreation.py
+++ b/src/pystencils/kernelcreation.py
@@ -132,7 +132,7 @@ def create_kernel(
                 f"Code generation for target {target} not implemented"
             )
 
-    #   Simplifying transformations
+    #   Fold and extract constants
     elim_constants = EliminateConstants(ctx, extract_constant_exprs=True)
     kernel_ast = cast(PsBlock, elim_constants(kernel_ast))
 
@@ -151,10 +151,16 @@ def create_kernel(
     select_functions = SelectFunctions(platform)
     kernel_ast = cast(PsBlock, select_functions(kernel_ast))
 
-    #   Lowering introduces new symbols, which have to be canonicalized
+    #   Late canonicalization and constant elimination passes
+    #    * Since lowering introduces new index calculations and indexing symbols into the AST,
+    #    * these need to be handled here
+    
     canonicalize = CanonicalizeSymbols(ctx, True)
     kernel_ast = cast(PsBlock, canonicalize(kernel_ast))
 
+    late_fold_constants = EliminateConstants(ctx, extract_constant_exprs=False)
+    kernel_ast = cast(PsBlock, late_fold_constants(kernel_ast))
+
     if config.target.is_cpu():
         return create_cpu_kernel_function(
             ctx,
-- 
GitLab