Commit 83912565 authored by Stephan Seitz
Gradient-check works again for PyTorch

parent 1972dcf8
 from collections import OrderedDict
-import sympy as sp
 from pystencils_autodiff.backends._pytorch import numpy_dtype_to_torch
 from pystencils_autodiff.backends.astnodes import TorchModule
 from pystencils_autodiff.tensorflow_jit import _hash
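For context: numpy_dtype_to_torch translates a field's NumPy dtype into the matching torch dtype before output buffers are allocated with torch.zeros further down. A minimal sketch of what such a helper could look like (an illustration with a hypothetical name, not the library's actual implementation) exploits the fact that torch.from_numpy preserves the dtype:

    import numpy as np
    import torch

    def numpy_dtype_to_torch_sketch(dtype):
        # Hypothetical stand-in for pystencils_autodiff's numpy_dtype_to_torch:
        # converting an empty array yields the corresponding torch dtype.
        return torch.from_numpy(np.empty(0, dtype=dtype)).dtype

    assert numpy_dtype_to_torch_sketch(np.float32) == torch.float32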
@@ -33,31 +32,47 @@ def create_autograd_function(autodiff_obj, use_cuda):
     # backward_wrapper_ast = [w for w in wrapper if w.function_name == "call_" + backward_ast.function_name][0]
     # forward_parameters = [str(p.symbol) for p in forward_wrapper_ast.get_parameters()]
     # backward_parameters = [str(p.symbol) for p in backward_wrapper_ast.get_parameters()]
 
-    def forward(self, *args):
+    class_kwargs = dict()
+
+    def forward(self, *args, **kwargs):
+        kwargs.update(class_kwargs)
+        # TODO: drop contiguous requirement
         if use_cuda:
             args = [a.contiguous().cuda() for a in args]
+            kwargs = {k: v.contiguous().cuda() for k, v in kwargs.items()}
         else:
             args = [a.contiguous().cpu() for a in args]
+            kwargs = {k: v.contiguous().cpu() for k, v in kwargs.items()}
 
-        input_tensors = dict()
-        input_tensors.update({f.name: args[i] for i, f in enumerate(
-            autodiff_obj.forward_input_fields) if f in forward_ast.fields_accessed})
-        assert all(f.shape == args[i].shape for i, f in enumerate(autodiff_obj.forward_input_fields)
-                   if not any(isinstance(s, sp.Symbol) for s in args[i].shape))
-        assert all(f.strides == tuple(args[i].stride(j) for j in range(args[i].ndim))
-                   for i, f in enumerate(autodiff_obj.forward_input_fields))
+        # assert all(f.shape == args[i].shape for i, f in enumerate(autodiff_obj.forward_input_fields)
+        #            if not any(isinstance(s, sp.Symbol) for s in args[i].shape))
+        # assert all(f.strides == tuple(args[i].stride(j) for j in range(args[i].ndim))
+        #            for i, f in enumerate(autodiff_obj.forward_input_fields))
+
+        # for field in autodiff_obj.forward_output_fields:
+        #     field_to_tensor_dict[field] = torch.zeros(
+        #         field.shape,
+        #         dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
+        #         device=args[0].device)
+
+        kwargs.update({f.name: args[i] for i, f in enumerate(
+            autodiff_obj.forward_input_fields) if f in forward_ast.fields_accessed if i < len(args)})
+
         for field in autodiff_obj.forward_output_fields:
-            field_to_tensor_dict[field] = torch.zeros(
-                field.shape,
-                dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
-                device=args[0].device)
-        output_tensors = OrderedDict({f.name: field_to_tensor_dict[f] for f in autodiff_obj.forward_output_fields})
+            if field.name not in kwargs:
+                kwargs[field.name] = torch.zeros(
+                    field.shape,
+                    dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
+                    device=args[0].device)
+        output_tensors = OrderedDict({f.name:
+                                      field_to_tensor_dict.get(f, kwargs[f.name])
+                                      for f in autodiff_obj.forward_output_fields})
+        field_to_tensor_dict.update(kwargs)
+        kwargs.update(output_tensors)
 
-        self.save_for_backward(*args)
+        self.saved_for_backward = kwargs
 
-        getattr(compiled_op, "call_" + forward_ast.function_name)(**input_tensors, **output_tensors)
+        getattr(compiled_op, "call_" + forward_ast.function_name)(**kwargs)
 
         return tuple(output_tensors.values())
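The reworked forward pass keys every tensor by its field name and stores the whole kwargs dict on the Function object (self.saved_for_backward) instead of going through save_for_backward, which is what the commit title refers to: torch.autograd.gradcheck can exercise the op again. A self-contained sketch of such a check, using a hypothetical stand-in op rather than one generated by create_autograd_function, and PyTorch's current staticmethod/ctx style rather than the legacy instance-method style used above:

    import torch

    class ScaleOp(torch.autograd.Function):
        # Hypothetical stand-in for a generated op: forward produces the
        # output tensor, backward returns one gradient per input.
        @staticmethod
        def forward(ctx, x):
            return 2.0 * x

        @staticmethod
        def backward(ctx, grad_output):
            return 2.0 * grad_output

    # gradcheck compares the analytic Jacobian against finite differences;
    # it needs double precision and requires_grad=True to be meaningful.
    x = torch.rand(4, 5, dtype=torch.float64, requires_grad=True)
    assert torch.autograd.gradcheck(ScaleOp.apply, (x,), atol=1e-6)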
@@ -70,23 +85,24 @@ def create_autograd_function(autodiff_obj, use_cuda):
         assert all(f.shape == grad_outputs[i].shape for i, f in enumerate(autodiff_obj.backward_input_fields))
         assert all(f.strides == tuple(grad_outputs[i].stride(j) for j in range(grad_outputs[i].ndim))
                    for i, f in enumerate(autodiff_obj.backward_input_fields))
-        assert all(a.is_cuda == use_cuda for a in grad_outputs), f"Some of the tensors were on the wrong device. " \
+        assert all(a.is_cuda == use_cuda for a in grad_outputs), "Some of the tensors were on the wrong device. " \
             f"Op was compiled for CUDA: {str(use_cuda)}"
 
-        saved = {f.name: self.saved_tensors[i] for i, f in enumerate(
-            autodiff_obj.forward_input_fields) if f in backward_ast.fields_accessed}
-        for field in autodiff_obj.backward_output_fields:
-            field_to_tensor_dict[field] = torch.zeros(
-                field.shape,
-                dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
-                device=grad_outputs[0].device)
-        backward_output_tensors = OrderedDict(
-            {f.name: field_to_tensor_dict[f] for f in autodiff_obj.backward_output_fields})
-        getattr(compiled_op, "call_" + backward_ast.function_name)(**gradients, **saved, **backward_output_tensors)
+        for field in autodiff_obj.backward_output_fields:
+            backward_output_tensors = OrderedDict({f.name: torch.zeros(field.shape,
+                                                                       dtype=numpy_dtype_to_torch(
+                                                                           field.dtype.numpy_dtype),
+                                                                       device=grad_outputs[0].device)
+                                                   for f in autodiff_obj.backward_output_fields})
+
+        field_names = [f.name for f in backward_ast.fields_accessed]
+        kwargs = {**gradients, **self.saved_for_backward, **backward_output_tensors}
+        kwargs = {k: v for k, v in kwargs.items() if k in field_names}
+        getattr(compiled_op, "call_" + backward_ast.function_name)(**kwargs)
 
         return tuple(backward_output_tensors.values())
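The new backward path merges the incoming gradients, the tensors saved during forward, and freshly allocated output buffers into one dict, then drops every entry the compiled kernel does not access, so the kernel is called with exactly the field names it expects. The pattern in isolation, with hypothetical names and placeholder values standing in for tensors (later dicts win on duplicate keys):

    from collections import OrderedDict

    def kernel_kwargs(accessed_field_names, *tensor_dicts):
        # Merge candidate tensors, then keep only the fields the
        # compiled kernel actually accesses.
        merged = {}
        for d in tensor_dicts:
            merged.update(d)
        return OrderedDict((k, v) for k, v in merged.items()
                           if k in accessed_field_names)

    kwargs = kernel_kwargs(["x", "dy", "dx"],
                           {"dy": "gradient"},
                           {"x": "saved input", "aux": "not accessed"},
                           {"dx": "output buffer"})
    assert list(kwargs) == ["dy", "x", "dx"]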
     cls = type(op_name, (torch.autograd.Function,), {})
+    cls.class_kwargs = class_kwargs
     cls.forward = forward
     cls.backward = backward
     cls.kernel = forward
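Building the op as type(op_name, (torch.autograd.Function,), {}) and attaching forward and backward afterwards gives each compiled kernel its own named Function subclass. A self-contained sketch of the same pattern in current PyTorch, where forward and backward must be staticmethods taking a ctx (the commit still uses the older instance-method style), with hypothetical helper names:

    import torch

    def make_op(op_name, forward, backward):
        # Mirror of the pattern above: create a Function subclass at
        # runtime and attach the callbacks as attributes.
        cls = type(op_name, (torch.autograd.Function,), {})
        cls.forward = staticmethod(forward)
        cls.backward = staticmethod(backward)
        return cls

    def _forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    def _backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return 2 * x * grad_output

    Square = make_op("Square", _forward, _backward)
    x = torch.rand(3, dtype=torch.float64, requires_grad=True)
    Square.apply(x).sum().backward()
    assert torch.allclose(x.grad, 2 * x)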