diff --git a/pystencils/typing/utilities.py b/pystencils/typing/utilities.py
index da40c510ef91c7ca7fee0e6a0259b3eef50f0ab8..223da701a4d5c133715eb30f99366c44b13f16b2 100644
--- a/pystencils/typing/utilities.py
+++ b/pystencils/typing/utilities.py
@@ -187,18 +187,15 @@ def get_type_of_expression(expr,
 
 # Fix for sympy versions from 1.9
 sympy_version = sp.__version__.split('.')
-if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109:
+sympy_version_int = int(sympy_version[0]) * 100 + int(sympy_version[1])
+if sympy_version_int >= 109:
     # __setstate__ would bypass the contructor, so we remove it
-    sp.Number.__getstate__ = sp.Basic.__getstate__
-    del sp.Basic.__getstate__
-
-    class FunctorWithStoredKwargs:
-        def __init__(self, func, **kwargs):
-            self.func = func
-            self.kwargs = kwargs
-
-        def __call__(self, *args):
-            return self.func(*args, **self.kwargs)
+    if sympy_version_int >= 111:
+        del sp.Basic.__setstate__
+        del sp.Symbol.__setstate__
+    else:
+        sp.Number.__getstate__ = sp.Basic.__getstate__
+        del sp.Basic.__getstate__
 
     # __reduce_ex__ would strip kwargs, so we override it
     def basic_reduce_ex(self, protocol):
@@ -210,9 +207,7 @@ if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109:
             state = self.__getstate__()
         else:
             state = None
-        return FunctorWithStoredKwargs(type(self), **kwargs), args, state
-
-    sp.Number.__reduce_ex__ = sp.Basic.__reduce_ex__
+        return partial(type(self), **kwargs), args, state
     sp.Basic.__reduce_ex__ = basic_reduce_ex