You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/09/07 02:03:27 UTC

[GitHub] threeleafzerg closed pull request #10696: [MXNET-366]Extend MXNet Distributed Training by AllReduce

threeleafzerg closed pull request #10696: [MXNET-366]Extend MXNet Distributed Training by AllReduce
URL: https://github.com/apache/incubator-mxnet/pull/10696
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/.gitmodules b/.gitmodules
index 9aeb1c75498..07b873a629c 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -6,7 +6,7 @@
 	url = https://github.com/dmlc/dmlc-core.git
 [submodule "3rdparty/ps-lite"]
 	path = 3rdparty/ps-lite
-	url = https://github.com/dmlc/ps-lite
+	url = https://github.com/dmlc/ps-lite.git
 [submodule "3rdparty/dlpack"]
 	path = 3rdparty/dlpack
 	url = https://github.com/dmlc/dlpack
diff --git a/3rdparty/ps-lite b/3rdparty/ps-lite
index 8a763892a97..f45e2e78a74 160000
--- a/3rdparty/ps-lite
+++ b/3rdparty/ps-lite
@@ -1 +1 @@
-Subproject commit 8a763892a973afc1acd3d4b469d05bb338a83a6e
+Subproject commit f45e2e78a7430be09f76264d2f4073fb2b1d54a2
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 483108a6841..296b5287125 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -26,6 +26,7 @@ mxnet_option(USE_GPERFTOOLS       "Build with GPerfTools support (if found)" ON)
 mxnet_option(USE_JEMALLOC         "Build with Jemalloc support"   ON)
 mxnet_option(USE_PROFILER         "Build with Profiler support"   ON)
 mxnet_option(USE_DIST_KVSTORE     "Build with DIST_KVSTORE support" OFF)
+mxnet_option(USE_ALLREDUCE_DIST_KVSTORE     "Build with ALLREDUCE_DIST_KVSTORE support" OFF)
 mxnet_option(USE_PLUGINS_WARPCTC  "Use WARPCTC Plugins" OFF)
 mxnet_option(USE_PLUGIN_CAFFE     "Use Caffe Plugin" OFF)
 mxnet_option(USE_CPP_PACKAGE      "Build C++ Package" OFF)
@@ -638,6 +639,24 @@ if(USE_DIST_KVSTORE)
   add_definitions(-DMXNET_USE_DIST_KVSTORE)
   include_directories(SYSTEM ${pslite_INCLUDE_DIR})
   list(APPEND mxnet_LINKER_LIBS ${pslite_LINKER_LIBS})
+  if(USE_ALLREDUCE_DIST_KVSTORE)
+    include(cmake/AllReduce.cmake)
+    set(MPI_ROOT "" CACHE PATH "MPI root directory containing include/ and lib/")
+    # Link libmpi by full path: link_directories() does not affect already-created targets.
+    find_library(mpi_LINKER_LIB NAMES mpi PATHS "${MPI_ROOT}/lib" NO_DEFAULT_PATH)
+    if("${MPI_ROOT}" STREQUAL "" OR NOT mpi_LINKER_LIB)
+      message(FATAL_ERROR "MPI_ROOT must point to an MPI install with include/ and lib/libmpi")
+    endif()
+    add_definitions(-DMXNET_USE_ALLREDUCE_DIST_KVSTORE)
+    include_directories(SYSTEM "${MPI_ROOT}/include")
+    list(APPEND mxnet_LINKER_LIBS "${mpi_LINKER_LIB}")
+    ## Generate proto files (NOTE(review): confirm SOURCE is still consumed after this point)
+    set(allreduce_src "src/kvstore/collectives/src")
+    file(GLOB_RECURSE proto_files "${allreduce_src}/*.proto")
+    allreduce_protobuf_generate_cpp(${allreduce_src} proto_srcs proto_hdrs ${allreduce_src} ${allreduce_src} ${proto_files})
+    include_directories(SYSTEM ${allreduce_src})
+    list(APPEND SOURCE ${proto_srcs})
+  endif()
 endif()
 
 target_link_libraries(mxnet PUBLIC ${mxnet_LINKER_LIBS})
diff --git a/Makefile b/Makefile
index 2cf1ed1ece2..a4a8eb2288e 100644
--- a/Makefile
+++ b/Makefile
@@ -351,12 +351,37 @@ ifeq ($(USE_DIST_KVSTORE), 1)
 	LDFLAGS += $(PS_LDFLAGS_A)
 endif
 
+# for kvstore with type dist_sync_allreduce
+PROTOBUF_DIR=$(ROOTDIR)/deps
+PROTOC=$(PROTOBUF_DIR)/bin/protoc
+COLL_PATH=$(ROOTDIR)/src/kvstore/collectives
+PROTO_GEN_FILE=
+DEF_MPI_PATH=$(ROOTDIR)/3rdparty/mpich
+ifeq ($(USE_DIST_KVSTORE), 1)
+ifeq ($(USE_ALLREDUCE_DIST_KVSTORE), 1)
+PROTO_GEN_FILE=src/kvstore/collectives/src/mpi_message.pb.cc src/kvstore/collectives/src/mpi_message.pb.h
+  # Default mpi: build MPICH unless MPI_ROOT is set or a clean goal is requested
+  ifeq ($(MPI_ROOT)$(filter clean clean_all,$(MAKECMDGOALS)),)
+    MPI_ROOT := $(shell $(ROOTDIR)/prepare_mpi.sh $(DEF_MPI_PATH))
+  endif
+ CFLAGS += -DMXNET_USE_ALLREDUCE_DIST_KVSTORE=1 -I$(MPI_ROOT)/include -I$(PROTOBUF_DIR)/include -I$(COLL_PATH)/include -I$(COLL_PATH)/src
+ LDFLAGS += -L$(MPI_ROOT)/lib -Wl,-rpath=$(MPI_ROOT)/lib -lmpi
+ LDFLAGS += $(PROTOBUF_DIR)/lib/libprotobuf.a
+endif
+endif
+
 .PHONY: clean all extra-packages test lint docs clean_all rcpplint rcppexport roxygen\
 	cython2 cython3 cython cyclean
 
 all: lib/libmxnet.a lib/libmxnet.so $(BIN) extra-packages
 
-SRC = $(wildcard src/*/*/*/*.cc src/*/*/*.cc src/*/*.cc src/*.cc)
+ALLREDUCE_SRC = $(wildcard src/kvstore/collectives/src/*.cc)
+ALLREDUCE_SRC += $(PROTO_GEN_FILE)
+ALLREDUCE_OBJ = $(patsubst %.cc, build/%.o, $(ALLREDUCE_SRC))
+
+SRC_FILTER = $(ALLREDUCE_SRC)
+ORIGSRC = $(wildcard src/*/*/*/*.cc src/*/*/*.cc src/*/*.cc src/*.cc)
+SRC =	$(filter-out $(SRC_FILTER), $(ORIGSRC))
 OBJ = $(patsubst %.cc, build/%.o, $(SRC))
 CUSRC = $(wildcard src/*/*/*/*.cu src/*/*/*.cu src/*/*.cu src/*.cu)
 CUOBJ = $(patsubst %.cu, build/%_gpu.o, $(CUSRC))
@@ -394,6 +419,7 @@ else
 	endif
 endif
 
+
 # all dep
 LIB_DEP += $(DMLC_CORE)/libdmlc.a $(NNVM_PATH)/lib/libnnvm.a
 ALL_DEP = $(OBJ) $(EXTRA_OBJ) $(PLUGIN_OBJ) $(LIB_DEP)
@@ -436,10 +462,17 @@ else
 	CFLAGS += -DMXNET_USE_LIBJPEG_TURBO=0
 endif
 
+ifeq ($(USE_DIST_KVSTORE), 1)
+ifeq ($(USE_ALLREDUCE_DIST_KVSTORE), 1)
+ ALL_DEP += $(ALLREDUCE_OBJ)
+endif
+endif
+
 # For quick compile test, used smaller subset
 ALLX_DEP= $(ALL_DEP)
 
-build/src/%.o: src/%.cc | mkldnn
+
+build/src/%.o: src/%.cc $(PROTO_GEN_FILE) | mkldnn
 	@mkdir -p $(@D)
 	$(CXX) -std=c++11 -c $(CFLAGS) -MMD -c $< -o $@
 
@@ -491,6 +524,9 @@ $(PS_PATH)/build/libps.a: PSLITE
 PSLITE:
 	$(MAKE) CXX="$(CXX)" DEPS_PATH="$(DEPS_PATH)" -C $(PS_PATH) ps
 
+# PSLITE always reruns, so keep it order-only: regenerate only when the .proto changes.
+$(PROTO_GEN_FILE): $(COLL_PATH)/src/mpi_message.proto | PSLITE
+	$(PROTOC) --cpp_out=$(COLL_PATH)/src --proto_path=$(COLL_PATH)/src $(COLL_PATH)/src/mpi_message.proto
 $(DMLC_CORE)/libdmlc.a: DMLCCORE
 
 DMLCCORE:
@@ -624,6 +660,8 @@ clean: cyclean $(EXTRA_PACKAGES_CLEAN)
 	cd $(AMALGAMATION_PATH); $(MAKE) clean; cd -
 	$(RM) -r  $(patsubst %, %/*.d, $(EXTRA_OPERATORS)) $(patsubst %, %/*/*.d, $(EXTRA_OPERATORS))
 	$(RM) -r  $(patsubst %, %/*.o, $(EXTRA_OPERATORS)) $(patsubst %, %/*/*.o, $(EXTRA_OPERATORS))
+	$(RM) $(COLL_PATH)/src/mpi_message.pb.*
+	$(RM) -r $(DEF_MPI_PATH)
 else
 clean: mkldnn_clean cyclean testclean $(EXTRA_PACKAGES_CLEAN)
 	$(RM) -r build lib bin *~ */*~ */*/*~ */*/*/*~ R-package/NAMESPACE R-package/man R-package/R/mxnet_generated.R \
@@ -632,6 +670,8 @@ clean: mkldnn_clean cyclean testclean $(EXTRA_PACKAGES_CLEAN)
 	cd $(PS_PATH); $(MAKE) clean; cd -
 	cd $(NNVM_PATH); $(MAKE) clean; cd -
 	cd $(AMALGAMATION_PATH); $(MAKE) clean; cd -
+	$(RM) $(COLL_PATH)/src/mpi_message.pb.*
+	$(RM) -r $(DEF_MPI_PATH)
 endif
 
 clean_all: clean
diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index 4f0b1464742..96b75b844b4 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -521,6 +521,20 @@ build_ubuntu_gpu_cmake() {
     report_ccache_usage
 }
 
+build_ubuntu_cpu_allreduce_kvstore() {  # CPU make build with USE_DIST_KVSTORE and USE_ALLREDUCE_DIST_KVSTORE enabled
+    set -ex
+
+    build_ccache_wrappers
+
+    make  \
+        DEV=1                         \
+        USE_BLAS=openblas             \
+        USE_DIST_KVSTORE=1            \
+        USE_ALLREDUCE_DIST_KVSTORE=1  \
+        -j$(nproc)
+
+    report_ccache_usage
+}
 
 # Testing
 
diff --git a/cmake/AllReduce.cmake b/cmake/AllReduce.cmake
new file mode 100644
index 00000000000..beb72ed4302
--- /dev/null
+++ b/cmake/AllReduce.cmake
@@ -0,0 +1,81 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Locates the Google Protocol Buffers library and the protoc compiler, and defines
+# allreduce_protobuf_generate_cpp() to compile .proto files into C++ sources.
+
+find_package(Protobuf REQUIRED)
+include_directories(SYSTEM ${PROTOBUF_INCLUDE_DIR})
+
+# find_package(Protobuf REQUIRED) does not require protoc, but this module needs it.
+if(EXISTS "${PROTOBUF_PROTOC_EXECUTABLE}")
+  message(STATUS "Found PROTOBUF Compiler: ${PROTOBUF_PROTOC_EXECUTABLE}")
+else()
+  message(FATAL_ERROR "Could not find PROTOBUF Compiler")
+endif()
+
+set(PROTOBUF_GENERATE_CPP_APPEND_PATH TRUE)
+
+################################################################################################
+# Generates C++ sources (.pb.cc/.pb.h) from .proto files with protoc at build time.
+# Usage:
+#   allreduce_protobuf_generate_cpp(<output_dir> <srcs_var> <hdrs_var> <work_path> <proto_path> <proto_files>...)
+#   output_dir : protoc --cpp_out directory (relative paths resolve against the current source dir)
+#   srcs_var / hdrs_var : names of variables receiving the generated .pb.cc / .pb.h paths
+#   work_path  : working directory of the protoc invocation
+#   proto_path : protoc --proto_path root; every proto file must be located below it
+################################################################################################
+function(allreduce_protobuf_generate_cpp output_dir srcs_var hdrs_var work_path proto_path)
+  if(NOT ARGN)
+    message(SEND_ERROR "Error: allreduce_protobuf_generate_cpp() called without any proto files")
+    return()
+  endif()
+
+  # Resolve the directory arguments to absolute paths so that protoc, OUTPUT and
+  # WORKING_DIRECTORY agree no matter whether the caller passed relative paths.
+  get_filename_component(abs_output_dir "${output_dir}" ABSOLUTE)
+  get_filename_component(abs_work_path "${work_path}" ABSOLUTE)
+  get_filename_component(abs_proto_path "${proto_path}" ABSOLUTE)
+
+  set(${srcs_var})
+  set(${hdrs_var})
+  foreach(fil ${ARGN})
+    get_filename_component(abs_fil "${fil}" ABSOLUTE)
+    get_filename_component(fil_we "${fil}" NAME_WE)
+    # protoc mirrors the file's location below --proto_path inside --cpp_out.
+    file(RELATIVE_PATH rel_fil "${abs_proto_path}" "${abs_fil}")
+    get_filename_component(rel_dir "${rel_fil}" PATH)
+    if("${rel_dir}" STREQUAL "")
+      set(gen_base "${abs_output_dir}/${fil_we}")
+    else()
+      set(gen_base "${abs_output_dir}/${rel_dir}/${fil_we}")
+    endif()
+    list(APPEND ${srcs_var} "${gen_base}.pb.cc")
+    list(APPEND ${hdrs_var} "${gen_base}.pb.h")
+    add_custom_command(
+      OUTPUT "${gen_base}.pb.cc" "${gen_base}.pb.h"
+      COMMAND ${CMAKE_COMMAND} -E make_directory "${abs_output_dir}"
+      COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} --cpp_out "${abs_output_dir}" --proto_path "${abs_proto_path}" "${abs_fil}"
+      DEPENDS "${abs_fil}"
+      WORKING_DIRECTORY "${abs_work_path}"
+      COMMENT "Running C++ protocol buffer compiler on ${rel_fil}" VERBATIM)
+  endforeach()
+
+  set_source_files_properties(${${srcs_var}} ${${hdrs_var}} PROPERTIES GENERATED TRUE)
+  set(${srcs_var} ${${srcs_var}} PARENT_SCOPE)
+  set(${hdrs_var} ${${hdrs_var}} PARENT_SCOPE)
+endfunction()
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 75147cfd706..e4b09538e5d 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1976,6 +1976,82 @@ MXNET_DLL int MXKVStorePullEx(KVStoreHandle handle,
                               NDArrayHandle* vals,
                               int priority);
 
+/*!
+ * \brief aggregate and sum up a list of (key, value) pairs from all nodes; the result is
+ *        stored in out_vals. It has the same semantics as allreduce. Note: currently only
+ *        kvstore with type 'dist_sync_allreduce' supports this api.
+ * \param handle handle to the kvstore
+ * \param num the number of key-value pairs
+ * \param keys the list of keys
+ * \param in_vals the list of values to be aggregated
+ * \param out_vals the list of values to store the result
+ * \param priority the priority of the action
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXKVStorePushPull(KVStoreHandle handle,
+                                mx_uint num,
+                                const int *keys,
+                                NDArrayHandle *in_vals,
+                                NDArrayHandle *out_vals,
+                                int priority);
+
+/*!
+ * \brief aggregate and sum up a list of (key, value) pairs from all nodes; the result is
+ *        stored in out_vals. It has the same semantics as allreduce. Note: currently only
+ *        kvstore with type 'dist_sync_allreduce' supports this api.
+ * \param handle handle to the kvstore
+ * \param num the number of key-value pairs
+ * \param keys the list of keys as strings.
+ * \param in_vals the list of values to be aggregated
+ * \param out_vals the list of values to store the result
+ * \param priority the priority of the action
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXKVStorePushPullEx(KVStoreHandle handle,
+                                  mx_uint num,
+                                  const char **keys,
+                                  NDArrayHandle *in_vals,
+                                  NDArrayHandle *out_vals,
+                                  int priority);
+
+/*!
+ * \brief broadcast a list of (key, value) pairs from root_rank to all other nodes. Note:
+ *        currently only kvstore with type 'dist_sync_allreduce' supports this api.
+ * \param handle handle to the kvstore
+ * \param num the number of key-value pairs
+ * \param keys the list of keys.
+ * \param vals on the root_rank node, the list of values to be broadcast;
+ *             on other nodes, the list of buffers that store the result.
+ * \param root_rank the rank whose values will be broadcast.
+ * \param priority the priority of the action
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXKVStoreBroadcast(KVStoreHandle handle,
+                                 mx_uint num,
+                                 const int *keys,
+                                 NDArrayHandle *vals,
+                                 int root_rank,
+                                 int priority);
+
+/*!
+ * \brief broadcast a list of (key, value) pairs from root_rank to all other nodes. Note:
+ *        currently only kvstore with type 'dist_sync_allreduce' supports this api.
+ * \param handle handle to the kvstore
+ * \param num the number of key-value pairs
+ * \param keys the list of keys as strings.
+ * \param vals on the root_rank node, the list of values to be broadcast;
+ *             on other nodes, the list of buffers that store the result.
+ * \param root_rank the rank whose values will be broadcast.
+ * \param priority the priority of the action
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXKVStoreBroadcastEx(KVStoreHandle handle,
+                                   mx_uint num,
+                                   const char **keys,
+                                   NDArrayHandle *vals,
+                                   int root_rank,
+                                   int priority);
+
 /*!
  * \brief pull a list of (key, value) pairs from the kvstore, where each key is an integer.
  *        The NDArray pulled back will be in row_sparse storage with only the specified
diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h
index e10bd213aa2..9ec2853a1e7 100644
--- a/include/mxnet/kvstore.h
+++ b/include/mxnet/kvstore.h
@@ -186,6 +186,68 @@ class KVStore {
                     const std::vector<NDArray*>& values,
                     int priority = 0, bool ignore_sparse = true) = 0;
 
+  /*!
+   * \brief push and pull a list of key-value pairs across all the nodes.
+   *        The values from all the nodes are aggregated, with the same
+   *        semantics as allreduce. The default implementation hits LOG(FATAL).
+   * \param keys the list of keys
+   * \param in_values the list of buffers to be allreduced
+   * \param out_values the list of buffers to store the result
+   * \param priority Priority of the action
+   */
+  virtual void PushPull(const std::vector<int> &keys,
+                        const std::vector<NDArray> &in_values,
+                        const std::vector<NDArray*> &out_values,
+                        int priority = 0) {
+    LOG(FATAL) << "The api is not supported in kvstore with type " << type_;
+  }
+
+  /*!
+   * \brief push and pull a list of key-value pairs across all the nodes.
+   *        The values from all the nodes are aggregated, with the same
+   *        semantics as allreduce. The default implementation hits LOG(FATAL).
+   * \param str_keys the list of keys in string format
+   * \param in_values the list of buffers to be allreduced
+   * \param out_values the list of buffers to store the result
+   * \param priority Priority of the action
+   */
+  virtual void PushPull(const std::vector<std::string> &str_keys,
+                        const std::vector<NDArray> &in_values,
+                        const std::vector<NDArray*> &out_values,
+                        int priority = 0) {
+    LOG(FATAL) << "The api is not supported in kvstore with type " << type_;
+  }
+
+  /*!
+   * \brief broadcast a list of key-value pairs from the root_rank node to all other nodes
+   * \param keys the list of keys
+   * \param values on the root_rank node, the buffers to be broadcast; on other nodes,
+   *        the buffers that store the result
+   * \param root_rank the rank of the node whose data will be broadcast
+   * \param priority Priority of the action
+   */
+  virtual void Broadcast(const std::vector<int> &keys,
+                         const std::vector<NDArray*> &values,
+                         int root_rank,
+                         int priority = 0) {
+    LOG(FATAL) << "The api is not supported in kvstore with type " << type_;
+  }
+
+  /*!
+   * \brief broadcast a list of key-value pairs from the root_rank node to all other nodes
+   * \param str_keys the list of keys in string format
+   * \param values on the root_rank node, the buffers to be broadcast; on other nodes,
+   *        the buffers that store the result
+   * \param root_rank the rank of the node whose data will be broadcast
+   * \param priority Priority of the action
+   */
+  virtual void Broadcast(const std::vector<std::string> &str_keys,
+                         const std::vector<NDArray*> &values,
+                         int root_rank,
+                         int priority = 0) {
+    LOG(FATAL) << "The api is not supported in kvstore with type " << type_;
+  }
+
   /*!
    * \brief pull a list of key-value pairs from the store.
    *        The NDArray pulled back will be in row_sparse storage with only the
diff --git a/make/config.mk b/make/config.mk
index b65f77c605f..2d4af331669 100644
--- a/make/config.mk
+++ b/make/config.mk
@@ -158,6 +158,14 @@ USE_F16C =
 # whether or not to enable multi-machine supporting
 USE_DIST_KVSTORE = 0
 
+# whether or not to enable kvstore with type dist_sync_allreduce (needs USE_DIST_KVSTORE = 1)
+USE_ALLREDUCE_DIST_KVSTORE = 0
+
+# MPI library root directory; the allreduce kvstore uses $(MPI_ROOT)/include
+# and $(MPI_ROOT)/lib. If left empty, the Makefile downloads and builds a
+# default MPI (MPICH) via prepare_mpi.sh.
+MPI_ROOT =
+
 # whether or not allow to read and write HDFS directly. If yes, then hadoop is
 # required
 USE_HDFS = 0
diff --git a/prepare_mpi.sh b/prepare_mpi.sh
new file mode 100755
index 00000000000..53e8fd5cf72
--- /dev/null
+++ b/prepare_mpi.sh
@@ -0,0 +1,59 @@
+#!/bin/bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Downloads and builds MPICH under the directory given as $1 (unless it is already built
+# there) and prints the install prefix on stdout. All progress output goes to stderr
+# because the Makefile captures stdout as MPI_ROOT.
+HOME_MPI_DIR=$1
+if [ -z "$HOME_MPI_DIR" ]; then
+    echo "Usage: $0 <mpi_dir>" >&2
+    exit 1
+fi
+mkdir -p "$HOME_MPI_DIR" || exit 1
+HOME_MPI_DIR=$(cd "$HOME_MPI_DIR" && pwd)  # configure --prefix needs an absolute path
+
+# Default MPI Vars
+DEF_MPI_URL=http://www.mpich.org/static/downloads/3.2.1/mpich-3.2.1.tar.gz
+DEF_MPI_TAR=mpich-3.2.1.tar.gz
+DEF_MPI_DIR=mpich-3.2.1
+DEF_MPI_BUILD=$HOME_MPI_DIR/build
+DEF_MPI_LIB=$DEF_MPI_BUILD/lib/libmpi.so
+
+if [ -e "$DEF_MPI_LIB" ]; then
+    echo "${DEF_MPI_BUILD}"
+    exit 0
+fi
+
+# Abort on the first failure so that no path is printed for a broken install.
+set -e
+mkdir -p "$DEF_MPI_BUILD"
+echo "Downloading mpi ..." >&2
+cd "$HOME_MPI_DIR"
+wget -O "$DEF_MPI_TAR" "$DEF_MPI_URL" >&2
+tar xvf "$DEF_MPI_TAR" >&2
+
+echo "Configure & Build & Install mpi ..." >&2
+cd "$HOME_MPI_DIR/$DEF_MPI_DIR"
+./configure --prefix="$DEF_MPI_BUILD" >&2
+make -j >&2
+make install >&2
+
+### Return MPI_ROOT
+echo "${DEF_MPI_BUILD}"
+
diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index 09ad96314d5..a58c793eccb 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -140,7 +140,10 @@ def _init_params(self):
                     idx = self._param2idx[param.name]
                     self._kvstore.init(idx, param_arrays[0])
                     if param._stype == 'default':
-                        self._kvstore.pull(idx, param_arrays, priority=-idx)
+                        if 'allreduce' not in self._kvstore.type:
+                            self._kvstore.pull(idx, param_arrays, priority=-idx)
+                        else:
+                            self._kvstore.broadcast(idx, param_arrays, 0, priority=-idx)
 
         self._params_to_init = params_to_init
 
@@ -289,11 +292,13 @@ def _allreduce_grads(self):
         if self._kvstore:
             for i, param in enumerate(self._params):
                 if param.grad_req != 'null':
+                    if 'allreduce' not in self._kvstore.type:
+                        self._kvstore.push(i, param.list_grad(), priority=-i)
+                        if not self._update_on_kvstore:
+                            self._kvstore.pull(i, param.list_grad(), priority=-i, ignore_sparse=False)
+                    else:
+                        self._kvstore.pushpull(i, param.list_grad(), param.list_grad(), priority=-i)
 
-                    self._kvstore.push(i, param.list_grad(), priority=-i)
-
-                    if not self._update_on_kvstore:
-                        self._kvstore.pull(i, param.list_grad(), priority=-i, ignore_sparse=False)
 
     def update(self, batch_size, ignore_stale_grad=False):
         """Makes one step of parameter update.
diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py
index 60973365975..7bfcbb892b8 100644
--- a/python/mxnet/kvstore.py
+++ b/python/mxnet/kvstore.py
@@ -166,6 +166,9 @@ def push(self, key, value, priority=0):
         There is no synchronization between workers.
         One can use ``_barrier()`` to sync all workers.
 
+        Note: This api is not supported for allreduce kvstore.
+        Use :py:meth:`pushpull` instead.
+
         Parameters
         ----------
         key : str, int, or sequence of str or int
@@ -226,14 +229,17 @@ def push(self, key, value, priority=0):
         >>> print b
         <RowSparseNDArray 2x3 @cpu(0)>
         """
-        ckeys, cvals, use_str_keys = _ctype_key_value(key, value)
-        if use_str_keys:
-            check_call(_LIB.MXKVStorePushEx(
-                self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority)))
+        if 'allreduce' not in self.type: # pylint: disable=unsupported-membership-test
+            ckeys, cvals, use_str_keys = _ctype_key_value(key, value)
+            if use_str_keys:
+                check_call(_LIB.MXKVStorePushEx(
+                    self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority)))
+            else:
+                check_call(_LIB.MXKVStorePush(
+                    self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority)))
         else:
-            check_call(_LIB.MXKVStorePush(
-                self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority)))
-
+            raise Exception("This api is not supported for kvstore with type %s. \
+                             Please use pushpull instead."%self.type)
 
     def pull(self, key, out=None, priority=0, ignore_sparse=True):
         """ Pulls a single value or a sequence of values from the store.
@@ -250,6 +256,9 @@ def pull(self, key, out=None, priority=0, ignore_sparse=True):
         pull with `RowSparseNDArray` is not supported for dist kvstore.
         Please use ``row_sparse_pull`` instead.
 
+        Note: This api is not supported for allreduce kvstore.
+        Use :py:meth:`pushpull` instead.
+
         Parameters
         ----------
         key : str, int, or sequence of str or int
@@ -299,15 +308,125 @@ def pull(self, key, out=None, priority=0, ignore_sparse=True):
         [ 2.  2.  2.]]
         """
         assert(out is not None)
-        ckeys, cvals, use_str_keys = _ctype_key_value(key, out)
-        if use_str_keys:
-            check_call(_LIB.MXKVStorePullWithSparseEx(self.handle, mx_uint(len(ckeys)), ckeys,
-                                                      cvals, ctypes.c_int(priority),
-                                                      ctypes.c_bool(ignore_sparse)))
+        if 'allreduce' not in self.type: # pylint: disable=unsupported-membership-test
+            ckeys, cvals, use_str_keys = _ctype_key_value(key, out)
+            if use_str_keys:
+                check_call(_LIB.MXKVStorePullWithSparseEx(self.handle, mx_uint(len(ckeys)), ckeys,
+                                                          cvals, ctypes.c_int(priority),
+                                                          ctypes.c_bool(ignore_sparse)))
+            else:
+                check_call(_LIB.MXKVStorePullWithSparse(self.handle, mx_uint(len(ckeys)), ckeys,
+                                                        cvals, ctypes.c_int(priority),
+                                                        ctypes.c_bool(ignore_sparse)))
+        else:
+            raise Exception("This api is not supported for kvstore with type %s. \
+                             Please use pushpull instead."%self.type)
+
+    def pushpull(self, keys, ins, outs, priority=0):
+        """ Allreduces a single or a sequence of key-value pairs across all nodes.
+
+        This function returns immediately after sending an allreduce request to the mpi
+        background thread. The rank 0 node collects allreduce request info from all nodes and
+        ensures the allreduce execution order is the same on every node.
+        Note: This api is only supported for allreduce kvstore.
+
+        Parameters
+        ----------
+        keys : str, int, or sequence of str or int
+              Keys.
+
+        ins : NDArray, or list of list of NDArray
+              Values corresponding to the keys to be allreduced.
+
+        outs : NDArray, or list of list of NDArray.
+               Values corresponding to the keys to store the result.
+
+        priority: int, optional
+            The priority of the pushpull operation.
+            Higher priority operations are likely to be executed before
+            other pushpull actions.
+
+        Examples
+        --------
+        >>> # allreduce a single key-value pair on 2 nodes
+        >>> shape = (2, 3)
+        >>> in_ = mx.nd.ones(shape)
+        >>> out_ = mx.nd.zeros(shape)
+        >>> kv.pushpull('key', in_, out_, 0)
+        >>> print out_.asnumpy()
+        [[ 2.  2.  2.]
+         [ 2.  2.  2.]]
+        >>> # allreduce a list of key-value pairs
+        >>> keys = ['5', '7', '9']
+        >>> in_ = [mx.nd.ones(shape) for _ in keys]
+        >>> out_ = [mx.nd.zeros(shape) for _ in keys]
+        >>> kv.pushpull(keys, in_, out_)
+        >>> print out_[1].asnumpy()
+        [[ 2.  2.  2.]
+         [ 2.  2.  2.]]
+        """
+        if 'allreduce' in self.type: # pylint: disable=unsupported-membership-test
+            ckeys, cinvals, use_str_keys = _ctype_key_value(keys, ins)
+            ckeys, coutvals, use_str_keys = _ctype_key_value(keys, outs)
+            if use_str_keys:
+                check_call(_LIB.MXKVStorePushPullEx(
+                    self.handle, mx_uint(len(ckeys)), ckeys, cinvals,
+                    coutvals, ctypes.c_int(priority)))
+            else:
+                check_call(_LIB.MXKVStorePushPull(
+                    self.handle, mx_uint(len(ckeys)), ckeys, cinvals,
+                    coutvals, ctypes.c_int(priority)))
         else:
-            check_call(_LIB.MXKVStorePullWithSparse(self.handle, mx_uint(len(ckeys)), ckeys,
-                                                    cvals, ctypes.c_int(priority),
-                                                    ctypes.c_bool(ignore_sparse)))
+            raise Exception("This api is not supported for kvstore with type %s. "
+                            "Please use push and pull instead." % self.type)
+
+    def broadcast(self, keys, values, root_rank, priority=0):
+        """ Broadcasts a single or a sequence of key-value pairs from root_rank to all other nodes.
+
+        This function returns immediately after sending a broadcast request to the mpi background
+        thread, which invokes MPI_Bcast on every node.
+        Note: This api is only supported for allreduce kvstore.
+
+        Parameters
+        ----------
+        keys : str, int, or sequence of str or int
+              Keys.
+
+        values : NDArray, or list of list of NDArray
+                 Values to broadcast on root_rank; buffers that receive the result elsewhere.
+
+        root_rank : int
+            The rank of the node whose values are broadcast.
+
+        priority: int, optional
+            The priority of the broadcast operation.
+            Higher priority operations are likely to be executed before
+            other broadcast actions.
+
+        Examples
+        --------
+        >>> shape = (2, 3)
+        >>> if kv.rank == 0:
+        ...     value = mx.nd.ones(shape)
+        ... else:
+        ...     value = mx.nd.zeros(shape)
+        >>> kv.broadcast('key', value, 0)
+        >>> print value.asnumpy()
+        [[ 1.  1.  1.]
+         [ 1.  1.  1.]]
+        """
+        if 'allreduce' in self.type: # pylint: disable=unsupported-membership-test
+            ckeys, cinvals, use_str_keys = _ctype_key_value(keys, values)
+            if use_str_keys:
+                check_call(_LIB.MXKVStoreBroadcastEx(
+                    self.handle, mx_uint(len(ckeys)), ckeys, cinvals,
+                    ctypes.c_int(root_rank), ctypes.c_int(priority)))
+            else:
+                check_call(_LIB.MXKVStoreBroadcast(
+                    self.handle, mx_uint(len(ckeys)), ckeys, cinvals,
+                    ctypes.c_int(root_rank), ctypes.c_int(priority)))
+        else:
+            raise Exception("This api is not supported for kvstore with type %s"%self.type)
 
     def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
         """ Pulls a single RowSparseNDArray value or a sequence of RowSparseNDArray values \
@@ -320,6 +439,8 @@ def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
 
         The returned values are guaranteed to be the latest values in the store.
 
+        Note: This api is not supported for allreduce kvstore
+
         Parameters
         ----------
         key : str, int, or sequence of str or int
@@ -363,6 +484,8 @@ def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
         """
         assert(out is not None)
         assert(row_ids is not None)
+        if 'allreduce' in self.type: # pylint: disable=unsupported-membership-test
+            raise Exception("This api is not supported for kvstore with type %s"%self.type)
         if isinstance(row_ids, NDArray):
             row_ids = [row_ids]
         assert(isinstance(row_ids, list)), \
@@ -437,7 +560,7 @@ def set_gradient_compression(self, compression_params):
             Other keys in this dictionary are optional and specific to the type
             of gradient compression.
         """
-        if ('device' in self.type) or ('dist' in self.type): # pylint: disable=unsupported-membership-test
+        if (('device' in self.type) or ('dist' in self.type)) and ('allreduce' not in self.type): # pylint: disable=unsupported-membership-test
             ckeys, cvals = _ctype_dict(compression_params)
             check_call(_LIB.MXKVStoreSetGradientCompression(self.handle,
                                                             mx_uint(len(compression_params)),
@@ -452,6 +575,8 @@ def set_optimizer(self, optimizer):
         If using multiple machines and this operation is invoked from a worker node,
         it will serialized the optimizer with pickle and send it to all servers.
         The function returns after all servers have been updated.
+        With an allreduce kvstore, this API only updates the local optimizer,
+        the same as on a single machine.
 
         Parameters
         ----------
@@ -479,7 +604,7 @@ def set_optimizer(self, optimizer):
         check_call(_LIB.MXKVStoreIsWorkerNode(ctypes.byref(is_worker)))
 
         # pylint: disable=invalid-name
-        if 'dist' in self.type and is_worker.value: # pylint: disable=unsupported-membership-test
+        if ('dist' in self.type) and ('allreduce' not in self.type) and is_worker.value: # pylint: disable=unsupported-membership-test
             # send the optimizer to server
             try:
                 # use ASCII protocol 0, might be slower, but not a big ideal
@@ -627,8 +752,11 @@ def _send_command_to_servers(self, head, body):
         body : str
             the body of the command.
         """
-        check_call(_LIB.MXKVStoreSendCommmandToServers(
-            self.handle, mx_uint(head), c_str(body)))
+        if 'allreduce' in self.type: # pylint: disable=unsupported-membership-test
+            raise Exception("This api is not supported for kvstore with type %s"%self.type)
+        else:
+            check_call(_LIB.MXKVStoreSendCommmandToServers(
+                self.handle, mx_uint(head), c_str(body)))
 
 def create(name='local'):
     """Creates a new KVStore.
@@ -656,9 +784,17 @@ def create(name='local'):
     No two updates happen on the same weight at the same time. However, the order is not
     guaranteed.
 
+    ``dist_sync_allreduce``: Behaves similarly to dist_sync but with some major differences.
+    With ``dist_sync_allreduce``, no parameter server is configured, and push and pull
+    are replaced with pushpull.
+
+    ``dist_device_sync_allreduce``: Behaves the same as dist_sync_allreduce, but performs a
+    two-level allreduce: first across devices in a single node, then across machines.
+
     Parameters
     ----------
-    name : {'local', 'device', 'nccl', 'dist_sync', 'dist_device_sync', 'dist_async'}
+    name : {'local', 'device', 'nccl', 'dist_sync', 'dist_device_sync', 'dist_async',
+            'dist_sync_allreduce', 'dist_device_sync_allreduce'}
         The type of KVStore.
     Returns
     -------
diff --git a/python/mxnet/model.py b/python/mxnet/model.py
index 3a50553a615..fd9812d733b 100644
--- a/python/mxnet/model.py
+++ b/python/mxnet/model.py
@@ -108,7 +108,7 @@ def _create_kvstore(kvstore, num_device, arg_params):
     else:
         raise TypeError('kvstore must be KVStore, str or None')
 
-    if kv is None:
+    if (kv is None) or ('allreduce' in kv.type):
         update_on_kvstore = False
 
     return (kv, update_on_kvstore)
@@ -118,6 +118,8 @@ def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names, update_o
     for idx, param_on_devs in enumerate(param_arrays):
         name = param_names[idx]
         kvstore.init(name, arg_params[name])
+        if 'allreduce' in kvstore.type:
+            kvstore.broadcast(name, param_on_devs, 0, priority=-idx)
 
         if update_on_kvstore:
             kvstore.pull(name, param_on_devs, priority=-idx)
@@ -164,10 +166,13 @@ def _update_params(param_arrays, grad_arrays, updater, num_device,
         index = i
         if kvstore:
             name = param_names[index]
-            # push gradient, priority is negative index
-            kvstore.push(name, grad_list, priority=-index)
-            # pull back the sum gradients, to the same locations.
-            kvstore.pull(name, grad_list, priority=-index)
+            if 'allreduce' not in kvstore.type:
+                # push gradient, priority is negative index
+                kvstore.push(name, grad_list, priority=-index)
+                # pull back the sum gradients, to the same locations.
+                kvstore.pull(name, grad_list, priority=-index)
+            else:
+                kvstore.pushpull(name, grad_list, grad_list, priority=-index)
         for k, p in enumerate(zip(arg_list, grad_list)):
             # faked an index here, to make optimizer create diff
             # state for the same index but on diff devs, TODO(mli)
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index efa7301d7ab..30dc2ec5969 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -923,6 +923,78 @@ int MXKVStorePullWithSparseEx(KVStoreHandle handle,
   API_END();
 }
 
+int MXKVStorePushPull(KVStoreHandle handle,
+                      mx_uint num,
+                      const int *keys,
+                      NDArrayHandle *in_vals,
+                      NDArrayHandle *out_vals,
+                      int priority) {
+  API_BEGIN();
+  // Marshal the C arrays: in-values are held by value, out-values by pointer.
+  std::vector<int> v_keys(keys, keys + num);
+  std::vector<NDArray> v_invals;
+  std::vector<NDArray*> v_outvals;
+  for (mx_uint idx = 0; idx < num; ++idx) {
+    v_invals.push_back(*static_cast<NDArray*>(in_vals[idx]));
+    v_outvals.push_back(static_cast<NDArray*>(out_vals[idx]));
+  }
+  static_cast<KVStore*>(handle)->PushPull(v_keys, v_invals, v_outvals, priority);
+  API_END();
+}
+
+int MXKVStorePushPullEx(KVStoreHandle handle,
+                        mx_uint num,
+                        const char **keys,
+                        NDArrayHandle *in_vals,
+                        NDArrayHandle *out_vals,
+                        int priority) {
+  API_BEGIN();
+  // String keys are copied into std::string; in-values by value, out-values by pointer.
+  std::vector<std::string> v_keys(keys, keys + num);
+  std::vector<NDArray> v_invals;
+  std::vector<NDArray*> v_outvals;
+  for (mx_uint idx = 0; idx < num; ++idx) {
+    v_invals.push_back(*static_cast<NDArray*>(in_vals[idx]));
+    v_outvals.push_back(static_cast<NDArray*>(out_vals[idx]));
+  }
+  static_cast<KVStore*>(handle)->PushPull(v_keys, v_invals, v_outvals, priority);
+  API_END();
+}
+
+int MXKVStoreBroadcast(KVStoreHandle handle,
+                       mx_uint num,
+                       const int *keys,
+                       NDArrayHandle *vals,
+                       int root_rank,
+                       int priority) {
+  API_BEGIN();
+  std::vector<int> v_keys(keys, keys + num);
+  std::vector<NDArray*> v_vals;
+  for (mx_uint idx = 0; idx < num; ++idx) {
+    v_vals.push_back(static_cast<NDArray*>(vals[idx]));
+  }
+  // Passed as pointers so Broadcast can update the caller's NDArrays in place.
+  static_cast<KVStore*>(handle)->Broadcast(v_keys, v_vals, root_rank, priority);
+  API_END();
+}
+
+int MXKVStoreBroadcastEx(KVStoreHandle handle,
+                         mx_uint num,
+                         const char **keys,
+                         NDArrayHandle *vals,
+                         int root_rank,
+                         int priority) {
+  API_BEGIN();
+  std::vector<std::string> v_keys(keys, keys + num);
+  std::vector<NDArray*> v_vals;
+  for (mx_uint idx = 0; idx < num; ++idx) {
+    v_vals.push_back(static_cast<NDArray*>(vals[idx]));
+  }
+  // String-key variant; values stay pointers so Broadcast can write into them.
+  static_cast<KVStore*>(handle)->Broadcast(v_keys, v_vals, root_rank, priority);
+  API_END();
+}
+
 int MXKVStorePullRowSparse(KVStoreHandle handle,
                            mx_uint num,
                            const int* keys,
diff --git a/src/kvstore/collectives/include/coll_util.h b/src/kvstore/collectives/include/coll_util.h
new file mode 100644
index 00000000000..23bd40c12f9
--- /dev/null
+++ b/src/kvstore/collectives/include/coll_util.h
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/**
+ * Copyright (c) 2018 by Contributors
+ */
+
+#ifndef MXNET_KVSTORE_COLLECTIVES_INCLUDE_COLL_UTIL_H_
+#define MXNET_KVSTORE_COLLECTIVES_INCLUDE_COLL_UTIL_H_
+
+#if MXNET_USE_ALLREDUCE_DIST_KVSTORE
+
+
+#include <stdio.h>
+#include <vector>
+
+#define COLL_UTIL_DEBUG_ON 0  // set to 1 to printf MXCOLL_DEBUG traces to stdout
+
+#if COLL_UTIL_DEBUG_ON
+#define MXCOLL_DEBUG(rank, fmt, args...)  printf("rank[%d]:" fmt, rank, ## args)
+#else
+// Same (rank, fmt, args...) signature as the enabled variant, so call sites agree.
+#define MXCOLL_DEBUG(rank, fmt, args...)
+#endif
+
+#endif  // MXNET_USE_ALLREDUCE_DIST_KVSTORE
+#endif  // MXNET_KVSTORE_COLLECTIVES_INCLUDE_COLL_UTIL_H_
diff --git a/src/kvstore/collectives/include/coll_wrapper.h b/src/kvstore/collectives/include/coll_wrapper.h
new file mode 100644
index 00000000000..8f73964b42f
--- /dev/null
+++ b/src/kvstore/collectives/include/coll_wrapper.h
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/**
+ * Copyright (c) 2018 by Contributors
+ */
+
+#ifndef MXNET_KVSTORE_COLLECTIVES_INCLUDE_COLL_WRAPPER_H_
+#define MXNET_KVSTORE_COLLECTIVES_INCLUDE_COLL_WRAPPER_H_
+
+#if MXNET_USE_ALLREDUCE_DIST_KVSTORE
+
+#include <mpi.h>
+
+#include "mxnet/ndarray.h"
+#include "mxnet/base.h"
+#include "mpi_message.pb.h"
+
+template<typename DType>
+MPI_Datatype MPI_Data_Type_Cast(void);  // only the int/float/double specializations exist
+// Full specializations are not implicitly inline; `inline` keeps multi-TU includes linkable.
+template<>
+inline MPI_Datatype MPI_Data_Type_Cast<int>(void) {
+  return MPI_INT;
+}
+
+template<>
+inline MPI_Datatype MPI_Data_Type_Cast<float>(void) {
+  return MPI_FLOAT;
+}
+
+template<>
+inline MPI_Datatype MPI_Data_Type_Cast<double>(void) {
+  return MPI_DOUBLE;
+}
+
+template <class xpu, class DType>
+struct COLL_Wrapper {  // fallback for unspecialized devices: a no-op returning 0
+  static int Broadcast(mxnet::NDArray *input_array, int root_rank) {
+    return 0;
+  }
+
+  static int AllReduce(mxnet::NDArray *input_array, mxnet::NDArray *output_array) {
+    return 0;
+  }
+};
+
+// CPU Implementation: MPI collectives over MPI_COMM_WORLD on the ndarrays' host buffers.
+template <class DType>
+struct COLL_Wrapper<mxnet::cpu, DType> {
+  // Copies root_rank's input_array data into input_array on all other ranks (MPI_Bcast).
+  static int Broadcast(mxnet::NDArray *input_array,
+                       int root_rank) {
+    DType *buf = input_array->data().dptr<DType>();
+    unsigned int count = input_array->data().Size();
+    return MPI_Bcast(buf, count, MPI_Data_Type_Cast<DType>(), root_rank, MPI_COMM_WORLD);
+  }
+
+  // Element-wise MPI_SUM across ranks into output_array; returns the MPI error code.
+  static int AllReduce(mxnet::NDArray *input_array,
+                       mxnet::NDArray *output_array) {
+    // CHECK_EQ, unlike assert, is still enforced in NDEBUG (release) builds.
+    CHECK_EQ(input_array->data().Size(), output_array->data().Size())
+      << "AllReduce input and output ndarrays must have the same number of elements";
+    DType *send_buf = input_array->data().dptr<DType>();
+    DType *recv_buf = output_array->data().dptr<DType>();
+    unsigned int count = input_array->data().Size();
+
+    // MPI forbids aliased send/recv buffers; a same-buffer reduction must use MPI_IN_PLACE.
+    if (send_buf != recv_buf) {
+      return MPI_Allreduce(send_buf, recv_buf,
+                           count, MPI_Data_Type_Cast<DType>(), MPI_SUM, MPI_COMM_WORLD);
+    }
+    return MPI_Allreduce(MPI_IN_PLACE, recv_buf,
+                         count, MPI_Data_Type_Cast<DType>(), MPI_SUM, MPI_COMM_WORLD);
+  }
+};
+
+// GPU Implementation: not available yet; both collectives abort through LOG(FATAL).
+template <class DType>
+struct COLL_Wrapper<mxnet::gpu, DType> {
+  static int Broadcast(mxnet::NDArray *input_array, int root_rank) {
+    return NotImplemented();  // TODO(zhouhaiy): implement gpu broadcast
+  }
+
+  static int AllReduce(mxnet::NDArray *input_array, mxnet::NDArray *output_array) {
+    return NotImplemented();  // TODO(zhouhaiy): implement gpu all reduce
+  }
+
+ private:
+  static int NotImplemented() {
+    LOG(FATAL) << "Collective For GPU version has not been implemented.";
+    return -1;
+  }
+};
+
+#endif
+#endif  // MXNET_KVSTORE_COLLECTIVES_INCLUDE_COLL_WRAPPER_H_
diff --git a/src/kvstore/collectives/include/collectives.h b/src/kvstore/collectives/include/collectives.h
new file mode 100644
index 00000000000..0370a6754f5
--- /dev/null
+++ b/src/kvstore/collectives/include/collectives.h
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/**
+ * Copyright (c) 2018 by Contributors
+ */
+
+#ifndef MXNET_KVSTORE_COLLECTIVES_INCLUDE_COLLECTIVES_H_
+#define MXNET_KVSTORE_COLLECTIVES_INCLUDE_COLLECTIVES_H_
+
+#if MXNET_USE_ALLREDUCE_DIST_KVSTORE
+
+#include <mxnet/ndarray.h>
+
+#include <vector>
+#include <string>
+
+#include "../../comm.h"
+
+namespace mxnet {
+namespace kvstore {
+
+/*!
+ * \brief Get the total number of ranks (presumably the MPI_COMM_WORLD size).
+ * \param ret out param receiving the number of ranks.
+ * \return 0 on success, -1 on failure
+ */
+int MXGetMpiSize(int *ret);
+
+/*!
+ * \brief Get the rank of this process (presumably its MPI_COMM_WORLD rank).
+ * \param ret out param receiving the rank.
+ * \return 0 on success, -1 on failure
+ */
+int MXGetMpiRank(int *ret);
+
+/*!
+ * \brief Initialize the collective library.
+ * \param comm Comm used for reduce and broadcast across devices
+ *        within a single node.
+ * \return 0 on success, -1 on failure
+ */
+int MXCOLLIBInit(Comm *comm);
+
+/*!
+ * \brief Get the rank of this process within its node.
+ * \param ret out param receiving the local rank.
+ * \return 0 on success, -1 on failure
+ */
+int MXGetLocalRank(int *ret);
+
+/*!
+ * \brief Allreduce send_value across all nodes, writing the result to recv_value.
+ * \param key the key identifying the ndarray.
+ * \param send_value value to be sent.
+ * \param recv_value value that receives the result (may be the same ndarray).
+ * \param priority the priority of the action.
+ * \return 0 on success, -1 on failure
+ */
+int MXAllReduce(int key,
+                mxnet::NDArray* send_value,
+                mxnet::NDArray* recv_value,
+                int priority);
+
+/*!
+ * \brief Broadcast the value held by root_rank to all other ranks.
+ * \param key the key for the ndarray.
+ * \param value the value to be broadcast (overwritten on non-root ranks).
+ * \param root_rank the rank whose value is broadcast.
+ * \param priority the priority of the action.
+ * \return 0 on success, -1 on failure
+ */
+int MXBroadcast(int key,
+                mxnet::NDArray* value,
+                int root_rank,
+                int priority);
+
+/*!
+ * \brief Gather the values from all nodes onto every node.
+ * \param key the key for the value.
+ * \param value the value to be gathered.
+ * \param priority the priority of the action.
+ * \return 0 on success, -1 on failure
+ */
+int MXAllGather(int key,
+                mxnet::NDArray* value,
+                int priority);
+
+/*!
+ * \brief Block until all ranks have reached this routine.
+ * \return 0 on success, -1 on failure
+ */
+int MXBarrier();
+
+}  // namespace kvstore
+}  // namespace mxnet
+#endif
+#endif  // MXNET_KVSTORE_COLLECTIVES_INCLUDE_COLLECTIVES_H_
diff --git a/src/kvstore/collectives/src/collectives.cc b/src/kvstore/collectives/src/collectives.cc
new file mode 100644
index 00000000000..0afefddcadd
--- /dev/null
+++ b/src/kvstore/collectives/src/collectives.cc
@@ -0,0 +1,792 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/**
+ * Copyright (c) 2018 by Contributors
+ */
+
+#if MXNET_USE_ALLREDUCE_DIST_KVSTORE
+
+#include <mpi.h>
+#include <unordered_map>
+#include <queue>
+#include <thread>
+#include <functional>
+#include <mutex>
+#include <condition_variable>
+#include <atomic>
+#include <iostream>
+
+#include "mxnet/base.h"
+#include "mxnet/ndarray.h"
+#include "mxnet/engine.h"
+#include "dmlc/logging.h"
+#include "mpi_message.pb.h"
+#include "collectives.h"
+#include "coll_wrapper.h"
+#include "coll_util.h"
+
+using namespace mxnet::kvstore;
+// String tokens; OPS_ALLREDUCE / OPS_BROADCAST name the op in collective failure logs.
+constexpr char INT_PREFIX[] = "INT";
+constexpr char STR_PREFIX[] = "STR";
+constexpr char IDX_PREFIX[] = "IDX";
+constexpr char OPS_PREFIX[] = "OPS";
+constexpr char OPS_ALLREDUCE[] = "ALLREDUCE";
+constexpr char OPS_BROADCAST[] = "BROADCAST";
+constexpr char DELIMITER[] = ":";
+
+namespace {
+
+struct CollectiveOpRecord {  // one pending local collective, consumed by PerformCollectiveOp
+  int rank;  // NOTE(review): not read by PerformCollectiveOp; presumably the enqueuing rank
+
+  std::string key;  // presumably the NDArrayTable key of this record — confirm
+
+  MPIDataType dtype;  // protobuf dtype (see DataTypeToMPIType); not read by PerformCollectiveOp
+
+  mxnet::NDArray *val_in;  // collective input; Broadcast writes its result into it
+
+  mxnet::NDArray *val_out;  // AllReduce destination; may alias val_in (in-place reduce)
+
+  int root_rank;  // source rank for Broadcast
+
+  mxnet::engine::CallbackOnComplete callback;  // invoked once the collective has finished
+};
+
+using NDArrayTable = std::unordered_map<std::string, CollectiveOpRecord>;  // key -> pending op
+// Coordinator (rank 0) only: key -> requests received so far, one expected per rank.
+using MessageTable = std::unordered_map<std::string, std::vector<MPIRequest> >;
+
+/*
+ *  CollectiveGlobalState is the process-wide state (the coll_global instance) shared
+ *  by caller threads and one background thread. On rank 0 (the coordinator) a
+ *  message table counts, per key, the requests received from every rank so that all
+ *  ranks run their collectives in a consistent order. The background thread performs
+ *  the collectives and exchanges the coordination messages over MPI.
+ */
+struct CollectiveGlobalState {
+  std::atomic_flag initialized_flag = ATOMIC_FLAG_INIT;  // presumably a one-time-init guard
+  // Notified by the background thread once MPI initialization has finished.
+  std::condition_variable cv;
+  // Set by the background thread after MPI setup (also when MPI_Init fails).
+  bool initialization_done = false;
+  // 0 if MPI_Init succeeded, -1 if it failed.
+  int init_status;
+  // Guards message_queue and ndarray_table.
+  std::mutex mu;
+  // Key -> pending local op; PerformCollectiveOp removes an entry when it runs it.
+  NDArrayTable ndarray_table;
+  // Requests queued locally; the background thread drains it on every tick.
+  std::queue<MPIRequest> message_queue;
+  // Joined by the destructor below.
+  std::thread background_thread;
+  // Set by the destructor. NOTE(review): plain bool written without mu; confirm the reader.
+  bool shut_down = false;
+  // Allocated on the coordinator (rank 0) only.
+  std::unique_ptr<MessageTable> message_table;
+  // MPI_COMM_WORLD rank of this process.
+  int rank = 0;
+  // Rank within the MPI_COMM_TYPE_SHARED (shared-memory) communicator.
+  int local_rank = 0;
+  // MPI_COMM_WORLD size (number of ranks).
+  int size = 1;
+  // NOTE(review): device/pinned_ctx are not set by the background loop; confirm their use.
+  int device = -1;
+
+  mxnet::Context pinned_ctx;
+  // Intra-node Comm; presumably the one passed to MXCOLLIBInit — confirm.
+  Comm *local_comm = NULL;
+  // NOTE(review): purpose of sync_var1/sync_var2/sync_key not evident here; confirm.
+  mxnet::NDArray sync_var1;
+
+  mxnet::NDArray sync_var2;
+
+  int sync_key;
+  // Signals shutdown and joins the background thread if it was started.
+~CollectiveGlobalState() {
+  if (background_thread.joinable()) {
+    shut_down = true;
+    background_thread.join();
+  }
+}
+};
+
+static CollectiveGlobalState coll_global;  // process-wide state used throughout this file
+
+// Rank that coordinates collective ordering; other ranks MPI_Send their requests to it.
+
+#define RANK_ZERO 0
+// MPI tag for coordination messages (requests and zero-length DONE markers) to RANK_ZERO.
+#define TAG_NOTIFY 1
+
+bool IncrementNDArrayCount(
+  const std::unique_ptr<MessageTable>& message_table,
+  const MPIRequest &msg, int mpi_size) {
+  // Queue msg under its key; true once the key's request count reaches mpi_size.
+  auto name = msg.key_name();
+  auto table_iter = message_table->find(name);
+  if (table_iter == message_table->end()) {
+    // emplace already yields the new entry, so no second find() is needed.
+    table_iter = message_table->emplace(name, std::vector<MPIRequest>({msg})).first;
+    MXCOLL_DEBUG(coll_global.rank, "Insert new message key [%s] request type [%d] from "
+                "rank[%d] into message table!\n", name.c_str(), msg.request_type(),
+                msg.request_rank());
+  } else {
+    MXCOLL_DEBUG(coll_global.rank, "Insert existing message key [%s] request type [%d] "
+                "from rank[%d] into message table!\n",
+                name.c_str(), msg.request_type(), msg.request_rank());
+    table_iter->second.push_back(msg);
+  }
+  int count = table_iter->second.size();
+  MXCOLL_DEBUG(coll_global.rank, "Message Key [%s] count [%d]\n", name.c_str(), count);
+  return count == mpi_size;
+}
+
+int DataTypeToMPIType(int ndarray_dtype, MPIDataType *mpi_dtype) {
+  // Translate an mshadow type flag into the MPIDataType carried in requests.
+  MPIDataType mapped;
+  switch (ndarray_dtype) {
+    case mshadow::kFloat32: mapped = MX_MPI_FLOAT32; break;
+    case mshadow::kInt32:   mapped = MX_MPI_INT32;   break;
+    case mshadow::kInt64:   mapped = MX_MPI_INT64;   break;
+    default: return -1;  // unsupported dtype; *mpi_dtype is left untouched
+  }
+  *mpi_dtype = mapped;
+  return 0;
+}
+
+MPIResponse ConstructMPIResponse(const std::unique_ptr<MessageTable>& message_table,
+                                 std::string name) {
+  bool error = false;  // set when ranks disagree on value type or op type for `name`
+  auto it = message_table->find(name);
+  CHECK(it != message_table->end()) << "No queued MPI requests for key " << name;
+  // A const reference suffices: the table entry is erased only at the end of this function.
+  const std::vector<MPIRequest>& requests = it->second;
+  CHECK(!requests.empty()) << "Empty MPI request list for key " << name;
+
+  std::ostringstream error_message_stream;
+  // All ranks must have requested the same value type for this key...
+  auto data_type = requests[0].value_type();
+  for (unsigned int i = 1; i < requests.size(); i++) {
+    auto request_type = requests[i].value_type();
+    if (data_type != request_type) {
+      error = true;
+      error_message_stream
+        << "Mismatched data types: One rank had type "
+        << MPIDataType_Name(data_type)
+        << ", but another rank had type "
+        << MPIDataType_Name(request_type)
+        << ".";
+      break;
+    }
+  }
+  // ...and the same collective operation.
+  auto message_type = requests[0].request_type();
+  for (unsigned int i = 1; i < requests.size(); i++) {
+    if (error) {
+      break;
+    }
+    auto request_type = requests[i].request_type();
+    if (message_type != request_type) {
+      error = true;
+      error_message_stream
+        << "Mismatched Collective operations: One rank did op "
+        << message_type
+        << ", but another rank did op "
+        << request_type
+        << ".";
+      break;
+    }
+  }
+
+  // TODO(zhouhaiy): Check value shape for all reduce and all gather
+
+  MPIResponse response;
+  response.set_key_name(name);
+  if (error) {
+    std::string error_message = error_message_stream.str();
+    response.set_response_type(MPIResponse::ERROR);
+    response.set_error_message(error_message);
+    MXCOLL_DEBUG(coll_global.rank, "MPI Response Key [%s] error_message [%s].\n",
+                 name.c_str(), error_message.c_str());
+  } else {
+    auto response_type = MPIResponse::ERROR;
+    if (message_type == MPIRequest::ALLREDUCE) {
+      response_type = MPIResponse::ALLREDUCE;
+    } else if (message_type == MPIRequest::ALLGATHER) {
+      response_type = MPIResponse::ALLGATHER;
+    } else {
+      response_type = MPIResponse::BROADCAST;  // any other request op maps to BROADCAST
+    }
+    response.set_response_type(response_type);
+  }
+
+  // Clear all queued up requests for this name. They are now taken care of
+  // by the constructed MPI response.
+  message_table->erase(it);
+
+  return response;
+}
+
+void PerformCollectiveOp(NDArrayTable *ndarray_table, MPIResponse response) {
+  mxnet::NDArray *input_array;   // send buffer; Broadcast also writes into it
+  mxnet::NDArray *output_array;  // AllReduce destination (may alias input_array)
+  mxnet::engine::CallbackOnComplete callback;  // invoked once the collective is done
+  int root_rank;                 // Broadcast source rank
+  {
+    std::lock_guard<std::mutex> guard(coll_global.mu);
+    auto name = response.key_name();
+    auto iter = ndarray_table->find(name);
+    CHECK(iter != ndarray_table->end()) << "No local ndarray queued for key " << name;
+
+    CHECK(response.response_type() == MPIResponse::ALLREDUCE ||
+          response.response_type() == MPIResponse::ALLGATHER ||
+          response.response_type() == MPIResponse::BROADCAST ||
+          response.response_type() == MPIResponse::ERROR);
+    // Copy the pending record out and remove it from the table.
+    CollectiveOpRecord record = iter->second;
+    input_array = record.val_in;
+    output_array = record.val_out;
+    callback = record.callback;
+    root_rank = record.root_rank;
+    ndarray_table->erase(iter);
+  }
+  // The record is out of the table; run the collective without holding coll_global.mu.
+  const int dev_in  = input_array->ctx().dev_mask();
+  if (response.response_type() == MPIResponse::ALLREDUCE) {
+    const int dev_out = output_array->ctx().dev_mask();
+    // Input and output ndarrays must currently live on the same device type
+    // in dist_sync_allreduce.
+    if (dev_in != dev_out) {
+      LOG(FATAL) << "input and output ndarray with mixed device "
+                 << "(One CPU the other GPU or vice versa) "
+                 << "is not supported in kvstore with type dist_sync_allreduce.";
+    }
+  }
+
+  auto dtype = input_array->dtype();
+  int ret = 0;
+  std::string coll_ops;
+  if (response.response_type() == MPIResponse::ALLREDUCE) {
+    coll_ops = OPS_ALLREDUCE;
+    if (dtype == mshadow::kFloat32) {
+      switch (dev_in) {
+        case mshadow::cpu::kDevMask: {
+          ret = COLL_Wrapper<mxnet::cpu, float>::AllReduce(input_array, output_array);
+          break;
+        }
+        case mshadow::gpu::kDevMask: {
+#if MXNET_USE_CUDA
+          ret = COLL_Wrapper<mxnet::gpu, float>::AllReduce(input_array, output_array);
+          break;
+#else
+          LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+          break;
+#endif
+        }
+        default: {
+          LOG(FATAL) << "Unknown device type " << dev_in;
+        }
+      }
+    } else if (dtype == mshadow::kInt32) {
+      switch (dev_in) {
+        case mshadow::cpu::kDevMask: {
+          ret = COLL_Wrapper<mxnet::cpu, int>::AllReduce(input_array, output_array);
+          break;
+        }
+        case mshadow::gpu::kDevMask: {
+#if MXNET_USE_CUDA
+          ret = COLL_Wrapper<mxnet::gpu, int>::AllReduce(input_array, output_array);
+          break;
+#else
+          LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+          break;
+#endif
+        }
+        default: {
+          LOG(FATAL) << "Unknown device type " << dev_in;
+        }
+      }
+    } else {
+      LOG(FATAL) << "rank[" << coll_global.rank << "]:" << "Not supported datatype:"
+                 << dtype << " of ndarray with name " << response.key_name();
+    }
+    // Broadcast: root_rank's input_array data overwrites input_array on all other ranks.
+  } else if (response.response_type() == MPIResponse::BROADCAST) {
+    coll_ops = OPS_BROADCAST;
+    if (dtype == mshadow::kFloat32) {
+      switch (dev_in) {
+        case mshadow::cpu::kDevMask: {
+          ret = COLL_Wrapper<mxnet::cpu, float>::Broadcast(input_array, root_rank);
+          break;
+        }
+        case mshadow::gpu::kDevMask: {
+#if MXNET_USE_CUDA
+          ret = COLL_Wrapper<mxnet::gpu, float>::Broadcast(input_array, root_rank);
+          break;
+#else
+          LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+          break;
+#endif
+        }
+        default: {
+          LOG(FATAL) << "Unknown device type " << dev_in;
+        }
+      }
+    } else if (dtype == mshadow::kInt32) {
+      switch (dev_in) {
+        case mshadow::cpu::kDevMask: {
+          ret = COLL_Wrapper<mxnet::cpu, int>::Broadcast(input_array, root_rank);
+          break;
+        }
+        case mshadow::gpu::kDevMask: {
+#if MXNET_USE_CUDA
+          ret = COLL_Wrapper<mxnet::gpu, int>::Broadcast(input_array, root_rank);
+          break;
+#else
+          LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+          break;
+#endif
+        }
+        default: {
+          LOG(FATAL) << "Unknown device type " << dev_in;
+        }
+      }
+    } else {
+      LOG(FATAL) << "rank[" << coll_global.rank << "]:" << "Not supported datatype:"
+                 << dtype << " of ndarray with name " << response.key_name();
+    }
+  } else {
+    LOG(FATAL) << "rank[" << coll_global.rank << "]:" << "Invalid MPI response type:"
+               << response.response_type() << ". " << response.error_message();
+  }
+  if (ret != 0) {
+    LOG(FATAL) << "rank[" << coll_global.rank << "]:" << "Collective Operation " << coll_ops
+               << " failed at ndarray with name " << response.key_name();
+  }
+  callback();
+}
+
+void BackgroundThreadLoop() {
+  auto init_result = MPI_Init(NULL, NULL);
+  if (init_result != MPI_SUCCESS) {
+    coll_global.init_status = -1;
+    LOG(FATAL) << "MPI_Initialization Failure!";
+    coll_global.initialization_done = true;
+    coll_global.cv.notify_all();
+    return;
+  } else {
+    coll_global.init_status = 0;
+  }
+
+  int rank;
+  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+  bool is_coordinator = rank == 0;
+
+  int size;
+  MPI_Comm_size(MPI_COMM_WORLD, &size);
+
+  MPI_Comm local_comm;
+  MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, &local_comm);
+  int local_rank;
+  MPI_Comm_rank(local_comm, &local_rank);
+
+  coll_global.rank = rank;
+  coll_global.local_rank = local_rank;
+  coll_global.size = size;
+  coll_global.initialization_done = true;
+
+  coll_global.cv.notify_all();
+
+  if (is_coordinator) {
+    coll_global.message_table =
+      std::unique_ptr<MessageTable>(new MessageTable());
+  }
+
+  bool should_shut_down = false;
+  do {
+    // TODO(zhouhaiy): Eliminate the need for thread sleep by making all activity
+    // depend on other activity (e.g. condition or MPI waits).
+    std::this_thread::sleep_for(std::chrono::milliseconds(1));
+
+    // Copy the data structures from global state under this lock.
+    // However, don't keep the lock for the rest of the loop, so that
+    // enqueued stream callbacks can continue.
+    std::queue<MPIRequest> message_queue;
+    {
+      std::lock_guard<std::mutex> guard(coll_global.mu);
+      while (!coll_global.message_queue.empty()) {
+        MPIRequest message = coll_global.message_queue.front();
+        coll_global.message_queue.pop();
+        message_queue.push(message);
+      }
+    }
+
+    // Collect all tensors that are ready to be reduced. Record them in the
+    // tensor count table (rank zero) or send them to rank zero to be
+    // recorded (everyone else).
+    std::vector<std::string> ready_to_reduce;
+    while (!message_queue.empty()) {
+      // Pop the first available message message
+      MPIRequest message = message_queue.front();
+      message_queue.pop();
+
+      if (is_coordinator) {
+        bool reduce = IncrementNDArrayCount(coll_global.message_table,
+                                           message, size);
+        if (reduce) {
+          MXCOLL_DEBUG(coll_global.rank, "Push back ndarray with key [%s] "
+                      "to ready_to_reduce!\n", message.key_name().c_str());
+          ready_to_reduce.push_back(message.key_name());
+        }
+      } else {
+        std::string encoded_message;
+        message.SerializeToString(&encoded_message);
+        MPI_Send(encoded_message.c_str(), encoded_message.length() + 1,
+                 MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
+        MXCOLL_DEBUG(coll_global.rank, "MPI_Send message %s!\n", encoded_message.c_str());
+      }
+    }
+
+    // Rank zero has put all its own tensors in the tensor count table.
+    // Now, it should count all the tensors that are coming from other
+    // ranks at this tick. It should keep getting tensors until it gets a
+    // DONE message from all the other ranks.
+    if (is_coordinator) {
+      // Count of DONE messages. Keep receiving messages until the number
+      // of messages is equal to the number of processes. Initialize to
+      // one since the coordinator is effectively done.
+      int completed_ranks = 1;
+      while (completed_ranks != size) {
+        MPI_Status status;
+        MPI_Probe(MPI_ANY_SOURCE, TAG_NOTIFY, MPI_COMM_WORLD, &status);
+
+        // Find number of characters in message (including zero byte).
+        int source_rank = status.MPI_SOURCE;
+        int msg_length;
+        MPI_Get_count(&status, MPI_BYTE, &msg_length);
+
+        // If the length is zero, this is a DONE message.
+        if (msg_length == 0) {
+          completed_ranks++;
+          MPI_Recv(NULL, 0, MPI_BYTE, source_rank, TAG_NOTIFY,
+                   MPI_COMM_WORLD, &status);
+          continue;
+        }
+
+        // Get tensor name from MPI into an std::string.
+        char* buffer = new char[msg_length];
+        MPI_Recv(buffer, msg_length, MPI_BYTE, source_rank,
+                 TAG_NOTIFY, MPI_COMM_WORLD, &status);
+        std::string received_data(buffer);
+        delete[] buffer;
+
+        MPIRequest received_message;
+        received_message.ParseFromString(received_data);
+        auto received_name = received_message.key_name();
+
+        bool reduce = IncrementNDArrayCount(
+                        coll_global.message_table, received_message, size);
+        if (reduce) {
+          MXCOLL_DEBUG(coll_global.rank, "Push back ndarray with key [%s] "
+                      "to ready_to_reduce!\n", received_name.c_str());
+          ready_to_reduce.push_back(received_name);
+        }
+      }
+
+      // At this point, rank zero should have a fully updated tensor
+      // count table and should know all the tensors that need to be
+      // reduced or gathered, and everyone else should have sent all
+      // their information to rank zero. We can now do reductions and
+      // gathers; rank zero will choose which ones and in what order,
+      // and will notify the other ranks before doing each reduction.
+      for (size_t i = 0; i < ready_to_reduce.size(); i++) {
+        // Notify all nodes which tensor we'd like to reduce now
+        auto name = ready_to_reduce[i];
+        MPIResponse response = ConstructMPIResponse(coll_global.message_table, name);
+        std::string encoded_response;
+        response.SerializeToString(&encoded_response);
+        for (int r = 1; r < size; r++) {
+          MPI_Send(encoded_response.c_str(),
+                   encoded_response.length() + 1,
+                   MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD);
+        }
+
+        // Perform the reduction. All nodes should end up performing
+        // the same reduction.
+        PerformCollectiveOp(&(coll_global.ndarray_table), response);
+      }
+
+      // Notify all nodes that we are done with the reductions for this
+      // tick.
+      MPIResponse done_response;
+      done_response.set_response_type(coll_global.shut_down ?
+                                      MPIResponse::SHUTDOWN : MPIResponse::DONE);
+
+      std::string encoded_response;
+      done_response.SerializeToString(&encoded_response);
+
+      for (int r = 1; r < size; r++) {
+        MPI_Send(encoded_response.c_str(),
+                 encoded_response.length() + 1,
+                 MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD);
+      }
+      if (coll_global.shut_down) {
+        should_shut_down = true;
+      }
+    } else {
+      // Notify the coordinator that this node is done sending messages.
+      // A DONE message is encoded as a zero-length message.
+      MPI_Send(NULL, 0, MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
+
+      // Receive names for tensors to reduce from rank zero. Once we
+      // receive a empty DONE message, stop waiting for more names.
+      while (true) {
+        MPI_Status status;
+        MPI_Probe(0, TAG_NOTIFY, MPI_COMM_WORLD, &status);
+
+        // Find number of characters in message (including zero byte).
+        int msg_length;
+        MPI_Get_count(&status, MPI_BYTE, &msg_length);
+
+        // Get tensor name from MPI into an std::string.
+        char* buffer = new char[msg_length];
+        MPI_Recv(buffer, msg_length, MPI_BYTE, 0,
+                 TAG_NOTIFY, MPI_COMM_WORLD, &status);
+        std::string received_message(buffer);
+        delete[] buffer;
+
+        MPIResponse response;
+        response.ParseFromString(received_message);
+        if (response.response_type() == MPIResponse::DONE) {
+          // No more messages this tick
+          break;
+        } else if (response.response_type() == MPIResponse::SHUTDOWN) {
+        // No more messages this tick, and the background thread
+        // should shut down
+          should_shut_down = true;
+          break;
+        } else {
+          // Process the current message
+          PerformCollectiveOp(&(coll_global.ndarray_table), response);
+        }
+      }
+    }
+  } while (!should_shut_down);
+
+  MPI_Finalize();
+}
+
+/*!
+ * \brief Initialize the collective (MPI) layer once per process.
+ *
+ * The first caller records the communicator and its pinned context, allocates
+ * the two 1x1 NDArrays that MXBarrier reduces, starts BackgroundThreadLoop and
+ * then blocks on coll_global.cv until that thread signals. Every later caller
+ * returns the cached coll_global.init_status immediately.
+ *
+ * \param comm local communicator; its pinned_ctx() becomes the context on
+ *        which this library pushes its engine operations.
+ * \return coll_global.init_status, which is forced to -1 here when the
+ *         background thread signalled without setting initialization_done.
+ */
+int InitializeMPIOnce(Comm *comm) {
+  // test_and_set() returns the previous flag value, so only the very first
+  // caller falls through to the initialization below.
+  // NOTE(review): a concurrent second caller returns init_status right away,
+  // possibly before the first caller has finished waiting below.
+  if (coll_global.initialized_flag.test_and_set())
+    return coll_global.init_status;
+
+  coll_global.device = -1;
+  coll_global.local_comm = comm;
+  coll_global.pinned_ctx = coll_global.local_comm->pinned_ctx();
+  // Input/output buffers of the all-reduce issued by MXBarrier.
+  coll_global.sync_var1 = mxnet::NDArray(mxnet::TShape({1, 1}), coll_global.pinned_ctx, true);
+  coll_global.sync_var2 = mxnet::NDArray(mxnet::TShape({1, 1}), coll_global.pinned_ctx, true);
+  // Key used by MXBarrier for its all-reduce.
+  coll_global.sync_key = 0xfeedbeaf;
+
+  coll_global.background_thread = std::thread(BackgroundThreadLoop);
+  std::unique_lock<std::mutex> lock(coll_global.mu);
+  // NOTE(review): this wait has no predicate, and the mutex is only taken after
+  // the background thread was started. If BackgroundThreadLoop notifies before
+  // this line is reached, the notification is lost and the call can block
+  // forever; a spurious wakeup would also be mistaken for completion. Waiting
+  // on a completion flag set by the background thread would avoid both —
+  // TODO confirm against BackgroundThreadLoop.
+  coll_global.cv.wait(lock);
+  if (!coll_global.initialization_done) {
+    coll_global.init_status = -1;
+  }
+
+  MXCOLL_DEBUG(coll_global.rank, "MPI Initialization Done!\n");
+  return coll_global.init_status;
+}
+
+/*!
+ * \brief Report whether the background thread completed MPI initialization.
+ * \return 1 when coll_global.initialization_done is set, 0 otherwise.
+ */
+int IsMPIInitialized() {
+  return coll_global.initialization_done ? 1 : 0;
+}
+
+/*!
+ * \brief Hand one collective request over to the background thread.
+ *
+ * Stores the engine completion callback on the record, then, under
+ * coll_global.mu, files the record in coll_global.ndarray_table under its key
+ * and queues an MPIRequest describing it on coll_global.message_queue.
+ *
+ * \param record arrays and ranks taking part (taken by value).
+ * \param rtype  collective operation being requested.
+ * \param cb     engine callback attached to the record.
+ */
+void EnqueueCollective(CollectiveOpRecord record,
+                       MPIRequest::RequestType rtype,
+                       mxnet::Engine::CallbackOnComplete cb) {
+  record.callback = cb;
+
+  MPIRequest request;
+  request.set_request_rank(record.rank);
+  request.set_key_name(record.key);
+
+  MPIDataType value_type;
+  if (DataTypeToMPIType(record.val_in->dtype(), &value_type) != 0) {
+    LOG(FATAL) << "Unknown ndarray type:" << record.val_in->dtype();
+    return;
+  }
+  request.set_value_type(value_type);
+  request.set_request_type(rtype);
+  // Only broadcasts carry a root rank; other requests leave the field unset.
+  if (rtype == MPIRequest::BROADCAST) {
+    request.set_root_rank(record.root_rank);
+  }
+
+  std::lock_guard<std::mutex> guard(coll_global.mu);
+  // NOTE(review): if ndarray_table is a std::map/unordered_map, emplace() keeps
+  // an existing entry for the same key rather than replacing it.
+  coll_global.ndarray_table.emplace(record.key, record);
+  coll_global.message_queue.push(request);
+  MXCOLL_DEBUG(coll_global.rank, "Enqueue ndarray key [%s] to message queue!\n",
+               record.key.c_str());
+}
+};  // namespace
+
+namespace mxnet {
+namespace kvstore {
+
+/*!
+ * \brief Read the number of MPI ranks recorded in coll_global.size.
+ * \param ret receives the size on success; left untouched otherwise.
+ * \return 0 on success, -1 if MPI has not been initialized yet.
+ */
+int MXGetMpiSize(int *ret) {
+  if (!IsMPIInitialized()) {
+    return -1;
+  }
+  *ret = coll_global.size;
+  return 0;
+}
+
+/*!
+ * \brief Read this process's MPI rank recorded in coll_global.rank.
+ * \param ret receives the rank on success; left untouched otherwise.
+ * \return 0 on success, -1 if MPI has not been initialized yet.
+ */
+int MXGetMpiRank(int *ret) {
+  if (!IsMPIInitialized()) {
+    return -1;
+  }
+  *ret = coll_global.rank;
+  return 0;
+}
+
+/*!
+ * \brief Public entry point that initializes the collective library.
+ * \param comm communicator forwarded to InitializeMPIOnce.
+ * \return status from InitializeMPIOnce; callers treat non-zero as failure.
+ */
+int MXCOLLIBInit(Comm *comm) {
+  const int status = InitializeMPIOnce(comm);
+  return status;
+}
+
+/*!
+ * \brief Read coll_global.local_rank (presumably the rank within this host —
+ *        TODO confirm where it is computed).
+ * \param ret receives the value on success; left untouched otherwise.
+ * \return 0 on success, -1 if MPI has not been initialized yet.
+ */
+int MXGetLocalRank(int *ret) {
+  if (!IsMPIInitialized()) {
+    return -1;
+  }
+  *ret = coll_global.local_rank;
+  return 0;
+}
+
+/*!
+ * \brief Schedule an asynchronous all-reduce of send_value into recv_value.
+ *
+ * Builds a CollectiveOpRecord for this rank (holding the two raw pointers)
+ * and pushes an async engine operation on coll_global.pinned_ctx. When the
+ * engine runs it, the record is handed to the background thread through
+ * EnqueueCollective. recv_value's var is always a write dependency;
+ * send_value's var is a read dependency only when it is a different var.
+ *
+ * \param mpi_key    name identifying this reduction.
+ * \param send_value input array.
+ * \param recv_value output array (may be the same array as send_value).
+ * \param priority   engine priority.
+ * \return always 0.
+ */
+int MXAllReduceImpl(const std::string& mpi_key,
+                    mxnet::NDArray* send_value,
+                    mxnet::NDArray* recv_value,
+                    int priority) {
+  CollectiveOpRecord record;
+  record.key = mpi_key;
+  record.rank = coll_global.rank;
+  record.val_in = send_value;
+  record.val_out = recv_value;
+  MXCOLL_DEBUG(coll_global.rank, "MXAllReduceImpl insert one record key [%s]!\n",
+              record.key.c_str());
+
+  // In-place reductions (same var for input and output) list that var only
+  // as a write dependency.
+  std::vector<mxnet::Engine::VarHandle> read_vars;
+  if (send_value->var() != recv_value->var()) {
+    read_vars.push_back(record.val_in->var());
+  }
+  const std::vector<mxnet::Engine::VarHandle> write_vars = {record.val_out->var()};
+
+  CHECK_NOTNULL(mxnet::Engine::Get())->PushAsync(
+    [record](mxnet::RunContext rctx, mxnet::Engine::CallbackOnComplete cb) {
+      EnqueueCollective(record, MPIRequest::ALLREDUCE, cb);
+    },
+    coll_global.pinned_ctx,
+    read_vars,
+    write_vars,
+    mxnet::FnProperty::kNormal,
+    priority, "KVSTORE ALLREDUCE");
+  return 0;
+}
+
+/*!
+ * \brief All-reduce an integer-keyed value across ranks.
+ *
+ * The key is namespaced as
+ * OPS_PREFIX DELIMITER OPS_ALLREDUCE DELIMITER INT_PREFIX DELIMITER <key>
+ * before being forwarded to MXAllReduceImpl.
+ * \return the status returned by MXAllReduceImpl.
+ */
+int MXAllReduce(int key,
+                mxnet::NDArray* send_value,
+                mxnet::NDArray* recv_value,
+                int priority) {
+  const std::string delimiter = DELIMITER;
+  std::string mpi_key = OPS_PREFIX;
+  mpi_key += delimiter;
+  mpi_key += OPS_ALLREDUCE;
+  mpi_key += delimiter;
+  mpi_key += INT_PREFIX;
+  mpi_key += delimiter;
+  mpi_key += std::to_string(key);
+  return MXAllReduceImpl(mpi_key, send_value, recv_value, priority);
+}
+
+/*!
+ * \brief Schedule an asynchronous broadcast of value from root_rank.
+ *
+ * Builds a CollectiveOpRecord with only val_in set and pushes an async engine
+ * operation on coll_global.pinned_ctx that lists value's var as its single
+ * write dependency. When run, the record is handed to the background thread
+ * through EnqueueCollective.
+ *
+ * \param mpi_key   name identifying this broadcast.
+ * \param value     array broadcast from / into.
+ * \param root_rank rank whose data is sent.
+ * \param priority  engine priority.
+ * \return always 0.
+ */
+int MXBroadcastImpl(std::string mpi_key,
+                    mxnet::NDArray* value,
+                    int root_rank,
+                    int priority) {
+  CollectiveOpRecord record;
+  record.key = mpi_key;
+  record.rank = coll_global.rank;
+  record.root_rank = root_rank;
+  record.val_in = value;
+  MXCOLL_DEBUG(coll_global.rank, "MXBroadCastImpl insert one record key [%s]!\n",
+              record.key.c_str());
+
+  mxnet::Engine* engine = CHECK_NOTNULL(mxnet::Engine::Get());
+  engine->PushAsync(
+    [record](mxnet::RunContext rctx, mxnet::Engine::CallbackOnComplete cb) {
+      EnqueueCollective(record, MPIRequest::BROADCAST, cb);
+    },
+    coll_global.pinned_ctx,
+    {},
+    {record.val_in->var()},
+    mxnet::FnProperty::kNormal,
+    priority, "KVSTORE BROADCAST");
+  return 0;
+}
+
+/*!
+ * \brief Broadcast an integer-keyed value from root_rank to all ranks.
+ *
+ * The key is namespaced as
+ * OPS_PREFIX DELIMITER OPS_BROADCAST DELIMITER INT_PREFIX DELIMITER <key>
+ * before being forwarded to MXBroadcastImpl.
+ * \return the status returned by MXBroadcastImpl.
+ */
+int MXBroadcast(int key,
+                mxnet::NDArray* value,
+                int root_rank,
+                int priority) {
+  const std::string delimiter = DELIMITER;
+  const std::string mpi_key = std::string(OPS_PREFIX) + delimiter + OPS_BROADCAST +
+                              delimiter + INT_PREFIX + delimiter + std::to_string(key);
+  return MXBroadcastImpl(mpi_key, value, root_rank, priority);
+}
+
+/*!
+ * \brief Placeholder for an all-gather collective; not implemented.
+ *
+ * Always fails through LOG(FATAL); key, value and priority are ignored.
+ */
+int MXAllGather(int key,
+                mxnet::NDArray* value,
+                int priority) {
+  // place holder
+  LOG(FATAL) << "Collective AllGather has not been implemented yet!";
+  return 0;
+}
+
+/*!
+ * \brief Synchronize all ranks.
+ *
+ * Issues an all-reduce on coll_global.sync_key (sync_var1 into sync_var2) and
+ * then waits for every pending engine operation, including that all-reduce.
+ * The barrier effect presumably comes from the coordinator releasing the
+ * reduction only once every rank has requested the same key.
+ * \return the status returned by MXAllReduce.
+ */
+int MXBarrier() {
+  const int status = MXAllReduce(coll_global.sync_key, &coll_global.sync_var1,
+                                 &coll_global.sync_var2, /*priority=*/0);
+  // MXAllReduce only schedules the work; block until the engine has run it.
+  mxnet::Engine::Get()->WaitForAll();
+  return status;
+}
+
+}  // end of namespace kvstore
+}  // end of namespace mxnet
+#endif
diff --git a/src/kvstore/collectives/src/mpi_message.proto b/src/kvstore/collectives/src/mpi_message.proto
new file mode 100644
index 00000000000..265287587eb
--- /dev/null
+++ b/src/kvstore/collectives/src/mpi_message.proto
@@ -0,0 +1,89 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+//
+
+/**
+ * Copyright (c) 2018 by Contributors
+ */
+
+syntax = "proto3";
+
+package mxnet.kvstore;
+
+// We would like to just use DataType here, but since this
+// is a contrib package, linking directly to MXNet protos seems to be
+// impossible. Doing so compiles, but fails with a cryptic error at runtime
+// about a pointer that was passed to free() but not created by malloc().
+//
+// Since using the mxnet/core protos seems to cause issues, we use our own,
+// which also has the benefit of supporting only the data types we want to support.
+enum MPIDataType {
+    // proto3 requires the first enum value to be 0, which is also what an
+    // unset value_type field decodes to; it is used as the invalid marker.
+    MX_MPI_INVALID_TYPE = 0;
+    MX_MPI_FLOAT32 = 1;
+    MX_MPI_INT32 = 2;
+    MX_MPI_INT64 = 3;
+};
+
+// An MPIRequest is a message sent from a rank greater than zero to the
+// coordinator (rank zero), informing the coordinator of an operation that
+// the rank wants to perform and the array (key) it wants to apply it to. The
+// coordinator also creates MPIRequests for its own arrays, but records them
+// locally instead of sending them.
+message MPIRequest {
+  enum RequestType {
+    ALLREDUCE = 0;
+    ALLGATHER = 1;
+    BROADCAST = 2;
+  }
+
+  // The request rank is necessary to create a consistent ordering of results,
+  // for example in the allgather where the order of outputs should be sorted
+  // by rank.
+  int32 request_rank = 1;
+  // Name of the collective; the coordinator counts requests per name, so the
+  // ranks are expected to use the same name for the same operation.
+  string key_name = 2;
+  RequestType request_type = 3;
+  // Element type of the array taking part in the collective.
+  MPIDataType value_type = 4;
+  // Source rank; only set by the sender when request_type is BROADCAST.
+  int32 root_rank = 5;
+
+  // We use a repeated integer instead of a TensorShapeProto because linking directly
+  // to MXNet protos causes issues. See the comment for MPIDataType.
+  // NOTE(review): EnqueueCollective does not currently populate this field.
+  repeated int64 value_shape = 6;
+};
+
+// An MPIResponse is a message sent from the coordinator (rank zero) to a rank
+// greater than zero, informing the rank that an operation should be performed
+// now. If the operation requested would result in an error (for example, due
+// to a type or shape mismatch), then the MPIResponse can contain an error and
+// an error message instead. Finally, an MPIResponse can be a DONE message (if
+// there are no more tensors to reduce on this tick of the background loop) or
+// SHUTDOWN if all MPI processes should shut down.
+message MPIResponse {
+  enum ResponseType {
+    ALLREDUCE = 0;
+    ALLGATHER = 1;
+    BROADCAST = 2;
+    ERROR = 3;
+    DONE = 4;
+    SHUTDOWN = 5;
+  }
+
+  ResponseType response_type = 1;
+  // Name of the collective to perform. Empty if the type is DONE or SHUTDOWN.
+  string key_name = 2;
+
+  // Empty unless response_type is ERROR.
+  string error_message = 3;
+};
diff --git a/src/kvstore/kvstore.cc b/src/kvstore/kvstore.cc
index 4d17fffa1a3..dc3591d970f 100644
--- a/src/kvstore/kvstore.cc
+++ b/src/kvstore/kvstore.cc
@@ -35,6 +35,10 @@ std::atomic<int> mxnet::kvstore::KVStoreDist::customer_id_{0};
 #include "./kvstore_nccl.h"
 #endif  // MXNET_USE_NCCL
 
+#if MXNET_USE_ALLREDUCE_DIST_KVSTORE
+#include "./kvstore_dist_sync_allreduce.h"
+#endif
+
 namespace mxnet {
 
 KVStore* KVStore::Create(const char *type_name) {
@@ -50,6 +54,19 @@ KVStore* KVStore::Create(const char *type_name) {
   }
 
   if (has("dist")) {
+#if defined(MXNET_USE_ALLREDUCE_DIST_KVSTORE) && defined(MXNET_USE_DIST_KVSTORE)
+    if (has("allreduce")) {
+      kv = new kvstore::KVStoreDistSyncAllReduce(use_device_comm);
+      kv->type_ = tname;
+      return kv;
+    }
+#else
+    if (has("allreduce")) {
+      LOG(FATAL) << "compile with USE_ALLREDUCE_DIST_KVSTORE=1 to use " << tname;
+      return nullptr;
+    }
+#endif
+
 #if MXNET_USE_DIST_KVSTORE
     kv = new kvstore::KVStoreDist(use_device_comm);
     if (!has("_async") && kv->IsWorkerNode() && kv->get_rank() == 0) {
diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h
index 7e2f5cb5faa..c8a04cd7af0 100644
--- a/src/kvstore/kvstore_dist.h
+++ b/src/kvstore/kvstore_dist.h
@@ -503,16 +503,6 @@ class KVStoreDist : public KVStoreLocal {
       "KVStoreDistRowSparsePull");
   }
 
-  /**
-   * \brief check if the keys are all unique
-   */
-  void CheckUnique(const std::vector<int>& keys) {
-    auto keys_copy = keys;
-    auto last = std::unique(keys_copy.begin(), keys_copy.end());
-    CHECK_EQ(static_cast<size_t>(std::distance(keys_copy.begin(), last)),
-             static_cast<size_t>(keys.size()));
-  }
-
   /**
    * \brief convert to pskv for parameter server
    * \param key
diff --git a/src/kvstore/kvstore_dist_sync_allreduce.h b/src/kvstore/kvstore_dist_sync_allreduce.h
new file mode 100644
index 00000000000..345fb8e5512
--- /dev/null
+++ b/src/kvstore/kvstore_dist_sync_allreduce.h
@@ -0,0 +1,254 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/**
+ * Copyright (c) 2018 by Contributors
+ * @file   kvstore_dist_sync_allreduce.h
+ * @brief  distributed implementation based on allreduce
+ */
+#ifndef MXNET_KVSTORE_KVSTORE_DIST_SYNC_ALLREDUCE_H_
+#define MXNET_KVSTORE_KVSTORE_DIST_SYNC_ALLREDUCE_H_
+
+#include <mxnet/kvstore.h>
+#include <unordered_map>
+#include <bitset>
+#include <vector>
+#include <string>
+#include <utility>
+#include <functional>
+#include <algorithm>
+#include "./comm.h"
+#include "./kvstore_utils.h"
+#include "./kvstore_local.h"
+
+#if MXNET_USE_ALLREDUCE_DIST_KVSTORE
+#include "collectives/include/collectives.h"
+
+namespace mxnet {
+namespace kvstore {
+
+/**
+ * \brief distributed KVStore that keeps values in sync across MPI ranks using
+ *        collective all-reduce and broadcast operations.
+ *
+ * Only PushPull, Broadcast, Barrier, get_rank and get_group_size are
+ * supported; the Push/Pull/PullRowSparse/SetGradientCompression overrides
+ * fail with LOG(FATAL).
+ */
+class KVStoreDistSyncAllReduce : public KVStoreLocal {
+ public:
+  /**
+   * \brief construct the store and initialize the collective library
+   * \param use_device_comm forwarded to KVStoreLocal
+   */
+  explicit KVStoreDistSyncAllReduce(bool use_device_comm)
+  : KVStoreLocal(use_device_comm) {
+    int ret = MXCOLLIBInit(comm_);
+    if (ret != 0) {
+      // type_ is not assigned yet here (KVStore::Create sets it after
+      // construction), so identify the store by its class name instead.
+      LOG(FATAL) << "KVStoreDistSyncAllReduce failed with collective library init. ret: "
+                 << ret;
+    }
+  }
+
+  virtual ~KVStoreDistSyncAllReduce() {
+  }
+
+  // Parameter-server style entry points have no all-reduce counterpart and
+  // always fail.
+  void Push(const std::vector<int>& keys,
+            const std::vector<NDArray>& values,
+            int priority) override {
+    LOG(FATAL) << "Not supported in KVStore with type " << type_ << ".";
+  }
+
+  void Pull(const std::vector<int>& keys,
+            const std::vector<NDArray*>& values,
+            int priority,
+            bool ignore_sparse) override {
+    LOG(FATAL) << "Not supported in KVStore with type " << type_ << ".";
+  }
+
+  void PullRowSparse(const std::vector<int>& keys,
+                     const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
+                     int priority = 0) override {
+    LOG(FATAL) << "Not supported in KVStore with type " << type_ << ".";
+  }
+
+  void Push(const std::vector<std::string>& str_keys,
+            const std::vector<NDArray>& values,
+            int priority) override {
+    LOG(FATAL) << "Not supported in KVStore with type " << type_ << ".";
+  }
+
+  void Pull(const std::vector<std::string>& str_keys,
+            const std::vector<NDArray*>& values,
+            int priority,
+            bool ignore_sparse) override {
+    LOG(FATAL) << "Not supported in KVStore with type " << type_ << ".";
+  }
+
+  void PullRowSparse(const std::vector<std::string>& str_keys,
+                     const std::vector<std::pair<NDArray*, NDArray>>& val_rowids,
+                     int priority = 0) override {
+    LOG(FATAL) << "Not supported in KVStore with type " << type_ << ".";
+  }
+
+  void SetGradientCompression(const std::vector<std::pair<std::string, std::string> >
+                              & kwargs) override {
+    LOG(FATAL) << "Not supported in KVStore with type " << type_ << ".";
+  }
+
+  /**
+   * \brief reduce in_values over local devices and all ranks and write the
+   *        result into out_values (integer keys)
+   */
+  void PushPull(const std::vector<int> &keys,
+                const std::vector<NDArray> &in_values,
+                const std::vector<NDArray*> &out_values,
+                int priority) override {
+    SetKeyType(kIntKey);
+    PushPullImpl(keys, in_values, out_values, priority);
+  }
+
+  /**
+   * \brief string-key variant of PushPull; keys are mapped to integers first
+   */
+  void PushPull(const std::vector<std::string> &str_keys,
+                const std::vector<NDArray> &in_values,
+                const std::vector<NDArray*> &out_values,
+                int priority) override {
+    SetKeyType(kStringKey);
+    std::vector<int> keys(str_keys.size());
+    LookupKeys(str_keys, &keys);
+    PushPullImpl(keys, in_values, out_values, priority);
+  }
+
+  /**
+   * \brief overwrite values on every rank with the copy held by root_rank
+   *        (integer keys)
+   */
+  void Broadcast(const std::vector<int> &keys,
+                 const std::vector<NDArray*> &values,
+                 int root_rank,
+                 int priority) override {
+    SetKeyType(kIntKey);
+    BroadcastImpl(keys, values, root_rank, priority);
+  }
+
+  /**
+   * \brief string-key variant of Broadcast; keys are mapped to integers first
+   */
+  void Broadcast(const std::vector<std::string> &str_keys,
+                 const std::vector<NDArray*> &values,
+                 int root_rank,
+                 int priority) override {
+    SetKeyType(kStringKey);
+    std::vector<int> keys(str_keys.size());
+    LookupKeys(str_keys, &keys);
+    BroadcastImpl(keys, values, root_rank, priority);
+  }
+
+  /**
+   * \brief synchronize all ranks; fatal if the collective barrier fails
+   */
+  void Barrier() override {
+    int ret = MXBarrier();
+    if (ret != 0) {
+      LOG(FATAL) << "MXBarrier is not successful. ret: " << ret;
+    }
+  }
+
+  /**
+   * \return this process's MPI rank
+   */
+  int get_rank() const override {
+    int ret, rank;
+    ret = MXGetMpiRank(&rank);
+    if (ret != 0) {
+      LOG(FATAL) << "MXGetMpiRank is not successful. ret: " << ret;
+      rank = -1;
+    }
+    return rank;
+  }
+
+  /**
+   * \return number of MPI ranks
+   */
+  int get_group_size() const override {
+    int ret, size;
+    ret = MXGetMpiSize(&size);
+    if (ret != 0) {
+      LOG(FATAL) << "MXGetMpiSize is not successful. ret: " << ret;
+      size = -1;
+    }
+    return size;
+  }
+
+ private:
+  /**
+   * \brief register each (unique) key with the local communicator
+   */
+  void InitImpl(const std::vector<int>& keys,
+                const std::vector<NDArray>& values) override {
+    CheckUnique(keys);
+    for (size_t i = 0; i < keys.size(); ++i) {
+      comm_->Init(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype());
+    }
+  }
+
+  /**
+   * \brief per key: reduce over local devices, all-reduce the result across
+   *        ranks, then broadcast it to every local output array
+   */
+  void PushPullImpl(const std::vector<int> &keys,
+                    const std::vector<NDArray> &in_values,
+                    const std::vector<NDArray*> &out_values,
+                    int priority) {
+    std::vector<int> uniq_keys;
+    std::vector<std::vector<NDArray> > grouped_invals;
+    std::vector<std::vector<NDArray*> > grouped_outvals;
+
+    CHECK_EQ(in_values.size(), out_values.size());
+    // uniq_keys from the second grouping indexes both grouped_invals and
+    // grouped_outvals (assumes both groupings order keys identically).
+    GroupKVPairsPush(keys, in_values, &uniq_keys, &grouped_invals, false);
+    uniq_keys.clear();
+    GroupKVPairsPull(keys, out_values, &uniq_keys, &grouped_outvals, true);
+
+    for (size_t i = 0; i < uniq_keys.size(); ++i) {
+      // reduce over devices
+      int key = uniq_keys[i];
+      const auto& invals = grouped_invals[i];
+      NDArray reduced = comm_->Reduce(key, invals, priority);
+      const auto storage_type = reduced.storage_type();
+      // A CPU result is used directly; otherwise it is copied into a buffer on
+      // pinned_ctx_ that is allocated once per key.
+      // NOTE(review): MXAllReduce keeps &comm_buf and dereferences it when the
+      // engine runs the op; reassigning comm_buf on a later PushPull of the same
+      // key before that point would change the array it sees — TODO confirm.
+      auto &comm_buf = comm_buf_[key];
+      if (reduced.ctx().dev_mask() == cpu::kDevMask) {
+        comm_buf = reduced;  // avoid memory copy
+      } else {
+        if (comm_buf.is_none()) {
+          if (storage_type == kDefaultStorage) {
+            comm_buf = NDArray(reduced.shape(), pinned_ctx_, true, reduced.dtype());
+          } else {
+            comm_buf = NDArray(storage_type, reduced.shape(), pinned_ctx_, true, reduced.dtype());
+          }
+        }
+        CopyFromTo(reduced, &comm_buf);
+      }
+      int ret = MXAllReduce(key, &comm_buf, &comm_buf, priority);
+      if (ret != 0) {
+        LOG(FATAL) << "MXAllReduce is not successful. ret:" << ret;
+      }
+      comm_->Broadcast(key, comm_buf, grouped_outvals[i], priority);
+    }
+  }
+
+  /**
+   * \brief per key: the root rank stages its first local value, the collective
+   *        broadcast distributes it, and the result is copied to every local
+   *        array of that key
+   */
+  void BroadcastImpl(const std::vector<int> &keys,
+                     const std::vector<NDArray*> &values,
+                     int root_rank,
+                     int priority) {
+    std::vector<int> uniq_keys;
+    std::vector<std::vector<NDArray*> > grouped_vals;
+    GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals, true);
+
+    for (size_t i = 0; i < uniq_keys.size(); ++i) {
+      int key = uniq_keys[i];
+      auto& comm_buf = comm_buf_[key];
+      const auto storage_type = grouped_vals[i][0]->storage_type();
+      CHECK_EQ(storage_type, kDefaultStorage)
+              << "Expected stype of value to be kDefaultStorage";
+      if (comm_buf.is_none()) {
+        comm_buf = NDArray(grouped_vals[i][0]->shape(), pinned_ctx_,
+                          true, grouped_vals[i][0]->dtype());
+      }
+
+      // Only the broadcast source contributes its data; the other ranks'
+      // buffers are overwritten by MXBroadcast. (Previously this compared
+      // against rank 0, which sent stale data whenever root_rank != 0.)
+      if (get_rank() == root_rank) {
+        CopyFromTo(*grouped_vals[i][0], &comm_buf);
+      }
+      int ret = MXBroadcast(key, &comm_buf, root_rank, priority);
+      if (ret != 0) {
+        LOG(FATAL) << "MXBroadcast is not successful. ret:" << ret;
+      }
+      comm_->Broadcast(key, comm_buf, grouped_vals[i], priority);
+    }
+  }
+
+  /// per-key staging buffer handed to the collective operations
+  std::unordered_map<int, NDArray> comm_buf_;
+};
+}  // namespace kvstore
+}  // namespace mxnet
+
+#endif  // MXNET_USE_ALLREDUCE_DIST_KVSTORE
+#endif  // MXNET_KVSTORE_KVSTORE_DIST_SYNC_ALLREDUCE_H_
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index 84e2700a20d..1df488ee0fa 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -416,6 +416,16 @@ class KVStoreLocal : public KVStore {
     return out;
   }
 
+  /**
+   * \brief check that no key appears more than once
+   *
+   * std::unique only collapses adjacent duplicates, so the copy is sorted
+   * first; without the sort, non-adjacent duplicates such as {1, 2, 1} passed.
+   * \param keys keys to validate; a CHECK failure is raised on any duplicate
+   */
+  void CheckUnique(const std::vector<int>& keys) {
+    auto keys_copy = keys;
+    std::sort(keys_copy.begin(), keys_copy.end());
+    auto last = std::unique(keys_copy.begin(), keys_copy.end());
+    CHECK_EQ(static_cast<size_t>(std::distance(keys_copy.begin(), last)),
+             static_cast<size_t>(keys.size()))
+        << "keys passed to the kvstore must be unique";
+  }
+
   /// reducer and broadcaster
   Comm* comm_;
   /// pinned context
diff --git a/tests/nightly/JenkinsfileForBinaries b/tests/nightly/JenkinsfileForBinaries
index c0c14b26667..28bf0a7a21d 100755
--- a/tests/nightly/JenkinsfileForBinaries
+++ b/tests/nightly/JenkinsfileForBinaries
@@ -78,6 +78,15 @@ try {
         }
       }
     }
+    parallel 'CPU: Build allreduce KVStore': {
+      node('mxnetlinux-cpu') {
+        ws('workspace/build-cpu') {
+          init_git()
+          docker_run('ubuntu_base_cpu', 'build_ubuntu_cpu_allreduce_kvstore', false)
+          pack_lib('cpu', mx_lib)
+        }
+      }
+    }
   }
 
   stage('NightlyTests'){
@@ -95,7 +104,7 @@ try {
         ws('workspace/nt-KVStoreTest') {
           init_git()
           unpack_lib('gpu', mx_lib)
-          docker_run('ubuntu_nightly_gpu', 'nightly_test_KVStore_singleNode', true) 
+          docker_run('ubuntu_nightly_gpu', 'nightly_test_KVStore_singleNode', true)
         }
       }
     }
diff --git a/tests/nightly/dist_allreduce_sync_kvstore.py b/tests/nightly/dist_allreduce_sync_kvstore.py
new file mode 100644
index 00000000000..701fc54b9f7
--- /dev/null
+++ b/tests/nightly/dist_allreduce_sync_kvstore.py
@@ -0,0 +1,85 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import sys
+sys.path.insert(0, "../../python/")
+import mxnet as mx
+import numpy as np
+import numpy.random as rnd
+import time
+
+def check_diff_to_scalar(A, x, rank=None):
+    """ assert A == x"""
+    assert(np.sum(np.abs((A - x).asnumpy())) == 0), (rank, A.asnumpy(), x)
+
+# setup
+# NOTE(review): `keys` is not referenced by the tests below.
+keys = ['3', '5', '7']
+shape = (2, 3)
+big_shape = (1200, 1200)        # bigger than MXNET_KVSTORE_BIGARRAY_BOUND
+
+# Created at import time and shared by every test via init_kv(); presumably
+# this requires running under mpirun (see tests/nightly/test_all.sh).
+kv = mx.kv.create('dist_sync_allreduce')
+
+def init_kv():
+    my_rank = kv.rank
+    nworker = kv.num_workers
+    return kv, my_rank, nworker
+
+def test_sync_pushpull():
+    kv, my_rank, nworker = init_kv()
+    kv.init('3', mx.nd.ones(shape))
+    kv.init('99', mx.nd.ones(big_shape))
+    def check_pushpull(kv, my_rank, nworker):
+        nrepeat = 3
+        for i in range(nrepeat):
+            val = mx.nd.zeros(shape)
+            val2 = mx.nd.zeros(big_shape)
+            in_ = mx.nd.ones(shape)
+            in2_ = mx.nd.ones(big_shape)
+            kv.pushpull('3', in_, val)
+            kv.pushpull('99', in2_, val2)
+            num = nworker;
+            check_diff_to_scalar(val, num)
+            check_diff_to_scalar(val2, num)
+
+    check_pushpull(kv, my_rank, nworker)
+    print('worker ' + str(my_rank) + ' pushpull is done')
+
+def test_sync_broadcast():
+    kv, my_rank, nworker = init_kv()
+    kv.init('4', mx.nd.zeros(shape))
+    kv.init('98', mx.nd.zeros(big_shape))
+    def check_broadcast(kv, my_rank, nworker):
+        nrepeat = 3
+        for i in range(nrepeat):
+            if my_rank == 0:
+                val = mx.nd.ones(shape)
+                val2 = mx.nd.ones(big_shape)
+            else:
+                val = mx.nd.zeros(shape)
+                val2 = mx.nd.zeros(big_shape)
+            kv.broadcast('4', val, 0)
+            kv.broadcast('98', val2, 0)
+            num = 1
+            check_diff_to_scalar(val, num)
+            check_diff_to_scalar(val2, num)
+    check_broadcast(kv, my_rank, nworker)
+    print('worker ' + str(my_rank) + ' broadcast is done')
+if __name__ == "__main__":
+    test_sync_pushpull()
+    test_sync_broadcast()
diff --git a/tests/nightly/test_all.sh b/tests/nightly/test_all.sh
index 04d895fecf2..36240ef0076 100755
--- a/tests/nightly/test_all.sh
+++ b/tests/nightly/test_all.sh
@@ -44,6 +44,7 @@ USE_CUDA=1
 USE_CUDA_PATH=/usr/local/cuda
 USE_CUDNN=1
 USE_DIST_KVSTORE=1
+USE_ALLREDUCE_DIST_KVSTORE=1
 EOF
 
 juLog -name=Build -error=Error build
@@ -54,6 +55,9 @@ juLog -name=Python.Local.KVStore -error=Error python test_kvstore.py
 # python: distributed kvstore
 juLog -name=Python.Distributed.KVStore -error=Error ../../tools/launch.py -n 4 python dist_sync_kvstore.py
 
+# python: distributed kvstore with allreduce type
+juLog -name=Python.Distributed.KVStore.AllReduce -error=Error ../../3rdparty/mpich/build/bin/mpirun -n 4 python dist_allreduce_sync_kvstore.py
+
 # download data
 juLog -name=DownloadData bash ./download.sh
 


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services