From 72c3c7cb382ccfa678b3383e8a51f14dd344fec2 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Tue, 9 Jan 2024 10:18:02 +0100 Subject: [PATCH] Fix type equality checks --- pystencils/astnodes.py | 2 +- pystencils/backends/cbackend.py | 2 +- pystencils/boundaries/boundaryconditions.py | 2 +- pystencils/cpu/vectorization.py | 2 +- pystencils/rng.py | 2 +- pystencils/sympyextensions.py | 6 +++--- 6 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index 3d1940e69..f399287ed 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -673,7 +673,7 @@ class SympyAssignment(Node): return hash((self.lhs, self.rhs)) def __eq__(self, other): - return type(self) == type(other) and (self.lhs, self.rhs) == (other.lhs, other.rhs) + return type(self) is type(other) and (self.lhs, self.rhs) == (other.lhs, other.rhs) class ResolvedFieldAccess(sp.Indexed): diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 5c8259699..7dbf84d37 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -152,7 +152,7 @@ class CustomCodeNode(Node): return self._symbols_read - self._symbols_defined def __eq__(self, other): - return type(self) == type(other) and self._code == other._code + return type(self) is type(other) and self._code == other._code def __hash__(self): return hash(self._code) diff --git a/pystencils/boundaries/boundaryconditions.py b/pystencils/boundaries/boundaryconditions.py index 65243177d..5fd8480b6 100644 --- a/pystencils/boundaries/boundaryconditions.py +++ b/pystencils/boundaries/boundaryconditions.py @@ -76,7 +76,7 @@ class Neumann(Boundary): return hash("Neumann") def __eq__(self, other): - return type(other) == Neumann + return type(other) is Neumann class Dirichlet(Boundary): diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py index 0e2b77da6..872f0b3c4 100644 --- a/pystencils/cpu/vectorization.py +++ b/pystencils/cpu/vectorization.py @@ -295,7 +295,7 @@ def insert_vector_casts(ast_node, instruction_set, default_float_type='double'): elif isinstance(expr, CastFunc): cast_type = expr.args[1] arg = visit_expr(expr.args[0], default_type, force_vectorize) - assert cast_type in [BasicType('float32'), BasicType('float64')],\ + assert cast_type in [BasicType('float32'), BasicType('float64')], \ f'Vectorization cannot vectorize type {cast_type}' return expr.func(arg, VectorType(cast_type, instruction_set['width'])) elif expr.func is sp.Abs and 'abs' not in instruction_set: diff --git a/pystencils/rng.py b/pystencils/rng.py index 6e9bc9548..84155b00c 100644 --- a/pystencils/rng.py +++ b/pystencils/rng.py @@ -65,7 +65,7 @@ class RNGBase(CustomCodeNode): return (self._name, *self.result_symbols, *self.args) def __eq__(self, other): - return type(self) == type(other) and self._hashable_content() == other._hashable_content() + return type(self) is type(other) and self._hashable_content() == other._hashable_content() def __hash__(self): return hash(self._hashable_content()) diff --git a/pystencils/sympyextensions.py b/pystencils/sympyextensions.py index 40be43eaa..680b58670 100644 --- a/pystencils/sympyextensions.py +++ b/pystencils/sympyextensions.py @@ -356,7 +356,7 @@ def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order factor_count = 0 if type(product) is Mul: for factor in product.args: - if type(factor) == Pow: + if type(factor) is Pow: if factor.args[0] in symbols: factor_count += factor.args[1] if factor in symbols: @@ -366,13 +366,13 @@ def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order factor_count += product.args[1] return factor_count - if type(expr) == Mul or type(expr) == Pow: + if type(expr) is Mul or type(expr) is Pow: if velocity_factors_in_product(expr) <= order: return expr else: return Zero() - if type(expr) != Add: + if type(expr) is not Add: return expr for sum_term in expr.args: -- GitLab