Skip to content
Snippets Groups Projects

Fix Sympy pipeline

Closed Markus Holzer requested to merge holzer/pystencils:FixSympy into master
Compare and
2 files
+ 51
22
Preferences
Compare changes
Files
2
+ 33
13
@@ -220,8 +220,10 @@ class TypedSymbol(sp.Symbol):
return obj
def __new_stage2__(cls, name, dtype, *args, **kwargs):
assumptions = assumptions_from_dtype(dtype)
obj = super(TypedSymbol, cls).__xnew__(cls, name, *args, **assumptions, **kwargs)
assumptions_from_type = assumptions_from_dtype(dtype)
kwargs.update(assumptions_from_type)
cls._sanitize(kwargs, cls)
obj = super(TypedSymbol, cls).__xnew__(cls, name, *args, **kwargs)
try:
obj._dtype = create_type(dtype)
except (TypeError, ValueError):
@@ -242,6 +244,9 @@ class TypedSymbol(sp.Symbol):
def __getnewargs__(self):
return self.name, self.dtype
def __getnewargs_ex__(self):
return ((self.name, self.dtype), self.assumptions0)
@property
def canonical(self):
return self
@@ -599,15 +604,24 @@ class BasicType(Type):
return 'ComplexDouble'
elif name.startswith('int'):
width = int(name[len("int"):])
return "int%d_t" % (width,)
return f"int{width}_t"
elif name.startswith('uint'):
width = int(name[len("uint"):])
return "uint%d_t" % (width,)
return f"uint{width}_t"
elif name == 'bool':
return 'bool'
else:
raise NotImplementedError(f"Can map numpy to C name for {name}")
def __new__(cls, dtype, const=False, *args, **kwargs):
obj = sp.Basic.__new__(cls)
obj.const = const
if isinstance(dtype, Type):
obj._dtype = dtype.numpy_dtype
else:
obj._dtype = np.dtype(dtype)
return obj
def __init__(self, dtype, const=False):
self.const = const
if isinstance(dtype, Type):
@@ -618,8 +632,14 @@ class BasicType(Type):
assert self._dtype.hasobject is False
assert self._dtype.subdtype is None
def __getnewargs__(self):
return self.numpy_dtype, self.const
# def __getnewargs__(self):
# return self.numpy_dtype, self.const
def __getnewargs_ex__(self):
return (self.numpy_dtype, self.const), {}
#
# def _hashable_content(self):
# return (self.i, self.label)
@property
def base_type(self):
@@ -698,7 +718,7 @@ class VectorType(Type):
def __str__(self):
if self.instruction_set is None:
return "%s[%d]" % (self.base_type, self.width)
return f"{self.base_type}[{self.width}]"
else:
if self.base_type == create_type("int64") or self.base_type == create_type("int32"):
return self.instruction_set['int']
@@ -724,6 +744,11 @@ class PointerType(Type):
self.const = const
self.restrict = restrict
def __new__(self, base_type, const=False, restrict=True):
self._base_type = base_type
self.const = const
self.restrict = restrict
def __getnewargs__(self):
return self.base_type, self.const, self.restrict
@@ -816,12 +841,7 @@ class TypedImaginaryUnit(TypedSymbol):
return obj
def __new_stage2__(cls, dtype, *args, **kwargs):
obj = super(TypedImaginaryUnit, cls).__xnew__(cls,
"_i",
dtype,
imaginary=True,
*args,
**kwargs)
obj = super(TypedImaginaryUnit, cls).__xnew__(cls, "_i", dtype, imaginary=True, *args, **kwargs)
return obj
headers = ['"cuda_complex.hpp"']