Skip to content
Snippets Groups Projects
Commit bd01a24d authored by Markus Holzer's avatar Markus Holzer
Browse files

Introduce DivFunc and UnevaluatedExpression

parent fe5cece2
No related branches found
No related tags found
1 merge request!309Improve FLOP counting function
Pipeline #47536 failed
......@@ -639,8 +639,10 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]],
for child_term, condition in t.args:
visit(child_term)
visit_children = False
elif isinstance(t, sp.Rel):
elif isinstance(t, (sp.Rel, sp.UnevaluatedExpr)):
pass
elif isinstance(t, DivFunc):
result["divs"] += 1
else:
warnings.warn(f"Unknown sympy node of type {str(t.func)} counting will be inaccurate")
......
......@@ -15,6 +15,7 @@ from pystencils.sympyextensions import scalar_product
from pystencils.sympyextensions import kronecker_delta
from pystencils import Assignment
from pystencils.functions import DivFunc
from pystencils.fast_approximation import (fast_division, fast_inv_sqrt, fast_sqrt,
insert_fast_divisions, insert_fast_sqrts)
......@@ -163,6 +164,19 @@ def test_count_operations():
assert ops['divs'] == 1
assert ops['sqrts'] == 1
expr = DivFunc(x, y)
ops = count_operations(expr, only_type=None)
assert ops['divs'] == 1
expr = DivFunc(x + z, y + z)
ops = count_operations(expr, only_type=None)
assert ops['adds'] == 2
assert ops['divs'] == 1
expr = sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False))
ops = count_operations(expr, only_type=None)
assert ops['muls'] == 99
def test_common_denominator():
x = sympy.symbols('x')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment