Skip to content
Snippets Groups Projects
Commit 865d2274 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add get_cubic_interpolation_include_paths

parent 9d500fb9
No related branches found
No related tags found
No related merge requests found
...@@ -13,9 +13,9 @@ import sys ...@@ -13,9 +13,9 @@ import sys
from collections.abc import Iterable from collections.abc import Iterable
from os.path import dirname, exists, join from os.path import dirname, exists, join
import pystencils
from pystencils.astnodes import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol from pystencils.astnodes import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
from pystencils.cpu.cpujit import get_cache_config, get_compiler_config from pystencils.cpu.cpujit import get_cache_config, get_compiler_config
from pystencils.gpucuda.cudajit import get_cubic_interpolation_include_paths
from pystencils.include import get_pycuda_include_path, get_pystencils_include_path from pystencils.include import get_pycuda_include_path, get_pystencils_include_path
from pystencils_autodiff._file_io import read_template_from_file, write_file from pystencils_autodiff._file_io import read_template_from_file, write_file
from pystencils_autodiff.backends.python_bindings import ( from pystencils_autodiff.backends.python_bindings import (
...@@ -26,6 +26,11 @@ from pystencils_autodiff.framework_integration.astnodes import ( ...@@ -26,6 +26,11 @@ from pystencils_autodiff.framework_integration.astnodes import (
from pystencils_autodiff.tensorflow_jit import _hash from pystencils_autodiff.tensorflow_jit import _hash
def get_cubic_interpolation_include_paths():
return [join(dirname(pystencils.gpucuda.__file__), 'CubicInterpolationCUDA', 'code'),
join(dirname(pystencils.gpucuda.__file__), 'CubicInterpolationCUDA', 'code', 'internal')]
class TorchTensorDestructuring(DestructuringBindingsForFieldClass): class TorchTensorDestructuring(DestructuringBindingsForFieldClass):
CLASS_TO_MEMBER_DICT = { CLASS_TO_MEMBER_DICT = {
FieldPointerSymbol: "data_ptr<{dtype}>()", FieldPointerSymbol: "data_ptr<{dtype}>()",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment