Skip to content
Snippets Groups Projects
benchmark.py 6.97 KiB
from typing import Union, List
from collections import namedtuple
from pathlib import Path
from jinja2 import Environment, PackageLoader, StrictUndefined

import numpy as np

from pystencils.backends.cbackend import generate_c, get_headers
from pystencils.astnodes import KernelFunction, PragmaBlock
from pystencils.enums import Backend
from pystencils.typing import get_base_type
from pystencils.sympyextensions import prod
from pystencils.integer_functions import modulo_ceil

from pystencils_benchmark.enums import Compiler


# Shared Jinja2 environment for every template of the package.
# StrictUndefined makes a missing template variable an error instead of an empty string;
# trim_blocks / lstrip_blocks strip the whitespace left behind by block tags.
_env = Environment(
    loader=PackageLoader('pystencils_benchmark'),
    undefined=StrictUndefined,
    keep_trailing_newline=True,
    trim_blocks=True,
    lstrip_blocks=True,
)


def generate_benchmark(kernel_asts: Union[KernelFunction, List[KernelFunction]],
                       path: Path = None,
                       *,
                       compiler: Compiler = Compiler.GCC,
                       timing: bool = True,
                       likwid: bool = False
                       ) -> None:
    """Write a buildable C benchmark project for the given kernel(s).

    Below *path* this creates ``include/<kernel>.h`` and ``src/<kernel>.c`` for
    every kernel, ``src/main.c``, the static timing helpers and the Makefile
    fragments for *compiler*.

    Args:
        kernel_asts: a single kernel AST or an iterable of kernel ASTs
        path: output directory, created if missing; defaults to the current
              directory. ``str`` / path-like values are accepted as well.
        compiler: compiler whose Makefile fragment is generated
        timing: add timing output to main.c (prints time per iteration)
        likwid: add likwid markers to main.c and the Makefile
    """
    path = Path('.') if path is None else Path(path)
    src_path = path / 'src'
    include_path = path / 'include'
    # parents=True also creates `path` itself when it does not exist yet.
    for directory in (src_path, include_path):
        directory.mkdir(parents=True, exist_ok=True)

    if isinstance(kernel_asts, KernelFunction):
        kernel_asts = [kernel_asts]
    else:
        # Materialize: the kernels are iterated twice (per-kernel files below,
        # then kernel_main), which would leave main.c empty for a generator.
        kernel_asts = list(kernel_asts)

    for kernel_ast in kernel_asts:
        kernel_name = kernel_ast.function_name
        (include_path / f'{kernel_name}.h').write_text(kernel_header(kernel_ast))
        (src_path / f'{kernel_name}.c').write_text(kernel_source(kernel_ast))

    (src_path / 'main.c').write_text(kernel_main(kernel_asts, timing=timing, likwid=likwid))

    copy_static_files(path)
    compiler_toolchain(path, compiler, likwid)


def compiler_toolchain(path: Path, compiler: Compiler, likwid: bool) -> None:
    """Render ``Makefile`` and ``<compiler name>.mk`` into *path*.

    Args:
        path: project root directory (``str`` / path-like values are accepted)
        compiler: compiler enum member; its ``name`` selects the ``.mk`` template
                  and is passed to the templates as ``compiler``
        likwid: passed to the templates as ``likwid``
    """
    path = Path(path)
    name = compiler.name
    jinja_context = {
        'compiler': name,
        'likwid': likwid,
    }

    for file_name in ('Makefile', f'{name}.mk'):
        # Render before touching the file: opening for writing truncates it, so a
        # template error (e.g. StrictUndefined) would otherwise leave an empty file.
        rendered = _env.get_template(file_name).render(**jinja_context)
        (path / file_name).write_text(rendered)


def copy_static_files(path: Path) -> None:
    """Render the static timing helper templates into the project at *path*.

    ``.h`` files go to ``include/``, ``.c`` files to ``src/``, anything else to
    *path* itself. Both directories are created if missing.

    Args:
        path: project root directory (``str`` / path-like values are accepted)
    """
    path = Path(path)
    src_path = path / 'src'
    src_path.mkdir(parents=True, exist_ok=True)
    include_path = path / 'include'
    include_path.mkdir(parents=True, exist_ok=True)

    # Dispatch on the real extension rather than the last character of the name
    # ('foo.sh' or 'graph' would otherwise be treated as headers).
    target_dirs = {'.h': include_path, '.c': src_path}

    for file_name in ('timing.h', 'timing.c'):
        rendered = _env.get_template(file_name).render()
        target_dir = target_dirs.get(Path(file_name).suffix, path)
        (target_dir / file_name).write_text(rendered)


def _field_allocation(kernel: KernelFunction, p, field) -> tuple:
    """Return ``(name, c_dtype, elements, size, offset, align)`` for one field pointer parameter.

    ``size``, ``offset`` and ``align`` are in bytes. Padding and alignment are
    only applied when the kernel was vectorized (``kernel.instruction_set`` set).
    """
    base_type = get_base_type(p.symbol.dtype)
    dtype = str(base_type)
    size_data_type = base_type.numpy_dtype.itemsize

    # TODO double check the size computation
    dim0_size = field.shape[-1]
    # pystencils' prod (not np.prod): np.prod(()) is the float 1.0 for 1-D fields,
    # which turned the allocation size into a float literal in the generated C.
    elements = prod(field.shape)

    if kernel.instruction_set:
        width = kernel.instruction_set['width']
        align = width * size_data_type
        padding_elements = modulo_ceil(dim0_size, width) - dim0_size
        padding_bytes = padding_elements * size_data_type
        # Largest ghost-layer count over all (begin, end) entries; max() of the
        # tuples themselves compares lexicographically and can miss the maximum.
        ghost_layers = max(max(gl) for gl in kernel.ghost_layers)

        dim1_size = prod(field.shape[:-1])
        size = dim1_size * padding_bytes + elements * size_data_type
        offset = ((dim0_size + padding_elements + ghost_layers) % width) * size_data_type
    else:
        size = elements * size_data_type
        offset = 0
        align = 0
    return p.field_name, dtype, elements, size, offset, align


def kernel_main(kernels_ast: List[KernelFunction], *,
                timing: bool = True, likwid: bool = False) -> str:
    """
    Return C code of a benchmark program for the given kernels.

    Args:
        kernels_ast: A list of the pystencils AST objects as returned by create_kernel for benchmarking
        timing: add timing output to the code, prints time per iteration to stdout
        likwid: add likwid marker to the code
    Returns:
        C code as string
    Raises:
        AssertionError: if a field parameter is not a field pointer
            (benchmarks are only implemented for kernels with fixed loop size)
    """
    Kernel = namedtuple('Kernel', ['name', 'constants', 'fields', 'call_parameters', 'call_argument_list', 'openmp'])
    kernels = []
    includes = set()
    for kernel in kernels_ast:
        name = kernel.function_name
        accessed_fields = {f.name: f for f in kernel.fields_accessed}
        constants = []
        fields = []
        call_parameters = []
        # TODO: Think about it maybe there is a better way to detect openmp
        # The kernel is treated as OpenMP when its body starts with a PragmaBlock;
        # an empty body is simply not OpenMP instead of raising IndexError.
        body_args = kernel.body.args
        openmp = bool(body_args) and isinstance(body_args[0], PragmaBlock)
        for p in kernel.get_parameters():
            if not p.is_field_parameter:
                constants.append((p.symbol.name, str(p.symbol.dtype)))
                call_parameters.append(p.symbol.name)
            else:
                assert p.is_field_pointer, "Benchmark implemented only for kernels with fixed loop size"
                fields.append(_field_allocation(kernel, p, accessed_fields[p.field_name]))
                call_parameters.append(p.field_name)

        kernels.append(Kernel(name=name, fields=fields, constants=constants, call_parameters=call_parameters,
                              call_argument_list=",".join(call_parameters), openmp=openmp))

        includes.add(name)

    jinja_context = {
        'kernels': kernels,
        # Sorted: iteration order of a set of str depends on hash randomization,
        # which made the include order in main.c differ between runs.
        'includes': sorted(includes),
        'timing': timing,
        'likwid': likwid,
    }

    return _env.get_template('cpu/main.c').render(**jinja_context)


def kernel_header(kernel_ast: KernelFunction, dialect: Backend = Backend.C) -> str:
    """Return the C header declaring the kernel's function signature.

    Args:
        kernel_ast: kernel whose signature is declared
        dialect: backend passed to ``generate_c``
    Returns:
        header text rendered from the ``cpu/kernel.h`` template
    """
    function_signature = generate_c(kernel_ast, dialect=dialect, signature_only=True)
    # No leading underscore: in C, identifiers beginning with '_' followed by an
    # uppercase letter are reserved for the implementation.
    header_guard = f'{kernel_ast.function_name.upper()}_H'

    jinja_context = {
        'header_guard': header_guard,
        'function_signature': function_signature,
    }

    return _env.get_template('cpu/kernel.h').render(**jinja_context)


def kernel_source(kernel_ast: KernelFunction, dialect: Backend = Backend.C) -> str:
    """Return the C source file implementing *kernel_ast*.

    The include list is the kernel's own header, ``<math.h>``, ``<stdint.h>``
    and whatever ``get_headers`` reports, de-duplicated and sorted.

    Args:
        kernel_ast: kernel to emit
        dialect: backend passed to ``generate_c``
    Returns:
        source text rendered from the ``cpu/kernel.c`` template
    """
    own_header = f'"{kernel_ast.function_name}.h"'
    body_code = generate_c(kernel_ast, dialect=dialect)

    include_set = {own_header, '<math.h>', '<stdint.h>'}
    include_set.update(get_headers(kernel_ast))

    template = _env.get_template('cpu/kernel.c')
    return template.render(
        function_source=body_code,
        headers=sorted(include_set),
        timing=True,
    )