diff --git a/pystencils/backends/cuda_backend.py b/pystencils/backends/cuda_backend.py index 915d5601b5fe7bc05b9833c21da62915b22552e5..cfef01e9052a35db0af0019cca830ac6ae7db1c7 100644 --- a/pystencils/backends/cuda_backend.py +++ b/pystencils/backends/cuda_backend.py @@ -8,6 +8,7 @@ from pystencils.astnodes import Node from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, generate_c from pystencils.data_types import cast_func, create_type, get_type_of_expression from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt +from pystencils.interpolation_astnodes import InterpolationMode with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f: lines = f.readlines() @@ -77,7 +78,7 @@ class CudaSympyPrinter(CustomSympyPrinter): def _print_TextureAccess(self, node): dtype = node.texture.field.dtype.numpy_dtype - if node.texture.cubic_bspline_interpolation: + if node.texture.interpolation_mode == InterpolationMode.CUBIC_SPLINE: template = "cubicTex%iDSimple<%s>(%s, %s)" else: if dtype.itemsize > 4: diff --git a/pystencils/gpucuda/cudajit.py b/pystencils/gpucuda/cudajit.py index ee78b4663c4f40ebf3de0f96ebd405046a75de3c..8a403fd76828620e471dc07161a33faf1a5f0c6b 100644 --- a/pystencils/gpucuda/cudajit.py +++ b/pystencils/gpucuda/cudajit.py @@ -7,7 +7,7 @@ from pystencils.data_types import StructType from pystencils.field import FieldType from pystencils.gpucuda.texture_utils import ndarray_to_tex from pystencils.include import get_pystencils_include_path -from pystencils.interpolation_astnodes import TextureAccess +from pystencils.interpolation_astnodes import InterpolationMode, TextureAccess from pystencils.kernelparameters import FieldPointerSymbol USE_FAST_MATH = True @@ -46,7 +46,7 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen if USE_FAST_MATH: nvcc_options.append("-use_fast_math") - if any(t.cubic_bspline_interpolation for t in textures): + if any(t.interpolation_mode == InterpolationMode.CUBIC_SPLINE for t in textures): assert isdir(join(dirname(__file__), "CubicInterpolationCUDA", "code")), \ "Submodule CubicInterpolationCUDA does not exist" nvcc_options += ["-I" + join(dirname(__file__), "CubicInterpolationCUDA", "code")]