Skip to content
Snippets Groups Projects
Commit 2d73140a authored by Jan Hönig's avatar Jan Hönig
Browse files

Removed default Vector width

parent 7315fde9
Branches
No related tags found
No related merge requests found
Pipeline #37268 failed
......@@ -219,7 +219,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
substitutions.update({s[0]: s[1] for s in zip(rng.result_symbols, new_result_symbols)})
rng._symbols_defined = set(new_result_symbols)
fast_subs(loop_node, substitutions, skip=lambda e: isinstance(e, RNGBase))
insert_vector_casts(loop_node, default_float_type)
insert_vector_casts(loop_node, default_float_type, vector_width)
def mask_conditionals(loop_body):
......@@ -248,7 +248,7 @@ def mask_conditionals(loop_body):
visit_node(loop_body, mask=True)
def insert_vector_casts(ast_node, default_float_type='double'):
def insert_vector_casts(ast_node, default_float_type='double', vector_width=4):
"""Inserts necessary casts from scalar values to vector values."""
handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt, vec_any, vec_all, DivFunc,
......@@ -262,7 +262,7 @@ def insert_vector_casts(ast_node, default_float_type='double'):
arg = visit_expr(expr.args[0])
assert cast_type in [BasicType('float32'), BasicType('float64')],\
f'Vectorization cannot vectorize type {cast_type}'
return expr.func(arg, VectorType(cast_type))
return expr.func(arg, VectorType(cast_type, vector_width))
elif expr.func is sp.Abs and 'abs' not in ast_node.instruction_set:
new_arg = visit_expr(expr.args[0], default_type)
base_type = get_type_of_expression(expr.args[0]).base_type if type(expr.args[0]) is VectorMemoryAccess \
......
......@@ -131,7 +131,7 @@ class VectorType(AbstractType):
# TODO: check with rest
instruction_set = None
def __init__(self, base_type: BasicType, width: int = 4): # TODO default vector length is dangerous
def __init__(self, base_type: BasicType, width: int):
self._base_type = base_type
self.width = width
......
......@@ -114,6 +114,7 @@ def get_type_of_expression(expr,
if not symbol_type_dict:
symbol_type_dict = defaultdict(lambda: create_type('double'))
# TODO this line is quite hard to understand, if possible simpl
get_type = partial(get_type_of_expression,
default_float_type=default_float_type,
default_int_type=default_int_type,
......@@ -170,13 +171,6 @@ def get_type_of_expression(expr,
expr: sp.Expr
if expr.args:
types = tuple(get_type(a) for a in expr.args)
# collate_types checks numpy_dtype in the special cases
if any(not hasattr(t, 'numpy_dtype') for t in types):
forbid_collation_to_complex = False
forbid_collation_to_float = False
else:
forbid_collation_to_complex = expr.is_real is True
forbid_collation_to_float = expr.is_integer is True
return collate_types(types)
else:
if expr.is_integer:
......@@ -187,7 +181,7 @@ def get_type_of_expression(expr,
raise NotImplementedError("Could not determine type for", expr, type(expr))
############################# End This is basically our type system ##################################################
# ############################# End This is basically our type system ##################################################
# TODO this seems quite wrong...
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment