Skip to content
Snippets Groups Projects
Commit 2cb231b2 authored by Martin Bauer's avatar Martin Bauer
Browse files

FLOPs counting now also counts sqrts, invsqrts and their fast approximations

parent 0f298c63
Branches
Tags
No related merge requests found
...@@ -6,7 +6,6 @@ from collections import defaultdict, Counter ...@@ -6,7 +6,6 @@ from collections import defaultdict, Counter
import sympy as sp import sympy as sp
from sympy.functions import Abs from sympy.functions import Abs
from typing import Optional, Union, List, TypeVar, Iterable, Sequence, Callable, Dict, Tuple from typing import Optional, Union, List, TypeVar, Iterable, Sequence, Callable, Dict, Tuple
from pystencils.data_types import get_type_of_expression, get_base_type, cast_func from pystencils.data_types import get_type_of_expression, get_base_type, cast_func
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
...@@ -233,7 +232,7 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr, ...@@ -233,7 +232,7 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
intersection = set(subexpression_coefficient_dict.keys()).intersection(set(expr_coefficients)) intersection = set(subexpression_coefficient_dict.keys()).intersection(set(expr_coefficients))
if len(intersection) >= max(normalized_replacement_match, normalized_current_expr_match): if len(intersection) >= max(normalized_replacement_match, normalized_current_expr_match):
# find common factor # find common factor
factors = defaultdict(lambda: 0) factors = defaultdict(int)
skips = 0 skips = 0
for common_symbol in subexpression_coefficient_dict.keys(): for common_symbol in subexpression_coefficient_dict.keys():
if common_symbol not in expr_coefficients: if common_symbol not in expr_coefficients:
...@@ -428,7 +427,10 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], ...@@ -428,7 +427,10 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
Returns: Returns:
dict with 'adds', 'muls' and 'divs' keys dict with 'adds', 'muls' and 'divs' keys
""" """
result = {'adds': 0, 'muls': 0, 'divs': 0} from pystencils.fast_approximation import fast_sqrt, fast_inv_sqrt, fast_division
result = {'adds': 0, 'muls': 0, 'divs': 0, 'sqrts': 0,
'fast_sqrts': 0, 'fast_inv_sqrts': 0, 'fast_div': 0}
if isinstance(term, Sequence): if isinstance(term, Sequence):
for element in term: for element in term:
...@@ -480,6 +482,12 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], ...@@ -480,6 +482,12 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
elif isinstance(t, cast_func): elif isinstance(t, cast_func):
visit_children = False visit_children = False
visit(t.args[0]) visit(t.args[0])
elif t.func is fast_sqrt:
result['fast_sqrts'] += 1
elif t.func is fast_inv_sqrt:
result['fast_inv_sqrts'] += 1
elif t.func is fast_division:
result['fast_div'] += 1
elif t.func is sp.Pow: elif t.func is sp.Pow:
if check_type(t.args[0]): if check_type(t.args[0]):
visit_children = False visit_children = False
...@@ -490,6 +498,10 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], ...@@ -490,6 +498,10 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
result['muls'] -= 1 result['muls'] -= 1
result['divs'] += 1 result['divs'] += 1
result['muls'] += (-int(t.exp)) - 1 result['muls'] += (-int(t.exp)) - 1
elif sp.nsimplify(t.exp) == sp.Rational(1, 2):
result['sqrts'] += 1
else:
warnings.warn("Cannot handle exponent", t.exp, " of sp.Pow node")
else: else:
warnings.warn("Counting operations: only integer exponents are supported in Pow, " warnings.warn("Counting operations: only integer exponents are supported in Pow, "
"counting will be inaccurate") "counting will be inaccurate")
...@@ -513,14 +525,13 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], ...@@ -513,14 +525,13 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
def count_operations_in_ast(ast) -> Dict[str, int]: def count_operations_in_ast(ast) -> Dict[str, int]:
"""Counts number of operations in an abstract syntax tree, see also :func:`count_operations`""" """Counts number of operations in an abstract syntax tree, see also :func:`count_operations`"""
from pystencils.astnodes import SympyAssignment from pystencils.astnodes import SympyAssignment
result = {'adds': 0, 'muls': 0, 'divs': 0} result = defaultdict(int)
def visit(node): def visit(node):
if isinstance(node, SympyAssignment): if isinstance(node, SympyAssignment):
r = count_operations(node.rhs) r = count_operations(node.rhs)
result['adds'] += r['adds'] for k, v in r.items():
result['muls'] += r['muls'] result[k] += v
result['divs'] += r['divs']
else: else:
for arg in node.args: for arg in node.args:
visit(arg) visit(arg)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment