diff --git a/pystencils/backends/cuda_backend.py b/pystencils/backends/cuda_backend.py index 915d5601b5fe7bc05b9833c21da62915b22552e5..cfef01e9052a35db0af0019cca830ac6ae7db1c7 100644 --- a/pystencils/backends/cuda_backend.py +++ b/pystencils/backends/cuda_backend.py @@ -8,6 +8,7 @@ from pystencils.astnodes import Node from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, generate_c from pystencils.data_types import cast_func, create_type, get_type_of_expression from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt +from pystencils.interpolation_astnodes import InterpolationMode with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f: lines = f.readlines() @@ -77,7 +78,7 @@ class CudaSympyPrinter(CustomSympyPrinter): def _print_TextureAccess(self, node): dtype = node.texture.field.dtype.numpy_dtype - if node.texture.cubic_bspline_interpolation: + if node.texture.interpolation_mode == InterpolationMode.CUBIC_SPLINE: template = "cubicTex%iDSimple<%s>(%s, %s)" else: if dtype.itemsize > 4: diff --git a/pystencils/gpucuda/cudajit.py b/pystencils/gpucuda/cudajit.py index ee78b4663c4f40ebf3de0f96ebd405046a75de3c..8a403fd76828620e471dc07161a33faf1a5f0c6b 100644 --- a/pystencils/gpucuda/cudajit.py +++ b/pystencils/gpucuda/cudajit.py @@ -7,7 +7,7 @@ from pystencils.data_types import StructType from pystencils.field import FieldType from pystencils.gpucuda.texture_utils import ndarray_to_tex from pystencils.include import get_pystencils_include_path -from pystencils.interpolation_astnodes import TextureAccess +from pystencils.interpolation_astnodes import InterpolationMode, TextureAccess from pystencils.kernelparameters import FieldPointerSymbol USE_FAST_MATH = True @@ -46,7 +46,7 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen if USE_FAST_MATH: nvcc_options.append("-use_fast_math") - if any(t.cubic_bspline_interpolation for t in textures): + if any(t.interpolation_mode == InterpolationMode.CUBIC_SPLINE for t in textures): assert isdir(join(dirname(__file__), "CubicInterpolationCUDA", "code")), \ "Submodule CubicInterpolationCUDA does not exist" nvcc_options += ["-I" + join(dirname(__file__), "CubicInterpolationCUDA", "code")] diff --git a/pystencils/interpolation_astnodes.py b/pystencils/interpolation_astnodes.py index b18c9c40e32c25668ef9d70fd6921f354bc1c261..46aef0bdb05e2875b5523d2e850a0421a64c3772 100644 --- a/pystencils/interpolation_astnodes.py +++ b/pystencils/interpolation_astnodes.py @@ -9,6 +9,7 @@ """ import itertools +from enum import Enum from typing import Set import sympy as sp @@ -24,7 +25,14 @@ except Exception: pass -class LinearInterpolator(object): +class InterpolationMode(str, Enum): + NEAREST_NEIGHBOR = "nearest_neighbour" + NN = NEAREST_NEIGHBOR + LINEAR = "linear" + CUBIC_SPLINE = "cubic_spline" + + +class Interpolator(object): """ Implements non-integer accesses on fields using linear interpolation. @@ -58,7 +66,9 @@ class LinearInterpolator(object): required_global_declarations = [] - def __init__(self, parent_field: pystencils.Field, + def __init__(self, + parent_field: pystencils.Field, + interpolation_mode: InterpolationMode, address_mode='BORDER', use_normalized_coordinates=False): super().__init__() @@ -72,6 +82,7 @@ class LinearInterpolator(object): 'dummy_symbol_carrying_field' + self.field.name + hash_str) self.symbol.field = self.field self.symbol.interpolator = self + self.interpolation_mode = interpolation_mode def at(self, offset): return InterpolatorAccess(self.symbol, *offset) @@ -93,6 +104,30 @@ class LinearInterpolator(object): self.use_normalized_coordinates)) +class LinearInterpolator(Interpolator): + + def __init__(self, + parent_field: pystencils.Field, + address_mode='BORDER', + use_normalized_coordinates=False): + super().__init__(parent_field, + InterpolationMode.LINEAR, + address_mode, + use_normalized_coordinates) + + +class NearestNeightborInterpolator(Interpolator): + + def __init__(self, + parent_field: pystencils.Field, + address_mode='BORDER', + use_normalized_coordinates=False): + super().__init__(parent_field, + InterpolationMode.NN, + address_mode, + use_normalized_coordinates) + + class InterpolatorAccess(TypedSymbol): def __new__(cls, field, offsets, *args, **kwargs): obj = TextureAccess.__xnew_cached_(cls, field, offsets, *args, **kwargs) @@ -149,6 +184,10 @@ class InterpolatorAccess(TypedSymbol): def symbols_defined(self) -> Set[sp.Symbol]: return {self} + @property + def interpolation_mode(self): + return self.interpolator.interpolation_mode + def implementation_with_stencils(self): field = self.field @@ -165,56 +204,66 @@ class InterpolatorAccess(TypedSymbol): offsets = self.offsets rounding_functions = (sp.floor, lambda x: sp.floor(x) + 1) - # TODO optimization: implement via lerp: https://devblogs.nvidia.com/lerp-faster-cuda/ - for c in itertools.product(rounding_functions, repeat=field.spatial_dimensions): - weight = sp.Mul(*[1 - sp.Abs(f(offset) - offset) for (f, offset) in zip(c, offsets)]) - index = [f(offset) for (f, offset) in zip(c, offsets)] - for channel_idx in range(field.shape[0] if field.index_dimensions else 1): - # Hardware boundary handling on GPU + for channel_idx in range(field.shape[0] if field.index_dimensions else 1): + if self.interpolation_mode == InterpolationMode.NN: if use_textures: - weight = sp.Mul(*[1 - sp.Abs(f(offset) - offset) for (f, offset) in zip(c, offsets)]) - sum[channel_idx] += \ - weight * absolute_access(index, channel_idx if field.index_dimensions else ()) - # else boundary handling using software - elif str(self.interpolator.address_mode).lower() == 'border': - is_inside_field = sp.And( - *itertools.chain([i >= 0 for i in index], - [idx < field.shape[dim] for (dim, idx) in enumerate(index)])) - index = [cast_func(i, default_int_type) for i in index] - sum[channel_idx] += sp.Piecewise( - (weight * absolute_access(index, channel_idx if field.index_dimensions else ()), - is_inside_field), - (sp.simplify(0), True) - ) - elif str(self.interpolator.address_mode).lower() == 'clamp': - index = [cast_func(sp.Min(sp.Max(0, i), field.shape[dim] - 1), default_int_type) - for (dim, i) in enumerate(index)] - sum[channel_idx] += weight * \ - absolute_access(index, channel_idx if field.index_dimensions else ()) - elif str(self.interpolator.address_mode).lower() == 'wrap': - index = [cast_func(sp.Piecewise((sp.Mod(i, field.shape[dim]), i >= 0), - (field.shape[dim] + sp.Mod(i, field.shape[dim]), True)), - default_int_type) - for (dim, i) in enumerate(index)] - sum[channel_idx] += weight * \ - absolute_access(index, channel_idx if field.index_dimensions else ()) - elif str(self.interpolator.address_mode).lower() == 'mirror': - def triangle_fun(x, half_period): - saw_tooth = sp.Abs(x) % (2 * half_period) - return sp.Piecewise((saw_tooth, saw_tooth < half_period), - (2 * half_period - 1 - saw_tooth, True)) - index = [cast_func(triangle_fun(i, field.shape[dim]), - default_int_type) for (dim, i) in enumerate(index)] - sum[channel_idx] += weight * absolute_access(index, channel_idx if field.index_dimensions else ()) + sum[channel_idx] = self else: - raise NotImplementedError() - - sum = [sp.factor(s) for s in sum] + sum[channel_idx] = absolute_access([sp.floor(i + 0.5) for i in offsets], channel_idx) - if field.index_dimensions: - return sp.Matrix(sum) - else: - return sum[0] + elif self.interpolation_mode == InterpolationMode.LINEAR: + # TODO optimization: implement via lerp: https://devblogs.nvidia.com/lerp-faster-cuda/ + for c in itertools.product(rounding_functions, repeat=field.spatial_dimensions): + weight = sp.Mul(*[1 - sp.Abs(f(offset) - offset) for (f, offset) in zip(c, offsets)]) + index = [f(offset) for (f, offset) in zip(c, offsets)] + # Hardware boundary handling on GPU + if use_textures: + weight = sp.Mul(*[1 - sp.Abs(f(offset) - offset) for (f, offset) in zip(c, offsets)]) + sum[channel_idx] += \ + weight * absolute_access(index, channel_idx if field.index_dimensions else ()) + # else boundary handling using software + elif str(self.interpolator.address_mode).lower() == 'border': + is_inside_field = sp.And( + *itertools.chain([i >= 0 for i in index], + [idx < field.shape[dim] for (dim, idx) in enumerate(index)])) + index = [cast_func(i, default_int_type) for i in index] + sum[channel_idx] += sp.Piecewise( + (weight * absolute_access(index, channel_idx if field.index_dimensions else ()), + is_inside_field), + (sp.simplify(0), True) + ) + elif str(self.interpolator.address_mode).lower() == 'clamp': + index = [cast_func(sp.Min(sp.Max(0, i), field.shape[dim] - 1), default_int_type) + for (dim, i) in enumerate(index)] + sum[channel_idx] += weight * \ + absolute_access(index, channel_idx if field.index_dimensions else ()) + elif str(self.interpolator.address_mode).lower() == 'wrap': + index = [cast_func(sp.Piecewise((sp.Mod(i, field.shape[dim]), i >= 0), + (field.shape[dim] + sp.Mod(i, field.shape[dim]), True)), + default_int_type) + for (dim, i) in enumerate(index)] + sum[channel_idx] += weight * \ + absolute_access(index, channel_idx if field.index_dimensions else ()) + elif str(self.interpolator.address_mode).lower() == 'mirror': + def triangle_fun(x, half_period): + saw_tooth = sp.Abs(x) % (2 * half_period) + return sp.Piecewise((saw_tooth, saw_tooth < half_period), + (2 * half_period - 1 - saw_tooth, True)) + index = [cast_func(triangle_fun(i, field.shape[dim]), + default_int_type) for (dim, i) in enumerate(index)] + sum[channel_idx] += weight * \ + absolute_access(index, channel_idx if field.index_dimensions else ()) + else: + raise NotImplementedError() + elif self.interpolation_mode == InterpolationMode.CUBIC_SPLINE: + raise NotImplementedError("only works with HW interpolation for float32") + + sum = [sp.factor(s) for s in sum] + + if field.index_dimensions: + return sp.Matrix(sum) + else: + return sum[0] # noinspection SpellCheckingInspection __xnew__ = staticmethod(__new_stage2__) @@ -231,9 +280,10 @@ class TextureCachedField: def __init__(self, parent_field, address_mode=None, filter_mode=None, + interpolation_mode: InterpolationMode = InterpolationMode.LINEAR, use_normalized_coordinates=False, - read_as_integer=False, - cubic_bspline_interpolation=False): + read_as_integer=False + ): if isinstance(address_mode, str): address_mode = getattr(pycuda.driver.address_mode, address_mode.upper()) @@ -252,14 +302,14 @@ class TextureCachedField: self.symbol = TypedSymbol(str(self), self.field.dtype.numpy_dtype) self.symbol.interpolator = self self.symbol.field = self.field - self.cubic_bspline_interpolation = cubic_bspline_interpolation + self.interpolation_mode = interpolation_mode # assert str(self.field.dtype) != 'double', "CUDA does not support double textures!" # assert dtype_supports_textures(self.field.dtype), "CUDA only supports texture types with 32 bits or less" @classmethod def from_interpolator(cls, interpolator: LinearInterpolator): - obj = cls(interpolator.field, interpolator.address_mode) + obj = cls(interpolator.field, interpolator.address_mode, interpolation_mode=interpolator.interpolation_mode) return obj def at(self, offset):