diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py index 95627c5b0550bc506fba5305ca09805badac55de..e51eb1b3e6a58293557c2628a82c2021ab792031 100644 --- a/src/pystencils_autodiff/backends/astnodes.py +++ b/src/pystencils_autodiff/backends/astnodes.py @@ -141,8 +141,10 @@ class TensorflowModule(TorchModule): def __init__(self, module_name, kernel_asts): """Create a C++ module with forward and optional backward_kernels - :param forward_kernel_ast: one or more kernel ASTs (can have any C dialect) - :param backward_kernel_ast: + Args: + module_name (str): Module name + kernel_asts (pystencils.kernel_wrappers.KernelWrapper): + ASTs as generated by `:func:pystencils.create_kernel` """ self.compiled_file = None