diff --git a/lbmpy/creationfunctions.py b/lbmpy/creationfunctions.py index 63f4b2650a22745d962ad03595dc1b3c605eba2a..d3fab3cd8a143651dcd8335e3b4937d326c1a92b 100644 --- a/lbmpy/creationfunctions.py +++ b/lbmpy/creationfunctions.py @@ -236,11 +236,6 @@ def create_lb_update_rule(collision_rule=None, optimization={}, **kwargs): lb_method = collision_rule.method - if params['output'] and params['kernel_type'] == 'stream_pull_collide': - cqc = lb_method.conserved_quantity_computation - output_eqs = cqc.output_equations_from_pdfs(lb_method.pre_collision_pdf_symbols, params['output']) - collision_rule = collision_rule.new_merged(output_eqs) - field_data_type = 'float64' if opt_params['double_precision'] else 'float32' q = len(collision_rule.method.stencil) @@ -268,13 +263,11 @@ def create_lb_update_rule(collision_rule=None, optimization={}, **kwargs): if any(opt_params['builtin_periodicity']): accessor = PeriodicTwoFieldsAccessor(opt_params['builtin_periodicity'], ghost_layers=1) return create_lbm_kernel(collision_rule, src_field, dst_field, accessor) - elif params['kernel_type'] == 'collide_only': - result = create_lbm_kernel(collision_rule, src_field, src_field, CollideOnlyInplaceAccessor()) - return add_subexpressions_for_field_reads(result, subexpressions=False, main_assignments=True) elif params['kernel_type'] == 'stream_pull_only': return create_stream_pull_with_output_kernel(lb_method, src_field, dst_field, params['output']) else: kernel_type_to_accessor = { + 'collide_only': CollideOnlyInplaceAccessor, 'collide_stream_push': StreamPushTwoFieldsAccessor, 'esotwist_even': EsoTwistEvenTimeStepAccessor, 'esotwist_odd': EsoTwistOddTimeStepAccessor, @@ -350,6 +343,11 @@ def create_lb_collision_rule(lb_method=None, optimization={}, **kwargs): from pystencils.simp import sympy_cse collision_rule = sympy_cse(collision_rule) + if params['output'] and params['kernel_type'] == 'stream_pull_collide': + cqc = lb_method.conserved_quantity_computation + output_eqs = cqc.output_equations_from_pdfs(lb_method.pre_collision_pdf_symbols, params['output']) + collision_rule = collision_rule.new_merged(output_eqs) + return collision_rule