Posted to commits@mxnet.apache.org by pa...@apache.org on 2019/05/24 22:50:41 UTC

[incubator-mxnet] branch master updated: MKLDNN RNN Inference Integration(fp32 LSTM and vRNN with tanh and relu) (#14713)

This is an automated email from the ASF dual-hosted git repository.

patriczhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 653cbb4  MKLDNN RNN Inference Integration(fp32 LSTM and vRNN with tanh and relu) (#14713)
653cbb4 is described below

commit 653cbb4f1485a450a81d53f6e76421c980be2689
Author: Hao Li <ha...@intel.com>
AuthorDate: Sat May 25 06:50:20 2019 +0800

    MKLDNN RNN Inference Integration(fp32 LSTM and vRNN with tanh and relu) (#14713)
    
    * trigger the ci
    
    * integrate mkldnn rnn fp32 inference(LSTM and vRNN with tanh and relu)
    
    * fix bug about comparison between signed and unsigned integer expressions
    
    * fix unix-gpu issue
    
    * fix unix gpu bug
    
    * fix unix-gpu issues
    
    * fix some comments
    
    * fix issue
    
    * fix comment
    
    * rename `cached` to `initialized`
    
    * support IType
    
    * TODO for MKLDNN GRU
    
    * fix bugs in memory adjustment
    
    * Reformat TODO for MKLDNN GRU
    
    * Reserve original RNN path
    
    * Remove MKLDNN GRU
    
    * Fix bug for rnn forward
    
    * Remove `__CUDACC__`
    
    * Move `RNNStatefulComputeCPU` to rnn.cc
    
    * Remove redundant macro of `__CUDACC__`
    
    * Remove the last macro `__CUDACC__` from rnn*
---
 src/operator/nn/mkldnn/mkldnn_rnn_impl.h | 740 +++++++++++++++++++++++++++++++
 src/operator/rnn-inl.h                   | 161 +++++--
 src/operator/rnn.cc                      | 431 ++++++++++++++++++
 src/operator/rnn_impl.h                  |   7 +
 4 files changed, 1297 insertions(+), 42 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_rnn_impl.h b/src/operator/nn/mkldnn/mkldnn_rnn_impl.h
new file mode 100644
index 0000000..ea8e07e
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_rnn_impl.h
@@ -0,0 +1,740 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RNN_IMPL_H_
+#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RNN_IMPL_H_
+#if MXNET_USE_MKLDNN == 1
+#include <dmlc/logging.h>
+#include <dmlc/parameter.h>
+#include <mxnet/operator.h>
+#include <mxnet/storage.h>
+#include <algorithm>
+#include <map>
+#include <vector>
+#include <utility>
+#include <string>
+#include "../../math_functions-inl.h"
+#include "../../operator_common.h"
+#include "../../rnn_impl.h"
+#include "../../rnn-inl.h"
+#include "mkldnn.hpp"
+#include "./mkldnn_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+static algorithm GetMKLDNNRNNAlgo(int mode,
+                                  int* ngates,
+                                  int* nstates) {
+  algorithm algo = algorithm::vanilla_rnn;
+  switch (mode) {
+    case rnn_enum::kLstm:
+      *ngates = 4;
+      *nstates = 2;
+      algo = algorithm::vanilla_lstm;
+      break;
+    case rnn_enum::kGru:
+      *ngates = 3;
+      *nstates = 1;
+      algo = algorithm::vanilla_gru;
+      break;
+    case rnn_enum::kRnnRelu:
+    case rnn_enum::kRnnTanh:
+      *ngates = 1;
+      *nstates = 1;
+      algo = algorithm::vanilla_rnn;
+      break;
+    default:
+      LOG(FATAL) << "unsupported RNN mode:" << mode;
+      break;
+  }
+  return algo;
+}
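
GetMKLDNNRNNAlgo returns the MKLDNN algorithm kind and reports the gate and
state counts through out-parameters; a condensed usage sketch, mirroring the
calls made further down in this header:

    int ngates = 0, nstates = 0;
    algorithm algo = GetMKLDNNRNNAlgo(rnn_enum::kLstm, &ngates, &nstates);
    // algo == algorithm::vanilla_lstm, ngates == 4, nstates == 2
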
+
+static void ConcatData(mkldnn::memory::format src_format,
+                       mkldnn::memory::format dst_format,
+                       std::vector<mkldnn::memory::dims> srcs_cds,
+                       mkldnn::memory::dims dst_cds,
+                       mkldnn::memory::data_type mkldnn_dtype,
+                       int concat_dimension,
+                       std::vector<void*> srcs_data,
+                       const mkldnn::memory &dst) {
+  auto cpu_engine = CpuEngine::Get()->get_engine();
+  std::vector<mkldnn::memory::primitive_desc> srcs_pd;
+  std::vector<mkldnn::memory> srcs;
+  for (size_t i = 0; i < srcs_cds.size(); i++) {
+    auto desc = mkldnn::memory::desc(srcs_cds[i], mkldnn_dtype, src_format);
+    auto mpd = mkldnn::memory::primitive_desc(desc, cpu_engine);
+    auto src_memory = mkldnn::memory(mpd, srcs_data[i]);
+    srcs_pd.push_back(mpd);
+    srcs.push_back(src_memory);
+  }
+  std::vector<primitive::at> inputs;
+  for (size_t i = 0; i < srcs_cds.size(); i++) {
+    inputs.push_back(srcs[i]);
+  }
+  auto dst_desc = mkldnn::memory::desc(dst_cds, mkldnn_dtype, dst_format);
+  auto concat_pd = concat::primitive_desc(dst_desc, concat_dimension, srcs_pd);
+  MKLDNNStream::Get()->RegisterPrim(concat(concat_pd, inputs, dst));
+  MKLDNNStream::Get()->Submit();
+}
+
+//  Cached MKLDNN memory holds:
+//  the first layer's wx and wh plus the next L - 1 layers' wx and wh,
+//  hx and cx for all L layers, and the src/dst data/iter buffers, etc.
+//  Memory is prepared both before and after the reorder and concat steps.
+//  For unidirectional RNNs, layers are fused as 1 + (L - 1) when I != H.
+//  For bidirectional RNNs, forward and backward parts (weight, bias, iter, etc.)
+//  are fused per layer, so the first layer and the following layers are
+//  handled separately.
+static size_t GetMKLDNNRNNCacheMemorySize(int L,
+                                          int D,
+                                          int T,
+                                          int N,
+                                          int I,
+                                          int H,
+                                          int mode) {
+  size_t size = 0;
+  switch (mode) {
+    case rnn_enum::kLstm:
+      size = 2 * (D * (I + H) * 4 * H + (L - 1) * D * (D * H + H) * 4 * H +
+             L * D * 2 * N * H) + T * N * D * H + L * 2 * D * 4 * H + (L + 2) * D * 2 * N * H +
+             6 * D * (I + H + 2) * 4 * H + T * N * I * 2;
+      break;
+    case rnn_enum::kGru:
+      size = 2 * (D * (I + H) * 3 * H + (L - 1) * D * (D * H + H) * 3 * H +
+             L * D * 2 * N * H) + T * N * D * H + L * 2 * D * 3 * H + (L + 2) * D * 2 * N * H +
+             6 * D * (I + H + 2) * 3 * H + T * N * I * 2;
+      break;
+    case rnn_enum::kRnnRelu:
+    case rnn_enum::kRnnTanh:
+      size = 2 * (D * (I + H) * 1 * H + (L - 1) * D * (D * H + H) * 1 * H +
+             L * D * 2 * N * H) + T * N * D * H + L * 2 * D * 1 * H + (L + 2) * D * 2 * N * H +
+             6 * D * (I + H + 2) * 1 * H + T * N * I * 2;
+      break;
+    default:
+      LOG(FATAL) << "unknown RNN mode " << mode;
+      break;
+  }
+  return size;
+}
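
For orientation, a reading key for the dimension parameters and MKLDNN memory
format tags used throughout this header, plus how the returned element count
is consumed (the caller in rnn.cc below scales it by the element size):

    // L = num_layers, D = direction (1 or 2), T = seq_length,
    // N = batch_size, I = input_size, H = state_size.
    //
    // tnc   - (time, batch, channels): src/dst layer data
    // ldigo - (layers, dirs, input, gates, output): weights, MKLDNN's layout
    // ldgoi - (layers, dirs, gates, output, input): user weights before the
    //         reorder to ldigo
    // ldgo  - (layers, dirs, gates, output): bias
    // ldsnc - (layers, dirs, states, batch, channels): hx/cx and hy/cy
    //
    // size_t r_size = GetMKLDNNRNNCacheMemorySize(L, D, T, N, I, H, mode);
    // mem_space_ = Storage::Get()->Alloc(r_size * sizeof(DType), Context::CPU());
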
+
+template <typename DType>
+static void AdjustGruWeightGateOrder(DType* weight,
+                                     const int I,
+                                     const int H) {
+  // MXNet GRU gate order: reset, update, new
+  // MKLDNN GRU gate order: update, reset, new
+  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  DType* weight_reset = weight;
+  DType* weight_update = weight + I * H;
+  #pragma omp parallel for num_threads(omp_threads)
+  for (int i = 0; i < I * H; i++) {
+    DType tmp = weight_update[i];
+    weight_update[i] = weight_reset[i];
+    weight_reset[i] = tmp;
+  }
+}
+
+template <typename DType>
+static void AdjustGruBiasGateOrder(DType* bias,
+                                   const int H) {
+  // MXNet GRU gate order: reset, update, new
+  // MKLDNN GRU gate order: update, reset, new
+  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  DType* bias_reset = bias;
+  DType* bias_update = bias + H;
+  #pragma omp parallel for num_threads(omp_threads)
+  for (int i = 0; i < H; i++) {
+    DType tmp = bias_update[i];
+    bias_update[i] = bias_reset[i];
+    bias_reset[i] = tmp;
+  }
+}
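
The gate reordering can be exercised in isolation; a minimal standalone sketch
with hypothetical sizes I = 2, H = 3, mirroring AdjustGruWeightGateOrder
without the OpenMP pragma:

    #include <cstdio>

    int main() {
      const int I = 2, H = 3;
      float w[3 * I * H];
      for (int i = 0; i < 3 * I * H; ++i) w[i] = static_cast<float>(i);
      float* w_reset  = w;          // first I*H block (MXNet: reset first)
      float* w_update = w + I * H;  // second I*H block (MKLDNN wants update first)
      for (int i = 0; i < I * H; ++i) {
        float tmp = w_update[i];
        w_update[i] = w_reset[i];
        w_reset[i] = tmp;
      }
      for (int i = 0; i < 3 * I * H; ++i) printf("%g ", w[i]);  // 6..11 0..5 12..17
      printf("\n");
      return 0;
    }
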
+// Since the semantics of MKLDNN's fused RNN differ from MXNet's FusedRNN,
+// bidirectional RNNs are fused layer by layer, while unidirectional RNNs run
+// as fused 1 + fused (L - 1) layers, or as fused L layers when I == H.
+
+template <typename DType>
+static void MKLDNNRNNForwardSingleLayerBi(bool state_outputs,
+                                          const int T,
+                                          const int N,
+                                          const int I,
+                                          const int H,
+                                          DType* x_ptr,
+                                          mkldnn::memory *user_src_layer_memory,
+                                          DType* hx_ptr,
+                                          DType* cx_ptr,
+                                          DType* w_ptr,
+                                          DType* b_ptr,
+                                          DType* y_ptr,
+                                          DType* hy_ptr,
+                                          DType* cy_ptr,
+                                          std::vector<mkldnn::memory> *concat_weight_memory,
+                                          std::vector<mkldnn::memory> *concat_iter_memory,
+                                          std::vector<mkldnn::memory> *x_memory,
+                                          std::vector<mkldnn::memory> *hcx_memory,
+                                          std::vector<mkldnn::memory> *wx_memory,
+                                          std::vector<mkldnn::memory> *wh_memory,
+                                          std::vector<mkldnn::memory> *bias_memory,
+                                          std::vector<mkldnn::memory> *y_memory,
+                                          std::vector<mkldnn::memory> *hcy_memory,
+                                          std::vector<primitive> *rnn_forward_prim,
+                                          int layer_index,
+                                          bool *has_cache,
+                                          int lvalue,
+                                          int dtype,
+                                          bool is_train,
+                                          int mode) {
+  int ngates = 0, nstates = 0;
+  algorithm nalgorithm = GetMKLDNNRNNAlgo(mode, &ngates, &nstates);
+  mkldnn::memory::data_type mkldnn_dtype = get_mkldnn_type(dtype);
+  const int single_cell_size = N * H;
+  const int single_b_size = ngates * H;
+  DType* wx = w_ptr;  //  ngates * H, I
+  DType* wh = w_ptr + I * H * ngates;  //  ngates * H, H
+  DType* back_wx = w_ptr + ngates * H * (I + H);
+  DType* back_wh = back_wx + I * H * ngates;
+  DType* bx = b_ptr;
+  DType* bh = b_ptr + H * ngates;
+  DType* back_bx = b_ptr + single_b_size * 2;
+  DType* back_bh = back_bx + H * ngates;
+  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  auto cpu_engine = CpuEngine::Get()->get_engine();
+  auto null_memory_ = null_memory(cpu_engine);
+  int offset1 = 0, offset2 = 0;
+  bool initialized = *has_cache;
+  mkldnn::memory::dims src_layer_tz = {T, N, I};
+  mkldnn::memory::dims dst_layer_tz = {T, N, 2 * H};
+  mkldnn::memory::dims weights_layer_tz = {1, 2, I, ngates, H};  //  ldigo
+  mkldnn::memory::dims weights_layer_r_tz = {1, 1, I, ngates, H};  //  ldigo for reorder
+  mkldnn::memory::dims weights_iter_tz = {1, 2, H, ngates, H};  //  ldigo
+  mkldnn::memory::dims weights_iter_r_tz = {1, 1, H, ngates, H};  //  ldigo for reorder
+  mkldnn::memory::dims bias_tz = {1, 2, ngates, H};
+  mkldnn::memory::dims src_iter_tz = {1, 2, nstates, N, H};  //  ldsnc
+  mkldnn::memory::dims dst_iter_tz = {1, 2, nstates, N, H};  //  ldsnc
+
+  if (!initialized) {
+    if (mode == rnn_enum::kGru) {
+      AdjustGruWeightGateOrder(wx, I, H);
+      AdjustGruWeightGateOrder(back_wx, I, H);
+      AdjustGruWeightGateOrder(wh, H, H);
+      AdjustGruWeightGateOrder(back_wh, H, H);
+      AdjustGruBiasGateOrder(bx, H);
+      AdjustGruBiasGateOrder(back_bx, H);
+      AdjustGruBiasGateOrder(bh, H);
+      AdjustGruBiasGateOrder(back_bh, H);
+    }
+    auto src_wx = (*concat_weight_memory)[2 * layer_index];
+    auto src_wh = (*concat_weight_memory)[2 * layer_index + 1];
+    std::vector<void*> srcs_data1;
+    srcs_data1.push_back(wx);
+    srcs_data1.push_back(back_wx);
+    ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi,
+        {weights_layer_r_tz, weights_layer_r_tz}, weights_layer_tz,
+        mkldnn_dtype, 1, srcs_data1, src_wx);
+    srcs_data1.clear();
+    srcs_data1.push_back(wh);
+    srcs_data1.push_back(back_wh);
+    ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi,
+        {weights_iter_r_tz, weights_iter_r_tz}, weights_iter_tz,
+         mkldnn_dtype, 1, srcs_data1, src_wh);
+    int tmpvalue = 0;
+    if (lvalue > 0) {
+      tmpvalue = lvalue + 1;
+    }
+    MKLDNNStream::Get()->RegisterPrim(reorder(src_wx, (*wx_memory)[tmpvalue]));
+    MKLDNNStream::Get()->RegisterPrim(reorder(src_wh, (*wh_memory)[tmpvalue]));
+
+    DType* user_bias = reinterpret_cast<DType *>
+        ((*bias_memory)[tmpvalue].get_data_handle());
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int j = 0; j < single_b_size; j++) {
+      user_bias[j] = bx[j] + bh[j];
+      user_bias[single_b_size + j] = back_bx[j] + back_bh[j];
+    }
+  }
+  if (lvalue > 0) {
+    (*wx_memory)[layer_index].set_data_handle((*wx_memory)[lvalue + 1].get_data_handle());
+    (*wh_memory)[layer_index].set_data_handle((*wh_memory)[lvalue + 1].get_data_handle());
+    (*bias_memory)[layer_index].set_data_handle((*bias_memory)[lvalue + 1].get_data_handle());
+  }
+
+  auto src_layer_md = mkldnn::memory::desc(
+      { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc);
+  auto weight_layer_md = mkldnn::memory::desc(
+      { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo);
+  auto weight_iter_md = mkldnn::memory::desc(
+      { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo);
+  auto dst_layer_md = mkldnn::memory::desc(
+      { dst_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc);
+  auto dst_iter_md = mkldnn::memory::desc(
+      { dst_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+  auto src_iter_md = mkldnn::memory::desc(
+      {src_iter_tz}, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+  auto bias_md = mkldnn::memory::desc({bias_tz},
+      mkldnn_dtype, mkldnn::memory::format::ldgo);
+
+  auto user_src_iter_memory = (*concat_iter_memory)[2];
+  if (mode == rnn_enum::kLstm) {
+    std::vector<void*> srcs_data1;
+    srcs_data1.push_back(hx_ptr);
+    srcs_data1.push_back(cx_ptr);
+    auto tmp1_src_iter_memory = (*concat_iter_memory)[0];
+    ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc,
+        {{1, 1, 1, N, H}, {1, 1, 1, N, H}}, {1, 1, nstates, N, H}, mkldnn_dtype, 2,
+        srcs_data1, tmp1_src_iter_memory);
+    std::vector<void*> srcs_data2;
+    srcs_data2.push_back(hx_ptr + single_cell_size);
+    srcs_data2.push_back(cx_ptr + single_cell_size);
+    auto tmp2_src_iter_memory = (*concat_iter_memory)[1];
+    ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc,
+        {{1, 1, 1, N, H}, {1, 1, 1, N, H}}, {1, 1, nstates, N, H}, mkldnn_dtype, 2,
+        srcs_data2, tmp2_src_iter_memory);
+    std::vector<void*> srcs_data3;
+    srcs_data3.push_back(reinterpret_cast<DType *>(tmp1_src_iter_memory.get_data_handle()));
+    srcs_data3.push_back(reinterpret_cast<DType *>(tmp2_src_iter_memory.get_data_handle()));
+    ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc,
+        {{1, 1, nstates, N, H}, {1, 1, nstates, N, H}}, {1, 2, nstates, N, H},
+        mkldnn_dtype, 1, srcs_data3, user_src_iter_memory);
+  } else {
+    user_src_iter_memory.set_data_handle(hx_ptr);
+  }
+  (*hcx_memory)[layer_index].set_data_handle(user_src_iter_memory.get_data_handle());
+
+  rnn_cell::desc rnn_cell(nalgorithm,
+      mode == rnn_enum::kRnnRelu ? algorithm::eltwise_relu : algorithm::eltwise_tanh);
+
+  rnn_forward::desc layer_desc(prop_kind::forward_inference, rnn_cell,
+      rnn_direction::bidirectional_concat, src_layer_md,
+      src_iter_md, weight_layer_md, weight_iter_md,
+      bias_md, dst_layer_md, dst_iter_md);
+
+  auto prim_desc
+       = rnn_forward::primitive_desc(layer_desc, cpu_engine);
+
+  if (x_ptr && layer_index == 0) {
+    (*x_memory)[layer_index].set_data_handle(x_ptr);
+  } else {
+    (*x_memory)[layer_index].set_data_handle((*user_src_layer_memory).get_data_handle());
+  }
+  (*y_memory)[layer_index].set_data_handle(y_ptr);
+
+  if (rnn_forward_prim->size() <= (size_t)layer_index) {
+    primitive rnn_prim = rnn_forward(prim_desc, (*x_memory)[layer_index],
+          (*hcx_memory)[layer_index], (*wx_memory)[layer_index],
+          (*wh_memory)[layer_index], (*bias_memory)[layer_index],
+          (*y_memory)[layer_index],
+         (*hcy_memory)[layer_index], null_memory_);
+    rnn_forward_prim->push_back(rnn_prim);
+  }
+  MKLDNNStream::Get()->RegisterPrim((*rnn_forward_prim)[layer_index]);
+  MKLDNNStream::Get()->Submit();
+
+  if (state_outputs) {
+    DType* dst_hcy = reinterpret_cast<DType *> ((*hcy_memory)[layer_index].get_data_handle());
+    if (mode == rnn_enum::kLstm) {
+      offset1 = nstates * single_cell_size;
+      offset2 = (nstates + 1) * single_cell_size;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int n = 0; n < single_cell_size; n++) {
+        hy_ptr[n] = dst_hcy[n];
+        hy_ptr[n + single_cell_size] = dst_hcy[n + offset1];
+        cy_ptr[n] = dst_hcy[n + single_cell_size];
+        cy_ptr[n + single_cell_size] = dst_hcy[n + offset2];
+      }
+    } else {
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int n = 0; n < 2 * single_cell_size; n++) {
+        hy_ptr[n] = dst_hcy[n];
+      }
+    }
+  }
+}
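
For reference, the state extraction above follows the dst_hcy layout for LSTM
with D = 2 (ldsnc, dims {1, 2, 2, N, H}), in blocks of single_cell_size = N * H:

    // [0] dir0 h   [1] dir0 c   [2] dir1 h   [3] dir1 c
    // hy_ptr receives blocks 0 and 2; cy_ptr receives blocks 1 and 3,
    // which is exactly offset1 = nstates * single_cell_size and
    // offset2 = (nstates + 1) * single_cell_size with nstates = 2.
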
+
+
+template <typename DType>
+static void MKLDNNRNNForwardUnidi(bool state_outputs,
+                                  const int L,
+                                  const int T,
+                                  const int N,
+                                  const int I,
+                                  const int H,
+                                  DType* x_ptr,
+                                  mkldnn::memory *user_src_layer_memory,
+                                  DType* hx_ptr,
+                                  DType* cx_ptr,
+                                  DType* w_ptr,
+                                  DType* b_ptr,
+                                  DType* y_ptr,
+                                  DType* hy_ptr,
+                                  DType* cy_ptr,
+                                  std::vector<mkldnn::memory> *concat_weight_memory,
+                                  std::vector<mkldnn::memory> *concat_iter_memory,
+                                  std::vector<mkldnn::memory> *x_memory,
+                                  std::vector<mkldnn::memory> *hcx_memory,
+                                  std::vector<mkldnn::memory> *wx_memory,
+                                  std::vector<mkldnn::memory> *wh_memory,
+                                  std::vector<mkldnn::memory> *bias_memory,
+                                  std::vector<mkldnn::memory> *y_memory,
+                                  std::vector<mkldnn::memory> *hcy_memory,
+                                  std::vector<primitive> *rnn_forward_prim,
+                                  int layer_index,
+                                  bool *has_cache,
+                                  int dtype,
+                                  bool is_train,
+                                  int mode) {
+  int ngates = 0, nstates = 0;
+  algorithm nalgorithm = GetMKLDNNRNNAlgo(mode, &ngates, &nstates);
+  mkldnn::memory::data_type mkldnn_dtype = get_mkldnn_type(dtype);
+  const int cell_size = N * H;
+  const int single_cell_size = N * H;
+  const int single_b_size = ngates * H;
+  int w_size = (I + H) * H * ngates;
+  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  auto cpu_engine = CpuEngine::Get()->get_engine();
+  auto null_memory_ = null_memory(cpu_engine);
+  int offset1 = 0, offset2 = 0;
+  bool initialized = *has_cache;
+
+  mkldnn::memory::dims src_layer_tz = {T, N, I};
+  mkldnn::memory::dims dst_layer_tz = {T, N, H};
+  mkldnn::memory::dims weights_layer_tz = {L, 1, I, ngates, H};  //  ldigo
+  mkldnn::memory::dims weights_iter_tz = {L, 1, H, ngates, H};  //  ldigo
+  mkldnn::memory::dims bias_tz = {L, 1, ngates, H};
+  mkldnn::memory::dims src_iter_tz = {L, 1, nstates, N, H};  //  ldsnc
+  mkldnn::memory::dims dst_iter_tz = {L, 1, nstates, N, H};  //  ldsnc
+  mkldnn::memory::dims weights_layer_r_tz = {1, 1, I, ngates, H};  //  ldigo for reorder
+  mkldnn::memory::dims weights_iter_r_tz = {1, 1, H, ngates, H};  //  ldigo for reorder
+
+  auto weight_layer_md = mkldnn::memory::desc(
+      { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo);
+  auto weight_iter_md = mkldnn::memory::desc(
+      { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo);
+  auto src_layer_md = mkldnn::memory::desc(
+      { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc);
+  auto dst_layer_md = mkldnn::memory::desc(
+      {dst_layer_tz}, mkldnn_dtype, mkldnn::memory::format::tnc);
+  auto src_iter_md = mkldnn::memory::desc(
+      {src_iter_tz}, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+  auto bias_md = mkldnn::memory::desc({bias_tz},
+      mkldnn_dtype, mkldnn::memory::format::ldgo);
+  auto dst_iter_md = mkldnn::memory::desc(
+      {dst_iter_tz}, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+
+  for (int l = 0; l < L; l++) {
+    if (mode == rnn_enum::kLstm) {
+      std::vector<void*> srcs_data;
+      srcs_data.push_back(hx_ptr);
+      srcs_data.push_back(cx_ptr);
+      auto tmp_src_iter_memory = (*concat_iter_memory)[l + layer_index];
+      ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc,
+          {{1, 1, 1, N, H}, {1, 1, 1, N, H}}, {1, 1, nstates, N, H}, mkldnn_dtype,
+          2, srcs_data, tmp_src_iter_memory);
+    } else {
+      (*concat_iter_memory)[l + layer_index].set_data_handle(hx_ptr);
+    }
+    hx_ptr += cell_size;
+    if (mode == rnn_enum::kLstm) {
+      cx_ptr += cell_size;
+    }
+  }
+
+  auto user_src_iter_memory = null_memory_;
+  if (L == 1) {
+    user_src_iter_memory = (*concat_iter_memory)[layer_index];
+  } else {
+    user_src_iter_memory = (*concat_iter_memory)[L + layer_index];
+    std::vector<void*> src_l_data;
+    std::vector<mkldnn::memory::dims> src_l_dim;
+    for (int l = 0; l < L; l++) {
+      src_l_data.push_back(reinterpret_cast<DType *>
+          ((*concat_iter_memory)[l + layer_index].get_data_handle()));
+      src_l_dim.push_back({1, 1, nstates, N, H});
+    }
+    ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc, src_l_dim,
+        {L, 1, nstates, N, H}, mkldnn_dtype, 0, src_l_data, user_src_iter_memory);
+  }
+  (*hcx_memory)[layer_index].set_data_handle(user_src_iter_memory.get_data_handle());
+
+  auto src_wx_f = (*concat_weight_memory)[2 * layer_index];
+  auto src_wh_f = (*concat_weight_memory)[2 * layer_index + 1];
+
+  std::vector<void*> srcs_data_x;
+  std::vector<void*> srcs_data_h;
+  std::vector<mkldnn::memory::dims> src_l_dim_x;
+  std::vector<mkldnn::memory::dims> src_l_dim_h;
+  if (!initialized) {
+    if (L == 1) {
+      DType* wx = w_ptr;
+      DType* wh = w_ptr + I * H * ngates;
+      if (mode == rnn_enum::kGru) {
+        AdjustGruWeightGateOrder(wx, I, H);
+        AdjustGruWeightGateOrder(wh, H, H);
+        AdjustGruBiasGateOrder(b_ptr, H);
+        AdjustGruBiasGateOrder(b_ptr + H * ngates, H);
+      }
+      src_wx_f.set_data_handle(wx);
+      src_wh_f.set_data_handle(wh);
+    } else {
+      for (int l = 0; l < L; l++) {
+        DType* wx = w_ptr;
+        DType* wh = w_ptr + I * H * ngates;
+        DType* bx = b_ptr + l * ngates * H * 2;
+        DType* bh = b_ptr + l * ngates * H * 2 + H * ngates;
+        if (mode == rnn_enum::kGru) {
+          AdjustGruWeightGateOrder(wx, I, H);
+          AdjustGruWeightGateOrder(wh, H, H);
+          AdjustGruBiasGateOrder(bx, H);
+          AdjustGruBiasGateOrder(bh, H);
+        }
+        srcs_data_x.push_back(wx);
+        srcs_data_h.push_back(wh);
+        src_l_dim_x.push_back(weights_layer_r_tz);
+        src_l_dim_h.push_back(weights_iter_r_tz);
+        w_ptr = w_ptr + w_size;
+      }
+      ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi,
+          src_l_dim_x, weights_layer_tz, mkldnn_dtype, 0, srcs_data_x, src_wx_f);
+      ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi,
+          src_l_dim_h, weights_iter_tz, mkldnn_dtype, 0, srcs_data_h, src_wh_f);
+    }
+    MKLDNNStream::Get()->RegisterPrim(reorder(src_wx_f, (*wx_memory)[layer_index]));
+    MKLDNNStream::Get()->RegisterPrim(reorder(src_wh_f, (*wh_memory)[layer_index]));
+
+    DType* user_bias_f = reinterpret_cast<DType *> ((*bias_memory)[layer_index].get_data_handle());
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int j = 0; j < L * single_b_size; j++) {
+      int k = j / single_b_size;
+      user_bias_f[j] = b_ptr[j + k * single_b_size] + b_ptr[j + k * single_b_size + single_b_size];
+    }
+  }
+
+  rnn_cell::desc rnn_cell(nalgorithm,
+      mode == rnn_enum::kRnnRelu ? algorithm::eltwise_relu : algorithm::eltwise_tanh);
+
+  rnn_forward::desc layer_desc(prop_kind::forward_inference, rnn_cell,
+      rnn_direction::unidirectional, src_layer_md,
+      src_iter_md, weight_layer_md, weight_iter_md,
+      bias_md, dst_layer_md, dst_iter_md);
+
+  auto prim_desc
+       = rnn_forward::primitive_desc(layer_desc, cpu_engine);
+
+  if (x_ptr && layer_index == 0) {
+    (*x_memory)[layer_index].set_data_handle(x_ptr);
+  } else {
+    (*x_memory)[layer_index].set_data_handle((*user_src_layer_memory).get_data_handle());
+  }
+  (*y_memory)[layer_index].set_data_handle(y_ptr);
+
+  if (rnn_forward_prim->size() <= (size_t)layer_index) {
+    primitive rnn_prim = rnn_forward(prim_desc, (*x_memory)[layer_index],
+          (*hcx_memory)[layer_index], (*wx_memory)[layer_index],
+          (*wh_memory)[layer_index], (*bias_memory)[layer_index],
+          (*y_memory)[layer_index],
+         (*hcy_memory)[layer_index], null_memory_);
+    rnn_forward_prim->push_back(rnn_prim);
+  }
+  MKLDNNStream::Get()->RegisterPrim((*rnn_forward_prim)[layer_index]);
+  MKLDNNStream::Get()->Submit();
+
+  if (state_outputs) {
+    DType* dst_hcy = reinterpret_cast<DType *> ((*hcy_memory)[layer_index].get_data_handle());
+    if (mode == rnn_enum::kLstm) {
+      for (int l = 0; l < L; l++) {
+        offset1 = l * single_cell_size;
+        offset2 = l * nstates * single_cell_size;
+        #pragma omp parallel for num_threads(omp_threads)
+        for (int n = 0; n < single_cell_size; n++) {
+          hy_ptr[offset1 + n] = dst_hcy[offset2 + n];
+          cy_ptr[offset1 + n] = dst_hcy[offset2 + n + single_cell_size];
+        }
+      }
+    } else {
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int n = 0; n < L * single_cell_size; n++) {
+        hy_ptr[n] = dst_hcy[n];
+      }
+    }
+  }
+}
+
+template <typename DType>
+static void MKLDNNRNNForward(bool state_outputs,
+                             const int L,
+                             const int D,
+                             const int T,
+                             const int N,
+                             const int I,
+                             const int H,
+                             DType* x_ptr,
+                             DType* hx_ptr,
+                             DType* cx_ptr,
+                             DType* w_ptr,
+                             DType* b_ptr,
+                             DType* y_ptr,
+                             DType* hy_ptr,
+                             DType* cy_ptr,
+                             std::vector<mkldnn::memory> *concat_weight_memory,
+                             std::vector<mkldnn::memory> *concat_iter_memory,
+                             std::vector<mkldnn::memory> *x_memory,
+                             std::vector<mkldnn::memory> *hcx_memory,
+                             std::vector<mkldnn::memory> *wx_memory,
+                             std::vector<mkldnn::memory> *wh_memory,
+                             std::vector<mkldnn::memory> *bias_memory,
+                             std::vector<mkldnn::memory> *y_memory,
+                             std::vector<mkldnn::memory> *hcy_memory,
+                             std::vector<primitive> *rnn_forward_prim,
+                             bool *has_cache,
+                             int dtype,
+                             bool is_train,
+                             int mode) {
+  int ngates = 0, nstates = 0;
+  GetMKLDNNRNNAlgo(mode, &ngates, &nstates);
+  const int b_size = 2 * H * ngates * D;
+  const int cell_size = N * H * D;
+  //  First layer
+  int w_size = (I + H) * H * ngates * D;
+  auto cpu_engine = CpuEngine::Get()->get_engine();
+  auto null_memory_ = null_memory(cpu_engine);
+  DType* tmpNull = NULL;
+  // when D = 1 and I == H, L layers can be fused together
+  if (D == 1 && I == H) {
+    MKLDNNRNNForwardUnidi(state_outputs, L, T, N, I, H, x_ptr, &null_memory_,
+        hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory,
+        concat_iter_memory, x_memory, hcx_memory, wx_memory, wh_memory,
+        bias_memory, y_memory, hcy_memory, rnn_forward_prim,
+        0, has_cache, dtype, is_train, mode);
+  } else {
+    auto user_src_layer_memory_l = null_memory_;
+    if (D == 2) {
+      MKLDNNRNNForwardSingleLayerBi(state_outputs, T, N, I, H, x_ptr, &user_src_layer_memory_l,
+          hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory,
+          concat_iter_memory, x_memory, hcx_memory, wx_memory, wh_memory,
+          bias_memory, y_memory, hcy_memory, rnn_forward_prim,
+          0, has_cache, 0, dtype, is_train, mode);
+    } else {
+      MKLDNNRNNForwardUnidi(state_outputs, 1, T, N, I, H, x_ptr, &user_src_layer_memory_l,
+          hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory,
+          concat_iter_memory, x_memory, hcx_memory, wx_memory, wh_memory,
+          bias_memory, y_memory, hcy_memory, rnn_forward_prim,
+          0, has_cache, dtype, is_train, mode);
+    }
+    if (L > 1) {
+      user_src_layer_memory_l = (*y_memory)[0];
+      //  Proceed to the next L - 1 layers.
+      //  If D = 2, process them layer by layer; if D = 1, fuse the L - 1 layers.
+      w_ptr += w_size;
+      b_ptr += b_size;
+      if (D == 2) {
+        w_size = (H * D + H) * H * ngates * D;
+        for (int l = 0; l < L - 1; l++) {
+          if (state_outputs) {
+            hy_ptr += cell_size;
+            if (mode == rnn_enum::kLstm) {
+              cy_ptr += cell_size;
+            }
+          }
+          hx_ptr += cell_size;
+          if (mode == rnn_enum::kLstm) {
+            cx_ptr += cell_size;
+          }
+          MKLDNNRNNForwardSingleLayerBi(state_outputs, T, N, D * H, H, tmpNull,
+              &user_src_layer_memory_l, hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr,
+              cy_ptr, concat_weight_memory, concat_iter_memory, x_memory,
+              hcx_memory, wx_memory, wh_memory, bias_memory,
+              y_memory, hcy_memory, rnn_forward_prim,
+              1, has_cache, l + 1, dtype, is_train, mode);
+          user_src_layer_memory_l = (*y_memory)[1];
+          w_ptr += w_size;
+          b_ptr += b_size;
+        }
+      }
+      if (D == 1) {
+        if (state_outputs) {
+          hy_ptr += cell_size;
+          if (mode == rnn_enum::kLstm) {
+            cy_ptr += cell_size;
+          }
+        }
+        w_size = (H + H) * H * ngates;
+        MKLDNNRNNForwardUnidi(state_outputs, L - 1, T, N, H, H, tmpNull, &user_src_layer_memory_l,
+            hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory,
+            concat_iter_memory, x_memory, hcx_memory, wx_memory,
+            wh_memory, bias_memory, y_memory, hcy_memory,
+            rnn_forward_prim, 1, has_cache, dtype, is_train, mode);
+      }
+    }
+  }
+  *has_cache = true;
+}
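
The dispatch and weight-pointer arithmetic above can be condensed into a
standalone sketch (hypothetical dims; ngates = 4 corresponds to LSTM):

    #include <cstdio>

    const char* FusionPlan(int L, int D, int I, int H) {
      if (D == 1 && I == H) return "fused L layers";
      if (D == 2) return "bidirectional: layer by layer";
      return "1 + fused (L - 1) layers";
    }

    int main() {
      const int ngates = 4, L = 3, D = 2, I = 128, H = 256;
      printf("plan: %s\n", FusionPlan(L, D, I, H));
      // The first layer reads I inputs; every later layer reads D * H inputs.
      size_t w_first = static_cast<size_t>(I + H) * H * ngates * D;
      size_t w_later = static_cast<size_t>(D * H + H) * H * ngates * D;
      printf("weights: first=%zu, each later layer=%zu\n", w_first, w_later);
      return 0;
    }
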
+
+template <typename DType>
+static void MKLDNNRNNForwardInference(bool state_outputs,
+                                      const int num_layers,
+                                      const int direction,
+                                      const int seq_length,
+                                      const int batch_size,
+                                      const int input_size,
+                                      const int state_size,
+                                      DType* x_ptr,
+                                      DType* hx_ptr,
+                                      DType* cx_ptr,
+                                      DType* w_ptr,
+                                      DType* b_ptr,
+                                      DType* y_ptr,
+                                      DType* hy_ptr,
+                                      DType* cy_ptr,
+                                      std::vector<mkldnn::memory>* concat_weight_memory,
+                                      std::vector<mkldnn::memory>* concat_iter_memory,
+                                      std::vector<mkldnn::memory> *x_memory,
+                                      std::vector<mkldnn::memory> *hcx_memory,
+                                      std::vector<mkldnn::memory> *wx_memory,
+                                      std::vector<mkldnn::memory> *wh_memory,
+                                      std::vector<mkldnn::memory> *bias_memory,
+                                      std::vector<mkldnn::memory> *y_memory,
+                                      std::vector<mkldnn::memory> *hcy_memory,
+                                      std::vector<primitive> *rnn_forward_prim,
+                                      bool *has_cache,
+                                      int dtype,
+                                      bool is_train,
+                                      int mode) {
+  switch (mode) {
+    case rnn_enum::kLstm:
+    case rnn_enum::kGru:
+    case rnn_enum::kRnnTanh:
+    case rnn_enum::kRnnRelu:
+      MKLDNNRNNForward<DType>(state_outputs, num_layers, direction, seq_length,
+                              batch_size, input_size, state_size, x_ptr, hx_ptr,
+                              cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr,
+                              concat_weight_memory, concat_iter_memory, x_memory,
+                              hcx_memory, wx_memory, wh_memory,
+                              bias_memory, y_memory, hcy_memory, rnn_forward_prim,
+                              has_cache, dtype, is_train, mode);
+      break;
+    default:
+      LOG(FATAL) << "unknown RNN mode" << mode;
+      break;
+  }
+}
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RNN_IMPL_H_
diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index e43b3c9..9785be2 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -44,17 +44,13 @@
 #include "./math_functions-inl.h"
 #include "./operator_common.h"
 #include "./rnn_impl.h"
+#if MXNET_USE_MKLDNN == 1
+#include "./nn/mkldnn/mkldnn_rnn_impl.h"
+#endif
 
 namespace mxnet {
 namespace op {
 
-namespace rnn_enum {
-  enum RNNOpInputs {kData, kParams, kState, kStateCell, kSequenceLength};
-  enum RNNOpOutputs {kOut, kStateOut, kStateCellOut};
-  enum RNNModeType {kRnnRelu, kRnnTanh, kLstm, kGru};
-  enum RNNOpResource {kTempSpace, kCuDNNDropoutDescSpace};
-}
-
 inline int GetRnnParamSize(int num_layer,
                            int input_size,
                            int state_size,
@@ -400,9 +396,29 @@ class RNNOp {
  public:
   RNNParam param_;
   Context ctx_;
+  #if MXNET_USE_MKLDNN == 1
+  std::vector<mkldnn::memory> concat_weight_memory;
+  std::vector<mkldnn::memory> concat_iter_memory;
+  std::vector<primitive> rnn_forward_prim;
+  std::vector<mkldnn::memory> x_memory;
+  std::vector<mkldnn::memory> hcx_memory;
+  std::vector<mkldnn::memory> wx_memory;
+  std::vector<mkldnn::memory> wh_memory;
+  std::vector<mkldnn::memory> bias_memory;
+  std::vector<mkldnn::memory> y_memory;
+  std::vector<mkldnn::memory> hcy_memory;
+  bool has_cache;
+  bool init_mem_;
+  size_t reserve_mem_size_;
+  Storage::Handle mem_space_;
+  #endif
   explicit RNNOp(RNNParam param, Context ctx) {
     this->param_ = param;
     this->ctx_ = ctx;
+    #if MXNET_USE_MKLDNN == 1
+    init_mem_ = false;
+    reserve_mem_size_ = 0;
+    #endif
     #if MXNET_USE_CUDNN_RNN
     init_cudnn_ = false;
     dtype_ = mshadow::DataType<DType>::kCudnnFlag;
@@ -410,8 +426,8 @@ class RNNOp {
     // No tests in place for fp16 RNNs, so leave TensorCore disabled for now.
     cudnn_tensor_core_ = false;
     // When fp16 RNN tests are introduced, we can enable TensorCore as follows:
-//    cudnn_tensor_core =
-//        mshadow::DataType<DType>::kFlag == mshadow::kFloat16 && GetEnvAllowTensorCore();
+    // cudnn_tensor_core =
+    //     mshadow::DataType<DType>::kFlag == mshadow::kFloat16 && GetEnvAllowTensorCore();
     // Defaults
     input_mode_ = CUDNN_LINEAR_INPUT;  // Don't support this yet
     // RNN Mode
@@ -492,7 +508,6 @@ class RNNOp {
       this->temp_init_space_ = false;
       this->reserve_cpu_space_size_ = 0;
       this->temp_cpu_space_size_ = 0;
-
       if (param_.projection_size.has_value()) {
         LOG(FATAL) <<
             "hidden layer projection is only supported for GPU with CuDNN later than 7.1.1";
@@ -505,6 +520,12 @@ class RNNOp {
   }
 
   ~RNNOp() {
+    #if MXNET_USE_MKLDNN == 1
+    if (init_mem_) {
+      Storage::Get()->Free(mem_space_);
+      init_mem_ = false;
+    }
+    #endif
     #if MXNET_USE_CUDNN_RNN
     CUDNN_CALL(cudnnDestroyTensorDescriptor(hx_desc_));
     CUDNN_CALL(cudnnDestroyTensorDescriptor(cx_desc_));
@@ -829,22 +850,23 @@ class RNNOp {
 #endif
 
     if (ctx_.dev_type == kCPU) {
-      // allocate temp space
-      const size_t work_cpu_space_size =
-        GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
-                            param_.state_size, direction, param_.mode);
-      if (temp_init_space_ && temp_cpu_space_size_ < work_cpu_space_size) {
-        Storage::Get()->Free(temp_cpu_space_);
-        temp_init_space_ = false;
-      }
-      if (!temp_init_space_) {
-        temp_cpu_space_ = Storage::Get()->Alloc
-          (work_cpu_space_size * sizeof(DType), Context::CPU());
-        temp_cpu_space_size_ = work_cpu_space_size;
-        temp_init_space_ = true;
-      }
-      DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.dptr);
       if (ctx.is_train) {
+        // allocate temp space
+        const size_t work_cpu_space_size =
+            GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+                              param_.state_size, direction, param_.mode);
+        if (temp_init_space_ && temp_cpu_space_size_ < work_cpu_space_size) {
+            Storage::Get()->Free(temp_cpu_space_);
+            temp_init_space_ = false;
+        }
+        if (!temp_init_space_) {
+          temp_cpu_space_ = Storage::Get()->Alloc
+              (work_cpu_space_size * sizeof(DType), Context::CPU());
+          temp_cpu_space_size_ = work_cpu_space_size;
+          temp_init_space_ = true;
+        }
+        DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.dptr);
+
         const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
                                                      param_.seq_length_, param_.batch_size_,
                                                      param_.state_size, param_.mode);
@@ -880,23 +902,78 @@ class RNNOp {
                                   param_.p,
                                   param_.mode);
       } else {
-        RNNForwardInference<DType>(work_cpu_space,
-                                   param_.state_outputs,
-                                   param_.num_layers,
-                                   direction,
-                                   param_.seq_length_,
-                                   param_.batch_size_,
-                                   param_.input_size_,
-                                   param_.state_size,
-                                   x.dptr_,
-                                   hx.dptr_,
-                                   cx_ptr,
-                                   w.dptr_,
-                                   b_ptr,
-                                   y.dptr_,
-                                   hy_ptr,
-                                   cy_ptr,
-                                   param_.mode);
+        #if MXNET_USE_MKLDNN == 1
+        if (dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1) && param_.mode != rnn_enum::kGru) {
+          // TODO(zixuanweeei): MKLDNN GRU has a precision issue. A stable
+          //   implementation will be added to MXNet once the issue is resolved.
+          int dtype = in_data[rnn_enum::kData].type_flag_;
+          MKLDNNRNNForwardInference<DType>(param_.state_outputs,
+                                           param_.num_layers,
+                                           direction,
+                                           param_.seq_length_,
+                                           param_.batch_size_,
+                                           param_.input_size_,
+                                           param_.state_size,
+                                           x.dptr_,
+                                           hx.dptr_,
+                                           cx_ptr,
+                                           w.dptr_,
+                                           b_ptr,
+                                           y.dptr_,
+                                           hy_ptr,
+                                           cy_ptr,
+                                           &concat_weight_memory,
+                                           &concat_iter_memory,
+                                           &x_memory,
+                                           &hcx_memory,
+                                           &wx_memory,
+                                           &wh_memory,
+                                           &bias_memory,
+                                           &y_memory,
+                                           &hcy_memory,
+                                           &rnn_forward_prim,
+                                           &has_cache,
+                                           dtype,
+                                           ctx.is_train,
+                                           param_.mode);
+        } else {
+        #endif
+          //  Until MKLDNN GRU fp32 inference is integrated, keep using the
+          //  code below so the existing functionality remains intact.
+          const size_t work_cpu_space_size =
+              GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+                                  param_.state_size, direction, param_.mode);
+          if (temp_init_space_ && temp_cpu_space_size_ < work_cpu_space_size) {
+            Storage::Get()->Free(temp_cpu_space_);
+            temp_init_space_ = false;
+          }
+          if (!temp_init_space_) {
+            temp_cpu_space_ = Storage::Get()->Alloc
+                (work_cpu_space_size * sizeof(DType), Context::CPU());
+            temp_cpu_space_size_ = work_cpu_space_size;
+            temp_init_space_ = true;
+          }
+          DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.dptr);
+          RNNForwardInference<DType>(work_cpu_space,
+                                     param_.state_outputs,
+                                     param_.num_layers,
+                                     direction,
+                                     param_.seq_length_,
+                                     param_.batch_size_,
+                                     param_.input_size_,
+                                     param_.state_size,
+                                     x.dptr_,
+                                     hx.dptr_,
+                                     cx_ptr,
+                                     w.dptr_,
+                                     b_ptr,
+                                     y.dptr_,
+                                     hy_ptr,
+                                     cy_ptr,
+                                     param_.mode);
+        #if MXNET_USE_MKLDNN == 1
+        }
+        #endif
       }
     }
   }
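
The branch above is gated by the MXNET_USE_MKLDNN_RNN environment variable
(default on, with GRU excluded); a minimal standalone sketch of the same gate,
using std::getenv in place of dmlc::GetEnv:

    #include <cstdio>
    #include <cstdlib>

    static bool UseMKLDNNRNN() {
      // Unset or non-zero selects the MKLDNN path; MXNET_USE_MKLDNN_RNN=0
      // falls back to the native RNNForwardInference kernels.
      const char* v = std::getenv("MXNET_USE_MKLDNN_RNN");
      return v == nullptr || std::atoi(v) != 0;
    }

    int main() {
      printf("MKLDNN RNN path: %s\n", UseMKLDNNRNN() ? "on" : "off");
      return 0;
    }
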
diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc
index 9b412a2..3218494 100644
--- a/src/operator/rnn.cc
+++ b/src/operator/rnn.cc
@@ -167,6 +167,21 @@ static bool RNNType(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+inline static bool RNNStorageType(const nnvm::NodeAttrs& attrs,
+                                  const int dev_mask,
+                                  DispatchMode* dispatch_mode,
+                                  std::vector<int> *in_attrs,
+                                  std::vector<int> *out_attrs) {
+  DispatchMode wanted_mode = DispatchMode::kFCompute;
+
+  #if MXNET_USE_MKLDNN == 1
+    wanted_mode = DispatchMode::kFComputeEx;
+  #endif
+
+  return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
+                             dispatch_mode, wanted_mode);
+}
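
A note on the mode choice (the operator registration itself is outside this
hunk, so the linkage is an assumption): kFComputeEx hands the kernel NDArrays,
which lets the MKLDNN path below call Reorder2Default on them, whereas
kFCompute would only see raw TBlobs.

    // Sketch of the effective behavior, assuming the Ex path is registered
    // to RNNStatefulComputeCPU (defined below):
    //   MXNET_USE_MKLDNN == 1  -> DispatchMode::kFComputeEx (NDArray in/out)
    //   otherwise              -> DispatchMode::kFCompute   (TBlob in/out)
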
+
 struct RNNGrad {
   const char *op_name;
   std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr &n,
@@ -191,6 +206,417 @@ struct RNNGrad {
   }
 };
 
+#if MXNET_USE_MKLDNN == 1
+static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr,
+                                  const OpContext& ctx,
+                                  const std::vector<NDArray>& inputs,
+                                  const std::vector<OpReqType>& req,
+                                  const std::vector<NDArray>& outputs) {
+  std::vector<TBlob> in_blobs;
+  std::vector<TBlob> out_blobs;
+  std::vector<NDArray> temp_ndarrays_i;
+  std::vector<NDArray> temp_ndarrays_o;
+  for (const NDArray& in : inputs) {
+    if (in.storage_type() == kDefaultStorage) {
+      temp_ndarrays_i.push_back(in.Reorder2Default());
+      in_blobs.emplace_back(temp_ndarrays_i.back().data());
+    } else {
+      in_blobs.emplace_back(in.data());
+    }
+  }
+
+  for (const NDArray& out : outputs) {
+    if (out.storage_type() == kDefaultStorage) {
+      temp_ndarrays_o.push_back(out.Reorder2Default());
+      out_blobs.emplace_back(temp_ndarrays_o.back().data());
+    } else {
+      out_blobs.emplace_back(out.data());
+    }
+  }
+  int dtype = in_blobs[rnn_enum::kData].type_flag_;
+  int itype = in_blobs[inputs.size()-1].type_flag_;
+  mkldnn::memory::data_type mkldnn_dtype = get_mkldnn_type(dtype);
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  auto cpu_engine = CpuEngine::Get()->get_engine();
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    MSHADOW_TYPE_SWITCH(itype, IType, {
+      RNNOp<cpu, DType, IType>& op = state_ptr.get_state<RNNOp<cpu, DType, IType>>();
+      const RNNParam& param = op.param_;
+      int ngates = 0, nstates = 0;
+      GetMKLDNNRNNAlgo(param.mode, &ngates, &nstates);
+      int D = param.bidirectional ? 2 : 1;
+      Tensor<cpu, 3, DType> x = in_blobs[rnn_enum::kData].get<cpu, 3, DType>(s);
+      int T = x.shape_[0];
+      int N = x.shape_[1];
+      int I = x.shape_[2];
+      int H = param.state_size;
+      int L = param.num_layers;
+
+      const size_t r_size = GetMKLDNNRNNCacheMemorySize(L, D, T, N, I, H, param.mode);
+      if (op.init_mem_ && op.reserve_mem_size_ < r_size) {
+        Storage::Get()->Free(op.mem_space_);
+        op.init_mem_ = false;
+      }
+      if (!op.init_mem_) {
+        op.mem_space_ = Storage::Get()->Alloc(
+            r_size * sizeof(DType),
+            Context::CPU());
+        op.reserve_mem_size_ = r_size;
+        op.init_mem_ = true;
+        op.has_cache = false;
+      }
+      if (op.has_cache && op.x_memory.size() == 0) {
+        op.has_cache = false;
+      }
+
+      DType* workptr = static_cast<DType*>(op.mem_space_.dptr);
+      mkldnn::memory::dims src_layer_tz_0 = {T, N, I};
+      mkldnn::memory::dims src_layer_tz = {T, N, D * H};
+      mkldnn::memory::dims dst_layer_tz = {T, N, D * H};
+      auto dst_layer_md = mkldnn::memory::desc(
+        { dst_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc);
+      if (op.x_memory.size() == 0) {
+        if (D == 1 && I == H) {
+          auto user_src_layer_md = mkldnn::memory::desc(
+              { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc);
+          auto user_src_layer_memory_n = mkldnn::memory({ user_src_layer_md, cpu_engine });
+          op.x_memory.push_back(user_src_layer_memory_n);
+
+          mkldnn::memory::dims weights_layer_tz = {L, 1, I, ngates, H};  //  ldigo
+          mkldnn::memory::dims weights_iter_tz = {L, 1, H, ngates, H};  //  ldigo
+          mkldnn::memory::dims bias_tz = {L, 1, ngates, H};
+          auto user_weight_layer_md = mkldnn::memory::desc(
+              { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo);
+          auto user_weight_iter_md = mkldnn::memory::desc(
+              { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo);
+          auto user_bias_md = mkldnn::memory::desc({ bias_tz },
+              mkldnn_dtype, mkldnn::memory::format::ldgo);
+          DType* weight_layer_n = workptr;  //  L * I * ngates * H
+          auto user_weight_layer_memory_n
+              = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n);
+          op.wx_memory.push_back(user_weight_layer_memory_n);
+
+          DType* weight_iter_n = weight_layer_n + L * I * ngates * H;  //  L * H * ngates * H
+          auto user_weight_iter_memory_n
+              = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n);
+          op.wh_memory.push_back(user_weight_iter_memory_n);
+
+          DType* bias_n = weight_iter_n + L * H * ngates * H;  //  L * ngates * H
+          auto user_bias_memory_n =
+              mkldnn::memory({ user_bias_md, cpu_engine }, bias_n);
+          op.bias_memory.push_back(user_bias_memory_n);
+
+          auto wx_md_n = mkldnn::memory::desc(
+              { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi);
+          DType* wx_n = bias_n + L * ngates * H;  //   L * ngates * I * H
+          auto wx_memory_n =
+              mkldnn::memory({ wx_md_n, cpu_engine }, wx_n);
+          DType* wh_n = wx_n + L * ngates * I * H;  //  L * ngates * H * H
+          auto wh_md_n = mkldnn::memory::desc(
+              { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi);
+          auto wh_memory_n =
+              mkldnn::memory({ wh_md_n, cpu_engine }, wh_n);
+
+          op.concat_weight_memory.push_back(wx_memory_n);
+          op.concat_weight_memory.push_back(wh_memory_n);
+          workptr = wh_n + L * ngates * H * H;
+
+          mkldnn::memory::dims src_iter_tz_n1 = {1, 1, nstates, N, H};  //  ldsnc
+          auto src_iter_md_n1 = mkldnn::memory::desc(
+              { src_iter_tz_n1 }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+          for (int l = 0; l < L; l++) {
+            DType* src_iter_n1 = workptr;  //  nstates * N * H
+            auto src_iter_memory_n1 =
+                mkldnn::memory({ src_iter_md_n1, cpu_engine }, src_iter_n1);
+            op.concat_iter_memory.push_back(src_iter_memory_n1);
+            workptr = src_iter_n1 + nstates * N * H;
+          }
+          mkldnn::memory::dims src_iter_tz_n = {L, 1, nstates, N, H};  //  ldsnc
+          auto src_iter_md_n = mkldnn::memory::desc(
+              { src_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+          DType* src_iter_n = workptr;  //  L * nstates * N * H
+          auto src_iter_memory_n =
+              mkldnn::memory({ src_iter_md_n, cpu_engine }, src_iter_n);
+          op.concat_iter_memory.push_back(src_iter_memory_n);
+          op.hcx_memory.push_back(src_iter_memory_n);
+          DType* dst_layer_n = src_iter_n + L * nstates * N * H;  //  T * N * D * H
+          auto dst_layer_memory_n
+              = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n);
+          op.y_memory.push_back(dst_layer_memory_n);
+
+          mkldnn::memory::dims dst_iter_tz_n = {L, 1, nstates, N, H};  //  ldsnc
+          auto dst_iter_md_n = mkldnn::memory::desc(
+              { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+          DType* dst_iter_n = dst_layer_n + T * N * D * H;  //  L * nstates * N * H
+          auto dst_iter_memory_n =
+              mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n);
+          op.hcy_memory.push_back(dst_iter_memory_n);
+          workptr = dst_iter_n + L * nstates * N * H;
+
+        } else {
+          auto user_src_layer_md_0 = mkldnn::memory::desc(
+              { src_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::tnc);
+          auto user_src_layer_memory_0 = mkldnn::memory({ user_src_layer_md_0, cpu_engine });
+          op.x_memory.push_back(user_src_layer_memory_0);
+
+          mkldnn::memory::dims weights_layer_tz_0 = {1, D, I, ngates, H};  //  ldigo
+          mkldnn::memory::dims weights_iter_tz_0 = {1, D, H, ngates, H};  //  ldigo
+          mkldnn::memory::dims bias_tz_0 = {1, D, ngates, H};
+          auto user_weight_layer_md_0 = mkldnn::memory::desc(
+              { weights_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldigo);
+          auto user_weight_iter_md_0 = mkldnn::memory::desc(
+              { weights_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldigo);
+          auto user_bias_md_0 = mkldnn::memory::desc({ bias_tz_0 },
+              mkldnn_dtype, mkldnn::memory::format::ldgo);
+
+          DType* weight_layer_0 = workptr;  //  D * I * ngates * H
+          auto user_weight_layer_memory_0
+              = mkldnn::memory({ user_weight_layer_md_0, cpu_engine }, weight_layer_0);
+          op.wx_memory.push_back(user_weight_layer_memory_0);
+
+          DType* weight_iter_0 = weight_layer_0 + D * I * ngates * H;  //  D * H * ngates * H
+          auto user_weight_iter_memory_0
+              = mkldnn::memory({ user_weight_iter_md_0, cpu_engine }, weight_iter_0);
+          op.wh_memory.push_back(user_weight_iter_memory_0);
+
+          DType* bias_0 = weight_iter_0 + D * H * ngates * H;  //  D * ngates * H
+          auto user_bias_memory_0 =
+              mkldnn::memory({ user_bias_md_0, cpu_engine }, bias_0);
+          op.bias_memory.push_back(user_bias_memory_0);
+          workptr = bias_0 + D * ngates * H;
+
+          auto wx_md_0 = mkldnn::memory::desc(
+              { weights_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldgoi);
+          auto wx_memory_0 =
+              mkldnn::memory({ wx_md_0, cpu_engine });
+          auto wh_md_0 = mkldnn::memory::desc(
+              { weights_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldgoi);
+          auto wh_memory_0 =
+              mkldnn::memory({ wh_md_0, cpu_engine });
+          if (D == 2) {
+            DType* wx_0 = workptr;  //  D * ngates * I * H
+            wx_memory_0.set_data_handle(wx_0);
+            DType* wh_0 = wx_0 + D * ngates * I * H;  //  D * ngates * H * H
+            wh_memory_0.set_data_handle(wh_0);
+            workptr = wh_0 + D * ngates * H * H;
+          }
+          op.concat_weight_memory.push_back(wx_memory_0);
+          op.concat_weight_memory.push_back(wh_memory_0);
+
+          mkldnn::memory::dims src_iter_undi_tz_0 = {1, 1, nstates, N, H};  //  ldsnc
+          auto src_iter_undi_md_0 = mkldnn::memory::desc(
+              { src_iter_undi_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+          DType* src_iter_undi_0 = workptr;  //  nstates * N * H
+          auto src_iter_undi_memory_0 =
+              mkldnn::memory({ src_iter_undi_md_0, cpu_engine }, src_iter_undi_0);
+          op.concat_iter_memory.push_back(src_iter_undi_memory_0);
+          workptr = src_iter_undi_0 + nstates * N * H;
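+          //  With D == 1 the single-direction state memory is used as hcx
+          //  directly; with D == 2 both directions are first concatenated
+          //  into a D-sized state memory.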
+          if (D == 1) {
+            op.hcx_memory.push_back(src_iter_undi_memory_0);
+          } else {
+            DType* src_iter_undi2_0 = workptr;  //  nstates * N * H
+            auto src_iter_undi2_memory_0 =
+                mkldnn::memory({ src_iter_undi_md_0, cpu_engine }, src_iter_undi2_0);
+            op.concat_iter_memory.push_back(src_iter_undi2_memory_0);
+
+            mkldnn::memory::dims src_iter_tz_0 = {1, D, nstates, N, H};  //  ldsnc
+            auto src_iter_md_0 = mkldnn::memory::desc(
+                { src_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+            DType* src_iter_0 = src_iter_undi2_0 + nstates * N * H;  //  D * nstates * N * H
+            auto src_iter_memory_0 =
+                mkldnn::memory({ src_iter_md_0, cpu_engine }, src_iter_0);
+            op.concat_iter_memory.push_back(src_iter_memory_0);
+            op.hcx_memory.push_back(src_iter_memory_0);
+            workptr = src_iter_0 + D * nstates * N * H;
+          }
+
+          DType* dst_layer_0 = workptr;  //  T * N * D * H
+          auto dst_layer_memory_0
+              = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_0);
+          op.y_memory.push_back(dst_layer_memory_0);
+
+          mkldnn::memory::dims dst_iter_tz_0 = {1, D, nstates, N, H};  //  ldsnc
+          auto dst_iter_md_0 = mkldnn::memory::desc(
+              { dst_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+          DType* dst_iter_0 = dst_layer_0 + T * N * D * H;  //  D * nstates * N * H
+          auto dst_iter_memory_0 =
+              mkldnn::memory({ dst_iter_md_0, cpu_engine }, dst_iter_0);
+          op.hcy_memory.push_back(dst_iter_memory_0);
+          workptr = dst_iter_0 + D * nstates * N * H;
+
+          //  next L - 1 layers
+          if (L > 1 && D == 1) {
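+            //  Layers 1 .. L - 1 all take an H-wide input, so they can share
+            //  one fused primitive; hence the leading L - 1 in the weight
+            //  and bias dims below.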
+            auto user_src_layer_md = mkldnn::memory::desc(
+                { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc);
+            auto user_src_layer_memory = mkldnn::memory({ user_src_layer_md, cpu_engine });
+            op.x_memory.push_back(user_src_layer_memory);
+
+            mkldnn::memory::dims weights_layer_tz = {L - 1, 1, H, ngates, H};  //  ldigo
+            mkldnn::memory::dims weights_iter_tz = {L - 1, 1, H, ngates, H};  //  ldigo
+            mkldnn::memory::dims bias_tz = {L - 1, 1, ngates, H};
+            auto user_weight_layer_md = mkldnn::memory::desc(
+                { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo);
+            auto user_weight_iter_md = mkldnn::memory::desc(
+                { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo);
+            auto user_bias_md = mkldnn::memory::desc({ bias_tz },
+                mkldnn_dtype, mkldnn::memory::format::ldgo);
+
+            DType* weight_layer_n = workptr;  //  (L - 1) * H * ngates * H
+            auto user_weight_layer_memory_n
+                = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n);
+            op.wx_memory.push_back(user_weight_layer_memory_n);
+
+            DType* weight_iter_n = weight_layer_n +
+                (L - 1) * H * ngates * H;  //  (L - 1) * H * ngates * H
+            auto user_weight_iter_memory_n
+                = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n);
+            op.wh_memory.push_back(user_weight_iter_memory_n);
+
+            DType* bias_n = weight_iter_n + (L - 1) * H * ngates * H;  //  (L - 1) * ngates * H
+            auto user_bias_memory_n =
+                mkldnn::memory({ user_bias_md, cpu_engine }, bias_n);
+            op.bias_memory.push_back(user_bias_memory_n);
+
+            auto wx_md_n = mkldnn::memory::desc(
+                { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi);
+            DType* wx_n = bias_n + (L - 1) * ngates * H;  //  (L - 1) * ngates * H * H
+            auto wx_memory_n =
+                mkldnn::memory({ wx_md_n, cpu_engine }, wx_n);
+            DType* wh_n = wx_n + (L - 1) * ngates * H * H;  //  (L - 1) * ngates * H * H
+            auto wh_md_n = mkldnn::memory::desc(
+                { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi);
+            auto wh_memory_n =
+                mkldnn::memory({ wh_md_n, cpu_engine }, wh_n);
+
+            op.concat_weight_memory.push_back(wx_memory_n);
+            op.concat_weight_memory.push_back(wh_memory_n);
+            workptr = wh_n + (L - 1) * ngates * H * H;
+
+            mkldnn::memory::dims src_iter_tz_n1 = {1, 1, nstates, N, H};  //  ldsnc
+            auto src_iter_md_n1 = mkldnn::memory::desc(
+                { src_iter_tz_n1 }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+            for (int l = 0; l < L - 1; l++) {
+              DType* src_iter_n1 = workptr;  //  nstates * N * H
+              auto src_iter_memory_n1 =
+                  mkldnn::memory({ src_iter_md_n1, cpu_engine }, src_iter_n1);
+              op.concat_iter_memory.push_back(src_iter_memory_n1);
+              workptr = src_iter_n1 + nstates * N * H;
+            }
+            mkldnn::memory::dims src_iter_tz_n = {L - 1, 1, nstates, N, H};  //  ldsnc
+            auto src_iter_md_n = mkldnn::memory::desc(
+                { src_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+            DType* src_iter_n = workptr;  //  (L - 1) * nstates * N * H
+            auto src_iter_memory_n =
+                mkldnn::memory({ src_iter_md_n, cpu_engine }, src_iter_n);
+            op.concat_iter_memory.push_back(src_iter_memory_n);
+            op.hcx_memory.push_back(src_iter_memory_n);
+
+            DType* dst_layer_n = src_iter_n + (L - 1) * nstates * N * H;  //  T * N * D * H
+            auto dst_layer_memory_n
+                = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n);
+            op.y_memory.push_back(dst_layer_memory_n);
+
+            mkldnn::memory::dims dst_iter_tz_n = {L - 1, 1, nstates, N, H};  //  ldsnc
+            auto dst_iter_md_n = mkldnn::memory::desc(
+                { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+            DType* dst_iter_n = dst_layer_n + T * N * D * H;  //  (L - 1) * nstates * N * H
+            auto dst_iter_memory_n =
+                mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n);
+            op.hcy_memory.push_back(dst_iter_memory_n);
+          }
+
+          if (L > 1 && D == 2) {
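+            //  After layer 0, a bidirectional layer takes a D * H wide input
+            //  (both directions of the previous layer), so each layer gets
+            //  its own weight and bias memories in the loop below.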
+            mkldnn::memory::dims weights_layer_tz = {1, D, H * D, ngates, H};  //  ldigo
+            mkldnn::memory::dims weights_iter_tz = {1, D, H, ngates, H};  //  ldigo
+            mkldnn::memory::dims bias_tz = {1, D, ngates, H};
+            auto user_weight_layer_md = mkldnn::memory::desc(
+                { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo);
+            auto user_weight_iter_md = mkldnn::memory::desc(
+                { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo);
+            auto user_bias_md = mkldnn::memory::desc({ bias_tz },
+                mkldnn_dtype, mkldnn::memory::format::ldgo);
+
+            auto user_src_layer_md = mkldnn::memory::desc(
+                { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc);
+            auto user_src_layer_memory = mkldnn::memory({ user_src_layer_md, cpu_engine });
+            op.x_memory.push_back(user_src_layer_memory);
+
+            auto wx_md_n = mkldnn::memory::desc(
+                { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi);
+            auto wh_md_n = mkldnn::memory::desc(
+                { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi);
+
+            for (int l = 0; l < L; l++) {
+              DType* weight_layer_n = workptr;  //  D * (H * D) * ngates * H
+              auto user_weight_layer_memory_n
+                  = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n);
+              op.wx_memory.push_back(user_weight_layer_memory_n);
+
+              DType* weight_iter_n = weight_layer_n +
+                  D * (H * D) * ngates * H;  //  D * H * ngates * H
+              auto user_weight_iter_memory_n
+                  = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n);
+              op.wh_memory.push_back(user_weight_iter_memory_n);
+
+              DType* bias_n = weight_iter_n + D * H * ngates * H;  //  D * ngates * H
+              auto user_bias_memory_n =
+                  mkldnn::memory({ user_bias_md, cpu_engine }, bias_n);
+              op.bias_memory.push_back(user_bias_memory_n);
+              workptr = bias_n + D * ngates * H;
+            }
+
+            DType* wx_n = workptr;  //  D * ngates * (D * H) * H
+            DType* wh_n = wx_n + D * ngates * (D * H) * H;  //  D * ngates * H * H
+            auto wx_memory_n =
+                mkldnn::memory({ wx_md_n, cpu_engine }, wx_n);
+            auto wh_memory_n =
+                mkldnn::memory({ wh_md_n, cpu_engine }, wh_n);
+            op.concat_weight_memory.push_back(wx_memory_n);
+            op.concat_weight_memory.push_back(wh_memory_n);
+
+            mkldnn::memory::dims src_iter_undi_tz = {1, 1, nstates, N, H};  //  ldsnc
+            auto src_iter_undi_md = mkldnn::memory::desc(
+                { src_iter_undi_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+            DType* src_iter_undi = wh_n + D * ngates * H * H;  //  nstates * N * H
+            auto src_iter_undi_memory =
+                mkldnn::memory({ src_iter_undi_md, cpu_engine }, src_iter_undi);
+            op.concat_iter_memory.push_back(src_iter_undi_memory);
+
+            DType* src_iter_undi2 = src_iter_undi + nstates * N * H;  //  nstates * N * H
+            auto src_iter_undi2_memory =
+                mkldnn::memory({ src_iter_undi_md, cpu_engine }, src_iter_undi2);
+            op.concat_iter_memory.push_back(src_iter_undi2_memory);
+
+            mkldnn::memory::dims src_iter_tz = {1, D, nstates, N, H};  //  ldsnc
+            auto src_iter_md = mkldnn::memory::desc(
+                { src_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+            DType* src_iter = src_iter_undi2 + nstates * N * H;  //  D * nstates * N * H
+            auto src_iter_memory =
+                mkldnn::memory({ src_iter_md, cpu_engine }, src_iter);
+            op.concat_iter_memory.push_back(src_iter_memory);
+            op.hcx_memory.push_back(src_iter_memory);
+
+            DType* dst_layer_n = src_iter + D * nstates * N * H;  //  T * N * D * H
+            auto dst_layer_memory_n
+                = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n);
+            op.y_memory.push_back(dst_layer_memory_n);
+
+            mkldnn::memory::dims dst_iter_tz_n = {1, D, nstates, N, H};  //  ldsnc
+            auto dst_iter_md_n = mkldnn::memory::desc(
+                { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc);
+            DType* dst_iter_n = dst_layer_n + T * N * D * H;  //  D * nstates * N * H
+            auto dst_iter_memory_n =
+                mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n);
+            op.hcy_memory.push_back(dst_iter_memory_n);
+          }
+        }
+      }
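+      //  All MKLDNN memories are in place; run the actual inference pass.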
+      op.Forward(ctx, in_blobs, req, out_blobs);
+    });
+  });
+}
+#endif
+
 NNVM_REGISTER_OP(RNN)
 .describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are
 implemented, with both multi-layer and bidirectional support.
@@ -269,8 +695,13 @@ The definition of GRU here is slightly different from paper but compatible with
 })
 .set_attr<mxnet::FInferShape>("FInferShape", RNNShape)
 .set_attr<nnvm::FInferType>("FInferType", RNNType)
+.set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
 .set_attr<FCreateOpState>("FCreateOpState", CreateRNNState)
 .set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulCompute<cpu>)
+#if MXNET_USE_MKLDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", RNNStatefulComputeCPU)
+#endif
 .set_attr<nnvm::FGradient>("FGradient", RNNGrad{"_backward_RNN"})
 .set_attr<FResourceRequestEx>("FResourceRequestEx",
   [](const NodeAttrs& attrs, const int dev_mask, const DispatchMode dispatch_mode) {
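
For readers following the patch: the memory plumbing above is built entirely
from MKL-DNN memory descriptors in the tnc/ldigo/ldgo/ldsnc formats. A
minimal, self-contained sketch of that pattern, using illustrative sizes and
the MKL-DNN 0.x API that this patch targets (not part of the commit itself):

    #include "mkldnn.hpp"

    int main() {
      //  Illustrative sizes: 1 layer, unidirectional, 2 LSTM states (h, c),
      //  batch 2, hidden 8.
      const int L = 1, D = 1, nstates = 2, N = 2, H = 8;
      auto cpu_engine = mkldnn::engine(mkldnn::engine::cpu, 0);
      mkldnn::memory::dims iter_tz = {L, D, nstates, N, H};  //  ldsnc
      auto iter_md = mkldnn::memory::desc(
          { iter_tz }, mkldnn::memory::data_type::f32,
          mkldnn::memory::format::ldsnc);
      //  When no data handle is given, MKL-DNN allocates the buffer itself,
      //  as with the user_src_layer memories in the patch.
      auto iter_mem = mkldnn::memory({ iter_md, cpu_engine });
      return 0;
    }
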
diff --git a/src/operator/rnn_impl.h b/src/operator/rnn_impl.h
index e1b4a2b..425ea4a 100644
--- a/src/operator/rnn_impl.h
+++ b/src/operator/rnn_impl.h
@@ -44,6 +44,13 @@
 namespace mxnet {
 namespace op {
 
+namespace rnn_enum {
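+  //  Indices for the RNN operator's inputs, outputs, modes, and resources,
+  //  e.g. inputs[rnn_enum::kData] selects the data blob.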
+  enum RNNOpInputs {kData, kParams, kState, kStateCell, kSequenceLength};
+  enum RNNOpOutputs {kOut, kStateOut, kStateCellOut};
+  enum RNNModeType {kRnnRelu, kRnnTanh, kLstm, kGru};
+  enum RNNOpResource {kTempSpace, kCuDNNDropoutDescSpace};
+}
+
 template<typename DType>
 inline DType sigmoid(DType x) {
   return 1.0f / (1.0f + exp(-x));