Skip to content
Snippets Groups Projects
Commit 30763da4 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Merge commit '70d68574'

parents 1c2c660d 70d68574
No related merge requests found
......@@ -4,8 +4,8 @@ import pystencils_autodiff.backends # NOQA
from pystencils_autodiff._field_to_tensors import ( # NOQA
tf_constant_from_field, tf_placeholder_from_field, tf_scalar_variable_from_field,
tf_variable_from_field, torch_tensor_from_field)
from pystencils_autodiff.adjoint_field import AdjointField
from pystencils_autodiff.autodiff import (
from pystencils_autodiff._adjoint_field import AdjointField
from pystencils_autodiff._autodiff import (
AutoDiffAstPair, AutoDiffOp, create_backward_assignments, get_jacobian_of_assignments)
__all__ = ['backends',
......
import numpy as np
try:
import tensorflow as tf
except ImportError:
pass
def tf_constant_from_field(field, init_val=0):
return tf.constant(init_val, dtype=field.dtype.numpy_dtype, shape=field.shape, name=field.name + '_constant')
def tf_scalar_variable_from_field(field, init_val, constraint=None):
var = tf.Variable(init_val, dtype=field.dtype.numpy_dtype, name=field.name + '_variable', constraint=constraint)
return var * tf_constant_from_field(field, 1)
def tf_variable_from_field(field, init_val=0, constraint=None):
if isinstance(init_val, (int, float)):
init_val *= np.ones(field.shape, field.dtype.numpy_dtype)
return tf.Variable(init_val, dtype=field.dtype.numpy_dtype, name=field.name + '_variable', constraint=constraint)
def tf_placeholder_from_field(field):
return tf.placeholder(dtype=field.dtype.numpy_dtype, name=field.name + '_placeholder', shape=field.shape)
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import numpy as np
try:
import torch
except ImportError:
pass
def torch_tensor_from_field(field, init_val=0, cuda=True, requires_grad=False):
if isinstance(init_val, (int, float)):
init_val *= np.ones(field.shape, field.dtype.numpy_dtype)
device = torch.device('cuda' if cuda else 'cpu')
return torch.tensor(init_val, requires_grad=requires_grad, device=device)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment