From 374d0fd5a54cd84b7f862e79b17a7e4e4b12badb Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Wed, 7 Aug 2019 15:20:22 +0200
Subject: [PATCH] Make Torch CUDA compilation pass

---
 .../backends/torch_native_cuda.tmpl.cu        | 52 ++++++++++++-------
 1 file changed, 34 insertions(+), 18 deletions(-)

diff --git a/src/pystencils_autodiff/backends/torch_native_cuda.tmpl.cu b/src/pystencils_autodiff/backends/torch_native_cuda.tmpl.cu
index 5edd28b..bca5d19 100644
--- a/src/pystencils_autodiff/backends/torch_native_cuda.tmpl.cu
+++ b/src/pystencils_autodiff/backends/torch_native_cuda.tmpl.cu
@@ -20,10 +20,12 @@ __global__ void {{ kernel_name }}_cuda_forward_kernel(
         {% for i in range(tensor.spatial_dimensions )-%}
         int _stride_{{ tensor.name }}_{{ i }} {{- ", " }} 
         {% endfor -%} 
+        {% for i in range(tensor.spatial_dimensions )-%}
+        int _size_{{ tensor.name }}_{{ i }} {{- ", " }} 
+        {% endfor -%} 
         {% endfor -%}
-        {% for i in range(forward_output_tensors[0].spatial_dimensions )-%}
-        int _size_{{ forward_output_tensors[0] }}_{{ i }} {{- "," if not loop.last }}
-        {% endfor %})
+        int _unused
+        )
 {
     {{forward_kernel}}
 }
@@ -36,10 +38,12 @@ __global__ void {{ kernel_name }}_cuda_backward_kernel(
         {% for i in range(tensor.spatial_dimensions )-%}
         int _stride_{{ tensor.name }}_{{ i }} {{- ", " }}
         {% endfor -%}
+        {% for i in range(tensor.spatial_dimensions )-%}
+        int _size_{{ tensor.name }}_{{ i }} {{- ", " }}
+        {% endfor -%}
         {% endfor -%}
-        {% for i in range(forward_output_tensors[0].spatial_dimensions )-%}
-        int _size_{{ forward_output_tensors[0].name }}_{{ i }} {{- "," if not loop.last }}
-        {% endfor %})
+        int _unused
+      )
 {
     {{backward_kernel}}
 }
@@ -50,9 +54,14 @@ void {{ kernel_name }}_cuda_forward(
     {%- endfor -%})
 {
 
-    {% for i in range(forward_output_tensors[0].spatial_dimensions )-%}
-    int _size_{{ forward_output_tensors[0].name }}_{{ i }} = {{ forward_output_tensors[0].name }}.size({{ i }});
-    {% endfor %}
+    {% for tensor in forward_tensors -%}
+    {% for i in dimensions -%}
+    int _stride_{{tensor}}_{{i}} = {{tensor}}.strides()[{{ i }}];
+    {% endfor -%}
+    {% for i in dimensions -%}
+    int _size_{{tensor}}_{{i}} = {{tensor}}.size({{ i }});
+    {% endfor -%}
+    {% endfor -%}
 
 /*at:: at::device(at::kCUDA).dtype(k{{ dtype }})*/
     AT_DISPATCH_FLOATING_TYPES({{ forward_input_tensors[0].name }}.type(), "{{ kernel_name }}_forward_cuda", ([&] {
@@ -63,10 +72,11 @@ void {{ kernel_name }}_cuda_forward(
                         {% for i in range(tensor.spatial_dimensions) -%}
                         {{tensor.name}}.strides()[{{ i }}] {{- "," }}
                         {% endfor -%}
+                        {% for i in range(tensor.spatial_dimensions) -%}
+                        {{tensor.name}}.size({{ i }}) {{- "," }}
+                        {% endfor -%}
                         {% endfor -%}
-                        {% for i in range(forward_output_tensors[0].spatial_dimensions) -%}
-                        {{ forward_output_tensors[0].name }}.size({{ i }}) {{- "," if not loop.last }}
-                        {% endfor %}
+                        0
                         );
                 }));
      cudaError_t err = cudaGetLastError();
@@ -82,9 +92,14 @@ void {{ kernel_name }}_cuda_backward(
     {%- endfor %})
 {
 
-    {% for i in range(backward_output_tensors[0].spatial_dimensions )-%}
-    int _size_{{ backward_output_tensors[0].name }}_{{ i }} = {{ backward_output_tensors[0].name }}.size({{ i }});
-    {% endfor %}
+    {% for tensor in backward_tensors -%}
+    {% for i in dimensions -%}
+    int _stride_{{tensor}}_{{i}} = {{tensor}}.strides()[{{ i }}];
+    {% endfor -%}
+    {% for i in dimensions -%}
+    int _size_{{tensor}}_{{i}} = {{tensor}}.size({{ i }});
+    {% endfor -%}
+    {% endfor -%}
 
 /*at:: at::device(at::kCUDA).dtype(k{{ dtype }})*/
     AT_DISPATCH_FLOATING_TYPES({{ backward_input_tensors[0].name }}.type(), "{{ kernel_name }}_backward_cuda", ([&] {
@@ -95,10 +110,11 @@ void {{ kernel_name }}_cuda_backward(
                         {% for i in range(tensor.spatial_dimensions )-%}
                         {{tensor.name}}.strides()[{{ i }}]{{- ", " }}
                         {% endfor -%}
+                        {% for i in range(tensor.spatial_dimensions )-%}
+                        {{tensor.name}}.size({{ i }}){{- ", " }}
+                        {% endfor -%}
                         {% endfor -%}
-                        {% for i in range(backward_output_tensors[0].spatial_dimensions )-%}
-                        {{ backward_output_tensors[0].name }}.size({{ i }}) {{- "," if not loop.last }}
-                        {% endfor %}
+                        0
                         );
                 }));
 
-- 
GitLab