Skip to content
Snippets Groups Projects
Commit 3fbfa910 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Allow field construction from Torch tensors

parent 74f49da9
No related branches found
No related tags found
No related merge requests found
......@@ -7,9 +7,6 @@ from pystencils_autodiff._adjoint_field import AdjointField
from pystencils_autodiff._autodiff import (
AutoDiffAstPair, AutoDiffBoundaryHandling, AutoDiffOp, DiffModes, create_backward_assignments,
get_jacobian_of_assignments)
from pystencils_autodiff._field_to_tensors import ( # NOQA
tf_constant_from_field, tf_placeholder_from_field, tf_scalar_variable_from_field,
tf_variable_from_field, torch_tensor_from_field)
__all__ = ['backends',
'AdjointField',
......@@ -18,11 +15,9 @@ __all__ = ['backends',
'AutoDiffOp',
'AutoDiffAstPair',
'tensorflow_jit',
'tf_constant_from_field', 'tf_placeholder_from_field',
'tf_scalar_variable_from_field', 'tf_variable_from_field',
'torch_tensor_from_field',
'DiffModes',
'AutoDiffBoundaryHandling']
sys.modules['pystencils.autodiff'] = pystencils_autodiff
sys.modules['pystencils.autodiff.backends'] = pystencils_autodiff.backends
......@@ -10,6 +10,7 @@
import itertools
import pystencils
from pystencils.astnodes import KernelFunction, ResolvedFieldAccess, SympyAssignment
......@@ -24,6 +25,18 @@ def compatibility_hacks():
return set().union(itertools.chain.from_iterable([f.field for f in a.rhs.free_symbols if hasattr(f, 'field')]
for a in assignments))
_pystencils_fields = pystencils.fields
def fields(*args, **kwargs):
    """Drop-in replacement for ``pystencils.fields`` that also accepts
    ``torch.Tensor`` values as keyword arguments.

    Tensor kwargs are wrapped in a numpy-like shim before being handed to
    the original ``pystencils.fields``; all other arguments are forwarded
    unchanged.
    """
    try:
        import torch
        from pystencils_autodiff.field_tensor_conversion import _torch_tensor_to_numpy_shim

        # Convert only the tensor-valued kwargs; keep every other kwarg as-is.
        # (The previous comprehension filtered on isinstance and thereby
        # silently DROPPED all non-tensor kwargs, e.g. plain numpy arrays.)
        kwargs = {k: _torch_tensor_to_numpy_shim(v) if isinstance(v, torch.Tensor) else v
                  for k, v in kwargs.items()}
    except ImportError:
        # torch is an optional dependency — without it, behave exactly
        # like the original pystencils.fields.
        torch = None
    return _pystencils_fields(*args, **kwargs)
pystencils.fields = fields
KernelFunction.fields_read = property(fields_read)
KernelFunction.fields_written = property(fields_written)
......
import numpy as np
import sympy
from pystencils import Field
class _WhatEverClass:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
def _torch_tensor_to_numpy_shim(tensor):
    """Wrap a ``torch.Tensor`` in a lightweight object exposing the
    numpy-style ``strides``, ``shape`` and ``dtype`` attributes that
    pystencils' field construction reads.

    NOTE(review): ``tensor.stride(i)`` returns strides in *elements*,
    whereas numpy's ``strides`` are in *bytes* — presumably the consumer
    only uses them for layout ordering; verify against the caller.
    """
    from pystencils.autodiff.backends._pytorch import torch_dtype_to_numpy

    ndim = len(tensor.shape)
    return _WhatEverClass(
        strides=[tensor.stride(d) for d in range(ndim)],
        shape=tensor.shape,
        dtype=torch_dtype_to_numpy(tensor.dtype))
def _create_field_from_array_like(field_name, maybe_array):
    """Create a pystencils ``Field`` named *field_name* from an array-like
    object, transparently supporting ``torch.Tensor`` inputs."""
    try:
        import torch
    except ImportError:
        torch = None

    # Torch tensors don't have t.strides but t.stride(dim). Let's fix that!
    if torch is not None and isinstance(maybe_array, torch.Tensor):
        maybe_array = _torch_tensor_to_numpy_shim(maybe_array)

    return Field.create_from_numpy_array(field_name, maybe_array)
def coerce_to_field(field_name, array_like):
    """Return a ``Field`` named *field_name* for *array_like*.

    Existing fields are renamed; anything else (numpy array, torch
    tensor, ...) is converted into a fresh field.

    NOTE(review): passing *array_like* as the second argument of
    ``new_field_with_different_name`` looks odd — confirm against the
    pystencils signature.
    """
    if not isinstance(array_like, Field):
        return _create_field_from_array_like(field_name, array_like)
    return array_like.new_field_with_different_name(field_name, array_like)
def is_array_like(a):
    """Return True if *a* can be treated as an array: it exposes
    ``__array__`` or is a pycuda ``GPUArray``, and is not a
    ``sympy.Matrix`` (which also implements ``__array__``).

    pycuda is optional: when it is not installed, GPU arrays simply
    cannot be recognized instead of the whole predicate raising
    ImportError.
    """
    try:
        import pycuda.gpuarray
        gpu_array_types = (pycuda.gpuarray.GPUArray,)
    except ImportError:
        gpu_array_types = ()
    return ((hasattr(a, '__array__') or isinstance(a, gpu_array_types))
            and not isinstance(a, sympy.Matrix))
def tf_constant_from_field(field, init_val=0):
import tensorflow as tf
......
......@@ -16,6 +16,7 @@ from os.path import exists, join
import p_tqdm
import pystencils
import pystencils.gpucuda
from pystencils.cpu.cpujit import get_cache_config, get_compiler_config, get_pystencils_include_path
from pystencils_autodiff._file_io import read_file, write_file
......
......@@ -9,9 +9,9 @@ from os.path import dirname, isfile, join
import numpy as np
import pytest
import sympy
import pystencils
import sympy
from pystencils_autodiff import create_backward_assignments
from pystencils_autodiff._file_io import write_cached_content
from pystencils_autodiff.backends.astnodes import PybindModule, TorchModule
......@@ -147,6 +147,7 @@ def test_torch_native_compilation_gpu():
@pytest.mark.parametrize('target', ('gpu', 'cpu'))
def test_execute_torch(target):
import pycuda.autoinit
module_name = "Ololol" + target
z, y, x = pystencils.fields("z, y, x: [20,40]")
......@@ -234,3 +235,15 @@ def test_reproducability():
output_0 = new_output
assert output_0 == new_output
def test_fields_from_torch_tensor():
    """pystencils.fields should accept torch tensors as keyword arguments."""
    # importorskip already returns the module — the former extra
    # `import torch` below it was redundant.
    torch = pytest.importorskip('torch')

    a, b = torch.zeros((20, 10)), torch.zeros((6, 7))
    x, y = pystencils.fields(x=a, y=b)
    print(x)
    print(y)

    # GPU tensors should work too — but only exercise this on machines
    # that actually have CUDA, instead of erroring out unconditionally.
    if torch.cuda.is_available():
        c = torch.zeros((20, 10)).cuda()
        z = pystencils.fields(z=c)
        print(z)
......@@ -98,69 +98,6 @@ def test_tfmad_gradient_check():
assert any(e < 1e-4 for e in gradient_error.values())
def check_tfmad_vector_input_data(args):
    """Build a 2D advection-diffusion operator from the parameters in *args*
    and print its forward and (transposed-forward) backward assignments.

    *args* is an argparse namespace providing at least ``dtype``,
    ``domain_shape``, ``dx``, ``dt`` and ``diffusion_coefficient``
    (see test_tfmad_vector_input_data).
    """
    dtype = args.dtype
    domain_shape = args.domain_shape
    ndim = len(domain_shape)

    # create arrays
    c_arr = np.zeros(domain_shape)
    # v is a vector field: one trailing component index per spatial dimension
    v_arr = np.zeros(domain_shape + (ndim, ))

    # create fields
    # NOTE(review): "% s" / "% i" use the space flag, so the rendered field
    # spec contains extra blanks (e.g. "double[ 100, 30]") — presumably
    # pystencils' parser tolerates them; TODO confirm.
    c, v, c_next = ps.fields("c, v(2), c_next: % s[ % i, % i]" %
                             ("float" if dtype == np.float32 else "double",
                              domain_shape[0], domain_shape[1]),
                             c=c_arr,
                             v=v_arr,
                             c_next=c_arr)

    # write down advection diffusion pde
    # the equation is represented by a single term and an implicit "=0" is assumed.
    adv_diff_pde = ps.fd.transient(c) - ps.fd.diffusion(
        c, sp.Symbol("D")) + ps.fd.advection(c, v)

    # discretize with 2nd-order finite differences, then substitute the
    # concrete diffusion coefficient for the symbolic "D"
    discretize = ps.fd.Discretization2ndOrder(args.dx, args.dt)
    discretization = discretize(adv_diff_pde)
    discretization = discretization.subs(sp.Symbol("D"),
                                         args.diffusion_coefficient)

    forward_assignments = ps.AssignmentCollection(
        [ps.Assignment(c_next.center(), discretization)], [])

    # derive backward assignments from the forward ones via transposed-forward mode
    autodiff = pystencils_autodiff.AutoDiffOp(
        forward_assignments,
        diff_mode='transposed-forward') # , constant_fields=[v]

    print('Forward assignments:')
    print(autodiff.forward_assignments)
    print('Backward assignments:')
    print(autodiff.backward_assignments)
def test_tfmad_vector_input_data():
    """Run check_tfmad_vector_input_data with its default parameters."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--domain_shape',
                        default=(100, 30),
                        nargs=2,
                        type=int,
                        help="")
    parser.add_argument('--dx', default=1, type=float, help="")
    parser.add_argument('--dt', default=0.01, type=float, help="")
    parser.add_argument('--diffusion_coefficient',
                        default=1,
                        type=float,
                        help="")
    parser.add_argument('--num_total_time_steps', default=100, type=int)
    parser.add_argument('--num_time_steps_for_op', default=1, type=int)
    parser.add_argument('--learning_rate', default=1e-2, type=float)
    parser.add_argument('--dtype', default=np.float64, type=np.dtype)
    parser.add_argument('--num_optimization_steps', default=2000, type=int)
    parser.add_argument('vargs', nargs='*')

    # Parse an explicit empty argument list instead of sys.argv: under
    # pytest, sys.argv carries pytest's own command-line options, which
    # argparse would reject and abort the test with SystemExit.
    args = parser.parse_args([])

    check_tfmad_vector_input_data(args)
def test_tfmad_gradient_check_torch():
torch = pytest.importorskip('torch')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment