Commit 545a4d03 authored by Stephan Seitz

Unify all module templates

parent b14f8614
Pipeline #17256 failed
pystencils_autodiff/backends/astnodes.py:

@@ -8,18 +8,15 @@
 """
+from collections.abc import Iterable
 from os.path import dirname, join

 import jinja2

 from pystencils.astnodes import (
-    DestructuringBindingsForFieldClass, FieldPointerSymbol, FieldShapeSymbol,
-    FieldStrideSymbol, Node)
-from pystencils.backends.cbackend import get_headers
-from pystencils.framework_intergration_astnodes import (
-    FrameworkIntegrationPrinter, WrapperFunction, generate_kernel_call)
+    DestructuringBindingsForFieldClass, FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol)
 from pystencils_autodiff._file_io import _read_template_from_file
 from pystencils_autodiff.framework_integration.astnodes import JinjaCppFile
+from pystencils_autodiff.framework_integration.astnodes import (
+    WrapperFunction, generate_kernel_call)

 # Torch
@@ -50,17 +47,34 @@ class TensorflowTensorDestructuring(DestructuringBindingsForFieldClass):
     ]


-class TorchCudaModule(JinjaCppFile):
-    TEMPLATE = _read_template_from_file(join(dirname(__file__), 'torch_native_cuda.tmpl.cu'))
+class PybindArrayDestructuring(DestructuringBindingsForFieldClass):
+    CLASS_TO_MEMBER_DICT = {
+        FieldPointerSymbol: "mutable_data()",
+        FieldShapeSymbol: "shape({dim})",
+        FieldStrideSymbol: "strides({dim})"
+    }
+
+    CLASS_NAME_TEMPLATE = "pybind11::array_t<{dtype}>"
+
+    headers = ["<pybind11/numpy.h>"]
+
+
+class TorchModule(JinjaCppFile):
+    TEMPLATE = _read_template_from_file(join(dirname(__file__), 'module.tmpl.cpp'))
     DESTRUCTURING_CLASS = TorchTensorDestructuring

-    def __init__(self, forward_kernel_ast, backward_kernel_ast):
+    def __init__(self, kernel_asts):
+        """Create a C++ module from one or more kernel ASTs (e.g. a forward and a backward kernel).
+
+        :param kernel_asts: a single kernel AST or an iterable of kernel ASTs (any C dialect)
+        """
+        if not isinstance(kernel_asts, Iterable):
+            kernel_asts = [kernel_asts]
         ast_dict = {
-            'forward_kernel': forward_kernel_ast,
-            'backward_kernel': backward_kernel_ast,
-            'forward_wrapper': self.generate_wrapper_function(forward_kernel_ast),
-            'backward_wrapper': self.generate_wrapper_function(backward_kernel_ast),
+            'kernels': kernel_asts,
+            'kernel_wrappers': [self.generate_wrapper_function(k) for k in kernel_asts],
         }
         super().__init__(ast_dict)

@@ -70,9 +84,9 @@ class TorchCudaModule(JinjaCppFile):
             function_name='call_' + kernel_ast.function_name)


-class TensorCudaModule(TorchCudaModule):
+class TensorflowModule(TorchModule):
     DESTRUCTURING_CLASS = TensorflowTensorDestructuring


-# Generic
-class MachineLearningBackend(FrameworkIntegrationPrinter):
+class PybindModule(TorchModule):
+    DESTRUCTURING_CLASS = PybindArrayDestructuring
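For orientation, a minimal usage sketch of the unified API, modelled on the tests further down; the field and assignment setup is illustrative and not part of this commit:

    import pystencils
    from pystencils_autodiff import create_backward_assignments
    from pystencils_autodiff.backends.astnodes import PybindModule, TorchModule

    # Illustrative forward kernel: z = x * y on 2D fields
    z, y, x = pystencils.fields("z, y, x: [2D]")
    forward_assignments = pystencils.AssignmentCollection(
        [pystencils.Assignment(z[0, 0], x[0, 0] * y[0, 0])], [])
    backward_assignments = create_backward_assignments(forward_assignments)

    forward_ast = pystencils.create_kernel(forward_assignments, target='cpu')
    forward_ast.function_name = 'forward'
    backward_ast = pystencils.create_kernel(backward_assignments, target='cpu')
    backward_ast.function_name = 'backward'

    # Every module class now takes a single AST or a list of ASTs; printing a module
    # renders the unified template with the kernels and their generated wrappers.
    print(TorchModule([forward_ast, backward_ast]))
    print(PybindModule(forward_ast))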
Module template, torch_native_cuda.tmpl.cu replaced by the unified module.tmpl.cpp:

-#include <cuda.h>
 #include <vector>

+// Most compilers don't care whether it's __restrict or __restrict__
+#define RESTRICT __restrict__
+
 {% for header in headers -%}
 #include {{ header }}
 {% endfor %}

-#define RESTRICT __restrict__
-#define FUNC_PREFIX __global__
+{% for global in globals -%}
+{{ global }}
+{% endfor %}

-{{ forward_kernel }}
-{{ backward_kernel }}
+{% for kernel in kernels %}
+{{ kernel }}
+{% endfor %}
+
+{% for wrapper in kernel_wrappers %}
+{{ wrapper }}
+{% endfor %}

-#define RESTRICT
-{{ forward_wrapper }}
-{{ backward_wrapper }}
+{{ python_bindings }}
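For a single kernel, the rendered module then has roughly this layout (a sketch; the pybind11 header appears only when PybindArrayDestructuring is the destructuring class, and kernel/wrapper bodies are elided):

    #include <vector>

    // Most compilers don't care whether it's __restrict or __restrict__
    #define RESTRICT __restrict__

    #include <pybind11/numpy.h>   // headers collected from the destructuring class via get_headers()

    /* {{ globals }}          - global declarations, if any                                   */
    /* {{ kernels }}          - each kernel, preceded by a matching '#define FUNC_PREFIX ...' */
    /* {{ kernel_wrappers }}  - one 'call_<kernel_name>' function per kernel that unpacks
                                pointers, shapes and strides from the framework's array type  */
    /* {{ python_bindings }}  - module registration, if provided                              */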
Removed template: the old Torch CPU module (forward/backward wrappers over at::Tensor plus PYBIND11_MODULE bindings), superseded by the unified module.tmpl.cpp:

#include <torch/extension.h>
#include <vector>
using namespace pybind11::literals;
using scalar_t = {{ dtype }};
#define RESTRICT __restrict
std::vector<at::Tensor> {{ kernel_name }}_forward(
{%- for tensor in forward_tensors -%}
at::Tensor {{ tensor }} {{- ", " if not loop.last -}}
{%- endfor %})
{
//{% for tensor in forward_output_tensors -%}
//auto {{tensor}} = at::zeros_like({{ forward_input_tensors[0] }});
//{% endfor %}
{% for tensor in forward_tensors -%}
{%- set last = loop.last -%}
scalar_t* _data_{{ tensor }} = {{ tensor }}.data<scalar_t>();
{% for i in dimensions -%}
int _stride_{{tensor}}_{{i}} = {{tensor}}.strides()[{{ i }}];
{% endfor -%}
{% for i in dimensions -%}
int _size_{{tensor}}_{{i}} = {{tensor}}.size({{ i }});
{% endfor -%}
{% endfor -%}
{{forward_kernel}}
return {
{%- for tensor in forward_output_tensors -%}
{{ tensor }} {{- "," if not loop.last -}}
{% endfor -%}
};
}
std::vector<at::Tensor> {{ kernel_name }}_backward(
{%- for tensor in backward_tensors -%}
at::Tensor {{ tensor }} {{- ", " if not loop.last -}}
{% endfor %})
{
//{% for tensor in backward_output_tensors -%}
//auto {{tensor}} = at::zeros_like({{ backward_input_tensors[0] }});
//{% endfor %}
{% for tensor in backward_tensors -%}
{%- set last = loop.last -%}
scalar_t* _data_{{ tensor }} = {{ tensor }}.data<scalar_t>();
{% for i in dimensions -%}
int _stride_{{ tensor }}_{{i}} = {{ tensor }}.strides()[{{ i }}];
{% endfor -%}
{% for i in dimensions -%}
int _size_{{tensor}}_{{i}} = {{tensor}}.size({{ i }});
{% endfor -%}
{% endfor -%}
{{backward_kernel}}
return {
{%- for tensor in backward_output_tensors -%}
{{ tensor }} {{- "," if not loop.last -}}
{% endfor -%}
};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &{{ kernel_name }}_forward, "{{ kernel_name }} forward (CPU)",
{%- for tensor in forward_tensors -%}
"{{ tensor }}"_a {{ ", " if not loop.last }}
{%- endfor -%} );
m.def("backward", &{{ kernel_name }}_backward, "{{ kernel_name }} backward (CPU)",
{%- for tensor in backward_tensors -%}
"{{ tensor }}"_a {{ ", " if not loop.last }}
{%- endfor -%} );
}
Removed template: the old Torch CUDA C++ wrapper (forward declarations of the .cu kernels, input checks, PYBIND11_MODULE bindings), superseded by the unified module.tmpl.cpp:

#include <torch/extension.h>
#include <vector>
// CUDA forward declarations
using namespace pybind11::literals;
void {{ kernel_name }}_cuda_forward(
{%- for tensor in forward_tensors %}
at::Tensor {{ tensor.name }} {{- ", " if not loop.last -}}
{% endfor %});
std::vector<at::Tensor> {{ kernel_name }}_cuda_backward(
{%- for tensor in backward_tensors -%}
at::Tensor {{ tensor.name }} {{- ", " if not loop.last -}}
{% endfor %});
// C++ interface
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
//AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x);
//CHECK_CONTIGUOUS(x)
std::vector<at::Tensor> {{ kernel_name }}_forward(
{%- for tensor in forward_tensors -%}
at::Tensor {{ tensor.name }} {{- ", " if not loop.last -}}
{%- endfor %})
{
{% for tensor in forward_tensors -%}
CHECK_INPUT({{ tensor.name }});
{% endfor %}
{{ kernel_name }}_cuda_forward(
{%- for tensor in forward_tensors %}
{{ tensor.name }} {{- ", " if not loop.last }}
{%- endfor %});
return std::vector<at::Tensor>{
{%- for tensor in forward_output_tensors %}
{{ tensor.name }} {{- ", " if not loop.last }}
{%- endfor %}
}
;
}
std::vector<at::Tensor> {{ kernel_name }}_backward(
{%- for tensor in backward_tensors -%}
at::Tensor {{ tensor.name }} {{- ", " if not loop.last -}}
{% endfor %})
{
{%- for tensor in forward_input_tensors + backward_input_tensors -%}
CHECK_INPUT({{ tensor }});
{% endfor %}
{{ kernel_name }}_cuda_backward(
{%- for tensor in backward_tensors -%}
{{ tensor.name }} {{- ", " if not loop.last }}
{%- endfor %});
return std::vector<at::Tensor>{
{%- for tensor in backward_output_tensors %}
{{ tensor.name }} {{- ", " if not loop.last }}
{%- endfor %}
}
;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &{{ kernel_name }}_forward, "{{ kernel_name }} forward (CUDA)",
{%- for tensor in forward_tensors -%}
"{{ tensor.name }}"_a {{ ", " if not loop.last }}
{%- endfor -%} );
m.def("backward", &{{ kernel_name }}_backward, "{{ kernel_name }} backward (CUDA)",
{%- for tensor in backward_tensors -%}
"{{ tensor.name }}"_a {{ ", " if not loop.last }}
{%- endfor -%} );
}
pystencils_autodiff/framework_integration/astnodes.py:

@@ -9,6 +9,7 @@ Astnodes useful for the generations of C modules for frameworks (apart from waLBerla)
 waLBerla currently uses `pystencils-walberla <https://pypi.org/project/pystencils-walberla/>`_.
 """
+from collections.abc import Iterable
 from typing import Any, List, Set

 import jinja2
@@ -19,6 +20,7 @@ import pystencils
 from pystencils.astnodes import Node, NodeOrExpr, ResolvedFieldAccess
 from pystencils.data_types import TypedSymbol
 from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
+from pystencils_autodiff.framework_integration.printer import FrameworkIntegrationPrinter


 class DestructuringBindingsForFieldClass(Node):
@@ -221,6 +223,7 @@ class JinjaCppFile(Node):

     def __init__(self, ast_dict):
         self.ast_dict = ast_dict
+        self.printer = FrameworkIntegrationPrinter()
         Node.__init__(self)

     @property
@@ -246,8 +249,9 @@ class JinjaCppFile(Node):

     def __str__(self):
         assert self.TEMPLATE, f"Template of {self.__class__} must be set"
-        render_dict = {k: self._print(v) for k, v in self.ast_dict.items()}
-        render_dict.update({"headers": pystencils.backend.cbackend.get_headers(self)})
+        render_dict = {k: self._print(v) if not isinstance(v, Iterable) else [self._print(a) for a in v]
+                       for k, v in self.ast_dict.items()}
+        render_dict.update({"headers": pystencils.backends.cbackend.get_headers(self)})
         return self.TEMPLATE.render(render_dict)
......
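The element-wise printing of iterable values is what lets the unified template loop over kernels and wrappers. A minimal, self-contained sketch of the idea (plain strings stand in for printed AST nodes):

    import jinja2

    template = jinja2.Template("{% for kernel in kernels %}{{ kernel }}\n{% endfor %}")

    ast_dict = {'kernels': ['void forward() {}', 'void backward() {}']}
    # Lists are rendered element-wise; scalar values would be printed as a single node.
    render_dict = {k: [str(a) for a in v] if isinstance(v, (list, tuple)) else str(v)
                   for k, v in ast_dict.items()}
    print(template.render(render_dict))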
pystencils_autodiff/framework_integration/printer.py:

@@ -20,7 +20,13 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend):
         return super_result.replace('FUNC_PREFIX ', '')

     def _print_KernelFunction(self, node):
-        return pystencils.backends.cbackend.generate_c(node, dialect=node.backend)
+        if node.backend == 'cuda':
+            prefix = '#define FUNC_PREFIX __global__\n'
+            kernel_code = pystencils.backends.cbackend.generate_c(node, dialect='cuda')
+        else:
+            prefix = '#define FUNC_PREFIX\n'
+            kernel_code = pystencils.backends.cbackend.generate_c(node, dialect='c')
+        return prefix + kernel_code

     def _print_KernelFunctionCall(self, node):
......
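The net effect is that every printed kernel is preceded by a FUNC_PREFIX definition matching its backend; roughly (kernel signatures abridged, parameter names follow pystencils' _data_<field> convention):

    /* 'c' dialect: FUNC_PREFIX expands to nothing */
    #define FUNC_PREFIX
    FUNC_PREFIX void forward(double * RESTRICT _data_x /* , ... */) { /* ... */ }

    /* 'cuda' dialect: the kernel becomes a device entry point */
    #define FUNC_PREFIX __global__
    FUNC_PREFIX void forward(double * RESTRICT _data_x /* , ... */) { /* ... */ }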
Tests (module printing):

@@ -11,7 +11,7 @@ import sympy
 import pystencils
 from pystencils_autodiff import create_backward_assignments
-from pystencils_autodiff.backends.astnodes import TensorCudaModule, TorchCudaModule
+from pystencils_autodiff.backends.astnodes import PybindModule, TensorflowModule, TorchModule

 TARGET_TO_DIALECT = {
     'cpu': 'c',
@@ -34,12 +34,18 @@ def test_module_printing():
         forward_ast.function_name = 'forward'
         backward_ast = pystencils.create_kernel(backward_assignments, target)
         backward_ast.function_name = 'backward'
-        module = TorchCudaModule(forward_ast, backward_ast)
+        module = TorchModule([forward_ast, backward_ast])
         print(module)
-        module = TensorCudaModule(forward_ast, backward_ast)
+        module = TensorflowModule({forward_ast: backward_ast})
         print(module)
+        if target == 'cpu':
+            module = PybindModule([forward_ast, backward_ast])
+            print(module)
+            module = PybindModule(forward_ast)
+            print(module)


 def test_module_printing_parameter():
     for target in ('cpu', 'gpu'):
@@ -57,16 +63,22 @@ def test_module_printing_parameter():
         forward_ast.function_name = 'forward'
         backward_ast = pystencils.create_kernel(backward_assignments, target)
         backward_ast.function_name = 'backward'
-        module = TorchCudaModule(forward_ast, backward_ast)
+        module = TorchModule([forward_ast, backward_ast])
         print(module)
-        module = TensorCudaModule(forward_ast, backward_ast)
+        module = TensorflowModule({forward_ast: backward_ast})
         print(module)
+        if target == 'cpu':
+            module = PybindModule([forward_ast, backward_ast])
+            print(module)
+            module = PybindModule(forward_ast)
+            print(module)


 def main():
     test_module_printing()
-    test_module_printing_parameter()
+    # test_module_printing_parameter()


 if __name__ == '__main__':
......