diff --git a/hog/operator_generation/function_space_impls.py b/hog/operator_generation/function_space_impls.py index 75fdd9f154f770bbc5adde432dc3bcb386a4b71c..a8e263789a64d776903a015afce6548b0f5fe73d 100644 --- a/hog/operator_generation/function_space_impls.py +++ b/hog/operator_generation/function_space_impls.py @@ -23,9 +23,9 @@ from pystencils.backend.ast import PsAstNode from pystencils.backend.ast.expressions import ( PsArrayInitList, PsCast, - PsConstant, PsExpression, PsLiteral, + PsLookup, ) from pystencils.backend.ast.structural import PsDeclaration from pystencils.backend.functions import CFunction @@ -538,16 +538,12 @@ class N1E1FunctionSpaceImpl(FunctionSpaceImpl): symbols = [sp.Symbol(f"{name}_{i}_{i}") for i in range(n_dofs)] diag_vars = [PsExpression.make(ctx.get_symbol(s.name)) for s in symbols] - # Create a symbol with a unique name to ensure that the name is not - # changed later by a CanonicalizeSymbols pass. This allows us to use - # this symbol's name below. - trafo_matrix_symb = ctx.duplicate_symbol( + trafo_matrix_var = PsExpression.make( ctx.get_symbol( name, PsCustomType(f"Eigen::DiagonalMatrix< real_t, {n_dofs} >", const=True), ) ) - trafo_matrix_var = PsExpression.make(trafo_matrix_symb) level_var = PsExpression.make(ctx.get_symbol("level", UInt(64))) macro_var = PsExpression.make( @@ -580,10 +576,7 @@ class N1E1FunctionSpaceImpl(FunctionSpaceImpl): PsDeclaration( var, PsCast( - ctx.default_dtype, - CFunction( - f"{trafo_matrix_symb.name}.diagonal()", (UInt(64),), real_t - )(PsExpression.make(PsConstant(i, UInt(64)))), + ctx.default_dtype, PsLookup(trafo_matrix_var, f"diagonal()({i})") ), ) for i, var in enumerate(diag_vars) diff --git a/hog/operator_generation/loop_strategies.py b/hog/operator_generation/loop_strategies.py index 37c0eaf89eeb45af9ed2390a025a3c52dacd3f15..1b1fff86ab2299ae994fd7d60d5641cdf61522ee 100644 --- a/hog/operator_generation/loop_strategies.py +++ b/hog/operator_generation/loop_strategies.py @@ -30,15 +30,14 @@ from abc import ABC, abstractmethod import sympy as sp from typing import Dict, Type, Union -from pystencils.backend.ast.expressions import PsSymbolExpr -from pystencils.backend.ast.structural import PsBlock, PsComment -from pystencils.backend.kernelcreation import KernelCreationContext -from pystencils.sympyextensions import fast_subs +from pystencils.backend.ast.expressions import PsConstant, PsExpression +from pystencils.backend.ast.structural import PsBlock, PsComment, PsConditional +from pystencils.backend.kernelcreation import AstFactory, KernelCreationContext +from pystencils.types.quick import Bool from hog.exception import HOGException from hog.operator_generation.pystencils_extensions import ( loop_over_simplex, - get_innermost_loop, create_micro_element_loops, fuse_loops_over_simplex, ) @@ -126,68 +125,58 @@ class CUBES(LoopStrategy): def __init__(self): super(CUBES, self).__init__() - def create_loop(self, dim, element_index, micro_edges_per_macro_edge): - """We now build all the conditional blocks for all element types. They are filled later.""" - self.conditional_blocks = {} + def create_loop( + self, + symbolizer: Symbolizer, + ctx: KernelCreationContext, + dim: int, + micro_edges_per_macro_edge: int, + loop_bodies: Dict[Union[FaceType, CellType], PsBlock], + pre_loop_stmts: Dict[Union[FaceType, CellType], PsBlock] = {}, + ) -> PsBlock: + """Create a single spatial loop nest, containing kernels for all element types.""" + + ast_factory = AstFactory(ctx) + body = PsBlock([]) + element_index = symbolizer.loop_counters(dim) + for element_type in all_element_types(dim): + block = PsBlock([PsComment(str(element_type))]) + if element_type in pre_loop_stmts: + block.statements += pre_loop_stmts[element_type].statements + if (dim, element_type) in [ (2, FaceType.GRAY), (3, CellType.WHITE_UP), ]: - cb = Conditional(sp.Eq(0, 0), Block([])) + cond = PsExpression.make(PsConstant(True, Bool())) elif (dim, element_type) in [ (3, CellType.WHITE_DOWN), ]: - cb = Conditional( + cond = ast_factory.parse_sympy( sp.Lt( element_index[0], micro_edges_per_macro_edge - 2 - sum(element_index[1:]), - ), - Block([]), + ) ) else: - cb = Conditional( + cond = ast_factory.parse_sympy( sp.Lt( element_index[0], micro_edges_per_macro_edge - 1 - sum(element_index[1:]), - ), - Block([]), + ) ) - self.conditional_blocks[element_type] = cb - - # For the "cubes" loop strategy we only need one loop. - # The different element types are handled later via conditionals. - return loop_over_simplex(dim, micro_edges_per_macro_edge) - - def add_body_to_loop(self, loop, body, element_type): - """Adds all conditionals to the innermost loop.""" - conditional_body = Block(body) - self.conditional_blocks[element_type].true_block.append(conditional_body) - conditional_body.parent = self.conditional_blocks[element_type].true_block - - body = Block([cb for cb in self.conditional_blocks.values()]) - innermost_loop = get_innermost_loop(loop) - innermost_loop[0].body = body - body.parent = innermost_loop[0] - def add_preloop_for_loop(self, loops, preloop_stmts, element_type): - """add given list of statements directly in front of the loop corresponding to element_type.""" - if not isinstance(loops, list): - loops = [loops] - preloop_stmts_lhs_subs = { - stmt.lhs: get_element_replacement(stmt.lhs, element_type) - for stmt in preloop_stmts - } + conditional = PsConditional(cond, loop_bodies[element_type]) + block.statements.append(conditional) + body.statements.append(block) - self.conditional_blocks[element_type] = fast_subs( - self.conditional_blocks[element_type], preloop_stmts_lhs_subs + # For the "cubes" loop strategy we only need one loop. + # The different element types are handled via conditionals. + loop = loop_over_simplex( + symbolizer, ctx, dim, micro_edges_per_macro_edge, body=body ) - - new_preloop_stmts = [ - stmt.fast_subs(preloop_stmts_lhs_subs) for stmt in preloop_stmts - ] - - return new_preloop_stmts + loops + return PsBlock([loop]) def __str__(self): return "CUBES" @@ -235,7 +224,7 @@ class SAWTOOTH(LoopStrategy): block = PsBlock( PsBlock( [PsComment(str(element_type))] - + pre_loop_stmts.get(element_type, PsBlock).statements + + pre_loop_stmts.get(element_type, PsBlock([])).statements + [element_loops[element_type]] ) for element_type, l in element_loops.items() diff --git a/hog/operator_generation/operators.py b/hog/operator_generation/operators.py index 023aa4d31d91cb51b210b01961b5bf419de0f3d5..3aee78524451b686f3f692e4216670bc7fedf077 100644 --- a/hog/operator_generation/operators.py +++ b/hog/operator_generation/operators.py @@ -30,7 +30,6 @@ from sympy.codegen.ast import Assignment from pystencils import AssignmentCollection, Target, TypedSymbol from pystencils.backend import KernelFunction -from pystencils.backend.ast import PsAstNode from pystencils.backend.ast.analysis import UndefinedSymbolsCollector from pystencils.backend.ast.structural import PsBlock from pystencils.backend.kernelcreation import ( @@ -502,6 +501,8 @@ class HyTeGElementwiseOperator: ) ) + operator_cpp_file.add(CppComment(comment=f"using std::max;", where="impl")) + operator_cpp_file.add(CppComment(comment=f"namespace hyteg {{")) operator_cpp_file.add(CppComment(comment=f"namespace operatorgeneration {{")) diff --git a/hog/operator_generation/optimizer.py b/hog/operator_generation/optimizer.py index dec59349eddc83b335871b84cf80f11127a77eb1..b85a13795b9d9818b1758bb2e6c3b870156453d8 100644 --- a/hog/operator_generation/optimizer.py +++ b/hog/operator_generation/optimizer.py @@ -22,9 +22,14 @@ from typing import Dict, Iterable, List, Set from pystencils import TypedSymbol from pystencils.backend import KernelFunction -from pystencils.backend.ast.structural import PsAssignment, PsBlock +from pystencils.backend.ast.expressions import PsCall, PsConstant, PsExpression +from pystencils.backend.ast.structural import PsAssignment, PsBlock, PsLoop +from pystencils.backend.functions import MathFunctions, PsMathFunction from pystencils.backend.kernelcreation import KernelCreationContext -from pystencils.backend.transformations import HoistLoopInvariantDeclarations +from pystencils.backend.transformations import ( + HoistLoopInvariantDeclarations, + ReshapeLoops, +) # from pystencils.cpu.vectorization import vectorize # from pystencils.transformations import ( @@ -42,9 +47,6 @@ from hog.operator_generation.loop_strategies import ( CUBES, FUSEDROWS, ) -from hog.operator_generation.pystencils_extensions import ( - get_innermost_loop, -) class Opts(enum.Enum): @@ -161,31 +163,51 @@ class Optimizer: # conditionals can be safely evaluated to True or False. with TimedLogger("cutting loops", logging.DEBUG): - loops = [ - loop - for loop in kernel_function.body.args - if isinstance(loop, LoopOverCoordinate) - ] + reshape_loops = ReshapeLoops(ctx) + + loops = [loop for loop in kernel.statements if isinstance(loop, PsLoop)] assert len(loops) == 1, f"Expecting a single loop here, not {loops}" loop = loops[0] + assert ( + len(loop.body.statements) == 1 + ), f"Expecting a single statement here" - innermost_loop = get_innermost_loop(loop, return_all_inner=True)[0] if dim == 2: - new_loops = cut_loop(innermost_loop, [innermost_loop.stop - 1]) - loop.body = new_loops - new_loops.parent = loop - + x_loop = loop.body.statements[0] + new_loops = reshape_loops.cut_loop( + x_loop, + [ + x_loop.stop + - PsExpression.make(PsConstant(1, ctx.index_dtype)) + ], + ) + loop.body = PsBlock(new_loops) elif dim == 3: - new_loops = cut_loop( - innermost_loop, - [innermost_loop.stop - 2, innermost_loop.stop - 1], - with_conditional=True, + y_loop = loop.body.statements[0] + assert ( + len(y_loop.body.statements) == 1 + ), f"Expecting a single statement here" + x_loop = y_loop.body.statements[0] + new_loops = reshape_loops.cut_loop( + x_loop, + [ + PsCall( + PsMathFunction(MathFunctions.Max), + ( + x_loop.start, + x_loop.stop + - PsExpression.make(PsConstant(2, ctx.index_dtype)), + ), + ), + x_loop.stop + - PsExpression.make(PsConstant(1, ctx.index_dtype)), + ], ) - innermost_loop.parent.body = new_loops - new_loops.parent = innermost_loop.parent + y_loop.body = PsBlock(new_loops) - with TimedLogger("simplifying conditionals", logging.DEBUG): - simplify_conditionals(loop, loop_counter_simplification=True) + # TODO + # with TimedLogger("simplifying conditionals", logging.DEBUG): + # simplify_conditionals(loop, loop_counter_simplification=True) if self[Opts.MOVECONSTANTS]: with TimedLogger("moving constants out of loop", logging.DEBUG): diff --git a/hog/operator_generation/pystencils_extensions.py b/hog/operator_generation/pystencils_extensions.py index a5c55dd2cac562a4d3c7f8ae47bb8c55110c25a6..25b99afbe07ebf4193c7596b1f6d194f73e2cd13 100644 --- a/hog/operator_generation/pystencils_extensions.py +++ b/hog/operator_generation/pystencils_extensions.py @@ -18,7 +18,6 @@ import sympy as sp from typing import Dict, List, Tuple, Union from pystencils import Field, FieldType -from pystencils.backend.ast.expressions import PsConstant, PsConstantExpr from pystencils.backend.ast.structural import PsAstNode, PsBlock, PsLoop from pystencils.backend.kernelcreation import AstFactory, KernelCreationContext from pystencils.types import UserTypeSpec @@ -34,6 +33,7 @@ def loop_over_simplex( dim: int, width: int, cut_innermost: int = 0, + body: PsBlock = PsBlock([]), ) -> PsLoop: """ Arranges loop (ast-)nodes implementing an iteration over a structured simplex of arbitrary dimension. @@ -75,9 +75,7 @@ def loop_over_simplex( ] loop_bounds[0] = slice(loop_bounds[0].stop - cut_innermost) - return ast_factory.loop_nest( - loop_counter_names[::-1], loop_bounds[::-1], PsBlock([]) - ) + return ast_factory.loop_nest(loop_counter_names[::-1], loop_bounds[::-1], body) def create_micro_element_loops( @@ -85,6 +83,7 @@ def create_micro_element_loops( ctx: KernelCreationContext, dim: int, micro_edges_per_macro_edge: int, + # TODO take (optional) dict of loop bodys ) -> Dict[Union[FaceType, CellType], PsLoop]: element_loops: Dict[Union[FaceType, CellType], PsLoop] = {} if dim == 2: @@ -163,6 +162,7 @@ def fuse_loops_over_simplex( return (fused_loops[max_dim], [loop.body for loop in dim_to_fuse_loops]) +# TODO remove dead code def get_innermost_loop( ast_node: PsAstNode, shift_to_outer: int = 0, return_all_inner: bool = False ) -> List[PsLoop]: diff --git a/pyproject.toml b/pyproject.toml index e47b4a4d3fc27b666305d8085820e59a6ac3ddff..61441314f815474746e85cb28f21ebd8c08d0757 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ dependencies = [ "numpy==1.24.3", "quadpy-gpl==0.16.10", "poly-cse-py", - "pystencils @ git+https://i10git.cs.fau.de/pycodegen/pystencils.git@d065462537241f2a60ab834743f34224d7b2101c", + "pystencils @ git+https://i10git.cs.fau.de/pycodegen/pystencils.git@90239d2cc0d5860a9fe9a0c04ff60f38450c6315", "pytest==7.3.1", "sympy==1.11.1", "tabulate==0.9.0", diff --git a/requirements.txt b/requirements.txt index ab5a2f91880a9f39d17bc29c954f21c865f60b3f..1987eb6d7dd51a43d350114fb02792a473264a1d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ islpy numpy==1.24.3 poly-cse-py -pystencils @ git+https://i10git.cs.fau.de/pycodegen/pystencils.git@d065462537241f2a60ab834743f34224d7b2101c +pystencils @ git+https://i10git.cs.fau.de/pycodegen/pystencils.git@90239d2cc0d5860a9fe9a0c04ff60f38450c6315 pytest==7.3.1 sympy==1.11.1 tabulate==0.9.0