diff --git a/pystencils/math_optimizations.py b/pystencils/math_optimizations.py index 44fda53e8b1aeb34dcfdd276c0a1f48da94cd6ef..0b64b3e7f02a9e2a6a026431a7678286395a2b80 100644 --- a/pystencils/math_optimizations.py +++ b/pystencils/math_optimizations.py @@ -9,6 +9,7 @@ import itertools from pystencils import Assignment from pystencils.astnodes import SympyAssignment +from pystencils.integer_functions import IntegerFunctionTwoArgsMixIn try: from sympy.codegen.rewriting import optims_c99, optimize @@ -24,6 +25,9 @@ try: optims_pystencils_cpu = [evaluate_constant_terms] + list(optims_c99) optims_pystencils_gpu = [evaluate_constant_terms] + list(optims_c99) except ImportError: + from warnings import warn + warn("Could not import ReplaceOptim, optims_c99, optimize from sympy.codegen.rewriting." + "Please update your sympy installation!") optims_c99 = [] optims_pystencils_cpu = [] optims_pystencils_gpu = [] @@ -34,7 +38,8 @@ def optimize_assignments(assignments, optimizations): if HAS_REWRITING: assignments = [Assignment(a.lhs, optimize(a.rhs, optimizations)) - if hasattr(a, 'lhs') else a for a in assignments] + if hasattr(a, 'lhs') and not a.rhs.atoms(IntegerFunctionTwoArgsMixIn) + else a for a in assignments] assignments_nodes = [a.atoms(SympyAssignment) for a in assignments] for a in itertools.chain.from_iterable(assignments_nodes): a.optimize(optimizations)