Skip to content
Snippets Groups Projects
Commit 91ac7c84 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add fix for Tensorflow python bindings if we cannot determine output_fields

parent d05067ad
Branches master
No related tags found
No related merge requests found
...@@ -90,7 +90,7 @@ REGISTER_KERNEL_BUILDER(Name("{{ python_name }}").Device({{ device }}), {{ pytho ...@@ -90,7 +90,7 @@ REGISTER_KERNEL_BUILDER(Name("{{ python_name }}").Device({{ device }}), {{ pytho
""" # noqa """ # noqa
) )
required_global_declarations = ["using namespace tensorflow;"] required_global_declarations = [CustomCodeNode("using namespace tensorflow;", (), ())]
headers = ['"tensorflow/core/framework/op.h"', headers = ['"tensorflow/core/framework/op.h"',
'"tensorflow/core/framework/op_kernel.h"'] '"tensorflow/core/framework/op_kernel.h"']
...@@ -100,7 +100,11 @@ REGISTER_KERNEL_BUILDER(Name("{{ python_name }}").Device({{ device }}), {{ pytho ...@@ -100,7 +100,11 @@ REGISTER_KERNEL_BUILDER(Name("{{ python_name }}").Device({{ device }}), {{ pytho
input_field_names = [f.name for f in input_fields] input_field_names = [f.name for f in input_fields]
output_field_names = [f.name for f in output_fields] output_field_names = [f.name for f in output_fields]
parameters = function_node.get_parameters() parameters = function_node.get_parameters()
output_shape = str(output_fields[0].shape).replace('(', '{').replace(')', '}') # noqa, TODO make work for flexible sizes if not output_fields:
output_shape = str(next(iter(function_node.fields_accessed)).shape).replace('(', '{').replace(')', '}') # noqa, TODO make work for flexible sizes
input_fields = function_node.fields_accessed
else:
output_shape = str(output_fields[0].shape).replace('(', '{').replace(')', '}') # noqa, TODO make work for flexible sizes
docstring = "TODO" # TODO docstring = "TODO" # TODO
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment