Skip to content
Snippets Groups Projects
Commit 0ff425f3 authored by Martin Bauer's avatar Martin Bauer
Browse files

Merge branch 'KernelWrapper' into 'master'

Kernel wrapper

See merge request pycodegen/pystencils!61
parents 9344cb42 aff0c6b8
No related branches found
No related tags found
No related merge requests found
import os import os
from collections import Hashable from collections.abc import Hashable
from functools import partial from functools import partial
from itertools import chain from itertools import chain
......
...@@ -60,6 +60,7 @@ from appdirs import user_cache_dir, user_config_dir ...@@ -60,6 +60,7 @@ from appdirs import user_cache_dir, user_config_dir
from pystencils import FieldType from pystencils import FieldType
from pystencils.backends.cbackend import generate_c, get_headers from pystencils.backends.cbackend import generate_c, get_headers
from pystencils.include import get_pystencils_include_path from pystencils.include import get_pystencils_include_path
from pystencils.kernel_wrapper import KernelWrapper
from pystencils.utils import atomic_file_write, file_handle_for_atomic_write, recursive_dict_update from pystencils.utils import atomic_file_write, file_handle_for_atomic_write, recursive_dict_update
...@@ -482,16 +483,6 @@ class ExtensionModuleCode: ...@@ -482,16 +483,6 @@ class ExtensionModuleCode:
print(create_module_boilerplate_code(self.module_name, self._function_names), file=file) print(create_module_boilerplate_code(self.module_name, self._function_names), file=file)
class KernelWrapper:
def __init__(self, kernel, parameters, ast_node):
self.kernel = kernel
self.parameters = parameters
self.ast = ast_node
def __call__(self, **kwargs):
return self.kernel(**kwargs)
def compile_module(code, code_hash, base_dir): def compile_module(code, code_hash, base_dir):
compiler_config = get_compiler_config() compiler_config = get_compiler_config()
extra_flags = ['-I' + get_paths()['include'], '-I' + get_pystencils_include_path()] extra_flags = ['-I' + get_paths()['include'], '-I' + get_pystencils_include_path()]
......
...@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional ...@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional
import sympy as sp import sympy as sp
from pystencils.astnodes import KernelFunction from pystencils.astnodes import KernelFunction
from pystencils.kernel_wrapper import KernelWrapper
def to_dot(expr: sp.Expr, graph_style: Optional[Dict[str, Any]] = None, short=True): def to_dot(expr: sp.Expr, graph_style: Optional[Dict[str, Any]] = None, short=True):
...@@ -40,6 +41,10 @@ def show_code(ast: KernelFunction, custom_backend=None): ...@@ -40,6 +41,10 @@ def show_code(ast: KernelFunction, custom_backend=None):
Can either be displayed as HTML in Jupyter notebooks or printed as normal string. Can either be displayed as HTML in Jupyter notebooks or printed as normal string.
""" """
from pystencils.backends.cbackend import generate_c from pystencils.backends.cbackend import generate_c
if isinstance(ast, KernelWrapper):
ast = ast.ast
dialect = 'cuda' if ast.backend == 'gpucuda' else 'c' dialect = 'cuda' if ast.backend == 'gpucuda' else 'c'
class CodeDisplay: class CodeDisplay:
......
...@@ -6,6 +6,7 @@ from pystencils.field import FieldType ...@@ -6,6 +6,7 @@ from pystencils.field import FieldType
from pystencils.gpucuda.texture_utils import ndarray_to_tex from pystencils.gpucuda.texture_utils import ndarray_to_tex
from pystencils.include import get_pycuda_include_path, get_pystencils_include_path from pystencils.include import get_pycuda_include_path, get_pystencils_include_path
from pystencils.interpolation_astnodes import TextureAccess from pystencils.interpolation_astnodes import TextureAccess
from pystencils.kernel_wrapper import KernelWrapper
from pystencils.kernelparameters import FieldPointerSymbol from pystencils.kernelparameters import FieldPointerSymbol
USE_FAST_MATH = True USE_FAST_MATH = True
...@@ -93,8 +94,9 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen ...@@ -93,8 +94,9 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen
func(*args, **block_and_thread_numbers) func(*args, **block_and_thread_numbers)
# import pycuda.driver as cuda # import pycuda.driver as cuda
# cuda.Context.synchronize() # useful for debugging, to get errors right after kernel was called # cuda.Context.synchronize() # useful for debugging, to get errors right after kernel was called
wrapper.ast = kernel_function_node ast = kernel_function_node
wrapper.parameters = kernel_function_node.get_parameters() parameters = kernel_function_node.get_parameters()
wrapper = KernelWrapper(wrapper, parameters, ast)
wrapper.num_regs = func.num_regs wrapper.num_regs = func.num_regs
return wrapper return wrapper
......
"""
Light-weight wrapper around a compiled kernel
"""
import pystencils
class KernelWrapper:
def __init__(self, kernel, parameters, ast_node):
self.kernel = kernel
self.parameters = parameters
self.ast = ast_node
self.num_regs = None
def __call__(self, **kwargs):
return self.kernel(**kwargs)
@property
def code(self):
return str(pystencils.show_code(self.ast))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment