Commit 08c373f4 authored by Stephan Seitz

Rewrite _torch_native for AST generation of torch ops

parent b8ca3472
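
In short: the `torch_native` backend no longer renders jinja2 templates and compiles them itself; the forward and backward kernels are taken from the op's own ASTs and wrapped in a `TorchModule`, so callers only pass the op and its input-field-to-tensor mapping. A minimal, hypothetical usage sketch (the field names, shape and assignment are invented, and it assumes an `AutoDiffOp` built directly from the forward assignments that exposes the `forward_ast_gpu`/`backward_ast_gpu` attributes used in the new code below):

```python
import torch
import pystencils
import pystencils_autodiff
from pystencils_autodiff.backends._torch_native import create_autograd_function

# Hypothetical fields and forward assignment, for illustration only
x, y = pystencils.fields('x, y: float32[32, 32]')
autodiff_op = pystencils_autodiff.AutoDiffOp([pystencils.Assignment(y.center, 2 * x.center)])

x_tensor = torch.ones(32, 32, dtype=torch.float32, device='cuda')
# New signature: only the op and the input-field -> tensor mapping are required
OpFunction = create_autograd_function(autodiff_op, {x: x_tensor})
```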
@@ -542,8 +542,7 @@ Backward:
                self, inputfield_tensor_dict, forward_loop, backward_loop)
        elif backend == 'torch_native':
            import pystencils_autodiff.backends._torch_native
            op = pystencils_autodiff.backends._torch_native.create_autograd_function(
                self, inputfield_tensor_dict, None, None)
            op = pystencils_autodiff.backends._torch_native.create_autograd_function(self, inputfield_tensor_dict)
        else:
            raise NotImplementedError()
import os
import uuid
from itertools import chain
from os.path import dirname, isdir, isfile, join
import jinja2
from appdirs import user_cache_dir
import pystencils
import pystencils_autodiff
import pystencils_autodiff.backends._pytorch
from pystencils.astnodes import FieldShapeSymbol
from pystencils.backends.cbackend import generate_c
from pystencils.backends.cuda_backend import CudaSympyPrinter, generate_cuda
from pystencils.cpu.kernelcreation import create_kernel
from pystencils.gpucuda.kernelcreation import create_cuda_kernel
from collections import OrderedDict
from pystencils_autodiff.backends._pytorch import numpy_dtype_to_torch
from pystencils_autodiff.backends.astnodes import TorchModule
try:
    import torch
@@ -22,156 +10,50 @@ except ImportError:
    pass
def read_file(file):
    with open(file, 'r') as f:
        return f.read()


def write_file(filename, content):
    with open(filename, 'w') as f:
        return f.write(content)
def generate_torch(destination_folder,
                   autodiff: pystencils_autodiff.AutoDiffOp,
                   is_cuda,
                   dtype,
                   forward_ast=None,
                   backward_ast=None):
    shape = autodiff.forward_output_fields[0].spatial_shape
    operation_hash = abs(hash(autodiff) + hash(shape) + hash(str(dtype)))
    operation_string = "{}_native_{}_{}_{:x}".format(
        autodiff.op_name, 'cuda' if is_cuda else 'cpu', 'x'.join(str(s) for s in shape), operation_hash)

    cpp_file = join(destination_folder, operation_string + '.cpp')
    cuda_kernel_file = join(destination_folder, operation_string + '.cu')
    required_files = [cpp_file, cuda_kernel_file] if is_cuda else [cpp_file]

    if not all(isfile(x) for x in required_files):
        generate_ast = create_cuda_kernel if is_cuda else create_kernel
        generate_code = generate_cuda if is_cuda else generate_c

        if not forward_ast:
            forward_ast = generate_ast(autodiff.forward_assignments.all_assignments)
        if not backward_ast:
            backward_ast = generate_ast(autodiff.backward_assignments.all_assignments)

        forward_ast.subs({s: FieldShapeSymbol(
            [autodiff.forward_output_fields[0].name], s.coordinate) for s in forward_ast.atoms(FieldShapeSymbol)})
        backward_ast.subs({s: FieldShapeSymbol(
            [autodiff.backward_output_fields[0].name], s.coordinate) for s in backward_ast.atoms(FieldShapeSymbol)})
        # backward_ast.subs({s: FieldStrideSymbol(
        #     autodiff.forward_input_fields[0].name, s.coordinate) for s in forward_ast.atoms(FieldStrideSymbol)})

        forward_code = generate_code(forward_ast.body).replace(
            'float *', 'scalar_t *').replace('double *', 'scalar_t *')
        backward_code = generate_code(backward_ast.body).replace(
            'float *', 'scalar_t *').replace('double *', 'scalar_t *')

        if is_cuda:
            printer = CudaSympyPrinter()
            block_and_thread_numbers = forward_ast.indexing.call_parameters(shape)
            forward_block = ', '.join(printer.doprint(i) for i in block_and_thread_numbers['block'])
            forward_grid = ', '.join(printer.doprint(i) for i in block_and_thread_numbers['grid'])
            backward_shape = autodiff.backward_output_fields[0].spatial_shape
            block_and_thread_numbers = backward_ast.indexing.call_parameters(backward_shape)
            backward_block = ', '.join(printer.doprint(i) for i in block_and_thread_numbers['block'])
            backward_grid = ', '.join(printer.doprint(i) for i in block_and_thread_numbers['grid'])
            cuda_globals = pystencils.backends.cbackend.get_global_declarations(forward_ast) | \
                pystencils.backends.cbackend.get_global_declarations(backward_ast)
            cuda_globals = [generate_cuda(g) for g in cuda_globals]
        else:
            backward_block = forward_block = "INVALID"
            backward_grid = forward_grid = "INVALID"
            cuda_globals = ""

        render_dict = {
            "forward_tensors": [f for f in autodiff.forward_fields],
            "forward_input_tensors": [f for f in autodiff.forward_input_fields],
            "forward_output_tensors": [f for f in autodiff.forward_output_fields],
            "backward_tensors": [f for f in autodiff.backward_fields + autodiff.forward_input_fields],
            "backward_input_tensors": [f for f in autodiff.backward_input_fields],
            "backward_output_tensors": [f for f in autodiff.backward_output_fields],
            "forward_kernel": forward_code,
            "backward_kernel": backward_code,
            "dimensions": range(autodiff.forward_fields[0].spatial_dimensions),
            "kernel_name": operation_string,
            "forward_threads": "{" + forward_block + "}",
            "forward_blocks": "{" + forward_grid + "}",
            "backward_threads": "{" + backward_block + "}",
            "backward_blocks": "{" + backward_grid + "}",
            "cuda_globals": cuda_globals,
            "dtype": pystencils.data_types.BasicType(dtype)
        }

        if is_cuda:
            template_string_cpp = read_file(join(dirname(__file__), 'torch_native_cuda.tmpl.cpp'))
            template = jinja2.Template(template_string_cpp)
            output = template.render(render_dict)
            write_file(join(destination_folder, operation_string + '.cpp'), output)

            template_string = read_file(join(dirname(__file__), 'torch_native_cuda.tmpl.cu'))
            template = jinja2.Template(template_string)
            output = template.render(render_dict)
            write_file(join(destination_folder, operation_string + '.cu'), output)
        else:
            template_string_cpp = read_file(join(dirname(__file__), 'torch_native_cpu.tmpl.cpp'))
            template = jinja2.Template(template_string_cpp)
            output = template.render(render_dict)
            write_file(join(destination_folder, operation_string + '.cpp'), output)

    from torch.utils.cpp_extension import load
    compiled_operation = load(operation_string, required_files, verbose=True,
                              extra_cuda_cflags=[] if is_cuda else [])
    compiled_operation.code = output
    return compiled_operation
def create_autograd_function(autodiff_obj, inputfield_to_tensor_dict, forward_loop=None, backward_loop=None):
    if forward_loop is None:
        assert backward_loop is None
        is_cuda = all(t.is_cuda for t in inputfield_to_tensor_dict.values())
        assert all(t.is_cuda for t in inputfield_to_tensor_dict.values()) or \
            all(not t.is_cuda for t in inputfield_to_tensor_dict.values()), "All tensors should be on GPU or all on CPU"
        dtype = pystencils_autodiff.backends._pytorch.torch_dtype_to_numpy(
            list(inputfield_to_tensor_dict.values())[0].dtype)

        cache_dir = user_cache_dir('pystencils')
        if not isdir(cache_dir):
            os.mkdir(cache_dir)
        # TODO: create function and stuff

        compiled_operation = generate_torch(cache_dir, autodiff_obj, is_cuda, dtype)

        field_to_tensor_dict = inputfield_to_tensor_dict
        # Allocate output tensor for forward and backward pass
        for field in chain(autodiff_obj.forward_output_fields, autodiff_obj.backward_output_fields):
            field_to_tensor_dict[field] = torch.zeros(
                *field.shape,
                dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
                device=list(inputfield_to_tensor_dict.values())[0].device)

        def forward(self):
            self.saved = {f: field_to_tensor_dict[f] for f in chain(
                autodiff_obj.forward_input_fields, autodiff_obj.backward_output_fields)}
            compiled_operation.forward(**{f.name: field_to_tensor_dict[f] for f in autodiff_obj.forward_fields})
            return tuple(field_to_tensor_dict[f] for f in autodiff_obj.forward_output_fields)

        def backward(self, *grad_outputs):
            self.saved.update({f.name: grad_outputs[i] for i, f in enumerate(autodiff_obj.backward_input_fields)})
            compiled_operation.backward(**{f.name: t for f, t in self.saved.items()})
            return tuple(self.saved[f] for f in autodiff_obj.backward_output_fields)

        cls = type(str(uuid.uuid4()), (torch.autograd.Function,), {})
        cls.saved = None
        cls.forward = forward
        cls.backward = backward
        cls.code = compiled_operation.code
        return cls()
def create_autograd_function(autodiff_obj, inputfield_to_tensor_dict):
    field_to_tensor_dict = inputfield_to_tensor_dict

    # Allocate output tensor for forward and backward pass
    for field in autodiff_obj.forward_output_fields + autodiff_obj.backward_output_fields:
        field_to_tensor_dict[field] = torch.zeros(
            *field.shape,
            dtype=numpy_dtype_to_torch(field.dtype.numpy_dtype),
            device=list(inputfield_to_tensor_dict.values())[0].device)

    all_tensors = field_to_tensor_dict.values()
    is_cuda = all(a.is_cuda for a in all_tensors)

    if is_cuda:
        forward_ast = autodiff_obj.forward_ast_gpu
        backward_ast = autodiff_obj.backward_ast_gpu
    else:
        op = pystencils_autodiff.backends._pytorch.create_autograd_function(autodiff_obj,
                                                                            inputfield_to_tensor_dict,
                                                                            forward_loop,
                                                                            backward_loop,
                                                                            convert_tensors_to_arrays=False)
        return op
        forward_ast = autodiff_obj.forward_ast_cpu
        backward_ast = autodiff_obj.backward_ast_cpu
    op_name = autodiff_obj.op_name + str(uuid.uuid4())
    compiled_op = TorchModule(op_name, [forward_ast, backward_ast])

    output_tensors = OrderedDict({f.name: field_to_tensor_dict[f] for f in autodiff_obj.forward_output_fields})
    backward_output_tensors = OrderedDict(
        {f.name: field_to_tensor_dict[f] for f in autodiff_obj.backward_output_fields})

    def forward(self, **input_tensors):
        self.save_for_backward(**input_tensors)
        getattr(compiled_op, "call_" + forward_ast.function_name)(**input_tensors, **output_tensors)
        return output_tensors.values()

    def backward(self, *grad_outputs):
        gradients = {f.name: grad_outputs[i] for i, f in enumerate(autodiff_obj.backward_input_fields)}
        saved = self.saved_tensors
        getattr(compiled_op, "call_" + backward_ast.function_name)(**gradients, **saved, **backward_output_tensors)
        return backward_output_tensors.values()

    cls = type(op_name, (torch.autograd.Function,), {})
    cls.forward = forward
    cls.backward = backward
    return cls
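
For orientation, a sketch of the dispatch convention used by the new forward and backward above: the `TorchModule` compiled from the two kernel ASTs exposes one `call_<kernel function name>` entry point per AST, and both the input tensors and the pre-allocated output tensors are passed by field name (all tensor and field names below are hypothetical):

```python
# compiled_op and forward_ast/backward_ast are the objects built inside
# create_autograd_function above; x/y and the gradient tensors are made up.
forward_kernel = getattr(compiled_op, 'call_' + forward_ast.function_name)
forward_kernel(x=x_tensor, y=y_tensor)  # input field 'x', pre-allocated output field 'y'

backward_kernel = getattr(compiled_op, 'call_' + backward_ast.function_name)
backward_kernel(diff_y=grad_y, x=x_tensor, diff_x=grad_x)  # incoming gradient, saved input, outgoing gradient
```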