From a3f92cb47e421779f6a7020380264bab4bfe40b5 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Tue, 14 Jan 2020 19:00:51 +0100
Subject: [PATCH] Implement interpolation as an optimization

---
 pystencils/transformations.py | 26 +++++++++++---------------
 1 file changed, 11 insertions(+), 15 deletions(-)

diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index eb68d53ef..a81d0fb0e 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -1322,6 +1322,8 @@ def implement_interpolations(ast_node: ast.Node,
     FLOAT32_T = create_type('float32')
 
     interpolation_accesses = ast_node.atoms(InterpolatorAccess)
+    if not interpolation_accesses:
+        return ast_node
 
     def can_use_hw_interpolation(i):
         return (use_hardware_interpolation_for_f32
@@ -1346,22 +1348,16 @@ def implement_interpolations(ast_node: ast.Node,
                 pass
             ast_node.subs({old_i: i})
 
-    if vectorize:
-        # TODO can be done in _interpolator_access_to_stencils field.absolute_access == simd_gather
-        raise NotImplementedError()
-    else:
-        substitutions = {i: i.implementation_with_stencils()
-                         for i in interpolation_accesses if not can_use_hw_interpolation(i)}
-        if isinstance(ast_node, AssignmentCollection):
-            ast_node = ast_node.subs(substitutions)
-        else:
-            ast_node.subs(substitutions)
+    from pystencils.math_optimizations import ReplaceOptim, optimize_ast
 
-    # from pystencils.math_optimizations import ReplaceOptim, optimize_ast
+    ImplementInterpolationByStencils = ReplaceOptim(lambda e: isinstance(e, InterpolatorAccess)
+                                                    and not can_use_hw_interpolation(i),
+                                                    lambda e: e.implementation_with_stencils()
+                                                    )
 
-    # RemoveConjugate = ReplaceOptim(lambda e: isinstance(e, sp.conjugate),
-            # lambda e: e.args[0]
-            # )
-    # optimize_ast(ast_node, [RemoveConjugate])
+    RemoveConjugate = ReplaceOptim(lambda e: isinstance(e, sp.conjugate),
+                                   lambda e: e.args[0]
+                                   )
+    optimize_ast(ast_node, [RemoveConjugate, ImplementInterpolationByStencils])
 
     return ast_node
-- 
GitLab