Skip to content
Snippets Groups Projects
Commit 41ee7b86 authored by Daniel Bauer :speech_balloon:
Browse files

fix flake8 lints

parent 3c863610
No related branches found
No related tags found
No related merge requests found
...@@ -241,7 +241,7 @@ class CBackend: ...@@ -241,7 +241,7 @@ class CBackend:
return func_declaration + "\n" + body return func_declaration + "\n" + body
def _print_Block(self, node): def _print_Block(self, node):
if node == None: if node is None:
return "\n" return "\n"
block_contents = "\n".join([self._print(child) for child in node.args]) block_contents = "\n".join([self._print(child) for child in node.args])
...@@ -797,7 +797,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -797,7 +797,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if exp.is_integer and exp.is_number and 0 < exp < 8: if exp.is_integer and exp.is_number and 0 < exp < 8:
return self._print(sp.Mul(*[expr.base] * exp, evaluate=False)) return self._print(sp.Mul(*[expr.base] * exp, evaluate=False))
elif exp.is_integer and exp.is_number and -8 < exp < 0: elif exp.is_integer and exp.is_number and -8 < exp < 0:
return self._print(sp.Mul(*[DivFunc(CastFunc(1.0, expr.base.dtype),expr.base)] * (-exp), evaluate=False)) return self._print(sp.Mul(*[DivFunc(CastFunc(1.0, expr.base.dtype), expr.base)] * (-exp), evaluate=False))
elif exp == 0.5: elif exp == 0.5:
return root return root
elif exp == -0.5: elif exp == -0.5:
......
...@@ -76,7 +76,7 @@ class CachelineSize(ast.Node): ...@@ -76,7 +76,7 @@ class CachelineSize(ast.Node):
def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best', def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best',
assume_aligned: bool = False, nontemporal: Union[bool, Container[Union[str, Field]]] = False, assume_aligned: bool = False, nontemporal: Union[bool, Container[Union[str, Field]]] = False,
assume_inner_stride_one: bool = False, assume_sufficient_line_padding: bool = True, assume_inner_stride_one: bool = False, assume_sufficient_line_padding: bool = True,
vectorized_loop_substitutions : dict = {}, moved_constants : set = set()): vectorized_loop_substitutions: dict = {}, moved_constants: set = set()):
# TODO Vectorization Revamp we first introduce the remainder loop and then check if we can even vectorise. # TODO Vectorization Revamp we first introduce the remainder loop and then check if we can even vectorise.
# Maybe first copy the ast and return the copied version on failure # Maybe first copy the ast and return the copied version on failure
"""Explicit vectorization using SIMD vectorization via intrinsics. """Explicit vectorization using SIMD vectorization via intrinsics.
...@@ -100,11 +100,11 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best', ...@@ -100,11 +100,11 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best',
assumes that at the end of each line there is enough padding with dummy data assumes that at the end of each line there is enough padding with dummy data
depending on the access pattern there might be additional padding depending on the access pattern there might be additional padding
required at the end of the array required at the end of the array
vectorized_loop_substitutions: a dictionary of symbols and substitutions to be applied only to the vectorized loop, vectorized_loop_substitutions: a dictionary of symbols and substitutions to be applied only to the vectorized
with the purpose of declaring certain scalar constants entering the kernel from outside loop, with the purpose of declaring certain scalar constants entering the kernel
as vectorized variables in front of the loop. from outside as vectorized variables in front of the loop.
moved_nodes: enables cooperation with moved constants: a set of TypedSymbol of symbols that were previously moved out moved_nodes: enables cooperation with moved constants: a set of TypedSymbol of symbols that were previously
of loops. They have to be adapted with a CastFunc to the vector datatype. moved out of loops. They have to be adapted with a CastFunc to the vector datatype.
""" """
if instruction_set == 'best': if instruction_set == 'best':
if get_supported_instruction_sets(): if get_supported_instruction_sets():
...@@ -262,12 +262,14 @@ def mask_conditionals(loop_body): ...@@ -262,12 +262,14 @@ def mask_conditionals(loop_body):
def insert_vector_casts(ast_node, instruction_set, default_float_type='double', moved_nodes: set = set()): def insert_vector_casts(ast_node, instruction_set, default_float_type='double', moved_nodes: set = set()):
"""Inserts necessary casts from scalar values to vector values. Casts are normally omitted for TypedSymbols, except they are in """Inserts necessary casts from scalar values to vector values. Casts are normally omitted for TypedSymbols, except
the set of symbols that were moved out of the loop previously, are still scalar and must therefore be vectorized.""" they are in the set of symbols that were moved out of the loop previously, are still scalar and must therefore
be vectorized."""
handled_functions = (sp.Add, sp.Mul, vec_any, vec_all, DivFunc, sp.Abs) handled_functions = (sp.Add, sp.Mul, vec_any, vec_all, DivFunc, sp.Abs)
def visit_expr(expr, default_type='double', _moved_nodes: set = moved_nodes): # TODO Vectorization Revamp: get rid of default_type # TODO Vectorization Revamp: get rid of default_type
def visit_expr(expr, default_type='double', _moved_nodes: set = moved_nodes):
if isinstance(expr, VectorMemoryAccess): if isinstance(expr, VectorMemoryAccess):
return VectorMemoryAccess(*expr.args[0:4], visit_expr(expr.args[4], default_type), *expr.args[5:]) return VectorMemoryAccess(*expr.args[0:4], visit_expr(expr.args[4], default_type), *expr.args[5:])
elif isinstance(expr, CastFunc): elif isinstance(expr, CastFunc):
...@@ -307,7 +309,10 @@ def insert_vector_casts(ast_node, instruction_set, default_float_type='double', ...@@ -307,7 +309,10 @@ def insert_vector_casts(ast_node, instruction_set, default_float_type='double',
target_type = collate_types(arg_types) target_type = collate_types(arg_types)
# insert cast function to target type (e.g. vectorType) if it's missing # insert cast function to target type (e.g. vectorType) if it's missing
casted_args = [ casted_args = [
CastFunc(a, target_type) if t != target_type and not isinstance(a, VectorMemoryAccess) and not all(isinstance(f, TypedSymbol) for f in a.free_symbols) else a CastFunc(a, target_type) if (
t != target_type and not isinstance(a, VectorMemoryAccess) and
not all(isinstance(f, TypedSymbol) for f in a.free_symbols)
) else a
for a, t in zip(new_args, arg_types)] for a, t in zip(new_args, arg_types)]
return expr.func(*casted_args) return expr.func(*casted_args)
elif expr.func is sp.UnevaluatedExpr: elif expr.func is sp.UnevaluatedExpr:
......
...@@ -568,7 +568,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=None, ...@@ -568,7 +568,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=None,
return visit_node(ast_node) return visit_node(ast_node)
def move_constants_before_loop(ast_node, rename_moved_consts = False): def move_constants_before_loop(ast_node, rename_moved_consts=False):
"""Moves :class:`pystencils.ast.SympyAssignment` nodes out of loop body if they are iteration independent. """Moves :class:`pystencils.ast.SympyAssignment` nodes out of loop body if they are iteration independent.
Call this after creating the loop structure with :func:`make_loop_over_domain` Call this after creating the loop structure with :func:`make_loop_over_domain`
...@@ -630,7 +630,8 @@ def move_constants_before_loop(ast_node, rename_moved_consts = False): ...@@ -630,7 +630,8 @@ def move_constants_before_loop(ast_node, rename_moved_consts = False):
block.append(child) block.append(child)
continue continue
if isinstance(child, ast.SympyAssignment) and isinstance(child.lhs, ast.ResolvedFieldAccess): # don't move field accesses # don't move field accesses
if isinstance(child, ast.SympyAssignment) and isinstance(child.lhs, ast.ResolvedFieldAccess):
block.append(child) block.append(child)
continue continue
...@@ -648,7 +649,7 @@ def move_constants_before_loop(ast_node, rename_moved_consts = False): ...@@ -648,7 +649,7 @@ def move_constants_before_loop(ast_node, rename_moved_consts = False):
if rename_moved_consts: if rename_moved_consts:
moved_symbol = TypedSymbol(old_symbol.name + "_mov" + str(moves_ctr), old_symbol.dtype) moved_symbol = TypedSymbol(old_symbol.name + "_mov" + str(moves_ctr), old_symbol.dtype)
moved_nodes.add(moved_symbol) moved_nodes.add(moved_symbol)
consts_moved_subs.update({old_symbol : moved_symbol}) consts_moved_subs.update({old_symbol: moved_symbol})
else: else:
moved_nodes.add(child.lhs) moved_nodes.add(child.lhs)
...@@ -667,8 +668,8 @@ def move_constants_before_loop(ast_node, rename_moved_consts = False): ...@@ -667,8 +668,8 @@ def move_constants_before_loop(ast_node, rename_moved_consts = False):
new_symbol = TypedSymbol(sp.Dummy().name, child.lhs.dtype) new_symbol = TypedSymbol(sp.Dummy().name, child.lhs.dtype)
target.insert_before(ast.SympyAssignment(new_symbol, child.rhs, is_const=child.is_const), target.insert_before(ast.SympyAssignment(new_symbol, child.rhs, is_const=child.is_const),
child_to_insert_before) child_to_insert_before)
#block.append(ast.SympyAssignment(child.lhs, new_symbol, is_const=child.is_const)) # block.append(ast.SympyAssignment(child.lhs, new_symbol, is_const=child.is_const))
consts_moved_subs.update({old_symbol : new_symbol}) consts_moved_subs.update({old_symbol: new_symbol})
if bool(consts_moved_subs): if bool(consts_moved_subs):
for child in children: for child in children:
...@@ -676,6 +677,7 @@ def move_constants_before_loop(ast_node, rename_moved_consts = False): ...@@ -676,6 +677,7 @@ def move_constants_before_loop(ast_node, rename_moved_consts = False):
moves_ctr += 1 moves_ctr += 1
return moved_nodes return moved_nodes
def split_inner_loop(ast_node: ast.Node, symbol_groups): def split_inner_loop(ast_node: ast.Node, symbol_groups):
""" """
Splits inner loop into multiple loops to minimize the amount of simultaneous load/store streams Splits inner loop into multiple loops to minimize the amount of simultaneous load/store streams
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment