Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/06/06 18:38:17 UTC

[incubator-mxnet] branch master updated: [MXNET-107]Fused GRU implementation for CPU (#10311)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 069026a  [MXNET-107]Fused GRU implementation for CPU (#10311)
069026a is described below

commit 069026ab1a9924fd870a625558e000b19b9b9507
Author: Hao Li <ha...@intel.com>
AuthorDate: Thu Jun 7 02:38:03 2018 +0800

    [MXNET-107]Fused GRU implementation for CPU (#10311)
    
    * Add GRU Support and Test Case
    
    * skip the gpu test case that has nothing to do with RNN GRU
    
    * fix robustness bug in GRU backward
    
    * fix bug for unifying weight parameter
    
    * add GRU multi-layer and bidirectional support with test cases
    
    * fix test case bug
    
    * fix test case bug
    
    * fix bug for memory issue
    
    * fix bug in the bidirectional case
    
    * rebase code and fix bug for memory corruption issue
    
    * fix gpu compile issue
    
    * fix bug and enable some test cases
    
    * fix robustness bug
    
    * trigger the build to check if quantize-gpu case is covered
    
    * trigger the build to check if MKLDNN+GPU case is covered
    
    * disable the failing GPU test case MKLDNN_UTIL_FUNC-MemFormat because it is unrelated to this PR; it will be re-enabled once the underlying issue is resolved
    
    * temporarily skip the failing test_reduce test case as it is unrelated to RNN
    
    * enable several test cases
    
    * retrigger the build
    
    * rebase code from lstm
    
    * rebase code to resolve conflicts
    
    * add gru code after resolving conflicts
    
    * fix bug introduced while resolving conflicts
    
    * add Fused GRU code with test case
    
    * retrigger the build
    
    * add GetRecommendedOMPThreadCount for omp
    
    * fix conflict issue
    
    * add gru related code
    
    * fix bug for code
    
    * update code for gru
    
    * retrigger the build
    
    * fix code about gru condition
    
    * enhance test case to test gradient weights and bias
    
    * fix bug for test case
    
    * fix bug for test case
    
    * fix bug about dropout condition and test case
    
    * fix bug for test case
    
    * fix bug for test case
    
    * retrigger the build
    
    * rebase code
    
    * add gru code
    
    * fix namespace issues; remove define and memcpy
    
    * retrigger the build
    
    * fix issues and add cudnn_gru_bucketing.py test case
    
    * retrigger the build
    
    * update cudnn_rnn_bucketing.py test case
    
    * update cudnn_rnn_bucketing.py test case
    
    * update cudnn_rnn_bucketing.py test case
    
    * add check for req[kParams] and kAddTo from cudnn_rnn-inl.h
    
    * retrigger the build
    
    * retrigger the build
    
    * retrigger the build
    
    * add kNullOp check
    
    * retrigger the build
    
    * update kNullOp support and test case for both GRU and LSTM
    
    * update kAddTo support for both GRU and LSTM
---
 ...nn_lstm_bucketing.py => cudnn_rnn_bucketing.py} |  33 +-
 python/mxnet/gluon/rnn/rnn_layer.py                |   2 +-
 src/operator/rnn-inl.h                             |  57 +-
 src/operator/rnn_impl.h                            | 955 ++++++++++++++++++++-
 tests/python/unittest/test_operator.py             |  63 +-
 5 files changed, 1060 insertions(+), 50 deletions(-)
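
For reference, the fused CPU kernels added in src/operator/rnn_impl.h compute the standard GRU cell with the reset gate applied to the hidden-side candidate term (the cuDNN convention). Below is a minimal NumPy sketch of a single forward step; the names (x_t, h_prev, wx, wh, bx, bh) only mirror the pointer names in the diff and are not an MXNet API.

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def gru_step(x_t, h_prev, wx, wh, bx, bh):
        # x_t: (N, I), h_prev: (N, H); wx: (3H, I), wh: (3H, H); bx, bh: (3, H)
        # gate order is r, z, n, matching the rtb/ztb/ntb offsets in the diff
        H = h_prev.shape[1]
        gx = x_t.dot(wx.T).reshape(-1, 3, H)      # gemmC1 in the diff
        gh = h_prev.dot(wh.T).reshape(-1, 3, H)   # gemmC2 in the diff
        r = sigmoid(gx[:, 0] + gh[:, 0] + bx[0] + bh[0])
        z = sigmoid(gx[:, 1] + gh[:, 1] + bx[1] + bh[1])
        n = np.tanh(gx[:, 2] + bx[2] + r * (gh[:, 2] + bh[2]))
        return (1 - z) * n + z * h_prev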

diff --git a/example/rnn/bucketing/cudnn_lstm_bucketing.py b/example/rnn/bucketing/cudnn_rnn_bucketing.py
similarity index 87%
rename from example/rnn/bucketing/cudnn_lstm_bucketing.py
rename to example/rnn/bucketing/cudnn_rnn_bucketing.py
index 84cfc9d..29a66a8 100644
--- a/example/rnn/bucketing/cudnn_lstm_bucketing.py
+++ b/example/rnn/bucketing/cudnn_rnn_bucketing.py
@@ -65,6 +65,8 @@ parser.add_argument('--stack-rnn', default=False,
                     help='stack fused RNN cells to reduce communication overhead')
 parser.add_argument('--dropout', type=float, default='0.0',
                     help='dropout probability (1.0 - keep probability)')
+parser.add_argument('--rnntype', type=str, default='lstm',
+                    help='rnn type: gru and lstm are supported')
 
 #buckets = [32]
 buckets = [10, 20, 30, 40, 50, 60]
@@ -97,13 +99,13 @@ def train(args):
         cell = mx.rnn.SequentialRNNCell()
         for i in range(args.num_layers):
             cell.add(mx.rnn.FusedRNNCell(args.num_hidden, num_layers=1,
-                                         mode='lstm', prefix='lstm_l%d'%i,
+                                         mode=args.rnntype, prefix='%s_l%d'%(args.rnntype,i),
                                          bidirectional=args.bidirectional))
-            if args.dropout > 0 and i < args.num_layers - 1:
-                cell.add(mx.rnn.DropoutCell(args.dropout, prefix='lstm_d%d'%i))
+            if args.dropout > 0 and i < args.num_layers - 1 and args.rnntype == 'lstm':
+                cell.add(mx.rnn.DropoutCell(args.dropout, prefix='%s_d%d'%(args.rnntype,i)))
     else:
         cell = mx.rnn.FusedRNNCell(args.num_hidden, num_layers=args.num_layers, dropout=args.dropout,
-                                   mode='lstm', bidirectional=args.bidirectional)
+                                   mode=args.rnntype, bidirectional=args.bidirectional)
 
     def sym_gen(seq_len):
         data = mx.sym.Variable('data')
@@ -168,16 +170,25 @@ def test(args):
 
     if not args.stack_rnn:
         stack = mx.rnn.FusedRNNCell(args.num_hidden, num_layers=args.num_layers,
-                mode='lstm', bidirectional=args.bidirectional).unfuse()
+                mode=args.rnntype, bidirectional=args.bidirectional).unfuse()
     else:
         stack = mx.rnn.SequentialRNNCell()
         for i in range(args.num_layers):
-            cell = mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='lstm_%dl0_'%i)
-            if args.bidirectional:
-                cell = mx.rnn.BidirectionalCell(
-                        cell,
-                        mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='lstm_%dr0_'%i),
-                        output_prefix='bi_lstm_%d'%i)
+            if args.rnntype == 'lstm':
+                cell = mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='%s_%dl0_'%(args.rnntype,i))
+                if args.bidirectional:
+                    cell = mx.rnn.BidirectionalCell(
+                            cell,
+                            mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='%s_%dr0_'%(args.rnntype,i)),
+                            output_prefix='bi_%s_%d'%(args.rnntype,i))
+            elif args.rnntype == 'gru':
+                cell = mx.rnn.GRUCell(num_hidden=args.num_hidden, prefix='%s_%dl0_'%(args.rnntype,i))
+                if args.bidirectional:
+                    cell = mx.rnn.BidirectionalCell(
+                            cell,
+                            mx.rnn.GRUCell(num_hidden=args.num_hidden, prefix='%s_%dr0_'%(args.rnntype,i)),
+                            output_prefix='bi_%s_%d'%(args.rnntype,i))
+
             stack.add(cell)
 
     def sym_gen(seq_len):
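
With the changes above, the renamed example script selects the cell type via the new flag, for example (other data and model flags assumed unchanged from the original LSTM example):

    python example/rnn/bucketing/cudnn_rnn_bucketing.py --rnntype gru

The default remains lstm, and per-layer DropoutCells are only inserted on the lstm path, as the condition in train() shows.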
diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index 056c1d5..d9dc98e 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -190,7 +190,7 @@ class _RNNLayer(Block):
                 self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
                 self.i2h_weight[i]._finish_deferred_init()
         if inputs.context.device_type == 'gpu' or \
-           self._mode == 'lstm' and not self._dropout:
+           self._mode in ['lstm', 'gru'] and not self._dropout:
             out = self._forward_kernel(inputs, states)
         else:
             out = self._forward(inputs, states)
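
With the rnn_layer.py change, a Gluon GRU (like LSTM) without dropout now takes the fused _forward_kernel path on CPU as well. A small illustrative sketch, assuming the default 'TNC' layout and arbitrary sizes:

    import mxnet as mx

    layer = mx.gluon.rnn.GRU(100, num_layers=2)
    layer.initialize()
    inputs = mx.nd.random.uniform(shape=(5, 3, 10))  # (seq_len, batch, input_size)
    outputs = layer(inputs)                          # dispatches to the fused CPU kernel when dropout is 0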
diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index eded6ae..9953173 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -101,12 +101,14 @@ inline size_t GetRNNWorkspaceSize(int seq_length,
   switch (mode) {
     case rnn_enum::kRnnRelu:
     case rnn_enum::kRnnTanh:
-    case rnn_enum::kGru:
-      LOG(FATAL) << "Only LSTM is supported at the moment";
+      LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
       break;
     case rnn_enum::kLstm:
       size = (seq_length + 1) * batch_size * hidden_size * 4 + batch_size * hidden_size * 2
-             + seq_length * batch_size * hidden_size * direction;
+             + seq_length * batch_size * hidden_size * direction + hidden_size * seq_length * 8;
+      break;
+    case rnn_enum::kGru:
+      size = seq_length * batch_size * hidden_size * direction * 4 + batch_size * hidden_size * 8;
       break;
     default:
       LOG(FATAL) << "unknown RNN mode " << mode;
@@ -125,12 +127,16 @@ inline size_t GetRNNReserveSpaceSize(int num_layer,
   switch (mode) {
     case rnn_enum::kRnnRelu:
     case rnn_enum::kRnnTanh:
-    case rnn_enum::kGru:
-      LOG(FATAL) << "Only LSTM is supported at the moment";
+      LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
       break;
     case rnn_enum::kLstm:
       size = num_layer * direction * seq_length * batch_size * hidden_size * 6;
       break;
+    case rnn_enum::kGru:
+      size = seq_length * batch_size * hidden_size * direction * num_layer * 8 +
+          batch_size * hidden_size * direction * 9 + hidden_size * seq_length * 6 +
+          seq_length * batch_size * 7 * hidden_size * direction;
+      break;
     default:
       LOG(FATAL) << "unknown RNN mode " << mode;
       break;
@@ -221,14 +227,18 @@ void RNNForwardTraining(DType* ws,
   switch (mode) {
     case rnn_enum::kRnnTanh:
     case rnn_enum::kRnnRelu:
-    case rnn_enum::kGru:
-      LOG(FATAL) << "Only LSTM is supported at the moment";
+      LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
       break;
     case rnn_enum::kLstm:
       LstmForwardTraining<DType>(ws, rs, state_outputs, num_layers, direction, seq_length,
                                  batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr,
                                  w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr);
       break;
+    case rnn_enum::kGru:
+      GruForwardTraining<DType>(ws, rs, state_outputs, num_layers, direction, seq_length,
+                                batch_size, input_size, state_size, x_ptr, hx_ptr,
+                                w_ptr, y_ptr, hy_ptr);
+      break;
     default:
       LOG(FATAL) << "unknown RNN mode " << mode;
       break;
@@ -256,14 +266,18 @@ void RNNForwardInference(DType* ws,
   switch (mode) {
     case rnn_enum::kRnnRelu:
     case rnn_enum::kRnnTanh:
-    case rnn_enum::kGru:
-      LOG(FATAL) << "Only LSTM is supported at the moment";
+      LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
       break;
     case rnn_enum::kLstm:
       LstmForwardInference<DType>(ws, state_outputs, num_layers, direction, seq_length,
                                   batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr,
                                   w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr);
       break;
+    case rnn_enum::kGru:
+      GruForwardInference<DType>(ws, state_outputs, num_layers, direction, seq_length,
+                                 batch_size, input_size, state_size, x_ptr, hx_ptr,
+                                 w_ptr, y_ptr, hy_ptr);
+      break;
     default:
       LOG(FATAL) << "unknown RNN mode" << mode;
       break;
@@ -292,16 +306,26 @@ void RNNBackward(DType* ws,
                  DType* dcx_ptr,
                  DType* dw_ptr,
                  DType* db_ptr,
+                 int req_data,
+                 int req_params,
+                 int req_state,
+                 int req_statecell,
                  int mode) {
   switch (mode) {
     case rnn_enum::kRnnRelu:
     case rnn_enum::kRnnTanh:
-    case rnn_enum::kGru:
       break;
     case rnn_enum::kLstm:
       LstmBackward<DType>(ws, rs, num_layers, direction, seq_length, batch_size,
                           input_size, state_size, x_ptr, hx_ptr, cx_ptr, w_ptr, y_ptr,
-                          dy_ptr, dhy_ptr, dcy_ptr, dx_ptr, dhx_ptr, dcx_ptr, dw_ptr, db_ptr);
+                          dy_ptr, dhy_ptr, dcy_ptr, dx_ptr, dhx_ptr, dcx_ptr, dw_ptr, db_ptr,
+                          req_data, req_params, req_state, req_statecell);
+      break;
+    case rnn_enum::kGru:
+      GruBackward<DType>(ws, rs, num_layers, direction, seq_length, batch_size,
+                         input_size, state_size, x_ptr, hx_ptr, w_ptr,
+                         dy_ptr, dhy_ptr, dx_ptr, dhx_ptr, dw_ptr,
+                         req_data, req_params, req_state);
       break;
     default:
       LOG(FATAL) << "unknown RNN mode" << mode;
@@ -330,7 +354,8 @@ class RNNOp : public Operator{
                        const std::vector<TBlob> &aux_args) {
     using namespace mshadow;
     using namespace mshadow::expr;
-    CHECK_EQ(param_.mode, rnn_enum::kLstm) << "Only lstm mode is supported at the moment.";
+    CHECK(param_.mode == rnn_enum::kLstm || param_.mode == rnn_enum::kGru)
+        << "Only lstm and gru mode are supported at the moment.";
     CHECK_EQ(param_.p, 0) << "Dropout is not supported at the moment.";
 
     size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
@@ -442,8 +467,10 @@ class RNNOp : public Operator{
                         const std::vector<TBlob> &aux_args) {
     using namespace mshadow;
     using namespace mshadow::expr;
-    CHECK_EQ(param_.mode, rnn_enum::kLstm) << "Only lstm mode is supported at the moment.";
+    CHECK(param_.mode == rnn_enum::kLstm || param_.mode == rnn_enum::kGru)
+        << "Only lstm and gru mode are supported at the moment.";
     CHECK_EQ(param_.p, 0) << "Dropout is not supported at the moment.";
+
     size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
     size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
     if (!param_.state_outputs) {
@@ -535,6 +562,10 @@ class RNNOp : public Operator{
                        dcx_ptr,
                        dw.dptr_,
                        db_ptr,
+                       req[rnn_enum::kData],
+                       req[rnn_enum::kParams],
+                       req[rnn_enum::kState],
+                       req[rnn_enum::kStateCell],
                        param_.mode);
   }
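
The req_data/req_params/req_state/req_statecell arguments threaded through RNNBackward above carry the standard OpReqType codes (kNullOp, kWriteTo, kAddTo), which map to the grad_req value chosen at bind time. A minimal sketch of exercising all three from Python, mirroring the updated unit tests at the end of this diff:

    import mxnet as mx

    data = mx.sym.Variable('data')
    out, _ = mx.rnn.FusedRNNCell(100, num_layers=1, mode='gru', prefix='').unroll(
        5, data, layout='NTC', merge_outputs=True)
    for grad_req in ('write', 'add', 'null'):  # -> kWriteTo, kAddTo, kNullOp in backward
        mod = mx.mod.Module(out, label_names=None, context=mx.cpu())
        mod.bind(data_shapes=[('data', (3, 5, 20))], label_shapes=None,
                 inputs_need_grad=True, grad_req=grad_req)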
 
diff --git a/src/operator/rnn_impl.h b/src/operator/rnn_impl.h
index 2ee374b..e92a182 100644
--- a/src/operator/rnn_impl.h
+++ b/src/operator/rnn_impl.h
@@ -40,6 +40,10 @@
 #include "./mshadow_op.h"
 #include "./linalg.h"
 
+
+namespace mxnet {
+namespace op {
+
 template<typename DType>
 inline DType sigmoid(DType x) {
   return 1.0f / (1.0f + exp(-x));
@@ -297,6 +301,7 @@ void LstmForwardInference(DType* ws,
 template <typename DType>
 void LstmBackwardSingleLayer(DType* ws,
                              DType* rs,
+                             DType* tmp_buf,
                              bool bid,
                              const int T,
                              const int N,
@@ -314,7 +319,11 @@ void LstmBackwardSingleLayer(DType* ws,
                              DType* dcy_ptr,
                              DType* w_ptr,
                              DType* dw_ptr,
-                             DType* db_ptr) {
+                             DType* db_ptr,
+                             int req_data,
+                             int req_params,
+                             int req_state,
+                             int req_statecell) {
   using namespace mshadow;
   const Tensor<cpu, 2, DType> wx(w_ptr, Shape2(H * 4, I));
   const Tensor<cpu, 2, DType> wh(w_ptr + I * H * 4, Shape2(H * 4, H));
@@ -336,6 +345,7 @@ void LstmBackwardSingleLayer(DType* ws,
   const DType alpha = 1.0;
   const DType beta0 = 0.0;
   const DType beta1 = 1.0;
+  const DType beta2 = 2.0;
   const int cell_size = N * H;
   if (dhy_ptr != NULL) {
     memcpy(dh.dptr_, dhy_ptr, cell_size * sizeof(DType));
@@ -367,24 +377,67 @@ void LstmBackwardSingleLayer(DType* ws,
       difgo[t][j][1][k] = dc[j][k] * cnext[j][k] * ft * (1 - ft);
       difgo[t][j][2][k] = dc[j][k] * it * (1 - gt * gt);
       difgo[t][j][3][k] = dh[j][k] * tc * ot * (1 - ot);
-      dcnext[j][k] = dc[j][k] * ft;
+      if (req_statecell != kNullOp || i > 0) {
+        dcnext[j][k] = dc[j][k] * ft;
+      }
       if (i) {
         htmp[j][k] = y[tnext][j][k + offset];
       }
     }
     Tensor<cpu, 2, DType> dyh(difgo[t].dptr_, Shape2(N, H * 4));
-    linalg_gemm(dyh, wh, dhnext, alpha, beta0, false, false);
-    linalg_gemm(dyh, hnext, dwh, alpha, beta1, true, false);
+    if (req_state != kNullOp || i > 0) {
+      linalg_gemm(dyh, wh, dhnext, alpha, beta0, false, false);
+    }
+    if (req_params != kNullOp) {
+      if (req_params != kAddTo) {
+        linalg_gemm(dyh, hnext, dwh, alpha, beta1, true, false);
+      } else {
+        linalg_gemm(dyh, hnext, dwh, alpha, beta2, true, false);
+
+        //  generate dwx every time step for AddTo
+        Tensor<cpu, 2, DType> x_t(x.dptr_ + i * N * I, Shape2(N, I));
+        Tensor<cpu, 2, DType> dyx_t(difgo.dptr_ + i * N * H * 4, Shape2(N, H * 4));
+        linalg_gemm(dyx_t, x_t, dwx, alpha, beta2, true, false);
+      }
+    }
   }
   Tensor<cpu, 2, DType> dyx(difgo.dptr_, Shape2(T * N, H * 4));
-  linalg_gemm(dyx, wx, dx, alpha, bid ? beta1 : beta0, false, false);
-  linalg_gemm(dyx, x, dwx, alpha, beta0, true, false);
+  if (req_data != kNullOp) {
+    linalg_gemm(dyx, wx, dx, alpha, bid ? beta1 : beta0, false, false);
+  }
+  if (req_params != kNullOp && req_params != kAddTo) {
+    linalg_gemm(dyx, x, dwx, alpha, beta0, true, false);
+  }
   const int row = T * N;
   const int col = H * 4;
-  for (int i = 0; i < row; ++i) {
-    for (int j = 0; j < col; ++j) {
-      dbx[j] += dyx[i][j];
-      dbh[j] = dbx[j];
+  if (req_params != kNullOp) {
+    if (req_params != kAddTo) {
+      for (int i = 0; i < row; ++i) {
+        #pragma omp parallel for num_threads(omp_threads)
+        for (int j = 0; j < col; ++j) {
+          dbx[j] += dyx[i][j];
+          dbh[j] = dbx[j];
+        }
+      }
+    } else {
+      const Tensor<cpu, 2, DType> tmp_dbx(tmp_buf, Shape2(col, T));
+      const Tensor<cpu, 2, DType> tmp_dbh(tmp_buf + col * T, Shape2(col, T));
+      memset(tmp_dbx.dptr_, 0, col * T * sizeof(DType));
+      memset(tmp_dbh.dptr_, 0, col * T * sizeof(DType));
+      for (int t = T - 1; t >= 0; --t) {
+        #pragma omp parallel for num_threads(omp_threads)
+        for (int j = 0; j < col; ++j) {
+          for (int i = 0; i < N; ++i) {
+            tmp_dbx[j][t] += dyx[t * N + i][j];
+            tmp_dbh[j][t] = tmp_dbx[j][t];
+          }
+        }
+        #pragma omp parallel for num_threads(omp_threads)
+        for (int j = 0; j < col; ++j) {
+          dbx[j] += tmp_dbx[j][t] + dbx[j];
+          dbh[j] += tmp_dbh[j][t] + dbh[j];
+        }
+      }
     }
   }
 }
@@ -410,7 +463,13 @@ void LstmBackward(DType* ws,
                   DType* dhx_ptr,
                   DType* dcx_ptr,
                   DType* dw_ptr,
-                  DType* db_ptr) {
+                  DType* db_ptr,
+                  int req_data,
+                  int req_params,
+                  int req_state,
+                  int req_statecell) {
+  DType* tmp_buf = ws;
+  DType* ws2 = tmp_buf + 8 * T * H;
   const int total_layers = D * L;
   Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(total_layers, N, H));
   Tensor<cpu, 3, DType> cx(cx_ptr, Shape3(total_layers, N, H));
@@ -422,7 +481,7 @@ void LstmBackward(DType* ws,
   const int w_size1 = (I + H) * H * 4;      // first layer
   const int w_size2 = (D * H + H) * H * 4;  // other layers
   const int cell_size = N * H;
-  DType* dy_tmp_ptr = ws + T * cell_size * 4 + cell_size * 3;
+  DType* dy_tmp_ptr = ws2 + T * cell_size * 4 + cell_size * 3;
   for (int i = L - 1; i >= 0; --i) {
     const int input_size = i ? H * D : I;
     const int w_size = i ? w_size2 : w_size1;
@@ -437,9 +496,10 @@ void LstmBackward(DType* ws,
     Tensor<cpu, 3, DType> dy(dy_ptr, Shape3(T, N, H * D));
     Tensor<cpu, 2, DType> x(i ? y.dptr_ - r_size : x_ptr, Shape2(T * N, input_size));
     Tensor<cpu, 2, DType> dx(i ? dy_tmp_ptr : dx_ptr, Shape2(T * N, input_size));
-    LstmBackwardSingleLayer<DType>(ws, rs_cur_ptr, false, T, N, input_size, H,
+    LstmBackwardSingleLayer<DType>(ws2, rs_cur_ptr, tmp_buf, false, T, N, input_size, H,
                                    x, hx[idx], cx[idx], y, dy, dx, dhx[idx], dcx[idx],
-                                   dhy_cur_ptr, dcy_cur_ptr, w_cur_ptr, dw_cur_ptr, db_cur_ptr);
+                                   dhy_cur_ptr, dcy_cur_ptr, w_cur_ptr, dw_cur_ptr, db_cur_ptr,
+                                   req_data, req_params, req_state, req_statecell);
     if (D == 2) {
       w_cur_ptr += w_size;
       dw_cur_ptr += w_size;
@@ -447,11 +507,874 @@ void LstmBackward(DType* ws,
       ++idx;
       dhy_cur_ptr = dhy_ptr ? dhy_cur_ptr + cell_size : NULL;
       dcy_cur_ptr = dcy_ptr ? dcy_cur_ptr + cell_size : NULL;
-      LstmBackwardSingleLayer<DType>(ws, rs_cur_ptr, true, T, N, input_size, H,
+      LstmBackwardSingleLayer<DType>(ws2, rs_cur_ptr, tmp_buf, true, T, N, input_size, H,
                                      x, hx[idx], cx[idx], y, dy, dx, dhx[idx], dcx[idx],
-                                     dhy_cur_ptr, dcy_cur_ptr, w_cur_ptr, dw_cur_ptr, db_cur_ptr);
+                                     dhy_cur_ptr, dcy_cur_ptr, w_cur_ptr, dw_cur_ptr, db_cur_ptr,
+                                     req_data, req_params, req_state, req_statecell);
     }
     dy_ptr = dx.dptr_;
   }
 }
+
+template<typename DType>
+void GruForwardInferenceSingleLayer(DType* ws,
+                                    DType* tmp_buf,
+                                    bool state_outputs,
+                                    const int D,
+                                    const int T,
+                                    const int N,
+                                    const int I,
+                                    const int H,
+                                    const Tensor<cpu, 2, DType> &x,
+                                    const Tensor<cpu, 2, DType> &hx,
+                                    DType* wx_ptr,
+                                    DType* wh_ptr,
+                                    DType* bx_ptr,
+                                    DType* bh_ptr,
+                                    DType* y_ptr,
+                                    DType* hy_ptr) {
+  DType* ht = y_ptr;
+  DType* ht_1 = y_ptr;
+  DType* back_ht_1 = y_ptr + (T-1) * N * H * D + H;
+  DType* back_ht = back_ht_1;
+  DType* gemmC1  = ws;              // [D, T, N, 3 * H]
+  DType* gemmC2  = gemmC1 + D * T * N * 3 * H;  // N * 3 * H
+  DType* rt = gemmC2 + N * 3 * H;
+  DType* zt = rt + N * H;
+  DType* nt = zt + N * H;
+  DType* back_wx_ptr = wx_ptr + I * 3 * H + H * 3 * H;
+  DType* back_wh_ptr = wh_ptr + I * 3 * H + H * 3 * H;
+  DType* back_bx_ptr = (bx_ptr != NULL)? bx_ptr + 3 * H * 2 : NULL;
+  DType* back_bh_ptr = (bh_ptr != NULL)? bh_ptr + 3 * H * 2: NULL;
+  DType* back_gemmC1 = gemmC1 + T * N * 3 * H;
+  DType* gemmC1_t = gemmC1;
+
+  const Tensor<cpu, 2, DType> wx(wx_ptr, Shape2(H * 3, I));
+  const Tensor<cpu, 2, DType> wh(wh_ptr, Shape2(H * 3, H));
+  const Tensor<cpu, 2, DType> bx(bx_ptr, Shape2(3, H));
+  const Tensor<cpu, 2, DType> bh(bh_ptr, Shape2(3, H));
+  const Tensor<cpu, 2, DType> back_wx(back_wx_ptr, Shape2(H * 3, I));
+  const Tensor<cpu, 2, DType> back_wh(back_wh_ptr, Shape2(H * 3, H));
+  const Tensor<cpu, 2, DType> back_bx(back_bx_ptr, Shape2(3, H));
+  const Tensor<cpu, 2, DType> back_bh(back_bh_ptr, Shape2(3, H));
+  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  if (D == 1) {
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; i++)
+      for (int j = 0; j < H; j++) {
+        y_ptr[i * H + j] = hx[i][j];
+      }
+  } else {
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; i++)
+      for (int j = 0; j < H; j++) {
+        y_ptr[i * D * H + j] = hx[i][j];
+        back_ht_1[i * D * H + j] = hx[N + i][j];
+    }
+  }
+  Tensor<cpu, 2, DType> dgemmC1(ws, Shape2(T * N, 3 * H));
+  Tensor<cpu, 2, DType> dgemmC2(gemmC2, Shape2(N, 3 * H));
+  Tensor<cpu, 2, DType> dback_gemmC1(back_gemmC1, Shape2(T * N, 3 * H));
+
+  // x * wx.T : [T * N, I] * [I, 3 * H]
+  DType alpha = 1.0;
+  DType beta = 0.0;
+  linalg_gemm(x, wx, dgemmC1, alpha, beta, false, true);
+  if (D == 2) {
+    linalg_gemm(x, back_wx, dback_gemmC1, alpha, beta, false, true);
+  }
+
+  for (int t = 0; t < T; t++) {
+    //  perform the first direction, X * wx and H * wh for each step
+    //  ht-1 * wh, ht-1:[N, H] wh:[3 * H, H]
+    Tensor<cpu, 2, DType> dht_1(ht_1, Shape2(N, D * H));
+    if (D == 1) {
+      linalg_gemm(dht_1, wh, dgemmC2, alpha, beta, false, true);
+    } else {
+      Tensor<cpu, 3, DType> dht_1_tmp = Tensor<cpu, 3, DType>(reinterpret_cast<DType*>(tmp_buf),
+                                     Shape3(D, H, N));
+      dht_1_tmp = reshape(dht_1.T(), Shape3(D, H, N));
+      linalg_gemm(dht_1_tmp[0], wh, dgemmC2, alpha, beta, true, true);
+    }
+    gemmC1_t = gemmC1 + t * N * 3 * H;
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; ++i) {
+      for (int j = 0; j < H; ++j) {
+        int rtb = i * 3 * H;
+        int ztb = i * 3 * H + H;
+        int ntb = i * 3 * H + 2 * H;
+        rt[i * H + j] = sigmoid(gemmC1_t[rtb + j] + gemmC2[rtb + j]
+            + bx[0][j] + bh[0][j]);
+        zt[i * H + j] = sigmoid(gemmC1_t[ztb + j] + gemmC2[ztb + j]
+            + bx[1][j] + bh[1][j]);
+        nt[i * H + j] = tanh(gemmC1_t[ntb + j] + bx[2][j] +
+            rt[i * H + j] * (gemmC2[ntb + j] + bh[2][j]));
+        ht[i * D * H + j] = (1-zt[i * H + j]) * nt[i * H + j] +
+            zt[i * H + j] * ht_1[i * D * H + j];
+      }
+    }
+    ht_1 = ht;
+    ht = ht + D * H * N;
+    //  perform the second direction
+    if (D == 2) {
+      gemmC1_t = back_gemmC1 + (T - 1 - t) * N * 3 * H;
+      Tensor<cpu, 2, DType> dback_ht_1(back_ht_1, Shape2(N, D * H));
+      Tensor<cpu, 3, DType> dback_ht_1_tmp = Tensor<cpu, 3, DType>
+          (reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N));
+      dback_ht_1_tmp = reshape(dback_ht_1.T(), Shape3(D, H, N));
+      linalg_gemm(dback_ht_1_tmp[0], back_wh, dgemmC2, alpha, beta, true, true);
+
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; ++i) {
+        for (int j = 0; j < H; ++j) {
+          int rtb = i * 3 * H;
+          int ztb = i * 3 * H + H;
+          int ntb = i * 3 * H + 2 * H;
+          rt[i * H + j] = sigmoid(gemmC1_t[rtb + j] +
+              gemmC2[rtb + j] + back_bx[0][j] + back_bh[0][j]);
+          zt[i * H + j] = sigmoid(gemmC1_t[ztb + j] +
+              gemmC2[ztb + j] + back_bx[1][j]+ back_bh[1][j]);
+          nt[i * H + j] = tanh(gemmC1_t[ntb + j] + back_bx[2][j]
+              + rt[i * H + j] * (gemmC2[ntb + j] + back_bh[2][j]));
+          back_ht[i * D * H + j] = (1 - zt[i * H + j]) * nt[i * H + j]
+              + zt[i * H + j] * back_ht_1[i * D * H + j];
+        }
+      }
+      back_ht_1 = back_ht;
+      back_ht = back_ht - D * H * N;
+    }
+  }
+  //  copy last state to hy, from(N, H * D) to (D, N, H)
+  if (state_outputs) {
+    if (D == 1) {
+      DType* y_start = y_ptr + (T - 1) * N * H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; i++)
+        for (int j = 0; j < H; j++) {
+          hy_ptr[i * H + j] = y_start[i * H + j];
+        }
+    } else {
+      DType* y_start = y_ptr + (T - 1) * N * H * D;
+      DType* y_back_start = y_ptr + H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; i++)
+        for (int j = 0; j < H; j++) {
+          hy_ptr[i * H + j] = y_start[i * D * H + j];
+          hy_ptr[N * H + i * H + j] = y_back_start[i * D * H + j];
+        }
+    }
+  }
+}
+
+template <typename DType>
+void GruForwardInference(DType* ws,
+                         bool state_outputs,
+                         const int L,
+                         const int D,
+                         const int T,
+                         const int N,
+                         int I,
+                         const int H,
+                         DType* x_ptr,
+                         DType* hx_ptr,
+                         DType* w_ptr,
+                         DType* y_ptr,
+                         DType* hy_ptr) {
+  DType* wx = w_ptr;
+  DType* wh = wx + I * H * 3;
+  DType* bx = wh + H * H * 3 + (D - 1) * (H * H * 3 + I * H * 3)
+      + (L - 1) * ((D + 1) * H) * H * 3 * D;
+  DType* bh = bx + H * 3;
+
+  DType* y_tmp = ws;
+  DType* y_l = x_ptr;
+  DType* tmp_buf = y_tmp + D * T * N * H;
+  DType* ws2 = y_tmp + D * T * N * H + D * H * N;
+
+  DType* wx_l = wx;
+  DType* wh_l = wh;
+  DType* bx_l = bx;
+  DType* bh_l = bh;
+  Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(D * L, N, H));
+  DType* hy_l = hy_ptr;
+  for (int l = 0; l < L; l++) {
+    Tensor<cpu, 2, DType> x_l(y_l, Shape2(T * N, I));
+    if ((L + l) % 2) {
+      y_l = y_ptr;
+    } else {
+      y_l = y_tmp;
+    }
+    Tensor<cpu, 2, DType> hx_l = hx[D * l];
+    GruForwardInferenceSingleLayer<DType>(ws2, tmp_buf, state_outputs, D, T, N, I, H,
+                                        x_l, hx_l, wx_l, wh_l, bx_l, bh_l, y_l, hy_l);
+    hy_l = hy_l + D * N * H;
+    bx_l = bx_l + 3 * H * D * 2;
+    bh_l = bh_l + 3 * H * D * 2;
+    wx_l = wx_l + I * H * 3 * D + H * H * 3 * D;
+    if (l == 0) {
+      I = D * H;
+    }
+    wh_l = wx_l + I * 3 * H;
+  }
+}
+
+
+template<typename DType>
+void GruForwardTrainingSingleLayer(DType* ws,
+                                   DType* tmp_buf,
+                                   bool state_outputs,
+                                   const int D,
+                                   const int T,
+                                   const int N,
+                                   const int I,
+                                   const int H,
+                                   const Tensor<cpu, 2, DType> &x,
+                                   const Tensor<cpu, 2, DType> &hx,
+                                   DType* wx_ptr,
+                                   DType* wh_ptr,
+                                   DType* bx_ptr,
+                                   DType* bh_ptr,
+                                   DType* gateR,
+                                   DType* gateZ,
+                                   DType* gateN,
+                                   DType* Mnh,
+                                   DType* y_ptr,
+                                   DType* hy_ptr) {
+  DType* ht = y_ptr;
+  DType* ht_1 = y_ptr;
+  DType* back_ht_1 = y_ptr + (T - 1)* N * H * D + H;
+  DType* back_ht = back_ht_1;
+
+  DType* gemmC1  = ws;              // [D, T, N, 3 * H]
+  DType* gemmC2  = gemmC1 + D * T * N * 3 * H;  // N * 3 * H
+  DType* rt = gateR;
+  DType* zt = gateZ;
+  DType* nt = gateN;
+  DType* back_wx_ptr = wx_ptr + I * 3 * H + H * 3 * H;
+  DType* back_wh_ptr = wh_ptr + I * 3 * H + H * 3 * H;
+  DType* back_bx_ptr = (bx_ptr != NULL)? bx_ptr + 3 * H * 2 : NULL;
+  DType* back_bh_ptr = (bh_ptr != NULL)? bh_ptr + 3 * H * 2 : NULL;
+  DType* back_gateR = gateR + T * N * H;
+  DType* back_gateZ = gateZ + T * N * H;
+  DType* back_gateN = gateN + T * N * H;
+  DType* back_Mnh = Mnh + T * N * H;
+  DType* back_gemmC1 = gemmC1 + T * N * 3 * H;
+  DType* gemmC1_t = gemmC1;
+
+  const Tensor<cpu, 2, DType> wx(wx_ptr, Shape2(H * 3, I));
+  const Tensor<cpu, 2, DType> wh(wh_ptr, Shape2(H * 3, H));
+  const Tensor<cpu, 2, DType> bx(bx_ptr, Shape2(3, H));
+  const Tensor<cpu, 2, DType> bh(bh_ptr, Shape2(3, H));
+  const Tensor<cpu, 2, DType> back_wx(back_wx_ptr, Shape2(H * 3, I));
+  const Tensor<cpu, 2, DType> back_wh(back_wh_ptr, Shape2(H * 3, H));
+  const Tensor<cpu, 2, DType> back_bx(back_bx_ptr, Shape2(3, H));
+  const Tensor<cpu, 2, DType> back_bh(back_bh_ptr, Shape2(3, H));
+  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  if (D == 1) {
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; i++)
+      for (int j = 0; j < H; j++) {
+        y_ptr[i * H + j] = hx[i][j];
+      }
+  } else {
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; i++)
+      for (int j = 0; j < H; j++) {
+        y_ptr[i * D * H + j] = hx[i][j];
+        back_ht_1[i * D * H + j] = hx[N + i][j];
+    }
+  }
+
+  Tensor<cpu, 2, DType> dgemmC1(ws, Shape2(T * N, 3 * H));
+  Tensor<cpu, 2, DType> dgemmC2(gemmC2, Shape2(N, 3 * H));
+  Tensor<cpu, 2, DType> dback_gemmC1(back_gemmC1, Shape2(T * N, 3 * H));
+
+  // x * wx.T : [T * N, I] * [I, 3 * H]
+  DType alpha = 1.0;
+  DType beta = 0.0;
+  linalg_gemm(x, wx, dgemmC1, alpha, beta, false, true);
+  if (D == 2) {
+    linalg_gemm(x, back_wx, dback_gemmC1, alpha, beta, false, true);
+  }
+
+  for (int t = 0; t < T; t++) {
+    //  perform the first direction, X * wx and H * wh for each step
+    //  ht-1 * wh, ht-1:[N, H] wh:[3 * H, H]
+    Tensor<cpu, 2, DType> dht_1(ht_1, Shape2(N, D * H));
+    if (D == 1) {
+      linalg_gemm(dht_1, wh, dgemmC2, alpha, beta, false, true);
+    } else {
+      Tensor<cpu, 3, DType> dht_1_tmp = Tensor<cpu, 3, DType>(reinterpret_cast<DType*>(tmp_buf),
+                                     Shape3(D, H, N));
+      dht_1_tmp = reshape(dht_1.T(), Shape3(D, H, N));
+      linalg_gemm(dht_1_tmp[0], wh, dgemmC2, alpha, beta, true, true);
+    }
+    rt = gateR + t * N * H;
+    zt = gateZ + t * N * H;
+    nt = gateN + t * N * H;
+    gemmC1_t = gemmC1 + t * N * 3 * H;
+    DType* Mnht = Mnh + t * N * H;
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; ++i) {
+      for (int j = 0; j < H; ++j) {
+        int rtb = i * 3 * H;
+        int ztb = i * 3 * H + H;
+        int ntb = i * 3 * H + 2 * H;
+        Mnht[i * H + j] = gemmC2[ntb + j] + bh[2][j];
+        rt[i * H + j] = sigmoid(gemmC1_t[rtb + j] + gemmC2[rtb + j]
+            + bx[0][j] + bh[0][j]);
+        zt[i * H + j] = sigmoid(gemmC1_t[ztb + j] + gemmC2[ztb + j]
+            + bx[1][j] + bh[1][j]);
+        nt[i * H + j] = tanh(gemmC1_t[ntb + j] + bx[2][j] +
+            rt[i * H + j] * (gemmC2[ntb + j] + bh[2][j]));
+        ht[i * D * H + j] = (1-zt[i * H + j]) * nt[i * H + j] +
+            zt[i * H + j] * ht_1[i * D * H + j];
+      }
+    }
+    ht_1 = ht;
+    ht = ht + D * H * N;
+    //  perform the second direction
+    if (D == 2) {
+      rt = back_gateR + (T - 1 - t) * N * H;
+      zt = back_gateZ + (T - 1 - t) * N * H;
+      nt = back_gateN + (T - 1 - t) * N * H;
+      gemmC1_t = back_gemmC1 + (T - 1 - t) * N * 3 * H;
+      Tensor<cpu, 2, DType> dback_ht_1(back_ht_1, Shape2(N, D * H));
+      Tensor<cpu, 3, DType> dback_ht_1_tmp = Tensor<cpu, 3, DType>
+          (reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N));
+      dback_ht_1_tmp = reshape(dback_ht_1.T(), Shape3(D, H, N));
+      linalg_gemm(dback_ht_1_tmp[0], back_wh, dgemmC2, alpha, beta, true, true);
+
+      DType* back_Mnht = back_Mnh + (T - 1 - t) * N * H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; ++i) {
+        for (int j = 0; j < H; ++j) {
+          int rtb = i * 3 * H;
+          int ztb = i * 3 * H + H;
+          int ntb = i * 3 * H + 2 * H;
+          back_Mnht[i * H + j] = gemmC2[ntb + j] + back_bh[2][j];
+          rt[i * H + j] = sigmoid(gemmC1_t[rtb + j] +
+              gemmC2[rtb + j] + back_bx[0][j] + back_bh[0][j]);
+          zt[i * H + j] = sigmoid(gemmC1_t[ztb + j] +
+              gemmC2[ztb + j] + back_bx[1][j] + back_bh[1][j]);
+          nt[i * H + j] = tanh(gemmC1_t[ntb + j] + back_bx[2][j]
+              + rt[i * H + j] * (gemmC2[ntb + j] + back_bh[2][j]));
+          back_ht[i * D * H + j] = (1 - zt[i * H + j]) * nt[i * H + j]
+              + zt[i * H + j] * back_ht_1[i * D * H + j];
+        }
+      }
+      back_ht_1 = back_ht;
+      back_ht = back_ht - D * H * N;
+    }
+  }
+
+  //  copy last state to hy, from(N, H * D) to (D, N, H)
+  if (state_outputs) {
+    if (D == 1) {
+      DType* y_start = y_ptr + (T - 1) * N * H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; i++)
+        for (int j = 0; j < H; j++) {
+          hy_ptr[i * H + j] = y_start[i * H + j];
+        }
+    } else {
+      DType* y_start = y_ptr + (T - 1) * N * H * D;
+      DType* y_back_start = y_ptr + H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; i++)
+        for (int j = 0; j < H; j++) {
+          hy_ptr[i * H + j] = y_start[i * D * H + j];
+          hy_ptr[N * H + i * H + j] = y_back_start[i * D * H + j];
+        }
+    }
+  }
+}
+
+template <typename DType>
+void GruForwardTraining(DType* ws,
+                        DType* rs,
+                        bool state_outputs,
+                        const int L,
+                        const int D,
+                        const int T,
+                        const int N,
+                        int I,
+                        const int H,
+                        DType* x_ptr,
+                        DType* hx_ptr,
+                        DType* w_ptr,
+                        DType* y_ptr,
+                        DType* hy_ptr) {
+  DType* wx = w_ptr;
+  DType* wh = wx + I * H * 3;
+  DType* bx = wh + H * H * 3 + (D - 1) * (H * H * 3 + I * H * 3)
+      + (L - 1) * ((D + 1) * H) * H * 3 * D;
+  DType* bh = bx + H * 3;
+  Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(D * L, N, H));
+  DType* hy_l = hy_ptr;
+  DType* gateR_l = rs;
+  DType* gateZ_l = gateR_l + L * T * D * N * H;
+  DType* gateN_l = gateZ_l + L * T * D * N * H;
+  DType* y_l = gateN_l + L * T * D * N * H;
+  DType* Mnh_l = y_l + L * T * N * H * D;
+  DType* tmp_buf = Mnh_l + L * D * T * N * H;
+  DType* ws2 = Mnh_l + L * D * T * N * H + D * H * N;
+  DType* wx_l = wx;
+  DType* wh_l = wh;
+  DType* bx_l = bx;
+  DType* bh_l = bh;
+  DType* y_tmp = x_ptr;
+
+  for (int l = 0; l < L; l++) {
+    if (l != 0) {
+      y_tmp = y_l;
+      y_l = y_l + T * N * H * D;
+    }
+    Tensor<cpu, 2, DType> x_l(y_tmp, Shape2(T * N, I));
+    Tensor<cpu, 2, DType> hx_l = hx[D * l];
+    GruForwardTrainingSingleLayer<DType>(ws2, tmp_buf, state_outputs, D, T, N, I, H,
+                                         x_l, hx_l, wx_l, wh_l, bx_l, bh_l,
+                                         gateR_l, gateZ_l, gateN_l, Mnh_l, y_l, hy_l);
+    gateR_l = gateR_l + T * D * N * H;
+    gateZ_l = gateZ_l + T * D * N * H;
+    gateN_l = gateN_l + T * D * N * H;
+    Mnh_l = Mnh_l +  T * D * N * H;
+    hy_l = hy_l + D * N * H;
+    bx_l = bx_l + 3 * H * D * 2;
+    bh_l = bh_l + 3 * H * D * 2;
+
+    wx_l = wx_l + I * H * 3 * D + H * H * 3 * D;
+    if (l == 0) {
+      I = D * H;
+    }
+    wh_l = wx_l + I * 3 * H;
+  }
+  memcpy(y_ptr, y_l, T * N * H * D * sizeof(DType));
+}
+
+template <typename DType>
+void GruBackwardSingleLayer(DType* ws,
+                            DType* tmp_buf,
+                            const int D,
+                            const int T,
+                            const int N,
+                            const int I,
+                            const int H,
+                            const Tensor<cpu, 2, DType> &x,
+                            const Tensor<cpu, 2, DType> &hx,
+                            DType* wx_ptr,
+                            DType* wh_ptr,
+                            DType* y_ptr,
+                            DType* dy_ptr,
+                            DType* dhy_ptr,
+                            DType* gateR,
+                            DType* gateZ,
+                            DType* gateN,
+                            DType* Mnh,
+                            DType* dx,
+                            DType* dhx,
+                            DType* dwx,
+                            DType* dwh,
+                            DType* dbx,
+                            DType* dbh,
+                            int req_data,
+                            int req_params,
+                            int req_state) {
+  DType* dyt;
+  DType* ht1;  // [N, D, H]
+  DType* rt;
+  DType* zt;
+  DType* nt;
+  DType* dat;
+  DType* dart;
+  DType* dar = ws;  // [T, N, 3 * H]
+  DType* da = dar + T * N * 3 * H;  // [T, N, 3 * H]
+  DType* dht1 = da + T * N * 3 * H;  // [D, N, H]
+  DType* hx_ = dht1 + D * N * H;  // [N, D, H]
+  DType* Mnht = Mnh;
+  DType* back_ht1;
+  DType* back_dht1 = dht1 + N * H;  // [N, H]
+  DType* back_Mnht = Mnh + T * N * H;
+  DType* back_gateR = gateR + T * N * H;
+  DType* back_gateZ = gateZ + T * N * H;
+  DType* back_gateN = gateN + T * N * H;
+  DType* back_wx_ptr = wx_ptr + I * 3 * H + H * 3 * H;
+  DType* back_wh_ptr = wh_ptr + I * 3 * H + H * 3 * H;
+  DType* back_dwx = dwx + I * 3 * H + H * 3 * H;
+  DType* back_dwh = dwh + I * 3 * H + H * 3 * H;
+  DType* back_dbx = dbx + 3 * H * 2;
+  DType* back_dbh = dbh + 3 * H * 2;
+
+  DType alpha = 1.0;
+  DType beta = 0.0;
+  const Tensor<cpu, 2, DType> wx(wx_ptr, Shape2(H * 3, I));
+  const Tensor<cpu, 2, DType> wh(wh_ptr, Shape2(H * 3, H));
+  const Tensor<cpu, 2, DType> back_wx(back_wx_ptr, Shape2(H * 3, I));
+  const Tensor<cpu, 2, DType> back_wh(back_wh_ptr, Shape2(H * 3, H));
+  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  #pragma omp parallel for num_threads(omp_threads)
+  for (int i = 0; i < N * H; ++i) {
+    if (dhy_ptr) {
+      dht1[i] = dhy_ptr[i];
+    } else {
+      dht1[i] = 0;
+    }
+  }
+
+  #pragma omp parallel for num_threads(omp_threads)
+  for (int i = 0; i < N; ++i) {
+    for (int j = 0; j < H; ++j) {
+      hx_[i * D * H + j] = hx[i][j];
+    }
+  }
+
+  if (D == 2) {
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N * H; ++i) {
+      if (dhy_ptr) {
+        back_dht1[i] = dhy_ptr[N * H + i];
+      } else {
+        back_dht1[i] = 0;
+      }
+    }
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; ++i) {
+      for (int j = 0; j < H; ++j) {
+        hx_[i * D * H + H + j] = hx[N + i][j];
+      }
+    }
+  }
+  for (int t = T - 1; t >= 0; --t) {
+    if (t) {
+      ht1 = y_ptr + (t - 1) * N * D * H;
+    } else {
+      ht1 = hx_;
+    }
+    // add dy[T, N, D, H] to dhy[D, N, H]
+    dyt = dy_ptr + t * N * D * H;
+
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; ++i) {
+      for (int j = 0; j < H; ++j) {
+        dht1[i * H + j] += dyt[i * D * H + j];
+      }
+    }
+
+    rt = gateR + t * N * H;
+    zt = gateZ + t * N * H;
+    nt = gateN + t * N * H;
+    Mnht = Mnh +  t * N * H;
+    dat = da + t * N * 3 * H;
+    dart = dar + t * N * 3 * H;
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; ++i) {
+      for (int j = 0; j < H; ++j) {
+        int nid = i * 3 * H + 2 * H + j;
+        int zid = i * 3 * H + H + j;
+        int rid = i * 3 * H + j;
+        int id = i * H + j;
+        dat[nid] = dht1[id] * (1 - zt[id]) * (1 - nt[id] * nt[id]);
+        dart[zid] = dat[zid] = dht1[id] * (ht1[i * D * H + j] - nt[id]) *
+            zt[id] * (1 - zt[id]);
+        dart[rid] = dat[rid] = dat[nid] * Mnht[id] * rt[id] *
+            (1 - rt[id]);
+        dart[nid] = dat[nid] * rt[id];
+        dht1[id] = dht1[id] * zt[id];
+      }
+    }
+    if (req_params != kNullOp) {
+      alpha = 1.0;
+      beta = 1.0;
+      // dht1 = dart * wh    [N, H] = [N, 3 * H] * [3 * H, H]
+      Tensor<cpu, 2, DType> d_dht1(dht1, Shape2(N, H));
+      Tensor<cpu, 2, DType> d_dart(dart, Shape2(N, 3 * H));
+      linalg_gemm(d_dart, wh, d_dht1, alpha, beta, false, false);
+
+      if (req_params == kAddTo) {
+        beta = 2.0;
+        // dwx = da.T * x    [3 * H, I] = [3 * H, N] * [N, I] for AddTo
+        Tensor<cpu, 2, DType> d_xt(x.dptr_ + t * N * I, Shape2(N, I));
+        Tensor<cpu, 2, DType> d_dat(dat, Shape2(N, 3 * H));
+        Tensor<cpu, 2, DType> d_dwx(dwx, Shape2(3 * H, I));
+        linalg_gemm(d_dat, d_xt, d_dwx, alpha, beta, true, false);
+      }
+      // dwh = dart.T * ht1    [3 * H, H] = [3 * H, N] * [N, H]
+      Tensor<cpu, 2, DType> d_ht1(ht1, Shape2(N, D * H));
+      Tensor<cpu, 2, DType> d_dwh(dwh, Shape2(3 * H, H));
+      Tensor<cpu, 3, DType> d_ht1_tmp = Tensor<cpu, 3, DType>
+          (reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N));
+      d_ht1_tmp = reshape(d_ht1.T(), Shape3(D, H, N));
+      linalg_gemm(d_dart, d_ht1_tmp[0], d_dwh, alpha, beta, true, true);
+    }
+  }
+
+  if (req_params != kNullOp) {
+    // dbx = e * da       [1, 3 * H] = [1, N] * [N, 3 * H]
+    if (req_params != kAddTo) {
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < 3 * H; ++i) {
+        for (int j = 0; j < N * T; ++j) {
+          dbx[i] += da[j * 3 * H + i];
+          dbh[i] += dar[j * 3 * H + i];
+        }
+      }
+    } else {
+      const Tensor<cpu, 2, DType> tmp_dbx(tmp_buf + T * N * D * H, Shape2(H * 3, T));
+      const Tensor<cpu, 2, DType> tmp_dbh(tmp_buf + T * N * D * H + 3 * H * T, Shape2(H * 3, T));
+      memset(tmp_dbx.dptr_, 0, H * T * 3 * sizeof(DType));
+      memset(tmp_dbh.dptr_, 0, H * T * 3 * sizeof(DType));
+
+      for (int t = T - 1; t >= 0; --t) {
+        #pragma omp parallel for num_threads(omp_threads)
+        for (int i = 0; i < 3 * H; ++i) {
+          for (int j = 0; j < N; ++j) {
+            tmp_dbx[i][t] += da[t * N * 3 * H + j * 3 * H + i];
+            tmp_dbh[i][t] += dar[t * N * 3 * H + j * 3 * H + i];
+          }
+        }
+        #pragma omp parallel for num_threads(omp_threads)
+        for (int i = 0; i < 3 * H; ++i) {
+          dbx[i] += tmp_dbx[i][t] + dbx[i];
+          dbh[i] += tmp_dbh[i][t] + dbh[i];
+        }
+      }
+    }
+  }
+  alpha = 1.0;
+  beta = 0.0;
+
+  // dx = da * wx    [T * N, I] = [T * N, 3 * H] * [3 * H, I]
+  Tensor<cpu, 2, DType> d_da(da, Shape2(T * N, 3 * H));
+  if (req_data != kNullOp) {
+    Tensor<cpu, 2, DType> d_dx(dx, Shape2(T * N, I));
+    linalg_gemm(d_da, wx, d_dx, alpha, beta, false, false);
+  }
+
+  // dwx = da.T * x    [3 * H, I] = [3 * H, T * N] * [T * N, I]
+  if (req_params != kNullOp && req_params != kAddTo) {
+    Tensor<cpu, 2, DType> d_dwx(dwx, Shape2(3 * H, I));
+    linalg_gemm(d_da, x, d_dwx, alpha, beta, true, false);
+  }
+
+  if (D == 2) {
+    for (int t = 0; t < T; ++t) {
+      if (t == T-1) {
+        back_ht1 = hx_;
+      } else {
+        back_ht1 = y_ptr + (t + 1) * N * D * H;
+      }
+
+      //  add dy[T, N, D, H] to dhy[D, N, H]
+      dyt = dy_ptr + t * N * D * H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; ++i) {
+        for (int j = 0; j < H; ++j) {
+          back_dht1[i * H + j] += dyt[i * D * H + H + j];
+        }
+      }
+
+      rt = back_gateR + t * N * H;
+      zt = back_gateZ + t * N * H;
+      nt = back_gateN + t * N * H;
+      back_Mnht = Mnh + (T + t) * N * H;
+      dat = da + t * N * 3 * H;
+      dart = dar + t * N * 3 * H;
+
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; ++i) {
+        for (int j = 0; j < H; ++j) {
+          int nid = i * 3 * H + 2 * H + j;
+          int zid = i * 3 * H + H + j;
+          int rid = i * 3 * H + j;
+          int id = i * H + j;
+          dat[nid] = back_dht1[id] * (1 - zt[id]) * (1 - nt[id] * nt[id]);
+          dart[zid] = dat[zid] = back_dht1[id] * (back_ht1[i * D * H + H + j] -
+              nt[id]) * zt[id] * (1 - zt[id]);
+          dart[rid] = dat[rid] = dat[nid] * back_Mnht[id] * rt[id] *
+              (1 - rt[id]);
+          dart[nid] = dat[nid] * rt[id];
+          back_dht1[id] = back_dht1[id] * zt[id];
+        }
+      }
+
+      if (req_params != kNullOp) {
+        alpha = 1.0;
+        beta = 1.0;
+        // dht1 = da * wh    [N, H] = [N, 3 * H] * [3 * H, H]
+        Tensor<cpu, 2, DType> d_dart(dart, Shape2(N, 3 * H));
+        Tensor<cpu, 2, DType> d_back_dht1(back_dht1, Shape2(N, H));
+        linalg_gemm(d_dart, back_wh, d_back_dht1, alpha, beta, false, false);
+
+        // dwh = da.T * ht1     [3 * H, H] = [3 * H, N] * [N, H]
+        Tensor<cpu, 2, DType> d_back_dwh(back_dwh, Shape2(3 * H, H));
+        Tensor<cpu, 2, DType> d_back_ht1(back_ht1 + H, Shape2(N, D * H));
+        Tensor<cpu, 3, DType> d_back_ht1_tmp = Tensor<cpu, 3, DType>
+            (reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N));
+        d_back_ht1_tmp = reshape(d_back_ht1.T(), Shape3(D, H, N));
+        if (req_params == kAddTo) {
+          beta = 2.0;
+          // dwx = da.T * x    [3 * H, I] = [3 * H, N] * [N, I] for AddTo
+          Tensor<cpu, 2, DType> d_xt(x.dptr_ + t * N * I, Shape2(N, I));
+          Tensor<cpu, 2, DType> d_dat(dat, Shape2(N, 3 * H));
+          Tensor<cpu, 2, DType> d_back_dwx(back_dwx, Shape2(3 * H, I));
+          linalg_gemm(d_dat, d_xt, d_back_dwx, alpha, beta, true, false);
+        }
+        linalg_gemm(d_dart, d_back_ht1_tmp[0], d_back_dwh, alpha, beta, true, true);
+      }
+    }
+
+    if (req_params != kNullOp) {
+    // dbx = e * da       [1, 3 * H] = [1, N] * [N, 3 * H]
+      if (req_params != kAddTo) {
+        #pragma omp parallel for num_threads(omp_threads)
+        for (int i = 0; i < 3 * H; ++i) {
+          for (int j = 0; j < N * T; ++j) {
+            back_dbx[i] += da[j * 3 * H + i];
+            back_dbh[i] += dar[j * 3 * H + i];
+          }
+        }
+      } else {
+        const Tensor<cpu, 2, DType> tmp_dbx(tmp_buf + T * N * D * H, Shape2(H * 3, T));
+        const Tensor<cpu, 2, DType> tmp_dbh(tmp_buf + T * N * D * H + 3 * H * T, Shape2(H * 3, T));
+        memset(tmp_dbx.dptr_, 0, H * T * 3 * sizeof(DType));
+        memset(tmp_dbh.dptr_, 0, H * T * 3 * sizeof(DType));
+
+        for (int t = T - 1; t >= 0; --t) {
+          #pragma omp parallel for num_threads(omp_threads)
+          for (int i = 0; i < 3 * H; ++i) {
+            for (int j = 0; j < N; ++j) {
+              tmp_dbx[i][t] += da[t * N * 3 * H + j * 3 * H + i];
+              tmp_dbh[i][t] += dar[t * N * 3 * H + j * 3 * H + i];
+            }
+          }
+          #pragma omp parallel for num_threads(omp_threads)
+          for (int i = 0; i < 3 * H; ++i) {
+            back_dbx[i] += tmp_dbx[i][t] + back_dbx[i];
+            back_dbh[i] += tmp_dbh[i][t] + back_dbh[i];
+          }
+        }
+      }
+    }
+    alpha = 1.0;
+    beta = 1.0;
+    // dxt = da * wx    [T * N, I] = [T * N, 3 * H] * [3 * H, I]
+    Tensor<cpu, 2, DType> d_da2(da, Shape2(T * N, 3 * H));
+    if (req_data != kNullOp) {
+      Tensor<cpu, 2, DType> d_dx(dx, Shape2(T * N, I));
+      linalg_gemm(d_da2, back_wx, d_dx, alpha, beta, false, false);
+    }
+    alpha = 1.0;
+    beta = 0.0;
+    // dwx = da.T * x    [3 * H, I] = [3 * H, T * N] * [T * N, I]
+    if (req_params != kNullOp && req_params != kAddTo) {
+      Tensor<cpu, 2, DType> d_back_dwx(back_dwx, Shape2(3 * H, I));
+      linalg_gemm(d_da2, x, d_back_dwx, alpha, beta, true, false);
+    }
+  }
+  if (req_state != kNullOp) {
+    memcpy(dhx, dht1, N * H * D * sizeof(DType));
+  }
+}
+
+template <typename DType>
+void GruBackward(DType* ws,
+                 DType* rs,
+                 const int L,
+                 const int D,
+                 const int T,
+                 const int N,
+                 int I,
+                 const int H,
+                 DType* x_ptr,
+                 DType* hx_ptr,
+                 DType* w_ptr,
+                 DType* dy_ptr,
+                 DType* dhy_ptr,
+                 DType* dx_ptr,
+                 DType* dhx_ptr,
+                 DType* dw_ptr,
+                 int req_data,
+                 int req_params,
+                 int req_state) {
+  DType* wx = w_ptr;
+  DType* dwx = dw_ptr;
+  DType* dwh = dwx + I * H * 3;
+  DType* dbx = dwh + H * H * 3 + (D - 1) * (H * H * 3 + I * H * 3)
+      + (L - 1) * ((D + 1) * H) * H * 3 * D;
+  DType* gateR_l = rs + (L - 1) * T * D * N * H;
+  DType* gateZ_l = gateR_l + L * T * D * N * H;
+  DType* gateN_l = gateZ_l + L * T * D * N * H;
+  DType* y_l = gateN_l + L * T * D * N * H;
+  DType* Mnh_l = y_l + L * T * N * H * D;
+  DType* tmp_buf = Mnh_l + L * D * T * N * H;
+  DType* dx_l = tmp_buf + T * N * D * H + 3 * H * T * 2;
+  DType* ws2 = dx_l + T * N * D * H;
+  DType* wx_l = (L == 1)? wx : wx + (L - 2) * D * (D + 1) * H * 3 * H
+      + D * I * 3 * H + D * H * 3 * H;
+  DType* wh_l = wx_l;
+  if (L == 1) {
+    wh_l = wh_l + I * H * 3;
+  } else {
+    wh_l = wh_l + (D * H) * H * 3;
+  }
+  DType* dhy_l = NULL;
+  if (dhy_ptr)
+    dhy_l = dhy_ptr + (L - 1) * D * N * H;
+  DType* dwx_l = (L == 1)? dwx : dwx + (L - 2) * D * (D + 1) * H * 3 * H
+      + D * I * 3 * H + D * H * 3 * H;
+  DType* dwh_l = NULL;
+  if (L == 1) {
+    dwh_l = dwx_l + I * H * 3;
+  } else {
+    dwh_l = dwx_l + (D * H) * H * 3;
+  }
+  DType* dbx_l = dbx + (L - 1) * D * 3 * H * 2;
+  DType* dbh_l = dbx_l + 3 * H;
+  DType* dhx_l = dhx_ptr + (L - 1) * D * N * H;
+  DType* dy_l = dy_ptr;
+  Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(L, D * N, H));
+  int inputsize = I;
+  DType* y_tmp = y_l - T * N * H * D;
+  for (int l = L - 1; l >= 0; --l) {
+    if (l == 0) {
+      I = inputsize;
+      y_tmp = x_ptr;
+      dx_l = dx_ptr;
+    } else {
+      I = D * H;
+    }
+    Tensor<cpu, 2, DType> hx_l = hx[l];
+    Tensor<cpu, 2, DType> x_l(y_tmp, Shape2(T * N, I));
+    GruBackwardSingleLayer<DType>(ws2, tmp_buf, D, T, N, I, H, x_l, hx_l, wx_l, wh_l, y_l, dy_l,
+                                  dhy_l, gateR_l, gateZ_l, gateN_l, Mnh_l, dx_l, dhx_l,
+                                  dwx_l, dwh_l, dbx_l, dbh_l, req_data, req_params, req_state);
+    if (l > 0) {
+      memcpy(dy_l, dx_l, T * N * H * D * sizeof(DType));
+      gateR_l = gateR_l - T * D * N * H;
+      gateZ_l = gateZ_l - T * D * N * H;
+      gateN_l = gateN_l - T * D * N * H;
+      Mnh_l = Mnh_l -  T * D * N * H;
+      dhx_l = dhx_l - D * N * H;
+      if (dhy_l)
+        dhy_l = dhy_l - D * N * H;
+      y_l = y_l - T * N * H * D;
+      y_tmp = y_l;
+      if (l == 1) {
+        wx_l = wx_l - (inputsize + H) * H * 3 * D;
+        wh_l = wx_l + inputsize * 3 * H;
+        dwx_l = dwx_l - (inputsize + H) * H * 3 * D;
+        dwh_l = dwx_l + inputsize * 3 * H;
+      } else {
+        wx_l = wx_l - (I + H) * H * 3 * D;
+        wh_l = wx_l + I * 3 * H;
+        dwx_l = dwx_l - (I + H) * H * 3 * D;
+        dwh_l = dwx_l + I * 3 * H;
+      }
+      dbx_l = dbx_l - D * 3 * H * 2;
+      dbh_l = dbx_l + 3 * H;
+    }
+  }
+}
+}  // namespace op
+}  // namespace mxnet
 #endif  // MXNET_OPERATOR_RNN_IMPL_H_
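
For readability, the per-step gate gradients computed in the inner loops of GruBackwardSingleLayer can be written out as below. This is only a single-step NumPy-style sketch of the gate pre-activation gradients (the dat/dart buffers in the diff), where Mn stands for the cached hidden-side candidate term (Mnh); the GEMMs that afterwards accumulate dwx/dwh/db and propagate dh to the previous step are omitted. The function name and arguments are illustrative, not MXNet API.

    def gru_gate_grads(dh, h_prev, r, z, n, Mn):
        # all arguments are (N, H) NumPy arrays for one time step
        dn  = dh * (1 - z) * (1 - n * n)        # dat[nid]
        dz  = dh * (h_prev - n) * z * (1 - z)   # dat[zid] == dart[zid]
        dr  = dn * Mn * r * (1 - r)             # dat[rid] == dart[rid]
        dnr = dn * r                            # dart[nid], fed into the wh GEMM
        dh_prev = dh * z                        # dht1 before adding dart.dot(wh)
        return dn, dz, dr, dnr, dh_prev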
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 1eb23cc..ab03973 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -28,17 +28,17 @@ from mxnet.base import py_str, MXNetError
 from common import setup_module, with_seed
 import unittest
 
-def check_rnn_consistency(cell1, cell2, T, N, I, H):
+def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req):
     dshape = (N, T, I)
     data = mx.sym.Variable('data')
 
     Y1, _ = cell1.unroll(T, data, layout='NTC', merge_outputs=True)
     mod1 = mx.mod.Module(Y1, label_names=None, context=default_context())
-    mod1.bind(data_shapes=[('data', dshape)], label_shapes=None, inputs_need_grad=True)
+    mod1.bind(data_shapes=[('data', dshape)], label_shapes=None, inputs_need_grad=True, grad_req=grad_req)
 
     Y2, _ = cell2.unroll(T, data, layout='NTC', merge_outputs=True)
     mod2 = mx.mod.Module(Y2, label_names=None, context=default_context())
-    mod2.bind(data_shapes=[('data', dshape)], label_shapes=None, inputs_need_grad=True)
+    mod2.bind(data_shapes=[('data', dshape)], label_shapes=None, inputs_need_grad=True, grad_req=grad_req)
 
     mod1.init_params()
     args, auxs = mod1.get_params()
@@ -60,8 +60,14 @@ def check_rnn_consistency(cell1, cell2, T, N, I, H):
 
     dy = mx.random.uniform(shape=mod1.get_outputs()[0].shape)
     mod1.backward(out_grads=[dy])
-    mod2.backward(out_grads=[dy])
-    assert_allclose(mod1.get_input_grads()[0].asnumpy(), mod2.get_input_grads()[0].asnumpy(), rtol=1e-2, atol=1e-4)
+    mod2.backward(out_grads=[dy])    
+    if grad_req != 'null':
+        assert_allclose(mod1.get_input_grads()[0].asnumpy(), mod2.get_input_grads()[0].asnumpy(), rtol=1e-2, atol=1e-4)
+    else:
+        assert(mod1.get_input_grads()[0] == None)
+        assert(mod2.get_input_grads()[0] == None)
+        
+        
 
 @with_seed()
 def test_lstm_sym():
@@ -71,8 +77,10 @@ def test_lstm_sym():
     stack.add(mx.rnn.LSTMCell(H, prefix='l0_'))
     stack.add(mx.rnn.LSTMCell(H, prefix='l1_'))
     stack.add(mx.rnn.LSTMCell(H, prefix='l2_'))
-    check_rnn_consistency(fused, stack, T, N, I, H)
-    check_rnn_consistency(stack, fused, T, N, I, H)
+    
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
 @with_seed()
 def test_lstm_bidirectional():
@@ -90,8 +98,45 @@ def test_lstm_bidirectional():
                 mx.rnn.LSTMCell(H, prefix='r1_'),
                 output_prefix='bi_lstm_1_'))
 
-    check_rnn_consistency(stack, fused, T, N, I, H)
-    check_rnn_consistency(fused, stack, T, N, I, H)
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+
+@with_seed()
+def test_gru_sym():
+    T, N, I, H = 5, 32, 800, 800
+    fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='gru', get_next_state=True, prefix='')
+    stack = mx.rnn.SequentialRNNCell()
+    stack.add(mx.rnn.GRUCell(H, prefix='l0_'))
+    stack.add(mx.rnn.GRUCell(H, prefix='l1_'))
+    stack.add(mx.rnn.GRUCell(H, prefix='l2_'))
+
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+
+@with_seed()
+def test_gru_bidirectional():
+    T, N, I, H = 5, 20, 800, 800
+    
+    fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='gru',
+                                bidirectional=True, get_next_state=True, prefix='')
+    
+    stack = mx.rnn.SequentialRNNCell()
+    stack.add(mx.rnn.BidirectionalCell(
+                mx.rnn.GRUCell(H, prefix='l0_'),
+                mx.rnn.GRUCell(H, prefix='r0_'),
+                output_prefix='bi_gru_0_'))    
+    
+    stack.add(mx.rnn.BidirectionalCell(
+                mx.rnn.GRUCell(H, prefix='l1_'),
+                mx.rnn.GRUCell(H, prefix='r1_'),
+                output_prefix='bi_gru_1_'))
+    
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+
 
 # Currently, fused LSTM operator doesn't support dropout.
 # Will change this test after dropout is supported

-- 
To stop receiving notification emails like this one, please contact
jxie@apache.org.