diff --git a/src/pystencils_autodiff/backends/torch_native_cpu.tmpl.cpp b/src/pystencils_autodiff/backends/torch_native_cpu.tmpl.cpp
index d7ee774e03061e4d878d93cc1dfca8f7306beb4b..74eec9ee424e3955453dc52b3516d8cdab82d8da 100644
--- a/src/pystencils_autodiff/backends/torch_native_cpu.tmpl.cpp
+++ b/src/pystencils_autodiff/backends/torch_native_cpu.tmpl.cpp
@@ -18,16 +18,15 @@ std::vector<at::Tensor> {{ kernel_name }}_forward(
     //auto {{tensor}} = at::zeros_like({{ forward_input_tensors[0] }});
     //{% endfor %}
 
-    {% for i in dimensions -%}
-    int _size_{{ forward_tensors[0] }}_{{ i }} = {{ forward_tensors[0] }}.size({{ i }});
-    {% endfor %}
-
     {% for tensor in forward_tensors -%}
     {%- set last = loop.last -%}
     scalar_t* _data_{{ tensor }} = {{ tensor }}.data<scalar_t>();
     {% for i in dimensions -%}
     int _stride_{{tensor}}_{{i}} = {{tensor}}.strides()[{{ i }}];
     {% endfor -%}
+    {% for i in dimensions -%}
+    int _size_{{tensor}}_{{i}} = {{tensor}}.size({{ i }});
+    {% endfor -%}
     {% endfor -%}
 
     {{forward_kernel}}
@@ -54,6 +53,9 @@ std::vector<at::Tensor> {{ kernel_name }}_backward(
     {% for i in dimensions -%}
     int _stride_{{ tensor }}_{{i}} = {{ tensor }}.strides()[{{ i }}];
     {% endfor -%}
+    {% for i in dimensions -%}
+    int _size_{{tensor}}_{{i}} = {{tensor}}.size({{ i }});
+    {% endfor -%}
     {% endfor -%}
 
     {{backward_kernel}}