You are viewing a plain-text version of this content. The canonical version lives at the original archive link, which is not preserved in this plain-text copy.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/11/07 18:13:09 UTC

[GitHub] szha closed pull request #13129: Gluon LSTM Projection and Clipping Support (#13055) v1.3.x

szha closed pull request #13129: Gluon LSTM Projection and Clipping Support (#13055) v1.3.x
URL: https://github.com/apache/incubator-mxnet/pull/13129
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this is a foreign pull request (from a fork), the diff is supplied
below; GitHub would not otherwise display it:

diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index 1e38ec48e6c..ba88b3886d7 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -612,6 +612,7 @@ unittest_ubuntu_python2_gpu() {
     export PYTHONPATH=./python/
     export MXNET_MKLDNN_DEBUG=1  # Ignored if not present
     export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
+    export CUDNN_VERSION=7.0.3
     nosetests-2.7 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file nosetests_gpu.xml --verbose tests/python/gpu
 }
 
@@ -644,6 +645,7 @@ unittest_ubuntu_python3_gpu() {
     export PYTHONPATH=./python/
     export MXNET_MKLDNN_DEBUG=1 # Ignored if not present
     export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
+    export CUDNN_VERSION=7.0.3
     nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file nosetests_gpu.xml --verbose tests/python/gpu
 }
 
@@ -660,6 +662,7 @@ unittest_ubuntu_tensorrt_gpu() {
     export PYTHONPATH=./python/
     export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
     export LD_LIBRARY_PATH=/work/mxnet/lib:$LD_LIBRARY_PATH
+    export CUDNN_VERSION=7.0.3
     python tests/python/tensorrt/lenet5_train.py
     nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file nosetests_trt_gpu.xml --verbose tests/python/tensorrt/
 }
@@ -671,6 +674,7 @@ unittest_ubuntu_python2_quantization_gpu() {
     export PYTHONPATH=./python/
     export MXNET_MKLDNN_DEBUG=1  # Ignored if not present
     export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
+    export CUDNN_VERSION=7.0.3
     nosetests-2.7 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file nosetests_quantization_gpu.xml --verbose tests/python/quantization_gpu
 }
 
@@ -681,6 +685,7 @@ unittest_ubuntu_python3_quantization_gpu() {
     export PYTHONPATH=./python/
     export MXNET_MKLDNN_DEBUG=1 # Ignored if not present
     export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
+    export CUDNN_VERSION=7.0.3
     nosetests-3.4 $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file nosetests_quantization_gpu.xml --verbose tests/python/quantization_gpu
 }
 
@@ -735,6 +740,7 @@ unittest_centos7_cpu() {
 unittest_centos7_gpu() {
     set -ex
     cd /work/mxnet
+    export CUDNN_VERSION=7.0.3
     python3.6 -m "nose" $NOSE_COVERAGE_ARGUMENTS --with-xunit --xunit-file nosetests_gpu.xml --verbose tests/python/gpu
 }
 
diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index daf8ecbf563..b9e82f1ce08 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -35,11 +35,14 @@ def __init__(self, hidden_size, num_layers, layout,
                  dropout, bidirectional, input_size,
                  i2h_weight_initializer, h2h_weight_initializer,
                  i2h_bias_initializer, h2h_bias_initializer,
-                 mode, **kwargs):
+                 mode, projection_size, h2r_weight_initializer,
+                 lstm_state_clip_min, lstm_state_clip_max, lstm_state_clip_nan,
+                 **kwargs):
         super(_RNNLayer, self).__init__(**kwargs)
         assert layout in ('TNC', 'NTC'), \
             "Invalid layout %s; must be one of ['TNC' or 'NTC']"%layout
         self._hidden_size = hidden_size
+        self._projection_size = projection_size if projection_size else None
         self._num_layers = num_layers
         self._mode = mode
         self._layout = layout
@@ -50,25 +53,50 @@ def __init__(self, hidden_size, num_layers, layout,
         self._h2h_weight_initializer = h2h_weight_initializer
         self._i2h_bias_initializer = i2h_bias_initializer
         self._h2h_bias_initializer = h2h_bias_initializer
+        self._h2r_weight_initializer = h2r_weight_initializer
+        self._lstm_state_clip_min = lstm_state_clip_min
+        self._lstm_state_clip_max = lstm_state_clip_max
+        self._lstm_state_clip_nan = lstm_state_clip_nan
 
         self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]
 
         ng, ni, nh = self._gates, input_size, hidden_size
-        for i in range(num_layers):
-            for j in ['l', 'r'][:self._dir]:
-                self._register_param('{}{}_i2h_weight'.format(j, i),
-                                     shape=(ng*nh, ni),
-                                     init=i2h_weight_initializer)
-                self._register_param('{}{}_h2h_weight'.format(j, i),
-                                     shape=(ng*nh, nh),
-                                     init=h2h_weight_initializer)
-                self._register_param('{}{}_i2h_bias'.format(j, i),
-                                     shape=(ng*nh,),
-                                     init=i2h_bias_initializer)
-                self._register_param('{}{}_h2h_bias'.format(j, i),
-                                     shape=(ng*nh,),
-                                     init=h2h_bias_initializer)
-            ni = nh * self._dir
+        if not projection_size:
+            for i in range(num_layers):
+                for j in ['l', 'r'][:self._dir]:
+                    self._register_param('{}{}_i2h_weight'.format(j, i),
+                                         shape=(ng*nh, ni),
+                                         init=i2h_weight_initializer)
+                    self._register_param('{}{}_h2h_weight'.format(j, i),
+                                         shape=(ng*nh, nh),
+                                         init=h2h_weight_initializer)
+                    self._register_param('{}{}_i2h_bias'.format(j, i),
+                                         shape=(ng*nh,),
+                                         init=i2h_bias_initializer)
+                    self._register_param('{}{}_h2h_bias'.format(j, i),
+                                         shape=(ng*nh,),
+                                         init=h2h_bias_initializer)
+                ni = nh * self._dir
+        else:
+            np = self._projection_size
+            for i in range(num_layers):
+                for j in ['l', 'r'][:self._dir]:
+                    self._register_param('{}{}_i2h_weight'.format(j, i),
+                                         shape=(ng*nh, ni),
+                                         init=i2h_weight_initializer)
+                    self._register_param('{}{}_h2h_weight'.format(j, i),
+                                         shape=(ng*nh, np),
+                                         init=h2h_weight_initializer)
+                    self._register_param('{}{}_i2h_bias'.format(j, i),
+                                         shape=(ng*nh,),
+                                         init=i2h_bias_initializer)
+                    self._register_param('{}{}_h2h_bias'.format(j, i),
+                                         shape=(ng*nh,),
+                                         init=h2h_bias_initializer)
+                    self._register_param('{}{}_h2r_weight'.format(j, i),
+                                         shape=(np, nh),
+                                         init=h2r_weight_initializer)
+                ni = np * self._dir
 
     def _register_param(self, name, shape, init):
         p = self.params.get(name, shape=shape, init=init,
@@ -114,6 +142,9 @@ def state_info(self, batch_size=0):
 
     def _unfuse(self):
         """Unfuses the fused RNN in to a stack of rnn cells."""
+        assert not self._projection_size, "_unfuse does not support projection layer yet!"
+        assert not self._lstm_state_clip_min and not self._lstm_state_clip_max, \
+                "_unfuse does not support state clipping yet!"
         get_cell = {'rnn_relu': lambda **kwargs: rnn_cell.RNNCell(self._hidden_size,
                                                                   activation='relu',
                                                                   **kwargs),
@@ -189,7 +220,7 @@ def hybrid_forward(self, F, inputs, states=None, **kwargs):
         skip_states = states is None
         if skip_states:
             if F is ndarray:
-                states = self.begin_state(batch_size, ctx=inputs.context)
+                states = self.begin_state(batch_size, ctx=inputs.context, dtype=inputs.dtype)
             else:
                 states = self.begin_state(0, func=symbol.zeros)
         if isinstance(states, tensor_types):
@@ -209,16 +240,29 @@ def _forward_kernel(self, F, inputs, states, **kwargs):
         """ forward using CUDNN or CPU kenrel"""
         if self._layout == 'NTC':
             inputs = F.swapaxes(inputs, dim1=0, dim2=1)
-        params = (kwargs['{}{}_{}_{}'.format(d, l, g, t)].reshape(-1)
-                  for t in ['weight', 'bias']
-                  for l in range(self._num_layers)
-                  for d in ['l', 'r'][:self._dir]
-                  for g in ['i2h', 'h2h'])
+        if self._projection_size is None:
+            params = (kwargs['{}{}_{}_{}'.format(d, l, g, t)].reshape(-1)
+                      for t in ['weight', 'bias']
+                      for l in range(self._num_layers)
+                      for d in ['l', 'r'][:self._dir]
+                      for g in ['i2h', 'h2h'])
+        else:
+            params = (kwargs['{}{}_{}_{}'.format(d, l, g, t)].reshape(-1)
+                      for t in ['weight', 'bias']
+                      for l in range(self._num_layers)
+                      for d in ['l', 'r'][:self._dir]
+                      for g in ['i2h', 'h2h', 'h2r']
+                      if g != 'h2r' or t != 'bias')
+
         params = F._internal._rnn_param_concat(*params, dim=0)
 
         rnn = F.RNN(inputs, params, *states, state_size=self._hidden_size,
+                    projection_size=self._projection_size,
                     num_layers=self._num_layers, bidirectional=self._dir == 2,
-                    p=self._dropout, state_outputs=True, mode=self._mode)
+                    p=self._dropout, state_outputs=True, mode=self._mode,
+                    lstm_state_clip_min=self._lstm_state_clip_min,
+                    lstm_state_clip_max=self._lstm_state_clip_max,
+                    lstm_state_clip_nan=self._lstm_state_clip_nan)
 
         if self._mode == 'lstm':
             outputs, states = rnn[0], [rnn[1], rnn[2]]
@@ -318,7 +362,8 @@ def __init__(self, hidden_size, num_layers=1, activation='relu',
                                   dropout, bidirectional, input_size,
                                   i2h_weight_initializer, h2h_weight_initializer,
                                   i2h_bias_initializer, h2h_bias_initializer,
-                                  'rnn_'+activation, **kwargs)
+                                  'rnn_'+activation, None, None, None, None, False,
+                                  **kwargs)
 
     def state_info(self, batch_size=0):
         return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
@@ -373,6 +418,20 @@ class LSTM(_RNNLayer):
         to zero.
     h2h_bias_initializer : str or Initializer
         Initializer for the bias vector.
+    projection_size: int, default None
+        The number of features after projection.
+    h2r_weight_initializer : str or Initializer, default None
+        Initializer for the projected recurrent weights matrix, used for the linear
+        transformation of the recurrent state to the projected space.
+    state_clip_min : float or None, default None
+        Minimum clip value of LSTM states. This option must be used together with
+        state_clip_max. If None, clipping is not applied.
+    state_clip_max : float or None, default None
+        Maximum clip value of LSTM states. This option must be used together with
+        state_clip_min. If None, clipping is not applied.
+    state_clip_nan : boolean, default False
+        Whether to stop NaN from propagating in state by clipping it to min/max.
+        If the clipping range is not specified, this option is ignored.
     input_size: int, default 0
         The number of expected features in the input x.
         If not specified, it will be inferred from input.
@@ -416,18 +475,28 @@ def __init__(self, hidden_size, num_layers=1, layout='TNC',
                  dropout=0, bidirectional=False, input_size=0,
                  i2h_weight_initializer=None, h2h_weight_initializer=None,
                  i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
+                 projection_size=None, h2r_weight_initializer=None,
+                 state_clip_min=None, state_clip_max=None, state_clip_nan=False,
                  **kwargs):
         super(LSTM, self).__init__(hidden_size, num_layers, layout,
                                    dropout, bidirectional, input_size,
                                    i2h_weight_initializer, h2h_weight_initializer,
                                    i2h_bias_initializer, h2h_bias_initializer,
-                                   'lstm', **kwargs)
+                                   'lstm', projection_size, h2r_weight_initializer,
+                                   state_clip_min, state_clip_max, state_clip_nan,
+                                   **kwargs)
 
     def state_info(self, batch_size=0):
-        return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
-                 '__layout__': 'LNC'},
-                {'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
-                 '__layout__': 'LNC'}]
+        if self._projection_size is None:
+            return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
+                     '__layout__': 'LNC'},
+                    {'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
+                     '__layout__': 'LNC'}]
+        else:
+            return [{'shape': (self._num_layers * self._dir, batch_size, self._projection_size),
+                     '__layout__': 'LNC'},
+                    {'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
+                     '__layout__': 'LNC'}]
 
 
 class GRU(_RNNLayer):
@@ -519,7 +588,8 @@ def __init__(self, hidden_size, num_layers=1, layout='TNC',
                                   dropout, bidirectional, input_size,
                                   i2h_weight_initializer, h2h_weight_initializer,
                                   i2h_bias_initializer, h2h_bias_initializer,
-                                  'gru', **kwargs)
+                                  'gru', None, None, None, None, False,
+                                  **kwargs)
 
     def state_info(self, batch_size=0):
         return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index 63b75cf2a23..0875e562e3c 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -336,7 +336,7 @@ def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=Non
         assert(False), "unknown storage type"
         return False
 
-def rand_ndarray(shape, stype, density=None, dtype=None,
+def rand_ndarray(shape, stype='default', density=None, dtype=None,
                  modifier_func=None, shuffle_csr_indices=False, distribution=None):
     if stype == 'default':
         arr = mx.nd.array(random_arrays(shape), dtype=dtype)
diff --git a/src/operator/cudnn_rnn-inl.h b/src/operator/cudnn_rnn-inl.h
index b33a717d15b..f7f5e5115a1 100644
--- a/src/operator/cudnn_rnn-inl.h
+++ b/src/operator/cudnn_rnn-inl.h
@@ -26,6 +26,8 @@
 #ifndef MXNET_OPERATOR_CUDNN_RNN_INL_H_
 #define MXNET_OPERATOR_CUDNN_RNN_INL_H_
 
+#define USE_CUDNN_LSTM_PROJ MXNET_USE_CUDNN == 1 && CUDNN_VERSION >= 7200
+
 #include <mxnet/storage.h>
 #include <vector>
 #include <map>
@@ -38,7 +40,7 @@ namespace mxnet {
 namespace op {
 #if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
 template<typename DType>
-class CuDNNRNNOp : public Operator{
+class CuDNNRNNOp : public Operator {
  public:
   explicit CuDNNRNNOp(RNNParam param) {
     this->param_ = param;
@@ -69,6 +71,32 @@ class CuDNNRNNOp : public Operator{
       default:
         LOG(FATAL) << "Not implmented";
     }
+#if USE_CUDNN_LSTM_PROJ
+    if (param_.projection_size.has_value()) {
+      CHECK_EQ(param_.mode, rnn_enum::kLstm)
+        << "Projection is only supported for LSTM.";
+      CHECK_GE(param_.state_size, param_.projection_size.value())
+        << "State size must be larger than projection size.";
+    }
+#else
+    CHECK(!param_.projection_size.has_value())
+      << "Projection is only supported for LSTM with CuDNN version later than 7.1.1.";
+#endif
+#if USE_CUDNN_LSTM_PROJ
+    if (param_.lstm_state_clip_min.has_value()
+        || param_.lstm_state_clip_max.has_value()) {
+      CHECK_EQ(param_.mode, rnn_enum::kLstm)
+        << "State clipping is only supported for LSTM.";
+      CHECK(param_.lstm_state_clip_min.has_value() && param_.lstm_state_clip_max.has_value())
+        << "lstm_state_clip_min and lstm_state_clip_max must be specified together.";
+      CHECK_GE(param_.lstm_state_clip_max.value(), param_.lstm_state_clip_min.value())
+        << "lstm_state_clip_max must be greater or equal to lstm_state_clip_min";
+    }
+#else
+    CHECK(!param_.lstm_state_clip_min.has_value()
+          && !param_.lstm_state_clip_max.has_value())
+      << "State clipping is only supported for LSTM with CuDNN version later than 7.2.1.";
+#endif
     // RNN Direction
     direction_ = param_.bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
     // Other
@@ -92,6 +120,13 @@ class CuDNNRNNOp : public Operator{
 
     CUDNN_CALL(cudnnCreateRNNDescriptor(&rnn_desc_));
     CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_));
+
+    #if USE_CUDNN_LSTM_PROJ
+    CUDNN_CALL(cudnnCreateRNNDataDescriptor(&x_data_desc_));
+    CUDNN_CALL(cudnnCreateRNNDataDescriptor(&y_data_desc_));
+    CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dx_data_desc_));
+    CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dy_data_desc_));
+    #endif
   }
 
   ~CuDNNRNNOp() {
@@ -123,6 +158,12 @@ class CuDNNRNNOp : public Operator{
         Storage::Get()->Free(dropout_states_);
       }
     }
+    #if USE_CUDNN_LSTM_PROJ
+    CUDNN_CALL(cudnnDestroyRNNDataDescriptor(x_data_desc_));
+    CUDNN_CALL(cudnnDestroyRNNDataDescriptor(y_data_desc_));
+    CUDNN_CALL(cudnnDestroyRNNDataDescriptor(dx_data_desc_));
+    CUDNN_CALL(cudnnDestroyRNNDataDescriptor(dy_data_desc_));
+    #endif
   }
 
   virtual void Forward(const OpContext &ctx, const std::vector<TBlob> &in_data,
@@ -169,7 +210,89 @@ class CuDNNRNNOp : public Operator{
     Tensor<gpu, 1, DType> temp_space =
       ctx.requested[rnn_enum::kTempSpace].get_space_typed<gpu, 1, DType>(
                               mshadow::Shape1(temp_size), s);
+    #if USE_CUDNN_LSTM_PROJ
+    std::vector<int> seqLengthArray(param_.batch_size_, param_.seq_length_);
+    CUDNN_CALL(cudnnSetRNNDataDescriptor(x_data_desc_,
+                                         dtype_,
+                                         CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                         param_.seq_length_,
+                                         param_.batch_size_,
+                                         param_.input_size_,
+                                         seqLengthArray.data(),
+                                         nullptr));
+    int out_size =
+      (param_.projection_size.has_value()) ? param_.projection_size.value() : param_.state_size;
+    out_size = (param_.bidirectional) ? (out_size * 2) : out_size;
+    CUDNN_CALL(cudnnSetRNNDataDescriptor(y_data_desc_,
+                                         dtype_,
+                                         CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                         param_.seq_length_,
+                                         param_.batch_size_,
+                                         out_size,
+                                         seqLengthArray.data(),
+                                         nullptr));
     if (ctx.is_train) {
+      CUDNN_CALL(cudnnSetRNNDataDescriptor(dx_data_desc_,
+                                           dtype_,
+                                           CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                           param_.seq_length_,
+                                           param_.batch_size_,
+                                           param_.input_size_,
+                                           seqLengthArray.data(),
+                                           nullptr));
+      CUDNN_CALL(cudnnSetRNNDataDescriptor(dy_data_desc_,
+                                           dtype_,
+                                           CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                           param_.seq_length_,
+                                           param_.batch_size_,
+                                           out_size,
+                                           seqLengthArray.data(),
+                                           nullptr));
+    }
+    #endif
+
+    #if USE_CUDNN_LSTM_PROJ
+    bool clip_state = param_.lstm_state_clip_min.has_value();
+    bool clip_nan = param_.lstm_state_clip_nan;
+    CUDNN_CALL(cudnnRNNSetClip(s->dnn_handle_,
+                               rnn_desc_,
+                               clip_state ? CUDNN_RNN_CLIP_MINMAX : CUDNN_RNN_CLIP_NONE,
+                               clip_nan ? CUDNN_NOT_PROPAGATE_NAN : CUDNN_PROPAGATE_NAN,
+                               clip_state ? param_.lstm_state_clip_min.value() : 0.0,
+                               clip_state ? param_.lstm_state_clip_max.value() : 0.0));
+    #endif
+
+    if (ctx.is_train) {
+      #if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnRNNForwardTrainingEx(s->dnn_handle_,
+                                           rnn_desc_,
+                                           x_data_desc_,
+                                           x.dptr_,
+                                           hx_desc_,
+                                           hx.dptr_,
+                                           cx_desc_,
+                                           cx_ptr,
+                                           w_desc_,
+                                           w.dptr_,
+                                           y_data_desc_,
+                                           y.dptr_,
+                                           hy_desc_,
+                                           hy_ptr,
+                                           cy_desc_,
+                                           cy_ptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           temp_space.dptr_,
+                                           workspace_byte_,
+                                           reserve_space_.dptr,
+                                           reserve_space_byte_));
+      #else
       CUDNN_CALL(cudnnRNNForwardTraining(s->dnn_handle_,
                                          rnn_desc_,
                                          param_.seq_length_,
@@ -191,8 +314,36 @@ class CuDNNRNNOp : public Operator{
                                          workspace_byte_,
                                          reserve_space_.dptr,
                                          reserve_space_byte_));
+      #endif
     } else {
-      // inference mode
+      #if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnRNNForwardInferenceEx(s->dnn_handle_,
+                                            rnn_desc_,
+                                            x_data_desc_,
+                                            x.dptr_,
+                                            hx_desc_,
+                                            hx.dptr_,
+                                            cx_desc_,
+                                            cx_ptr,
+                                            w_desc_,
+                                            w.dptr_,
+                                            y_data_desc_,
+                                            y.dptr_,
+                                            hy_desc_,
+                                            hy_ptr,
+                                            cy_desc_,
+                                            cy_ptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            temp_space.dptr_,
+                                            workspace_byte_));
+      #else
       CUDNN_CALL(cudnnRNNForwardInference(s->dnn_handle_,
                                           rnn_desc_,
                                           param_.seq_length_,
@@ -212,6 +363,7 @@ class CuDNNRNNOp : public Operator{
                                           cy_ptr,
                                           temp_space.dptr_,
                                           workspace_byte_));
+      #endif
     }
   }
 
@@ -283,6 +435,52 @@ class CuDNNRNNOp : public Operator{
     Tensor<gpu, 1, DType> temp_space =
       ctx.requested[rnn_enum::kTempSpace].get_space_typed<gpu, 1, DType>(
                               mshadow::Shape1(temp_size), s);
+    #if USE_CUDNN_LSTM_PROJ
+    CUDNN_CALL(cudnnRNNBackwardDataEx(s->dnn_handle_,
+                                      rnn_desc_,
+                                      y_data_desc_,
+                                      y.dptr_,
+                                      dy_data_desc_,
+                                      dy.dptr_,
+                                      nullptr,
+                                      nullptr,
+                                      dhy_desc_,
+                                      dhy_ptr,
+                                      dcy_desc_,
+                                      dcy_ptr,
+                                      w_desc_,
+                                      w.dptr_,
+                                      hx_desc_,
+                                      hx.dptr_,
+                                      cx_desc_,
+                                      cx_ptr,
+                                      dx_data_desc_,
+                                      dx.dptr_,
+                                      dhx_desc_,
+                                      dhx.dptr_,
+                                      dcx_desc_,
+                                      dcx_ptr,
+                                      nullptr,
+                                      nullptr,
+                                      temp_space.dptr_,
+                                      workspace_byte_,
+                                      reserve_space_.dptr,
+                                      reserve_space_byte_));
+    CUDNN_CALL(cudnnRNNBackwardWeightsEx(s->dnn_handle_,
+                                         rnn_desc_,
+                                         x_data_desc_,
+                                         x.dptr_,
+                                         hx_desc_,
+                                         hx.dptr_,
+                                         y_data_desc_,
+                                         y.dptr_,
+                                         temp_space.dptr_,
+                                         workspace_byte_,
+                                         dw_desc_,
+                                         dw.dptr_,
+                                         reserve_space_.dptr,
+                                         reserve_space_byte_));
+    #else
     CUDNN_CALL(cudnnRNNBackwardData(s->dnn_handle_,
                                     rnn_desc_,
                                     param_.seq_length_,
@@ -325,6 +523,7 @@ class CuDNNRNNOp : public Operator{
                                        dw.dptr_,
                                        reserve_space_.dptr,
                                        reserve_space_byte_));
+    #endif
   }
 
  private:
@@ -367,8 +566,6 @@ class CuDNNRNNOp : public Operator{
         dimA[0] = param_.batch_size_;
         dimA[1] = param_.input_size_;
         dimA[2] = 1;
-        dimA[0] = param_.batch_size_;
-        dimA[1] = param_.input_size_;
         strideA[0] = dimA[2] * dimA[1];
         strideA[1] = dimA[2];
         strideA[2] = 1;
@@ -391,10 +588,10 @@ class CuDNNRNNOp : public Operator{
         strideA[2] = 1;
 
         CUDNN_CALL(cudnnSetTensorNdDescriptor(y_vec[i],
-                                             dtype_,
-                                             3,
-                                             dimA,
-                                             strideA));
+                                              dtype_,
+                                              3,
+                                              dimA,
+                                              strideA));
         CUDNN_CALL(cudnnSetTensorNdDescriptor(dy_vec[i],
                                               dtype_,
                                               3,
@@ -413,42 +610,85 @@ class CuDNNRNNOp : public Operator{
       strideA[0] = dimA[2] * dimA[1];
       strideA[1] = dimA[2];
       strideA[2] = 1;
+      #if USE_CUDNN_LSTM_PROJ
+      int dimB[3];
+      int strideB[3];
+      dimB[0] = param_.num_layers * (param_.bidirectional ? 2 : 1);
+      dimB[1] = param_.batch_size_;
+      dimB[2] = param_.projection_size.has_value() ?
+                param_.projection_size.value() : param_.state_size;
+      strideB[0] = dimB[2] * dimB[1];
+      strideB[1] = dimB[2];
+      strideB[2] = 1;
+      #endif
 
+      #if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_,
+                                            dtype_,
+                                            3,
+                                            dimB,
+                                            strideB));
+      #else
       CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_,
                                             dtype_,
                                             3,
                                             dimA,
                                             strideA));
+      #endif
       CUDNN_CALL(cudnnSetTensorNdDescriptor(cx_desc_,
                                             dtype_,
                                             3,
                                             dimA,
                                             strideA));
+      #if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_,
+                                            dtype_,
+                                            3,
+                                            dimB,
+                                            strideB));
+      #else
       CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_,
                                             dtype_,
                                             3,
                                             dimA,
                                             strideA));
+      #endif
       CUDNN_CALL(cudnnSetTensorNdDescriptor(cy_desc_,
                                             dtype_,
                                             3,
                                             dimA,
                                             strideA));
+      #if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_,
+                                            dtype_,
+                                            3,
+                                            dimB,
+                                            strideB));
+      #else
       CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_,
                                             dtype_,
                                             3,
                                             dimA,
                                             strideA));
+      #endif
       CUDNN_CALL(cudnnSetTensorNdDescriptor(dcx_desc_,
                                             dtype_,
                                             3,
                                             dimA,
                                             strideA));
+      #if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_,
+                                            dtype_,
+                                            3,
+                                            dimB,
+                                            strideB));
+      #else
       CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_,
                                             dtype_,
                                             3,
                                             dimA,
                                             strideA));
+      #endif
       CUDNN_CALL(cudnnSetTensorNdDescriptor(dcy_desc_,
                                             dtype_,
                                             3,
@@ -470,26 +710,26 @@ class CuDNNRNNOp : public Operator{
                                            seed_));
       // RNN descriptors
       #if CUDNN_MAJOR >= 6
-        cudnnRNNAlgo_t rnn_algo = CUDNN_RNN_ALGO_STANDARD;
-        CUDNN_CALL(cudnnSetRNNDescriptor_v6(s->dnn_handle_,
-                                            rnn_desc_,
-                                            param_.state_size,
-                                            param_.num_layers,
-                                            dropout_desc_,
-                                            input_mode_,
-                                            direction_,
-                                            mode_,
-                                            rnn_algo,
-                                            dtype_));
+      cudnnRNNAlgo_t rnn_algo = CUDNN_RNN_ALGO_STANDARD;
+      CUDNN_CALL(cudnnSetRNNDescriptor_v6(s->dnn_handle_,
+                                          rnn_desc_,
+                                          param_.state_size,
+                                          param_.num_layers,
+                                          dropout_desc_,
+                                          input_mode_,
+                                          direction_,
+                                          mode_,
+                                          rnn_algo,
+                                          dtype_));
       #else
-        CUDNN_CALL(cudnnSetRNNDescriptor(rnn_desc_,
-                                         param_.state_size,
-                                         param_.num_layers,
-                                         dropout_desc_,
-                                         input_mode_,
-                                         direction_,
-                                         mode_,
-                                         dtype_));
+      CUDNN_CALL(cudnnSetRNNDescriptor(rnn_desc_,
+                                       param_.state_size,
+                                       param_.num_layers,
+                                       dropout_desc_,
+                                       input_mode_,
+                                       direction_,
+                                       mode_,
+                                       dtype_));
       #endif
       #if CUDNN_MAJOR >= 7
         cudnnMathType_t math_type = CUDNN_DEFAULT_MATH;
@@ -498,6 +738,14 @@ class CuDNNRNNOp : public Operator{
         }
         CUDNN_CALL(cudnnSetRNNMatrixMathType(rnn_desc_, math_type));
       #endif
+      #if USE_CUDNN_LSTM_PROJ
+      if (param_.projection_size.has_value()) {
+        CUDNN_CALL(cudnnSetRNNProjectionLayers(s->dnn_handle_,
+                                               rnn_desc_,
+                                               param_.projection_size.value(),
+                                               0));
+      }
+      #endif
       // Get temp space sizes
       CUDNN_CALL(cudnnGetRNNWorkspaceSize(s->dnn_handle_,
                                           rnn_desc_,
@@ -586,6 +834,9 @@ class CuDNNRNNOp : public Operator{
   size_t workspace_byte_, reserve_space_byte_, dropout_byte_;
   int workspace_size_, dropout_size_;
   std::vector<cudnnTensorDescriptor_t> x_desc_vec_, y_desc_vec_, dx_desc_vec_, dy_desc_vec_;
+  #if USE_CUDNN_LSTM_PROJ
+  cudnnRNNDataDescriptor_t x_data_desc_, y_data_desc_, dx_data_desc_, dy_data_desc_;
+  #endif
   cudnnTensorDescriptor_t hx_desc_, cx_desc_;
   cudnnTensorDescriptor_t hy_desc_, cy_desc_;
   cudnnTensorDescriptor_t dhx_desc_, dcx_desc_;
diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc
index 9df459e9224..99d27e5bd38 100644
--- a/src/operator/nn/concat.cc
+++ b/src/operator/nn/concat.cc
@@ -86,14 +86,17 @@ static bool RNNParamConcatShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
   TShape dshape;
   index_t size = 0;
-  int num_zero = 0;
+  std::vector<int> zero_indices;
   int axis = -1;
   for (int i = 0; i < param_.num_args; ++i) {
     TShape tmp = (*in_shape)[i];
     if (tmp.ndim()) {
       axis = CheckAxis(param_.dim, tmp.ndim());
-      num_zero += tmp[axis] == 0;
-      size += tmp[axis];
+      if (tmp[axis] == 0) {
+        zero_indices.emplace_back(i);
+      } else {
+        size += tmp[axis];
+      }
       tmp[axis] = 0;
       shape_assign(&dshape, tmp);
     }
@@ -113,18 +116,18 @@ static bool RNNParamConcatShape(const nnvm::NodeAttrs& attrs,
         << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i];
   }
 
-  if (!num_zero) dshape[axis] = size;
+  if (zero_indices.empty()) dshape[axis] = size;
   CHECK(shape_assign(&(*out_shape)[0], dshape))
       << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];
-  if ((*out_shape)[0][axis] != 0 && num_zero) {
+  if ((*out_shape)[0][axis] != 0 && !zero_indices.empty()) {
     int residual = (*out_shape)[0][axis] - size;
     CHECK_GE(residual, 0)
         << "Input size already exceeds output size. Residual: " << residual;
-    CHECK(num_zero <= 2 && num_zero >= 0)
-        << "Expecting 1 or 2 inputs that need shape inference. Got: " << num_zero;
+    CHECK(zero_indices.size() <= 2 && zero_indices.size() >= 0)
+        << "Expecting 1 or 2 inputs that need shape inference. Got: " << zero_indices.size();
     bool need_infer = !(*out_shape)[0].Size();
-    for (int i = 0; i < num_zero; i++) {
-      (*in_shape)[i*2][axis] = residual / num_zero;
+    for (int i : zero_indices) {
+      (*in_shape)[i][axis] = residual / zero_indices.size();
       need_infer = need_infer || !(*in_shape)[i].Size();
     }
     return !need_infer;
diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index 9211f6a456f..63c20711aae 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -54,7 +54,8 @@ inline int GetRnnParamSize(int num_layer,
                            int input_size,
                            int state_size,
                            int direction,
-                           int mode) {
+                           int mode,
+                           const dmlc::optional<int>& projection_size) {
   int size = state_size * direction;
   switch (mode) {
     case rnn_enum::kRnnRelu:
@@ -69,7 +70,15 @@ inline int GetRnnParamSize(int num_layer,
   }
   int size1 = (input_size + state_size + 2) * size;  // first layer size
   int size2 = (state_size * direction + state_size + 2) * size;  // other layers size
+  if (projection_size.has_value()) {
+    int proj_size = projection_size.value();
+    size1 = (input_size + proj_size + 2) * size;
+    size2 = (proj_size * direction + proj_size + 2) * size;
+  }
   int param_size = size1 + (num_layer - 1) * size2;
+  if (projection_size.has_value()) {
+    param_size += projection_size.value() * state_size * num_layer * direction;
+  }
   return param_size;
 }
 
@@ -154,6 +163,9 @@ struct RNNParam : public dmlc::Parameter<RNNParam> {
   float p, pkeep_;
   int seq_length_, batch_size_, input_size_;
   bool lstm_q_;  // whether type is lstm
+  dmlc::optional<int> projection_size;
+  dmlc::optional<double> lstm_state_clip_min, lstm_state_clip_max;
+  bool lstm_state_clip_nan;
 
   DMLC_DECLARE_PARAMETER(RNNParam) {
     DMLC_DECLARE_FIELD(state_size)
@@ -174,10 +186,29 @@ struct RNNParam : public dmlc::Parameter<RNNParam> {
 
     DMLC_DECLARE_FIELD(p).set_default(0.)
     .set_range(0, 1)
-    .describe("Dropout probability, fraction of the input that gets dropped out at training time");
+    .describe("drop rate of the dropout on the outputs of each RNN layer, except the last layer.");
 
     DMLC_DECLARE_FIELD(state_outputs).set_default(false)
     .describe("Whether to have the states as symbol outputs.");
+
+    DMLC_DECLARE_FIELD(projection_size)
+    .set_default(dmlc::optional<int>())
+    .describe("size of project size");
+
+    DMLC_DECLARE_FIELD(lstm_state_clip_min)
+    .set_default(dmlc::optional<double>())
+    .describe("Minimum clip value of LSTM states. This option must be used together with "
+              "lstm_state_clip_max.");
+
+    DMLC_DECLARE_FIELD(lstm_state_clip_max)
+    .set_default(dmlc::optional<double>())
+    .describe("Maximum clip value of LSTM states. This option must be used together with "
+              "lstm_state_clip_min.");
+
+    DMLC_DECLARE_FIELD(lstm_state_clip_nan)
+    .set_default(false)
+    .describe("Whether to stop NaN from propagating in state by clipping it to min/max. "
+              "If clipping range is not specified, this option is ignored.");
   }
 };
 
@@ -349,8 +380,15 @@ template<typename DType>
 class RNNOp : public Operator{
  public:
   explicit RNNOp(RNNParam p)
-    :param_(p), init_space_(false), reserve_space_size_(0)
-  {}
+    :param_(p), init_space_(false), reserve_space_size_(0) {  // rejects options RNNOp cannot run
+    if (param_.projection_size.has_value()) {  // LSTM projection is not implemented by RNNOp
+      LOG(FATAL) << "hidden layer projection is only supported for GPU with CuDNN later than 7.1.1";
+    }  // NOTE(review): GPU projection tests gate on cuDNN 7.2.1 while this says 7.1.1 -- confirm
+    if (param_.lstm_state_clip_min.has_value()
+        || param_.lstm_state_clip_max.has_value()) {  // either clip bound alone is rejected
+      LOG(FATAL) << "LSTM state clipping is only supported for GPU with CuDNN later than 7.2.1";
+    }
+  }  // unsupported options fail fast, at operator construction
 
   ~RNNOp() {
     if (init_space_) {
@@ -646,26 +684,33 @@ class RNNProp : public OperatorProperty {
     int input_size = dshape[2];
     int numDirections = param_.bidirectional ? 2 : 1;
     int total_layers = numDirections * param_.num_layers;  // double for bidirectional
+    int layer_size = (param_.projection_size.has_value()) ?
+                     param_.projection_size.value() : param_.state_size;
     SHAPE_ASSIGN_CHECK(*in_shape,
                        rnn_enum::kState,
-                       Shape3(total_layers, batch_size, param_.state_size));
+                       Shape3(total_layers, batch_size, layer_size));
     if (param_.mode == rnn_enum::kLstm)
       SHAPE_ASSIGN_CHECK(*in_shape,
-                        rnn_enum::kStateCell,
-                        Shape3(total_layers, batch_size, param_.state_size));
+                         rnn_enum::kStateCell,
+                         Shape3(total_layers, batch_size, param_.state_size));
 
     // calculate parameter vector length
     int param_size = GetRnnParamSize(param_.num_layers,
-                                    input_size,
-                                    param_.state_size,
-                                    numDirections,
-                                    param_.mode);
+                                     input_size,
+                                     param_.state_size,
+                                     numDirections,
+                                     param_.mode,
+                                     param_.projection_size);
     SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kParams, Shape1(param_size));
 
     out_shape->clear();
     // output: [sequence len, batch, output size]
     TShape oshape = dshape;
-    oshape[2] = numDirections * param_.state_size;
+    if (param_.projection_size.has_value()) {
+      oshape[2] = numDirections * param_.projection_size.value();
+    } else {
+      oshape[2] = numDirections * param_.state_size;
+    }
     out_shape->push_back(oshape);
     if (!param_.state_outputs) {
       return true;
@@ -674,11 +719,20 @@ class RNNProp : public OperatorProperty {
       TShape outStateShape = dshape;
       outStateShape[0] = total_layers;
       outStateShape[1] = batch_size;
-      outStateShape[2] = param_.state_size;
+      if (param_.projection_size.has_value()) {
+        outStateShape[2] = param_.projection_size.value();
+      } else {
+        outStateShape[2] = param_.state_size;
+      }
       out_shape->push_back(outStateShape);
       // Deal with lstm cell state
-      if (param_.mode == rnn_enum::kLstm)
-        out_shape->push_back(outStateShape);
+      if (param_.mode == rnn_enum::kLstm) {
+        TShape cellStateShape = dshape;
+        cellStateShape[0] = total_layers;
+        cellStateShape[1] = batch_size;
+        cellStateShape[2] = param_.state_size;
+        out_shape->push_back(cellStateShape);
+      }
       return true;
     }
   }
diff --git a/src/operator/rnn.cu b/src/operator/rnn.cu
index 59517932b78..402a8cf5f50 100644
--- a/src/operator/rnn.cu
+++ b/src/operator/rnn.cu
@@ -40,7 +40,7 @@ Operator* CreateOp<gpu>(RNNParam param, int dtype) {
     op = new CuDNNRNNOp<DType>(param);
   })
 #else
-  LOG(FATAL) << "RNN is only available for cuDNN at the moment.";
+  LOG(FATAL) << "RNN on GPU is only available for cuDNN at the moment.";
 #endif  // MXNET_USE_CUDNN && CUDNN_MAJOR
   return op;
 }
diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index 80c28d9b472..e5f3231dde6 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -22,6 +22,7 @@
 import time
 import multiprocessing as mp
 import unittest
+import random
 import mxnet as mx
 import numpy as np
 import unittest
@@ -31,11 +32,12 @@
 from mxnet.base import MXNetError
 from mxnet import autograd
 from numpy.testing import assert_allclose
+from mxnet.test_utils import rand_ndarray
 
 
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
 sys.path.insert(0, os.path.join(curr_path, '../unittest'))
-from common import setup_module, with_seed, teardown, assert_raises_cudnn_disabled
+from common import setup_module, with_seed, teardown, assert_raises_cudnn_not_satisfied
 from test_gluon import *
 from test_loss import *
 from test_gluon_rnn import *
@@ -79,7 +81,81 @@ def check_rnn_layer_w_rand_inputs(layer):
 
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='7.2.1')
+def test_lstmp():  # fused LSTM with projection must match an unrolled LSTMPCell (outputs + grads)
+    hidden_size, projection_size = 3, 2
+    rtol, atol = 1e-2, 1e-2
+    batch_size, seq_len = 7, 11
+    input_size = 5
+    lstm_input = mx.nd.uniform(shape=(seq_len, batch_size, input_size), ctx=mx.gpu(0))  # TNC
+    shapes = {'i2h_weight': (hidden_size*4, input_size),
+              'h2h_weight': (hidden_size*4, projection_size),
+              'i2h_bias': (hidden_size*4,),
+              'h2h_bias': (hidden_size*4,),
+              'h2r_weight': (projection_size, hidden_size)}  # h2r: hidden_size -> projection_size
+    weights = {k: rand_ndarray(v) for k, v in shapes.items()}  # loaded into both nets below
+    lstm_layer = gluon.rnn.LSTM(hidden_size, projection_size=projection_size,
+                                input_size=input_size, prefix='lstm0_')
+    lstm_cell = gluon.contrib.rnn.LSTMPCell(hidden_size=hidden_size,
+                                            projection_size=projection_size,
+                                            input_size=input_size,
+                                            prefix='lstm0_l0_')  # same names as lstm_layer's layer 0
+    lstm_layer.initialize(ctx=mx.gpu(0))
+    lstm_cell.initialize(ctx=mx.gpu(0))
+    layer_params = lstm_layer.collect_params()
+    cell_params = lstm_cell.collect_params()
+    for k, v in weights.items():
+        layer_params['lstm0_l0_'+k].set_data(v.copy())
+        cell_params['lstm0_l0_'+k].set_data(v.copy())
+    with autograd.record():
+        layer_output = lstm_layer(lstm_input.copy())
+        cell_output = lstm_cell.unroll(seq_len, lstm_input.copy(), layout='TNC',
+                                       merge_outputs=True)[0]
+    assert_almost_equal(layer_output.asnumpy(), cell_output.asnumpy(), rtol=rtol, atol=atol)
+    layer_output.backward()
+    cell_output.backward()
+    for k, v in weights.items():  # gradients must agree parameter by parameter
+        layer_grad = layer_params['lstm0_l0_'+k].grad()
+        cell_grad = cell_params['lstm0_l0_'+k].grad()
+        print('checking gradient for {}'.format('lstm0_l0_'+k))
+        assert_almost_equal(layer_grad.asnumpy(), cell_grad.asnumpy(),
+                            rtol=rtol, atol=atol)
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5), mx.nd.ones((8, 3, 20)))  # helper from test_gluon_rnn
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5, bidirectional=True), mx.nd.ones((8, 3, 20)), [mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))])
+
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5, projection_size=5), mx.nd.ones((8, 3, 20)),
+                            run_only=True)  # dropout is random; run_only presumably skips value checks
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, dropout=0.5, projection_size=5),
+                            mx.nd.ones((8, 3, 20)),
+                            [mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))], run_only=True)  # states: [h, c]
+
+
+@with_seed()
+@assert_raises_cudnn_not_satisfied(min_version='7.2.1')
+def test_lstm_clip():  # with state clipping on, the returned LSTM cell state must lie in [min, max]
+    hidden_size, projection_size = 4096, 2048
+    batch_size, seq_len = 32, 80
+    input_size = 50
+    clip_min, clip_max, clip_nan = -5, 5, True
+    lstm_input = mx.nd.uniform(shape=(seq_len, batch_size, input_size), ctx=mx.gpu(0))
+    lstm_states = [mx.nd.uniform(shape=(2, batch_size, projection_size), ctx=mx.gpu(0)),  # h
+                   mx.nd.uniform(shape=(2, batch_size, hidden_size), ctx=mx.gpu(0))]  # c
+    lstm_layer = gluon.rnn.LSTM(hidden_size, projection_size=projection_size,
+                                input_size=input_size, prefix='lstm0_',
+                                bidirectional=True,
+                                state_clip_min=clip_min,
+                                state_clip_max=clip_max,
+                                state_clip_nan=clip_nan)
+    lstm_layer.initialize(ctx=mx.gpu(0))
+    with autograd.record():
+        _, layer_output_states = lstm_layer(lstm_input, lstm_states)
+    cell_states = layer_output_states[1].asnumpy()  # states are [h, c]; clipping acts on c
+    assert (cell_states >= clip_min).all() and (cell_states <= clip_max).all()
+    assert not np.isnan(cell_states).any()
+
+
+@with_seed()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_rnn_layer():
     check_rnn_layer(gluon.rnn.RNN(100, num_layers=3))
     check_rnn_layer(gluon.rnn.RNN(100, activation='tanh', num_layers=3))
@@ -90,7 +166,65 @@ def test_rnn_layer():
     check_rnn_layer_w_rand_inputs(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True))
 
 
+def check_layer_bidirectional(size, in_size, proj_size):  # proj_size=0 means no projection
+    class RefBiLSTM(gluon.Block):  # reference: two unidirectional LSTMs, one on reversed time
+        def __init__(self, size, proj_size, **kwargs):
+            super(RefBiLSTM, self).__init__(**kwargs)
+            with self.name_scope():
+                self._lstm_fwd = gluon.rnn.LSTM(size, projection_size=proj_size, bidirectional=False, prefix='l0')
+                self._lstm_bwd = gluon.rnn.LSTM(size, projection_size=proj_size, bidirectional=False, prefix='r0')
+
+        def forward(self, inpt):
+            fwd = self._lstm_fwd(inpt)
+            bwd_inpt = nd.flip(inpt, 0)  # axis 0 is time: feed the sequence backwards
+            bwd = self._lstm_bwd(bwd_inpt)
+            bwd = nd.flip(bwd, 0)  # restore the original time order
+            return nd.concat(fwd, bwd, dim=2)  # feature-axis concat, like the fused layer's output
+    weights = {}  # one random weight set, loaded into both the fused and the reference net
+    for d in ['l', 'r']:
+        weights['lstm_{}0_i2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, in_size))
+        if proj_size:
+            weights['lstm_{}0_h2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, proj_size))
+            weights['lstm_{}0_h2r_weight'.format(d)] = mx.random.uniform(shape=(proj_size, size))
+        else:
+            weights['lstm_{}0_h2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, size))
+        weights['lstm_{}0_i2h_bias'.format(d)] = mx.random.uniform(shape=(size*4,))
+        weights['lstm_{}0_h2h_bias'.format(d)] = mx.random.uniform(shape=(size*4,))
+
+    net = gluon.rnn.LSTM(size, projection_size=proj_size, bidirectional=True, prefix='lstm_')
+    ref_net = RefBiLSTM(size, proj_size, prefix='lstm_')
+    net.initialize()
+    ref_net.initialize()
+    net_params = net.collect_params()
+    ref_net_params = ref_net.collect_params()
+    for k in weights:
+        net_params[k].set_data(weights[k])
+        ref_net_params[k.replace('l0', 'l0l0').replace('r0', 'r0l0')].set_data(weights[k])  # sub-LSTM layer-0 names
+    # Fused and reference paths may round differently; exact float32 equality is too strict.
+    data = mx.random.uniform(shape=(11, 10, in_size))  # (seq_len, batch, in_size)
+    assert_allclose(net(data).asnumpy(), ref_net(data).asnumpy(), rtol=1e-4, atol=1e-4)
+
 @with_seed()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
+def test_layer_bidirectional():  # fused bidirectional LSTM vs. two-LSTM reference
+    check_layer_bidirectional(7, 5, 0)  # size=7, in_size=5, no projection
+
+@with_seed()
+@assert_raises_cudnn_not_satisfied(min_version='7.2.1')
+def test_layer_bidirectional_proj():  # same comparison with LSTM projection enabled
+    check_layer_bidirectional(7, 5, 3)  # size=7, in_size=5, proj_size=3
+
+
+@with_seed()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
+def test_rnn_layer_begin_state_type():  # float16 LSTM called without explicit states must run
+    fake_data = nd.random.uniform(shape=(3, 5, 7), dtype='float16')
+    modeling_layer = gluon.rnn.LSTM(hidden_size=11, num_layers=2, dropout=0.2, bidirectional=True)
+    modeling_layer.cast('float16')  # parameters become float16 to match the input
+    modeling_layer.initialize()
+    modeling_layer(fake_data)  # no begin states passed: the layer must create them itself
+
+
 def test_gluon_ctc_consistency():
     loss = mx.gluon.loss.CTCLoss()
     data = mx.nd.arange(0, 4, repeat=40, ctx=mx.gpu(0)).reshape((2,20,4)).flip(axis=0)
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 5612b0a647e..3d6654d017b 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -32,7 +32,7 @@
 
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
 sys.path.insert(0, os.path.join(curr_path, '../unittest'))
-from common import setup_module, with_seed, teardown, assert_raises_cudnn_disabled
+from common import setup_module, with_seed, teardown, assert_raises_cudnn_not_satisfied
 from test_operator import *
 from test_optimizer import *
 from test_random import *
@@ -410,7 +410,7 @@ def test_3d_batchnorm(fix_gamma, use_global_stats):
 
 
 @with_seed(1234)
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_convolution_with_type():
     sym1 = mx.sym.Convolution(num_filter=3, kernel=(3,3), name='conv')
 
@@ -1360,7 +1360,7 @@ def check_rnn_consistency(cell1, cell2):
     assert_allclose(mod1.get_outputs()[0].asnumpy(), mod2.get_outputs()[0].asnumpy(), rtol=1e-2, atol=1e-4)
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_rnn():
     fused = mx.rnn.FusedRNNCell(100, num_layers=2, mode='rnn_relu', prefix='')
 
@@ -1372,7 +1372,7 @@ def test_rnn():
     check_rnn_consistency(stack, fused)
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_lstm_forget_bias():
     forget_bias = 2.0
     fused = mx.rnn.FusedRNNCell(10, forget_bias=forget_bias, num_layers=2, mode='lstm', prefix='')
@@ -1394,7 +1394,7 @@ def test_lstm_forget_bias():
     assert_allclose(args[bias_name].asnumpy(), expected_bias)
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_gru():
     fused = mx.rnn.FusedRNNCell(100, num_layers=2, mode='gru', prefix='')
 
@@ -1406,7 +1406,7 @@ def test_gru():
     check_rnn_consistency(stack, fused)
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_bidirectional():
     fused = mx.rnn.FusedRNNCell(100, num_layers=2, mode='gru', prefix='',
             bidirectional=True)
@@ -1425,7 +1425,7 @@ def test_bidirectional():
     check_rnn_consistency(stack, fused)
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_unfuse():
     for mode in ['rnn_tanh', 'rnn_relu', 'lstm', 'gru']:
         fused = mx.rnn.FusedRNNCell(
@@ -1601,7 +1601,7 @@ def test_deformable_convolution_options():
                                                name='deformable_conv')
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_residual_fused():
     cell = mx.rnn.ResidualCell(
             mx.rnn.FusedRNNCell(50, num_layers=3, mode='lstm',
diff --git a/tests/python/unittest/common.py b/tests/python/unittest/common.py
index f98bb79dfab..abfba73ab72 100644
--- a/tests/python/unittest/common.py
+++ b/tests/python/unittest/common.py
@@ -95,16 +95,30 @@ def random_seed(seed=None):
         random.seed(next_seed)


-def assert_raises_cudnn_disabled():
+def assert_raises_cudnn_not_satisfied(min_version):
+    """Decorator factory: the test must fail unless cuDNN can satisfy it.
+
+    The wrapped test runs normally on CPU, or when cuDNN is on and its
+    version (env CUDNN_VERSION, '7.3.1' if unset) is at least `min_version`.
+    Otherwise it is expected to raise MXNetError or RuntimeError.
+    Versions compare numerically: '10.0' and '7.10.0' are newer than '7.2.1'.
+    """
+    def _version_key(version):
+        # Dotted version -> int tuple; non-numeric parts (e.g. '4-rc') are dropped.
+        return tuple(int(part) for part in version.strip().split('.') if part.isdigit())
+
     def test_helper(orig_test):
         @make_decorator(orig_test)
         def test_new(*args, **kwargs):
-            cudnn_disabled = (os.getenv('CUDNN_OFF_TEST_ONLY') == "true")
-            if not cudnn_disabled or mx.context.current_context().device_type == 'cpu':
+            cudnn_off = os.getenv('CUDNN_OFF_TEST_ONLY') == 'true'
+            cudnn_env_version = os.getenv('CUDNN_VERSION', None if cudnn_off else '7.3.1')
+            # `or` short-circuits, so the (possibly None) version is never parsed when cuDNN is off.
+            cudnn_test_disabled = (cudnn_off or
+                                   _version_key(cudnn_env_version) < _version_key(min_version))
+            if not cudnn_test_disabled or mx.context.current_context().device_type == 'cpu':
                 orig_test(*args, **kwargs)
             else:
-                errors = (MXNetError, RuntimeError)
-                assert_raises(errors, orig_test, *args, **kwargs)
+                assert_raises((MXNetError, RuntimeError), orig_test, *args, **kwargs)
         return test_new
     return test_helper
 
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 4e13fc38e87..d92e95dfc19 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -23,7 +23,8 @@
 from mxnet.gluon import nn
 from mxnet.test_utils import assert_almost_equal
 from mxnet.ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
-from common import setup_module, with_seed, assertRaises, teardown, assert_raises_cudnn_disabled
+from common import (setup_module, with_seed, assertRaises, teardown,
+                    assert_raises_cudnn_not_satisfied)
 import numpy as np
 from numpy.testing import assert_array_equal
 from nose.tools import raises, assert_raises
@@ -339,7 +340,7 @@ def hybrid_forward(self, F, x):
     net.hybridize()
     assert isinstance(net(mx.nd.zeros((16, 10))), mx.nd.NDArray)
 
-    # Test case to verify if initializing the SymbolBlock from a model with params 
+    # Test case to verify if initializing the SymbolBlock from a model with params
     # other than fp32 param dtype.
 
     # 1. Load a resnet model, cast it to fp64 and export
@@ -1321,7 +1322,7 @@ def record_name(block):
 
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_summary():
     net = gluon.model_zoo.vision.resnet50_v1()
     net.initialize()
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index 4e8241ffc1e..f7c5edd59e1 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -22,7 +22,7 @@
 from numpy.testing import assert_allclose
 import unittest
 from mxnet.test_utils import almost_equal, assert_almost_equal
-from common import assert_raises_cudnn_disabled
+from common import assert_raises_cudnn_not_satisfied
 
 
 def test_rnn():
@@ -71,7 +71,7 @@ def test_lstm_forget_bias():
     assert_allclose(mod.get_params()[0][bias_argument].asnumpy(), expected_bias)
 
 
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_lstm_cpu_inference():
     # should behave the same as lstm cell
     EXPECTED_LSTM_OUTPUT = np.array([[[0.72045636, 0.72045636, 0.95215213, 0.95215213],
@@ -243,7 +243,7 @@ def test_bidirectional():
     assert outs == [(10, 200), (10, 200), (10, 200)]
 
 
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_layer_bidirectional():
     class RefBiLSTM(gluon.Block):
         def __init__(self, size, **kwargs):
@@ -278,7 +278,7 @@ def forward(self, inpt):
         net_params[k].set_data(weights[k])
         ref_net_params[k.replace('l0', 'l0l0').replace('r0', 'r0l0')].set_data(weights[k])
 
-    data = mx.random.uniform(shape=(3, 10, in_size))
+    data = mx.random.uniform(shape=(11, 10, in_size))
     assert_allclose(net(data).asnumpy(), ref_net(data).asnumpy())
 
 
@@ -419,7 +419,7 @@ def check_rnn_layer_forward(layer, inputs, states=None, run_only=False):
         mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(), rtol=1e-3, atol=1e-5)
 
 
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_rnn_layers():
     check_rnn_layer_forward(gluon.rnn.RNN(10, 2), mx.nd.ones((8, 3, 20)))
     check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True), mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)))
@@ -442,12 +442,11 @@ def test_rnn_layers():
     check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True, dropout=0.5),
                             mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)), run_only=True)
 
-    net = gluon.nn.HybridSequential()
+    net = gluon.nn.Sequential()
     net.add(gluon.rnn.LSTM(10, bidirectional=True))
     net.add(gluon.nn.BatchNorm(axis=2))
     net.add(gluon.nn.Flatten())
     net.add(gluon.nn.Dense(3, activation='relu'))
-    net.hybridize()
     net.collect_params().initialize()
     with mx.autograd.record():
         net(mx.nd.ones((2, 3, 10))).backward()
@@ -544,7 +543,7 @@ def test_cell_fill_shape():
     assert cell.i2h_weight.shape[1] == 7, cell.i2h_weight.shape[1]
 
 
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_layer_fill_shape():
     layer = gluon.rnn.LSTM(10)
     check_rnn_layer_forward(layer, mx.nd.ones((3, 2, 7)))
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 5e5e956691f..d2fba23b1c4 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -27,7 +27,7 @@
 from numpy.testing import assert_allclose, assert_array_equal
 from mxnet.test_utils import *
 from mxnet.base import py_str, MXNetError, _as_list
-from common import setup_module, with_seed, teardown, assert_raises_cudnn_disabled, assertRaises
+from common import setup_module, with_seed, teardown, assert_raises_cudnn_not_satisfied, assertRaises
 import unittest
 
 def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req, rtol=1e-2, atol=1e-4):
@@ -72,7 +72,7 @@ def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req, rtol=1e-2, atol=1e
 
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_lstm_sym():
     T, N, I, H = 5, 32, 800, 800
     fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='lstm', get_next_state=True, prefix='')
@@ -86,7 +86,7 @@ def test_lstm_sym():
     check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_lstm_bidirectional():
     T, N, I, H = 5, 20, 800, 800
     fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='lstm',
@@ -107,7 +107,7 @@ def test_lstm_bidirectional():
     check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_gru_sym():
     T, N, I, H = 5, 32, 800, 800
     fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='gru', get_next_state=True, prefix='')
@@ -121,7 +121,7 @@ def test_gru_sym():
     check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_gru_bidirectional():
     T, N, I, H = 5, 20, 800, 800
 
@@ -144,7 +144,7 @@ def test_gru_bidirectional():
     check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_rnntanh_sym():
     T, N, I, H = 5, 32, 800, 800
 
@@ -159,7 +159,7 @@ def test_rnntanh_sym():
     check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_rnntanh_bidirectional():
     T, N, I, H = 5, 20, 800, 800
 
@@ -181,7 +181,7 @@ def test_rnntanh_bidirectional():
     check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_rnnrelu_sym():
     T, N, I, H = 5, 32, 200, 200
 
@@ -196,7 +196,7 @@ def test_rnnrelu_sym():
     check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
 @with_seed()
-@assert_raises_cudnn_disabled()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_rnnrelu_bidirectional():
     T, N, I, H = 5, 20, 200, 200
 
@@ -4753,6 +4753,29 @@ def test_quantization_op():
     assert same(qa.asnumpy(), qa_real.asnumpy())
     assert same(a_.asnumpy(),  a_real.asnumpy())
 
+@with_seed()
+def test_index_copy():
+    x = mx.nd.zeros((5,3))
+    t = mx.nd.array([[1,2,3],[4,5,6],[7,8,9]])
+    index = mx.nd.array([0,4,2], dtype=np.int64)
+
+    x.attach_grad()
+    t.attach_grad()
+    index.attach_grad()
+
+    with mx.autograd.record():
+        out = mx.nd.contrib.index_copy(x, index, t)
+    out.backward()
+
+    tensor = mx.nd.array([[1,2,3],[0,0,0],[7,8,9],[0,0,0],[4,5,6]])
+    x_grad = mx.nd.array([[0,0,0],[1,1,1],[0,0,0],[1,1,1],[0,0,0]])
+    t_grad = mx.nd.array([[1,1,1],[1,1,1],[1,1,1]])
+    index_grad = mx.nd.array([0,0,0])
+
+    assert same(out.asnumpy(), tensor.asnumpy())
+    assert same(x.grad.asnumpy(), x_grad.asnumpy())
+    assert same(t.grad.asnumpy(), t_grad.asnumpy())
+    assert same(index.grad.asnumpy(), index_grad.asnumpy())
 
 @with_seed()
 def test_div_sqrt_dim():


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services