diff --git a/src/pystencils_autodiff/backends/_torch_native.py b/src/pystencils_autodiff/backends/_torch_native.py index 60599e4e0c094df78058592856c8cd19e78d7db9..2fd3c9eae02900dda019b80c985bdf3ac87e77fb 100644 --- a/src/pystencils_autodiff/backends/_torch_native.py +++ b/src/pystencils_autodiff/backends/_torch_native.py @@ -74,8 +74,8 @@ def generate_torch(destination_folder, block_and_thread_numbers = backward_ast.indexing.call_parameters(backward_shape) backward_block = ', '.join(printer.doprint(i) for i in block_and_thread_numbers['block']) backward_grid = ', '.join(printer.doprint(i) for i in block_and_thread_numbers['grid']) - cuda_globals = pystencils.backends.cuda_backend.get_global_declarations(forward_ast) | \ - pystencils.backends.cuda_backend.get_global_declarations(backward_ast) + cuda_globals = pystencils.backends.cbackend.get_global_declarations(forward_ast) | \ + pystencils.backends.cbackend.get_global_declarations(backward_ast) cuda_globals = [generate_cuda(g) for g in cuda_globals] else: backward_block = forward_block = "INVALID"