From f9a17154f174d34849e850c829e4a0f5a08a6bb4 Mon Sep 17 00:00:00 2001 From: Michael Kuron <m.kuron@gmx.de> Date: Thu, 27 May 2021 14:38:37 +0200 Subject: [PATCH] never attempt to vectorize the tail loop --- pystencils/backends/cbackend.py | 10 +++++++++- pystencils/cpu/vectorization.py | 5 +++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 94f21c076..271bf2322 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -192,7 +192,9 @@ class CBackend: def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'): if sympy_printer is None: if vector_instruction_set is not None: - self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set) + self.vector_sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set) + self.scalar_sympy_printer = CustomSympyPrinter() + self.sympy_printer = self.vector_sympy_printer else: self.sympy_printer = CustomSympyPrinter() else: @@ -259,6 +261,12 @@ class CBackend: prefix = "\n".join(node.prefix_lines) if prefix: prefix += "\n" + if self._vector_instruction_set and hasattr(node, 'instruction_set') and node.instruction_set is None: + # the tail loop must not be vectorized + self.sympy_printer = self.scalar_sympy_printer + code = f"{prefix}{loop_str}\n{self._print(node.body)}" + self.sympy_printer = self.vector_sympy_printer + return code return f"{prefix}{loop_str}\n{self._print(node.body)}" def _print_SympyAssignment(self, node): diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py index b9fa2819e..e4a27ad34 100644 --- a/pystencils/cpu/vectorization.py +++ b/pystencils/cpu/vectorization.py @@ -173,6 +173,8 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a cutting_point = modulo_floor(loop_range, vector_width) + loop_node.start loop_nodes = [l for l in cut_loop(loop_node, [cutting_point]).args if isinstance(l, ast.LoopOverCoordinate)] assert len(loop_nodes) in (0, 1, 2) # 2 for main and tail loop, 1 if loop range divisible by vector width + if len(loop_nodes) == 2: + loop_nodes[1].instruction_set = None if len(loop_nodes) == 0: continue loop_node = loop_nodes[0] @@ -322,6 +324,9 @@ def insert_vector_casts(ast_node): return expr def visit_node(node, substitution_dict): + if hasattr(node, 'instruction_set') and node.instruction_set is None: + # the tail loop must not be vectorized + return substitution_dict = substitution_dict.copy() for arg in node.args: if isinstance(arg, ast.SympyAssignment): -- GitLab