Skip to content
Snippets Groups Projects

Print small integer powers as divisions/multiplications

Merged Daniel Bauer requested to merge terraneo/pystencils:bauerd/print-pow into master
1 file
+ 51
0
Compare changes
  • Side-by-side
  • Inline
import pytest
import re
import sympy as sp
import pystencils
@@ -51,3 +52,53 @@ def test_print_subtraction(dtype, target):
code = pystencils.get_code_str(ast)
assert "-1.0" not in code
def test_print_small_integer_pow():
printer = pystencils.backends.cbackend.CBackend()
x = sp.Symbol("x")
y = sp.Symbol("y")
n = pystencils.TypedSymbol("n", "int")
t = pystencils.TypedSymbol("t", "float32")
s = pystencils.TypedSymbol("s", "float32")
equs = [
pystencils.astnodes.SympyAssignment(y, 1/x),
pystencils.astnodes.SympyAssignment(y, x*x),
pystencils.astnodes.SympyAssignment(y, 1/(x*x)),
pystencils.astnodes.SympyAssignment(y, x**8),
pystencils.astnodes.SympyAssignment(y, x**(-8)),
pystencils.astnodes.SympyAssignment(y, x**9),
pystencils.astnodes.SympyAssignment(y, x**(-9)),
pystencils.astnodes.SympyAssignment(y, x**n),
pystencils.astnodes.SympyAssignment(y, sp.Pow(4, 4, evaluate=False)),
pystencils.astnodes.SympyAssignment(y, x**0.25),
pystencils.astnodes.SympyAssignment(y, x**y),
pystencils.astnodes.SympyAssignment(y, pystencils.typing.cast_functions.CastFunc(1/x, "float32")),
pystencils.astnodes.SympyAssignment(y, pystencils.typing.cast_functions.CastFunc(x*x, "float32")),
pystencils.astnodes.SympyAssignment(y, (t+s)**(-8)),
pystencils.astnodes.SympyAssignment(y, (t+s)**(-9)),
]
typed = pystencils.typing.transformations.add_types(equs, pystencils.CreateKernelConfig())
regexes = [
r"1\.0\s*/\s*\(?\s*x\s*\)?",
r"x\s*\*\s*x",
r"1\.0\s*/\s*\(\s*x\s*\*x\s*\)",
r"x(\s*\*\s*x){7}",
r"1\.0\s*/\s*\(\s*x(\s*\*\s*x){7}\s*\)",
r"pow\(\s*x\s*,\s*9(\.0)?\s*\)",
r"pow\(\s*x\s*,\s*-9(\.0)?\s*\)",
r"pow\(\s*x\s*,\s*\(?\s*\(\s*double\s*\)\s*\(\s*n\s*\)\s*\)?\s*\)",
r"\(\s*int[a-zA-Z0-9_]*\s*\)\s*\(+\s*4(\s*\*\s*4){3}\s*\)+",
r"pow\(\s*x\s*,\s*0\.25\s*\)",
r"pow\(\s*x\s*,\s*y\s*\)",
r"\(\s*float\s*\)[ ()]*1\.0\s*/\s*\(?\s*x\s*\)?",
r"\(\s*float\s*\)[ ()]*x\s*\*\s*x",
r"\(\s*float\s*\)\s*\(\s*1\.0f\s*/\s*\(\s*\(\s*s\s*\+\s*t\s*\)(\s*\*\s*\(\s*s\s*\+\s*t\s*\)){7}\s*\)",
r"powf\(\s*s\s*\+\s*t\s*,\s*-9\.0f\s*\)",
]
for r, e in zip(regexes, typed):
assert re.search(r, printer(e))
Loading