diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index d3f0b03313c58f4308eda9eb07bf6d7eb6835c88..259821afaf7289113de7816e1569b237ea384cb4 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -10,9 +10,17 @@ from ...types import ( PsIntegerType, PsArrayType, PsSubscriptableType, + PsBoolType, deconstify, ) -from ..ast.structural import PsAstNode, PsBlock, PsLoop, PsExpression, PsAssignment +from ..ast.structural import ( + PsAstNode, + PsBlock, + PsLoop, + PsConditional, + PsExpression, + PsAssignment, +) from ..ast.expressions import ( PsSymbolExpr, PsConstantExpr, @@ -162,6 +170,15 @@ class Typifier: assert tc.target_type is not None self.visit_expr(rhs, tc) + case PsConditional(cond, branch_true, branch_false): + cond_tc = TypeContext(PsBoolType(const=True)) + self.visit_expr(cond, cond_tc) + + self.visit(branch_true) + + if branch_false is not None: + self.visit(branch_false) + case PsLoop(ctr, start, stop, step, body): if ctr.symbol.dtype is None: ctr.symbol.apply_dtype(self._ctx.index_dtype)