diff --git a/pystencils/data_types.py b/pystencils/data_types.py index 4fc236e539fa6e6a0ab2ef0686d22bcd38a119bb..f57448277d5794cf041fa4c32afb27614ba5b89d 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -220,8 +220,10 @@ class TypedSymbol(sp.Symbol): return obj def __new_stage2__(cls, name, dtype, *args, **kwargs): - assumptions = assumptions_from_dtype(dtype) - obj = super(TypedSymbol, cls).__xnew__(cls, name, *args, **assumptions, **kwargs) + assumptions_from_type = assumptions_from_dtype(dtype) + kwargs.update(assumptions_from_type) + cls._sanitize(kwargs, cls) + obj = super(TypedSymbol, cls).__xnew__(cls, name, *args, **kwargs) try: obj._dtype = create_type(dtype) except (TypeError, ValueError): @@ -602,15 +604,24 @@ class BasicType(Type): return 'ComplexDouble' elif name.startswith('int'): width = int(name[len("int"):]) - return "int%d_t" % (width,) + return f"int{width}_t" elif name.startswith('uint'): width = int(name[len("uint"):]) - return "uint%d_t" % (width,) + return f"uint{width}_t" elif name == 'bool': return 'bool' else: raise NotImplementedError(f"Can map numpy to C name for {name}") + def __new__(cls, dtype, const=False, *args, **kwargs): + obj = sp.Basic.__new__(cls) + obj.const = const + if isinstance(dtype, Type): + obj._dtype = dtype.numpy_dtype + else: + obj._dtype = np.dtype(dtype) + return obj + def __init__(self, dtype, const=False): self.const = const if isinstance(dtype, Type): @@ -621,8 +632,14 @@ class BasicType(Type): assert self._dtype.hasobject is False assert self._dtype.subdtype is None - def __getnewargs__(self): - return self.numpy_dtype, self.const + # def __getnewargs__(self): + # return self.numpy_dtype, self.const + + def __getnewargs_ex__(self): + return (self.numpy_dtype, self.const), {} + # + # def _hashable_content(self): + # return (self.i, self.label) @property def base_type(self): @@ -701,7 +718,7 @@ class VectorType(Type): def __str__(self): if self.instruction_set is None: - return "%s[%d]" % (self.base_type, self.width) + return f"{self.base_type}[{self.width}]" else: if self.base_type == create_type("int64") or self.base_type == create_type("int32"): return self.instruction_set['int'] @@ -727,6 +744,11 @@ class PointerType(Type): self.const = const self.restrict = restrict + def __new__(self, base_type, const=False, restrict=True): + self._base_type = base_type + self.const = const + self.restrict = restrict + def __getnewargs__(self): return self.base_type, self.const, self.restrict @@ -819,12 +841,7 @@ class TypedImaginaryUnit(TypedSymbol): return obj def __new_stage2__(cls, dtype, *args, **kwargs): - obj = super(TypedImaginaryUnit, cls).__xnew__(cls, - "_i", - dtype, - imaginary=True, - *args, - **kwargs) + obj = super(TypedImaginaryUnit, cls).__xnew__(cls, "_i", dtype, imaginary=True, *args, **kwargs) return obj headers = ['"cuda_complex.hpp"'] diff --git a/pystencils/kernelparameters.py b/pystencils/kernelparameters.py index 3257522e419bf921b13010215e44a51a5290ce80..2e4e8a00750f43080400a68215b639ef77b63d3f 100644 --- a/pystencils/kernelparameters.py +++ b/pystencils/kernelparameters.py @@ -28,15 +28,18 @@ class FieldStrideSymbol(TypedSymbol): obj = FieldStrideSymbol.__xnew_cached_(cls, *args, **kwds) return obj - def __new_stage2__(cls, field_name, coordinate): + def __new_stage2__(cls, field_name, coordinate, **assumptions): name = f"_stride_{field_name}_{coordinate}" obj = super(FieldStrideSymbol, cls).__xnew__(cls, name, STRIDE_DTYPE, positive=True) obj.field_name = field_name obj.coordinate = coordinate return obj - def __getnewargs__(self): - return self.field_name, self.coordinate + # def __getnewargs__(self): + # return self.field_name, self.coordinate + + def __getnewargs_ex__(self): + return (self.field_name, self.coordinate), self.assumptions0 __xnew__ = staticmethod(__new_stage2__) __xnew_cached_ = staticmethod(cacheit(__new_stage2__)) @@ -52,7 +55,7 @@ class FieldShapeSymbol(TypedSymbol): obj = FieldShapeSymbol.__xnew_cached_(cls, *args, **kwds) return obj - def __new_stage2__(cls, field_names, coordinate): + def __new_stage2__(cls, field_names, coordinate, **assumptions): names = "_".join([field_name for field_name in field_names]) name = f"_size_{names}_{coordinate}" obj = super(FieldShapeSymbol, cls).__xnew__(cls, name, SHAPE_DTYPE, positive=True) @@ -60,8 +63,11 @@ class FieldShapeSymbol(TypedSymbol): obj.coordinate = coordinate return obj - def __getnewargs__(self): - return self.field_names, self.coordinate + # def __getnewargs__(self): + # return self.field_names, self.coordinate + + def __getnewargs_ex__(self): + return (self.field_names, self.coordinate), self.assumptions0 __xnew__ = staticmethod(__new_stage2__) __xnew_cached_ = staticmethod(cacheit(__new_stage2__)) @@ -76,15 +82,18 @@ class FieldPointerSymbol(TypedSymbol): obj = FieldPointerSymbol.__xnew_cached_(cls, *args, **kwds) return obj - def __new_stage2__(cls, field_name, field_dtype, const): + def __new_stage2__(cls, field_name, field_dtype, const, **assumptions): name = f"_data_{field_name}" dtype = PointerType(get_base_type(field_dtype), const=const, restrict=True) obj = super(FieldPointerSymbol, cls).__xnew__(cls, name, dtype) obj.field_name = field_name return obj - def __getnewargs__(self): - return self.field_name, self.dtype, self.dtype.const + # def __getnewargs__(self): + # return self.field_name, self.dtype, self.dtype.const + + def __getnewargs_ex__(self): + return (self.field_name, self.dtype, self.dtype.const), self.assumptions0 def _hashable_content(self): return super()._hashable_content(), self.field_name