diff --git a/src/pystencils_autodiff/backends/torch_native_cpu.tmpl.cpp b/src/pystencils_autodiff/backends/torch_native_cpu.tmpl.cpp index d7ee774e03061e4d878d93cc1dfca8f7306beb4b..74eec9ee424e3955453dc52b3516d8cdab82d8da 100644 --- a/src/pystencils_autodiff/backends/torch_native_cpu.tmpl.cpp +++ b/src/pystencils_autodiff/backends/torch_native_cpu.tmpl.cpp @@ -18,16 +18,15 @@ std::vector<at::Tensor> {{ kernel_name }}_forward( //auto {{tensor}} = at::zeros_like({{ forward_input_tensors[0] }}); //{% endfor %} - {% for i in dimensions -%} - int _size_{{ forward_tensors[0] }}_{{ i }} = {{ forward_tensors[0] }}.size({{ i }}); - {% endfor %} - {% for tensor in forward_tensors -%} {%- set last = loop.last -%} scalar_t* _data_{{ tensor }} = {{ tensor }}.data<scalar_t>(); {% for i in dimensions -%} int _stride_{{tensor}}_{{i}} = {{tensor}}.strides()[{{ i }}]; {% endfor -%} + {% for i in dimensions -%} + int _size_{{tensor}}_{{i}} = {{tensor}}.size({{ i }}); + {% endfor -%} {% endfor -%} {{forward_kernel}} @@ -54,6 +53,9 @@ std::vector<at::Tensor> {{ kernel_name }}_backward( {% for i in dimensions -%} int _stride_{{ tensor }}_{{i}} = {{ tensor }}.strides()[{{ i }}]; {% endfor -%} + {% for i in dimensions -%} + int _size_{{tensor}}_{{i}} = {{tensor}}.size({{ i }}); + {% endfor -%} {% endfor -%} {{backward_kernel}}