diff --git a/pystencils/__init__.py b/pystencils/__init__.py index a7e21703b48c7398da8be6f609f0f3123f7822d1..273cebf01705d4fec464ac9c96e43af9e0e4321d 100644 --- a/pystencils/__init__.py +++ b/pystencils/__init__.py @@ -5,18 +5,18 @@ from . import stencil as stencil from .assignment import Assignment, assignment_from_stencil from .data_types import TypedSymbol from .datahandling import create_data_handling +from .slicing import make_slice +from .kernelcreation import create_kernel, create_indexed_kernel, create_staggered_kernel, make_python_function from .display_utils import show_code, to_dot from .field import Field, FieldType, fields from .kernel_decorator import kernel -from .kernelcreation import create_indexed_kernel, create_kernel, create_staggered_kernel from .simp import AssignmentCollection -from .slicing import make_slice from .sympyextensions import SymbolCreator __all__ = ['Field', 'FieldType', 'fields', 'TypedSymbol', 'make_slice', - 'create_kernel', 'create_indexed_kernel', 'create_staggered_kernel', + 'create_kernel', 'create_indexed_kernel', 'create_staggered_kernel', 'make_python_function', 'show_code', 'to_dot', 'AssignmentCollection', 'Assignment', diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index 441de76aa9b2c7a771ceba7ba4010e67023ac9e7..71ce9016fb15bd623e45f19c42607071422dc4b9 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -3,6 +3,7 @@ from typing import Any, List, Optional, Sequence, Set, Union import jinja2 import sympy as sp +import pystencils from pystencils.data_types import TypedSymbol, cast_func, create_type from pystencils.field import Field from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol @@ -246,6 +247,8 @@ class KernelFunction(Node): if self._compile_function is None: raise ValueError("No compile-function provided for this KernelFunction node") return self._compile_function(self, *args, **kwargs) + """ + return pystencils.make_python_function(self, backend=self.backend) class SkipIteration(Node): diff --git a/pystencils/gpucuda/kernelcreation.py b/pystencils/gpucuda/kernelcreation.py index ff82107000b00595dbcca6699ed42b172c324353..ec5ad24b6f26126e9dd8d11b4a81d96df1d96e15 100644 --- a/pystencils/gpucuda/kernelcreation.py +++ b/pystencils/gpucuda/kernelcreation.py @@ -1,6 +1,6 @@ +from pystencils import Field, FieldType from pystencils.astnodes import Block, KernelFunction, LoopOverCoordinate, SympyAssignment from pystencils.data_types import BasicType, StructType, TypedSymbol -from pystencils.field import Field, FieldType from pystencils.gpucuda.cudajit import make_python_function from pystencils.gpucuda.indexing import BlockIndexing from pystencils.transformations import ( diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py index ade980f55969a4005f2c5055ad27f861ab5524de..434b8dfb87bb73d45c2529e683f4f3a58a95e75e 100644 --- a/pystencils/kernelcreation.py +++ b/pystencils/kernelcreation.py @@ -3,6 +3,7 @@ from types import MappingProxyType import sympy as sp +import pystencils from pystencils.assignment import Assignment from pystencils.astnodes import Block, Conditional, LoopOverCoordinate, SympyAssignment from pystencils.cpu.vectorization import vectorize @@ -267,3 +268,20 @@ def create_staggered_kernel(staggered_field, expressions, subexpressions=(), tar elif isinstance(cpu_vectorize_info, dict): vectorize(ast, **cpu_vectorize_info) return ast + + +def make_python_function(kernel_function_node, backend='cpu', argument_dict=None): + """ + A generic version of the {cuda,cpu,llvm}jit.make_python_function + + """ + if backend == 'cpu' or not backend: + kernel = pystencils.cpu.cpujit.make_python_function(kernel_function_node) + elif backend == 'gpu' or backend == 'gpucuda': + kernel = pystencils.gpucuda.cudajit.make_python_function(kernel_function_node, argument_dict) + elif backend == 'llvm': + kernel = pystencils.llvm.llvmjit.make_python_function(kernel_function_node, argument_dict) + else: + raise NotImplementedError('Unsupported target for make_python_function %s' % backend) + + return kernel