Skip to content
Snippets Groups Projects
Commit 1c5814c2 authored by Markus Holzer's avatar Markus Holzer
Browse files

node collection replacements more specific

parent 404f5138
No related branches found
No related tags found
1 merge request!306Improve Vectorisation
...@@ -43,14 +43,15 @@ class NodeCollection: ...@@ -43,14 +43,15 @@ class NodeCollection:
def evaluate_terms(self): def evaluate_terms(self):
evaluate_constant_terms = ReplaceOptim( evaluate_constant_terms = ReplaceOptim(
lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer, lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer,
lambda p: p.evalf()) lambda p: p.evalf()
)
evaluate_pow = ReplaceOptim( evaluate_pow = ReplaceOptim(
lambda e: e.is_Pow and e.exp.is_Integer and abs(e.exp) <= 8, lambda e: e.is_Pow and e.exp.is_Integer and abs(e.exp) <= 8,
lambda p: ( lambda p: sp.UnevaluatedExpr(sp.Mul(*([p.base] * +p.exp), evaluate=False)) if p.exp > 0 else
sp.UnevaluatedExpr(sp.Mul(*([p.base] * +p.exp), evaluate=False)) if p.exp > 0 else (DivFunc(sp.Integer(1), p.base) if p.exp == -1 else
DivFunc(sp.Integer(1), sp.UnevaluatedExpr(sp.Mul(*([p.base] * -p.exp), evaluate=False))) DivFunc(sp.Integer(1), sp.UnevaluatedExpr(sp.Mul(*([p.base] * -p.exp), evaluate=False))))
)) )
sympy_optimisations = [evaluate_constant_terms, evaluate_pow] sympy_optimisations = [evaluate_constant_terms, evaluate_pow]
if self.is_Nodes: if self.is_Nodes:
...@@ -65,6 +66,7 @@ class NodeCollection: ...@@ -65,6 +66,7 @@ class NodeCollection:
return optimize(node, sympy_optimisations) return optimize(node, sympy_optimisations)
else: else:
raise NotImplementedError(f'{node} {type(node)} has no valid visitor') raise NotImplementedError(f'{node} {type(node)} has no valid visitor')
self.all_assignments = [visitor(assignment) for assignment in self.all_assignments] self.all_assignments = [visitor(assignment) for assignment in self.all_assignments]
else: else:
self.all_assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations)) self.all_assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations))
......
...@@ -282,6 +282,7 @@ def test_vectorised_pow(instruction_set=instruction_set): ...@@ -282,6 +282,7 @@ def test_vectorised_pow(instruction_set=instruction_set):
ast = ps.create_kernel(as6) ast = ps.create_kernel(as6)
vectorize(ast, instruction_set=instruction_set) vectorize(ast, instruction_set=instruction_set)
ps.show_code(ast)
ast.compile() ast.compile()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment