From c3a2d7341665adb11fbdb18cdfdd884734f6f204 Mon Sep 17 00:00:00 2001 From: Daniel Bauer <daniel.j.bauer@fau.de> Date: Tue, 30 Apr 2024 17:18:27 +0200 Subject: [PATCH] reenable quadloop --- .../function_space_impls.py | 18 +-- hog/operator_generation/kernel_types.py | 2 +- hog/operator_generation/operators.py | 99 ++++++++------- hog/operator_generation/optimizer.py | 3 - hog/quadrature/quad_loop.py | 116 +++++++++++------- pyproject.toml | 2 +- requirements.txt | 2 +- 7 files changed, 143 insertions(+), 99 deletions(-) diff --git a/hog/operator_generation/function_space_impls.py b/hog/operator_generation/function_space_impls.py index bd73d80..75fdd9f 100644 --- a/hog/operator_generation/function_space_impls.py +++ b/hog/operator_generation/function_space_impls.py @@ -27,11 +27,10 @@ from pystencils.backend.ast.expressions import ( PsExpression, PsLiteral, ) -from pystencils.backend.ast.structural import PsComment, PsDeclaration +from pystencils.backend.ast.structural import PsDeclaration from pystencils.backend.functions import CFunction from pystencils.kernelcreation import FreezeExpressions, KernelCreationContext, Typifier -from pystencils.types import PsType -from pystencils.types.basic_types import PsCustomType +from pystencils.types import PsCustomType, PsType from pystencils.types.quick import UInt from hog.element_geometry import ElementGeometry @@ -539,12 +538,17 @@ class N1E1FunctionSpaceImpl(FunctionSpaceImpl): symbols = [sp.Symbol(f"{name}_{i}_{i}") for i in range(n_dofs)] diag_vars = [PsExpression.make(ctx.get_symbol(s.name)) for s in symbols] - trafo_matrix_var = PsExpression.make( + # Create a symbol with a unique name to ensure that the name is not + # changed later by a CanonicalizeSymbols pass. This allows us to use + # this symbol's name below. + trafo_matrix_symb = ctx.duplicate_symbol( ctx.get_symbol( name, PsCustomType(f"Eigen::DiagonalMatrix< real_t, {n_dofs} >", const=True), ) ) + trafo_matrix_var = PsExpression.make(trafo_matrix_symb) + level_var = PsExpression.make(ctx.get_symbol("level", UInt(64))) macro_var = PsExpression.make( ctx.get_symbol(f"{macro}", PsCustomType(f"{Macro}&", const=True)) @@ -577,9 +581,9 @@ class N1E1FunctionSpaceImpl(FunctionSpaceImpl): var, PsCast( ctx.default_dtype, - CFunction(f"{name}.diagonal()", (UInt(64),), real_t)( - PsExpression.make(PsConstant(i, UInt(64))) - ), + CFunction( + f"{trafo_matrix_symb.name}.diagonal()", (UInt(64),), real_t + )(PsExpression.make(PsConstant(i, UInt(64)))), ), ) for i, var in enumerate(diag_vars) diff --git a/hog/operator_generation/kernel_types.py b/hog/operator_generation/kernel_types.py index 57e2db3..6b4fd2f 100644 --- a/hog/operator_generation/kernel_types.py +++ b/hog/operator_generation/kernel_types.py @@ -38,7 +38,7 @@ from pystencils.backend.ast.structural import ( from pystencils.backend.functions import CFunction from pystencils.field import Field from pystencils.kernelcreation import FreezeExpressions, KernelCreationContext, Typifier -from pystencils.types.basic_types import PsCustomType +from pystencils.types import PsCustomType from pystencils.types.quick import SInt from hog.cpp_printing import ( diff --git a/hog/operator_generation/operators.py b/hog/operator_generation/operators.py index 89f3b18..54621aa 100644 --- a/hog/operator_generation/operators.py +++ b/hog/operator_generation/operators.py @@ -16,6 +16,7 @@ from dataclasses import dataclass from enum import auto, Enum +from functools import reduce import logging import os from textwrap import indent @@ -30,17 +31,14 @@ from sympy.codegen.ast import Assignment from pystencils import AssignmentCollection, Target, TypedSymbol from pystencils.backend import KernelFunction from pystencils.backend.ast import PsAstNode -from pystencils.backend.ast.structural import ( - PsAssignment, - PsBlock, - PsComment, - PsLoop, -) +from pystencils.backend.ast.analysis import UndefinedSymbolsCollector +from pystencils.backend.ast.structural import PsBlock from pystencils.backend.kernelcreation import ( FreezeExpressions, KernelCreationContext, Typifier, ) +from pystencils.backend.transformations import CanonicalizeSymbols from pystencils.sympyextensions import fast_subs from hog.cpp_printing import ( @@ -67,7 +65,7 @@ from hog.operator_generation.function_space_impls import FunctionSpaceImpl from hog.operator_generation.pystencils_extensions import create_generic_fields from hog.forms import Form -from hog.ast import Operations, count_operations +from hog.ast import Operations from hog.blending import GeometryMap import hog.code_generation import hog.cse @@ -754,14 +752,14 @@ class HyTeGElementwiseOperator: ) if integration_info.quad_loop: - accessed_mat_entries = mat.atoms(TypedSymbol) - accessed_mat_entries &= set().union( - *[ass.undefined_symbols for ass in kernel_op_assignments] + accessed_symbols = reduce( + set.union, (ass.rhs.free_symbols for ass in kernel_op_assignments) ) + accessed_mat_entries = mat.atoms(sp.Symbol) & accessed_symbols with TimedLogger("constructing quadrature loops"): quad_loop = integration_info.quad_loop.construct_quad_loop( - accessed_mat_entries, self._optimizer.cse_impl() + ctx, accessed_mat_entries, self._optimizer.cse_impl() ) else: quad_loop = [] @@ -784,8 +782,10 @@ class HyTeGElementwiseOperator: ) jacobi_assignments = hog.code_generation.jacobi_matrix_assignments( - mat, - integration_info.tables + quad_loop, + mat + if integration_info.quad_loop is None + else integration_info.quad_loop.mat_integrand, + integration_info.tables, geometry, self.symbolizer, affine_points=coord_symbols_for_jac_affine, @@ -851,7 +851,12 @@ class HyTeGElementwiseOperator: # coefficient that "lives" on a FEM function). dof_symbols_set: Set[DoFSymbol] = { a - for ass in kernel_op_assignments + quad_loop + for ass in kernel_op_assignments + + ( + [] + if integration_info.quad_loop is None + else [integration_info.quad_loop.mat_integrand] + ) for a in ass.atoms(DoFSymbol) } dof_symbols = sorted(dof_symbols_set, key=lambda ds: ds.name) @@ -909,46 +914,47 @@ class HyTeGElementwiseOperator: for component in range(geometry.dimensions) ] - body = ( - loop_counter_custom_code_nodes - + coords_assignments - + load_vecs - + quad_loop - + kernel_op_assignments - ) - if not self._optimizer[Opts.QUADLOOPS]: # Only now we replace the quadrature points and weights - if there are any. - # We also setup sympy assignments in body with TimedLogger( "replacing quadrature points and weigths", logging.DEBUG ): if not quadrature.is_exact() and not quadrature.inline_values: subs_dict = dict(quadrature.points() + quadrature.weights()) - for i, node in enumerate(body): - body[i] = fast_subs(node, subs_dict) + for i, node in enumerate(kernel_op_assignments): + kernel_op_assignments[i] = fast_subs(node, subs_dict) # count operations - # TODO: count in post operation + # TODO: count in backend ast after optimizations ops = Operations() - for stmt in body: - if isinstance(stmt, PsLoop): - for stmt2 in stmt.body.args: - count_operations( - stmt2.rhs, ops, loop_factor=stmt.stop - stmt.start - ) - elif isinstance(stmt, sp.codegen.Assignment) or isinstance( - stmt, PsAssignment - ): - count_operations(stmt.rhs, ops) - elif isinstance(stmt, PsComment): - pass - else: - ops.unknown_ops += 1 + # for stmt in body: + # if isinstance(stmt, PsLoop): + # for stmt2 in stmt.body.statements: + # count_operations( + # stmt2.rhs, ops, loop_factor=stmt.stop - stmt.start + # ) + # elif isinstance(stmt, sp.codegen.Assignment) or isinstance( + # stmt, PsAssignment + # ): + # count_operations(stmt.rhs, ops) + # elif isinstance(stmt, PsComment): + # pass + # else: + # ops.unknown_ops += 1 freeze = FreezeExpressions(ctx) typify = Typifier(ctx) - loop_bodies[element_type] = typify(freeze(AssignmentCollection(body))) + loop_bodies[element_type] = typify( + freeze( + AssignmentCollection( + loop_counter_custom_code_nodes + coords_assignments + load_vecs + ) + ) + ) + loop_bodies[element_type].statements += [stmt.clone() for stmt in quad_loop] + loop_bodies[element_type].statements += typify( + freeze(AssignmentCollection(kernel_op_assignments)) + ).statements loop_bodies[element_type].statements += kernel_op_post_ast with TimedLogger( @@ -1008,10 +1014,10 @@ class HyTeGElementwiseOperator: # Add quadrature points and weights array declarations, but only those # which are actually needed. if integration_info.quad_loop: - q_decls = integration_info.quad_loop.point_weight_decls() - undefined = block.undefined_symbols + q_decls = integration_info.quad_loop.point_weight_decls(ctx) + undefined = UndefinedSymbolsCollector()(block) block.statements = [ - q_decl for q_decl in q_decls if q_decl.lhs in undefined + q_decl for q_decl in q_decls if q_decl.lhs.symbol in undefined ] + block.statements return (block, ops.to_table()) @@ -1054,6 +1060,11 @@ class HyTeGElementwiseOperator: ctx, dim, integration_info, loop_strategy, kernel_type ) + # TODO: remove unused symbols automatically + with TimedLogger("canonicalizing symbols", logging.DEBUG): + canonicalize = CanonicalizeSymbols(ctx) + kernel = canonicalize(kernel) + # optimizer applies optimizations with TimedLogger( f"Optimizing kernel: {kernel_type.name} in {dim}D", logging.INFO diff --git a/hog/operator_generation/optimizer.py b/hog/operator_generation/optimizer.py index 3419dee..dec5934 100644 --- a/hog/operator_generation/optimizer.py +++ b/hog/operator_generation/optimizer.py @@ -40,7 +40,6 @@ from hog.logger import TimedLogger from hog.operator_generation.loop_strategies import ( LoopStrategy, CUBES, - SAWTOOTH, FUSEDROWS, ) from hog.operator_generation.pystencils_extensions import ( @@ -189,8 +188,6 @@ class Optimizer: simplify_conditionals(loop, loop_counter_simplification=True) if self[Opts.MOVECONSTANTS]: - # TODO: Symbols must be canonicalized before hoisting invariants. - with TimedLogger("moving constants out of loop", logging.DEBUG): hoist_invariants = HoistLoopInvariantDeclarations(ctx) kernel.statements = hoist_invariants(kernel).statements diff --git a/hog/quadrature/quad_loop.py b/hog/quadrature/quad_loop.py index bd75f1b..4e59169 100644 --- a/hog/quadrature/quad_loop.py +++ b/hog/quadrature/quad_loop.py @@ -18,14 +18,22 @@ import logging import sympy as sp from typing import Iterable, List, Optional -import pystencils as ps +from pystencils import AssignmentCollection, TypedSymbol from pystencils.backend.ast import PsAstNode +from pystencils.backend.ast.expressions import PsArrayInitList, PsConstant, PsExpression +from pystencils.backend.ast.structural import PsDeclaration +from pystencils.backend.kernelcreation import ( + AstFactory, + FreezeExpressions, + KernelCreationContext, + Typifier, +) +from pystencils.types import PsArrayType from .quadrature import Quadrature import hog from hog.cse import CseImplementation from hog.logger import TimedLogger -from hog.operator_generation.pystencils_extensions import create_field_access from hog.operator_generation.types import HOGType from hog.symbolizer import Symbolizer from hog.sympy_extensions import fast_subs @@ -34,7 +42,7 @@ from hog.sympy_extensions import fast_subs class QuadLoop: """Implements a quadrature scheme by an explicit loop over quadrature points.""" - # q_ctr = ps.TypedSymbol("q", BasicType(int)) + q_ctr = sp.Symbol("q") def __init__( self, @@ -61,11 +69,7 @@ class QuadLoop: if self.symmetric and row > col: continue - q_acc = ps.TypedSymbol( - f"q_acc_{row}_{col}", - BasicType(self.type_descriptor.pystencils_type), - const=False, - ) + q_acc = sp.Symbol(f"q_acc_{row}_{col}") self.mat[row, col] = q_acc if self.symmetric: @@ -73,7 +77,8 @@ class QuadLoop: def construct_quad_loop( self, - accessed_mat_entries: Iterable[ps.TypedSymbol], + ctx: KernelCreationContext, + accessed_mat_entries: Iterable[sp.Symbol], cse: Optional[CseImplementation] = None, ) -> List[PsAstNode]: ref_symbols = self.symbolizer.ref_coords_as_list( @@ -93,29 +98,35 @@ class QuadLoop: continue coord_subs_dict = { - symbol: create_field_access( - self.p_array_names[dim], - self.type_descriptor.pystencils_type, + symbol: sp.Indexed( + TypedSymbol( + self.p_array_names[dim], + PsArrayType( + ctx.default_dtype, len(self.quadrature.weights()) + ), + ), self.q_ctr, ) for dim, symbol in enumerate(ref_symbols) } - weight = create_field_access( - self.w_array_name, self.type_descriptor.pystencils_type, self.q_ctr + weight = sp.Indexed( + TypedSymbol( + self.w_array_name, + PsArrayType(ctx.default_dtype, len(self.quadrature.weights())), + ), + self.q_ctr, ) integrated = weight * fast_subs( self.mat_integrand[row, col], coord_subs_dict ) - accumulator_declarations.append( - ast.SympyAssignment(q_acc, 0.0, is_const=False) - ) + accumulator_declarations.append(sp.codegen.Assignment(q_acc, 0.0)) quadrature_assignments.append( - ast.SympyAssignment(tmp_symbol, integrated) + sp.codegen.Assignment(tmp_symbol, integrated) ) accumulator_updates.append( - ast.SympyAssignment(q_acc, q_acc + tmp_symbol, is_const=False) + sp.codegen.ast.AddAugmentedAssignment(q_acc, tmp_symbol) ) # common subexpression elimination @@ -125,40 +136,61 @@ class QuadLoop: quadrature_assignments, cse, "tmp_qloop", - return_type=ast.SympyAssignment, + return_type=sp.codegen.Assignment, ) - return accumulator_declarations + [ - ast.ForLoop( - ast.Block(quadrature_assignments + accumulator_updates), - self.q_ctr, - 0, - len(self.quadrature.weights()), - ) - ] + ast_factory = AstFactory(ctx) + freeze = FreezeExpressions(ctx) + typify = Typifier(ctx) + + ctx.get_symbol(self.q_ctr.name, ctx.index_dtype) - def point_weight_decls(self) -> List[PsAstNode]: + block = typify(freeze(AssignmentCollection(accumulator_declarations))) + loop_body = typify( + freeze(AssignmentCollection(quadrature_assignments + accumulator_updates)) + ) + loop = ast_factory.loop( + self.q_ctr.name, slice(len(self.quadrature.weights())), loop_body + ) + block.statements.append(loop) + + return block.statements + + def point_weight_decls(self, ctx: KernelCreationContext) -> List[PsAstNode]: """Returns statements that declare the quadrature rules' points and weights as c arrays.""" + typify = Typifier(ctx) + quad_decls = [] quad_decls.append( - ast.ArrayDeclaration( - ast.FieldPointerSymbol( - self.w_array_name, - BasicType(self.type_descriptor.pystencils_type), - False, + PsDeclaration( + PsExpression.make( + ctx.get_symbol( + self.w_array_name, + PsArrayType(ctx.default_dtype, len(self.quadrature.weights())), + ) + ), + PsArrayInitList( + PsExpression.make(PsConstant(w)) + for _, w in self.quadrature.weights() ), - *(sp.Float(w) for _, w in self.quadrature.weights()), ) ) for dim in range(0, self.quadrature.geometry.dimensions): quad_decls.append( - ast.ArrayDeclaration( - ast.FieldPointerSymbol( - self.p_array_names[dim], - BasicType(self.type_descriptor.pystencils_type), - False, + PsDeclaration( + PsExpression.make( + ctx.get_symbol( + self.p_array_names[dim], + PsArrayType( + ctx.default_dtype, len(self.quadrature.weights()) + ), + ) + ), + PsArrayInitList( + PsExpression.make(PsConstant(point[dim])) + for point in self.quadrature._points ), - *(sp.Float(point[dim]) for point in self.quadrature._points), ) ) - return quad_decls + + return [typify(d) for d in quad_decls] diff --git a/pyproject.toml b/pyproject.toml index 232abf5..e47b4a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ dependencies = [ "numpy==1.24.3", "quadpy-gpl==0.16.10", "poly-cse-py", - "pystencils @ git+https://i10git.cs.fau.de/pycodegen/pystencils.git@fafe58ec98a8a8f9693f65e1fa8a8b7a09e142e1", + "pystencils @ git+https://i10git.cs.fau.de/pycodegen/pystencils.git@d065462537241f2a60ab834743f34224d7b2101c", "pytest==7.3.1", "sympy==1.11.1", "tabulate==0.9.0", diff --git a/requirements.txt b/requirements.txt index cdeb93f..ab5a2f9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ islpy numpy==1.24.3 poly-cse-py -pystencils @ git+https://i10git.cs.fau.de/pycodegen/pystencils.git@fafe58ec98a8a8f9693f65e1fa8a8b7a09e142e1 +pystencils @ git+https://i10git.cs.fau.de/pycodegen/pystencils.git@d065462537241f2a60ab834743f34224d7b2101c pytest==7.3.1 sympy==1.11.1 tabulate==0.9.0 -- GitLab