Skip to content
Snippets Groups Projects

[BugFix] Fix indexing with ghostlayers

Closed Markus Holzer requested to merge holzer/pystencils:FixMod into master
Files
6
from typing import Union
import sympy as sp
import pystencils.astnodes as ast
from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.config import CreateKernelConfig
from pystencils.enums import Target, Backend
from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment
@@ -18,7 +15,7 @@ from pystencils.transformations import (
resolve_field_accesses, split_inner_loop)
def create_kernel(assignments: Union[NodeCollection],
def create_kernel(assignments: NodeCollection,
config: CreateKernelConfig) -> KernelFunction:
"""Creates an abstract syntax tree for a kernel function, by taking a list of update rules.
@@ -94,7 +91,7 @@ def create_kernel(assignments: Union[NodeCollection],
return ast_node
def create_indexed_kernel(assignments: Union[AssignmentCollection, NodeCollection],
def create_indexed_kernel(assignments: NodeCollection,
config: CreateKernelConfig) -> KernelFunction:
"""
Similar to :func:`create_kernel`, but here not all cells of a field are updated but only cells with
@@ -115,21 +112,24 @@ def create_indexed_kernel(assignments: Union[AssignmentCollection, NodeCollectio
fields_written = assignments.bound_fields
fields_read = assignments.rhs_fields
all_fields = fields_read.union(fields_written)
# extract the index fields based on the name. The original index field might have been modified
index_fields = [idx_field for idx_field in index_fields if idx_field.name in [f.name for f in all_fields]]
non_index_fields = [f for f in all_fields if f not in index_fields]
spatial_coordinates = {f.spatial_dimensions for f in non_index_fields}
assert len(spatial_coordinates) == 1, f"Non-index fields do not have the same number of spatial coordinates " \
f"Non index fields are {non_index_fields}, spatial coordinates are " \
f"{spatial_coordinates}"
spatial_coordinates = list(spatial_coordinates)[0]
assignments = assignments.all_assignments
assignments = add_types(assignments, config)
all_fields = fields_read.union(fields_written)
for index_field in index_fields:
index_field.field_type = FieldType.INDEXED
assert FieldType.is_indexed(index_field)
assert index_field.spatial_dimensions == 1, "Index fields have to be 1D"
non_index_fields = [f for f in all_fields if f not in index_fields]
spatial_coordinates = {f.spatial_dimensions for f in non_index_fields}
assert len(spatial_coordinates) == 1, "Non-index fields do not have the same number of spatial coordinates"
spatial_coordinates = list(spatial_coordinates)[0]
def get_coordinate_symbol_assignment(name):
for idx_field in index_fields:
assert isinstance(idx_field.dtype, StructType), "Index fields have to have a struct data type"
Loading