From a8c5cc85ea6b5ed3923beae766a968f7c409900c Mon Sep 17 00:00:00 2001
From: Markus Holzer <markus.holzer@fau.de>
Date: Wed, 2 Feb 2022 20:05:03 +0100
Subject: [PATCH] Fix linter

---
 pystencils/cpu/vectorization.py          | 13 +++++++------
 pystencils/gpucuda/kernelcreation.py     |  2 +-
 pystencils/simp/assignment_collection.py | 12 +++++-------
 pystencils/transformations.py            |  2 +-
 pystencils/typing/__init__.py            |  6 ------
 pystencils/typing/cast_functions.py      |  7 +++----
 pystencils/typing/leaf_typing.py         |  2 +-
 pystencils/typing/transformations.py     |  1 -
 pystencils/typing/types.py               |  1 -
 pystencils/typing/utilities.py           | 11 ++---------
 10 files changed, 20 insertions(+), 37 deletions(-)

diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py
index 4d609a11..ac25639b 100644
--- a/pystencils/cpu/vectorization.py
+++ b/pystencils/cpu/vectorization.py
@@ -7,8 +7,8 @@ from sympy.logic.boolalg import BooleanFunction, BooleanAtom
 
 import pystencils.astnodes as ast
 from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets, get_vector_instruction_set
-from pystencils.typing import ( BasicType, PointerType, TypedSymbol, VectorType, CastFunc, collate_types,
-                                get_type_of_expression, VectorMemoryAccess)
+from pystencils.typing import (BasicType, PointerType, TypedSymbol, VectorType, CastFunc, collate_types,
+                               get_type_of_expression, VectorMemoryAccess)
 from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
 from pystencils.functions import DivFunc
 from pystencils.field import Field
@@ -203,9 +203,10 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, assume_aligned, nontem
         loop_node.step = vector_width
         loop_node.subs(substitutions)
         vector_int_width = ast_node.instruction_set['intwidth']
-        vector_loop_counter = CastFunc(loop_counter_symbol, VectorType(loop_counter_symbol.dtype, vector_int_width)) \
-                              + CastFunc(tuple(range(vector_int_width if type(vector_int_width) is int else 2)),
-                                         VectorType(loop_counter_symbol.dtype, vector_int_width))
+        arg_1 = CastFunc(loop_counter_symbol, VectorType(loop_counter_symbol.dtype, vector_int_width))
+        arg_2 = CastFunc(tuple(range(vector_int_width if type(vector_int_width) is int else 2)),
+                         VectorType(loop_counter_symbol.dtype, vector_int_width))
+        vector_loop_counter = arg_1 + arg_2
 
         fast_subs(loop_node, {loop_counter_symbol: vector_loop_counter},
                   skip=lambda e: isinstance(e, ast.ResolvedFieldAccess) or isinstance(e, VectorMemoryAccess))
@@ -333,7 +334,7 @@ def insert_vector_casts(ast_node, instruction_set, default_float_type='double'):
                 assignment = arg
                 # If there is a remainder loop we do not vectorise it, thus lhs will indicate this
                 # if isinstance(assignment.lhs, ast.ResolvedFieldAccess):
-                    # continue
+                # continue
                 subs_expr = fast_subs(assignment.rhs, substitution_dict,
                                       skip=lambda e: isinstance(e, ast.ResolvedFieldAccess))
                 assignment.rhs = visit_expr(subs_expr, default_type)
diff --git a/pystencils/gpucuda/kernelcreation.py b/pystencils/gpucuda/kernelcreation.py
index 21721bb7..a50953b6 100644
--- a/pystencils/gpucuda/kernelcreation.py
+++ b/pystencils/gpucuda/kernelcreation.py
@@ -10,7 +10,7 @@ from pystencils.field import Field, FieldType
 from pystencils.enums import Target, Backend
 from pystencils.gpucuda.cudajit import make_python_function
 from pystencils.node_collection import NodeCollection
-from pystencils.gpucuda.indexing import BlockIndexing, indexing_creator_from_params
+from pystencils.gpucuda.indexing import indexing_creator_from_params
 from pystencils.simp.assignment_collection import AssignmentCollection
 from pystencils.transformations import (
     get_base_buffer_index, get_common_shape, parse_base_pointer_info,
diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py
index 69dcf956..b3324e42 100644
--- a/pystencils/simp/assignment_collection.py
+++ b/pystencils/simp/assignment_collection.py
@@ -136,8 +136,7 @@ class AssignmentCollection:
         bound_symbols_set = bound_symbols_set.union(*[
             assignment.symbols_defined for assignment in self.all_assignments
             if isinstance(assignment, pystencils.astnodes.Node)
-        ]
-                                                    )
+        ])
 
         return bound_symbols_set
 
@@ -159,11 +158,9 @@ class AssignmentCollection:
     @property
     def defined_symbols(self) -> Set[sp.Symbol]:
         """All symbols which occur as left-hand-sides of one of the main equations"""
-        return (set(
-            [assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)]
-        ).union(*[assignment.symbols_defined for assignment in self.main_assignments if isinstance(
-            assignment, pystencils.astnodes.Node)]
-                ))
+        lhs_set = set([assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)])
+        return (lhs_set.union(*[assignment.symbols_defined for assignment in self.main_assignments
+                                if isinstance(assignment, pystencils.astnodes.Node)]))
 
     @property
     def operation_count(self):
@@ -365,6 +362,7 @@ class AssignmentCollection:
 
         new_assignment = [fast_subs(eq, substitution_dict) for eq in self.main_assignments]
         return self.copy(new_assignment, kept_subexpressions)
+
     # ----------------------------------------- Display and Printing   -------------------------------------------------
 
     def _repr_html_(self):
diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index 7f864f9a..2e885904 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -11,7 +11,7 @@ import pystencils.astnodes as ast
 from pystencils.assignment import Assignment
 from pystencils.typing import (
     PointerType, StructType, TypedSymbol, get_base_type, ReinterpretCastFunc, get_next_parent_of_type, parents_of_type)
-from pystencils.field import Field, Field, FieldType
+from pystencils.field import Field, FieldType
 from pystencils.typing import FieldPointerSymbol
 from pystencils.simp.assignment_collection import AssignmentCollection
 from pystencils.slicing import normalize_slice
diff --git a/pystencils/typing/__init__.py b/pystencils/typing/__init__.py
index 2221b812..e69de29b 100644
--- a/pystencils/typing/__init__.py
+++ b/pystencils/typing/__init__.py
@@ -1,6 +0,0 @@
-
-
-from pystencils.typing.types import *
-from pystencils.typing.typed_sympy import *
-from pystencils.typing.cast_functions import *
-from pystencils.typing.utilities import *
diff --git a/pystencils/typing/cast_functions.py b/pystencils/typing/cast_functions.py
index 8200e969..76686c21 100644
--- a/pystencils/typing/cast_functions.py
+++ b/pystencils/typing/cast_functions.py
@@ -2,7 +2,7 @@ import numpy as np
 import sympy as sp
 from sympy.logic.boolalg import Boolean
 
-from pystencils.typing.types import AbstractType, BasicType, create_type
+from pystencils.typing.types import AbstractType, BasicType
 from pystencils.typing.typed_sympy import TypedSymbol
 
 
@@ -93,9 +93,8 @@ class CastFunc(sp.Function):
         See :func:`.TypedSymbol.is_integer`
         """
         if hasattr(self.dtype, 'numpy_dtype'):
-            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \
-                   np.issubdtype(self.dtype.numpy_dtype, np.floating) or \
-                   super().is_real
+            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or np.issubdtype(self.dtype.numpy_dtype,
+                                                                                      np.floating) or super().is_real
         else:
             return super().is_real
 
diff --git a/pystencils/typing/leaf_typing.py b/pystencils/typing/leaf_typing.py
index 6ccd864e..c6282489 100644
--- a/pystencils/typing/leaf_typing.py
+++ b/pystencils/typing/leaf_typing.py
@@ -185,7 +185,7 @@ class TypeAdder:
             collated_type = collate_types([t for _, t in args_types])
             new_expressions = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
             return expr.func(expr.args[0], expr.args[1], *new_expressions), collated_type
-        #elif isinstance(expr, sp.Mul):
+        # elif isinstance(expr, sp.Mul):
         #    raise NotImplementedError('sp.Mul')
         #    # TODO can we ignore this and move it to general expr handling, i.e. removing Mul?
         #    # args_types = [self.figure_out_type(arg) for arg in expr.args if arg not in (-1, 1)]
diff --git a/pystencils/typing/transformations.py b/pystencils/typing/transformations.py
index f5ddcfa4..74ecf19f 100644
--- a/pystencils/typing/transformations.py
+++ b/pystencils/typing/transformations.py
@@ -2,7 +2,6 @@ from typing import List
 
 from pystencils.config import CreateKernelConfig
 from pystencils.typing.leaf_typing import TypeAdder
-from pystencils.typing import BasicType
 from sympy.codegen import Assignment
 
 
diff --git a/pystencils/typing/types.py b/pystencils/typing/types.py
index 2f45ff4a..dbe28449 100644
--- a/pystencils/typing/types.py
+++ b/pystencils/typing/types.py
@@ -293,4 +293,3 @@ def create_type(specification: Union[np.dtype, AbstractType, str]) -> AbstractTy
             return BasicType(numpy_dtype, const=False)
         else:
             return StructType(numpy_dtype, const=False)
-
diff --git a/pystencils/typing/utilities.py b/pystencils/typing/utilities.py
index 15d0beed..6a43c798 100644
--- a/pystencils/typing/utilities.py
+++ b/pystencils/typing/utilities.py
@@ -34,8 +34,6 @@ def get_base_type(data_type):
     return data_type
 
 
-############################# This is basically our type system ########################################################
-
 def result_type(*args: np.dtype):
     s = sorted(args, key=lambda x: x.itemsize)
 
@@ -104,7 +102,8 @@ def get_type_of_expression(expr,
                            # TODO: we shouldn't need to have default. AST leaves should have a type
                            default_int_type='int',
                            # TODO: we shouldn't need to have default. AST leaves should have a type
-                           symbol_type_dict=None):  # TODO: we shouldn't need to have default. AST leaves should have a type
+                           # TODO: we shouldn't need to have default. AST leaves should have a type
+                           symbol_type_dict=None):
     from pystencils.astnodes import ResolvedFieldAccess
     from pystencils.cpu.vectorization import vec_all, vec_any
 
@@ -181,9 +180,6 @@ def get_type_of_expression(expr,
     raise NotImplementedError("Could not determine type for", expr, type(expr))
 
 
-# ############################# End This is basically our type system ##################################################
-
-
 # TODO this seems quite wrong...
 sympy_version = sp.__version__.split('.')
 if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109:
@@ -191,7 +187,6 @@ if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109:
     sp.Number.__getstate__ = sp.Basic.__getstate__
     del sp.Basic.__getstate__
 
-
     class FunctorWithStoredKwargs:
         def __init__(self, func, **kwargs):
             self.func = func
@@ -200,7 +195,6 @@ if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109:
         def __call__(self, *args):
             return self.func(*args, **self.kwargs)
 
-
     # __reduce_ex__ would strip kwargs, so we override it
     def basic_reduce_ex(self, protocol):
         if hasattr(self, '__getnewargs_ex__'):
@@ -213,7 +207,6 @@ if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109:
             state = None
         return FunctorWithStoredKwargs(type(self), **kwargs), args, state
 
-
     sp.Number.__reduce_ex__ = sp.Basic.__reduce_ex__
     sp.Basic.__reduce_ex__ = basic_reduce_ex
 
-- 
GitLab