From dced5877b80043460e2d9754a435b4439dacb862 Mon Sep 17 00:00:00 2001
From: markus holzer <markus.holzer@fau.de>
Date: Wed, 24 Mar 2021 20:18:47 +0100
Subject: [PATCH] Add type conversion for SP types

---
 pystencils/kernelcreation.py  | 5 ++++-
 pystencils/transformations.py | 9 +++++++++
 2 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py
index b158754c..5cf43b73 100644
--- a/pystencils/kernelcreation.py
+++ b/pystencils/kernelcreation.py
@@ -12,7 +12,8 @@ from pystencils.gpucuda.indexing import indexing_creator_from_params
 from pystencils.simp.assignment_collection import AssignmentCollection
 from pystencils.stencil import direction_string_to_offset, inverse_direction_string
 from pystencils.transformations import (
-    loop_blocking, move_constants_before_loop, remove_conditionals_in_staggered_kernel)
+    loop_blocking, move_constants_before_loop, remove_conditionals_in_staggered_kernel,
+    replace_data_type_of_typed_symbols)
 
 
 def create_kernel(assignments,
@@ -88,6 +89,8 @@ def create_kernel(assignments,
             split_groups = assignments.simplification_hints['split_groups']
         assignments = assignments.all_assignments
 
+    assignments = replace_data_type_of_typed_symbols(assignments, data_type)
+
     # ----  Creating ast
     if target == 'cpu':
         from pystencils.cpu import create_kernel
diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index fc09f34e..da243f54 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -1101,6 +1101,15 @@ def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction, i
     move_constants_before_loop(function_node.body)
     cleanup_blocks(function_node.body)
 
+def replace_data_type_of_typed_symbols(assignments, data_type):
+    """changes the data types of the lhs of assignments which are already specified as TypedSymbol. This is needed
+       if the Assignments are already typed to double but the kernel is created for single precision"""
+    for i, assignment in enumerate(assignments):
+        if type(assignment.lhs) is TypedSymbol and assignment.lhs.dtype != data_type:
+            assignments[i] = Assignment(TypedSymbol(assignments[i].lhs.name, data_type), assignments[i].rhs)
+
+    return assignments
+
 
 # --------------------------------------- Helper Functions -------------------------------------------------------------
 
-- 
GitLab