Skip to content
Snippets Groups Projects
Commit 13e4eba9 authored by Nils Kohl's avatar Nils Kohl :full_moon_with_face:
Browse files

Extended LoopOverCoordinate.

The condition's comparison operator can now be set to any sp.Relational.
This allows for example to loop backwards.
parent 37650bc9
No related branches found
No related tags found
1 merge request!149WIP: Hyteg
Pipeline #15965 failed
...@@ -345,7 +345,7 @@ class LoopOverCoordinate(Node): ...@@ -345,7 +345,7 @@ class LoopOverCoordinate(Node):
LOOP_COUNTER_NAME_PREFIX = "ctr" LOOP_COUNTER_NAME_PREFIX = "ctr"
BlOCK_LOOP_COUNTER_NAME_PREFIX = "_blockctr" BlOCK_LOOP_COUNTER_NAME_PREFIX = "_blockctr"
def __init__(self, body, coordinate_to_loop_over, start, stop, step=1, is_block_loop=False): def __init__(self, body, coordinate_to_loop_over, start, stop, step=1, is_block_loop=False, relational=sp.Lt):
super(LoopOverCoordinate, self).__init__(parent=None) super(LoopOverCoordinate, self).__init__(parent=None)
self.body = body self.body = body
body.parent = self body.parent = self
...@@ -356,10 +356,11 @@ class LoopOverCoordinate(Node): ...@@ -356,10 +356,11 @@ class LoopOverCoordinate(Node):
self.body.parent = self self.body.parent = self
self.prefix_lines = [] self.prefix_lines = []
self.is_block_loop = is_block_loop self.is_block_loop = is_block_loop
self.relational = relational
def new_loop_with_different_body(self, new_body): def new_loop_with_different_body(self, new_body):
result = LoopOverCoordinate(new_body, self.coordinate_to_loop_over, self.start, self.stop, result = LoopOverCoordinate(new_body, self.coordinate_to_loop_over, self.start, self.stop,
self.step, self.is_block_loop) self.step, self.is_block_loop, self.relational)
result.prefix_lines = [l for l in self.prefix_lines] result.prefix_lines = [l for l in self.prefix_lines]
return result return result
...@@ -452,14 +453,16 @@ class LoopOverCoordinate(Node): ...@@ -452,14 +453,16 @@ class LoopOverCoordinate(Node):
return len(self.atoms(LoopOverCoordinate)) == 0 return len(self.atoms(LoopOverCoordinate)) == 0
def __str__(self): def __str__(self):
return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})\n{!s}'.format(self.loop_counter_name, self.start, return 'for({!s}={!s}; {!s}; {!s}+={!s})\n{!s}'.format(self.loop_counter_name, self.start,
self.loop_counter_name, self.stop, sp.ccode(self.relational(self.loop_counter_name,
self.loop_counter_name, self.step, self.stop)),
self.loop_counter_name, self.step,
("\t" + "\t".join(str(self.body).splitlines(True)))) ("\t" + "\t".join(str(self.body).splitlines(True))))
def __repr__(self): def __repr__(self):
return 'for({!s}={!s}; {!s}<{!s}; {!s}+={!s})'.format(self.loop_counter_name, self.start, return 'for({!s}={!s}; {!s}; {!s}+={!s})'.format(self.loop_counter_name, self.start,
self.loop_counter_name, self.stop, sp.ccode(self.relational(self.loop_counter_name,
self.stop)),
self.loop_counter_name, self.step) self.loop_counter_name, self.step)
......
...@@ -154,7 +154,8 @@ class CBackend: ...@@ -154,7 +154,8 @@ class CBackend:
def _print_LoopOverCoordinate(self, node): def _print_LoopOverCoordinate(self, node):
counter_symbol = node.loop_counter_name counter_symbol = node.loop_counter_name
start = "int %s = %s" % (counter_symbol, self.sympy_printer.doprint(node.start)) start = "int %s = %s" % (counter_symbol, self.sympy_printer.doprint(node.start))
condition = "%s < %s" % (counter_symbol, self.sympy_printer.doprint(node.stop)) condition_expression = node.relational(node.loop_counter_symbol, node.stop)
condition = self.sympy_printer.doprint(condition_expression)
update = "%s += %s" % (counter_symbol, self.sympy_printer.doprint(node.step),) update = "%s += %s" % (counter_symbol, self.sympy_printer.doprint(node.step),)
loop_str = "for (%s; %s; %s)" % (start, condition, update) loop_str = "for (%s; %s; %s)" % (start, condition, update)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment