diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index de9cb0d31b3be9903ac758991a0fca345f17a4a7..42e1cbfd82340eb69b2d206539c1c61800bfcd60 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -651,14 +651,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): # known = self.known_functions[arg.__class__.__name__.lower()] # code = self._print(arg) # return code.replace(known, f"{known}f") - elif isinstance(arg, sp.Pow) and data_type == BasicType('float32'): - raise NotImplementedError('Vectorizer cannot print casted aka. not double pow') - # known = ['sqrt', 'cbrt', 'pow'] - # code = self._print(arg) - # for k in known: - # if k in code: - # return code.replace(k, f'{k}f') - # raise ValueError(f"{code} doesn't give {known=} function back.") + elif isinstance(arg, sp.Pow): + return self._print_Pow(arg) else: raise NotImplementedError('Vectorizer cannot cast between different datatypes') # to_type = self.instruction_set['suffix'][data_type.base_type.c_name] diff --git a/pystencils_tests/test_vectorization.py b/pystencils_tests/test_vectorization.py index f526341ec587b508f6c28f8dd2596125366f8ecc..21df29b9a838e1700b5908d16397de196b0c83e6 100644 --- a/pystencils_tests/test_vectorization.py +++ b/pystencils_tests/test_vectorization.py @@ -10,6 +10,7 @@ from pystencils.backends.simd_instruction_sets import get_supported_instruction_ from pystencils.cpu.vectorization import vectorize from pystencils.enums import Target from pystencils.transformations import replace_inner_stride_with_one +from pystencils.typing import CastFunc, BasicType supported_instruction_sets = get_supported_instruction_sets() if supported_instruction_sets: @@ -247,6 +248,7 @@ def test_vectorised_pow(instruction_set=instruction_set): as4 = ps.Assignment(g[0, 0], sp.Pow(f[0, 0], 4)) as5 = ps.Assignment(g[0, 0], sp.Pow(f[0, 0], -4)) as6 = ps.Assignment(g[0, 0], sp.Pow(f[0, 0], -1)) + as7 = ps.Assignment(g[0, 0], CastFunc(as2.rhs, BasicType('double'))) ast = ps.create_kernel(as1) vectorize(ast, instruction_set=instruction_set) @@ -273,6 +275,12 @@ def test_vectorised_pow(instruction_set=instruction_set): vectorize(ast, instruction_set=instruction_set) ast.compile() + ast2 = ps.create_kernel(as2) + vectorize(ast2, instruction_set=instruction_set) + ast7 = ps.create_kernel(as7) + vectorize(ast7, instruction_set=instruction_set) + np.testing.assert_equal(ps.get_code_str(ast2), ps.get_code_str(ast7)) + def test_issue40(*_): """https://i10git.cs.fau.de/pycodegen/pystencils/-/issues/40"""