Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/05/07 21:38:12 UTC

[GitHub] eric-haibin-lin closed pull request #8915: [MXNET-331] NVLink communication pattern updated

URL: https://github.com/apache/incubator-mxnet/pull/8915

This is a PR merged from a forked repository. As GitHub hides the original
diff on merge, it is displayed below for the sake of provenance.
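Before the full diff, here is a minimal, standalone sketch of the grouping rule
the patch adds in InitMergeBuffer: GPUs whose dev_id is below NVLINK_SUPPORT
(defined as 4) form the NVLink-connected group, the remaining GPUs form a second
group, and one GPU pair bridging the two groups is chosen to host the merge and
stage buffers. This is an illustration only, not MXNet code; plain ints stand in
for mxnet::Context, kNvlinkSupport is a local stand-in for the NVLINK_SUPPORT
macro, and the eight device ids are a hypothetical machine layout.

// Standalone illustration of the device grouping (assumptions noted above).
#include <cstdio>
#include <cstddef>
#include <vector>

constexpr int kNvlinkSupport = 4;  // mirrors the NVLINK_SUPPORT macro

int main() {
  std::vector<int> devs = {0, 1, 2, 3, 4, 5, 6, 7};  // hypothetical 8-GPU box
  std::vector<int> group1, group2;
  for (int d : devs)
    (d < kNvlinkSupport ? group1 : group2).push_back(d);

  // Pick the bridge pair: the first (g1, g2) with g2 - g1 == kNvlinkSupport,
  // i.e. the two GPUs assumed to share a link across the groups; fall back
  // to (group1[0], group2[0]) if no such pair exists.
  std::size_t gpu0 = 0, gpu1 = 0;
  while (gpu0 < group1.size() && gpu1 < group2.size()) {
    const int gap = group2[gpu1] - group1[gpu0];
    if (gap == kNvlinkSupport) break;
    if (gap > kNvlinkSupport) ++gpu0; else ++gpu1;
  }
  if (gpu0 == group1.size() || gpu1 == group2.size()) gpu0 = gpu1 = 0;

  std::printf("merge buffer lives on GPU %d, stage buffer on GPU %d\n",
              group1[gpu0], group2[gpu1]);
  return 0;
}

With eight GPUs this prints "merge buffer lives on GPU 0, stage buffer on
GPU 4". When every dev_id falls on one side of the threshold, one of the groups
is empty and the patch keeps the original all-to-all buffer placement.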

diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h
index da2d03d519f..c007a78aefa 100644
--- a/src/kvstore/comm.h
+++ b/src/kvstore/comm.h
@@ -22,19 +22,20 @@
  */
 #ifndef MXNET_KVSTORE_COMM_H_
 #define MXNET_KVSTORE_COMM_H_
+#define NVLINK_SUPPORT 4
 #include <dmlc/omp.h>
-#include <string>
 #include <algorithm>
-#include <utility>
 #include <limits>
-#include <vector>
-#include <tuple>
+#include <string>
 #include <thread>
-#include "mxnet/ndarray.h"
-#include "gradient_compression.h"
+#include <tuple>
+#include <utility>
+#include <vector>
 #include "../ndarray/ndarray_function.h"
 #include "../operator/tensor/sparse_retain-inl.h"
 #include "./kvstore_utils.h"
+#include "gradient_compression.h"
+#include "mxnet/ndarray.h"
 namespace mxnet {
 namespace kvstore {
 /**
@@ -42,10 +43,8 @@ namespace kvstore {
  */
 class Comm {
  public:
-  Comm() {
-    pinned_ctx_ = Context::CPUPinned(0);
-  }
-  virtual ~Comm() { }
+  Comm() { pinned_ctx_ = Context::CPUPinned(0); }
+  virtual ~Comm() {}
   /**
    * \brief init key with the data shape and storage shape
    */
@@ -54,33 +53,32 @@ class Comm {
   /**
    * \brief returns src[0] + .. + src[src.size()-1]
    */
-  virtual const NDArray& Reduce(
-      int key, const std::vector<NDArray>& src, int priority) = 0;
+  virtual const NDArray& Reduce(int key, const std::vector<NDArray>& src,
+                                int priority) = 0;
   /**
    * \brief copy from src to dst[i] for every i
    */
-  virtual void Broadcast(
-      int key, const NDArray& src,
-      const std::vector<NDArray*> dst, int priority) = 0;
+  virtual void Broadcast(int key, const NDArray& src,
+                         const std::vector<NDArray*> dst, int priority) = 0;
 
   /**
    * \brief broadcast src to dst[i] with target row_ids for every i
-   * \param dst a list of destination row_sparse NDArray and its target row_ids to broadcast,
+   * \param dst a list of destination row_sparse NDArray and its target row_ids
+   to broadcast,
             where the row_ids are expected to be unique and sorted
-   * \param use_copy if set to true, directly copy src to dst[i] without looking up the
+   * \param use_copy if set to true, directly copy src to dst[i] without looking
+   up the
             provided row_ids
    */
-  virtual void BroadcastRowSparse(int key, const NDArray& src,
-                                  const std::vector<std::pair<NDArray*, NDArray>>& dst,
-                                  const bool use_copy,
-                                  const int priority) = 0;
+  virtual void BroadcastRowSparse(
+      int key, const NDArray& src,
+      const std::vector<std::pair<NDArray*, NDArray>>& dst, const bool use_copy,
+      const int priority) = 0;
 
   /**
    * \brief return a pinned contex
    */
-  Context pinned_ctx() const {
-    return pinned_ctx_;
-  }
+  Context pinned_ctx() const { return pinned_ctx_; }
 
   /**
    * \brief Sets gradient compression parameters to be able to
@@ -108,7 +106,7 @@ class CommCPU : public Comm {
     // TODO(junwu) delete the following data member, now for benchmark only
     is_serial_push_ = dmlc::GetEnv("MXNET_KVSTORE_SERIAL_PUSH", 0);
   }
-  virtual ~CommCPU() { }
+  virtual ~CommCPU() {}
 
   void Init(int key, const NDArrayStorageType stype, const TShape& shape,
             int type = mshadow::kFloat32) override {
@@ -121,7 +119,7 @@ class CommCPU : public Comm {
 
   const NDArray& Reduce(int key, const std::vector<NDArray>& src,
                         int priority) override {
-    auto& buf = merge_buf_[key];
+    BufferEntry& buf = merge_buf_[key];
     // avoid extra copy for single device, but it may bring problems for
     // abnormal usage of kvstore
     if (src.size() == 1) {
@@ -140,25 +138,28 @@ class CommCPU : public Comm {
       reduce[0] = buf.merged;
 
       if (buf.copy_buf.empty()) {
-        buf.copy_buf.resize(src.size()-1);
+        buf.copy_buf.resize(src.size() - 1);
         for (size_t j = 0; j < src.size() - 1; ++j) {
           // allocate NDArray based on storage type
-          buf.copy_buf[j] = NDArray(
-            src[0].shape(), pinned_ctx_, false, src[0].dtype());
+          buf.copy_buf[j] =
+              NDArray(src[0].shape(), pinned_ctx_, false, src[0].dtype());
         }
       }
       for (size_t i = 1; i < src.size(); ++i) {
-        CopyFromTo(src[i], &(buf.copy_buf[i-1]), priority);
-        reduce[i] = buf.copy_buf[i-1];
-        const_vars[i-1] = reduce[i].var();
+        CopyFromTo(src[i], &(buf.copy_buf[i - 1]), priority);
+        reduce[i] = buf.copy_buf[i - 1];
+        const_vars[i - 1] = reduce[i].var();
       }
 
       Engine::Get()->PushAsync(
-        [reduce, this](RunContext rctx, Engine::CallbackOnComplete on_complete) {
-          ReduceSumCPU(reduce);
-          on_complete();
-        }, Context::CPU(), const_vars, {reduce[0].var()},
-        FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce"));
+          [reduce, this](RunContext rctx,
+                         Engine::CallbackOnComplete on_complete) {
+            ReduceSumCPU(reduce);
+            on_complete();
+          },
+          Context::CPU(), const_vars, {reduce[0].var()},
+          FnProperty::kCPUPrioritized, priority,
+          PROFILER_MESSAGE("KVStoreReduce"));
 
     } else {
       // buf.merged is a sparse ndarray.
@@ -168,8 +169,8 @@ class CommCPU : public Comm {
       if (buf.copy_buf.empty()) {
         buf.copy_buf.resize(src.size());
         for (size_t j = 0; j < src.size(); ++j) {
-          buf.copy_buf[j] = NDArray(
-            src[0].storage_type(), src[0].shape(), pinned_ctx_, true, src[0].dtype());
+          buf.copy_buf[j] = NDArray(src[0].storage_type(), src[0].shape(),
+                                    pinned_ctx_, true, src[0].dtype());
         }
       }
       for (size_t i = 0; i < src.size(); ++i) {
@@ -178,44 +179,46 @@ class CommCPU : public Comm {
         const_vars[i] = reduce[i].var();
       }
       NDArray result = buf.merged;
-      Resource rsc = ResourceManager::Get()->Request(result.ctx(),
-          ResourceRequest(ResourceRequest::kTempSpace));
+      Resource rsc = ResourceManager::Get()->Request(
+          result.ctx(), ResourceRequest(ResourceRequest::kTempSpace));
       Engine::Get()->PushAsync(
-        [reduce, result, rsc, this](RunContext rctx, Engine::CallbackOnComplete on_complete) {
-          NDArray out = result;
-          is_serial_push_?
-            ReduceSumCPUExSerial(reduce, &out)
-            : mxnet::ndarray::ElementwiseSum(rctx.get_stream<cpu>(), rsc, reduce, &out);
-          on_complete();
-        }, Context::CPU(), const_vars, {result.var(), rsc.var},
-        FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce"));
+          [reduce, result, rsc, this](RunContext rctx,
+                                      Engine::CallbackOnComplete on_complete) {
+            NDArray out = result;
+            is_serial_push_ ? ReduceSumCPUExSerial(reduce, &out)
+                            : mxnet::ndarray::ElementwiseSum(
+                                  rctx.get_stream<cpu>(), rsc, reduce, &out);
+            on_complete();
+          },
+          Context::CPU(), const_vars, {result.var(), rsc.var},
+          FnProperty::kCPUPrioritized, priority,
+          PROFILER_MESSAGE("KVStoreReduce"));
     }
 
     return buf.merged;
   }
 
-  void Broadcast(int key, const NDArray& src,
-                 const std::vector<NDArray*> dst, int priority) override {
+  void Broadcast(int key, const NDArray& src, const std::vector<NDArray*> dst,
+                 int priority) override {
     int mask = src.ctx().dev_mask();
     if (mask == Context::kCPU) {
-      for (auto d : dst) CopyFromTo(src, d, priority);
+      for (auto& d : dst) CopyFromTo(src, d, priority);
     } else {
       // first copy data to cpu, then broadcast
-      auto& buf = merge_buf_[key];
+      BufferEntry& buf = merge_buf_[key];
       CopyFromTo(src, &buf.merged, priority);
-      for (auto d : dst) CopyFromTo(buf.merged, d, priority);
+      for (auto& d : dst) CopyFromTo(buf.merged, d, priority);
     }
   }
 
   void BroadcastRowSparse(int key, const NDArray& src,
                           const std::vector<std::pair<NDArray*, NDArray>>& dst,
-                          const bool use_copy,
-                          const int priority) override {
+                          const bool use_copy, const int priority) override {
     using namespace mshadow;
     CHECK_EQ(src.storage_type(), kRowSparseStorage)
-      << "BroadcastRowSparse expects row-sparse src NDArray";
+        << "BroadcastRowSparse expects row-sparse src NDArray";
     CHECK_EQ(src.ctx().dev_mask(), Context::kCPU)
-      << "BroadcastRowSparse with src on gpu context not supported";
+        << "BroadcastRowSparse with src on gpu context not supported";
     for (size_t i = 0; i < dst.size(); ++i) {
       NDArray* out = dst[i].first;
       NDArray row_id = dst[i].second;
@@ -223,40 +226,47 @@ class CommCPU : public Comm {
         CopyFromTo(src, out, priority);
       } else {
         CHECK_EQ(out->storage_type(), kRowSparseStorage)
-                 << "BroadcastRowSparse expects row_sparse dst NDArray";
+            << "BroadcastRowSparse expects row_sparse dst NDArray";
         CHECK_EQ(row_id.ctx().dev_mask(), Context::kCPU)
-                 << "BroadcastRowSparse with row_indices on gpu context not supported";
+            << "BroadcastRowSparse with row_indices on gpu context not "
+               "supported";
         // retain according to unique indices
-        const bool use_sparse_retain = (src.shape()[0] != src.storage_shape()[0])
-          || (row_id.dtype() != out->aux_type(rowsparse::kIdx))
-          || (out->ctx().dev_mask() != Context::kGPU);
+        const bool use_sparse_retain =
+            (src.shape()[0] != src.storage_shape()[0]) ||
+            (row_id.dtype() != out->aux_type(rowsparse::kIdx)) ||
+            (out->ctx().dev_mask() != Context::kGPU);
         if (use_sparse_retain) {  // use sparse_retain op
           const bool is_to_gpu = out->ctx().dev_mask() == Context::kGPU;
-          NDArray out_cpu = is_to_gpu? NDArray(kRowSparseStorage, src.shape(),
-              src.ctx(), true, src.dtype(), src.aux_types()) : *out;
+          NDArray out_cpu =
+              is_to_gpu ? NDArray(kRowSparseStorage, src.shape(), src.ctx(),
+                                  true, src.dtype(), src.aux_types())
+                        : *out;
           Engine::Get()->PushAsync(
-            [=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
-              const TBlob& indices = row_id.data();
-              NDArray temp = out_cpu;  // get rid of const qualifier
-              op::SparseRetainOpForwardRspImpl<cpu>(rctx.get_stream<cpu>(),
-                                                    src, indices, kWriteTo,
-                                                    &temp);
-              on_complete();
-            }, Context::CPU(), {src.var(), row_id.var()}, {out_cpu.var()},
-            FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreSparseRetain"));
+              [=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+                const TBlob& indices = row_id.data();
+                NDArray temp = out_cpu;  // get rid of const qualifier
+                op::SparseRetainOpForwardRspImpl<cpu>(
+                    rctx.get_stream<cpu>(), src, indices, kWriteTo, &temp);
+                on_complete();
+              },
+              Context::CPU(), {src.var(), row_id.var()}, {out_cpu.var()},
+              FnProperty::kNormal, priority,
+              PROFILER_MESSAGE("KVStoreSparseRetain"));
           if (is_to_gpu) {
             CopyFromTo(out_cpu, out, priority);
           }
         } else {  // direct copy rows
           Engine::Get()->PushAsync(
-            [=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
-              CopyRetainedRowsToGPU(rctx.get_stream<cpu>(), rctx.get_stream<gpu>(),
-                                    src, row_id, out);
-              // wait for GPU operations to complete
-              rctx.get_stream<gpu>()->Wait();
-              on_complete();
-            }, out->ctx(), {src.var(), row_id.var()}, {out->var()},
-            FnProperty::kCopyToGPU, priority, PROFILER_MESSAGE("KVStoreCopyRetainedRowsToGPU"));
+              [=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+                CopyRetainedRowsToGPU(rctx.get_stream<cpu>(),
+                                      rctx.get_stream<gpu>(), src, row_id, out);
+                // wait for GPU operations to complete
+                rctx.get_stream<gpu>()->Wait();
+                on_complete();
+              },
+              out->ctx(), {src.var(), row_id.var()}, {out->var()},
+              FnProperty::kCopyToGPU, priority,
+              PROFILER_MESSAGE("KVStoreCopyRetainedRowsToGPU"));
         }
       }
     }
@@ -270,22 +280,22 @@ class CommCPU : public Comm {
    */
   void CopyRetainedRowsToGPU(mshadow::Stream<cpu>* cpu_stream,
                              mshadow::Stream<gpu>* gpu_stream,
-                             const NDArray& src,
-                             const NDArray& indices,
+                             const NDArray& src, const NDArray& indices,
                              NDArray* dst) {
 #if MXNET_USE_CUDA == 1
     CHECK_EQ(src.storage_type(), kRowSparseStorage)
-      << "CopyRetainedRowsToGPU expects row-sparse src NDArray";
+        << "CopyRetainedRowsToGPU expects row-sparse src NDArray";
     CHECK_EQ(src.ctx().dev_mask(), Context::kCPU)
-      << "CopyRetainedRowsToGPU with src on gpu context not supported";
+        << "CopyRetainedRowsToGPU with src on gpu context not supported";
     CHECK_EQ(src.storage_shape()[0], src.shape()[0])
-      << "CopyRetainedRowsToGPU only supports src rsp with full rows";
+        << "CopyRetainedRowsToGPU only supports src rsp with full rows";
     CHECK_EQ(indices.storage_type(), kDefaultStorage);
     CHECK_EQ(indices.ctx().dev_mask(), Context::kCPU);
     CHECK_EQ(dst->storage_type(), kRowSparseStorage);
     CHECK_EQ(dst->ctx().dev_mask(), Context::kGPU);
     CHECK_EQ(indices.dtype(), dst->aux_type(rowsparse::kIdx))
-      << "CopyRetainedRowsToGPU only supports same data type for idx array and dst aux_data(0)";
+        << "CopyRetainedRowsToGPU only supports same data type for idx array "
+           "and dst aux_data(0)";
     if (!src.storage_initialized() || indices.data().Size() == 0U) {
       op::FillZerosRspImpl(gpu_stream, *dst);
       return;
@@ -299,29 +309,33 @@ class CommCPU : public Comm {
     dst->CheckAndAlloc({Shape1(num_rows_retained)});
     TBlob dst_data = dst->data();
     TBlob dst_idx_data = dst->aux_data(rowsparse::kIdx);
-    MSHADOW_TYPE_SWITCH(src.dtype(), DType, {
-      MSHADOW_IDX_TYPE_SWITCH(indices.dtype(), IType, {
-        // copy idx array
-        Tensor<gpu, 1, IType> dst_idx_tensor = dst_idx_data.FlatTo1D<gpu, IType>(gpu_stream);
-        const Tensor<cpu, 1, IType> idx_tensor = idx_data.FlatTo1D<cpu, IType>(cpu_stream);
-        Copy(dst_idx_tensor, idx_tensor, gpu_stream);
-        // copy src data
-        const Tensor<cpu, 2, DType> src_data_tensor = src_data.get_with_shape<cpu, 2, DType>(
-            Shape2(src_data.shape_[0], row_length), cpu_stream);
-        Tensor<gpu, 2, DType> dst_data_tensor = dst_data.get_with_shape<gpu, 2, DType>(
-            Shape2(dst_data.shape_[0], row_length), gpu_stream);
-        for (size_t i = 0; i < num_rows_retained; ++i) {
-          Copy(dst_data_tensor[i], src_data_tensor[idx_tensor[i]], gpu_stream);
-        }
-      })
-    })
+    MSHADOW_TYPE_SWITCH(
+        src.dtype(), DType, {MSHADOW_IDX_TYPE_SWITCH(indices.dtype(), IType, {
+          // copy idx array
+          Tensor<gpu, 1, IType> dst_idx_tensor =
+              dst_idx_data.FlatTo1D<gpu, IType>(gpu_stream);
+          const Tensor<cpu, 1, IType> idx_tensor =
+              idx_data.FlatTo1D<cpu, IType>(cpu_stream);
+          Copy(dst_idx_tensor, idx_tensor, gpu_stream);
+          // copy src data
+          const Tensor<cpu, 2, DType> src_data_tensor =
+              src_data.get_with_shape<cpu, 2, DType>(
+                  Shape2(src_data.shape_[0], row_length), cpu_stream);
+          Tensor<gpu, 2, DType> dst_data_tensor =
+              dst_data.get_with_shape<gpu, 2, DType>(
+                  Shape2(dst_data.shape_[0], row_length), gpu_stream);
+          for (size_t i = 0; i < num_rows_retained; ++i) {
+            Copy(dst_data_tensor[i], src_data_tensor[idx_tensor[i]],
+                 gpu_stream);
+          }
+        })})
 #else
     LOG(FATAL) << "GPU not enabled";
 #endif
   }
 
   // reduce sum into val[0]
-  inline void ReduceSumCPU(const std::vector<NDArray> &in_data) {
+  inline void ReduceSumCPU(const std::vector<NDArray>& in_data) {
     MSHADOW_TYPE_SWITCH(in_data[0].dtype(), DType, {
       std::vector<DType*> dptr(in_data.size());
       for (size_t i = 0; i < in_data.size(); ++i) {
@@ -335,7 +349,8 @@ class CommCPU : public Comm {
   }
 
   // serial implementation of reduce sum for row sparse NDArray.
-  inline void ReduceSumCPUExSerial(const std::vector<NDArray> &in, NDArray *out) {
+  inline void ReduceSumCPUExSerial(const std::vector<NDArray>& in,
+                                   NDArray* out) {
     using namespace rowsparse;
     using namespace mshadow;
     auto stype = out->storage_type();
@@ -374,7 +389,8 @@ class CommCPU : public Comm {
         CHECK_EQ(indices.size(), total_num_rows);
         // dedup indices
         std::sort(indices.begin(), indices.end());
-        indices.resize(std::unique(indices.begin(), indices.end()) - indices.begin());
+        indices.resize(std::unique(indices.begin(), indices.end()) -
+                       indices.begin());
         // the one left are unique non-zero rows
         size_t nnr = indices.size();
         // allocate memory for output
@@ -406,12 +422,12 @@ class CommCPU : public Comm {
     });
   }
 
-  template<typename DType>
-  inline static void ReduceSumCPU(
-      const std::vector<DType*> &dptr, size_t offset, index_t size) {
+  template <typename DType>
+  inline static void ReduceSumCPU(const std::vector<DType*>& dptr,
+                                  size_t offset, index_t size) {
     using namespace mshadow;  // NOLINT(*)
     Tensor<cpu, 1, DType> in_0(dptr[0] + offset, Shape1(size));
-    for (size_t i = 1; i < dptr.size(); i+=4) {
+    for (size_t i = 1; i < dptr.size(); i += 4) {
       switch (dptr.size() - i) {
         case 1: {
           Tensor<cpu, 1, DType> in_1(dptr[i] + offset, Shape1(size));
@@ -420,22 +436,22 @@ class CommCPU : public Comm {
         }
         case 2: {
           Tensor<cpu, 1, DType> in_1(dptr[i] + offset, Shape1(size));
-          Tensor<cpu, 1, DType> in_2(dptr[i+1] + offset, Shape1(size));
+          Tensor<cpu, 1, DType> in_2(dptr[i + 1] + offset, Shape1(size));
           in_0 += in_1 + in_2;
           break;
         }
         case 3: {
           Tensor<cpu, 1, DType> in_1(dptr[i] + offset, Shape1(size));
-          Tensor<cpu, 1, DType> in_2(dptr[i+1] + offset, Shape1(size));
-          Tensor<cpu, 1, DType> in_3(dptr[i+2] + offset, Shape1(size));
+          Tensor<cpu, 1, DType> in_2(dptr[i + 1] + offset, Shape1(size));
+          Tensor<cpu, 1, DType> in_3(dptr[i + 2] + offset, Shape1(size));
           in_0 += in_1 + in_2 + in_3;
           break;
         }
         default: {
           Tensor<cpu, 1, DType> in_1(dptr[i] + offset, Shape1(size));
-          Tensor<cpu, 1, DType> in_2(dptr[i+1] + offset, Shape1(size));
-          Tensor<cpu, 1, DType> in_3(dptr[i+2] + offset, Shape1(size));
-          Tensor<cpu, 1, DType> in_4(dptr[i+3] + offset, Shape1(size));
+          Tensor<cpu, 1, DType> in_2(dptr[i + 1] + offset, Shape1(size));
+          Tensor<cpu, 1, DType> in_3(dptr[i + 2] + offset, Shape1(size));
+          Tensor<cpu, 1, DType> in_4(dptr[i + 3] + offset, Shape1(size));
           in_0 += in_1 + in_2 + in_3 + in_4;
           break;
         }
@@ -443,15 +459,15 @@ class CommCPU : public Comm {
     }
   }
 
-  template<typename DType>
+  template <typename DType>
   inline void ReduceSumCPUImpl(std::vector<DType*> dptr, size_t total) {
     const size_t step = std::min(bigarray_bound_, static_cast<size_t>(4 << 10));
-    long ntask = (total + step - 1) / step; // NOLINT(*)
+    long ntask = (total + step - 1) / step;  // NOLINT(*)
     if (total < bigarray_bound_ || nthread_reduction_ <= 1) {
       ReduceSumCPU(dptr, 0, total);
     } else {
-      #pragma omp parallel for schedule(static) num_threads(nthread_reduction_)
-      for (long j = 0; j < ntask; ++j) { // NOLINT(*)
+#pragma omp parallel for schedule(static) num_threads(nthread_reduction_)
+      for (long j = 0; j < ntask; ++j) {  // NOLINT(*)
         size_t k = static_cast<size_t>(j);
         size_t begin = std::min(k * step, total);
         size_t end = std::min((k + 1) * step, total);
@@ -484,11 +500,9 @@ class CommCPU : public Comm {
  */
 class CommDevice : public Comm {
  public:
-  CommDevice() {
-    inited_ = false;
-  }
+  CommDevice() { inited_ = false; }
 
-  virtual ~CommDevice() { }
+  virtual ~CommDevice() {}
 
   void Init(int key, const NDArrayStorageType stype, const TShape& shape,
             int dtype = mshadow::kFloat32) override {
@@ -523,95 +537,236 @@ class CommDevice : public Comm {
     }
 
     InitBuffersAndComm(src);
-    auto& buf = merge_buf_[key];
-    std::vector<NDArray> reduce(src.size());
-
-    const NDArrayStorageType stype = buf.merged.storage_type();
+    // merge buffer holds the first group of gpus
+    BufferEntry& buf = merge_buf_[key];
+    // stage buffer holds the data of the second group  or the first when merge
+    // buffer is empty
+    BufferEntry& stage = stage_buf_[key];
+    std::vector<NDArray> reduce_s;
+
+    const NDArrayStorageType stype = stage.merged.storage_type();
     if (stype == kDefaultStorage) {
-      CopyFromTo(src[0], &(buf.merged), priority);
-      reduce[0] = buf.merged;
-
-      if (buf.copy_buf.empty()) {
-        // TODO(mli) this results in large device memory usage for huge ndarray,
-        // such as the largest fullc in VGG. consider to do segment reduce with
-        // NDArray.Slice or gpu direct memory access. for the latter, we need to
-        // remove some ctx check, and also it reduces 20% perf
-        buf.copy_buf.resize(src.size()-1);
-        for (size_t i = 0; i < src.size()-1; ++i) {
-          buf.copy_buf[i] = NDArray(
-            buf.merged.shape(), buf.merged.ctx(), false, buf.merged.dtype());
-        }
+      if (buf.merged.is_none() && stage.copy_buf.empty()) {
+        stage.copy_buf.resize(src.size() - 1);
+        for (size_t i = 0; i < src.size() - 1; ++i)
+          stage.copy_buf[i] = NDArray(stage.merged.shape(), stage.merged.ctx(),
+                                      false, stage.merged.dtype());
       }
-      for (size_t i = 0; i < src.size()-1; ++i) {
-        CopyFromTo(src[i+1], &(buf.copy_buf[i]), priority);
-        reduce[i+1] = buf.copy_buf[i];
+      reduce_s.resize(stage.copy_buf.size() + 1);
+      for (size_t i = 0, j = 0; i < src.size(); ++i) {
+        int id = src[i].ctx().dev_id;
+        if ((!buf.merged.is_none() && id == stage.merged.ctx().dev_id) ||
+            (buf.merged.is_none() && i == 0)) {
+          CopyFromTo(src[i], &stage.merged, priority);
+          reduce_s[0] = stage.merged;
+        } else if (id >= 4 || buf.merged.is_none()) {
+          CopyFromTo(src[i], &(stage.copy_buf[j]), priority);
+          reduce_s[j + 1] = stage.copy_buf[j];
+          j++;
+        }
       }
     } else {
-      if (buf.copy_buf.empty()) {
-        buf.copy_buf.resize(src.size());
-        for (size_t j = 0; j < src.size(); ++j) {
-          buf.copy_buf[j] = NDArray(
-            buf.merged.storage_type(), buf.merged.shape(), buf.merged.ctx(),
-            true, buf.merged.dtype());
+      if (buf.merged.is_none() && stage.copy_buf.empty()) {
+        stage.copy_buf.resize(src.size());
+        for (size_t j = 0; j < src.size(); ++j)
+          stage.copy_buf[j] =
+              NDArray(stage.merged.storage_type(), stage.merged.shape(),
+                      stage.merged.ctx(), true, stage.merged.dtype());
+      }
+      reduce_s.resize(stage.copy_buf.size());
+      for (size_t i = 0, j = 0; i < src.size(); ++i) {
+        int id = src[i].ctx().dev_id;
+        if (id >= 4 || buf.merged.is_none()) {
+          CopyFromTo(src[i], &(stage.copy_buf[j]), priority);
+          reduce_s[j] = stage.copy_buf[j];
+          j++;
         }
       }
-      for (size_t i = 0; i < src.size(); ++i) {
-        CopyFromTo(src[i], &(buf.copy_buf[i]), priority);
-        reduce[i] = buf.copy_buf[i];
+    }
+    // Reducing either the second group of data or the second when merge buffer
+    // is empty
+    ElementwiseSum(reduce_s, &stage.merged, priority);
+    // Main reduce result on the first group of GPUs including the partial
+    // result from the second group
+    if (!buf.merged.is_none()) {
+      const NDArrayStorageType sstype = buf.merged.storage_type();
+      std::vector<NDArray> reduce;
+      if (sstype == kDefaultStorage) {
+        reduce.resize(buf.copy_buf.size() + 1);
+        for (size_t i = 0, j = 0; i < src.size(); ++i) {
+          int id = src[i].ctx().dev_id;
+          if (id == buf.merged.ctx().dev_id) {
+            reduce[0] = src[i];
+          } else if (id < 4) {
+            CopyFromTo(src[i], &(buf.copy_buf[j]), priority);
+            reduce[j + 1] = buf.copy_buf[j];
+            j++;
+          }
+        }
+      } else {
+        reduce.resize(buf.copy_buf.size());
+        for (size_t i = 0, j = 0; i < src.size(); ++i) {
+          int id = src[i].ctx().dev_id;
+          if (id < 4) {
+            CopyFromTo(src[i], &(buf.copy_buf[j]), priority);
+            reduce[j] = buf.copy_buf[j];
+            j++;
+          }
+        }
       }
+      // Copy the second group's reducing result to merge buffer
+      CopyFromTo(stage.merged, &(buf.copy_buf[buf.copy_buf.size() - 1]),
+                 priority);
+      reduce[reduce.size() - 1] = buf.copy_buf[buf.copy_buf.size() - 1];
+      ElementwiseSum(reduce, &buf.merged);
+    } else {
+      return stage.merged;
     }
-    ElementwiseSum(reduce, &buf.merged, priority);
+
     return buf.merged;
   }
 
   const NDArray& ReduceCompressed(int key, const std::vector<NDArray>& src,
                                   int priority) {
     InitBuffersAndComm(src);
-    auto& buf = merge_buf_[key];
-    std::vector<NDArray> reduce(src.size());
-    if (buf.copy_buf.empty()) {
+    BufferEntry& buf = merge_buf_[key];
+    BufferEntry& stage = stage_buf_[key];
+    if (buf.merged.is_none() && stage.copy_buf.empty()) {
       // one buf for each context
-      buf.copy_buf.resize(src.size());
-      buf.compressed_recv_buf.resize(src.size());
-      buf.compressed_send_buf.resize(src.size());
-      buf.residual.resize(src.size());
+      stage.copy_buf.resize(src.size());
+      stage.compressed_recv_buf.resize(src.size());
+      stage.compressed_send_buf.resize(src.size());
+      stage.residual.resize(src.size());
 
       for (size_t i = 0; i < src.size(); ++i) {
-        buf.copy_buf[i] = NDArray(buf.merged.shape(), buf.merged.ctx(),
-                                  false, buf.merged.dtype());
-        buf.residual[i] = NDArray(buf.merged.shape(), src[i].ctx(),
-                                  false, buf.merged.dtype());
-        buf.residual[i] = 0;
+        stage.copy_buf[i] = NDArray(stage.merged.shape(), stage.merged.ctx(),
+                                    false, stage.merged.dtype());
+        stage.residual[i] = NDArray(stage.merged.shape(), src[i].ctx(), false,
+                                    stage.merged.dtype());
+        stage.residual[i] = 0;
+        int64_t small_size =
+            gc_->GetCompressedSize(stage.merged.shape().Size());
+        stage.compressed_recv_buf[i] =
+            NDArray(TShape{small_size}, stage.merged.ctx(), false,
+                    stage.merged.dtype());
+        stage.compressed_send_buf[i] = NDArray(TShape{small_size}, src[i].ctx(),
+                                               false, stage.merged.dtype());
+      }
+    } else if (!buf.merged.is_none()) {
+      if (buf.copy_buf.empty() && stage.copy_buf.empty()) {
+        buf.copy_buf.resize(group1.size() + 1);
+        buf.compressed_recv_buf.resize(group1.size() + 1);
+        buf.compressed_send_buf.resize(group1.size() + 1);
+        buf.residual.resize(group1.size());
+        stage.copy_buf.resize(group2.size());
+        stage.compressed_recv_buf.resize(group2.size());
+        stage.compressed_send_buf.resize(group2.size());
+        stage.residual.resize(group2.size());
+        for (size_t i = 0, j = 0, k = 0; i < src.size(); ++i) {
+          int id = src[i].ctx().dev_id;
+          if (id < NVLINK_SUPPORT) {
+            buf.copy_buf[j] = NDArray(buf.merged.shape(), buf.merged.ctx(),
+                                      false, buf.merged.dtype());
+            buf.residual[j] = NDArray(buf.merged.shape(), src[i].ctx(), false,
+                                      buf.merged.dtype());
+            buf.residual[j] = 0;
+            int64_t small_size =
+                gc_->GetCompressedSize(buf.merged.shape().Size());
+            buf.compressed_recv_buf[j] =
+                NDArray(TShape{small_size}, buf.merged.ctx(), false,
+                        buf.merged.dtype());
+            buf.compressed_send_buf[j] = NDArray(
+                TShape{small_size}, src[i].ctx(), false, buf.merged.dtype());
+            j++;
+          } else {
+            stage.copy_buf[k] =
+                NDArray(stage.merged.shape(), stage.merged.ctx(), false,
+                        stage.merged.dtype());
+            stage.residual[k] = NDArray(stage.merged.shape(), src[i].ctx(),
+                                        false, stage.merged.dtype());
+            stage.residual[k] = 0;
+            int64_t small_size =
+                gc_->GetCompressedSize(stage.merged.shape().Size());
+            stage.compressed_recv_buf[k] =
+                NDArray(TShape{small_size}, stage.merged.ctx(), false,
+                        stage.merged.dtype());
+            stage.compressed_send_buf[k] = NDArray(
+                TShape{small_size}, src[i].ctx(), false, stage.merged.dtype());
+            k++;
+          }
+        }
+        buf.copy_buf[group1.size()] = NDArray(
+            buf.merged.shape(), buf.merged.ctx(), false, buf.merged.dtype());
         int64_t small_size = gc_->GetCompressedSize(buf.merged.shape().Size());
-        buf.compressed_recv_buf[i] = NDArray(TShape{small_size}, buf.merged.ctx(),
-                                        false, buf.merged.dtype());
-        buf.compressed_send_buf[i] = NDArray(TShape{small_size}, src[i].ctx(),
-                                        false, buf.merged.dtype());
+        buf.compressed_recv_buf[group1.size()] = NDArray(
+            TShape{small_size}, buf.merged.ctx(), false, buf.merged.dtype());
+        buf.compressed_send_buf[group1.size()] = NDArray(
+            TShape{small_size}, stage.merged.ctx(), false, buf.merged.dtype());
       }
     }
+    std::vector<NDArray> reduce_s(stage.copy_buf.size());
+    std::vector<NDArray> reduce(buf.copy_buf.size());
+
+    for (size_t i = 0, j = 0, k = 0; i < src.size(); ++i) {
+      int id = src[i].ctx().dev_id;
+      if (id >= NVLINK_SUPPORT || buf.merged.is_none()) {
+        // compress before copy
+        // this is done even if the data is on same context as copy_buf because
+        // we don't want the training to be biased towards data on this GPU
+        gc_->Quantize(src[i], &(stage.compressed_send_buf[j]),
+                      &(stage.residual[j]), priority);
+
+        if (stage.compressed_send_buf[j].ctx() !=
+            stage.compressed_recv_buf[j].ctx()) {
+          CopyFromTo(stage.compressed_send_buf[j],
+                     &(stage.compressed_recv_buf[j]), priority);
+        } else {
+          // avoid memory copy when they are on same context
+          stage.compressed_recv_buf[j] = stage.compressed_send_buf[j];
+        }
 
-    for (size_t i = 0; i < src.size(); ++i) {
-      // compress before copy
-      // this is done even if the data is on same context as copy_buf because
-      // we don't want the training to be biased towards data on this GPU
-      gc_->Quantize(src[i], &(buf.compressed_send_buf[i]), &(buf.residual[i]), priority);
-
-      if (buf.compressed_send_buf[i].ctx() != buf.compressed_recv_buf[i].ctx()) {
-        CopyFromTo(buf.compressed_send_buf[i], &(buf.compressed_recv_buf[i]), priority);
+        gc_->Dequantize(stage.compressed_recv_buf[j], &(stage.copy_buf[j]),
+                        priority);
+        reduce_s[j] = stage.copy_buf[j];
+        j++;
       } else {
-        // avoid memory copy when they are on same context
-        buf.compressed_recv_buf[i] = buf.compressed_send_buf[i];
-      }
+        gc_->Quantize(src[i], &(buf.compressed_send_buf[k]), &(buf.residual[k]),
+                      priority);
+
+        if (buf.compressed_send_buf[k].ctx() !=
+            buf.compressed_recv_buf[k].ctx()) {
+          CopyFromTo(buf.compressed_send_buf[k], &(buf.compressed_recv_buf[k]),
+                     priority);
+        } else {
+          // avoid memory copy when they are on same context
+          buf.compressed_recv_buf[k] = buf.compressed_send_buf[k];
+        }
 
-      gc_->Dequantize(buf.compressed_recv_buf[i], &(buf.copy_buf[i]), priority);
-      reduce[i] = buf.copy_buf[i];
+        gc_->Dequantize(buf.compressed_recv_buf[k], &(buf.copy_buf[k]),
+                        priority);
+        reduce[k] = buf.copy_buf[k];
+        k++;
+      }
+    }
+    ElementwiseSum(reduce_s, &stage.merged);
+    if (buf.merged.is_none()) {
+      return stage.merged;
+    } else {
+      gc_->Quantize(stage.merged, &buf.compressed_send_buf[group1.size()],
+                    &(buf.residual[buf.residual.size() - 1]), priority);
+      CopyFromTo(buf.compressed_send_buf[group1.size()],
+                 &(buf.compressed_recv_buf[group1.size()]), priority);
+      gc_->Dequantize(buf.compressed_recv_buf[group1.size()],
+                      &(buf.copy_buf[group1.size()]), priority);
+      reduce[reduce.size() - 1] = buf.copy_buf[group1.size()];
+      ElementwiseSum(reduce, &buf.merged);
     }
-    ElementwiseSum(reduce, &buf.merged);
+
     return buf.merged;
   }
 
-  void Broadcast(int key, const NDArray& src,
-                 const std::vector<NDArray*> dst, int priority) override {
+  void Broadcast(int key, const NDArray& src, const std::vector<NDArray*> dst,
+                 int priority) override {
     if (!inited_) {
       // copy to a random device first
       int dev_id = key % dst.size();
@@ -622,20 +777,24 @@ class CommDevice : public Comm {
         }
       }
     } else {
-      auto& buf = merge_buf_[key];
-      CopyFromTo(src, &buf.merged, priority);
+      BufferEntry& buf = merge_buf_[key];
+      BufferEntry& stage = stage_buf_[key];
+      if (!buf.merged.is_none()) CopyFromTo(src, &buf.merged, priority);
+      CopyFromTo(src, &stage.merged, priority);
       for (auto d : dst) {
-        CopyFromTo(buf.merged, d, priority);
+        if (d->ctx().dev_id >= NVLINK_SUPPORT || buf.merged.is_none())
+          CopyFromTo(stage.merged, d, priority);
+        else
+          CopyFromTo(buf.merged, d, priority);
       }
     }
   }
 
   void BroadcastRowSparse(int key, const NDArray& src,
                           const std::vector<std::pair<NDArray*, NDArray>>& dst,
-                          const bool use_copy,
-                          const int priority) override {
+                          const bool use_copy, const int priority) override {
     CHECK_EQ(src.storage_type(), kRowSparseStorage)
-      << "BroadcastRowSparse expects row-sparse src NDArray";
+        << "BroadcastRowSparse expects row-sparse src NDArray";
 
     for (size_t i = 0; i < dst.size(); ++i) {
       NDArray* out = dst[i].first;
@@ -644,38 +803,44 @@ class CommDevice : public Comm {
         CopyFromTo(src, out, priority);
       } else {
         CHECK_EQ(out->storage_type(), kRowSparseStorage)
-                 << "BroadcastRowSparse expects row_sparse dst NDArray";
+            << "BroadcastRowSparse expects row_sparse dst NDArray";
 
         const bool is_diff_ctx = out->ctx() != src.ctx();
-        NDArray out_gpu = is_diff_ctx? NDArray(kRowSparseStorage, out->shape(),
-            src.ctx(), true, out->dtype(), out->aux_types()) : *out;
+        NDArray out_gpu =
+            is_diff_ctx ? NDArray(kRowSparseStorage, out->shape(), src.ctx(),
+                                  true, out->dtype(), out->aux_types())
+                        : *out;
 
         CHECK_EQ(row_id.ctx(), src.ctx())
-                << "row_id and src are expected to be on the same context";
-
-        Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
-            NDArray temp = out_gpu;
-            const TBlob& indices = row_id.data();
-            switch (temp.ctx().dev_mask()) {
-              case cpu::kDevMask: {
-                mxnet::common::SparseRetainOpForwardRspWrapper<cpu>(rctx.get_stream<cpu>(),
-                    src, indices, kWriteTo, &temp);
-                break;
-              }
+            << "row_id and src are expected to be on the same context";
+
+        Engine::Get()->PushAsync(
+            [=](RunContext rctx, Engine::CallbackOnComplete on_complete) {
+              NDArray temp = out_gpu;
+              const TBlob& indices = row_id.data();
+              switch (temp.ctx().dev_mask()) {
+                case cpu::kDevMask: {
+                  mxnet::common::SparseRetainOpForwardRspWrapper<cpu>(
+                      rctx.get_stream<cpu>(), src, indices, kWriteTo, &temp);
+                  break;
+                }
 #if MXNET_USE_CUDA
-              case gpu::kDevMask: {
-                mxnet::common::SparseRetainOpForwardRspWrapper<gpu>(rctx.get_stream<gpu>(),
-                    src, indices, kWriteTo, &temp);
-                // wait for GPU operations to complete
-                rctx.get_stream<gpu>()->Wait();
-                break;
-              }
+                case gpu::kDevMask: {
+                  mxnet::common::SparseRetainOpForwardRspWrapper<gpu>(
+                      rctx.get_stream<gpu>(), src, indices, kWriteTo, &temp);
+                  // wait for GPU operations to complete
+                  rctx.get_stream<gpu>()->Wait();
+                  break;
+                }
 #endif
-              default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
-            }
-            on_complete();
-          }, out_gpu.ctx(), {src.var(), row_id.var()}, {out_gpu.var()},
-        FnProperty::kNormal, priority, PROFILER_MESSAGE("KVStoreSparseRetain"));
+                default:
+                  LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+              }
+              on_complete();
+            },
+            out_gpu.ctx(), {src.var(), row_id.var()}, {out_gpu.var()},
+            FnProperty::kNormal, priority,
+            PROFILER_MESSAGE("KVStoreSparseRetain"));
         if (is_diff_ctx) {
           CopyFromTo(out_gpu, out, priority);
         }
@@ -694,7 +859,7 @@ class CommDevice : public Comm {
     }
     int n = static_cast<int>(gpus.size());
     int enabled = 0;
-    std::vector<int> p2p(n*n);
+    std::vector<int> p2p(n * n);
     for (int i = 0; i < n; ++i) {
       cudaSetDevice(gpus[i]);
       for (int j = 0; j < n; j++) {
@@ -704,21 +869,21 @@ class CommDevice : public Comm {
           cudaError_t e = cudaDeviceEnablePeerAccess(gpus[j], 0);
           if (e == cudaSuccess || e == cudaErrorPeerAccessAlreadyEnabled) {
             ++enabled;
-            p2p[i*n+j] = 1;
+            p2p[i * n + j] = 1;
           }
         }
       }
     }
-    if (enabled != n*(n-1)) {
+    if (enabled != n * (n - 1)) {
       // print warning info if not fully enabled
-      LOG(WARNING) << "only " << enabled <<  " out of "
-                   << n*(n-1) << " GPU pairs are enabled direct access. "
+      LOG(WARNING) << "only " << enabled << " out of " << n * (n - 1)
+                   << " GPU pairs are enabled direct access. "
                    << "It may affect the performance. "
                    << "You can set MXNET_ENABLE_GPU_P2P=0 to turn it off";
       std::string access(n, '.');
       for (int i = 0; i < n; ++i) {
         for (int j = 0; j < n; ++j) {
-          access[j] = p2p[i*n+j] ? 'v' : '.';
+          access[j] = p2p[i * n + j] ? 'v' : '.';
         }
         LOG(WARNING) << access;
       }
@@ -729,40 +894,109 @@ class CommDevice : public Comm {
   using KeyAttrs = std::tuple<int, TShape, int, NDArrayStorageType>;
   // try to allocate buff on device evenly
   void InitMergeBuffer(const std::vector<Context>& devs) {
-    std::sort(sorted_key_attrs_.begin(), sorted_key_attrs_.end(), [](
-              const KeyAttrs& a, const KeyAttrs& b) {
-      return std::get<1>(a).Size() > std::get<1>(b).Size();
-    });
-
-    std::unordered_map<int, std::pair<Context, size_t>> ctx_info;
-    for (auto d : devs) {
-      ctx_info[d.dev_id] = std::make_pair(d, 0);
+    std::sort(sorted_key_attrs_.begin(), sorted_key_attrs_.end(),
+              [](const KeyAttrs& a, const KeyAttrs& b) {
+                return std::get<1>(a).Size() > std::get<1>(b).Size();
+              });
+
+    for (auto& d : devs) {
+      if (d.dev_id < NVLINK_SUPPORT)
+        group1.push_back(d);
+      else
+        group2.push_back(d);
     }
-    for (size_t i = 0; i < sorted_key_attrs_.size(); ++i) {
-      const int key  = std::get<0>(sorted_key_attrs_[i]);
-      const TShape& shape = std::get<1>(sorted_key_attrs_[i]);
-      const int type = std::get<2>(sorted_key_attrs_[i]);
-      const NDArrayStorageType stype = std::get<3>(sorted_key_attrs_[i]);
-      auto& buf = merge_buf_[key];
-      Context ctx;
-      size_t min_size = std::numeric_limits<size_t>::max();
-      for (auto it = ctx_info.begin(); it != ctx_info.end(); ++it) {
-        size_t size = it->second.second;
-        if (size <= min_size) {
-          ctx = it->second.first;
-          min_size = size;
+    if (group1.empty() || group2.empty()) {
+      // all gpus are all connected by NVLinks: use all-to-all
+      std::unordered_map<int, std::pair<Context, size_t>> ctx_info;
+      for (auto d : devs) {
+        ctx_info[d.dev_id] = std::make_pair(d, 0);
+      }
+      for (size_t i = 0; i < sorted_key_attrs_.size(); ++i) {
+        const int key = std::get<0>(sorted_key_attrs_[i]);
+        const TShape shape = std::get<1>(sorted_key_attrs_[i]);
+        const int type = std::get<2>(sorted_key_attrs_[i]);
+        const NDArrayStorageType stype = std::get<3>(sorted_key_attrs_[i]);
+        BufferEntry& stage = stage_buf_[key];
+        Context ctx;
+        size_t min_size = std::numeric_limits<size_t>::max();
+        for (auto it = ctx_info.begin(); it != ctx_info.end(); ++it) {
+          size_t size = it->second.second;
+          if (size <= min_size) {
+            ctx = it->second.first;
+            min_size = size;
+          }
+        }
+        if (stype == kDefaultStorage) {
+          stage.merged = NDArray(shape, ctx, false, type);
+        } else {
+          stage.merged = NDArray(stype, shape, ctx, true, type);
         }
+        ctx_info[ctx.dev_id].second += shape.Size();
       }
-      if (stype == kDefaultStorage) {
-        buf.merged = NDArray(shape, ctx, false, type);
-      } else {
-        buf.merged = NDArray(stype, shape, ctx, true, type);
+    } else {
+      // QPI connections are included: use spanning tree
+      size_t gpu0, gpu1;
+      // gpu0 and gpu1 hold the gpu indexes connected by nvlink between group1
+      // and group2 groups accordingly
+      for (gpu0 = 0, gpu1 = 0; gpu0 < group1.size() && gpu1 < group2.size();) {
+        if (group2[gpu1].dev_id - group1[gpu0].dev_id == NVLINK_SUPPORT)
+          break;
+        else if (group2[gpu1].dev_id - group1[gpu0].dev_id > NVLINK_SUPPORT)
+          gpu0++;
+        else
+          gpu1++;
+      }
+      if (gpu0 == group1.size() || gpu1 == group2.size()) gpu0 = gpu1 = 0;
+      for (size_t i = 0; i < sorted_key_attrs_.size(); ++i) {
+        const int key = std::get<0>(sorted_key_attrs_[i]);
+        const TShape shape = std::get<1>(sorted_key_attrs_[i]);
+        const int type = std::get<2>(sorted_key_attrs_[i]);
+        const NDArrayStorageType stype = std::get<3>(sorted_key_attrs_[i]);
+        BufferEntry& buf = merge_buf_[key];
+        BufferEntry& stage = stage_buf_[key];
+        if (stype == kDefaultStorage) {
+          buf.merged = NDArray(shape, group1[gpu0], false, type);
+          if (buf.copy_buf.empty()) {
+            buf.copy_buf.resize(group1.size());
+            for (size_t i = 0; i < group1.size(); ++i)
+              buf.copy_buf[i] = NDArray(buf.merged.shape(), buf.merged.ctx(),
+                                        false, buf.merged.dtype());
+          }
+
+          stage.merged = NDArray(shape, group2[gpu1], false, type);
+          if (stage.copy_buf.empty()) {
+            stage.copy_buf.resize(group2.size() - 1);
+            for (size_t i = 0; i < group2.size() - 1; ++i)
+              stage.copy_buf[i] =
+                  NDArray(stage.merged.shape(), stage.merged.ctx(), false,
+                          stage.merged.dtype());
+          }
+        } else {
+          buf.merged = NDArray(stype, shape, group1[gpu0], true, type);
+          if (buf.copy_buf.empty()) {
+            buf.copy_buf.resize(group1.size() + 1);
+            for (size_t i = 0; i < group1.size() + 1; ++i)
+              buf.copy_buf[i] =
+                  NDArray(stype, buf.merged.shape(), buf.merged.ctx(), true,
+                          buf.merged.dtype());
+          }
+
+          stage.merged = NDArray(stype, shape, group2[gpu1], true, type);
+          if (stage.copy_buf.empty()) {
+            stage.copy_buf.resize(group2.size());
+            for (size_t i = 0; i < group2.size(); ++i)
+              stage.copy_buf[i] =
+                  NDArray(stype, stage.merged.shape(), stage.merged.ctx(), true,
+                          stage.merged.dtype());
+          }
+        }
       }
-      ctx_info[ctx.dev_id].second += shape.Size();
     }
     inited_ = true;
   }
 
+  /// \brief the NVLinked connected gpu groups
+  std::vector<Context> group1, group2;
   std::vector<KeyAttrs> sorted_key_attrs_;
   /// \brief temporal space for pushing and pulling
   struct BufferEntry {
@@ -778,6 +1012,8 @@ class CommDevice : public Comm {
     std::vector<NDArray> compressed_recv_buf;
   };
   std::unordered_map<int, BufferEntry> merge_buf_;
+  /// \brief the small buffer for partially merged data
+  std::unordered_map<int, BufferEntry> stage_buf_;
   bool inited_;
 };
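For reference, the dense-gradient dataflow that the reworked
CommDevice::Reduce implements can be restated outside MXNet: the second group's
gradients are summed into a stage buffer on one of its GPUs, that partial sum is
copied once across the inter-group link into the last copy_buf slot of the first
group, and the first group then finishes the reduction in its merge buffer. The
sketch below only illustrates that flow; Grad and ElementwiseSum here are local
stand-ins (std::vector<double> in place of NDArray), and the device split is
assumed to be the 0-3 / 4-7 layout from the earlier sketch.

// Plain C++ illustration of the two-stage reduce (assumptions noted above).
#include <cstdio>
#include <cstddef>
#include <vector>

using Grad = std::vector<double>;  // stand-in for a dense NDArray

Grad ElementwiseSum(const std::vector<Grad>& parts) {
  Grad out(parts[0].size(), 0.0);
  for (const Grad& p : parts)
    for (std::size_t i = 0; i < p.size(); ++i) out[i] += p[i];
  return out;
}

int main() {
  constexpr int kNvlinkSupport = 4;
  // One gradient per GPU 0..7, each filled with its dev_id so the final sum
  // is easy to check by hand.
  std::vector<Grad> src;
  for (int d = 0; d < 8; ++d) src.push_back(Grad(2, d));

  std::vector<Grad> group1_parts, group2_parts;
  for (int d = 0; d < 8; ++d)
    (d < kNvlinkSupport ? group1_parts : group2_parts).push_back(src[d]);

  // Stage 1: reduce the second group locally (stage.merged in the patch).
  Grad stage_merged = ElementwiseSum(group2_parts);
  // Single inter-group transfer: hand the partial sum to the first group
  // (the CopyFromTo into the last buf.copy_buf slot in the patch).
  group1_parts.push_back(stage_merged);
  // Stage 2: final reduction into the merge buffer (buf.merged).
  Grad merged = ElementwiseSum(group1_parts);

  std::printf("merged[0] = %g (expected 0+1+...+7 = 28)\n", merged[0]);
  return 0;
}

Broadcast runs the other way: the source is copied into both the merge and
stage buffers, and each destination device then pulls from the buffer belonging
to its own group.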
 


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services