From 865d22744e00a91954beb9d9cefa3b933699b677 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Tue, 21 Jan 2020 09:29:13 +0100
Subject: [PATCH] Add get_cubic_interpolation_include_paths

---
 src/pystencils_autodiff/backends/astnodes.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py
index 343dd10..3f3b0fa 100644
--- a/src/pystencils_autodiff/backends/astnodes.py
+++ b/src/pystencils_autodiff/backends/astnodes.py
@@ -13,9 +13,9 @@ import sys
 from collections.abc import Iterable
 from os.path import dirname, exists, join
 
+import pystencils
 from pystencils.astnodes import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
 from pystencils.cpu.cpujit import get_cache_config, get_compiler_config
-from pystencils.gpucuda.cudajit import get_cubic_interpolation_include_paths
 from pystencils.include import get_pycuda_include_path, get_pystencils_include_path
 from pystencils_autodiff._file_io import read_template_from_file, write_file
 from pystencils_autodiff.backends.python_bindings import (
@@ -26,6 +26,11 @@ from pystencils_autodiff.framework_integration.astnodes import (
 from pystencils_autodiff.tensorflow_jit import _hash
 
 
+def get_cubic_interpolation_include_paths():
+    return [join(dirname(pystencils.gpucuda.__file__), 'CubicInterpolationCUDA', 'code'),
+            join(dirname(pystencils.gpucuda.__file__), 'CubicInterpolationCUDA', 'code', 'internal')]
+
+
 class TorchTensorDestructuring(DestructuringBindingsForFieldClass):
     CLASS_TO_MEMBER_DICT = {
         FieldPointerSymbol: "data_ptr<{dtype}>()",
-- 
GitLab