From cfd615c75d61b7cce63dccd257be8004208c80ae Mon Sep 17 00:00:00 2001
From: Christoph Rettinger <christoph.rettinger@fau.de>
Date: Thu, 20 Feb 2020 19:58:19 +0100
Subject: [PATCH] Fixed floating point precision generation

---
 lbmpy_walberla/templates/LatticeModel.tmpl.h |  8 ++---
 lbmpy_walberla/walberla_lbm_generation.py    | 32 ++++++++++++--------
 2 files changed, 23 insertions(+), 17 deletions(-)

diff --git a/lbmpy_walberla/templates/LatticeModel.tmpl.h b/lbmpy_walberla/templates/LatticeModel.tmpl.h
index 4224136..be160e9 100644
--- a/lbmpy_walberla/templates/LatticeModel.tmpl.h
+++ b/lbmpy_walberla/templates/LatticeModel.tmpl.h
@@ -146,7 +146,7 @@ private:
         {% endif %}
         blockId = &block->getId();
 
-        {% if refinement_scaling -%}
+        {% if refinement_scaling_info -%}
         const uint_t targetLevel = block->getBlockStorage().getLevel(*block);
 
         if( targetLevel != currentLevel )
@@ -157,7 +157,7 @@ private:
             else // currentLevel > targetLevel
                level_scale_factor = real_t(1) / real_c( uint_t(1) << ( currentLevel - targetLevel ) );
 
-            {% for scalingType, name, expression in refinement_scaling.scaling_info -%}
+            {% for scalingType, name, expression in refinement_scaling_info -%}
             {% if scalingType == 'normal' %}
             {{name}} = {{expression}};
             {% elif scalingType in ('field_with_f', 'field_xyz') %}
@@ -278,7 +278,7 @@ struct AdaptVelocityToForce<{{class_name}}, void>
       auto y = it.y();
       auto z = it.z();
       {% if macroscopic_velocity_shift %}
-      return velocity - Vector3<real_t>({{macroscopic_velocity_shift | join(",") }} {% if D == 2 %}, 0.0 {%endif %} );
+      return velocity - Vector3<real_t>({{macroscopic_velocity_shift | join(",") }} {% if D == 2 %}, real_t(0.0) {%endif %} );
       {% else %}
       return velocity;
       {% endif %}
@@ -289,7 +289,7 @@ struct AdaptVelocityToForce<{{class_name}}, void>
    {
       {% if macroscopic_velocity_shift %}
 
-      return velocity - Vector3<real_t>({{macroscopic_velocity_shift | join(",") }} {% if D == 2 %}, 0.0 {%endif %} );
+      return velocity - Vector3<real_t>({{macroscopic_velocity_shift | join(",") }} {% if D == 2 %}, real_t(0.0) {%endif %} );
       {% else %}
       return velocity;
       {% endif %}
diff --git a/lbmpy_walberla/walberla_lbm_generation.py b/lbmpy_walberla/walberla_lbm_generation.py
index b78c84a..05e98a9 100644
--- a/lbmpy_walberla/walberla_lbm_generation.py
+++ b/lbmpy_walberla/walberla_lbm_generation.py
@@ -14,7 +14,7 @@ from lbmpy.updatekernels import create_lbm_kernel, create_stream_pull_only_kerne
 from pystencils import AssignmentCollection, create_kernel
 from pystencils.astnodes import SympyAssignment
 from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, get_headers
-from pystencils.data_types import TypedSymbol
+from pystencils.data_types import TypedSymbol, type_all_numbers
 from pystencils.field import Field
 from pystencils.stencil import have_same_entries, offset_to_direction_string
 from pystencils.sympyextensions import get_symmetric_part
@@ -52,7 +52,7 @@ def __lattice_model(generation_context, class_name, lb_method, stream_collide_as
     macroscopic_velocity_shift = None
     if force_model:
         if hasattr(force_model, 'macroscopic_velocity_shift'):
-            macroscopic_velocity_shift = [expression_to_code(e, "lm.", ["rho"])
+            macroscopic_velocity_shift = [expression_to_code(e, "lm.", ["rho"],dtype=dtype)
                                           for e in force_model.macroscopic_velocity_shift(rho_sym)]
 
     cqc = lb_method.conserved_quantity_computation
@@ -66,6 +66,11 @@ def __lattice_model(generation_context, class_name, lb_method, stream_collide_as
 
     required_headers = get_headers(stream_collide_ast)
 
+    if refinement_scaling:
+        refinement_scaling_info = [ (e0,e1,expression_to_code(e2, '', dtype=dtype)) for e0,e1,e2 in refinement_scaling.scaling_info ]
+    else:
+        refinement_scaling_info = None
+
     jinja_context = {
         'class_name': class_name,
         'stencil_name': stencil_name,
@@ -87,7 +92,7 @@ def __lattice_model(generation_context, class_name, lb_method, stream_collide_as
                                                      dtype=dtype),
         'density_velocity_setter_macroscopic_values': density_velocity_setter_macroscopic_values,
 
-        'refinement_scaling': refinement_scaling,
+        'refinement_scaling_info': refinement_scaling_info,
 
         'stream_collide_kernel': KernelInfo(stream_collide_ast, ['pdfs_tmp'], [('pdfs', 'pdfs_tmp')], []),
         'collide_kernel': KernelInfo(collide_ast, [], [], []),
@@ -174,15 +179,15 @@ class RefinementScaling:
                 scaling_type = 'field_xyz'
                 field_access = field.center
             expr = scaling_rule(field_access, self.level_scale_factor)
-            self.scaling_info.append((scaling_type, name, expression_to_code(expr, '')))
+            self.scaling_info.append((scaling_type, name, expr))
         elif isinstance(parameter, Field.Access):
             field_access = parameter
             expr = scaling_rule(field_access, self.level_scale_factor)
             name = field_access.field.name
-            self.scaling_info.append(('field_xyz', name, expression_to_code(expr, '')))
+            self.scaling_info.append(('field_xyz', name, expr))
         elif isinstance(parameter, sp.Symbol):
             expr = scaling_rule(parameter, self.level_scale_factor)
-            self.scaling_info.append(('normal', parameter.name, expression_to_code(expr, '')))
+            self.scaling_info.append(('normal', parameter.name, expr))
         elif isinstance(parameter, list) or isinstance(parameter, tuple):
             for p in parameter:
                 self.add_scaling(p, scaling_rule)
@@ -227,7 +232,7 @@ def field_and_symbol_substitute(expr, variable_prefix="lm.", variables_without_p
     return expr.subs(substitutions)
 
 
-def expression_to_code(expr, variable_prefix="lm.", variables_without_prefix=[]):
+def expression_to_code(expr, variable_prefix="lm.", variables_without_prefix=[],dtype="double"):
     """
     Takes a sympy expression and creates a C code string from it. Replaces field accesses by
     walberla field accesses i.e. field_W^1 -> field->get(-1, 0, 0, 1)
@@ -237,12 +242,13 @@ def expression_to_code(expr, variable_prefix="lm.", variables_without_prefix=[])
     :param variables_without_prefix: this variables are not prefixed
     :return: code string
     """
-    return cpp_printer.doprint(field_and_symbol_substitute(expr, variable_prefix, variables_without_prefix))
+    return cpp_printer.doprint(type_expr(field_and_symbol_substitute(expr, variable_prefix, variables_without_prefix),dtype=dtype))
 
+def type_expr(eq, dtype):
+    eq=type_all_numbers(eq,dtype=dtype)
+    return eq.subs({s: TypedSymbol(s.name, dtype) for s in eq.atoms(sp.Symbol)})
 
 def equations_to_code(equations, variable_prefix="lm.", variables_without_prefix=[], dtype="double"):
-    def type_eq(eq):
-        return eq.subs({s: TypedSymbol(s.name, dtype) for s in eq.atoms(sp.Symbol)})
 
     if isinstance(equations, AssignmentCollection):
         equations = equations.all_assignments
@@ -252,9 +258,9 @@ def equations_to_code(equations, variable_prefix="lm.", variables_without_prefix
     result = []
     left_hand_side_names = [e.lhs.name for e in equations]
     for eq in equations:
-        assignment = SympyAssignment(type_eq(eq.lhs),
-                                     field_and_symbol_substitute(eq.rhs, variable_prefix,
-                                                                 variables_without_prefix + left_hand_side_names))
+        assignment = SympyAssignment(type_expr(eq.lhs,dtype=dtype),
+                                     type_expr(field_and_symbol_substitute(eq.rhs, variable_prefix,
+                                                                 variables_without_prefix + left_hand_side_names),dtype=dtype))
         result.append(c_backend(assignment))
     return "\n".join(result)
 
-- 
GitLab