diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py index 343dd1036e671447386d6eac660bde584080da49..3f3b0fa39be4fc27c54456714c9e649eb3c58ec8 100644 --- a/src/pystencils_autodiff/backends/astnodes.py +++ b/src/pystencils_autodiff/backends/astnodes.py @@ -13,9 +13,9 @@ import sys from collections.abc import Iterable from os.path import dirname, exists, join +import pystencils from pystencils.astnodes import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol from pystencils.cpu.cpujit import get_cache_config, get_compiler_config -from pystencils.gpucuda.cudajit import get_cubic_interpolation_include_paths from pystencils.include import get_pycuda_include_path, get_pystencils_include_path from pystencils_autodiff._file_io import read_template_from_file, write_file from pystencils_autodiff.backends.python_bindings import ( @@ -26,6 +26,11 @@ from pystencils_autodiff.framework_integration.astnodes import ( from pystencils_autodiff.tensorflow_jit import _hash +def get_cubic_interpolation_include_paths(): + return [join(dirname(pystencils.gpucuda.__file__), 'CubicInterpolationCUDA', 'code'), + join(dirname(pystencils.gpucuda.__file__), 'CubicInterpolationCUDA', 'code', 'internal')] + + class TorchTensorDestructuring(DestructuringBindingsForFieldClass): CLASS_TO_MEMBER_DICT = { FieldPointerSymbol: "data_ptr<{dtype}>()",