You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/06 18:38:05 UTC

[GitHub] piiswrong closed pull request #10311: [MXNET-107]Fused GRU implementation for CPU

piiswrong closed pull request #10311: [MXNET-107]Fused GRU implementation for CPU
URL: https://github.com/apache/incubator-mxnet/pull/10311
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff of a merged pull request, it is reproduced below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not otherwise display it:

diff --git a/example/rnn/bucketing/cudnn_lstm_bucketing.py b/example/rnn/bucketing/cudnn_rnn_bucketing.py
similarity index 87%
rename from example/rnn/bucketing/cudnn_lstm_bucketing.py
rename to example/rnn/bucketing/cudnn_rnn_bucketing.py
index 84cfc9d4380..29a66a8f484 100644
--- a/example/rnn/bucketing/cudnn_lstm_bucketing.py
+++ b/example/rnn/bucketing/cudnn_rnn_bucketing.py
@@ -65,6 +65,8 @@
                     help='stack fused RNN cells to reduce communication overhead')
 parser.add_argument('--dropout', type=float, default='0.0',
                     help='dropout probability (1.0 - keep probability)')
+parser.add_argument('--rnntype', type=str, default='lstm',
+                    help='rnn type: gru and lstm are supported')
 
 #buckets = [32]
 buckets = [10, 20, 30, 40, 50, 60]
@@ -97,13 +99,13 @@ def train(args):
         cell = mx.rnn.SequentialRNNCell()
         for i in range(args.num_layers):
             cell.add(mx.rnn.FusedRNNCell(args.num_hidden, num_layers=1,
-                                         mode='lstm', prefix='lstm_l%d'%i,
+                                         mode=args.rnntype, prefix='%s_l%d'%(args.rnntype,i),
                                          bidirectional=args.bidirectional))
-            if args.dropout > 0 and i < args.num_layers - 1:
-                cell.add(mx.rnn.DropoutCell(args.dropout, prefix='lstm_d%d'%i))
+            if args.dropout > 0 and i < args.num_layers - 1 and args.rnntype == 'lstm':
+                cell.add(mx.rnn.DropoutCell(args.dropout, prefix='%s_d%d'%(args.rnntype,i)))
     else:
         cell = mx.rnn.FusedRNNCell(args.num_hidden, num_layers=args.num_layers, dropout=args.dropout,
-                                   mode='lstm', bidirectional=args.bidirectional)
+                                   mode=args.rnntype, bidirectional=args.bidirectional)
 
     def sym_gen(seq_len):
         data = mx.sym.Variable('data')
@@ -168,16 +170,25 @@ def test(args):
 
     if not args.stack_rnn:
         stack = mx.rnn.FusedRNNCell(args.num_hidden, num_layers=args.num_layers,
-                mode='lstm', bidirectional=args.bidirectional).unfuse()
+                mode=args.rnntype, bidirectional=args.bidirectional).unfuse()
     else:
         stack = mx.rnn.SequentialRNNCell()
         for i in range(args.num_layers):
-            cell = mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='lstm_%dl0_'%i)
-            if args.bidirectional:
-                cell = mx.rnn.BidirectionalCell(
-                        cell,
-                        mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='lstm_%dr0_'%i),
-                        output_prefix='bi_lstm_%d'%i)
+            if args.rnntype == 'lstm':
+                cell = mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='%s_%dl0_'%(args.rnntype,i))
+                if args.bidirectional:
+                    cell = mx.rnn.BidirectionalCell(
+                            cell,
+                            mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='%s_%dr0_'%(args.rnntype,i)),
+                            output_prefix='bi_%s_%d'%(args.rnntype,i))
+            elif args.rnntype == 'gru':
+                cell = mx.rnn.GRUCell(num_hidden=args.num_hidden, prefix='%s_%dl0_'%(args.rnntype,i))
+                if args.bidirectional:
+                    cell = mx.rnn.BidirectionalCell(
+                            cell,
+                            mx.rnn.GRUCell(num_hidden=args.num_hidden, prefix='%s_%dr0_'%(args.rnntype,i)),
+                            output_prefix='bi_%s_%d'%(args.rnntype,i))
+
             stack.add(cell)
 
     def sym_gen(seq_len):
diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index 056c1d517c0..d9dc98ece48 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -190,7 +190,7 @@ def forward(self, inputs, states=None):
                 self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
                 self.i2h_weight[i]._finish_deferred_init()
         if inputs.context.device_type == 'gpu' or \
-           self._mode == 'lstm' and not self._dropout:
+           self._mode in ['lstm', 'gru'] and not self._dropout:
             out = self._forward_kernel(inputs, states)
         else:
             out = self._forward(inputs, states)
diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index eded6aeed8a..99531739afa 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -101,12 +101,14 @@ inline size_t GetRNNWorkspaceSize(int seq_length,
   switch (mode) {
     case rnn_enum::kRnnRelu:
     case rnn_enum::kRnnTanh:
-    case rnn_enum::kGru:
-      LOG(FATAL) << "Only LSTM is supported at the moment";
+      LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
       break;
     case rnn_enum::kLstm:
       size = (seq_length + 1) * batch_size * hidden_size * 4 + batch_size * hidden_size * 2
-             + seq_length * batch_size * hidden_size * direction;
+             + seq_length * batch_size * hidden_size * direction + hidden_size * seq_length * 8;
+      break;
+    case rnn_enum::kGru:
+      size = seq_length * batch_size * hidden_size * direction * 4 + batch_size * hidden_size * 8;
       break;
     default:
       LOG(FATAL) << "unknown RNN mode " << mode;
@@ -125,12 +127,16 @@ inline size_t GetRNNReserveSpaceSize(int num_layer,
   switch (mode) {
     case rnn_enum::kRnnRelu:
     case rnn_enum::kRnnTanh:
-    case rnn_enum::kGru:
-      LOG(FATAL) << "Only LSTM is supported at the moment";
+      LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
       break;
     case rnn_enum::kLstm:
       size = num_layer * direction * seq_length * batch_size * hidden_size * 6;
       break;
+    case rnn_enum::kGru:
+      size = seq_length * batch_size * hidden_size * direction * num_layer * 8 +
+          batch_size * hidden_size * direction * 9 + hidden_size * seq_length * 6 +
+          seq_length * batch_size * 7 * hidden_size * direction;
+      break;
     default:
       LOG(FATAL) << "unknown RNN mode " << mode;
       break;
@@ -221,14 +227,18 @@ void RNNForwardTraining(DType* ws,
   switch (mode) {
     case rnn_enum::kRnnTanh:
     case rnn_enum::kRnnRelu:
-    case rnn_enum::kGru:
-      LOG(FATAL) << "Only LSTM is supported at the moment";
+      LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
       break;
     case rnn_enum::kLstm:
       LstmForwardTraining<DType>(ws, rs, state_outputs, num_layers, direction, seq_length,
                                  batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr,
                                  w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr);
       break;
+    case rnn_enum::kGru:
+      GruForwardTraining<DType>(ws, rs, state_outputs, num_layers, direction, seq_length,
+                                batch_size, input_size, state_size, x_ptr, hx_ptr,
+                                w_ptr, y_ptr, hy_ptr);
+      break;
     default:
       LOG(FATAL) << "unknown RNN mode " << mode;
       break;
@@ -256,14 +266,18 @@ void RNNForwardInference(DType* ws,
   switch (mode) {
     case rnn_enum::kRnnRelu:
     case rnn_enum::kRnnTanh:
-    case rnn_enum::kGru:
-      LOG(FATAL) << "Only LSTM is supported at the moment";
+      LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
       break;
     case rnn_enum::kLstm:
       LstmForwardInference<DType>(ws, state_outputs, num_layers, direction, seq_length,
                                   batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr,
                                   w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr);
       break;
+    case rnn_enum::kGru:
+      GruForwardInference<DType>(ws, state_outputs, num_layers, direction, seq_length,
+                                 batch_size, input_size, state_size, x_ptr, hx_ptr,
+                                 w_ptr, y_ptr, hy_ptr);
+      break;
     default:
       LOG(FATAL) << "unknown RNN mode" << mode;
       break;
@@ -292,16 +306,26 @@ void RNNBackward(DType* ws,
                  DType* dcx_ptr,
                  DType* dw_ptr,
                  DType* db_ptr,
+                 int req_data,
+                 int req_params,
+                 int req_state,
+                 int req_statecell,
                  int mode) {
   switch (mode) {
     case rnn_enum::kRnnRelu:
     case rnn_enum::kRnnTanh:
-    case rnn_enum::kGru:
       break;
     case rnn_enum::kLstm:
       LstmBackward<DType>(ws, rs, num_layers, direction, seq_length, batch_size,
                           input_size, state_size, x_ptr, hx_ptr, cx_ptr, w_ptr, y_ptr,
-                          dy_ptr, dhy_ptr, dcy_ptr, dx_ptr, dhx_ptr, dcx_ptr, dw_ptr, db_ptr);
+                          dy_ptr, dhy_ptr, dcy_ptr, dx_ptr, dhx_ptr, dcx_ptr, dw_ptr, db_ptr,
+                          req_data, req_params, req_state, req_statecell);
+      break;
+    case rnn_enum::kGru:
+      GruBackward<DType>(ws, rs, num_layers, direction, seq_length, batch_size,
+                         input_size, state_size, x_ptr, hx_ptr, w_ptr,
+                         dy_ptr, dhy_ptr, dx_ptr, dhx_ptr, dw_ptr,
+                         req_data, req_params, req_state);
       break;
     default:
       LOG(FATAL) << "unknown RNN mode" << mode;
@@ -330,7 +354,8 @@ class RNNOp : public Operator{
                        const std::vector<TBlob> &aux_args) {
     using namespace mshadow;
     using namespace mshadow::expr;
-    CHECK_EQ(param_.mode, rnn_enum::kLstm) << "Only lstm mode is supported at the moment.";
+    CHECK(param_.mode == rnn_enum::kLstm || param_.mode == rnn_enum::kGru)
+        << "Only lstm and gru mode are supported at the moment.";
     CHECK_EQ(param_.p, 0) << "Dropout is not supported at the moment.";
 
     size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
@@ -442,8 +467,10 @@ class RNNOp : public Operator{
                         const std::vector<TBlob> &aux_args) {
     using namespace mshadow;
     using namespace mshadow::expr;
-    CHECK_EQ(param_.mode, rnn_enum::kLstm) << "Only lstm mode is supported at the moment.";
+    CHECK(param_.mode == rnn_enum::kLstm || param_.mode == rnn_enum::kGru)
+        << "Only lstm and gru mode are supported at the moment.";
     CHECK_EQ(param_.p, 0) << "Dropout is not supported at the moment.";
+
     size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
     size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
     if (!param_.state_outputs) {
@@ -535,6 +562,10 @@ class RNNOp : public Operator{
                        dcx_ptr,
                        dw.dptr_,
                        db_ptr,
+                       req[rnn_enum::kData],
+                       req[rnn_enum::kParams],
+                       req[rnn_enum::kState],
+                       req[rnn_enum::kStateCell],
                        param_.mode);
   }
 
diff --git a/src/operator/rnn_impl.h b/src/operator/rnn_impl.h
index 2ee374bbf56..e92a18218f9 100644
--- a/src/operator/rnn_impl.h
+++ b/src/operator/rnn_impl.h
@@ -40,6 +40,10 @@
 #include "./mshadow_op.h"
 #include "./linalg.h"
 
+
+namespace mxnet {
+namespace op {
+
 template<typename DType>
 inline DType sigmoid(DType x) {
   return 1.0f / (1.0f + exp(-x));
@@ -297,6 +301,7 @@ void LstmForwardInference(DType* ws,
 template <typename DType>
 void LstmBackwardSingleLayer(DType* ws,
                              DType* rs,
+                             DType* tmp_buf,
                              bool bid,
                              const int T,
                              const int N,
@@ -314,7 +319,11 @@ void LstmBackwardSingleLayer(DType* ws,
                              DType* dcy_ptr,
                              DType* w_ptr,
                              DType* dw_ptr,
-                             DType* db_ptr) {
+                             DType* db_ptr,
+                             int req_data,
+                             int req_params,
+                             int req_state,
+                             int req_statecell) {
   using namespace mshadow;
   const Tensor<cpu, 2, DType> wx(w_ptr, Shape2(H * 4, I));
   const Tensor<cpu, 2, DType> wh(w_ptr + I * H * 4, Shape2(H * 4, H));
@@ -336,6 +345,7 @@ void LstmBackwardSingleLayer(DType* ws,
   const DType alpha = 1.0;
   const DType beta0 = 0.0;
   const DType beta1 = 1.0;
+  const DType beta2 = 2.0;
   const int cell_size = N * H;
   if (dhy_ptr != NULL) {
     memcpy(dh.dptr_, dhy_ptr, cell_size * sizeof(DType));
@@ -367,24 +377,67 @@ void LstmBackwardSingleLayer(DType* ws,
       difgo[t][j][1][k] = dc[j][k] * cnext[j][k] * ft * (1 - ft);
       difgo[t][j][2][k] = dc[j][k] * it * (1 - gt * gt);
       difgo[t][j][3][k] = dh[j][k] * tc * ot * (1 - ot);
-      dcnext[j][k] = dc[j][k] * ft;
+      if (req_statecell != kNullOp || i > 0) {
+        dcnext[j][k] = dc[j][k] * ft;
+      }
       if (i) {
         htmp[j][k] = y[tnext][j][k + offset];
       }
     }
     Tensor<cpu, 2, DType> dyh(difgo[t].dptr_, Shape2(N, H * 4));
-    linalg_gemm(dyh, wh, dhnext, alpha, beta0, false, false);
-    linalg_gemm(dyh, hnext, dwh, alpha, beta1, true, false);
+    if (req_state != kNullOp || i > 0) {
+      linalg_gemm(dyh, wh, dhnext, alpha, beta0, false, false);
+    }
+    if (req_params != kNullOp) {
+      if (req_params != kAddTo) {
+        linalg_gemm(dyh, hnext, dwh, alpha, beta1, true, false);
+      } else {
+        linalg_gemm(dyh, hnext, dwh, alpha, beta2, true, false);
+
+        //  generate dwx every time step for AddTo
+        Tensor<cpu, 2, DType> x_t(x.dptr_ + i * N * I, Shape2(N, I));
+        Tensor<cpu, 2, DType> dyx_t(difgo.dptr_ + i * N * H * 4, Shape2(N, H * 4));
+        linalg_gemm(dyx_t, x_t, dwx, alpha, beta2, true, false);
+      }
+    }
   }
   Tensor<cpu, 2, DType> dyx(difgo.dptr_, Shape2(T * N, H * 4));
-  linalg_gemm(dyx, wx, dx, alpha, bid ? beta1 : beta0, false, false);
-  linalg_gemm(dyx, x, dwx, alpha, beta0, true, false);
+  if (req_data != kNullOp) {
+    linalg_gemm(dyx, wx, dx, alpha, bid ? beta1 : beta0, false, false);
+  }
+  if (req_params != kNullOp && req_params != kAddTo) {
+    linalg_gemm(dyx, x, dwx, alpha, beta0, true, false);
+  }
   const int row = T * N;
   const int col = H * 4;
-  for (int i = 0; i < row; ++i) {
-    for (int j = 0; j < col; ++j) {
-      dbx[j] += dyx[i][j];
-      dbh[j] = dbx[j];
+  if (req_params != kNullOp) {
+    if (req_params != kAddTo) {
+      for (int i = 0; i < row; ++i) {
+        #pragma omp parallel for num_threads(omp_threads)
+        for (int j = 0; j < col; ++j) {
+          dbx[j] += dyx[i][j];
+          dbh[j] = dbx[j];
+        }
+      }
+    } else {
+      const Tensor<cpu, 2, DType> tmp_dbx(tmp_buf, Shape2(col, T));
+      const Tensor<cpu, 2, DType> tmp_dbh(tmp_buf + col * T, Shape2(col, T));
+      memset(tmp_dbx.dptr_, 0, col * T * sizeof(DType));
+      memset(tmp_dbh.dptr_, 0, col * T * sizeof(DType));
+      for (int t = T - 1; t >= 0; --t) {
+        #pragma omp parallel for num_threads(omp_threads)
+        for (int j = 0; j < col; ++j) {
+          for (int i = 0; i < N; ++i) {
+            tmp_dbx[j][t] += dyx[t * N + i][j];
+            tmp_dbh[j][t] = tmp_dbx[j][t];
+          }
+        }
+        #pragma omp parallel for num_threads(omp_threads)
+        for (int j = 0; j < col; ++j) {
+          dbx[j] += tmp_dbx[j][t] + dbx[j];
+          dbh[j] += tmp_dbh[j][t] + dbh[j];
+        }
+      }
     }
   }
 }
@@ -410,7 +463,13 @@ void LstmBackward(DType* ws,
                   DType* dhx_ptr,
                   DType* dcx_ptr,
                   DType* dw_ptr,
-                  DType* db_ptr) {
+                  DType* db_ptr,
+                  int req_data,
+                  int req_params,
+                  int req_state,
+                  int req_statecell) {
+  DType* tmp_buf = ws;
+  DType* ws2 = tmp_buf + 8 * T * H;
   const int total_layers = D * L;
   Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(total_layers, N, H));
   Tensor<cpu, 3, DType> cx(cx_ptr, Shape3(total_layers, N, H));
@@ -422,7 +481,7 @@ void LstmBackward(DType* ws,
   const int w_size1 = (I + H) * H * 4;      // first layer
   const int w_size2 = (D * H + H) * H * 4;  // other layers
   const int cell_size = N * H;
-  DType* dy_tmp_ptr = ws + T * cell_size * 4 + cell_size * 3;
+  DType* dy_tmp_ptr = ws2 + T * cell_size * 4 + cell_size * 3;
   for (int i = L - 1; i >= 0; --i) {
     const int input_size = i ? H * D : I;
     const int w_size = i ? w_size2 : w_size1;
@@ -437,9 +496,10 @@ void LstmBackward(DType* ws,
     Tensor<cpu, 3, DType> dy(dy_ptr, Shape3(T, N, H * D));
     Tensor<cpu, 2, DType> x(i ? y.dptr_ - r_size : x_ptr, Shape2(T * N, input_size));
     Tensor<cpu, 2, DType> dx(i ? dy_tmp_ptr : dx_ptr, Shape2(T * N, input_size));
-    LstmBackwardSingleLayer<DType>(ws, rs_cur_ptr, false, T, N, input_size, H,
+    LstmBackwardSingleLayer<DType>(ws2, rs_cur_ptr, tmp_buf, false, T, N, input_size, H,
                                    x, hx[idx], cx[idx], y, dy, dx, dhx[idx], dcx[idx],
-                                   dhy_cur_ptr, dcy_cur_ptr, w_cur_ptr, dw_cur_ptr, db_cur_ptr);
+                                   dhy_cur_ptr, dcy_cur_ptr, w_cur_ptr, dw_cur_ptr, db_cur_ptr,
+                                   req_data, req_params, req_state, req_statecell);
     if (D == 2) {
       w_cur_ptr += w_size;
       dw_cur_ptr += w_size;
@@ -447,11 +507,874 @@ void LstmBackward(DType* ws,
       ++idx;
       dhy_cur_ptr = dhy_ptr ? dhy_cur_ptr + cell_size : NULL;
       dcy_cur_ptr = dcy_ptr ? dcy_cur_ptr + cell_size : NULL;
-      LstmBackwardSingleLayer<DType>(ws, rs_cur_ptr, true, T, N, input_size, H,
+      LstmBackwardSingleLayer<DType>(ws2, rs_cur_ptr, tmp_buf, true, T, N, input_size, H,
                                      x, hx[idx], cx[idx], y, dy, dx, dhx[idx], dcx[idx],
-                                     dhy_cur_ptr, dcy_cur_ptr, w_cur_ptr, dw_cur_ptr, db_cur_ptr);
+                                     dhy_cur_ptr, dcy_cur_ptr, w_cur_ptr, dw_cur_ptr, db_cur_ptr,
+                                     req_data, req_params, req_state, req_statecell);
     }
     dy_ptr = dx.dptr_;
   }
 }
+
+template<typename DType>  // one GRU layer, inference only (gates not kept); D == 2 adds a reverse pass
+void GruForwardInferenceSingleLayer(DType* ws,
+                                    DType* tmp_buf,
+                                    bool state_outputs,
+                                    const int D,
+                                    const int T,
+                                    const int N,
+                                    const int I,
+                                    const int H,
+                                    const Tensor<cpu, 2, DType> &x,
+                                    const Tensor<cpu, 2, DType> &hx,
+                                    DType* wx_ptr,
+                                    DType* wh_ptr,
+                                    DType* bx_ptr,
+                                    DType* bh_ptr,
+                                    DType* y_ptr,
+                                    DType* hy_ptr) {
+  DType* ht = y_ptr;  // ht/ht_1 start at y's step-0 slot, seeded with hx below
+  DType* ht_1 = y_ptr;
+  DType* back_ht_1 = y_ptr + (T-1) * N * H * D + H;  // reverse half (column offset H) of step T-1
+  DType* back_ht = back_ht_1;
+  DType* gemmC1  = ws;              // [D, T, N, 3 * H]
+  DType* gemmC2  = gemmC1 + D * T * N * 3 * H;  // N * 3 * H
+  DType* rt = gemmC2 + N * 3 * H;  // reset gate [N, H], reused every step
+  DType* zt = rt + N * H;  // update gate [N, H]
+  DType* nt = zt + N * H;  // candidate state [N, H]
+  DType* back_wx_ptr = wx_ptr + I * 3 * H + H * 3 * H;  // reverse weights follow forward wx, wh
+  DType* back_wh_ptr = wh_ptr + I * 3 * H + H * 3 * H;
+  DType* back_bx_ptr = (bx_ptr != NULL)? bx_ptr + 3 * H * 2 : NULL;  // NOTE(review): bias tensors are dereferenced unconditionally below
+  DType* back_bh_ptr = (bh_ptr != NULL)? bh_ptr + 3 * H * 2: NULL;
+  DType* back_gemmC1 = gemmC1 + T * N * 3 * H;
+  DType* gemmC1_t = gemmC1;
+
+  const Tensor<cpu, 2, DType> wx(wx_ptr, Shape2(H * 3, I));
+  const Tensor<cpu, 2, DType> wh(wh_ptr, Shape2(H * 3, H));
+  const Tensor<cpu, 2, DType> bx(bx_ptr, Shape2(3, H));
+  const Tensor<cpu, 2, DType> bh(bh_ptr, Shape2(3, H));
+  const Tensor<cpu, 2, DType> back_wx(back_wx_ptr, Shape2(H * 3, I));
+  const Tensor<cpu, 2, DType> back_wh(back_wh_ptr, Shape2(H * 3, H));
+  const Tensor<cpu, 2, DType> back_bx(back_bx_ptr, Shape2(3, H));
+  const Tensor<cpu, 2, DType> back_bh(back_bh_ptr, Shape2(3, H));
+  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  if (D == 1) {  // stage hx in y so step 0 reads it as h(t-1); overwritten in place
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; i++)
+      for (int j = 0; j < H; j++) {
+        y_ptr[i * H + j] = hx[i][j];
+      }
+  } else {
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; i++)
+      for (int j = 0; j < H; j++) {
+        y_ptr[i * D * H + j] = hx[i][j];
+        back_ht_1[i * D * H + j] = hx[N + i][j];
+    }
+  }
+  Tensor<cpu, 2, DType> dgemmC1(ws, Shape2(T * N, 3 * H));
+  Tensor<cpu, 2, DType> dgemmC2(gemmC2, Shape2(N, 3 * H));
+  Tensor<cpu, 2, DType> dback_gemmC1(back_gemmC1, Shape2(T * N, 3 * H));
+
+  // input projection for all T steps at once: x * wx.T : [T * N, I] * [I, 3 * H]
+  DType alpha = 1.0;
+  DType beta = 0.0;
+  linalg_gemm(x, wx, dgemmC1, alpha, beta, false, true);
+  if (D == 2) {
+    linalg_gemm(x, back_wx, dback_gemmC1, alpha, beta, false, true);
+  }
+
+  for (int t = 0; t < T; t++) {
+    //  forward direction: x * wx is precomputed, so only h(t-1) * wh.T is done per step
+    //  ht-1 * wh.T, ht-1:[N, H] wh:[3 * H, H] -> gemmC2:[N, 3 * H]
+    Tensor<cpu, 2, DType> dht_1(ht_1, Shape2(N, D * H));
+    if (D == 1) {
+      linalg_gemm(dht_1, wh, dgemmC2, alpha, beta, false, true);
+    } else {  // rows are [N, D * H]; gemm on the forward half via a transposed copy in tmp_buf
+      Tensor<cpu, 3, DType> dht_1_tmp = Tensor<cpu, 3, DType>(reinterpret_cast<DType*>(tmp_buf),
+                                     Shape3(D, H, N));
+      dht_1_tmp = reshape(dht_1.T(), Shape3(D, H, N));
+      linalg_gemm(dht_1_tmp[0], wh, dgemmC2, alpha, beta, true, true);
+    }
+    gemmC1_t = gemmC1 + t * N * 3 * H;
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; ++i) {
+      for (int j = 0; j < H; ++j) {
+        int rtb = i * 3 * H;
+        int ztb = i * 3 * H + H;
+        int ntb = i * 3 * H + 2 * H;
+        rt[i * H + j] = sigmoid(gemmC1_t[rtb + j] + gemmC2[rtb + j]
+            + bx[0][j] + bh[0][j]);
+        zt[i * H + j] = sigmoid(gemmC1_t[ztb + j] + gemmC2[ztb + j]
+            + bx[1][j] + bh[1][j]);
+        nt[i * H + j] = tanh(gemmC1_t[ntb + j] + bx[2][j] +  // r gates only the hidden-side term
+            rt[i * H + j] * (gemmC2[ntb + j] + bh[2][j]));
+        ht[i * D * H + j] = (1-zt[i * H + j]) * nt[i * H + j] +  // h = (1 - z) * n + z * h(t-1)
+            zt[i * H + j] * ht_1[i * D * H + j];
+      }
+    }
+    ht_1 = ht;
+    ht = ht + D * H * N;
+    //  reverse direction: processes step T-1-t, writing the second H columns of each y row
+    if (D == 2) {
+      gemmC1_t = back_gemmC1 + (T - 1 - t) * N * 3 * H;
+      Tensor<cpu, 2, DType> dback_ht_1(back_ht_1, Shape2(N, D * H));
+      Tensor<cpu, 3, DType> dback_ht_1_tmp = Tensor<cpu, 3, DType>
+          (reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N));
+      dback_ht_1_tmp = reshape(dback_ht_1.T(), Shape3(D, H, N));
+      linalg_gemm(dback_ht_1_tmp[0], back_wh, dgemmC2, alpha, beta, true, true);
+
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; ++i) {
+        for (int j = 0; j < H; ++j) {
+          int rtb = i * 3 * H;
+          int ztb = i * 3 * H + H;
+          int ntb = i * 3 * H + 2 * H;
+          rt[i * H + j] = sigmoid(gemmC1_t[rtb + j] +
+              gemmC2[rtb + j] + back_bx[0][j] + back_bh[0][j]);
+          zt[i * H + j] = sigmoid(gemmC1_t[ztb + j] +
+              gemmC2[ztb + j] + back_bx[1][j]+ back_bh[1][j]);
+          nt[i * H + j] = tanh(gemmC1_t[ntb + j] + back_bx[2][j]
+              + rt[i * H + j] * (gemmC2[ntb + j] + back_bh[2][j]));
+          back_ht[i * D * H + j] = (1 - zt[i * H + j]) * nt[i * H + j]
+              + zt[i * H + j] * back_ht_1[i * D * H + j];
+        }
+      }
+      back_ht_1 = back_ht;
+      back_ht = back_ht - D * H * N;
+    }
+  }
+  //  copy final states to hy (forward from step T-1, reverse from step 0): (N, H * D) -> (D, N, H)
+  if (state_outputs) {
+    if (D == 1) {
+      DType* y_start = y_ptr + (T - 1) * N * H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; i++)
+        for (int j = 0; j < H; j++) {
+          hy_ptr[i * H + j] = y_start[i * H + j];
+        }
+    } else {
+      DType* y_start = y_ptr + (T - 1) * N * H * D;
+      DType* y_back_start = y_ptr + H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; i++)
+        for (int j = 0; j < H; j++) {
+          hy_ptr[i * H + j] = y_start[i * D * H + j];
+          hy_ptr[N * H + i * H + j] = y_back_start[i * D * H + j];
+        }
+    }
+  }
+}
+
+template <typename DType>  // stacked L-layer GRU inference over packed w_ptr (all weights, then all biases)
+void GruForwardInference(DType* ws,
+                         bool state_outputs,
+                         const int L,
+                         const int D,
+                         const int T,
+                         const int N,
+                         int I,
+                         const int H,
+                         DType* x_ptr,
+                         DType* hx_ptr,
+                         DType* w_ptr,
+                         DType* y_ptr,
+                         DType* hy_ptr) {
+  DType* wx = w_ptr;
+  DType* wh = wx + I * H * 3;
+  DType* bx = wh + H * H * 3 + (D - 1) * (H * H * 3 + I * H * 3)  // biases start after every layer's weights
+      + (L - 1) * ((D + 1) * H) * H * 3 * D;
+  DType* bh = bx + H * 3;
+
+  DType* y_tmp = ws;  // [D * T * N * H] ping-pong output buffer for intermediate layers
+  DType* y_l = x_ptr;
+  DType* tmp_buf = y_tmp + D * T * N * H;  // D * H * N scratch for the single-layer transpose
+  DType* ws2 = y_tmp + D * T * N * H + D * H * N;
+
+  DType* wx_l = wx;
+  DType* wh_l = wh;
+  DType* bx_l = bx;
+  DType* bh_l = bh;
+  Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(D * L, N, H));
+  DType* hy_l = hy_ptr;
+  for (int l = 0; l < L; l++) {
+    Tensor<cpu, 2, DType> x_l(y_l, Shape2(T * N, I));
+    if ((L + l) % 2) {  // alternate buffers so the last layer (l == L - 1) writes y_ptr
+      y_l = y_ptr;
+    } else {
+      y_l = y_tmp;
+    }
+    Tensor<cpu, 2, DType> hx_l = hx[D * l];  // D == 2: reverse state is read at row offset N
+    GruForwardInferenceSingleLayer<DType>(ws2, tmp_buf, state_outputs, D, T, N, I, H,
+                                        x_l, hx_l, wx_l, wh_l, bx_l, bh_l, y_l, hy_l);
+    hy_l = hy_l + D * N * H;
+    bx_l = bx_l + 3 * H * D * 2;  // per layer: D directions x (bx, bh) x 3 gates
+    bh_l = bh_l + 3 * H * D * 2;
+    wx_l = wx_l + I * H * 3 * D + H * H * 3 * D;
+    if (l == 0) {  // layers above the first take the D * H wide output as input
+      I = D * H;
+    }
+    wh_l = wx_l + I * 3 * H;
+  }
+}
+
+
+template<typename DType>
+void GruForwardTrainingSingleLayer(DType* ws,
+                                   DType* tmp_buf,
+                                   bool state_outputs,
+                                   const int D,
+                                   const int T,
+                                   const int N,
+                                   const int I,
+                                   const int H,
+                                   const Tensor<cpu, 2, DType> &x,
+                                   const Tensor<cpu, 2, DType> &hx,
+                                   DType* wx_ptr,
+                                   DType* wh_ptr,
+                                   DType* bx_ptr,
+                                   DType* bh_ptr,
+                                   DType* gateR,
+                                   DType* gateZ,
+                                   DType* gateN,
+                                   DType* Mnh,
+                                   DType* y_ptr,
+                                   DType* hy_ptr) {
+  DType* ht = y_ptr;
+  DType* ht_1 = y_ptr;
+  DType* back_ht_1 = y_ptr + (T - 1)* N * H * D + H;
+  DType* back_ht = back_ht_1;
+
+  DType* gemmC1  = ws;              // [D, T, N, 3 * H]
+  DType* gemmC2  = gemmC1 + D * T * N * 3 * H;  // N * 3 * H
+  DType* rt = gateR;
+  DType* zt = gateZ;
+  DType* nt = gateN;
+  DType* back_wx_ptr = wx_ptr + I * 3 * H + H * 3 * H;
+  DType* back_wh_ptr = wh_ptr + I * 3 * H + H * 3 * H;
+  DType* back_bx_ptr = (bx_ptr != NULL)? bx_ptr + 3 * H * 2 : NULL;
+  DType* back_bh_ptr = (bh_ptr != NULL)? bh_ptr + 3 * H * 2 : NULL;
+  DType* back_gateR = gateR + T * N * H;
+  DType* back_gateZ = gateZ + T * N * H;
+  DType* back_gateN = gateN + T * N * H;
+  DType* back_Mnh = Mnh + T * N * H;
+  DType* back_gemmC1 = gemmC1 + T * N * 3 * H;
+  DType* gemmC1_t = gemmC1;
+
+  const Tensor<cpu, 2, DType> wx(wx_ptr, Shape2(H * 3, I));
+  const Tensor<cpu, 2, DType> wh(wh_ptr, Shape2(H * 3, H));
+  const Tensor<cpu, 2, DType> bx(bx_ptr, Shape2(3, H));
+  const Tensor<cpu, 2, DType> bh(bh_ptr, Shape2(3, H));
+  const Tensor<cpu, 2, DType> back_wx(back_wx_ptr, Shape2(H * 3, I));
+  const Tensor<cpu, 2, DType> back_wh(back_wh_ptr, Shape2(H * 3, H));
+  const Tensor<cpu, 2, DType> back_bx(back_bx_ptr, Shape2(3, H));
+  const Tensor<cpu, 2, DType> back_bh(back_bh_ptr, Shape2(3, H));
+  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  if (D == 1) {
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; i++)
+      for (int j = 0; j < H; j++) {
+        y_ptr[i * H + j] = hx[i][j];
+      }
+  } else {
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; i++)
+      for (int j = 0; j < H; j++) {
+        y_ptr[i * D * H + j] = hx[i][j];
+        back_ht_1[i * D * H + j] = hx[N + i][j];
+    }
+  }
+
+  Tensor<cpu, 2, DType> dgemmC1(ws, Shape2(T * N, 3 * H));
+  Tensor<cpu, 2, DType> dgemmC2(gemmC2, Shape2(N, 3 * H));
+  Tensor<cpu, 2, DType> dback_gemmC1(back_gemmC1, Shape2(T * N, 3 * H));
+
+  // x * wx.T : [T * N, I] * [I, 3 * H]
+  DType alpha = 1.0;
+  DType beta = 0.0;
+  linalg_gemm(x, wx, dgemmC1, alpha, beta, false, true);
+  if (D == 2) {
+    linalg_gemm(x, back_wx, dback_gemmC1, alpha, beta, false, true);
+  }
+
+  for (int t = 0; t < T; t++) {
+    //  perform the first direction, X * wx and H * wh for each step
+    //  ht-1 * wh, ht-1:[N, H] wh:[3 * H, H]
+    Tensor<cpu, 2, DType> dht_1(ht_1, Shape2(N, D * H));
+    if (D == 1) {
+      linalg_gemm(dht_1, wh, dgemmC2, alpha, beta, false, true);
+    } else {
+      Tensor<cpu, 3, DType> dht_1_tmp = Tensor<cpu, 3, DType>(reinterpret_cast<DType*>(tmp_buf),
+                                     Shape3(D, H, N));
+      dht_1_tmp = reshape(dht_1.T(), Shape3(D, H, N));
+      linalg_gemm(dht_1_tmp[0], wh, dgemmC2, alpha, beta, true, true);
+    }
+    rt = gateR + t * N * H;
+    zt = gateZ + t * N * H;
+    nt = gateN + t * N * H;
+    gemmC1_t = gemmC1 + t * N * 3 * H;
+    DType* Mnht = Mnh + t * N * H;
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; ++i) {
+      for (int j = 0; j < H; ++j) {
+        int rtb = i * 3 * H;
+        int ztb = i * 3 * H + H;
+        int ntb = i * 3 * H + 2 * H;
+        Mnht[i * H + j] = gemmC2[ntb + j] + bh[2][j];
+        rt[i * H + j] = sigmoid(gemmC1_t[rtb + j] + gemmC2[rtb + j]
+            + bx[0][j] + bh[0][j]);
+        zt[i * H + j] = sigmoid(gemmC1_t[ztb + j] + gemmC2[ztb + j]
+            + bx[1][j] + bh[1][j]);
+        nt[i * H + j] = tanh(gemmC1_t[ntb + j] + bx[2][j] +
+            rt[i * H + j] * (gemmC2[ntb + j] + bh[2][j]));
+        ht[i * D * H + j] = (1-zt[i * H + j]) * nt[i * H + j] +
+            zt[i * H + j] * ht_1[i * D * H + j];
+      }
+    }
+    ht_1 = ht;
+    ht = ht + D * H * N;
+    //  perform the second direction
+    if (D == 2) {
+      rt = back_gateR + (T - 1 - t) * N * H;
+      zt = back_gateZ + (T - 1 - t) * N * H;
+      nt = back_gateN + (T - 1 - t) * N * H;
+      gemmC1_t = back_gemmC1 + (T - 1 - t) * N * 3 * H;
+      Tensor<cpu, 2, DType> dback_ht_1(back_ht_1, Shape2(N, D * H));
+      Tensor<cpu, 3, DType> dback_ht_1_tmp = Tensor<cpu, 3, DType>
+          (reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N));
+      dback_ht_1_tmp = reshape(dback_ht_1.T(), Shape3(D, H, N));
+      linalg_gemm(dback_ht_1_tmp[0], back_wh, dgemmC2, alpha, beta, true, true);
+
+      DType* back_Mnht = back_Mnh + (T - 1 - t) * N * H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; ++i) {
+        for (int j = 0; j < H; ++j) {
+          int rtb = i * 3 * H;
+          int ztb = i * 3 * H + H;
+          int ntb = i * 3 * H + 2 * H;
+          back_Mnht[i * H + j] = gemmC2[ntb + j] + back_bh[2][j];
+          rt[i * H + j] = sigmoid(gemmC1_t[rtb + j] +
+              gemmC2[rtb + j] + back_bx[0][j] + back_bh[0][j]);
+          zt[i * H + j] = sigmoid(gemmC1_t[ztb + j] +
+              gemmC2[ztb + j] + back_bx[1][j] + back_bh[1][j]);
+          nt[i * H + j] = tanh(gemmC1_t[ntb + j] + back_bx[2][j]
+              + rt[i * H + j] * (gemmC2[ntb + j] + back_bh[2][j]));
+          back_ht[i * D * H + j] = (1 - zt[i * H + j]) * nt[i * H + j]
+              + zt[i * H + j] * back_ht_1[i * D * H + j];
+        }
+      }
+      back_ht_1 = back_ht;
+      back_ht = back_ht - D * H * N;
+    }
+  }
+
+  //  copy last state to hy, from(N, H * D) to (D, N, H)
+  if (state_outputs) {
+    if (D == 1) {
+      DType* y_start = y_ptr + (T - 1) * N * H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; i++)
+        for (int j = 0; j < H; j++) {
+          hy_ptr[i * H + j] = y_start[i * H + j];
+        }
+    } else {
+      DType* y_start = y_ptr + (T - 1) * N * H * D;
+      DType* y_back_start = y_ptr + H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; i++)
+        for (int j = 0; j < H; j++) {
+          hy_ptr[i * H + j] = y_start[i * D * H + j];
+          hy_ptr[N * H + i * H + j] = y_back_start[i * D * H + j];
+        }
+    }
+  }
+}
+
+// Multi-layer GRU forward pass used for training.
+//
+// Calls GruForwardTrainingSingleLayer once per layer (0 .. L - 1). It keeps
+// the per-step gates r, z, n, every layer's output y and the Mnh term in the
+// reserve space `rs` so the backward pass can reuse them. The last layer's
+// output is finally copied to y_ptr.
+//
+// Parameters (as used in this function):
+//   ws            - not referenced here; scratch space (ws2) is carved out
+//                   of `rs` instead. NOTE(review): confirm the reserve-space
+//                   size accounts for tmp_buf and ws2.
+//   rs            - reserve space, laid out from its start as:
+//                     gateR | gateZ | gateN   (each L * T * D * N * H)
+//                     y                       (L * T * N * H * D)
+//                     Mnh                     (L * D * T * N * H)
+//                     tmp_buf                 (D * H * N)
+//                     ws2                     (single-layer scratch)
+//   state_outputs - forwarded to GruForwardTrainingSingleLayer.
+//   L, D, T, N    - number of layers, directions (1 or 2), time steps, batch.
+//   I             - input width of layer 0; reassigned to D * H after
+//                   layer 0 because deeper layers consume both directions.
+//   H             - hidden size.
+//   x_ptr         - layer-0 input, viewed as [T * N, I].
+//   hx_ptr        - initial hidden state, viewed as [D * L, N, H].
+//   w_ptr         - packed parameters: all wx/wh matrices first, biases last.
+//   y_ptr         - receives the last layer's output (T * N * H * D values).
+//   hy_ptr        - final-state output, advanced by D * N * H per layer.
+template <typename DType>
+void GruForwardTraining(DType* ws,
+                        DType* rs,
+                        bool state_outputs,
+                        const int L,
+                        const int D,
+                        const int T,
+                        const int N,
+                        int I,
+                        const int H,
+                        DType* x_ptr,
+                        DType* hx_ptr,
+                        DType* w_ptr,
+                        DType* y_ptr,
+                        DType* hy_ptr) {
+  // Layer 0 forward direction: wx [3 * H, I] followed by wh [3 * H, H].
+  DType* wx = w_ptr;
+  DType* wh = wx + I * H * 3;
+  // Biases start after every weight matrix of every layer/direction:
+  // layer 0 holds D * (I + H) * 3 * H weights, each later layer
+  // D * (D * H + H) * 3 * H.
+  DType* bx = wh + H * H * 3 + (D - 1) * (H * H * 3 + I * H * 3)
+      + (L - 1) * ((D + 1) * H) * H * 3 * D;
+  // bx and bh are 3 * H each, stored back to back.
+  DType* bh = bx + H * 3;
+  Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(D * L, N, H));
+  DType* hy_l = hy_ptr;
+  // Carve the reserve space (see layout in the header comment).
+  DType* gateR_l = rs;
+  DType* gateZ_l = gateR_l + L * T * D * N * H;
+  DType* gateN_l = gateZ_l + L * T * D * N * H;
+  DType* y_l = gateN_l + L * T * D * N * H;
+  DType* Mnh_l = y_l + L * T * N * H * D;
+  DType* tmp_buf = Mnh_l + L * D * T * N * H;
+  DType* ws2 = Mnh_l + L * D * T * N * H + D * H * N;
+  DType* wx_l = wx;
+  DType* wh_l = wh;
+  DType* bx_l = bx;
+  DType* bh_l = bh;
+  // Input of the current layer: x for layer 0, previous y afterwards.
+  DType* y_tmp = x_ptr;
+
+  for (int l = 0; l < L; l++) {
+    if (l != 0) {
+      // Previous layer's output becomes this layer's input; advance y_l to
+      // this layer's output slot.
+      y_tmp = y_l;
+      y_l = y_l + T * N * H * D;
+    }
+    Tensor<cpu, 2, DType> x_l(y_tmp, Shape2(T * N, I));
+    // [N, H] slice of hx for (layer l, direction 0).
+    // NOTE(review): for D == 2 the single-layer routine is presumably
+    // expected to read the reverse-direction state that follows it -- confirm.
+    Tensor<cpu, 2, DType> hx_l = hx[D * l];
+    GruForwardTrainingSingleLayer<DType>(ws2, tmp_buf, state_outputs, D, T, N, I, H,
+                                         x_l, hx_l, wx_l, wh_l, bx_l, bh_l,
+                                         gateR_l, gateZ_l, gateN_l, Mnh_l, y_l, hy_l);
+    // Advance every per-layer pointer to layer l + 1.
+    gateR_l = gateR_l + T * D * N * H;
+    gateZ_l = gateZ_l + T * D * N * H;
+    gateN_l = gateN_l + T * D * N * H;
+    Mnh_l = Mnh_l +  T * D * N * H;
+    hy_l = hy_l + D * N * H;
+    // Each layer stores D directions x (bx, bh) x 3 * H biases.
+    bx_l = bx_l + 3 * H * D * 2;
+    bh_l = bh_l + 3 * H * D * 2;
+
+    // Skip this layer's wx and wh for all directions (uses this layer's I).
+    wx_l = wx_l + I * H * 3 * D + H * H * 3 * D;
+    if (l == 0) {
+      // Deeper layers take the concatenated D-direction output as input.
+      I = D * H;
+    }
+    wh_l = wx_l + I * 3 * H;
+  }
+  // y_l now points at the last layer's output.
+  memcpy(y_ptr, y_l, T * N * H * D * sizeof(DType));
+}
+
+// Backward pass of a single GRU layer with one or two directions (D = 1 or 2).
+//
+// The forward direction is walked from t = T - 1 down to 0. When D == 2, the
+// reverse direction is then walked from t = 0 up to T - 1. At every step:
+//   - The gate pre-activation gradients are formed from the saved gates
+//     r, z, n and the saved Mnh term.
+//   - The hidden-state gradient is carried to h_{t-1} through z and
+//     through wh.
+//   - dwh is accumulated, and so is dwx when req_params == kAddTo.
+// After each direction's time loop, dbx/dbh, dx = da * wx and (for
+// non-kAddTo) dwx = da.T * x are computed.
+//
+// Workspace layout (ws):
+//   dar  [T, N, 3 * H]  h-side gate gradients, rows laid out [r | z | n]
+//   da   [T, N, 3 * H]  x-side gate gradients, rows laid out [r | z | n]
+//   dht1 [D, N, H]      running gradient w.r.t. h_{t-1} (both directions)
+//   hx_  [N, D, H]      initial state re-laid out like a row of y
+// tmp_buf holds a [D, H, N] transposed copy of h_{t-1}. For kAddTo it also
+// holds two [3 * H, T] per-step bias-sum tables starting T * N * D * H
+// elements in.
+//
+// The reverse direction's weights and weight gradients start
+// I * 3 * H + H * 3 * H elements after the forward ones; its biases start
+// 3 * H * 2 elements after. req_data, req_params and req_state gate dx, the
+// weight/bias gradients and dhx respectively (kNullOp skips them).
+template <typename DType>
+void GruBackwardSingleLayer(DType* ws,
+                            DType* tmp_buf,
+                            const int D,
+                            const int T,
+                            const int N,
+                            const int I,
+                            const int H,
+                            const Tensor<cpu, 2, DType> &x,
+                            const Tensor<cpu, 2, DType> &hx,
+                            DType* wx_ptr,
+                            DType* wh_ptr,
+                            DType* y_ptr,
+                            DType* dy_ptr,
+                            DType* dhy_ptr,
+                            DType* gateR,
+                            DType* gateZ,
+                            DType* gateN,
+                            DType* Mnh,
+                            DType* dx,
+                            DType* dhx,
+                            DType* dwx,
+                            DType* dwh,
+                            DType* dbx,
+                            DType* dbh,
+                            int req_data,
+                            int req_params,
+                            int req_state) {
+  DType* dyt;
+  DType* ht1;  // [N, D, H]
+  DType* rt;
+  DType* zt;
+  DType* nt;
+  DType* dat;
+  DType* dart;
+  DType* dar = ws;  // [T, N, 3 * H]
+  DType* da = dar + T * N * 3 * H;  // [T, N, 3 * H]
+  DType* dht1 = da + T * N * 3 * H;  // [D, N, H]
+  DType* hx_ = dht1 + D * N * H;  // [N, D, H]
+  DType* Mnht = Mnh;
+  DType* back_ht1;
+  DType* back_dht1 = dht1 + N * H;  // [N, H]
+  DType* back_Mnht = Mnh + T * N * H;
+  DType* back_gateR = gateR + T * N * H;
+  DType* back_gateZ = gateZ + T * N * H;
+  DType* back_gateN = gateN + T * N * H;
+  DType* back_wx_ptr = wx_ptr + I * 3 * H + H * 3 * H;
+  DType* back_wh_ptr = wh_ptr + I * 3 * H + H * 3 * H;
+  DType* back_dwx = dwx + I * 3 * H + H * 3 * H;
+  DType* back_dwh = dwh + I * 3 * H + H * 3 * H;
+  DType* back_dbx = dbx + 3 * H * 2;
+  DType* back_dbh = dbh + 3 * H * 2;
+
+  DType alpha = 1.0;
+  DType beta = 0.0;
+  const Tensor<cpu, 2, DType> wx(wx_ptr, Shape2(H * 3, I));
+  const Tensor<cpu, 2, DType> wh(wh_ptr, Shape2(H * 3, H));
+  const Tensor<cpu, 2, DType> back_wx(back_wx_ptr, Shape2(H * 3, I));
+  const Tensor<cpu, 2, DType> back_wh(back_wh_ptr, Shape2(H * 3, H));
+  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  // Seed the forward-direction hidden gradient with dhy (or zero if absent).
+  #pragma omp parallel for num_threads(omp_threads)
+  for (int i = 0; i < N * H; ++i) {
+    if (dhy_ptr) {
+      dht1[i] = dhy_ptr[i];
+    } else {
+      dht1[i] = 0;
+    }
+  }
+
+  // Copy the forward-direction initial state into hx_ so that, at t == 0,
+  // it can be indexed exactly like a row of y ([N, D, H]).
+  #pragma omp parallel for num_threads(omp_threads)
+  for (int i = 0; i < N; ++i) {
+    for (int j = 0; j < H; ++j) {
+      hx_[i * D * H + j] = hx[i][j];
+    }
+  }
+
+  if (D == 2) {
+    // Same two steps for the reverse direction (second half of dhy and hx).
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N * H; ++i) {
+      if (dhy_ptr) {
+        back_dht1[i] = dhy_ptr[N * H + i];
+      } else {
+        back_dht1[i] = 0;
+      }
+    }
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; ++i) {
+      for (int j = 0; j < H; ++j) {
+        hx_[i * D * H + H + j] = hx[N + i][j];
+      }
+    }
+  }
+  // Forward direction: walk time backwards.
+  for (int t = T - 1; t >= 0; --t) {
+    // h_{t-1}: previous output row, or the initial state at t == 0.
+    if (t) {
+      ht1 = y_ptr + (t - 1) * N * D * H;
+    } else {
+      ht1 = hx_;
+    }
+    // add dy[T, N, D, H] to dhy[D, N, H]
+    dyt = dy_ptr + t * N * D * H;
+
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; ++i) {
+      for (int j = 0; j < H; ++j) {
+        dht1[i * H + j] += dyt[i * D * H + j];
+      }
+    }
+
+    rt = gateR + t * N * H;
+    zt = gateZ + t * N * H;
+    nt = gateN + t * N * H;
+    Mnht = Mnh +  t * N * H;
+    dat = da + t * N * 3 * H;
+    dart = dar + t * N * 3 * H;
+    // Gate gradients for step t, derived from
+    //   h_t = (1 - z) * n + z * h_{t-1}
+    //   n   = tanh(x-side n + r * Mnh)
+    // dat and dart share their r and z entries; dart's n entry is scaled by
+    // r. Afterwards dht1 holds only the direct z * dh_t path to h_{t-1}.
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; ++i) {
+      for (int j = 0; j < H; ++j) {
+        int nid = i * 3 * H + 2 * H + j;
+        int zid = i * 3 * H + H + j;
+        int rid = i * 3 * H + j;
+        int id = i * H + j;
+        dat[nid] = dht1[id] * (1 - zt[id]) * (1 - nt[id] * nt[id]);
+        dart[zid] = dat[zid] = dht1[id] * (ht1[i * D * H + j] - nt[id]) *
+            zt[id] * (1 - zt[id]);
+        dart[rid] = dat[rid] = dat[nid] * Mnht[id] * rt[id] *
+            (1 - rt[id]);
+        dart[nid] = dat[nid] * rt[id];
+        dht1[id] = dht1[id] * zt[id];
+      }
+    }
+    alpha = 1.0;
+    beta = 1.0;
+    // dht1 += dart * wh    [N, H] = [N, 3 * H] * [3 * H, H]
+    // This recurrent term feeds da (and therefore dx and dhx) at earlier
+    // steps, so it is computed even when req_params == kNullOp.
+    Tensor<cpu, 2, DType> d_dht1(dht1, Shape2(N, H));
+    Tensor<cpu, 2, DType> d_dart(dart, Shape2(N, 3 * H));
+    linalg_gemm(d_dart, wh, d_dht1, alpha, beta, false, false);
+
+    if (req_params != kNullOp) {
+      if (req_params == kAddTo) {
+        // NOTE(review): with the usual C = alpha * op(A) * op(B) + beta * C
+        // semantics, beta = 2.0 doubles the accumulated dwx/dwh each step.
+        // The kAddTo bias code after this loop mirrors it; confirm intent.
+        beta = 2.0;
+        // dwx = da.T * x    [3 * H, I] = [3 * H, N] * [N, I] for AddTo
+        Tensor<cpu, 2, DType> d_xt(x.dptr_ + t * N * I, Shape2(N, I));
+        Tensor<cpu, 2, DType> d_dat(dat, Shape2(N, 3 * H));
+        Tensor<cpu, 2, DType> d_dwx(dwx, Shape2(3 * H, I));
+        linalg_gemm(d_dat, d_xt, d_dwx, alpha, beta, true, false);
+      }
+      // dwh = dart.T * ht1    [3 * H, H] = [3 * H, N] * [N, H]
+      // ht1 rows are D * H wide; transpose into tmp_buf as [D, H, N] and use
+      // the forward-direction slice [0].
+      Tensor<cpu, 2, DType> d_ht1(ht1, Shape2(N, D * H));
+      Tensor<cpu, 2, DType> d_dwh(dwh, Shape2(3 * H, H));
+      Tensor<cpu, 3, DType> d_ht1_tmp = Tensor<cpu, 3, DType>
+          (reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N));
+      d_ht1_tmp = reshape(d_ht1.T(), Shape3(D, H, N));
+      linalg_gemm(d_dart, d_ht1_tmp[0], d_dwh, alpha, beta, true, true);
+    }
+  }
+
+  if (req_params != kNullOp) {
+    // dbx = e * da       [1, 3 * H] = [1, N] * [N, 3 * H]
+    if (req_params != kAddTo) {
+      // Sum da / dar over all T * N rows. NOTE(review): uses +=, so for
+      // kWriteTo the caller is presumably expected to zero dbx/dbh -- confirm.
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < 3 * H; ++i) {
+        for (int j = 0; j < N * T; ++j) {
+          dbx[i] += da[j * 3 * H + i];
+          dbh[i] += dar[j * 3 * H + i];
+        }
+      }
+    } else {
+      // Per-step bias sums, stored after the transposed-h region of tmp_buf.
+      const Tensor<cpu, 2, DType> tmp_dbx(tmp_buf + T * N * D * H, Shape2(H * 3, T));
+      const Tensor<cpu, 2, DType> tmp_dbh(tmp_buf + T * N * D * H + 3 * H * T, Shape2(H * 3, T));
+      memset(tmp_dbx.dptr_, 0, H * T * 3 * sizeof(DType));
+      memset(tmp_dbh.dptr_, 0, H * T * 3 * sizeof(DType));
+
+      for (int t = T - 1; t >= 0; --t) {
+        #pragma omp parallel for num_threads(omp_threads)
+        for (int i = 0; i < 3 * H; ++i) {
+          for (int j = 0; j < N; ++j) {
+            tmp_dbx[i][t] += da[t * N * 3 * H + j * 3 * H + i];
+            tmp_dbh[i][t] += dar[t * N * 3 * H + j * 3 * H + i];
+          }
+        }
+        // NOTE(review): doubles the running value each step, matching the
+        // beta = 2.0 used for dwx/dwh under kAddTo above -- confirm intent.
+        #pragma omp parallel for num_threads(omp_threads)
+        for (int i = 0; i < 3 * H; ++i) {
+          dbx[i] += tmp_dbx[i][t] + dbx[i];
+          dbh[i] += tmp_dbh[i][t] + dbh[i];
+        }
+      }
+    }
+  }
+  alpha = 1.0;
+  beta = 0.0;
+
+  // dx = da * wx    [T * N, I] = [T * N, 3 * H] * [3 * H, I]
+  // beta = 0 overwrites dx. NOTE(review): req_data == kAddTo is not honoured
+  // here -- confirm the caller handles it.
+  Tensor<cpu, 2, DType> d_da(da, Shape2(T * N, 3 * H));
+  if (req_data != kNullOp) {
+    Tensor<cpu, 2, DType> d_dx(dx, Shape2(T * N, I));
+    linalg_gemm(d_da, wx, d_dx, alpha, beta, false, false);
+  }
+
+  // dwx = da.T * x    [3 * H, I] = [3 * H, T * N] * [T * N, I]
+  if (req_params != kNullOp && req_params != kAddTo) {
+    Tensor<cpu, 2, DType> d_dwx(dwx, Shape2(3 * H, I));
+    linalg_gemm(d_da, x, d_dwx, alpha, beta, true, false);
+  }
+
+  if (D == 2) {
+    // Reverse direction: walk time forwards. Its h_{t-1} is the output row
+    // at t + 1 (or the initial state at t == T - 1), second H columns. The
+    // da/dar buffers are reused; the forward-direction dx and dwx are
+    // already computed above.
+    for (int t = 0; t < T; ++t) {
+      if (t == T-1) {
+        back_ht1 = hx_;
+      } else {
+        back_ht1 = y_ptr + (t + 1) * N * D * H;
+      }
+
+      //  add dy[T, N, D, H] to dhy[D, N, H]
+      dyt = dy_ptr + t * N * D * H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; ++i) {
+        for (int j = 0; j < H; ++j) {
+          back_dht1[i * H + j] += dyt[i * D * H + H + j];
+        }
+      }
+
+      rt = back_gateR + t * N * H;
+      zt = back_gateZ + t * N * H;
+      nt = back_gateN + t * N * H;
+      back_Mnht = Mnh + (T + t) * N * H;
+      dat = da + t * N * 3 * H;
+      dart = dar + t * N * 3 * H;
+
+      // Same gate-gradient formulas as the forward direction.
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; ++i) {
+        for (int j = 0; j < H; ++j) {
+          int nid = i * 3 * H + 2 * H + j;
+          int zid = i * 3 * H + H + j;
+          int rid = i * 3 * H + j;
+          int id = i * H + j;
+          dat[nid] = back_dht1[id] * (1 - zt[id]) * (1 - nt[id] * nt[id]);
+          dart[zid] = dat[zid] = back_dht1[id] * (back_ht1[i * D * H + H + j] -
+              nt[id]) * zt[id] * (1 - zt[id]);
+          dart[rid] = dat[rid] = dat[nid] * back_Mnht[id] * rt[id] *
+              (1 - rt[id]);
+          dart[nid] = dat[nid] * rt[id];
+          back_dht1[id] = back_dht1[id] * zt[id];
+        }
+      }
+
+      alpha = 1.0;
+      beta = 1.0;
+      // back_dht1 += dart * back_wh    [N, H] = [N, 3 * H] * [3 * H, H]
+      // Always computed: it feeds da (and therefore dx and dhx) at later steps.
+      Tensor<cpu, 2, DType> d_dart(dart, Shape2(N, 3 * H));
+      Tensor<cpu, 2, DType> d_back_dht1(back_dht1, Shape2(N, H));
+      linalg_gemm(d_dart, back_wh, d_back_dht1, alpha, beta, false, false);
+
+      if (req_params != kNullOp) {
+        // dwh = da.T * ht1     [3 * H, H] = [3 * H, N] * [N, H]
+        // The view starts H elements in, so slice [0] of the transposed copy
+        // is the reverse-direction half of each row.
+        // NOTE(review): the view spans N * D * H elements and so reads H
+        // elements past the row block. Confirm the memory behind back_ht1
+        // (including hx_) allows this.
+        Tensor<cpu, 2, DType> d_back_dwh(back_dwh, Shape2(3 * H, H));
+        Tensor<cpu, 2, DType> d_back_ht1(back_ht1 + H, Shape2(N, D * H));
+        Tensor<cpu, 3, DType> d_back_ht1_tmp = Tensor<cpu, 3, DType>
+            (reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N));
+        d_back_ht1_tmp = reshape(d_back_ht1.T(), Shape3(D, H, N));
+        if (req_params == kAddTo) {
+          beta = 2.0;
+          // dwx = da.T * x    [3 * H, I] = [3 * H, N] * [N, I] for AddTo
+          Tensor<cpu, 2, DType> d_xt(x.dptr_ + t * N * I, Shape2(N, I));
+          Tensor<cpu, 2, DType> d_dat(dat, Shape2(N, 3 * H));
+          Tensor<cpu, 2, DType> d_back_dwx(back_dwx, Shape2(3 * H, I));
+          linalg_gemm(d_dat, d_xt, d_back_dwx, alpha, beta, true, false);
+        }
+        linalg_gemm(d_dart, d_back_ht1_tmp[0], d_back_dwh, alpha, beta, true, true);
+      }
+    }
+
+    if (req_params != kNullOp) {
+      // dbx = e * da       [1, 3 * H] = [1, N] * [N, 3 * H]
+      if (req_params != kAddTo) {
+        #pragma omp parallel for num_threads(omp_threads)
+        for (int i = 0; i < 3 * H; ++i) {
+          for (int j = 0; j < N * T; ++j) {
+            back_dbx[i] += da[j * 3 * H + i];
+            back_dbh[i] += dar[j * 3 * H + i];
+          }
+        }
+      } else {
+        const Tensor<cpu, 2, DType> tmp_dbx(tmp_buf + T * N * D * H, Shape2(H * 3, T));
+        const Tensor<cpu, 2, DType> tmp_dbh(tmp_buf + T * N * D * H + 3 * H * T, Shape2(H * 3, T));
+        memset(tmp_dbx.dptr_, 0, H * T * 3 * sizeof(DType));
+        memset(tmp_dbh.dptr_, 0, H * T * 3 * sizeof(DType));
+
+        for (int t = T - 1; t >= 0; --t) {
+          #pragma omp parallel for num_threads(omp_threads)
+          for (int i = 0; i < 3 * H; ++i) {
+            for (int j = 0; j < N; ++j) {
+              tmp_dbx[i][t] += da[t * N * 3 * H + j * 3 * H + i];
+              tmp_dbh[i][t] += dar[t * N * 3 * H + j * 3 * H + i];
+            }
+          }
+          #pragma omp parallel for num_threads(omp_threads)
+          for (int i = 0; i < 3 * H; ++i) {
+            back_dbx[i] += tmp_dbx[i][t] + back_dbx[i];
+            back_dbh[i] += tmp_dbh[i][t] + back_dbh[i];
+          }
+        }
+      }
+    }
+    alpha = 1.0;
+    beta = 1.0;
+    // dxt = da * wx    [T * N, I] = [T * N, 3 * H] * [3 * H, I]
+    // beta = 1 adds the reverse-direction contribution to the forward one.
+    Tensor<cpu, 2, DType> d_da2(da, Shape2(T * N, 3 * H));
+    if (req_data != kNullOp) {
+      Tensor<cpu, 2, DType> d_dx(dx, Shape2(T * N, I));
+      linalg_gemm(d_da2, back_wx, d_dx, alpha, beta, false, false);
+    }
+    alpha = 1.0;
+    beta = 0.0;
+    // dwx = da.T * x    [3 * H, I] = [3 * H, T * N] * [T * N, I]
+    if (req_params != kNullOp && req_params != kAddTo) {
+      Tensor<cpu, 2, DType> d_back_dwx(back_dwx, Shape2(3 * H, I));
+      linalg_gemm(d_da2, x, d_back_dwx, alpha, beta, true, false);
+    }
+  }
+  // dht1 ([D, N, H]) now holds the gradient w.r.t. the initial state.
+  if (req_state != kNullOp) {
+    memcpy(dhx, dht1, N * H * D * sizeof(DType));
+  }
+}
+
+// Multi-layer GRU backward pass.
+//
+// Walks the layers from L - 1 down to 0 and calls GruBackwardSingleLayer on
+// each one. Gates, per-layer outputs and Mnh are read from the reserve space
+// `rs` using the same region order GruForwardTraining writes:
+//   gateR | gateZ | gateN | y | Mnh   (each region L layers deep)
+// Those regions are followed here by tmp_buf, an internal dx buffer for
+// layers l > 0, and the single-layer workspace ws2. The input gradient of
+// layer l becomes the output gradient of layer l - 1.
+//
+// NOTE(review): `ws` is not referenced; all scratch lives in `rs`. Confirm
+// the reserve space is sized for it. For L > 1 the caller's dy buffer
+// (dy_ptr) is overwritten with the propagated gradient.
+template <typename DType>
+void GruBackward(DType* ws,
+                 DType* rs,
+                 const int L,
+                 const int D,
+                 const int T,
+                 const int N,
+                 int I,
+                 const int H,
+                 DType* x_ptr,
+                 DType* hx_ptr,
+                 DType* w_ptr,
+                 DType* dy_ptr,
+                 DType* dhy_ptr,
+                 DType* dx_ptr,
+                 DType* dhx_ptr,
+                 DType* dw_ptr,
+                 int req_data,
+                 int req_params,
+                 int req_state) {
+  DType* wx = w_ptr;
+  DType* dwx = dw_ptr;
+  DType* dwh = dwx + I * H * 3;
+  // Bias gradients follow all weight gradients (same offset as the weights).
+  DType* dbx = dwh + H * H * 3 + (D - 1) * (H * H * 3 + I * H * 3)
+      + (L - 1) * ((D + 1) * H) * H * 3 * D;
+  // Start at the last layer's slice of each reserve-space region.
+  DType* gateR_l = rs + (L - 1) * T * D * N * H;
+  DType* gateZ_l = gateR_l + L * T * D * N * H;
+  DType* gateN_l = gateZ_l + L * T * D * N * H;
+  DType* y_l = gateN_l + L * T * D * N * H;
+  DType* Mnh_l = y_l + L * T * N * H * D;
+  DType* tmp_buf = Mnh_l + L * D * T * N * H;
+  DType* dx_l = tmp_buf + T * N * D * H + 3 * H * T * 2;
+  DType* ws2 = dx_l + T * N * D * H;
+  // Last layer's weights: layer 0 holds D * (I + H) * 3 * H values and every
+  // later layer D * (D * H + H) * 3 * H.
+  DType* wx_l = (L == 1)? wx : wx + (L - 2) * D * (D + 1) * H * 3 * H
+      + D * I * 3 * H + D * H * 3 * H;
+  DType* wh_l = wx_l;
+  if (L == 1) {
+    wh_l = wh_l + I * H * 3;
+  } else {
+    wh_l = wh_l + (D * H) * H * 3;
+  }
+  DType* dhy_l = NULL;
+  if (dhy_ptr)
+    dhy_l = dhy_ptr + (L - 1) * D * N * H;
+  DType* dwx_l = (L == 1)? dwx : dwx + (L - 2) * D * (D + 1) * H * 3 * H
+      + D * I * 3 * H + D * H * 3 * H;
+  DType* dwh_l = NULL;
+  if (L == 1) {
+    dwh_l = dwx_l + I * H * 3;
+  } else {
+    dwh_l = dwx_l + (D * H) * H * 3;
+  }
+  DType* dbx_l = dbx + (L - 1) * D * 3 * H * 2;
+  DType* dbh_l = dbx_l + 3 * H;
+  DType* dhx_l = dhx_ptr + (L - 1) * D * N * H;
+  DType* dy_l = dy_ptr;
+  Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(L, D * N, H));
+  int inputsize = I;
+  // Input of the last layer is the previous layer's output.
+  DType* y_tmp = y_l - T * N * H * D;
+  for (int l = L - 1; l >= 0; --l) {
+    if (l == 0) {
+      I = inputsize;
+      y_tmp = x_ptr;
+      dx_l = dx_ptr;
+    } else {
+      I = D * H;
+    }
+    // For l > 0, dx_l is an internal buffer that is copied into dy_l below as
+    // the output gradient of layer l - 1, so it must always be written. Only
+    // layer 0's dx (the user-visible input gradient) honours req_data.
+    const int req_data_l = (l == 0) ? req_data : static_cast<int>(kWriteTo);
+    Tensor<cpu, 2, DType> hx_l = hx[l];
+    Tensor<cpu, 2, DType> x_l(y_tmp, Shape2(T * N, I));
+    GruBackwardSingleLayer<DType>(ws2, tmp_buf, D, T, N, I, H, x_l, hx_l, wx_l, wh_l, y_l, dy_l,
+                                  dhy_l, gateR_l, gateZ_l, gateN_l, Mnh_l, dx_l, dhx_l,
+                                  dwx_l, dwh_l, dbx_l, dbh_l, req_data_l, req_params, req_state);
+    if (l > 0) {
+      // Propagate the gradient down and step every per-layer pointer to l - 1.
+      memcpy(dy_l, dx_l, T * N * H * D * sizeof(DType));
+      gateR_l = gateR_l - T * D * N * H;
+      gateZ_l = gateZ_l - T * D * N * H;
+      gateN_l = gateN_l - T * D * N * H;
+      Mnh_l = Mnh_l -  T * D * N * H;
+      dhx_l = dhx_l - D * N * H;
+      if (dhy_l)
+        dhy_l = dhy_l - D * N * H;
+      y_l = y_l - T * N * H * D;
+      y_tmp = y_l;
+      if (l == 1) {
+        // Next layer is layer 0, whose input width is the original I.
+        wx_l = wx_l - (inputsize + H) * H * 3 * D;
+        wh_l = wx_l + inputsize * 3 * H;
+        dwx_l = dwx_l - (inputsize + H) * H * 3 * D;
+        dwh_l = dwx_l + inputsize * 3 * H;
+      } else {
+        wx_l = wx_l - (I + H) * H * 3 * D;
+        wh_l = wx_l + I * 3 * H;
+        dwx_l = dwx_l - (I + H) * H * 3 * D;
+        dwh_l = dwx_l + I * 3 * H;
+      }
+      dbx_l = dbx_l - D * 3 * H * 2;
+      dbh_l = dbx_l + 3 * H;
+    }
+  }
+}
+}  // namespace op
+}  // namespace mxnet
 #endif  // MXNET_OPERATOR_RNN_IMPL_H_
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 1eb23cc9228..ab03973e8e8 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -28,17 +28,17 @@
 from common import setup_module, with_seed
 import unittest
 
-def check_rnn_consistency(cell1, cell2, T, N, I, H):
+def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req):
     dshape = (N, T, I)
     data = mx.sym.Variable('data')
 
     Y1, _ = cell1.unroll(T, data, layout='NTC', merge_outputs=True)
     mod1 = mx.mod.Module(Y1, label_names=None, context=default_context())
-    mod1.bind(data_shapes=[('data', dshape)], label_shapes=None, inputs_need_grad=True)
+    mod1.bind(data_shapes=[('data', dshape)], label_shapes=None, inputs_need_grad=True, grad_req=grad_req)
 
     Y2, _ = cell2.unroll(T, data, layout='NTC', merge_outputs=True)
     mod2 = mx.mod.Module(Y2, label_names=None, context=default_context())
-    mod2.bind(data_shapes=[('data', dshape)], label_shapes=None, inputs_need_grad=True)
+    mod2.bind(data_shapes=[('data', dshape)], label_shapes=None, inputs_need_grad=True, grad_req=grad_req)
 
     mod1.init_params()
     args, auxs = mod1.get_params()
@@ -60,8 +60,14 @@ def check_rnn_consistency(cell1, cell2, T, N, I, H):
 
     dy = mx.random.uniform(shape=mod1.get_outputs()[0].shape)
     mod1.backward(out_grads=[dy])
-    mod2.backward(out_grads=[dy])
-    assert_allclose(mod1.get_input_grads()[0].asnumpy(), mod2.get_input_grads()[0].asnumpy(), rtol=1e-2, atol=1e-4)
+    mod2.backward(out_grads=[dy])    
+    if grad_req != 'null':
+        assert_allclose(mod1.get_input_grads()[0].asnumpy(), mod2.get_input_grads()[0].asnumpy(), rtol=1e-2, atol=1e-4)
+    else:
+        assert(mod1.get_input_grads()[0] == None)
+        assert(mod2.get_input_grads()[0] == None)
+        
+        
 
 @with_seed()
 def test_lstm_sym():
@@ -71,8 +77,10 @@ def test_lstm_sym():
     stack.add(mx.rnn.LSTMCell(H, prefix='l0_'))
     stack.add(mx.rnn.LSTMCell(H, prefix='l1_'))
     stack.add(mx.rnn.LSTMCell(H, prefix='l2_'))
-    check_rnn_consistency(fused, stack, T, N, I, H)
-    check_rnn_consistency(stack, fused, T, N, I, H)
+    
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
 @with_seed()
 def test_lstm_bidirectional():
@@ -90,8 +98,45 @@ def test_lstm_bidirectional():
                 mx.rnn.LSTMCell(H, prefix='r1_'),
                 output_prefix='bi_lstm_1_'))
 
-    check_rnn_consistency(stack, fused, T, N, I, H)
-    check_rnn_consistency(fused, stack, T, N, I, H)
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+
+@with_seed()
+def test_gru_sym():
+    T, N, I, H = 5, 32, 800, 800
+    fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='gru', get_next_state=True, prefix='')
+    stack = mx.rnn.SequentialRNNCell()
+    stack.add(mx.rnn.GRUCell(H, prefix='l0_'))
+    stack.add(mx.rnn.GRUCell(H, prefix='l1_'))
+    stack.add(mx.rnn.GRUCell(H, prefix='l2_'))
+
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+
+@with_seed()
+def test_gru_bidirectional():
+    T, N, I, H = 5, 20, 800, 800
+    
+    fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='gru',
+                                bidirectional=True, get_next_state=True, prefix='')
+    
+    stack = mx.rnn.SequentialRNNCell()
+    stack.add(mx.rnn.BidirectionalCell(
+                mx.rnn.GRUCell(H, prefix='l0_'),
+                mx.rnn.GRUCell(H, prefix='r0_'),
+                output_prefix='bi_gru_0_'))    
+    
+    stack.add(mx.rnn.BidirectionalCell(
+                mx.rnn.GRUCell(H, prefix='l1_'),
+                mx.rnn.GRUCell(H, prefix='r1_'),
+                output_prefix='bi_gru_1_'))
+    
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+
 
 # Currently, fused LSTM operator doesn't support dropout.
 # Will change this test after dropout is supported


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services