diff --git a/pystencils/data_types.py b/pystencils/data_types.py
index 4fc236e539fa6e6a0ab2ef0686d22bcd38a119bb..f57448277d5794cf041fa4c32afb27614ba5b89d 100644
--- a/pystencils/data_types.py
+++ b/pystencils/data_types.py
@@ -220,8 +220,10 @@ class TypedSymbol(sp.Symbol):
         return obj
 
     def __new_stage2__(cls, name, dtype, *args, **kwargs):
-        assumptions = assumptions_from_dtype(dtype)
-        obj = super(TypedSymbol, cls).__xnew__(cls, name, *args, **assumptions, **kwargs)
+        assumptions_from_type = assumptions_from_dtype(dtype)
+        kwargs.update(assumptions_from_type)
+        cls._sanitize(kwargs, cls)
+        obj = super(TypedSymbol, cls).__xnew__(cls, name, *args, **kwargs)
         try:
             obj._dtype = create_type(dtype)
         except (TypeError, ValueError):
@@ -602,15 +604,24 @@ class BasicType(Type):
             return 'ComplexDouble'
         elif name.startswith('int'):
             width = int(name[len("int"):])
-            return "int%d_t" % (width,)
+            return f"int{width}_t"
         elif name.startswith('uint'):
             width = int(name[len("uint"):])
-            return "uint%d_t" % (width,)
+            return f"uint{width}_t"
         elif name == 'bool':
             return 'bool'
         else:
             raise NotImplementedError(f"Can map numpy to C name for {name}")
 
+    def __new__(cls,  dtype, const=False, *args, **kwargs):
+        obj = sp.Basic.__new__(cls)
+        obj.const = const
+        if isinstance(dtype, Type):
+            obj._dtype = dtype.numpy_dtype
+        else:
+            obj._dtype = np.dtype(dtype)
+        return obj
+
     def __init__(self, dtype, const=False):
         self.const = const
         if isinstance(dtype, Type):
@@ -621,8 +632,14 @@ class BasicType(Type):
         assert self._dtype.hasobject is False
         assert self._dtype.subdtype is None
 
-    def __getnewargs__(self):
-        return self.numpy_dtype, self.const
+    # def __getnewargs__(self):
+    #     return self.numpy_dtype, self.const
+
+    def __getnewargs_ex__(self):
+        return (self.numpy_dtype, self.const), {}
+    #
+    # def _hashable_content(self):
+    #     return (self.i, self.label)
 
     @property
     def base_type(self):
@@ -701,7 +718,7 @@ class VectorType(Type):
 
     def __str__(self):
         if self.instruction_set is None:
-            return "%s[%d]" % (self.base_type, self.width)
+            return f"{self.base_type}[{self.width}]"
         else:
             if self.base_type == create_type("int64") or self.base_type == create_type("int32"):
                 return self.instruction_set['int']
@@ -727,6 +744,11 @@ class PointerType(Type):
         self.const = const
         self.restrict = restrict
 
+    def __new__(self, base_type, const=False, restrict=True):
+        self._base_type = base_type
+        self.const = const
+        self.restrict = restrict
+
     def __getnewargs__(self):
         return self.base_type, self.const, self.restrict
 
@@ -819,12 +841,7 @@ class TypedImaginaryUnit(TypedSymbol):
         return obj
 
     def __new_stage2__(cls, dtype, *args, **kwargs):
-        obj = super(TypedImaginaryUnit, cls).__xnew__(cls,
-                                                      "_i",
-                                                      dtype,
-                                                      imaginary=True,
-                                                      *args,
-                                                      **kwargs)
+        obj = super(TypedImaginaryUnit, cls).__xnew__(cls, "_i", dtype, imaginary=True, *args, **kwargs)
         return obj
 
     headers = ['"cuda_complex.hpp"']
diff --git a/pystencils/kernelparameters.py b/pystencils/kernelparameters.py
index 3257522e419bf921b13010215e44a51a5290ce80..2e4e8a00750f43080400a68215b639ef77b63d3f 100644
--- a/pystencils/kernelparameters.py
+++ b/pystencils/kernelparameters.py
@@ -28,15 +28,18 @@ class FieldStrideSymbol(TypedSymbol):
         obj = FieldStrideSymbol.__xnew_cached_(cls, *args, **kwds)
         return obj
 
-    def __new_stage2__(cls, field_name, coordinate):
+    def __new_stage2__(cls, field_name, coordinate, **assumptions):
         name = f"_stride_{field_name}_{coordinate}"
         obj = super(FieldStrideSymbol, cls).__xnew__(cls, name, STRIDE_DTYPE, positive=True)
         obj.field_name = field_name
         obj.coordinate = coordinate
         return obj
 
-    def __getnewargs__(self):
-        return self.field_name, self.coordinate
+    # def __getnewargs__(self):
+    #     return self.field_name, self.coordinate
+
+    def __getnewargs_ex__(self):
+        return (self.field_name, self.coordinate), self.assumptions0
 
     __xnew__ = staticmethod(__new_stage2__)
     __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
@@ -52,7 +55,7 @@ class FieldShapeSymbol(TypedSymbol):
         obj = FieldShapeSymbol.__xnew_cached_(cls, *args, **kwds)
         return obj
 
-    def __new_stage2__(cls, field_names, coordinate):
+    def __new_stage2__(cls, field_names, coordinate, **assumptions):
         names = "_".join([field_name for field_name in field_names])
         name = f"_size_{names}_{coordinate}"
         obj = super(FieldShapeSymbol, cls).__xnew__(cls, name, SHAPE_DTYPE, positive=True)
@@ -60,8 +63,11 @@ class FieldShapeSymbol(TypedSymbol):
         obj.coordinate = coordinate
         return obj
 
-    def __getnewargs__(self):
-        return self.field_names, self.coordinate
+    # def __getnewargs__(self):
+    #     return self.field_names, self.coordinate
+
+    def __getnewargs_ex__(self):
+        return (self.field_names, self.coordinate), self.assumptions0
 
     __xnew__ = staticmethod(__new_stage2__)
     __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
@@ -76,15 +82,18 @@ class FieldPointerSymbol(TypedSymbol):
         obj = FieldPointerSymbol.__xnew_cached_(cls, *args, **kwds)
         return obj
 
-    def __new_stage2__(cls, field_name, field_dtype, const):
+    def __new_stage2__(cls, field_name, field_dtype, const, **assumptions):
         name = f"_data_{field_name}"
         dtype = PointerType(get_base_type(field_dtype), const=const, restrict=True)
         obj = super(FieldPointerSymbol, cls).__xnew__(cls, name, dtype)
         obj.field_name = field_name
         return obj
 
-    def __getnewargs__(self):
-        return self.field_name, self.dtype, self.dtype.const
+    # def __getnewargs__(self):
+    #     return self.field_name, self.dtype, self.dtype.const
+
+    def __getnewargs_ex__(self):
+        return (self.field_name, self.dtype, self.dtype.const), self.assumptions0
 
     def _hashable_content(self):
         return super()._hashable_content(), self.field_name