Skip to content
Snippets Groups Projects
Commit ad9cba7c authored by Markus Holzer's avatar Markus Holzer Committed by Christoph Alt
Browse files

Improve FLOP counting function

parent dee98bb8
Branches
Tags
No related merge requests found
......@@ -639,8 +639,10 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]],
for child_term, condition in t.args:
visit(child_term)
visit_children = False
elif isinstance(t, sp.Rel):
elif isinstance(t, (sp.Rel, sp.UnevaluatedExpr)):
pass
elif isinstance(t, DivFunc):
result["divs"] += 1
else:
warnings.warn(f"Unknown sympy node of type {str(t.func)} counting will be inaccurate")
......
......@@ -15,6 +15,7 @@ from pystencils.sympyextensions import scalar_product
from pystencils.sympyextensions import kronecker_delta
from pystencils import Assignment
from pystencils.functions import DivFunc
from pystencils.fast_approximation import (fast_division, fast_inv_sqrt, fast_sqrt,
insert_fast_divisions, insert_fast_sqrts)
......@@ -163,6 +164,30 @@ def test_count_operations():
assert ops['divs'] == 1
assert ops['sqrts'] == 1
expr = DivFunc(x, y)
ops = count_operations(expr, only_type=None)
assert ops['divs'] == 1
expr = DivFunc(x + z, y + z)
ops = count_operations(expr, only_type=None)
assert ops['adds'] == 2
assert ops['divs'] == 1
expr = sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False))
ops = count_operations(expr, only_type=None)
assert ops['muls'] == 99
expr = DivFunc(1, sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False)))
ops = count_operations(expr, only_type=None)
assert ops['divs'] == 1
assert ops['muls'] == 99
expr = DivFunc(y + z, sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False)))
ops = count_operations(expr, only_type=None)
assert ops['adds'] == 1
assert ops['divs'] == 1
assert ops['muls'] == 99
def test_common_denominator():
x = sympy.symbols('x')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment