Skip to content
Snippets Groups Projects
Commit 20012153 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add _field_to_tensors

parent 9d08fa1d
No related branches found
No related tags found
No related merge requests found
import pystencils_autodiff.backends
# from pystencils_autodiff._field_to_tensors import (
# tf_constant_from_field, tf_placeholder_from_field,
# tf_scalar_variable_from_field, tf_variable_from_field,
# torch_tensor_from_field)
from pystencils_autodiff._field_to_tensors import (
tf_constant_from_field, tf_placeholder_from_field, tf_scalar_variable_from_field,
tf_variable_from_field, torch_tensor_from_field)
from pystencils_autodiff.adjoint_field import AdjointField
from pystencils_autodiff.autodiff import (AutoDiffAstPair, AutoDiffOp,
create_backward_assignments,
get_jacobian_of_assignments)
from pystencils_autodiff.autodiff import (
AutoDiffAstPair, AutoDiffOp, create_backward_assignments, get_jacobian_of_assignments)
__all__ = ['backends',
'AdjointField',
......@@ -14,7 +12,6 @@ __all__ = ['backends',
'create_backward_assignments',
'AutoDiffOp',
'AutoDiffAstPair',
# "tf_constant_from_field", " tf_placeholder_from_field",
# "tf_scalar_variable_from_field", " tf_variable_from_field",
# "torch_tensor_from_field"
]
"tf_constant_from_field", " tf_placeholder_from_field",
"tf_scalar_variable_from_field", " tf_variable_from_field",
"torch_tensor_from_field"]
import numpy as np
try:
import tensorflow as tf
except ImportError:
pass
try:
import torch
except ImportError:
pass
def tf_constant_from_field(field, init_val=0):
return tf.constant(init_val, dtype=field.dtype.numpy_dtype, shape=field.shape, name=field.name + '_constant')
def tf_scalar_variable_from_field(field, init_val, constraint=None):
var = tf.Variable(init_val, dtype=field.dtype.numpy_dtype, name=field.name + '_variable', constraint=constraint)
return var * tf_constant_from_field(field, 1)
def tf_variable_from_field(field, init_val=0, constraint=None):
if isinstance(init_val, (int, float)):
init_val *= np.ones(field.shape, field.dtype.numpy_dtype)
return tf.Variable(init_val, dtype=field.dtype.numpy_dtype, name=field.name + '_variable', constraint=constraint)
def tf_placeholder_from_field(field):
return tf.placeholder(dtype=field.dtype.numpy_dtype, name=field.name + '_placeholder', shape=field.shape)
def torch_tensor_from_field(field, init_val=0, cuda=True, requires_grad=False):
if isinstance(init_val, (int, float)):
init_val *= np.ones(field.shape, field.dtype.numpy_dtype)
device = torch.device('cuda' if cuda else 'cpu')
return torch.tensor(init_val, requires_grad=requires_grad, device=device)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment