Skip to content
Snippets Groups Projects
Commit 13e4eba9 authored by Nils Kohl's avatar Nils Kohl :full_moon_with_face:
Browse files

Extended LoopOverCoordinate.

The condition's comparison operator can now be set to any sp.Relational.
This allows for example to loop backwards.
parent 37650bc9
No related merge requests found
...@@ -345,7 +345,7 @@ class LoopOverCoordinate(Node): ...@@ -345,7 +345,7 @@ class LoopOverCoordinate(Node):
LOOP_COUNTER_NAME_PREFIX = "ctr" LOOP_COUNTER_NAME_PREFIX = "ctr"
BlOCK_LOOP_COUNTER_NAME_PREFIX = "_blockctr" BlOCK_LOOP_COUNTER_NAME_PREFIX = "_blockctr"
def __init__(self, body, coordinate_to_loop_over, start, stop, step=1, is_block_loop=False): def __init__(self, body, coordinate_to_loop_over, start, stop, step=1, is_block_loop=False, relational=sp.Lt):
super(LoopOverCoordinate, self).__init__(parent=None) super(LoopOverCoordinate, self).__init__(parent=None)
self.body = body self.body = body
body.parent = self body.parent = self
...@@ -356,10 +356,11 @@ class LoopOverCoordinate(Node): ...@@ -356,10 +356,11 @@ class LoopOverCoordinate(Node):
self.body.parent = self self.body.parent = self
self.prefix_lines = [] self.prefix_lines = []
self.is_block_loop = is_block_loop self.is_block_loop = is_block_loop
self.relational = relational
def new_loop_with_different_body(self, new_body): def new_loop_with_different_body(self, new_body):
result = LoopOverCoordinate(new_body, self.coordinate_to_loop_over, self.start, self.stop, result = LoopOverCoordinate(new_body, self.coordinate_to_loop_over, self.start, self.stop,
self.step, self.is_block_loop) self.step, self.is_block_loop, self.relational)
result.prefix_lines = [l for l in self.prefix_lines] result.prefix_lines = [l for l in self.prefix_lines]
return result return result
...@@ -452,14 +453,16 @@ class LoopOverCoordinate(Node): ...@@ -452,14 +453,16 @@ class LoopOverCoordinate(Node):
return len(self.atoms(LoopOverCoordinate)) == 0 return len(self.atoms(LoopOverCoordinate)) == 0
def __str__(self): def __str__(self):
return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loop_counter_name, self.start, return 'for({!s}={!s}; {!s}; {!s}+={!s})\n{!s}'.format(self.loop_counter_name, self.start,
self.loop_counter_name, self.stop, sp.ccode(self.relational(self.loop_counter_name,
self.loop_counter_name, self.step, self.stop)),
self.loop_counter_name, self.step,
("\t" + "\t".join(str(self.body).splitlines(True)))) ("\t" + "\t".join(str(self.body).splitlines(True))))
def __repr__(self): def __repr__(self):
return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loop_counter_name, self.start, return 'for({!s}={!s}; {!s}; {!s}+={!s})'.format(self.loop_counter_name, self.start,
self.loop_counter_name, self.stop, sp.ccode(self.relational(self.loop_counter_name,
self.stop)),
self.loop_counter_name, self.step) self.loop_counter_name, self.step)
......
...@@ -154,7 +154,8 @@ class CBackend: ...@@ -154,7 +154,8 @@ class CBackend:
def _print_LoopOverCoordinate(self, node): def _print_LoopOverCoordinate(self, node):
counter_symbol = node.loop_counter_name counter_symbol = node.loop_counter_name
start = "int %s = %s" % (counter_symbol, self.sympy_printer.doprint(node.start)) start = "int %s = %s" % (counter_symbol, self.sympy_printer.doprint(node.start))
condition = "%s < %s" % (counter_symbol, self.sympy_printer.doprint(node.stop)) condition_expression = node.relational(node.loop_counter_symbol, node.stop)
condition = self.sympy_printer.doprint(condition_expression)
update = "%s += %s" % (counter_symbol, self.sympy_printer.doprint(node.step),) update = "%s += %s" % (counter_symbol, self.sympy_printer.doprint(node.step),)
loop_str = "for (%s; %s; %s)" % (start, condition, update) loop_str = "for (%s; %s; %s)" % (start, condition, update)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment