diff --git a/hog/operator_generation/operators.py b/hog/operator_generation/operators.py index ebaba4bc6d69b8826d99a1d27a55952811308d9f..8b5f1875d6dda4e3dd4af967248109b4faf54a52 100644 --- a/hog/operator_generation/operators.py +++ b/hog/operator_generation/operators.py @@ -24,14 +24,16 @@ from typing import Dict, List, Optional, Set, Tuple, Union import numpy as np import sympy as sp +import tabulate import pystencils as ps from sympy.codegen.ast import Assignment from pystencils import AssignmentCollection, Target, TypedSymbol from pystencils.backend import KernelFunction -from pystencils.backend.ast.analysis import UndefinedSymbolsCollector -from pystencils.backend.ast.structural import PsBlock +from pystencils.backend.ast.analysis import OperationCounter, UndefinedSymbolsCollector +from pystencils.backend.ast.iteration import dfs_preorder +from pystencils.backend.ast.structural import PsBlock, PsLoop from pystencils.backend.kernelcreation import ( FreezeExpressions, KernelCreationContext, @@ -672,7 +674,7 @@ class HyTeGElementwiseOperator: integration_info: IntegrationInfo, loop_strategy: LoopStrategy, kernel_type: KernelType, - ) -> Tuple[PsBlock, str]: + ) -> PsBlock: """Generate an AST that represents the passed kernel type. The AST roughly looks like: @@ -700,8 +702,7 @@ class HyTeGElementwiseOperator: :param kernel_type: Specifies the kernel to execute - this could be e.g., a matrix-vector multiplication. - :returns: A tuple of the AST (of type PsBlock) and a string stating the - number of operations performed per element. + :returns: The AST (of type PsBlock). """ geometry = integration_info.geometry @@ -927,24 +928,6 @@ class HyTeGElementwiseOperator: for i, node in enumerate(kernel_op_assignments): kernel_op_assignments[i] = fast_subs(node, subs_dict) - # count operations - # TODO: count in backend ast after optimizations - ops = Operations() - # for stmt in body: - # if isinstance(stmt, PsLoop): - # for stmt2 in stmt.body.statements: - # count_operations( - # stmt2.rhs, ops, loop_factor=stmt.stop - stmt.start - # ) - # elif isinstance(stmt, sp.codegen.Assignment) or isinstance( - # stmt, PsAssignment - # ): - # count_operations(stmt.rhs, ops) - # elif isinstance(stmt, PsComment): - # pass - # else: - # ops.unknown_ops += 1 - freeze = FreezeExpressions(ctx) typify = Typifier(ctx) loop_bodies[element_type] = typify( @@ -954,7 +937,9 @@ class HyTeGElementwiseOperator: ) ) ) - loop_bodies[element_type].statements += [stmt.clone() for stmt in quad_loop] + loop_bodies[element_type].statements += [ + typify(stmt.clone()) for stmt in quad_loop + ] loop_bodies[element_type].statements += typify( freeze(AssignmentCollection(kernel_op_assignments)) ).statements @@ -1017,7 +1002,7 @@ class HyTeGElementwiseOperator: q_decl for q_decl in q_decls if q_decl.lhs.symbol in undefined ] + block.statements - return (block, ops.to_table()) + return block def generate_kernels(self, loop_strategy: LoopStrategy) -> None: """ @@ -1050,10 +1035,7 @@ class HyTeGElementwiseOperator: with TimedLogger( f"Generating kernel: {kernel_type.name} in {dim}D", logging.INFO ): - ( - kernel, - kernel_op_count, - ) = self._generate_kernel( + kernel = self._generate_kernel( ctx, dim, integration_info, loop_strategy, kernel_type ) @@ -1075,6 +1057,18 @@ class HyTeGElementwiseOperator: ctx, kernel, dim, loop_strategy ) + first_x_loop = next( + dfs_preorder( + kernel, + lambda node: isinstance(node, PsLoop) + and node.counter.symbol.name + == self.symbolizer.loop_counters(1)[0].name, + ) + ) + op_counts = OperationCounter()(first_x_loop.body) + d = vars(op_counts) + kernel_op_count = tabulate.tabulate([d.values()], headers=d.keys()) + kernel_function = ps.kernelcreation.create_kernel_function( ctx, kernel, diff --git a/pyproject.toml b/pyproject.toml index 820b8ede46df780c3b22acd9909b3b23550320c7..4e60f33060d69b7178a83b600600567ee824d3fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ dependencies = [ "numpy==1.24.3", "quadpy-gpl==0.16.10", "poly-cse-py", - "pystencils @ git+https://i10git.cs.fau.de/pycodegen/pystencils.git@d6621ef950a0f695ea32e8e715f2e64c468248bc", + "pystencils @ git+https://i10git.cs.fau.de/pycodegen/pystencils.git@b6f6afd88981ea41032538d0ae2c04e1be1d9352", "pytest==7.3.1", "sympy==1.11.1", "tabulate==0.9.0", diff --git a/requirements.txt b/requirements.txt index 21f1dad66f44a4ec345b10396daf16abf23ffd7b..4d10e521fef28b7ac76a2635c043965f221ec852 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ islpy numpy==1.24.3 poly-cse-py -pystencils @ git+https://i10git.cs.fau.de/pycodegen/pystencils.git@d6621ef950a0f695ea32e8e715f2e64c468248bc +pystencils @ git+https://i10git.cs.fau.de/pycodegen/pystencils.git@b6f6afd88981ea41032538d0ae2c04e1be1d9352 pytest==7.3.1 sympy==1.11.1 tabulate==0.9.0