Skip to content
Snippets Groups Projects
Commit d21f6f8c authored by Jan Hönig's avatar Jan Hönig
Browse files

Fixed kernel_decorator with config parameter

parent 157c4f11
No related branches found
No related tags found
2 merge requests!271Kernel decorator fix,!270Fixed kernel_decorator with config parameter
......@@ -7,11 +7,12 @@ import sympy as sp
from pystencils.assignment import Assignment
from pystencils.sympyextensions import SymbolCreator
from pystencils.kernelcreation import CreateKernelConfig
__all__ = ['kernel']
def kernel(func: Callable[..., None], return_config: bool = False, **kwargs) -> Union[List[Assignment], Dict]:
def kernel(config: CreateKernelConfig = None, **kwargs) -> Callable[..., Union[List[Assignment], Dict]]:
"""Decorator to simplify generation of pystencils Assignments.
Changes the meaning of the '@=' operator. Each line containing this operator gives a symbolic assignment
......@@ -21,8 +22,8 @@ def kernel(func: Callable[..., None], return_config: bool = False, **kwargs) ->
The decorated function may not receive any arguments, with exception of an argument called 's' that specifies
a SymbolCreator()
func: the decorated function
return_config: Specify whether to return the list with assignments, or a dictionary containing additional settings
like func_name
config: Specify whether to return the list with assignments, or a dictionary containing additional settings
like func_name
Examples:
>>> import pystencils as ps
......@@ -34,31 +35,34 @@ def kernel(func: Callable[..., None], return_config: bool = False, **kwargs) ->
>>> f, g = ps.fields('f, g: [2D]')
>>> assert my_kernel[0].rhs == f[0,1] + f[1,0]
"""
source = inspect.getsource(func)
source = textwrap.dedent(source)
a = ast.parse(source)
KernelFunctionRewrite().visit(a)
ast.fix_missing_locations(a)
gl = func.__globals__.copy()
assignments = []
def assignment_adder(lhs, rhs):
assignments.append(Assignment(lhs, rhs))
gl['_add_assignment'] = assignment_adder
gl['_Piecewise'] = sp.Piecewise
gl.update(inspect.getclosurevars(func).nonlocals)
exec(compile(a, filename="<ast>", mode="exec"), gl)
func = gl[func.__name__]
args = inspect.getfullargspec(func).args
if 's' in args and 's' not in kwargs:
kwargs['s'] = SymbolCreator()
func(**kwargs)
if return_config:
return {'assignments': assignments, 'function_name': func.__name__}
else:
return assignments
def decorator(func: Callable[..., None]) -> Union[List[Assignment], Dict]:
source = inspect.getsource(func)
source = textwrap.dedent(source)
a = ast.parse(source)
KernelFunctionRewrite().visit(a)
ast.fix_missing_locations(a)
gl = func.__globals__.copy()
assignments = []
def assignment_adder(lhs, rhs):
assignments.append(Assignment(lhs, rhs))
gl['_add_assignment'] = assignment_adder
gl['_Piecewise'] = sp.Piecewise
gl.update(inspect.getclosurevars(func).nonlocals)
exec(compile(a, filename="<ast>", mode="exec"), gl)
func = gl[func.__name__]
args = inspect.getfullargspec(func).args
if 's' in args and 's' not in kwargs:
kwargs['s'] = SymbolCreator()
func(**kwargs)
if config:
config.function_name = func.__name__
return {'assignments': assignments, 'config': config}
else:
return assignments
return decorator
# noinspection PyMethodMayBeStatic
......
import numpy as np
import pystencils as ps
......@@ -15,3 +16,14 @@ def test_create_kernel_config():
c = ps.CreateKernelConfig(backend=ps.Backend.CUDA)
assert c.target == ps.Target.CPU
assert c.backend == ps.Backend.CUDA
def test_kernel_decorator_config():
config = ps.CreateKernelConfig()
a, b, c = ps.fields(a=np.ones(100), b=np.ones(100), c=np.ones(100))
@ps.kernel(config)
def test():
a[0] @= b[0] + c[0]
ps.create_kernel(**test)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment