diff --git a/pystencils/typing/utilities.py b/pystencils/typing/utilities.py index da40c510ef91c7ca7fee0e6a0259b3eef50f0ab8..223da701a4d5c133715eb30f99366c44b13f16b2 100644 --- a/pystencils/typing/utilities.py +++ b/pystencils/typing/utilities.py @@ -187,18 +187,15 @@ def get_type_of_expression(expr, # Fix for sympy versions from 1.9 sympy_version = sp.__version__.split('.') -if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109: +sympy_version_int = int(sympy_version[0]) * 100 + int(sympy_version[1]) +if sympy_version_int >= 109: # __setstate__ would bypass the contructor, so we remove it - sp.Number.__getstate__ = sp.Basic.__getstate__ - del sp.Basic.__getstate__ - - class FunctorWithStoredKwargs: - def __init__(self, func, **kwargs): - self.func = func - self.kwargs = kwargs - - def __call__(self, *args): - return self.func(*args, **self.kwargs) + if sympy_version_int >= 111: + del sp.Basic.__setstate__ + del sp.Symbol.__setstate__ + else: + sp.Number.__getstate__ = sp.Basic.__getstate__ + del sp.Basic.__getstate__ # __reduce_ex__ would strip kwargs, so we override it def basic_reduce_ex(self, protocol): @@ -210,9 +207,7 @@ if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109: state = self.__getstate__() else: state = None - return FunctorWithStoredKwargs(type(self), **kwargs), args, state - - sp.Number.__reduce_ex__ = sp.Basic.__reduce_ex__ + return partial(type(self), **kwargs), args, state sp.Basic.__reduce_ex__ = basic_reduce_ex