Skip to content
Snippets Groups Projects
Commit b132a731 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

WIP: add interpolation_mode (NEAREST_NEIGHBOR vs LINEAR)

parent 4e8932a1
No related branches found
No related tags found
No related merge requests found
...@@ -8,6 +8,7 @@ from pystencils.astnodes import Node ...@@ -8,6 +8,7 @@ from pystencils.astnodes import Node
from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, generate_c from pystencils.backends.cbackend import CBackend, CustomSympyPrinter, generate_c
from pystencils.data_types import cast_func, create_type, get_type_of_expression from pystencils.data_types import cast_func, create_type, get_type_of_expression
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.interpolation_astnodes import InterpolationMode
with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f: with open(join(dirname(__file__), 'cuda_known_functions.txt')) as f:
lines = f.readlines() lines = f.readlines()
...@@ -77,7 +78,7 @@ class CudaSympyPrinter(CustomSympyPrinter): ...@@ -77,7 +78,7 @@ class CudaSympyPrinter(CustomSympyPrinter):
def _print_TextureAccess(self, node): def _print_TextureAccess(self, node):
dtype = node.texture.field.dtype.numpy_dtype dtype = node.texture.field.dtype.numpy_dtype
if node.texture.cubic_bspline_interpolation: if node.texture.interpolation_mode == InterpolationMode.CUBIC_SPLINE:
template = "cubicTex%iDSimple<%s>(%s, %s)" template = "cubicTex%iDSimple<%s>(%s, %s)"
else: else:
if dtype.itemsize > 4: if dtype.itemsize > 4:
......
...@@ -7,7 +7,7 @@ from pystencils.data_types import StructType ...@@ -7,7 +7,7 @@ from pystencils.data_types import StructType
from pystencils.field import FieldType from pystencils.field import FieldType
from pystencils.gpucuda.texture_utils import ndarray_to_tex from pystencils.gpucuda.texture_utils import ndarray_to_tex
from pystencils.include import get_pystencils_include_path from pystencils.include import get_pystencils_include_path
from pystencils.interpolation_astnodes import TextureAccess from pystencils.interpolation_astnodes import InterpolationMode, TextureAccess
from pystencils.kernelparameters import FieldPointerSymbol from pystencils.kernelparameters import FieldPointerSymbol
USE_FAST_MATH = True USE_FAST_MATH = True
...@@ -46,7 +46,7 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen ...@@ -46,7 +46,7 @@ def make_python_function(kernel_function_node, argument_dict=None, custom_backen
if USE_FAST_MATH: if USE_FAST_MATH:
nvcc_options.append("-use_fast_math") nvcc_options.append("-use_fast_math")
if any(t.cubic_bspline_interpolation for t in textures): if any(t.interpolation_mode == InterpolationMode.CUBIC_SPLINE for t in textures):
assert isdir(join(dirname(__file__), "CubicInterpolationCUDA", "code")), \ assert isdir(join(dirname(__file__), "CubicInterpolationCUDA", "code")), \
"Submodule CubicInterpolationCUDA does not exist" "Submodule CubicInterpolationCUDA does not exist"
nvcc_options += ["-I" + join(dirname(__file__), "CubicInterpolationCUDA", "code")] nvcc_options += ["-I" + join(dirname(__file__), "CubicInterpolationCUDA", "code")]
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
""" """
import itertools import itertools
from enum import Enum
from typing import Set from typing import Set
import sympy as sp import sympy as sp
...@@ -24,7 +25,14 @@ except Exception: ...@@ -24,7 +25,14 @@ except Exception:
pass pass
class LinearInterpolator(object): class InterpolationMode(str, Enum):
NEAREST_NEIGHBOR = "nearest_neighbour"
NN = NEAREST_NEIGHBOR
LINEAR = "linear"
CUBIC_SPLINE = "cubic_spline"
class Interpolator(object):
""" """
Implements non-integer accesses on fields using linear interpolation. Implements non-integer accesses on fields using linear interpolation.
...@@ -58,7 +66,9 @@ class LinearInterpolator(object): ...@@ -58,7 +66,9 @@ class LinearInterpolator(object):
required_global_declarations = [] required_global_declarations = []
def __init__(self, parent_field: pystencils.Field, def __init__(self,
parent_field: pystencils.Field,
interpolation_mode: InterpolationMode,
address_mode='BORDER', address_mode='BORDER',
use_normalized_coordinates=False): use_normalized_coordinates=False):
super().__init__() super().__init__()
...@@ -72,6 +82,7 @@ class LinearInterpolator(object): ...@@ -72,6 +82,7 @@ class LinearInterpolator(object):
'dummy_symbol_carrying_field' + self.field.name + hash_str) 'dummy_symbol_carrying_field' + self.field.name + hash_str)
self.symbol.field = self.field self.symbol.field = self.field
self.symbol.interpolator = self self.symbol.interpolator = self
self.interpolation_mode = interpolation_mode
def at(self, offset): def at(self, offset):
return InterpolatorAccess(self.symbol, *offset) return InterpolatorAccess(self.symbol, *offset)
...@@ -93,6 +104,30 @@ class LinearInterpolator(object): ...@@ -93,6 +104,30 @@ class LinearInterpolator(object):
self.use_normalized_coordinates)) self.use_normalized_coordinates))
class LinearInterpolator(Interpolator):
    """Interpolator preconfigured with ``InterpolationMode.LINEAR``.

    Accepts the same arguments as ``Interpolator`` except for the
    interpolation mode, which is fixed here.

    Args:
        parent_field: field handed on to ``Interpolator.__init__``.
        address_mode: handed on unchanged; defaults to ``'BORDER'``.
        use_normalized_coordinates: handed on unchanged; defaults to ``False``.
    """

    def __init__(self,
                 parent_field: pystencils.Field,
                 address_mode='BORDER',
                 use_normalized_coordinates=False):
        # Only the mode is decided by this subclass; everything else is
        # forwarded to the generic base constructor.
        super().__init__(parent_field=parent_field,
                         interpolation_mode=InterpolationMode.LINEAR,
                         address_mode=address_mode,
                         use_normalized_coordinates=use_normalized_coordinates)
class NearestNeightborInterpolator(Interpolator):
    """Interpolator preconfigured with ``InterpolationMode.NEAREST_NEIGHBOR``.

    The class name contains a historical typo ("Neightbor"). It is kept so
    existing imports continue to work; new code should use the correctly
    spelled alias ``NearestNeighborInterpolator`` defined right below.

    Args:
        parent_field: field handed on to ``Interpolator.__init__``.
        address_mode: handed on unchanged; defaults to ``'BORDER'``.
        use_normalized_coordinates: handed on unchanged; defaults to ``False``.
    """

    def __init__(self,
                 parent_field: pystencils.Field,
                 address_mode='BORDER',
                 use_normalized_coordinates=False):
        # ``InterpolationMode.NN`` is an alias of ``NEAREST_NEIGHBOR``; the
        # canonical member name is used here for readability.
        super().__init__(parent_field,
                         InterpolationMode.NEAREST_NEIGHBOR,
                         address_mode,
                         use_normalized_coordinates)


# Correctly spelled public name; refers to the very same class object.
NearestNeighborInterpolator = NearestNeightborInterpolator
class InterpolatorAccess(TypedSymbol): class InterpolatorAccess(TypedSymbol):
def __new__(cls, field, offsets, *args, **kwargs): def __new__(cls, field, offsets, *args, **kwargs):
obj = TextureAccess.__xnew_cached_(cls, field, offsets, *args, **kwargs) obj = TextureAccess.__xnew_cached_(cls, field, offsets, *args, **kwargs)
...@@ -149,6 +184,10 @@ class InterpolatorAccess(TypedSymbol): ...@@ -149,6 +184,10 @@ class InterpolatorAccess(TypedSymbol):
def symbols_defined(self) -> Set[sp.Symbol]: def symbols_defined(self) -> Set[sp.Symbol]:
return {self} return {self}
@property
def interpolation_mode(self):
    """Return the ``interpolation_mode`` of the interpolator attached to this access."""
    owner = self.interpolator
    return owner.interpolation_mode
def implementation_with_stencils(self): def implementation_with_stencils(self):
field = self.field field = self.field
...@@ -165,56 +204,66 @@ class InterpolatorAccess(TypedSymbol): ...@@ -165,56 +204,66 @@ class InterpolatorAccess(TypedSymbol):
offsets = self.offsets offsets = self.offsets
rounding_functions = (sp.floor, lambda x: sp.floor(x) + 1) rounding_functions = (sp.floor, lambda x: sp.floor(x) + 1)
# TODO optimization: implement via lerp: https://devblogs.nvidia.com/lerp-faster-cuda/ for channel_idx in range(field.shape[0] if field.index_dimensions else 1):
for c in itertools.product(rounding_functions, repeat=field.spatial_dimensions): if self.interpolation_mode == InterpolationMode.NN:
weight = sp.Mul(*[1 - sp.Abs(f(offset) - offset) for (f, offset) in zip(c, offsets)])
index = [f(offset) for (f, offset) in zip(c, offsets)]
for channel_idx in range(field.shape[0] if field.index_dimensions else 1):
# Hardware boundary handling on GPU
if use_textures: if use_textures:
weight = sp.Mul(*[1 - sp.Abs(f(offset) - offset) for (f, offset) in zip(c, offsets)]) sum[channel_idx] = self
sum[channel_idx] += \
weight * absolute_access(index, channel_idx if field.index_dimensions else ())
# else boundary handling using software
elif str(self.interpolator.address_mode).lower() == 'border':
is_inside_field = sp.And(
*itertools.chain([i >= 0 for i in index],
[idx < field.shape[dim] for (dim, idx) in enumerate(index)]))
index = [cast_func(i, default_int_type) for i in index]
sum[channel_idx] += sp.Piecewise(
(weight * absolute_access(index, channel_idx if field.index_dimensions else ()),
is_inside_field),
(sp.simplify(0), True)
)
elif str(self.interpolator.address_mode).lower() == 'clamp':
index = [cast_func(sp.Min(sp.Max(0, i), field.shape[dim] - 1), default_int_type)
for (dim, i) in enumerate(index)]
sum[channel_idx] += weight * \
absolute_access(index, channel_idx if field.index_dimensions else ())
elif str(self.interpolator.address_mode).lower() == 'wrap':
index = [cast_func(sp.Piecewise((sp.Mod(i, field.shape[dim]), i >= 0),
(field.shape[dim] + sp.Mod(i, field.shape[dim]), True)),
default_int_type)
for (dim, i) in enumerate(index)]
sum[channel_idx] += weight * \
absolute_access(index, channel_idx if field.index_dimensions else ())
elif str(self.interpolator.address_mode).lower() == 'mirror':
def triangle_fun(x, half_period):
saw_tooth = sp.Abs(x) % (2 * half_period)
return sp.Piecewise((saw_tooth, saw_tooth < half_period),
(2 * half_period - 1 - saw_tooth, True))
index = [cast_func(triangle_fun(i, field.shape[dim]),
default_int_type) for (dim, i) in enumerate(index)]
sum[channel_idx] += weight * absolute_access(index, channel_idx if field.index_dimensions else ())
else: else:
raise NotImplementedError() sum[channel_idx] = absolute_access([sp.floor(i + 0.5) for i in offsets], channel_idx)
sum = [sp.factor(s) for s in sum]
if field.index_dimensions: elif self.interpolation_mode == InterpolationMode.LINEAR:
return sp.Matrix(sum) # TODO optimization: implement via lerp: https://devblogs.nvidia.com/lerp-faster-cuda/
else: for c in itertools.product(rounding_functions, repeat=field.spatial_dimensions):
return sum[0] weight = sp.Mul(*[1 - sp.Abs(f(offset) - offset) for (f, offset) in zip(c, offsets)])
index = [f(offset) for (f, offset) in zip(c, offsets)]
# Hardware boundary handling on GPU
if use_textures:
weight = sp.Mul(*[1 - sp.Abs(f(offset) - offset) for (f, offset) in zip(c, offsets)])
sum[channel_idx] += \
weight * absolute_access(index, channel_idx if field.index_dimensions else ())
# else boundary handling using software
elif str(self.interpolator.address_mode).lower() == 'border':
is_inside_field = sp.And(
*itertools.chain([i >= 0 for i in index],
[idx < field.shape[dim] for (dim, idx) in enumerate(index)]))
index = [cast_func(i, default_int_type) for i in index]
sum[channel_idx] += sp.Piecewise(
(weight * absolute_access(index, channel_idx if field.index_dimensions else ()),
is_inside_field),
(sp.simplify(0), True)
)
elif str(self.interpolator.address_mode).lower() == 'clamp':
index = [cast_func(sp.Min(sp.Max(0, i), field.shape[dim] - 1), default_int_type)
for (dim, i) in enumerate(index)]
sum[channel_idx] += weight * \
absolute_access(index, channel_idx if field.index_dimensions else ())
elif str(self.interpolator.address_mode).lower() == 'wrap':
index = [cast_func(sp.Piecewise((sp.Mod(i, field.shape[dim]), i >= 0),
(field.shape[dim] + sp.Mod(i, field.shape[dim]), True)),
default_int_type)
for (dim, i) in enumerate(index)]
sum[channel_idx] += weight * \
absolute_access(index, channel_idx if field.index_dimensions else ())
elif str(self.interpolator.address_mode).lower() == 'mirror':
def triangle_fun(x, half_period):
saw_tooth = sp.Abs(x) % (2 * half_period)
return sp.Piecewise((saw_tooth, saw_tooth < half_period),
(2 * half_period - 1 - saw_tooth, True))
index = [cast_func(triangle_fun(i, field.shape[dim]),
default_int_type) for (dim, i) in enumerate(index)]
sum[channel_idx] += weight * \
absolute_access(index, channel_idx if field.index_dimensions else ())
else:
raise NotImplementedError()
elif self.interpolation_mode == InterpolationMode.CUBIC_SPLINE:
raise NotImplementedError("only works with HW interpolation for float32")
sum = [sp.factor(s) for s in sum]
if field.index_dimensions:
return sp.Matrix(sum)
else:
return sum[0]
# noinspection SpellCheckingInspection # noinspection SpellCheckingInspection
__xnew__ = staticmethod(__new_stage2__) __xnew__ = staticmethod(__new_stage2__)
...@@ -231,9 +280,10 @@ class TextureCachedField: ...@@ -231,9 +280,10 @@ class TextureCachedField:
def __init__(self, parent_field, def __init__(self, parent_field,
address_mode=None, address_mode=None,
filter_mode=None, filter_mode=None,
interpolation_mode: InterpolationMode = InterpolationMode.LINEAR,
use_normalized_coordinates=False, use_normalized_coordinates=False,
read_as_integer=False, read_as_integer=False
cubic_bspline_interpolation=False): ):
if isinstance(address_mode, str): if isinstance(address_mode, str):
address_mode = getattr(pycuda.driver.address_mode, address_mode.upper()) address_mode = getattr(pycuda.driver.address_mode, address_mode.upper())
...@@ -252,14 +302,14 @@ class TextureCachedField: ...@@ -252,14 +302,14 @@ class TextureCachedField:
self.symbol = TypedSymbol(str(self), self.field.dtype.numpy_dtype) self.symbol = TypedSymbol(str(self), self.field.dtype.numpy_dtype)
self.symbol.interpolator = self self.symbol.interpolator = self
self.symbol.field = self.field self.symbol.field = self.field
self.cubic_bspline_interpolation = cubic_bspline_interpolation self.interpolation_mode = interpolation_mode
# assert str(self.field.dtype) != 'double', "CUDA does not support double textures!" # assert str(self.field.dtype) != 'double', "CUDA does not support double textures!"
# assert dtype_supports_textures(self.field.dtype), "CUDA only supports texture types with 32 bits or less" # assert dtype_supports_textures(self.field.dtype), "CUDA only supports texture types with 32 bits or less"
@classmethod @classmethod
def from_interpolator(cls, interpolator: LinearInterpolator): def from_interpolator(cls, interpolator: LinearInterpolator):
obj = cls(interpolator.field, interpolator.address_mode) obj = cls(interpolator.field, interpolator.address_mode, interpolation_mode=interpolator.interpolation_mode)
return obj return obj
def at(self, offset): def at(self, offset):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment