diff --git a/pystencils/boundaries/boundaryconditions.py b/pystencils/boundaries/boundaryconditions.py index 39338634a51cae5bfc782efaae2ac69ca523152e..c1b8f4927f1dd74bd03a9855bed0d7c683d4ab5f 100644 --- a/pystencils/boundaries/boundaryconditions.py +++ b/pystencils/boundaries/boundaryconditions.py @@ -1,5 +1,8 @@ from typing import Any, List, Tuple +import pystencils as ps +import sympy as sp +import numpy as np from pystencils import Assignment from pystencils.boundaries.boundaryhandling import BoundaryOffsetInfo from pystencils.data_types import create_type @@ -111,3 +114,42 @@ class Dirichlet(Boundary): assert len(self._value) == field.index_shape[0], "Dirichlet value does not match index shape of field" return [Assignment(field(i), self._value[i]) for i in range(field.index_shape[0])] raise NotImplementedError("Dirichlet boundary not implemented for fields with more than one index dimension") + + +class NoFlux(Boundary): + inner_or_boundary = True # call the boundary condition with the fluid cell + single_link = False # needs to be called for all directional fluxes + + def __init__(self, stencil): + self.stencil = stencil + + def __call__(self, field, direction_symbol, **kwargs): + assert ps.FieldType.is_staggered(field) + + assert all([s == 0 for s in self.stencil[0]]) + accesses = [field.staggered_vector_access(ps.stencil.offset_to_direction_string(d)) + for d in self.stencil[1:]] + conds = [sp.Equality(direction_symbol, d+1) for d in range(len(accesses))] + + val = sp.Matrix(np.zeros(accesses[0].shape, dtype=int)) + + # use conditional + conditional = None + for a, c in zip(accesses, conds): + assignments = [] + for i in range(len(a)): + if type(a[i]) is sp.Mul and a[i].args[0] == -1: + continue # this will be written by the neighboring cell + else: + assignments.append(ps.Assignment(a[i], val[i])) + if len(assignments) > 0: + conditional = ps.astnodes.Conditional(ps.data_types.type_all_numbers(c, "int"), + ps.astnodes.Block(assignments), + conditional) + return [conditional] + + def __hash__(self): + return hash((NoFlux, self.stencil)) + + def __eq__(self, other): + return type(other) == NoFlux and other.stencil == self.stencil