Skip to content
Snippets Groups Projects
Commit 8833346c authored by Martin Bauer's avatar Martin Bauer
Browse files

Fixes in vectorization to also support float kernels

parent c378ca19
Branches
Tags
No related merge requests found
......@@ -707,7 +707,7 @@ class KernelConstraintsCheck:
new_lhs = self._process_lhs(assignment.lhs)
return ast.SympyAssignment(new_lhs, new_rhs)
def process_expression(self, rhs):
def process_expression(self, rhs, type_constants=True):
self._update_accesses_rhs(rhs)
if isinstance(rhs, Field.Access):
self.fields_read.add(rhs.field)
......@@ -716,19 +716,19 @@ class KernelConstraintsCheck:
return rhs
elif isinstance(rhs, sp.Symbol):
return TypedSymbol(symbol_name_to_variable_name(rhs.name), self._type_for_symbol[rhs.name])
elif isinstance(rhs, sp.Number):
elif type_constants and isinstance(rhs, sp.Number):
return cast_func(rhs, create_type(self._type_for_symbol['_constant']))
elif isinstance(rhs, sp.Mul):
new_args = [self.process_expression(arg) if arg not in (-1, 1) else arg for arg in rhs.args]
new_args = [self.process_expression(arg, type_constants) if arg not in (-1, 1) else arg for arg in rhs.args]
return rhs.func(*new_args) if new_args else rhs
elif isinstance(rhs, sp.Indexed):
return rhs
else:
if isinstance(rhs, sp.Pow):
# don't process exponents -> they should remain integers
return sp.Pow(self.process_expression(rhs.args[0]), rhs.args[1])
return sp.Pow(self.process_expression(rhs.args[0], type_constants), rhs.args[1])
else:
new_args = [self.process_expression(arg) for arg in rhs.args]
new_args = [self.process_expression(arg, type_constants) for arg in rhs.args]
return rhs.func(*new_args) if new_args else rhs
@property
......@@ -796,7 +796,7 @@ def add_types(eqs, type_for_symbol, check_independence_condition):
return check.process_assignment(obj)
elif isinstance(obj, ast.Conditional):
false_block = None if obj.false_block is None else visit(obj.false_block)
return ast.Conditional(check.process_expression(obj.condition_expr),
return ast.Conditional(check.process_expression(obj.condition_expr, type_constants=False),
true_block=visit(obj.true_block), false_block=false_block)
elif isinstance(obj, ast.Block):
return ast.Block([visit(e) for e in obj.args])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment