Skip to content
Snippets Groups Projects
Commit a5284d6e authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add torch support to @crazy decorator

parent 67f71033
No related branches found
No related tags found
No related merge requests found
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
import inspect import inspect
from pystencils.autodiff.backends._pytorch import torch_dtype_to_numpy
from pystencils.field import Field from pystencils.field import Field
try: try:
...@@ -23,13 +24,28 @@ except ImportError: ...@@ -23,13 +24,28 @@ except ImportError:
tf = None tf = None
def _create_field_from_array_like(field_name, maybe_array):
if torch:
# Torch tensors don't have t.strides but t.stride(dim). Let's fix that!
if isinstance(maybe_array, torch.Tensor):
fake_array = object()
fake_array.strides = [maybe_array.stride(i) for i in range(len(maybe_array.shape))]
fake_array.shape = maybe_array.shape
fake_array.dtype = torch_dtype_to_numpy(maybe_array.dtype)
field = Field.create_from_numpy_array(field_name, fake_array)
return field
return Field.create_from_numpy_array(field_name, maybe_array)
def crazy(function): def crazy(function):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
arg_names = inspect.getfullargspec(function).args arg_names = inspect.getfullargspec(function).args
compile_args = [Field.create_from_numpy_array( compile_args = [_create_field_from_array_like(
arg_names[i], a) if hasattr(a, '__array__') else a for i, a in enumerate(args)] arg_names[i], a) if hasattr(a, '__array__') else a for i, a in enumerate(args)]
compile_kwargs = {k: Field.create_from_numpy_array(str(k), a) if hasattr( compile_kwargs = {k: _create_field_from_array_like(str(k), a) if hasattr(
a, '__array__') else a for (k, a) in kwargs.items()} a, '__array__') else a for (k, a) in kwargs.items()}
assignments = function(*compile_args, **compile_kwargs) assignments = function(*compile_args, **compile_kwargs)
......
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
""" """
""" """
import torch
import pystencils import pystencils
from pystencils.autodiff import torch_tensor_from_field from pystencils.autodiff import torch_tensor_from_field
from pystencils_reco.filters import mean_filter from pystencils_reco.filters import mean_filter
...@@ -43,7 +45,22 @@ def test_pytorch_gpu(): ...@@ -43,7 +45,22 @@ def test_pytorch_gpu():
print(torch_op) print(torch_op)
def test_pytorch_from_tensors():
block_stencil = BallStencil(1, ndim=2)
x, y = pystencils.fields('x,y: float32[100,100]')
x_tensor = torch_tensor_from_field(x, requires_grad=True, cuda=True)
y_tensor = torch_tensor_from_field(y, cuda=True)
filter = mean_filter(x_tensor, y_tensor, block_stencil)
print(filter)
print(filter.backward())
torch_op = filter.create_pytorch_op(x=x_tensor+1, y=y_tensor)
print(torch_op)
def main(): def main():
# test_pytorch()
test_pytorch() test_pytorch()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment