diff --git a/lbmpy/boundaries_in_kernel.py b/lbmpy/boundaries_in_kernel.py index a1d44c72bde2e3081a6ee8d93bc5afc5eb7fe0ca..a4719be4baa0b5691e8797986d534f5c871d3554 100644 --- a/lbmpy/boundaries_in_kernel.py +++ b/lbmpy/boundaries_in_kernel.py @@ -6,8 +6,10 @@ from pystencils.astnodes import LoopOverCoordinate from pystencils.data_types import cast_func from pystencils.field import Field from pystencils.simp.assignment_collection import AssignmentCollection +from pystencils.simp.simplifications import sympy_cse_on_assignment_list from pystencils.stencil import inverse_direction from pystencils.sympyextensions import fast_subs +from pystencils.astnodes import Block, Conditional def direction_indices_in_direction(direction, stencil): @@ -32,29 +34,6 @@ def boundary_substitutions(lb_method): return replacements -def transformed_boundary_rule(boundary, accessor, field, direction_symbol, lb_method, **kwargs): - tmp_field = field.new_field_with_different_name("_tmp") - rule = boundary(tmp_field, direction_symbol, lb_method, **kwargs) - bsubs = boundary_substitutions(lb_method) - rule = [a.subs(bsubs) for a in rule] - accessor_writes = accessor.write(tmp_field, lb_method.stencil) - to_replace = set() - for assignment in rule: - to_replace.update({fa for fa in assignment.atoms(Field.Access) if fa.field == tmp_field}) - - def compute_replacement(fa): - f = fa.index[0] - shift = accessor_writes[f].offsets - new_index = tuple(a + b for a, b in zip(fa.offsets, shift)) - return field[new_index](accessor_writes[f].index[0]) - - substitutions = {fa: compute_replacement(fa) for fa in to_replace} - all_assignments = [assignment.subs(substitutions) for assignment in rule] - main_assignments = [a for a in all_assignments if isinstance(a.lhs, Field.Access)] - sub_expressions = [a for a in all_assignments if not isinstance(a.lhs, Field.Access)] - return AssignmentCollection(main_assignments, sub_expressions) - - def type_all_numbers(expr, dtype): substitutions = {a: cast_func(a, dtype) for a in expr.atoms(sp.Number)} return expr.subs(substitutions) @@ -83,41 +62,55 @@ def border_conditions(direction, field, ghost_layers=1): return type_all_numbers(result, loop_ctr.dtype) -def read_assignments_with_boundaries(collision_rule, pdf_field, - boundary_spec, - prev_timestep_accessor, - current_timestep_accessor): - method = collision_rule.method - result = {a: [b, a] for a, b in zip(current_timestep_accessor.read(pdf_field, method.stencil), - method.pre_collision_pdf_symbols)} +def transformed_boundary_rule(boundary, accessor_func, field, direction_symbol, lb_method, **kwargs): + tmp_field = field.new_field_with_different_name("t") + rule = boundary(tmp_field, direction_symbol, lb_method, **kwargs) + bsubs = boundary_substitutions(lb_method) + rule = [a.subs(bsubs) for a in rule] + accessor_writes = accessor_func(tmp_field, lb_method.stencil) + to_replace = set() + for assignment in rule: + to_replace.update({fa for fa in assignment.rhs.atoms(Field.Access) if fa.field == tmp_field}) + + def compute_replacement(fa): + f = fa.index[0] + shift = accessor_writes[f].offsets + new_index = tuple(a + b for a, b in zip(fa.offsets, shift)) + return field[new_index](accessor_writes[f].index[0]) + + substitutions = {fa: compute_replacement(fa) for fa in to_replace} + all_assignments = [assignment.subs(substitutions) for assignment in rule] + main_assignments = [a for a in all_assignments if isinstance(a.lhs, Field.Access)] + sub_expressions = [a for a in all_assignments if not isinstance(a.lhs, Field.Access)] + assert len(main_assignments) == 1 + ac = AssignmentCollection(main_assignments, sub_expressions).new_without_subexpressions() + return ac.main_assignments[0].rhs + + +def read_assignments_with_boundaries(method, pdf_field, boundary_spec, pre_stream_access, read_access): + stencil = method.stencil + reads = [Assignment(*v) for v in zip(method.pre_collision_pdf_symbols, + read_access(pdf_field, method.stencil))] for direction, boundary in boundary_spec.items(): dir_indices = direction_indices_in_direction(direction, method.stencil) border_cond = border_conditions(direction, pdf_field, ghost_layers=1) for dir_index in dir_indices: - ac = transformed_boundary_rule(boundary, prev_timestep_accessor, pdf_field, dir_index, - method, index_field=None) - assignments = ac.new_without_subexpressions().main_assignments - assert len(assignments) == 1 - assignment = assignments[0] - assert assignment.lhs in result - - value_without_boundary = result[assignment.lhs][1] - result[assignment.lhs][1] = sp.Piecewise((assignment.rhs, border_cond), - (value_without_boundary, True)) + inv_index = stencil.index(inverse_direction(stencil[dir_index])) + value_from_boundary = transformed_boundary_rule(boundary, pre_stream_access, pdf_field, dir_index, + method, index_field=None) + value_without_boundary = reads[inv_index].rhs + new_rhs = sp.Piecewise((value_from_boundary, border_cond), + (value_without_boundary, True)) + reads[inv_index] = Assignment(reads[inv_index].lhs, new_rhs) - return AssignmentCollection([Assignment(*e) for e in result.values()]) + return AssignmentCollection(reads) def update_rule_with_boundaries(collision_rule, input_field, output_field, - boundaries, - accessor, prev_accessor=None): - if prev_accessor is None: - prev_accessor = accessor - - reads = read_assignments_with_boundaries(collision_rule, input_field, boundaries, - prev_timestep_accessor=prev_accessor, - current_timestep_accessor=accessor) + boundaries, accessor, pre_stream_access): + reads = read_assignments_with_boundaries(collision_rule.method, input_field, boundaries, + pre_stream_access, accessor.read) write_substitutions = {} method = collision_rule.method @@ -144,3 +137,47 @@ def update_rule_with_boundaries(collision_rule, input_field, output_field, result.simplification_hints['split_groups'] = new_split_groups return result + + +def boundary_conditional(boundary, direction, read_of_next_accessor, lb_method, output_field, cse=False, **kwargs): + stencil = lb_method.stencil + tmp_field = output_field.new_field_with_different_name("t") + + dir_indices = direction_indices_in_direction(direction, stencil) + + assignments = [] + for direction_idx in dir_indices: + rule = boundary(tmp_field, direction_idx, lb_method, **kwargs) + boundary_subs = boundary_substitutions(lb_method) + rule = [a.subs(boundary_subs) for a in rule] + + rhs_substitutions = {tmp_field(i): sym for i, sym in enumerate(lb_method.post_collision_pdf_symbols)} + offset = stencil[direction_idx] + inv_offset = inverse_direction(offset) + inv_idx = stencil.index(inv_offset) + + lhs_substitutions = { + tmp_field[offset](inv_idx): read_of_next_accessor(output_field, stencil)[inv_idx]} + rule = [Assignment(a.lhs.subs(lhs_substitutions), a.rhs.subs(rhs_substitutions)) for a in rule] + ac = AssignmentCollection([rule[-1]], rule[:-1]).new_without_subexpressions() + assignments += ac.main_assignments + + border_cond = border_conditions(direction, output_field, ghost_layers=1) + if cse: + assignments = sympy_cse_on_assignment_list(assignments) + return Conditional(border_cond, Block(assignments)) + + +def update_rule_with_push_boundaries(collision_rule, field, boundary_spec, accessor, read_of_next_accessor): + if 'split_groups' in collision_rule.simplification_hints: + raise NotImplementedError("Split is not supported yet") + method = collision_rule.method + loads = [Assignment(a, b) for a, b in zip(method.pre_collision_pdf_symbols, accessor.read(field, method.stencil))] + stores = [Assignment(a, b) for a, b in + zip(accessor.write(field, method.stencil), method.post_collision_pdf_symbols)] + + result = loads + collision_rule.all_assignments + stores + for direction, boundary in boundary_spec.items(): + cond = boundary_conditional(boundary, direction, read_of_next_accessor, method, field) + result.append(cond) + return result \ No newline at end of file