diff --git a/src/pystencils_autodiff/backends/_torch_native.py b/src/pystencils_autodiff/backends/_torch_native.py
index 61022bdc49717b86fd024de299212b9c3da246c4..b574c301ee4e41bac8b035f8b378f8b19e7a25b3 100644
--- a/src/pystencils_autodiff/backends/_torch_native.py
+++ b/src/pystencils_autodiff/backends/_torch_native.py
@@ -1,8 +1,8 @@
-import uuid
 from collections import OrderedDict
 
 from pystencils_autodiff.backends._pytorch import numpy_dtype_to_torch
 from pystencils_autodiff.backends.astnodes import TorchModule
+from pystencils_autodiff.tensorflow_jit import _hash
 
 try:
     import torch
@@ -30,28 +30,55 @@ def create_autograd_function(autodiff_obj, inputfield_to_tensor_dict):
     forward_ast = autodiff_obj.forward_ast_cpu
     backward_ast = autodiff_obj.backward_ast_cpu
 
-    op_name = autodiff_obj.op_name + str(uuid.uuid4())
-    compiled_op = TorchModule(op_name, [forward_ast, backward_ast])
-
-    output_tensors = OrderedDict({f.name: field_to_tensor_dict[f] for f in autodiff_obj.forward_output_fields})
-    backward_output_tensors = OrderedDict(
-        {f.name: field_to_tensor_dict[f] for f in autodiff_obj.backward_output_fields})
-
-    def forward(self, **input_tensors):
-
-        self.save_for_backward(**input_tensors)
+    # Name the op by a content hash (rather than a random uuid) so identical
+    # ops get identical module names.
+    op_name = f'{autodiff_obj.op_name}_{_hash(str(autodiff_obj).encode()).hexdigest()}'
+    module = TorchModule(op_name, [forward_ast, backward_ast])
+    compiled_op = module.compile()
+
+    # print(TorchModule(op_name, [forward_ast, backward_ast]))
+    # wrapper = module.atoms(WrapperFunction)
+    # forward_wrapper_ast = [w for w in wrapper if w.function_name == "call_" + forward_ast.function_name][0]
+    # backward_wrapper_ast = [w for w in wrapper if w.function_name == "call_" + backward_ast.function_name][0]
+    # forward_parameters = [str(p.symbol) for p in forward_wrapper_ast.get_parameters()]
+    # backward_parameters = [str(p.symbol) for p in backward_wrapper_ast.get_parameters()]
+
+    def forward(self, *args):
+        input_tensors = {f.name: args[i] for i, f in enumerate(
+            autodiff_obj.forward_input_fields) if f in forward_ast.fields_accessed}
+        # The kernels were generated for fixed shapes and strides, so verify them.
+        assert all(f.shape == args[i].shape for i, f in enumerate(autodiff_obj.forward_input_fields))
+        assert all(f.strides == tuple(args[i].stride(j) for j in range(args[i].ndim))
+                   for i, f in enumerate(autodiff_obj.forward_input_fields))
+        # Allocate output tensors on the same device as the input tensors.
+        for field in autodiff_obj.forward_output_fields:
+            field_to_tensor_dict[field] = torch.zeros(
+                field.shape,
+                dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
+                device=list(inputfield_to_tensor_dict.values())[0].device)
+        output_tensors = OrderedDict({f.name: field_to_tensor_dict[f] for f in autodiff_obj.forward_output_fields})
+
+        self.save_for_backward(*args)
 
         getattr(compiled_op, "call_" + forward_ast.function_name)(**input_tensors, **output_tensors)
 
-        return output_tensors.values()
+        return tuple(output_tensors.values())
 
     def backward(self, *grad_outputs):
         gradients = {f.name: grad_outputs[i] for i, f in enumerate(autodiff_obj.backward_input_fields)}
-        saved = self.saved_tensors
-
+        assert all(f.shape == grad_outputs[i].shape for i, f in enumerate(autodiff_obj.backward_input_fields))
+        assert all(f.strides == tuple(grad_outputs[i].stride(j) for j in range(grad_outputs[i].ndim))
+                   for i, f in enumerate(autodiff_obj.backward_input_fields))
+        # Only pass on the saved tensors that the backward kernel actually accesses.
+        saved = {f.name: self.saved_tensors[i] for i, f in enumerate(
+            autodiff_obj.forward_input_fields) if f in backward_ast.fields_accessed}
+        for field in autodiff_obj.backward_output_fields:
+            field_to_tensor_dict[field] = torch.zeros(
+                field.shape,
+                dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
+                device=list(inputfield_to_tensor_dict.values())[0].device)
+
+        backward_output_tensors = OrderedDict(
+            {f.name: field_to_tensor_dict[f] for f in autodiff_obj.backward_output_fields})
 
         getattr(compiled_op, "call_" + backward_ast.function_name)(**gradients, **saved, **backward_output_tensors)
 
-        return backward_output_tensors.values()
+        return tuple(backward_output_tensors.values())
 
     cls = type(op_name, (torch.autograd.Function,), {})
     cls.forward = forward
diff --git a/tests/backends/test_torch_native_compilation.py b/tests/backends/test_torch_native_compilation.py
index 2d01db730754f57f15013431a342232ae87478d0..5814ca50eb93dd1af056909636360f73fcbfcb02 100644
--- a/tests/backends/test_torch_native_compilation.py
+++ b/tests/backends/test_torch_native_compilation.py
@@ -13,7 +13,7 @@ import sympy
 
 import pystencils
 from pystencils_autodiff import create_backward_assignments
-from pystencils_autodiff._file_io import write_cached_content, write_file
+from pystencils_autodiff._file_io import write_cached_content
 from pystencils_autodiff.backends.astnodes import PybindModule, TorchModule
 
 torch = pytest.importorskip('torch')
diff --git a/tests/test_tfmad.py b/tests/test_tfmad.py
index 7dd10e1d6a642b127744f07593fa02332e3cccb4..7065d9c0c3e65f14f5b02f074f99153566367304 100644
--- a/tests/test_tfmad.py
+++ b/tests/test_tfmad.py
@@ -165,10 +165,10 @@ def test_tfmad_gradient_check_torch():
 
     a, b, out = ps.fields("a, b, out: float[21,13]")
 
-    cont = ps.fd.Diff(a, 0) - ps.fd.Diff(a, 1) - \
-        ps.fd.Diff(b, 0) + ps.fd.Diff(b, 1)
+    cont = 2*ps.fd.Diff(a, 0) - 1.5*ps.fd.Diff(a, 1) - \
+        ps.fd.Diff(b, 0) + 3 * ps.fd.Diff(b, 1)
     discretize = ps.fd.Discretization2ndOrder(dx=1)
-    discretization = discretize(cont)
+    discretization = discretize(cont) + 1.2*a.center
 
     assignment = ps.Assignment(out.center(), discretization)
     assignment_collection = ps.AssignmentCollection([assignment], [])
@@ -189,12 +189,51 @@ def test_tfmad_gradient_check_torch():
     function = auto_diff.create_tensorflow_op({
         a: a_tensor,
         b: b_tensor
-    },
-        backend='torch')
+    }, backend='torch')
 
     torch.autograd.gradcheck(function.apply, [a_tensor, b_tensor])
 
 
+@pytest.mark.parametrize('with_offsets', (True, False))
+def test_tfmad_gradient_check_torch_native(with_offsets):
+    torch = pytest.importorskip('torch')
+
+    a, b, out = ps.fields("a, b, out: float64[21,13]")
+
+    if with_offsets:
+        cont = 2*ps.fd.Diff(a, 0) - 1.5*ps.fd.Diff(a, 1) - ps.fd.Diff(b, 0) + 3 * ps.fd.Diff(b, 1)
+        discretize = ps.fd.Discretization2ndOrder(dx=1)
+        discretization = discretize(cont)
+
+        assignment = ps.Assignment(out.center(), discretization + 1.2*a.center())
+    else:
+        assignment = ps.Assignment(out.center(), 1.2*a.center + 0.1*b.center)
+
+    assignment_collection = ps.AssignmentCollection([assignment], [])
+    print('Forward')
+    print(assignment_collection)
+
+    print('Backward')
+    auto_diff = pystencils_autodiff.AutoDiffOp(assignment_collection,
+                                               diff_mode='transposed-forward')
+    backward = auto_diff.backward_assignments
+    print(backward)
+    print('Forward input fields (to check order)')
+    print(auto_diff.forward_input_fields)
+
+    a_tensor = torch.zeros(*a.shape, dtype=torch.float64, requires_grad=True).contiguous()
+    b_tensor = torch.zeros(*b.shape, dtype=torch.float64, requires_grad=True).contiguous()
+
+    tensor_dict = {
+        a: a_tensor,
+        b: b_tensor
+    }
+    function = auto_diff.create_tensorflow_op(tensor_dict, backend='torch_native')
+
+    torch.autograd.gradcheck(function.apply,
+                             tuple(tensor_dict[f] for f in auto_diff.forward_input_fields),
+                             atol=1e-4, raise_exception=True)
+
+
 def get_curl(input_field: ps.Field, curl_field: ps.Field):
     """Return a ps.AssignmentCollection describing the calculation of
     the curl given a 2d or 3d vector field [z,y,x](f) or [y,x](f)
@@ -256,3 +295,7 @@ def test_tfmad_two_outputs():
 
     print('Backward')
     print(curl_op.backward_assignments)
+
+
+if __name__ == '__main__':
+    test_tfmad_gradient_check_torch_native(True)
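
Usage sketch for reviewers (not part of the patch): the snippet below shows how
the torch.autograd.Function produced by the new torch_native backend is meant
to be driven, mirroring the test added above. The field sizes and the trivial
assignment are illustrative assumptions; the API calls are the ones exercised
in this diff.

    import pystencils as ps
    import pystencils_autodiff
    import torch

    a, b, out = ps.fields("a, b, out: float64[21,13]")
    assignment = ps.Assignment(out.center(), 1.2 * a.center + 0.1 * b.center)
    auto_diff = pystencils_autodiff.AutoDiffOp(
        ps.AssignmentCollection([assignment], []), diff_mode='transposed-forward')

    a_tensor = torch.rand(*a.shape, dtype=torch.float64, requires_grad=True)
    b_tensor = torch.rand(*b.shape, dtype=torch.float64, requires_grad=True)

    # backend='torch_native' yields the generated torch.autograd.Function
    # subclass; inputs are positional, in auto_diff.forward_input_fields order.
    function = auto_diff.create_tensorflow_op({a: a_tensor, b: b_tensor},
                                              backend='torch_native')
    (result,) = function.apply(a_tensor, b_tensor)
    result.sum().backward()  # populates a_tensor.grad and b_tensor.grad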