From e354ca10df56225799aad08080706df277fd4a5e Mon Sep 17 00:00:00 2001
From: Martin Bauer <martin.bauer@fau.de>
Date: Sun, 6 Oct 2019 19:48:01 +0200
Subject: [PATCH] Support for MPI-dtype-based communication

---
 pystencils_walberla/__init__.py               |  5 +-
 pystencils_walberla/codegen.py                | 56 ++++++++++++-
 pystencils_walberla/jinja_filters.py          |  2 +-
 .../templates/CpuPackInfo.tmpl.cpp            |  1 -
 .../templates/MpiDtypeInfo.tmpl.h             | 84 +++++++++++++++++++
 5 files changed, 143 insertions(+), 5 deletions(-)
 create mode 100644 pystencils_walberla/templates/MpiDtypeInfo.tmpl.h

diff --git a/pystencils_walberla/__init__.py b/pystencils_walberla/__init__.py
index b8f06f3..767581a 100644
--- a/pystencils_walberla/__init__.py
+++ b/pystencils_walberla/__init__.py
@@ -1,7 +1,8 @@
 from .cmake_integration import CodeGeneration
 from .codegen import (
     generate_pack_info, generate_pack_info_for_field, generate_pack_info_from_kernel,
-    generate_sweep)
+    generate_mpidtype_info_from_kernel, generate_sweep)
 
 __all__ = ['CodeGeneration',
-           'generate_sweep', 'generate_pack_info_from_kernel', 'generate_pack_info_for_field', 'generate_pack_info']
+           'generate_sweep', 'generate_pack_info_from_kernel', 'generate_pack_info_for_field', 'generate_pack_info',
+           'generate_mpidtype_info_from_kernel']
diff --git a/pystencils_walberla/codegen.py b/pystencils_walberla/codegen.py
index e47861d..e1f9623 100644
--- a/pystencils_walberla/codegen.py
+++ b/pystencils_walberla/codegen.py
@@ -12,7 +12,7 @@ from pystencils.stencil import inverse_direction, offset_to_direction_string
 from pystencils_walberla.jinja_filters import add_pystencils_filters_to_jinja_env
 
 __all__ = ['generate_sweep', 'generate_pack_info', 'generate_pack_info_for_field', 'generate_pack_info_from_kernel',
-           'default_create_kernel_parameters', 'KernelInfo']
+           'generate_mpidtype_info_from_kernel', 'default_create_kernel_parameters', 'KernelInfo']
 
 
 def generate_sweep(generation_context, class_name, assignments,
@@ -249,6 +249,60 @@ def generate_pack_info(generation_context, class_name: str,
     generation_context.write_file("{}.{}".format(class_name, source_extension), source)
 
 
+def generate_mpidtype_info_from_kernel(generation_context, class_name: str,
+                                       assignments: Sequence[Assignment], kind='pull',  namespace='pystencils',):
+    """Render ``MpiDtypeInfo.tmpl.h`` for a kernel and write the result to ``<class_name>.h``.
+
+    The field accesses of ``assignments`` determine, per communication direction, which values of the
+    single index dimension are handed to the template as ``spec``.
+
+    Args:
+        generation_context: object whose ``write_file(name, content)`` receives the rendered header
+        class_name: name of the generated class, also used as the header file name
+        assignments: sequence of assignments or an AssignmentCollection (its ``all_assignments`` are used)
+        kind: 'pull' inspects right-hand-side accesses and inverts their offsets,
+              'push' inspects left-hand-side accesses and uses their offsets as they are
+        namespace: passed to the template as ``namespace``
+    """
+    assert kind in ('push', 'pull')
+    if isinstance(assignments, AssignmentCollection):
+        assignments = assignments.all_assignments
+
+    rhs_accesses, lhs_accesses = set(), set()
+    for assignment in assignments:
+        rhs_accesses.update(assignment.rhs.atoms(Field.Access))
+        lhs_accesses.update(assignment.lhs.atoms(Field.Access))
+
+    # choose which accesses are inspected and how an access offset maps to a communication direction
+    if kind == 'pull':
+        accesses, to_comm_offset = rhs_accesses, inverse_direction
+    elif kind == 'push':
+        accesses, to_comm_offset = lhs_accesses, (lambda offsets: offsets)
+    else:
+        raise ValueError("Invalid 'kind' parameter")
+
+    accessed_fields = {access.field for access in accesses}
+    assert len(accessed_fields) == 1, "Only scenarios where one fields neighbors are accessed"
+    field = accessed_fields.pop()
+
+    spec = defaultdict(set)
+    for access in accesses:
+        assert all(abs(e) <= 1 for e in access.offsets)
+        if all(offset == 0 for offset in access.offsets):
+            # accesses with an all-zero offset add nothing to spec
+            continue
+        for comm_dir in comm_directions(to_comm_offset(access.offsets)):
+            assert len(access.index) == 1, "Supports only fields with a single index dimension"
+            spec[(offset_to_direction_string(comm_dir),)].add(access.index[0])
+
+    template_values = {
+        'class_name': class_name, 'namespace': namespace, 'kind': kind,
+        'field_name': field.name, 'f_size': field.index_shape[0], 'spec': spec,
+    }
+    template = Environment(loader=PackageLoader('pystencils_walberla')).get_template("MpiDtypeInfo.tmpl.h")
+    generation_context.write_file("{}.h".format(class_name), template.render(**template_values))
+
+
 # ---------------------------------- Internal --------------------------------------------------------------------------
 
 
diff --git a/pystencils_walberla/jinja_filters.py b/pystencils_walberla/jinja_filters.py
index 7cc2d25..174b9df 100644
--- a/pystencils_walberla/jinja_filters.py
+++ b/pystencils_walberla/jinja_filters.py
@@ -117,7 +117,7 @@ def field_extraction_code(field, is_temporary, declaration_only=False,
             return "%s * %s;" % (field_type, field_name)
         else:
             prefix = "" if no_declaration else "auto "
-            return "%s%s = block->getData< %s >(%sID);" % (prefix, field_name, field_type, field_name)
+            return "%s%s = block->uncheckedFastGetData< %s >(%sID);" % (prefix, field_name, field_type, field_name)
     else:
         assert field_name.endswith('_tmp')
         original_field_name = field_name[:-len('_tmp')]
diff --git a/pystencils_walberla/templates/CpuPackInfo.tmpl.cpp b/pystencils_walberla/templates/CpuPackInfo.tmpl.cpp
index c47e913..24fd293 100644
--- a/pystencils_walberla/templates/CpuPackInfo.tmpl.cpp
+++ b/pystencils_walberla/templates/CpuPackInfo.tmpl.cpp
@@ -1,6 +1,5 @@
 #include "stencil/Directions.h"
 #include "core/cell/CellInterval.h"
-#include "cuda/GPUField.h"
 #include "core/DataTypes.h"
 #include "{{class_name}}.h"
 
diff --git a/pystencils_walberla/templates/MpiDtypeInfo.tmpl.h b/pystencils_walberla/templates/MpiDtypeInfo.tmpl.h
new file mode 100644
index 0000000..3f9cbb2
--- /dev/null
+++ b/pystencils_walberla/templates/MpiDtypeInfo.tmpl.h
@@ -0,0 +1,84 @@
+#pragma once
+
+#include "core/debug/Debug.h"
+#include "communication/UniformMPIDatatypeInfo.h"
+#include "field/communication/MPIDatatypes.h"
+
+#include <set>
+
+namespace walberla {
+namespace {{namespace}} {
+// Generated UniformMPIDatatypeInfo implementation for the GhostLayerField_T stored under a single BlockDataID.
+class {{class_name}} : public ::walberla::communication::UniformMPIDatatypeInfo
+{
+public:
+    using GhostLayerField_T = GhostLayerField<real_t, {{f_size}}>;  // value type is fixed to real_t
+
+    {{class_name}}( BlockDataID {{field_name}} )
+        :{{field_name}}_({{field_name}})  // kept for the block data lookup in getField()
+    {}
+    virtual ~{{class_name}}() {}
+
+    virtual shared_ptr<mpi::Datatype> getSendDatatype ( IBlock * block, const stencil::Direction dir )
+    {   // exactly one branch is rendered, chosen by the generator's 'kind' value
+       {% if kind == 'pull' %}
+        return make_shared<mpi::Datatype>( field::communication::mpiDatatypeSliceBeforeGhostlayerXYZ(
+                *getField( block ), dir, uint_t( 1 ), getOptimizedCommunicationIndices( dir ), false ) );
+       {% else %}
+        return make_shared<mpi::Datatype>( field::communication::mpiDatatypeGhostLayerOnlyXYZ(
+                *getField( block ), dir, false, getOptimizedCommunicationIndices( dir ) ) );
+       {% endif %}
+    }
+
+    virtual shared_ptr<mpi::Datatype> getRecvDatatype ( IBlock * block, const stencil::Direction dir )
+    {   // same helpers as getSendDatatype but swapped per kind; index set is looked up for the inverse direction
+        {% if kind == 'pull' %}
+        return make_shared<mpi::Datatype>( field::communication::mpiDatatypeGhostLayerOnlyXYZ(
+                *getField( block ), dir, false, getOptimizedCommunicationIndices( stencil::inverseDir[dir] ) ) );
+        {% else %}
+        return make_shared<mpi::Datatype>( field::communication::mpiDatatypeSliceBeforeGhostlayerXYZ(
+                *getField( block ), dir, uint_t( 1 ), getOptimizedCommunicationIndices( stencil::inverseDir[dir] ), false ) );
+        {% endif %}
+    }
+
+    virtual void * getSendPointer( IBlock * block, const stencil::Direction ) {
+        return getField(block)->data();  // direction argument is unused; always the field's data() pointer
+    }
+
+    virtual void * getRecvPointer( IBlock * block, const stencil::Direction ) {
+        return getField(block)->data();  // direction argument is unused; always the field's data() pointer
+    }
+
+private:
+
+    inline static std::set< cell_idx_t > getOptimizedCommunicationIndices( const stencil::Direction dir )
+    {   // per-direction index sets collected by the code generator; one case group per entry
+        switch(dir)
+        {
+            {%- for direction_set, index_set in spec.items()  %}
+            {%- for dir in direction_set %}
+            case stencil::{{dir}}:
+            {%- endfor %}
+               return {{index_set}};
+            {% endfor %}
+            default:
+                WALBERLA_ASSERT(false);  // direction has no entry in the generated table
+                return {};  // fallback result, presumably only reached when the assertion is compiled out
+        }
+    }
+
+    GhostLayerField_T * getField( IBlock * block )
+    {   // looks the field up on the given block and asserts that the result is not null
+        GhostLayerField_T * const f = block->getData<GhostLayerField_T>( {{field_name}}_ );
+        WALBERLA_ASSERT_NOT_NULLPTR( f );
+        return f;
+    }
+
+    BlockDataID {{field_name}}_;  // id handed to the constructor
+};
+
+
+} // namespace {{namespace}}
+} // namespace walberla
+
+
-- 
GitLab