diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py
index e658108e08148a78dd1fd8e834d78c3edcce7c9b..6a04f4f95a390a33cfb59a62f67c04b9c8dd54c6 100644
--- a/src/pystencils/backend/ast/expressions.py
+++ b/src/pystencils/backend/ast/expressions.py
@@ -34,6 +34,10 @@ class PsExpression(PsAstNode, ABC):
 
     The type annotations are used by various transformation passes to make decisions, e.g. in
     function materialization and intrinsic selection.
+
+    .. attention::
+        The ``structurally_equal`` check currently does not take expression data types into
+        account. This may change in the future.
     """
 
     def __init__(self, dtype: PsType | None = None) -> None:
@@ -94,8 +98,26 @@ class PsExpression(PsAstNode, ABC):
         else:
             raise ValueError(f"Cannot make expression out of {obj}")
 
+    def clone(self):
+        """Clone this expression.
+        
+        .. note::
+            Subclasses of `PsExpression` should not override this method,
+            but implement `_clone_expr` instead.
+            That implementation shall call `clone` on any of its subexpressions,
+            but does not need to fix the `dtype` property.
+            The `dtype` is correctly applied by `PsExpression.clone` internally.
+        """
+        cloned = self._clone_expr()
+        cloned._dtype = self.dtype
+        return cloned
+
     @abstractmethod
-    def clone(self) -> PsExpression:
+    def _clone_expr(self) -> PsExpression:
+        """Implementation of expression cloning.
+        
+        :meta public:
+        """
         pass
 
 
@@ -121,7 +143,7 @@ class PsSymbolExpr(PsLeafMixIn, PsLvalue, PsExpression):
     def symbol(self, symbol: PsSymbol):
         self._symbol = symbol
 
-    def clone(self) -> PsSymbolExpr:
+    def _clone_expr(self) -> PsSymbolExpr:
         return PsSymbolExpr(self._symbol)
 
     def structurally_equal(self, other: PsAstNode) -> bool:
@@ -149,7 +171,7 @@ class PsConstantExpr(PsLeafMixIn, PsExpression):
     def constant(self, c: PsConstant):
         self._constant = c
 
-    def clone(self) -> PsConstantExpr:
+    def _clone_expr(self) -> PsConstantExpr:
         return PsConstantExpr(self._constant)
 
     def structurally_equal(self, other: PsAstNode) -> bool:
@@ -177,7 +199,7 @@ class PsLiteralExpr(PsLeafMixIn, PsExpression):
     def literal(self, lit: PsLiteral):
         self._literal = lit
 
-    def clone(self) -> PsLiteralExpr:
+    def _clone_expr(self) -> PsLiteralExpr:
         return PsLiteralExpr(self._literal)
 
     def structurally_equal(self, other: PsAstNode) -> bool:
@@ -240,7 +262,7 @@ class PsBufferAcc(PsLvalue, PsExpression):
         else:
             self._index[idx - 1] = failing_cast(PsExpression, c)
 
-    def clone(self) -> PsBufferAcc:
+    def _clone_expr(self) -> PsBufferAcc:
         return PsBufferAcc(self._base_ptr.symbol, [i.clone() for i in self._index])
 
     def __repr__(self) -> str:
@@ -277,7 +299,7 @@ class PsSubscript(PsLvalue, PsExpression):
     def index(self, idx: Sequence[PsExpression]):
         self._index = list(idx)
 
-    def clone(self) -> PsSubscript:
+    def _clone_expr(self) -> PsSubscript:
         return PsSubscript(self._arr.clone(), [i.clone() for i in self._index])
 
     def get_children(self) -> tuple[PsAstNode, ...]:
@@ -322,7 +344,7 @@ class PsMemAcc(PsLvalue, PsExpression):
     def offset(self, expr: PsExpression):
         self._offset = expr
 
-    def clone(self) -> PsMemAcc:
+    def _clone_expr(self) -> PsMemAcc:
         return PsMemAcc(self._ptr.clone(), self._offset.clone())
 
     def get_children(self) -> tuple[PsAstNode, ...]:
@@ -374,7 +396,7 @@ class PsVectorMemAcc(PsMemAcc):
     def get_vector_type(self) -> PsVectorType:
         return cast(PsVectorType, self._dtype)
 
-    def clone(self) -> PsVectorMemAcc:
+    def _clone_expr(self) -> PsVectorMemAcc:
         return PsVectorMemAcc(
             self._ptr.clone(),
             self._offset.clone(),
@@ -419,7 +441,7 @@ class PsLookup(PsExpression, PsLvalue):
     def member_name(self, name: str):
         self._name = name
 
-    def clone(self) -> PsLookup:
+    def _clone_expr(self) -> PsLookup:
         return PsLookup(self._aggregate.clone(), self._member_name)
 
     def get_children(self) -> tuple[PsAstNode, ...]:
@@ -469,7 +491,7 @@ class PsCall(PsExpression):
 
         self._args = list(exprs)
 
-    def clone(self) -> PsCall:
+    def _clone_expr(self) -> PsCall:
         return PsCall(self._function, [arg.clone() for arg in self._args])
 
     def get_children(self) -> tuple[PsAstNode, ...]:
@@ -513,7 +535,7 @@ class PsTernary(PsExpression):
     def case_else(self) -> PsExpression:
         return self._else
 
-    def clone(self) -> PsExpression:
+    def _clone_expr(self) -> PsExpression:
         return PsTernary(self._cond.clone(), self._then.clone(), self._else.clone())
 
     def get_children(self) -> tuple[PsExpression, ...]:
@@ -563,7 +585,7 @@ class PsUnOp(PsExpression):
     def operand(self, expr: PsExpression):
         self._operand = expr
 
-    def clone(self) -> PsUnOp:
+    def _clone_expr(self) -> PsUnOp:
         return type(self)(self._operand.clone())
 
     def get_children(self) -> tuple[PsAstNode, ...]:
@@ -617,7 +639,7 @@ class PsCast(PsUnOp):
     def target_type(self, dtype: PsType):
         self._target_type = dtype
 
-    def clone(self) -> PsUnOp:
+    def _clone_expr(self) -> PsUnOp:
         return PsCast(self._target_type, self._operand.clone())
 
     def structurally_equal(self, other: PsAstNode) -> bool:
@@ -653,7 +675,7 @@ class PsBinOp(PsExpression):
     def operand2(self, expr: PsExpression):
         self._op2 = expr
 
-    def clone(self) -> PsBinOp:
+    def _clone_expr(self) -> PsBinOp:
         return type(self)(self._op1.clone(), self._op2.clone())
 
     def get_children(self) -> tuple[PsAstNode, ...]:
@@ -838,7 +860,7 @@ class PsArrayInitList(PsExpression):
     def set_child(self, idx: int, c: PsAstNode):
         self._items.flat[idx] = failing_cast(PsExpression, c)
 
-    def clone(self) -> PsExpression:
+    def _clone_expr(self) -> PsExpression:
         return PsArrayInitList(
             np.array([expr.clone() for expr in self.children]).reshape(  # type: ignore
                 self._items.shape
diff --git a/src/pystencils/backend/ast/structural.py b/src/pystencils/backend/ast/structural.py
index 57244c03b6413d3fd8c7b521618cb9021b0e2037..e2f202f6594b8591903531ba795be251b6b544d6 100644
--- a/src/pystencils/backend/ast/structural.py
+++ b/src/pystencils/backend/ast/structural.py
@@ -320,7 +320,7 @@ class PsPragma(PsLeafMixIn, PsEmptyLeafMixIn, PsAstNode):
     Example usage: ``PsPragma("omp parallel for")`` translates to ``#pragma omp parallel for``.
 
     Args:
-        text: The pragma's text, without the ``#pragma ``.
+        text: The pragmas text, without the ``#pragma``.
     """
 
     __match_args__ = ("text",)
diff --git a/src/pystencils/backend/extensions/cpp.py b/src/pystencils/backend/extensions/cpp.py
index 1055b79e9ab197d62c4307b70ac5b2a71c13f139..025f4a3fb61d51d7fd9c485b597a671ae2cfc231 100644
--- a/src/pystencils/backend/extensions/cpp.py
+++ b/src/pystencils/backend/extensions/cpp.py
@@ -25,7 +25,7 @@ class CppMethodCall(PsForeignExpression):
 
         return super().structurally_equal(other) and self._method == other._method
 
-    def clone(self) -> CppMethodCall:
+    def _clone_expr(self) -> CppMethodCall:
         return CppMethodCall(
             cast(PsExpression, self.children[0]),
             self._method,
diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py
index 975ea5a6041fbd3e7c408fd9885a79b7b23fcefe..debcc3cf61d4b419f6387dfaca8e1eac655406f4 100644
--- a/src/pystencils/backend/kernelcreation/typification.py
+++ b/src/pystencils/backend/kernelcreation/typification.py
@@ -23,6 +23,7 @@ from ..ast.structural import (
     PsExpression,
     PsAssignment,
     PsDeclaration,
+    PsStatement,
     PsEmptyLeafMixIn,
 )
 from ..ast.expressions import (
@@ -301,6 +302,12 @@ class Typifier:
                 for s in statements:
                     self.visit(s)
 
+            case PsStatement(expr):
+                tc = TypeContext()
+                self.visit_expr(expr, tc)
+                if tc.target_type is None:
+                    tc.apply_dtype(self._ctx.default_dtype)
+
             case PsDeclaration(lhs, rhs) if isinstance(rhs, PsArrayInitList):
                 #   Special treatment for array declarations
                 assert isinstance(lhs, PsSymbolExpr)
diff --git a/src/pystencils/types/parsing.py b/src/pystencils/types/parsing.py
index 5771eaca84413708c68c4f7941e07cbd63403e9e..8e7d27f58265c08461cba6b05373848112a6fee7 100644
--- a/src/pystencils/types/parsing.py
+++ b/src/pystencils/types/parsing.py
@@ -8,6 +8,7 @@ from .types import (
     PsUnsignedIntegerType,
     PsSignedIntegerType,
     PsIeeeFloatType,
+    PsBoolType,
 )
 
 UserTypeSpec = str | type | np.dtype | PsType
@@ -143,6 +144,9 @@ def parse_type_string(s: str) -> PsType:
 
 def parse_type_name(typename: str, const: bool):
     match typename:
+        case "bool":
+            return PsBoolType(const=const)
+        
         case "int" | "int64" | "int64_t":
             return PsSignedIntegerType(64, const=const)
         case "int32" | "int32_t":
diff --git a/tests/nbackend/test_ast.py b/tests/nbackend/test_ast.py
index cf7fd3f31b13f0fbbac3b350f769e6993ab44d9d..2408b8d867038a0f2fd5c4d8a5f22bc82312c701 100644
--- a/tests/nbackend/test_ast.py
+++ b/tests/nbackend/test_ast.py
@@ -1,6 +1,7 @@
 import pytest
 
-from pystencils.backend.kernelcreation import KernelCreationContext, AstFactory
+from pystencils import create_type
+from pystencils.backend.kernelcreation import KernelCreationContext, AstFactory, Typifier
 from pystencils.backend.memory import PsSymbol, BufferBasePtr
 from pystencils.backend.constants import PsConstant
 from pystencils.backend.ast.expressions import (
@@ -15,6 +16,7 @@ from pystencils.backend.ast.expressions import (
 from pystencils.backend.ast.structural import (
     PsStatement,
     PsAssignment,
+    PsDeclaration,
     PsBlock,
     PsConditional,
     PsComment,
@@ -25,15 +27,25 @@ from pystencils.types.quick import Fp, Ptr
 
 
 def test_cloning():
-    x, y, z = [PsExpression.make(PsSymbol(name)) for name in "xyz"]
+    ctx = KernelCreationContext()
+    typify = Typifier(ctx)
+
+    x, y, z, m = [PsExpression.make(ctx.get_symbol(name)) for name in "xyzm"]
+    q = PsExpression.make(ctx.get_symbol("q", create_type("bool")))
+    a, b, c = [PsExpression.make(ctx.get_symbol(name, ctx.index_dtype)) for name in "abc"]
     c1 = PsExpression.make(PsConstant(3.0))
     c2 = PsExpression.make(PsConstant(-1.0))
-    one = PsExpression.make(PsConstant(1))
+    one_f = PsExpression.make(PsConstant(1.0))
+    one_i = PsExpression.make(PsConstant(1))
 
     def check(orig, clone):
         assert not (orig is clone)
         assert type(orig) is type(clone)
         assert orig.structurally_equal(clone)
+        
+        if isinstance(orig, PsExpression):
+            #   Regression: Expression data types used to not be cloned
+            assert orig.dtype == clone.dtype
 
         for c1, c2 in zip(orig.children, clone.children, strict=True):
             check(c1, c2)
@@ -49,18 +61,21 @@ def test_cloning():
         PsAssignment(y, x / c1),
         PsBlock([PsAssignment(x, c1 * y), PsAssignment(z, c2 + c1 * z)]),
         PsConditional(
-            y, PsBlock([PsStatement(x + y)]), PsBlock([PsComment("hello world")])
+            q, PsBlock([PsStatement(x + y)]), PsBlock([PsComment("hello world")])
+        ),
+        PsDeclaration(
+            m,
+            PsArrayInitList([
+                [x, y, one_f + x],
+                [one_f, c2, z]
+            ])
         ),
-        PsArrayInitList([
-            [x, y, one + x],
-            [one, c2, z]
-        ]),
         PsPragma("omp parallel for"),
         PsLoop(
-            x,
-            y,
-            z,
-            one,
+            a,
+            b,
+            c,
+            one_i,
             PsBlock(
                 [
                     PsComment("Loop body"),
@@ -68,13 +83,14 @@ def test_cloning():
                     PsAssignment(x, y),
                     PsPragma("#pragma clang loop vectorize(enable)"),
                     PsStatement(
-                        PsMemAcc(PsCast(Ptr(Fp(32)), z), one)
-                        + PsSubscript(z, (one + one + one, y + one))
+                        PsMemAcc(PsCast(Ptr(Fp(32)), z), one_i)
+                        + PsCast(Fp(32), PsSubscript(m, (one_i + one_i + one_i, b + one_i)))
                     ),
                 ]
             ),
         ),
     ]:
+        ast = typify(ast)
         ast_clone = ast.clone()
         check(ast, ast_clone)