Skip to content
Snippets Groups Projects

Fixed kernel_decorator with config parameter

Closed Jan Hönig requested to merge hoenig/pystencils:master into master
Compare and
2 files
+ 44
28
Preferences
Compare changes
Files
2
@@ -7,11 +7,12 @@ import sympy as sp
from pystencils.assignment import Assignment
from pystencils.sympyextensions import SymbolCreator
from pystencils.kernelcreation import CreateKernelConfig
__all__ = ['kernel']
def kernel(func: Callable[..., None], return_config: bool = False, **kwargs) -> Union[List[Assignment], Dict]:
def kernel(config: CreateKernelConfig = None, **kwargs) -> Callable[..., Union[List[Assignment], Dict]]:
"""Decorator to simplify generation of pystencils Assignments.
Changes the meaning of the '@=' operator. Each line containing this operator gives a symbolic assignment
@@ -21,8 +22,8 @@ def kernel(func: Callable[..., None], return_config: bool = False, **kwargs) ->
The decorated function may not receive any arguments, with exception of an argument called 's' that specifies
a SymbolCreator()
func: the decorated function
return_config: Specify whether to return the list with assignments, or a dictionary containing additional settings
like func_name
config: Specify whether to return the list with assignments, or a dictionary containing additional settings
like func_name
Examples:
>>> import pystencils as ps
@@ -34,31 +35,34 @@ def kernel(func: Callable[..., None], return_config: bool = False, **kwargs) ->
>>> f, g = ps.fields('f, g: [2D]')
>>> assert my_kernel[0].rhs == f[0,1] + f[1,0]
"""
source = inspect.getsource(func)
source = textwrap.dedent(source)
a = ast.parse(source)
KernelFunctionRewrite().visit(a)
ast.fix_missing_locations(a)
gl = func.__globals__.copy()
assignments = []
def assignment_adder(lhs, rhs):
assignments.append(Assignment(lhs, rhs))
gl['_add_assignment'] = assignment_adder
gl['_Piecewise'] = sp.Piecewise
gl.update(inspect.getclosurevars(func).nonlocals)
exec(compile(a, filename="<ast>", mode="exec"), gl)
func = gl[func.__name__]
args = inspect.getfullargspec(func).args
if 's' in args and 's' not in kwargs:
kwargs['s'] = SymbolCreator()
func(**kwargs)
if return_config:
return {'assignments': assignments, 'function_name': func.__name__}
else:
return assignments
def decorator(func: Callable[..., None]) -> Union[List[Assignment], Dict]:
source = inspect.getsource(func)
source = textwrap.dedent(source)
a = ast.parse(source)
KernelFunctionRewrite().visit(a)
ast.fix_missing_locations(a)
gl = func.__globals__.copy()
assignments = []
def assignment_adder(lhs, rhs):
assignments.append(Assignment(lhs, rhs))
gl['_add_assignment'] = assignment_adder
gl['_Piecewise'] = sp.Piecewise
gl.update(inspect.getclosurevars(func).nonlocals)
exec(compile(a, filename="<ast>", mode="exec"), gl)
func = gl[func.__name__]
args = inspect.getfullargspec(func).args
if 's' in args and 's' not in kwargs:
kwargs['s'] = SymbolCreator()
func(**kwargs)
if config:
config.function_name = func.__name__
return {'assignments': assignments, 'config': config}
else:
return assignments
return decorator
# noinspection PyMethodMayBeStatic