diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py index df11dabddac2702c1508f6d1a5f2bf124a2ddae3..4b0b8300e5665b5e9f55fd4a76312394bdc6f9ce 100644 --- a/src/pystencils_autodiff/framework_integration/astnodes.py +++ b/src/pystencils_autodiff/framework_integration/astnodes.py @@ -328,3 +328,20 @@ class DynamicFunction(sp.Function): def __repr__(self): return self.__str__() + + +class MeshNormalFunctor(DynamicFunction): + def __new__(cls, mesh_name, base_dtype, *args): + from pystencils.data_types import TypedMatrixSymbol + + A = TypedMatrixSymbol('A', 3, 1, base_dtype, 'Vector3<real_t>') + obj = DynamicFunction.__new__(cls, + TypedSymbol(str(mesh_name), + 'std::function<Vector3<real_t>(int, int, int)>'), + A.dtype, + *args) + obj.mesh_name = mesh_name + return obj + + def __getnewargs__(self): + return self.mesh_name, self.dtype.base_dtype, self.args[2:]