From a8c5cc85ea6b5ed3923beae766a968f7c409900c Mon Sep 17 00:00:00 2001 From: Markus Holzer <markus.holzer@fau.de> Date: Wed, 2 Feb 2022 20:05:03 +0100 Subject: [PATCH] Fix linter --- pystencils/cpu/vectorization.py | 13 +++++++------ pystencils/gpucuda/kernelcreation.py | 2 +- pystencils/simp/assignment_collection.py | 12 +++++------- pystencils/transformations.py | 2 +- pystencils/typing/__init__.py | 6 ------ pystencils/typing/cast_functions.py | 7 +++---- pystencils/typing/leaf_typing.py | 2 +- pystencils/typing/transformations.py | 1 - pystencils/typing/types.py | 1 - pystencils/typing/utilities.py | 11 ++--------- 10 files changed, 20 insertions(+), 37 deletions(-) diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py index 4d609a11..ac25639b 100644 --- a/pystencils/cpu/vectorization.py +++ b/pystencils/cpu/vectorization.py @@ -7,8 +7,8 @@ from sympy.logic.boolalg import BooleanFunction, BooleanAtom import pystencils.astnodes as ast from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets, get_vector_instruction_set -from pystencils.typing import ( BasicType, PointerType, TypedSymbol, VectorType, CastFunc, collate_types, - get_type_of_expression, VectorMemoryAccess) +from pystencils.typing import (BasicType, PointerType, TypedSymbol, VectorType, CastFunc, collate_types, + get_type_of_expression, VectorMemoryAccess) from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt from pystencils.functions import DivFunc from pystencils.field import Field @@ -203,9 +203,10 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, assume_aligned, nontem loop_node.step = vector_width loop_node.subs(substitutions) vector_int_width = ast_node.instruction_set['intwidth'] - vector_loop_counter = CastFunc(loop_counter_symbol, VectorType(loop_counter_symbol.dtype, vector_int_width)) \ - + CastFunc(tuple(range(vector_int_width if type(vector_int_width) is int else 2)), - VectorType(loop_counter_symbol.dtype, vector_int_width)) + arg_1 = CastFunc(loop_counter_symbol, VectorType(loop_counter_symbol.dtype, vector_int_width)) + arg_2 = CastFunc(tuple(range(vector_int_width if type(vector_int_width) is int else 2)), + VectorType(loop_counter_symbol.dtype, vector_int_width)) + vector_loop_counter = arg_1 + arg_2 fast_subs(loop_node, {loop_counter_symbol: vector_loop_counter}, skip=lambda e: isinstance(e, ast.ResolvedFieldAccess) or isinstance(e, VectorMemoryAccess)) @@ -333,7 +334,7 @@ def insert_vector_casts(ast_node, instruction_set, default_float_type='double'): assignment = arg # If there is a remainder loop we do not vectorise it, thus lhs will indicate this # if isinstance(assignment.lhs, ast.ResolvedFieldAccess): - # continue + # continue subs_expr = fast_subs(assignment.rhs, substitution_dict, skip=lambda e: isinstance(e, ast.ResolvedFieldAccess)) assignment.rhs = visit_expr(subs_expr, default_type) diff --git a/pystencils/gpucuda/kernelcreation.py b/pystencils/gpucuda/kernelcreation.py index 21721bb7..a50953b6 100644 --- a/pystencils/gpucuda/kernelcreation.py +++ b/pystencils/gpucuda/kernelcreation.py @@ -10,7 +10,7 @@ from pystencils.field import Field, FieldType from pystencils.enums import Target, Backend from pystencils.gpucuda.cudajit import make_python_function from pystencils.node_collection import NodeCollection -from pystencils.gpucuda.indexing import BlockIndexing, indexing_creator_from_params +from pystencils.gpucuda.indexing import indexing_creator_from_params from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.transformations import ( get_base_buffer_index, get_common_shape, parse_base_pointer_info, diff --git a/pystencils/simp/assignment_collection.py b/pystencils/simp/assignment_collection.py index 69dcf956..b3324e42 100644 --- a/pystencils/simp/assignment_collection.py +++ b/pystencils/simp/assignment_collection.py @@ -136,8 +136,7 @@ class AssignmentCollection: bound_symbols_set = bound_symbols_set.union(*[ assignment.symbols_defined for assignment in self.all_assignments if isinstance(assignment, pystencils.astnodes.Node) - ] - ) + ]) return bound_symbols_set @@ -159,11 +158,9 @@ class AssignmentCollection: @property def defined_symbols(self) -> Set[sp.Symbol]: """All symbols which occur as left-hand-sides of one of the main equations""" - return (set( - [assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)] - ).union(*[assignment.symbols_defined for assignment in self.main_assignments if isinstance( - assignment, pystencils.astnodes.Node)] - )) + lhs_set = set([assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)]) + return (lhs_set.union(*[assignment.symbols_defined for assignment in self.main_assignments + if isinstance(assignment, pystencils.astnodes.Node)])) @property def operation_count(self): @@ -365,6 +362,7 @@ class AssignmentCollection: new_assignment = [fast_subs(eq, substitution_dict) for eq in self.main_assignments] return self.copy(new_assignment, kept_subexpressions) + # ----------------------------------------- Display and Printing ------------------------------------------------- def _repr_html_(self): diff --git a/pystencils/transformations.py b/pystencils/transformations.py index 7f864f9a..2e885904 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -11,7 +11,7 @@ import pystencils.astnodes as ast from pystencils.assignment import Assignment from pystencils.typing import ( PointerType, StructType, TypedSymbol, get_base_type, ReinterpretCastFunc, get_next_parent_of_type, parents_of_type) -from pystencils.field import Field, Field, FieldType +from pystencils.field import Field, FieldType from pystencils.typing import FieldPointerSymbol from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.slicing import normalize_slice diff --git a/pystencils/typing/__init__.py b/pystencils/typing/__init__.py index 2221b812..e69de29b 100644 --- a/pystencils/typing/__init__.py +++ b/pystencils/typing/__init__.py @@ -1,6 +0,0 @@ - - -from pystencils.typing.types import * -from pystencils.typing.typed_sympy import * -from pystencils.typing.cast_functions import * -from pystencils.typing.utilities import * diff --git a/pystencils/typing/cast_functions.py b/pystencils/typing/cast_functions.py index 8200e969..76686c21 100644 --- a/pystencils/typing/cast_functions.py +++ b/pystencils/typing/cast_functions.py @@ -2,7 +2,7 @@ import numpy as np import sympy as sp from sympy.logic.boolalg import Boolean -from pystencils.typing.types import AbstractType, BasicType, create_type +from pystencils.typing.types import AbstractType, BasicType from pystencils.typing.typed_sympy import TypedSymbol @@ -93,9 +93,8 @@ class CastFunc(sp.Function): See :func:`.TypedSymbol.is_integer` """ if hasattr(self.dtype, 'numpy_dtype'): - return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \ - np.issubdtype(self.dtype.numpy_dtype, np.floating) or \ - super().is_real + return np.issubdtype(self.dtype.numpy_dtype, np.integer) or np.issubdtype(self.dtype.numpy_dtype, + np.floating) or super().is_real else: return super().is_real diff --git a/pystencils/typing/leaf_typing.py b/pystencils/typing/leaf_typing.py index 6ccd864e..c6282489 100644 --- a/pystencils/typing/leaf_typing.py +++ b/pystencils/typing/leaf_typing.py @@ -185,7 +185,7 @@ class TypeAdder: collated_type = collate_types([t for _, t in args_types]) new_expressions = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types] return expr.func(expr.args[0], expr.args[1], *new_expressions), collated_type - #elif isinstance(expr, sp.Mul): + # elif isinstance(expr, sp.Mul): # raise NotImplementedError('sp.Mul') # # TODO can we ignore this and move it to general expr handling, i.e. removing Mul? # # args_types = [self.figure_out_type(arg) for arg in expr.args if arg not in (-1, 1)] diff --git a/pystencils/typing/transformations.py b/pystencils/typing/transformations.py index f5ddcfa4..74ecf19f 100644 --- a/pystencils/typing/transformations.py +++ b/pystencils/typing/transformations.py @@ -2,7 +2,6 @@ from typing import List from pystencils.config import CreateKernelConfig from pystencils.typing.leaf_typing import TypeAdder -from pystencils.typing import BasicType from sympy.codegen import Assignment diff --git a/pystencils/typing/types.py b/pystencils/typing/types.py index 2f45ff4a..dbe28449 100644 --- a/pystencils/typing/types.py +++ b/pystencils/typing/types.py @@ -293,4 +293,3 @@ def create_type(specification: Union[np.dtype, AbstractType, str]) -> AbstractTy return BasicType(numpy_dtype, const=False) else: return StructType(numpy_dtype, const=False) - diff --git a/pystencils/typing/utilities.py b/pystencils/typing/utilities.py index 15d0beed..6a43c798 100644 --- a/pystencils/typing/utilities.py +++ b/pystencils/typing/utilities.py @@ -34,8 +34,6 @@ def get_base_type(data_type): return data_type -############################# This is basically our type system ######################################################## - def result_type(*args: np.dtype): s = sorted(args, key=lambda x: x.itemsize) @@ -104,7 +102,8 @@ def get_type_of_expression(expr, # TODO: we shouldn't need to have default. AST leaves should have a type default_int_type='int', # TODO: we shouldn't need to have default. AST leaves should have a type - symbol_type_dict=None): # TODO: we shouldn't need to have default. AST leaves should have a type + # TODO: we shouldn't need to have default. AST leaves should have a type + symbol_type_dict=None): from pystencils.astnodes import ResolvedFieldAccess from pystencils.cpu.vectorization import vec_all, vec_any @@ -181,9 +180,6 @@ def get_type_of_expression(expr, raise NotImplementedError("Could not determine type for", expr, type(expr)) -# ############################# End This is basically our type system ################################################## - - # TODO this seems quite wrong... sympy_version = sp.__version__.split('.') if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109: @@ -191,7 +187,6 @@ if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109: sp.Number.__getstate__ = sp.Basic.__getstate__ del sp.Basic.__getstate__ - class FunctorWithStoredKwargs: def __init__(self, func, **kwargs): self.func = func @@ -200,7 +195,6 @@ if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109: def __call__(self, *args): return self.func(*args, **self.kwargs) - # __reduce_ex__ would strip kwargs, so we override it def basic_reduce_ex(self, protocol): if hasattr(self, '__getnewargs_ex__'): @@ -213,7 +207,6 @@ if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109: state = None return FunctorWithStoredKwargs(type(self), **kwargs), args, state - sp.Number.__reduce_ex__ = sp.Basic.__reduce_ex__ sp.Basic.__reduce_ex__ = basic_reduce_ex -- GitLab