diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 0ee8d0e43d8ca7eca932dcae30a044e09df4ccdc..5c8259699247d3c20f893c933c71bf37058010ed 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -151,8 +151,8 @@ class CustomCodeNode(Node): def undefined_symbols(self): return self._symbols_read - self._symbols_defined - def __eq___(self, other): - return self._code == other._code + def __eq__(self, other): + return type(self) == type(other) and self._code == other._code def __hash__(self): return hash(self._code) diff --git a/pystencils/rng.py b/pystencils/rng.py index c75c3f9727720d2d313adee3cda3eead520334c7..6e9bc95480cf83654ce4b4b0b7d783fbb0c6718b 100644 --- a/pystencils/rng.py +++ b/pystencils/rng.py @@ -61,6 +61,15 @@ class RNGBase(CustomCodeNode): return ", ".join([str(s) for s in self.result_symbols]) + " \\leftarrow " + \ self._name.capitalize() + "_RNG(" + ", ".join([str(a) for a in self.args]) + ")" + def _hashable_content(self): + return (self._name, *self.result_symbols, *self.args) + + def __eq__(self, other): + return type(self) == type(other) and self._hashable_content() == other._hashable_content() + + def __hash__(self): + return hash(self._hashable_content()) + class PhiloxTwoDoubles(RNGBase): _name = "philox_double2"