From 91ac7c844c5027e697282476fd39b6a49779438b Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Thu, 27 Feb 2020 11:40:07 +0100 Subject: [PATCH] Add fix for Tensorflow python bindings if we cannot determine output_fields --- src/pystencils_autodiff/backends/python_bindings.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/pystencils_autodiff/backends/python_bindings.py b/src/pystencils_autodiff/backends/python_bindings.py index fd6f02b..d86c4f1 100644 --- a/src/pystencils_autodiff/backends/python_bindings.py +++ b/src/pystencils_autodiff/backends/python_bindings.py @@ -90,7 +90,7 @@ REGISTER_KERNEL_BUILDER(Name("{{ python_name }}").Device({{ device }}), {{ pytho """ # noqa ) - required_global_declarations = ["using namespace tensorflow;"] + required_global_declarations = [CustomCodeNode("using namespace tensorflow;", (), ())] headers = ['"tensorflow/core/framework/op.h"', '"tensorflow/core/framework/op_kernel.h"'] @@ -100,7 +100,11 @@ REGISTER_KERNEL_BUILDER(Name("{{ python_name }}").Device({{ device }}), {{ pytho input_field_names = [f.name for f in input_fields] output_field_names = [f.name for f in output_fields] parameters = function_node.get_parameters() - output_shape = str(output_fields[0].shape).replace('(', '{').replace(')', '}') # noqa, TODO make work for flexible sizes + if not output_fields: + output_shape = str(next(iter(function_node.fields_accessed)).shape).replace('(', '{').replace(')', '}') # noqa, TODO make work for flexible sizes + input_fields = function_node.fields_accessed + else: + output_shape = str(output_fields[0].shape).replace('(', '{').replace(')', '}') # noqa, TODO make work for flexible sizes docstring = "TODO" # TODO -- GitLab