From 59caa88afd5e04ee9d338a8ee5b9d4db8e86c1e4 Mon Sep 17 00:00:00 2001
From: Alexander Reinauer <areinauer@icp.uni-stuttgart.de>
Date: Thu, 2 Apr 2020 12:13:50 +0200
Subject: [PATCH] NoFlux Boundary condition of Michael

---
 pystencils/boundaries/boundaryconditions.py | 42 +++++++++++++++++++++
 1 file changed, 42 insertions(+)

diff --git a/pystencils/boundaries/boundaryconditions.py b/pystencils/boundaries/boundaryconditions.py
index 39338634..c1b8f492 100644
--- a/pystencils/boundaries/boundaryconditions.py
+++ b/pystencils/boundaries/boundaryconditions.py
@@ -1,5 +1,8 @@
 from typing import Any, List, Tuple
 
+import pystencils as ps
+import sympy as sp
+import numpy as np
 from pystencils import Assignment
 from pystencils.boundaries.boundaryhandling import BoundaryOffsetInfo
 from pystencils.data_types import create_type
@@ -111,3 +114,42 @@ class Dirichlet(Boundary):
             assert len(self._value) == field.index_shape[0], "Dirichlet value does not match index shape of field"
             return [Assignment(field(i), self._value[i]) for i in range(field.index_shape[0])]
         raise NotImplementedError("Dirichlet boundary not implemented for fields with more than one index dimension")
+
+
+class NoFlux(Boundary):
+    inner_or_boundary = True  # call the boundary condition with the fluid cell
+    single_link = False  # needs to be called for all directional fluxes
+
+    def __init__(self, stencil):
+        self.stencil = stencil
+
+    def __call__(self, field, direction_symbol, **kwargs):
+        assert ps.FieldType.is_staggered(field)
+
+        assert all([s == 0 for s in self.stencil[0]])
+        accesses = [field.staggered_vector_access(ps.stencil.offset_to_direction_string(d))
+                    for d in self.stencil[1:]]
+        conds = [sp.Equality(direction_symbol, d+1) for d in range(len(accesses))]
+
+        val = sp.Matrix(np.zeros(accesses[0].shape, dtype=int))
+
+        # use conditional
+        conditional = None
+        for a, c in zip(accesses, conds):
+            assignments = []
+            for i in range(len(a)):
+                if type(a[i]) is sp.Mul and a[i].args[0] == -1:
+                    continue  # this will be written by the neighboring cell
+                else:
+                    assignments.append(ps.Assignment(a[i], val[i]))
+            if len(assignments) > 0:
+                conditional = ps.astnodes.Conditional(ps.data_types.type_all_numbers(c, "int"),
+                                                      ps.astnodes.Block(assignments),
+                                                      conditional)
+        return [conditional]
+
+    def __hash__(self):
+        return hash((NoFlux, self.stencil))
+
+    def __eq__(self, other):
+        return type(other) == NoFlux and other.stencil == self.stencil
-- 
GitLab