diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index 727539373856950a9312966f3fe27859ae17a8f4..4f1d46d76bd86d78a86c4d7eacd9393b08601019 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -345,7 +345,7 @@ class LoopOverCoordinate(Node): LOOP_COUNTER_NAME_PREFIX = "ctr" BlOCK_LOOP_COUNTER_NAME_PREFIX = "_blockctr" - def __init__(self, body, coordinate_to_loop_over, start, stop, step=1, is_block_loop=False): + def __init__(self, body, coordinate_to_loop_over, start, stop, step=1, is_block_loop=False, relational=sp.Lt): super(LoopOverCoordinate, self).__init__(parent=None) self.body = body body.parent = self @@ -356,10 +356,11 @@ class LoopOverCoordinate(Node): self.body.parent = self self.prefix_lines = [] self.is_block_loop = is_block_loop + self.relational = relational def new_loop_with_different_body(self, new_body): result = LoopOverCoordinate(new_body, self.coordinate_to_loop_over, self.start, self.stop, - self.step, self.is_block_loop) + self.step, self.is_block_loop, self.relational) result.prefix_lines = [l for l in self.prefix_lines] return result @@ -452,14 +453,16 @@ class LoopOverCoordinate(Node): return len(self.atoms(LoopOverCoordinate)) == 0 def __str__(self): - return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loop_counter_name, self.start, - self.loop_counter_name, self.stop, - self.loop_counter_name, self.step, + return 'for({!s}={!s}; {!s}; {!s}+={!s})\n{!s}'.format(self.loop_counter_name, self.start, + sp.ccode(self.relational(self.loop_counter_name, + self.stop)), + self.loop_counter_name, self.step, ("\t" + "\t".join(str(self.body).splitlines(True)))) def __repr__(self): - return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loop_counter_name, self.start, - self.loop_counter_name, self.stop, + return 'for({!s}={!s}; {!s}; {!s}+={!s})'.format(self.loop_counter_name, self.start, + sp.ccode(self.relational(self.loop_counter_name, + self.stop)), self.loop_counter_name, self.step) diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index cd71d0ffb3ccc0e7c54bea2a7a08477fb019f451..f9bf0b1291997dc3617403db72fb6b85691da17a 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -154,7 +154,8 @@ class CBackend: def _print_LoopOverCoordinate(self, node): counter_symbol = node.loop_counter_name start = "int %s = %s" % (counter_symbol, self.sympy_printer.doprint(node.start)) - condition = "%s < %s" % (counter_symbol, self.sympy_printer.doprint(node.stop)) + condition_expression = node.relational(node.loop_counter_symbol, node.stop) + condition = self.sympy_printer.doprint(condition_expression) update = "%s += %s" % (counter_symbol, self.sympy_printer.doprint(node.step),) loop_str = "for (%s; %s; %s)" % (start, condition, update)