diff --git a/src/pystencils_autodiff/backends/python_bindings.py b/src/pystencils_autodiff/backends/python_bindings.py index fd6f02bd6b0d11451a007a8b75db1cfa9a0605c7..d86c4f1c1dfa8738a22518ff8462497ee8f0f77e 100644 --- a/src/pystencils_autodiff/backends/python_bindings.py +++ b/src/pystencils_autodiff/backends/python_bindings.py @@ -90,7 +90,7 @@ REGISTER_KERNEL_BUILDER(Name("{{ python_name }}").Device({{ device }}), {{ pytho """ # noqa ) - required_global_declarations = ["using namespace tensorflow;"] + required_global_declarations = [CustomCodeNode("using namespace tensorflow;", (), ())] headers = ['"tensorflow/core/framework/op.h"', '"tensorflow/core/framework/op_kernel.h"'] @@ -100,7 +100,11 @@ REGISTER_KERNEL_BUILDER(Name("{{ python_name }}").Device({{ device }}), {{ pytho input_field_names = [f.name for f in input_fields] output_field_names = [f.name for f in output_fields] parameters = function_node.get_parameters() - output_shape = str(output_fields[0].shape).replace('(', '{').replace(')', '}') # noqa, TODO make work for flexible sizes + if not output_fields: + output_shape = str(next(iter(function_node.fields_accessed)).shape).replace('(', '{').replace(')', '}') # noqa, TODO make work for flexible sizes + input_fields = function_node.fields_accessed + else: + output_shape = str(output_fields[0].shape).replace('(', '{').replace(')', '}') # noqa, TODO make work for flexible sizes docstring = "TODO" # TODO