Skip to content
Snippets Groups Projects

Add autodiff

Closed Stephan Seitz requested to merge seitz/pystencils:autodiff into master
1 unresolved thread
1 file
+ 29
5
Compare changes
  • Side-by-side
  • Inline
  • Add kernelcreation.make_python_function (generic version of {cuda,cpu,llvm}jit.make_python_function)
    
    This is roughly equivalent to kernel_function_node.compile(backend=...)
    but it is easier to find and inspect.
    Since all backend have `make_python_function` it makes sense to add a
    generic one.
import itertools
from types import MappingProxyType
import sympy as sp
import itertools
import pystencils
from pystencils.assignment import Assignment
from pystencils.astnodes import LoopOverCoordinate, Conditional, Block, SympyAssignment
from pystencils.astnodes import (Block, Conditional, LoopOverCoordinate,
SympyAssignment)
from pystencils.cpu.vectorization import vectorize
from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.gpucuda.indexing import indexing_creator_from_params
from pystencils.transformations import remove_conditionals_in_staggered_kernel, loop_blocking, \
move_constants_before_loop
from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.transformations import (
loop_blocking, move_constants_before_loop,
remove_conditionals_in_staggered_kernel)
def create_kernel(assignments, target='cpu', data_type="double", iteration_slice=None, ghost_layers=None,
@@ -265,3 +270,22 @@ def create_staggered_kernel(staggered_field, expressions, subexpressions=(), tar
elif isinstance(cpu_vectorize_info, dict):
vectorize(ast, **cpu_vectorize_info)
return ast
def make_python_function(kernel_function_node, target='cpu', argument_dict=None):
"""
A generic version of the {cuda,cpu,llvm}jit.make_python_function
Less confusing than kernel_function_node.compile(backend=...)
"""
if target == 'cpu':
kernel = pystencils.cpu.cpujit.make_python_function(kernel_function_node)
elif target == 'gpu':
kernel = pystencils.gpucuda.cudajit.make_python_function(kernel_function_node, argument_dict)
elif target == 'llvm':
kernel = pystencils.llvm.llvmjit.make_python_function(kernel_function_node, argument_dict)
else:
raise NotImplementedError('Unsupported target for make_python_function: %s' % target)
return kernel
Loading