diff --git a/src/pystencils/backends/cbackend.py b/src/pystencils/backends/cbackend.py index 657f60d2f16f14a20f81ebfc77414eb31ba0236a..6f62e1c74f1d99fdf5198174f3cf0b5624ce2876 100644 --- a/src/pystencils/backends/cbackend.py +++ b/src/pystencils/backends/cbackend.py @@ -634,7 +634,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): return None def _print_Abs(self, expr): - if 'abs' in self.instruction_set and isinstance(expr.args[0], VectorMemoryAccess): + if isinstance(get_type_of_expression(expr), (VectorType, VectorMemoryAccess)): return self.instruction_set['abs'].format(self._print(expr.args[0]), **self._kwargs) return super()._print_Abs(expr) diff --git a/tests/test_vectorization_specific.py b/tests/test_vectorization_specific.py index 19c6e0033c1b73a967d18cc36fbb93438c7359f5..55606808bfeb49e336ac94c417791c97c8fc47d8 100644 --- a/tests/test_vectorization_specific.py +++ b/tests/test_vectorization_specific.py @@ -39,7 +39,7 @@ def test_vectorisation_varying_arch(instruction_set): @pytest.mark.parametrize('dtype', ('float32', 'float64')) @pytest.mark.parametrize('instruction_set', supported_instruction_sets) -def test_vectorized_abs(instruction_set, dtype): +def test_vectorized_abs_field(instruction_set, dtype): """Some instructions sets have abs, some don't. Furthermore, the special treatment of unary minus makes this data type-sensitive too. """ @@ -58,6 +58,24 @@ def test_vectorized_abs(instruction_set, dtype): np.testing.assert_equal(np.sum(dst[1:-1, 1:-1]), 2 ** 2 * 2 ** 3) +@pytest.mark.parametrize('instruction_set', supported_instruction_sets) +def test_vectorized_abs_scalar(instruction_set): + """Some instructions sets have abs, some don't. + Furthermore, the special treatment of unary minus makes this data type-sensitive too. + """ + arr = np.zeros((2 ** 2 + 2, 2 ** 3 + 2), dtype="float64") + + f = ps.fields(f=arr) + update_rule = [ps.Assignment(f.center(), sp.Abs(sp.Symbol("a")))] + + config = pystencils.config.CreateKernelConfig(cpu_vectorize_info={'instruction_set': instruction_set}) + ast = ps.create_kernel(update_rule, config=config) + + func = ast.compile() + func(f=arr, a=-1) + np.testing.assert_equal(np.sum(arr[1:-1, 1:-1]), 2 ** 2 * 2 ** 3) + + @pytest.mark.parametrize('dtype', ('float32', 'float64')) @pytest.mark.parametrize('instruction_set', supported_instruction_sets) @pytest.mark.parametrize('nontemporal', [False, True])