From 91ac7c844c5027e697282476fd39b6a49779438b Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Thu, 27 Feb 2020 11:40:07 +0100
Subject: [PATCH] Add fix for Tensorflow python bindings if we cannot determine
 output_fields

---
 src/pystencils_autodiff/backends/python_bindings.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/pystencils_autodiff/backends/python_bindings.py b/src/pystencils_autodiff/backends/python_bindings.py
index fd6f02b..d86c4f1 100644
--- a/src/pystencils_autodiff/backends/python_bindings.py
+++ b/src/pystencils_autodiff/backends/python_bindings.py
@@ -90,7 +90,7 @@ REGISTER_KERNEL_BUILDER(Name("{{ python_name }}").Device({{ device }}), {{ pytho
         """  # noqa
     )
 
-    required_global_declarations = ["using namespace tensorflow;"]
+    required_global_declarations = [CustomCodeNode("using namespace tensorflow;", (), ())]
     headers = ['"tensorflow/core/framework/op.h"',
                '"tensorflow/core/framework/op_kernel.h"']
 
@@ -100,7 +100,11 @@ REGISTER_KERNEL_BUILDER(Name("{{ python_name }}").Device({{ device }}), {{ pytho
         input_field_names = [f.name for f in input_fields]
         output_field_names = [f.name for f in output_fields]
         parameters = function_node.get_parameters()
-        output_shape = str(output_fields[0].shape).replace('(', '{').replace(')', '}')  # noqa,  TODO make work for flexible sizes
+        if not output_fields:
+            output_shape = str(next(iter(function_node.fields_accessed)).shape).replace('(', '{').replace(')', '}')  # noqa,  TODO make work for flexible sizes
+            input_fields = function_node.fields_accessed
+        else:
+            output_shape = str(output_fields[0].shape).replace('(', '{').replace(')', '}')  # noqa,  TODO make work for flexible sizes
 
         docstring = "TODO"  # TODO
 
-- 
GitLab