diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 654b92b82f8c5e225f26abd5a5c14ce0b113a43e..b1f7899ed8f44aeba1eb10da64792fefcbc70e84 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -609,6 +609,20 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): return result def _print_Add(self, expr, order=None): + def visit(summands): + if len(summands) == 2: + sign = summands[0].sign * summands[1].sign + func = self.instruction_set['-'] if sign == -1 else self.instruction_set['+'] + return func.format(summands[0].term, summands[1].term) + else: + elements = len(summands) // 2 + if len(summands[:elements]) < 2: + func = self.instruction_set['-'] if summands[0].sign == -1 else self.instruction_set['+'] + return func.format(summands[0].term, visit(summands[elements:])) + else: + func = self.instruction_set['+'] + return func.format(visit(summands[:elements]), visit(summands[elements:])) + result = self._scalarFallback('_print_Add', expr) if result: return result @@ -628,10 +642,13 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): summands.insert(0, self.SummandInfo(1, "0")) assert len(summands) >= 2 - processed = summands[0].term - for summand in summands[1:]: - func = self.instruction_set['-'] if summand.sign == -1 else self.instruction_set['+'] - processed = func.format(processed, summand.term) + if len(summands) < 10: + processed = summands[0].term + for summand in summands[1:]: + func = self.instruction_set['-'] if summand.sign == -1 else self.instruction_set['+'] + processed = func.format(processed, summand.term) + else: + processed = visit(summands) return processed def _print_Pow(self, expr):