Skip to content
Snippets Groups Projects
Commit 0ff425f3 authored by Martin Bauer's avatar Martin Bauer
Browse files

Merge branch 'KernelWrapper' into 'master'

Kernel wrapper

See merge request !61
parents 9344cb42 aff0c6b8
Branches
No related tags found
No related merge requests found
import os
from collections import Hashable
from collections.abc import Hashable
from functools import partial
from itertools import chain
......
......@@ -60,6 +60,7 @@ from appdirs import user_cache_dir, user_config_dir
from pystencils import FieldType
from pystencils.backends.cbackend import generate_c, get_headers
from pystencils.include import get_pystencils_include_path
from pystencils.kernel_wrapper import KernelWrapper
from pystencils.utils import atomic_file_write, file_handle_for_atomic_write, recursive_dict_update
......@@ -482,16 +483,6 @@ class ExtensionModuleCode:
print(create_module_boilerplate_code(self.module_name, self._function_names), file=file)
class KernelWrapper:
def __init__(self, kernel, parameters, ast_node):
self.kernel = kernel
self.parameters = parameters
self.ast = ast_node
def __call__(self, **kwargs):
return self.kernel(**kwargs)
def compile_module(code, code_hash, base_dir):
compiler_config = get_compiler_config()
extra_flags = ['-I' + get_paths()['include'], '-I' + get_pystencils_include_path()]
......
......@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional
import sympy as sp
from pystencils.astnodes import KernelFunction
from pystencils.kernel_wrapper import KernelWrapper
def to_dot(expr: sp.Expr, graph_style: Optional[Dict[str, Any]] = None, short=True):
......@@ -40,6 +41,10 @@ def show_code(ast: KernelFunction, custom_backend=None):
Can either be displayed as HTML in Jupyter notebooks or printed as normal string.
"""
from pystencils.backends.cbackend import generate_c
if isinstance(ast, KernelWrapper):
ast = ast.ast
dialect = 'cuda' if ast.backend == 'gpucuda' else 'c'
class CodeDisplay:
......
......@@ -6,6 +6,7 @@ from pystencils.field import FieldType
from pystencils.gpucuda.texture_utils import ndarray_to_tex
from pystencils.include import get_pycuda_include_path, get_pystencils_include_path
from pystencils.interpolation_astnodes import TextureAccess
from pystencils.kernel_wrapper import KernelWrapper
from pystencils.kernelparameters import FieldPointerSymbol
USE_FAST_MATH = True
......@@ -93,8 +94,9 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen
func(*args, **block_and_thread_numbers)
# import pycuda.driver as cuda
# cuda.Context.synchronize() # useful for debugging, to get errors right after kernel was called
wrapper.ast = kernel_function_node
wrapper.parameters = kernel_function_node.get_parameters()
ast = kernel_function_node
parameters = kernel_function_node.get_parameters()
wrapper = KernelWrapper(wrapper, parameters, ast)
wrapper.num_regs = func.num_regs
return wrapper
......
"""
Light-weight wrapper around a compiled kernel
"""
import pystencils
class KernelWrapper:
def __init__(self, kernel, parameters, ast_node):
self.kernel = kernel
self.parameters = parameters
self.ast = ast_node
self.num_regs = None
def __call__(self, **kwargs):
return self.kernel(**kwargs)
@property
def code(self):
return str(pystencils.show_code(self.ast))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment