Skip to content
Snippets Groups Projects

Add AST pass for counting operations

Merged Daniel Bauer requested to merge hyteg/pystencils:bauerd/count-ops into backend-rework
Compare and
3 files
+ 292
3
Preferences
Compare changes
Files
3
from dataclasses import dataclass
from typing import cast
from functools import reduce
import operator
from .structural import (
PsAssignment,
@@ -12,11 +14,28 @@ from .structural import (
PsLoop,
PsStatement,
)
from .expressions import PsSymbolExpr, PsConstantExpr
from .expressions import (
PsAdd,
PsArrayAccess,
PsCall,
PsConstantExpr,
PsDiv,
PsIntDiv,
PsLiteralExpr,
PsMul,
PsNeg,
PsRem,
PsSub,
PsSymbolExpr,
PsTernary,
)
from ..symbols import PsSymbol
from ..exceptions import PsInternalCompilerError
from ...types import PsNumericType
from ...types.exception import PsTypeError
class UndefinedSymbolsCollector:
"""Collect undefined symbols.
@@ -120,3 +139,212 @@ def collect_required_headers(node: PsAstNode) -> set[str]:
return reduce(
set.union, (collect_required_headers(c) for c in node.children), set()
)
@dataclass
class OperationCounts:
float_adds: int = 0
float_muls: int = 0
float_divs: int = 0
int_adds: int = 0
int_muls: int = 0
int_divs: int = 0
calls: int = 0
branches: int = 0
loops_with_dynamic_bounds: int = 0
def __add__(self, other):
if not isinstance(other, OperationCounts):
return NotImplemented
return OperationCounts(
float_adds=self.float_adds + other.float_adds,
float_muls=self.float_muls + other.float_muls,
float_divs=self.float_divs + other.float_divs,
int_adds=self.int_adds + other.int_adds,
int_muls=self.int_muls + other.int_muls,
int_divs=self.int_divs + other.int_divs,
calls=self.calls + other.calls,
branches=self.branches + other.branches,
loops_with_dynamic_bounds=self.loops_with_dynamic_bounds
+ other.loops_with_dynamic_bounds,
)
def __rmul__(self, other):
if not isinstance(other, int):
return NotImplemented
return OperationCounts(
float_adds=other * self.float_adds,
float_muls=other * self.float_muls,
float_divs=other * self.float_divs,
int_adds=other * self.int_adds,
int_muls=other * self.int_muls,
int_divs=other * self.int_divs,
calls=other * self.calls,
branches=other * self.branches,
loops_with_dynamic_bounds=other * self.loops_with_dynamic_bounds,
)
class OperationCounter:
"""Counts the number of operations in an AST.
Assumes that the AST is typed. It is recommended that constant folding is
applied prior to this pass.
The counted operations are:
- Additions, multiplications and divisions of floating and integer type.
The counts of either type are reported separately and operations on
other types are ignored.
- Function calls.
- Branches.
Includes `PsConditional` and `PsTernary`. The operations in all branches
are summed up (i.e. the result is an overestimation).
- Loops with an unknown number of iterations.
The operations in the loop header and body are counted exactly once,
i.e. it is assumed that there is one loop iteration.
If the start, stop and step of the loop are `PsConstantExpr`, then any
operation within the body is multiplied by the number of iterations.
"""
def __call__(self, node: PsAstNode) -> OperationCounts:
"""Counts the number of operations in the given AST."""
return self.visit(node)
def visit(self, node: PsAstNode) -> OperationCounts:
match node:
case PsExpression():
return self.visit_expr(node)
case PsStatement(expr):
return self.visit_expr(expr)
case PsAssignment(lhs, rhs):
return self.visit_expr(lhs) + self.visit_expr(rhs)
case PsBlock(statements):
return reduce(
operator.add, (self.visit(s) for s in statements), OperationCounts()
)
case PsLoop(_, start, stop, step, body):
if (
isinstance(start, PsConstantExpr)
and isinstance(stop, PsConstantExpr)
and isinstance(step, PsConstantExpr)
):
val_start = start.constant.value
val_stop = stop.constant.value
val_step = step.constant.value
if (val_stop - val_start) % val_step == 0:
iteration_count = max(0, int((val_stop - val_start) / val_step))
else:
iteration_count = max(
0, int((val_stop - val_start) / val_step) + 1
)
return self.visit_expr(start) + iteration_count * (
OperationCounts(int_adds=1) # loop counter increment
+ self.visit_expr(stop)
+ self.visit_expr(step)
+ self.visit(body)
)
else:
return (
OperationCounts(loops_with_dynamic_bounds=1)
+ self.visit_expr(start)
+ self.visit_expr(stop)
+ self.visit_expr(step)
+ self.visit(body)
)
case PsConditional(cond, branch_true, branch_false):
op_counts = (
OperationCounts(branches=1)
+ self.visit(cond)
+ self.visit(branch_true)
)
if branch_false is not None:
op_counts += self.visit(branch_false)
return op_counts
case PsEmptyLeafMixIn():
return OperationCounts()
case unknown:
raise PsInternalCompilerError(f"Can't count operations in {unknown}")
def visit_expr(self, expr: PsExpression) -> OperationCounts:
match expr:
case PsSymbolExpr(_) | PsConstantExpr(_) | PsLiteralExpr(_):
return OperationCounts()
case PsArrayAccess(_, index):
return self.visit_expr(index)
case PsCall(_, args):
return OperationCounts(calls=1) + reduce(
operator.add, (self.visit(a) for a in args), OperationCounts()
)
case PsTernary(cond, then, els):
return (
OperationCounts(branches=1)
+ self.visit_expr(cond)
+ self.visit_expr(then)
+ self.visit_expr(els)
)
case PsNeg(arg):
if expr.dtype is None:
raise PsTypeError(f"Untyped arithmetic expression: {expr}")
op_counts = self.visit_expr(arg)
if isinstance(expr.dtype, PsNumericType) and expr.dtype.is_float():
op_counts.float_muls += 1
elif isinstance(expr.dtype, PsNumericType) and expr.dtype.is_int():
op_counts.int_muls += 1
return op_counts
case PsAdd(arg1, arg2) | PsSub(arg1, arg2):
if expr.dtype is None:
raise PsTypeError(f"Untyped arithmetic expression: {expr}")
op_counts = self.visit_expr(arg1) + self.visit_expr(arg2)
if isinstance(expr.dtype, PsNumericType) and expr.dtype.is_float():
op_counts.float_adds += 1
elif isinstance(expr.dtype, PsNumericType) and expr.dtype.is_int():
op_counts.int_adds += 1
return op_counts
case PsMul(arg1, arg2):
if expr.dtype is None:
raise PsTypeError(f"Untyped arithmetic expression: {expr}")
op_counts = self.visit_expr(arg1) + self.visit_expr(arg2)
if isinstance(expr.dtype, PsNumericType) and expr.dtype.is_float():
op_counts.float_muls += 1
elif isinstance(expr.dtype, PsNumericType) and expr.dtype.is_int():
op_counts.int_muls += 1
return op_counts
case PsDiv(arg1, arg2) | PsIntDiv(arg1, arg2) | PsRem(arg1, arg2):
if expr.dtype is None:
raise PsTypeError(f"Untyped arithmetic expression: {expr}")
op_counts = self.visit_expr(arg1) + self.visit_expr(arg2)
if isinstance(expr.dtype, PsNumericType) and expr.dtype.is_float():
op_counts.float_divs += 1
elif isinstance(expr.dtype, PsNumericType) and expr.dtype.is_int():
op_counts.int_divs += 1
return op_counts
case _:
return reduce(
operator.add,
(self.visit_expr(cast(PsExpression, c)) for c in expr.children),
OperationCounts(),
)