diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index cc1de06c031b2c1a7a554f95a06f4f52d7a0ff3b..dab3d50c65da7520983e214ad1b342f4933a2019 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -443,10 +443,22 @@ class CustomSympyPrinter(CCodePrinter): def _print_Pow(self, expr): """Don't use std::pow function, for small integer exponents, write as multiplication""" - if isinstance(expr.exp, sp.Integer) and (-8 < expr.exp < 8): - raise ValueError(f"This expression: {expr} contains a pow function that should be simplified already with " - f"a sequence of multiplications") - return super(CustomSympyPrinter, self)._print_Pow(expr) + # Ideally the printer has as little logic as possible. Therefore, + # powers should be rewritten as `DivFunc`s / unevaluated `Mul`s before + # printing. `NodeCollection` offers a convenience function to do just + # that. However, `cut_loops` rewrites unevaluated multiplications as + # `Pow`s again. Neither `deepcopy` nor `func(*args)` are suited to + # rebuild unevaluated expressions. Therefore, as long as we stick with + # SymPy, this is the only way to avoid printing `pow`s. + exp = expr.exp.expr if isinstance(expr.exp, CastFunc) else expr.exp + one_type = expr.base.dtype if hasattr(expr.base, "dtype") else get_type_of_expression(expr.base) + + if exp.is_integer and exp.is_number and (0 < exp <= 8): + return f"({self._print(sp.Mul(*[expr.base] * exp, evaluate=False))})" + elif exp.is_integer and exp.is_number and (-8 <= exp < 0): + return f"{self._typed_number(1, one_type)} / ({self._print(sp.Mul(*([expr.base] * -exp), evaluate=False))})" + else: + return super(CustomSympyPrinter, self)._print_Pow(expr) # TODO don't print ones in sp.Mul @@ -490,6 +502,7 @@ class CustomSympyPrinter(CCodePrinter): assert len(expr.args) == 1, "address_of must only have one argument" return f"&({self._print(expr.args[0])})" elif isinstance(expr, CastFunc): + cast = "(({data_type})({code}))" arg, data_type = expr.args if arg.is_Number and not isinstance(arg, (sp.core.numbers.Infinity, sp.core.numbers.NegativeInfinity)): return self._typed_number(arg, data_type) @@ -504,9 +517,12 @@ class CustomSympyPrinter(CCodePrinter): for k in known: if k in code: return code.replace(k, f'{k}f') + # Powers of small integers are printed as divisions/multiplications. + if '/' in code or '*' in code: + return cast.format(data_type=data_type, code=code) raise ValueError(f"{code} doesn't give {known=} function back.") else: - return f"(({data_type})({self._print(arg)}))" + return cast.format(data_type=data_type, code=self._print(arg)) elif isinstance(expr, fast_division): raise ValueError("fast_division is only supported for Taget.GPU") elif isinstance(expr, fast_sqrt): diff --git a/pystencils_tests/test_printing.py b/pystencils_tests/test_printing.py index 6490f59e43c0c42eb64077b3c27dc99d250f04eb..cb6423c0fc7a99b804fc01037a556739d0c09612 100644 --- a/pystencils_tests/test_printing.py +++ b/pystencils_tests/test_printing.py @@ -1,4 +1,5 @@ import pytest +import re import sympy as sp import pystencils @@ -51,3 +52,53 @@ def test_print_subtraction(dtype, target): code = pystencils.get_code_str(ast) assert "-1.0" not in code + + +def test_print_small_integer_pow(): + printer = pystencils.backends.cbackend.CBackend() + + x = sp.Symbol("x") + y = sp.Symbol("y") + n = pystencils.TypedSymbol("n", "int") + t = pystencils.TypedSymbol("t", "float32") + s = pystencils.TypedSymbol("s", "float32") + + equs = [ + pystencils.astnodes.SympyAssignment(y, 1/x), + pystencils.astnodes.SympyAssignment(y, x*x), + pystencils.astnodes.SympyAssignment(y, 1/(x*x)), + pystencils.astnodes.SympyAssignment(y, x**8), + pystencils.astnodes.SympyAssignment(y, x**(-8)), + pystencils.astnodes.SympyAssignment(y, x**9), + pystencils.astnodes.SympyAssignment(y, x**(-9)), + pystencils.astnodes.SympyAssignment(y, x**n), + pystencils.astnodes.SympyAssignment(y, sp.Pow(4, 4, evaluate=False)), + pystencils.astnodes.SympyAssignment(y, x**0.25), + pystencils.astnodes.SympyAssignment(y, x**y), + pystencils.astnodes.SympyAssignment(y, pystencils.typing.cast_functions.CastFunc(1/x, "float32")), + pystencils.astnodes.SympyAssignment(y, pystencils.typing.cast_functions.CastFunc(x*x, "float32")), + pystencils.astnodes.SympyAssignment(y, (t+s)**(-8)), + pystencils.astnodes.SympyAssignment(y, (t+s)**(-9)), + ] + typed = pystencils.typing.transformations.add_types(equs, pystencils.CreateKernelConfig()) + + regexes = [ + r"1\.0\s*/\s*\(?\s*x\s*\)?", + r"x\s*\*\s*x", + r"1\.0\s*/\s*\(\s*x\s*\*x\s*\)", + r"x(\s*\*\s*x){7}", + r"1\.0\s*/\s*\(\s*x(\s*\*\s*x){7}\s*\)", + r"pow\(\s*x\s*,\s*9(\.0)?\s*\)", + r"pow\(\s*x\s*,\s*-9(\.0)?\s*\)", + r"pow\(\s*x\s*,\s*\(?\s*\(\s*double\s*\)\s*\(\s*n\s*\)\s*\)?\s*\)", + r"\(\s*int[a-zA-Z0-9_]*\s*\)\s*\(+\s*4(\s*\*\s*4){3}\s*\)+", + r"pow\(\s*x\s*,\s*0\.25\s*\)", + r"pow\(\s*x\s*,\s*y\s*\)", + r"\(\s*float\s*\)[ ()]*1\.0\s*/\s*\(?\s*x\s*\)?", + r"\(\s*float\s*\)[ ()]*x\s*\*\s*x", + r"\(\s*float\s*\)\s*\(\s*1\.0f\s*/\s*\(\s*\(\s*s\s*\+\s*t\s*\)(\s*\*\s*\(\s*s\s*\+\s*t\s*\)){7}\s*\)", + r"powf\(\s*s\s*\+\s*t\s*,\s*-9\.0f\s*\)", + ] + + for r, e in zip(regexes, typed): + assert re.search(r, printer(e))