diff --git a/pystencils/transformations.py b/pystencils/transformations.py index eb68d53efcf406bec4ed6b604a6b5424990bd7b1..a81d0fb0e2dfc06f232eb2f234491dd4858e8f69 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -1322,6 +1322,8 @@ def implement_interpolations(ast_node: ast.Node, FLOAT32_T = create_type('float32') interpolation_accesses = ast_node.atoms(InterpolatorAccess) + if not interpolation_accesses: + return ast_node def can_use_hw_interpolation(i): return (use_hardware_interpolation_for_f32 @@ -1346,22 +1348,16 @@ def implement_interpolations(ast_node: ast.Node, pass ast_node.subs({old_i: i}) - if vectorize: - # TODO can be done in _interpolator_access_to_stencils field.absolute_access == simd_gather - raise NotImplementedError() - else: - substitutions = {i: i.implementation_with_stencils() - for i in interpolation_accesses if not can_use_hw_interpolation(i)} - if isinstance(ast_node, AssignmentCollection): - ast_node = ast_node.subs(substitutions) - else: - ast_node.subs(substitutions) + from pystencils.math_optimizations import ReplaceOptim, optimize_ast - # from pystencils.math_optimizations import ReplaceOptim, optimize_ast + ImplementInterpolationByStencils = ReplaceOptim(lambda e: isinstance(e, InterpolatorAccess) + and not can_use_hw_interpolation(i), + lambda e: e.implementation_with_stencils() + ) - # RemoveConjugate = ReplaceOptim(lambda e: isinstance(e, sp.conjugate), - # lambda e: e.args[0] - # ) - # optimize_ast(ast_node, [RemoveConjugate]) + RemoveConjugate = ReplaceOptim(lambda e: isinstance(e, sp.conjugate), + lambda e: e.args[0] + ) + optimize_ast(ast_node, [RemoveConjugate, ImplementInterpolationByStencils]) return ast_node