From 703816fe09f5fb59d2a923bb7b55f24eed942be7 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Thu, 13 Feb 2020 08:58:45 +0100
Subject: [PATCH] Allow adding arrays from fields, warn only if run fails

---
 src/pystencils_autodiff/graph_datahandling.py | 69 ++++++++++++-------
 1 file changed, 44 insertions(+), 25 deletions(-)

diff --git a/src/pystencils_autodiff/graph_datahandling.py b/src/pystencils_autodiff/graph_datahandling.py
index b0c6283..bc04f5e 100644
--- a/src/pystencils_autodiff/graph_datahandling.py
+++ b/src/pystencils_autodiff/graph_datahandling.py
@@ -15,6 +15,7 @@ import numpy as np
 import pystencils.datahandling
 import pystencils.kernel_wrapper
 import pystencils.timeloop
+from pystencils.data_types import create_type
 from pystencils.field import FieldType
 
 
@@ -149,7 +150,11 @@ class GraphDataHandling(pystencils.datahandling.SerialDataHandling):
         def run(self, time_steps=1):
             former_call_queue = copy(self.parent.call_queue)
             self.parent.call_queue = []
-            super().run(time_steps)
+            try:
+                super().run(time_steps)
+            except Exception as e:
+                import warnings
+                warnings.warn(e)
             self.parent.call_queue = former_call_queue
             former_call_queue.append(TimeloopRun(self, time_steps))
 
@@ -181,16 +186,8 @@ class GraphDataHandling(pystencils.datahandling.SerialDataHandling):
         if layout is None:
             layout = self.default_layout
 
-        rtn = super().add_array(name,
-                                values_per_cell,
-                                dtype,
-                                latex_name,
-                                ghost_layers,
-                                layout,
-                                cpu,
-                                gpu,
-                                alignment,
-                                field_type)
+        if gpu is None:
+            gpu = self.default_target in self._GPU_LIKE_TARGETS
 
         # Weird code happening in super class
         if not hasattr(values_per_cell, '__len__'):
@@ -198,21 +195,43 @@ class GraphDataHandling(pystencils.datahandling.SerialDataHandling):
         if len(values_per_cell) == 1 and values_per_cell[0] == 1:
             values_per_cell = ()
 
-        if shape:
-            rtn = self._fields[name] = pystencils.Field.create_fixed_size(name,
-                                                                          shape,
-                                                                          index_dimensions=len(values_per_cell),
-                                                                          layout=layout,
-                                                                          dtype=dtype,
-                                                                          field_type=field_type)
+        if isinstance(name, pystencils.Field):
+            rtn = name
+            name = name.name
+            super().add_array(rtn.name,
+                              rtn.values_per_cell(),
+                              rtn.dtype.numpy_dtype,
+                              rtn.latex_name,
+                              1,
+                              cpu=cpu,
+                              gpu=gpu,
+                              field_type=rtn.field_type)
         else:
-            rtn = self._fields[name] = pystencils.Field.create_generic(name,
-                                                                       self.dim,
-                                                                       dtype,
-                                                                       index_dimensions=len(values_per_cell),
-                                                                       layout=layout,
-                                                                       index_shape=values_per_cell,
-                                                                       field_type=field_type)
+            rtn = super().add_array(name,
+                                    values_per_cell,
+                                    dtype,
+                                    latex_name,
+                                    ghost_layers,
+                                    layout,
+                                    cpu,
+                                    gpu,
+                                    alignment,
+                                    field_type)
+            if shape:
+                rtn = self._fields[name] = pystencils.Field.create_fixed_size(name,
+                                                                              shape,
+                                                                              index_dimensions=len(values_per_cell),
+                                                                              layout=layout,
+                                                                              dtype=dtype,
+                                                                              field_type=field_type)
+            else:
+                rtn = self._fields[name] = pystencils.Field.create_generic(name,
+                                                                           self.dim,
+                                                                           dtype,
+                                                                           index_dimensions=len(values_per_cell),
+                                                                           layout=layout,
+                                                                           index_shape=values_per_cell,
+                                                                           field_type=field_type)
 
         rtn.latex_name = latex_name
 
-- 
GitLab