Skip to content
Snippets Groups Projects
Commit 0b31077f authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Check gradient for native torch

parent 89562889
No related branches found
No related tags found
No related merge requests found
Pipeline #18136 failed
import uuid
from collections import OrderedDict
from pystencils_autodiff.backends._pytorch import numpy_dtype_to_torch
from pystencils_autodiff.backends.astnodes import TorchModule
from pystencils_autodiff.tensorflow_jit import _hash
try:
import torch
......@@ -30,28 +30,55 @@ def create_autograd_function(autodiff_obj, inputfield_to_tensor_dict):
forward_ast = autodiff_obj.forward_ast_cpu
backward_ast = autodiff_obj.backward_ast_cpu
op_name = autodiff_obj.op_name + str(uuid.uuid4())
compiled_op = TorchModule(op_name, [forward_ast, backward_ast])
output_tensors = OrderedDict({f.name: field_to_tensor_dict[f] for f in autodiff_obj.forward_output_fields})
backward_output_tensors = OrderedDict(
{f.name: field_to_tensor_dict[f] for f in autodiff_obj.backward_output_fields})
def forward(self, **input_tensors):
self.save_for_backward(**input_tensors)
op_name = f'{autodiff_obj.op_name}_{_hash(str(autodiff_obj).encode()).hexdigest()}'
module = TorchModule(op_name, [forward_ast, backward_ast])
compiled_op = module.compile()
# print(TorchModule(op_name, [forward_ast, backward_ast]))
# wrapper = module.atoms(WrapperFunction)
# forward_wrapper_ast = [w for w in wrapper if w.function_name == "call_" + forward_ast.function_name][0]
# backward_wrapper_ast = [w for w in wrapper if w.function_name == "call_" + backward_ast.function_name][0]
# forward_parameters = [str(p.symbol) for p in forward_wrapper_ast.get_parameters()]
# backward_parameters = [str(p.symbol) for p in backward_wrapper_ast.get_parameters()]
def forward(self, *args):
input_tensors = dict()
input_tensors.update({f.name: args[i] for i, f in enumerate(
autodiff_obj.forward_input_fields) if f in forward_ast.fields_accessed})
assert all(f.shape == args[i].shape for i, f in enumerate(autodiff_obj.forward_input_fields))
assert all(f.strides == tuple(args[i].stride(j) for j in range(args[i].ndim))
for i, f in enumerate(autodiff_obj.forward_input_fields))
for field in autodiff_obj.forward_output_fields:
field_to_tensor_dict[field] = torch.zeros(
field.shape,
dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
device=list(inputfield_to_tensor_dict.values())[0].device)
output_tensors = OrderedDict({f.name: field_to_tensor_dict[f] for f in autodiff_obj.forward_output_fields})
self.save_for_backward(*args)
getattr(compiled_op, "call_" + forward_ast.function_name)(**input_tensors, **output_tensors)
return output_tensors.values()
return tuple(output_tensors.values())
def backward(self, *grad_outputs):
gradients = {f.name: grad_outputs[i] for i, f in enumerate(autodiff_obj.backward_input_fields)}
saved = self.saved_tensors
assert all(f.shape == grad_outputs[i].shape for i, f in enumerate(autodiff_obj.backward_input_fields))
assert all(f.strides == tuple(grad_outputs[i].stride(j) for j in range(grad_outputs[i].ndim))
for i, f in enumerate(autodiff_obj.backward_input_fields))
saved = {f.name: self.saved_tensors[i] for i, f in enumerate(
autodiff_obj.forward_input_fields) if f in backward_ast.fields_accessed}
for field in autodiff_obj.backward_output_fields:
field_to_tensor_dict[field] = torch.zeros(
field.shape,
dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
device=list(inputfield_to_tensor_dict.values())[0].device)
backward_output_tensors = OrderedDict(
{f.name: field_to_tensor_dict[f] for f in autodiff_obj.backward_output_fields})
getattr(compiled_op, "call_" + backward_ast.function_name)(**gradients, **saved, **backward_output_tensors)
return backward_output_tensors.values()
return tuple(backward_output_tensors.values())
cls = type(op_name, (torch.autograd.Function,), {})
cls.forward = forward
......
......@@ -13,7 +13,7 @@ import sympy
import pystencils
from pystencils_autodiff import create_backward_assignments
from pystencils_autodiff._file_io import write_cached_content, write_file
from pystencils_autodiff._file_io import write_cached_content
from pystencils_autodiff.backends.astnodes import PybindModule, TorchModule
torch = pytest.importorskip('torch')
......
......@@ -165,10 +165,10 @@ def test_tfmad_gradient_check_torch():
a, b, out = ps.fields("a, b, out: float[21,13]")
cont = ps.fd.Diff(a, 0) - ps.fd.Diff(a, 1) - \
ps.fd.Diff(b, 0) + ps.fd.Diff(b, 1)
cont = 2*ps.fd.Diff(a, 0) - 1.5*ps.fd.Diff(a, 1) - \
ps.fd.Diff(b, 0) + 3 * ps.fd.Diff(b, 1)
discretize = ps.fd.Discretization2ndOrder(dx=1)
discretization = discretize(cont)
discretization = discretize(cont) + 1.2*a.center
assignment = ps.Assignment(out.center(), discretization)
assignment_collection = ps.AssignmentCollection([assignment], [])
......@@ -189,12 +189,51 @@ def test_tfmad_gradient_check_torch():
function = auto_diff.create_tensorflow_op({
a: a_tensor,
b: b_tensor
},
backend='torch')
}, backend='torch')
torch.autograd.gradcheck(function.apply, [a_tensor, b_tensor])
@pytest.mark.parametrize('with_offsets', (True, False))
def test_tfmad_gradient_check_torch_native(with_offsets):
torch = pytest.importorskip('torch')
a, b, out = ps.fields("a, b, out: float64[21,13]")
if with_offsets:
cont = 2*ps.fd.Diff(a, 0) - 1.5*ps.fd.Diff(a, 1) - ps.fd.Diff(b, 0) + 3 * ps.fd.Diff(b, 1)
discretize = ps.fd.Discretization2ndOrder(dx=1)
discretization = discretize(cont)
assignment = ps.Assignment(out.center(), discretization + 1.2*a.center())
else:
assignment = ps.Assignment(out.center(), 1.2*a.center + 0.1*b.center)
assignment_collection = ps.AssignmentCollection([assignment], [])
print('Forward')
print(assignment_collection)
print('Backward')
auto_diff = pystencils_autodiff.AutoDiffOp(assignment_collection,
diff_mode='transposed-forward')
backward = auto_diff.backward_assignments
print(backward)
print('Forward output fields (to check order)')
print(auto_diff.forward_input_fields)
a_tensor = torch.zeros(*a.shape, dtype=torch.float64, requires_grad=True).contiguous()
b_tensor = torch.zeros(*b.shape, dtype=torch.float64, requires_grad=True).contiguous()
dict = {
a: a_tensor,
b: b_tensor
}
function = auto_diff.create_tensorflow_op(dict, backend='torch_native')
import torch
torch.autograd.gradcheck(function.apply, tuple(
[dict[f] for f in auto_diff.forward_input_fields]), atol=1e-4, raise_exception=True)
def get_curl(input_field: ps.Field, curl_field: ps.Field):
"""Return a ps.AssignmentCollection describing the calculation of
the curl given a 2d or 3d vector field [z,y,x](f) or [y,x](f)
......@@ -256,3 +295,7 @@ def test_tfmad_two_outputs():
print('Backward')
print(curl_op.backward_assignments)
if __name__ == '__main__':
test_tfmad_gradient_check_torch_native()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment