diff --git a/pystencils/transformations.py b/pystencils/transformations.py index 2d8210a1bf59523c727b18fadc8b8b7a4c96ddad..8ac7b7230bff34ead855f519f3722168e12ffba8 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -901,14 +901,22 @@ def cut_loop(loop_node, cutting_points, with_conditional: bool = False, replace_ new_start = loop_node.start cutting_points = list(cutting_points) + [loop_node.stop] for new_end in cutting_points: - if replace_loops_with_length_one and (new_end - new_start == 1): + if new_end - new_start == 1: new_body = deepcopy(loop_node.body) - new_body.subs({loop_node.loop_counter_symbol: new_start}) + if replace_loops_with_length_one: + new_body.subs({loop_node.loop_counter_symbol: new_start}) if with_conditional: conditional_expr = sp.And(sp.Ge(new_start, loop_node.start), sp.Le(new_start, loop_node.stop)) - new_loops.append(ast.Conditional(conditional_expr, new_body)) + new_body_wrapped = ast.Block(ast.Conditional(conditional_expr, new_body)) else: - new_loops.append(new_body) + new_body_wrapped = new_body + if not replace_loops_with_length_one: + new_loop = ast.LoopOverCoordinate( + new_body_wrapped, loop_node.coordinate_to_loop_over, + new_start, new_end, loop_node.step) + else: + new_loop = new_body_wrapped + new_loops.append(new_loop) elif new_end - new_start == 0: pass else: