Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/26 17:43:59 UTC

[GitHub] piiswrong closed pull request #11399: [MXNET-107] Add Fused Vanilla RNN and dropout for CPU

URL: https://github.com/apache/incubator-mxnet/pull/11399
 
 
   

This PR was merged from a forked repository. Because GitHub hides the original
diff once a forked pull request is merged, the diff is reproduced below for the
sake of provenance.
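
As a quick orientation before the diff: the new CPU kernels are exercised
through the symbolic RNN operator exactly as in the unit tests added at the end
of this patch. A minimal sketch in Python, with shapes and the dropout
probability taken from those tests (illustrative values, not requirements):

    import mxnet as mx

    # Multi-layer vanilla RNN (tanh) with inter-layer dropout, bound on CPU.
    # Mirrors test_rnntanh_dropout added in tests/python/unittest/test_operator.py.
    X = mx.sym.Variable('x')
    Params = mx.sym.Variable('params')
    HX = mx.sym.Variable('state')
    T, N, I, H = 300, 20, 800, 800   # seq_len, batch, input size, hidden size
    rnn = mx.sym.RNN(data=X, parameters=Params, state=HX, state_size=H,
                     num_layers=5, mode='rnn_tanh', p=0.5, state_outputs=True,
                     name='RNN_TANH')
    exe = rnn.simple_bind(ctx=mx.cpu(), x=(T, N, I))
    out = exe.forward(is_train=True)  # dropout only takes effect in training mode
    out[0].wait_to_read()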

diff --git a/example/rnn/bucketing/cudnn_rnn_bucketing.py b/example/rnn/bucketing/cudnn_rnn_bucketing.py
index 29a66a8f484..5825290e73e 100644
--- a/example/rnn/bucketing/cudnn_rnn_bucketing.py
+++ b/example/rnn/bucketing/cudnn_rnn_bucketing.py
@@ -66,7 +66,7 @@
 parser.add_argument('--dropout', type=float, default='0.0',
                     help='dropout probability (1.0 - keep probability)')
 parser.add_argument('--rnntype', type=str, default='lstm',
-                    help='rnn type: gru and lstm are supported')
+                    help='rnn type: gru, lstm, rnn_tanh and rnn_relu are supported')
 
 #buckets = [32]
 buckets = [10, 20, 30, 40, 50, 60]
@@ -188,6 +188,20 @@ def test(args):
                             cell,
                             mx.rnn.GRUCell(num_hidden=args.num_hidden, prefix='%s_%dr0_'%(args.rnntype,i)),
                             output_prefix='bi_%s_%d'%(args.rnntype,i))
+            elif args.rnntype == 'rnn_tanh':
+                cell = mx.rnn.RNNCell(num_hidden=args.num_hidden, activation='tanh', prefix='%s_%dl0_'%(args.rnntype,i))
+                if args.bidirectional:
+                    cell = mx.rnn.BidirectionalCell(
+                            cell,
+                            mx.rnn.RNNCell(num_hidden=args.num_hidden, activation='tanh', prefix='%s_%dr0_'%(args.rnntype,i)),
+                            output_prefix='bi_%s_%d'%(args.rnntype,i))
+            elif args.rnntype == 'rnn_relu':
+                cell = mx.rnn.RNNCell(num_hidden=args.num_hidden, activation='relu', prefix='%s_%dl0_'%(args.rnntype,i))
+                if args.bidirectional:
+                    cell = mx.rnn.BidirectionalCell(
+                            cell,
+                            mx.rnn.RNNCell(num_hidden=args.num_hidden, activation='relu', prefix='%s_%dr0_'%(args.rnntype,i)),
+                            output_prefix='bi_%s_%d'%(args.rnntype,i))
 
             stack.add(cell)
 
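
The test path above stacks unfused RNNCell/BidirectionalCell objects; the fused
counterpart backed by the new CPU kernels is selected through mode='rnn_tanh'
or 'rnn_relu' on mx.rnn.FusedRNNCell, as in the unit tests added at the end of
this patch. A minimal sketch (hidden size and layer count follow those tests):

    import mxnet as mx

    H = 200
    # Fused bidirectional vanilla RNN (ReLU), two layers.
    fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='rnn_relu',
                                bidirectional=True, get_next_state=True, prefix='')
    # Equivalent unfused construction, one layer shown; the new tests check that
    # fused and unfused stacks produce consistent results.
    layer0 = mx.rnn.BidirectionalCell(
        mx.rnn.RNNCell(H, activation='relu', prefix='l0_'),
        mx.rnn.RNNCell(H, activation='relu', prefix='r0_'),
        output_prefix='bi_rnnrelu_0_')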
diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index 99531739afa..1f905eda4a9 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -99,10 +99,6 @@ inline size_t GetRNNWorkspaceSize(int seq_length,
                                   int mode) {
   size_t size = 0;
   switch (mode) {
-    case rnn_enum::kRnnRelu:
-    case rnn_enum::kRnnTanh:
-      LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
-      break;
     case rnn_enum::kLstm:
       size = (seq_length + 1) * batch_size * hidden_size * 4 + batch_size * hidden_size * 2
              + seq_length * batch_size * hidden_size * direction + hidden_size * seq_length * 8;
@@ -110,6 +106,10 @@ inline size_t GetRNNWorkspaceSize(int seq_length,
     case rnn_enum::kGru:
       size = seq_length * batch_size * hidden_size * direction * 4 + batch_size * hidden_size * 8;
       break;
+    case rnn_enum::kRnnRelu:
+    case rnn_enum::kRnnTanh:
+      size = seq_length * batch_size * hidden_size * direction * 2 + batch_size * hidden_size * 4;
+      break;
     default:
       LOG(FATAL) << "unknown RNN mode " << mode;
       break;
@@ -125,18 +125,20 @@ inline size_t GetRNNReserveSpaceSize(int num_layer,
                                      int mode) {
   size_t size = 0;
   switch (mode) {
-    case rnn_enum::kRnnRelu:
-    case rnn_enum::kRnnTanh:
-      LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
-      break;
     case rnn_enum::kLstm:
-      size = num_layer * direction * seq_length * batch_size * hidden_size * 6;
+      size = direction * seq_length * batch_size * hidden_size * (num_layer * 7 - 1);
       break;
     case rnn_enum::kGru:
-      size = seq_length * batch_size * hidden_size * direction * num_layer * 8 +
+      size = seq_length * batch_size * hidden_size * direction * (num_layer * 9 - 1) +
           batch_size * hidden_size * direction * 9 + hidden_size * seq_length * 6 +
           seq_length * batch_size * 7 * hidden_size * direction;
       break;
+    case rnn_enum::kRnnRelu:
+    case rnn_enum::kRnnTanh:
+      size = seq_length * batch_size * hidden_size * direction * (num_layer * 6 - 1) +
+          batch_size * hidden_size * direction * 3 + hidden_size * seq_length * 2 +
+          seq_length * batch_size * 2 * hidden_size * direction;
+      break;
     default:
       LOG(FATAL) << "unknown RNN mode " << mode;
       break;
@@ -223,21 +225,24 @@ void RNNForwardTraining(DType* ws,
                         DType* y_ptr,
                         DType* hy_ptr,
                         DType* cy_ptr,
+                        const float dropout,
                         int mode) {
   switch (mode) {
-    case rnn_enum::kRnnTanh:
-    case rnn_enum::kRnnRelu:
-      LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
-      break;
     case rnn_enum::kLstm:
       LstmForwardTraining<DType>(ws, rs, state_outputs, num_layers, direction, seq_length,
                                  batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr,
-                                 w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr);
+                                 w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, dropout);
       break;
     case rnn_enum::kGru:
       GruForwardTraining<DType>(ws, rs, state_outputs, num_layers, direction, seq_length,
                                 batch_size, input_size, state_size, x_ptr, hx_ptr,
-                                w_ptr, y_ptr, hy_ptr);
+                                w_ptr, y_ptr, hy_ptr, dropout);
+      break;
+    case rnn_enum::kRnnTanh:
+    case rnn_enum::kRnnRelu:
+      VanillaRNNForwardTraining<DType>(ws, rs, state_outputs, num_layers, direction, seq_length,
+                                       batch_size, input_size, state_size, x_ptr, hx_ptr,
+                                       w_ptr, y_ptr, hy_ptr, dropout, mode);
       break;
     default:
       LOG(FATAL) << "unknown RNN mode " << mode;
@@ -264,10 +269,6 @@ void RNNForwardInference(DType* ws,
                          DType* cy_ptr,
                          int mode) {
   switch (mode) {
-    case rnn_enum::kRnnRelu:
-    case rnn_enum::kRnnTanh:
-      LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
-      break;
     case rnn_enum::kLstm:
       LstmForwardInference<DType>(ws, state_outputs, num_layers, direction, seq_length,
                                   batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr,
@@ -278,6 +279,12 @@ void RNNForwardInference(DType* ws,
                                  batch_size, input_size, state_size, x_ptr, hx_ptr,
                                  w_ptr, y_ptr, hy_ptr);
       break;
+    case rnn_enum::kRnnTanh:
+    case rnn_enum::kRnnRelu:
+      VanillaRNNForwardInference<DType>(ws, state_outputs, num_layers, direction, seq_length,
+                                        batch_size, input_size, state_size, x_ptr, hx_ptr,
+                                        w_ptr, y_ptr, hy_ptr, mode);
+      break;
     default:
       LOG(FATAL) << "unknown RNN mode" << mode;
       break;
@@ -310,22 +317,27 @@ void RNNBackward(DType* ws,
                  int req_params,
                  int req_state,
                  int req_statecell,
+                 const float dropout,
                  int mode) {
   switch (mode) {
-    case rnn_enum::kRnnRelu:
-    case rnn_enum::kRnnTanh:
-      break;
     case rnn_enum::kLstm:
       LstmBackward<DType>(ws, rs, num_layers, direction, seq_length, batch_size,
                           input_size, state_size, x_ptr, hx_ptr, cx_ptr, w_ptr, y_ptr,
                           dy_ptr, dhy_ptr, dcy_ptr, dx_ptr, dhx_ptr, dcx_ptr, dw_ptr, db_ptr,
-                          req_data, req_params, req_state, req_statecell);
+                          req_data, req_params, req_state, req_statecell, dropout);
       break;
     case rnn_enum::kGru:
       GruBackward<DType>(ws, rs, num_layers, direction, seq_length, batch_size,
                          input_size, state_size, x_ptr, hx_ptr, w_ptr,
                          dy_ptr, dhy_ptr, dx_ptr, dhx_ptr, dw_ptr,
-                         req_data, req_params, req_state);
+                         req_data, req_params, req_state, dropout);
+      break;
+    case rnn_enum::kRnnTanh:
+    case rnn_enum::kRnnRelu:
+      VanillaRNNBackward<DType>(ws, rs, num_layers, direction, seq_length, batch_size,
+                                input_size, state_size, x_ptr, hx_ptr, w_ptr,
+                                dy_ptr, dhy_ptr, dx_ptr, dhx_ptr, dw_ptr,
+                                req_data, req_params, req_state, dropout, mode);
       break;
     default:
       LOG(FATAL) << "unknown RNN mode" << mode;
@@ -354,9 +366,8 @@ class RNNOp : public Operator{
                        const std::vector<TBlob> &aux_args) {
     using namespace mshadow;
     using namespace mshadow::expr;
-    CHECK(param_.mode == rnn_enum::kLstm || param_.mode == rnn_enum::kGru)
-        << "Only lstm and gru mode are supported at the moment.";
-    CHECK_EQ(param_.p, 0) << "Dropout is not supported at the moment.";
+    CHECK(param_.p >= 0.0f && param_.p < 1.0f)
+        << "unsupported dropout value, should be 0 <= dropout < 1";
 
     size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
     size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
@@ -436,6 +447,7 @@ class RNNOp : public Operator{
                                 y.dptr_,
                                 hy_ptr,
                                 cy_ptr,
+                                param_.p,
                                 param_.mode);
     } else {
       RNNForwardInference<DType>(workspace.dptr_,
@@ -467,9 +479,8 @@ class RNNOp : public Operator{
                         const std::vector<TBlob> &aux_args) {
     using namespace mshadow;
     using namespace mshadow::expr;
-    CHECK(param_.mode == rnn_enum::kLstm || param_.mode == rnn_enum::kGru)
-        << "Only lstm and gru mode are supported at the moment.";
-    CHECK_EQ(param_.p, 0) << "Dropout is not supported at the moment.";
+    CHECK(param_.p >= 0.0f && param_.p < 1.0f)
+        << "unsupported dropout value, should be 0 <= dropout < 1";
 
     size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
     size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
@@ -566,6 +577,7 @@ class RNNOp : public Operator{
                        req[rnn_enum::kParams],
                        req[rnn_enum::kState],
                        req[rnn_enum::kStateCell],
+                       param_.p,
                        param_.mode);
   }
 
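
For reference, the two size formulas added above for the vanilla RNN modes,
transcribed into Python (element counts, following the surrounding LSTM/GRU
cases; part of the reserve space now also holds the (num_layer - 1)
inter-layer dropout masks written during forward training):

    def vanilla_rnn_workspace_size(seq_length, batch_size, hidden_size, direction):
        # kRnnRelu / kRnnTanh case of GetRNNWorkspaceSize above.
        return (seq_length * batch_size * hidden_size * direction * 2
                + batch_size * hidden_size * 4)

    def vanilla_rnn_reserve_size(num_layer, direction, seq_length, batch_size, hidden_size):
        # kRnnRelu / kRnnTanh case of GetRNNReserveSpaceSize above.
        return (seq_length * batch_size * hidden_size * direction * (num_layer * 6 - 1)
                + batch_size * hidden_size * direction * 3
                + hidden_size * seq_length * 2
                + seq_length * batch_size * 2 * hidden_size * direction)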
diff --git a/src/operator/rnn_impl.h b/src/operator/rnn_impl.h
index fa8d671a200..e1b4a2b79c0 100644
--- a/src/operator/rnn_impl.h
+++ b/src/operator/rnn_impl.h
@@ -49,6 +49,11 @@ inline DType sigmoid(DType x) {
   return 1.0f / (1.0f + exp(-x));
 }
 
+template<typename DType>
+inline DType relu(DType x) {
+  return x > 0.0f ? static_cast<float>(x) : 0.0f;
+}
+
 template<typename DType>
 void LstmForwardTrainingSingleLayer(DType* ws,
                                     DType* rs,
@@ -133,7 +138,10 @@ void LstmForwardTraining(DType* ws,
                          DType* b_ptr,
                          DType* y_ptr,
                          DType* hy_ptr,
-                         DType* cy_ptr) {
+                         DType* cy_ptr,
+                         const float dropout) {
+  DType* dropout_random = rs;
+  DType* rs2 = dropout_random + (L - 1) * D * T * N * H;
   const int total_layers = D * L;
   Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(total_layers, N, H));
   Tensor<cpu, 3, DType> cx(cx_ptr, Shape3(total_layers, N, H));
@@ -141,14 +149,15 @@ void LstmForwardTraining(DType* ws,
   const int r_size = D * T * N * H * 6;
   const int y_offset = T * N * H * 5;
   const int cell_size = N * H;
+  unsigned int seed_ = 17 + rand() % 4096;  // NOLINT(runtime/threadsafe_fn)
   int idx = 0;  // state & cell state's idx;
   const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
   for (int i = 0; i < L; ++i) {
     const int input_size = i ? H * D : I;
     const int w_size = (input_size + H) * H * 4;
     Tensor<cpu, 2, DType> x(x_ptr, Shape2(T * N, input_size));
-    Tensor<cpu, 3, DType> y(rs + y_offset, Shape3(T, N, H * D));
-    LstmForwardTrainingSingleLayer<DType>(ws, rs, state_outputs, false, T, N, input_size, H, x,
+    Tensor<cpu, 3, DType> y(rs2 + y_offset, Shape3(T, N, H * D));
+    LstmForwardTrainingSingleLayer<DType>(ws, rs2, state_outputs, false, T, N, input_size, H, x,
                                           hx[idx], cx[idx], y, w_ptr, b_ptr, hy_ptr, cy_ptr);
     if (D == 2) {
       w_ptr += w_size;
@@ -158,14 +167,27 @@ void LstmForwardTraining(DType* ws,
         hy_ptr += cell_size;
         cy_ptr += cell_size;
       }
-      LstmForwardTrainingSingleLayer<DType>(ws, rs, state_outputs, true, T, N, input_size, H, x,
+      LstmForwardTrainingSingleLayer<DType>(ws, rs2, state_outputs, true, T, N, input_size, H, x,
                                             hx[idx], cx[idx], y, w_ptr, b_ptr, hy_ptr, cy_ptr);
     }
     if (i != L - 1) {
       w_ptr += w_size;
       b_ptr += b_size;
+      if (dropout > 0.0f) {
+        #pragma omp parallel for num_threads(omp_threads)
+        for (int j = 0; j < T * N * H * D; j++) {
+          int rand_data = rand_r(&seed_);
+          if (static_cast<float>(rand_data % 1000) < static_cast<float>(1000 * dropout)) {
+            dropout_random[i * T * N * H * D + j] = 0;
+            y.dptr_[j] = 0;
+          } else {
+            dropout_random[i * T * N * H * D + j] = 1.0f - dropout;
+            y.dptr_[j] =  y.dptr_[j] / (1.0f - dropout);
+          }
+        }
+      }
       x_ptr = y.dptr_;
-      rs += r_size;
+      rs2 += r_size;
       ++idx;
       if (state_outputs) {
         hy_ptr += cell_size;
@@ -175,7 +197,7 @@ void LstmForwardTraining(DType* ws,
   }
   #pragma omp parallel for num_threads(omp_threads)
   for (int i = 0; i < T * N * H * D; ++i) {
-    y_ptr[i] = (rs + y_offset)[i];
+    y_ptr[i] = (rs2 + y_offset)[i];
   }
 }
 
@@ -498,7 +520,10 @@ void LstmBackward(DType* ws,
                   int req_data,
                   int req_params,
                   int req_state,
-                  int req_statecell) {
+                  int req_statecell,
+                  const float dropout) {
+  DType* dropout_random = rs + (L - 1) * D * T * N * H;
+  DType* rs2 = rs + (L - 1) * D * T * N * H;
   DType* tmp_buf = ws;
   DType* ws2 = tmp_buf + 8 * T * H;
   const int total_layers = D * L;
@@ -520,7 +545,7 @@ void LstmBackward(DType* ws,
     DType* w_cur_ptr = i ? w_ptr + (w_size1 + (i - 1) * w_size2) * D : w_ptr;
     DType* dw_cur_ptr = i ? dw_ptr + (w_size1 + (i - 1) * w_size2) * D : dw_ptr;
     DType* db_cur_ptr = db_ptr + i * b_size * D;
-    DType* rs_cur_ptr = rs + i * r_size;
+    DType* rs_cur_ptr = rs2 + i * r_size;
     DType* dhy_cur_ptr = dhy_ptr ? dhy_ptr + i * cell_size * D : NULL;
     DType* dcy_cur_ptr = dcy_ptr ? dcy_ptr + i * cell_size * D : NULL;
     Tensor<cpu, 3, DType> y(rs_cur_ptr + y_offset, Shape3(T, N, H * D));
@@ -543,6 +568,18 @@ void LstmBackward(DType* ws,
                                      dhy_cur_ptr, dcy_cur_ptr, w_cur_ptr, dw_cur_ptr, db_cur_ptr,
                                      req_data, req_params, req_state, req_statecell);
     }
+    if (dropout > 0.0f && i > 0 && req_data != kNullOp) {
+      dropout_random = dropout_random - T * N * D * H;
+      const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int j = 0; j < T * N * D * H; j++) {
+        if (dropout_random[j] == 0) {
+          dx.dptr_[j] = 0;
+        } else {
+          dx.dptr_[j] = dx.dptr_[j] / (1.0f - dropout);
+        }
+      }
+    }
     dy_ptr = dx.dptr_;
   }
 }
@@ -935,7 +972,8 @@ void GruForwardTraining(DType* ws,
                         DType* hx_ptr,
                         DType* w_ptr,
                         DType* y_ptr,
-                        DType* hy_ptr) {
+                        DType* hy_ptr,
+                        const float dropout) {
   DType* wx = w_ptr;
   DType* wh = wx + I * H * 3;
   DType* bx = wh + H * H * 3 + (D - 1) * (H * H * 3 + I * H * 3)
@@ -948,19 +986,34 @@ void GruForwardTraining(DType* ws,
   DType* gateN_l = gateZ_l + L * T * D * N * H;
   DType* y_l = gateN_l + L * T * D * N * H;
   DType* Mnh_l = y_l + L * T * N * H * D;
-  DType* tmp_buf = Mnh_l + L * D * T * N * H;
-  DType* ws2 = Mnh_l + L * D * T * N * H + D * H * N;
+  DType* dropout_random = Mnh_l + L * D * T * N * H;
+  DType* tmp_buf = dropout_random + (L - 1) * D * T * N * H;
+  DType* ws2 = tmp_buf + D * N * H;
   DType* wx_l = wx;
   DType* wh_l = wh;
   DType* bx_l = bx;
   DType* bh_l = bh;
   DType* y_tmp = x_ptr;
-
+  unsigned int seed_ = 17 + rand() % 4096;  // NOLINT(runtime/threadsafe_fn)
   for (int l = 0; l < L; l++) {
     if (l != 0) {
       y_tmp = y_l;
       y_l = y_l + T * N * H * D;
     }
+    if (dropout > 0.0f && l > 0) {
+      const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < T * N * I; i++) {
+        int rand_data = rand_r(&seed_);
+        if (static_cast<float>(rand_data % 1000) < static_cast<float>(1000 * dropout)) {
+          dropout_random[(l - 1) * T * N * I + i] = 0;
+          y_tmp[i] = 0;
+        } else {
+          dropout_random[(l - 1) * T * N * I + i] = 1.0f - dropout;
+          y_tmp[i] =  y_tmp[i] / (1.0f - dropout);
+        }
+      }
+    }
     Tensor<cpu, 2, DType> x_l(y_tmp, Shape2(T * N, I));
     Tensor<cpu, 2, DType> hx_l = hx[D * l];
     GruForwardTrainingSingleLayer<DType>(ws2, tmp_buf, state_outputs, D, T, N, I, H,
@@ -1349,7 +1402,8 @@ void GruBackward(DType* ws,
                  DType* dw_ptr,
                  int req_data,
                  int req_params,
-                 int req_state) {
+                 int req_state,
+                 const float dropout) {
   DType* wx = w_ptr;
   DType* dwx = dw_ptr;
   DType* dwh = dwx + I * H * 3;
@@ -1360,7 +1414,8 @@ void GruBackward(DType* ws,
   DType* gateN_l = gateZ_l + L * T * D * N * H;
   DType* y_l = gateN_l + L * T * D * N * H;
   DType* Mnh_l = y_l + L * T * N * H * D;
-  DType* tmp_buf = Mnh_l + L * D * T * N * H;
+  DType* dropout_random = Mnh_l + L * D * T * N * H;
+  DType* tmp_buf = dropout_random + (L - 1) * D * T * N * H;
   DType* dx_l = tmp_buf + T * N * D * H + 3 * H * T * 2;
   DType* ws2 = dx_l + T * N * D * H;
   DType* wx_l = (L == 1)? wx : wx + (L - 2) * D * (D + 1) * H * 3 * H
@@ -1403,6 +1458,17 @@ void GruBackward(DType* ws,
     GruBackwardSingleLayer<DType>(ws2, tmp_buf, D, T, N, I, H, x_l, hx_l, wx_l, wh_l, y_l, dy_l,
                                   dhy_l, gateR_l, gateZ_l, gateN_l, Mnh_l, dx_l, dhx_l,
                                   dwx_l, dwh_l, dbx_l, dbh_l, req_data, req_params, req_state);
+    if (dropout > 0.0f && l > 0 && req_data != kNullOp) {
+      dropout_random = dropout_random - T * N * D * H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < T * N * I; i++) {
+        if (dropout_random[i] == 0) {
+          dx_l[i] = 0;
+        } else {
+          dx_l[i] = dx_l[i] / (1.0f - dropout);
+        }
+      }
+    }
     if (l > 0) {
       #pragma omp parallel for num_threads(omp_threads)
       for (int i = 0; i < T * N * H * D; ++i) {
@@ -1433,6 +1499,859 @@ void GruBackward(DType* ws,
     }
   }
 }
+
+template<typename DType>
+void VanillaRNNForwardInferenceSingleLayer(DType* ws,
+                                           DType* tmp_buf,
+                                           bool state_outputs,
+                                           const int D,
+                                           const int T,
+                                           const int N,
+                                           const int I,
+                                           const int H,
+                                           const Tensor<cpu, 2, DType> &x,
+                                           const Tensor<cpu, 2, DType> &hx,
+                                           DType* wx_ptr,
+                                           DType* wh_ptr,
+                                           DType* bx_ptr,
+                                           DType* bh_ptr,
+                                           DType* y_ptr,
+                                           DType* hy_ptr,
+                                           int mode) {
+  DType* ht = y_ptr;
+  DType* ht_1 = y_ptr;
+  DType* back_ht_1 = y_ptr + (T-1) * N * H * D + H;
+  DType* back_ht = back_ht_1;
+  DType* gemmC1  = ws;              // [D, T, N, H]
+  DType* gemmC2  = gemmC1 + D * T * N * H;  // N * H
+  DType* back_wx_ptr = wx_ptr + I * H + H * H;
+  DType* back_wh_ptr = wh_ptr + I * H + H * H;
+  DType* back_bx_ptr = (bx_ptr != NULL)? bx_ptr + H * 2 : NULL;
+  DType* back_bh_ptr = (bh_ptr != NULL)? bh_ptr + H * 2: NULL;
+  DType* back_gemmC1 = gemmC1 + T * N * H;
+  DType* gemmC1_t = gemmC1;
+
+  const Tensor<cpu, 2, DType> wx(wx_ptr, Shape2(H, I));
+  const Tensor<cpu, 2, DType> wh(wh_ptr, Shape2(H, H));
+  const Tensor<cpu, 2, DType> bx(bx_ptr, Shape2(1, H));
+  const Tensor<cpu, 2, DType> bh(bh_ptr, Shape2(1, H));
+  const Tensor<cpu, 2, DType> back_wx(back_wx_ptr, Shape2(H, I));
+  const Tensor<cpu, 2, DType> back_wh(back_wh_ptr, Shape2(H, H));
+  const Tensor<cpu, 2, DType> back_bx(back_bx_ptr, Shape2(1, H));
+  const Tensor<cpu, 2, DType> back_bh(back_bh_ptr, Shape2(1, H));
+  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  if (D == 1) {
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; i++)
+      for (int j = 0; j < H; j++) {
+        y_ptr[i * H + j] = hx[i][j];
+      }
+  } else {
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; i++)
+      for (int j = 0; j < H; j++) {
+        y_ptr[i * D * H + j] = hx[i][j];
+        back_ht_1[i * D * H + j] = hx[N + i][j];
+    }
+  }
+  Tensor<cpu, 2, DType> dgemmC1(ws, Shape2(T * N, H));
+  Tensor<cpu, 2, DType> dgemmC2(gemmC2, Shape2(N, H));
+  Tensor<cpu, 2, DType> dback_gemmC1(back_gemmC1, Shape2(T * N, H));
+
+  // x * wx.T : [T * N, I] * [I, H]
+  DType alpha = 1.0;
+  DType beta = 0.0;
+  linalg_gemm(x, wx, dgemmC1, alpha, beta, false, true);
+  if (D == 2) {
+    linalg_gemm(x, back_wx, dback_gemmC1, alpha, beta, false, true);
+  }
+
+  for (int t = 0; t < T; t++) {
+    //  perform the first direction, X * wx and H * wh for each step
+    //  ht-1 * wh, ht-1:[N, H] wh:[H, H]
+    Tensor<cpu, 2, DType> dht_1(ht_1, Shape2(N, D * H));
+    if (D == 1) {
+      linalg_gemm(dht_1, wh, dgemmC2, alpha, beta, false, true);
+    } else {
+      Tensor<cpu, 3, DType> dht_1_tmp = Tensor<cpu, 3, DType>(reinterpret_cast<DType*>(tmp_buf),
+                                     Shape3(D, H, N));
+      dht_1_tmp = reshape(dht_1.T(), Shape3(D, H, N));
+      linalg_gemm(dht_1_tmp[0], wh, dgemmC2, alpha, beta, true, true);
+    }
+    gemmC1_t = gemmC1 + t * N * H;
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; ++i) {
+      for (int j = 0; j < H; ++j) {
+        int tb = i * H;
+        if (mode == 1) {
+          ht[i * D * H + j] = tanh(gemmC1_t[tb + j] + bx[0][j] +
+              gemmC2[tb + j] + bh[0][j]);
+        } else {
+          ht[i * D * H + j] = relu(gemmC1_t[tb + j] + bx[0][j] +
+              gemmC2[tb + j] + bh[0][j]);
+        }
+      }
+    }
+    ht_1 = ht;
+    ht = ht + D * H * N;
+    //  perform the second direction
+    if (D == 2) {
+      gemmC1_t = back_gemmC1 + (T - 1 - t) * N * H;
+      Tensor<cpu, 2, DType> dback_ht_1(back_ht_1 - H, Shape2(N, D * H));
+      Tensor<cpu, 3, DType> dback_ht_1_tmp = Tensor<cpu, 3, DType>
+          (reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N));
+      dback_ht_1_tmp = reshape(dback_ht_1.T(), Shape3(D, H, N));
+      linalg_gemm(dback_ht_1_tmp[1], back_wh, dgemmC2, alpha, beta, true, true);
+
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; ++i) {
+        for (int j = 0; j < H; ++j) {
+          int tb = i * H;
+          if (mode == 1) {
+            back_ht[i * D * H + j] = tanh(gemmC1_t[tb + j] + back_bx[0][j]
+                + gemmC2[tb + j] + back_bh[0][j]);
+          } else {
+            back_ht[i * D * H + j] = relu(gemmC1_t[tb + j] + back_bx[0][j]
+              + gemmC2[tb + j] + back_bh[0][j]);
+          }
+        }
+      }
+      back_ht_1 = back_ht;
+      back_ht = back_ht - D * H * N;
+    }
+  }
+  //  copy last state to hy, from(N, H * D) to (D, N, H)
+  if (state_outputs) {
+    if (D == 1) {
+      DType* y_start = y_ptr + (T - 1) * N * H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; i++)
+        for (int j = 0; j < H; j++) {
+          hy_ptr[i * H + j] = y_start[i * H + j];
+        }
+    } else {
+      DType* y_start = y_ptr + (T - 1) * N * H * D;
+      DType* y_back_start = y_ptr + H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; i++)
+        for (int j = 0; j < H; j++) {
+          hy_ptr[i * H + j] = y_start[i * D * H + j];
+          hy_ptr[N * H + i * H + j] = y_back_start[i * D * H + j];
+        }
+    }
+  }
+}
+
+template <typename DType>
+void VanillaRNNForwardInference(DType* ws,
+                                bool state_outputs,
+                                const int L,
+                                const int D,
+                                const int T,
+                                const int N,
+                                int I,
+                                const int H,
+                                DType* x_ptr,
+                                DType* hx_ptr,
+                                DType* w_ptr,
+                                DType* y_ptr,
+                                DType* hy_ptr,
+                                int mode) {
+  DType* wx = w_ptr;
+  DType* wh = wx + I * H;
+  DType* bx = wh + H * H + (D - 1) * (H * H + I * H)
+      + (L - 1) * ((D + 1) * H) * H * D;
+  DType* bh = bx + H;
+
+  DType* y_tmp = ws;
+  DType* y_l = x_ptr;
+  DType* tmp_buf = y_tmp + D * T * N * H;
+  DType* ws2 = y_tmp + D * T * N * H + D * H * N;
+
+  DType* wx_l = wx;
+  DType* wh_l = wh;
+  DType* bx_l = bx;
+  DType* bh_l = bh;
+  Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(D * L, N, H));
+  DType* hy_l = hy_ptr;
+  for (int l = 0; l < L; l++) {
+    Tensor<cpu, 2, DType> x_l(y_l, Shape2(T * N, I));
+    if ((L + l) % 2) {
+      y_l = y_ptr;
+    } else {
+      y_l = y_tmp;
+    }
+    Tensor<cpu, 2, DType> hx_l = hx[D * l];
+    VanillaRNNForwardInferenceSingleLayer<DType>(ws2, tmp_buf, state_outputs, D, T, N, I, H,
+                                                 x_l, hx_l, wx_l, wh_l, bx_l, bh_l, y_l,
+                                                 hy_l, mode);
+    hy_l = hy_l + D * N * H;
+    bx_l = bx_l + H * D * 2;
+    bh_l = bh_l + H * D * 2;
+    wx_l = wx_l + I * H * D + H * H * D;
+    if (l == 0) {
+      I = D * H;
+    }
+    wh_l = wx_l + I * H;
+  }
+}
+
+
+template<typename DType>
+void VanillaRNNForwardTrainingSingleLayer(DType* ws,
+                                       DType* tmp_buf,
+                                       bool state_outputs,
+                                       const int D,
+                                       const int T,
+                                       const int N,
+                                       const int I,
+                                       const int H,
+                                       const Tensor<cpu, 2, DType> &x,
+                                       const Tensor<cpu, 2, DType> &hx,
+                                       DType* wx_ptr,
+                                       DType* wh_ptr,
+                                       DType* bx_ptr,
+                                       DType* bh_ptr,
+                                       DType* gateN,
+                                       DType* y_ptr,
+                                       DType* hy_ptr,
+                                       int mode) {
+  DType* ht = y_ptr;
+  DType* ht_1 = y_ptr;
+  DType* back_ht_1 = y_ptr + (T - 1)* N * H * D + H;
+  DType* back_ht = back_ht_1;
+
+  DType* gemmC1  = ws;              // [D, T, N, H]
+  DType* gemmC2  = gemmC1 + D * T * N * H;  // N * H
+  DType* nt = gateN;
+  DType* back_wx_ptr = wx_ptr + I * H + H * H;
+  DType* back_wh_ptr = wh_ptr + I * H + H * H;
+  DType* back_bx_ptr = (bx_ptr != NULL)? bx_ptr + H * 2 : NULL;
+  DType* back_bh_ptr = (bh_ptr != NULL)? bh_ptr + H * 2 : NULL;
+  DType* back_gateN = gateN + T * N * H;
+  DType* back_gemmC1 = gemmC1 + T * N * H;
+  DType* gemmC1_t = gemmC1;
+
+  const Tensor<cpu, 2, DType> wx(wx_ptr, Shape2(H, I));
+  const Tensor<cpu, 2, DType> wh(wh_ptr, Shape2(H, H));
+  const Tensor<cpu, 2, DType> bx(bx_ptr, Shape2(1, H));
+  const Tensor<cpu, 2, DType> bh(bh_ptr, Shape2(1, H));
+  const Tensor<cpu, 2, DType> back_wx(back_wx_ptr, Shape2(H * 1, I));
+  const Tensor<cpu, 2, DType> back_wh(back_wh_ptr, Shape2(H * 1, H));
+  const Tensor<cpu, 2, DType> back_bx(back_bx_ptr, Shape2(1, H));
+  const Tensor<cpu, 2, DType> back_bh(back_bh_ptr, Shape2(1, H));
+  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  if (D == 1) {
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; i++)
+      for (int j = 0; j < H; j++) {
+        y_ptr[i * H + j] = hx[i][j];
+      }
+  } else {
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; i++)
+      for (int j = 0; j < H; j++) {
+        y_ptr[i * D * H + j] = hx[i][j];
+        back_ht_1[i * D * H + j] = hx[N + i][j];
+    }
+  }
+
+  Tensor<cpu, 2, DType> dgemmC1(ws, Shape2(T * N, H));
+  Tensor<cpu, 2, DType> dgemmC2(gemmC2, Shape2(N, H));
+  Tensor<cpu, 2, DType> dback_gemmC1(back_gemmC1, Shape2(T * N, H));
+
+  // x * wx.T : [T * N, I] * [I, H]
+  DType alpha = 1.0;
+  DType beta = 0.0;
+  linalg_gemm(x, wx, dgemmC1, alpha, beta, false, true);
+  if (D == 2) {
+    linalg_gemm(x, back_wx, dback_gemmC1, alpha, beta, false, true);
+  }
+
+  for (int t = 0; t < T; t++) {
+    //  perform the first direction, X * wx and H * wh for each step
+    //  ht-1 * wh, ht-1:[N, H] wh:[H, H]
+    Tensor<cpu, 2, DType> dht_1(ht_1, Shape2(N, D * H));
+    if (D == 1) {
+      linalg_gemm(dht_1, wh, dgemmC2, alpha, beta, false, true);
+    } else {
+      Tensor<cpu, 3, DType> dht_1_tmp = Tensor<cpu, 3, DType>(reinterpret_cast<DType*>(tmp_buf),
+                                     Shape3(D, H, N));
+      dht_1_tmp = reshape(dht_1.T(), Shape3(D, H, N));
+      linalg_gemm(dht_1_tmp[0], wh, dgemmC2, alpha, beta, true, true);
+    }
+    nt = gateN + t * N * H;
+    gemmC1_t = gemmC1 + t * N * H;
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; ++i) {
+      for (int j = 0; j < H; ++j) {
+        int tb = i * H;
+        if (mode == 1) {
+          nt[tb + j] = ht[i * D * H + j] = tanh(gemmC1_t[tb + j] + bx[0][j] +
+              gemmC2[tb + j] + bh[0][j]);
+        } else {
+          nt[tb + j] = gemmC1_t[tb + j] + bx[0][j] + gemmC2[tb + j] + bh[0][j];
+          ht[i * D * H + j] = relu(nt[tb + j]);
+        }
+      }
+    }
+    ht_1 = ht;
+    ht = ht + D * H * N;
+    //  perform the second direction
+    if (D == 2) {
+      nt = back_gateN + (T - 1 - t) * N * H;
+      gemmC1_t = back_gemmC1 + (T - 1 - t) * N * H;
+      Tensor<cpu, 2, DType> dback_ht_1(back_ht_1 - H, Shape2(N, D * H));
+      Tensor<cpu, 3, DType> dback_ht_1_tmp = Tensor<cpu, 3, DType>
+          (reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N));
+      dback_ht_1_tmp = reshape(dback_ht_1.T(), Shape3(D, H, N));
+      linalg_gemm(dback_ht_1_tmp[1], back_wh, dgemmC2, alpha, beta, true, true);
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; ++i) {
+        for (int j = 0; j < H; ++j) {
+          int tb = i * H;
+          if (mode == 1) {
+            nt[tb + j] = back_ht[i * D * H + j] = tanh(gemmC1_t[tb + j] + back_bx[0][j]
+                + gemmC2[tb + j] + back_bh[0][j]);
+          } else {
+            nt[tb + j] = gemmC1_t[tb + j] + back_bx[0][j] + gemmC2[tb + j] + back_bh[0][j];
+            back_ht[i * D * H + j] = relu(nt[tb + j]);
+          }
+        }
+      }
+      back_ht_1 = back_ht;
+      back_ht = back_ht - D * H * N;
+    }
+  }
+
+  //  copy last state to hy, from(N, H * D) to (D, N, H)
+  if (state_outputs) {
+    if (D == 1) {
+      DType* y_start = y_ptr + (T - 1) * N * H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; i++)
+        for (int j = 0; j < H; j++) {
+          hy_ptr[i * H + j] = y_start[i * H + j];
+        }
+    } else {
+      DType* y_start = y_ptr + (T - 1) * N * H * D;
+      DType* y_back_start = y_ptr + H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; i++)
+        for (int j = 0; j < H; j++) {
+          hy_ptr[i * H + j] = y_start[i * D * H + j];
+          hy_ptr[N * H + i * H + j] = y_back_start[i * D * H + j];
+        }
+    }
+  }
+}
+
+template <typename DType>
+void VanillaRNNForwardTraining(DType* ws,
+                               DType* rs,
+                               bool state_outputs,
+                               const int L,
+                               const int D,
+                               const int T,
+                               const int N,
+                               int I,
+                               const int H,
+                               DType* x_ptr,
+                               DType* hx_ptr,
+                               DType* w_ptr,
+                               DType* y_ptr,
+                               DType* hy_ptr,
+                               const float dropout,
+                               int mode) {
+  DType* wx = w_ptr;
+  DType* wh = wx + I * H;
+  DType* bx = wh + H * H + (D - 1) * (H * H + I * H)
+      + (L - 1) * ((D + 1) * H) * H * D;
+  DType* bh = bx + H;
+  Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(D * L, N, H));
+  DType* hy_l = hy_ptr;
+  DType* gateN_l = rs;
+  DType* y_l = gateN_l + L * T * D * N * H;
+  DType* dropout_random = y_l + L * D * T * N * H;
+  DType* tmp_buf = dropout_random + (L - 1) * D * T * N * H;
+  DType* ws2 = tmp_buf + D * N * H;
+  DType* wx_l = wx;
+  DType* wh_l = wh;
+  DType* bx_l = bx;
+  DType* bh_l = bh;
+  DType* y_tmp = x_ptr;
+  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  unsigned int seed_ = 17 + rand() % 4096;  // NOLINT(runtime/threadsafe_fn)
+  for (int l = 0; l < L; l++) {
+    if (l != 0) {
+      y_tmp = y_l;
+      y_l = y_l + T * N * H * D;
+    }
+    if (dropout > 0.0f && l > 0) {
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < T * N * I; i++) {
+        int rand_data = rand_r(&seed_);
+        if (static_cast<float>(rand_data % 1000) < static_cast<float>(1000 * dropout)) {
+          dropout_random[(l - 1) * T * N * I + i] = 0;
+          y_tmp[i] = 0;
+        } else {
+          dropout_random[(l - 1) * T * N * I + i] = 1.0f - dropout;
+          y_tmp[i] =  y_tmp[i] / (1.0f - dropout);
+        }
+      }
+    }
+    Tensor<cpu, 2, DType> x_l(y_tmp, Shape2(T * N, I));
+    Tensor<cpu, 2, DType> hx_l = hx[D * l];
+    VanillaRNNForwardTrainingSingleLayer<DType>(ws2, tmp_buf, state_outputs, D, T, N, I, H,
+                                             x_l, hx_l, wx_l, wh_l, bx_l, bh_l,
+                                             gateN_l, y_l, hy_l, mode);
+    gateN_l = gateN_l +  T * D * N * H;
+    hy_l = hy_l + D * N * H;
+    bx_l = bx_l + H * D * 2;
+    bh_l = bh_l + H * D * 2;
+
+    wx_l = wx_l + I * H * D + H * H * D;
+    if (l == 0) {
+      I = D * H;
+    }
+    wh_l = wx_l + I * H;
+  }
+  #pragma omp parallel for num_threads(omp_threads)
+  for (int i = 0; i < T * N * H * D; ++i) {
+    y_ptr[i] = y_l[i];
+  }
+}
+
+template <typename DType>
+void VanillaRNNBackwardSingleLayer(DType* ws,
+                                   DType* tmp_buf,
+                                   const int D,
+                                   const int T,
+                                   const int N,
+                                   const int I,
+                                   const int H,
+                                   const Tensor<cpu, 2, DType> &x,
+                                   const Tensor<cpu, 2, DType> &hx,
+                                   DType* wx_ptr,
+                                   DType* wh_ptr,
+                                   DType* y_ptr,
+                                   DType* dy_ptr,
+                                   DType* dhy_ptr,
+                                   DType* gateN,
+                                   DType* dx,
+                                   DType* dhx,
+                                   DType* dwx,
+                                   DType* dwh,
+                                   DType* dbx,
+                                   DType* dbh,
+                                   int req_data,
+                                   int req_params,
+                                   int req_state,
+                                   int mode) {
+  DType* dyt;
+  DType* ht1;  // [N, D, H]
+  DType* dart;
+  DType* nt;
+  DType* dar = ws;  // [T, N, H]
+  DType* dht1 = dar + T * N * H;  // [D, N, H]
+  DType* hx_ = dht1 + D * N * H;  // [N, D, H]
+
+  DType* back_ht1;
+  DType* back_dht1 = dht1 + N * H;  // [N, H]
+  DType* back_gateN = gateN + T * N * H;
+  DType* back_wx_ptr = wx_ptr + I * H + H * H;
+  DType* back_wh_ptr = wh_ptr + I * H + H * H;
+  DType* back_dwx = dwx + I * H + H * H;
+  DType* back_dwh = dwh + I * H + H * H;
+  DType* back_dbx = dbx + H * 2;
+  DType* back_dbh = dbh + H * 2;
+
+  DType alpha = 1.0;
+  DType beta = 0.0;
+  const Tensor<cpu, 2, DType> wx(wx_ptr, Shape2(H, I));
+  const Tensor<cpu, 2, DType> wh(wh_ptr, Shape2(H, H));
+  const Tensor<cpu, 2, DType> back_wx(back_wx_ptr, Shape2(H, I));
+  const Tensor<cpu, 2, DType> back_wh(back_wh_ptr, Shape2(H, H));
+  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  if (req_params != kNullOp && req_params != kAddTo) {
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < D * H * H; ++i) {
+      dwh[i] = 0;
+    }
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < D * H; ++i) {
+      dbx[i] = 0;
+      dbh[i] = 0;
+    }
+  }
+
+  #pragma omp parallel for num_threads(omp_threads)
+  for (int i = 0; i < N * H; ++i) {
+    if (dhy_ptr) {
+      dht1[i] = dhy_ptr[i];
+    } else {
+      dht1[i] = 0;
+    }
+  }
+
+  #pragma omp parallel for num_threads(omp_threads)
+  for (int i = 0; i < N; ++i) {
+    for (int j = 0; j < H; ++j) {
+      hx_[i * D * H + j] = hx[i][j];
+    }
+  }
+
+  if (D == 2) {
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N * H; ++i) {
+      if (dhy_ptr) {
+        back_dht1[i] = dhy_ptr[N * H + i];
+      } else {
+        back_dht1[i] = 0;
+      }
+    }
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; ++i) {
+      for (int j = 0; j < H; ++j) {
+        hx_[i * D * H + H + j] = hx[N + i][j];
+      }
+    }
+  }
+  for (int t = T - 1; t >= 0; --t) {
+    if (t) {
+      ht1 = y_ptr + (t - 1) * N * D * H;
+    } else {
+      ht1 = hx_;
+    }
+    // add dy[T, N, D, H] to dhy[D, N, H]
+    dyt = dy_ptr + t * N * D * H;
+
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; ++i) {
+      for (int j = 0; j < H; ++j) {
+        dht1[i * H + j] += dyt[i * D * H + j];
+      }
+    }
+
+    nt = gateN + t * N * H;
+    dart = dar + t * N * H;
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; ++i) {
+      for (int j = 0; j < H; ++j) {
+        int id = i * H + j;
+        if (mode == 1) {
+          dart[id] = dht1[id] * (1 - nt[id] * nt[id]);
+        } else {
+          dart[id] = nt[id] > 0.0f ? static_cast<float>(dht1[id]) : 0.0f;
+        }
+        dht1[id] = 0;
+      }
+    }
+    if (req_params != kNullOp) {
+      alpha = 1.0;
+      beta = 1.0;
+      // dht1 = dart * wh    [N, H] = [N, H] * [H, H]
+      Tensor<cpu, 2, DType> d_dht1(dht1, Shape2(N, H));
+      Tensor<cpu, 2, DType> d_dart(dart, Shape2(N, H));
+      linalg_gemm(d_dart, wh, d_dht1, alpha, beta, false, false);
+
+      if (req_params == kAddTo) {
+        beta = 2.0;
+        // dwx = da.T * x    [H, I] = [H, N] * [N, I] for AddTo
+        Tensor<cpu, 2, DType> d_xt(x.dptr_ + t * N * I, Shape2(N, I));
+        Tensor<cpu, 2, DType> d_dwx(dwx, Shape2(H, I));
+        linalg_gemm(d_dart, d_xt, d_dwx, alpha, beta, true, false);
+      }
+      // dwh = dart.T * ht1    [H, H] = [H, N] * [N, H]
+      Tensor<cpu, 2, DType> d_ht1(ht1, Shape2(N, D * H));
+      Tensor<cpu, 2, DType> d_dwh(dwh, Shape2(H, H));
+      Tensor<cpu, 3, DType> d_ht1_tmp = Tensor<cpu, 3, DType>
+          (reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N));
+      d_ht1_tmp = reshape(d_ht1.T(), Shape3(D, H, N));
+      linalg_gemm(d_dart, d_ht1_tmp[0], d_dwh, alpha, beta, true, true);
+    }
+  }
+
+  if (req_params != kNullOp) {
+    // dbx = e * da       [1, H] = [1, N] * [N, H]
+    if (req_params != kAddTo) {
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < H; ++i) {
+        for (int j = 0; j < N * T; ++j) {
+          dbx[i] += dar[j * H + i];
+          dbh[i] = dbx[i];
+        }
+      }
+    } else {
+      const Tensor<cpu, 2, DType> tmp_dbx(tmp_buf + T * N * D * H, Shape2(H, T));
+      const Tensor<cpu, 2, DType> tmp_dbh(tmp_buf + T * N * D * H + H * T, Shape2(H, T));
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < H * T; ++i) {
+        tmp_dbx.dptr_[i] = 0;
+        tmp_dbh.dptr_[i] = 0;
+      }
+
+      for (int t = T - 1; t >= 0; --t) {
+        #pragma omp parallel for num_threads(omp_threads)
+        for (int i = 0; i < H; ++i) {
+          for (int j = 0; j < N; ++j) {
+            tmp_dbx[i][t] += dar[t * N * H + j * H + i];
+            tmp_dbh[i][t] = tmp_dbx[i][t];
+          }
+        }
+        #pragma omp parallel for num_threads(omp_threads)
+        for (int i = 0; i < H; ++i) {
+          dbx[i] += tmp_dbx[i][t] + dbx[i];
+          dbh[i] = dbx[i];
+        }
+      }
+    }
+  }
+  alpha = 1.0;
+  beta = 0.0;
+
+  // dx = da * wx    [T * N, I] = [T * N, H] * [H, I]
+  Tensor<cpu, 2, DType> d_dar(dar, Shape2(T * N, H));
+  if (req_data != kNullOp) {
+    Tensor<cpu, 2, DType> d_dx(dx, Shape2(T * N, I));
+    linalg_gemm(d_dar, wx, d_dx, alpha, beta, false, false);
+  }
+
+  // dwx = da.T * x    [H, I] = [H, T * N] * [T * N, I]
+  if (req_params != kNullOp && req_params != kAddTo) {
+    Tensor<cpu, 2, DType> d_dwx(dwx, Shape2(H, I));
+    linalg_gemm(d_dar, x, d_dwx, alpha, beta, true, false);
+  }
+
+  if (D == 2) {
+    for (int t = 0; t < T; ++t) {
+      if (t == T-1) {
+        back_ht1 = hx_;
+      } else {
+        back_ht1 = y_ptr + (t + 1) * N * D * H;
+      }
+
+      //  add dy[T, N, D, H] to dhy[D, N, H]
+      dyt = dy_ptr + t * N * D * H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; ++i) {
+        for (int j = 0; j < H; ++j) {
+          back_dht1[i * H + j] += dyt[i * D * H + H + j];
+        }
+      }
+
+      nt = back_gateN + t * N * H;
+      dart = dar + t * N * H;
+
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; ++i) {
+        for (int j = 0; j < H; ++j) {
+          int id = i * H + j;
+          if (mode == 1) {
+            dart[id] = back_dht1[id] * (1 - nt[id] * nt[id]);
+          } else {
+            dart[id] = nt[id] > 0.0f ? static_cast<float>(back_dht1[id]) : 0.0f;
+          }
+          back_dht1[id] = 0;
+        }
+      }
+
+      if (req_params != kNullOp) {
+        alpha = 1.0;
+        beta = 1.0;
+        // dht1 = da * wh    [N, H] = [N, H] * [H, H]
+        Tensor<cpu, 2, DType> d_dart(dart, Shape2(N, H));
+        Tensor<cpu, 2, DType> d_back_dht1(back_dht1, Shape2(N, H));
+        linalg_gemm(d_dart, back_wh, d_back_dht1, alpha, beta, false, false);
+
+        // dwh = da.T * ht1     [H, H] = [H, N] * [N, H]
+        Tensor<cpu, 2, DType> d_back_dwh(back_dwh, Shape2(H, H));
+        Tensor<cpu, 2, DType> d_back_ht1(back_ht1 + H, Shape2(N, D * H));
+        Tensor<cpu, 3, DType> d_back_ht1_tmp = Tensor<cpu, 3, DType>
+            (reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N));
+        d_back_ht1_tmp = reshape(d_back_ht1.T(), Shape3(D, H, N));
+        if (req_params == kAddTo) {
+          beta = 2.0;
+          // dwx = da.T * x    [ H, I] = [H, N] * [N, I] for AddTo
+          Tensor<cpu, 2, DType> d_xt(x.dptr_ + t * N * I, Shape2(N, I));
+          Tensor<cpu, 2, DType> d_back_dwx(back_dwx, Shape2(H, I));
+          linalg_gemm(d_dart, d_xt, d_back_dwx, alpha, beta, true, false);
+        }
+        linalg_gemm(d_dart, d_back_ht1_tmp[0], d_back_dwh, alpha, beta, true, true);
+      }
+    }
+
+    if (req_params != kNullOp) {
+    // dbx = e * da       [1, H] = [1, N] * [N, H]
+      if (req_params != kAddTo) {
+        #pragma omp parallel for num_threads(omp_threads)
+        for (int i = 0; i < H; ++i) {
+          for (int j = 0; j < N * T; ++j) {
+            back_dbx[i] += dar[j * H + i];
+            back_dbh[i] = back_dbx[i];
+          }
+        }
+      } else {
+        const Tensor<cpu, 2, DType> tmp_dbx(tmp_buf + T * N * D * H, Shape2(H, T));
+        const Tensor<cpu, 2, DType> tmp_dbh(tmp_buf + T * N * D * H + H * T, Shape2(H, T));
+        #pragma omp parallel for num_threads(omp_threads)
+        for (int i = 0; i < H * T; ++i) {
+          tmp_dbx.dptr_[i] = 0;
+          tmp_dbh.dptr_[i] = 0;
+        }
+
+        for (int t = T - 1; t >= 0; --t) {
+          #pragma omp parallel for num_threads(omp_threads)
+          for (int i = 0; i < H; ++i) {
+            for (int j = 0; j < N; ++j) {
+              tmp_dbx[i][t] += dar[t * N * H + j * H + i];
+              tmp_dbh[i][t] = tmp_dbx[i][t];
+            }
+          }
+          #pragma omp parallel for num_threads(omp_threads)
+          for (int i = 0; i < H; ++i) {
+            back_dbx[i] += tmp_dbx[i][t] + back_dbx[i];
+            back_dbh[i] = back_dbx[i];
+          }
+        }
+      }
+    }
+    alpha = 1.0;
+    beta = 1.0;
+    // dxt = da * wx    [T * N, I] = [T * N, H] * [H, I]
+     Tensor<cpu, 2, DType> d_dar2(dar, Shape2(T * N, H));
+    if (req_data != kNullOp) {
+      Tensor<cpu, 2, DType> d_dx(dx, Shape2(T * N, I));
+      linalg_gemm(d_dar2, back_wx, d_dx, alpha, beta, false, false);
+    }
+    alpha = 1.0;
+    beta = 0.0;
+    // dwx = da.T * x    [H, I] = [H, T * N] * [T * N, I]
+    if (req_params != kNullOp && req_params != kAddTo) {
+      Tensor<cpu, 2, DType> d_back_dwx(back_dwx, Shape2(H, I));
+      linalg_gemm(d_dar2, x, d_back_dwx, alpha, beta, true, false);
+    }
+  }
+  if (req_state != kNullOp) {
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N * H * D; ++i) {
+      dhx[i] = dht1[i];
+    }
+  }
+}
+
+template <typename DType>
+void VanillaRNNBackward(DType* ws,
+                        DType* rs,
+                        const int L,
+                        const int D,
+                        const int T,
+                        const int N,
+                        int I,
+                        const int H,
+                        DType* x_ptr,
+                        DType* hx_ptr,
+                        DType* w_ptr,
+                        DType* dy_ptr,
+                        DType* dhy_ptr,
+                        DType* dx_ptr,
+                        DType* dhx_ptr,
+                        DType* dw_ptr,
+                        int req_data,
+                        int req_params,
+                        int req_state,
+                        const float dropout,
+                        int mode) {
+  DType* wx = w_ptr;
+  DType* dwx = dw_ptr;
+  DType* dwh = dwx + I * H;
+  DType* dbx = dwh + H * H + (D - 1) * (H * H + I * H)
+      + (L - 1) * ((D + 1) * H) * H * D;
+  DType* gateN_l = rs + (L - 1) * T * D * N * H;
+  DType* y_l = gateN_l + L * T * D * N * H;
+  DType* dropout_random = y_l + L * D * T * N * H;
+  DType* tmp_buf = dropout_random + (L - 1) * D * T * N * H;
+  DType* dx_l = tmp_buf + T * N * D * H + H * T * 2;
+  DType* ws2 = dx_l + T * N * D * H;
+  DType* wx_l = (L == 1)? wx : wx + (L - 2) * D * (D + 1) * H * H
+      + D * I * H + D * H * H;
+  DType* wh_l = wx_l;
+  if (L == 1) {
+    wh_l = wh_l + I * H;
+  } else {
+    wh_l = wh_l + (D * H) * H;
+  }
+  DType* dhy_l = NULL;
+  if (dhy_ptr)
+    dhy_l = dhy_ptr + (L - 1) * D * N * H;
+  DType* dwx_l = (L == 1)? dwx : dwx + (L - 2) * D * (D + 1) * H * H
+      + D * I * H + D * H * H;
+  DType* dwh_l = NULL;
+  if (L == 1) {
+    dwh_l = dwx_l + I * H;
+  } else {
+    dwh_l = dwx_l + (D * H) * H;
+  }
+  DType* dbx_l = dbx + (L - 1) * D * H * 2;
+  DType* dbh_l = dbx_l + H;
+  DType* dhx_l = dhx_ptr + (L - 1) * D * N * H;
+  DType* dy_l = dy_ptr;
+  Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(L, D * N, H));
+  int inputsize = I;
+  DType* y_tmp = y_l - T * N * H * D;
+  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  for (int l = L - 1; l >= 0; --l) {
+    if (l == 0) {
+      I = inputsize;
+      y_tmp = x_ptr;
+      dx_l = dx_ptr;
+    } else {
+      I = D * H;
+    }
+    Tensor<cpu, 2, DType> hx_l = hx[l];
+    Tensor<cpu, 2, DType> x_l(y_tmp, Shape2(T * N, I));
+    VanillaRNNBackwardSingleLayer<DType>(ws2, tmp_buf, D, T, N, I, H, x_l, hx_l, wx_l, wh_l,
+                                         y_l, dy_l, dhy_l, gateN_l, dx_l, dhx_l, dwx_l, dwh_l,
+                                         dbx_l, dbh_l, req_data, req_params, req_state, mode);
+    if (dropout > 0.0f && l > 0 && req_data != kNullOp) {
+      dropout_random = dropout_random - T * N * D * H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < T * N * I; i++) {
+        if (dropout_random[i] == 0) {
+          dx_l[i] = 0;
+        } else {
+          dx_l[i] = dx_l[i] / (1.0f - dropout);
+        }
+      }
+    }
+    if (l > 0) {
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < T * N * H * D; ++i) {
+        dy_l[i] = dx_l[i];
+      }
+      gateN_l = gateN_l -  T * D * N * H;
+      dhx_l = dhx_l - D * N * H;
+      if (dhy_l)
+        dhy_l = dhy_l - D * N * H;
+      y_l = y_l - T * N * H * D;
+      y_tmp = y_l;
+      if (l == 1) {
+        wx_l = wx_l - (inputsize + H) * H * D;
+        wh_l = wx_l + inputsize * H;
+        dwx_l = dwx_l - (inputsize + H) * H * D;
+        dwh_l = dwx_l + inputsize * H;
+      } else {
+        wx_l = wx_l - (I + H) * H * D;
+        wh_l = wx_l + I * H;
+        dwx_l = dwx_l - (I + H) * H * D;
+        dwh_l = dwx_l + I * H;
+      }
+      dbx_l = dbx_l - D * H * 2;
+      dbh_l = dbx_l + H;
+    }
+  }
+}
+
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_OPERATOR_RNN_IMPL_H_
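
The inter-layer dropout added throughout rnn_impl.h is the usual inverted
dropout: between layers, each element of the previous layer's output is zeroed
with probability p and the survivors are scaled by 1/(1 - p); the mask (stored
as 0 or 1 - p) is kept in the reserve space so the backward pass can zero and
rescale the matching input gradients. A NumPy sketch of the scheme (the
operator itself draws its mask with rand_r inside the OpenMP loops; the sketch
uses NumPy's RNG instead):

    import numpy as np

    def inverted_dropout_forward(y, p, rng=np.random):
        # Zero each element with probability p, scale survivors by 1/(1 - p),
        # and return the mask so backward can apply the same pattern.
        mask = (rng.uniform(size=y.shape) >= p).astype(y.dtype)
        return y * mask / (1.0 - p), mask

    def inverted_dropout_backward(dy, mask, p):
        # Matches the backward loops above: zero where the mask is zero,
        # otherwise scale the gradient by 1/(1 - p).
        return dy * mask / (1.0 - p)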
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 82d97871559..fe9ece73389 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -137,8 +137,76 @@ def test_gru_bidirectional():
     check_rnn_consistency(fused, stack, T, N, I, H, 'add')
     check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
-# Currently, fused LSTM operator doesn't support dropout.
-# Will change this test after dropout is supported
+@with_seed()
+def test_rnntanh_sym():
+    T, N, I, H = 5, 32, 800, 800
+
+    fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='rnn_tanh', get_next_state=True, prefix='')
+    stack = mx.rnn.SequentialRNNCell()
+    stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l0_'))
+    stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l1_'))
+    stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l2_'))
+
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+
+@with_seed()
+def test_rnntanh_bidirectional():
+    T, N, I, H = 5, 20, 800, 800
+
+    fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='rnn_tanh',
+                                bidirectional=True, get_next_state=True, prefix='')
+    
+    stack = mx.rnn.SequentialRNNCell()
+    stack.add(mx.rnn.BidirectionalCell(
+                mx.rnn.RNNCell(H, activation='tanh', prefix='l0_'),
+                mx.rnn.RNNCell(H, activation='tanh', prefix='r0_'),
+                output_prefix='bi_rnntanh_0_'))    
+    stack.add(mx.rnn.BidirectionalCell(
+                mx.rnn.RNNCell(H, activation='tanh', prefix='l1_'),
+                mx.rnn.RNNCell(H, activation='tanh', prefix='r1_'),
+                output_prefix='bi_rnntanh_1_'))
+    
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+
+@with_seed()
+def test_rnnrelu_sym():
+    T, N, I, H = 5, 32, 200, 200
+
+    fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='rnn_relu', get_next_state=True, prefix='')
+    stack = mx.rnn.SequentialRNNCell()
+    stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l0_'))
+    stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l1_'))
+    stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l2_'))
+
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+
+@with_seed()
+def test_rnnrelu_bidirectional():
+    T, N, I, H = 5, 20, 200, 200
+
+    fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='rnn_relu',
+                                bidirectional=True, get_next_state=True, prefix='')
+    
+    stack = mx.rnn.SequentialRNNCell()
+    stack.add(mx.rnn.BidirectionalCell(
+                mx.rnn.RNNCell(H, activation='relu', prefix='l0_'),
+                mx.rnn.RNNCell(H, activation='relu', prefix='r0_'),
+                output_prefix='bi_rnnrelu_0_'))    
+    stack.add(mx.rnn.BidirectionalCell(
+                mx.rnn.RNNCell(H, activation='relu', prefix='l1_'),
+                mx.rnn.RNNCell(H, activation='relu', prefix='r1_'),
+                output_prefix='bi_rnnrelu_1_'))
+
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+
 @with_seed()
 def test_lstm_dropout():
     X = mx.sym.Variable('x')
@@ -149,12 +217,44 @@ def test_lstm_dropout():
     rnn = mx.sym.RNN(data=X, parameters=Params, state=HX, state_cell=CX,
                      state_size=H, num_layers=5, mode='lstm', p=0.5, state_outputs=True, name='LSTM')
     exe = rnn.simple_bind(ctx=mx.cpu(), x=(T, N, I))
-    try:
-        out = exe.forward(is_train=False)
-        out[0].wait_to_read()
-        assert False  # should not reach here
-    except mx.base.MXNetError as err:
-        assert str(err).find('Dropout is not supported at the moment') != -1
+    out = exe.forward(is_train=True)
+    out[0].wait_to_read()
+
+@with_seed()
+def test_gru_dropout():
+    X = mx.sym.Variable('x')
+    Params = mx.sym.Variable('params')
+    HX = mx.sym.Variable('state')
+    T, N, I, H = 300, 20, 800, 800
+    rnn = mx.sym.RNN(data=X, parameters=Params, state=HX,
+                     state_size=H, num_layers=5, mode='gru', p=0.5, state_outputs=True, name='GRU')
+    exe = rnn.simple_bind(ctx=mx.cpu(), x=(T, N, I))
+    out = exe.forward(is_train=True)
+    out[0].wait_to_read()
+
+@with_seed()
+def test_rnntanh_dropout():
+    X = mx.sym.Variable('x')
+    Params = mx.sym.Variable('params')
+    HX = mx.sym.Variable('state')
+    T, N, I, H = 300, 20, 800, 800
+    rnn = mx.sym.RNN(data=X, parameters=Params, state=HX,
+                     state_size=H, num_layers=5, mode='rnn_tanh', p=0.5, state_outputs=True, name='RNN_TANH')
+    exe = rnn.simple_bind(ctx=mx.cpu(), x=(T, N, I))
+    out = exe.forward(is_train=True)
+    out[0].wait_to_read()
+
+@with_seed()
+def test_rnnrelu_dropout():
+    X = mx.sym.Variable('x')
+    Params = mx.sym.Variable('params')
+    HX = mx.sym.Variable('state')
+    T, N, I, H = 300, 20, 800, 800
+    rnn = mx.sym.RNN(data=X, parameters=Params, state=HX,
+                     state_size=H, num_layers=5, mode='rnn_relu', p=0.5, state_outputs=True, name='RNN_RELU')
+    exe = rnn.simple_bind(ctx=mx.cpu(), x=(T, N, I))
+    out = exe.forward(is_train=True)
+    out[0].wait_to_read()
 
 def np_softmax(x, axis=-1):
     # fix for old numpy on Travis not supporting keepdims


 
