diff --git a/lbmpy/methods/momentbased/moment_transforms.py b/lbmpy/methods/momentbased/moment_transforms.py index 83e85805a8ab78b33d51a7b9da34897c67044028..e31d2e6e440f09258e3d7321c04e4fcf867d4040 100644 --- a/lbmpy/methods/momentbased/moment_transforms.py +++ b/lbmpy/methods/momentbased/moment_transforms.py @@ -142,7 +142,7 @@ class PdfsToCentralMomentsByMatrix(AbstractMomentTransform): central_moments = self.forward_matrix * f_vec main_assignments = [Assignment(sq_sym(moment_symbol_base, e), eq) for e, eq in zip(self.moment_exponents, central_moments)] - symbol_gen = SymbolGen(subexpression_base, dtype=float) + symbol_gen = SymbolGen(subexpression_base) ac = AssignmentCollection(main_assignments, subexpression_symbol_generator=symbol_gen) if simplification: @@ -157,7 +157,7 @@ class PdfsToCentralMomentsByMatrix(AbstractMomentTransform): moment_vec = sp.Matrix(moments) pdfs_from_moments = self.backward_matrix * moment_vec main_assignments = [Assignment(f, eq) for f, eq in zip(pdf_symbols, pdfs_from_moments)] - symbol_gen = SymbolGen(subexpression_base, dtype=float) + symbol_gen = SymbolGen(subexpression_base) ac = AssignmentCollection(main_assignments, subexpression_symbol_generator=symbol_gen) if simplification: @@ -228,7 +228,7 @@ class FastCentralMomentTransform(AbstractMomentTransform): collect_partial_sums(e) subexpressions = [Assignment(lhs, rhs) for lhs, rhs in subexpressions_dict.items()] - symbol_gen = SymbolGen(subexpression_base, dtype=float) + symbol_gen = SymbolGen(subexpression_base) ac = AssignmentCollection(main_assignments, subexpressions=subexpressions, subexpression_symbol_generator=symbol_gen) if simplification: @@ -244,7 +244,7 @@ class FastCentralMomentTransform(AbstractMomentTransform): pdf_symbols, moment_symbol_base=POST_COLLISION_CENTRAL_MOMENT, simplification=False) raw_equations = raw_equations.new_without_subexpressions() - symbol_gen = SymbolGen(subexpression_base, dtype=float) + symbol_gen = SymbolGen(subexpression_base) ac = self._split_backward_equations(raw_equations, symbol_gen) if simplification: @@ -441,7 +441,7 @@ class PdfsToRawMomentsTransform(AbstractMomentTransform): collect_partial_sums(e) subexpressions += [Assignment(lhs, rhs) for lhs, rhs in partial_sums_dict.items()] - symbol_gen = SymbolGen(subexpression_base, dtype=float) + symbol_gen = SymbolGen(subexpression_base) ac = AssignmentCollection(main_assignments, subexpressions=subexpressions, subexpression_symbol_generator=symbol_gen) ac.add_simplification_hint('cq_symbols_to_moments', self.get_cq_to_moment_symbols_dict(moment_symbol_base)) @@ -457,7 +457,7 @@ class PdfsToRawMomentsTransform(AbstractMomentTransform): post_collision_moments = [sq_sym(moment_symbol_base, e) for e in self.moment_exponents] rm_to_f_vec = self.inv_moment_matrix * sp.Matrix(post_collision_moments) main_assignments = [Assignment(f, eq) for f, eq in zip(pdf_symbols, rm_to_f_vec)] - symbol_gen = SymbolGen(subexpression_base, dtype=float) + symbol_gen = SymbolGen(subexpression_base) ac = AssignmentCollection(main_assignments, subexpression_symbol_generator=symbol_gen) ac.add_simplification_hint('stencil', self.stencil) diff --git a/lbmpy/simplificationfactory.py b/lbmpy/simplificationfactory.py index cb62817e04b71e75059b47b857ba04437a2b6744..37599e8dc496be6fa2a43cd116d9ad0ffd677c61 100644 --- a/lbmpy/simplificationfactory.py +++ b/lbmpy/simplificationfactory.py @@ -3,7 +3,7 @@ import sympy as sp from lbmpy.innerloopsplit import create_lbm_split_groups from lbmpy.methods.momentbased.momentbasedmethod import MomentBasedLbMethod from lbmpy.methods.centeredcumulant import CenteredCumulantBasedLbMethod -from lbmpy.methods.centeredcumulant.simplification import insert_aliases, insert_zeros +from lbmpy.methods.centeredcumulant.simplification import insert_aliases, insert_zeros, insert_constants from lbmpy.methods.momentbased.momentbasedsimplifications import ( factor_density_after_factoring_relaxation_times, factor_relaxation_rates, replace_common_quadratic_and_constant_term, replace_density_and_velocity, replace_second_order_velocity_products) @@ -36,5 +36,6 @@ def create_simplification_strategy(lb_method, split_inner_loop=False): elif isinstance(lb_method, CenteredCumulantBasedLbMethod): s.add(insert_zeros) s.add(insert_aliases) + s.add(insert_constants) s.add(lambda ac: ac.new_without_unused_subexpressions()) return s