Skip to content
Snippets Groups Projects
Commit 54b91e2f authored by Jean-Noël Grad
Browse files

Enable vectorization of casted sp.Pow expressions

parent 245a6ee1
No related merge requests found
Pipeline #41788 passed with stages in 17 minutes and 53 seconds
......@@ -651,14 +651,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
# known = self.known_functions[arg.__class__.__name__.lower()]
# code = self._print(arg)
# return code.replace(known, f"{known}f")
elif isinstance(arg, sp.Pow) and data_type == BasicType('float32'):
raise NotImplementedError('Vectorizer cannot print casted aka. not double pow')
# known = ['sqrt', 'cbrt', 'pow']
# code = self._print(arg)
# for k in known:
# if k in code:
# return code.replace(k, f'{k}f')
# raise ValueError(f"{code} doesn't give {known=} function back.")
elif isinstance(arg, sp.Pow):
return self._print_Pow(arg)
else:
raise NotImplementedError('Vectorizer cannot cast between different datatypes')
# to_type = self.instruction_set['suffix'][data_type.base_type.c_name]
......
......@@ -10,6 +10,7 @@ from pystencils.backends.simd_instruction_sets import get_supported_instruction_
from pystencils.cpu.vectorization import vectorize
from pystencils.enums import Target
from pystencils.transformations import replace_inner_stride_with_one
from pystencils.typing import CastFunc, BasicType
supported_instruction_sets = get_supported_instruction_sets()
if supported_instruction_sets:
......@@ -247,6 +248,7 @@ def test_vectorised_pow(instruction_set=instruction_set):
as4 = ps.Assignment(g[0, 0], sp.Pow(f[0, 0], 4))
as5 = ps.Assignment(g[0, 0], sp.Pow(f[0, 0], -4))
as6 = ps.Assignment(g[0, 0], sp.Pow(f[0, 0], -1))
as7 = ps.Assignment(g[0, 0], CastFunc(as2.rhs, BasicType('double')))
ast = ps.create_kernel(as1)
vectorize(ast, instruction_set=instruction_set)
......@@ -273,6 +275,12 @@ def test_vectorised_pow(instruction_set=instruction_set):
vectorize(ast, instruction_set=instruction_set)
ast.compile()
ast2 = ps.create_kernel(as2)
vectorize(ast2, instruction_set=instruction_set)
ast7 = ps.create_kernel(as7)
vectorize(ast7, instruction_set=instruction_set)
np.testing.assert_equal(ps.get_code_str(ast2), ps.get_code_str(ast7))
def test_issue40(*_):
"""https://i10git.cs.fau.de/pycodegen/pystencils/-/issues/40"""
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment