diff --git a/src/pystencils_autodiff/_autodiff.py b/src/pystencils_autodiff/_autodiff.py
index 4edd071b606b6c733046306b85e048b992e37696..b9755a4aec8969b3da0282961824ed860489ddb3
--- a/src/pystencils_autodiff/_autodiff.py
+++ b/src/pystencils_autodiff/_autodiff.py
@@ -542,8 +542,7 @@ Backward:
                 self, inputfield_tensor_dict, forward_loop, backward_loop)
         elif backend == 'torch_native':
             import pystencils_autodiff.backends._torch_native
-            op = pystencils_autodiff.backends._torch_native.create_autograd_function(
-                self, inputfield_tensor_dict, None, None)
+            op = pystencils_autodiff.backends._torch_native.create_autograd_function(self, inputfield_tensor_dict)
         else:
             raise NotImplementedError()
 
diff --git a/src/pystencils_autodiff/backends/_torch_native.py b/src/pystencils_autodiff/backends/_torch_native.py
index 53619b5e2504b5926bcd2cf6defd026a5af14c73..61022bdc49717b86fd024de299212b9c3da246c4
--- a/src/pystencils_autodiff/backends/_torch_native.py
+++ b/src/pystencils_autodiff/backends/_torch_native.py
@@ -1,20 +1,8 @@
-import os
 import uuid
-from itertools import chain
-from os.path import dirname, isdir, isfile, join
-
-import jinja2
-from appdirs import user_cache_dir
-
-import pystencils
-import pystencils_autodiff
-import pystencils_autodiff.backends._pytorch
-from pystencils.astnodes import FieldShapeSymbol
-from pystencils.backends.cbackend import generate_c
-from pystencils.backends.cuda_backend import CudaSympyPrinter, generate_cuda
-from pystencils.cpu.kernelcreation import create_kernel
-from pystencils.gpucuda.kernelcreation import create_cuda_kernel
+from collections import OrderedDict
+
 from pystencils_autodiff.backends._pytorch import numpy_dtype_to_torch
+from pystencils_autodiff.backends.astnodes import TorchModule
 
 try:
     import torch
@@ -22,156 +10,50 @@ except ImportError:
     pass
 
 
-def read_file(file):
-    with open(file, 'r') as f:
-        return f.read()
-
-
-def write_file(filename, content):
-    with open(filename, 'w') as f:
-        return f.write(content)
-
-
-def generate_torch(destination_folder,
-                   autodiff: pystencils_autodiff.AutoDiffOp,
-                   is_cuda,
-                   dtype,
-                   forward_ast=None,
-                   backward_ast=None):
-    shape = autodiff.forward_output_fields[0].spatial_shape
-    operation_hash = abs(hash(autodiff) + hash(shape) + hash(str(dtype)))
-    operation_string = "{}_native_{}_{}_{:x}".format(
-        autodiff.op_name, 'cuda' if is_cuda else 'cpu', 'x'.join(str(s) for s in shape), operation_hash)
-
-    cpp_file = join(destination_folder, operation_string + '.cpp')
-    cuda_kernel_file = join(destination_folder, operation_string + '.cu')
-
-    required_files = [cpp_file, cuda_kernel_file] if is_cuda else [cpp_file]
-
-    if not all(isfile(x) for x in required_files):
-        generate_ast = create_cuda_kernel if is_cuda else create_kernel
-        generate_code = generate_cuda if is_cuda else generate_c
-
-        if not forward_ast:
-            forward_ast = generate_ast(autodiff.forward_assignments.all_assignments)
-        if not backward_ast:
-            backward_ast = generate_ast(autodiff.backward_assignments.all_assignments)
-
-        forward_ast.subs({s: FieldShapeSymbol(
-            [autodiff.forward_output_fields[0].name], s.coordinate) for s in forward_ast.atoms(FieldShapeSymbol)})
-        backward_ast.subs({s: FieldShapeSymbol(
-            [autodiff.backward_output_fields[0].name], s.coordinate) for s in backward_ast.atoms(FieldShapeSymbol)})
-        # backward_ast.subs({s: FieldStrideSymbol(
-        #     autodiff.forward_input_fields[0].name, s.coordinate) for s in forward_ast.atoms(FieldStrideSymbol)})
-
-        forward_code = generate_code(forward_ast.body).replace(
-            'float *', 'scalar_t *').replace('double *', 'scalar_t *')
-        backward_code = generate_code(backward_ast.body).replace(
-            'float *', 'scalar_t *').replace('double *', 'scalar_t *')
-
-        if is_cuda:
-            printer = CudaSympyPrinter()
-            block_and_thread_numbers = forward_ast.indexing.call_parameters(shape)
-            forward_block = ', '.join(printer.doprint(i) for i in block_and_thread_numbers['block'])
-            forward_grid = ', '.join(printer.doprint(i) for i in block_and_thread_numbers['grid'])
-            backward_shape = autodiff.backward_output_fields[0].spatial_shape
-            block_and_thread_numbers = backward_ast.indexing.call_parameters(backward_shape)
-            backward_block = ', '.join(printer.doprint(i) for i in block_and_thread_numbers['block'])
-            backward_grid = ', '.join(printer.doprint(i) for i in block_and_thread_numbers['grid'])
-            cuda_globals = pystencils.backends.cbackend.get_global_declarations(forward_ast) | \
-                pystencils.backends.cbackend.get_global_declarations(backward_ast)
-            cuda_globals = [generate_cuda(g) for g in cuda_globals]
-        else:
-            backward_block = forward_block = "INVALID"
-            backward_grid = forward_grid = "INVALID"
-            cuda_globals = ""
-
-        render_dict = {
-            "forward_tensors": [f for f in autodiff.forward_fields],
-            "forward_input_tensors": [f for f in autodiff.forward_input_fields],
-            "forward_output_tensors": [f for f in autodiff.forward_output_fields],
-            "backward_tensors": [f for f in autodiff.backward_fields + autodiff.forward_input_fields],
-            "backward_input_tensors": [f for f in autodiff.backward_input_fields],
-            "backward_output_tensors": [f for f in autodiff.backward_output_fields],
-            "forward_kernel": forward_code,
-            "backward_kernel": backward_code,
-            "dimensions": range(autodiff.forward_fields[0].spatial_dimensions),
-            "kernel_name": operation_string,
-            "forward_threads": "{" + forward_block + "}",
-            "forward_blocks": "{" + forward_grid + "}",
-            "backward_threads": "{" + backward_block + "}",
-            "backward_blocks": "{" + backward_grid + "}",
-            "cuda_globals": cuda_globals,
-            "dtype": pystencils.data_types.BasicType(dtype)
-        }
-
-        if is_cuda:
-            template_string_cpp = read_file(join(dirname(__file__), 'torch_native_cuda.tmpl.cpp'))
-            template = jinja2.Template(template_string_cpp)
-            output = template.render(render_dict)
-            write_file(join(destination_folder, operation_string + '.cpp'), output)
-
-            template_string = read_file(join(dirname(__file__), 'torch_native_cuda.tmpl.cu'))
-            template = jinja2.Template(template_string)
-            output = template.render(render_dict)
-            write_file(join(destination_folder, operation_string + '.cu'), output)
-        else:
-            template_string_cpp = read_file(join(dirname(__file__), 'torch_native_cpu.tmpl.cpp'))
-            template = jinja2.Template(template_string_cpp)
-            output = template.render(render_dict)
-            write_file(join(destination_folder, operation_string + '.cpp'), output)
-
-    from torch.utils.cpp_extension import load
-    compiled_operation = load(operation_string, required_files, verbose=True,
-                              extra_cuda_cflags=[] if is_cuda else [])
-    compiled_operation.code = output
-    return compiled_operation
-
-
-def create_autograd_function(autodiff_obj, inputfield_to_tensor_dict, forward_loop=None, backward_loop=None):
-    if forward_loop is None:
-        assert backward_loop is None
-        is_cuda = all(t.is_cuda for t in inputfield_to_tensor_dict.values())
-        assert all(t.is_cuda for t in inputfield_to_tensor_dict.values()) or \
-            all(not t.is_cuda for t in inputfield_to_tensor_dict.values()), "All tensor should be on GPU or all on CPU"
-        dtype = pystencils_autodiff.backends._pytorch.torch_dtype_to_numpy(
-            list(inputfield_to_tensor_dict.values())[0].dtype)
-
-        cache_dir = user_cache_dir('pystencils')
-        if not isdir(cache_dir):
-            os.mkdir(cache_dir)
-        # TODO: create function and stuff
-
-        compiled_operation = generate_torch(cache_dir, autodiff_obj, is_cuda, dtype)
-        field_to_tensor_dict = inputfield_to_tensor_dict
-        # Allocate output tensor for forward and backward pass
-        for field in chain(autodiff_obj.forward_output_fields, autodiff_obj.backward_output_fields):
-            field_to_tensor_dict[field] = torch.zeros(
-                *field.shape,
-                dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
-                device=list(inputfield_to_tensor_dict.values())[0].device)
-
-        def forward(self):
-            self.saved = {f: field_to_tensor_dict[f] for f in chain(
-                autodiff_obj.forward_input_fields, autodiff_obj.backward_output_fields)}
-            compiled_operation.forward(**{f.name: field_to_tensor_dict[f] for f in autodiff_obj.forward_fields})
-            return tuple(field_to_tensor_dict[f] for f in autodiff_obj.forward_output_fields)
-
-        def backward(self, *grad_outputs):
-            self.saved.update({f.name: grad_outputs[i] for i, f in enumerate(autodiff_obj.backward_input_fields)})
-            compiled_operation.backward(**{f.name: t for f, t in self.saved.items()})
-            return tuple(self.saved[f] for f in autodiff_obj.backward_output_fields)
-
-        cls = type(str(uuid.uuid4()), (torch.autograd.Function,), {})
-        cls.saved = None
-        cls.forward = forward
-        cls.backward = backward
-        cls.code = compiled_operation.code
-        return cls()
+def create_autograd_function(autodiff_obj, inputfield_to_tensor_dict):
+    field_to_tensor_dict = inputfield_to_tensor_dict
+
+    # Allocate output tensor for forward and backward pass
+    for field in autodiff_obj.forward_output_fields + autodiff_obj.backward_output_fields:
+        field_to_tensor_dict[field] = torch.zeros(
+            *field.shape,
+            dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
+            device=list(inputfield_to_tensor_dict.values())[0].device)
+
+    all_tensors = field_to_tensor_dict.values()
+    is_cuda = all(a.is_cuda for a in all_tensors)
+
+    if is_cuda:
+        forward_ast = autodiff_obj.forward_ast_gpu
+        backward_ast = autodiff_obj.backward_ast_gpu
     else:
-        op = pystencils_autodiff.backends._pytorch.create_autograd_function(autodiff_obj,
-                                                                            inputfield_to_tensor_dict,
-                                                                            forward_loop,
-                                                                            backward_loop,
-                                                                            convert_tensors_to_arrays=False)
-        return op
+        forward_ast = autodiff_obj.forward_ast_cpu
+        backward_ast = autodiff_obj.backward_ast_cpu
+
+    op_name = autodiff_obj.op_name + str(uuid.uuid4())
+    compiled_op = TorchModule(op_name, [forward_ast, backward_ast])
+
+    output_tensors = OrderedDict({f.name: field_to_tensor_dict[f] for f in autodiff_obj.forward_output_fields})
+    backward_output_tensors = OrderedDict(
+        {f.name: field_to_tensor_dict[f] for f in autodiff_obj.backward_output_fields})
+
+    def forward(self, **input_tensors):
+
+        self.save_for_backward(**input_tensors)
+
+        getattr(compiled_op, "call_" + forward_ast.function_name)(**input_tensors, **output_tensors)
+
+        return output_tensors.values()
+
+    def backward(self, *grad_outputs):
+        gradients = {f.name: grad_outputs[i] for i, f in enumerate(autodiff_obj.backward_input_fields)}
+        saved = self.saved_tensors
+
+        getattr(compiled_op, "call_" + backward_ast.function_name)(**gradients, **saved, **backward_output_tensors)
+
+        return backward_output_tensors.values()
+
+    cls = type(op_name, (torch.autograd.Function,), {})
+    cls.forward = forward
+    cls.backward = backward
+    return cls
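
Usage sketch (editor's note, not part of the patch): a minimal, hypothetical illustration of the simplified two-argument interface introduced above. The field, assignment, and AutoDiffOp construction below are assumptions made for illustration; only the create_autograd_function call itself reflects the new signature in this diff.

import torch
import pystencils
from pystencils_autodiff import AutoDiffOp
from pystencils_autodiff.backends._torch_native import create_autograd_function

# Assumed setup: one input field x and one output field y with y = 2 * x.
x, y = pystencils.fields('x, y: float32[32, 32]')
autodiff_op = AutoDiffOp([pystencils.Assignment(y.center, 2 * x.center)])  # assumed constructor usage

# Map every forward input field to a torch tensor; all tensors must live on one device (CPU here).
input_tensors = {x: torch.ones(32, 32, dtype=torch.float32)}

# Old call: create_autograd_function(autodiff_op, input_tensors, None, None)
# New call after this patch: the forward_loop/backward_loop arguments are gone.
autograd_cls = create_autograd_function(autodiff_op, input_tensors)
assert issubclass(autograd_cls, torch.autograd.Function)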