Skip to content
Snippets Groups Projects

WIP: Astnodes for interpolation

1 file
+ 103
53
Compare changes
  • Side-by-side
  • Inline
@@ -9,6 +9,7 @@
"""
import itertools
from enum import Enum
from typing import Set
import sympy as sp
@@ -24,7 +25,14 @@ except Exception:
pass
class LinearInterpolator(object):
class InterpolationMode(str, Enum):
    """Interpolation schemes an interpolator can be configured with.

    The enum mixes in ``str``, so every member compares equal to its plain
    string value (e.g. ``InterpolationMode.LINEAR == "linear"``).
    """

    NEAREST_NEIGHBOR = "nearest_neighbour"
    # Short alias: it shares the value above, so it resolves to the very same
    # member rather than creating a new one.
    NN = NEAREST_NEIGHBOR
    LINEAR = "linear"
    CUBIC_SPLINE = "cubic_spline"
class Interpolator(object):
"""
Implements non-integer accesses on fields using interpolation; the scheme is selected by ``interpolation_mode``.
@@ -58,7 +66,9 @@ class LinearInterpolator(object):
required_global_declarations = []
def __init__(self, parent_field: pystencils.Field,
def __init__(self,
parent_field: pystencils.Field,
interpolation_mode: InterpolationMode,
address_mode='BORDER',
use_normalized_coordinates=False):
super().__init__()
@@ -72,6 +82,7 @@ class LinearInterpolator(object):
'dummy_symbol_carrying_field' + self.field.name + hash_str)
self.symbol.field = self.field
self.symbol.interpolator = self
self.interpolation_mode = interpolation_mode
def at(self, offset):
    """Create an ``InterpolatorAccess`` on this interpolator's symbol.

    The components of ``offset`` are unpacked and passed as separate
    positional arguments after the symbol.
    """
    access = InterpolatorAccess(self.symbol, *offset)
    return access
@@ -93,6 +104,30 @@ class LinearInterpolator(object):
self.use_normalized_coordinates))
class LinearInterpolator(Interpolator):
    """Interpolator preset to ``InterpolationMode.LINEAR``.

    Args:
        parent_field: field handed on to ``Interpolator.__init__``.
        address_mode: forwarded unchanged; defaults to ``'BORDER'``.
        use_normalized_coordinates: forwarded unchanged; defaults to ``False``.
    """

    def __init__(self,
                 parent_field: pystencils.Field,
                 address_mode='BORDER',
                 use_normalized_coordinates=False):
        # Only the interpolation mode is fixed here; everything else is
        # passed straight through to the base class.
        super().__init__(parent_field=parent_field,
                         interpolation_mode=InterpolationMode.LINEAR,
                         address_mode=address_mode,
                         use_normalized_coordinates=use_normalized_coordinates)
class NearestNeighborInterpolator(Interpolator):
    """Interpolator preset to ``InterpolationMode.NN`` (nearest neighbour).

    Args:
        parent_field: field handed on to ``Interpolator.__init__``.
        address_mode: forwarded unchanged; defaults to ``'BORDER'``.
        use_normalized_coordinates: forwarded unchanged; defaults to ``False``.
    """

    def __init__(self,
                 parent_field: pystencils.Field,
                 address_mode='BORDER',
                 use_normalized_coordinates=False):
        # Only the interpolation mode is fixed here; everything else is
        # passed straight through to the base class.
        super().__init__(parent_field,
                         InterpolationMode.NN,
                         address_mode,
                         use_normalized_coordinates)


# Backward-compatible alias: the class was first published under this
# misspelled name, so existing imports must keep working.
NearestNeightborInterpolator = NearestNeighborInterpolator
class InterpolatorAccess(TypedSymbol):
def __new__(cls, field, offsets, *args, **kwargs):
obj = TextureAccess.__xnew_cached_(cls, field, offsets, *args, **kwargs)
@@ -149,6 +184,10 @@ class InterpolatorAccess(TypedSymbol):
def symbols_defined(self) -> Set[sp.Symbol]:
return {self}
@property
def interpolation_mode(self):
    """Return the ``interpolation_mode`` of this access's ``interpolator``."""
    owner = self.interpolator
    return owner.interpolation_mode
def implementation_with_stencils(self):
field = self.field
@@ -165,56 +204,66 @@ class InterpolatorAccess(TypedSymbol):
offsets = self.offsets
rounding_functions = (sp.floor, lambda x: sp.floor(x) + 1)
# TODO optimization: implement via lerp: https://devblogs.nvidia.com/lerp-faster-cuda/
for c in itertools.product(rounding_functions, repeat=field.spatial_dimensions):
weight = sp.Mul(*[1 - sp.Abs(f(offset) - offset) for (f, offset) in zip(c, offsets)])
index = [f(offset) for (f, offset) in zip(c, offsets)]
for channel_idx in range(field.shape[0] if field.index_dimensions else 1):
# Hardware boundary handling on GPU
for channel_idx in range(field.shape[0] if field.index_dimensions else 1):
if self.interpolation_mode == InterpolationMode.NN:
if use_textures:
weight = sp.Mul(*[1 - sp.Abs(f(offset) - offset) for (f, offset) in zip(c, offsets)])
sum[channel_idx] += \
weight * absolute_access(index, channel_idx if field.index_dimensions else ())
# else boundary handling using software
elif str(self.interpolator.address_mode).lower() == 'border':
is_inside_field = sp.And(
*itertools.chain([i >= 0 for i in index],
[idx < field.shape[dim] for (dim, idx) in enumerate(index)]))
index = [cast_func(i, default_int_type) for i in index]
sum[channel_idx] += sp.Piecewise(
(weight * absolute_access(index, channel_idx if field.index_dimensions else ()),
is_inside_field),
(sp.simplify(0), True)
)
elif str(self.interpolator.address_mode).lower() == 'clamp':
index = [cast_func(sp.Min(sp.Max(0, i), field.shape[dim] - 1), default_int_type)
for (dim, i) in enumerate(index)]
sum[channel_idx] += weight * \
absolute_access(index, channel_idx if field.index_dimensions else ())
elif str(self.interpolator.address_mode).lower() == 'wrap':
index = [cast_func(sp.Piecewise((sp.Mod(i, field.shape[dim]), i >= 0),
(field.shape[dim] + sp.Mod(i, field.shape[dim]), True)),
default_int_type)
for (dim, i) in enumerate(index)]
sum[channel_idx] += weight * \
absolute_access(index, channel_idx if field.index_dimensions else ())
elif str(self.interpolator.address_mode).lower() == 'mirror':
def triangle_fun(x, half_period):
saw_tooth = sp.Abs(x) % (2 * half_period)
return sp.Piecewise((saw_tooth, saw_tooth < half_period),
(2 * half_period - 1 - saw_tooth, True))
index = [cast_func(triangle_fun(i, field.shape[dim]),
default_int_type) for (dim, i) in enumerate(index)]
sum[channel_idx] += weight * absolute_access(index, channel_idx if field.index_dimensions else ())
sum[channel_idx] = self
else:
raise NotImplementedError()
sum = [sp.factor(s) for s in sum]
sum[channel_idx] = absolute_access([sp.floor(i + 0.5) for i in offsets], channel_idx)
if field.index_dimensions:
return sp.Matrix(sum)
else:
return sum[0]
elif self.interpolation_mode == InterpolationMode.LINEAR:
# TODO optimization: implement via lerp: https://devblogs.nvidia.com/lerp-faster-cuda/
for c in itertools.product(rounding_functions, repeat=field.spatial_dimensions):
weight = sp.Mul(*[1 - sp.Abs(f(offset) - offset) for (f, offset) in zip(c, offsets)])
index = [f(offset) for (f, offset) in zip(c, offsets)]
# Hardware boundary handling on GPU
if use_textures:
weight = sp.Mul(*[1 - sp.Abs(f(offset) - offset) for (f, offset) in zip(c, offsets)])
sum[channel_idx] += \
weight * absolute_access(index, channel_idx if field.index_dimensions else ())
# else boundary handling using software
elif str(self.interpolator.address_mode).lower() == 'border':
is_inside_field = sp.And(
*itertools.chain([i >= 0 for i in index],
[idx < field.shape[dim] for (dim, idx) in enumerate(index)]))
index = [cast_func(i, default_int_type) for i in index]
sum[channel_idx] += sp.Piecewise(
(weight * absolute_access(index, channel_idx if field.index_dimensions else ()),
is_inside_field),
(sp.simplify(0), True)
)
elif str(self.interpolator.address_mode).lower() == 'clamp':
index = [cast_func(sp.Min(sp.Max(0, i), field.shape[dim] - 1), default_int_type)
for (dim, i) in enumerate(index)]
sum[channel_idx] += weight * \
absolute_access(index, channel_idx if field.index_dimensions else ())
elif str(self.interpolator.address_mode).lower() == 'wrap':
index = [cast_func(sp.Piecewise((sp.Mod(i, field.shape[dim]), i >= 0),
(field.shape[dim] + sp.Mod(i, field.shape[dim]), True)),
default_int_type)
for (dim, i) in enumerate(index)]
sum[channel_idx] += weight * \
absolute_access(index, channel_idx if field.index_dimensions else ())
elif str(self.interpolator.address_mode).lower() == 'mirror':
def triangle_fun(x, half_period):
saw_tooth = sp.Abs(x) % (2 * half_period)
return sp.Piecewise((saw_tooth, saw_tooth < half_period),
(2 * half_period - 1 - saw_tooth, True))
index = [cast_func(triangle_fun(i, field.shape[dim]),
default_int_type) for (dim, i) in enumerate(index)]
sum[channel_idx] += weight * \
absolute_access(index, channel_idx if field.index_dimensions else ())
else:
raise NotImplementedError()
elif self.interpolation_mode == InterpolationMode.CUBIC_SPLINE:
raise NotImplementedError("only works with HW interpolation for float32")
sum = [sp.factor(s) for s in sum]
if field.index_dimensions:
return sp.Matrix(sum)
else:
return sum[0]
# noinspection SpellCheckingInspection
__xnew__ = staticmethod(__new_stage2__)
@@ -231,9 +280,10 @@ class TextureCachedField:
def __init__(self, parent_field,
address_mode=None,
filter_mode=None,
interpolation_mode: InterpolationMode = InterpolationMode.LINEAR,
use_normalized_coordinates=False,
read_as_integer=False,
cubic_bspline_interpolation=False):
read_as_integer=False
):
if isinstance(address_mode, str):
address_mode = getattr(pycuda.driver.address_mode, address_mode.upper())
@@ -252,14 +302,14 @@ class TextureCachedField:
self.symbol = TypedSymbol(str(self), self.field.dtype.numpy_dtype)
self.symbol.interpolator = self
self.symbol.field = self.field
self.cubic_bspline_interpolation = cubic_bspline_interpolation
self.interpolation_mode = interpolation_mode
# assert str(self.field.dtype) != 'double', "CUDA does not support double textures!"
# assert dtype_supports_textures(self.field.dtype), "CUDA only supports texture types with 32 bits or less"
@classmethod
def from_interpolator(cls, interpolator: Interpolator):
    """Build an instance of ``cls`` that mirrors an existing interpolator.

    Args:
        interpolator: source object; its ``field``, ``address_mode`` and
            ``interpolation_mode`` attributes are forwarded to the
            constructor.

    Returns:
        The newly constructed ``cls`` instance.
    """
    # Construct exactly once and carry the interpolation mode along, so no
    # throw-away instance is created without it.
    return cls(interpolator.field,
               interpolator.address_mode,
               interpolation_mode=interpolator.interpolation_mode)
def at(self, offset):
Loading