Skip to content
Snippets Groups Projects
Commit 4179a44c authored by Martin Bauer's avatar Martin Bauer
Browse files

In-kernel boundaries using push

parent 183729f9
Branches
Tags
No related merge requests found
......@@ -6,8 +6,10 @@ from pystencils.astnodes import LoopOverCoordinate
from pystencils.data_types import cast_func
from pystencils.field import Field
from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.simp.simplifications import sympy_cse_on_assignment_list
from pystencils.stencil import inverse_direction
from pystencils.sympyextensions import fast_subs
from pystencils.astnodes import Block, Conditional
def direction_indices_in_direction(direction, stencil):
......@@ -32,29 +34,6 @@ def boundary_substitutions(lb_method):
return replacements
def transformed_boundary_rule(boundary, accessor, field, direction_symbol, lb_method, **kwargs):
tmp_field = field.new_field_with_different_name("_tmp")
rule = boundary(tmp_field, direction_symbol, lb_method, **kwargs)
bsubs = boundary_substitutions(lb_method)
rule = [a.subs(bsubs) for a in rule]
accessor_writes = accessor.write(tmp_field, lb_method.stencil)
to_replace = set()
for assignment in rule:
to_replace.update({fa for fa in assignment.atoms(Field.Access) if fa.field == tmp_field})
def compute_replacement(fa):
f = fa.index[0]
shift = accessor_writes[f].offsets
new_index = tuple(a + b for a, b in zip(fa.offsets, shift))
return field[new_index](accessor_writes[f].index[0])
substitutions = {fa: compute_replacement(fa) for fa in to_replace}
all_assignments = [assignment.subs(substitutions) for assignment in rule]
main_assignments = [a for a in all_assignments if isinstance(a.lhs, Field.Access)]
sub_expressions = [a for a in all_assignments if not isinstance(a.lhs, Field.Access)]
return AssignmentCollection(main_assignments, sub_expressions)
def type_all_numbers(expr, dtype):
substitutions = {a: cast_func(a, dtype) for a in expr.atoms(sp.Number)}
return expr.subs(substitutions)
......@@ -83,41 +62,55 @@ def border_conditions(direction, field, ghost_layers=1):
return type_all_numbers(result, loop_ctr.dtype)
def read_assignments_with_boundaries(collision_rule, pdf_field,
boundary_spec,
prev_timestep_accessor,
current_timestep_accessor):
method = collision_rule.method
result = {a: [b, a] for a, b in zip(current_timestep_accessor.read(pdf_field, method.stencil),
method.pre_collision_pdf_symbols)}
def transformed_boundary_rule(boundary, accessor_func, field, direction_symbol, lb_method, **kwargs):
tmp_field = field.new_field_with_different_name("t")
rule = boundary(tmp_field, direction_symbol, lb_method, **kwargs)
bsubs = boundary_substitutions(lb_method)
rule = [a.subs(bsubs) for a in rule]
accessor_writes = accessor_func(tmp_field, lb_method.stencil)
to_replace = set()
for assignment in rule:
to_replace.update({fa for fa in assignment.rhs.atoms(Field.Access) if fa.field == tmp_field})
def compute_replacement(fa):
f = fa.index[0]
shift = accessor_writes[f].offsets
new_index = tuple(a + b for a, b in zip(fa.offsets, shift))
return field[new_index](accessor_writes[f].index[0])
substitutions = {fa: compute_replacement(fa) for fa in to_replace}
all_assignments = [assignment.subs(substitutions) for assignment in rule]
main_assignments = [a for a in all_assignments if isinstance(a.lhs, Field.Access)]
sub_expressions = [a for a in all_assignments if not isinstance(a.lhs, Field.Access)]
assert len(main_assignments) == 1
ac = AssignmentCollection(main_assignments, sub_expressions).new_without_subexpressions()
return ac.main_assignments[0].rhs
def read_assignments_with_boundaries(method, pdf_field, boundary_spec, pre_stream_access, read_access):
stencil = method.stencil
reads = [Assignment(*v) for v in zip(method.pre_collision_pdf_symbols,
read_access(pdf_field, method.stencil))]
for direction, boundary in boundary_spec.items():
dir_indices = direction_indices_in_direction(direction, method.stencil)
border_cond = border_conditions(direction, pdf_field, ghost_layers=1)
for dir_index in dir_indices:
ac = transformed_boundary_rule(boundary, prev_timestep_accessor, pdf_field, dir_index,
method, index_field=None)
assignments = ac.new_without_subexpressions().main_assignments
assert len(assignments) == 1
assignment = assignments[0]
assert assignment.lhs in result
value_without_boundary = result[assignment.lhs][1]
result[assignment.lhs][1] = sp.Piecewise((assignment.rhs, border_cond),
(value_without_boundary, True))
inv_index = stencil.index(inverse_direction(stencil[dir_index]))
value_from_boundary = transformed_boundary_rule(boundary, pre_stream_access, pdf_field, dir_index,
method, index_field=None)
value_without_boundary = reads[inv_index].rhs
new_rhs = sp.Piecewise((value_from_boundary, border_cond),
(value_without_boundary, True))
reads[inv_index] = Assignment(reads[inv_index].lhs, new_rhs)
return AssignmentCollection([Assignment(*e) for e in result.values()])
return AssignmentCollection(reads)
def update_rule_with_boundaries(collision_rule, input_field, output_field,
boundaries,
accessor, prev_accessor=None):
if prev_accessor is None:
prev_accessor = accessor
reads = read_assignments_with_boundaries(collision_rule, input_field, boundaries,
prev_timestep_accessor=prev_accessor,
current_timestep_accessor=accessor)
boundaries, accessor, pre_stream_access):
reads = read_assignments_with_boundaries(collision_rule.method, input_field, boundaries,
pre_stream_access, accessor.read)
write_substitutions = {}
method = collision_rule.method
......@@ -144,3 +137,47 @@ def update_rule_with_boundaries(collision_rule, input_field, output_field,
result.simplification_hints['split_groups'] = new_split_groups
return result
def boundary_conditional(boundary, direction, read_of_next_accessor, lb_method, output_field, cse=False, **kwargs):
stencil = lb_method.stencil
tmp_field = output_field.new_field_with_different_name("t")
dir_indices = direction_indices_in_direction(direction, stencil)
assignments = []
for direction_idx in dir_indices:
rule = boundary(tmp_field, direction_idx, lb_method, **kwargs)
boundary_subs = boundary_substitutions(lb_method)
rule = [a.subs(boundary_subs) for a in rule]
rhs_substitutions = {tmp_field(i): sym for i, sym in enumerate(lb_method.post_collision_pdf_symbols)}
offset = stencil[direction_idx]
inv_offset = inverse_direction(offset)
inv_idx = stencil.index(inv_offset)
lhs_substitutions = {
tmp_field[offset](inv_idx): read_of_next_accessor(output_field, stencil)[inv_idx]}
rule = [Assignment(a.lhs.subs(lhs_substitutions), a.rhs.subs(rhs_substitutions)) for a in rule]
ac = AssignmentCollection([rule[-1]], rule[:-1]).new_without_subexpressions()
assignments += ac.main_assignments
border_cond = border_conditions(direction, output_field, ghost_layers=1)
if cse:
assignments = sympy_cse_on_assignment_list(assignments)
return Conditional(border_cond, Block(assignments))
def update_rule_with_push_boundaries(collision_rule, field, boundary_spec, accessor, read_of_next_accessor):
if 'split_groups' in collision_rule.simplification_hints:
raise NotImplementedError("Split is not supported yet")
method = collision_rule.method
loads = [Assignment(a, b) for a, b in zip(method.pre_collision_pdf_symbols, accessor.read(field, method.stencil))]
stores = [Assignment(a, b) for a, b in
zip(accessor.write(field, method.stencil), method.post_collision_pdf_symbols)]
result = loads + collision_rule.all_assignments + stores
for direction, boundary in boundary_spec.items():
cond = boundary_conditional(boundary, direction, read_of_next_accessor, method, field)
result.append(cond)
return result
\ No newline at end of file
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment