diff --git a/pystencils/typing/leaf_typing.py b/pystencils/typing/leaf_typing.py index 560b94143f8b35182d2be4d02033a8ade652839b..a80944d2d79f0a8aced712c2de9e8e303001b5f5 100644 --- a/pystencils/typing/leaf_typing.py +++ b/pystencils/typing/leaf_typing.py @@ -54,16 +54,14 @@ class TypeAdder: if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)): return self.process_assignment(obj) elif isinstance(obj, ast.Conditional): + condition, condition_type = self.figure_out_type(obj.condition_expr) + assert condition_type == BasicType('bool') + true_block = self.visit(obj.true_block) false_block = None if obj.false_block is None else self.visit( obj.false_block) - result = ast.Conditional(self.process_expression( - obj.condition_expr, type_constants=False), - true_block=self.visit(obj.true_block), - false_block=false_block) - return result + return ast.Conditional(condition, true_block=true_block, false_block=false_block) elif isinstance(obj, ast.Block): - result = ast.Block([self.visit(e) for e in obj.args]) - return result + return ast.Block([self.visit(e) for e in obj.args]) elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate): return obj else: