From c3a2d7341665adb11fbdb18cdfdd884734f6f204 Mon Sep 17 00:00:00 2001
From: Daniel Bauer <daniel.j.bauer@fau.de>
Date: Tue, 30 Apr 2024 17:18:27 +0200
Subject: [PATCH] reenable quadloop

---
 .../function_space_impls.py                   |  18 +--
 hog/operator_generation/kernel_types.py       |   2 +-
 hog/operator_generation/operators.py          |  99 ++++++++-------
 hog/operator_generation/optimizer.py          |   3 -
 hog/quadrature/quad_loop.py                   | 116 +++++++++++-------
 pyproject.toml                                |   2 +-
 requirements.txt                              |   2 +-
 7 files changed, 143 insertions(+), 99 deletions(-)

diff --git a/hog/operator_generation/function_space_impls.py b/hog/operator_generation/function_space_impls.py
index bd73d80..75fdd9f 100644
--- a/hog/operator_generation/function_space_impls.py
+++ b/hog/operator_generation/function_space_impls.py
@@ -27,11 +27,10 @@ from pystencils.backend.ast.expressions import (
     PsExpression,
     PsLiteral,
 )
-from pystencils.backend.ast.structural import PsComment, PsDeclaration
+from pystencils.backend.ast.structural import PsDeclaration
 from pystencils.backend.functions import CFunction
 from pystencils.kernelcreation import FreezeExpressions, KernelCreationContext, Typifier
-from pystencils.types import PsType
-from pystencils.types.basic_types import PsCustomType
+from pystencils.types import PsCustomType, PsType
 from pystencils.types.quick import UInt
 
 from hog.element_geometry import ElementGeometry
@@ -539,12 +538,17 @@ class N1E1FunctionSpaceImpl(FunctionSpaceImpl):
         symbols = [sp.Symbol(f"{name}_{i}_{i}") for i in range(n_dofs)]
         diag_vars = [PsExpression.make(ctx.get_symbol(s.name)) for s in symbols]
 
-        trafo_matrix_var = PsExpression.make(
+        # Create a symbol with a unique name to ensure that the name is not
+        # changed later by a CanonicalizeSymbols pass. This allows us to use
+        # this symbol's name below.
+        trafo_matrix_symb = ctx.duplicate_symbol(
             ctx.get_symbol(
                 name,
                 PsCustomType(f"Eigen::DiagonalMatrix< real_t, {n_dofs} >", const=True),
             )
         )
+        trafo_matrix_var = PsExpression.make(trafo_matrix_symb)
+
         level_var = PsExpression.make(ctx.get_symbol("level", UInt(64)))
         macro_var = PsExpression.make(
             ctx.get_symbol(f"{macro}", PsCustomType(f"{Macro}&", const=True))
@@ -577,9 +581,9 @@ class N1E1FunctionSpaceImpl(FunctionSpaceImpl):
                 var,
                 PsCast(
                     ctx.default_dtype,
-                    CFunction(f"{name}.diagonal()", (UInt(64),), real_t)(
-                        PsExpression.make(PsConstant(i, UInt(64)))
-                    ),
+                    CFunction(
+                        f"{trafo_matrix_symb.name}.diagonal()", (UInt(64),), real_t
+                    )(PsExpression.make(PsConstant(i, UInt(64)))),
                 ),
             )
             for i, var in enumerate(diag_vars)
diff --git a/hog/operator_generation/kernel_types.py b/hog/operator_generation/kernel_types.py
index 57e2db3..6b4fd2f 100644
--- a/hog/operator_generation/kernel_types.py
+++ b/hog/operator_generation/kernel_types.py
@@ -38,7 +38,7 @@ from pystencils.backend.ast.structural import (
 from pystencils.backend.functions import CFunction
 from pystencils.field import Field
 from pystencils.kernelcreation import FreezeExpressions, KernelCreationContext, Typifier
-from pystencils.types.basic_types import PsCustomType
+from pystencils.types import PsCustomType
 from pystencils.types.quick import SInt
 
 from hog.cpp_printing import (
diff --git a/hog/operator_generation/operators.py b/hog/operator_generation/operators.py
index 89f3b18..54621aa 100644
--- a/hog/operator_generation/operators.py
+++ b/hog/operator_generation/operators.py
@@ -16,6 +16,7 @@
 
 from dataclasses import dataclass
 from enum import auto, Enum
+from functools import reduce
 import logging
 import os
 from textwrap import indent
@@ -30,17 +31,14 @@ from sympy.codegen.ast import Assignment
 from pystencils import AssignmentCollection, Target, TypedSymbol
 from pystencils.backend import KernelFunction
 from pystencils.backend.ast import PsAstNode
-from pystencils.backend.ast.structural import (
-    PsAssignment,
-    PsBlock,
-    PsComment,
-    PsLoop,
-)
+from pystencils.backend.ast.analysis import UndefinedSymbolsCollector
+from pystencils.backend.ast.structural import PsBlock
 from pystencils.backend.kernelcreation import (
     FreezeExpressions,
     KernelCreationContext,
     Typifier,
 )
+from pystencils.backend.transformations import CanonicalizeSymbols
 from pystencils.sympyextensions import fast_subs
 
 from hog.cpp_printing import (
@@ -67,7 +65,7 @@ from hog.operator_generation.function_space_impls import FunctionSpaceImpl
 from hog.operator_generation.pystencils_extensions import create_generic_fields
 
 from hog.forms import Form
-from hog.ast import Operations, count_operations
+from hog.ast import Operations
 from hog.blending import GeometryMap
 import hog.code_generation
 import hog.cse
@@ -754,14 +752,14 @@ class HyTeGElementwiseOperator:
             )
 
         if integration_info.quad_loop:
-            accessed_mat_entries = mat.atoms(TypedSymbol)
-            accessed_mat_entries &= set().union(
-                *[ass.undefined_symbols for ass in kernel_op_assignments]
+            accessed_symbols = reduce(
+                set.union, (ass.rhs.free_symbols for ass in kernel_op_assignments)
             )
+            accessed_mat_entries = mat.atoms(sp.Symbol) & accessed_symbols
 
             with TimedLogger("constructing quadrature loops"):
                 quad_loop = integration_info.quad_loop.construct_quad_loop(
-                    accessed_mat_entries, self._optimizer.cse_impl()
+                    ctx, accessed_mat_entries, self._optimizer.cse_impl()
                 )
         else:
             quad_loop = []
@@ -784,8 +782,10 @@ class HyTeGElementwiseOperator:
         )
 
         jacobi_assignments = hog.code_generation.jacobi_matrix_assignments(
-            mat,
-            integration_info.tables + quad_loop,
+            mat
+            if integration_info.quad_loop is None
+            else integration_info.quad_loop.mat_integrand,
+            integration_info.tables,
             geometry,
             self.symbolizer,
             affine_points=coord_symbols_for_jac_affine,
@@ -851,7 +851,12 @@ class HyTeGElementwiseOperator:
             # coefficient that "lives" on a FEM function).
             dof_symbols_set: Set[DoFSymbol] = {
                 a
-                for ass in kernel_op_assignments + quad_loop
+                for ass in kernel_op_assignments
+                + (
+                    []
+                    if integration_info.quad_loop is None
+                    else [integration_info.quad_loop.mat_integrand]
+                )
                 for a in ass.atoms(DoFSymbol)
             }
             dof_symbols = sorted(dof_symbols_set, key=lambda ds: ds.name)
@@ -909,46 +914,47 @@ class HyTeGElementwiseOperator:
                     for component in range(geometry.dimensions)
                 ]
 
-            body = (
-                loop_counter_custom_code_nodes
-                + coords_assignments
-                + load_vecs
-                + quad_loop
-                + kernel_op_assignments
-            )
-
             if not self._optimizer[Opts.QUADLOOPS]:
                 # Only now we replace the quadrature points and weights - if there are any.
-                # We also setup sympy assignments in body
                 with TimedLogger(
                     "replacing quadrature points and weigths", logging.DEBUG
                 ):
                     if not quadrature.is_exact() and not quadrature.inline_values:
                         subs_dict = dict(quadrature.points() + quadrature.weights())
-                        for i, node in enumerate(body):
-                            body[i] = fast_subs(node, subs_dict)
+                        for i, node in enumerate(kernel_op_assignments):
+                            kernel_op_assignments[i] = fast_subs(node, subs_dict)
 
             # count operations
-            # TODO: count in post operation
+            # TODO: count in backend ast after optimizations
             ops = Operations()
-            for stmt in body:
-                if isinstance(stmt, PsLoop):
-                    for stmt2 in stmt.body.args:
-                        count_operations(
-                            stmt2.rhs, ops, loop_factor=stmt.stop - stmt.start
-                        )
-                elif isinstance(stmt, sp.codegen.Assignment) or isinstance(
-                    stmt, PsAssignment
-                ):
-                    count_operations(stmt.rhs, ops)
-                elif isinstance(stmt, PsComment):
-                    pass
-                else:
-                    ops.unknown_ops += 1
+            # for stmt in body:
+            #     if isinstance(stmt, PsLoop):
+            #         for stmt2 in stmt.body.statements:
+            #             count_operations(
+            #                 stmt2.rhs, ops, loop_factor=stmt.stop - stmt.start
+            #             )
+            #     elif isinstance(stmt, sp.codegen.Assignment) or isinstance(
+            #         stmt, PsAssignment
+            #     ):
+            #         count_operations(stmt.rhs, ops)
+            #     elif isinstance(stmt, PsComment):
+            #         pass
+            #     else:
+            #         ops.unknown_ops += 1
 
             freeze = FreezeExpressions(ctx)
             typify = Typifier(ctx)
-            loop_bodies[element_type] = typify(freeze(AssignmentCollection(body)))
+            loop_bodies[element_type] = typify(
+                freeze(
+                    AssignmentCollection(
+                        loop_counter_custom_code_nodes + coords_assignments + load_vecs
+                    )
+                )
+            )
+            loop_bodies[element_type].statements += [stmt.clone() for stmt in quad_loop]
+            loop_bodies[element_type].statements += typify(
+                freeze(AssignmentCollection(kernel_op_assignments))
+            ).statements
             loop_bodies[element_type].statements += kernel_op_post_ast
 
             with TimedLogger(
@@ -1008,10 +1014,10 @@ class HyTeGElementwiseOperator:
         # Add quadrature points and weights array declarations, but only those
         # which are actually needed.
         if integration_info.quad_loop:
-            q_decls = integration_info.quad_loop.point_weight_decls()
-            undefined = block.undefined_symbols
+            q_decls = integration_info.quad_loop.point_weight_decls(ctx)
+            undefined = UndefinedSymbolsCollector()(block)
             block.statements = [
-                q_decl for q_decl in q_decls if q_decl.lhs in undefined
+                q_decl for q_decl in q_decls if q_decl.lhs.symbol in undefined
             ] + block.statements
 
         return (block, ops.to_table())
@@ -1054,6 +1060,11 @@ class HyTeGElementwiseOperator:
                         ctx, dim, integration_info, loop_strategy, kernel_type
                     )
 
+                # TODO: remove unused symbols automatically
+                with TimedLogger("canonicalizing symbols", logging.DEBUG):
+                    canonicalize = CanonicalizeSymbols(ctx)
+                    kernel = canonicalize(kernel)
+
                 # optimizer applies optimizations
                 with TimedLogger(
                     f"Optimizing kernel: {kernel_type.name} in {dim}D", logging.INFO
diff --git a/hog/operator_generation/optimizer.py b/hog/operator_generation/optimizer.py
index 3419dee..dec5934 100644
--- a/hog/operator_generation/optimizer.py
+++ b/hog/operator_generation/optimizer.py
@@ -40,7 +40,6 @@ from hog.logger import TimedLogger
 from hog.operator_generation.loop_strategies import (
     LoopStrategy,
     CUBES,
-    SAWTOOTH,
     FUSEDROWS,
 )
 from hog.operator_generation.pystencils_extensions import (
@@ -189,8 +188,6 @@ class Optimizer:
                 simplify_conditionals(loop, loop_counter_simplification=True)
 
         if self[Opts.MOVECONSTANTS]:
-            # TODO: Symbols must be canonicalized before hoisting invariants.
-
             with TimedLogger("moving constants out of loop", logging.DEBUG):
                 hoist_invariants = HoistLoopInvariantDeclarations(ctx)
                 kernel.statements = hoist_invariants(kernel).statements
diff --git a/hog/quadrature/quad_loop.py b/hog/quadrature/quad_loop.py
index bd75f1b..4e59169 100644
--- a/hog/quadrature/quad_loop.py
+++ b/hog/quadrature/quad_loop.py
@@ -18,14 +18,22 @@ import logging
 import sympy as sp
 from typing import Iterable, List, Optional
 
-import pystencils as ps
+from pystencils import AssignmentCollection, TypedSymbol
 from pystencils.backend.ast import PsAstNode
+from pystencils.backend.ast.expressions import PsArrayInitList, PsConstant, PsExpression
+from pystencils.backend.ast.structural import PsDeclaration
+from pystencils.backend.kernelcreation import (
+    AstFactory,
+    FreezeExpressions,
+    KernelCreationContext,
+    Typifier,
+)
+from pystencils.types import PsArrayType
 
 from .quadrature import Quadrature
 import hog
 from hog.cse import CseImplementation
 from hog.logger import TimedLogger
-from hog.operator_generation.pystencils_extensions import create_field_access
 from hog.operator_generation.types import HOGType
 from hog.symbolizer import Symbolizer
 from hog.sympy_extensions import fast_subs
@@ -34,7 +42,7 @@ from hog.sympy_extensions import fast_subs
 class QuadLoop:
     """Implements a quadrature scheme by an explicit loop over quadrature points."""
 
-    # q_ctr = ps.TypedSymbol("q", BasicType(int))
+    q_ctr = sp.Symbol("q")
 
     def __init__(
         self,
@@ -61,11 +69,7 @@ class QuadLoop:
                 if self.symmetric and row > col:
                     continue
 
-                q_acc = ps.TypedSymbol(
-                    f"q_acc_{row}_{col}",
-                    BasicType(self.type_descriptor.pystencils_type),
-                    const=False,
-                )
+                q_acc = sp.Symbol(f"q_acc_{row}_{col}")
                 self.mat[row, col] = q_acc
 
                 if self.symmetric:
@@ -73,7 +77,8 @@ class QuadLoop:
 
     def construct_quad_loop(
         self,
-        accessed_mat_entries: Iterable[ps.TypedSymbol],
+        ctx: KernelCreationContext,
+        accessed_mat_entries: Iterable[sp.Symbol],
         cse: Optional[CseImplementation] = None,
     ) -> List[PsAstNode]:
         ref_symbols = self.symbolizer.ref_coords_as_list(
@@ -93,29 +98,35 @@ class QuadLoop:
                     continue
 
                 coord_subs_dict = {
-                    symbol: create_field_access(
-                        self.p_array_names[dim],
-                        self.type_descriptor.pystencils_type,
+                    symbol: sp.Indexed(
+                        TypedSymbol(
+                            self.p_array_names[dim],
+                            PsArrayType(
+                                ctx.default_dtype, len(self.quadrature.weights())
+                            ),
+                        ),
                         self.q_ctr,
                     )
                     for dim, symbol in enumerate(ref_symbols)
                 }
 
-                weight = create_field_access(
-                    self.w_array_name, self.type_descriptor.pystencils_type, self.q_ctr
+                weight = sp.Indexed(
+                    TypedSymbol(
+                        self.w_array_name,
+                        PsArrayType(ctx.default_dtype, len(self.quadrature.weights())),
+                    ),
+                    self.q_ctr,
                 )
                 integrated = weight * fast_subs(
                     self.mat_integrand[row, col], coord_subs_dict
                 )
 
-                accumulator_declarations.append(
-                    ast.SympyAssignment(q_acc, 0.0, is_const=False)
-                )
+                accumulator_declarations.append(sp.codegen.Assignment(q_acc, 0.0))
                 quadrature_assignments.append(
-                    ast.SympyAssignment(tmp_symbol, integrated)
+                    sp.codegen.Assignment(tmp_symbol, integrated)
                 )
                 accumulator_updates.append(
-                    ast.SympyAssignment(q_acc, q_acc + tmp_symbol, is_const=False)
+                    sp.codegen.ast.AddAugmentedAssignment(q_acc, tmp_symbol)
                 )
 
         # common subexpression elimination
@@ -125,40 +136,61 @@ class QuadLoop:
                     quadrature_assignments,
                     cse,
                     "tmp_qloop",
-                    return_type=ast.SympyAssignment,
+                    return_type=sp.codegen.Assignment,
                 )
 
-        return accumulator_declarations + [
-            ast.ForLoop(
-                ast.Block(quadrature_assignments + accumulator_updates),
-                self.q_ctr,
-                0,
-                len(self.quadrature.weights()),
-            )
-        ]
+        ast_factory = AstFactory(ctx)
+        freeze = FreezeExpressions(ctx)
+        typify = Typifier(ctx)
+
+        ctx.get_symbol(self.q_ctr.name, ctx.index_dtype)
 
-    def point_weight_decls(self) -> List[PsAstNode]:
+        block = typify(freeze(AssignmentCollection(accumulator_declarations)))
+        loop_body = typify(
+            freeze(AssignmentCollection(quadrature_assignments + accumulator_updates))
+        )
+        loop = ast_factory.loop(
+            self.q_ctr.name, slice(len(self.quadrature.weights())), loop_body
+        )
+        block.statements.append(loop)
+
+        return block.statements
+
+    def point_weight_decls(self, ctx: KernelCreationContext) -> List[PsAstNode]:
         """Returns statements that declare the quadrature rules' points and weights as c arrays."""
+        typify = Typifier(ctx)
+
         quad_decls = []
         quad_decls.append(
-            ast.ArrayDeclaration(
-                ast.FieldPointerSymbol(
-                    self.w_array_name,
-                    BasicType(self.type_descriptor.pystencils_type),
-                    False,
+            PsDeclaration(
+                PsExpression.make(
+                    ctx.get_symbol(
+                        self.w_array_name,
+                        PsArrayType(ctx.default_dtype, len(self.quadrature.weights())),
+                    )
+                ),
+                PsArrayInitList(
+                    PsExpression.make(PsConstant(w))
+                    for _, w in self.quadrature.weights()
                 ),
-                *(sp.Float(w) for _, w in self.quadrature.weights()),
             )
         )
         for dim in range(0, self.quadrature.geometry.dimensions):
             quad_decls.append(
-                ast.ArrayDeclaration(
-                    ast.FieldPointerSymbol(
-                        self.p_array_names[dim],
-                        BasicType(self.type_descriptor.pystencils_type),
-                        False,
+                PsDeclaration(
+                    PsExpression.make(
+                        ctx.get_symbol(
+                            self.p_array_names[dim],
+                            PsArrayType(
+                                ctx.default_dtype, len(self.quadrature.weights())
+                            ),
+                        )
+                    ),
+                    PsArrayInitList(
+                        PsExpression.make(PsConstant(point[dim]))
+                        for point in self.quadrature._points
                     ),
-                    *(sp.Float(point[dim]) for point in self.quadrature._points),
                 )
             )
-        return quad_decls
+
+        return [typify(d) for d in quad_decls]
diff --git a/pyproject.toml b/pyproject.toml
index 232abf5..e47b4a4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,7 +6,7 @@ dependencies = [
   "numpy==1.24.3",
   "quadpy-gpl==0.16.10",
   "poly-cse-py",
-  "pystencils @ git+https://i10git.cs.fau.de/pycodegen/pystencils.git@fafe58ec98a8a8f9693f65e1fa8a8b7a09e142e1",
+  "pystencils @ git+https://i10git.cs.fau.de/pycodegen/pystencils.git@d065462537241f2a60ab834743f34224d7b2101c",
   "pytest==7.3.1",
   "sympy==1.11.1",
   "tabulate==0.9.0",
diff --git a/requirements.txt b/requirements.txt
index cdeb93f..ab5a2f9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,7 +3,7 @@
 islpy
 numpy==1.24.3
 poly-cse-py
-pystencils @ git+https://i10git.cs.fau.de/pycodegen/pystencils.git@fafe58ec98a8a8f9693f65e1fa8a8b7a09e142e1
+pystencils @ git+https://i10git.cs.fau.de/pycodegen/pystencils.git@d065462537241f2a60ab834743f34224d7b2101c
 pytest==7.3.1
 sympy==1.11.1
 tabulate==0.9.0
-- 
GitLab