Skip to content
Snippets Groups Projects
Commit c062fc5c authored by Jan Hönig's avatar Jan Hönig Committed by Markus Holzer
Browse files

Removed default Vector width

parent 939241f2
No related branches found
No related tags found
1 merge request!292Rebase of pystencils Type System
...@@ -219,7 +219,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a ...@@ -219,7 +219,7 @@ def vectorize_inner_loops_and_adapt_load_stores(ast_node, vector_width, assume_a
substitutions.update({s[0]: s[1] for s in zip(rng.result_symbols, new_result_symbols)}) substitutions.update({s[0]: s[1] for s in zip(rng.result_symbols, new_result_symbols)})
rng._symbols_defined = set(new_result_symbols) rng._symbols_defined = set(new_result_symbols)
fast_subs(loop_node, substitutions, skip=lambda e: isinstance(e, RNGBase)) fast_subs(loop_node, substitutions, skip=lambda e: isinstance(e, RNGBase))
insert_vector_casts(loop_node, default_float_type) insert_vector_casts(loop_node, default_float_type, vector_width)
def mask_conditionals(loop_body): def mask_conditionals(loop_body):
...@@ -248,7 +248,7 @@ def mask_conditionals(loop_body): ...@@ -248,7 +248,7 @@ def mask_conditionals(loop_body):
visit_node(loop_body, mask=True) visit_node(loop_body, mask=True)
def insert_vector_casts(ast_node, default_float_type='double'): def insert_vector_casts(ast_node, default_float_type='double', vector_width=4):
"""Inserts necessary casts from scalar values to vector values.""" """Inserts necessary casts from scalar values to vector values."""
handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt, vec_any, vec_all, DivFunc, handled_functions = (sp.Add, sp.Mul, fast_division, fast_sqrt, fast_inv_sqrt, vec_any, vec_all, DivFunc,
...@@ -262,7 +262,7 @@ def insert_vector_casts(ast_node, default_float_type='double'): ...@@ -262,7 +262,7 @@ def insert_vector_casts(ast_node, default_float_type='double'):
arg = visit_expr(expr.args[0]) arg = visit_expr(expr.args[0])
assert cast_type in [BasicType('float32'), BasicType('float64')],\ assert cast_type in [BasicType('float32'), BasicType('float64')],\
f'Vectorization cannot vectorize type {cast_type}' f'Vectorization cannot vectorize type {cast_type}'
return expr.func(arg, VectorType(cast_type)) return expr.func(arg, VectorType(cast_type, vector_width))
elif expr.func is sp.Abs and 'abs' not in ast_node.instruction_set: elif expr.func is sp.Abs and 'abs' not in ast_node.instruction_set:
new_arg = visit_expr(expr.args[0], default_type) new_arg = visit_expr(expr.args[0], default_type)
base_type = get_type_of_expression(expr.args[0]).base_type if type(expr.args[0]) is VectorMemoryAccess \ base_type = get_type_of_expression(expr.args[0]).base_type if type(expr.args[0]) is VectorMemoryAccess \
......
...@@ -131,7 +131,7 @@ class VectorType(AbstractType): ...@@ -131,7 +131,7 @@ class VectorType(AbstractType):
# TODO: check with rest # TODO: check with rest
instruction_set = None instruction_set = None
def __init__(self, base_type: BasicType, width: int = 4): # TODO default vector length is dangerous def __init__(self, base_type: BasicType, width: int):
self._base_type = base_type self._base_type = base_type
self.width = width self.width = width
......
...@@ -114,6 +114,7 @@ def get_type_of_expression(expr, ...@@ -114,6 +114,7 @@ def get_type_of_expression(expr,
if not symbol_type_dict: if not symbol_type_dict:
symbol_type_dict = defaultdict(lambda: create_type('double')) symbol_type_dict = defaultdict(lambda: create_type('double'))
# TODO this line is quite hard to understand, if possible simpl
get_type = partial(get_type_of_expression, get_type = partial(get_type_of_expression,
default_float_type=default_float_type, default_float_type=default_float_type,
default_int_type=default_int_type, default_int_type=default_int_type,
...@@ -170,13 +171,6 @@ def get_type_of_expression(expr, ...@@ -170,13 +171,6 @@ def get_type_of_expression(expr,
expr: sp.Expr expr: sp.Expr
if expr.args: if expr.args:
types = tuple(get_type(a) for a in expr.args) types = tuple(get_type(a) for a in expr.args)
# collate_types checks numpy_dtype in the special cases
if any(not hasattr(t, 'numpy_dtype') for t in types):
forbid_collation_to_complex = False
forbid_collation_to_float = False
else:
forbid_collation_to_complex = expr.is_real is True
forbid_collation_to_float = expr.is_integer is True
return collate_types(types) return collate_types(types)
else: else:
if expr.is_integer: if expr.is_integer:
...@@ -187,7 +181,7 @@ def get_type_of_expression(expr, ...@@ -187,7 +181,7 @@ def get_type_of_expression(expr,
raise NotImplementedError("Could not determine type for", expr, type(expr)) raise NotImplementedError("Could not determine type for", expr, type(expr))
############################# End This is basically our type system ################################################## # ############################# End This is basically our type system ##################################################
# TODO this seems quite wrong... # TODO this seems quite wrong...
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment