From c062fc5c707f2548da5eb6302f1e5fc5cce5751d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20H=C3=B6nig?= <jan.hoenig@fau.de> Date: Wed, 2 Feb 2022 15:30:53 +0100 Subject: [PATCH] Removed default Vector width --- pystencils/cpu/vectorization.py | 6 +++--- pystencils/typing/types.py | 2 +- pystencils/typing/utilities.py | 10 ++-------- 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py index 4069a2485..7996d6d3f 100644 --- a/pystencils/cpu/vectorization.py +++ b/pystencils/cpu/vectorization.py @@ -219,7 +219,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a substitutions.update({s[0]: s[1] for s in zip(rng.result_symbols, new_result_symbols)}) rng._symbols_defined = set(new_result_symbols) fast_subs(loop_node, substitutions, skip=lambda e: isinstance(e, RNGBase)) - insert_vector_casts(loop_node, default_float_type) + insert_vector_casts(loop_node, default_float_type, vector_width) def mask_conditionals(loop_body): @@ -248,7 +248,7 @@ def mask_conditionals(loop_body): visit_node(loop_body, mask=True) -def insert_vector_casts(ast_node, default_float_type='double'): +def insert_vector_casts(ast_node, default_float_type='double', vector_width=4): """Inserts necessary casts from scalar values to vector values.""" handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt, vec_any, vec_all, DivFunc, @@ -262,7 +262,7 @@ def insert_vector_casts(ast_node, default_float_type='double'): arg = visit_expr(expr.args[0]) assert cast_type in [BasicType('float32'), BasicType('float64')],\ f'Vectorization cannot vectorize type {cast_type}' - return expr.func(arg, VectorType(cast_type)) + return expr.func(arg, VectorType(cast_type, vector_width)) elif expr.func is sp.Abs and 'abs' not in ast_node.instruction_set: new_arg = visit_expr(expr.args[0], default_type) base_type = get_type_of_expression(expr.args[0]).base_type if type(expr.args[0]) is VectorMemoryAccess \ diff --git a/pystencils/typing/types.py b/pystencils/typing/types.py index 849c3318e..2f45ff4af 100644 --- a/pystencils/typing/types.py +++ b/pystencils/typing/types.py @@ -131,7 +131,7 @@ class VectorType(AbstractType): # TODO: check with rest instruction_set = None - def __init__(self, base_type: BasicType, width: int = 4): # TODO default vector length is dangerous + def __init__(self, base_type: BasicType, width: int): self._base_type = base_type self.width = width diff --git a/pystencils/typing/utilities.py b/pystencils/typing/utilities.py index 4f67435bb..15d0beed1 100644 --- a/pystencils/typing/utilities.py +++ b/pystencils/typing/utilities.py @@ -114,6 +114,7 @@ def get_type_of_expression(expr, if not symbol_type_dict: symbol_type_dict = defaultdict(lambda: create_type('double')) + # TODO this line is quite hard to understand, if possible simpl get_type = partial(get_type_of_expression, default_float_type=default_float_type, default_int_type=default_int_type, @@ -170,13 +171,6 @@ def get_type_of_expression(expr, expr: sp.Expr if expr.args: types = tuple(get_type(a) for a in expr.args) - # collate_types checks numpy_dtype in the special cases - if any(not hasattr(t, 'numpy_dtype') for t in types): - forbid_collation_to_complex = False - forbid_collation_to_float = False - else: - forbid_collation_to_complex = expr.is_real is True - forbid_collation_to_float = expr.is_integer is True return collate_types(types) else: if expr.is_integer: @@ -187,7 +181,7 @@ def get_type_of_expression(expr, raise NotImplementedError("Could not determine type for", expr, type(expr)) -############################# End This is basically our type system ################################################## +# ############################# End This is basically our type system ################################################## # TODO this seems quite wrong... -- GitLab