Skip to content
Snippets Groups Projects
Commit e47d321f authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Make torch_native_cpu.tmpl.cpp compile

parent 397becd0
Branches
Tags
No related merge requests found
......@@ -18,16 +18,15 @@ std::vector<at::Tensor> {{ kernel_name }}_forward(
//auto {{tensor}} = at::zeros_like({{ forward_input_tensors[0] }});
//{% endfor %}
{% for i in dimensions -%}
int _size_{{ forward_tensors[0] }}_{{ i }} = {{ forward_tensors[0] }}.size({{ i }});
{% endfor %}
{% for tensor in forward_tensors -%}
{%- set last = loop.last -%}
scalar_t* _data_{{ tensor }} = {{ tensor }}.data<scalar_t>();
{% for i in dimensions -%}
int _stride_{{tensor}}_{{i}} = {{tensor}}.strides()[{{ i }}];
{% endfor -%}
{% for i in dimensions -%}
int _size_{{tensor}}_{{i}} = {{tensor}}.size({{ i }});
{% endfor -%}
{% endfor -%}
{{forward_kernel}}
......@@ -54,6 +53,9 @@ std::vector<at::Tensor> {{ kernel_name }}_backward(
{% for i in dimensions -%}
int _stride_{{ tensor }}_{{i}} = {{ tensor }}.strides()[{{ i }}];
{% endfor -%}
{% for i in dimensions -%}
int _size_{{tensor}}_{{i}} = {{tensor}}.size({{ i }});
{% endfor -%}
{% endfor -%}
{{backward_kernel}}
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment