From 54b91e2fa6c385841d7e91f739a8af8128d27783 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jean-No=C3=ABl=20Grad?= <jgrad@icp.uni-stuttgart.de> Date: Thu, 28 Jul 2022 16:38:18 +0200 Subject: [PATCH] Enable vectorization of casted sp.Pow expressions --- pystencils/backends/cbackend.py | 10 ++-------- pystencils_tests/test_vectorization.py | 8 ++++++++ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index de9cb0d31..42e1cbfd8 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -651,14 +651,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): # known = self.known_functions[arg.__class__.__name__.lower()] # code = self._print(arg) # return code.replace(known, f"{known}f") - elif isinstance(arg, sp.Pow) and data_type == BasicType('float32'): - raise NotImplementedError('Vectorizer cannot print casted aka. not double pow') - # known = ['sqrt', 'cbrt', 'pow'] - # code = self._print(arg) - # for k in known: - # if k in code: - # return code.replace(k, f'{k}f') - # raise ValueError(f"{code} doesn't give {known=} function back.") + elif isinstance(arg, sp.Pow): + return self._print_Pow(arg) else: raise NotImplementedError('Vectorizer cannot cast between different datatypes') # to_type = self.instruction_set['suffix'][data_type.base_type.c_name] diff --git a/pystencils_tests/test_vectorization.py b/pystencils_tests/test_vectorization.py index f526341ec..21df29b9a 100644 --- a/pystencils_tests/test_vectorization.py +++ b/pystencils_tests/test_vectorization.py @@ -10,6 +10,7 @@ from pystencils.backends.simd_instruction_sets import get_supported_instruction_ from pystencils.cpu.vectorization import vectorize from pystencils.enums import Target from pystencils.transformations import replace_inner_stride_with_one +from pystencils.typing import CastFunc, BasicType supported_instruction_sets = get_supported_instruction_sets() if supported_instruction_sets: @@ -247,6 +248,7 @@ def test_vectorised_pow(instruction_set=instruction_set): as4 = ps.Assignment(g[0, 0], sp.Pow(f[0, 0], 4)) as5 = ps.Assignment(g[0, 0], sp.Pow(f[0, 0], -4)) as6 = ps.Assignment(g[0, 0], sp.Pow(f[0, 0], -1)) + as7 = ps.Assignment(g[0, 0], CastFunc(as2.rhs, BasicType('double'))) ast = ps.create_kernel(as1) vectorize(ast, instruction_set=instruction_set) @@ -273,6 +275,12 @@ def test_vectorised_pow(instruction_set=instruction_set): vectorize(ast, instruction_set=instruction_set) ast.compile() + ast2 = ps.create_kernel(as2) + vectorize(ast2, instruction_set=instruction_set) + ast7 = ps.create_kernel(as7) + vectorize(ast7, instruction_set=instruction_set) + np.testing.assert_equal(ps.get_code_str(ast2), ps.get_code_str(ast7)) + def test_issue40(*_): """https://i10git.cs.fau.de/pycodegen/pystencils/-/issues/40""" -- GitLab