From a7e61dc46d75c873223aa3d7a8ea8cd8644333b9 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Wed, 10 Jul 2019 19:43:16 +0200
Subject: [PATCH] Add `pystencils.make_python_function` used for
 KernelFunction.compile

`KernelFunction.compile = None` is currently set by the
`create_kernel` function of each respective backend as partial function
of `<backend>.make_python_function`.

The code would be clearer with a unified `make_python_function`.
`KernelFunction.compile` can then be implemented  as a call to this
function with the respective backend.
---
 pystencils/__init__.py               |  6 +++---
 pystencils/astnodes.py               |  3 +++
 pystencils/gpucuda/kernelcreation.py |  2 +-
 pystencils/kernelcreation.py         | 18 ++++++++++++++++++
 4 files changed, 25 insertions(+), 4 deletions(-)

diff --git a/pystencils/__init__.py b/pystencils/__init__.py
index a7e21703..273cebf0 100644
--- a/pystencils/__init__.py
+++ b/pystencils/__init__.py
@@ -5,18 +5,18 @@ from . import stencil as stencil
 from .assignment import Assignment, assignment_from_stencil
 from .data_types import TypedSymbol
 from .datahandling import create_data_handling
+from .slicing import make_slice
+from .kernelcreation import create_kernel, create_indexed_kernel, create_staggered_kernel, make_python_function
 from .display_utils import show_code, to_dot
 from .field import Field, FieldType, fields
 from .kernel_decorator import kernel
-from .kernelcreation import create_indexed_kernel, create_kernel, create_staggered_kernel
 from .simp import AssignmentCollection
-from .slicing import make_slice
 from .sympyextensions import SymbolCreator
 
 __all__ = ['Field', 'FieldType', 'fields',
            'TypedSymbol',
            'make_slice',
-           'create_kernel', 'create_indexed_kernel', 'create_staggered_kernel',
+           'create_kernel', 'create_indexed_kernel', 'create_staggered_kernel', 'make_python_function',
            'show_code', 'to_dot',
            'AssignmentCollection',
            'Assignment',
diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py
index 441de76a..71ce9016 100644
--- a/pystencils/astnodes.py
+++ b/pystencils/astnodes.py
@@ -3,6 +3,7 @@ from typing import Any, List, Optional, Sequence, Set, Union
 import jinja2
 import sympy as sp
 
+import pystencils
 from pystencils.data_types import TypedSymbol, cast_func, create_type
 from pystencils.field import Field
 from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
@@ -246,6 +247,8 @@ class KernelFunction(Node):
         if self._compile_function is None:
             raise ValueError("No compile-function provided for this KernelFunction node")
         return self._compile_function(self, *args, **kwargs)
+        """
+        return pystencils.make_python_function(self, backend=self.backend)
 
 
 class SkipIteration(Node):
diff --git a/pystencils/gpucuda/kernelcreation.py b/pystencils/gpucuda/kernelcreation.py
index ff821070..ec5ad24b 100644
--- a/pystencils/gpucuda/kernelcreation.py
+++ b/pystencils/gpucuda/kernelcreation.py
@@ -1,6 +1,6 @@
+from pystencils import Field, FieldType
 from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment
 from pystencils.data_types import BasicType, StructType, TypedSymbol
-from pystencils.field import Field, FieldType
 from pystencils.gpucuda.cudajit import make_python_function
 from pystencils.gpucuda.indexing import BlockIndexing
 from pystencils.transformations import (
diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py
index ade980f5..434b8dfb 100644
--- a/pystencils/kernelcreation.py
+++ b/pystencils/kernelcreation.py
@@ -3,6 +3,7 @@ from types import MappingProxyType
 
 import sympy as sp
 
+import pystencils
 from pystencils.assignment import Assignment
 from pystencils.astnodes import Block, Conditional, LoopOverCoordinate, SympyAssignment
 from pystencils.cpu.vectorization import vectorize
@@ -267,3 +268,20 @@ def create_staggered_kernel(staggered_field, expressions, subexpressions=(), tar
         elif isinstance(cpu_vectorize_info, dict):
             vectorize(ast, **cpu_vectorize_info)
     return ast
+
+
+def make_python_function(kernel_function_node, backend='cpu', argument_dict=None):
+    """
+    A generic version of the {cuda,cpu,llvm}jit.make_python_function
+
+    """
+    if backend == 'cpu' or not backend:
+        kernel = pystencils.cpu.cpujit.make_python_function(kernel_function_node)
+    elif backend == 'gpu' or backend == 'gpucuda':
+        kernel = pystencils.gpucuda.cudajit.make_python_function(kernel_function_node, argument_dict)
+    elif backend == 'llvm':
+        kernel = pystencils.llvm.llvmjit.make_python_function(kernel_function_node, argument_dict)
+    else:
+        raise NotImplementedError('Unsupported target for make_python_function %s' % backend)
+
+    return kernel
-- 
GitLab