diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index 0c49160c8ba6beb4e9aa58707ec2697d0f47663a..f50ef65a938f48c9102b97b8969fe99019b59b70 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -303,7 +303,7 @@ class SkipIteration(Node): class Block(Node): - def __init__(self, nodes: List[Node]): + def __init__(self, nodes: Union[Node, List[Node]]): super(Block, self).__init__() if not isinstance(nodes, list): nodes = [nodes] diff --git a/pystencils/gpu/kernelcreation.py b/pystencils/gpu/kernelcreation.py index c0d6e71d05e3a2250249224e3a12e3daa30978d0..e3ad451bb154c2cee279713fdfd2167a4268b3fb 100644 --- a/pystencils/gpu/kernelcreation.py +++ b/pystencils/gpu/kernelcreation.py @@ -1,5 +1,3 @@ -from typing import Union - from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment from pystencils.config import CreateKernelConfig from pystencils.typing import StructType, TypedSymbol @@ -9,15 +7,13 @@ from pystencils.enums import Target, Backend from pystencils.gpu.gpujit import make_python_function from pystencils.node_collection import NodeCollection from pystencils.gpu.indexing import indexing_creator_from_params -from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.slicing import normalize_slice from pystencils.transformations import ( get_base_buffer_index, get_common_field, parse_base_pointer_info, resolve_buffer_accesses, resolve_field_accesses, unify_shape_symbols) -def create_cuda_kernel(assignments: Union[AssignmentCollection, NodeCollection], - config: CreateKernelConfig): +def create_cuda_kernel(assignments: NodeCollection, config: CreateKernelConfig): function_name = config.function_name indexing_creator = indexing_creator_from_params(config.gpu_indexing, config.gpu_indexing_params) @@ -114,31 +110,24 @@ def create_cuda_kernel(assignments: Union[AssignmentCollection, NodeCollection], return ast -def created_indexed_cuda_kernel(assignments: Union[AssignmentCollection, NodeCollection], - config: CreateKernelConfig): +def created_indexed_cuda_kernel(assignments: NodeCollection, config: CreateKernelConfig): index_fields = config.index_fields function_name = config.function_name coordinate_names = config.coordinate_names indexing_creator = indexing_creator_from_params(config.gpu_indexing, config.gpu_indexing_params) - fields_written = assignments.bound_fields fields_read = assignments.rhs_fields - assignments = assignments.all_assignments - - assignments = add_types(assignments, config) all_fields = fields_read.union(fields_written) read_only_fields = set([f.name for f in fields_read - fields_written]) - - for index_field in index_fields: - index_field.field_type = FieldType.INDEXED - assert FieldType.is_indexed(index_field) - assert index_field.spatial_dimensions == 1, "Index fields have to be 1D" - + # extract the index fields based on the name. The original index field might have been modified + index_fields = [idx_field for idx_field in index_fields if idx_field.name in [f.name for f in all_fields]] non_index_fields = [f for f in all_fields if f not in index_fields] spatial_coordinates = {f.spatial_dimensions for f in non_index_fields} - assert len(spatial_coordinates) == 1, "Non-index fields do not have the same number of spatial coordinates" + assert len(spatial_coordinates) == 1, f"Non-index fields do not have the same number of spatial coordinates " \ + f"Non index fields are {non_index_fields}, spatial coordinates are " \ + f"{spatial_coordinates}" spatial_coordinates = list(spatial_coordinates)[0] def get_coordinate_symbol_assignment(name): diff --git a/pystencils/node_collection.py b/pystencils/node_collection.py index 352406566e47c9279bd486410d34f6b6b9bfff53..e0af05fd055bf8df0443c7f9401e352138f6b303 100644 --- a/pystencils/node_collection.py +++ b/pystencils/node_collection.py @@ -1,34 +1,42 @@ -from collections.abc import Iterable from typing import Any, Dict, List, Union, Optional, Set import sympy import sympy as sp -from sympy.codegen.ast import Assignment, AddAugmentedAssignment from sympy.codegen.rewriting import ReplaceOptim, optimize -from pystencils.astnodes import Block, Node, SympyAssignment +from pystencils.assignment import Assignment, AddAugmentedAssignment +import pystencils.astnodes as ast from pystencils.backends.cbackend import CustomCodeNode from pystencils.functions import DivFunc from pystencils.simp import AssignmentCollection class NodeCollection: - def __init__(self, assignments: List[Union[Node, Assignment]], + def __init__(self, assignments: List[Union[ast.Node, Assignment]], simplification_hints: Optional[Dict[str, Any]] = None, bound_fields: Set[sp.Symbol] = None, rhs_fields: Set[sp.Symbol] = None): - nodes = list() - assignments = [assignments, ] if not isinstance(assignments, Iterable) else assignments - for assignment in assignments: - if isinstance(assignment, Assignment): - nodes.append(SympyAssignment(assignment.lhs, assignment.rhs)) - elif isinstance(assignment, AddAugmentedAssignment): - nodes.append(SympyAssignment(assignment.lhs, assignment.lhs + assignment.rhs)) - elif isinstance(assignment, Node): - nodes.append(assignment) + + def visit(obj): + if isinstance(obj, (list, tuple)): + return [visit(e) for e in obj] + if isinstance(obj, Assignment): + return ast.SympyAssignment(obj.lhs, obj.rhs) + elif isinstance(obj, AddAugmentedAssignment): + return ast.SympyAssignment(obj.lhs, obj.lhs + obj.rhs) + elif isinstance(obj, ast.SympyAssignment): + return obj + elif isinstance(obj, ast.Conditional): + true_block = visit(obj.true_block) + false_block = None if obj.false_block is None else visit(obj.false_block) + return ast.Conditional(obj.condition_expr, true_block=true_block, false_block=false_block) + elif isinstance(obj, ast.Block): + return ast.Block([visit(e) for e in obj.args]) + elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate): + return obj else: - raise ValueError(f"Unknown node in the AssignmentCollection: {assignment}") + raise ValueError("Invalid object in the List of Assignments " + str(type(obj))) - self.all_assignments = nodes + self.all_assignments = visit(assignments) self.simplification_hints = simplification_hints if simplification_hints else {} self.bound_fields = bound_fields if bound_fields else {} self.rhs_fields = rhs_fields if rhs_fields else {} @@ -57,13 +65,13 @@ class NodeCollection: def visitor(node): if isinstance(node, CustomCodeNode): return node - elif isinstance(node, Block): + elif isinstance(node, ast.Block): return node.func([visitor(child) for child in node.args]) - elif isinstance(node, SympyAssignment): + elif isinstance(node, ast.SympyAssignment): new_lhs = visitor(node.lhs) new_rhs = visitor(node.rhs) return node.func(new_lhs, new_rhs, node.is_const, node.use_auto) - elif isinstance(node, Node): + elif isinstance(node, ast.Node): return node.func(*[visitor(child) for child in node.args]) elif isinstance(node, sympy.Basic): return optimize(node, sympy_optimisations)