diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py
index 241145bd251dd3e5461c5725cbf7cee9e7603c53..84d184c15bb30334c702a5cf785eb8e0020a4a9e 100644
--- a/src/pystencils_autodiff/backends/astnodes.py
+++ b/src/pystencils_autodiff/backends/astnodes.py
@@ -8,18 +8,15 @@
 """
+from collections.abc import Iterable
 from os.path import dirname, join
 
-import jinja2
-
 from pystencils.astnodes import (
-    DestructuringBindingsForFieldClass, FieldPointerSymbol, FieldShapeSymbol,
-    FieldStrideSymbol, Node)
-from pystencils.backends.cbackend import get_headers
-from pystencils.framework_intergration_astnodes import (
-    FrameworkIntegrationPrinter, WrapperFunction, generate_kernel_call)
+    DestructuringBindingsForFieldClass, FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol)
 
 from pystencils_autodiff._file_io import _read_template_from_file
 from pystencils_autodiff.framework_integration.astnodes import JinjaCppFile
+from pystencils_autodiff.framework_integration.astnodes import (
+    WrapperFunction, generate_kernel_call)
 
 
 # Torch
@@ -50,17 +47,34 @@ class TensorflowTensorDestructuring(DestructuringBindingsForFieldClass):
     ]
 
 
-class TorchCudaModule(JinjaCppFile):
-    TEMPLATE = _read_template_from_file(join(dirname(__file__), 'torch_native_cuda.tmpl.cu'))
+class PybindArrayDestructuring(DestructuringBindingsForFieldClass):
+    CLASS_TO_MEMBER_DICT = {
+        FieldPointerSymbol: "mutable_data()",
+        FieldShapeSymbol: "shape({dim})",
+        FieldStrideSymbol: "strides({dim})"
+    }
+
+    CLASS_NAME_TEMPLATE = "pybind11::array_t<{dtype}>"
+
+    headers = ["<pybind11/numpy.h>"]
+
+
+class TorchModule(JinjaCppFile):
+    TEMPLATE = _read_template_from_file(join(dirname(__file__), 'module.tmpl.cpp'))
 
     DESTRUCTURING_CLASS = TorchTensorDestructuring
 
-    def __init__(self, forward_kernel_ast, backward_kernel_ast):
+    def __init__(self, kernel_asts):
+        """Create a C++ module from one or more kernel ASTs.
+
+        :param kernel_asts: a single kernel AST or an iterable of kernel ASTs
+                            (each kernel may use any C dialect)
+        """
+        if not isinstance(kernel_asts, Iterable):
+            kernel_asts = [kernel_asts]
         ast_dict = {
-            'forward_kernel': forward_kernel_ast,
-            'backward_kernel': backward_kernel_ast,
-            'forward_wrapper': self.generate_wrapper_function(forward_kernel_ast),
-            'backward_wrapper': self.generate_wrapper_function(backward_kernel_ast),
+            'kernels': kernel_asts,
+            'kernel_wrappers': [self.generate_wrapper_function(k) for k in kernel_asts],
         }
 
         super().__init__(ast_dict)
@@ -70,9 +84,9 @@ class TorchCudaModule(JinjaCppFile):
             function_name='call_' + kernel_ast.function_name)
 
 
-class TensorCudaModule(TorchCudaModule):
+class TensorflowModule(TorchModule):
     DESTRUCTURING_CLASS = TensorflowTensorDestructuring
 
 
-# Generic
-class MachineLearningBackend(FrameworkIntegrationPrinter):
+class PybindModule(TorchModule):
+    DESTRUCTURING_CLASS = PybindArrayDestructuring
diff --git a/src/pystencils_autodiff/backends/module.tmpl.cpp b/src/pystencils_autodiff/backends/module.tmpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2e9c2f0d0e14a48b2a9b672c6490ec1fab6425cc
--- /dev/null
+++ b/src/pystencils_autodiff/backends/module.tmpl.cpp
@@ -0,0 +1,24 @@
+#include <cuda.h>
+#include <vector>
+
+// Most compilers don't care whether it's __restrict or __restrict__
+#define RESTRICT __restrict__
+
+{% for header in headers -%}
+#include {{ header }}
+{% endfor %}
+
+{% for global in globals -%}
+{{ global }}
+{% endfor %}
+
+{% for kernel in kernels %}
+{{ kernel }}
+{% endfor %}
+
+{% for 
wrapper in kernel_wrappers %} +{{ wrapper }} +{% endfor %} + +{{ python_bindings }} + diff --git a/src/pystencils_autodiff/backends/torch_native_cpu.tmpl.cpp b/src/pystencils_autodiff/backends/torch_native_cpu.tmpl.cpp deleted file mode 100644 index 74eec9ee424e3955453dc52b3516d8cdab82d8da..0000000000000000000000000000000000000000 --- a/src/pystencils_autodiff/backends/torch_native_cpu.tmpl.cpp +++ /dev/null @@ -1,79 +0,0 @@ -#include <torch/extension.h> - -#include <vector> - -using namespace pybind11::literals; - -using scalar_t = {{ dtype }}; - - -#define RESTRICT __restrict - -std::vector<at::Tensor> {{ kernel_name }}_forward( -{%- for tensor in forward_tensors -%} - at::Tensor {{ tensor }} {{- ", " if not loop.last -}} -{%- endfor %}) -{ - //{% for tensor in forward_output_tensors -%} - //auto {{tensor}} = at::zeros_like({{ forward_input_tensors[0] }}); - //{% endfor %} - - {% for tensor in forward_tensors -%} - {%- set last = loop.last -%} - scalar_t* _data_{{ tensor }} = {{ tensor }}.data<scalar_t>(); - {% for i in dimensions -%} - int _stride_{{tensor}}_{{i}} = {{tensor}}.strides()[{{ i }}]; - {% endfor -%} - {% for i in dimensions -%} - int _size_{{tensor}}_{{i}} = {{tensor}}.size({{ i }}); - {% endfor -%} - {% endfor -%} - - {{forward_kernel}} - - return { - {%- for tensor in forward_output_tensors -%} - {{ tensor }} {{- "," if not loop.last -}} - {% endfor -%} - }; -} - -std::vector<at::Tensor> {{ kernel_name }}_backward( -{%- for tensor in backward_tensors -%} - at::Tensor {{ tensor }} {{- ", " if not loop.last -}} -{% endfor %}) -{ - //{% for tensor in backward_output_tensors -%} - //auto {{tensor}} = at::zeros_like({{ backward_input_tensors[0] }}); - //{% endfor %} - - {% for tensor in backward_tensors -%} - {%- set last = loop.last -%} - scalar_t* _data_{{ tensor }} = {{ tensor }}.data<scalar_t>(); - {% for i in dimensions -%} - int _stride_{{ tensor }}_{{i}} = {{ tensor }}.strides()[{{ i }}]; - {% endfor -%} - {% for i in dimensions -%} - int _size_{{tensor}}_{{i}} = {{tensor}}.size({{ i }}); - {% endfor -%} - {% endfor -%} - - {{backward_kernel}} - - return { - {%- for tensor in backward_output_tensors -%} - {{ tensor }} {{- "," if not loop.last -}} - {% endfor -%} - }; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &{{ kernel_name }}_forward, "{{ kernel_name }} forward (CPU)", -{%- for tensor in forward_tensors -%} - "{{ tensor }}"_a {{ ", " if not loop.last }} -{%- endfor -%} ); - m.def("backward", &{{ kernel_name }}_backward, "{{ kernel_name }} backward (CPU)", -{%- for tensor in backward_tensors -%} - "{{ tensor }}"_a {{ ", " if not loop.last }} -{%- endfor -%} ); -} diff --git a/src/pystencils_autodiff/backends/torch_native_cuda.tmpl.cpp b/src/pystencils_autodiff/backends/torch_native_cuda.tmpl.cpp deleted file mode 100644 index 0662090d3320b045d41ac6ab11fb15c11c4d8d74..0000000000000000000000000000000000000000 --- a/src/pystencils_autodiff/backends/torch_native_cuda.tmpl.cpp +++ /dev/null @@ -1,79 +0,0 @@ -#include <torch/extension.h> -#include <vector> - -// CUDA forward declarations -using namespace pybind11::literals; - -void {{ kernel_name }}_cuda_forward( -{%- for tensor in forward_tensors %} - at::Tensor {{ tensor.name }} {{- ", " if not loop.last -}} -{% endfor %}); - -std::vector<at::Tensor> {{ kernel_name }}_cuda_backward( -{%- for tensor in backward_tensors -%} - at::Tensor {{ tensor.name }} {{- ", " if not loop.last -}} -{% endfor %}); - -// C++ interface - -// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4. 
-#define CHECK_CUDA(x) \ - AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) \ - //AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); - //CHECK_CONTIGUOUS(x) - -std::vector<at::Tensor> {{ kernel_name }}_forward( -{%- for tensor in forward_tensors -%} - at::Tensor {{ tensor.name }} {{- ", " if not loop.last -}} -{%- endfor %}) -{ - {% for tensor in forward_tensors -%} - CHECK_INPUT({{ tensor.name }}); - {% endfor %} - - {{ kernel_name }}_cuda_forward( - {%- for tensor in forward_tensors %} - {{ tensor.name }} {{- ", " if not loop.last }} - {%- endfor %}); - - return std::vector<at::Tensor>{ - {%- for tensor in forward_output_tensors %} - {{ tensor.name }} {{- ", " if not loop.last }} - {%- endfor %} - } - ; -} - -std::vector<at::Tensor> {{ kernel_name }}_backward( -{%- for tensor in backward_tensors -%} - at::Tensor {{ tensor.name }} {{- ", " if not loop.last -}} -{% endfor %}) -{ - {%- for tensor in forward_input_tensors + backward_input_tensors -%} - CHECK_INPUT({{ tensor }}); - {% endfor %} - {{ kernel_name }}_cuda_backward( - {%- for tensor in backward_tensors -%} - {{ tensor.name }} {{- ", " if not loop.last }} - {%- endfor %}); - - return std::vector<at::Tensor>{ - {%- for tensor in backward_output_tensors %} - {{ tensor.name }} {{- ", " if not loop.last }} - {%- endfor %} - } - ; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &{{ kernel_name }}_forward, "{{ kernel_name }} forward (CUDA)", -{%- for tensor in forward_tensors -%} - "{{ tensor.name }}"_a {{ ", " if not loop.last }} -{%- endfor -%} ); - m.def("backward", &{{ kernel_name }}_backward, "{{ kernel_name }} backward (CUDA)", -{%- for tensor in backward_tensors -%} - "{{ tensor.name }}"_a {{ ", " if not loop.last }} -{%- endfor -%} ); -} diff --git a/src/pystencils_autodiff/backends/torch_native_cuda.tmpl.cu b/src/pystencils_autodiff/backends/torch_native_cuda.tmpl.cu deleted file mode 100644 index 4d706ea2d3b9c1c6d7fbb316678ada2771d6d7db..0000000000000000000000000000000000000000 --- a/src/pystencils_autodiff/backends/torch_native_cuda.tmpl.cu +++ /dev/null @@ -1,21 +0,0 @@ -#include <cuda.h> -#include <vector> - - -{% for header in headers -%} -#include {{ header }} -{% endfor %} - -#define RESTRICT __restrict__ -#define FUNC_PREFIX __global__ - - -{{ forward_kernel }} - -{{ backward_kernel }} - - -#define RESTRICT -{{ forward_wrapper }} - -{{ backward_wrapper }} diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py index 1e7d12827010efc584013b2a9c5d0131dc3141cf..db85b65031baa23153c2b2a9952289ebbe36ebc8 100644 --- a/src/pystencils_autodiff/framework_integration/astnodes.py +++ b/src/pystencils_autodiff/framework_integration/astnodes.py @@ -9,6 +9,7 @@ Astnodes useful for the generations of C modules for frameworks (apart from waLB waLBerla currently uses `pystencils-walberla <https://pypi.org/project/pystencils-walberla/>`_. 
""" +from collections.abc import Iterable from typing import Any, List, Set import jinja2 @@ -19,6 +20,7 @@ import pystencils from pystencils.astnodes import Node, NodeOrExpr, ResolvedFieldAccess from pystencils.data_types import TypedSymbol from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol +from pystencils_autodiff.framework_integration.printer import FrameworkIntegrationPrinter class DestructuringBindingsForFieldClass(Node): @@ -221,6 +223,7 @@ class JinjaCppFile(Node): def __init__(self, ast_dict): self.ast_dict = ast_dict + self.printer = FrameworkIntegrationPrinter() Node.__init__(self) @property @@ -246,8 +249,9 @@ class JinjaCppFile(Node): def __str__(self): assert self.TEMPLATE, f"Template of {self.__class__} must be set" - render_dict = {k: self._print(v) for k, v in self.ast_dict.items()} - render_dict.update({"headers": pystencils.backend.cbackend.get_headers(self)}) + render_dict = {k: self._print(v) if not isinstance(v, Iterable) else [self._print(a) for a in v] + for k, v in self.ast_dict.items()} + render_dict.update({"headers": pystencils.backends.cbackend.get_headers(self)}) return self.TEMPLATE.render(render_dict) diff --git a/src/pystencils_autodiff/framework_integration/printer.py b/src/pystencils_autodiff/framework_integration/printer.py index 81a182667500ee5de1c9a33a3ee8477c6383f7d2..3442ffaace5b8ff3ebbd5e567c2a7bd868d865b8 100644 --- a/src/pystencils_autodiff/framework_integration/printer.py +++ b/src/pystencils_autodiff/framework_integration/printer.py @@ -20,7 +20,13 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend): return super_result.replace('FUNC_PREFIX ', '') def _print_KernelFunction(self, node): - return pystencils.backends.cbackend.generate_c(node, dialect=node.backend) + if node.backend == 'cuda': + prefix = '#define FUNC_PREFIX __global__\n' + kernel_code = pystencils.backends.cbackend.generate_c(node, dialect='cuda') + else: + prefix = '#define FUNC_PREFIX\n"' + kernel_code = pystencils.backends.cbackend.generate_c(node, dialect='c') + return prefix + kernel_code def _print_KernelFunctionCall(self, node): diff --git a/tests/test_module_printing.py b/tests/test_module_printing.py index a2d253e5b97097426719d8c1c96a0f6782c33e73..db0005fe5fc44ace94a553ed82a895e475966931 100644 --- a/tests/test_module_printing.py +++ b/tests/test_module_printing.py @@ -11,7 +11,7 @@ import sympy import pystencils from pystencils_autodiff import create_backward_assignments -from pystencils_autodiff.backends.astnodes import TensorCudaModule, TorchCudaModule +from pystencils_autodiff.backends.astnodes import PybindModule, TensorflowModule, TorchModule TARGET_TO_DIALECT = { 'cpu': 'c', @@ -34,12 +34,18 @@ def test_module_printing(): forward_ast.function_name = 'forward' backward_ast = pystencils.create_kernel(backward_assignments, target) backward_ast.function_name = 'backward' - module = TorchCudaModule(forward_ast, backward_ast) + module = TorchModule([forward_ast, backward_ast]) print(module) - module = TensorCudaModule(forward_ast, backward_ast) + module = TensorflowModule({forward_ast: backward_ast}) print(module) + if target == 'cpu': + module = PybindModule([forward_ast, backward_ast]) + print(module) + module = PybindModule(forward_ast) + print(module) + def test_module_printing_parameter(): for target in ('cpu', 'gpu'): @@ -57,16 +63,22 @@ def test_module_printing_parameter(): forward_ast.function_name = 'forward' backward_ast = pystencils.create_kernel(backward_assignments, target) 
backward_ast.function_name = 'backward' - module = TorchCudaModule(forward_ast, backward_ast) + module = TorchModule([forward_ast, backward_ast]) print(module) - module = TensorCudaModule(forward_ast, backward_ast) + module = TensorflowModule({forward_ast: backward_ast}) print(module) + if target == 'cpu': + module = PybindModule([forward_ast, backward_ast]) + print(module) + module = PybindModule(forward_ast) + print(module) + def main(): test_module_printing() - test_module_printing_parameter() + # test_module_printing_parameter() if __name__ == '__main__':
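
Usage sketch (not part of the patch): the renamed module classes are constructed from plain pystencils kernel ASTs, mirroring the updated tests. The field names, shapes, and the forward assignment below are illustrative assumptions, not taken from this diff.

    import pystencils
    import sympy

    from pystencils_autodiff import create_backward_assignments
    from pystencils_autodiff.backends.astnodes import PybindModule, TorchModule

    # Illustrative forward kernel: z = x * log(x * y)
    z, y, x = pystencils.fields("z, y, x: float64[20,40]")
    forward_assignments = pystencils.AssignmentCollection({
        z[0, 0]: x[0, 0] * sympy.log(x[0, 0] * y[0, 0])
    })
    backward_assignments = create_backward_assignments(forward_assignments)

    forward_ast = pystencils.create_kernel(forward_assignments, target='cpu')
    forward_ast.function_name = 'forward'
    backward_ast = pystencils.create_kernel(backward_assignments, target='cpu')
    backward_ast.function_name = 'backward'

    # A module accepts a single kernel AST or an iterable of them;
    # TorchModule and PybindModule differ only in their destructuring bindings.
    print(TorchModule([forward_ast, backward_ast]))
    print(PybindModule(forward_ast))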