You are viewing a plain-text version of this content; the hyperlink to its canonical version ("here") was lost in this copy.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/03/26 13:53:01 UTC

[GitHub] [incubator-mxnet] lihaofd commented on a change in pull request #14476: Change RNN OP to stateful

lihaofd commented on a change in pull request #14476: Change RNN OP to stateful
URL: https://github.com/apache/incubator-mxnet/pull/14476#discussion_r269112331
 
 

 ##########
 File path: src/operator/rnn-inl.h
 ##########
 @@ -436,387 +566,897 @@ class RNNOp : public Operator{
     if (param_.state_outputs) {
       hy_ptr = out_data[rnn_enum::kStateOut].dptr<DType>();
     }
-    DType* cx_ptr = NULL;
-    DType* cy_ptr = NULL;
+    DType * cx_ptr = NULL;
+    DType * cy_ptr = NULL;
+    if (param_.mode == rnn_enum::kLstm)
+      cx_ptr = (in_data[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
+    if (param_.mode == rnn_enum::kLstm && param_.state_outputs)
+      cy_ptr = (out_data[rnn_enum::kStateCellOut].get<xpu, 3, DType>(s)).dptr_;
 
-    if (param_.mode == rnn_enum::kLstm) {
-      cx_ptr = in_data[rnn_enum::kStateCell].dptr<DType>();
-      if (param_.state_outputs) {
-        cy_ptr = out_data[rnn_enum::kStateCellOut].dptr<DType>();
-      }
-    }
+    CHECK_EQ(x.CheckContiguous(), true);
+    CHECK_EQ(w.CheckContiguous(), true);
+    CHECK_EQ(hx.CheckContiguous(), true);
+    CHECK_EQ(y.CheckContiguous(), true);
 
     // allocate temp space
-    const size_t workspace_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+    const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
                                                       param_.state_size, direction, param_.mode);
-    Tensor<cpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
-        .get_space_typed<cpu, 1, DType>(Shape1(workspace_size), s);
+    DType* work_cpu_space = NULL;
+    #if MXNET_USE_CUDNN_RNN && defined(__CUDACC__)
+    if (!init_cudnn_) {
+      Init(s, in_data, out_data);
+    }
+    // Get temp space
+    int temp_size = workspace_size_;
+    Tensor<xpu, 1, DType> temp_space =
+      ctx.requested[rnn_enum::kTempSpace].get_space_typed<xpu, 1, DType>(
+                              mshadow::Shape1(temp_size + work_cpu_space_size), s);
+
+    work_cpu_space = temp_space.dptr_ + temp_size;
+
+    #if USE_CUDNN_LSTM_PROJ
+    std::vector<int> seqLengthArray(param_.batch_size_, param_.seq_length_);
+    CUDNN_CALL(cudnnSetRNNDataDescriptor(x_data_desc_,
+                                         dtype_,
+                                         CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                         param_.seq_length_,
+                                         param_.batch_size_,
+                                         param_.input_size_,
+                                         seqLengthArray.data(),
+                                         nullptr));
+    int out_size =
+      (param_.projection_size.has_value()) ? param_.projection_size.value() : param_.state_size;
+    out_size = (param_.bidirectional) ? (out_size * 2) : out_size;
+    CUDNN_CALL(cudnnSetRNNDataDescriptor(y_data_desc_,
+                                         dtype_,
+                                         CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                         param_.seq_length_,
+                                         param_.batch_size_,
+                                         out_size,
+                                         seqLengthArray.data(),
+                                         nullptr));
+    if (ctx.is_train) {
+      CUDNN_CALL(cudnnSetRNNDataDescriptor(dx_data_desc_,
+                                           dtype_,
+                                           CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                           param_.seq_length_,
+                                           param_.batch_size_,
+                                           param_.input_size_,
+                                           seqLengthArray.data(),
+                                           nullptr));
+      CUDNN_CALL(cudnnSetRNNDataDescriptor(dy_data_desc_,
+                                           dtype_,
+                                           CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                           param_.seq_length_,
+                                           param_.batch_size_,
+                                           out_size,
+                                           seqLengthArray.data(),
+                                           nullptr));
+    }
+    #endif
+
+    #if USE_CUDNN_LSTM_PROJ
+    bool clip_state = param_.lstm_state_clip_min.has_value();
+    bool clip_nan = param_.lstm_state_clip_nan;
+    CUDNN_CALL(cudnnRNNSetClip(s->dnn_handle_,
+                               rnn_desc_,
+                               clip_state ? CUDNN_RNN_CLIP_MINMAX : CUDNN_RNN_CLIP_NONE,
+                               clip_nan ? CUDNN_NOT_PROPAGATE_NAN : CUDNN_PROPAGATE_NAN,
+                               clip_state ? param_.lstm_state_clip_min.value() : 0.0,
+                               clip_state ? param_.lstm_state_clip_max.value() : 0.0));
+    #endif
 
     if (ctx.is_train) {
-      const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
-                                                   param_.seq_length_, param_.batch_size_,
-                                                   param_.state_size, param_.mode);
-      if (init_space_ && reserve_space_size_ < r_size) {
-        Storage::Get()->Free(reserve_space_);
-        init_space_ = false;
-      }
+      #if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnRNNForwardTrainingEx(s->dnn_handle_,
+                                           rnn_desc_,
+                                           x_data_desc_,
+                                           x.dptr_,
+                                           hx_desc_,
+                                           hx.dptr_,
+                                           cx_desc_,
+                                           cx_ptr,
+                                           w_desc_,
+                                           w.dptr_,
+                                           y_data_desc_,
+                                           y.dptr_,
+                                           hy_desc_,
+                                           hy_ptr,
+                                           cy_desc_,
+                                           cy_ptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           temp_space.dptr_,
+                                           workspace_byte_,
+                                           reserve_space_.dptr,
+                                           reserve_space_byte_));
+      #else
+      CUDNN_CALL(cudnnRNNForwardTraining(s->dnn_handle_,
+                                         rnn_desc_,
+                                         param_.seq_length_,
+                                         x_desc_vec_.data(),
+                                         x.dptr_,
+                                         hx_desc_,
+                                         hx.dptr_,
+                                         cx_desc_,
+                                         cx_ptr,
+                                         w_desc_,
+                                         w.dptr_,
+                                         y_desc_vec_.data(),
+                                         y.dptr_,
+                                         hy_desc_,
+                                         hy_ptr,
+                                         cy_desc_,
+                                         cy_ptr,
+                                         temp_space.dptr_,
+                                         workspace_byte_,
+                                         reserve_space_.dptr,
+                                         reserve_space_byte_));
+      #endif
+    } else {
+      #if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnRNNForwardInferenceEx(s->dnn_handle_,
+                                            rnn_desc_,
+                                            x_data_desc_,
+                                            x.dptr_,
+                                            hx_desc_,
+                                            hx.dptr_,
+                                            cx_desc_,
+                                            cx_ptr,
+                                            w_desc_,
+                                            w.dptr_,
+                                            y_data_desc_,
+                                            y.dptr_,
+                                            hy_desc_,
+                                            hy_ptr,
+                                            cy_desc_,
+                                            cy_ptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            temp_space.dptr_,
+                                            workspace_byte_));
+      #else
+      CUDNN_CALL(cudnnRNNForwardInference(s->dnn_handle_,
+                                          rnn_desc_,
+                                          param_.seq_length_,
+                                          x_desc_vec_.data(),
+                                          x.dptr_,
+                                          hx_desc_,
+                                          hx.dptr_,
+                                          cx_desc_,
+                                          cx_ptr,
+                                          w_desc_,
+                                          w.dptr_,
+                                          y_desc_vec_.data(),
+                                          y.dptr_,
+                                          hy_desc_,
+                                          hy_ptr,
+                                          cy_desc_,
+                                          cy_ptr,
+                                          temp_space.dptr_,
+                                          workspace_byte_));
+      #endif
+    }
+    #endif
 
-      if (!init_space_) {
-        reserve_space_ = Storage::Get()->Alloc(r_size * sizeof(DType), Context::CPU());
-        reserve_space_size_ = r_size;
-        init_space_ = true;
+    if (ctx_.dev_type == kCPU) {
+      if (!work_cpu_space) {
+        Tensor<xpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
+          .get_space_typed<xpu, 1, DType>(Shape1(work_cpu_space_size), s);
+        work_cpu_space = workspace.dptr_;
+      }
+      if (ctx.is_train) {
+        const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
+                                                     param_.seq_length_, param_.batch_size_,
+                                                     param_.state_size, param_.mode);
+        if (init_space_ && reserve_cpu_space_size_ < r_size) {
+          Storage::Get()->Free(reserve_cpu_space_);
+          init_space_ = false;
+        }
+        if (!init_space_) {
+          reserve_cpu_space_ = Storage::Get()->Alloc(r_size * sizeof(DType), Context::CPU());
+          reserve_cpu_space_size_ = r_size;
+          init_space_ = true;
+        }
+
+        DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.dptr);
+
+        RNNForwardTraining<DType>(work_cpu_space,
+                                  reserve_space_ptr,
+                                  param_.state_outputs,
+                                  param_.num_layers,
+                                  direction,
+                                  param_.seq_length_,
+                                  param_.batch_size_,
+                                  param_.input_size_,
+                                  param_.state_size,
+                                  x.dptr_,
+                                  hx.dptr_,
+                                  cx_ptr,
+                                  w.dptr_,
+                                  b_ptr,
+                                  y.dptr_,
+                                  hy_ptr,
+                                  cy_ptr,
+                                  param_.p,
+                                  param_.mode);
+      } else {
+        RNNForwardInference<DType>(work_cpu_space,
+                                   param_.state_outputs,
+                                   param_.num_layers,
+                                   direction,
+                                   param_.seq_length_,
+                                   param_.batch_size_,
+                                   param_.input_size_,
+                                   param_.state_size,
+                                   x.dptr_,
+                                   hx.dptr_,
+                                   cx_ptr,
+                                   w.dptr_,
+                                   b_ptr,
+                                   y.dptr_,
+                                   hy_ptr,
+                                   cy_ptr,
+                                   param_.mode);
       }
-
-      DType* reserve_space_ptr = static_cast<DType*>(reserve_space_.dptr);
-      RNNForwardTraining<DType>(workspace.dptr_,
-                                reserve_space_ptr,
-                                param_.state_outputs,
-                                param_.num_layers,
-                                direction,
-                                param_.seq_length_,
-                                param_.batch_size_,
-                                param_.input_size_,
-                                param_.state_size,
-                                x.dptr_,
-                                hx.dptr_,
-                                cx_ptr,
-                                w.dptr_,
-                                b_ptr,
-                                y.dptr_,
-                                hy_ptr,
-                                cy_ptr,
-                                param_.p,
-                                param_.mode);
-    } else {
-      RNNForwardInference<DType>(workspace.dptr_,
-                                 param_.state_outputs,
-                                 param_.num_layers,
-                                 direction,
-                                 param_.seq_length_,
-                                 param_.batch_size_,
-                                 param_.input_size_,
-                                 param_.state_size,
-                                 x.dptr_,
-                                 hx.dptr_,
-                                 cx_ptr,
-                                 w.dptr_,
-                                 b_ptr,
-                                 y.dptr_,
-                                 hy_ptr,
-                                 cy_ptr,
-                                 param_.mode);
     }
   }
 
-  virtual void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_args) {
+  void Backward(const OpContext &ctx,
+                const std::vector<TBlob> &out_grad,
+                const std::vector<TBlob> &in_data,
+                const std::vector<TBlob> &out_data,
+                const std::vector<OpReqType> &req,
+                const std::vector<TBlob> &in_grad) {
     using namespace mshadow;
     using namespace mshadow::expr;
     CHECK(param_.p >= 0.0f && param_.p < 1.0f)
         << "unsupported dropout value, should be 0 <= dropout < 1";
 
-    size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
-    size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
-    if (!param_.state_outputs) {
-      out_expected = 1;
+    size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
+    // Only kOut is produced unless state_outputs is set.
+    size_t num_outputs = 1;
+    if (param_.state_outputs) {
+      // kOut and kStateOut, plus kStateCellOut for LSTM.
+      num_outputs = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
     }
-    CHECK_EQ(in_data.size(), in_expected);
-    CHECK_EQ(out_data.size(), out_expected);
-    CHECK_EQ(in_grad.size(), in_expected);
-    CHECK_EQ(out_grad.size(), out_expected);
-    CHECK_EQ(req.size(), in_expected);
+
+    CHECK_EQ(in_data.size(), num_inputs);
+    CHECK_EQ(out_data.size(), num_outputs);
+    CHECK_EQ(in_grad.size(), num_inputs);
+    CHECK_EQ(out_grad.size(), num_outputs);
+    CHECK_EQ(req.size(), num_inputs);
     CHECK_NE(req[rnn_enum::kData], kAddTo) << "AddTo is not supported for data";
     CHECK_NE(req[rnn_enum::kState], kAddTo) << "AddTo is not supported for state";
-    mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+    Stream<xpu> *s = ctx.get_stream<xpu>();
     // get input + output tensors
-    Tensor<cpu, 3, DType> x = in_data[rnn_enum::kData].get<cpu, 3, DType>(s);
-    Tensor<cpu, 1, DType> w = in_data[rnn_enum::kParams].get<cpu, 1, DType>(s);
-    Tensor<cpu, 3, DType> hx = in_data[rnn_enum::kState].get<cpu, 3, DType>(s);
-    Tensor<cpu, 3, DType> y = out_data[rnn_enum::kOut].get<cpu, 3, DType>(s);
-    Tensor<cpu, 3, DType> dx = in_grad[rnn_enum::kData].get<cpu, 3, DType>(s);
-    Tensor<cpu, 1, DType> dw = in_grad[rnn_enum::kParams].get<cpu, 1, DType>(s);
-    Tensor<cpu, 3, DType> dhx = in_grad[rnn_enum::kState].get<cpu, 3, DType>(s);
-    Tensor<cpu, 3, DType> dy = out_grad[rnn_enum::kOut].get<cpu, 3, DType>(s);
-    CHECK(x.CheckContiguous());
-    CHECK(w.CheckContiguous());
-    CHECK(hx.CheckContiguous());
-    CHECK(y.CheckContiguous());
-    CHECK(dx.CheckContiguous());
-    CHECK(dw.CheckContiguous());
-    CHECK(dhx.CheckContiguous());
-    CHECK(dy.CheckContiguous());
+    Tensor<xpu, 3, DType> x = in_data[rnn_enum::kData].get<xpu, 3, DType>(s);
+    Tensor<xpu, 3, DType> dx = in_grad[rnn_enum::kData].get<xpu, 3, DType>(s);
+    Tensor<xpu, 1, DType> w = in_data[rnn_enum::kParams].get<xpu, 1, DType>(s);
+    Tensor<xpu, 1, DType> dw = in_grad[rnn_enum::kParams].get<xpu, 1, DType>(s);
+    Tensor<xpu, 3, DType> hx = in_data[rnn_enum::kState].get<xpu, 3, DType>(s);
+    Tensor<xpu, 3, DType> dhx = in_grad[rnn_enum::kState].get<xpu, 3, DType>(s);
+    Tensor<xpu, 3, DType> y = out_data[rnn_enum::kOut].get<xpu, 3, DType>(s);
+    Tensor<xpu, 3, DType> dy = out_grad[rnn_enum::kOut].get<xpu, 3, DType>(s);
+    CHECK_EQ(x.CheckContiguous(), true);
+    CHECK_EQ(dx.CheckContiguous(), true);
+    CHECK_EQ(w.CheckContiguous(), true);
+    CHECK_EQ(dw.CheckContiguous(), true);
+    CHECK_EQ(hx.CheckContiguous(), true);
+    CHECK_EQ(dhx.CheckContiguous(), true);
+    CHECK_EQ(y.CheckContiguous(), true);
+    CHECK_EQ(dy.CheckContiguous(), true);
+
+    if (req[rnn_enum::kParams] != kAddTo) {
+      dw = mshadow::expr::ScalarExp<DType>(0.0f);
+    }
+
     param_.seq_length_ = x.shape_[0];
     param_.batch_size_ = x.shape_[1];
     param_.input_size_ = x.shape_[2];
 
     const int direction = param_.bidirectional ? 2 : 1;
     const int bsize = GetRnnBiasSize(param_.num_layers, param_.state_size, direction, param_.mode);
+
     DType* db_ptr = dw.dptr_ + w.shape_[0] - bsize;
 
     DType * dhy_ptr = NULL;
     if (param_.state_outputs) {
       dhy_ptr = out_grad[rnn_enum::kStateOut].dptr<DType>();
     }
 
-    DType * cx_ptr = NULL;
-    DType * dcx_ptr = NULL;
-    DType * dcy_ptr = NULL;
+    DType* dcx_ptr = NULL;
+    DType* dcy_ptr = NULL;
+    DType* cx_ptr = NULL;
 
     if (param_.mode == rnn_enum::kLstm) {
       CHECK_NE(req[rnn_enum::kStateCell], kAddTo) << "AddTo is not supported for state cell";
-      cx_ptr = in_data[rnn_enum::kStateCell].dptr<DType>();
-      dcx_ptr = in_grad[rnn_enum::kStateCell].dptr<DType>();
-      if (param_.state_outputs) {
-        dcy_ptr = out_grad[rnn_enum::kStateCellOut].dptr<DType>();
-      }
+      cx_ptr = (in_data[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
+      dcx_ptr = (in_grad[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
     }
+    if ((param_.mode == rnn_enum::kLstm) && param_.state_outputs)
+        dcy_ptr = (out_grad[rnn_enum::kStateCellOut].get<xpu, 3, DType>(s)).dptr_;
 
     // allocate temp space
-    const size_t workspace_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
-                                                      param_.state_size, direction, param_.mode);
-    Tensor<cpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
-        .get_space_typed<cpu, 1, DType>(Shape1(workspace_size), s);
-
-    size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
-                                           param_.seq_length_, param_.batch_size_,
-                                           param_.state_size, param_.mode);
-    if (!init_space_ || reserve_space_size_ != r_size) {
-      LOG(FATAL) << "Check forward init error";
+    const size_t work_cpu_space_size =
+        GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+                            param_.state_size, direction, param_.mode);
+    DType* work_cpu_space = NULL;
+    #if MXNET_USE_CUDNN_RNN && defined(__CUDACC__)
+    if (!init_cudnn_) {
+      Init(s, in_data, out_data);
     }
 
-    DType* reserve_space_ptr = static_cast<DType*>(reserve_space_.dptr);
-    RNNBackward<DType>(workspace.dptr_,
-                       reserve_space_ptr,
-                       param_.num_layers,
-                       direction,
-                       param_.seq_length_,
-                       param_.batch_size_,
-                       param_.input_size_,
-                       param_.state_size,
-                       x.dptr_,
-                       hx.dptr_,
-                       cx_ptr,
-                       w.dptr_,
-                       y.dptr_,
-                       dy.dptr_,
-                       dhy_ptr,
-                       dcy_ptr,
-                       dx.dptr_,
-                       dhx.dptr_,
-                       dcx_ptr,
-                       dw.dptr_,
-                       db_ptr,
-                       req[rnn_enum::kData],
-                       req[rnn_enum::kParams],
-                       req[rnn_enum::kState],
-                       // State cell should be present for LSTMs, but is absent for other RNNs.
-                       param_.mode == rnn_enum::kLstm ? req[rnn_enum::kStateCell] : kNullOp,
-                       param_.p,
-                       param_.mode);
-  }
-
- private:
-  RNNParam param_;
-  bool init_space_;
-  size_t reserve_space_size_;
-  Storage::Handle reserve_space_;
-};  // class RNNOp
+    // One temp allocation: cuDNN workspace first, CPU workspace right after it.
+    int temp_size = workspace_size_;
+    Tensor<xpu, 1, DType> temp_space =
+      ctx.requested[rnn_enum::kTempSpace].get_space_typed<xpu, 1, DType>(
+                              mshadow::Shape1(temp_size + work_cpu_space_size), s);
+    work_cpu_space = temp_space.dptr_ + temp_size;
+    #if USE_CUDNN_LSTM_PROJ
+    CUDNN_CALL(cudnnRNNBackwardDataEx(s->dnn_handle_,
+                                      rnn_desc_,
+                                      y_data_desc_,
+                                      y.dptr_,
+                                      dy_data_desc_,
+                                      dy.dptr_,
+                                      nullptr,
+                                      nullptr,
+                                      dhy_desc_,
+                                      dhy_ptr,
+                                      dcy_desc_,
+                                      dcy_ptr,
+                                      w_desc_,
+                                      w.dptr_,
+                                      hx_desc_,
+                                      hx.dptr_,
+                                      cx_desc_,
+                                      cx_ptr,
+                                      dx_data_desc_,
+                                      dx.dptr_,
+                                      dhx_desc_,
+                                      dhx.dptr_,
+                                      dcx_desc_,
+                                      dcx_ptr,
+                                      nullptr,
+                                      nullptr,
+                                      temp_space.dptr_,
+                                      workspace_byte_,
+                                      reserve_space_.dptr,
+                                      reserve_space_byte_));
+    CUDNN_CALL(cudnnRNNBackwardWeightsEx(s->dnn_handle_,
+                                         rnn_desc_,
+                                         x_data_desc_,
+                                         x.dptr_,
+                                         hx_desc_,
+                                         hx.dptr_,
+                                         y_data_desc_,
+                                         y.dptr_,
+                                         temp_space.dptr_,
+                                         workspace_byte_,
+                                         dw_desc_,
+                                         dw.dptr_,
+                                         reserve_space_.dptr,
+                                         reserve_space_byte_));
+    #else
+    CUDNN_CALL(cudnnRNNBackwardData(s->dnn_handle_,
+                                    rnn_desc_,
+                                    param_.seq_length_,
+                                    y_desc_vec_.data(),
+                                    y.dptr_,
+                                    dy_desc_vec_.data(),
+                                    dy.dptr_,
+                                    dhy_desc_,
+                                    dhy_ptr,
+                                    dcy_desc_,
+                                    dcy_ptr,
+                                    w_desc_,
+                                    w.dptr_,
+                                    hx_desc_,
+                                    hx.dptr_,
+                                    cx_desc_,
+                                    cx_ptr,
+                                    dx_desc_vec_.data(),
+                                    dx.dptr_,
+                                    dhx_desc_,
+                                    dhx.dptr_,
+                                    dcx_desc_,
+                                    dcx_ptr,
+                                    temp_space.dptr_,
+                                    workspace_byte_,
+                                    reserve_space_.dptr,
+                                    reserve_space_byte_));
+    CUDNN_CALL(cudnnRNNBackwardWeights(s->dnn_handle_,
+                                       rnn_desc_,
+                                       param_.seq_length_,
+                                       x_desc_vec_.data(),
+                                       x.dptr_,
+                                       hx_desc_,
+                                       hx.dptr_,
+                                       y_desc_vec_.data(),
+                                       y.dptr_,
+                                       temp_space.dptr_,
+                                       workspace_byte_,
+                                       dw_desc_,
+                                       dw.dptr_,
+                                       reserve_space_.dptr,
+                                       reserve_space_byte_));
+    #endif
+    #endif
+
+    if (ctx_.dev_type == kCPU) {
+      if (!work_cpu_space) {
+        Tensor<xpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
+          .get_space_typed<xpu, 1, DType>(Shape1(work_cpu_space_size), s);
+        work_cpu_space = workspace.dptr_;
+      }
+      size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
+                                             param_.seq_length_, param_.batch_size_,
+                                             param_.state_size, param_.mode);
 
-template<typename xpu>
-Operator* CreateOp(RNNParam param, int dtype);
+      if (!init_space_ || reserve_cpu_space_size_ != r_size) {
+        LOG(FATAL) << "Check forward init error";
+      }
 
-#if DMLC_USE_CXX11
-class RNNProp : public OperatorProperty {
- public:
-  std::vector<std::string> ListArguments() const override {
-    if (param_.mode == rnn_enum::kLstm) {
-      return {"data", "parameters", "state", "state_cell"};
-    } else {
-      return {"data", "parameters", "state"};
+      DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.dptr);
+      RNNBackward<DType>(work_cpu_space,
+                         reserve_space_ptr,
+                         param_.num_layers,
+                         direction,
+                         param_.seq_length_,
+                         param_.batch_size_,
+                         param_.input_size_,
+                         param_.state_size,
+                         x.dptr_,
+                         hx.dptr_,
+                         cx_ptr,
+                         w.dptr_,
+                         y.dptr_,
+                         dy.dptr_,
+                         dhy_ptr,
+                         dcy_ptr,
+                         dx.dptr_,
+                         dhx.dptr_,
+                         dcx_ptr,
+                         dw.dptr_,
+                         db_ptr,
+                         req[rnn_enum::kData],
+                         req[rnn_enum::kParams],
+                         req[rnn_enum::kState],
+                         // State cell should be present for LSTMs, but is absent for other RNNs.
+                         param_.mode == rnn_enum::kLstm ? req[rnn_enum::kStateCell] : kNullOp,
+                         param_.p,
+                         param_.mode);
     }
   }
 
-  std::vector<std::string> ListOutputs() const override {
-    std::vector<std::string> outputs = {"output"};
-    if (!param_.state_outputs)
-      return outputs;
-    else
-      outputs.emplace_back("state");
-    if (param_.mode == rnn_enum::kLstm)
-      outputs.emplace_back("state_cell");
-    return outputs;
-  }
-
-  int NumOutputs() const override {
-    int mode_num = (param_.mode == rnn_enum::kLstm) ? 2 : 1;
-    int num_outputs = param_.state_outputs ? (mode_num + 1) : 1;
-    return num_outputs;
-  }
-
-  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
-    param_.Init(kwargs);
-  }
 
-  std::map<std::string, std::string> GetParams() const override {
-    return param_.__DICT__();
-  }
-
-  bool InferShape(mxnet::ShapeVector *in_shape,
-                  mxnet::ShapeVector *out_shape,
-                  mxnet::ShapeVector *aux_shape) const override {
+ private:
+  inline void Init(mshadow::Stream<xpu> *s,
+                   const std::vector<TBlob> &in_data,
+                   const std::vector<TBlob> &out_data) {
     using namespace mshadow;
-    if (param_.mode == rnn_enum::kLstm) {
-      CHECK_EQ(in_shape->size(), 4U) << "Input:[data, parameters, state, cell_state]";
-    } else {
-      CHECK_EQ(in_shape->size(), 3U) << "Input:[data, parameters, state]";
-    }
-    const mxnet::TShape &dshape = (*in_shape)[rnn_enum::kData];
-    if (dshape.ndim() ==  0) return false;
-    CHECK_EQ(dshape.ndim(), 3U) \
-        << "Input data should be rank-3 tensor of dim [sequence length, batch size, input size]";
-    // data: [sequence len, batch, input dimension]
-    int batch_size = dshape[1];
-    int input_size = dshape[2];
-    int numDirections = param_.bidirectional ? 2 : 1;
-    int total_layers = numDirections * param_.num_layers;  // double for bidirectional
-    int layer_size = (param_.projection_size.has_value()) ?
-                     param_.projection_size.value() : param_.state_size;
-    SHAPE_ASSIGN_CHECK(*in_shape,
-                       rnn_enum::kState,
-                       Shape3(total_layers, batch_size, layer_size));
-    if (param_.mode == rnn_enum::kLstm)
-      SHAPE_ASSIGN_CHECK(*in_shape,
-                         rnn_enum::kStateCell,
-                         Shape3(total_layers, batch_size, param_.state_size));
-
-    // calculate parameter vector length
-    int param_size = GetRnnParamSize(param_.num_layers,
-                                     input_size,
-                                     param_.state_size,
-                                     numDirections,
-                                     param_.mode,
-                                     param_.projection_size);
-    SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kParams, Shape1(param_size));
-
-    out_shape->clear();
-    // output: [sequence len, batch, output size]
-    mxnet::TShape oshape = dshape;
-    if (param_.projection_size.has_value()) {
-      oshape[2] = numDirections * param_.projection_size.value();
-    } else {
-      oshape[2] = numDirections * param_.state_size;
+    size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
+    //  kOut
+    size_t num_outputs = 1;
+    if (param_.state_outputs) {
+      // kOut, kStateOut, kStateCellOut
+      num_outputs = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
     }
-    out_shape->push_back(oshape);
-    if (!param_.state_outputs) {
-      return true;
-    } else {
-      // outStateShape: [layer_num, batch, state size]
-      mxnet::TShape outStateShape = dshape;
-      outStateShape[0] = total_layers;
-      outStateShape[1] = batch_size;
-      if (param_.projection_size.has_value()) {
-        outStateShape[2] = param_.projection_size.value();
+
+    CHECK_EQ(in_data.size(), num_inputs);
+    CHECK_EQ(out_data.size(), num_outputs);
+
+    #if MXNET_USE_CUDNN_RNN && defined(__CUDACC__)
+    #if CUDNN_MAJOR >= 5
+    format_ = CUDNN_TENSOR_NCHW;
+    #endif
+
+    if (!init_cudnn_) {
+      init_cudnn_ = true;
+      // get input + output tensors
+      Tensor<xpu, 3, DType> x = in_data[rnn_enum::kData].get<xpu, 3, DType>(s);
+      Tensor<xpu, 1, DType> w = in_data[rnn_enum::kParams].get<xpu, 1, DType>(s);
+      param_.seq_length_ = x.shape_[0];
+      param_.batch_size_ = x.shape_[1];
+      param_.input_size_ = x.shape_[2];
+
+      // Tensor Descriptors
+      std::vector<cudnnTensorDescriptor_t> x_vec(param_.seq_length_);
+      std::vector<cudnnTensorDescriptor_t> y_vec(param_.seq_length_);
+      std::vector<cudnnTensorDescriptor_t> dx_vec(param_.seq_length_);
+      std::vector<cudnnTensorDescriptor_t> dy_vec(param_.seq_length_);
+      int dimA[3];
+      int strideA[3];
+      for (int i = 0; i < param_.seq_length_; i++) {
+        CUDNN_CALL(cudnnCreateTensorDescriptor(&x_vec[i]));
+        CUDNN_CALL(cudnnCreateTensorDescriptor(&y_vec[i]));
+        CUDNN_CALL(cudnnCreateTensorDescriptor(&dx_vec[i]));
+        CUDNN_CALL(cudnnCreateTensorDescriptor(&dy_vec[i]));
+
+        dimA[0] = param_.batch_size_;
+        dimA[1] = param_.input_size_;
+        dimA[2] = 1;
+        strideA[0] = dimA[2] * dimA[1];
+        strideA[1] = dimA[2];
+        strideA[2] = 1;
+
+        CUDNN_CALL(cudnnSetTensorNdDescriptor(x_vec[i],
+                                              dtype_,
+                                              3,
+                                              dimA,
+                                              strideA));
+        CUDNN_CALL(cudnnSetTensorNdDescriptor(dx_vec[i],
+                                              dtype_,
+                                              3,
+                                              dimA,
+                                              strideA));
+        dimA[0] = param_.batch_size_;
+        dimA[1] = param_.bidirectional ? param_.state_size * 2 : param_.state_size;
+        dimA[2] = 1;
+        strideA[0] = dimA[2] * dimA[1];
+        strideA[1] = dimA[2];
+        strideA[2] = 1;
+
+        CUDNN_CALL(cudnnSetTensorNdDescriptor(y_vec[i],
+                                              dtype_,
+                                              3,
+                                              dimA,
+                                              strideA));
+        CUDNN_CALL(cudnnSetTensorNdDescriptor(dy_vec[i],
+                                              dtype_,
+                                              3,
+                                              dimA,
+                                              strideA));
+      }
+      x_desc_vec_ = x_vec;
+      y_desc_vec_ = y_vec;
+      dx_desc_vec_ = dx_vec;
+      dy_desc_vec_ = dy_vec;
+
+      // set the state tensors
+      dimA[0] = param_.num_layers * (param_.bidirectional ? 2 : 1);
+      dimA[1] = param_.batch_size_;
+      dimA[2] = param_.state_size;
+      strideA[0] = dimA[2] * dimA[1];
+      strideA[1] = dimA[2];
+      strideA[2] = 1;
+      #if USE_CUDNN_LSTM_PROJ
+      int dimB[3];
+      int strideB[3];
+      dimB[0] = param_.num_layers * (param_.bidirectional ? 2 : 1);
+      dimB[1] = param_.batch_size_;
+      dimB[2] = param_.projection_size.has_value() ?
+                param_.projection_size.value() : param_.state_size;
+      strideB[0] = dimB[2] * dimB[1];
+      strideB[1] = dimB[2];
+      strideB[2] = 1;
+      #endif
+      #if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_,
+                                            dtype_,
+                                            3,
+                                            dimB,
+                                            strideB));
+      #else
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+      #endif
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(cx_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+      #if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_,
+                                            dtype_,
+                                            3,
+                                            dimB,
+                                            strideB));
+      #else
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+      #endif
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(cy_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+      #if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_,
+                                            dtype_,
+                                            3,
+                                            dimB,
+                                            strideB));
+      #else
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+      #endif
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dcx_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+      #if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_,
+                                            dtype_,
+                                            3,
+                                            dimB,
+                                            strideB));
+      #else
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+      #endif
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dcy_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+
+      // Create Dropout descriptors
+      if (param_.p > 0) {
+        CUDNN_CALL(cudnnDropoutGetStatesSize(s->dnn_handle_, &dropout_byte_));
+        dropout_size_ = dropout_byte_ / sizeof(DType);
+        dropout_states_ = Storage::Get()->Alloc(dropout_byte_, Context::GPU(s->dev_id));
 
 Review comment:
   How can the dropout state storage be obtained after calling get_cudnn_dropout_desc?
   Calling get_space_typed after get_cudnn_dropout_desc on the same resource, as in the snippet below, fails at runtime:
   
   ctx.requested[rnn_enum::kCuDNNDropoutDesc].get_cudnn_dropout_desc
             (&dropout_desc_, s, 1.0f - param_.p, seed_);
   Tensor<xpu, 1, DType> dropout_space =
               ctx.requested[rnn_enum::kCuDNNDropoutDesc].get_space_typed<xpu, 1, DType>(
               mshadow::Shape1(dropout_size_), s);  // fails at runtime with: src/storage/storage.cc:134: Cannot Free space to a device you have not allocated

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services