Skip to content
Snippets Groups Projects

count_operations: fix to not count integer expressions for addresses/constants as real operations

Merged Dominik Ernst requested to merge count_ops into master
@@ -10,7 +10,8 @@ from sympy.functions import Abs
@@ -10,7 +10,8 @@ from sympy.functions import Abs
from sympy.core.numbers import Zero
from sympy.core.numbers import Zero
from pystencils.assignment import Assignment
from pystencils.assignment import Assignment
from pystencils.data_types import cast_func, get_base_type, get_type_of_expression
from pystencils.data_types import cast_func, get_base_type, get_type_of_expression, PointerType
 
from pystencils.kernelparameters import FieldPointerSymbol
T = TypeVar('T')
T = TypeVar('T')
@@ -445,7 +446,6 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
@@ -445,7 +446,6 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
result = {'adds': 0, 'muls': 0, 'divs': 0, 'sqrts': 0,
result = {'adds': 0, 'muls': 0, 'divs': 0, 'sqrts': 0,
'fast_sqrts': 0, 'fast_inv_sqrts': 0, 'fast_div': 0}
'fast_sqrts': 0, 'fast_inv_sqrts': 0, 'fast_div': 0}
if isinstance(term, Sequence):
if isinstance(term, Sequence):
for element in term:
for element in term:
r = count_operations(element, only_type)
r = count_operations(element, only_type)
@@ -455,16 +455,19 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
@@ -455,16 +455,19 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
elif isinstance(term, Assignment):
elif isinstance(term, Assignment):
term = term.rhs
term = term.rhs
if hasattr(term, 'evalf'):
term = term.evalf()
def check_type(e):
def check_type(e):
if only_type is None:
if only_type is None:
return True
return True
 
if isinstance(e, FieldPointerSymbol) and only_type == "real":
 
return only_type == "int"
 
try:
try:
base_type = get_base_type(get_type_of_expression(e))
base_type = get_type_of_expression(e)
except ValueError:
except ValueError:
return False
return False
 
if isinstance(base_type, PointerType):
 
return only_type == 'int'
if only_type == 'int' and (base_type.is_int() or base_type.is_uint()):
if only_type == 'int' and (base_type.is_int() or base_type.is_uint()):
return True
return True
if only_type == 'real' and (base_type.is_float()):
if only_type == 'real' and (base_type.is_float()):
@@ -515,6 +518,9 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
@@ -515,6 +518,9 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
result['muls'] += (-int(t.exp)) - 1
result['muls'] += (-int(t.exp)) - 1
elif sp.nsimplify(t.exp) == sp.Rational(1, 2):
elif sp.nsimplify(t.exp) == sp.Rational(1, 2):
result['sqrts'] += 1
result['sqrts'] += 1
 
elif sp.nsimplify(t.exp) == -sp.Rational(1, 2):
 
result["sqrts"] += 1
 
result["divs"] += 1
else:
else:
warnings.warn(f"Cannot handle exponent {t.exp} of sp.Pow node")
warnings.warn(f"Cannot handle exponent {t.exp} of sp.Pow node")
else:
else:
Loading