# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.

"""

"""

import sys
from collections.abc import Iterable
from os.path import dirname, exists, join

import pystencils
from pystencils.astnodes import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
from pystencils_autodiff._file_io import read_template_from_file, write_file
from pystencils_autodiff.backends.python_bindings import (
    PybindFunctionWrapping, PybindPythonBindings, TensorflowFunctionWrapping,
    TensorflowPythonBindings, TorchPythonBindings)
from pystencils_autodiff.framework_integration.astnodes import (
    DestructuringBindingsForFieldClass, JinjaCppFile, WrapperFunction, generate_kernel_call)
from pystencils_autodiff.tensorflow_jit import _hash


class TorchTensorDestructuring(DestructuringBindingsForFieldClass):
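    """Destructures ``at::Tensor`` arguments into pointer, shape and stride symbols."""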
    CLASS_TO_MEMBER_DICT = {
        FieldPointerSymbol: "data<{dtype}>()",
        FieldShapeSymbol: "size({dim})",
        FieldStrideSymbol: "strides()[{dim}]"
    }

    CLASS_NAME_TEMPLATE = "at::Tensor"

    headers = ["<ATen/ATen.h>"]


class TensorflowTensorDestructuring(DestructuringBindingsForFieldClass):
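    """Destructures ``tensorflow::Tensor`` arguments into pointer, shape and stride symbols."""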
    CLASS_TO_MEMBER_DICT = {
        FieldPointerSymbol: "flat<{dtype}>().data()",
        FieldShapeSymbol: "dim_size({dim})",
        FieldStrideSymbol: "dim_size({dim}) * tensorflow::DataTypeSize({field_name}.dtype())"
    }

    CLASS_NAME_TEMPLATE = "tensorflow::Tensor"

    headers = ["<tensorflow/core/framework/tensor.h>",
               "<tensorflow/core/framework/types.h>",
               ]


class PybindArrayDestructuring(DestructuringBindingsForFieldClass):
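    """Destructures ``pybind11::array_t`` arguments into pointer, shape and stride symbols."""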
    CLASS_TO_MEMBER_DICT = {
        FieldPointerSymbol: "mutable_data()",
        FieldShapeSymbol: "shape({dim})",
        FieldStrideSymbol: "strides({dim})"
    }

    CLASS_NAME_TEMPLATE = "pybind11::array_t<{dtype}>"

    headers = ["<pybind11/numpy.h>"]


class TorchModule(JinjaCppFile):
    TEMPLATE = read_template_from_file(join(dirname(__file__), 'module.tmpl.cpp'))
    DESTRUCTURING_CLASS = TorchTensorDestructuring
    PYTHON_BINDINGS_CLASS = TorchPythonBindings
    PYTHON_FUNCTION_WRAPPING_CLASS = PybindFunctionWrapping

    def __init__(self, module_name, kernel_asts):
        """Create a C++ module exposing one or more kernels as Python functions.

        :param module_name: name of the generated extension module
        :param kernel_asts: one or more kernel ASTs (can have any C dialect)
        """
        if not isinstance(kernel_asts, Iterable):
            kernel_asts = [kernel_asts]
        wrapper_functions = [self.generate_wrapper_function(k) for k in kernel_asts]
        self.module_name = module_name

        ast_dict = {
            'kernels': kernel_asts,
            'kernel_wrappers': wrapper_functions,
            'python_bindings': self.PYTHON_BINDINGS_CLASS(module_name,
                                                          [self.PYTHON_FUNCTION_WRAPPING_CLASS(a)
                                                              for a in wrapper_functions])
        }

        super().__init__(ast_dict)

    def generate_wrapper_function(self, kernel_ast):
        return WrapperFunction(self.DESTRUCTURING_CLASS(generate_kernel_call(kernel_ast)),
                               function_name='call_' + kernel_ast.function_name)

    def compile(self):
        from torch.utils.cpp_extension import load
        file_extension = '.cu' if self.is_cuda else '.cpp'
        source_code = str(self)
        source_hash = _hash(source_code.encode()).hexdigest()
        file_name = join(pystencils.cache.cache_dir, f'{source_hash}{file_extension}')

        if not exists(file_name):
            write_file(file_name, source_code)
        # TODO: propagate extra headers
        torch_extension = load(source_hash, [file_name], with_cuda=self.is_cuda)
        return torch_extension
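
# Minimal usage sketch (assumes a pystencils kernel; names like `copy_kernel`
# are illustrative):
#
#   import pystencils as ps
#   src, dst = ps.fields('src, dst: float32[2D]')
#   kernel_ast = ps.create_kernel([ps.Assignment(dst.center, src.center)])
#   extension = TorchModule('copy_kernel', [kernel_ast]).compile()
#   # the compiled extension exposes 'call_' + the kernel's function name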


class TensorflowModule(TorchModule):
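    """Variant of :class:`TorchModule` that emits Tensorflow bindings and compiles via tensorflow_jit."""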
    DESTRUCTURING_CLASS = TensorflowTensorDestructuring
    PYTHON_BINDINGS_CLASS = TensorflowPythonBindings
    PYTHON_FUNCTION_WRAPPING_CLASS = TensorflowFunctionWrapping

    def compile(self):
        from pystencils_autodiff.tensorflow_jit import compile_sources_and_load
        if self.is_cuda:
            return compile_sources_and_load([], cuda_sources=[str(self)])
        else:
            return compile_sources_and_load([str(self)])
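
# Usage mirrors TorchModule (sketch; the module name is illustrative):
#
#   lib = TensorflowModule('copy_kernel_tf', [kernel_ast]).compile()
#   # CUDA kernels are passed to the JIT via `cuda_sources`, CPU kernels as
#   # plain C++ sources (see `compile` above).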


class PybindModule(TorchModule):
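    """Variant of :class:`TorchModule` that emits plain pybind11 bindings and compiles via cppimport."""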
    DESTRUCTURING_CLASS = PybindArrayDestructuring
    PYTHON_BINDINGS_CLASS = PybindPythonBindings

    CPP_IMPORT_PREFIX = """/*
<%
setup_pybind11(cfg)
%>
*/
"""

    def compile(self):
        try:
            import cppimport
        except ImportError:
            raise ImportError('cppimport is required for compiling pybind11 modules')

        assert not self.is_cuda, 'PybindModule does not support CUDA kernels'

        source_code = self.CPP_IMPORT_PREFIX + str(self)
        file_name = join(pystencils.cache.cache_dir, f'{self.module_name}.cpp')

        if not exists(file_name):
            write_file(file_name, source_code)
        # TODO: propagate extra headers
        cache_dir = pystencils.cache.cache_dir
        if cache_dir not in sys.path:
            sys.path.append(cache_dir)
        pybind_extension = cppimport.imp(self.module_name)
        return pybind_extension
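
# CPU-only usage sketch (requires cppimport; names illustrative):
#
#   module = PybindModule('copy_kernel_py', [kernel_ast])
#   ext = module.compile()
#   # wrapper functions are named 'call_' + the kernel's function name, e.g.
#   # ext.call_kernel(...)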