diff --git a/tests/test_framework_printer.py b/tests/test_framework_printer.py index ffab27ea6530063a819eb90468c22fee46a8ee69..b7dac07dde6c4e85ff75c74d139b4960353d63eb 100644 --- a/tests/test_framework_printer.py +++ b/tests/test_framework_printer.py @@ -6,12 +6,12 @@ """ import pytest +import sympy as sp import pystencils -import sympy as sp from pystencils.astnodes import Block from pystencils_autodiff.framework_integration.astnodes import ( - DestructuringBindingsForFieldClass, KernelFunctionCall, WrapperFunction, generate_kernel_call) + DestructuringBindingsForFieldClass, FunctionCall, WrapperFunction, generate_kernel_call) from pystencils_autodiff.framework_integration.printer import FrameworkIntegrationPrinter # TODO @@ -31,7 +31,7 @@ def test_pure_call(): for target in ('cpu', 'gpu'): ast = pystencils.create_kernel(forward_assignments, target=target) - kernel_call_ast = KernelFunctionCall(ast) + kernel_call_ast = FunctionCall(ast) code = FrameworkIntegrationPrinter()(kernel_call_ast) print(code) @@ -44,7 +44,7 @@ def test_call_with_destructuring(): for target in ('cpu', 'gpu'): ast = pystencils.create_kernel(forward_assignments, target=target) - kernel_call_ast = KernelFunctionCall(ast) + kernel_call_ast = FunctionCall(ast) wrapper = DestructuringBindingsForFieldClass(kernel_call_ast) code = FrameworkIntegrationPrinter()(wrapper) print(code) @@ -59,7 +59,7 @@ def test_call_with_destructuring_fixed_size(): for target in ('cpu', 'gpu'): ast = pystencils.create_kernel(forward_assignments, target=target) - kernel_call_ast = KernelFunctionCall(ast) + kernel_call_ast = FunctionCall(ast) wrapper = DestructuringBindingsForFieldClass(kernel_call_ast) code = FrameworkIntegrationPrinter()(wrapper) print(code) @@ -73,14 +73,14 @@ def test_wrapper_function(): for target in ('cpu', 'gpu'): ast = pystencils.create_kernel(forward_assignments, target=target) - kernel_call_ast = KernelFunctionCall(ast) + kernel_call_ast = FunctionCall(ast) wrapper = WrapperFunction(DestructuringBindingsForFieldClass(kernel_call_ast)) code = FrameworkIntegrationPrinter()(wrapper) print(code) for target in ('cpu', 'gpu'): ast = pystencils.create_kernel(forward_assignments, target=target) - kernel_call_ast = KernelFunctionCall(ast) + kernel_call_ast = FunctionCall(ast) wrapper = WrapperFunction(Block([kernel_call_ast])) code = FrameworkIntegrationPrinter()(wrapper) print(code)