From 6be6ba63ba21fcc5d713cdcb81add75fea7154cd Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Wed, 15 Jan 2020 19:29:39 +0100
Subject: [PATCH] Making cool stuff with interpolators

---
 pystencils/backends/cuda_backend.py  |  2 +-
 pystencils/fd/finitedifferences.py   |  4 +++-
 pystencils/interpolation_astnodes.py | 18 +++++++++++++++++-
 3 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/pystencils/backends/cuda_backend.py b/pystencils/backends/cuda_backend.py
index 9797bc7da..d590d65b4 100644
--- a/pystencils/backends/cuda_backend.py
+++ b/pystencils/backends/cuda_backend.py
@@ -75,7 +75,7 @@ class CudaSympyPrinter(CustomSympyPrinter):
 
         if type(node) == DiffInterpolatorAccess:
             # cubicTex3D_1st_derivative_x(texture tex, float3 coord)
-            template = f"cubicTex%iD_1st_derivative_{'xyz'[node.diff_coordinate_idx]}(%s, %s)"
+            template = f"cubicTex%iD_1st_derivative_{list(reversed('xyz'[:node.ndim]))[node.diff_coordinate_idx]}(%s, %s)"  # noqa
         elif node.interpolator.interpolation_mode == InterpolationMode.CUBIC_SPLINE:
             template = "cubicTex%iDSimple(%s, %s)"
         else:
diff --git a/pystencils/fd/finitedifferences.py b/pystencils/fd/finitedifferences.py
index d5bce66e9..5b6b15f95 100644
--- a/pystencils/fd/finitedifferences.py
+++ b/pystencils/fd/finitedifferences.py
@@ -109,7 +109,9 @@ class Discretization2ndOrder:
             return self._discretize_advection(e)
         elif isinstance(e, Diff):
             arg, *indices = diff_args(e)
-            if not isinstance(arg, Field.Access):
+            from pystencils.interpolation_astnodes import InterpolatorAccess
+
+            if not isinstance(arg, (Field.Access, InterpolatorAccess)):
                 raise ValueError("Only derivatives with field or field accesses as arguments can be discretized")
             return self.spatial_stencil(indices, self.dx, arg)
         else:
diff --git a/pystencils/interpolation_astnodes.py b/pystencils/interpolation_astnodes.py
index 28d45c2dc..3ecc2a70a 100644
--- a/pystencils/interpolation_astnodes.py
+++ b/pystencils/interpolation_astnodes.py
@@ -170,6 +170,14 @@ class InterpolatorAccess(TypedSymbol):
     def __repr__(self):
         return self.__str__()
 
+    @property
+    def ndim(self):
+        return len(self.offsets)
+
+    @property
+    def is_texture(self):
+        return isinstance(self.interpolator, TextureCachedField)
+
     def atoms(self, *types):
         if self.offsets:
             offsets = set(o for o in self.offsets if isinstance(o, types))
@@ -182,6 +190,11 @@ class InterpolatorAccess(TypedSymbol):
         else:
             return set()
 
+    def neighbor(self, coord_id, offset):
+        offset_list = list(self.offsets)
+        offset_list[coord_id] += offset
+        return self.interpolator.at(tuple(offset_list))
+
     @property
     def free_symbols(self):
         symbols = set()
@@ -318,6 +331,9 @@ class InterpolatorAccess(TypedSymbol):
 
 class DiffInterpolatorAccess(InterpolatorAccess):
     def __new__(cls, symbol, diff_coordinate_idx, *offsets, **kwargs):
+        if symbol.interpolator.interpolation_mode == InterpolationMode.LINEAR:
+            from pystencils.fd import Diff, Discretization2ndOrder
+            return Discretization2ndOrder(1)(Diff(symbol.interpolator.at(offsets), diff_coordinate_idx))
         obj = DiffInterpolatorAccess.__xnew_cached_(cls, symbol, diff_coordinate_idx, *offsets, **kwargs)
         return obj
 
@@ -363,7 +379,7 @@ class DiffInterpolatorAccess(InterpolatorAccess):
 ##########################################################################################
 
 
-class TextureCachedField:
+class TextureCachedField(Interpolator):
 
     def __init__(self, parent_field,
                  address_mode=None,
-- 
GitLab