Skip to content
Snippets Groups Projects
Commit 64eda171 authored by Markus Holzer's avatar Markus Holzer
Browse files

test

parent 3887bb4f
No related branches found
No related tags found
No related merge requests found
Pipeline #55583 failed
from typing import Union
import sympy as sp
import pystencils.astnodes as ast
......@@ -17,7 +15,7 @@ from pystencils.transformations import (
resolve_field_accesses, split_inner_loop)
def create_kernel(assignments: Union[NodeCollection],
def create_kernel(assignments: NodeCollection,
config: CreateKernelConfig) -> KernelFunction:
"""Creates an abstract syntax tree for a kernel function, by taking a list of update rules.
......@@ -114,40 +112,24 @@ def create_indexed_kernel(assignments: NodeCollection,
fields_written = assignments.bound_fields
fields_read = assignments.rhs_fields
Index_IDBefore = ""
for index_field in index_fields:
Index_IDBefore += f"index field name: {index_field.name}: {str(id(index_field))}"
all_fields = fields_read.union(fields_written)
# extract the index fields based on the name. The original index field might have been modified
index_fields = [idx_field for idx_field in index_fields if idx_field.name in [f.name for f in all_fields]]
non_index_fields = [f for f in all_fields if f not in index_fields]
spatial_coordinates = {f.spatial_dimensions for f in non_index_fields}
assert len(spatial_coordinates) == 1, f"Non-index fields do not have the same number of spatial coordinates " \
f"Non index fields are {non_index_fields}, spatial coordinates are " \
f"{spatial_coordinates}"
spatial_coordinates = list(spatial_coordinates)[0]
assignments = assignments.all_assignments
assignments = add_types(assignments, config)
all_fields = fields_read.union(fields_written)
AllFieldsIDSBefore = ""
for field in all_fields:
AllFieldsIDSBefore += f"field name: {field.name}: {str(id(field))}"
for index_field in index_fields:
index_field.field_type = FieldType.INDEXED
assert FieldType.is_indexed(index_field)
assert index_field.spatial_dimensions == 1, "Index fields have to be 1D"
Index_IDAfter = ""
for index_field in index_fields:
Index_IDAfter += f"index field name: {index_field.name}: {str(id(index_field))}"
non_index_fields = [f for f in all_fields if f not in index_fields]
spatial_coordinates = {f.spatial_dimensions for f in non_index_fields}
assert len(spatial_coordinates) == 1, f"Non-index fields do not have the same number of spatial coordinates " \
f"Non index fields are {non_index_fields}, spatial coordinates are " \
f"{spatial_coordinates} and len(spatial coordiantes) is " \
f"{len(spatial_coordinates)}, " \
f"index_fields in config: {config.index_fields}, " \
f"extracted index fields: {index_fields} " \
f"Index_IDBefore {Index_IDBefore} " \
f"Index_IDAfter {Index_IDAfter} " \
f"AllFieldsIDSBefore {AllFieldsIDSBefore}"
spatial_coordinates = list(spatial_coordinates)[0]
def get_coordinate_symbol_assignment(name):
for idx_field in index_fields:
assert isinstance(idx_field.dtype, StructType), "Index fields have to have a struct data type"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment