Skip to content
Snippets Groups Projects

[BugFix] Fix indexing with ghostlayers

Closed Markus Holzer requested to merge holzer/pystencils:FixMod into master
Files
6
from typing import Union
import sympy as sp
import sympy as sp
import pystencils.astnodes as ast
import pystencils.astnodes as ast
from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.config import CreateKernelConfig
from pystencils.config import CreateKernelConfig
from pystencils.enums import Target, Backend
from pystencils.enums import Target, Backend
from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment
from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment
@@ -18,7 +15,7 @@ from pystencils.transformations import (
@@ -18,7 +15,7 @@ from pystencils.transformations import (
resolve_field_accesses, split_inner_loop)
resolve_field_accesses, split_inner_loop)
def create_kernel(assignments: Union[NodeCollection],
def create_kernel(assignments: NodeCollection,
config: CreateKernelConfig) -> KernelFunction:
config: CreateKernelConfig) -> KernelFunction:
"""Creates an abstract syntax tree for a kernel function, by taking a list of update rules.
"""Creates an abstract syntax tree for a kernel function, by taking a list of update rules.
@@ -94,7 +91,7 @@ def create_kernel(assignments: Union[NodeCollection],
@@ -94,7 +91,7 @@ def create_kernel(assignments: Union[NodeCollection],
return ast_node
return ast_node
def create_indexed_kernel(assignments: Union[AssignmentCollection, NodeCollection],
def create_indexed_kernel(assignments: NodeCollection,
config: CreateKernelConfig) -> KernelFunction:
config: CreateKernelConfig) -> KernelFunction:
"""
"""
Similar to :func:`create_kernel`, but here not all cells of a field are updated but only cells with
Similar to :func:`create_kernel`, but here not all cells of a field are updated but only cells with
@@ -115,21 +112,24 @@ def create_indexed_kernel(assignments: Union[AssignmentCollection, NodeCollectio
@@ -115,21 +112,24 @@ def create_indexed_kernel(assignments: Union[AssignmentCollection, NodeCollectio
fields_written = assignments.bound_fields
fields_written = assignments.bound_fields
fields_read = assignments.rhs_fields
fields_read = assignments.rhs_fields
 
all_fields = fields_read.union(fields_written)
 
# extract the index fields based on the name. The original index field might have been modified
 
index_fields = [idx_field for idx_field in index_fields if idx_field.name in [f.name for f in all_fields]]
 
non_index_fields = [f for f in all_fields if f not in index_fields]
 
spatial_coordinates = {f.spatial_dimensions for f in non_index_fields}
 
assert len(spatial_coordinates) == 1, f"Non-index fields do not have the same number of spatial coordinates " \
 
f"Non index fields are {non_index_fields}, spatial coordinates are " \
 
f"{spatial_coordinates}"
 
spatial_coordinates = list(spatial_coordinates)[0]
 
assignments = assignments.all_assignments
assignments = assignments.all_assignments
assignments = add_types(assignments, config)
assignments = add_types(assignments, config)
all_fields = fields_read.union(fields_written)
for index_field in index_fields:
for index_field in index_fields:
index_field.field_type = FieldType.INDEXED
index_field.field_type = FieldType.INDEXED
assert FieldType.is_indexed(index_field)
assert FieldType.is_indexed(index_field)
assert index_field.spatial_dimensions == 1, "Index fields have to be 1D"
assert index_field.spatial_dimensions == 1, "Index fields have to be 1D"
non_index_fields = [f for f in all_fields if f not in index_fields]
spatial_coordinates = {f.spatial_dimensions for f in non_index_fields}
assert len(spatial_coordinates) == 1, "Non-index fields do not have the same number of spatial coordinates"
spatial_coordinates = list(spatial_coordinates)[0]
def get_coordinate_symbol_assignment(name):
def get_coordinate_symbol_assignment(name):
for idx_field in index_fields:
for idx_field in index_fields:
assert isinstance(idx_field.dtype, StructType), "Index fields have to have a struct data type"
assert isinstance(idx_field.dtype, StructType), "Index fields have to have a struct data type"
Loading