Skip to content
Snippets Groups Projects
Commit dc5023b5 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add test for ugly hack

parent f280ee36
Branches
Tags
No related merge requests found
Pipeline #18973 failed
...@@ -2,6 +2,7 @@ import numpy as np ...@@ -2,6 +2,7 @@ import numpy as np
import sympy import sympy
from pystencils import Field from pystencils import Field
from pystencils.field import FieldType
class _WhatEverClass: class _WhatEverClass:
...@@ -10,9 +11,13 @@ class _WhatEverClass: ...@@ -10,9 +11,13 @@ class _WhatEverClass:
class ArrayWithIndexDimensions: class ArrayWithIndexDimensions:
def __init__(self, array, index_dimensions): def __init__(self,
array,
index_dimensions,
field_type=FieldType.GENERIC):
self.array = array self.array = array
self.index_dimensions = index_dimensions self.index_dimensions = index_dimensions
self.field_type = field_type
def __array__(self): def __array__(self):
return self.array return self.array
...@@ -34,9 +39,11 @@ def _torch_tensor_to_numpy_shim(tensor): ...@@ -34,9 +39,11 @@ def _torch_tensor_to_numpy_shim(tensor):
def create_field_from_array_like(field_name, maybe_array): def create_field_from_array_like(field_name, maybe_array):
if isinstance(maybe_array, ArrayWithIndexDimensions): if isinstance(maybe_array, ArrayWithIndexDimensions):
index_dimensions = maybe_array.index_dimensions index_dimensions = maybe_array.index_dimensions
field_type = maybe_array.field_type
maybe_array = maybe_array.array maybe_array = maybe_array.array
else: else:
index_dimensions = 0 index_dimensions = 0
field_type = FieldType.GENERIC
try: try:
import torch import torch
...@@ -48,7 +55,9 @@ def create_field_from_array_like(field_name, maybe_array): ...@@ -48,7 +55,9 @@ def create_field_from_array_like(field_name, maybe_array):
if isinstance(maybe_array, torch.Tensor): if isinstance(maybe_array, torch.Tensor):
maybe_array = _torch_tensor_to_numpy_shim(maybe_array) maybe_array = _torch_tensor_to_numpy_shim(maybe_array)
return Field.create_from_numpy_array(field_name, maybe_array, index_dimensions) field = Field.create_from_numpy_array(field_name, maybe_array, index_dimensions)
field.field_type = field_type
return field
def coerce_to_field(field_name, array_like): def coerce_to_field(field_name, array_like):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment