Skip to content
Snippets Groups Projects
Select Git revision
  • 5ee0ec34770f83c622be08e0cb9709f9f4e9e22b
  • master default protected
  • v2.0-dev protected
  • zikeliml/Task-96-dotExporterForAST
  • zikeliml/124-rework-tutorials
  • fma
  • fhennig/v2.0-deprecations
  • holzer-master-patch-46757
  • 66-absolute-access-is-probably-not-copied-correctly-after-_eval_subs
  • gpu_bufferfield_fix
  • hyteg
  • vectorization_sqrt_fix
  • target_dh_refactoring
  • const_fix
  • improved_comm
  • gpu_liveness_opts
  • release/1.3.7 protected
  • release/1.3.6 protected
  • release/2.0.dev0 protected
  • release/1.3.5 protected
  • release/1.3.4 protected
  • release/1.3.3 protected
  • release/1.3.2 protected
  • release/1.3.1 protected
  • release/1.3 protected
  • release/1.2 protected
  • release/1.1.1 protected
  • release/1.1 protected
  • release/1.0.1 protected
  • release/1.0 protected
  • release/0.4.4 protected
  • last/Kerncraft
  • last/OpenCL
  • last/LLVM
  • release/0.4.3 protected
  • release/0.4.2 protected
36 results

test_jacobi_cbackend.py

Blame
  • kernel_decorator.py 2.72 KiB
    import ast
    import inspect
    import sympy as sp
    import textwrap
    from pystencils.sympyextensions import SymbolCreator
    from pystencils.assignment import Assignment
    
    __all__ = ['kernel']
    
    
    def kernel(func, **kwargs):
        """Decorator to simplify generation of pystencils Assignments.

        Changes the meaning of the '@=' operator: each statement using this operator adds a symbolic
        assignment to the result list. Furthermore, the ternary inline 'if-else' is reinterpreted to
        denote a sympy Piecewise.

        The decorated function may not receive any arguments, with exception of an argument called 's' that specifies
        a SymbolCreator(). If the function declares 's' (as a regular or keyword-only argument) and no value for it
        is passed in ``kwargs``, a fresh SymbolCreator() is supplied automatically.

        Args:
            func: function whose body is re-interpreted; its source code has to be retrievable
                  via ``inspect.getsource``
            **kwargs: keyword arguments forwarded to the rewritten function when it is executed

        Returns:
            list of the Assignments collected while executing the rewritten function. Note that when used
            as a decorator, the decorated name is bound to this list, not to a callable.

        Examples:
            >>> import pystencils as ps
            >>> @kernel
            ... def my_kernel(s):
            ...     f, g = ps.fields('f, g: [2D]')
            ...     s.neighbors @= f[0,1] + f[1,0]
            ...     g[0,0]      @= s.neighbors + f[0,0] if f[0,0] > 0 else 0
            >>> f, g = ps.fields('f, g: [2D]')
            >>> assert my_kernel[0].rhs == f[0,1] + f[1,0]
        """
        # The function may be defined in an indented scope, so dedent before parsing
        source = textwrap.dedent(inspect.getsource(func))
        tree = ast.parse(source)
        KernelFunctionRewrite().visit(tree)
        ast.fix_missing_locations(tree)

        assignments = []

        def assignment_adder(lhs, rhs):
            assignments.append(Assignment(lhs, rhs))

        # Re-define the rewritten function in a copy of its original globals, extended by the
        # helper names the rewritten AST refers to and by the closure variables of the original
        gl = func.__globals__.copy()
        gl['_add_assignment'] = assignment_adder
        gl['_Piecewise'] = sp.Piecewise
        gl.update(inspect.getclosurevars(func).nonlocals)
        exec(compile(tree, filename="<ast>", mode="exec"), gl)
        func = gl[func.__name__]

        # Accept 's' both as a positional-or-keyword and as a keyword-only argument
        arg_spec = inspect.getfullargspec(func)
        if 's' in arg_spec.args + arg_spec.kwonlyargs and 's' not in kwargs:
            kwargs['s'] = SymbolCreator()
        func(**kwargs)
        return assignments
    
    
    # noinspection PyMethodMayBeStatic
    class KernelFunctionRewrite(ast.NodeTransformer):
        """AST transformer that rewrites a function so that it records symbolic assignments.

        The following rewrites are applied:
            - ``lhs @= rhs``                 becomes the call ``_add_assignment(lhs, rhs)``
            - ``body if test else orelse``   becomes ``_Piecewise((body, test), (orelse, True))``
            - decorators of function definitions are removed

        The names ``_add_assignment`` and ``_Piecewise`` are only referenced by the generated code;
        they have to be provided in the namespace in which the rewritten tree is executed.
        """

        def visit_IfExp(self, node):
            """Replaces a ternary expression by a call to ``_Piecewise``."""
            # Transform the children first, such that nested ternaries are rewritten as well
            self.generic_visit(node)
            piecewise_func = ast.Name(id='_Piecewise', ctx=ast.Load())
            piecewise_func = ast.copy_location(piecewise_func, node)
            # ast.Constant supersedes ast.NameConstant, which is deprecated since Python 3.8
            otherwise_condition = ast.Constant(value=True)
            piecewise_args = [ast.Tuple(elts=[node.body, node.test], ctx=ast.Load()),
                              ast.Tuple(elts=[node.orelse, otherwise_condition], ctx=ast.Load())]
            result = ast.Call(func=piecewise_func, args=piecewise_args, keywords=[])

            return ast.copy_location(result, node)

        def visit_AugAssign(self, node):
            """Replaces ``lhs @= rhs`` by the expression statement ``_add_assignment(lhs, rhs)``."""
            self.generic_visit(node)
            # Only '@=' gets a new meaning, other augmented assignments ('+=', '-=', ...) keep
            # their normal Python semantics
            if not isinstance(node.op, ast.MatMult):
                return node
            # The former assignment target is now read as a function argument
            node.target.ctx = ast.Load()
            new_node = ast.Expr(ast.Call(func=ast.Name(id='_add_assignment', ctx=ast.Load()),
                                         args=[node.target, node.value],
                                         keywords=[]))
            return ast.copy_location(new_node, node)

        def visit_FunctionDef(self, node):
            """Strips all decorators from function definitions after visiting their bodies."""
            self.generic_visit(node)
            node.decorator_list = []
            return node