Skip to content
Snippets Groups Projects
Commit 1c5814c2 authored by Markus Holzer's avatar Markus Holzer
Browse files

node collection replacements more specific

parent 404f5138
No related branches found
No related tags found
No related merge requests found
Pipeline #47235 failed
......@@ -43,14 +43,15 @@ class NodeCollection:
def evaluate_terms(self):
evaluate_constant_terms = ReplaceOptim(
lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer,
lambda p: p.evalf())
lambda p: p.evalf()
)
evaluate_pow = ReplaceOptim(
lambda e: e.is_Pow and e.exp.is_Integer and abs(e.exp) <= 8,
lambda p: (
sp.UnevaluatedExpr(sp.Mul(*([p.base] * +p.exp), evaluate=False)) if p.exp > 0 else
DivFunc(sp.Integer(1), sp.UnevaluatedExpr(sp.Mul(*([p.base] * -p.exp), evaluate=False)))
))
lambda p: sp.UnevaluatedExpr(sp.Mul(*([p.base] * +p.exp), evaluate=False)) if p.exp > 0 else
(DivFunc(sp.Integer(1), p.base) if p.exp == -1 else
DivFunc(sp.Integer(1), sp.UnevaluatedExpr(sp.Mul(*([p.base] * -p.exp), evaluate=False))))
)
sympy_optimisations = [evaluate_constant_terms, evaluate_pow]
if self.is_Nodes:
......@@ -65,6 +66,7 @@ class NodeCollection:
return optimize(node, sympy_optimisations)
else:
raise NotImplementedError(f'{node} {type(node)} has no valid visitor')
self.all_assignments = [visitor(assignment) for assignment in self.all_assignments]
else:
self.all_assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations))
......
......@@ -282,6 +282,7 @@ def test_vectorised_pow(instruction_set=instruction_set):
ast = ps.create_kernel(as6)
vectorize(ast, instruction_set=instruction_set)
ps.show_code(ast)
ast.compile()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment