You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/03/28 04:58:08 UTC

[GitHub] lihaofd closed pull request #10289: [WIP][MXNET-107]Fused GRU implementation for CPU

lihaofd closed pull request #10289: [WIP][MXNET-107]Fused GRU implementation for CPU
URL: https://github.com/apache/incubator-mxnet/pull/10289
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index 13c077dd9e3..163588ad1d3 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -21,7 +21,7 @@
  * Copyright (c) 2015 by Contributors
  * \file rnn-inl.h
  * \brief
- * \author Sebastian Bodenstein
+ * \author Sebastian Bodenstein, Shu Zhang(shu.zhang@intel.com)
 */
 #ifndef MXNET_OPERATOR_RNN_INL_H_
 #define MXNET_OPERATOR_RNN_INL_H_
@@ -29,6 +29,7 @@
 #include <dmlc/logging.h>
 #include <dmlc/parameter.h>
 #include <mxnet/operator.h>
+#include <mxnet/storage.h>
 #include <algorithm>
 #include <map>
 #include <vector>
@@ -37,8 +38,7 @@
 #include "./math.h"
 #include "./math_functions-inl.h"
 #include "./operator_common.h"
-#include "./mshadow_op.h"
-#include "./linalg.h"
+#include "./rnn_impl.hpp"
 
 namespace mxnet {
 namespace op {
@@ -50,18 +50,37 @@ namespace rnn_enum {
   enum RNNOpResource {kTempSpace};
 }
 
-// A utility function to calculate input size
-inline int rnn_single_param_size(int inputSize,
-                                int hiddenSize,
-                                int mode) {
-  int size = hiddenSize * (hiddenSize + inputSize + 2);
-  // Different RNN's have different num weights
+inline int GetRnnParamSize(int num_layer,
+                           int input_size,
+                           int state_size,
+                           int direction,
+                           int mode) {  // returns the parameter count summed over all layers
+  int size = state_size * direction;  // scaled per mode below: x4 for LSTM, x3 for GRU
   switch (mode) {
     case rnn_enum::kRnnRelu:
-      size *= 1;
+    case rnn_enum::kRnnTanh:
+      break;
+    case rnn_enum::kLstm:
+      size *= 4;
       break;
+    case rnn_enum::kGru:
+      size *= 3;
+      break;
+  }
+  int size1 = (input_size + state_size + 2) * size;  // first layer size
+  int size2 = (state_size * direction + state_size + 2) * size;  // other layers size
+  int param_size = size1 + (num_layer - 1) * size2;  // one first layer + (num_layer - 1) others
+  return param_size;
+}
+
+inline int GetRnnBiasSize(int num_layer,
+                           int state_size,
+                           int direction,
+                           int mode) {
+  int size = 2 * state_size * direction * num_layer;
+  switch (mode) {
+    case rnn_enum::kRnnRelu:
     case rnn_enum::kRnnTanh:
-      size *= 1;
       break;
     case rnn_enum::kLstm:
       size *= 4;
@@ -73,19 +92,46 @@ inline int rnn_single_param_size(int inputSize,
   return size;
 }
 
-inline int rnn_param_size(int layerNum,
-                          int inputSize,
-                          int hiddenSize,
-                          bool bidirectional,
-                          int mode) {
-  // get size of first layer
-  int size = rnn_single_param_size(inputSize, hiddenSize, mode);
-  // get size of remaining layers
-  if (bidirectional) {
-    size += (layerNum - 1) * rnn_single_param_size(2 * hiddenSize, hiddenSize, mode);
-    size *= 2;
-  } else {
-    size += (layerNum - 1) * rnn_single_param_size(hiddenSize, hiddenSize, mode);
+// Element count (in DType units) of the temp workspace; `direction` is not used yet.
+// Sizes are accumulated in size_t so the product of the int dimensions cannot overflow.
+inline size_t GetRNNWorkspaceSize(int seq_length, int batch_size, int hidden_size,
+                                  int direction, int mode) {
+  const size_t cells = static_cast<size_t>(batch_size) * hidden_size;  // one time step
+  size_t size = 0;
+  switch (mode) {
+    case rnn_enum::kRnnRelu:
+    case rnn_enum::kRnnTanh:
+    case rnn_enum::kGru:
+      size = seq_length * cells * 4 + cells * 6;
+      break;
+    case rnn_enum::kLstm:
+      LOG(FATAL) << "Only GRU is supported at the moment";
+      break;
+    default:
+      LOG(FATAL) << "unknown RNN mode " << mode;
+      break;
+  }
+  return size;
+}
+
+// Element count (in DType units) of the reserve space that training keeps from forward.
+// Sizes are accumulated in size_t so the product of the int dimensions cannot overflow.
+inline size_t GetRNNReserveSpaceSize(int seq_length, int batch_size,
+                                     int hidden_size, int mode) {
+  const size_t cells = static_cast<size_t>(batch_size) * hidden_size;  // one time step
+  size_t size = 0;
+  switch (mode) {
+    case rnn_enum::kRnnRelu:
+    case rnn_enum::kRnnTanh:
+    case rnn_enum::kGru:
+      size = seq_length * cells * 5 + cells * 7 + seq_length * cells * 3 * 2;
+      break;
+    case rnn_enum::kLstm:
+      LOG(FATAL) << "Only GRU is supported at the moment";
+      break;
+    default:
+      LOG(FATAL) << "unknown RNN mode " << mode;
+      break;
   }
   return size;
 }
@@ -123,420 +169,457 @@ struct RNNParam : public dmlc::Parameter<RNNParam> {
     DMLC_DECLARE_FIELD(state_outputs).set_default(false)
     .describe("Whether to have the states as symbol outputs.");
   }
-};
 
-template<typename xpu, typename DType>
-class RNNOp : public Operator {
- public:
-  explicit RNNOp(RNNParam p) {
+  // Field-wise equality: true only when every member compared below matches.
+  bool operator==(const RNNParam& other) const {
+    const bool same_config = state_size == other.state_size &&
+                             num_layers == other.num_layers &&
+                             bidirectional == other.bidirectional &&
+                             state_outputs == other.state_outputs && mode == other.mode;
+    const bool same_shape = seq_length_ == other.seq_length_ &&
+                            batch_size_ == other.batch_size_ &&
+                            input_size_ == other.input_size_;
+    return same_config && same_shape && lstm_q_ == other.lstm_q_;
   }
+};
 
-  virtual void Forward(const OpContext &ctx,
-                       const std::vector<TBlob> &in_data,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &out_data,
-                       const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    // TODO(sbodenstein): add MShadow implementation
+typedef ParamOpSign<RNNParam> RNNSignature;
+
+/**
+ * @params: ws: Temp workspace for gemm's output storage.
+ *          rs: Reserve space of forward intermediate data used for training.
+ *          num_layers: The number of recurrent layers.
+ *          direction: direction is 2 if use bidirectional recurrent layers, else is 1;
+ *          seq_length: The number of iterations to unroll over.
+ *          batch_size: size of batch.
+ *          input_size: The number of expected input features.
+ *          state_size: The number of hidden state features.
+ *          x_ptr: Pointer of tensor x containing the features of the input sequence.
+ *                 x's shape is [seq_length, batch_size, input_size]
+ *          hx_ptr: Pointer of tensor hx containing the initial hidden state.
+ *                  hx's shape is [num_layers, batch_size, state_size]
+ *          cx_ptr: Only used in lstm mode. Pointer of tensor cx containing the initial cell state.
+ *                  cx's shape is [num_layers, batch_size, state_size]
+ *          w_ptr: Pointer of tensor w containing weights.
+ *          b_ptr: Pointer to the bias section located at the end of tensor w.
+ *          y_ptr: Pointer of tensor y containing the output features from the
+ *                 last layer of the RNN. y's shape is [seq_length, batch_size, state_size]
+ *          hy_ptr: Pointer of tensor hy containing the hidden state for t=seq_length.
+ *                  hy's shape is [num_layers, batch_size, state_size]
+ *          cy_ptr: Only used in lstm mode. Pointer of tensor cy containing the cell state
+ *                  for t=seq_length. cy's shape is [num_layers, batch_size, state_size]
+ *          mode: Specifies the type of RNN to compute.
+ */
+template <typename DType>
+void RNNForwardTraining(DType* ws,
+                        DType* rs,
+                        bool state_outputs,
+                        const int num_layers, const int direction,
+                        const int seq_length, const int batch_size,
+                        const int input_size, const int state_size,
+                        DType* x_ptr,
+                        DType* hx_ptr, DType* cx_ptr,
+                        DType* w_ptr,
+                        DType* y_ptr,
+                        DType* hy_ptr, DType* cy_ptr,
+                        int mode) {
+  // Training-time forward dispatcher: selects the kernel for `mode` and gives
+  // it the reserve space `rs`. ws, cx_ptr and cy_ptr are accepted but are not
+  // passed on to the kernel.
+  // Relu, Tanh and GRU modes are all routed to GruForwardTraining.
+  const bool use_gru_kernel = mode == rnn_enum::kRnnRelu ||
+                              mode == rnn_enum::kRnnTanh ||
+                              mode == rnn_enum::kGru;
+  if (use_gru_kernel) {
+    GruForwardTraining<DType>(rs, state_outputs,
+                              num_layers, direction,
+                              seq_length, batch_size,
+                              input_size, state_size,
+                              x_ptr, hx_ptr,
+                              w_ptr, y_ptr, hy_ptr);
+  } else if (mode == rnn_enum::kLstm) {
+    LOG(FATAL) << "Only GRU is supported at the moment";
+  } else {
+    // Not one of the four modes handled above.
+    LOG(FATAL) << "unknown RNN mode " << mode;
   }
+}
 
-  virtual void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    // TODO(sbodenstein): add MShadow implementation
+template <typename DType>
+void RNNForwardInference(DType* ws,
+                         bool state_outputs,
+                         const int num_layers, const int direction,
+                         const int seq_length, const int batch_size,
+                         const int input_size, const int state_size,
+                         DType* x_ptr,
+                         DType* hx_ptr, DType* cx_ptr,
+                         DType* w_ptr, DType* b_ptr,
+                         DType* y_ptr,
+                         DType* hy_ptr, DType* cy_ptr,
+                         int mode) {
+  // Inference-time forward dispatcher: selects the kernel for `mode` and gives
+  // it the temp workspace `ws`.
+  // cx_ptr, b_ptr and cy_ptr are accepted but are not passed on to the
+  // kernel.
+  // Relu, Tanh and GRU modes are all routed to GruForwardInference.
+  const bool use_gru_kernel = mode == rnn_enum::kRnnRelu ||
+                              mode == rnn_enum::kRnnTanh ||
+                              mode == rnn_enum::kGru;
+  if (use_gru_kernel) {
+    GruForwardInference<DType>(ws, state_outputs,
+                               num_layers, direction,
+                               seq_length, batch_size,
+                               input_size, state_size,
+                               x_ptr, hx_ptr,
+                               w_ptr, y_ptr, hy_ptr);
+  } else if (mode == rnn_enum::kLstm) {
+    LOG(FATAL) << "Only GRU is supported at the moment";
+  } else {
+    // Not one of the four modes handled above.
+    LOG(FATAL) << "unknown RNN mode " << mode;
   }
+}
 
- private:
-  RNNParam param_;
-};  // class RNNOp
+/**
+ * Dispatches the backward pass to the implementation selected by `mode`.
+ * Only rnn_enum::kGru computes gradients (through GruBackward); kLstm and
+ * unrecognized modes are reported through LOG(FATAL).
+ * ws, cx_ptr, y_ptr, dcy_ptr and dcx_ptr are accepted but are not forwarded
+ * to GruBackward.
+ */
+template <typename DType>
+void RNNBackward(DType* ws, DType* rs,
+                 const int num_layers, const int direction,
+                 const int seq_length, const int batch_size,
+                 const int input_size, const int state_size,
+                 DType* x_ptr, DType* hx_ptr, DType* cx_ptr,
+                 DType* w_ptr, DType* y_ptr,
+                 DType* dy_ptr, DType* dhy_ptr, DType* dcy_ptr,
+                 DType* dx_ptr, DType* dhx_ptr, DType* dcx_ptr,
+                 DType* dw_ptr,
+                 int mode) {
+  switch (mode) {
+    case rnn_enum::kRnnRelu:
+    case rnn_enum::kRnnTanh:
+      // NOTE(review): no gradient is written for these modes; confirm this is intended.
+      break;
+    case rnn_enum::kLstm:
+      LOG(FATAL) << "Only GRU is supported at the moment";
+      break;
+    case rnn_enum::kGru:
+      GruBackward<DType>(rs, num_layers, direction, seq_length, batch_size,
+                         input_size, state_size, x_ptr, hx_ptr, w_ptr,
+                         dy_ptr, dhy_ptr, dx_ptr, dhx_ptr, dw_ptr);
+      break;
+    default:
+      // Report an unrecognized mode instead of returning silently, as the forward dispatchers do.
+      LOG(FATAL) << "unknown RNN mode " << mode;
+      break;
+  }
+}
 
 template<typename DType>
-class RNNOp<cpu, DType> : public Operator {
+class RNNOp {
  public:
-  explicit RNNOp(RNNParam param) {
-    this->param_ = param;
-    // RNN Mode
-    param_.lstm_q_ = false;
-    switch (param_.mode) {
-      case rnn_enum::kLstm:
-        param_.lstm_q_ = true;
-        break;
-      default:
-        LOG(FATAL) << "only LSTM is implmented on CPU";
+  explicit RNNOp(RNNParam p) {
+    param_ = p;
+    init_space_ = false;  // no reserve space has been allocated yet
+    reserve_space_size_ = 0;
+  }
+
+  ~RNNOp() {
+    if (init_space_) {  // only free a reserve space that Forward actually allocated
+      Storage::Get()->Free(reserve_space_);
     }
   }
 
-  virtual void Forward(const OpContext &ctx,
-                       const std::vector<TBlob> &in_data,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &out_data,
-                       const std::vector<TBlob> &aux_args) {
-    // Layout TNC
-    CHECK(!ctx.is_train) << "only inference mode is available"
-      "for cpu at the moment.";
-    size_t in_expected = param_.lstm_q_ ? 4 : 3;
-    size_t out_expected = param_.lstm_q_ ? 3 : 2;
-
-    if (!param_.state_outputs)
-      LOG(FATAL) << "no state outputs is currently not supported for cpu.";
-
-    CHECK_EQ(req[rnn_enum::kOut], kWriteTo);
+  void Forward(const OpContext &ctx,
+               const std::vector<TBlob> &in_data,
+               const std::vector<OpReqType> &req,
+               const std::vector<TBlob> &out_data) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    CHECK_EQ(param_.mode, rnn_enum::kGru) << "Only gru mode is supported at the moment while param_.mode is:" << param_.mode;
+
+    size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
+    size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
+    if (!param_.state_outputs) {
+      out_expected = 1;
+    }
     CHECK_EQ(in_data.size(), in_expected);
     CHECK_EQ(out_data.size(), out_expected);
-
-    mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
-    // get input + output tensors
-    // w layout i2h_w, h2h_w, i2h_b, h2h_b
-    Tensor<cpu, 3, DType> x =
-        in_data[rnn_enum::kData].get<cpu, 3, DType>(s);  // TNC
+    Stream<cpu> *s = ctx.get_stream<cpu>();
+    // get input + output tensor
+    Tensor<cpu, 3, DType> x = in_data[rnn_enum::kData].get<cpu, 3, DType>(s);
     Tensor<cpu, 1, DType> w = in_data[rnn_enum::kParams].get<cpu, 1, DType>(s);
-    Tensor<cpu, 3, DType> hx =
-        in_data[rnn_enum::kState].get<cpu, 3, DType>(s);  // LNC
-    Tensor<cpu, 3, DType> y =
-        out_data[rnn_enum::kOut].get<cpu, 3, DType>(s);  // TNC
-    int64_t seq_len = x.shape_[0];
-    int64_t num_layers = hx.shape_[0];
-    int64_t batch_size = x.shape_[1];
-    int64_t h_channel = hx.shape_[2];
-    int64_t in_channel = x.shape_[2];
-    Tensor<cpu, 2, DType> x_flatten = in_data[rnn_enum::kData]
-      .get_with_shape<cpu, 2, DType>(
-          mshadow::Shape2(seq_len * batch_size, in_channel), s);  // (T*N)C
-    Tensor<cpu, 2, DType> y_flatten = out_data[rnn_enum::kOut]
-      .get_with_shape<cpu, 2, DType>(
-          mshadow::Shape2(
-              y.shape_[0] * y.shape_[1], y.shape_[2]), s);  // (T*N)C
-
+    Tensor<cpu, 3, DType> hx = in_data[rnn_enum::kState].get<cpu, 3, DType>(s);
+    Tensor<cpu, 3, DType> y = out_data[rnn_enum::kOut].get<cpu, 3, DType>(s);
     CHECK(x.CheckContiguous());
     CHECK(w.CheckContiguous());
     CHECK(hx.CheckContiguous());
     CHECK(y.CheckContiguous());
+    param_.seq_length_ = x.shape_[0];
+    param_.batch_size_ = x.shape_[1];
+    param_.input_size_ = x.shape_[2];
 
-    if (param_.lstm_q_) {
-      const size_t kNumMat = 4;
-      int64_t fused_h_ch = kNumMat * h_channel;
-      int64_t h_size = batch_size * fused_h_ch;
-      int64_t num_dir = 1 + param_.bidirectional;
-      int64_t h2h_w_size = h_channel * fused_h_ch;
-
-      Tensor<cpu, 3, DType> cx =
-          in_data[rnn_enum::kStateCell].get<cpu, 3, DType>(s);
-      CHECK(cx.CheckContiguous());
-
-      Tensor<cpu, 3, DType> cy =
-          out_data[rnn_enum::kStateCellOut].get<cpu, 3, DType>(s);
-      Tensor<cpu, 3, DType> hy =
-          out_data[rnn_enum::kStateOut].get<cpu, 3, DType>(s);
-      CHECK(cy.CheckContiguous());
-      CHECK(hy.CheckContiguous());
-
-      DType* workspace_addr =
-      static_cast<DType *>(ctx.requested[rnn_enum::kTempSpace]
-          .get_host_space_internal(sizeof(DType) *
-                                  (seq_len * h_size + h_size
-                                  + y.shape_[0] * y.shape_[1] * y.shape_[2])));
-      Tensor<cpu, 3, DType> i2h_y(
-          workspace_addr, mshadow::Shape3(seq_len, batch_size, fused_h_ch));
-      Tensor<cpu, 2, DType> i2h_y_flatten(
-          workspace_addr, mshadow::Shape2(seq_len * batch_size, fused_h_ch));
-      Tensor<cpu, 2, DType> h2h_y(workspace_addr
-          + seq_len * h_size, mshadow::Shape2(batch_size, fused_h_ch));
-      Tensor<cpu, 3, DType> y_tmp(workspace_addr
-          + (seq_len + 1) * h_size, y.shape_);
-      Tensor<cpu, 2, DType> y_flatten_tmp(workspace_addr
-          + (seq_len + 1) * h_size, y_flatten.shape_);
-      CHECK(i2h_y.CheckContiguous());
-      CHECK(h2h_y.CheckContiguous());
-      CHECK(y_tmp.CheckContiguous());
-
-      for (int64_t layer = 0; layer < num_layers; layer++) {
-        int reverse_dir = 0;
-        int out_tmp = 0;
-        if (param_.bidirectional && layer % 2)
-          reverse_dir = 1;
-        if (layer / num_dir % 2 == 0)
-          out_tmp = 1;
-        mshadow::Shape<2> i2h_w_shape = mshadow::Shape2(fused_h_ch,
-            (layer < num_dir) ? in_channel : num_dir * h_channel);
-        mshadow::Shape<2> h2h_w_shape = mshadow::Shape2(fused_h_ch, h_channel);
-        int64_t start = layer < num_dir ?
-            (layer * (in_channel * fused_h_ch + h2h_w_size)) :  // input layer
-              (num_dir * (in_channel * fused_h_ch + h2h_w_size)
-              + (layer - num_dir) * (h2h_w_size * num_dir + h2h_w_size));
-        Tensor<cpu, 2, DType> i2h_w(w.dptr_ + start, i2h_w_shape);
-        start += layer < num_dir ?
-            in_channel * fused_h_ch : h2h_w_size * num_dir;
-        Tensor<cpu, 2, DType> h2h_w(w.dptr_ + start, h2h_w_shape);
-        start = num_dir * (in_channel * fused_h_ch + h2h_w_size)
-            + (num_layers - num_dir) * (h2h_w_size * (num_dir + 1))
-              + layer * fused_h_ch * 2;
-        Tensor<cpu, 1, DType> i2h_b = w.Slice(start, start + fused_h_ch);
-        start += fused_h_ch;
-        Tensor<cpu, 1, DType> h2h_b = w.Slice(start, start + fused_h_ch);
-        if (out_tmp) {
-          linalg_gemm(layer < num_dir ? x_flatten:y_flatten, i2h_w,
-              i2h_y_flatten, false, true, s);
-        } else {
-          linalg_gemm(layer < num_dir ? x_flatten:y_flatten_tmp, i2h_w,
-              i2h_y_flatten, false, true, s);
-        }
-        i2h_y_flatten += repmat(i2h_b, seq_len * batch_size);
-        for (int64_t t = 0; t < seq_len; t++) {
-          int64_t timestep = t;
-          if (reverse_dir)
-            timestep = seq_len - 1 - t;
-          linalg_gemm(t == 0 ? hx[layer]:hy[layer], h2h_w, h2h_y,
-              false, true, s);
-          h2h_y += repmat(h2h_b, batch_size);
-          // fused element-wise ops
-          LSTMFusedElementWiseCPUOps(i2h_y[timestep], cx[layer], h2h_y,
-              y[timestep], out_tmp ? y_tmp[timestep]: y[timestep],
-                hy[layer], cy[layer], batch_size, h_channel, t,
-                reverse_dir, out_tmp && (layer == num_layers - 1));
-        }
-      }
-    } else {
-      LOG(FATAL) << "only LSTM is available for cpu at the moment.";
-    }
-  }
-
-  virtual void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-      const std::vector<TBlob> &out_data,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_args) {
-    LOG(FATAL) << "LSTM backward is not available for cpu at the moment.";
-  }
-
- private:
-  RNNParam param_;
+    const int direction = param_.bidirectional ? 2 : 1;
+    const int bsize = GetRnnBiasSize(param_.num_layers, param_.state_size, direction, param_.mode);
+    DType* b_ptr = w.dptr_ + w.shape_[0] - bsize;
 
-  void LSTMFusedElementWiseCPUOps(const Tensor<cpu, 2, DType> &i2h_y,
-                                  const Tensor<cpu, 2, DType> &cx,
-                                  const Tensor<cpu, 2, DType> &h2h_y,
-                                  const Tensor<cpu, 2, DType> &y,
-                                  // holding intermediate layer output
-                                  const Tensor<cpu, 2, DType> &tmp,
-                                  const Tensor<cpu, 2, DType> &hy,
-                                  const Tensor<cpu, 2, DType> &cy,
-                                  const int64_t batch_size,
-                                  const int64_t h_channel,
-                                  const int64_t t,
-                                  const int reverse_dir,
-                                  const int copy_tmp2y) {
-    int64_t length = batch_size * h_channel;
-    #pragma omp parallel for
-    for (int64_t ji = 0; ji < length; ++ji) {
-      int64_t j = ji / h_channel;  // batch dim
-      int64_t i = ji % h_channel;
-      int64_t f = i + h_channel;
-      int64_t c = i + h_channel * 2;
-      int64_t o = i + h_channel * 3;
-      int64_t j_pos = j * h_channel * 4;
-      h2h_y.dptr_[j_pos + i] += i2h_y.dptr_[j_pos + i];
-      h2h_y.dptr_[j_pos + f] += i2h_y.dptr_[j_pos + f];
-      h2h_y.dptr_[j_pos + o] += i2h_y.dptr_[j_pos + o];
-      h2h_y.dptr_[j_pos + c] += i2h_y.dptr_[j_pos + c];
-      h2h_y.dptr_[j_pos + i] = 1.0f / (1.0f + math::exp(-h2h_y.dptr_[j_pos + i]));
-      h2h_y.dptr_[j_pos + f] = 1.0f / (1.0f + math::exp(-h2h_y.dptr_[j_pos + f]));
-      h2h_y.dptr_[j_pos + o] = 1.0f / (1.0f + math::exp(-h2h_y.dptr_[j_pos + o]));
-      h2h_y.dptr_[j_pos + c] = tanh(h2h_y.dptr_[j_pos + c]);
-      cy[j][i] = h2h_y.dptr_[j_pos + f] * (t == 0 ? cx[j][i]:cy[j][i])
-          + h2h_y.dptr_[j_pos + i] * h2h_y.dptr_[j_pos + c];
-      hy[j][i] = h2h_y.dptr_[j_pos + o] * tanh(cy[j][i]);
-      tmp[j][i + h_channel * reverse_dir] = hy[j][i];
-      if (copy_tmp2y) {
-        y[j][i] = tmp[j][i];
-        if (reverse_dir)
-          y[j][i + h_channel] = tmp[j][i + h_channel];
-      }
+    DType* hy_ptr = NULL;
+    if (param_.state_outputs) {
+      hy_ptr = out_data[rnn_enum::kStateOut].dptr<DType>();
     }
-  }
-};  // class RNNOp
+    DType* cx_ptr = NULL;
+    DType* cy_ptr = NULL;
 
-template<typename xpu>
-Operator* CreateOp(RNNParam param, int dtype);
-
-#if DMLC_USE_CXX11
-class RNNProp : public OperatorProperty {
- public:
-  std::vector<std::string> ListArguments() const override {
     if (param_.mode == rnn_enum::kLstm) {
-      return {"data", "parameters", "state", "state_cell"};
-    } else {
-      return {"data", "parameters", "state"};
+      cx_ptr = in_data[rnn_enum::kStateCell].dptr<DType>();
+      if (param_.state_outputs) {
+        cy_ptr = out_data[rnn_enum::kStateCellOut].dptr<DType>();
+      }
     }
-  }
-
-  std::vector<std::string> ListOutputs() const override {
-    std::vector<std::string> outputs = {"output"};
-    if (!param_.state_outputs)
-      return outputs;
-    else
-      outputs.push_back("state");
-    if (param_.mode == rnn_enum::kLstm)
-      outputs.push_back("state_cell");
-    return outputs;
-  }
-
-  int NumOutputs() const override {
-    int mode_num = (param_.mode == rnn_enum::kLstm) ? 2 : 1;
-    int num_outputs = param_.state_outputs ? (mode_num + 1) : 1;
-    return num_outputs;
-  }
 
-  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
-    param_.Init(kwargs);
-  }
+    // allocate temp space
+    const size_t workspace_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+                                                      param_.state_size, direction, param_.mode);
+    Tensor<cpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
+        .get_space_typed<cpu, 1, DType>(Shape1(workspace_size), s);
+
+    if (ctx.is_train) {
+      const size_t r_size = GetRNNReserveSpaceSize(param_.seq_length_, param_.batch_size_,
+                                                   param_.state_size, param_.mode);
+      if (init_space_ && reserve_space_size_ < r_size) {
+        Storage::Get()->Free(reserve_space_);
+        init_space_ = false;
+      }
 
-  std::map<std::string, std::string> GetParams() const override {
-    return param_.__DICT__();
-  }
+      if (!init_space_) {
+        reserve_space_ = Storage::Get()->Alloc(r_size * sizeof(DType), Context::CPU());
+        reserve_space_size_ = r_size;
+        init_space_ = true;
+      }
 
-  bool InferShape(std::vector<TShape> *in_shape,
-                  std::vector<TShape> *out_shape,
-                  std::vector<TShape> *aux_shape) const override {
-    using namespace mshadow;
-    if (param_.mode == rnn_enum::kLstm) {
-      CHECK_EQ(in_shape->size(), 4U) << "Input:[data, parameters, state, cell_state]";
-    } else {
-      CHECK_EQ(in_shape->size(), 3U) << "Input:[data, parameters, state]";
-    }
-    const TShape &dshape = (*in_shape)[rnn_enum::kData];
-    if (dshape.ndim() ==  0) return false;
-    CHECK_EQ(dshape.ndim(), 3U) \
-        << "Input data should be rank-3 tensor of dim [sequence length, batch size, input size]";
-    // data: [sequence len, batch, input dimension]
-    int batch_size = dshape[1];
-    int input_size = dshape[2];
-    int numDirections = param_.bidirectional ? 2 : 1;
-    int total_layers = numDirections * param_.num_layers;  // double for bidirectional
-    SHAPE_ASSIGN_CHECK(*in_shape,
-                       rnn_enum::kState,
-                       Shape3(total_layers, batch_size, param_.state_size));
-    if (param_.mode == rnn_enum::kLstm)
-      SHAPE_ASSIGN_CHECK(*in_shape,
-                        rnn_enum::kStateCell,
-                        Shape3(total_layers, batch_size, param_.state_size));
-
-    // calculate parameter vector length
-    int param_size = rnn_param_size(param_.num_layers,
-                                    input_size,
-                                    param_.state_size,
-                                    param_.bidirectional,
-                                    param_.mode);
-    SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kParams, Shape1(param_size));
-
-    out_shape->clear();
-    // output: [sequence len, batch, output size]
-    TShape oshape = dshape;
-    oshape[2] = numDirections * param_.state_size;
-    out_shape->push_back(oshape);
-    if (!param_.state_outputs) {
-      return true;
+      DType* reserve_space_ptr = static_cast<DType*>(reserve_space_.dptr);
+      RNNForwardTraining<DType>(workspace.dptr_,
+                                reserve_space_ptr,
+                                param_.state_outputs,
+                                param_.num_layers,
+                                direction,
+                                param_.seq_length_,
+                                param_.batch_size_,
+                                param_.input_size_,
+                                param_.state_size,
+                                x.dptr_,
+                                hx.dptr_,
+                                cx_ptr,
+                                w.dptr_,
+                                y.dptr_,
+                                hy_ptr,
+                                cy_ptr,
+                                param_.mode);
     } else {
-      // outStateShape: [layer_num, batch, state size]
-      TShape outStateShape = dshape;
-      outStateShape[0] = total_layers;
-      outStateShape[1] = batch_size;
-      outStateShape[2] = param_.state_size;
-      out_shape->push_back(outStateShape);
-      // Deal with lstm cell state
-      if (param_.mode == rnn_enum::kLstm)
-        out_shape->push_back(outStateShape);
-      return true;
+      RNNForwardInference<DType>(workspace.dptr_,
+                                 param_.state_outputs,
+                                 param_.num_layers,
+                                 direction,
+                                 param_.seq_length_,
+                                 param_.batch_size_,
+                                 param_.input_size_,
+                                 param_.state_size,
+                                 x.dptr_,
+                                 hx.dptr_,
+                                 cx_ptr,
+                                 w.dptr_,
+                                 b_ptr,
+                                 y.dptr_,
+                                 hy_ptr,
+                                 cy_ptr,
+                                 param_.mode);
     }
   }
 
-  bool InferType(std::vector<int> *in_type,
-                 std::vector<int> *out_type,
-                 std::vector<int> *aux_type) const override {
-    CHECK_GE(in_type->size(), 1U);
-    int dtype = (*in_type)[0];
-    CHECK_NE(dtype, -1) << "First input must have specified type";
-    for (index_t i = 0; i < in_type->size(); ++i) {
-      if ((*in_type)[i] == -1) {
-        (*in_type)[i] = dtype;
-      } else {
-        UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
-      }
+  void Backward(const OpContext &ctx,
+                const std::vector<TBlob> &out_grad,
+                const std::vector<TBlob> &in_data,
+                const std::vector<TBlob> &out_data,
+                const std::vector<OpReqType> &req,
+                const std::vector<TBlob> &in_grad) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    CHECK_EQ(param_.mode, rnn_enum::kGru) << "Only gru mode is supported at the moment while param_.mode is:" << param_.mode;
+    if (param_.bidirectional || param_.num_layers != 1) {
+      LOG(FATAL) << "Only single layer and unidirectional is supported at the moment";
     }
-    out_type->clear();
-    out_type->push_back(dtype);
+    size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
+    size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
     if (!param_.state_outputs) {
-      return true;
-    } else {
-      out_type->push_back(dtype);
-      // Deal with lstm cell state
-      if (param_.mode == rnn_enum::kLstm)
-        out_type->push_back(dtype);
-      return true;
+      out_expected = 1;
     }
-  }
-
-  OperatorProperty* Copy() const override {
-    auto ptr = new RNNProp();
-    ptr->param_ = param_;
-    return ptr;
-  }
-
-  std::string TypeString() const override {
-    return "RNN";
-  }
+    CHECK_EQ(in_data.size(), in_expected);
+    CHECK_EQ(out_data.size(), out_expected);
+    CHECK_EQ(in_grad.size(), in_expected);
+    CHECK_EQ(out_grad.size(), out_expected);
+    CHECK_EQ(req.size(), in_expected);
+    CHECK_NE(req[rnn_enum::kData], kAddTo) << "AddTo is not supported for data";
+    CHECK_NE(req[rnn_enum::kState], kAddTo) << "AddTo is not supported for state";
+    mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+    // get input + output tensors
+    Tensor<cpu, 3, DType> x = in_data[rnn_enum::kData].get<cpu, 3, DType>(s);
+    Tensor<cpu, 1, DType> w = in_data[rnn_enum::kParams].get<cpu, 1, DType>(s);
+    Tensor<cpu, 3, DType> hx = in_data[rnn_enum::kState].get<cpu, 3, DType>(s);
+    Tensor<cpu, 3, DType> y = out_data[rnn_enum::kOut].get<cpu, 3, DType>(s);
+    Tensor<cpu, 3, DType> dx = in_grad[rnn_enum::kData].get<cpu, 3, DType>(s);
+    Tensor<cpu, 1, DType> dw = in_grad[rnn_enum::kParams].get<cpu, 1, DType>(s);
+    Tensor<cpu, 3, DType> dhx = in_grad[rnn_enum::kState].get<cpu, 3, DType>(s);
+    Tensor<cpu, 3, DType> dy = out_grad[rnn_enum::kOut].get<cpu, 3, DType>(s);
+    CHECK(x.CheckContiguous());
+    CHECK(w.CheckContiguous());
+    CHECK(hx.CheckContiguous());
+    CHECK(y.CheckContiguous());
+    CHECK(dx.CheckContiguous());
+    CHECK(dw.CheckContiguous());
+    CHECK(dhx.CheckContiguous());
+    CHECK(dy.CheckContiguous());
+    param_.seq_length_ = x.shape_[0];
+    param_.batch_size_ = x.shape_[1];
+    param_.input_size_ = x.shape_[2];
 
-  std::vector<int> DeclareBackwardDependency(
-    const std::vector<int> &out_grad,
-    const std::vector<int> &in_data,
-    const std::vector<int> &out_data) const override {
-    std::vector<int> dep = {in_data[rnn_enum::kData], in_data[rnn_enum::kParams],
-        in_data[rnn_enum::kState], out_data[rnn_enum::kOut], out_grad[rnn_enum::kOut]};
+    const int direction = param_.bidirectional ? 2 : 1;
 
+    DType * dhy_ptr = NULL;
     if (param_.state_outputs) {
-      dep.push_back(out_data[rnn_enum::kStateOut]);
-      dep.push_back(out_grad[rnn_enum::kStateOut]);
+      dhy_ptr = out_grad[rnn_enum::kStateOut].dptr<DType>();
     }
 
+    DType * cx_ptr = NULL;
+    DType * dcx_ptr = NULL;
+    DType * dcy_ptr = NULL;
+
     if (param_.mode == rnn_enum::kLstm) {
-      dep.push_back(in_data[rnn_enum::kStateCell]);
+      CHECK_NE(req[rnn_enum::kStateCell], kAddTo) << "AddTo is not supported for state cell";
+      cx_ptr = in_data[rnn_enum::kStateCell].dptr<DType>();
+      dcx_ptr = in_grad[rnn_enum::kStateCell].dptr<DType>();
       if (param_.state_outputs) {
-        dep.push_back(out_data[rnn_enum::kStateCellOut]);
-        dep.push_back(out_grad[rnn_enum::kStateCellOut]);
+        dcy_ptr = out_grad[rnn_enum::kStateCellOut].dptr<DType>();
       }
     }
-    return dep;
-  }
 
-  std::vector<ResourceRequest> ForwardResource(
-      const std::vector<TShape> &in_shape) const override {
-    return {ResourceRequest::kTempSpace};
-  }
+    // allocate temp space
+    const size_t workspace_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+                                                      param_.state_size, direction, param_.mode);
+    Tensor<cpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
+        .get_space_typed<cpu, 1, DType>(Shape1(workspace_size), s);
 
-  std::vector<ResourceRequest> BackwardResource(
-      const std::vector<TShape> &in_shape) const override {
-    return {ResourceRequest::kTempSpace};
-  }
+    size_t r_size = GetRNNReserveSpaceSize(param_.seq_length_, param_.batch_size_,
+                                           param_.state_size, param_.mode);
+    if (!init_space_ || reserve_space_size_ != r_size) {
+      LOG(FATAL) << " Check forward init error" << reserve_space_size_;
+    }
 
-  Operator* CreateOperator(Context ctx) const override {
-    LOG(FATAL) << "Not Implemented";
-    return NULL;
+    DType* reserve_space_ptr = static_cast<DType*>(reserve_space_.dptr);
+    RNNBackward<DType>(workspace.dptr_,
+                       reserve_space_ptr,
+                       param_.num_layers,
+                       direction,
+                       param_.seq_length_,
+                       param_.batch_size_,
+                       param_.input_size_,
+                       param_.state_size,
+                       x.dptr_,
+                       hx.dptr_,
+                       cx_ptr,
+                       w.dptr_,
+                       y.dptr_,
+                       dy.dptr_,
+                       dhy_ptr,
+                       dcy_ptr,
+                       dx.dptr_,
+                       dhx.dptr_,
+                       dcx_ptr,
+                       dw.dptr_,
+                       param_.mode);
   }
 
-  Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
-                             std::vector<int> *in_type) const override;
-
  private:
   RNNParam param_;
-};  // class RNNProp
-#endif  // DMLC_USE_CXX11
+  bool init_space_;
+  size_t reserve_space_size_;
+  Storage::Handle reserve_space_;
+};  // class RNNOp
+
+/*!
+ * \brief Return a reference to the function-local, thread-local RNNOp<DType> instance.
+ * \param param RNN parameters forwarded to the RNNOp constructor.
+ * \return reference to the `op` object declared in this function.
+ * NOTE(review): `op` is a static local, so `param` is presumably only consumed the first
+ * time a thread reaches the declaration; later calls with different parameters would
+ * then get the instance built from the first ones -- confirm this is intended.
+ */
+template<typename DType>
+static RNNOp<DType> &GetRNNOp(const RNNParam &param) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local RNNOp<DType> op(param);
+#else
+  static MX_THREAD_LOCAL RNNOp<DType> op(param);
+#endif
+  return op;
+}
+
+/*!
+ * \brief FCompute entry point of the forward RNN operator.
+ *
+ * Reads the parsed RNNParam from attrs, switches on the type flag of the data input and
+ * forwards all arguments to RNNOp<DType>::Forward of the operator from GetRNNOp.
+ */
+template<typename xpu>
+void RNNCompute(const nnvm::NodeAttrs& attrs,
+                const OpContext& ctx,
+                const std::vector<TBlob>& inputs,
+                const std::vector<OpReqType>& req,
+                const std::vector<TBlob>& outputs) {
+  const RNNParam& rnn_param = nnvm::get<RNNParam>(attrs.parsed);
+  const auto data_type_flag = inputs[rnn_enum::kData].type_flag_;
+  MSHADOW_REAL_TYPE_SWITCH(data_type_flag, DType, {
+    RNNOp<DType>& rnn_op = GetRNNOp<DType>(rnn_param);
+    rnn_op.Forward(ctx, inputs, req, outputs);
+  });
+}
+
+/*!
+ * \brief FCompute entry point of the backward RNN operator.
+ *
+ * Splits the flat `inputs` list back into in_data / out_data / out_grad and calls
+ * RNNOp<DType>::Backward. The order consumed here is:
+ *   data, parameters, state, output, output gradient,
+ *   [state output, state output gradient]            when state_outputs is set,
+ *   [state cell, [cell output, cell output gradient]] in LSTM mode.
+ * `outputs` is handed to Backward as the input gradients.
+ */
+template<typename xpu>
+void RNNGradCompute(const nnvm::NodeAttrs& attrs,
+                    const OpContext& ctx,
+                    const std::vector<TBlob>& inputs,
+                    const std::vector<OpReqType>& req,
+                    const std::vector<TBlob>& outputs) {
+  const RNNParam& rnn_param = nnvm::get<RNNParam>(attrs.parsed);
+
+  // fixed prefix shared by every mode
+  std::vector<TBlob> in_data(inputs.begin(), inputs.begin() + 3);
+  std::vector<TBlob> out_data{inputs[3]};
+  std::vector<TBlob> out_grad{inputs[4]};
+
+  // position of the next optional entry
+  int cursor = 5;
+  if (rnn_param.state_outputs) {
+    out_data.push_back(inputs[cursor]);
+    out_grad.push_back(inputs[cursor + 1]);
+    cursor += 2;
+  }
+
+  if (rnn_param.mode == rnn_enum::kLstm) {
+    in_data.push_back(inputs[cursor]);
+    if (rnn_param.state_outputs) {
+      out_data.push_back(inputs[cursor + 1]);
+      out_grad.push_back(inputs[cursor + 2]);
+    }
+  }
+
+  MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    GetRNNOp<DType>(rnn_param).Backward(ctx, out_grad, in_data, out_data, req, outputs);
+  });
+}
+
 }  // namespace op
 }  // namespace mxnet
+
+namespace std {
+/*!
+ * \brief std::hash specialization for mxnet::op::RNNParam.
+ *
+ * Folds every field listed below into one value with dmlc::HashCombine. The call
+ * operator is const so that the hasher can be invoked through a const object, which is
+ * what the standard Hash requirements (and the unordered containers) expect.
+ */
+template<>
+struct hash<mxnet::op::RNNParam> {
+  /*!
+   * \param val parameter set to hash
+   * \return combined hash of the fields of val
+   */
+  size_t operator()(const mxnet::op::RNNParam& val) const {
+    size_t ret = 0;
+    ret = dmlc::HashCombine(ret, val.state_size);
+    ret = dmlc::HashCombine(ret, val.num_layers);
+    ret = dmlc::HashCombine(ret, val.bidirectional);
+    ret = dmlc::HashCombine(ret, val.state_outputs);
+    ret = dmlc::HashCombine(ret, val.mode);
+    ret = dmlc::HashCombine(ret, val.seq_length_);
+    ret = dmlc::HashCombine(ret, val.batch_size_);
+    ret = dmlc::HashCombine(ret, val.input_size_);
+    ret = dmlc::HashCombine(ret, val.lstm_q_);
+    return ret;
+  }
+};
+}  // namespace std
+
 #endif  // MXNET_OPERATOR_RNN_INL_H_
diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc
index a60adbcd2fb..7e75d628ab6 100644
--- a/src/operator/rnn.cc
+++ b/src/operator/rnn.cc
@@ -21,32 +21,172 @@
  * Copyright (c) 2015 by Contributors
  * \file rnn.cc
  * \brief
- * \author Sebastian Bodenstein
+ * \author Sebastian Bodenstein, Shu Zhang(shu.zhang@intel.com)
 */
-
 #include "./rnn-inl.h"
 
 namespace mxnet {
 namespace op {
-template<>
-Operator *CreateOp<cpu>(RNNParam param, int dtype) {
-  Operator *op = NULL;
-  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-    op = new RNNOp<cpu, DType>(param);
-  });
-  return op;
+
+DMLC_REGISTER_PARAMETER(RNNParam);
+/*!
+ * \brief Names of the operator inputs for the given parameters.
+ * \return {"data", "parameters", "state"}, followed by "state_cell" in LSTM mode.
+ */
+static inline std::vector<std::string> ListArguments(const RNNParam& param_) {
+  std::vector<std::string> names{"data", "parameters", "state"};
+  if (param_.mode == rnn_enum::kLstm) {
+    names.push_back("state_cell");
+  }
+  return names;
 }
 
-Operator *RNNProp::CreateOperatorEx(Context ctx,
-                                  std::vector<TShape> *in_shape,
-                                  std::vector<int> *in_type) const {
-  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
+/*!
+ * \brief Shape inference for the RNN operator.
+ *
+ * Requires a rank-3 data shape [sequence length, batch size, input size]; from it the
+ * state (and, for LSTM, cell state) and the flat parameter vector shapes are assigned,
+ * and the output shapes are filled in.
+ * \return false while the data shape is still unknown (ndim == 0), true otherwise.
+ */
+static bool RNNShape(const nnvm::NodeAttrs& attrs,
+                     std::vector<TShape> *in_shape,
+                     std::vector<TShape> *out_shape) {
+  using namespace mshadow;
+  const RNNParam& rnn_param = nnvm::get<RNNParam>(attrs.parsed);
+  const bool is_lstm = (rnn_param.mode == rnn_enum::kLstm);
+  if (is_lstm) {
+    CHECK_EQ(in_shape->size(), 4U) << "Input:[data, parameters, state, cell_state]";
+  } else {
+    CHECK_EQ(in_shape->size(), 3U) << "Input:[data, parameters, state]";
+  }
+
+  const TShape &dshape = (*in_shape)[rnn_enum::kData];
+  if (dshape.ndim() == 0) {
+    return false;
+  }
+  CHECK_EQ(dshape.ndim(), 3U)
+      << "Input data should be rank-3 tensor of dim [sequence length, batch size, input size]";
+
+  // data: [sequence len, batch, input dimension]
+  const int batch_size = dshape[1];
+  const int input_size = dshape[2];
+  const int num_directions = rnn_param.bidirectional ? 2 : 1;
+  // double for bidirectional
+  const int total_layers = num_directions * rnn_param.num_layers;
+
+  // the hidden state and the LSTM cell state share one shape
+  const auto state_shape = Shape3(total_layers, batch_size, rnn_param.state_size);
+  SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kState, state_shape);
+  if (is_lstm) {
+    SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kStateCell, state_shape);
+  }
+
+  // calculate parameter vector length
+  const int param_size = GetRnnParamSize(rnn_param.num_layers, input_size,
+                                         rnn_param.state_size, num_directions,
+                                         rnn_param.mode);
+  SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kParams, Shape1(param_size));
+
+  out_shape->clear();
+  // output: [sequence len, batch, output size]
+  TShape oshape = dshape;
+  oshape[2] = num_directions * rnn_param.state_size;
+  out_shape->push_back(oshape);
+  if (!rnn_param.state_outputs) {
+    return true;
+  }
+
+  // state output: [layer_num, batch, state size]
+  TShape state_out_shape = dshape;
+  state_out_shape[0] = total_layers;
+  state_out_shape[1] = batch_size;
+  state_out_shape[2] = rnn_param.state_size;
+  out_shape->push_back(state_out_shape);
+  // Deal with lstm cell state
+  if (is_lstm) {
+    out_shape->push_back(state_out_shape);
+  }
+  return true;
 }
 
-DMLC_REGISTER_PARAMETER(RNNParam);
+/*!
+ * \brief Type inference for the RNN operator.
+ *
+ * The type of the first input must be known. Every input with an unknown type (-1)
+ * receives it, every other input is checked against it, and all outputs use it.
+ * \return always true (mismatches are reported through the CHECK macros).
+ */
+static bool RNNType(const nnvm::NodeAttrs& attrs,
+                    std::vector<int> *in_type,
+                    std::vector<int> *out_type) {
+  const RNNParam& rnn_param = nnvm::get<RNNParam>(attrs.parsed);
+  CHECK_GE(in_type->size(), 1U);
+  int dtype = in_type->front();
+  CHECK_NE(dtype, -1) << "First input must have specified type";
+  for (index_t i = 0; i < in_type->size(); ++i) {
+    int &current = (*in_type)[i];
+    if (current != -1) {
+      UNIFORM_TYPE_CHECK(current, dtype, ListArguments(rnn_param)[i]);
+    } else {
+      current = dtype;
+    }
+  }
+
+  // one output, plus the state output and (LSTM only) the cell state output
+  size_t num_outputs = 1;
+  if (rnn_param.state_outputs) {
+    num_outputs += (rnn_param.mode == rnn_enum::kLstm) ? 2 : 1;
+  }
+  out_type->assign(num_outputs, dtype);
+  return true;
+}
 
-MXNET_REGISTER_OP_PROPERTY(RNN, RNNProp)
-.describe("Applies a recurrent layer to input.")
+/*!
+ * \brief Storage type inference for the forward RNN operator.
+ *
+ * Delegates to storage_type_assign with mxnet::kDefaultStorage for out_attrs and
+ * DispatchMode::kFCompute as the requested dispatch mode, and returns its result.
+ */
+inline static bool RNNStorageType(const nnvm::NodeAttrs& attrs,
+                                  const int dev_mask,
+                                  DispatchMode* dispatch_mode,
+                                  std::vector<int> *in_attrs,
+                                  std::vector<int> *out_attrs) {
+  return storage_type_assign(out_attrs,
+                             mxnet::kDefaultStorage,
+                             dispatch_mode,
+                             DispatchMode::kFCompute);
+}
+
+/*!
+ * \brief Storage type inference for the backward RNN operator.
+ *
+ * Delegates to storage_type_assign with mxnet::kDefaultStorage for out_attrs and
+ * DispatchMode::kFCompute as the requested dispatch mode, and returns its result.
+ */
+inline static bool BackwardRNNStorageType(const nnvm::NodeAttrs& attrs,
+                                          const int dev_mask,
+                                          DispatchMode* dispatch_mode,
+                                          std::vector<int> *in_attrs,
+                                          std::vector<int> *out_attrs) {
+  return storage_type_assign(out_attrs,
+                             mxnet::kDefaultStorage,
+                             dispatch_mode,
+                             DispatchMode::kFCompute);
+}
+
+/*!
+ * \brief Gradient functor: collects the entries the backward node depends on.
+ *
+ * The list is built in this order: data, parameters, state, output, output gradient,
+ * then (with state_outputs) state output and its gradient, then (LSTM mode) the cell
+ * state and, with state_outputs, the cell output and its gradient. The result is passed
+ * to MakeGradNode together with the forward node and its attribute dictionary.
+ */
+struct RNNGrad {
+  const char *op_name;
+  std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr &n,
+          const std::vector<nnvm::NodeEntry> &ograd) const {
+    const RNNParam& params = nnvm::get<RNNParam>(n->attrs.parsed);
+    const bool with_states = params.state_outputs;
+
+    std::vector<nnvm::NodeEntry> heads;
+    heads.push_back(n->inputs[rnn_enum::kData]);
+    heads.push_back(n->inputs[rnn_enum::kParams]);
+    heads.push_back(n->inputs[rnn_enum::kState]);
+    heads.push_back(nnvm::NodeEntry{n, rnn_enum::kOut, 0});
+    heads.push_back(ograd[rnn_enum::kOut]);
+    if (with_states) {
+      heads.push_back(nnvm::NodeEntry{n, rnn_enum::kStateOut, 0});
+      heads.push_back(ograd[rnn_enum::kStateOut]);
+    }
+    if (params.mode == rnn_enum::kLstm) {
+      heads.push_back(n->inputs[rnn_enum::kStateCell]);
+      if (with_states) {
+        heads.push_back(nnvm::NodeEntry{n, rnn_enum::kStateCellOut, 0});
+        heads.push_back(ograd[rnn_enum::kStateCellOut]);
+      }
+    }
+    return MakeGradNode(op_name, n, heads, n->attrs.dict);
+  }
+};
+
+NNVM_REGISTER_OP(RNN)
+.describe(R"code(Applies a recurrent layer to input
+)code" ADD_FILELINE)
+.set_attr_parser(ParamParser<RNNParam>)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
+  return params.mode == rnn_enum::kLstm ? 4 : 3;
+})
+.set_num_outputs([](const NodeAttrs& attrs) {
+  const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
+  int mode_num = (params.mode == rnn_enum::kLstm) ? 2 : 1;
+  int num_outputs = params.state_outputs ? (mode_num + 1) : 1;
+  return num_outputs;
+})
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+  const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
+  return ListArguments(params);
+})
+.set_attr<nnvm::FInferShape>("FInferShape", RNNShape)
+.set_attr<nnvm::FInferType>("FInferType", RNNType)
+.set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
+.set_attr<FCompute>("FCompute<cpu>", RNNCompute<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", RNNGrad{"_backward_RNN"})
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
+  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
 .add_argument("data", "NDArray-or-Symbol", "Input data to RNN")
 .add_argument("parameters", "NDArray-or-Symbol",
               "Vector of all RNN trainable parameters concatenated")
@@ -54,5 +194,19 @@ MXNET_REGISTER_OP_PROPERTY(RNN, RNNProp)
 .add_argument("state_cell", "NDArray-or-Symbol",
               "initial cell state for LSTM networks (only for LSTM)")
 .add_arguments(RNNParam::__FIELDS__());
+
+// Registration of the backward operator; RNNGradCompute<cpu> is its CPU implementation.
+NNVM_REGISTER_OP(_backward_RNN)
+// number of outputs: 4 in LSTM mode, 3 otherwise
+.set_num_outputs([](const NodeAttrs& attrs) {
+  const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
+  return params.mode == rnn_enum::kLstm ? 4 : 3;
+})
+.set_attr_parser(ParamParser<RNNParam>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FInferStorageType>("FInferStorageType", BackwardRNNStorageType)
+// a single temp-space resource is requested for this operator
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
+  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+.set_attr<FCompute>("FCompute<cpu>", RNNGradCompute<cpu>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/rnn.cu b/src/operator/rnn.cu
index 59517932b78..d4a00ffe1e1 100644
--- a/src/operator/rnn.cu
+++ b/src/operator/rnn.cu
@@ -23,7 +23,7 @@
  * \brief
  * \author Sebastian Bodenstein
 */
-
+/*
 #include "./rnn-inl.h"
 #include <algorithm>
 #if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
@@ -47,3 +47,4 @@ Operator* CreateOp<gpu>(RNNParam param, int dtype) {
 
 }  // namespace op
 }  // namespace mxnet
+*/
diff --git a/src/operator/rnn_impl.hpp b/src/operator/rnn_impl.hpp
new file mode 100644
index 00000000000..c81f1e1d5fe
--- /dev/null
+++ b/src/operator/rnn_impl.hpp
@@ -0,0 +1,491 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file    rnn_impl.hpp
+ * \brief
+ * \author  Shu Zhang(shu.zhang@intel.com)
+*/
+#ifndef MXNET_OPERATOR_RNN_IMPL_HPP_
+#define MXNET_OPERATOR_RNN_IMPL_HPP_
+
+#include <dmlc/logging.h>
+#include <dmlc/parameter.h>
+#include <mxnet/operator.h>
+#include <algorithm>
+#include <map>
+#include <vector>
+#include <string>
+#include <utility>
+#include "./math.h"
+#include "./math_functions-inl.h"
+#include "./operator_common.h"
+#include "./mshadow_op.h"
+#include "./linalg.h"
+
+/*!
+ * \brief Logistic function 1 / (1 + exp(-x)).
+ * \param x input value
+ * \return the result converted to DType
+ */
+template<typename DType>
+inline DType sigmoid(DType x) {
+  // keep the type produced by exp() so the arithmetic matches a single expression
+  const auto decay = exp(-x);
+  return 1.0f / (1.0f + decay);
+}
+
+/*!
+ * \brief Forward pass of one GRU layer without saving intermediate gate values.
+ *
+ * x * wx^T is computed once for all steps; then, for every step t, h_{t-1} * wh^T is
+ * computed and the gates are evaluated element-wise:
+ *   r = sigmoid(.), z = sigmoid(.), n = tanh(.), h_t = (1 - z) * n + z * h_{t-1}.
+ * \param ws scratch buffer, carved up as gemmC1 (D * T * N * 3 * H values),
+ *        gemmC2 (N * 3 * H) and rt, zt, nt (N * H each)
+ * \param state_outputs when true the last step of y is copied to hy_ptr
+ * \param D, T, N, H sizes used for all indexing below
+ * \param I not referenced in the body
+ * \param x, hx, wx, wh, bx, bh input, initial state, weights and biases; rows 0, 1, 2 of
+ *        bx / bh are used for the r, z and n gates respectively
+ * \param y_ptr receives the hidden state of every step
+ * \param hy_ptr receives the final hidden state (only written if state_outputs)
+ * NOTE(review): the final copy starts at y_ptr + (T - 1) * N * H, i.e. without the
+ * factor D used for the step stride elsewhere -- presumably only D == 1 is exercised.
+ */
+template<typename DType>
+void GruForwardInferenceSingleLayer(DType* ws,
+                                    bool state_outputs,
+                                    const int D,
+                                    const int T,
+                                    const int N,
+                                    const int I,
+                                    const int H,
+                                    const Tensor<cpu, 2, DType> &x,
+                                    const Tensor<cpu, 2, DType> &hx,
+                                    const Tensor<cpu, 2, DType> &wx,
+                                    const Tensor<cpu, 2, DType> &wh,
+                                    const Tensor<cpu, 2, DType> &bx,
+                                    const Tensor<cpu, 2, DType> &bh,
+                                    DType* y_ptr,
+                                    DType* hy_ptr) {
+  // seed the first N * H entries of y with hx; they serve as h_{t-1} for step 0
+  #pragma omp parallel for
+  for (int i = 0; i < N; i++)
+    for (int j = 0; j < H; j++) {
+      y_ptr[i * H + j] = hx[i][j];
+    }
+
+  // ht and ht_1 start on the same memory: step 0 overwrites the seeded values in place
+  DType* ht = y_ptr;
+  DType* ht_1 = y_ptr;
+  DType* gemmC1  = ws;              // [D, T, N, 3 * H]
+  DType* gemmC2  = gemmC1 + D * T * N * 3 * H;  // N * 3 * H
+  DType* rt = gemmC2 + N * 3 * H;
+  DType* zt = rt + N * H;
+  DType* nt = zt + N * H;
+  DType* gemmC1_t = gemmC1;
+  Tensor<cpu, 2, DType> dgemmC1(ws, Shape2(D * T * N, 3 * H));
+  Tensor<cpu, 2, DType> dgemmC2(gemmC2, Shape2(D * N, 3 * H));
+
+  // x * wx.T : [T * N, I] * [I, 3 * H]
+  DType alpha = 1.0;
+  DType beta = 0.0;
+  linalg_gemm(x, wx, dgemmC1, alpha, beta, false, true);
+
+  for (int t = 0; t < T; t++) {
+    //  perform the first direction, X * wx and H * wh for each step
+    //  ht-1 * wh, ht-1:[N, H] wh:[3 * H, H]
+    Tensor<cpu, 2, DType> dht_1(ht_1, Shape2(N, D * H));
+    linalg_gemm(dht_1, wh, dgemmC2, alpha, beta, false, true);
+    // slice of the precomputed x * wx.T that belongs to step t
+    gemmC1_t = gemmC1 + t * N * 3 * H;
+    #pragma omp parallel for
+    for (int i = 0; i < N; ++i) {
+      for (int j = 0; j < H; ++j) {
+        // offsets of the r, z and n blocks inside row i of the [N, 3 * H] buffers
+        int rtb = i * 3 * H;
+        int ztb = i * 3 * H + H;
+        int ntb = i * 3 * H + 2 * H;
+        rt[i * H + j] = sigmoid(gemmC1_t[rtb + j] + gemmC2[rtb + j]
+            + bx[0][j] + bh[0][j]);
+        zt[i * H + j] = sigmoid(gemmC1_t[ztb + j] + gemmC2[ztb + j]
+            + bx[1][j] + bh[1][j]);
+        nt[i * H + j] = tanh(gemmC1_t[ntb + j] + bx[2][j] +
+            rt[i * H + j] * (gemmC2[ntb + j] + bh[2][j]));
+        ht[i * D * H + j] = (1-zt[i * H + j]) * nt[i * H + j] +
+            zt[i * H + j] * ht_1[i * D * H + j];
+      }
+    }
+    // the state just written becomes h_{t-1}; ht moves on to the next step's slot in y
+    ht_1 = ht;
+    ht = ht + D * H * N;
+  }
+  //  copy last state to hy, from(N, H * D) to (D, N, H)
+  if (state_outputs) {
+    DType* y_start = y_ptr + (T - 1) * N * H;
+    #pragma omp parallel for
+    for (int i = 0; i < N; i++)
+      for (int j = 0; j < H; j++) {
+        hy_ptr[i * H + j] = y_start[i * H + j];
+      }
+  }
+}
+
+/*!
+ * \brief Inference-mode GRU forward pass; wraps GruForwardInferenceSingleLayer.
+ *
+ * Builds tensor views over the flat buffers and runs the single-layer kernel once on
+ * the first entry of hx.
+ * \param ws scratch buffer; the kernel is given the region starting D * T * N * H values
+ *        into it (the leading region is skipped to keep the existing layout)
+ * \param state_outputs forwarded to the kernel (controls whether hy_ptr is written)
+ * \param L first extent of the [L, N, H] view over hx_ptr; only entry 0 is used
+ * \param D, T, N, I, H sizes forwarded to the kernel
+ * \param x_ptr input, viewed as [T * N, I]; it is only read
+ * \param hx_ptr initial hidden state, viewed as [L, N, H]
+ * \param w_ptr flat parameters laid out as wx [3 * H, I], wh [3 * H, H],
+ *        bx [3, H], bh [3, H]
+ * \param y_ptr output buffer handed to the kernel
+ * \param hy_ptr final-state buffer handed to the kernel
+ */
+template <typename DType>
+void GruForwardInference(DType* ws,
+                         bool state_outputs,
+                         const int L,
+                         const int D,
+                         const int T,
+                         const int N,
+                         const int I,
+                         const int H,
+                         DType* x_ptr,
+                         DType* hx_ptr,
+                         DType* w_ptr,
+                         DType* y_ptr,
+                         DType* hy_ptr) {
+  const Tensor<cpu, 2, DType> wx(w_ptr, Shape2(H * 3, I));
+  const Tensor<cpu, 2, DType> wh(w_ptr + I * H * 3, Shape2(H * 3, H));
+  const Tensor<cpu, 2, DType> bx(wh.dptr_ + H * H * 3, Shape2(3, H));
+  const Tensor<cpu, 2, DType> bh(bx.dptr_ + H * 3, Shape2(3, H));
+
+  Tensor<cpu, 2, DType> x(x_ptr, Shape2(T * N, I));
+  Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(L, N, H));
+  Tensor<cpu, 2, DType> hx_l = hx[0];
+
+  // The previous element-wise copy of x_ptr onto itself was a no-op and is dropped;
+  // the input view is passed to the kernel directly.
+  DType* ws2 = ws + D * T * N * H;
+
+  GruForwardInferenceSingleLayer<DType>(ws2, state_outputs, D, T, N, I, H,
+                                        x, hx_l, wx, wh, bx, bh, y_ptr, hy_ptr);
+}
+
+
+/*!
+ * \brief Forward pass of one GRU layer that also stores the per-step gate values.
+ *
+ * Same recurrence as the inference kernel (r and z through sigmoid, n through tanh,
+ * h_t = (1 - z) * n + z * h_{t-1}), but r, z, n and the term gemmC2_n + bh[2] are
+ * written to gateR, gateZ, gateN and Mnh at offset t * N * H for every step t.
+ * \param ws scratch buffer, carved up as gemmC1 (D * T * N * 3 * H values) followed by
+ *        gemmC2
+ * \param state_outputs when true the last step of y is copied to hy_ptr
+ * \param D, T, N, H sizes used for all indexing below
+ * \param I not referenced in the body
+ * \param x, hx, wx, wh, bx, bh input, initial state, weights and biases; rows 0, 1, 2 of
+ *        bx / bh are used for the r, z and n gates respectively
+ * \param gateR, gateZ, gateN, Mnh per-step buffers written by this function
+ * \param y_ptr receives the hidden state of every step
+ * \param hy_ptr receives the final hidden state (only written if state_outputs)
+ * NOTE(review): the final copy starts at y_ptr + (T - 1) * N * H, i.e. without the
+ * factor D used for the step stride elsewhere -- presumably only D == 1 is exercised.
+ */
+template<typename DType>
+void GruForwardTrainingSingleLayer(DType* ws,
+                                   bool state_outputs,
+                                   const int D,
+                                   const int T,
+                                   const int N,
+                                   const int I,
+                                   const int H,
+                                   const Tensor<cpu, 2, DType> &x,
+                                   const Tensor<cpu, 2, DType> &hx,
+                                   const Tensor<cpu, 2, DType> &wx,
+                                   const Tensor<cpu, 2, DType> &wh,
+                                   const Tensor<cpu, 2, DType> &bx,
+                                   const Tensor<cpu, 2, DType> &bh,
+                                   DType* gateR,
+                                   DType* gateZ,
+                                   DType* gateN,
+                                   DType* Mnh,
+                                   DType* y_ptr,
+                                   DType* hy_ptr) {
+  // ht and ht_1 start on the same memory: step 0 overwrites the seeded values in place
+  DType* ht = y_ptr;
+  DType* ht_1 = y_ptr;
+  DType* gemmC1  = ws;              // [D, T, N, 3 * H]
+  DType* gemmC2  = gemmC1 + D * T * N * 3 * H;  // N * 3 * H
+  DType* rt = gateR;
+  DType* zt = gateZ;
+  DType* nt = gateN;
+  DType* gemmC1_t = gemmC1;
+  Tensor<cpu, 2, DType> dgemmC1(ws, Shape2(D * T * N, 3 * H));
+  Tensor<cpu, 2, DType> dgemmC2(gemmC2, Shape2(D * N, 3 * H));
+
+  // seed the first N * H entries of y with hx; they serve as h_{t-1} for step 0
+  #pragma omp parallel for
+  for (int i = 0; i < N; i++)
+    for (int j = 0; j < H; j++) {
+      y_ptr[i * H + j] = hx[i][j];
+    }
+
+  // x * wx.T : [T * N, I] * [I, 3 * H]
+  DType alpha = 1.0;
+  DType beta = 0.0;
+  linalg_gemm(x, wx, dgemmC1, alpha, beta, false, true);
+
+  for (int t = 0; t < T; t++) {
+    //  perform the first direction, X * wx and H * wh for each step
+    //  ht-1 * wh, ht-1:[N, H] wh:[3 * H, H]
+
+    Tensor<cpu, 2, DType> dht_1(ht_1, Shape2(N, D * H));
+    linalg_gemm(dht_1, wh, dgemmC2, alpha, beta, false, true);
+    gemmC1_t = gemmC1 + t * N * 3 * H;
+
+    // per-step slots of the saved gate buffers
+    rt = gateR + t * N * H;
+    zt = gateZ + t * N * H;
+    nt = gateN + t * N * H;
+    // (same value as assigned a few lines above)
+    gemmC1_t = gemmC1 + t * N * 3 * H;
+    DType* Mnht = Mnh + t * N * H;
+    #pragma omp parallel for
+    for (int i = 0; i < N; ++i) {
+      for (int j = 0; j < H; ++j) {
+        // offsets of the r, z and n blocks inside row i of the [N, 3 * H] buffers
+        int rtb = i * 3 * H;
+        int ztb = i * 3 * H + H;
+        int ntb = i * 3 * H + 2 * H;
+        // saved so that the backward pass can read it back per step
+        Mnht[i * H + j] = gemmC2[ntb + j] + bh[2][j];
+        rt[i * H + j] = sigmoid(gemmC1_t[rtb + j] + gemmC2[rtb + j]
+            + bx[0][j] + bh[0][j]);
+        zt[i * H + j] = sigmoid(gemmC1_t[ztb + j] + gemmC2[ztb + j]
+            + bx[1][j] + bh[1][j]);
+        nt[i * H + j] = tanh(gemmC1_t[ntb + j] + bx[2][j] +
+            rt[i * H + j] * (gemmC2[ntb + j] + bh[2][j]));
+        ht[i * D * H + j] = (1-zt[i * H + j]) * nt[i * H + j] +
+            zt[i * H + j] * ht_1[i * D * H + j];
+      }
+    }
+    // the state just written becomes h_{t-1}; ht moves on to the next step's slot in y
+    ht_1 = ht;
+    ht = ht + D * H * N;
+  }
+  //  copy last state to hy, from(N, H * D) to (D, N, H)
+  if (state_outputs) {
+    DType* y_start = y_ptr + (T - 1) * N * H;
+    #pragma omp parallel for
+    for (int i = 0; i < N; i++)
+      for (int j = 0; j < H; j++) {
+        hy_ptr[i * H + j] = y_start[i * H + j];
+      }
+  }
+}
+
+/*!
+ * \brief Training-mode GRU forward pass; wraps GruForwardTrainingSingleLayer.
+ *
+ * Carves `ws` into gateR, gateZ, gateN, y and Mnh (L * T * D * N * H values each)
+ * followed by the scratch region handed to the kernel, runs the single-layer kernel
+ * once on the first entry of hx, and finally copies the y region into y_ptr.
+ * \param ws buffer holding the saved values and the kernel scratch space
+ * \param state_outputs forwarded to the kernel (controls whether hy_ptr is written)
+ * \param L, D, T, N, I, H sizes used for the layout and forwarded to the kernel
+ * \param x_ptr input, viewed as [T * N, I]
+ * \param hx_ptr initial hidden state, viewed as [L, N, H]; only entry 0 is used
+ * \param w_ptr flat parameters laid out as wx [3 * H, I], wh [3 * H, H],
+ *        bx [3, H], bh [3, H]
+ * \param y_ptr receives a copy of the T * N * H * D hidden states
+ * \param hy_ptr final-state buffer handed to the kernel
+ */
+template <typename DType>
+void GruForwardTraining(DType* ws,
+                        bool state_outputs,
+                        const int L,
+                        const int D,
+                        const int T,
+                        const int N,
+                        const int I,
+                        const int H,
+                        DType* x_ptr,
+                        DType* hx_ptr,
+                        DType* w_ptr,
+                        DType* y_ptr,
+                        DType* hy_ptr) {
+  const Tensor<cpu, 2, DType> wx(w_ptr, Shape2(H * 3, I));
+  const Tensor<cpu, 2, DType> wh(w_ptr + I * H * 3, Shape2(H * 3, H));
+  const Tensor<cpu, 2, DType> bx(wh.dptr_ + H * H * 3, Shape2(3, H));
+  const Tensor<cpu, 2, DType> bh(bx.dptr_ + H * 3, Shape2(3, H));
+  Tensor<cpu, 2, DType> x(x_ptr, Shape2(T * N, I));
+  Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(L, N, H));
+  // hy is not referenced below; the raw hy_ptr is passed to the kernel instead
+  Tensor<cpu, 3, DType> hy(hy_ptr, Shape3(L, N, H));
+  Tensor<cpu, 2, DType> x_l = x;
+  Tensor<cpu, 2, DType> hx_l = hx[0];
+  DType* hy_l = hy_ptr;
+  // consecutive regions inside ws: gateR | gateZ | gateN | y | Mnh | kernel scratch
+  DType* gateR_l = ws;
+  DType* gateZ_l = gateR_l + L * T * D * N * H;
+  DType* gateN_l = gateZ_l + L * T * D * N * H;
+  DType* y_l = gateN_l + L * T * D * N * H;
+  DType* Mnh_l = y_l + L * T * N * H * D;
+  DType* ws2 = Mnh_l + L * D * T * N * H;
+  const Tensor<cpu, 2, DType> wx_l = wx;
+  const Tensor<cpu, 2, DType> wh_l = wh;
+  const Tensor<cpu, 2, DType> bx_l = bx;
+  const Tensor<cpu, 2, DType> bh_l = bh;
+
+  GruForwardTrainingSingleLayer<DType>(ws2, state_outputs, D, T, N, I, H,
+                                       x_l, hx_l, wx_l, wh_l, bx_l, bh_l,
+                                       gateR_l, gateZ_l, gateN_l, Mnh_l, y_l, hy_l);
+
+  // the kernel wrote the hidden states into the y region of ws; publish them to y_ptr
+  #pragma omp parallel for
+  for (int i = 0; i < T * N * H * D; i++) {
+    y_ptr[i] = y_l[i];
+  }
+}
+
+/*!
+ * \brief Backward pass of one GRU layer: walks the T steps in reverse and produces dx,
+ *        dhx and the weight / bias gradients.
+ *
+ * All loop indices are declared inside their loops so that every OpenMP thread has its
+ * own copy (an index declared outside a parallel region is shared between threads).
+ * \param ws scratch buffer, carved up as dar [T, N, 3 * H], da [T, N, 3 * H],
+ *        dht1 [D, N, H] and hx_ [N, D, H]
+ * \param D, T, N, I, H sizes used for all indexing below
+ * \param x layer input, used as [T * N, I]
+ * \param hx initial hidden state, read as hx[i][j] for i < N, j < H
+ * \param wx input weights, used as [3 * H, I]
+ * \param wh hidden weights, used as [3 * H, H]
+ * \param y_ptr forward hidden states; the slice of step t - 1 is h_{t-1} for step t
+ * \param dy_ptr gradient of y, added to the running hidden gradient at every step
+ * \param dhy_ptr gradient of the final hidden state; may be null, in which case the
+ *        running hidden gradient starts from zero
+ * \param gateR, gateZ, gateN, Mnh saved per-step values, read at offset t * N * H
+ * \param dx, dhx, dwx, dwh, dbx, dbh outputs; dwh, dbx and dbh are zeroed here and then
+ *        accumulated, dx and dwx are overwritten, dhx receives the final dht1
+ */
+template <typename DType>
+void GruBackwardSingleLayer(DType* ws,
+                            const int D,
+                            const int T,
+                            const int N,
+                            const int I,
+                            const int H,
+                            const Tensor<cpu, 2, DType> &x,
+                            const Tensor<cpu, 2, DType> &hx,
+                            const Tensor<cpu, 2, DType> &wx,
+                            const Tensor<cpu, 2, DType> &wh,
+                            DType* y_ptr,
+                            DType* dy_ptr,
+                            DType* dhy_ptr,
+                            DType* gateR,
+                            DType* gateZ,
+                            DType* gateN,
+                            DType* Mnh,
+                            DType* dx,
+                            DType* dhx,
+                            DType* dwx,
+                            DType* dwh,
+                            DType* dbx,
+                            DType* dbh) {
+  DType* dar = ws;  // [T, N, 3 * H]
+  DType* da = dar + T * N * 3 * H;  // [T, N, 3 * H]
+  DType* dht1 = da + T * N * 3 * H;  // [D, N, H]
+  DType* hx_ = dht1 + D * N * H;  // [N, D, H]
+  DType alpha = 1.0;
+  DType beta = 0.0;
+
+  #pragma omp parallel for
+  for (int i = 0; i < D * H * 3 * H; ++i) {
+    dwh[i] = 0;
+  }
+
+  #pragma omp parallel for
+  for (int i = 0; i < D * 3 * H; ++i) {
+    dbx[i] = 0;
+    dbh[i] = 0;
+  }
+
+  // start the running hidden gradient from dhy, or from zero when no dhy is supplied
+  #pragma omp parallel for
+  for (int i = 0; i < N * H; ++i) {
+    dht1[i] = (dhy_ptr != NULL) ? dhy_ptr[i] : static_cast<DType>(0);
+  }
+
+  // copy hx into the [N, D, H] layout used for h_{t-1}
+  #pragma omp parallel for
+  for (int i = 0; i < N; ++i) {
+    for (int j = 0; j < H; ++j) {
+      hx_[i * D * H + j] = hx[i][j];
+    }
+  }
+
+  for (int t = T - 1; t >= 0; --t) {
+    // h_{t-1}: previous slice of y, or the initial state for the first step
+    DType* ht1 = t ? y_ptr + (t - 1) * N * D * H : hx_;  // [N, D, H]
+
+    // add dy[T, N, D, H] to dhy[D, N, H]
+    DType* dyt = dy_ptr + t * N * D * H;
+    #pragma omp parallel for
+    for (int i = 0; i < N; ++i) {
+      for (int j = 0; j < H; ++j) {
+        dht1[i * H + j] += dyt[i * D * H + j];
+      }
+    }
+
+    DType* rt = gateR + t * N * H;
+    DType* zt = gateZ + t * N * H;
+    DType* nt = gateN + t * N * H;
+    DType* Mnht = Mnh + t * N * H;
+    DType* dat = da + t * N * 3 * H;
+    DType* dart = dar + t * N * 3 * H;
+
+    #pragma omp parallel for
+    for (int i = 0; i < N; ++i) {
+      for (int j = 0; j < H; ++j) {
+        int nid = i * 3 * H + 2 * H + j;
+        int zid = i * 3 * H + H + j;
+        int rid = i * 3 * H + j;
+        int id = i * H + j;
+        dat[nid] = dht1[id] * (1 - zt[id]) * (1 - nt[id] * nt[id]);
+        dart[zid] = dat[zid] = dht1[id] * (ht1[i * D * H + j] - nt[id]) *
+            zt[id] * (1 - zt[id]);
+        dart[rid] = dat[rid] = dat[nid] * Mnht[id] * rt[id] *
+            (1 - rt[id]);
+        dart[nid] = dat[nid] * rt[id];
+        // part of the hidden gradient that flows straight to h_{t-1} through z
+        dht1[id] = dht1[id] * zt[id];
+      }
+    }
+
+    // both products below accumulate into their destination
+    alpha = 1.0;
+    beta = 1.0;
+
+    // dht1 = dart * wh    [N, H] = [N, 3 * H] * [3 * H, H]
+    Tensor<cpu, 2, DType> d_dht1(dht1, Shape2(N, H));
+    Tensor<cpu, 2, DType> d_dart(dart, Shape2(N, 3 * H));
+    linalg_gemm(d_dart, wh, d_dht1, alpha, beta, false, false);
+
+    // dwh = dart.T * ht1    [3 * H, H] = [3 * H, N] * [N, H]
+    Tensor<cpu, 2, DType> d_ht1(ht1, Shape2(N, H));
+    Tensor<cpu, 2, DType> d_dwh(dwh, Shape2(3 * H, H));
+    linalg_gemm(d_dart, d_ht1, d_dwh, alpha, beta, true, false);
+  }
+
+  // dbx = e * da       [1, 3 * H] = [1, N] * [N, 3 * H]
+  #pragma omp parallel for
+  for (int i = 0; i < 3 * H; ++i) {
+    for (int j = 0; j < N * T; ++j) {
+      dbx[i] += da[j * 3 * H + i];
+      dbh[i] += dar[j * 3 * H + i];
+    }
+  }
+
+  // the remaining products overwrite their destination
+  alpha = 1.0;
+  beta = 0.0;
+  // dx = da * wx    [T * N, I] = [T * N,3 * H] * [3 * H, I]
+  Tensor<cpu, 2, DType> d_da(da, Shape2(T * N, 3 * H));
+  Tensor<cpu, 2, DType> d_dx(dx, Shape2(T * N, I));
+  linalg_gemm(d_da, wx, d_dx, alpha, beta, false, false);
+
+  // dwx = da.T * x    [3 * H, I] = [3 * H, T * N] * [T * N, I]
+  Tensor<cpu, 2, DType> d_dwx(dwx, Shape2(3 * H, I));
+  linalg_gemm(d_da, x, d_dwx, alpha, beta, true, false);
+
+  #pragma omp parallel for
+  for (int i = 0; i < D * N * H; ++i) {
+    dhx[i] = dht1[i];
+  }
+}
+
+/*!
+ * \brief GRU backward pass; sets up the buffer pointers and calls
+ *        GruBackwardSingleLayer once.
+ *
+ * Weight and gradient pointers are derived from the flat w_ptr / dw_ptr buffers, and the
+ * saved forward values (gateR, gateZ, gateN, y, Mnh) are located inside `ws` in the same
+ * order GruForwardTraining lays them out. Every "(L - 1)" offset below selects the last
+ * of L entries.
+ * \param ws buffer containing the saved forward values followed by scratch space
+ * \param L, D, T, N, I, H sizes used for the pointer arithmetic and forwarded on
+ * \param x_ptr input, viewed as [T * N, I]
+ * \param hx_ptr initial hidden state
+ * \param w_ptr flat parameter buffer
+ * \param dy_ptr gradient of the output
+ * \param dhy_ptr gradient of the final hidden state
+ * \param dx_ptr, dhx_ptr, dw_ptr gradient outputs
+ * NOTE(review): ws2 is Mnh_l + T * N * H * D here, while GruForwardTraining uses
+ * Mnh_l + L * D * T * N * H; the two only coincide for L == 1 -- confirm that is the
+ * only supported case.
+ */
+template <typename DType>
+void GruBackward(DType* ws,
+                 const int L,
+                 const int D,
+                 const int T,
+                 const int N,
+                 const int I,
+                 const int H,
+                 DType* x_ptr,
+                 DType* hx_ptr,
+                 DType* w_ptr,
+                 DType* dy_ptr,
+                 DType* dhy_ptr,
+                 DType* dx_ptr,
+                 DType* dhx_ptr,
+                 DType* dw_ptr) {
+  // flat parameter layout: wx | wh (and, for the gradients, | dbx | dbh)
+  DType* wx = w_ptr;
+  DType* wh = wx + I * H * 3 * D;
+  DType* dwx = dw_ptr;
+  DType* dwh = dwx + I * H * 3 * D;
+  DType* dbx = dwh + H * H * 3 * D;
+  DType* dbh = dbx + H * 3 * D;
+  // saved forward values inside ws: gateR | gateZ | gateN | y | Mnh | scratch
+  DType* gateR_l = ws + (L - 1) * T * D * N * H;
+  DType* gateZ_l = gateR_l + L * T * D * N * H;
+  DType* gateN_l = gateZ_l + L * T * D * N * H;
+  DType* y_l = gateN_l + L * T * D * N * H;
+  DType* Mnh_l = y_l + L * T * N * H * D;
+  DType* ws2 = Mnh_l + T * N * H * D;
+  DType* wx_l_ptr = (L == 1)? wx : wx + (L - 2) * D * (D * H) * 3 * H + D * I * 3 * H;
+  DType* wh_l_ptr = wh + (L - 1) * D * H * 3 * H;
+  DType* x_l_ptr = x_ptr;
+  DType* hx_l_ptr = hx_ptr + (L - 1) * D * N * H;
+  DType* dhy_l = dhy_ptr + (L - 1) * D * N * H;
+  DType* dwx_l = (L == 1)? dwx : dwx + (L - 2) * D * (D * H) * 3 * H + D * I * 3 * H;
+  DType* dwh_l = dwh + (L - 1) * D * H * 3 * H;
+  DType* dbx_l = dbx + (L - 1) * D * 3 * H;
+  DType* dbh_l = dbh + (L - 1) * D * 3 * H;
+  DType* dx_l = dx_ptr;
+  DType* dhx_l = dhx_ptr + (L - 1) * D * N * H;
+  DType* dy_l = dy_ptr;
+  const Tensor<cpu, 2, DType> wx_l(wx_l_ptr, Shape2(H * 3, I));
+  const Tensor<cpu, 2, DType> wh_l(wh_l_ptr, Shape2(H * 3, H));
+  Tensor<cpu, 2, DType> x_l(x_l_ptr, Shape2(T * N, I));
+  Tensor<cpu, 3, DType> hx(hx_l_ptr, Shape3(L, N, H));
+  Tensor<cpu, 2, DType> hx_l = hx[0];
+
+  GruBackwardSingleLayer<DType>(ws2, D, T, N, I, H, x_l, hx_l, wx_l, wh_l, y_l, dy_l,
+                                dhy_l, gateR_l, gateZ_l, gateN_l, Mnh_l, dx_l, dhx_l,
+                                dwx_l, dwh_l, dbx_l, dbh_l);
+}
+#endif  // MXNET_OPERATOR_RNN_IMPL_HPP_
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index f22b13d6575..c1615632184 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -67,6 +67,7 @@ def test_lstm_forget_bias():
                                forget_bias * np.ones(100, ), np.zeros((2 * 100,))])
     assert_allclose(mod.get_params()[0][bias_argument].asnumpy(), expected_bias)
 
+@unittest.skip("Test fails intermittently. Temporarily disabled until fixed. Tracked at https://github.com/apache/incubator-mxnet/issues/10104")
 def test_lstm_cpu_inference():
     # should behave the same as lstm cell
     EXPECTED_LSTM_OUTPUT = np.array([[[0.72045636, 0.72045636, 0.95215213, 0.95215213],
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 2486be04a52..8037de9efb0 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -27,6 +27,86 @@
 from common import setup_module, with_seed
 import unittest
 
+def check_gru_with_type(xpu, type1, type2, atol):
+    X = mx.sym.Variable('x')
+    Params = mx.sym.Variable('params')
+    HX = mx.sym.Variable('state')
+    T, N, I, H, nd, nl = 5, 32, 100, 100, 1, 1
+    x1 = mx.random.uniform(-1, 1, (T, N, I), ctx=xpu, dtype=type1)
+    dy = mx.random.uniform(-1, 1, (T, N, H), ctx=xpu, dtype=type1)
+    dhy = mx.random.uniform(-1, 1, (nl, N, H), ctx=xpu, dtype=type1)    
+    wx = mx.random.uniform(-1, 1, (3 * H, I), ctx=xpu,dtype=type1)
+    wh = mx.random.uniform(-1, 1, (3 * H, H), ctx=xpu,dtype=type1)
+    bx = mx.random.uniform(-1, 1, (3 * H, ), ctx=xpu,dtype=type1)
+    bh = mx.random.uniform(-1, 1, (3 * H, ), ctx=xpu,dtype=type1)
+    hx = mx.nd.zeros((nl, N, H), ctx=xpu, dtype=type1)
+    x1.attach_grad()
+    wx.attach_grad()
+    wh.attach_grad()
+    bx.attach_grad()
+    bh.attach_grad()
+    
+    #GRUCell case
+    cell = mx.rnn.GRUCell(H, params=None) 
+    Y, [HY] = cell.unroll(T, X, layout='TNC', merge_outputs=True)
+    G = mx.symbol.Group([Y, HY])
+
+    exe = G.bind(
+        xpu, 
+        args={
+            'x':x1, 
+            'gru_i2h_weight':wx, 
+            'gru_h2h_weight':wh, 
+            'gru_i2h_bias':bx, 
+            'gru_h2h_bias':bh,
+        }
+        ,
+        args_grad={
+            'x':x1.grad, 
+            'gru_i2h_weight':wx.grad, 
+            'gru_h2h_weight':wh.grad,
+            'gru_i2h_bias':bx.grad, 
+            'gru_h2h_bias':bh.grad
+        }
+        ,
+        grad_req='write'
+    )
+    fwd1 = exe.forward(is_train=True)
+    exe.backward([dy, dhy.reshape([N, H])])
+    bwd_dx1 = x1.grad
+    bwd_dw1 = mx.ndarray.concat(wx.grad.reshape((3*H*I,)), wh.grad.reshape((3*H*H,)),
+                                bx.grad, bh.grad, dim=0)
+
+
+    # sym.RNN
+    x2 = x1.astype(type2)    
+    params = mx.ndarray.concat(wx.reshape((3*H*I,)), wh.reshape((3*H*H,)),
+                               bx, bh, dim=0).astype(type2)
+    hx = mx.nd.zeros((nl, N, H), ctx=xpu, dtype=type2)
+    x2.attach_grad()
+    params.attach_grad()
+    Y = mx.sym.RNN(data=X, parameters=Params, state=HX, 
+               state_size=H, num_layers=nl, mode='gru', state_outputs = True, name='GRU')
+    yexe = Y.bind(xpu, 
+            args={'x':x2, 'params':params, 'state':hx},
+            args_grad={'x':x2.grad, 'params':params.grad})
+    
+    fwd2 = yexe.forward(is_train=True)
+    yexe.backward([dy.astype(type2), dhy.astype(type2)])
+    bwd_dx2 = x2.grad
+    bwd_dw2 = params.grad
+
+    # check forward:y, hy
+    assert_allclose(fwd1[0].asnumpy(), fwd2[0].asnumpy(), rtol=1e-2, atol=atol)
+    assert_allclose(fwd1[1].asnumpy(), fwd2[1][0].asnumpy(), rtol=1e-2, atol=atol)
+
+    # check backward: dx, dparams
+    assert_allclose(bwd_dx1[0].asnumpy(), bwd_dx2[0].asnumpy(), rtol=1e-2, atol=atol)
+    assert_allclose(bwd_dw1[0].asnumpy(), bwd_dw2[0].asnumpy(), rtol=1e-2, atol=atol)
+
+def test_gru():
+    check_gru_with_type(mx.cpu(), np.float32, np.float32, 1e-4)
+
 
 def np_softmax(x, axis=-1):
     # fix for old numpy on Travis not supporting keepdims


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services