From 72c3c7cb382ccfa678b3383e8a51f14dd344fec2 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Tue, 9 Jan 2024 10:18:02 +0100
Subject: [PATCH] Fix type equality checks

---
 pystencils/astnodes.py                      | 2 +-
 pystencils/backends/cbackend.py             | 2 +-
 pystencils/boundaries/boundaryconditions.py | 2 +-
 pystencils/cpu/vectorization.py             | 2 +-
 pystencils/rng.py                           | 2 +-
 pystencils/sympyextensions.py               | 6 +++---
 6 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py
index 3d1940e69..f399287ed 100644
--- a/pystencils/astnodes.py
+++ b/pystencils/astnodes.py
@@ -673,7 +673,7 @@ class SympyAssignment(Node):
         return hash((self.lhs, self.rhs))
 
     def __eq__(self, other):
-        return type(self) == type(other) and (self.lhs, self.rhs) == (other.lhs, other.rhs)
+        return type(self) is type(other) and (self.lhs, self.rhs) == (other.lhs, other.rhs)
 
 
 class ResolvedFieldAccess(sp.Indexed):
diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index 5c8259699..7dbf84d37 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -152,7 +152,7 @@ class CustomCodeNode(Node):
         return self._symbols_read - self._symbols_defined
 
     def __eq__(self, other):
-        return type(self) == type(other) and self._code == other._code
+        return type(self) is type(other) and self._code == other._code
 
     def __hash__(self):
         return hash(self._code)
diff --git a/pystencils/boundaries/boundaryconditions.py b/pystencils/boundaries/boundaryconditions.py
index 65243177d..5fd8480b6 100644
--- a/pystencils/boundaries/boundaryconditions.py
+++ b/pystencils/boundaries/boundaryconditions.py
@@ -76,7 +76,7 @@ class Neumann(Boundary):
         return hash("Neumann")
 
     def __eq__(self, other):
-        return type(other) == Neumann
+        return type(other) is Neumann
 
 
 class Dirichlet(Boundary):
diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py
index 0e2b77da6..872f0b3c4 100644
--- a/pystencils/cpu/vectorization.py
+++ b/pystencils/cpu/vectorization.py
@@ -295,7 +295,7 @@ def insert_vector_casts(ast_node, instruction_set, default_float_type='double'):
         elif isinstance(expr, CastFunc):
             cast_type = expr.args[1]
             arg = visit_expr(expr.args[0], default_type, force_vectorize)
-            assert cast_type in [BasicType('float32'), BasicType('float64')],\
+            assert cast_type in [BasicType('float32'), BasicType('float64')], \
                 f'Vectorization cannot vectorize type {cast_type}'
             return expr.func(arg, VectorType(cast_type, instruction_set['width']))
         elif expr.func is sp.Abs and 'abs' not in instruction_set:
diff --git a/pystencils/rng.py b/pystencils/rng.py
index 6e9bc9548..84155b00c 100644
--- a/pystencils/rng.py
+++ b/pystencils/rng.py
@@ -65,7 +65,7 @@ class RNGBase(CustomCodeNode):
         return (self._name, *self.result_symbols, *self.args)
 
     def __eq__(self, other):
-        return type(self) == type(other) and self._hashable_content() == other._hashable_content()
+        return type(self) is type(other) and self._hashable_content() == other._hashable_content()
 
     def __hash__(self):
         return hash(self._hashable_content())
diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py
index 40be43eaa..680b58670 100644
--- a/pystencils/sympyextensions.py
+++ b/pystencils/sympyextensions.py
@@ -356,7 +356,7 @@ def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order
         factor_count = 0
         if type(product) is Mul:
             for factor in product.args:
-                if type(factor) == Pow:
+                if type(factor) is Pow:
                     if factor.args[0] in symbols:
                         factor_count += factor.args[1]
                 if factor in symbols:
@@ -366,13 +366,13 @@ def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order
                 factor_count += product.args[1]
         return factor_count
 
-    if type(expr) == Mul or type(expr) == Pow:
+    if type(expr) is Mul or type(expr) is Pow:
         if velocity_factors_in_product(expr) <= order:
             return expr
         else:
             return Zero()
 
-    if type(expr) != Add:
+    if type(expr) is not Add:
         return expr
 
     for sum_term in expr.args:
-- 
GitLab