diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py index b158754c8716a2d0bdae495ceda59341f945c8d8..5cf43b73c858e6b81ad4adfe93702af40a6c105a 100644 --- a/pystencils/kernelcreation.py +++ b/pystencils/kernelcreation.py @@ -12,7 +12,8 @@ from pystencils.gpucuda.indexing import indexing_creator_from_params from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.stencil import direction_string_to_offset, inverse_direction_string from pystencils.transformations import ( - loop_blocking, move_constants_before_loop, remove_conditionals_in_staggered_kernel) + loop_blocking, move_constants_before_loop, remove_conditionals_in_staggered_kernel, + replace_data_type_of_typed_symbols) def create_kernel(assignments, @@ -88,6 +89,8 @@ def create_kernel(assignments, split_groups = assignments.simplification_hints['split_groups'] assignments = assignments.all_assignments + assignments = replace_data_type_of_typed_symbols(assignments, data_type) + # ---- Creating ast if target == 'cpu': from pystencils.cpu import create_kernel diff --git a/pystencils/transformations.py b/pystencils/transformations.py index fc09f34e439c51c57b6579781619522eb166b510..da243f54b6ea6dca2de267068cfa89ee2a65b010 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -1101,6 +1101,15 @@ def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction, i move_constants_before_loop(function_node.body) cleanup_blocks(function_node.body) +def replace_data_type_of_typed_symbols(assignments, data_type): + """changes the data types of the lhs of assignments which are already specified as TypedSymbol. This is needed + if the Assignments are already typed to double but the kernel is created for single precision""" + for i, assignment in enumerate(assignments): + if type(assignment.lhs) is TypedSymbol and assignment.lhs.dtype != data_type: + assignments[i] = Assignment(TypedSymbol(assignments[i].lhs.name, data_type), assignments[i].rhs) + + return assignments + # --------------------------------------- Helper Functions -------------------------------------------------------------