diff --git a/src/pystencils_autodiff/graph_datahandling.py b/src/pystencils_autodiff/graph_datahandling.py index b0c6283f6dabce8d64ee19e3bdf1cb160624035c..bc04f5e4a52b283f5ffcad7cb72b4e43a7db8993 100644 --- a/src/pystencils_autodiff/graph_datahandling.py +++ b/src/pystencils_autodiff/graph_datahandling.py @@ -15,6 +15,7 @@ import numpy as np import pystencils.datahandling import pystencils.kernel_wrapper import pystencils.timeloop +from pystencils.data_types import create_type from pystencils.field import FieldType @@ -149,7 +150,11 @@ class GraphDataHandling(pystencils.datahandling.SerialDataHandling): def run(self, time_steps=1): former_call_queue = copy(self.parent.call_queue) self.parent.call_queue = [] - super().run(time_steps) + try: + super().run(time_steps) + except Exception as e: + import warnings + warnings.warn(e) self.parent.call_queue = former_call_queue former_call_queue.append(TimeloopRun(self, time_steps)) @@ -181,16 +186,8 @@ class GraphDataHandling(pystencils.datahandling.SerialDataHandling): if layout is None: layout = self.default_layout - rtn = super().add_array(name, - values_per_cell, - dtype, - latex_name, - ghost_layers, - layout, - cpu, - gpu, - alignment, - field_type) + if gpu is None: + gpu = self.default_target in self._GPU_LIKE_TARGETS # Weird code happening in super class if not hasattr(values_per_cell, '__len__'): @@ -198,21 +195,43 @@ class GraphDataHandling(pystencils.datahandling.SerialDataHandling): if len(values_per_cell) == 1 and values_per_cell[0] == 1: values_per_cell = () - if shape: - rtn = self._fields[name] = pystencils.Field.create_fixed_size(name, - shape, - index_dimensions=len(values_per_cell), - layout=layout, - dtype=dtype, - field_type=field_type) + if isinstance(name, pystencils.Field): + rtn = name + name = name.name + super().add_array(rtn.name, + rtn.values_per_cell(), + rtn.dtype.numpy_dtype, + rtn.latex_name, + 1, + cpu=cpu, + gpu=gpu, + field_type=rtn.field_type) else: - rtn = self._fields[name] = pystencils.Field.create_generic(name, - self.dim, - dtype, - index_dimensions=len(values_per_cell), - layout=layout, - index_shape=values_per_cell, - field_type=field_type) + rtn = super().add_array(name, + values_per_cell, + dtype, + latex_name, + ghost_layers, + layout, + cpu, + gpu, + alignment, + field_type) + if shape: + rtn = self._fields[name] = pystencils.Field.create_fixed_size(name, + shape, + index_dimensions=len(values_per_cell), + layout=layout, + dtype=dtype, + field_type=field_type) + else: + rtn = self._fields[name] = pystencils.Field.create_generic(name, + self.dim, + dtype, + index_dimensions=len(values_per_cell), + layout=layout, + index_shape=values_per_cell, + field_type=field_type) rtn.latex_name = latex_name