Skip to content
Snippets Groups Projects
Commit fc285c14 authored by Markus Holzer's avatar Markus Holzer
Browse files

codegen.py fix: bug in the communication of corner directions

In the function comm_directions the corner directions were treated wrongly.
This fix also generalized comm_directions so it will also work for higher dimensions
parent 772fdd77
No related branches found
No related tags found
No related merge requests found
0.2.5.dev3+772fdd77df
\ No newline at end of file
...@@ -188,7 +188,7 @@ def generate_pack_info(generation_context, class_name: str, ...@@ -188,7 +188,7 @@ def generate_pack_info(generation_context, class_name: str,
fields_accessed = set() fields_accessed = set()
for terms in directions_to_pack_terms.values(): for terms in directions_to_pack_terms.values():
for term in terms: for term in terms:
assert isinstance(term, Field.Access) #and all(e == 0 for e in term.offsets) assert isinstance(term, Field.Access) # and all(e == 0 for e in term.offsets)
fields_accessed.add(term) fields_accessed.add(term)
field_names = {fa.field.name for fa in fields_accessed} field_names = {fa.field.name for fa in fields_accessed}
...@@ -252,7 +252,7 @@ def generate_pack_info(generation_context, class_name: str, ...@@ -252,7 +252,7 @@ def generate_pack_info(generation_context, class_name: str,
def generate_mpidtype_info_from_kernel(generation_context, class_name: str, def generate_mpidtype_info_from_kernel(generation_context, class_name: str,
assignments: Sequence[Assignment], kind='pull', namespace='pystencils',): assignments: Sequence[Assignment], kind='pull', namespace='pystencils', ):
assert kind in ('push', 'pull') assert kind in ('push', 'pull')
reads = set() reads = set()
writes = set() writes = set()
...@@ -345,10 +345,23 @@ def default_create_kernel_parameters(generation_context, params): ...@@ -345,10 +345,23 @@ def default_create_kernel_parameters(generation_context, params):
def comm_directions(direction): def comm_directions(direction):
yield direction if all(e == 0 for e in direction):
for i in range(len(direction)): yield direction
if direction[i] != 0: binary_numbers_list = binary_numbers(len(direction))
dir_as_list = list(direction) for comm_direction in binary_numbers_list:
dir_as_list[i] = 0 for i in range(len(direction)):
if not all(e == 0 for e in dir_as_list): if direction[i] == 0:
yield tuple(dir_as_list) comm_direction[i] = 0
if direction[i] == -1 and comm_direction[i] == 1:
comm_direction[i] = -1
if not all(e == 0 for e in comm_direction):
yield tuple(comm_direction)
def binary_numbers(n):
result = list()
for i in range(1 << n):
binary_number = bin(i)[2:]
binary_number = '0' * (n - len(binary_number)) + binary_number
result.append((list(map(int, binary_number))))
return result
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment