From 865d22744e00a91954beb9d9cefa3b933699b677 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Tue, 21 Jan 2020 09:29:13 +0100 Subject: [PATCH] Add get_cubic_interpolation_include_paths --- src/pystencils_autodiff/backends/astnodes.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py index 343dd10..3f3b0fa 100644 --- a/src/pystencils_autodiff/backends/astnodes.py +++ b/src/pystencils_autodiff/backends/astnodes.py @@ -13,9 +13,9 @@ import sys from collections.abc import Iterable from os.path import dirname, exists, join +import pystencils from pystencils.astnodes import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol from pystencils.cpu.cpujit import get_cache_config, get_compiler_config -from pystencils.gpucuda.cudajit import get_cubic_interpolation_include_paths from pystencils.include import get_pycuda_include_path, get_pystencils_include_path from pystencils_autodiff._file_io import read_template_from_file, write_file from pystencils_autodiff.backends.python_bindings import ( @@ -26,6 +26,11 @@ from pystencils_autodiff.framework_integration.astnodes import ( from pystencils_autodiff.tensorflow_jit import _hash +def get_cubic_interpolation_include_paths(): + return [join(dirname(pystencils.gpucuda.__file__), 'CubicInterpolationCUDA', 'code'), + join(dirname(pystencils.gpucuda.__file__), 'CubicInterpolationCUDA', 'code', 'internal')] + + class TorchTensorDestructuring(DestructuringBindingsForFieldClass): CLASS_TO_MEMBER_DICT = { FieldPointerSymbol: "data_ptr<{dtype}>()", -- GitLab