From 3e930f35565af4cd669e63fff149dfce8297fc04 Mon Sep 17 00:00:00 2001
From: Christoph Alt <christoph.alt@fau.de>
Date: Tue, 8 Aug 2023 13:41:53 +0200
Subject: [PATCH] removed some code duplication between benchmark and
 benchmark_gpu move shared function to a extra file

---
 pystencils_benchmark/benchmark.py     | 80 +++-------------------
 pystencils_benchmark/benchmark_gpu.py | 96 ++++++--------------------
 pystencils_benchmark/common.py        | 97 +++++++++++++++++++++++++++
 3 files changed, 126 insertions(+), 147 deletions(-)
 create mode 100644 pystencils_benchmark/common.py

diff --git a/pystencils_benchmark/benchmark.py b/pystencils_benchmark/benchmark.py
index 0cc7b11..4258ee5 100644
--- a/pystencils_benchmark/benchmark.py
+++ b/pystencils_benchmark/benchmark.py
@@ -1,24 +1,24 @@
 from typing import Union, List
 from collections import namedtuple
 from pathlib import Path
-from jinja2 import Environment, PackageLoader, StrictUndefined
 
 import numpy as np
 
-from pystencils.backends.cbackend import generate_c, get_headers
 from pystencils.astnodes import KernelFunction, PragmaBlock
 from pystencils.enums import Backend
 from pystencils.typing import get_base_type
 from pystencils.sympyextensions import prod
 from pystencils.integer_functions import modulo_ceil
 
+from pystencils_benchmark.common import (_env,
+                                         _kernel_source,
+                                         _kernel_header,
+                                         compiler_toolchain,
+                                         copy_static_files,
+                                         setup_directories)
 from pystencils_benchmark.enums import Compiler
 
 
-_env = Environment(loader=PackageLoader('pystencils_benchmark'), undefined=StrictUndefined, keep_trailing_newline=True,
-                   trim_blocks=True, lstrip_blocks=True)
-
-
 def generate_benchmark(kernel_asts: Union[KernelFunction, List[KernelFunction]],
                        path: Path = None,
                        *,
@@ -26,14 +26,8 @@ def generate_benchmark(kernel_asts: Union[KernelFunction, List[KernelFunction]],
                        timing: bool = True,
                        likwid: bool = False
                        ) -> None:
-    if path is None:
-        path = Path('.')
-    else:
-        path.mkdir(parents=True, exist_ok=True)
-    src_path = path / 'src'
-    src_path.mkdir(parents=True, exist_ok=True)
-    include_path = path / 'include'
-    include_path.mkdir(parents=True, exist_ok=True)
+
+    src_path, include_path = setup_directories(path)
 
     if isinstance(kernel_asts, KernelFunction):
         kernel_asts = [kernel_asts]
@@ -56,39 +50,6 @@ def generate_benchmark(kernel_asts: Union[KernelFunction, List[KernelFunction]],
     compiler_toolchain(path, compiler, likwid)
 
 
-def compiler_toolchain(path: Path, compiler: Compiler, likwid: bool) -> None:
-    name = compiler.name
-    jinja_context = {
-        'compiler': name,
-        'likwid': likwid,
-    }
-
-    files = ['Makefile', f'{name}.mk']
-    for file_name in files:
-        with open(path / file_name, 'w+') as f:
-            template = _env.get_template(file_name).render(**jinja_context)
-            f.write(template)
-
-
-def copy_static_files(path: Path) -> None:
-    src_path = path / 'src'
-    src_path.mkdir(parents=True, exist_ok=True)
-    include_path = path / 'include'
-    include_path.mkdir(parents=True, exist_ok=True)
-
-    files = ['timing.h', 'timing.c']
-    for file_name in files:
-        template = _env.get_template(file_name).render()
-        if file_name[-1] == 'h':
-            target_path = include_path / file_name
-        elif file_name[-1] == 'c':
-            target_path = src_path / file_name
-        else:
-            target_path = path / file_name
-        with open(target_path, 'w+') as f:
-            f.write(template)
-
-
 def kernel_main(kernels_ast: List[KernelFunction], *,
                 timing: bool = True, likwid: bool = False) -> str:
     """
@@ -164,29 +125,8 @@ def kernel_main(kernels_ast: List[KernelFunction], *,
 
 
 def kernel_header(kernel_ast: KernelFunction, dialect: Backend = Backend.C) -> str:
-    function_signature = generate_c(kernel_ast, dialect=dialect, signature_only=True)
-    header_guard = f'_{kernel_ast.function_name.upper()}_H'
-
-    jinja_context = {
-        'header_guard': header_guard,
-        'function_signature': function_signature,
-    }
-
-    header = _env.get_template('cpu/kernel.h').render(**jinja_context)
-    return header
+    return _kernel_header(kernel_ast, dialect=dialect, template_file='cpu/kernel.h')
 
 
 def kernel_source(kernel_ast: KernelFunction, dialect: Backend = Backend.C) -> str:
-    kernel_name = kernel_ast.function_name
-    function_source = generate_c(kernel_ast, dialect=dialect)
-    headers = {f'"{kernel_name}.h"', '<math.h>', '<stdint.h>'}
-    headers.update(get_headers(kernel_ast))
-
-    jinja_context = {
-        'function_source': function_source,
-        'headers': sorted(headers),
-        'timing': True,
-    }
-
-    source = _env.get_template('cpu/kernel.c').render(**jinja_context)
-    return source
+    return _kernel_source(kernel_ast, dialect=dialect, template_file='cpu/kernel.c')
diff --git a/pystencils_benchmark/benchmark_gpu.py b/pystencils_benchmark/benchmark_gpu.py
index befd83d..d0ccbe1 100644
--- a/pystencils_benchmark/benchmark_gpu.py
+++ b/pystencils_benchmark/benchmark_gpu.py
@@ -1,37 +1,31 @@
 from typing import Union, List
 from collections import namedtuple
 from pathlib import Path
-from jinja2 import Environment, PackageLoader, StrictUndefined
 
-from pystencils.backends.cbackend import generate_c, get_headers
 from pystencils.astnodes import KernelFunction
 from pystencils.enums import Backend
 from pystencils.typing import get_base_type
 from pystencils.sympyextensions import prod
 from pystencils.transformations import get_common_field
-# from pystencils.gpucuda import BlockIndexing
 
+from pystencils_benchmark.common import (_env,
+                                         _kernel_source,
+                                         _kernel_header,
+                                         compiler_toolchain,
+                                         copy_static_files,
+                                         setup_directories)
 from pystencils_benchmark.enums import Compiler
 
-_env = Environment(loader=PackageLoader('pystencils_benchmark'), undefined=StrictUndefined, keep_trailing_newline=True,
-                   trim_blocks=True, lstrip_blocks=True)
-
 
 def generate_benchmark_gpu(kernel_asts: Union[KernelFunction, List[KernelFunction]],
                            path: Path = None,
                            *,
-                           compiler: Compiler = Compiler.GCC,
+                           compiler: Compiler = Compiler.NVCC,
                            timing: bool = True,
                            cuda_block_size: tuple = (32, 1, 1)
                            ) -> None:
-    if path is None:
-        path = Path('.')
-    else:
-        path.mkdir(parents=True, exist_ok=True)
-    src_path = path / 'src'
-    src_path.mkdir(parents=True, exist_ok=True)
-    include_path = path / 'include'
-    include_path.mkdir(parents=True, exist_ok=True)
+
+    src_path, include_path = setup_directories(path)
 
     if isinstance(kernel_asts, KernelFunction):
         kernel_asts = [kernel_asts]
@@ -53,43 +47,8 @@ def generate_benchmark_gpu(kernel_asts: Union[KernelFunction, List[KernelFunctio
                             timing=timing,
                             cuda_block_size=cuda_block_size))
 
-    copy_static_files(path)
-    compiler_toolchain(path, compiler)
-
-
-def compiler_toolchain(path: Path, compiler: Compiler) -> None:
-    name = compiler.name
-    jinja_context = {
-        'compiler': name,
-        'likwid': False,
-    }
-
-    files = ['Makefile', f'{name}.mk']
-    for file_name in files:
-        with open(path / file_name, 'w+') as f:
-            template = _env.get_template(file_name).render(**jinja_context)
-            f.write(template)
-
-
-def copy_static_files(path: Path) -> None:
-    src_path = path / 'src'
-    src_path.mkdir(parents=True, exist_ok=True)
-    include_path = path / 'include'
-    include_path.mkdir(parents=True, exist_ok=True)
-
-    files = ['timing.h', 'timing.c']
-    for file_name in files:
-        template = _env.get_template(file_name).render()
-        if file_name[-1] == 'h':
-            target_path = include_path / file_name
-        elif file_name[-1] == 'c':
-            target_path = src_path / file_name
-            # TODO CUDA specific suffix:
-            target_path = target_path.with_suffix('.cu')
-        else:
-            target_path = path / file_name
-        with open(target_path, 'w+') as f:
-            f.write(template)
+    copy_static_files(path, source_file_suffix='.cu')
+    compiler_toolchain(path, compiler, likwid=False)
 
 
 def kernel_main(kernels_ast: List[KernelFunction], *, timing: bool = True, cuda_block_size: tuple):
@@ -149,31 +108,14 @@ def kernel_main(kernels_ast: List[KernelFunction], *, timing: bool = True, cuda_
 
 
 def kernel_header(kernel_ast: KernelFunction, dialect: Backend = Backend.C) -> str:
-    function_signature = generate_c(kernel_ast, dialect=dialect, signature_only=True)
-    header_guard = f'_{kernel_ast.function_name.upper()}_H'
-
-    jinja_context = {
-        'header_guard': header_guard,
-        'function_signature': function_signature,
-        'target': 'gpu'
-    }
-
-    header = _env.get_template('gpu/kernel.h').render(**jinja_context)
-    return header
+    return _kernel_header(kernel_ast,
+                          dialect=dialect,
+                          template_file='gpu/kernel.h',
+                          additional_jinja_context={'target': 'gpu'})
 
 
 def kernel_source(kernel_ast: KernelFunction, dialect: Backend = Backend.C) -> str:
-    kernel_name = kernel_ast.function_name
-    function_source = generate_c(kernel_ast, dialect=dialect)
-    headers = {f'"{kernel_name}.h"', '<math.h>', '<stdint.h>'}
-    headers.update(get_headers(kernel_ast))
-
-    jinja_context = {
-        'function_source': function_source,
-        'headers': sorted(headers),
-        'timing': True,
-        'target': 'gpu'
-    }
-
-    source = _env.get_template('gpu/kernel.cu').render(**jinja_context)
-    return source
+    return _kernel_source(kernel_ast,
+                          dialect=dialect,
+                          template_file='gpu/kernel.cu',
+                          additional_jinja_context={'target': 'gpu'})
diff --git a/pystencils_benchmark/common.py b/pystencils_benchmark/common.py
new file mode 100644
index 0000000..beeeed6
--- /dev/null
+++ b/pystencils_benchmark/common.py
@@ -0,0 +1,97 @@
+from pystencils.backends.cbackend import generate_c, get_headers
+from pystencils.astnodes import KernelFunction
+from pystencils.enums import Backend
+from jinja2 import Environment, PackageLoader, StrictUndefined
+
+from pystencils_benchmark.enums import Compiler
+from pathlib import Path
+
+_env = Environment(loader=PackageLoader('pystencils_benchmark'),
+                   undefined=StrictUndefined,
+                   keep_trailing_newline=True,
+                   trim_blocks=True, lstrip_blocks=True)
+
+
+def _kernel_header(kernel_ast: KernelFunction,
+                   dialect: Backend = Backend.C,
+                   *,
+                   template_file: str,
+                   additional_jinja_context: dict = {}) -> str:
+    function_signature = generate_c(kernel_ast, dialect=dialect, signature_only=True)
+    header_guard = f'_{kernel_ast.function_name.upper()}_H'
+
+    jinja_context = {
+        'header_guard': header_guard,
+        'function_signature': function_signature,
+        **additional_jinja_context
+    }
+
+    header = _env.get_template(template_file).render(**jinja_context)
+    return header
+
+
+def _kernel_source(kernel_ast: KernelFunction,
+                   dialect: Backend = Backend.C,
+                   *,
+                   template_file: str,
+                   additional_jinja_context: dict = {}) -> str:
+    kernel_name = kernel_ast.function_name
+    function_source = generate_c(kernel_ast, dialect=dialect)
+    headers = {f'"{kernel_name}.h"', '<math.h>', '<stdint.h>'}
+    headers.update(get_headers(kernel_ast))
+
+    jinja_context = {
+        'function_source': function_source,
+        'headers': sorted(headers),
+        'timing': True,
+        **additional_jinja_context,
+    }
+
+    source = _env.get_template(template_file).render(**jinja_context)
+    return source
+
+
+def compiler_toolchain(path: Path, compiler: Compiler, likwid: bool) -> None:
+    name = compiler.name
+    jinja_context = {
+        'compiler': name,
+        'likwid': likwid,
+    }
+
+    files = ['Makefile', f'{name}.mk']
+    for file_name in files:
+        with open(path / file_name, 'w+') as f:
+            template = _env.get_template(file_name).render(**jinja_context)
+            f.write(template)
+
+
+def copy_static_files(path: Path, *, source_file_suffix='.c') -> None:
+    src_path = path / 'src'
+    src_path.mkdir(parents=True, exist_ok=True)
+    include_path = path / 'include'
+    include_path.mkdir(parents=True, exist_ok=True)
+
+    files = ['timing.h', 'timing.c']
+    for file_name in files:
+        template = _env.get_template(file_name).render()
+        if file_name[-1] == 'h':
+            target_path = include_path / file_name
+        elif file_name[-1] == 'c':
+            target_path = src_path / file_name
+            target_path = target_path.with_suffix(source_file_suffix)
+        else:
+            target_path = path / file_name
+        with open(target_path, 'w+') as f:
+            f.write(template)
+
+
+def setup_directories(path: Path):
+    if path is None:
+        path = Path('.')
+    else:
+        path.mkdir(parents=True, exist_ok=True)
+    src_path = path / 'src'
+    src_path.mkdir(parents=True, exist_ok=True)
+    include_path = path / 'include'
+    include_path.mkdir(parents=True, exist_ok=True)
+    return src_path, include_path
-- 
GitLab