diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py
index 90146e596123eedc73eea5f813964f7593b6f43f..133f504b140ee36dd7db7c7c01d8bb710fa9c8c5 100644
--- a/src/pystencilssfg/composer/basic_composer.py
+++ b/src/pystencilssfg/composer/basic_composer.py
@@ -7,13 +7,7 @@ from functools import reduce
 
 from pystencils import Field
 from pystencils.backend import KernelFunction
-from pystencils.types import (
-    create_type,
-    UserTypeSpec,
-    PsCustomType,
-    PsPointerType,
-    PsType,
-)
+from pystencils.types import create_type, UserTypeSpec
 
 from ..context import SfgContext
 from .custom import CustomGenerator
@@ -325,30 +319,6 @@ class SfgBasicComposer(SfgIComposer):
         """Use inside a function body to require the inclusion of headers."""
         return SfgRequireIncludes((HeaderFile.parse(incl) for incl in incls))
 
-    def cpptype(
-        self,
-        typename: UserTypeSpec,
-        ptr: bool = False,
-        ref: bool = False,
-        const: bool = False,
-    ) -> PsType:
-        if ptr and ref:
-            raise SfgException("Create either a pointer, or a ref type, not both!")
-
-        ref_qual = "&" if ref else ""
-        try:
-            base_type = create_type(typename)
-        except ValueError:
-            if not isinstance(typename, str):
-                raise ValueError(f"Could not parse type: {typename}")
-
-            base_type = PsCustomType(typename + ref_qual, const=const)
-
-        if ptr:
-            return PsPointerType(base_type)
-        else:
-            return base_type
-
     def var(self, name: str, dtype: UserTypeSpec) -> AugExpr:
         """Create a variable with given name and data type."""
         return AugExpr(create_type(dtype)).var(name)
diff --git a/src/pystencilssfg/composer/class_composer.py b/src/pystencilssfg/composer/class_composer.py
index 1f4c4865987c1f10ec133f95d11e4bccc1ef8b76..489823b9ce619be88e3220ce4b941cf49c62b298 100644
--- a/src/pystencilssfg/composer/class_composer.py
+++ b/src/pystencilssfg/composer/class_composer.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 from typing import Sequence
 
-from pystencils.types import PsCustomType, UserTypeSpec
+from pystencils.types import PsCustomType, UserTypeSpec, create_type
 
 from ..lang import (
     _VarLike,
@@ -177,7 +177,7 @@ class SfgClassComposer(SfgComposerMixIn):
             return SfgMethod(
                 name,
                 tree,
-                return_type=self._composer.cpptype(returns),
+                return_type=create_type(returns),
                 inline=inline,
                 const=const,
             )
diff --git a/src/pystencilssfg/emission/printers.py b/src/pystencilssfg/emission/printers.py
index eb05ff6efa8eb8166c6d5f81e8e4596c6c797704..3dc4a8f74177c4de16158b1e6f314be25cca9c18 100644
--- a/src/pystencilssfg/emission/printers.py
+++ b/src/pystencilssfg/emission/printers.py
@@ -166,7 +166,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter):
 
     @visit.case(SfgMethod)
     def sfg_method(self, method: SfgMethod):
-        code = f"{method.return_type} {method.name} ({self.param_list(method)})"
+        code = f"{method.return_type.c_string()} {method.name} ({self.param_list(method)})"
         code += "const" if method.const else ""
         if method.inline:
             code += (
diff --git a/src/pystencilssfg/extensions/sycl.py b/src/pystencilssfg/extensions/sycl.py
index 349e030c6c9057472b9f513258c8a979bba16b8a..88dbc9be2e215b1fdce5833ef18eac6eab336d74 100644
--- a/src/pystencilssfg/extensions/sycl.py
+++ b/src/pystencilssfg/extensions/sycl.py
@@ -57,9 +57,7 @@ class SyclRange(AugExpr):
     _template = cpptype("sycl::range< {dims} >", "<sycl/sycl.hpp>")
 
     def __init__(self, dims: int, const: bool = False, ref: bool = False):
-        dtype = self._template(dims=dims, const=const)
-        if ref:
-            dtype = Ref(dtype)
+        dtype = self._template(dims=dims, const=const, ref=ref)
         super().__init__(dtype)
 
 
diff --git a/src/pystencilssfg/lang/cpp/std_mdspan.py b/src/pystencilssfg/lang/cpp/std_mdspan.py
index 5f80ad6f15d71ecc12d6f5b55b8671e1dd1124d2..68308850f9dd7d5c3486265fb6c9b634d97c6fb7 100644
--- a/src/pystencilssfg/lang/cpp/std_mdspan.py
+++ b/src/pystencilssfg/lang/cpp/std_mdspan.py
@@ -11,7 +11,7 @@ from pystencils.types import (
 
 from pystencilssfg.lang.expressions import AugExpr
 
-from ...lang import SrcField, IFieldExtraction, cpptype, Ref, HeaderFile, ExprLike
+from ...lang import SrcField, IFieldExtraction, cpptype, HeaderFile, ExprLike
 
 
 class StdMdspan(SrcField):
@@ -111,11 +111,8 @@ class StdMdspan(SrcField):
             layout_policy = f"{self._namespace}::{layout_policy}"
 
         dtype = self._template(
-            T=T, extents=extents_str, layout_policy=layout_policy, const=const
+            T=T, extents=extents_str, layout_policy=layout_policy, const=const, ref=ref
         )
-
-        if ref:
-            dtype = Ref(dtype)
         super().__init__(dtype)
 
         self._extents_type = extents_str
diff --git a/src/pystencilssfg/lang/cpp/std_span.py b/src/pystencilssfg/lang/cpp/std_span.py
index 861a4c4bb1ea81b0cbaaef4cb683316274ab2edd..f161f4874f627fa8943f4e24c2a1082780259572 100644
--- a/src/pystencilssfg/lang/cpp/std_span.py
+++ b/src/pystencilssfg/lang/cpp/std_span.py
@@ -1,7 +1,7 @@
 from pystencils.field import Field
 from pystencils.types import UserTypeSpec, create_type, PsType
 
-from ...lang import SrcField, IFieldExtraction, AugExpr, cpptype, Ref
+from ...lang import SrcField, IFieldExtraction, AugExpr, cpptype
 
 
 class StdSpan(SrcField):
@@ -9,10 +9,7 @@ class StdSpan(SrcField):
 
     def __init__(self, T: UserTypeSpec, ref=False, const=False):
         T = create_type(T)
-        dtype = self._template(T=T, const=const)
-        if ref:
-            dtype = Ref(dtype)
-
+        dtype = self._template(T=T, const=const, ref=ref)
         self._element_type = T
         super().__init__(dtype)
 
diff --git a/src/pystencilssfg/lang/cpp/std_tuple.py b/src/pystencilssfg/lang/cpp/std_tuple.py
index bbf2ba33b8f1a19081593501885d0dc935fc3055..58a3530b9e98e2c39e205fd7dac9845b4ff35bda 100644
--- a/src/pystencilssfg/lang/cpp/std_tuple.py
+++ b/src/pystencilssfg/lang/cpp/std_tuple.py
@@ -2,7 +2,7 @@ from typing import Sequence
 
 from pystencils.types import UserTypeSpec, create_type
 
-from ...lang import SrcVector, AugExpr, cpptype, Ref
+from ...lang import SrcVector, AugExpr, cpptype
 
 
 class StdTuple(SrcVector):
@@ -18,10 +18,7 @@ class StdTuple(SrcVector):
         self._length = len(element_types)
         elt_type_strings = tuple(t.c_string() for t in self._element_types)
 
-        dtype = self._template(ts=", ".join(elt_type_strings), const=const)
-        if ref:
-            dtype = Ref(dtype)
-
+        dtype = self._template(ts=", ".join(elt_type_strings), const=const, ref=ref)
         super().__init__(dtype)
 
     def extract_component(self, coordinate: int) -> AugExpr:
diff --git a/src/pystencilssfg/lang/cpp/std_vector.py b/src/pystencilssfg/lang/cpp/std_vector.py
index 5696b32dd55c6092db0ff298fdcbd79fa7df69f5..7e9291eab670a1f4f45996b60ea8b8b3e8f49ff4 100644
--- a/src/pystencilssfg/lang/cpp/std_vector.py
+++ b/src/pystencilssfg/lang/cpp/std_vector.py
@@ -1,7 +1,7 @@
 from pystencils.field import Field
 from pystencils.types import UserTypeSpec, create_type, PsType
 
-from ...lang import SrcField, SrcVector, AugExpr, IFieldExtraction, cpptype, Ref
+from ...lang import SrcField, SrcVector, AugExpr, IFieldExtraction, cpptype
 
 
 class StdVector(SrcVector, SrcField):
@@ -15,9 +15,7 @@ class StdVector(SrcVector, SrcField):
         const: bool = False,
     ):
         T = create_type(T)
-        dtype = self._template(T=T, const=const)
-        if ref:
-            dtype = Ref(dtype)
+        dtype = self._template(T=T, const=const, ref=ref)
         super().__init__(dtype)
 
         self._element_type = T
diff --git a/src/pystencilssfg/lang/cpp/sycl_accessor.py b/src/pystencilssfg/lang/cpp/sycl_accessor.py
index f01c53d24750dc3b4f0350134e9bccd6f8ea4c26..4bcad56cd4ef109faa66757075eeefb6a5b416d3 100644
--- a/src/pystencilssfg/lang/cpp/sycl_accessor.py
+++ b/src/pystencilssfg/lang/cpp/sycl_accessor.py
@@ -3,7 +3,7 @@ from ...lang import SrcField, IFieldExtraction
 from pystencils import Field
 from pystencils.types import UserTypeSpec, create_type
 
-from ...lang import AugExpr, cpptype, Ref
+from ...lang import AugExpr, cpptype
 
 
 class SyclAccessor(SrcField):
@@ -29,9 +29,7 @@ class SyclAccessor(SrcField):
         T = create_type(T)
         if dimensions > 3:
             raise ValueError("sycl accessors can only have dims 1, 2 or 3")
-        dtype = self._template(T=T, dims=dimensions, const=const)
-        if ref:
-            dtype = Ref(dtype)
+        dtype = self._template(T=T, dims=dimensions, const=const, ref=ref)
 
         super().__init__(dtype)
 
diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py
index 53064bd13a201b0716545bf1e84677baa8df3be9..03818c6b481a5436deddb1a0266597e927e982eb 100644
--- a/src/pystencilssfg/lang/expressions.py
+++ b/src/pystencilssfg/lang/expressions.py
@@ -147,7 +147,7 @@ class VarExpr(DependentExpression):
         incls: Iterable[HeaderFile]
         match base_type:
             case CppType():
-                incls = base_type.includes
+                incls = base_type.class_includes
             case _:
                 incls = (
                     HeaderFile.parse(header) for header in var.dtype.required_headers
@@ -408,7 +408,11 @@ def includes(expr: ExprLike) -> set[HeaderFile]:
 
     match expr:
         case SfgVar(_, dtype):
-            return set(HeaderFile.parse(h) for h in dtype.required_headers)
+            match dtype:
+                case CppType():
+                    return set(dtype.includes)
+                case _:
+                    return set(HeaderFile.parse(h) for h in dtype.required_headers)
         case TypedSymbol():
             return includes(asvar(expr))
         case str():
diff --git a/src/pystencilssfg/lang/types.py b/src/pystencilssfg/lang/types.py
index e437ae6ea895a6e0db92c7e339aefbbf13e4967d..5f064d54fc9dcfd2f2178021e26a8c3446c50138 100644
--- a/src/pystencilssfg/lang/types.py
+++ b/src/pystencilssfg/lang/types.py
@@ -2,6 +2,7 @@ from __future__ import annotations
 from typing import Any, Iterable, Sequence, Mapping
 from abc import ABC
 from dataclasses import dataclass
+from itertools import chain
 
 import string
 
@@ -38,7 +39,9 @@ class _TemplateArgFormatter(string.Formatter):
     def check_unused_args(
         self, used_args: set[int | str], args: Sequence, kwargs: Mapping[str, Any]
     ) -> None:
-        max_args_len: int = max((k for k in used_args if isinstance(k, int)), default=-1) + 1
+        max_args_len: int = (
+            max((k for k in used_args if isinstance(k, int)), default=-1) + 1
+        )
         if len(args) > max_args_len:
             raise ValueError(
                 f"Too many positional arguments: Expected {max_args_len}, but got {len(args)}"
@@ -56,21 +59,9 @@ class _TemplateArgs:
 
 
 class CppType(PsCustomType, ABC):
-    includes: frozenset[HeaderFile]
+    class_includes: frozenset[HeaderFile]
     template_string: str
 
-    def __new__(  # type: ignore
-        cls,
-        *args,
-        ref: bool = False,
-        **kwargs,
-    ) -> CppType | Ref:
-        if ref:
-            obj = cls(*args, **kwargs)
-            return Ref(obj)
-        else:
-            return super().__new__(cls)
-
     def __init__(self, *template_args, const: bool = False, **template_kwargs):
         #   Support for cloning CppTypes
         if template_args and isinstance(template_args[0], _TemplateArgs):
@@ -89,17 +80,34 @@ class CppType(PsCustomType, ABC):
         name = formatter.format(self.template_string, *pargs, **kwargs)
 
         self._targs = targs
+        self._includes = self.class_includes
+
+        for arg in chain(pargs, kwargs.values()):
+            match arg:
+                case CppType():
+                    self._includes |= arg.includes
+                case PsType():
+                    self._includes |= {
+                        HeaderFile.parse(h) for h in arg.required_headers
+                    }
+
         super().__init__(name, const=const)
 
     def __args__(self) -> tuple[Any, ...]:
         return (self._targs,)
 
+    @property
+    def includes(self) -> frozenset[HeaderFile]:
+        return self._includes
+
     @property
     def required_headers(self) -> set[str]:
-        return set(str(h) for h in self.includes)
+        return set(str(h) for h in self.class_includes)
 
 
-def cpptype(typestr: str, include: str | HeaderFile | Iterable[str | HeaderFile] = ()):
+def cpptype(
+    typestr: str, include: str | HeaderFile | Iterable[str | HeaderFile] = ()
+) -> type[CppType]:
     headers: list[str | HeaderFile]
 
     if isinstance(include, (str, HeaderFile)):
@@ -111,9 +119,17 @@ def cpptype(typestr: str, include: str | HeaderFile | Iterable[str | HeaderFile]
 
     class TypeClass(CppType):
         template_string = typestr
-        includes = frozenset(HeaderFile.parse(h) for h in headers)
+        class_includes = frozenset(HeaderFile.parse(h) for h in headers)
+
+    class TypeClassFactory(TypeClass):
+        def __new__(cls, *args, ref: bool = False, **kwargs):
+            obj = TypeClass(*args, **kwargs)
+            if ref:
+                return Ref(obj)
+            else:
+                return obj
 
-    return TypeClass
+    return TypeClassFactory
 
 
 class Ref(PsType):
diff --git a/tests/generator_scripts/source/Conditionals.py b/tests/generator_scripts/source/Conditionals.py
index d1088a96925bf52ab3634cf439538ce1de2b58a7..9016b73744f78fef504bb4f09b6742a630a6b12d 100644
--- a/tests/generator_scripts/source/Conditionals.py
+++ b/tests/generator_scripts/source/Conditionals.py
@@ -1,4 +1,5 @@
 from pystencilssfg import SourceFileGenerator
+from pystencils.types import PsCustomType
 
 with SourceFileGenerator() as sfg:
     sfg.namespace("gen")
@@ -6,7 +7,7 @@ with SourceFileGenerator() as sfg:
     sfg.include("<iostream>")
     sfg.code(r"enum class Noodles { RIGATONI, RAMEN, SPAETZLE, SPAGHETTI };")
 
-    noodle = sfg.var("noodle", sfg.cpptype("Noodles"))
+    noodle = sfg.var("noodle", PsCustomType("Noodles"))
 
     sfg.function("printOpinion")(
         sfg.switch(noodle)
diff --git a/tests/generator_scripts/source/SimpleClasses.py b/tests/generator_scripts/source/SimpleClasses.py
index 64093f5744c61918bf505518e2c39dffbb525fad..454f1a26f8103f7c8b330f1a5b70b1b79f96ebbc 100644
--- a/tests/generator_scripts/source/SimpleClasses.py
+++ b/tests/generator_scripts/source/SimpleClasses.py
@@ -16,7 +16,7 @@ with SourceFileGenerator() as sfg:
             .init(y_)(y)
             .init(z_)(z),
 
-            sfg.method("getX", returns="const int64_t &", const=True, inline=True)(
+            sfg.method("getX", returns="const int64_t", const=True, inline=True)(
                 "return this->x_;"
             )
         ),
diff --git a/tests/lang/test_types.py b/tests/lang/test_types.py
index 35713ea7b4f4bb0797d8f9d414342d02d249dd5e..8bd0903b92c79f67cfcc8707585157174b80db08 100644
--- a/tests/lang/test_types.py
+++ b/tests/lang/test_types.py
@@ -11,7 +11,7 @@ def test_cpptypes():
     vec_type = tclass(create_type("float32"), "std::allocator< float >")
     assert str(vec_type).strip() == "std::vector< float, std::allocator< float > >"
     assert (
-        tclass.includes
+        tclass.class_includes
         == vec_type.includes
         == {HeaderFile("vector", system_header=True)}
     )
@@ -65,3 +65,15 @@ def test_cpptype_ref():
     vec_type = tclass(T=create_type("uint32"), ref=True)
     assert isinstance(vec_type, Ref)
     assert strip_ptr_ref(vec_type) == tclass(T=create_type("uint32"))
+
+
+def test_cpptype_inherits_headers():
+    optional_tclass = cpptype("std::optional< {T} >", "<optional>")
+    vec_tclass = cpptype("std::vector< {T} >", "<vector>")
+
+    vec_type = vec_tclass(T=optional_tclass(T="int"))
+    assert (
+        vec_type.includes
+        == optional_tclass.class_includes | vec_tclass.class_includes
+        == {HeaderFile.parse("<optional>"), HeaderFile.parse("<vector>")}
+    )