Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/10/03 22:07:32 UTC

[GitHub] haojin2 closed pull request #12541: [MXNET-936] [WIP] [DO NOT REVIEW] Support projection and clip in CuDNN LSTM

URL: https://github.com/apache/incubator-mxnet/pull/12541
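
For context, here is a rough usage sketch of the projection_size argument this
change threads through gluon.rnn.LSTM. The sizes below are made up for
illustration, and per the checks in the diff the projected path only works on a
GPU build with cuDNN 7.2 or later:

    import mxnet as mx
    from mxnet import gluon

    seq_len, batch_size, input_size = 5, 2, 32
    hidden_size, projection_size = 64, 16

    # LSTM with a recurrent projection layer (LSTMP); hidden states are
    # projected down to projection_size before being fed back and returned.
    layer = gluon.rnn.LSTM(hidden_size, num_layers=1,
                           projection_size=projection_size,
                           input_size=input_size)
    layer.initialize(ctx=mx.gpu(0))

    inputs = mx.nd.random.uniform(shape=(seq_len, batch_size, input_size),
                                  ctx=mx.gpu(0))
    begin_state = layer.begin_state(batch_size=batch_size, ctx=mx.gpu(0))
    outputs, states = layer(inputs, begin_state)

    # Outputs and the hidden state carry the projected size; the cell state
    # keeps the full hidden size, matching state_info() in this diff.
    assert outputs.shape == (seq_len, batch_size, projection_size)
    assert states[0].shape == (1, batch_size, projection_size)  # h
    assert states[1].shape == (1, batch_size, hidden_size)      # c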
 
 
   

As this is a foreign pull request (from a fork) and GitHub hides the original
diff once the pull request is closed, the diff is reproduced below for the
sake of provenance:

diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index daf8ecbf563..f53c327f3cf 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -35,11 +35,13 @@ def __init__(self, hidden_size, num_layers, layout,
                  dropout, bidirectional, input_size,
                  i2h_weight_initializer, h2h_weight_initializer,
                  i2h_bias_initializer, h2h_bias_initializer,
-                 mode, **kwargs):
+                 mode, projection_size=None, h2p_weight_initializer=None,
+                 **kwargs):
         super(_RNNLayer, self).__init__(**kwargs)
         assert layout in ('TNC', 'NTC'), \
             "Invalid layout %s; must be one of ['TNC' or 'NTC']"%layout
         self._hidden_size = hidden_size
+        self._projection_size = projection_size
         self._num_layers = num_layers
         self._mode = mode
         self._layout = layout
@@ -54,21 +56,42 @@ def __init__(self, hidden_size, num_layers, layout,
         self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]
 
         ng, ni, nh = self._gates, input_size, hidden_size
-        for i in range(num_layers):
-            for j in ['l', 'r'][:self._dir]:
-                self._register_param('{}{}_i2h_weight'.format(j, i),
-                                     shape=(ng*nh, ni),
-                                     init=i2h_weight_initializer)
-                self._register_param('{}{}_h2h_weight'.format(j, i),
-                                     shape=(ng*nh, nh),
-                                     init=h2h_weight_initializer)
-                self._register_param('{}{}_i2h_bias'.format(j, i),
-                                     shape=(ng*nh,),
-                                     init=i2h_bias_initializer)
-                self._register_param('{}{}_h2h_bias'.format(j, i),
-                                     shape=(ng*nh,),
-                                     init=h2h_bias_initializer)
-            ni = nh * self._dir
+        if projection_size is None:
+            for i in range(num_layers):
+                for j in ['l', 'r'][:self._dir]:
+                    self._register_param('{}{}_i2h_weight'.format(j, i),
+                                         shape=(ng*nh, ni),
+                                         init=i2h_weight_initializer)
+                    self._register_param('{}{}_h2h_weight'.format(j, i),
+                                         shape=(ng*nh, nh),
+                                         init=h2h_weight_initializer)
+                    self._register_param('{}{}_i2h_bias'.format(j, i),
+                                         shape=(ng*nh,),
+                                         init=i2h_bias_initializer)
+                    self._register_param('{}{}_h2h_bias'.format(j, i),
+                                         shape=(ng*nh,),
+                                         init=h2h_bias_initializer)
+                ni = nh * self._dir
+        else:
+            np = self._projection_size
+            for i in range(num_layers):
+                for j in ['l', 'r'][:self._dir]:
+                    self._register_param('{}{}_i2h_weight'.format(j, i),
+                                         shape=(ng*nh, ni),
+                                         init=i2h_weight_initializer)
+                    self._register_param('{}{}_h2h_weight'.format(j, i),
+                                         shape=(ng*nh, np),
+                                         init=h2h_weight_initializer)
+                    self._register_param('{}{}_i2h_bias'.format(j, i),
+                                         shape=(ng*nh,),
+                                         init=i2h_bias_initializer)
+                    self._register_param('{}{}_h2h_bias'.format(j, i),
+                                         shape=(ng*nh,),
+                                         init=h2h_bias_initializer)
+                    self._register_param('{}{}_h2p_weight'.format(j, i),
+                                         shape=(np, nh),
+                                         init=h2p_weight_initializer)
+                ni = np * self._dir
 
     def _register_param(self, name, shape, init):
         p = self.params.get(name, shape=shape, init=init,
@@ -125,6 +148,7 @@ def _unfuse(self):
                     'gru': lambda **kwargs: rnn_cell.GRUCell(self._hidden_size,
                                                              **kwargs)}[self._mode]
 
+        assert self._projection_size is None, "_unfuse does not support projection layer yet!"
         stack = rnn_cell.HybridSequentialRNNCell(prefix=self.prefix, params=self.params)
         with stack.name_scope():
             ni = self._input_size
@@ -209,14 +233,24 @@ def _forward_kernel(self, F, inputs, states, **kwargs):
         """ forward using CUDNN or CPU kenrel"""
         if self._layout == 'NTC':
             inputs = F.swapaxes(inputs, dim1=0, dim2=1)
-        params = (kwargs['{}{}_{}_{}'.format(d, l, g, t)].reshape(-1)
-                  for t in ['weight', 'bias']
-                  for l in range(self._num_layers)
-                  for d in ['l', 'r'][:self._dir]
-                  for g in ['i2h', 'h2h'])
+        if self._projection_size is None:
+            params = (kwargs['{}{}_{}_{}'.format(d, l, g, t)].reshape(-1)
+                      for t in ['weight', 'bias']
+                      for l in range(self._num_layers)
+                      for d in ['l', 'r'][:self._dir]
+                      for g in ['i2h', 'h2h'])
+        else:
+            params = [kwargs['{}{}_{}_{}'.format(d, l, g, t)].reshape(-1)
+                      for t in ['weight', 'bias']
+                      for l in range(self._num_layers)
+                      for d in ['l', 'r'][:self._dir]
+                      for g in ['i2h', 'h2h', 'h2p']
+                      if g != 'h2p' or t != 'bias']
+
         params = F._internal._rnn_param_concat(*params, dim=0)
 
         rnn = F.RNN(inputs, params, *states, state_size=self._hidden_size,
+                    projection_size=self._projection_size,
                     num_layers=self._num_layers, bidirectional=self._dir == 2,
                     p=self._dropout, state_outputs=True, mode=self._mode)
 
@@ -373,6 +407,8 @@ class LSTM(_RNNLayer):
         to zero.
     h2h_bias_initializer : str or Initializer
         Initializer for the bias vector.
+    projection_size: int, default None
+        The number of features after projection.
     input_size: int, default 0
         The number of expected features in the input x.
         If not specified, it will be inferred from input.
@@ -416,18 +452,24 @@ def __init__(self, hidden_size, num_layers=1, layout='TNC',
                  dropout=0, bidirectional=False, input_size=0,
                  i2h_weight_initializer=None, h2h_weight_initializer=None,
                  i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
-                 **kwargs):
+                 projection_size=None, h2p_weight_initializer=None, **kwargs):
         super(LSTM, self).__init__(hidden_size, num_layers, layout,
                                    dropout, bidirectional, input_size,
                                    i2h_weight_initializer, h2h_weight_initializer,
                                    i2h_bias_initializer, h2h_bias_initializer,
-                                   'lstm', **kwargs)
+                                   'lstm', projection_size, h2p_weight_initializer, **kwargs)
 
     def state_info(self, batch_size=0):
-        return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
-                 '__layout__': 'LNC'},
-                {'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
-                 '__layout__': 'LNC'}]
+        if self._projection_size is None:
+            return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
+                     '__layout__': 'LNC'},
+                    {'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
+                     '__layout__': 'LNC'}]
+        else:
+            return [{'shape': (self._num_layers * self._dir, batch_size, self._projection_size),
+                     '__layout__': 'LNC'},
+                    {'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
+                     '__layout__': 'LNC'}]
 
 
 class GRU(_RNNLayer):
diff --git a/src/operator/cudnn_rnn-inl.h b/src/operator/cudnn_rnn-inl.h
index b33a717d15b..fc61bce5f1c 100644
--- a/src/operator/cudnn_rnn-inl.h
+++ b/src/operator/cudnn_rnn-inl.h
@@ -38,7 +38,7 @@ namespace mxnet {
 namespace op {
 #if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
 template<typename DType>
-class CuDNNRNNOp : public Operator{
+class CuDNNRNNOp : public Operator {
  public:
   explicit CuDNNRNNOp(RNNParam param) {
     this->param_ = param;
@@ -69,6 +69,12 @@ class CuDNNRNNOp : public Operator{
       default:
        LOG(FATAL) << "Not implemented";
     }
+#if MXNET_USE_CUDNN == 1 && ((CUDNN_MAJOR == 7 && CUDNN_MINOR >= 2) || CUDNN_MAJOR > 7)
+    CHECK((param_.mode == rnn_enum::kLstm) || !param_.projection_size.has_value())
+#else
+    CHECK(!param_.projection_size.has_value())
+#endif
+      << "Projection is only supported for LSTM with CuDNN version later than 7.1.1";
     // RNN Direction
     direction_ = param_.bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
     // Other
@@ -123,6 +129,12 @@ class CuDNNRNNOp : public Operator{
         Storage::Get()->Free(dropout_states_);
       }
     }
+    #if (CUDNN_MAJOR == 7 && CUDNN_MINOR >= 2) || CUDNN_MAJOR > 7
+    // CUDNN_CALL(cudnnDestroyRNNDataDescriptor(x_data_desc_));
+    // CUDNN_CALL(cudnnDestroyRNNDataDescriptor(y_data_desc_));
+    // CUDNN_CALL(cudnnDestroyRNNDataDescriptor(dx_data_desc_));
+    // CUDNN_CALL(cudnnDestroyRNNDataDescriptor(dy_data_desc_));
+    #endif
   }
 
   virtual void Forward(const OpContext &ctx, const std::vector<TBlob> &in_data,
@@ -169,49 +181,165 @@ class CuDNNRNNOp : public Operator{
     Tensor<gpu, 1, DType> temp_space =
       ctx.requested[rnn_enum::kTempSpace].get_space_typed<gpu, 1, DType>(
                               mshadow::Shape1(temp_size), s);
+    #if (CUDNN_MAJOR == 7 && CUDNN_MINOR >= 2) || CUDNN_MAJOR > 7
+    std::vector<int> seqLengthArray(param_.batch_size_, param_.seq_length_);
+    if (param_.projection_size.has_value()) {
+      CUDNN_CALL(cudnnSetRNNDataDescriptor(x_data_desc_,
+                                           dtype_,
+                                           CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED,
+                                           param_.seq_length_,
+                                           param_.batch_size_,
+                                           param_.input_size_,
+                                           seqLengthArray.data(),
+                                           nullptr));
+      int out_size =
+        (param_.projection_size.has_value()) ? param_.projection_size.value() : param_.state_size;
+      out_size = (param_.bidirectional) ? (out_size * 2) : out_size;
+      CUDNN_CALL(cudnnSetRNNDataDescriptor(y_data_desc_,
+                                           dtype_,
+                                           CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED,
+                                           param_.seq_length_,
+                                           param_.batch_size_,
+                                           out_size,
+                                           seqLengthArray.data(),
+                                           nullptr));
+      if (ctx.is_train) {
+        CUDNN_CALL(cudnnSetRNNDataDescriptor(dx_data_desc_,
+                                             dtype_,
+                                             CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED,
+                                             param_.seq_length_,
+                                             param_.batch_size_,
+                                             param_.input_size_,
+                                             seqLengthArray.data(),
+                                             nullptr));
+        CUDNN_CALL(cudnnSetRNNDataDescriptor(dy_data_desc_,
+                                             dtype_,
+                                             CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED,
+                                             param_.seq_length_,
+                                             param_.batch_size_,
+                                             out_size,
+                                             seqLengthArray.data(),
+                                             nullptr));
+      }
+    }
+    #endif
     if (ctx.is_train) {
-      CUDNN_CALL(cudnnRNNForwardTraining(s->dnn_handle_,
-                                         rnn_desc_,
-                                         param_.seq_length_,
-                                         x_desc_vec_.data(),
-                                         x.dptr_,
-                                         hx_desc_,
-                                         hx.dptr_,
-                                         cx_desc_,
-                                         cx_ptr,
-                                         w_desc_,
-                                         w.dptr_,
-                                         y_desc_vec_.data(),
-                                         y.dptr_,
-                                         hy_desc_,
-                                         hy_ptr,
-                                         cy_desc_,
-                                         cy_ptr,
-                                         temp_space.dptr_,
-                                         workspace_byte_,
-                                         reserve_space_.dptr,
-                                         reserve_space_byte_));
+      #if (CUDNN_MAJOR == 7 && CUDNN_MINOR >= 2) || CUDNN_MAJOR > 7
+      if (!param_.projection_size.has_value())
+      #endif
+      {
+        CUDNN_CALL(cudnnRNNForwardTraining(s->dnn_handle_,
+                                           rnn_desc_,
+                                           param_.seq_length_,
+                                           x_desc_vec_.data(),
+                                           x.dptr_,
+                                           hx_desc_,
+                                           hx.dptr_,
+                                           cx_desc_,
+                                           cx_ptr,
+                                           w_desc_,
+                                           w.dptr_,
+                                           y_desc_vec_.data(),
+                                           y.dptr_,
+                                           hy_desc_,
+                                           hy_ptr,
+                                           cy_desc_,
+                                           cy_ptr,
+                                           temp_space.dptr_,
+                                           workspace_byte_,
+                                           reserve_space_.dptr,
+                                           reserve_space_byte_));
+      #if (CUDNN_MAJOR == 7 && CUDNN_MINOR >= 2) || CUDNN_MAJOR > 7
+      } else {
+        CUDNN_CALL(cudnnRNNForwardTrainingEx(s->dnn_handle_,
+                                             rnn_desc_,
+                                             x_data_desc_,
+                                             x.dptr_,
+                                             hx_desc_,
+                                             hx.dptr_,
+                                             cx_desc_,
+                                             cx_ptr,
+                                             w_desc_,
+                                             w.dptr_,
+                                             y_data_desc_,
+                                             y.dptr_,
+                                             hy_desc_,
+                                             hy_ptr,
+                                             cy_desc_,
+                                             cy_ptr,
+                                             nullptr,
+                                             nullptr,
+                                             nullptr,
+                                             nullptr,
+                                             nullptr,
+                                             nullptr,
+                                             nullptr,
+                                             nullptr,
+                                             temp_space.dptr_,
+                                             workspace_byte_,
+                                             reserve_space_.dptr,
+                                             reserve_space_byte_));
+      }
+      #else
+      }
+      #endif
     } else {
-      // inference mode
-      CUDNN_CALL(cudnnRNNForwardInference(s->dnn_handle_,
-                                          rnn_desc_,
-                                          param_.seq_length_,
-                                          x_desc_vec_.data(),
-                                          x.dptr_,
-                                          hx_desc_,
-                                          hx.dptr_,
-                                          cx_desc_,
-                                          cx_ptr,
-                                          w_desc_,
-                                          w.dptr_,
-                                          y_desc_vec_.data(),
-                                          y.dptr_,
-                                          hy_desc_,
-                                          hy_ptr,
-                                          cy_desc_,
-                                          cy_ptr,
-                                          temp_space.dptr_,
-                                          workspace_byte_));
+      #if (CUDNN_MAJOR == 7 && CUDNN_MINOR >= 2) || CUDNN_MAJOR > 7
+      if (!param_.projection_size.has_value())
+      #endif
+      {
+        // inference mode
+        CUDNN_CALL(cudnnRNNForwardInference(s->dnn_handle_,
+                                            rnn_desc_,
+                                            param_.seq_length_,
+                                            x_desc_vec_.data(),
+                                            x.dptr_,
+                                            hx_desc_,
+                                            hx.dptr_,
+                                            cx_desc_,
+                                            cx_ptr,
+                                            w_desc_,
+                                            w.dptr_,
+                                            y_desc_vec_.data(),
+                                            y.dptr_,
+                                            hy_desc_,
+                                            hy_ptr,
+                                            cy_desc_,
+                                            cy_ptr,
+                                            temp_space.dptr_,
+                                            workspace_byte_));
+      #if (CUDNN_MAJOR == 7 && CUDNN_MINOR >= 2) || CUDNN_MAJOR > 7
+      } else {
+        CUDNN_CALL(cudnnRNNForwardInferenceEx(s->dnn_handle_,
+                                              rnn_desc_,
+                                              x_data_desc_,
+                                              x.dptr_,
+                                              hx_desc_,
+                                              hx.dptr_,
+                                              cx_desc_,
+                                              cx_ptr,
+                                              w_desc_,
+                                              w.dptr_,
+                                              y_data_desc_,
+                                              y.dptr_,
+                                              hy_desc_,
+                                              hy_ptr,
+                                              cy_desc_,
+                                              cy_ptr,
+                                              nullptr,
+                                              nullptr,
+                                              nullptr,
+                                              nullptr,
+                                              nullptr,
+                                              nullptr,
+                                              nullptr,
+                                              nullptr,
+                                              temp_space.dptr_,
+                                              workspace_byte_));
+      }
+      #else
+      }
+      #endif
     }
   }
 
@@ -283,48 +411,103 @@ class CuDNNRNNOp : public Operator{
     Tensor<gpu, 1, DType> temp_space =
       ctx.requested[rnn_enum::kTempSpace].get_space_typed<gpu, 1, DType>(
                               mshadow::Shape1(temp_size), s);
-    CUDNN_CALL(cudnnRNNBackwardData(s->dnn_handle_,
-                                    rnn_desc_,
-                                    param_.seq_length_,
-                                    y_desc_vec_.data(),
-                                    y.dptr_,
-                                    dy_desc_vec_.data(),
-                                    dy.dptr_,
-                                    dhy_desc_,
-                                    dhy_ptr,
-                                    dcy_desc_,
-                                    dcy_ptr,
-                                    w_desc_,
-                                    w.dptr_,
-                                    hx_desc_,
-                                    hx.dptr_,
-                                    cx_desc_,
-                                    cx_ptr,
-                                    dx_desc_vec_.data(),
-                                    dx.dptr_,
-                                    dhx_desc_,
-                                    dhx.dptr_,
-                                    dcx_desc_,
-                                    dcx_ptr,
-                                    temp_space.dptr_,
-                                    workspace_byte_,
-                                    reserve_space_.dptr,
-                                    reserve_space_byte_));
-    CUDNN_CALL(cudnnRNNBackwardWeights(s->dnn_handle_,
-                                       rnn_desc_,
-                                       param_.seq_length_,
-                                       x_desc_vec_.data(),
-                                       x.dptr_,
-                                       hx_desc_,
-                                       hx.dptr_,
-                                       y_desc_vec_.data(),
-                                       y.dptr_,
-                                       temp_space.dptr_,
-                                       workspace_byte_,
-                                       dw_desc_,
-                                       dw.dptr_,
-                                       reserve_space_.dptr,
-                                       reserve_space_byte_));
+    #if (CUDNN_MAJOR == 7 && CUDNN_MINOR >= 2) || CUDNN_MAJOR > 7
+    if (!param_.projection_size.has_value()) {
+    #else
+    {
+    #endif
+      CUDNN_CALL(cudnnRNNBackwardData(s->dnn_handle_,
+                                      rnn_desc_,
+                                      param_.seq_length_,
+                                      y_desc_vec_.data(),
+                                      y.dptr_,
+                                      dy_desc_vec_.data(),
+                                      dy.dptr_,
+                                      dhy_desc_,
+                                      dhy_ptr,
+                                      dcy_desc_,
+                                      dcy_ptr,
+                                      w_desc_,
+                                      w.dptr_,
+                                      hx_desc_,
+                                      hx.dptr_,
+                                      cx_desc_,
+                                      cx_ptr,
+                                      dx_desc_vec_.data(),
+                                      dx.dptr_,
+                                      dhx_desc_,
+                                      dhx.dptr_,
+                                      dcx_desc_,
+                                      dcx_ptr,
+                                      temp_space.dptr_,
+                                      workspace_byte_,
+                                      reserve_space_.dptr,
+                                      reserve_space_byte_));
+      CUDNN_CALL(cudnnRNNBackwardWeights(s->dnn_handle_,
+                                         rnn_desc_,
+                                         param_.seq_length_,
+                                         x_desc_vec_.data(),
+                                         x.dptr_,
+                                         hx_desc_,
+                                         hx.dptr_,
+                                         y_desc_vec_.data(),
+                                         y.dptr_,
+                                         temp_space.dptr_,
+                                         workspace_byte_,
+                                         dw_desc_,
+                                         dw.dptr_,
+                                         reserve_space_.dptr,
+                                         reserve_space_byte_));
+    #if (CUDNN_MAJOR == 7 && CUDNN_MINOR >= 2) || CUDNN_MAJOR > 7
+    } else {
+      CUDNN_CALL(cudnnRNNBackwardDataEx(s->dnn_handle_,
+                                        rnn_desc_,
+                                        y_data_desc_,
+                                        y.dptr_,
+                                        dy_data_desc_,
+                                        dy.dptr_,
+                                        nullptr,
+                                        nullptr,
+                                        dhy_desc_,
+                                        dhy_ptr,
+                                        dcy_desc_,
+                                        dcy_ptr,
+                                        w_desc_,
+                                        w.dptr_,
+                                        hx_desc_,
+                                        hx.dptr_,
+                                        cx_desc_,
+                                        cx_ptr,
+                                        dx_data_desc_,
+                                        dx.dptr_,
+                                        dhx_desc_,
+                                        dhx.dptr_,
+                                        dcx_desc_,
+                                        dcx_ptr,
+                                        nullptr,
+                                        nullptr,
+                                        temp_space.dptr_,
+                                        workspace_byte_,
+                                        reserve_space_.dptr,
+                                        reserve_space_byte_));
+      CUDNN_CALL(cudnnRNNBackwardWeightsEx(s->dnn_handle_,
+                                           rnn_desc_,
+                                           x_data_desc_,
+                                           x.dptr_,
+                                           hx_desc_,
+                                           hx.dptr_,
+                                           y_data_desc_,
+                                           y.dptr_,
+                                           temp_space.dptr_,
+                                           workspace_byte_,
+                                           dw_desc_,
+                                           dw.dptr_,
+                                           reserve_space_.dptr,
+                                           reserve_space_byte_));
+    }
+    #else
+    }
+    #endif
   }
 
  private:
@@ -405,6 +588,12 @@ class CuDNNRNNOp : public Operator{
       y_desc_vec_ = y_vec;
       dx_desc_vec_ = dx_vec;
       dy_desc_vec_ = dy_vec;
+      #if (CUDNN_MAJOR == 7 && CUDNN_MINOR >= 2) || CUDNN_MAJOR > 7
+      CUDNN_CALL(cudnnCreateRNNDataDescriptor(&x_data_desc_));
+      CUDNN_CALL(cudnnCreateRNNDataDescriptor(&y_data_desc_));
+      CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dx_data_desc_));
+      CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dy_data_desc_));
+      #endif
 
       // set the state tensors
       dimA[0] = param_.num_layers * (param_.bidirectional ? 2 : 1);
@@ -413,12 +602,28 @@ class CuDNNRNNOp : public Operator{
       strideA[0] = dimA[2] * dimA[1];
       strideA[1] = dimA[2];
       strideA[2] = 1;
+      #if (CUDNN_MAJOR == 7 && CUDNN_MINOR >= 2) || CUDNN_MAJOR > 7
+      int dimB[3];
+      int strideB[3];
+      dimB[0] = param_.num_layers * (param_.bidirectional ? 2 : 1);
+      dimB[1] = param_.batch_size_;
+      dimB[2] = param_.projection_size.has_value() ?
+                param_.projection_size.value() : param_.state_size;
+      strideB[0] = dimB[2] * dimB[1];
+      strideB[1] = dimB[2];
+      strideB[2] = 1;
+      #endif
 
       CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_,
                                             dtype_,
                                             3,
+      #if (CUDNN_MAJOR == 7 && CUDNN_MINOR >= 2) || CUDNN_MAJOR > 7
+                                            dimB,
+                                            strideB));
+      #else
                                             dimA,
                                             strideA));
+      #endif
       CUDNN_CALL(cudnnSetTensorNdDescriptor(cx_desc_,
                                             dtype_,
                                             3,
@@ -427,8 +632,13 @@ class CuDNNRNNOp : public Operator{
       CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_,
                                             dtype_,
                                             3,
+      #if (CUDNN_MAJOR == 7 && CUDNN_MINOR >= 2) || CUDNN_MAJOR > 7
+                                            dimB,
+                                            strideB));
+      #else
                                             dimA,
                                             strideA));
+      #endif
       CUDNN_CALL(cudnnSetTensorNdDescriptor(cy_desc_,
                                             dtype_,
                                             3,
@@ -437,8 +647,13 @@ class CuDNNRNNOp : public Operator{
       CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_,
                                             dtype_,
                                             3,
+      #if (CUDNN_MAJOR == 7 && CUDNN_MINOR >= 2) || CUDNN_MAJOR > 7
+                                            dimB,
+                                            strideB));
+      #else
                                             dimA,
                                             strideA));
+      #endif
       CUDNN_CALL(cudnnSetTensorNdDescriptor(dcx_desc_,
                                             dtype_,
                                             3,
@@ -447,8 +662,13 @@ class CuDNNRNNOp : public Operator{
       CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_,
                                             dtype_,
                                             3,
+      #if (CUDNN_MAJOR == 7 && CUDNN_MINOR >= 2) || CUDNN_MAJOR > 7
+                                            dimB,
+                                            strideB));
+      #else
                                             dimA,
                                             strideA));
+      #endif
       CUDNN_CALL(cudnnSetTensorNdDescriptor(dcy_desc_,
                                             dtype_,
                                             3,
@@ -498,6 +718,14 @@ class CuDNNRNNOp : public Operator{
         }
         CUDNN_CALL(cudnnSetRNNMatrixMathType(rnn_desc_, math_type));
       #endif
+      #if (CUDNN_MAJOR == 7 && CUDNN_MINOR >= 2) || CUDNN_MAJOR > 7
+      if (param_.projection_size.has_value()) {
+        CUDNN_CALL(cudnnSetRNNProjectionLayers(s->dnn_handle_,
+                                               rnn_desc_,
+                                               param_.projection_size.value(),
+                                               0));
+      }
+      #endif
       // Get temp space sizes
       CUDNN_CALL(cudnnGetRNNWorkspaceSize(s->dnn_handle_,
                                           rnn_desc_,
@@ -586,6 +814,9 @@ class CuDNNRNNOp : public Operator{
   size_t workspace_byte_, reserve_space_byte_, dropout_byte_;
   int workspace_size_, dropout_size_;
   std::vector<cudnnTensorDescriptor_t> x_desc_vec_, y_desc_vec_, dx_desc_vec_, dy_desc_vec_;
+  #if (CUDNN_MAJOR == 7 && CUDNN_MINOR >= 2) || CUDNN_MAJOR > 7
+  cudnnRNNDataDescriptor_t x_data_desc_, y_data_desc_, dx_data_desc_, dy_data_desc_;
+  #endif
   cudnnTensorDescriptor_t hx_desc_, cx_desc_;
   cudnnTensorDescriptor_t hy_desc_, cy_desc_;
   cudnnTensorDescriptor_t dhx_desc_, dcx_desc_;
diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index 9211f6a456f..247ff68a2ee 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -54,7 +54,8 @@ inline int GetRnnParamSize(int num_layer,
                            int input_size,
                            int state_size,
                            int direction,
-                           int mode) {
+                           int mode,
+                           const dmlc::optional<int>& projection_size) {
   int size = state_size * direction;
   switch (mode) {
     case rnn_enum::kRnnRelu:
@@ -69,7 +70,15 @@ inline int GetRnnParamSize(int num_layer,
   }
   int size1 = (input_size + state_size + 2) * size;  // first layer size
   int size2 = (state_size * direction + state_size + 2) * size;  // other layers size
+  if (projection_size.has_value()) {
+    int proj_size = projection_size.value();
+    size1 = (input_size + proj_size + 2) * size;
+    size2 = (proj_size * direction + proj_size + 2) * size;
+  }
   int param_size = size1 + (num_layer - 1) * size2;
+  if (projection_size.has_value()) {
+    param_size += projection_size.value() * state_size * num_layer * direction;
+  }
   return param_size;
 }
 
@@ -154,6 +163,7 @@ struct RNNParam : public dmlc::Parameter<RNNParam> {
   float p, pkeep_;
   int seq_length_, batch_size_, input_size_;
   bool lstm_q_;  // whether type is lstm
+  dmlc::optional<int> projection_size;
 
   DMLC_DECLARE_PARAMETER(RNNParam) {
     DMLC_DECLARE_FIELD(state_size)
@@ -178,6 +188,10 @@ struct RNNParam : public dmlc::Parameter<RNNParam> {
 
     DMLC_DECLARE_FIELD(state_outputs).set_default(false)
     .describe("Whether to have the states as symbol outputs.");
+
+    DMLC_DECLARE_FIELD(projection_size)
+    .set_default(dmlc::optional<int>())
+    .describe("size of project size");
   }
 };
 
@@ -349,8 +363,11 @@ template<typename DType>
 class RNNOp : public Operator{
  public:
   explicit RNNOp(RNNParam p)
-    :param_(p), init_space_(false), reserve_space_size_(0)
-  {}
+    :param_(p), init_space_(false), reserve_space_size_(0) {
+    if (param_.projection_size.has_value()) {
+      LOG(FATAL) << "hidden layer projection is only supported for GPU with CuDNN later than 7.1.1";
+    }
+  }
 
   ~RNNOp() {
     if (init_space_) {
@@ -646,9 +663,11 @@ class RNNProp : public OperatorProperty {
     int input_size = dshape[2];
     int numDirections = param_.bidirectional ? 2 : 1;
     int total_layers = numDirections * param_.num_layers;  // double for bidirectional
+    int layer_size = (param_.projection_size.has_value()) ?
+                     param_.projection_size.value() : param_.state_size;
     SHAPE_ASSIGN_CHECK(*in_shape,
                        rnn_enum::kState,
-                       Shape3(total_layers, batch_size, param_.state_size));
+                       Shape3(total_layers, batch_size, layer_size));
     if (param_.mode == rnn_enum::kLstm)
       SHAPE_ASSIGN_CHECK(*in_shape,
                         rnn_enum::kStateCell,
@@ -659,13 +678,18 @@ class RNNProp : public OperatorProperty {
                                     input_size,
                                     param_.state_size,
                                     numDirections,
-                                    param_.mode);
+                                    param_.mode,
+                                    param_.projection_size);
     SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kParams, Shape1(param_size));
 
     out_shape->clear();
     // output: [sequence len, batch, output size]
     TShape oshape = dshape;
-    oshape[2] = numDirections * param_.state_size;
+    if (param_.projection_size.has_value()) {
+      oshape[2] = numDirections * param_.projection_size.value();
+    } else {
+      oshape[2] = numDirections * param_.state_size;
+    }
     out_shape->push_back(oshape);
     if (!param_.state_outputs) {
       return true;
@@ -674,7 +698,11 @@ class RNNProp : public OperatorProperty {
       TShape outStateShape = dshape;
       outStateShape[0] = total_layers;
       outStateShape[1] = batch_size;
-      outStateShape[2] = param_.state_size;
+      if (param_.projection_size.has_value()) {
+        outStateShape[2] = param_.projection_size.value();
+      } else {
+        outStateShape[2] = param_.state_size;
+      }
       out_shape->push_back(outStateShape);
       // Deal with lstm cell state
       if (param_.mode == rnn_enum::kLstm)
diff --git a/src/operator/rnn.cu b/src/operator/rnn.cu
index 59517932b78..402a8cf5f50 100644
--- a/src/operator/rnn.cu
+++ b/src/operator/rnn.cu
@@ -40,7 +40,7 @@ Operator* CreateOp<gpu>(RNNParam param, int dtype) {
     op = new CuDNNRNNOp<DType>(param);
   })
 #else
-  LOG(FATAL) << "RNN is only available for cuDNN at the moment.";
+  LOG(FATAL) << "RNN on GPU is only available for cuDNN at the moment.";
 #endif  // MXNET_USE_CUDNN && CUDNN_MAJOR
   return op;
 }
diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index 8394276c8ef..9431d07788b 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -22,6 +22,7 @@
 import time
 import multiprocessing as mp
 import unittest
+import random
 import mxnet as mx
 import numpy as np
 import unittest
@@ -30,6 +31,7 @@
 from mxnet.base import MXNetError
 from mxnet import autograd
 from numpy.testing import assert_allclose
+from mxnet.test_utils import rand_ndarray
 
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
 sys.path.insert(0, os.path.join(curr_path, '../unittest'))
@@ -76,16 +78,63 @@ def check_rnn_layer_w_rand_inputs(layer):
         assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6)
 
 
+def check_lstmp(hidden_size, projection_size, rtol=1e-3, atol=1e-6):
+    batch_size = 1
+    seq_len = 1
+    input_size = random.randint(100, 300)
+    in_data = {'rnn_t%d_data' % i: rand_ndarray((batch_size, input_size), stype='default') for i in range(seq_len)}
+    i2h_weight = rand_ndarray((hidden_size*4, input_size), stype='default').as_in_context(mx.gpu(0))
+    i2h_bias = rand_ndarray((hidden_size*4, ), stype='default').as_in_context(mx.gpu(0))
+    h2h_weight = rand_ndarray((hidden_size*4, projection_size), stype='default').as_in_context(mx.gpu(0))
+    h2h_bias = rand_ndarray((hidden_size*4, ), stype='default').as_in_context(mx.gpu(0))
+    h2p_weight = rand_ndarray((projection_size, hidden_size), stype='default').as_in_context(mx.gpu(0))
+    lstm_layer = gluon.rnn.LSTM(hidden_size, projection_size=projection_size, input_size=input_size, prefix='lstm0_')
+    lstm_layer.collect_params().initialize(ctx=mx.gpu(0))
+    cudnn_input = mx.nd.stack(*(in_data.values()))
+    _ = lstm_layer(cudnn_input)
+    for name, param in lstm_layer.collect_params().items():
+        if "i2h_weight" in name:
+            param.set_data(i2h_weight)
+        if "i2h_bias" in name:
+            param.set_data(i2h_bias)
+        if "h2h_weight" in name:
+            param.set_data(h2h_weight)
+        if "h2h_bias" in name:
+            param.set_data(h2h_bias)
+        if "h2p_weight" in name:
+            param.set_data(h2p_weight)
+    cudnn_output = lstm_layer(cudnn_input).asnumpy()
+    syms = [mx.sym.Variable("rnn_t%d_data" % i) for i  in range(seq_len)]
+    cell = gluon.contrib.rnn.LSTMPCell(hidden_size=hidden_size, projection_size=projection_size, prefix='self_')
+    syms, _ = cell.unroll(seq_len, syms)
+    sym = mx.sym.Group(syms)
+    args = in_data.copy()
+    args['self_i2h_weight'] = i2h_weight
+    args['self_h2h_weight'] = h2h_weight
+    args['self_i2h_bias'] = i2h_bias
+    args['self_h2h_bias'] = h2h_bias
+    args['self_h2r_weight'] = h2p_weight
+    exe = sym.bind(mx.gpu(0), args)
+    exe.forward()
+    outputs = list(map(lambda x: x.asnumpy(), exe.outputs))
+    self_outputs = np.stack(outputs)
+    assert_almost_equal(self_outputs, cudnn_output, rtol=rtol, atol=atol)
+
+
 @with_seed()
 @assert_raises_cudnn_disabled()
 def test_rnn_layer():
-    check_rnn_layer(gluon.rnn.RNN(100, num_layers=3))
-    check_rnn_layer(gluon.rnn.RNN(100, activation='tanh', num_layers=3))
-    check_rnn_layer(gluon.rnn.LSTM(100, num_layers=3))
-    check_rnn_layer(gluon.rnn.GRU(100, num_layers=3))
-
-    check_rnn_layer(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True))
-    check_rnn_layer_w_rand_inputs(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True))
+    # check_rnn_layer(gluon.rnn.RNN(100, num_layers=3))
+    # check_rnn_layer(gluon.rnn.RNN(100, activation='tanh', num_layers=3))
+    # check_rnn_layer(gluon.rnn.LSTM(100, num_layers=3))
+    # check_rnn_layer(gluon.rnn.GRU(100, num_layers=3))
+    #
+    # check_rnn_layer(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True))
+    # check_rnn_layer_w_rand_inputs(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True))
+
+    check_lstmp(500, 100, rtol=1e-3, atol=1e-6)
+    # check_rnn_layer_forward(gluon.rnn.LSTM(7, projection_size=5), mx.nd.ones((2, 3, 7)), run_only=False)
+    # check_rnn_layer_forward(gluon.rnn.LSTM(7, 2, projection_size=5), mx.nd.ones((2, 3, 7)), run_only=False)
 
 
 @with_seed()
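
For reference, the parameter shapes registered above in rnn_layer.py
(i2h_weight of shape (4*hidden, input), h2h_weight of shape (4*hidden, proj),
h2p_weight of shape (proj, hidden)) correspond to the usual
LSTM-with-projection recurrence, where the hidden state is multiplied by an
extra projection matrix before being fed back. A minimal NumPy sketch of a
single step under those shape assumptions (the i, f, g, o gate order is an
assumption for illustration, not taken from this diff):

    import numpy as np

    def sigmoid(v):
        return 1.0 / (1.0 + np.exp(-v))

    def lstmp_step(x, r_prev, c_prev, i2h_w, i2h_b, h2h_w, h2h_b, h2p_w):
        # x: (input,), r_prev: (proj,), c_prev: (hidden,)
        gates = i2h_w.dot(x) + i2h_b + h2h_w.dot(r_prev) + h2h_b  # (4*hidden,)
        i, f, g, o = np.split(gates, 4)
        i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)
        c = f * c_prev + i * np.tanh(g)   # full-size cell state, shape (hidden,)
        h = o * np.tanh(c)                # full-size hidden state, shape (hidden,)
        r = h2p_w.dot(h)                  # projected state fed back, shape (proj,)
        return r, c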


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services