Skip to content
Snippets Groups Projects

Fix integration pipeline

3 files
+ 34
37
Compare changes
  • Side-by-side
  • Inline
Files
3
from typing import Union
from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment
from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment
from pystencils.config import CreateKernelConfig
from pystencils.config import CreateKernelConfig
from pystencils.typing import StructType, TypedSymbol
from pystencils.typing import StructType, TypedSymbol
@@ -9,15 +7,13 @@ from pystencils.enums import Target, Backend
@@ -9,15 +7,13 @@ from pystencils.enums import Target, Backend
from pystencils.gpu.gpujit import make_python_function
from pystencils.gpu.gpujit import make_python_function
from pystencils.node_collection import NodeCollection
from pystencils.node_collection import NodeCollection
from pystencils.gpu.indexing import indexing_creator_from_params
from pystencils.gpu.indexing import indexing_creator_from_params
from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.slicing import normalize_slice
from pystencils.slicing import normalize_slice
from pystencils.transformations import (
from pystencils.transformations import (
get_base_buffer_index, get_common_field, parse_base_pointer_info,
get_base_buffer_index, get_common_field, parse_base_pointer_info,
resolve_buffer_accesses, resolve_field_accesses, unify_shape_symbols)
resolve_buffer_accesses, resolve_field_accesses, unify_shape_symbols)
def create_cuda_kernel(assignments: Union[AssignmentCollection, NodeCollection],
def create_cuda_kernel(assignments: NodeCollection, config: CreateKernelConfig):
config: CreateKernelConfig):
function_name = config.function_name
function_name = config.function_name
indexing_creator = indexing_creator_from_params(config.gpu_indexing, config.gpu_indexing_params)
indexing_creator = indexing_creator_from_params(config.gpu_indexing, config.gpu_indexing_params)
@@ -114,31 +110,24 @@ def create_cuda_kernel(assignments: Union[AssignmentCollection, NodeCollection],
@@ -114,31 +110,24 @@ def create_cuda_kernel(assignments: Union[AssignmentCollection, NodeCollection],
return ast
return ast
def created_indexed_cuda_kernel(assignments: Union[AssignmentCollection, NodeCollection],
def created_indexed_cuda_kernel(assignments: NodeCollection, config: CreateKernelConfig):
config: CreateKernelConfig):
index_fields = config.index_fields
index_fields = config.index_fields
function_name = config.function_name
function_name = config.function_name
coordinate_names = config.coordinate_names
coordinate_names = config.coordinate_names
indexing_creator = indexing_creator_from_params(config.gpu_indexing, config.gpu_indexing_params)
indexing_creator = indexing_creator_from_params(config.gpu_indexing, config.gpu_indexing_params)
fields_written = assignments.bound_fields
fields_written = assignments.bound_fields
fields_read = assignments.rhs_fields
fields_read = assignments.rhs_fields
assignments = assignments.all_assignments
assignments = add_types(assignments, config)
all_fields = fields_read.union(fields_written)
all_fields = fields_read.union(fields_written)
read_only_fields = set([f.name for f in fields_read - fields_written])
read_only_fields = set([f.name for f in fields_read - fields_written])
# extract the index fields based on the name. The original index field might have been modified
for index_field in index_fields:
index_fields = [idx_field for idx_field in index_fields if idx_field.name in [f.name for f in all_fields]]
index_field.field_type = FieldType.INDEXED
assert FieldType.is_indexed(index_field)
assert index_field.spatial_dimensions == 1, "Index fields have to be 1D"
non_index_fields = [f for f in all_fields if f not in index_fields]
non_index_fields = [f for f in all_fields if f not in index_fields]
spatial_coordinates = {f.spatial_dimensions for f in non_index_fields}
spatial_coordinates = {f.spatial_dimensions for f in non_index_fields}
assert len(spatial_coordinates) == 1, "Non-index fields do not have the same number of spatial coordinates"
assert len(spatial_coordinates) == 1, f"Non-index fields do not have the same number of spatial coordinates " \
 
f"Non index fields are {non_index_fields}, spatial coordinates are " \
 
f"{spatial_coordinates}"
spatial_coordinates = list(spatial_coordinates)[0]
spatial_coordinates = list(spatial_coordinates)[0]
def get_coordinate_symbol_assignment(name):
def get_coordinate_symbol_assignment(name):
Loading