Skip to content
Snippets Groups Projects
Commit d7ac0ce9 authored by Markus Holzer's avatar Markus Holzer
Browse files

Remove cases for Pow

parent 7d9f3dfa
No related branches found
No related tags found
1 merge request!306Improve Vectorisation
......@@ -787,18 +787,14 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
else:
exp = expr.exp
# TODO the printer should not have any intelligence like this.
# TODO To remove all of these cases the vectoriser needs to be reworked. See loop cutting
if exp.is_integer and exp.is_number and 0 < exp < 8:
return self._print(sp.Mul(*[expr.base] * exp, evaluate=False))
elif exp == -1:
return self.instruction_set['/'].format(one, self._print(expr.base), **self._kwargs)
elif exp == 0.5:
return root
elif exp == -0.5:
return self.instruction_set['/'].format(one, root, **self._kwargs)
elif exp.is_integer and exp.is_number and - 8 < exp < 0:
return self.instruction_set['/'].format(one,
self._print(sp.Mul(*[expr.base] * (-exp), evaluate=False)),
**self._kwargs)
else:
raise ValueError("Generic exponential not supported: " + str(expr))
......
......@@ -19,7 +19,7 @@ def test_vec_any(instruction_set, dtype):
width = 4 # we don't know the actual value
else:
width = get_vector_instruction_set(dtype, instruction_set)['width']
data_arr = np.zeros((4 * width, 4 * width), dtype=np.float64 if dtype == 'double' else np.float32)
data_arr = np.zeros((4 * width, 4 * width), dtype=dtype)
data_arr[3:9, 1:3 * width - 1] = 1.0
data = ps.fields(f"data: {dtype}[2D]", data=data_arr)
......
......@@ -261,7 +261,6 @@ def test_vectorised_pow(instruction_set=instruction_set):
ast = ps.create_kernel(as1)
vectorize(ast, instruction_set=instruction_set)
print(ast)
ast.compile()
ast = ps.create_kernel(as2)
......@@ -282,7 +281,6 @@ def test_vectorised_pow(instruction_set=instruction_set):
ast = ps.create_kernel(as6)
vectorize(ast, instruction_set=instruction_set)
ps.show_code(ast)
ast.compile()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment