diff --git a/src/pystencils_autodiff/backends/_torch_native.py b/src/pystencils_autodiff/backends/_torch_native.py
index 4822425f1165aec4cd45cb35a83c986287c5555a..4bef57312a4ca5c8e3e4d0f52330e3a3b60cf530 100644
--- a/src/pystencils_autodiff/backends/_torch_native.py
+++ b/src/pystencils_autodiff/backends/_torch_native.py
@@ -1,6 +1,5 @@
 from collections import OrderedDict
 
-import sympy as sp
 from pystencils_autodiff.backends._pytorch import numpy_dtype_to_torch
 from pystencils_autodiff.backends.astnodes import TorchModule
 from pystencils_autodiff.tensorflow_jit import _hash
@@ -33,31 +32,47 @@ def create_autograd_function(autodiff_obj, use_cuda):
     # backward_wrapper_ast = [w for w in wrapper if w.function_name == "call_" + backward_ast.function_name][0]
     # forward_parameters = [str(p.symbol) for p in forward_wrapper_ast.get_parameters()]
     # backward_parameters = [str(p.symbol) for p in backward_wrapper_ast.get_parameters()]
+    class_kwargs = dict()
 
-    def forward(self, *args):
+    def forward(self, *args, **kwargs):
+        kwargs.update(class_kwargs)
+        # TODO: drop contiguous requirement
         if use_cuda:
             args = [a.contiguous().cuda() for a in args]
+            kwargs = {k: v.contiguous().cuda() for k, v in kwargs.items()}
         else:
             args = [a.contiguous().cpu() for a in args]
+            kwargs = {k: v.contiguous().cpu() for k, v in kwargs.items()}
+
+        # assert all(f.shape == args[i].shape for i, f in enumerate(autodiff_obj.forward_input_fields)
+        #            if not any(isinstance(s, sp.Symbol) for s in args[i].shape))
+        # assert all(f.strides == tuple(args[i].stride(j) for j in range(args[i].ndim))
+        #            for i, f in enumerate(autodiff_obj.forward_input_fields))
+        # for field in autodiff_obj.forward_output_fields:
+        #     field_to_tensor_dict[field] = torch.zeros(
+        #         field.shape,
+        #         dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
+        #         device=args[0].device)
+
+        kwargs.update({f.name: args[i] for i, f in enumerate(
+            autodiff_obj.forward_input_fields) if f in forward_ast.fields_accessed if i < len(args)})
 
-        input_tensors = dict()
-        input_tensors.update({f.name: args[i] for i, f in enumerate(
-            autodiff_obj.forward_input_fields) if f in forward_ast.fields_accessed})
-        assert all(f.shape == args[i].shape for i, f in enumerate(autodiff_obj.forward_input_fields)
-                   if not any(isinstance(s, sp.Symbol) for s in args[i].shape))
-        assert all(f.strides == tuple(args[i].stride(j) for j in range(args[i].ndim))
-                   for i, f in enumerate(autodiff_obj.forward_input_fields))
 
         for field in autodiff_obj.forward_output_fields:
-            field_to_tensor_dict[field] = torch.zeros(
-                field.shape,
-                dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
-                device=args[0].device)
-        output_tensors = OrderedDict({f.name: field_to_tensor_dict[f] for f in autodiff_obj.forward_output_fields})
+            if field.name not in kwargs:
+                kwargs[field.name] = torch.zeros(
+                    field.shape,
+                    dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
+                    device=args[0].device)
+        output_tensors = OrderedDict({f.name:
+                                      field_to_tensor_dict.get(f, kwargs[f.name])
+                                      for f in autodiff_obj.forward_output_fields})
+        field_to_tensor_dict.update(kwargs)
+        kwargs.update(output_tensors)
 
-        self.save_for_backward(*args)
+        self.saved_for_backward = kwargs
 
-        getattr(compiled_op, "call_" + forward_ast.function_name)(**input_tensors, **output_tensors)
+        getattr(compiled_op, "call_" + forward_ast.function_name)(**kwargs)
 
         return tuple(output_tensors.values())
 
@@ -70,23 +85,24 @@ def create_autograd_function(autodiff_obj, use_cuda):
         assert all(f.shape == grad_outputs[i].shape for i, f in
                    enumerate(autodiff_obj.backward_input_fields))
         assert all(f.strides == tuple(grad_outputs[i].stride(j) for j in range(grad_outputs[i].ndim))
                    for i, f in enumerate(autodiff_obj.backward_input_fields))
-        assert all(a.is_cuda == use_cuda for a in grad_outputs), f"Some of the tensors where on the wrong device. " \
-            f"Op was compiled for CUDA: {str(use_cuda)}"
-        saved = {f.name: self.saved_tensors[i] for i, f in enumerate(
-            autodiff_obj.forward_input_fields) if f in backward_ast.fields_accessed}
-        for field in autodiff_obj.backward_output_fields:
-            field_to_tensor_dict[field] = torch.zeros(
-                field.shape,
-                dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
-                device=grad_outputs[0].device)
+        assert all(a.is_cuda == use_cuda for a in grad_outputs), "Some of the tensors were on the wrong device. " \
+            f"Op was compiled for CUDA: {str(use_cuda)}"
 
-        backward_output_tensors = OrderedDict(
-            {f.name: field_to_tensor_dict[f] for f in autodiff_obj.backward_output_fields})
-        getattr(compiled_op, "call_" + backward_ast.function_name)(**gradients, **saved, **backward_output_tensors)
+        backward_output_tensors = OrderedDict(
+            {f.name: torch.zeros(f.shape,
+                                 dtype=numpy_dtype_to_torch(
+                                     f.dtype.numpy_dtype),
+                                 device=grad_outputs[0].device)
+             for f in autodiff_obj.backward_output_fields})
+        field_names = [f.name for f in backward_ast.fields_accessed]
+        kwargs = {**gradients, **self.saved_for_backward, **backward_output_tensors}
+        kwargs = {k: v for k, v in kwargs.items() if k in field_names}
+        getattr(compiled_op, "call_" + backward_ast.function_name)(**kwargs)
         return tuple(backward_output_tensors.values())
 
     cls = type(op_name, (torch.autograd.Function,), {})
+    cls.class_kwargs = class_kwargs
     cls.forward = forward
     cls.backward = backward
     cls.kernel = forward
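
Usage sketch (illustrative only, not part of the patch): the snippet below exercises the class returned by create_autograd_function as it behaves after this change. `autodiff_obj` is a placeholder for an autodiff op assembled elsewhere with pystencils_autodiff; the field names "x" and "y" and the (32, 32) shape are invented, and the backward call assumes the usual one-gradient-tensor-per-backward-input-field convention suggested by its asserts.

import torch

from pystencils_autodiff.backends._torch_native import create_autograd_function

# `autodiff_obj` stands for an op built elsewhere with pystencils_autodiff;
# one input field "x" and one output field "y" of shape (32, 32) are assumed.
op_cls = create_autograd_function(autodiff_obj, use_cuda=False)

# Tensors bound once on the class are merged into every forward() call;
# output fields that are neither bound nor passed are allocated automatically.
op_cls.class_kwargs['y'] = torch.zeros(32, 32)

op = op_cls()
(y,) = op.forward(torch.rand(32, 32))     # inputs may be passed positionally ...
(y,) = op.forward(x=torch.rand(32, 32))   # ... or by field name as keyword arguments
grads = op.backward(torch.ones_like(y))   # one gradient tensor per backward input field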