Posted to commits@mxnet.apache.org by ha...@apache.org on 2017/11/23 01:43:44 UTC

[incubator-mxnet] branch v1.0.0 updated: Indexing (#187)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch v1.0.0
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.0.0 by this push:
     new 2cdb2da  Indexing (#187)
2cdb2da is described below

commit 2cdb2dad1fdc719f3f11dc4b92844a8f8f38857b
Author: Aaron Markham <ma...@amazon.com>
AuthorDate: Wed Nov 22 17:20:47 2017 -0800

    Indexing (#187)
    
    * changed url references from dmlc to apache/incubator-mxnet
    
    * gradient compression faq
    
    * added examples, edited content and order
    
    * indexing features tutorial
    
    * further technical notes, plus example invocation
    
    * updates needed for gradient compression example
    
    * minor patch from rahul
    
    * minor edits after reviews
    
    * added reference and minor grammar fixes
    
    * removed one word
    
    * minor updates to text
---
 docs/faq/gradient_compression.md           | 107 ++++++++
 docs/faq/index.md                          |   7 +-
 docs/faq/multi_devices.md                  |  13 +
 docs/tutorials/basic/ndarray_indexing.md   | 377 +++++++++++++++++++++++++++++
 docs/tutorials/index.md                    |   1 +
 example/gluon/word_language_model/train.py |  10 +-
 python/mxnet/kvstore.py                    |  11 +-
 7 files changed, 519 insertions(+), 7 deletions(-)

diff --git a/docs/faq/gradient_compression.md b/docs/faq/gradient_compression.md
new file mode 100644
index 0000000..4cd58f0
--- /dev/null
+++ b/docs/faq/gradient_compression.md
@@ -0,0 +1,107 @@
+# Gradient Compression
+
+Gradient compression reduces the communication bandwidth required for training, and in some scenarios it can make training more scalable and efficient without significant loss in convergence rate or accuracy. Example implementations with GPUs, CPUs, and distributed training are provided in this document.
+
+
+## Benefits
+
+**Increased Speed**
+
+For architectures with fully connected layers, gradient compression has been observed to speed up training by about 2x, depending on the size of the model and the network bandwidth of the instance. Bigger models see a larger speedup with gradient compression.
+
+**Minimal Accuracy Loss**
+
+Gradient compression works by delaying the synchronization of weight updates that are small. Although small weight updates might not be sent for that batch, this information is not discarded. Once the weight updates for this location accumulate to become a larger value, they will be propagated. Since there is no information loss, but only delayed updates, it does not lead to a significant loss in accuracy or convergence rate. In distributed training experiments[1], the accur [...]
+
+
+## When to Use Gradient Compression
+
+When training models whose architectures include large fully connected components, it can be helpful to use gradient compression. For larger models, as well as recurrent neural networks, the communication cost becomes a major factor. Such models stand to benefit greatly from gradient compression.
+
+
+### GPU versus CPU
+
+The greatest benefits from gradient compression are realized when using multi-node (single or multi-GPU) distributed training. Training on CPUs provides a lower compute density per node than training on GPUs, so CPU-based nodes require less communication bandwidth during training. Hence, the benefits of gradient compression are lower for CPU-based nodes than for GPU-based nodes.
+
+
+### Network Latency
+
+The benefits of gradient compression appear when using distributed training across network-connected nodes. Depending on the latency between nodes and the size of the model, communication can slow training enough that gradient compression provides a meaningful speedup.
+
+You may not want to use gradient compression if your network communication already has low latency.
+
+
+### Model Size
+
+Distributed training involves synchronization of weights after each batch. Larger models have much higher communication costs during training, hence such models stand to benefit much more from gradient compression.
+When running distributed training with gradient compression, the quantize and dequantize operations happen on the CPU, parallelized with OpenMP. For smaller models trained on GPUs, it helps to set `OMP_NUM_THREADS=1` on each node so that the overhead of launching OpenMP threads doesn't slow down compression and decompression.
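+
+As a sketch only (assuming the variable takes effect as long as it is set before MXNet's OpenMP runtime initializes; exporting it in the shell before launching the training processes is the more common approach), this could look like:
+
+```python
+import os
+
+# Limit OpenMP to one thread for the quantize/dequantize operations.
+# Set this before mxnet is imported so the OpenMP runtime picks it up.
+os.environ['OMP_NUM_THREADS'] = '1'
+
+import mxnet as mx
+```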
+
+### Model Architecture
+
+The communication bandwidth required during training varies across neural network architectures, and hence the benefits of gradient compression vary accordingly.
+
+In networks with significant fully connected components, such layers have a low compute cost on GPUs, so communication becomes the bottleneck that limits the speed of distributed training. Gradient compression can help reduce the communication cost, and thus speed up training in such cases. We have observed speedup of about 2x on large fully connected neural networks. Models like AlexNet and VGG have large fully connected components as part of the network, hence stand to benefit from g [...]
+
+Convolutional architectures, on the other hand, have a higher compute cost, and some of the communication can be overlapped with computation. Since communication is not the bottleneck in such networks, gradient compression doesn't help much.
+
+
+### Single Node Gradient Compression
+
+When the training is configured to use device-to-device communication on a single node with multiple GPUs, gradient compression can be used to reduce the cost of communication. This can provide about a 20% speedup for large models on older generation architectures. However, the speed benefits may be negligible on a machine with a newer generation architecture, where GPUs can communicate with each other at low latency.
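+
+As a sketch of the single-node case (the APIs are covered in more detail below; `net`, the context list, and the optimizer settings here are placeholders):
+
+```python
+import mxnet as mx
+from mxnet import gluon
+
+ctx = [mx.gpu(i) for i in range(4)]   # four GPUs on one node, for illustration
+net = gluon.nn.Dense(10)
+net.collect_params().initialize(mx.init.Xavier(), ctx=ctx)
+
+# `kvstore='device'` keeps aggregation on the GPUs; compression reduces the
+# cost of the GPU-to-GPU communication at the price of extra residual memory.
+trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1},
+                        kvstore='device',
+                        compression_params={'type': '2bit', 'threshold': 0.5})
+```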
+
+
+## Approach
+
+The idea behind gradient compression comes from two observations:
+
+First, when training large neural networks, the gradients of weights computed for a small mini-batch of training data are typically sparse. Only a small fraction of the weights have significant updates after each mini-batch. The synchronization of updates that are near zero can safely be delayed across many mini-batches. This essentially means that the rate of weight update can vary depending on the value of an individual weight.
+
+Second, gradients can be compressed significantly by considering only those gradient elements whose absolute values exceed a threshold, and then quantizing them to use fewer bits per gradient value. By compressing the gradients, we reduce the communication bandwidth. The delayed gradient values, in the form of quantization error and values that don't meet the threshold, are aggregated into a gradient residual which is communicated when it reaches the threshold.
+
+## Technical Implementation
+
+### Two Bit Quantization
+
+Currently, the supported type of quantization uses two bits for each gradient value. Any positive value greater than or equal to the threshold sets two bits as `11`, any negative value whose absolute value is greater than or equal to the threshold sets two bits as `10`, and all others are set to `00`. This enables us to store 16 quantized gradients in one float. The error in quantization, which is `original_value - quantized_value`, is stored in the form of a gradient residual.
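+
+As an illustration only (this is not MXNet's internal kernel, which is implemented in C++/CUDA; the function and variable names below are made up for the example), the following NumPy sketch mimics one quantize/dequantize round trip with a gradient residual:
+
+```python
+import numpy as np
+
+def two_bit_round_trip(grad, residual, threshold=0.5):
+    """Sketch of 2-bit compression: quantize (grad + residual) against
+    +/- threshold and carry the quantization error forward as the residual."""
+    acc = grad + residual                      # fold in previously delayed updates
+    quantized = np.zeros_like(acc)
+    quantized[acc >= threshold] = threshold    # values encoded as bits `11`
+    quantized[acc <= -threshold] = -threshold  # values encoded as bits `10`
+    new_residual = acc - quantized             # error kept for later iterations
+    return quantized, new_residual
+
+grad = np.array([0.7, -0.2, 0.1, -0.9])
+residual = np.zeros_like(grad)
+q, residual = two_bit_round_trip(grad, residual)
+print(q)         # [ 0.5  0.   0.  -0.5]
+print(residual)  # [ 0.2 -0.2  0.1 -0.4]
+```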
+
+### Types of Kvstore
+
+Supported types of `kvstore` are `device` and all distributed kvstores such as `dist_sync`, `dist_async`, and `dist_sync_device`. When `kvstore` is `device`, the communication between GPUs is compressed. Please note that this increases the memory usage of GPUs because of the additional residual stored. When using a distributed kvstore, worker-to-server communication is compressed. In this case, compression and decompression happen on the CPU, and gradient residuals will be stored on the  [...]
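+
+For reference, a minimal sketch of turning on compression directly on a `kvstore` handle (a single-machine `device` store is used here; a distributed type such as `dist_sync` is configured the same way, though it is typically set through the higher-level APIs shown in the next section):
+
+```python
+import mxnet as mx
+
+# Create a kvstore of a supported type and enable 2-bit compression on it.
+kv = mx.kv.create('device')
+kv.set_gradient_compression({'type': '2bit', 'threshold': 0.5})
+```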
+
+## Enabling Gradient Compression in MXNet
+
+Gradient compression is a run-time configuration parameter to be enabled during training. Here are the MXNet APIs to enable gradient compression:
+
+**Gluon API**:
+
+```
+trainer = gluon.Trainer(..., compression_params={'type': '2bit', 'threshold': 0.5})
+```
+A reference `gluon` implementation with a gradient compression option can be found in the [train.py script from a word-level language modeling RNN example](https://github.com/apache/incubator-mxnet/blob/master/example/gluon/word_language_model/train.py).
+
+**Module API**:
+
+```
+mod = mx.mod.Module(..., compression_params={'type': '2bit', 'threshold': 0.5})
+```
+
+A `module` example is provided with [this guide for setting up MXNet with distributed training](https://mxnet.incubator.apache.org/versions/master/how_to/multi_devices.html#distributed-training-with-multiple-machines). It comes with the option of turning on gradient compression as an argument to the [train_mnist.py script](https://github.com/apache/incubator-mxnet/blob/master/example/image-classification/train_mnist.py).
+
+### Configuration Details
+
+**Threshold**
+
+A default `threshold` value of `0.5` is good for most use cases, but to get the most benefit from gradient compression for a particular scenario, it can be beneficial to experiment. If the threshold is set to a very large value, say `10.0`, then the updates become too infrequent and the training will converge slower. Setting the threshold automatically is expected in a future release.
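+
+For instance, a hypothetical threshold sweep in a `gluon` script (sketch only; the model and training loop are placeholders) might look like:
+
+```python
+from mxnet import gluon
+
+net = gluon.nn.Dense(10)   # placeholder model for illustration
+net.initialize()
+
+for threshold in (0.5, 1.0, 2.0):
+    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1},
+                            compression_params={'type': '2bit', 'threshold': threshold})
+    # ... train for a fixed budget with each setting and compare convergence and accuracy
+```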
+
+**Quantization**
+
+This release supports 2-bit quantization for encoding of gradients to reduce the communication bandwidth during training. Future releases will support 1-bit quantization and other approaches for encoding of gradients based on experimental evidence of benefits and user demand.
+
+**Sparse Format**
+
+We believe the data will need to be quite sparse (i.e., more than about 90% zeros) to reap the benefits of a sparse format. However, this is an area of experimentation that will be explored in a future release.
+
+
+## References
+
+1. [Nikko Storm, Amazon.com, Scalable Distributed Training using commodity GPU cloud computing.](https://s3-us-west-2.amazonaws.com/amazon.jobs-public-documents/strom_interspeech2015.pdf)
diff --git a/docs/faq/index.md b/docs/faq/index.md
index 883d8e6..68c7d41 100644
--- a/docs/faq/index.md
+++ b/docs/faq/index.md
@@ -14,12 +14,15 @@ and full working examples, visit the [tutorials section](../tutorials/index.md).
 * [How do I visualize neural networks as computation graphs?](http://mxnet.io/how_to/visualize_graph.html)
 
 
-## Speed
-
+## Scale
 * [How can I train with multiple CPU/GPUs with data parallelism?](http://mxnet.io/how_to/multi_devices.html)
 
 * [How can I train with multiple GPUs with model parallelism?](http://mxnet.io/how_to/model_parallel_lstm.html)
 
+
+## Speed
+* [How do I use gradient compression with distributed training?](http://mxnet.io/how_to/gradient_compression.html)
+
 * [Can I use nnpack to improve the CPU performance of MXNet?](http://mxnet.io/how_to/nnpack.html)
 
 * [What are the best setup and data-handling tips and tricks for improving speed?](http://mxnet.io/how_to/perf.html)
diff --git a/docs/faq/multi_devices.md b/docs/faq/multi_devices.md
index 3272062..c79d1f8 100644
--- a/docs/faq/multi_devices.md
+++ b/docs/faq/multi_devices.md
@@ -167,6 +167,19 @@ python ../../tools/launch.py -n 2 -H hosts --sync-dst-dir /tmp/mxnet \
    python train_mnist.py --network lenet --kv-store dist_sync
 ```
 
+
+### Gradient compression
+
+If your model has fully connected components or recurrent neural networks, you may achieve increased training speed by using gradient compression, with a potentially slight loss of accuracy. Please see [Gradient Compression](https://mxnet.incubator.apache.org/versions/master/faq/gradient_compression.html) for more details on when and how to use it. For the above example, gradient compression can be enabled by running the following:
+
+```bash
+python ../../tools/launch.py -n 2 --launcher ssh -H hosts python train_mnist.py --network lenet \
+    --kv-store dist_sync --gc-type 2bit
+```
+
+In this example, `--gc-type` has been set to `2bit` to enable two-bit gradient compression.
+
+
 ### Use a Particular Network Interface
 
 _MXNet_ often chooses the first available network interface.
diff --git a/docs/tutorials/basic/ndarray_indexing.md b/docs/tutorials/basic/ndarray_indexing.md
new file mode 100644
index 0000000..37168b3
--- /dev/null
+++ b/docs/tutorials/basic/ndarray_indexing.md
@@ -0,0 +1,377 @@
+
+# NDArray Indexing - Array indexing features
+
+MXNet's advanced indexing features are modeled after [NumPy's implementation and documentation](https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#combining-advanced-and-basic-indexing). You will see direct adaptations of many NumPy indexing features, and examples that are close, if not identical, to NumPy's; we borrow heavily from their documentation.
+
+`NDArray`s can be indexed using the standard Python `x[obj]` syntax, where _x_ is the array and _obj_ the selection.
+
+There are two kinds of indexing available:
+
+1. basic slicing
+1. advanced indexing
+
+In MXNet, we support both basic and advanced indexing following the convention of indexing NumPy's `ndarray`.
+
+
+## Basic Slicing and Indexing
+
+Basic slicing extends Python’s basic concept of slicing to N dimensions. For a quick review:
+
+```
+a[start:end] # items start through end-1
+a[start:]    # items start through the rest of the array
+a[:end]      # items from the beginning through end-1
+a[:]         # a copy of the whole array
+```
+
+
+```python
+from mxnet import nd
+```
+
+Let's start simple with some working examples of basic slicing.
+
+
+```python
+x = nd.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype='int32')
+x[5:]
+```
+
+
+
+
+
+    [5 6 7 8 9]
+    <NDArray 5 @cpu(0)>
+
+
+
+
+```python
+x = nd.array([0, 1, 2, 3])
+print('1D complete array, x=', x)
+s = x[1:3]
+print('slicing the 2nd and 3rd elements, s=', s)
+```
+
+    1D complete array, x=
+    [ 0.  1.  2.  3.]
+    <NDArray 4 @cpu(0)>
+    slicing the 2nd and 3rd elements, s=
+    [ 1.  2.]
+    <NDArray 2 @cpu(0)>
+
+
+Now let's try slicing the 2nd and 3rd elements of a multi-dimensional array.
+
+
+```python
+x = nd.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
+print('multi-D complete array, x=', x)
+s = x[1:3]
+print('slicing the 2nd and 3rd elements, s=', s)
+```
+
+    multi-D complete array, x=
+    [[  1.   2.   3.   4.]
+     [  5.   6.   7.   8.]
+     [  9.  10.  11.  12.]]
+    <NDArray 3x4 @cpu(0)>
+    slicing the 2nd and 3rd elements, s=
+    [[  5.   6.   7.   8.]
+     [  9.  10.  11.  12.]]
+    <NDArray 2x4 @cpu(0)>
+
+
+Now let's try writing to a specific element. We'll write `9` to element `2` using `x[2] = 9.0`, which will update the whole row.
+
+
+```python
+print('original x, x=', x)
+x[2] = 9.0
+print('replaced entire row with x[2] = 9.0, x=', x)
+```
+
+    original x, x=
+    [[  1.   2.   3.   4.]
+     [  5.   6.   7.   8.]
+     [  9.  10.  11.  12.]]
+    <NDArray 3x4 @cpu(0)>
+    replaced entire row with x[2] = 9.0, x=
+    [[ 1.  2.  3.  4.]
+     [ 5.  6.  7.  8.]
+     [ 9.  9.  9.  9.]]
+    <NDArray 3x4 @cpu(0)>
+
+
+We can target specific elements too. Let's replace the number `3` in the first row with the number `9` using `x[0, 2] = 9.0`.
+
+
+```python
+print('original x, x=', x)
+x[0, 2] = 9.0
+print('replaced specific element with x[0, 2] = 9.0, x=', x)
+```
+
+    original x, x=
+    [[ 1.  2.  3.  4.]
+     [ 5.  6.  7.  8.]
+     [ 9.  9.  9.  9.]]
+    <NDArray 3x4 @cpu(0)>
+    replaced specific element with x[0, 2] = 9.0, x=
+    [[ 1.  2.  9.  4.]
+     [ 5.  6.  7.  8.]
+     [ 9.  9.  9.  9.]]
+    <NDArray 3x4 @cpu(0)>
+
+
+Now let's target even more by selecting a couple of elements at the same time. We'll replace the `6` and the `7` with `x[1:2, 1:3] = 5.0`.
+
+
+```python
+print('original x, x=', x)
+x[1:2, 1:3] = 5.0
+print('replaced range of elements with x[1:2, 1:3] = 5.0, x=', x)
+```
+
+    original x, x=
+    [[ 1.  2.  9.  4.]
+     [ 5.  6.  7.  8.]
+     [ 9.  9.  9.  9.]]
+    <NDArray 3x4 @cpu(0)>
+    replaced range of elements with x[1:2, 1:3] = 5.0, x=
+    [[ 1.  2.  9.  4.]
+     [ 5.  5.  5.  8.]
+     [ 9.  9.  9.  9.]]
+    <NDArray 3x4 @cpu(0)>
+
+
+## New Indexing Features in v1.0
+
+### Step
+
+The basic slice syntax is `i:j:k` where _i_ is the starting index, _j_ is the stopping index, and _k_ is the step (_k_ must be nonzero).
+
+**Note**: Previously, MXNet supported basic slicing and indexing only with `step=1`. From release 1.0, arbitrary values of `step` are supported.
+
+
+```python
+x = nd.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype='int32')
+# Select elements from index 1 up to (but not including) index 7, with a step of 2
+x[1:7:2]
+```
+
+
+
+
+
+    [1 3 5]
+    <NDArray 3 @cpu(0)>
+
+
+
+### Negative Indices
+Negative _i_ and _j_ are interpreted as _n + i_ and _n + j_ where _n_ is the number of elements in the corresponding dimension. Negative _k_ makes stepping go towards smaller indices.
+
+
+```python
+x[-2:10]
+```
+
+
+
+
+
+    [8 9]
+    <NDArray 2 @cpu(0)>
+
+
+
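+A negative step reverses the direction of traversal. As a quick sketch (the expected values are shown in the comment, following the slicing rules above):
+
+```python
+x = nd.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype='int32')
+# start at index 8, stop before index 2, stepping backwards by 2
+x[8:2:-2]   # expected: [8 6 4]
+```
+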
+If the number of objects in the selection tuple is less than _N_, then `:` is assumed for any subsequent dimensions.
+
+
+```python
+x = nd.array([[[1],[2],[3]],
+              [[4],[5],[6]]], dtype='int32')
+x[1:2]
+```
+
+
+
+
+
+    [[[4]
+      [5]
+      [6]]]
+    <NDArray 1x3x1 @cpu(0)>
+
+
+
+You may use slicing to set values in the array, but (unlike lists) you can never grow the array. The size of the value to be set in `x[obj] = value` must be able to broadcast to the same shape as `x[obj]`.
+
+
+```python
+x = nd.arange(16, dtype='int32').reshape((4, 4))
+print(x)
+```
+
+
+    [[ 0  1  2  3]
+     [ 4  5  6  7]
+     [ 8  9 10 11]
+     [12 13 14 15]]
+    <NDArray 4x4 @cpu(0)>
+
+
+
+```python
+print(x[1:4:2, 3:0:-1])
+```
+
+
+    [[ 7  6  5]
+     [15 14 13]]
+    <NDArray 2x3 @cpu(0)>
+
+
+
+```python
+x[1:4:2, 3:0:-1] = [[16], [17]]
+print(x)
+```
+
+
+    [[ 0  1  2  3]
+     [ 4 16 16 16]
+     [ 8  9 10 11]
+     [12 17 17 17]]
+    <NDArray 4x4 @cpu(0)>
+
+
+## New Advanced Indexing Features in v1.0
+
+Advanced indexing is triggered when the selection object, obj, is a non-tuple sequence object (e.g. a Python list), a NumPy `ndarray` (of data type integer), an MXNet `NDArray`, or a tuple with at least one sequence object.
+
+Advanced indexing always returns a __copy__ of the data.
+
+**Note**:
+- When the selection object is a Python list, it must be a list of integers. MXNet does not support the selection object being a nested list. That is, `x[[1, 2]]` is supported, while `x[[[1], [2]]]` is not.
+- When the selection object is a NumPy `ndarray` or an MXNet `NDArray`, there are no dimension restrictions on the object.
+- When the selection object is a tuple containing Python list(s), both integer lists and nested lists are supported. That is, both `x[1:4, [1, 2]]` and `x[1:4, [[1], [2]]]` are supported (see the short example after this list).
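+
+A short sketch of these cases (the shapes in the comments follow the NumPy semantics described above):
+
+```python
+x = nd.array([[1, 2, 3],
+              [4, 5, 6],
+              [7, 8, 9]], dtype='int32')
+print(x[[0, 2]].shape)           # a list of integers selects rows 0 and 2 -> (2, 3)
+print(x[1:3, [[0], [2]]].shape)  # a nested list inside a tuple -> (2, 2, 1)
+```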
+
+### Purely Integer Array Indexing
+When the index consists of as many integer arrays as the array being indexed has dimensions, the indexing is straightforward, but different from slicing.
+
+Advanced indices are always [broadcast](https://docs.scipy.org/doc/numpy-1.13.0/reference/ufuncs.html#ufuncs-broadcasting) and iterated as one:
+```python
+result[i_1, ..., i_M] == x[ind_1[i_1, ..., i_M], ind_2[i_1, ..., i_M],
+                           ..., ind_N[i_1, ..., i_M]]
+```
+Note that the result shape is identical to the (broadcast) indexing array shapes `ind_1, ..., ind_N`.
+
+**Example:**
+From each row, a specific element should be selected. The row index is just `[0, 1, 2]` and the column index specifies the element to choose for the corresponding row, here `[0, 1, 0]`. Using both together, the task can be solved using advanced indexing:
+
+
+```python
+x = nd.array([[1, 2],
+              [3, 4],
+              [5, 6]], dtype='int32')
+x[[0, 1, 2], [0, 1, 0]]
+```
+
+
+
+
+
+    [1 4 5]
+    <NDArray 3 @cpu(0)>
+
+
+
+To achieve a behavior similar to the basic slicing above, broadcasting can be used. This is best understood with an example.
+
+**Example:**
+From a 4x3 array, the corner elements should be selected using advanced indexing. Thus all elements for which the column is one of `[0, 2]` and the row is one of `[0, 3]` need to be selected. To use advanced indexing, one needs to select all elements explicitly. Using the method explained previously, one could write:
+
+
+```python
+x = nd.array([[ 0,  1,  2],
+              [ 3,  4,  5],
+              [ 6,  7,  8],
+              [ 9, 10, 11]], dtype='int32')
+x[[[0, 0], [3, 3]],
+  [[0, 2], [0, 2]]]
+```
+
+
+
+
+
+    [[ 0  2]
+     [ 9 11]]
+    <NDArray 2x2 @cpu(0)>
+
+
+
+However, since the indexing arrays above just repeat themselves, broadcasting can be used.
+
+
+```python
+x[[[0], [3]],
+  [[0, 2]]]
+```
+
+
+
+
+
+    [[ 0  2]
+     [ 9 11]]
+    <NDArray 2x2 @cpu(0)>
+
+
+
+### Combining Advanced and Basic Indexing
+There are three situations we need to consider when mixing advanced and basic indices in a single selection object. Let's look at examples to understand each one's behavior.
+
+- There is only one advanced index in the selection object. For example, `x` is an `NDArray` with `shape=(10, 20, 30, 40, 50)` and `result=x[:, :, ind]` has one advanced index `ind` with `shape=(2, 3, 4)` on the third axis. The `result` will have `shape=(10, 20, 2, 3, 4, 40, 50)` because the subspace of `x` in the third dimension is replaced by the subspace of `shape=(2, 3, 4)`. If we let _i_, _j_, _k_ loop over the (2, 3, 4)-shaped subspace, it is equivalent to `result[:, :, i, j, k, :, [...]
+
+
+```python
+import numpy as np
+shape = (10, 20, 30, 40, 50)
+x = nd.arange(np.prod(shape), dtype='int32').reshape(shape)
+ind = nd.arange(24).reshape((2, 3, 4))
+print(x[:, :, ind].shape)
+```
+
+    (10, 20, 2, 3, 4, 40, 50)
+
+
+- There are at least two advanced indices in the selection object, and all the advanced indices are adjacent to each other. For example, `x` is an `NDArray` with `shape=(10, 20, 30, 40, 50)` and `result=x[:, :, ind1, ind2, :]` has two advanced indices with shapes that are broadcastable to `shape=(2, 3, 4)`. Then the `result` has `shape=(10, 20, 2, 3, 4, 50)` because the `(30, 40)`-shaped subspace has been replaced with the `(2, 3, 4)`-shaped subspace from the indices.
+
+
+```python
+ind1 = [0, 1, 2, 3]
+ind2 = [[[0], [1], [2]], [[3], [4], [5]]]
+print(x[:, :, ind1, ind2, :].shape)
+```
+
+    (10, 20, 2, 3, 4, 50)
+
+
+- There are at least two advanced indices in the selection object, and there is at least one advanced index separated from the others by basic indices. For example, `x` is an `NDArray` with `shape=(10, 20, 30, 40, 50)` and `result=x[:, :, ind1, :, ind2]` has two advanced indices with shapes that are broadcastable to `shape=(2, 3, 4)`. Then the `result` has `shape=(2, 3, 4, 10, 20, 40)` because there is no unambiguous place to put the indexing subspace, so it is prepended to the beginning of the result.
+
+
+```python
+print(x[:, :, ind1, :, ind2].shape)
+```
+
+    (2, 3, 4, 10, 20, 40)
+
+## References
+
+[NumPy documentation](https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#combining-advanced-and-basic-indexing)
+
+<!-- INSERT SOURCE DOWNLOAD BUTTONS -->
diff --git a/docs/tutorials/index.md b/docs/tutorials/index.md
index 6429dfb..325f385 100644
--- a/docs/tutorials/index.md
+++ b/docs/tutorials/index.md
@@ -39,6 +39,7 @@ please see [gluon.mxnet.io](http://gluon.mxnet.io).
    :maxdepth: 1
 
    basic/ndarray
+   basic/ndarray_indexing
    basic/symbol
    basic/module
    basic/data
diff --git a/example/gluon/word_language_model/train.py b/example/gluon/word_language_model/train.py
index 0b50499..b419277 100644
--- a/example/gluon/word_language_model/train.py
+++ b/example/gluon/word_language_model/train.py
@@ -54,6 +54,11 @@ parser.add_argument('--log-interval', type=int, default=200, metavar='N',
                     help='report interval')
 parser.add_argument('--save', type=str, default='model.params',
                     help='path to save the final model')
+parser.add_argument('--gctype', type=str, default='none',
+                    help='type of gradient compression to use, \
+                          takes `2bit` or `none` for now.')
+parser.add_argument('--gcthreshold', type=float, default=0.5,
+                    help='threshold for 2bit gradient compression')
 args = parser.parse_args()
 
 
@@ -90,10 +95,13 @@ ntokens = len(corpus.dictionary)
 model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid,
                        args.nlayers, args.dropout, args.tied)
 model.collect_params().initialize(mx.init.Xavier(), ctx=context)
+
+compression_params = None if args.gctype == 'none' else {'type': args.gctype, 'threshold': args.gcthreshold}
 trainer = gluon.Trainer(model.collect_params(), 'sgd',
                         {'learning_rate': args.lr,
                          'momentum': 0,
-                         'wd': 0})
+                         'wd': 0},
+                        compression_params=compression_params)
 loss = gluon.loss.SoftmaxCrossEntropyLoss()
 
 ###############################################################################
diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py
index d068d06..a6d3aa5 100644
--- a/python/mxnet/kvstore.py
+++ b/python/mxnet/kvstore.py
@@ -408,10 +408,13 @@ class KVStore(object):
             Other keys in this dictionary are optional and specific to the type
             of gradient compression.
         """
-        ckeys, cvals = _ctype_dict(compression_params)
-        check_call(_LIB.MXKVStoreSetGradientCompression(self.handle,
-                                                        mx_uint(len(compression_params)),
-                                                        ckeys, cvals))
+        if (self.type == 'device') or ('dist' in self.type):
+            ckeys, cvals = _ctype_dict(compression_params)
+            check_call(_LIB.MXKVStoreSetGradientCompression(self.handle,
+                                                            mx_uint(len(compression_params)),
+                                                            ckeys, cvals))
+        else:
+            raise Exception('Gradient compression is not supported for this type of kvstore')
 
     def set_optimizer(self, optimizer):
         """ Registers an optimizer with the kvstore.
