Skip to content
Snippets Groups Projects
Commit 6be6ba63 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Making cool stuff with interpolators

parent 064147b6
No related branches found
No related tags found
1 merge request!129Interpolation refactoring
......@@ -75,7 +75,7 @@ class CudaSympyPrinter(CustomSympyPrinter):
if type(node) == DiffInterpolatorAccess:
# cubicTex3D_1st_derivative_x(texture tex, float3 coord)
template = f"cubicTex%iD_1st_derivative_{'xyz'[node.diff_coordinate_idx]}(%s, %s)"
template = f"cubicTex%iD_1st_derivative_{list(reversed('xyz'[:node.ndim]))[node.diff_coordinate_idx]}(%s, %s)" # noqa
elif node.interpolator.interpolation_mode == InterpolationMode.CUBIC_SPLINE:
template = "cubicTex%iDSimple(%s, %s)"
else:
......
......@@ -109,7 +109,9 @@ class Discretization2ndOrder:
return self._discretize_advection(e)
elif isinstance(e, Diff):
arg, *indices = diff_args(e)
if not isinstance(arg, Field.Access):
from pystencils.interpolation_astnodes import InterpolatorAccess
if not isinstance(arg, (Field.Access, InterpolatorAccess)):
raise ValueError("Only derivatives with field or field accesses as arguments can be discretized")
return self.spatial_stencil(indices, self.dx, arg)
else:
......
......@@ -170,6 +170,14 @@ class InterpolatorAccess(TypedSymbol):
def __repr__(self):
return self.__str__()
@property
def ndim(self):
return len(self.offsets)
@property
def is_texture(self):
return isinstance(self.interpolator, TextureCachedField)
def atoms(self, *types):
if self.offsets:
offsets = set(o for o in self.offsets if isinstance(o, types))
......@@ -182,6 +190,11 @@ class InterpolatorAccess(TypedSymbol):
else:
return set()
def neighbor(self, coord_id, offset):
offset_list = list(self.offsets)
offset_list[coord_id] += offset
return self.interpolator.at(tuple(offset_list))
@property
def free_symbols(self):
symbols = set()
......@@ -318,6 +331,9 @@ class InterpolatorAccess(TypedSymbol):
class DiffInterpolatorAccess(InterpolatorAccess):
def __new__(cls, symbol, diff_coordinate_idx, *offsets, **kwargs):
if symbol.interpolator.interpolation_mode == InterpolationMode.LINEAR:
from pystencils.fd import Diff, Discretization2ndOrder
return Discretization2ndOrder(1)(Diff(symbol.interpolator.at(offsets), diff_coordinate_idx))
obj = DiffInterpolatorAccess.__xnew_cached_(cls, symbol, diff_coordinate_idx, *offsets, **kwargs)
return obj
......@@ -363,7 +379,7 @@ class DiffInterpolatorAccess(InterpolatorAccess):
##########################################################################################
class TextureCachedField:
class TextureCachedField(Interpolator):
def __init__(self, parent_field,
address_mode=None,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment