Skip to content
Snippets Groups Projects
Forked from pycodegen / pystencils
65 commits behind the upstream repository.
kernelcreation.py 10.99 KiB
import sympy as sp

import pystencils.astnodes as ast
from pystencils.config import CreateKernelConfig
from pystencils.enums import Target, Backend
from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment
from pystencils.cpu.cpujit import make_python_function
from pystencils.typing import StructType, TypedSymbol, create_type
from pystencils.typing.transformations import add_types
from pystencils.field import Field, FieldType
from pystencils.node_collection import NodeCollection
from pystencils.transformations import (
    filtered_tree_iteration, iterate_loops_by_depth, get_base_buffer_index, get_optimal_loop_ordering,
    make_loop_over_domain, move_constants_before_loop, parse_base_pointer_info, resolve_buffer_accesses,
    resolve_field_accesses, split_inner_loop)


def create_kernel(assignments: NodeCollection,
                  config: CreateKernelConfig) -> KernelFunction:
    """Build the AST of a CPU kernel that applies the given update rules inside generated domain loops.

    The loop nest is derived from the field accesses occurring in the assignments. After the loops are
    created, the inner loop is optionally split (if the collection carries a ``'split_groups'``
    simplification hint), buffer and field accesses are resolved and loop constants are moved in front of
    the loops.

    Args:
        assignments: node collection holding the update rules; its ``bound_fields``, ``rhs_fields``,
                     ``simplification_hints`` and ``all_assignments`` are used.
        config: kernel creation config; ``function_name``, ``iteration_slice``, ``ghost_layers``,
                ``data_type`` and ``base_pointer_specification`` are read here, and the whole object is
                forwarded to ``add_types``.

    Returns:
        The :class:`KernelFunction` node (target CPU, backend C) wrapping the generated loop nest.

    Raises:
        ValueError: if a split group contains a term that is neither a field access nor a symbol.
    """
    kernel_name = config.function_name
    slice_spec = config.iteration_slice
    ghost_layer_spec = config.ghost_layers

    # Everything that needs the collection interface is read up front; afterwards only the plain
    # list of (typed) assignments is used.
    written = assignments.bound_fields
    read = assignments.rhs_fields
    hints = assignments.simplification_hints
    split_groups = hints['split_groups'] if 'split_groups' in hints else ()

    # TODO Cleanup: move add_types to create_domain_kernel or create_kernel
    typed_assignments = add_types(assignments.all_assignments, config)

    all_fields = read.union(written)
    read_only_fields = {f.name for f in read - written}

    # Buffers are kept apart: they do not influence the loop order and get a fixed base pointer layout.
    buffers = {f for f in all_fields if FieldType.is_buffer(f)}
    fields_without_buffers = all_fields - buffers

    kernel_body = ast.Block(typed_assignments)
    loop_order = get_optimal_loop_ordering(fields_without_buffers)
    loop_node, ghost_layer_info = make_loop_over_domain(kernel_body, iteration_slice=slice_spec,
                                                        ghost_layers=ghost_layer_spec, loop_order=loop_order)
    kernel = KernelFunction(loop_node, Target.CPU, Backend.C, compile_function=make_python_function,
                            ghost_layers=ghost_layer_info, function_name=kernel_name,
                            assignments=typed_assignments)

    if split_groups:
        type_info = config.data_type
        # A string (or anything that cannot be indexed) is one type for all symbols;
        # otherwise ``type_info`` is looked up per symbol name.
        single_type_for_all = isinstance(type_info, str) or not hasattr(type_info, '__getitem__')

        def as_typed_term(term):
            if isinstance(term, (Field.Access, TypedSymbol)):
                return term
            if not isinstance(term, sp.Symbol):
                raise ValueError("Term has to be field access or symbol")
            if single_type_for_all:
                return TypedSymbol(term.name, create_type(type_info))
            return TypedSymbol(term.name, type_info[term.name])

        typed_split_groups = [[as_typed_term(term) for term in group] for group in split_groups]
        split_inner_loop(kernel, typed_split_groups)

    pointer_spec = config.base_pointer_specification
    pointer_spec = [] if pointer_spec is None else pointer_spec

    base_pointer_info = {}
    for field in fields_without_buffers:
        base_pointer_info[field.name] = parse_base_pointer_info(pointer_spec, loop_order,
                                                                field.spatial_dimensions, field.index_dimensions)
    # Buffer entries are added last, so they win if a name occurs in both groups.
    buffer_pointer_info = {}
    for field in buffers:
        buffer_pointer_info[field.name] = parse_base_pointer_info([['spatialInner0']], [0],
                                                                  field.spatial_dimensions, field.index_dimensions)
    base_pointer_info.update(buffer_pointer_info)

    if buffers:
        resolve_buffer_accesses(kernel, get_base_buffer_index(kernel), read_only_fields)
    # TODO think about typing
    resolve_field_accesses(kernel, read_only_fields, field_to_base_pointer_info=base_pointer_info)
    move_constants_before_loop(kernel)
    return kernel


def create_indexed_kernel(assignments: NodeCollection,
                          config: CreateKernelConfig) -> KernelFunction:
    """Build the AST of a CPU kernel that updates only the cells listed in an index field.

    In contrast to :func:`create_kernel`, no loops over the spatial domain are generated. Instead a single
    1D loop runs over the entries of the first index field, and the spatial coordinates of the other fields
    are read from the index field(s) in every iteration. This traversal can e.g. be used for boundary
    handling.

    Index fields are one dimensional and have a struct data type. The struct must provide the coordinate
    entries named in ``config.coordinate_names`` (one per spatial dimension of the non-index fields); it may
    hold further entries that are read or written by the kernel, for example boundary parameters.

    Args:
        assignments: node collection holding the update rules
        config: kernel configuration; ``function_name``, ``index_fields`` and ``coordinate_names`` are read
                here, and the whole object is forwarded to ``add_types``.

    Returns:
        The :class:`KernelFunction` node (target CPU, backend C) containing the loop over the index field.

    Raises:
        ValueError: if a required coordinate name is not an element of any index field's struct type.
    """
    kernel_name = config.function_name
    coordinate_names = config.coordinate_names
    written = assignments.bound_fields
    read = assignments.rhs_fields
    all_fields = read.union(written)

    # extract the index fields based on the name. The original index field might have been modified
    accessed_names = {f.name for f in all_fields}
    index_fields = [candidate for candidate in config.index_fields if candidate.name in accessed_names]
    non_index_fields = [f for f in all_fields if f not in index_fields]

    dimension_counts = {f.spatial_dimensions for f in non_index_fields}
    assert len(dimension_counts) == 1, f"Non-index fields do not have the same number of spatial coordinates " \
                                       f"Non index fields are {non_index_fields}, spatial coordinates are " \
                                       f"{dimension_counts}"
    num_spatial_dims = list(dimension_counts)[0]

    typed_assignments = add_types(assignments.all_assignments, config)

    # Note: this marks the passed index field objects as INDEXED in place.
    for index_field in index_fields:
        index_field.field_type = FieldType.INDEXED
        assert FieldType.is_indexed(index_field)
        assert index_field.spatial_dimensions == 1, "Index fields have to be 1D"

    def coordinate_assignment(name):
        # The first index field whose struct has an element called ``name`` provides the coordinate.
        for idx_field in index_fields:
            struct_type = idx_field.dtype
            assert isinstance(struct_type, StructType), "Index fields have to have a struct data type"
            if not struct_type.has_element(name):
                continue
            value = idx_field[0](name)
            target = TypedSymbol(name, struct_type.get_element_type(name))
            return SympyAssignment(target, value)
        raise ValueError(f"Index {name} not found in any of the passed index fields")

    coordinate_assignments = [coordinate_assignment(name) for name in coordinate_names[:num_spatial_dims]]
    coordinate_symbols = [assignment.lhs for assignment in coordinate_assignments]
    # The coordinate reads have to come first so the update rules can use them.
    kernel_assignments = coordinate_assignments + typed_assignments

    # make 1D loop over index fields
    loop_body = Block([])
    index_loop = LoopOverCoordinate(loop_body, coordinate_to_loop_over=0, start=0,
                                    stop=index_fields[0].shape[0])
    for assignment in kernel_assignments:
        loop_body.append(assignment)

    kernel = KernelFunction(Block([index_loop]), Target.CPU, Backend.C, compile_function=make_python_function,
                            ghost_layers=None, function_name=kernel_name, assignments=kernel_assignments)

    # Every non-index field is accessed at the coordinates loaded from the index field.
    fixed_coordinate_mapping = dict.fromkeys((f.name for f in non_index_fields), coordinate_symbols)
    read_only_fields = {f.name for f in read - written}

    resolve_field_accesses(kernel, read_only_fields, field_to_fixed_coordinates=fixed_coordinate_mapping)
    move_constants_before_loop(kernel)
    return kernel


def add_openmp(ast_node, schedule="static", num_threads=True, collapse=None, assume_single_outer_loop=True):
    """Parallelize the outer loop with OpenMP.

    The kernel body is wrapped into a ``#pragma omp parallel`` block and every outermost loop gets a
    ``#pragma omp for`` prefix line. The AST is modified in place.

    Args:
        ast_node: abstract syntax tree created e.g. by :func:`create_kernel`
        schedule: OpenMP scheduling policy e.g. 'static' or 'dynamic'
        num_threads: controls parallelization and the thread count:

                     - a falsy value (``False``, ``None``, ``0``): the AST is left untouched
                     - ``True``: parallelize without a ``num_threads`` clause
                     - any other value: parallelize and emit ``num_threads(<value>)``
        collapse: number of nested loops to include in parallel region (see OpenMP collapse)
        assume_single_outer_loop: if True an exception is raised if multiple outer loops are detected for all but
                                  optimized staggered kernels the single-outer-loop assumption should be true

    Returns:
        None

    Raises:
        ValueError: if more than one outer loop is found and ``assume_single_outer_loop`` is set.
    """
    if not num_threads:
        return

    assert isinstance(ast_node, ast.KernelFunction)
    body = ast_node.body

    # ``True`` only switches parallelization on; it is not a thread count.
    explicit_thread_count = not isinstance(num_threads, bool)
    threads_clause = f" num_threads({num_threads})" if explicit_thread_count else ""
    wrapper_block = ast.PragmaBlock('#pragma omp parallel' + threads_clause, body.take_child_nodes())
    body.append(wrapper_block)

    outer_loops = [loop for loop in filtered_tree_iteration(body, LoopOverCoordinate, stop_type=SympyAssignment)
                   if loop.is_outermost_loop]
    assert outer_loops, "No outer loop found"
    if assume_single_outer_loop and len(outer_loops) > 1:
        raise ValueError("More than one outer loop found, only one outer loop expected")

    for loop_to_parallelize in outer_loops:
        try:
            loop_range = int(loop_to_parallelize.stop - loop_to_parallelize.start)
        except TypeError:
            # the extent cannot be converted to an int (e.g. symbolic bounds)
            loop_range = None

        # If the outer loop has fewer iterations than requested threads, a longer directly nested loop
        # is parallelized instead. This is only meaningful for an explicit thread count: comparing the
        # extent against the boolean ``True`` (== 1) would not express "fewer iterations than threads".
        if explicit_thread_count and loop_range is not None and loop_range < num_threads and not collapse:
            contained_loops = [node for node in loop_to_parallelize.body.args
                               if isinstance(node, LoopOverCoordinate)]
            if len(contained_loops) == 1:
                contained_loop = contained_loops[0]
                try:
                    contained_loop_range = int(contained_loop.stop - contained_loop.start)
                except TypeError:
                    contained_loop_range = None
                if contained_loop_range is not None and contained_loop_range > loop_range:
                    loop_to_parallelize = contained_loop

        prefix = f"#pragma omp for schedule({schedule})"
        if collapse:
            prefix += f" collapse({collapse})"
        loop_to_parallelize.prefix_lines.append(prefix)


def add_pragmas(ast_node, pragma_lines, nesting_depth=-1):
    """Adds the given pragma lines to the prefix lines of all loops of the specified nesting depth.

    The lines are appended to each loop's existing ``prefix_lines``; the AST is modified in place.

    Args:
        ast_node: pystencils abstract syntax tree
        pragma_lines: Iterable of strings containing the pragma lines. One-shot iterables
                      (e.g. generators) are supported as well.
        nesting_depth: Nesting depth of the loops the pragmas should be applied to.
                       Outermost loop has depth 0.
                       A depth of -1 indicates the innermost loops.
    """
    # Materialize once: a generator/iterator would be exhausted after the first loop,
    # and all further loops would silently receive no pragma lines.
    lines = list(pragma_lines)
    for loop_node in iterate_loops_by_depth(ast_node, nesting_depth):
        loop_node.prefix_lines += lines