Skip to content
Snippets Groups Projects
Commit 30763da4 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Merge commit '70d68574'

parents 1c2c660d 70d68574
No related branches found
No related tags found
No related merge requests found
Pipeline #17172 failed
......@@ -4,8 +4,8 @@ import pystencils_autodiff.backends # NOQA
from pystencils_autodiff._field_to_tensors import ( # NOQA
tf_constant_from_field, tf_placeholder_from_field, tf_scalar_variable_from_field,
tf_variable_from_field, torch_tensor_from_field)
from pystencils_autodiff.adjoint_field import AdjointField
from pystencils_autodiff.autodiff import (
from pystencils_autodiff._adjoint_field import AdjointField
from pystencils_autodiff._autodiff import (
AutoDiffAstPair, AutoDiffOp, create_backward_assignments, get_jacobian_of_assignments)
__all__ = ['backends',
......
import numpy as np
try:
import tensorflow as tf
except ImportError:
pass
def tf_constant_from_field(field, init_val=0):
return tf.constant(init_val, dtype=field.dtype.numpy_dtype, shape=field.shape, name=field.name + '_constant')
def tf_scalar_variable_from_field(field, init_val, constraint=None):
var = tf.Variable(init_val, dtype=field.dtype.numpy_dtype, name=field.name + '_variable', constraint=constraint)
return var * tf_constant_from_field(field, 1)
def tf_variable_from_field(field, init_val=0, constraint=None):
if isinstance(init_val, (int, float)):
init_val *= np.ones(field.shape, field.dtype.numpy_dtype)
return tf.Variable(init_val, dtype=field.dtype.numpy_dtype, name=field.name + '_variable', constraint=constraint)
def tf_placeholder_from_field(field):
return tf.placeholder(dtype=field.dtype.numpy_dtype, name=field.name + '_placeholder', shape=field.shape)
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import numpy as np
try:
import torch
except ImportError:
pass
def torch_tensor_from_field(field, init_val=0, cuda=True, requires_grad=False):
if isinstance(init_val, (int, float)):
init_val *= np.ones(field.shape, field.dtype.numpy_dtype)
device = torch.device('cuda' if cuda else 'cpu')
return torch.tensor(init_val, requires_grad=requires_grad, device=device)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment