diff --git a/lbmpy_walberla/templates/LatticeModel.tmpl.h b/lbmpy_walberla/templates/LatticeModel.tmpl.h index 42241363c2124bf56df86f30da220646c110396c..be160e99e89d596f86511a97db314e35aeaaca18 100644 --- a/lbmpy_walberla/templates/LatticeModel.tmpl.h +++ b/lbmpy_walberla/templates/LatticeModel.tmpl.h @@ -146,7 +146,7 @@ private: {% endif %} blockId = &block->getId(); - {% if refinement_scaling -%} + {% if refinement_scaling_info -%} const uint_t targetLevel = block->getBlockStorage().getLevel(*block); if( targetLevel != currentLevel ) @@ -157,7 +157,7 @@ private: else // currentLevel > targetLevel level_scale_factor = real_t(1) / real_c( uint_t(1) << ( currentLevel - targetLevel ) ); - {% for scalingType, name, expression in refinement_scaling.scaling_info -%} + {% for scalingType, name, expression in refinement_scaling_info -%} {% if scalingType == 'normal' %} {{name}} = {{expression}}; {% elif scalingType in ('field_with_f', 'field_xyz') %} @@ -278,7 +278,7 @@ struct AdaptVelocityToForce<{{class_name}}, void> auto y = it.y(); auto z = it.z(); {% if macroscopic_velocity_shift %} - return velocity - Vector3<real_t>({{macroscopic_velocity_shift | join(",") }} {% if D == 2 %}, 0.0 {%endif %} ); + return velocity - Vector3<real_t>({{macroscopic_velocity_shift | join(",") }} {% if D == 2 %}, real_t(0.0) {%endif %} ); {% else %} return velocity; {% endif %} @@ -289,7 +289,7 @@ struct AdaptVelocityToForce<{{class_name}}, void> { {% if macroscopic_velocity_shift %} - return velocity - Vector3<real_t>({{macroscopic_velocity_shift | join(",") }} {% if D == 2 %}, 0.0 {%endif %} ); + return velocity - Vector3<real_t>({{macroscopic_velocity_shift | join(",") }} {% if D == 2 %}, real_t(0.0) {%endif %} ); {% else %} return velocity; {% endif %} diff --git a/lbmpy_walberla/walberla_lbm_generation.py b/lbmpy_walberla/walberla_lbm_generation.py index b78c84a078d845025c937dc7dc6b178a99d5b75d..05e98a96fe31c1fd001c400f696b4e7c143a762c 100644 --- a/lbmpy_walberla/walberla_lbm_generation.py +++ b/lbmpy_walberla/walberla_lbm_generation.py @@ -14,7 +14,7 @@ from lbmpy.updatekernels import create_lbm_kernel, create_stream_pull_only_kerne from pystencils import AssignmentCollection, create_kernel from pystencils.astnodes import SympyAssignment from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, get_headers -from pystencils.data_types import TypedSymbol +from pystencils.data_types import TypedSymbol, type_all_numbers from pystencils.field import Field from pystencils.stencil import have_same_entries, offset_to_direction_string from pystencils.sympyextensions import get_symmetric_part @@ -52,7 +52,7 @@ def __lattice_model(generation_context, class_name, lb_method, stream_collide_as macroscopic_velocity_shift = None if force_model: if hasattr(force_model, 'macroscopic_velocity_shift'): - macroscopic_velocity_shift = [expression_to_code(e, "lm.", ["rho"]) + macroscopic_velocity_shift = [expression_to_code(e, "lm.", ["rho"],dtype=dtype) for e in force_model.macroscopic_velocity_shift(rho_sym)] cqc = lb_method.conserved_quantity_computation @@ -66,6 +66,11 @@ def __lattice_model(generation_context, class_name, lb_method, stream_collide_as required_headers = get_headers(stream_collide_ast) + if refinement_scaling: + refinement_scaling_info = [ (e0,e1,expression_to_code(e2, '', dtype=dtype)) for e0,e1,e2 in refinement_scaling.scaling_info ] + else: + refinement_scaling_info = None + jinja_context = { 'class_name': class_name, 'stencil_name': stencil_name, @@ -87,7 +92,7 @@ def __lattice_model(generation_context, class_name, lb_method, stream_collide_as dtype=dtype), 'density_velocity_setter_macroscopic_values': density_velocity_setter_macroscopic_values, - 'refinement_scaling': refinement_scaling, + 'refinement_scaling_info': refinement_scaling_info, 'stream_collide_kernel': KernelInfo(stream_collide_ast, ['pdfs_tmp'], [('pdfs', 'pdfs_tmp')], []), 'collide_kernel': KernelInfo(collide_ast, [], [], []), @@ -174,15 +179,15 @@ class RefinementScaling: scaling_type = 'field_xyz' field_access = field.center expr = scaling_rule(field_access, self.level_scale_factor) - self.scaling_info.append((scaling_type, name, expression_to_code(expr, ''))) + self.scaling_info.append((scaling_type, name, expr)) elif isinstance(parameter, Field.Access): field_access = parameter expr = scaling_rule(field_access, self.level_scale_factor) name = field_access.field.name - self.scaling_info.append(('field_xyz', name, expression_to_code(expr, ''))) + self.scaling_info.append(('field_xyz', name, expr)) elif isinstance(parameter, sp.Symbol): expr = scaling_rule(parameter, self.level_scale_factor) - self.scaling_info.append(('normal', parameter.name, expression_to_code(expr, ''))) + self.scaling_info.append(('normal', parameter.name, expr)) elif isinstance(parameter, list) or isinstance(parameter, tuple): for p in parameter: self.add_scaling(p, scaling_rule) @@ -227,7 +232,7 @@ def field_and_symbol_substitute(expr, variable_prefix="lm.", variables_without_p return expr.subs(substitutions) -def expression_to_code(expr, variable_prefix="lm.", variables_without_prefix=[]): +def expression_to_code(expr, variable_prefix="lm.", variables_without_prefix=[],dtype="double"): """ Takes a sympy expression and creates a C code string from it. Replaces field accesses by walberla field accesses i.e. field_W^1 -> field->get(-1, 0, 0, 1) @@ -237,12 +242,13 @@ def expression_to_code(expr, variable_prefix="lm.", variables_without_prefix=[]) :param variables_without_prefix: this variables are not prefixed :return: code string """ - return cpp_printer.doprint(field_and_symbol_substitute(expr, variable_prefix, variables_without_prefix)) + return cpp_printer.doprint(type_expr(field_and_symbol_substitute(expr, variable_prefix, variables_without_prefix),dtype=dtype)) +def type_expr(eq, dtype): + eq=type_all_numbers(eq,dtype=dtype) + return eq.subs({s: TypedSymbol(s.name, dtype) for s in eq.atoms(sp.Symbol)}) def equations_to_code(equations, variable_prefix="lm.", variables_without_prefix=[], dtype="double"): - def type_eq(eq): - return eq.subs({s: TypedSymbol(s.name, dtype) for s in eq.atoms(sp.Symbol)}) if isinstance(equations, AssignmentCollection): equations = equations.all_assignments @@ -252,9 +258,9 @@ def equations_to_code(equations, variable_prefix="lm.", variables_without_prefix result = [] left_hand_side_names = [e.lhs.name for e in equations] for eq in equations: - assignment = SympyAssignment(type_eq(eq.lhs), - field_and_symbol_substitute(eq.rhs, variable_prefix, - variables_without_prefix + left_hand_side_names)) + assignment = SympyAssignment(type_expr(eq.lhs,dtype=dtype), + type_expr(field_and_symbol_substitute(eq.rhs, variable_prefix, + variables_without_prefix + left_hand_side_names),dtype=dtype)) result.append(c_backend(assignment)) return "\n".join(result)