Skip to content
Snippets Groups Projects
Commit 91ac7c84 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add fix for Tensorflow python bindings if we cannot determine output_fields

parent d05067ad
No related branches found
No related tags found
No related merge requests found
......@@ -90,7 +90,7 @@ REGISTER_KERNEL_BUILDER(Name("{{ python_name }}").Device({{ device }}), {{ pytho
""" # noqa
)
required_global_declarations = ["using namespace tensorflow;"]
required_global_declarations = [CustomCodeNode("using namespace tensorflow;", (), ())]
headers = ['"tensorflow/core/framework/op.h"',
'"tensorflow/core/framework/op_kernel.h"']
......@@ -100,7 +100,11 @@ REGISTER_KERNEL_BUILDER(Name("{{ python_name }}").Device({{ device }}), {{ pytho
input_field_names = [f.name for f in input_fields]
output_field_names = [f.name for f in output_fields]
parameters = function_node.get_parameters()
output_shape = str(output_fields[0].shape).replace('(', '{').replace(')', '}') # noqa, TODO make work for flexible sizes
if not output_fields:
output_shape = str(next(iter(function_node.fields_accessed)).shape).replace('(', '{').replace(')', '}') # noqa, TODO make work for flexible sizes
input_fields = function_node.fields_accessed
else:
output_shape = str(output_fields[0].shape).replace('(', '{').replace(')', '}') # noqa, TODO make work for flexible sizes
docstring = "TODO" # TODO
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment