You are viewing a plain text version of this content. The canonical link for it (shown as "here" in the original) was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/06/26 17:44:05 UTC

[incubator-mxnet] branch master updated: add vRNN and dropout (#11399)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 0538ad9  add vRNN and dropout (#11399)
0538ad9 is described below

commit 0538ad9115e0856c2f45fcff479a9af431b31f76
Author: Hao Li <ha...@intel.com>
AuthorDate: Wed Jun 27 01:43:57 2018 +0800

    add vRNN and dropout (#11399)
---
 example/rnn/bucketing/cudnn_rnn_bucketing.py |  16 +-
 src/operator/rnn-inl.h                       |  74 ++-
 src/operator/rnn_impl.h                      | 947 ++++++++++++++++++++++++++-
 tests/python/unittest/test_operator.py       | 116 +++-
 4 files changed, 1099 insertions(+), 54 deletions(-)

diff --git a/example/rnn/bucketing/cudnn_rnn_bucketing.py b/example/rnn/bucketing/cudnn_rnn_bucketing.py
index 29a66a8..5825290 100644
--- a/example/rnn/bucketing/cudnn_rnn_bucketing.py
+++ b/example/rnn/bucketing/cudnn_rnn_bucketing.py
@@ -66,7 +66,7 @@ parser.add_argument('--stack-rnn', default=False,
 parser.add_argument('--dropout', type=float, default='0.0',
                     help='dropout probability (1.0 - keep probability)')
 parser.add_argument('--rnntype', type=str, default='lstm',
-                    help='rnn type: gru and lstm are supported')
+                    help='rnn type: gru, lstm, rnn_tanh and rnn_relu are supported')
 
 #buckets = [32]
 buckets = [10, 20, 30, 40, 50, 60]
@@ -188,6 +188,20 @@ def test(args):
                             cell,
                             mx.rnn.GRUCell(num_hidden=args.num_hidden, prefix='%s_%dr0_'%(args.rnntype,i)),
                             output_prefix='bi_%s_%d'%(args.rnntype,i))
+            elif args.rnntype == 'rnn_tanh':
+                cell = mx.rnn.RNNCell(num_hidden=args.num_hidden, activation='tanh', prefix='%s_%dl0_'%(args.rnntype,i))
+                if args.bidirectional:
+                    cell = mx.rnn.BidirectionalCell(
+                            cell,
+                            mx.rnn.RNNCell(num_hidden=args.num_hidden, activation='tanh', prefix='%s_%dr0_'%(args.rnntype,i)),
+                            output_prefix='bi_%s_%d'%(args.rnntype,i))
+            elif args.rnntype == 'rnn_relu':
+                cell = mx.rnn.RNNCell(num_hidden=args.num_hidden, activation='relu', prefix='%s_%dl0_'%(args.rnntype,i))
+                if args.bidirectional:
+                    cell = mx.rnn.BidirectionalCell(
+                            cell,
+                            mx.rnn.RNNCell(num_hidden=args.num_hidden, activation='relu', prefix='%s_%dr0_'%(args.rnntype,i)),
+                            output_prefix='bi_%s_%d'%(args.rnntype,i))
 
             stack.add(cell)
 
diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index 9953173..1f905ed 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -99,10 +99,6 @@ inline size_t GetRNNWorkspaceSize(int seq_length,
                                   int mode) {
   size_t size = 0;
   switch (mode) {
-    case rnn_enum::kRnnRelu:
-    case rnn_enum::kRnnTanh:
-      LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
-      break;
     case rnn_enum::kLstm:
       size = (seq_length + 1) * batch_size * hidden_size * 4 + batch_size * hidden_size * 2
              + seq_length * batch_size * hidden_size * direction + hidden_size * seq_length * 8;
@@ -110,6 +106,10 @@ inline size_t GetRNNWorkspaceSize(int seq_length,
     case rnn_enum::kGru:
       size = seq_length * batch_size * hidden_size * direction * 4 + batch_size * hidden_size * 8;
       break;
+    case rnn_enum::kRnnRelu:
+    case rnn_enum::kRnnTanh:
+      size = seq_length * batch_size * hidden_size * direction * 2 + batch_size * hidden_size * 4;
+      break;
     default:
       LOG(FATAL) << "unknown RNN mode " << mode;
       break;
@@ -125,18 +125,20 @@ inline size_t GetRNNReserveSpaceSize(int num_layer,
                                      int mode) {
   size_t size = 0;
   switch (mode) {
-    case rnn_enum::kRnnRelu:
-    case rnn_enum::kRnnTanh:
-      LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
-      break;
     case rnn_enum::kLstm:
-      size = num_layer * direction * seq_length * batch_size * hidden_size * 6;
+      size = direction * seq_length * batch_size * hidden_size * (num_layer * 7 - 1);
       break;
     case rnn_enum::kGru:
-      size = seq_length * batch_size * hidden_size * direction * num_layer * 8 +
+      size = seq_length * batch_size * hidden_size * direction * (num_layer * 9 - 1) +
           batch_size * hidden_size * direction * 9 + hidden_size * seq_length * 6 +
           seq_length * batch_size * 7 * hidden_size * direction;
       break;
+    case rnn_enum::kRnnRelu:
+    case rnn_enum::kRnnTanh:
+      size = seq_length * batch_size * hidden_size * direction * (num_layer * 6 - 1) +
+          batch_size * hidden_size * direction * 3 + hidden_size * seq_length * 2 +
+          seq_length * batch_size * 2 * hidden_size * direction;
+      break;
     default:
       LOG(FATAL) << "unknown RNN mode " << mode;
       break;
@@ -223,21 +225,24 @@ void RNNForwardTraining(DType* ws,
                         DType* y_ptr,
                         DType* hy_ptr,
                         DType* cy_ptr,
+                        const float dropout,
                         int mode) {
   switch (mode) {
-    case rnn_enum::kRnnTanh:
-    case rnn_enum::kRnnRelu:
-      LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
-      break;
     case rnn_enum::kLstm:
       LstmForwardTraining<DType>(ws, rs, state_outputs, num_layers, direction, seq_length,
                                  batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr,
-                                 w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr);
+                                 w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, dropout);
       break;
     case rnn_enum::kGru:
       GruForwardTraining<DType>(ws, rs, state_outputs, num_layers, direction, seq_length,
                                 batch_size, input_size, state_size, x_ptr, hx_ptr,
-                                w_ptr, y_ptr, hy_ptr);
+                                w_ptr, y_ptr, hy_ptr, dropout);
+      break;
+    case rnn_enum::kRnnTanh:
+    case rnn_enum::kRnnRelu:
+      VanillaRNNForwardTraining<DType>(ws, rs, state_outputs, num_layers, direction, seq_length,
+                                       batch_size, input_size, state_size, x_ptr, hx_ptr,
+                                       w_ptr, y_ptr, hy_ptr, dropout, mode);
       break;
     default:
       LOG(FATAL) << "unknown RNN mode " << mode;
@@ -264,10 +269,6 @@ void RNNForwardInference(DType* ws,
                          DType* cy_ptr,
                          int mode) {
   switch (mode) {
-    case rnn_enum::kRnnRelu:
-    case rnn_enum::kRnnTanh:
-      LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
-      break;
     case rnn_enum::kLstm:
       LstmForwardInference<DType>(ws, state_outputs, num_layers, direction, seq_length,
                                   batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr,
@@ -278,6 +279,12 @@ void RNNForwardInference(DType* ws,
                                  batch_size, input_size, state_size, x_ptr, hx_ptr,
                                  w_ptr, y_ptr, hy_ptr);
       break;
+    case rnn_enum::kRnnTanh:
+    case rnn_enum::kRnnRelu:
+      VanillaRNNForwardInference<DType>(ws, state_outputs, num_layers, direction, seq_length,
+                                        batch_size, input_size, state_size, x_ptr, hx_ptr,
+                                        w_ptr, y_ptr, hy_ptr, mode);
+      break;
     default:
       LOG(FATAL) << "unknown RNN mode" << mode;
       break;
@@ -310,22 +317,27 @@ void RNNBackward(DType* ws,
                  int req_params,
                  int req_state,
                  int req_statecell,
+                 const float dropout,
                  int mode) {
   switch (mode) {
-    case rnn_enum::kRnnRelu:
-    case rnn_enum::kRnnTanh:
-      break;
     case rnn_enum::kLstm:
       LstmBackward<DType>(ws, rs, num_layers, direction, seq_length, batch_size,
                           input_size, state_size, x_ptr, hx_ptr, cx_ptr, w_ptr, y_ptr,
                           dy_ptr, dhy_ptr, dcy_ptr, dx_ptr, dhx_ptr, dcx_ptr, dw_ptr, db_ptr,
-                          req_data, req_params, req_state, req_statecell);
+                          req_data, req_params, req_state, req_statecell, dropout);
       break;
     case rnn_enum::kGru:
       GruBackward<DType>(ws, rs, num_layers, direction, seq_length, batch_size,
                          input_size, state_size, x_ptr, hx_ptr, w_ptr,
                          dy_ptr, dhy_ptr, dx_ptr, dhx_ptr, dw_ptr,
-                         req_data, req_params, req_state);
+                         req_data, req_params, req_state, dropout);
+      break;
+    case rnn_enum::kRnnTanh:
+    case rnn_enum::kRnnRelu:
+      VanillaRNNBackward<DType>(ws, rs, num_layers, direction, seq_length, batch_size,
+                                input_size, state_size, x_ptr, hx_ptr, w_ptr,
+                                dy_ptr, dhy_ptr, dx_ptr, dhx_ptr, dw_ptr,
+                                req_data, req_params, req_state, dropout, mode);
       break;
     default:
       LOG(FATAL) << "unknown RNN mode" << mode;
@@ -354,9 +366,8 @@ class RNNOp : public Operator{
                        const std::vector<TBlob> &aux_args) {
     using namespace mshadow;
     using namespace mshadow::expr;
-    CHECK(param_.mode == rnn_enum::kLstm || param_.mode == rnn_enum::kGru)
-        << "Only lstm and gru mode are supported at the moment.";
-    CHECK_EQ(param_.p, 0) << "Dropout is not supported at the moment.";
+    CHECK(param_.p >= 0.0f && param_.p < 1.0f)
+        << "unsupported dropout value, should be 0 <= dropout < 1";
 
     size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
     size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
@@ -436,6 +447,7 @@ class RNNOp : public Operator{
                                 y.dptr_,
                                 hy_ptr,
                                 cy_ptr,
+                                param_.p,
                                 param_.mode);
     } else {
       RNNForwardInference<DType>(workspace.dptr_,
@@ -467,9 +479,8 @@ class RNNOp : public Operator{
                         const std::vector<TBlob> &aux_args) {
     using namespace mshadow;
     using namespace mshadow::expr;
-    CHECK(param_.mode == rnn_enum::kLstm || param_.mode == rnn_enum::kGru)
-        << "Only lstm and gru mode are supported at the moment.";
-    CHECK_EQ(param_.p, 0) << "Dropout is not supported at the moment.";
+    CHECK(param_.p >= 0.0f && param_.p < 1.0f)
+        << "unsupported dropout value, should be 0 <= dropout < 1";
 
     size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
     size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
@@ -566,6 +577,7 @@ class RNNOp : public Operator{
                        req[rnn_enum::kParams],
                        req[rnn_enum::kState],
                        req[rnn_enum::kStateCell],
+                       param_.p,
                        param_.mode);
   }
 
diff --git a/src/operator/rnn_impl.h b/src/operator/rnn_impl.h
index fa8d671..e1b4a2b 100644
--- a/src/operator/rnn_impl.h
+++ b/src/operator/rnn_impl.h
@@ -50,6 +50,11 @@ inline DType sigmoid(DType x) {
 }
 
 template<typename DType>
+inline DType relu(DType x) {
+  return x > 0.0f ? static_cast<float>(x) : 0.0f;
+}
+
+template<typename DType>
 void LstmForwardTrainingSingleLayer(DType* ws,
                                     DType* rs,
                                     bool state_outputs,
@@ -133,7 +138,10 @@ void LstmForwardTraining(DType* ws,
                          DType* b_ptr,
                          DType* y_ptr,
                          DType* hy_ptr,
-                         DType* cy_ptr) {
+                         DType* cy_ptr,
+                         const float dropout) {
+  DType* dropout_random = rs;
+  DType* rs2 = dropout_random + (L - 1) * D * T * N * H;
   const int total_layers = D * L;
   Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(total_layers, N, H));
   Tensor<cpu, 3, DType> cx(cx_ptr, Shape3(total_layers, N, H));
@@ -141,14 +149,15 @@ void LstmForwardTraining(DType* ws,
   const int r_size = D * T * N * H * 6;
   const int y_offset = T * N * H * 5;
   const int cell_size = N * H;
+  unsigned int seed_ = 17 + rand() % 4096;  // NOLINT(runtime/threadsafe_fn)
   int idx = 0;  // state & cell state's idx;
   const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
   for (int i = 0; i < L; ++i) {
     const int input_size = i ? H * D : I;
     const int w_size = (input_size + H) * H * 4;
     Tensor<cpu, 2, DType> x(x_ptr, Shape2(T * N, input_size));
-    Tensor<cpu, 3, DType> y(rs + y_offset, Shape3(T, N, H * D));
-    LstmForwardTrainingSingleLayer<DType>(ws, rs, state_outputs, false, T, N, input_size, H, x,
+    Tensor<cpu, 3, DType> y(rs2 + y_offset, Shape3(T, N, H * D));
+    LstmForwardTrainingSingleLayer<DType>(ws, rs2, state_outputs, false, T, N, input_size, H, x,
                                           hx[idx], cx[idx], y, w_ptr, b_ptr, hy_ptr, cy_ptr);
     if (D == 2) {
       w_ptr += w_size;
@@ -158,14 +167,27 @@ void LstmForwardTraining(DType* ws,
         hy_ptr += cell_size;
         cy_ptr += cell_size;
       }
-      LstmForwardTrainingSingleLayer<DType>(ws, rs, state_outputs, true, T, N, input_size, H, x,
+      LstmForwardTrainingSingleLayer<DType>(ws, rs2, state_outputs, true, T, N, input_size, H, x,
                                             hx[idx], cx[idx], y, w_ptr, b_ptr, hy_ptr, cy_ptr);
     }
     if (i != L - 1) {
       w_ptr += w_size;
       b_ptr += b_size;
+      if (dropout > 0.0f) {
+        #pragma omp parallel for num_threads(omp_threads)
+        for (int j = 0; j < T * N * H * D; j++) {
+          int rand_data = rand_r(&seed_);
+          if (static_cast<float>(rand_data % 1000) < static_cast<float>(1000 * dropout)) {
+            dropout_random[i * T * N * H * D + j] = 0;
+            y.dptr_[j] = 0;
+          } else {
+            dropout_random[i * T * N * H * D + j] = 1.0f - dropout;
+            y.dptr_[j] =  y.dptr_[j] / (1.0f - dropout);
+          }
+        }
+      }
       x_ptr = y.dptr_;
-      rs += r_size;
+      rs2 += r_size;
       ++idx;
       if (state_outputs) {
         hy_ptr += cell_size;
@@ -175,7 +197,7 @@ void LstmForwardTraining(DType* ws,
   }
   #pragma omp parallel for num_threads(omp_threads)
   for (int i = 0; i < T * N * H * D; ++i) {
-    y_ptr[i] = (rs + y_offset)[i];
+    y_ptr[i] = (rs2 + y_offset)[i];
   }
 }
 
@@ -498,7 +520,10 @@ void LstmBackward(DType* ws,
                   int req_data,
                   int req_params,
                   int req_state,
-                  int req_statecell) {
+                  int req_statecell,
+                  const float dropout) {
+  DType* dropout_random = rs + (L - 1) * D * T * N * H;
+  DType* rs2 = rs + (L - 1) * D * T * N * H;
   DType* tmp_buf = ws;
   DType* ws2 = tmp_buf + 8 * T * H;
   const int total_layers = D * L;
@@ -520,7 +545,7 @@ void LstmBackward(DType* ws,
     DType* w_cur_ptr = i ? w_ptr + (w_size1 + (i - 1) * w_size2) * D : w_ptr;
     DType* dw_cur_ptr = i ? dw_ptr + (w_size1 + (i - 1) * w_size2) * D : dw_ptr;
     DType* db_cur_ptr = db_ptr + i * b_size * D;
-    DType* rs_cur_ptr = rs + i * r_size;
+    DType* rs_cur_ptr = rs2 + i * r_size;
     DType* dhy_cur_ptr = dhy_ptr ? dhy_ptr + i * cell_size * D : NULL;
     DType* dcy_cur_ptr = dcy_ptr ? dcy_ptr + i * cell_size * D : NULL;
     Tensor<cpu, 3, DType> y(rs_cur_ptr + y_offset, Shape3(T, N, H * D));
@@ -543,6 +568,18 @@ void LstmBackward(DType* ws,
                                      dhy_cur_ptr, dcy_cur_ptr, w_cur_ptr, dw_cur_ptr, db_cur_ptr,
                                      req_data, req_params, req_state, req_statecell);
     }
+    if (dropout > 0.0f && i > 0 && req_data != kNullOp) {
+      dropout_random = dropout_random - T * N * D * H;
+      const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int j = 0; j < T * N * D * H; j++) {
+        if (dropout_random[j] == 0) {
+          dx.dptr_[j] = 0;
+        } else {
+          dx.dptr_[j] = dx.dptr_[j] / (1.0f - dropout);
+        }
+      }
+    }
     dy_ptr = dx.dptr_;
   }
 }
@@ -935,7 +972,8 @@ void GruForwardTraining(DType* ws,
                         DType* hx_ptr,
                         DType* w_ptr,
                         DType* y_ptr,
-                        DType* hy_ptr) {
+                        DType* hy_ptr,
+                        const float dropout) {
   DType* wx = w_ptr;
   DType* wh = wx + I * H * 3;
   DType* bx = wh + H * H * 3 + (D - 1) * (H * H * 3 + I * H * 3)
@@ -948,19 +986,34 @@ void GruForwardTraining(DType* ws,
   DType* gateN_l = gateZ_l + L * T * D * N * H;
   DType* y_l = gateN_l + L * T * D * N * H;
   DType* Mnh_l = y_l + L * T * N * H * D;
-  DType* tmp_buf = Mnh_l + L * D * T * N * H;
-  DType* ws2 = Mnh_l + L * D * T * N * H + D * H * N;
+  DType* dropout_random = Mnh_l + L * D * T * N * H;
+  DType* tmp_buf = dropout_random + (L - 1) * D * T * N * H;
+  DType* ws2 = tmp_buf + D * N * H;
   DType* wx_l = wx;
   DType* wh_l = wh;
   DType* bx_l = bx;
   DType* bh_l = bh;
   DType* y_tmp = x_ptr;
-
+  unsigned int seed_ = 17 + rand() % 4096;  // NOLINT(runtime/threadsafe_fn)
   for (int l = 0; l < L; l++) {
     if (l != 0) {
       y_tmp = y_l;
       y_l = y_l + T * N * H * D;
     }
+    if (dropout > 0.0f && l > 0) {
+      const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < T * N * I; i++) {
+        int rand_data = rand_r(&seed_);
+        if (static_cast<float>(rand_data % 1000) < static_cast<float>(1000 * dropout)) {
+          dropout_random[(l - 1) * T * N * I + i] = 0;
+          y_tmp[i] = 0;
+        } else {
+          dropout_random[(l - 1) * T * N * I + i] = 1.0f - dropout;
+          y_tmp[i] =  y_tmp[i] / (1.0f - dropout);
+        }
+      }
+    }
     Tensor<cpu, 2, DType> x_l(y_tmp, Shape2(T * N, I));
     Tensor<cpu, 2, DType> hx_l = hx[D * l];
     GruForwardTrainingSingleLayer<DType>(ws2, tmp_buf, state_outputs, D, T, N, I, H,
@@ -1349,7 +1402,8 @@ void GruBackward(DType* ws,
                  DType* dw_ptr,
                  int req_data,
                  int req_params,
-                 int req_state) {
+                 int req_state,
+                 const float dropout) {
   DType* wx = w_ptr;
   DType* dwx = dw_ptr;
   DType* dwh = dwx + I * H * 3;
@@ -1360,7 +1414,8 @@ void GruBackward(DType* ws,
   DType* gateN_l = gateZ_l + L * T * D * N * H;
   DType* y_l = gateN_l + L * T * D * N * H;
   DType* Mnh_l = y_l + L * T * N * H * D;
-  DType* tmp_buf = Mnh_l + L * D * T * N * H;
+  DType* dropout_random = Mnh_l + L * D * T * N * H;
+  DType* tmp_buf = dropout_random + (L - 1) * D * T * N * H;
   DType* dx_l = tmp_buf + T * N * D * H + 3 * H * T * 2;
   DType* ws2 = dx_l + T * N * D * H;
   DType* wx_l = (L == 1)? wx : wx + (L - 2) * D * (D + 1) * H * 3 * H
@@ -1403,6 +1458,17 @@ void GruBackward(DType* ws,
     GruBackwardSingleLayer<DType>(ws2, tmp_buf, D, T, N, I, H, x_l, hx_l, wx_l, wh_l, y_l, dy_l,
                                   dhy_l, gateR_l, gateZ_l, gateN_l, Mnh_l, dx_l, dhx_l,
                                   dwx_l, dwh_l, dbx_l, dbh_l, req_data, req_params, req_state);
+    if (dropout > 0.0f && l > 0 && req_data != kNullOp) {
+      dropout_random = dropout_random - T * N * D * H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < T * N * I; i++) {
+        if (dropout_random[i] == 0) {
+          dx_l[i] = 0;
+        } else {
+          dx_l[i] = dx_l[i] / (1.0f - dropout);
+        }
+      }
+    }
     if (l > 0) {
       #pragma omp parallel for num_threads(omp_threads)
       for (int i = 0; i < T * N * H * D; ++i) {
@@ -1433,6 +1499,859 @@ void GruBackward(DType* ws,
     }
   }
 }
+
+template<typename DType>  // one vanilla RNN layer (tanh or relu), forward only, no training buffers
+void VanillaRNNForwardInferenceSingleLayer(DType* ws,  // scratch: gemmC1 (D*T*N*H) then gemmC2 (N*H)
+                                           DType* tmp_buf,  // scratch for transposed h_{t-1}, D*H*N
+                                           bool state_outputs,  // copy final state(s) into hy_ptr
+                                           const int D,  // directions; code handles D == 1 or D == 2
+                                           const int T,  // number of time steps
+                                           const int N,  // batch size
+                                           const int I,  // input feature size
+                                           const int H,  // hidden state size
+                                           const Tensor<cpu, 2, DType> &x,  // [T * N, I]
+                                           const Tensor<cpu, 2, DType> &hx,  // rows [0,N) fwd, [N,2N) bwd
+                                           DType* wx_ptr,  // forward wx [H, I]; backward set follows wh
+                                           DType* wh_ptr,  // forward wh [H, H]
+                                           DType* bx_ptr,  // forward bx [H]; backward bx at +2*H
+                                           DType* bh_ptr,  // forward bh [H]; backward bh at +2*H
+                                           DType* y_ptr,  // output; step t starts at y_ptr + t*D*H*N
+                                           DType* hy_ptr,  // final state(s): D blocks of N*H
+                                           int mode) {  // 1 -> tanh, any other value -> relu
+  DType* ht = y_ptr;  // forward write cursor, advances D*H*N per step
+  DType* ht_1 = y_ptr;  // forward h_{t-1}; step 0 reads the hx copy staged at y_ptr
+  DType* back_ht_1 = y_ptr + (T-1) * N * H * D + H;  // backward half of the last step
+  DType* back_ht = back_ht_1;  // backward write cursor, moves toward step 0
+  DType* gemmC1  = ws;              // x * wx^T for all steps: [D, T, N, H]
+  DType* gemmC2  = gemmC1 + D * T * N * H;  // h_{t-1} * wh^T for one step: N * H
+  DType* back_wx_ptr = wx_ptr + I * H + H * H;  // skip forward wx [H, I] and wh [H, H]
+  DType* back_wh_ptr = wh_ptr + I * H + H * H;
+  DType* back_bx_ptr = (bx_ptr != NULL)? bx_ptr + H * 2 : NULL;  // backward biases start 2*H on
+  DType* back_bh_ptr = (bh_ptr != NULL)? bh_ptr + H * 2: NULL;
+  DType* back_gemmC1 = gemmC1 + T * N * H;  // backward direction's block of gemmC1
+  DType* gemmC1_t = gemmC1;  // per-step view into gemmC1 / back_gemmC1
+
+  const Tensor<cpu, 2, DType> wx(wx_ptr, Shape2(H, I));
+  const Tensor<cpu, 2, DType> wh(wh_ptr, Shape2(H, H));
+  const Tensor<cpu, 2, DType> bx(bx_ptr, Shape2(1, H));
+  const Tensor<cpu, 2, DType> bh(bh_ptr, Shape2(1, H));
+  const Tensor<cpu, 2, DType> back_wx(back_wx_ptr, Shape2(H, I));
+  const Tensor<cpu, 2, DType> back_wh(back_wh_ptr, Shape2(H, H));
+  const Tensor<cpu, 2, DType> back_bx(back_bx_ptr, Shape2(1, H));
+  const Tensor<cpu, 2, DType> back_bh(back_bh_ptr, Shape2(1, H));
+  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  if (D == 1) {  // stage hx in the output buffer; step 0 reads it before overwriting it
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; i++)
+      for (int j = 0; j < H; j++) {
+        y_ptr[i * H + j] = hx[i][j];
+      }
+  } else {  // forward hx at step 0, backward hx (rows N..2N-1 of hx) at step T - 1
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; i++)
+      for (int j = 0; j < H; j++) {
+        y_ptr[i * D * H + j] = hx[i][j];
+        back_ht_1[i * D * H + j] = hx[N + i][j];
+    }
+  }
+  Tensor<cpu, 2, DType> dgemmC1(ws, Shape2(T * N, H));  // aliases gemmC1
+  Tensor<cpu, 2, DType> dgemmC2(gemmC2, Shape2(N, H));
+  Tensor<cpu, 2, DType> dback_gemmC1(back_gemmC1, Shape2(T * N, H));
+
+  // input projection for all steps at once: x * wx.T : [T * N, I] * [I, H]
+  DType alpha = 1.0;
+  DType beta = 0.0;
+  linalg_gemm(x, wx, dgemmC1, alpha, beta, false, true);
+  if (D == 2) {
+    linalg_gemm(x, back_wx, dback_gemmC1, alpha, beta, false, true);
+  }
+
+  for (int t = 0; t < T; t++) {
+    //  forward direction, step t: h_t = act(x_t * wx^T + bx + h_{t-1} * wh^T + bh)
+    //  ht-1 * wh, ht-1:[N, H] wh:[H, H]
+    Tensor<cpu, 2, DType> dht_1(ht_1, Shape2(N, D * H));  // rows hold both directions when D == 2
+    if (D == 1) {
+      linalg_gemm(dht_1, wh, dgemmC2, alpha, beta, false, true);
+    } else {
+      Tensor<cpu, 3, DType> dht_1_tmp = Tensor<cpu, 3, DType>(reinterpret_cast<DType*>(tmp_buf),
+                                     Shape3(D, H, N));
+      dht_1_tmp = reshape(dht_1.T(), Shape3(D, H, N));  // [0]: forward half, transposed to [H, N]
+      linalg_gemm(dht_1_tmp[0], wh, dgemmC2, alpha, beta, true, true);
+    }
+    gemmC1_t = gemmC1 + t * N * H;
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; ++i) {
+      for (int j = 0; j < H; ++j) {
+        int tb = i * H;  // row offset into the [N, H] gemm results
+        if (mode == 1) {  // tanh
+          ht[i * D * H + j] = tanh(gemmC1_t[tb + j] + bx[0][j] +
+              gemmC2[tb + j] + bh[0][j]);
+        } else {  // relu
+          ht[i * D * H + j] = relu(gemmC1_t[tb + j] + bx[0][j] +
+              gemmC2[tb + j] + bh[0][j]);
+        }
+      }
+    }
+    ht_1 = ht;
+    ht = ht + D * H * N;
+    //  backward direction: computes step T - 1 - t, walking from the last step to step 0
+    if (D == 2) {
+      gemmC1_t = back_gemmC1 + (T - 1 - t) * N * H;
+      Tensor<cpu, 2, DType> dback_ht_1(back_ht_1 - H, Shape2(N, D * H));  // back up H to row start
+      Tensor<cpu, 3, DType> dback_ht_1_tmp = Tensor<cpu, 3, DType>
+          (reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N));
+      dback_ht_1_tmp = reshape(dback_ht_1.T(), Shape3(D, H, N));  // [1]: backward half
+      linalg_gemm(dback_ht_1_tmp[1], back_wh, dgemmC2, alpha, beta, true, true);
+
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; ++i) {
+        for (int j = 0; j < H; ++j) {
+          int tb = i * H;
+          if (mode == 1) {  // tanh
+            back_ht[i * D * H + j] = tanh(gemmC1_t[tb + j] + back_bx[0][j]
+                + gemmC2[tb + j] + back_bh[0][j]);
+          } else {  // relu
+            back_ht[i * D * H + j] = relu(gemmC1_t[tb + j] + back_bx[0][j]
+              + gemmC2[tb + j] + back_bh[0][j]);
+          }
+        }
+      }
+      back_ht_1 = back_ht;
+      back_ht = back_ht - D * H * N;
+    }
+  }
+  //  copy last state to hy, from (N, H * D) to (D, N, H); backward's last state is at step 0
+  if (state_outputs) {
+    if (D == 1) {
+      DType* y_start = y_ptr + (T - 1) * N * H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; i++)
+        for (int j = 0; j < H; j++) {
+          hy_ptr[i * H + j] = y_start[i * H + j];
+        }
+    } else {
+      DType* y_start = y_ptr + (T - 1) * N * H * D;  // forward half of step T - 1
+      DType* y_back_start = y_ptr + H;  // backward half of step 0
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; i++)
+        for (int j = 0; j < H; j++) {
+          hy_ptr[i * H + j] = y_start[i * D * H + j];
+          hy_ptr[N * H + i * H + j] = y_back_start[i * D * H + j];
+        }
+    }
+  }
+}
+
+template <typename DType>  // stacked (L-layer) vanilla RNN, forward only
+void VanillaRNNForwardInference(DType* ws,  // scratch: y_tmp, tmp_buf, then the per-layer ws2
+                                bool state_outputs,  // write final states into hy_ptr
+                                const int L,  // number of layers
+                                const int D,  // number of directions
+                                const int T,  // number of time steps
+                                const int N,  // batch size
+                                int I,  // layer-0 input size; reassigned to D * H after layer 0
+                                const int H,  // hidden state size
+                                DType* x_ptr,  // input, viewed as [T * N, I]
+                                DType* hx_ptr,  // initial states [D * L, N, H]
+                                DType* w_ptr,  // all layers' weights, followed by all biases
+                                DType* y_ptr,  // receives the last layer's output
+                                DType* hy_ptr,  // final states, D * N * H per layer
+                                int mode) {  // passed through: 1 -> tanh, otherwise relu
+  DType* wx = w_ptr;
+  DType* wh = wx + I * H;
+  DType* bx = wh + H * H + (D - 1) * (H * H + I * H)
+      + (L - 1) * ((D + 1) * H) * H * D;  // biases start after every layer's weights
+  DType* bh = bx + H;
+
+  DType* y_tmp = ws;  // intermediate layer outputs, D * T * N * H
+  DType* y_l = x_ptr;  // input of the current layer
+  DType* tmp_buf = y_tmp + D * T * N * H;
+  DType* ws2 = y_tmp + D * T * N * H + D * H * N;  // scratch for the single-layer kernel
+
+  DType* wx_l = wx;
+  DType* wh_l = wh;
+  DType* bx_l = bx;
+  DType* bh_l = bh;
+  Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(D * L, N, H));
+  DType* hy_l = hy_ptr;
+  for (int l = 0; l < L; l++) {
+    Tensor<cpu, 2, DType> x_l(y_l, Shape2(T * N, I));
+    if ((L + l) % 2) {  // alternate y_ptr / y_tmp so that layer L - 1 always writes y_ptr
+      y_l = y_ptr;
+    } else {
+      y_l = y_tmp;
+    }
+    Tensor<cpu, 2, DType> hx_l = hx[D * l];  // first of this layer's D initial-state slices
+    VanillaRNNForwardInferenceSingleLayer<DType>(ws2, tmp_buf, state_outputs, D, T, N, I, H,
+                                                 x_l, hx_l, wx_l, wh_l, bx_l, bh_l, y_l,
+                                                 hy_l, mode);
+    hy_l = hy_l + D * N * H;
+    bx_l = bx_l + H * D * 2;  // 2 * H bias values per direction
+    bh_l = bh_l + H * D * 2;
+    wx_l = wx_l + I * H * D + H * H * D;  // skip this layer's wx and wh for all directions
+    if (l == 0) {
+      I = D * H;  // deeper layers consume the concatenated D * H output
+    }
+    wh_l = wx_l + I * H;
+  }
+}
+
+
+template<typename DType>
+void VanillaRNNForwardTrainingSingleLayer(DType* ws,
+                                       DType* tmp_buf,
+                                       bool state_outputs,
+                                       const int D,
+                                       const int T,
+                                       const int N,
+                                       const int I,
+                                       const int H,
+                                       const Tensor<cpu, 2, DType> &x,
+                                       const Tensor<cpu, 2, DType> &hx,
+                                       DType* wx_ptr,
+                                       DType* wh_ptr,
+                                       DType* bx_ptr,
+                                       DType* bh_ptr,
+                                       DType* gateN,
+                                       DType* y_ptr,
+                                       DType* hy_ptr,
+                                       int mode) {
+  DType* ht = y_ptr;
+  DType* ht_1 = y_ptr;
+  DType* back_ht_1 = y_ptr + (T - 1)* N * H * D + H;
+  DType* back_ht = back_ht_1;
+
+  DType* gemmC1  = ws;              // [D, T, N, H]
+  DType* gemmC2  = gemmC1 + D * T * N * H;  // N * H
+  DType* nt = gateN;
+  DType* back_wx_ptr = wx_ptr + I * H + H * H;
+  DType* back_wh_ptr = wh_ptr + I * H + H * H;
+  DType* back_bx_ptr = (bx_ptr != NULL)? bx_ptr + H * 2 : NULL;
+  DType* back_bh_ptr = (bh_ptr != NULL)? bh_ptr + H * 2 : NULL;
+  DType* back_gateN = gateN + T * N * H;
+  DType* back_gemmC1 = gemmC1 + T * N * H;
+  DType* gemmC1_t = gemmC1;
+
+  const Tensor<cpu, 2, DType> wx(wx_ptr, Shape2(H, I));
+  const Tensor<cpu, 2, DType> wh(wh_ptr, Shape2(H, H));
+  const Tensor<cpu, 2, DType> bx(bx_ptr, Shape2(1, H));
+  const Tensor<cpu, 2, DType> bh(bh_ptr, Shape2(1, H));
+  const Tensor<cpu, 2, DType> back_wx(back_wx_ptr, Shape2(H * 1, I));
+  const Tensor<cpu, 2, DType> back_wh(back_wh_ptr, Shape2(H * 1, H));
+  const Tensor<cpu, 2, DType> back_bx(back_bx_ptr, Shape2(1, H));
+  const Tensor<cpu, 2, DType> back_bh(back_bh_ptr, Shape2(1, H));
+  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  if (D == 1) {
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; i++)
+      for (int j = 0; j < H; j++) {
+        y_ptr[i * H + j] = hx[i][j];
+      }
+  } else {
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; i++)
+      for (int j = 0; j < H; j++) {
+        y_ptr[i * D * H + j] = hx[i][j];
+        back_ht_1[i * D * H + j] = hx[N + i][j];
+    }
+  }
+
+  Tensor<cpu, 2, DType> dgemmC1(ws, Shape2(T * N, H));
+  Tensor<cpu, 2, DType> dgemmC2(gemmC2, Shape2(N, H));
+  Tensor<cpu, 2, DType> dback_gemmC1(back_gemmC1, Shape2(T * N, H));
+
+  // x * wx.T : [T * N, I] * [I, H]
+  DType alpha = 1.0;
+  DType beta = 0.0;
+  linalg_gemm(x, wx, dgemmC1, alpha, beta, false, true);
+  if (D == 2) {
+    linalg_gemm(x, back_wx, dback_gemmC1, alpha, beta, false, true);
+  }
+
+  for (int t = 0; t < T; t++) {
+    //  perform the first direction, X * wx and H * wh for each step
+    //  ht-1 * wh, ht-1:[N, H] wh:[H, H]
+    Tensor<cpu, 2, DType> dht_1(ht_1, Shape2(N, D * H));
+    if (D == 1) {
+      linalg_gemm(dht_1, wh, dgemmC2, alpha, beta, false, true);
+    } else {
+      Tensor<cpu, 3, DType> dht_1_tmp = Tensor<cpu, 3, DType>(reinterpret_cast<DType*>(tmp_buf),
+                                     Shape3(D, H, N));
+      dht_1_tmp = reshape(dht_1.T(), Shape3(D, H, N));
+      linalg_gemm(dht_1_tmp[0], wh, dgemmC2, alpha, beta, true, true);
+    }
+    nt = gateN + t * N * H;
+    gemmC1_t = gemmC1 + t * N * H;
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; ++i) {
+      for (int j = 0; j < H; ++j) {
+        int tb = i * H;
+        if (mode == 1) {
+          nt[tb + j] = ht[i * D * H + j] = tanh(gemmC1_t[tb + j] + bx[0][j] +
+              gemmC2[tb + j] + bh[0][j]);
+        } else {
+          nt[tb + j] = gemmC1_t[tb + j] + bx[0][j] + gemmC2[tb + j] + bh[0][j];
+          ht[i * D * H + j] = relu(nt[tb + j]);
+        }
+      }
+    }
+    ht_1 = ht;
+    ht = ht + D * H * N;
+    //  perform the second direction
+    if (D == 2) {
+      nt = back_gateN + (T - 1 - t) * N * H;
+      gemmC1_t = back_gemmC1 + (T - 1 - t) * N * H;
+      Tensor<cpu, 2, DType> dback_ht_1(back_ht_1 - H, Shape2(N, D * H));
+      Tensor<cpu, 3, DType> dback_ht_1_tmp = Tensor<cpu, 3, DType>
+          (reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N));
+      dback_ht_1_tmp = reshape(dback_ht_1.T(), Shape3(D, H, N));
+      linalg_gemm(dback_ht_1_tmp[1], back_wh, dgemmC2, alpha, beta, true, true);
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; ++i) {
+        for (int j = 0; j < H; ++j) {
+          int tb = i * H;
+          if (mode == 1) {
+            nt[tb + j] = back_ht[i * D * H + j] = tanh(gemmC1_t[tb + j] + back_bx[0][j]
+                + gemmC2[tb + j] + back_bh[0][j]);
+          } else {
+            nt[tb + j] = gemmC1_t[tb + j] + back_bx[0][j] + gemmC2[tb + j] + back_bh[0][j];
+            back_ht[i * D * H + j] = relu(nt[tb + j]);
+          }
+        }
+      }
+      back_ht_1 = back_ht;
+      back_ht = back_ht - D * H * N;
+    }
+  }
+
+  //  copy last state to hy, from(N, H * D) to (D, N, H)
+  if (state_outputs) {
+    if (D == 1) {
+      DType* y_start = y_ptr + (T - 1) * N * H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; i++)
+        for (int j = 0; j < H; j++) {
+          hy_ptr[i * H + j] = y_start[i * H + j];
+        }
+    } else {
+      DType* y_start = y_ptr + (T - 1) * N * H * D;
+      DType* y_back_start = y_ptr + H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; i++)
+        for (int j = 0; j < H; j++) {
+          hy_ptr[i * H + j] = y_start[i * D * H + j];
+          hy_ptr[N * H + i * H + j] = y_back_start[i * D * H + j];
+        }
+    }
+  }
+}
+
+template <typename DType>
+void VanillaRNNForwardTraining(DType* ws,
+                               DType* rs,
+                               bool state_outputs,
+                               const int L,
+                               const int D,
+                               const int T,
+                               const int N,
+                               int I,
+                               const int H,
+                               DType* x_ptr,
+                               DType* hx_ptr,
+                               DType* w_ptr,
+                               DType* y_ptr,
+                               DType* hy_ptr,
+                               const float dropout,
+                               int mode) {
+  // Training-mode forward pass of an L-layer, D-direction vanilla RNN (mode is passed
+  // through to the single-layer kernel to select the activation).
+  //
+  // Everything the backward pass reads is kept in the reserve space rs, in this order:
+  //   gateN   : T * D * N * H per layer, saved activations
+  //   y       : T * N * D * H per layer, layer outputs
+  //   dropout : T * N * D * H per layer 1..L-1, masks applied to those layers' inputs
+  //   tmp_buf : D * N * H scratch, followed by ws2, the single-layer kernel's scratch
+  // The ws argument itself is not used by this function.
+  // Weights: layer 0 starts with wx [H, I] then wh [H, H]; all biases [bx, bh] (per
+  // direction, per layer) are stored after the last layer's weights.
+  DType* wx = w_ptr;
+  DType* wh = wx + I * H;
+  DType* bx = wh + H * H + (D - 1) * (H * H + I * H)
+      + (L - 1) * ((D + 1) * H) * H * D;
+  DType* bh = bx + H;
+  Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(D * L, N, H));
+  DType* hy_l = hy_ptr;
+  DType* gateN_l = rs;
+  DType* y_l = gateN_l + L * T * D * N * H;
+  DType* dropout_random = y_l + L * D * T * N * H;
+  DType* tmp_buf = dropout_random + (L - 1) * D * T * N * H;
+  DType* ws2 = tmp_buf + D * N * H;
+  DType* wx_l = wx;
+  DType* wh_l = wh;
+  DType* bx_l = bx;
+  DType* bh_l = bh;
+  DType* y_tmp = x_ptr;
+  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  unsigned int seed_ = 17 + rand() % 4096;  // NOLINT(runtime/threadsafe_fn)
+  for (int l = 0; l < L; l++) {
+    if (l != 0) {
+      // The input of layer l is the output of layer l - 1.
+      y_tmp = y_l;
+      y_l = y_l + T * N * H * D;
+    }
+    if (dropout > 0.0f && l > 0) {
+      // Drop entries of layer l's input and record the mask for the backward pass:
+      // 0 for dropped entries, 1 - dropout for kept ones; kept entries are rescaled
+      // by 1 / (1 - dropout).
+      // rand_r advances seed_ on every call, so this loop is deliberately serial:
+      // running it under an OpenMP parallel for would race on the shared seed_.
+      // NOTE(review): the mask is applied in place to y_tmp, which is also layer l - 1's
+      // saved output in rs; the backward pass of layer l - 1 reads those values as its
+      // previous hidden states. Confirm this is intended.
+      for (int i = 0; i < T * N * I; i++) {
+        int rand_data = rand_r(&seed_);
+        if (static_cast<float>(rand_data % 1000) < static_cast<float>(1000 * dropout)) {
+          dropout_random[(l - 1) * T * N * I + i] = 0;
+          y_tmp[i] = 0;
+        } else {
+          dropout_random[(l - 1) * T * N * I + i] = 1.0f - dropout;
+          y_tmp[i] =  y_tmp[i] / (1.0f - dropout);
+        }
+      }
+    }
+    Tensor<cpu, 2, DType> x_l(y_tmp, Shape2(T * N, I));
+    Tensor<cpu, 2, DType> hx_l = hx[D * l];
+    VanillaRNNForwardTrainingSingleLayer<DType>(ws2, tmp_buf, state_outputs, D, T, N, I, H,
+                                             x_l, hx_l, wx_l, wh_l, bx_l, bh_l,
+                                             gateN_l, y_l, hy_l, mode);
+    gateN_l = gateN_l +  T * D * N * H;
+    hy_l = hy_l + D * N * H;
+    bx_l = bx_l + H * D * 2;
+    bh_l = bh_l + H * D * 2;
+
+    wx_l = wx_l + I * H * D + H * H * D;
+    if (l == 0) {
+      // Every layer above 0 consumes the D * H wide output of the layer below.
+      I = D * H;
+    }
+    wh_l = wx_l + I * H;
+  }
+  // y_l now points at the last layer's output; copy it to the operator output.
+  #pragma omp parallel for num_threads(omp_threads)
+  for (int i = 0; i < T * N * H * D; ++i) {
+    y_ptr[i] = y_l[i];
+  }
+}
+
+template <typename DType>
+void VanillaRNNBackwardSingleLayer(DType* ws,
+                                   DType* tmp_buf,
+                                   const int D,
+                                   const int T,
+                                   const int N,
+                                   const int I,
+                                   const int H,
+                                   const Tensor<cpu, 2, DType> &x,
+                                   const Tensor<cpu, 2, DType> &hx,
+                                   DType* wx_ptr,
+                                   DType* wh_ptr,
+                                   DType* y_ptr,
+                                   DType* dy_ptr,
+                                   DType* dhy_ptr,
+                                   DType* gateN,
+                                   DType* dx,
+                                   DType* dhx,
+                                   DType* dwx,
+                                   DType* dwh,
+                                   DType* dbx,
+                                   DType* dbh,
+                                   int req_data,
+                                   int req_params,
+                                   int req_state,
+                                   int mode) {
+  // Backward pass of one vanilla RNN layer: the forward direction first, then, when
+  // D == 2, the reverse direction.
+  //
+  //   x       : [T * N, I] layer input
+  //   hx      : initial hidden state; row i is direction 0, row N + i direction 1
+  //   y_ptr   : [T, N, D, H] layer output saved by the forward pass
+  //   dy_ptr  : [T, N, D, H] gradient w.r.t. y
+  //   dhy_ptr : [D, N, H] gradient w.r.t. the final hidden state, or NULL for none
+  //   gateN   : [D, T, N, H] saved activations; tanh output when mode == 1, otherwise the
+  //             relu pre-activation
+  //   mode    : 1 selects tanh, anything else relu
+  //   req_data / req_params / req_state : requests for dx, {dwx, dwh, dbx, dbh} and dhx.
+  //             dx is written with beta = 0 (unless kNullOp); parameter gradients honor
+  //             kAddTo by accumulating onto the existing values.
+  // The reverse direction's weights start I * H + H * H past the forward ones, and its
+  // biases 2 * H past dbx / dbh.
+  DType* dyt;
+  DType* ht1;  // [N, D, H]
+  DType* dart;
+  DType* nt;
+  DType* dar = ws;  // [T, N, H]
+  DType* dht1 = dar + T * N * H;  // [D, N, H]
+  DType* hx_ = dht1 + D * N * H;  // [N, D, H]
+
+  DType* back_ht1;
+  DType* back_dht1 = dht1 + N * H;  // [N, H]
+  DType* back_gateN = gateN + T * N * H;
+  DType* back_wx_ptr = wx_ptr + I * H + H * H;
+  DType* back_wh_ptr = wh_ptr + I * H + H * H;
+  DType* back_dwx = dwx + I * H + H * H;
+  DType* back_dwh = dwh + I * H + H * H;
+  DType* back_dbx = dbx + H * 2;
+  DType* back_dbh = dbh + H * 2;
+
+  DType alpha = 1.0;
+  DType beta = 0.0;
+  const Tensor<cpu, 2, DType> wx(wx_ptr, Shape2(H, I));
+  const Tensor<cpu, 2, DType> wh(wh_ptr, Shape2(H, H));
+  const Tensor<cpu, 2, DType> back_wx(back_wx_ptr, Shape2(H, I));
+  const Tensor<cpu, 2, DType> back_wh(back_wh_ptr, Shape2(H, H));
+  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  if (req_params != kNullOp && req_params != kAddTo) {
+    // dwh / back_dwh are accumulated over the time steps below, so in write mode they
+    // must start from zero. They are zeroed separately because back_dwh lies
+    // I * H + H * H past dwh. Bias gradients are assigned directly further below.
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < H * H; ++i) {
+      dwh[i] = 0;
+      if (D == 2) {
+        back_dwh[i] = 0;
+      }
+    }
+  }
+
+  #pragma omp parallel for num_threads(omp_threads)
+  for (int i = 0; i < N * H; ++i) {
+    if (dhy_ptr) {
+      dht1[i] = dhy_ptr[i];
+    } else {
+      dht1[i] = 0;
+    }
+  }
+
+  #pragma omp parallel for num_threads(omp_threads)
+  for (int i = 0; i < N; ++i) {
+    for (int j = 0; j < H; ++j) {
+      hx_[i * D * H + j] = hx[i][j];
+    }
+  }
+
+  if (D == 2) {
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N * H; ++i) {
+      if (dhy_ptr) {
+        back_dht1[i] = dhy_ptr[N * H + i];
+      } else {
+        back_dht1[i] = 0;
+      }
+    }
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; ++i) {
+      for (int j = 0; j < H; ++j) {
+        hx_[i * D * H + H + j] = hx[N + i][j];
+      }
+    }
+  }
+  for (int t = T - 1; t >= 0; --t) {
+    if (t) {
+      ht1 = y_ptr + (t - 1) * N * D * H;
+    } else {
+      ht1 = hx_;
+    }
+    // add dy[T, N, D, H] to dhy[D, N, H]
+    dyt = dy_ptr + t * N * D * H;
+
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; ++i) {
+      for (int j = 0; j < H; ++j) {
+        dht1[i * H + j] += dyt[i * D * H + j];
+      }
+    }
+
+    nt = gateN + t * N * H;
+    dart = dar + t * N * H;
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; ++i) {
+      for (int j = 0; j < H; ++j) {
+        int id = i * H + j;
+        if (mode == 1) {
+          // nt holds tanh(a), and d tanh(a) / da = 1 - tanh(a)^2.
+          dart[id] = dht1[id] * (1 - nt[id] * nt[id]);
+        } else {
+          // nt holds the relu pre-activation; stay in DType (no float round trip).
+          dart[id] = nt[id] > 0.0f ? dht1[id] : static_cast<DType>(0.0f);
+        }
+        dht1[id] = 0;
+      }
+    }
+
+    // Propagate to h_{t-1}: dht1 = dart * wh    [N, H] = [N, H] * [H, H].
+    // This feeds dx and dhx, so it runs whether or not parameter gradients are requested.
+    alpha = 1.0;
+    beta = 1.0;
+    Tensor<cpu, 2, DType> d_dht1(dht1, Shape2(N, H));
+    Tensor<cpu, 2, DType> d_dart(dart, Shape2(N, H));
+    linalg_gemm(d_dart, wh, d_dht1, alpha, beta, false, false);
+
+    if (req_params != kNullOp) {
+      if (req_params == kAddTo) {
+        // dwx += dart.T * x_t    [H, I] = [H, N] * [N, I]
+        Tensor<cpu, 2, DType> d_xt(x.dptr_ + t * N * I, Shape2(N, I));
+        Tensor<cpu, 2, DType> d_dwx(dwx, Shape2(H, I));
+        linalg_gemm(d_dart, d_xt, d_dwx, alpha, beta, true, false);
+      }
+      // dwh += dart.T * ht1    [H, H] = [H, N] * [N, H], forward half of ht1 only
+      Tensor<cpu, 2, DType> d_ht1(ht1, Shape2(N, D * H));
+      Tensor<cpu, 2, DType> d_dwh(dwh, Shape2(H, H));
+      Tensor<cpu, 3, DType> d_ht1_tmp = Tensor<cpu, 3, DType>
+          (reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N));
+      d_ht1_tmp = reshape(d_ht1.T(), Shape3(D, H, N));
+      linalg_gemm(d_dart, d_ht1_tmp[0], d_dwh, alpha, beta, true, true);
+    }
+  }
+
+  if (req_params != kNullOp) {
+    // dbx = dbh = e * dar    [1, H] = [1, T * N] * [T * N, H]
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < H; ++i) {
+      DType sum = 0.0;
+      for (int j = 0; j < N * T; ++j) {
+        sum += dar[j * H + i];
+      }
+      if (req_params == kAddTo) {
+        dbx[i] += sum;
+        dbh[i] += sum;
+      } else {
+        dbx[i] = sum;
+        dbh[i] = sum;
+      }
+    }
+  }
+  alpha = 1.0;
+  beta = 0.0;
+
+  // dx = da * wx    [T * N, I] = [T * N, H] * [H, I]
+  Tensor<cpu, 2, DType> d_dar(dar, Shape2(T * N, H));
+  if (req_data != kNullOp) {
+    Tensor<cpu, 2, DType> d_dx(dx, Shape2(T * N, I));
+    linalg_gemm(d_dar, wx, d_dx, alpha, beta, false, false);
+  }
+
+  // dwx = da.T * x    [H, I] = [H, T * N] * [T * N, I]  (kAddTo accumulated per step above)
+  if (req_params != kNullOp && req_params != kAddTo) {
+    Tensor<cpu, 2, DType> d_dwx(dwx, Shape2(H, I));
+    linalg_gemm(d_dar, x, d_dwx, alpha, beta, true, false);
+  }
+
+  if (D == 2) {
+    // Reverse direction: walk time forwards; dar is reused as its scratch.
+    for (int t = 0; t < T; ++t) {
+      if (t == T-1) {
+        back_ht1 = hx_;
+      } else {
+        back_ht1 = y_ptr + (t + 1) * N * D * H;
+      }
+
+      //  add dy[T, N, D, H] to dhy[D, N, H]
+      dyt = dy_ptr + t * N * D * H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; ++i) {
+        for (int j = 0; j < H; ++j) {
+          back_dht1[i * H + j] += dyt[i * D * H + H + j];
+        }
+      }
+
+      nt = back_gateN + t * N * H;
+      dart = dar + t * N * H;
+
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; ++i) {
+        for (int j = 0; j < H; ++j) {
+          int id = i * H + j;
+          if (mode == 1) {
+            dart[id] = back_dht1[id] * (1 - nt[id] * nt[id]);
+          } else {
+            dart[id] = nt[id] > 0.0f ? back_dht1[id] : static_cast<DType>(0.0f);
+          }
+          back_dht1[id] = 0;
+        }
+      }
+
+      // Propagate to the previous step of the reverse pass, regardless of req_params:
+      // back_dht1 = da * back_wh    [N, H] = [N, H] * [H, H]
+      alpha = 1.0;
+      beta = 1.0;
+      Tensor<cpu, 2, DType> d_dart(dart, Shape2(N, H));
+      Tensor<cpu, 2, DType> d_back_dht1(back_dht1, Shape2(N, H));
+      linalg_gemm(d_dart, back_wh, d_back_dht1, alpha, beta, false, false);
+
+      if (req_params != kNullOp) {
+        // back_dwh += da.T * back_ht1     [H, H] = [H, N] * [N, H]
+        Tensor<cpu, 2, DType> d_back_dwh(back_dwh, Shape2(H, H));
+        // NOTE(review): this view starts H elements into back_ht1 but spans N * D * H
+        // elements, so the reshape copy reads H elements past the N rows; only the
+        // first H columns ([0] below) are used. Confirm the trailing memory is valid.
+        Tensor<cpu, 2, DType> d_back_ht1(back_ht1 + H, Shape2(N, D * H));
+        Tensor<cpu, 3, DType> d_back_ht1_tmp = Tensor<cpu, 3, DType>
+            (reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N));
+        d_back_ht1_tmp = reshape(d_back_ht1.T(), Shape3(D, H, N));
+        if (req_params == kAddTo) {
+          // back_dwx += da.T * x_t    [H, I] = [H, N] * [N, I]
+          Tensor<cpu, 2, DType> d_xt(x.dptr_ + t * N * I, Shape2(N, I));
+          Tensor<cpu, 2, DType> d_back_dwx(back_dwx, Shape2(H, I));
+          linalg_gemm(d_dart, d_xt, d_back_dwx, alpha, beta, true, false);
+        }
+        linalg_gemm(d_dart, d_back_ht1_tmp[0], d_back_dwh, alpha, beta, true, true);
+      }
+    }
+
+    if (req_params != kNullOp) {
+      // back_dbx = back_dbh = e * da    [1, H] = [1, T * N] * [T * N, H]
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < H; ++i) {
+        DType sum = 0.0;
+        for (int j = 0; j < N * T; ++j) {
+          sum += dar[j * H + i];
+        }
+        if (req_params == kAddTo) {
+          back_dbx[i] += sum;
+          back_dbh[i] += sum;
+        } else {
+          back_dbx[i] = sum;
+          back_dbh[i] = sum;
+        }
+      }
+    }
+    alpha = 1.0;
+    beta = 1.0;
+    // dx += da * back_wx    [T * N, I] = [T * N, H] * [H, I]
+    Tensor<cpu, 2, DType> d_dar2(dar, Shape2(T * N, H));
+    if (req_data != kNullOp) {
+      Tensor<cpu, 2, DType> d_dx(dx, Shape2(T * N, I));
+      linalg_gemm(d_dar2, back_wx, d_dx, alpha, beta, false, false);
+    }
+    alpha = 1.0;
+    beta = 0.0;
+    // back_dwx = da.T * x    [H, I] = [H, T * N] * [T * N, I]
+    if (req_params != kNullOp && req_params != kAddTo) {
+      Tensor<cpu, 2, DType> d_back_dwx(back_dwx, Shape2(H, I));
+      linalg_gemm(d_dar2, x, d_back_dwx, alpha, beta, true, false);
+    }
+  }
+  // dht1 and back_dht1 are contiguous, so this copies both directions into dhx.
+  if (req_state != kNullOp) {
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N * H * D; ++i) {
+      dhx[i] = dht1[i];
+    }
+  }
+}
+
+template <typename DType>
+void VanillaRNNBackward(DType* ws,
+                        DType* rs,
+                        const int L,
+                        const int D,
+                        const int T,
+                        const int N,
+                        int I,
+                        const int H,
+                        DType* x_ptr,
+                        DType* hx_ptr,
+                        DType* w_ptr,
+                        DType* dy_ptr,
+                        DType* dhy_ptr,
+                        DType* dx_ptr,
+                        DType* dhx_ptr,
+                        DType* dw_ptr,
+                        int req_data,
+                        int req_params,
+                        int req_state,
+                        const float dropout,
+                        int mode) {
+  // Backward pass of an L-layer, D-direction vanilla RNN. Layers are processed from
+  // L - 1 down to 0; the input gradient of layer l becomes the output gradient of
+  // layer l - 1. Saved activations, outputs and dropout masks are read from rs in the
+  // layout written by the training forward pass; the ws argument is not used here.
+  // NOTE(review): for l > 0 the propagated gradient is copied into dy_ptr, i.e. the
+  // incoming output gradient buffer is overwritten.
+  DType* wx = w_ptr;
+  DType* dwx = dw_ptr;
+  DType* dwh = dwx + I * H;
+  DType* dbx = dwh + H * H + (D - 1) * (H * H + I * H)
+      + (L - 1) * ((D + 1) * H) * H * D;
+  DType* gateN_l = rs + (L - 1) * T * D * N * H;
+  DType* y_l = gateN_l + L * T * D * N * H;
+  // dropout_random points at the END of the mask region; it is walked backwards below.
+  DType* dropout_random = y_l + L * D * T * N * H;
+  // Scratch begins right after the masks (adding another (L - 1) * D * T * N * H here
+  // would skip past unused reserve space).
+  DType* tmp_buf = dropout_random;
+  DType* dx_l = tmp_buf + T * N * D * H + H * T * 2;
+  DType* ws2 = dx_l + T * N * D * H;
+  DType* wx_l = (L == 1)? wx : wx + (L - 2) * D * (D + 1) * H * H
+      + D * I * H + D * H * H;
+  DType* wh_l = wx_l;
+  if (L == 1) {
+    wh_l = wh_l + I * H;
+  } else {
+    wh_l = wh_l + (D * H) * H;
+  }
+  DType* dhy_l = NULL;
+  if (dhy_ptr)
+    dhy_l = dhy_ptr + (L - 1) * D * N * H;
+  DType* dwx_l = (L == 1)? dwx : dwx + (L - 2) * D * (D + 1) * H * H
+      + D * I * H + D * H * H;
+  DType* dwh_l = NULL;
+  if (L == 1) {
+    dwh_l = dwx_l + I * H;
+  } else {
+    dwh_l = dwx_l + (D * H) * H;
+  }
+  DType* dbx_l = dbx + (L - 1) * D * H * 2;
+  DType* dbh_l = dbx_l + H;
+  DType* dhx_l = dhx_ptr + (L - 1) * D * N * H;
+  DType* dy_l = dy_ptr;
+  Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(L, D * N, H));
+  int inputsize = I;
+  DType* y_tmp = y_l - T * N * H * D;
+  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  for (int l = L - 1; l >= 0; --l) {
+    if (l == 0) {
+      I = inputsize;
+      y_tmp = x_ptr;
+      dx_l = dx_ptr;
+    } else {
+      I = D * H;
+    }
+    Tensor<cpu, 2, DType> hx_l = hx[l];
+    Tensor<cpu, 2, DType> x_l(y_tmp, Shape2(T * N, I));
+    // Layers above 0 must always produce their input gradient, because it is the
+    // output gradient of layer l - 1; only layer 0 (writing dx_ptr) honors req_data.
+    const int req_data_l = (l > 0) ? static_cast<int>(kWriteTo) : req_data;
+    VanillaRNNBackwardSingleLayer<DType>(ws2, tmp_buf, D, T, N, I, H, x_l, hx_l, wx_l, wh_l,
+                                         y_l, dy_l, dhy_l, gateN_l, dx_l, dhx_l, dwx_l, dwh_l,
+                                         dbx_l, dbh_l, req_data_l, req_params, req_state, mode);
+    if (dropout > 0.0f && l > 0) {
+      // Undo the forward dropout on layer l's input: zero dropped entries and rescale
+      // kept ones by 1 / (1 - dropout), using the mask saved for layer l.
+      dropout_random = dropout_random - T * N * D * H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < T * N * I; i++) {
+        if (dropout_random[i] == 0) {
+          dx_l[i] = 0;
+        } else {
+          dx_l[i] = dx_l[i] / (1.0f - dropout);
+        }
+      }
+    }
+    if (l > 0) {
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < T * N * H * D; ++i) {
+        dy_l[i] = dx_l[i];
+      }
+      gateN_l = gateN_l -  T * D * N * H;
+      dhx_l = dhx_l - D * N * H;
+      if (dhy_l)
+        dhy_l = dhy_l - D * N * H;
+      y_l = y_l - T * N * H * D;
+      y_tmp = y_l;
+      if (l == 1) {
+        wx_l = wx_l - (inputsize + H) * H * D;
+        wh_l = wx_l + inputsize * H;
+        dwx_l = dwx_l - (inputsize + H) * H * D;
+        dwh_l = dwx_l + inputsize * H;
+      } else {
+        wx_l = wx_l - (I + H) * H * D;
+        wh_l = wx_l + I * H;
+        dwx_l = dwx_l - (I + H) * H * D;
+        dwh_l = dwx_l + I * H;
+      }
+      dbx_l = dbx_l - D * H * 2;
+      dbh_l = dbx_l + H;
+    }
+  }
+}
+
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_OPERATOR_RNN_IMPL_H_
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 3de30f2..e07a602 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -137,8 +137,76 @@ def test_gru_bidirectional():
     check_rnn_consistency(fused, stack, T, N, I, H, 'add')
     check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
-# Currently, fused LSTM operator doesn't support dropout.
-# Will change this test after dropout is supported
+@with_seed()
+def test_rnntanh_sym():
+    T, N, I, H = 5, 32, 800, 800
+
+    fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='rnn_tanh', get_next_state=True, prefix='')
+    stack = mx.rnn.SequentialRNNCell()
+    stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l0_'))
+    stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l1_'))
+    stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l2_'))
+
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+
+@with_seed()
+def test_rnntanh_bidirectional():
+    T, N, I, H = 5, 20, 800, 800
+
+    fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='rnn_tanh',
+                                bidirectional=True, get_next_state=True, prefix='')
+    
+    stack = mx.rnn.SequentialRNNCell()
+    stack.add(mx.rnn.BidirectionalCell(
+                mx.rnn.RNNCell(H, activation='tanh', prefix='l0_'),
+                mx.rnn.RNNCell(H, activation='tanh', prefix='r0_'),
+                output_prefix='bi_rnntanh_0_'))    
+    stack.add(mx.rnn.BidirectionalCell(
+                mx.rnn.RNNCell(H, activation='tanh', prefix='l1_'),
+                mx.rnn.RNNCell(H, activation='tanh', prefix='r1_'),
+                output_prefix='bi_rnntanh_1_'))
+    
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+
+@with_seed()
+def test_rnnrelu_sym():
+    T, N, I, H = 5, 32, 200, 200
+
+    fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='rnn_relu', get_next_state=True, prefix='')
+    stack = mx.rnn.SequentialRNNCell()
+    stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l0_'))
+    stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l1_'))
+    stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l2_'))
+
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+
+@with_seed()
+def test_rnnrelu_bidirectional():
+    T, N, I, H = 5, 20, 200, 200
+
+    fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='rnn_relu',
+                                bidirectional=True, get_next_state=True, prefix='')
+    
+    stack = mx.rnn.SequentialRNNCell()
+    stack.add(mx.rnn.BidirectionalCell(
+                mx.rnn.RNNCell(H, activation='relu', prefix='l0_'),
+                mx.rnn.RNNCell(H, activation='relu', prefix='r0_'),
+                output_prefix='bi_rnnrelu_0_'))    
+    stack.add(mx.rnn.BidirectionalCell(
+                mx.rnn.RNNCell(H, activation='relu', prefix='l1_'),
+                mx.rnn.RNNCell(H, activation='relu', prefix='r1_'),
+                output_prefix='bi_rnnrelu_1_'))
+
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+
 @with_seed()
 def test_lstm_dropout():
     X = mx.sym.Variable('x')
@@ -149,12 +217,44 @@ def test_lstm_dropout():
     rnn = mx.sym.RNN(data=X, parameters=Params, state=HX, state_cell=CX,
                      state_size=H, num_layers=5, mode='lstm', p=0.5, state_outputs=True, name='LSTM')
     exe = rnn.simple_bind(ctx=mx.cpu(), x=(T, N, I))
-    try:
-        out = exe.forward(is_train=False)
-        out[0].wait_to_read()
-        assert False  # should not reach here
-    except mx.base.MXNetError as err:
-        assert str(err).find('Dropout is not supported at the moment') != -1
+    out = exe.forward(is_train=True)
+    out[0].wait_to_read()
+
+@with_seed()
+def test_gru_dropout():
+    X = mx.sym.Variable('x')
+    Params = mx.sym.Variable('params')
+    HX = mx.sym.Variable('state')
+    T, N, I, H = 300, 20, 800, 800
+    rnn = mx.sym.RNN(data=X, parameters=Params, state=HX,
+                     state_size=H, num_layers=5, mode='gru', p=0.5, state_outputs=True, name='GRU')
+    exe = rnn.simple_bind(ctx=mx.cpu(), x=(T, N, I))
+    out = exe.forward(is_train=True)
+    out[0].wait_to_read()
+
+@with_seed()
+def test_rnntanh_dropout():
+    X = mx.sym.Variable('x')
+    Params = mx.sym.Variable('params')
+    HX = mx.sym.Variable('state')
+    T, N, I, H = 300, 20, 800, 800
+    rnn = mx.sym.RNN(data=X, parameters=Params, state=HX,
+                     state_size=H, num_layers=5, mode='rnn_tanh', p=0.5, state_outputs=True, name='RNN_TANH')
+    exe = rnn.simple_bind(ctx=mx.cpu(), x=(T, N, I))
+    out = exe.forward(is_train=True)
+    out[0].wait_to_read()
+
+@with_seed()
+def test_rnnrelu_dropout():
+    X = mx.sym.Variable('x')
+    Params = mx.sym.Variable('params')
+    HX = mx.sym.Variable('state')
+    T, N, I, H = 300, 20, 800, 800
+    rnn = mx.sym.RNN(data=X, parameters=Params, state=HX,
+                     state_size=H, num_layers=5, mode='rnn_relu', p=0.5, state_outputs=True, name='RNN_RELU')
+    exe = rnn.simple_bind(ctx=mx.cpu(), x=(T, N, I))
+    out = exe.forward(is_train=True)
+    out[0].wait_to_read()
 
 def np_softmax(x, axis=-1):
     # fix for old numpy on Travis not supporting keepdims