diff --git a/lbmpy/boundaries_in_kernel.py b/lbmpy/boundaries/boundaries_in_kernel.py similarity index 63% rename from lbmpy/boundaries_in_kernel.py rename to lbmpy/boundaries/boundaries_in_kernel.py index a4719be4baa0b5691e8797986d534f5c871d3554..9be2062643d456f1eddbe7919244dae62300ab97 100644 --- a/lbmpy/boundaries_in_kernel.py +++ b/lbmpy/boundaries/boundaries_in_kernel.py @@ -2,14 +2,13 @@ import sympy as sp from lbmpy.boundaries.boundaryhandling import BoundaryOffsetInfo, LbmWeightInfo from pystencils.assignment import Assignment -from pystencils.astnodes import LoopOverCoordinate -from pystencils.data_types import cast_func +from pystencils.astnodes import Block, Conditional, LoopOverCoordinate, SympyAssignment +from pystencils.data_types import type_all_numbers from pystencils.field import Field from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.simp.simplifications import sympy_cse_on_assignment_list from pystencils.stencil import inverse_direction from pystencils.sympyextensions import fast_subs -from pystencils.astnodes import Block, Conditional def direction_indices_in_direction(direction, stencil): @@ -34,11 +33,6 @@ def boundary_substitutions(lb_method): return replacements -def type_all_numbers(expr, dtype): - substitutions = {a: cast_func(a, dtype) for a in expr.atoms(sp.Number)} - return expr.subs(substitutions) - - def border_conditions(direction, field, ghost_layers=1): abs_direction = tuple(-e if e < 0 else e for e in direction) assert sum(abs_direction) == 1 @@ -87,59 +81,7 @@ def transformed_boundary_rule(boundary, accessor_func, field, direction_symbol, return ac.main_assignments[0].rhs -def read_assignments_with_boundaries(method, pdf_field, boundary_spec, pre_stream_access, read_access): - stencil = method.stencil - reads = [Assignment(*v) for v in zip(method.pre_collision_pdf_symbols, - read_access(pdf_field, method.stencil))] - - for direction, boundary in boundary_spec.items(): - dir_indices = direction_indices_in_direction(direction, method.stencil) - border_cond = border_conditions(direction, pdf_field, ghost_layers=1) - for dir_index in dir_indices: - inv_index = stencil.index(inverse_direction(stencil[dir_index])) - value_from_boundary = transformed_boundary_rule(boundary, pre_stream_access, pdf_field, dir_index, - method, index_field=None) - value_without_boundary = reads[inv_index].rhs - new_rhs = sp.Piecewise((value_from_boundary, border_cond), - (value_without_boundary, True)) - reads[inv_index] = Assignment(reads[inv_index].lhs, new_rhs) - - return AssignmentCollection(reads) - - -def update_rule_with_boundaries(collision_rule, input_field, output_field, - boundaries, accessor, pre_stream_access): - reads = read_assignments_with_boundaries(collision_rule.method, input_field, boundaries, - pre_stream_access, accessor.read) - - write_substitutions = {} - method = collision_rule.method - post_collision_symbols = method.post_collision_pdf_symbols - pre_collision_symbols = method.pre_collision_pdf_symbols - - output_accesses = accessor.write(output_field, method.stencil) - input_accesses = accessor.read(input_field, method.stencil) - - for (idx, offset), output_access in zip(enumerate(method.stencil), output_accesses): - write_substitutions[post_collision_symbols[idx]] = output_access - - result = collision_rule.new_with_substitutions(write_substitutions) - result.subexpressions = reads.all_assignments + result.subexpressions - - if 'split_groups' in result.simplification_hints: - all_substitutions = write_substitutions.copy() - for (idx, offset), input_access in zip(enumerate(method.stencil), input_accesses): - all_substitutions[pre_collision_symbols[idx]] = input_access - - new_split_groups = [] - for split_group in result.simplification_hints['split_groups']: - new_split_groups.append([fast_subs(e, all_substitutions) for e in split_group]) - result.simplification_hints['split_groups'] = new_split_groups - - return result - - -def boundary_conditional(boundary, direction, read_of_next_accessor, lb_method, output_field, cse=False, **kwargs): +def boundary_conditional(boundary, direction, read_of_next_accessor, lb_method, output_field, cse=False): stencil = lb_method.stencil tmp_field = output_field.new_field_with_different_name("t") @@ -147,7 +89,7 @@ def boundary_conditional(boundary, direction, read_of_next_accessor, lb_method, assignments = [] for direction_idx in dir_indices: - rule = boundary(tmp_field, direction_idx, lb_method, **kwargs) + rule = boundary(tmp_field, direction_idx, lb_method, index_field=None) boundary_subs = boundary_substitutions(lb_method) rule = [a.subs(boundary_subs) for a in rule] @@ -165,19 +107,28 @@ def boundary_conditional(boundary, direction, read_of_next_accessor, lb_method, border_cond = border_conditions(direction, output_field, ghost_layers=1) if cse: assignments = sympy_cse_on_assignment_list(assignments) + assignments = [SympyAssignment(a.lhs, a.rhs) for a in assignments] return Conditional(border_cond, Block(assignments)) def update_rule_with_push_boundaries(collision_rule, field, boundary_spec, accessor, read_of_next_accessor): - if 'split_groups' in collision_rule.simplification_hints: - raise NotImplementedError("Split is not supported yet") method = collision_rule.method loads = [Assignment(a, b) for a, b in zip(method.pre_collision_pdf_symbols, accessor.read(field, method.stencil))] stores = [Assignment(a, b) for a, b in zip(accessor.write(field, method.stencil), method.post_collision_pdf_symbols)] - result = loads + collision_rule.all_assignments + stores + result = collision_rule.copy() + result.subexpressions = loads + result.subexpressions + result.main_assignments += stores for direction, boundary in boundary_spec.items(): cond = boundary_conditional(boundary, direction, read_of_next_accessor, method, field) - result.append(cond) - return result \ No newline at end of file + result.main_assignments.append(cond) + + if 'split_groups' in result.simplification_hints: + substitutions = {b: a for a, b in zip(accessor.write(field, method.stencil), method.post_collision_pdf_symbols)} + new_split_groups = [] + for split_group in result.simplification_hints['split_groups']: + new_split_groups.append([fast_subs(e, substitutions) for e in split_group]) + result.simplification_hints['split_groups'] = new_split_groups + + return result