You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by cj...@apache.org on 2017/11/21 03:20:52 UTC

[incubator-mxnet] branch master updated: 2bit gradient compression (#8728)

This is an automated email from the ASF dual-hosted git repository.

cjolivier01 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new cd47962  2bit gradient compression (#8728)
cd47962 is described below

commit cd47962fc4a670f1058bf797d71544c34da32c62
Author: Rahul Huilgol <ra...@gmail.com>
AuthorDate: Mon Nov 20 19:20:50 2017 -0800

    2bit gradient compression (#8728)
    
    * 2bit gradient compression
    
    * trigger CI
---
 example/image-classification/common/fit.py |  44 ++--
 example/rnn/lstm_bucketing.py              |   1 -
 include/mxnet/c_api.h                      |  13 +
 include/mxnet/kvstore.h                    |  15 ++
 python/mxnet/gluon/trainer.py              |  12 +-
 python/mxnet/kvstore.py                    |  62 +++++
 python/mxnet/module/bucketing_module.py    |  17 +-
 python/mxnet/module/module.py              |  11 +-
 src/c_api/c_api.cc                         |  14 ++
 src/kvstore/comm.h                         |  87 ++++++-
 src/kvstore/gradient_compression-inl.h     | 155 ++++++++++++
 src/kvstore/gradient_compression.cc        | 193 ++++++++++++++
 src/kvstore/gradient_compression.cu        |  40 +++
 src/kvstore/gradient_compression.h         | 138 ++++++++++
 src/kvstore/kvstore.cc                     |   2 +-
 src/kvstore/kvstore_dist.h                 | 388 +++++++++++++++++++++--------
 src/kvstore/kvstore_dist_server.h          | 143 +++++++++--
 src/kvstore/kvstore_local.h                |   7 +
 tests/nightly/dist_sync_kvstore.py         | 120 ++++++++-
 tests/nightly/test_kvstore.py              | 200 +++++++++++++--
 tools/bandwidth/measure.py                 |   6 +-
 21 files changed, 1501 insertions(+), 167 deletions(-)

diff --git a/example/image-classification/common/fit.py b/example/image-classification/common/fit.py
index 51a1abe..2b002c7 100755
--- a/example/image-classification/common/fit.py
+++ b/example/image-classification/common/fit.py
@@ -103,6 +103,11 @@ def add_fit_args(parser):
                        help='1 means test reading speed without training')
     train.add_argument('--dtype', type=str, default='float32',
                        help='precision: float32 or float16')
+    train.add_argument('--gc-type', type=str, default='none',
+                       help='type of gradient compression to use, \
+                             takes `2bit` or `none` for now')
+    train.add_argument('--gc-threshold', type=float, default=0.5,
+                       help='threshold for 2bit gradient compression')
     return train
 
 def fit(args, network, data_loader, **kwargs):
@@ -114,6 +119,9 @@ def fit(args, network, data_loader, **kwargs):
     """
     # kvstore
     kv = mx.kvstore.create(args.kv_store)
+    if args.gc_type != 'none':
+        kv.set_gradient_compression({'type': args.gc_type,
+                                     'threshold': args.gc_threshold})
 
     # logging
     head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
@@ -162,10 +170,10 @@ def fit(args, network, data_loader, **kwargs):
 
     lr_scheduler  = lr_scheduler
     optimizer_params = {
-            'learning_rate': lr,
-            'wd' : args.wd,
-            'lr_scheduler': lr_scheduler,
-            'multi_precision': True}
+        'learning_rate': lr,
+        'wd' : args.wd,
+        'lr_scheduler': lr_scheduler,
+        'multi_precision': True}
 
     # Only a limited number of optimizers have 'momentum' property
     has_momentum = {'sgd', 'dcasgd', 'nag'}
@@ -195,17 +203,17 @@ def fit(args, network, data_loader, **kwargs):
 
     # run
     model.fit(train,
-        begin_epoch        = args.load_epoch if args.load_epoch else 0,
-        num_epoch          = args.num_epochs,
-        eval_data          = val,
-        eval_metric        = eval_metrics,
-        kvstore            = kv,
-        optimizer          = args.optimizer,
-        optimizer_params   = optimizer_params,
-        initializer        = initializer,
-        arg_params         = arg_params,
-        aux_params         = aux_params,
-        batch_end_callback = batch_end_callbacks,
-        epoch_end_callback = checkpoint,
-        allow_missing      = True,
-        monitor            = monitor)
+              begin_epoch        = args.load_epoch if args.load_epoch else 0,
+              num_epoch          = args.num_epochs,
+              eval_data          = val,
+              eval_metric        = eval_metrics,
+              kvstore            = kv,
+              optimizer          = args.optimizer,
+              optimizer_params   = optimizer_params,
+              initializer        = initializer,
+              arg_params         = arg_params,
+              aux_params         = aux_params,
+              batch_end_callback = batch_end_callbacks,
+              epoch_end_callback = checkpoint,
+              allow_missing      = True,
+              monitor            = monitor)
diff --git a/example/rnn/lstm_bucketing.py b/example/rnn/lstm_bucketing.py
index 2e7bc65..0e7f064 100644
--- a/example/rnn/lstm_bucketing.py
+++ b/example/rnn/lstm_bucketing.py
@@ -48,7 +48,6 @@ parser.add_argument('--batch-size', type=int, default=32,
 parser.add_argument('--disp-batches', type=int, default=50,
                     help='show progress for every n batches')
 
-
 def tokenize_text(fname, vocab=None, invalid_label=-1, start_label=0):
     if not os.path.isfile(fname):
         raise IOError("Please use get_ptb_data.sh to download requied file (data/ptb.train.txt)")
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index fa8d995..9815786 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1544,6 +1544,19 @@ MXNET_DLL int MXInitPSEnv(mx_uint num_vars,
  */
 MXNET_DLL int MXKVStoreCreate(const char *type,
                               KVStoreHandle *out);
+
+/*!
+ * \brief Set parameters to use low-bit compressed gradients
+ * \param handle handle to the kvstore
+ * \param keys keys for compression parameters
+ * \param vals values for compression parameters
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXKVStoreSetGradientCompression(KVStoreHandle handle,
+                                              mx_uint num_params,
+                                              const char** keys,
+                                              const char** vals);
+
 /*!
  * \brief Delete a KVStore handle.
  * \param handle handle to the kvstore
diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h
index 1649c43..4e99a9c 100644
--- a/include/mxnet/kvstore.h
+++ b/include/mxnet/kvstore.h
@@ -31,6 +31,7 @@
 #include <string>
 #include <functional>
 #include <atomic>
+#include "../../src/kvstore/gradient_compression.h"
 #include "./ndarray.h"
 #if MXNET_USE_DIST_KVSTORE
 #include "ps/ps.h"
@@ -65,6 +66,14 @@ class KVStore {
    */
   inline const std::string& type() { return type_; }
 
+  /**
+   * \brief Set parameters to use low-bit compressed gradients
+   * \param kwargs (key, value) string pairs of compression parameters,
+   *        e.g. `type` (such as 2bit) and `threshold` for 2bit compression
+   */
+  virtual void SetGradientCompression(const std::vector<std::pair<std::string, std::string> >
+                                      & kwargs) = 0;
+
   /*!
    * \brief Initialize a list of key-value pair to the store.
    *
@@ -388,6 +397,12 @@ class KVStore {
    */
   std::string type_;
 
+  /** \brief Gradient compression object; starts in CompressionType::kNone mode.
+   * Used if SetGradientCompression sets the type.
+   * Currently there is no support for un-setting gradient compression
+   */
+  std::shared_ptr<kvstore::GradientCompression> gradient_compression_;
+
   /**
    * \brief whether to do barrier when finalize
    */
diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index 115d1ff..f3a1460 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -44,6 +44,11 @@ class Trainer(object):
     kvstore : str or KVStore
         kvstore type for multi-gpu and distributed training. See help on
         :any:`mxnet.kvstore.create` for more information.
+    compression_params : dict
+        Specifies type of gradient compression and additional arguments depending
+        on the type of compression being used. For example, 2bit compression requires a threshold.
+        Arguments would then be {'type':'2bit', 'threshold':0.5}
+        See mxnet.KVStore.set_gradient_compression method for more details on gradient compression.
 
     Properties
     ----------
@@ -51,7 +56,8 @@ class Trainer(object):
         The current learning rate of the optimizer. Given an Optimizer object
         optimizer, its learning rate can be accessed as optimizer.learning_rate.
     """
-    def __init__(self, params, optimizer, optimizer_params=None, kvstore='device'):
+    def __init__(self, params, optimizer, optimizer_params=None, kvstore='device',
+                 compression_params=None):
         if isinstance(params, (dict, ParameterDict)):
             params = list(params.values())
         if not isinstance(params, (list, tuple)):
@@ -65,7 +71,7 @@ class Trainer(object):
                     "First argument must be a list or dict of Parameters, " \
                     "got list of %s."%(type(param)))
             self._params.append(param)
-
+        self._compression_params = compression_params
         optimizer_params = optimizer_params if optimizer_params else {}
         self._scale = optimizer_params.get('rescale_grad', 1.0)
         self._contexts = self._check_contexts()
@@ -104,6 +110,8 @@ class Trainer(object):
         kvstore, update_on_kvstore = _create_kvstore(self._kvstore, len(self._contexts),
                                                      arg_arrays)
         if kvstore:
+            if self._compression_params:
+                kvstore.set_gradient_compression(self._compression_params)
             if 'dist' in kvstore.type:
                 update_on_kvstore = False
             for i, param in enumerate(self._params):
diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py
index 8625303..bf42455 100644
--- a/python/mxnet/kvstore.py
+++ b/python/mxnet/kvstore.py
@@ -64,6 +64,16 @@ def _ctype_key_value(keys, vals):
                  else c_array_buf(ctypes.c_int, array('i', [keys] * len(vals)))
         return (c_keys, c_handle_array(vals), use_str_keys)
 
+def _ctype_dict(param_dict):
+    """
+    Returns ctype arrays for the keys and values (converted to strings) of a dictionary.
+    """
+    assert(isinstance(param_dict, dict)), \
+        "unexpected type for param_dict: " + str(type(param_dict))
+    c_keys = c_array(ctypes.c_char_p, [c_str(k) for k in param_dict.keys()])
+    c_vals = c_array(ctypes.c_char_p, [c_str(str(v)) for v in param_dict.values()])
+    return (c_keys, c_vals)
+
 def _updater_wrapper(updater):
     """A wrapper for the user-defined handle."""
     def updater_handle(key, lhs_handle, rhs_handle, _):
@@ -350,6 +360,58 @@ class KVStore(object):
             check_call(_LIB.MXKVStorePullRowSparse(
                 self.handle, mx_uint(len(ckeys)), ckeys, cvals, crow_ids, ctypes.c_int(priority)))
 
+    def set_gradient_compression(self, compression_params):
+        """ Specifies type of low-bit quantization for gradient compression \
+         and additional arguments depending on the type of compression being used.
+
+        2bit Gradient Compression takes a positive float `threshold`.
+        The technique works by thresholding values such that positive values in the
+        gradient above threshold will be set to threshold. Negative values whose absolute
+        values are higher than threshold, will be set to the negative of threshold.
+        Values whose absolute values are less than threshold will be set to 0.
+        By doing so, each value in the gradient is in one of three states. 2bits are
+        used to represent these states, and every 16 float values in the original
+        gradient can be represented using one float. This compressed representation
+        can reduce communication costs. The difference between these thresholded values and
+        original values is stored at the sender's end as residual and added to the
+        gradient in the next iteration.
+
+        When kvstore is 'local', gradient compression is used to reduce communication
+        between multiple devices (gpus). Gradient is quantized on each GPU which
+        computed the gradients, then sent to the GPU which merges the gradients. This
+        receiving GPU dequantizes the gradients and merges them. Note that this
+        increases memory usage on each GPU because of the residual array stored.
+
+        When kvstore is 'dist', gradient compression is used to reduce communication
+        from worker to server. Gradient is quantized on each worker which
+        computed the gradients, then sent to the server which dequantizes
+        this data and merges the gradients from each worker. Note that this
+        increases CPU memory usage on each worker because of the residual array stored.
+        Only worker to server communication is compressed in this setting.
+        If each machine has multiple GPUs, currently this GPU to GPU or GPU to CPU communication
+        is not compressed. Server to worker communication (in the case of pull)
+        is also not compressed.
+
+        To use 2bit compression, we need to specify `type` as `2bit`.
+        Only specifying `type` would use default value for the threshold.
+        To completely specify the arguments for 2bit compression, we would need to pass
+        a dictionary which includes `threshold` like:
+        {'type': '2bit', 'threshold': 0.5}
+
+        Parameters
+        ----------
+        compression_params : dict
+            A dictionary specifying the type and parameters for gradient compression.
+            The key `type` in this dictionary is a
+            required string argument and specifies the type of gradient compression.
+            Currently `type` can only be `2bit`.
+            Other keys in this dictionary are optional and specific to the type
+            of gradient compression.
+        """
+        ckeys, cvals = _ctype_dict(compression_params)
+        check_call(_LIB.MXKVStoreSetGradientCompression(self.handle,
+                                                        mx_uint(len(compression_params)),
+                                                        ckeys, cvals))
 
     def set_optimizer(self, optimizer):
         """ Registers an optimizer with the kvstore.
diff --git a/python/mxnet/module/bucketing_module.py b/python/mxnet/module/bucketing_module.py
index fa92c5d..0bea260 100644
--- a/python/mxnet/module/bucketing_module.py
+++ b/python/mxnet/module/bucketing_module.py
@@ -54,10 +54,16 @@ class BucketingModule(BaseModule):
         Instead they are initialized to 0 and can be set by set_states()
     group2ctxs : list of dict of str to context
         Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
+    compression_params : dict
+        Specifies type of gradient compression and additional arguments depending
+        on the type of compression being used. For example, 2bit compression requires a threshold.
+        Arguments would then be {'type':'2bit', 'threshold':0.5}
+        See mxnet.KVStore.set_gradient_compression method for more details on gradient compression.
     """
     def __init__(self, sym_gen, default_bucket_key=None, logger=logging,
                  context=ctx.cpu(), work_load_list=None,
-                 fixed_param_names=None, state_names=None, group2ctxs=None):
+                 fixed_param_names=None, state_names=None, group2ctxs=None,
+                 compression_params=None):
         super(BucketingModule, self).__init__(logger=logger)
 
         assert default_bucket_key is not None
@@ -75,6 +81,7 @@ class BucketingModule(BaseModule):
         _check_input_names(symbol, state_names, "state", True)
         _check_input_names(symbol, fixed_param_names, "fixed_param", True)
 
+        self._compression_params = compression_params
         self._fixed_param_names = fixed_param_names
         self._state_names = state_names
         self._context = context
@@ -323,7 +330,9 @@ class BucketingModule(BaseModule):
         module = Module(symbol, data_names, label_names, logger=self.logger,
                         context=self._context, work_load_list=self._work_load_list,
                         fixed_param_names=self._fixed_param_names,
-                        state_names=self._state_names, group2ctxs=self._group2ctxs)
+                        state_names=self._state_names,
+                        group2ctxs=self._group2ctxs,
+                        compression_params=self._compression_params)
         module.bind(data_shapes, label_shapes, for_training, inputs_need_grad,
                     force_rebind=False, shared_module=None, grad_req=grad_req)
         self._curr_module = module
@@ -353,7 +362,9 @@ class BucketingModule(BaseModule):
                             logger=self.logger, context=self._context,
                             work_load_list=self._work_load_list,
                             fixed_param_names=self._fixed_param_names,
-                            state_names=self._state_names, group2ctxs=self._group2ctxs)
+                            state_names=self._state_names,
+                            group2ctxs=self._group2ctxs,
+                            compression_params=self._compression_params)
             module.bind(data_shapes, label_shapes, self._curr_module.for_training,
                         self._curr_module.inputs_need_grad,
                         force_rebind=False, shared_module=self._buckets[self._default_bucket_key])
diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py
index 8301330..a9c6516 100644
--- a/python/mxnet/module/module.py
+++ b/python/mxnet/module/module.py
@@ -61,10 +61,16 @@ class Module(BaseModule):
         Instead they are initialized to 0 and can be set by `set_states()`.
     group2ctxs : list of dict of str to context
         Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
+    compression_params : dict
+        Specifies type of gradient compression and additional arguments depending
+        on the type of compression being used. For example, 2bit compression requires a threshold.
+        Arguments would then be {'type':'2bit', 'threshold':0.5}
+        See mxnet.KVStore.set_gradient_compression method for more details on gradient compression.
     """
     def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
                  logger=logging, context=ctx.cpu(), work_load_list=None,
-                 fixed_param_names=None, state_names=None, group2ctxs=None):
+                 fixed_param_names=None, state_names=None, group2ctxs=None,
+                 compression_params=None):
         super(Module, self).__init__(logger=logger)
 
         if isinstance(context, ctx.Context):
@@ -103,6 +109,7 @@ class Module(BaseModule):
         self._aux_params = None
         self._params_dirty = False
 
+        self._compression_params = compression_params
         self._optimizer = None
         self._kvstore = None
         self._update_on_kvstore = None
@@ -525,6 +532,8 @@ class Module(BaseModule):
         self._updater = None
 
         if kvstore:
+            if self._compression_params:
+                kvstore.set_gradient_compression(self._compression_params)
             # copy initialized local parameters to kvstore
             _initialize_kvstore(kvstore=kvstore,
                                 param_arrays=self._exec_group.param_arrays,
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 0dde004..027f00b 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -748,6 +748,20 @@ int MXKVStoreCreate(const char *type,
   API_END();
 }
 
+int MXKVStoreSetGradientCompression(KVStoreHandle handle, mx_uint num_params,
+                                    const char** keys, const char** vals) {
+  API_BEGIN();
+  std::vector<std::pair<std::string, std::string> > params;
+  for (mx_uint i = 0; i < num_params; ++i) {
+    std::pair<std::string, std::string> p;
+    p.first = keys[i];
+    p.second = vals[i];
+    params.push_back(p);
+  }
+  static_cast<KVStore*>(handle)->SetGradientCompression(params);
+  API_END();
+}
+
 int MXKVStoreFree(KVStoreHandle handle) {
   API_BEGIN();
   delete static_cast<KVStore*>(handle);
diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h
index fcf1e6b..5e15c2a 100644
--- a/src/kvstore/comm.h
+++ b/src/kvstore/comm.h
@@ -31,6 +31,7 @@
 #include <tuple>
 #include <thread>
 #include "mxnet/ndarray.h"
+#include "gradient_compression.h"
 #include "../ndarray/ndarray_function.h"
 #include "../operator/tensor/sparse_retain-inl.h"
 namespace mxnet {
@@ -80,8 +81,18 @@ class Comm {
     return pinned_ctx_;
   }
 
+  /**
+   * \brief Sets gradient compression parameters to be able to
+   * perform reduce with compressed gradients
+   */
+  void SetGradientCompression(std::shared_ptr<GradientCompression> gc) {
+    gc_ = gc;
+  }
+
  protected:
   Context pinned_ctx_;
+
+  std::shared_ptr<GradientCompression> gc_;
 };
 
 /**
@@ -485,14 +496,7 @@ class CommDevice : public Comm {
     }
   }
 
-  const NDArray& Reduce(int key, const std::vector<NDArray>& src,
-                        int priority) override {
-    // avoid extra copy for single device, but it may bring problems for
-    // abnormal usage of kvstore
-    if (src.size() == 1) {
-      return src[0];
-    }
-
+  void InitBuffersAndComm(const std::vector<NDArray>& src) {
     if (!inited_) {
       std::vector<Context> devs;
       for (const auto& a : src) {
@@ -503,7 +507,23 @@ class CommDevice : public Comm {
         EnableP2P(devs);
       }
     }
+  }
+
+  const NDArray& Reduce(int key, const std::vector<NDArray>& src,
+                        int priority) override {
+    // when this reduce is called from kvstore_dist, gc_ is not set,
+    // so compression is not applied twice in dist_sync_device
+    if ((gc_ != nullptr) && (gc_->get_type() != CompressionType::kNone)) {
+      return ReduceCompressed(key, src, priority);
+    }
+
+    // avoid extra copy for single device, but it may bring problems for
+    // abnormal usage of kvstore
+    if (src.size() == 1) {
+      return src[0];
+    }
 
+    InitBuffersAndComm(src);
     auto& buf = merge_buf_[key];
     std::vector<NDArray> reduce(src.size());
     CopyFromTo(src[0], &(buf.merged), priority);
@@ -526,7 +546,52 @@ class CommDevice : public Comm {
     }
 
     ElementwiseSum(reduce, &buf.merged);
+    return buf.merged;
+  }
+
+  const NDArray& ReduceCompressed(int key, const std::vector<NDArray>& src,
+                                  int priority) {
+    InitBuffersAndComm(src);
+    auto& buf = merge_buf_[key];
+    std::vector<NDArray> reduce(src.size());
+    if (buf.copy_buf.empty()) {
+      // one buf for each context
+      buf.copy_buf.resize(src.size());
+      buf.compressed_recv_buf.resize(src.size());
+      buf.compressed_send_buf.resize(src.size());
+      buf.residual.resize(src.size());
 
+      for (size_t i = 0; i < src.size(); ++i) {
+        buf.copy_buf[i] = NDArray(buf.merged.shape(), buf.merged.ctx(),
+                                  false, buf.merged.dtype());
+        buf.residual[i] = NDArray(buf.merged.shape(), src[i].ctx(),
+                                  false, buf.merged.dtype());
+        buf.residual[i] = 0;
+        int64_t small_size = gc_->GetCompressedSize(buf.merged.shape().Size());
+        buf.compressed_recv_buf[i] = NDArray(TShape{small_size}, buf.merged.ctx(),
+                                        false, buf.merged.dtype());
+        buf.compressed_send_buf[i] = NDArray(TShape{small_size}, src[i].ctx(),
+                                        false, buf.merged.dtype());
+      }
+    }
+
+    for (size_t i = 0; i < src.size(); ++i) {
+      // compress before copy
+      // this is done even if the data is on same context as copy_buf because
+      // we don't want the training to be biased towards data on this GPU
+      gc_->Quantize(src[i], &(buf.compressed_send_buf[i]), &(buf.residual[i]), priority);
+
+      if (buf.compressed_send_buf[i].ctx() != buf.compressed_recv_buf[i].ctx()) {
+        CopyFromTo(buf.compressed_send_buf[i], &(buf.compressed_recv_buf[i]), priority);
+      } else {
+        // avoid memory copy when they are on same context
+        buf.compressed_recv_buf[i] = buf.compressed_send_buf[i];
+      }
+
+      gc_->Dequantize(buf.compressed_recv_buf[i], &(buf.copy_buf[i]), priority);
+      reduce[i] = buf.copy_buf[i];
+    }
+    ElementwiseSum(reduce, &buf.merged);
     return buf.merged;
   }
 
@@ -639,6 +704,12 @@ class CommDevice : public Comm {
     NDArray merged;
     /// \brief the gpu buffer
     std::vector<NDArray> copy_buf;
+    /// \brief the residual buffer for gradient compression
+    std::vector<NDArray> residual;
+    /// \brief the small buffer for compressed data in sender
+    std::vector<NDArray> compressed_send_buf;
+    /// \brief the small buffer for compressed data in receiver
+    std::vector<NDArray> compressed_recv_buf;
   };
   std::unordered_map<int, BufferEntry> merge_buf_;
   bool inited_;
diff --git a/src/kvstore/gradient_compression-inl.h b/src/kvstore/gradient_compression-inl.h
new file mode 100644
index 0000000..9b69bd1
--- /dev/null
+++ b/src/kvstore/gradient_compression-inl.h
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file gradient_compression-inl.h
+ * \author Rahul Huilgol
+ * \brief Declares and defines functions used to quantize and dequantize data
+ */
+#ifndef MXNET_KVSTORE_GRADIENT_COMPRESSION_INL_H_
+#define MXNET_KVSTORE_GRADIENT_COMPRESSION_INL_H_
+
+#include <vector>
+#include "../operator/mxnet_op.h"
+
+namespace mxnet {
+namespace kvstore {
+
+// these gpu functions are defined in gradient_compression.cu
+void Quantize2BitImpl(mshadow::Stream<mshadow::gpu> *s, const std::vector<mxnet::TBlob> &inputs,
+                      const float threshold);
+void Dequantize2BitImpl(mshadow::Stream<mshadow::gpu> *s, const std::vector<mxnet::TBlob> &inputs,
+                        const float threshold);
+
+struct quantize_2bit {
+  MSHADOW_XINLINE static void Map(int out_block_id,
+                                  int original_size,
+                                  float *out,
+                                  float *grad,
+                                  float *residual,
+                                  const float neg_threshold,
+                                  const float pos_threshold) {
+    // this block contains the compressed representation of
+    // up to 16 values starting from out_block_id*16
+    float *compr_block = out + out_block_id;
+    // init to 0
+    *compr_block = 0;
+    // start and end are indices in original grad array
+    const int start = out_block_id << 4;
+    const int end = (start + 16 <= original_size) ? start + 16 : original_size;
+    // cast as char* to manipulate bits of float addresses
+    char *block_ptr = reinterpret_cast < char * > (compr_block);
+    // masks to set bits when value meets pos_threshold
+    // 0xc0 is the mask used when a value is represented by the first two bits of a byte
+    // 0xc0 sets those first two bits to 11
+    const uint8_t posbits[] = {0xc0, 0x30, 0x0c, 0x03};
+    // masks to set bits when value meets neg_threshold
+    const uint8_t negbits[] = {0x80, 0x20, 0x08, 0x02};
+    for (int i = start; i < end; i++) {
+      // adds offset to reach appropriate byte
+      char *curr_byte = block_ptr + ((i - start) >> 2);
+      // adds gradient to existing residual to get updated grad
+      residual[i] += grad[i];
+      if (residual[i] >= pos_threshold) {
+        // set data to 11
+        *curr_byte |= posbits[(i & 3)];
+        // reduce residual by pos_threshold
+        residual[i] -= pos_threshold;
+      } else if (residual[i] <= neg_threshold) {
+        // set data to 10
+        *curr_byte |= negbits[(i & 3)];
+        residual[i] -= neg_threshold;
+      }
+    }
+  }
+};
+
+template<typename xpu>
+void Quantize2BitKernelLaunch(mshadow::Stream<xpu> *s, const std::vector<mxnet::TBlob> &inputs,
+                              const float threshold) {
+  mxnet::op::mxnet_op::Kernel<quantize_2bit, xpu>
+    ::Launch(s,
+            inputs[2].Size(),         // compressed array size
+            inputs[0].Size(),         // original size
+            inputs[2].dptr<float>(),  // compressed array
+            inputs[0].dptr<float>(),  // original array
+            inputs[1].dptr<float>(),  // residual array
+            -1 *threshold,            // negative threshold
+            threshold);               // positive threshold
+}
+
+struct dequantize_2bit {
+  MSHADOW_XINLINE static void Map(int i,
+                                  float *out,
+                                  float *in,
+                                  const float neg_threshold,
+                                  const float pos_threshold) {
+    // get position of dequantized value to fill
+    float *outval = out + i;
+    // gets byte which holds quantized value for this position
+    char *ch_ptr = reinterpret_cast<char *>(in + (i >> 4));
+    ch_ptr += ((i & 15) >> 2);
+    // masks used to decode the quantized 2-bit values
+    const uint8_t posbits[] = {0xc0, 0x30, 0x0c, 0x03};
+    const uint8_t negbits[] = {0x80, 0x20, 0x08, 0x02};
+    // col denotes which two bits of a byte are set for this value
+    // col=0 implies first two bits, col=3 implies last two bits,...
+    const int col = i & 3;
+    const uint8_t mask = posbits[col];
+    const uint8_t negmask = negbits[col];
+    const uint8_t masked = *ch_ptr & mask;
+    if (masked == mask) {
+      *outval = pos_threshold;
+    } else if (masked == negmask) {
+      // use posbits for mask as posbits are both 1s
+      // then compare masked with negbits to see if only negbits were set
+      *outval = neg_threshold;
+    } else {
+      *outval = 0;
+    }
+  }
+};
+
+template<typename xpu>
+void Dequantize2BitKernelLaunch(mshadow::Stream<xpu> *s, const std::vector<mxnet::TBlob> &inputs,
+                                const float threshold) {
+  mxnet::op::mxnet_op::Kernel<dequantize_2bit, xpu>
+  ::Launch(s,
+          inputs[1].Size(),         // original size
+          inputs[1].dptr<float>(),  // out array
+          inputs[0].dptr<float>(),  // compressed array
+          -1 *threshold,            // negative threshold
+          threshold);               // positive threshold
+}
+
+inline void Quantize2BitImpl(mshadow::Stream<mshadow::cpu> *s,
+                             const std::vector<mxnet::TBlob> &inputs,
+                             const float threshold) {
+  Quantize2BitKernelLaunch(s, inputs, threshold);
+}
+
+inline void Dequantize2BitImpl(mshadow::Stream<mshadow::cpu> *s,
+                               const std::vector<mxnet::TBlob> &inputs,
+                               const float threshold) {
+  Dequantize2BitKernelLaunch(s, inputs, threshold);
+}
+}  // namespace kvstore
+}  // namespace mxnet
+
+#endif  // MXNET_KVSTORE_GRADIENT_COMPRESSION_INL_H_
diff --git a/src/kvstore/gradient_compression.cc b/src/kvstore/gradient_compression.cc
new file mode 100644
index 0000000..b8c626c
--- /dev/null
+++ b/src/kvstore/gradient_compression.cc
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file gradient_compression.cc
+ * \brief Gradient compression for kvstore
+ * \author Rahul Huilgol
+ */
+
+#include <iterator>
+#include <sstream>
+#include <string>
+#include <vector>
+#include "gradient_compression.h"
+#include "gradient_compression-inl.h"
+
+namespace mxnet {
+namespace kvstore {
+
+/*!
+ * \brief Splits a string into smaller strings using char as delimiter
+ * Example: "a,b,c,,d" is split into ["a","b","c","","d"].
+ * An empty string yields no tokens; a trailing delimiter adds no empty token.
+ * \param s string to split
+ * \param delim char to split string around
+ * \param result output iterator that receives the extracted tokens
+ */
+template<typename Out>
+void split(const std::string &s, const char delim, Out result) {
+  std::stringstream ss;
+  ss.str(s);
+  std::string item;
+  while (std::getline(ss, item, delim)) {
+    *(result++) = item;
+  }
+}
+
+DMLC_REGISTER_PARAMETER(GradientCompressionParam);
+
+GradientCompression::GradientCompression() {
+  type_ = CompressionType::kNone;
+}
+
+/*!
+ * \brief Parses kwargs through GradientCompressionParam (unknown keys are allowed)
+ * and enables the requested compression. Only type "2bit" is supported and the
+ * threshold must be positive; anything else is fatal.
+ */
+void GradientCompression::SetParams(const std::vector<std::pair<std::string, std::string> >
+                                    & kwargs) {
+  GradientCompressionParam params;
+  params.InitAllowUnknown(kwargs);
+  CHECK_GT(params.threshold, 0) << "threshold must be greater than 0";
+  if (params.type == "2bit") {
+    SetTwoBitCompression(params.threshold);
+  } else {
+    LOG(FATAL) << "Unknown type for gradient compression " << params.type;
+  }
+}
+
+CompressionType GradientCompression::get_type() {
+  return type_;
+}
+
+// Integer value of the CompressionType enum as a string, e.g. "1" for kTwoBit.
+std::string GradientCompression::get_type_str() {
+  return std::to_string(static_cast<int>(type_));
+}
+
+void GradientCompression::SetTwoBitCompression(const float threshold) {
+  type_ = CompressionType::kTwoBit;
+  threshold_ = threshold;
+}
+
+/*!
+ * \brief Serializes the settings as "<type>" or, for 2bit, "<type>,<threshold>".
+ * std::to_string(float) formats like printf("%f") with six decimals, so a small
+ * threshold such as 1e-7 would be sent as "0.000000" and decoded as 0 on the
+ * receiving side. Nine significant digits are enough to round-trip any float.
+ */
+std::string GradientCompression::EncodeParams() {
+  std::ostringstream oss;
+  oss.precision(9);
+  oss << get_type_str();
+  if (type_ == CompressionType::kTwoBit) {
+    oss << "," << threshold_;
+  }
+  return oss.str();
+}
+
+/*!
+ * \brief Inverse of EncodeParams: restores type_ and, when present, threshold_.
+ */
+void GradientCompression::DecodeParams(const std::string &s) {
+  std::vector<std::string> elems;
+  split(s, ',', std::back_inserter(elems));
+  // split() yields no tokens for an empty string; never index into an empty vector
+  CHECK(!elems.empty()) << "Cannot decode empty gradient compression params";
+  type_ = static_cast<CompressionType>(std::stoi(elems[0]));
+  if (elems.size() > 1 && !elems[1].empty()) {
+    threshold_ = std::stof(elems[1]);
+  }
+}
+
+/*!
+ * \brief Number of original values packed into one element of the compressed array.
+ * For 2bit each byte holds four values and each float four bytes, hence 16.
+ */
+int GradientCompression::GetCompressionFactor() {
+  if (type_ == CompressionType::kTwoBit) {
+    return 16;
+  } else {
+    LOG(FATAL) << "Unsupported compression type: " << get_type_str();
+    return 0;
+  }
+}
+
+// Size of the compressed array: original_size / factor, rounded up.
+int64_t GradientCompression::GetCompressedSize(const int64_t original_size) {
+  const int factor = GetCompressionFactor();
+  return ((original_size % factor == 0) ?
+          original_size / factor :
+          original_size / factor + 1);
+}
+
+/*!
+ * \brief Schedules 2bit quantization of `from` into `to` on the engine, using and
+ * updating `residual`. `from` and `to` must both be on CPU or both on GPU
+ * (GPU requires a CUDA build). The op reads `from` and writes `to` and `residual`.
+ */
+void GradientCompression::Quantize(const mxnet::NDArray &from, mxnet::NDArray *to,
+                                   mxnet::NDArray *residual, const int priority) {
+  CHECK(from.shape().ndim() != 0) << "source operand has zero dimension shape";
+  CHECK(to->shape().ndim() != 0) << "destination operand has zero dimension shape";
+  CHECK(residual->shape().ndim() != 0) << "residual operand has zero dimension shape";
+  const int a = from.ctx().dev_mask();
+  const int b = to->ctx().dev_mask();
+  const float threshold = threshold_;
+  if (type_ == CompressionType::kTwoBit) {
+    if (a == mshadow::cpu::kDevMask && b == mshadow::cpu::kDevMask) {
+      mxnet::Engine::Get()->PushSync([from, to, residual, threshold](mxnet::RunContext ctx) {
+        std::vector<mxnet::TBlob> inputs = {from.data(), residual->data(), to->data()};
+        Quantize2BitImpl(ctx.get_stream<mshadow::cpu>(), inputs, threshold);
+      }, from.ctx(), {from.var()}, {to->var(), residual->var()},
+      mxnet::FnProperty::kNormal, priority, PROFILER_MESSAGE("QuantizeCPU"));
+    } else {
+#if MXNET_USE_CUDA
+      if (a == mshadow::gpu::kDevMask && b == mshadow::gpu::kDevMask) {
+        mxnet::Engine::Get()->PushSync([from, to, residual, threshold](mxnet::RunContext ctx) {
+          std::vector<mxnet::TBlob> inputs = {from.data(), residual->data(), to->data()};
+          Quantize2BitImpl(ctx.get_stream<mshadow::gpu>(), inputs, threshold);
+          // Wait GPU kernel to complete
+          ctx.get_stream<mshadow::gpu>()->Wait();
+        }, from.ctx(), {from.var()}, {to->var(), residual->var()},
+        mxnet::FnProperty::kNormal, priority, PROFILER_MESSAGE("QuantizeGPU"));
+      } else {
+        LOG(FATAL) << "unknown device mask";
+      }
+#else
+      LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+#endif
+    }
+  } else {
+    LOG(FATAL) << "Unsupported quantization of type " << get_type_str();
+  }
+}
+
+/*!
+ * \brief Schedules 2bit dequantization of `from` into `to` on the engine.
+ * Same device rules as Quantize; the op reads `from` and writes `to`.
+ */
+void GradientCompression::Dequantize(const mxnet::NDArray &from, mxnet::NDArray *to,
+                                     const int priority) {
+  CHECK(from.shape().ndim() != 0) << "source operands has zero dimension shape";
+  CHECK(to->shape().ndim() != 0) << "destination operand has zero dimension shape";
+  const int a = from.ctx().dev_mask();
+  const int b = to->ctx().dev_mask();
+  const float threshold = threshold_;
+  if (type_ == CompressionType::kTwoBit) {
+    if (a == mshadow::cpu::kDevMask && b == mshadow::cpu::kDevMask) {
+      mxnet::Engine::Get()->PushSync([from, to, threshold](mxnet::RunContext ctx) {
+        std::vector<mxnet::TBlob> inputs = {from.data(), to->data()};
+        Dequantize2BitImpl(ctx.get_stream<mshadow::cpu>(), inputs, threshold);
+      }, from.ctx(), {from.var()}, {to->var()},
+      mxnet::FnProperty::kNormal, priority, PROFILER_MESSAGE("DequantizeCPU"));
+    } else {
+#if MXNET_USE_CUDA
+      if (a == mshadow::gpu::kDevMask && b == mshadow::gpu::kDevMask) {
+        mxnet::Engine::Get()->PushSync([from, to, threshold](mxnet::RunContext ctx) {
+          std::vector<mxnet::TBlob> inputs = {from.data(), to->data()};
+          Dequantize2BitImpl(ctx.get_stream<mshadow::gpu>(), inputs, threshold);
+          // Wait GPU kernel to complete
+          ctx.get_stream<mshadow::gpu>()->Wait();
+        }, from.ctx(), {from.var()}, {to->var()},
+        mxnet::FnProperty::kNormal, priority, PROFILER_MESSAGE("DequantizeGPU"));
+      } else {
+        LOG(FATAL) << "unknown device mask";
+      }
+#else
+      LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+#endif
+    }
+  } else {
+    LOG(FATAL) << "Unsupported dequantization of type " << get_type_str();
+  }
+}
+
+}  // namespace kvstore
+}  // namespace mxnet
+
diff --git a/src/kvstore/gradient_compression.cu b/src/kvstore/gradient_compression.cu
new file mode 100644
index 0000000..b0d9662
--- /dev/null
+++ b/src/kvstore/gradient_compression.cu
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file gradient_compression.cu
+ * \author Rahul Huilgol
+ * \brief GPU overloads of Quantize2BitImpl/Dequantize2BitImpl (CPU ones: inline in -inl.h)
+ */
+
+#include "gradient_compression-inl.h"
+
+namespace mxnet {
+namespace kvstore {
+void Quantize2BitImpl(mshadow::Stream<gpu>* s, const std::vector<TBlob>& inputs,
+                      const float threshold) {  // forwards to the templated launcher
+  Quantize2BitKernelLaunch(s, inputs, threshold);
+}
+
+void Dequantize2BitImpl(mshadow::Stream<gpu>* s, const std::vector<TBlob>& inputs,
+                        const float threshold) {  // forwards to the templated launcher
+  Dequantize2BitKernelLaunch(s, inputs, threshold);
+}
+}  // namespace kvstore
+}  // namespace mxnet
diff --git a/src/kvstore/gradient_compression.h b/src/kvstore/gradient_compression.h
new file mode 100644
index 0000000..f40b45f
--- /dev/null
+++ b/src/kvstore/gradient_compression.h
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file gradient_compression.h
+ * \brief Gradient compression for kvstore
+ * \author Rahul Huilgol
+ */
+
+#ifndef MXNET_KVSTORE_GRADIENT_COMPRESSION_H_
+#define MXNET_KVSTORE_GRADIENT_COMPRESSION_H_
+#include <dmlc/parameter.h>
+#include <string>
+#include <utility>
+#include <vector>
+#include "mxnet/ndarray.h"
+
+namespace mxnet {
+namespace kvstore {
+
+enum class CompressionType {
+  kNone, kTwoBit  // underlying ints 0 and 1; get_type_str() transmits this int as text
+};
+
+struct GradientCompressionParam : public dmlc::Parameter<GradientCompressionParam> {
+  std::string type;  // compression scheme; GradientCompression::SetParams accepts only "2bit"
+  float threshold;  // default 0.5; GradientCompression::SetParams rejects values <= 0
+  DMLC_DECLARE_PARAMETER(GradientCompressionParam) {
+    DMLC_DECLARE_FIELD(type)
+      .describe("Type of gradient compression to use, like `2bit` for example");
+    DMLC_DECLARE_FIELD(threshold).set_default(0.5)
+      .describe("Threshold to use for 2bit gradient compression");
+  }
+};
+
+class GradientCompression {
+ public:
+  GradientCompression();
+
+  virtual ~GradientCompression() {}
+
+  /*!
+   * \brief parses kwargs and enables the requested gradient compression
+   * \param kwargs a vector of pair of strings. A pair represents key and value
+   * of the parameter, parsed by GradientCompressionParam (only type "2bit" is supported)
+   */
+  void SetParams(const std::vector<std::pair<std::string, std::string> >& kwargs);
+
+  /*!
+   * \brief returns the configured compression type (kNone until one is set)
+   */
+  CompressionType get_type();
+
+  /*!
+   * \brief returns the integer value of the compression type enum as a string, e.g. "1"
+   */
+  std::string get_type_str();
+
+  /*!
+   * \brief sets two bit gradient compression
+   * \param threshold float value used for thresholding gradients
+   */
+  void SetTwoBitCompression(const float threshold);
+
+  /*!
+   * \brief encodes type (and, for 2bit, threshold) as "<type>[,<threshold>]"
+   */
+  std::string EncodeParams();
+
+  /*!
+   * \brief decodes a string produced by EncodeParams into type_ and threshold_
+   */
+  void DecodeParams(const std::string &s);
+
+  /*!
+   * \brief returns compression factor: how many original values are packed into
+   * one element of the compressed array (16 for 2bit)
+   */
+  int GetCompressionFactor();
+
+  /*!
+   * \brief returns the compressed size for original_size values, rounded up
+   */
+  int64_t GetCompressedSize(const int64_t original_size);
+
+  /*!
+   * \brief Issues quantize operation to be scheduled by the engine
+   * Compresses `from` into `to` and accumulates the quantization error
+   * into `residual`, using the quantization of type `type_`
+   * \param from the ndarray containing original data to be quantized
+   * \param to the target ndarray which contains quantized data
+   * \param residual the ndarray which accumulates quantization error
+   * \param priority Priority of the action.
+   */
+  void Quantize(const mxnet::NDArray &from, mxnet::NDArray *to,
+                mxnet::NDArray *residual, const int priority);
+
+  /*!
+   * \brief Issues dequantize operation to be scheduled by the engine
+   * Decompresses `from` into `to` using the current `type_` and `threshold_`
+   * \param from the ndarray containing quantized data
+   * \param to the target ndarray which contains final dequantized data
+   * \param priority Priority of the action.
+   */
+  void Dequantize(const mxnet::NDArray &from, mxnet::NDArray *to, const int priority);
+
+ private:
+  /*!
+   * \brief denotes the type of gradient compression which has been set
+   */
+  CompressionType type_;
+
+  /*!
+   * \brief denotes threshold used for quantization and dequantization
+   * Must be a positive value (SetParams checks this). Dequantized values are
+   * `threshold_`, -1*`threshold_` or 0 (see the dequantize_2bit kernel)
+   */
+  float threshold_ = 0;
+};
+}  // namespace kvstore
+}  // namespace mxnet
+#endif  // MXNET_KVSTORE_GRADIENT_COMPRESSION_H_
diff --git a/src/kvstore/kvstore.cc b/src/kvstore/kvstore.cc
index ac37d5d..ac15873 100644
--- a/src/kvstore/kvstore.cc
+++ b/src/kvstore/kvstore.cc
@@ -49,7 +49,7 @@ KVStore* KVStore::Create(const char *type_name) {
     kv = new kvstore::KVStoreDist(use_device_comm);
     if (!has("_async") && kv->IsWorkerNode() && kv->get_rank() == 0) {
       // configure the server to be the sync mode
-      kv->SendCommandToServers(kvstore::kSyncMode, "");
+      kv->SendCommandToServers(static_cast<int>(kvstore::CommandType::kSyncMode), "");
     }
 #else
     LOG(FATAL) << "compile with USE_DIST_KVSTORE=1 to use " << tname;
diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h
index 571767d..b00d0de 100644
--- a/src/kvstore/kvstore_dist.h
+++ b/src/kvstore/kvstore_dist.h
@@ -69,7 +69,7 @@ class KVStoreDist : public KVStoreLocal {
         Barrier();
         if (get_rank() == 0) {
           // stop the executor at servers
-          SendCommandToServers(kStopServer, "");
+          SendCommandToServers(static_cast<int>(CommandType::kStopServer), "");
         }
       }
       ps::Finalize(barrier_before_exit_);
@@ -86,6 +86,15 @@ class KVStoreDist : public KVStoreLocal {
     }
   }
 
+  void SetGradientCompression(const std::vector<std::pair<std::string, std::string> >
+                              & kwargs) override {
+    KVStoreLocal::SetGradientCompression(kwargs);  // base-class setup; used below
+    if (get_rank() == 0) {  // only rank 0 forwards the encoded params to the servers
+      SendCommandToServers(static_cast<int>(CommandType::kSetGradientCompression),
+                           gradient_compression_->EncodeParams());
+    }
+  }
+
   void Barrier() override {
     ps::Postoffice::Get()->Barrier(ps::kWorkerGroup);
   }
@@ -132,6 +141,38 @@ class KVStoreDist : public KVStoreLocal {
   }
 
  private:
+  /**
+   * \brief struct for ps keys and lens
+   */
+  struct PSKV {
+    ps::SArray<ps::Key> keys;  // n keys
+    ps::SArray<int> lens;  // the length of the i-th value
+    int size;  // total number of values; checked against the array size on reuse
+  };
+
+  struct ComprPSKV {  // the two pskvs kept per key when gradient compression is on
+    PSKV push;  // sized for the compressed data that is pushed
+    PSKV pull;  // decompressed sizes for the same parts, used when pulling
+  };
+
+  /**
+   * \brief cache all key partitions
+   *
+   * `ps_kv_` is used for pushes and pulls without gradient compression
+   * `compr_ps_kv_` is used for gradient compression. It contains different
+   * pskv for push and pull because sizes would be different in both cases.
+   * Note: `ps_kv_[k]` for some key k may not be the same as `compr_ps_kv_[k].pull`
+   * This is because sharding may cause slightly different divisions when size is
+   * not perfectly divisible.
+   */
+  std::unordered_map<int, PSKV> ps_kv_;
+  std::unordered_map<int, ComprPSKV> compr_ps_kv_;
+
+  /**
+   * \brief serialize lookups into the key caches (ps_kv_, compr_ps_kv_) while encoding keys
+   */
+  std::mutex mu_;
+
   void InitImpl(const std::vector<int>& keys,
                 const std::vector<NDArray>& values) override {
     CheckUnique(keys);
@@ -143,6 +184,7 @@ class KVStoreDist : public KVStoreLocal {
       // wait until the push is finished
       for (const int key : keys) {
         comm_buf_[key].WaitToWrite();
+        compr_buf_[key].WaitToWrite();
       }
     } else {
       // do nothing
@@ -182,7 +224,10 @@ class KVStoreDist : public KVStoreLocal {
           RunContext rctx, Engine::CallbackOnComplete cb) {
         // convert to ps keys
         size_t size = recv_buf.shape().Size();
-        PSKV& pskv = EncodeKey(key, size);
+
+        PSKV& pskv = (gradient_compression_->get_type() == CompressionType::kNone) ?
+                      EncodeDefaultKey(key, size, false) :
+                      EncodeCompressedKey(key, size, false);
 #if MKL_EXPERIMENTAL == 1
         mkl_set_tblob_eager_mode(recv_buf.data());
 #endif
@@ -190,8 +235,11 @@ class KVStoreDist : public KVStoreLocal {
         // false means not to delete data when SArray is deleted
         auto vals = new ps::SArray<real_t>(data, size, false);
         // issue pull
+        int cmd = (gradient_compression_->get_type() != CompressionType::kNone) ?
+                  static_cast<int>(DataHandleType::kCompressedPushPull) :
+                  static_cast<int>(DataHandleType::kDefaultPushPull);
         CHECK_NOTNULL(ps_worker_)->ZPull(
-          pskv.keys, vals, &pskv.lens, kDefaultPushPull, [vals, cb](){ delete vals; cb(); });
+          pskv.keys, vals, &pskv.lens, cmd, [vals, cb](){ delete vals; cb(); });
       };
 
       CHECK_NOTNULL(Engine::Get())->PushAsync(
@@ -201,7 +249,7 @@ class KVStoreDist : public KVStoreLocal {
           {recv_buf.var()},
           FnProperty::kNormal,
           priority,
-          PROFILER_MESSAGE("KVStoreDistDefaultPull"));
+          PROFILER_MESSAGE("KVStoreDistDefaultStoragePull"));
 
       comm_->Broadcast(key, recv_buf, grouped_vals[i], priority);
     }
@@ -261,103 +309,121 @@ class KVStoreDist : public KVStoreLocal {
     GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals);
 
     for (size_t i = 0; i < uniq_keys.size(); ++i) {
-      // merge over devcies
+      // merge over devices
       int key = uniq_keys[i];
       const auto& vals = grouped_vals[i];
       NDArray merged = do_merge ? comm_->Reduce(key, vals, priority) : vals[0];
 
-      auto& send_buf = comm_buf_[key];
       const auto storage_type = merged.storage_type();
+      auto &comm_buf = comm_buf_[key];
       if (merged.ctx().dev_mask() == cpu::kDevMask) {
         // Start of a push doesn't guarantee that the previous pushes are completed.
         // This shouldn't affect training of networks though because training involves
         // a sequence of push, pull, then push. This imposes ordering that the
         // second push happens after the first pull, and the pull happens after first push.
-        send_buf = merged;  // avoid memory copy
+        comm_buf = merged;  // avoid memory copy
       } else {
-        if (send_buf.is_none()) {
+        if (comm_buf.is_none()) {
           if (storage_type == kDefaultStorage) {
-            send_buf = NDArray(merged.shape(), pinned_ctx_, true, merged.dtype());
+            comm_buf = NDArray(merged.shape(), pinned_ctx_, true, merged.dtype());
           } else {
-            send_buf = NDArray(storage_type, merged.shape(), pinned_ctx_, true, merged.dtype());
+            comm_buf = NDArray(storage_type, merged.shape(), pinned_ctx_, true, merged.dtype());
           }
         }
-        CopyFromTo(merged, &send_buf);
+        CopyFromTo(merged, &comm_buf);
       }
 
       // push to servers
       if (storage_type == kDefaultStorage) {
-      auto push_to_servers =
-          [this, key, send_buf](RunContext rctx, Engine::CallbackOnComplete cb) {
-          // convert to ps keys
-          size_t size = send_buf.shape().Size();
-          PSKV& pskv = EncodeKey(key, size);
-
-#if MKL_EXPERIMENTAL == 1
-          mkl_set_tblob_eager_mode(send_buf.data());
-#endif
-          real_t* data = send_buf.data().dptr<real_t>();
-          // do push. false means no delete
-          ps::SArray<real_t> vals(data, size, false);
-          CHECK_NOTNULL(ps_worker_)->ZPush(
-              pskv.keys, vals, pskv.lens, 0, [cb]() { cb(); });
-        };
-        Engine::Get()->PushAsync(
-            push_to_servers,
-            pinned_ctx_,
-            {send_buf.var()},
-            {},
-            FnProperty::kNormal,
-            priority,
-            PROFILER_MESSAGE("KVStoreDistDefaultPush"));
+        if (gradient_compression_->get_type() == CompressionType::kNone) {
+          PSKV& pskv = EncodeDefaultKey(key, comm_buf.shape().Size(), true);
+          PushDefault(key, comm_buf, pskv, priority);
+        } else {
+          // Note: gradient compression uses `do_merge` as proxy to
+          // detect whether the push is initialization of a key or not.
+          // is_active is false when push is initialization of key
+          bool is_active = do_merge;
+          PSKV &pskv = EncodeCompressedKey(key, comm_buf.shape().Size(), is_active);
+          // Returns push_pskv if active, else pull_pskv
+          // we want inactive gc to send uncompressed gradients,
+          // but sharded in the same way as later pushes would when gc becomes active
+          if (is_active) {
+            PushCompressed(key, comm_buf, pskv, priority);
+          } else {
+            PushDefault(key, comm_buf, pskv, priority);
+          }
+        }
       } else if (storage_type == kRowSparseStorage) {
-        PushRowSparse(key, send_buf, priority);
+        CHECK(gradient_compression_->get_type() == CompressionType::kNone)
+          << "Gradient compression for row sparse storage type is not supported";
+        PushRowSparse(key, comm_buf, priority);
       } else {
         LOG(FATAL) << "unknown storage type";
       }
     }
   }
 
-  // pull row sparse weight into `recv_buf` based on indices given by `indices`
-  void PullRowSparse_(const int key, const NDArray& recv_buf,
-                      const NDArray& indices, int priority) {
-    using namespace rowsparse;
-    auto pull_from_servers = [this, key, recv_buf, indices]
-                             (RunContext rctx, Engine::CallbackOnComplete cb) {
-      // allocate memory for the buffer
-      size_t num_rows = indices.shape().Size();
-      recv_buf.CheckAndAlloc({mshadow::Shape1(num_rows)});
+  void PushCompressed(int key, const NDArray& comm_buf, const PSKV& pskv, int priority) {
+    auto &small_buf = compr_buf_[key];
+    auto &res_buf = residual_[key];
+    size_t original_size = comm_buf.shape().Size();
+
+    // lazily allocate the compressed buffer (pskv.size values) and a zeroed residual
+    if (small_buf.is_none()) {
+      small_buf = NDArray(TShape{pskv.size}, comm_buf.ctx(), false, comm_buf.dtype());
+      res_buf = NDArray(TShape{(int64_t) original_size}, comm_buf.ctx(),
+                        false, comm_buf.dtype());
+      res_buf = 0;
+    }
+    gradient_compression_->Quantize(comm_buf, &small_buf, &res_buf, priority);
+    auto push_to_servers =
+      [this, key, pskv, small_buf](RunContext rctx, Engine::CallbackOnComplete cb) {
+        size_t size = small_buf.shape().Size();
+        real_t* data = small_buf.data().dptr<real_t>();
 #if MKL_EXPERIMENTAL == 1
-      mkl_set_tblob_eager_mode(recv_buf.data());
+        mkl_set_tblob_eager_mode(small_buf.data());
 #endif
-      real_t* data = recv_buf.data().dptr<real_t>();
-      const auto offsets = indices.data().dptr<int64_t>();
-      const auto unit_len = recv_buf.shape().ProdShape(1, recv_buf.shape().ndim());
-      const int64_t size = num_rows * unit_len;
-       // convert to ps keys in row sparse format
-      PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets,
-                                      unit_len, recv_buf.shape()[0]);
-      if (this->log_verbose_) {
-        LOG(INFO) << "worker " << get_rank() << " pull lens: " << pskv.lens << " keys: "
-                  << pskv.keys << " size: " << size;
-      }
-      auto vals = new ps::SArray<real_t>(data, size, false);
-      // copy indices to recv_buf. this needs to be done before ZPull
-      // because after pull is done, the callback function returns and locks are released.
-      // at this point, later functions may access the indices variable while copy happens
-      mshadow::Copy(recv_buf.aux_data(kIdx).FlatTo1D<cpu, int64_t>(),
-                    indices.data().FlatTo1D<cpu, int64_t>());
-      CHECK_NOTNULL(ps_worker_)->ZPull(pskv.keys, vals, &pskv.lens, kRowSparsePushPull,
-        [vals, cb]() { delete vals; cb(); });
-    };
-    CHECK_NOTNULL(Engine::Get())->PushAsync(
-        pull_from_servers,
+        // push the compressed values; false means the SArray does not delete `data`
+        ps::SArray<real_t> vals(data, size, false);
+        CHECK_NOTNULL(ps_worker_)->ZPush(
+          pskv.keys, vals, pskv.lens,
+          static_cast<int>(DataHandleType::kCompressedPushPull), [cb]() { cb(); });
+      };
+    // acquire locks on both comm_buf and small_buf so that
+    // pull (which uses comm_buf) for the same key waits till push finishes
+    Engine::Get()->PushAsync(
+      push_to_servers,
+      pinned_ctx_,
+      {small_buf.var(), comm_buf.var()},
+      {},
+      FnProperty::kNormal,
+      priority,
+      PROFILER_MESSAGE("KVStoreDistCompressedPush"));
+  }
+
+  void PushDefault(int key, const NDArray &send_buf, const PSKV& pskv, int priority) {
+    auto push_to_servers =
+        [this, key, pskv, send_buf](RunContext rctx, Engine::CallbackOnComplete cb) {
+          // pskv (ps keys/lens) was already computed by the caller
+          size_t size = send_buf.shape().Size();
+          real_t* data = send_buf.data().dptr<real_t>();
+#if MKL_EXPERIMENTAL == 1
+          mkl_set_tblob_eager_mode(send_buf.data());
+#endif
+          // do push. false means the SArray does not delete `data`
+          ps::SArray<real_t> vals(data, size, false);
+          CHECK_NOTNULL(ps_worker_)->ZPush(
+              pskv.keys, vals, pskv.lens,
+              static_cast<int>(DataHandleType::kDefaultPushPull), [cb]() { cb(); });
+        };
+    Engine::Get()->PushAsync(
+        push_to_servers,
         pinned_ctx_,
-        {indices.var()},
-        {recv_buf.var()},
+        {send_buf.var()},
+        {},
         FnProperty::kNormal,
         priority,
-        PROFILER_MESSAGE("KVStoreDistRowSparsePull"));
+        PROFILER_MESSAGE("KVStoreDistDefaultPush"));
   }
 
   // push row sparse gradient
@@ -382,9 +448,9 @@ class KVStoreDist : public KVStoreLocal {
                   << pskv.keys << " size: " << size;
       }
       ps::SArray<real_t> vals(data, size, false);
-      CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens, kRowSparsePushPull, [cb]() {
-        cb();
-      });
+      CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens,
+                                       static_cast<int>(DataHandleType::kRowSparsePushPull),
+                                       [cb]() { cb(); });
     };
     Engine::Get()->PushAsync(
         push_to_servers,
@@ -396,6 +462,50 @@ class KVStoreDist : public KVStoreLocal {
         PROFILER_MESSAGE("KVStoreDistRowSparsePush"));
   }
 
+
+  // pull row sparse weight into `recv_buf` based on indices given by `indices`
+  void PullRowSparse_(const int key, const NDArray& recv_buf,
+                      const NDArray& indices, int priority) {
+    using namespace rowsparse;
+    auto pull_from_servers = [this, key, recv_buf, indices]
+      (RunContext rctx, Engine::CallbackOnComplete cb) {
+      // allocate recv_buf with room for num_rows rows (one per requested index)
+      size_t num_rows = indices.shape().Size();
+      recv_buf.CheckAndAlloc({mshadow::Shape1(num_rows)});
+#if MKL_EXPERIMENTAL == 1
+      mkl_set_tblob_eager_mode(recv_buf.data());
+#endif
+      real_t* data = recv_buf.data().dptr<real_t>();
+      const auto offsets = indices.data().dptr<int64_t>();
+      const auto unit_len = recv_buf.shape().ProdShape(1, recv_buf.shape().ndim());
+      const int64_t size = num_rows * unit_len;
+      // convert to ps keys in row sparse format
+      PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets,
+                                      unit_len, recv_buf.shape()[0]);
+      if (this->log_verbose_) {
+        LOG(INFO) << "worker " << get_rank() << " pull lens: " << pskv.lens << " keys: "
+                  << pskv.keys << " size: " << size;
+      }
+      auto vals = new ps::SArray<real_t>(data, size, false);
+      // copy indices into recv_buf's row-index aux data before ZPull: once the pull is
+      // done, the callback runs and the engine locks are released, after which later
+      // operations may access `indices` while a copy would still be in progress
+      mshadow::Copy(recv_buf.aux_data(kIdx).FlatTo1D<cpu, int64_t>(),
+                    indices.data().FlatTo1D<cpu, int64_t>());
+      CHECK_NOTNULL(ps_worker_)->ZPull(pskv.keys, vals, &pskv.lens,
+                                       static_cast<int>(DataHandleType::kRowSparsePushPull),
+                                       [vals, cb]() { delete vals; cb(); });
+    };
+    CHECK_NOTNULL(Engine::Get())->PushAsync(
+      pull_from_servers,
+      pinned_ctx_,
+      {indices.var()},
+      {recv_buf.var()},
+      FnProperty::kNormal,
+      priority,
+      PROFILER_MESSAGE("KVStoreDistRowSparsePull"));
+  }
+
   /**
    * \brief check if the keys are all unique
    */
@@ -407,32 +517,12 @@ class KVStoreDist : public KVStoreLocal {
   }
 
   /**
-   * \brief struct for ps keys and lens
-   */
-  struct PSKV {
-    ps::SArray<ps::Key> keys;  // n keys
-    ps::SArray<int> lens;  // the length of the i-th value
-    int size;
-  };
-
-  /**
-   * \brief cache all key partitions
-   */
-  std::unordered_map<int, PSKV> ps_kv_;
-
-  /**
-   * \brief serizelize EncodeRowSparseKey and EncodeKey
-   */
-  std::mutex mu_;
-
-  /**
    * \brief convert to keys in ps
    */
-  inline PSKV& EncodeKey(int key, size_t size) {
+  inline PSKV& EncodeDefaultKey(int key, size_t size, bool is_push) {
     mu_.lock();
     PSKV& pskv = ps_kv_[key];
     mu_.unlock();
-
     if (!pskv.keys.empty()) {
       CHECK_EQ(static_cast<size_t>(pskv.size), size) << "The value size cannot be changed";
     } else {
@@ -454,8 +544,8 @@ class KVStoreDist : public KVStoreLocal {
         pskv.size = 0;
         for (int i = 0; i < num_servers; ++i) {
           size_t part_size =
-              static_cast<size_t>(round(static_cast<double>(size)/num_servers*(i+1))) -
-              static_cast<size_t>(round(static_cast<double>(size)/num_servers*i));
+            static_cast<size_t>(round(static_cast<double>(size)/num_servers*(i+1))) -
+            static_cast<size_t>(round(static_cast<double>(size)/num_servers*i));
           ps::Key ps_key = krs[i].begin() + key;
           CHECK_LT(ps_key, krs[i].end());
           pskv.keys.push_back(ps_key);
@@ -468,6 +558,94 @@ class KVStoreDist : public KVStoreLocal {
     return pskv;
   }
 
+  /**
+   * \brief Convert to keys in ps for compressed values.
+   * Arrays smaller than bigarray_bound_ are sent to a single server; larger
+   * arrays are split into one part per server. The split is computed on the
+   * compressed size; each part's original length is its compressed length
+   * times the compression factor, except the last part, which takes the
+   * remainder.
+   * Populates BOTH the push and the pull pskv of \a key on the first call.
+   * On push, every data key is preceded by a meta key of len 0 whose offset
+   * from the server's key range start encodes that part's original length.
+   * \param key kvstore key
+   * \param original_size number of elements of the uncompressed array
+   * \param is_push true to return the push pskv (compressed lengths),
+   *        false for the pull pskv (original lengths)
+   * \return the cached push or pull pskv for \a key
+   */
+  inline PSKV& EncodeCompressedKey(int key, size_t original_size, bool is_push) {
+    auto krs = ps::Postoffice::Get()->GetServerKeyRanges();
+    int num_servers = krs.size();
+    CHECK_GT(num_servers, 0);
+
+    // represents size of data to be sent
+    size_t compr_size = gradient_compression_->GetCompressedSize(original_size);
+
+    mu_.lock();
+    PSKV& pskv = (is_push) ? compr_ps_kv_[key].push : compr_ps_kv_[key].pull;
+    mu_.unlock();
+
+    if (!pskv.keys.empty()) {
+      size_t size = (is_push) ? compr_size : original_size;
+      CHECK_EQ(static_cast<size_t>(pskv.size), size) << "The value size can't be changed";
+    } else {
+      // populate both pull and push pskvs
+      // push pskv has sizes corresponding to compressed data
+      // pull pskv has decompressed sizes for parts in push_pskv
+      mu_.lock();
+      PSKV& pull_pskv = compr_ps_kv_[key].pull;
+      PSKV& push_pskv = compr_ps_kv_[key].push;
+      mu_.unlock();
+
+      if (original_size < bigarray_bound_) {
+        // a simple heuristic for load balancing
+        // send it to a single random picked server
+        int server = (key * 9973) % num_servers;
+        ps::Key ps_key = krs[server].begin() + key;
+        CHECK_LT(ps_key, krs[server].end());
+        // meta info: the key offset carries the original size, no payload.
+        // Bounds-checked like the data key (and like the partitioned branch
+        // below) so the meta key is guaranteed to lie in this server's range.
+        ps::Key ps_key_dummy = krs[server].begin() + original_size;
+        CHECK_LT(ps_key_dummy, krs[server].end());
+        push_pskv.keys.push_back(ps_key_dummy);
+        push_pskv.lens.push_back(0);
+        // data
+        push_pskv.keys.push_back(ps_key);
+        pull_pskv.keys.push_back(ps_key);
+        push_pskv.lens.push_back(compr_size);
+        pull_pskv.lens.push_back(original_size);
+        push_pskv.size = compr_size;
+        pull_pskv.size = original_size;
+      } else {
+        // partition it to all servers
+        push_pskv.size = 0;
+        pull_pskv.size = 0;
+
+        for (int i = 0; i < num_servers; ++i) {
+          size_t part_compr, part_orig;
+          if (i == num_servers-1) {
+            // the last server takes whatever remains
+            part_compr = compr_size - push_pskv.size;
+            part_orig = original_size - pull_pskv.size;
+          } else {
+            part_compr =
+              static_cast<size_t> (round(static_cast<double>(compr_size)/num_servers*(i+1))) -
+              static_cast<size_t> (round(static_cast<double>(compr_size)/num_servers*(i)));
+            part_orig = part_compr * gradient_compression_->GetCompressionFactor();
+          }
+
+          // meta info
+          ps::Key ps_key_dummy = krs[i].begin() + part_orig;
+          CHECK_LT(ps_key_dummy, krs[i].end());
+          push_pskv.keys.push_back(ps_key_dummy);
+          push_pskv.lens.push_back(0);
+
+          // data
+          ps::Key ps_key = krs[i].begin() + key;
+          CHECK_LT(ps_key, krs[i].end());
+          push_pskv.keys.push_back(ps_key);
+          pull_pskv.keys.push_back(ps_key);
+          // push_pskv stores lengths of compressed blocks
+          push_pskv.lens.push_back(part_compr);
+          // pull_pskv stores lengths of original data
+          pull_pskv.lens.push_back(part_orig);
+          push_pskv.size += part_compr;
+          pull_pskv.size += part_orig;
+        }
+        CHECK_EQ(static_cast<size_t>(push_pskv.size), compr_size);
+        CHECK_EQ(static_cast<size_t>(pull_pskv.size), original_size);
+        // two keys (meta + data) per server; cast avoids a signed/unsigned compare
+        CHECK_EQ(push_pskv.lens.size(), static_cast<size_t>(num_servers) * 2);
+      }
+    }
+    return pskv;
+  }
+
   // Note: this encoding method for row sparse keys doesn't allow cross-layer batching
   inline PSKV& EncodeRowSparseKey(const int key, const int64_t size, const int64_t num_rows,
                                   const int64_t *offsets, const size_t unit_len,
@@ -528,7 +706,6 @@ class KVStoreDist : public KVStoreLocal {
     return pskv;
   }
 
-
   /**
    * \brief for worker to push and pull data
    */
@@ -541,8 +718,23 @@ class KVStoreDist : public KVStoreLocal {
    * \brief threshold for partition
    */
   size_t bigarray_bound_;
-  /// \brief send & recver buffer
+  /**
+   * \brief buffer for non-compressed data.
+   * When gradient compression is active, this is used
+   * for the data in pull and for original data in push
+   */
   std::unordered_map<int, NDArray> comm_buf_;
+  /**
+   * \brief buffer for compressed data
+   * Used when gradient compression is active and action
+   * is push
+   */
+  std::unordered_map<int, NDArray> compr_buf_;
+  /**
+   * \brief residual buffer to accumulate quantization error
+   * during gradient compression
+   */
+  std::unordered_map<int, NDArray> residual_;
   bool log_verbose_;
 };
 
diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h
index f2123e7..de94c86 100644
--- a/src/kvstore/kvstore_dist_server.h
+++ b/src/kvstore/kvstore_dist_server.h
@@ -40,10 +40,13 @@
 namespace mxnet {
 namespace kvstore {
 
-static const int kRowSparsePushPull = 1;
-static const int kDefaultPushPull = 0;
-static const int kStopServer = -1;
-static const int kSyncMode = -2;
+/**
+ * \brief commands handled by the server's CommandHandle; the value is sent
+ * as the message head, so every enumerator is pinned explicitly.
+ * kController must stay 0: control messages from the frontend use id 0.
+ */
+enum class CommandType {
+  kController = 0,
+  kStopServer = 1,
+  kSyncMode = 2,
+  kSetGradientCompression = 3
+};
+
+/**
+ * \brief kind of data request, carried in KVMeta::cmd and dispatched by DataHandleEx
+ */
+enum class DataHandleType {
+  kDefaultPushPull = 0,
+  kCompressedPushPull = 1,
+  kRowSparsePushPull = 2
+};
 
 /**
  * \brief executor runs a function using the thread called \ref Start
@@ -117,6 +120,7 @@ class KVStoreDistServer {
     ps_server_->set_request_handle(
         std::bind(&KVStoreDistServer::DataHandleEx, this, _1, _2, _3));
     sync_mode_ = false;
+    gradient_compression_ = std::make_shared<GradientCompression>();
     log_verbose_ = dmlc::GetEnv("MXNET_KVSTORE_DIST_ROW_SPARSE_VERBOSE", false);
   }
 
@@ -148,11 +152,15 @@ class KVStoreDistServer {
   };
 
   void CommandHandle(const ps::SimpleData& recved, ps::SimpleApp* app) {
-    if (recved.head == kStopServer) {
+    CommandType recved_type = static_cast<CommandType>(recved.head);
+    if (recved_type == CommandType::kStopServer) {
       exec_.Stop();
-    } else if (recved.head == kSyncMode) {
+    } else if (recved_type == CommandType::kSyncMode) {
       sync_mode_ = true;
+    } else if (recved_type == CommandType::kSetGradientCompression) {
+      gradient_compression_->DecodeParams(recved.body);
     } else {
+      // this uses value 0 for message id from frontend
       // let the main thread to execute ctrl, which is necessary for python
       exec_.Exec([this, recved]() {
           CHECK(controller_);
@@ -165,8 +173,11 @@ class KVStoreDistServer {
   void DataHandleEx(const ps::KVMeta& req_meta,
                     const ps::KVPairs<real_t>& req_data,
                     ps::KVServer<real_t>* server) {
-    if (req_meta.cmd == kRowSparsePushPull) {
+    DataHandleType recved_type = static_cast<DataHandleType>(req_meta.cmd);
+    if (recved_type == DataHandleType::kRowSparsePushPull) {
       DataHandleRowSparse(req_meta, req_data, server);
+    } else if (recved_type == DataHandleType::kCompressedPushPull) {
+      DataHandleCompressed(req_meta, req_data, server);
     } else {
       DataHandleDefault(req_meta, req_data, server);
     }
@@ -359,10 +370,91 @@ class KVStoreDistServer {
     }
   }
 
+  /**
+   * \brief answer a pull request with the full dense contents of \a stored
+   * \param key key being pulled (only used in the error message)
+   * \param stored server-side value of the key; must already be initialized
+   * \param req_meta meta of the request being answered
+   * \param req_data request payload; its keys are echoed in the response
+   * \param server server used to send the response
+   */
+  void DefaultStorageResponse(int key, const NDArray& stored,
+                              const ps::KVMeta& req_meta,
+                              const ps::KVPairs<real_t> &req_data,
+                              ps::KVServer<real_t>* server) {
+    CHECK(!stored.is_none()) << "init " << key << " first";
+    auto num_elems = stored.shape().Size();
+    ps::KVPairs<real_t> response;
+    response.keys = req_data.keys;
+    response.lens = {num_elems};
+    // TODO(mli) try to remove this CopyFrom
+    const float* src = static_cast<const float*>(stored.data().dptr_);
+    response.vals.CopyFrom(src, num_elems);
+    server->Response(req_meta, response);
+  }
+
+  /**
+   * \brief handle a push or pull request for a key using gradient compression
+   *
+   * Push: the request carries two keys. The first is a meta key (len 0)
+   * whose offset encodes the original, decompressed number of elements; the
+   * second is the data key whose values are the compressed payload. The
+   * payload is dequantized and then written straight into the store (if the
+   * key holds no value yet), merged with other workers' pushes (sync mode)
+   * or passed to the updater (async mode).
+   * Pull: responds with the stored, uncompressed value.
+   */
+  void DataHandleCompressed(const ps::KVMeta& req_meta,
+                            const ps::KVPairs<real_t> &req_data,
+                            ps::KVServer<real_t>* server) {
+    if (req_meta.push) {
+      // Several WaitToRead calls are used because the memory behind \a recved
+      // may be deallocated once this function returns, so the operators
+      // reading that NDArray must have finished by then.
+
+      // first for dummy key which represents original size of array, whose len is 0
+      CHECK_EQ(req_data.keys.size(), (size_t)2);
+      CHECK_EQ(req_data.lens.size(), (size_t)2);
+      CHECK_EQ(req_data.vals.size(), (size_t)req_data.lens[1]);
+
+      int original_size = DecodeKey(req_data.keys[0]);
+      int key = DecodeKey(req_data.keys[1]);
+      auto& stored = store_[key];
+
+      // wrap the received compressed values in an NDArray without copying
+      size_t ds[] = {(size_t)req_data.lens[1]};
+      TShape dshape(ds, ds + 1);
+      TBlob recv_blob((real_t*) req_data.vals.data(), // NOLINT(*)
+                      dshape, cpu::kDevMask);
+      NDArray recved = NDArray(recv_blob, 0);
+
+      // Bind a reference into decomp_buf_ (not a copy of the handle) so the
+      // buffer allocated below is kept in the map and reused by later pushes
+      // of this key; with a copy, the map entry stayed empty and a fresh
+      // buffer was allocated on every push.
+      NDArray& decomp_buf = decomp_buf_[key];
+      dshape = TShape{(int64_t) original_size};
+
+      if (decomp_buf.is_none()) {
+        decomp_buf = NDArray(dshape, Context());
+      }
+
+      if (stored.is_none()) {
+        // no value yet for this key: dequantize straight into the store
+        stored = NDArray(dshape, Context());
+        gradient_compression_->Dequantize(recved, &stored, 0);
+        server->Response(req_meta);
+        stored.WaitToRead();
+      } else if (sync_mode_) {
+        // synced push
+        auto& merged = merge_buf_[key];
+        if (merged.array.is_none()) {
+          merged.array = NDArray(dshape, Context());
+        }
+        if (merged.request.size() == 0) {
+          // nothing merged yet: dequantize directly into the merge buffer
+          gradient_compression_->Dequantize(recved, &merged.array, 0);
+        } else {
+          gradient_compression_->Dequantize(recved, &decomp_buf, 0);
+          merged.array += decomp_buf;
+        }
+        merged.request.push_back(req_meta);
+        ApplyUpdates(key, &merged, &stored, server);
+      } else {
+        // async push
+        gradient_compression_->Dequantize(recved, &decomp_buf, 0);
+        exec_.Exec([this, key, &decomp_buf, &stored]() {
+          CHECK(updater_);
+          updater_(key, decomp_buf, &stored);
+        });
+        server->Response(req_meta);
+        stored.WaitToRead();
+      }
+    } else {       // pull
+      CHECK_EQ(req_data.keys.size(), (size_t)1);
+      CHECK_EQ(req_data.lens.size(), (size_t)0);
+      int key = DecodeKey(req_data.keys[0]);
+      DefaultStorageResponse(key, store_[key], req_meta, req_data, server);
+    }
+  }
+
   void DataHandleDefault(const ps::KVMeta& req_meta,
                          const ps::KVPairs<real_t> &req_data,
                          ps::KVServer<real_t>* server) {
-    CHECK_EQ(req_meta.cmd, kDefaultPushPull);
+    CHECK_EQ(req_meta.cmd, static_cast<int>(DataHandleType::kDefaultPushPull));
     // do some check
     CHECK_EQ(req_data.keys.size(), (size_t)1);
     if (req_meta.push) {
@@ -411,15 +503,7 @@ class KVStoreDistServer {
         stored.WaitToRead();
       }
     } else {
-      // pull
-      ps::KVPairs<real_t> response;
-      CHECK(!stored.is_none()) << "init " << key << " first";
-      auto len = stored.shape().Size();
-      response.keys = req_data.keys;
-      response.lens = {len};
-      // TODO(mli) try to remove this CopyFrom
-      response.vals.CopyFrom(static_cast<const float*>(stored.data().dptr_), len);
-      server->Response(req_meta, response);
+      DefaultStorageResponse(key, stored, req_meta, req_data, server);
     }
   }
 
@@ -428,21 +512,44 @@ class KVStoreDistServer {
     return key - kr.begin();
   }
 
+
   /**
-   * \brief user defined
+   * \brief user defined mode for push
    */
   bool sync_mode_;
   KVStore::Controller controller_;
   KVStore::Updater updater_;
 
+  /**
+   * \brief store_ contains the value at kvstore for each key
+   */
   std::unordered_map<int, NDArray> store_;
+
+  /**
+   * \brief merge_buf_ is a buffer used if sync_mode is true. It represents
+   * values from different workers being merged. The store will be updated
+   * to this value when values from all workers are pushed into this buffer.
+   */
   std::unordered_map<int, MergeBuf> merge_buf_;
 
+  /**
+   * \brief decomp_buf_ is a buffer into which compressed values are
+   * decompressed before merging to the store. used when compress_!='none'
+   */
+  std::unordered_map<int, NDArray> decomp_buf_;
+
   Executor exec_;
   ps::KVServer<float>* ps_server_;
 
   // whether to LOG verbose information
   bool log_verbose_;
+
+  /**
+   * \brief gradient compression object.
+   * starts with none, used after SetGradientCompression sets the type
+   * currently there is no support for unsetting gradient compression
+   */
+  std::shared_ptr<kvstore::GradientCompression> gradient_compression_;
 };
 
 }  // namespace kvstore
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index 1a4ced8..9fe161c 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -59,6 +59,7 @@ class KVStoreLocal : public KVStore {
       comm_ = new CommCPU();
     }
     pinned_ctx_ = comm_->pinned_ctx();
+    gradient_compression_ = std::make_shared<GradientCompression>();
   }
 
   virtual ~KVStoreLocal() {
@@ -136,6 +137,11 @@ class KVStoreLocal : public KVStore {
     PullRowSparseImpl(keys, val_rowids, priority);
   }
 
+  /**
+   * \brief configure gradient compression from string key/value pairs;
+   * the pairs are forwarded unchanged to gradient_compression_->SetParams
+   */
+  void SetGradientCompression(
+      const std::vector<std::pair<std::string, std::string> >& kwargs) override {
+    this->gradient_compression_->SetParams(kwargs);
+  }
+
  private:
   virtual void InitImpl(const std::vector<int>& keys,
                         const std::vector<NDArray>& values) {
@@ -145,6 +151,7 @@ class KVStoreLocal : public KVStore {
       local_[keys[i]] = values[i].Copy(pinned_ctx_);
       comm_->Init(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype());
     }
+    comm_->SetGradientCompression(gradient_compression_);
   }
 
   virtual void PushImpl(const std::vector<int>& keys,
diff --git a/tests/nightly/dist_sync_kvstore.py b/tests/nightly/dist_sync_kvstore.py
index 900d6bb..df85fe5 100644
--- a/tests/nightly/dist_sync_kvstore.py
+++ b/tests/nightly/dist_sync_kvstore.py
@@ -23,7 +23,8 @@ sys.path.insert(0, "../../python/")
 import mxnet as mx
 import numpy as np
 import numpy.random as rnd
-import time
+from mxnet.test_utils import assert_almost_equal
+from test_kvstore import compute_expected_2bit_quantization
 
 def check_diff_to_scalar(A, x, rank=None):
     """ assert A == x"""
@@ -39,6 +40,7 @@ init_test_keys_device_big = [str(i) for i in range(500,600)]
 
 rate = 2
 shape = (2, 3)
+irregular_shape = (1211,1211)
 big_shape = (1200, 1200)        # bigger than MXNET_KVSTORE_BIGARRAY_BOUND
 
 kv = mx.kv.create('dist_sync')
@@ -57,6 +59,17 @@ def init_kv():
     kv.set_optimizer(mx.optimizer.create('test', rescale_grad=rate))
     return kv, my_rank, nworker
 
+def init_kv_compressed(kv):
+    threshold = 0.5
+    kv.set_gradient_compression({'type': '2bit', 'threshold':threshold})
+    # init kv compression keys
+    kv.init('11221', mx.nd.zeros(big_shape))
+    kv.init('112221', mx.nd.zeros(irregular_shape))
+    kv.init('1121', mx.nd.zeros(shape))
+    # to test inactive mode
+    kv.init('1122', mx.nd.ones(shape))
+    return kv, threshold
+
 def test_sync_push_pull():
     kv, my_rank, nworker = init_kv()
     def check_default_keys(kv, my_rank, nworker):
@@ -173,11 +186,114 @@ def test_sync_push_pull():
                 expected[row] = updated_val[row]
             check_diff_to_scalar(val, expected, rank=my_rank)
 
+    def check_compr_residual(kv, threshold, nworker):
+        for k,s in [('1121', shape),('112221',irregular_shape),('11221', big_shape)]:
+            # doesn't meet threshold
+            kv.push(k, mx.nd.ones(s)*0.4)
+            val=mx.nd.zeros(s)
+            kv.pull(k,val)
+            check_diff_to_scalar(val, 0)
+
+            # just meets threshold with residual
+            kv.push(k, mx.nd.ones(s)*(threshold - 0.4))
+            val2 = mx.nd.zeros(s)
+            kv.pull(k,val2)
+            curval = threshold * rate * nworker
+            check_diff_to_scalar(val2, curval)
+
+            # doesn't meet threshold
+            kv.push(k, mx.nd.ones(s)*0.2)
+            val3= mx.nd.zeros(s)
+            kv.pull(k, val3)
+            check_diff_to_scalar(val3, curval)
+
+            # exceeds again
+            kv.push(k, mx.nd.ones(s)*(threshold-0.2))
+            val4 = mx.nd.zeros(s)
+            kv.pull(k,val4)
+            curval += threshold*rate*nworker
+            check_diff_to_scalar(val4, curval)
+            # residual is 0 now
+
+    def check_compr_ones(kv, threshold, nworker):
+        for k,s in [('1121', shape),('112221',irregular_shape),('11221', big_shape)]:
+            val = mx.nd.zeros(s)
+            kv.pull(k, val)
+            curval = val[0][0].asnumpy()[0]
+            kv.push(k,mx.nd.ones(s)*threshold)
+            val2 = mx.nd.zeros(s)
+            kv.pull(k, val2)
+            newval = curval + rate*nworker*threshold
+            check_diff_to_scalar(val2, newval)
+            # residual = 0  again
+
+    def check_compr_pull_before_push(kv):
+        for k,s in [('1121', shape),('112221',irregular_shape),
+                    ('11221', big_shape), ('1122',shape)]:
+            if k=='1122':
+                # tests that GC is not used for init of a key
+                val = mx.nd.zeros(s)
+                kv.pull(k, val)
+                check_diff_to_scalar(val, 1)
+            else:
+                val = mx.nd.ones(s)
+                kv.pull(k, val)
+                check_diff_to_scalar(val, 0)
+
+    def check_compr_zero(kv):
+        for k,s in [('1121', shape),('112221',irregular_shape),('11221', big_shape)]:
+            kv.push(k, mx.nd.zeros(s))
+            # to check that all are set to 0s
+            val = mx.nd.ones(s)
+            kv.pull(k, val)
+            check_diff_to_scalar(val, 0)
+
+    def check_compr_random(kv, threshold, nworker):
+        # set a seed so all workers generate same data. knowing this helps
+        # calculate expected value after pull
+        mx.random.seed(123)
+        rnd.seed(123)
+        nrepeat = 5
+        compr_random_keys_shapes = [('2121', shape),('212221',irregular_shape),('21221', big_shape)]
+        # use new keys so residual is 0 for calculation of expected
+        for k,s in compr_random_keys_shapes:
+            kv.init(k, mx.nd.zeros(s))
+        for k,s in compr_random_keys_shapes:
+            curr_residual = np.zeros(s)
+            for l in range(nrepeat):
+                orig_val = mx.nd.zeros(s)
+                kv.pull(k, orig_val)
+
+                grad = mx.nd.array(rnd.rand(s[0], s[1]))
+                # creates a copy because push changes grad because of assignment
+                grad_cpy = mx.nd.array(grad)
+                kv.push(k, grad)
+                val = mx.nd.zeros(s)
+                kv.pull(k, val)
+
+                diff = val - orig_val
+
+                # compute expected by using simulation of operator
+                compr, curr_residual, decompr = compute_expected_2bit_quantization(grad_cpy, curr_residual, threshold)
+                decompr *= nworker * rate
+                assert_almost_equal(diff.asnumpy(), decompr)
+
+    print ('worker '+str(my_rank)+' started with non compression tests')
     check_default_keys(kv, my_rank, nworker)
     check_row_sparse_keys(kv, my_rank, nworker)
     check_row_sparse_keys_with_zeros(kv, my_rank, nworker)
     check_big_row_sparse_keys(kv, my_rank, nworker)
-    print('worker ' + str(my_rank) + ' is done')
+    print('worker ' + str(my_rank) + ' is done with non compression tests')
+
+    # don't run non compressed keys after this as kvstore now is set to compressed
+    print ('worker '+str(my_rank)+' started with compression tests')
+    kv, threshold = init_kv_compressed(kv)
+    check_compr_pull_before_push(kv)
+    check_compr_zero(kv)
+    check_compr_residual(kv, threshold, nworker)
+    check_compr_ones(kv, threshold, nworker)
+    check_compr_random(kv, threshold, nworker)
+    print('worker ' + str(my_rank) + ' is done with compression tests')
 
 def test_sync_init():
     def check_init(kv, cur_keys, cur_shape, device=False):
diff --git a/tests/nightly/test_kvstore.py b/tests/nightly/test_kvstore.py
index 081bc9c..a14feac 100644
--- a/tests/nightly/test_kvstore.py
+++ b/tests/nightly/test_kvstore.py
@@ -21,17 +21,59 @@ import sys
 sys.path.insert(0, "../../python/")
 import mxnet as mx
 import numpy as np
+import numpy.random as rnd
+import copy
 
-keys = [3, 5, 7]
-# let the last shape exceed MXNET_KVSTORE_BIGARRAY_BOUND
-shapes = [(4, 4), (100, 100), (2000, 2000)];
+from mxnet.test_utils import assert_almost_equal
 
-lr = .1
-nworker = 4
-nrepeat = 10
+def check_diff_to_scalar(A, x, rank=None):
+    """ assert A == x"""
+    assert(np.sum(np.abs((A - x).asnumpy())) == 0), (rank, A.asnumpy(), x)
 
-## generate data
-data = [[[np.random.random(s)*2-1 for i in range(nworker)] for s in shapes] for j in range(nrepeat)]
+def compute_expected_2bit_quantization(arr, curr_residual, threshold):
+    from struct import pack,unpack
+    def bits2int(bits):
+        bits = [int(x) for x in bits[::-1]]
+        x = 0
+        for i in range(len(bits)):
+            x += bits[i]*2**i
+        return x
+
+    def as_float32(s):
+        return unpack("f",pack("I", bits2int(s)))[0]
+
+    # str_quant stores the quantized representation as a sequence of bits
+    str_quant = ''
+    new_residual = []
+    decompr = []
+
+    arr_npy = arr.asnumpy()
+    for i, a in np.ndenumerate(arr_npy):
+        a += curr_residual[i]
+        if a >= threshold:
+            str_quant += '11'
+            new_residual.append(a - threshold)
+            decompr.append(threshold)
+        elif a <= (-1*threshold):
+            str_quant += '10'
+            new_residual.append(a + threshold)
+            decompr.append(-1*threshold)
+        else:
+            str_quant += '00'
+            new_residual.append(a)
+            decompr.append(0)
+    # append extra bits when size of array not a factor of 16
+    if len(str_quant)%16 != 0:
+        str_quant += '0'*(16 - len(str_quant)%16)
+
+    compr = []
+    # converts the string generated into integers 32chars at a time
+    i = 0
+    while i<len(str_quant):
+        cur_float = str_quant[i+24:i+32] + str_quant[i+16:i+24] + str_quant[i+8:i+16] + str_quant[i:i+8]
+        compr.append(as_float32(cur_float))
+        i+=32
+    return np.array(compr), np.array(new_residual).reshape(arr.shape), np.array(decompr).reshape(arr.shape)
 
 ## individual key interface
 def test_kvstore(kv_type):
@@ -55,9 +97,118 @@ def test_kvstore(kv_type):
             err = sum(err) / np.sum(np.abs(res[j]))
             assert(err < 1e-6), (err, shapes[j])
 
-test_kvstore('local_update_cpu')
-test_kvstore('local_allreduce_cpu')
-test_kvstore('local_allreduce_device')
+def test_compress_kvstore(kv_type, compression='2bit', threshold=0.5):
+    """Check gradient compression on a local kvstore of type ``kv_type``.
+
+    Every key is pushed/pulled from ``nworker`` GPU contexts. The checks run
+    in order and each relies on the store state left by the previous one:
+    init is not compressed, pulls before any push return the zero init,
+    pushing zeros keeps the values at zero, values below ``threshold``
+    accumulate in the residual, pushing ``-threshold`` from every context
+    changes the value by ``-threshold * rate * nworker`` per round, and random
+    gradients match a CPU simulation of the quantization.
+
+    Reads the module globals ``keys``, ``shapes``, ``nworker``, ``nrepeat``
+    and ``gc_init_test_key``.
+    """
+    print(kv_type + ' with ' + compression + ' compression')
+    # rescale_grad of the 'test' optimizer; expected values are scaled by it
+    rate = 2
+    kv = mx.kv.create(kv_type)
+    kv.set_gradient_compression({'type':compression, 'threshold':threshold})
+    kv.set_optimizer(mx.optimizer.create('test', rescale_grad=rate))
+    for k, s in zip(keys, shapes):
+        kv.init(k, mx.nd.zeros(s))
+    # init one key with 1s so we can check if it was compressed during init
+    kv.init(gc_init_test_key, mx.nd.ones(shapes[0]))
+    # use different keys for random tests so that
+    # we can track residual from start
+    random_keys = [13, 15, 17]
+    for k, s in zip(random_keys, shapes):
+        kv.init(k, mx.nd.zeros(s))
+
+    def pull_init_test(kv):
+        # checks that compression is not applied to init of key
+        out = [mx.nd.zeros(shapes[0], mx.gpu(g)) for g in range(nworker)]
+        kv.pull(gc_init_test_key, out=out)
+        exp = np.ones_like(out[0].asnumpy())
+        for o in out:
+            assert_almost_equal(o.asnumpy(), exp)
+
+    def pull_before_push(kv):
+        # pull into ones: every key must still hold its zero init value
+        for i in range(nrepeat):
+            for j in range(len(keys)):
+                out = [mx.nd.ones(shapes[j], mx.gpu(g)) for g in range(nworker)]
+                kv.pull(keys[j], out=out)
+                exp = np.zeros_like(out[0].asnumpy())
+                for o in out:
+                    assert_almost_equal(o.asnumpy(), exp)
+
+    def push_zeros(kv):
+        # pushing zeros from every context must keep the stored value at zero
+        for i in range(nrepeat):
+            for j in range(len(keys)):
+                kv.push(keys[j], [mx.nd.zeros(shapes[j], mx.gpu(g)) for g in range(nworker)])
+                out = [mx.nd.ones(shapes[j], mx.gpu(g)) for g in range(nworker)]
+                kv.pull(keys[j], out=out)
+                exp = np.zeros_like(out[0].asnumpy())
+                for o in out:
+                    assert_almost_equal(o.asnumpy(), exp)
+
+    def verify_residual(kv, threshold, rate):
+        # Per key (numbers assume the default threshold of 0.5): 0.4 stays
+        # below the threshold; adding threshold-0.3 crosses it, so one step is
+        # applied; 0.2 stays below again; the final threshold-0.3 crosses once
+        # more. Returns the value expected in the store afterwards.
+        for j in range(len(keys)):
+            kv.push(keys[j], [mx.nd.ones(shapes[j], mx.gpu(g))*0.4 for g in range(nworker)])
+            out = [mx.nd.zeros(shapes[j], mx.gpu(g)) for g in range(nworker)]
+            kv.pull(keys[j],out=out)
+            for o in out:
+                check_diff_to_scalar(o, 0)
+
+            kv.push(keys[j], [mx.nd.ones(shapes[j], mx.gpu(g))*(threshold-0.3) for g in range(nworker)])
+            out = [mx.nd.zeros(shapes[j], mx.gpu(g)) for g in range(nworker)]
+            kv.pull(keys[j],out=out)
+            curval = threshold * rate * nworker
+            for o in out:
+                check_diff_to_scalar(o, curval)
+
+            kv.push(keys[j], [mx.nd.ones(shapes[j], mx.gpu(g))*(0.2) for g in range(nworker)])
+            out = [mx.nd.zeros(shapes[j], mx.gpu(g)) for g in range(nworker)]
+            kv.pull(keys[j],out=out)
+            for o in out:
+                check_diff_to_scalar(o, curval)
+
+            kv.push(keys[j], [mx.nd.ones(shapes[j], mx.gpu(g))*(threshold-0.3) for g in range(nworker)])
+            out = [mx.nd.zeros(shapes[j], mx.gpu(g)) for g in range(nworker)]
+            kv.pull(keys[j],out=out)
+            curval += threshold*rate*nworker
+            for o in out:
+                check_diff_to_scalar(o, curval)
+            # residual would be 0 now
+        return curval
+
+    def check_neg(kv, neg, rate, curval):
+        # each round pushes ``neg`` from every context, changing the expected
+        # value by rate * nworker * neg
+        for r in range(nrepeat):
+            curval = curval + rate*nworker*neg
+            for j in range(len(keys)):
+                kv.push(keys[j], [mx.nd.ones(shapes[j], mx.gpu(g))*neg for g in range(nworker)])
+                out = [mx.nd.ones(shapes[j], mx.gpu(g)) for g in range(nworker)]
+                kv.pull(keys[j], out=out)
+                for o in out:
+                    check_diff_to_scalar(o, curval)
+            # residual would be 0 again
+
+    def check_compr_random(kv, threshold):
+        # random_keys have not been pushed yet, so every context's residual
+        # starts at zero and a single push can be simulated exactly
+        for k, s in zip(random_keys, shapes):
+            curr_residual = [np.zeros(s) for g in range(nworker)]
+            orig_val = [mx.nd.zeros(s, mx.gpu(g)) for g in range(nworker)]
+            kv.pull(k, out=orig_val)
+            grads = [mx.nd.random_uniform(-0.6, 0.6, shape=s, ctx=mx.gpu(g)) for g in range(nworker)]
+            grads_cpy = copy.deepcopy(grads)
+            kv.push(k, grads)
+            val = [mx.nd.zeros(s, mx.gpu(g)) for g in range(nworker)]
+            kv.pull(k, out=val)
+            diffs = [val[g] - orig_val[g] for g in range(nworker)]
+            # compute expected by using simulation of operator
+            # on cpu
+            sum_dequantized_vals = np.zeros(s)
+            for g in range(nworker):
+                compr, curr_residual[g], decompr = compute_expected_2bit_quantization(
+                                                    grads_cpy[g], curr_residual[g], threshold)
+                sum_dequantized_vals += (decompr * rate)
+
+            for g in range(nworker):
+                assert_almost_equal(diffs[g].asnumpy(), sum_dequantized_vals)
+
+    # order matters: each check relies on the values and residuals left by the previous one
+    pull_init_test(kv)
+    pull_before_push(kv)
+    push_zeros(kv)
+    curval = verify_residual(kv, threshold, rate)
+    check_neg(kv, -1*threshold, rate, curval)
+    check_compr_random(kv, threshold)
 
 ## group keys interface
 def test_group_kvstore(kv_type):
@@ -79,6 +230,27 @@ def test_group_kvstore(kv_type):
             err = sum(err) / np.sum(np.abs(a))
             assert(err < 1e-6), (err, a.shape)
 
-test_group_kvstore('local_update_cpu')
-test_group_kvstore('local_allreduce_cpu')
-test_group_kvstore('local_allreduce_device')
+if __name__ == "__main__":
+    keys = [3, 5, 7]
+    # let the last shape exceed MXNET_KVSTORE_BIGARRAY_BOUND
+    shapes = [(4, 4), (100, 100), (2000, 2000)]
+
+    gc_init_test_key = 9
+
+    lr = .1
+    nworker = 4
+    nrepeat = 10
+
+    ## generate data
+    data = [[[np.random.random(s)*2-1 for i in range(nworker)] for s in shapes] for j in range(nrepeat)]
+
+    test_kvstore('local_update_cpu')
+    test_kvstore('local_allreduce_cpu')
+    test_kvstore('local_allreduce_device')
+
+    # compression for local kvstore happens only when reduce is on device
+    test_compress_kvstore('local_allreduce_device')
+
+    test_group_kvstore('local_update_cpu')
+    test_group_kvstore('local_allreduce_cpu')
+    test_group_kvstore('local_allreduce_device')
diff --git a/tools/bandwidth/measure.py b/tools/bandwidth/measure.py
index 66ef737..cd4f0fe 100644
--- a/tools/bandwidth/measure.py
+++ b/tools/bandwidth/measure.py
@@ -53,6 +53,8 @@ def parse_args():
                         help='number of classes')
     parser.add_argument('--optimizer', type=str, default='None',
                         help='the optimizer set to kvstore. None means no optimizer')
+    parser.add_argument('--gc-type', type=str, default='none',
+                        help='type of gradient compression')
     args = parser.parse_args()
     logging.info(args)
     return args
@@ -72,10 +74,12 @@ def error(gpu_res, cpu_res):
     return res
 
 def run(network, optimizer, gpus, kv_store, image_shape, disp_batches,
-        num_batches, test_results, **kwargs):
+        num_batches, test_results, gc_type, **kwargs):
     # create kvstore and optimizer
     devs = [mx.gpu(int(i)) for i in gpus.split(',')]
     kv = mx.kv.create(kv_store)
+    if gc_type != 'none':
+        kv.set_gradient_compression({'type': gc_type})
     if optimizer is None or optimizer == 'None':
         opt = None
     else:

-- 
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.