You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by pa...@apache.org on 2020/04/12 06:32:58 UTC

[incubator-mxnet] branch master updated: Support projection feature of LSTM (#17996)

This is an automated email from the ASF dual-hosted git repository.

patriczhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 0bff90d  Support projection feature of LSTM (#17996)
0bff90d is described below

commit 0bff90dcfe698c1a154b9536bae6a688f373cd70
Author: Zixuan Wei <zi...@intel.com>
AuthorDate: Sun Apr 12 14:31:54 2020 +0800

    Support projection feature of LSTM (#17996)
    
    * cpp unittest dependency
---
 3rdparty/mkldnn                               |   2 +-
 docs/static_site/src/pages/api/faq/env_var.md |   4 +
 src/operator/nn/mkldnn/mkldnn_rnn-inl.h       |  69 +++++----
 src/operator/nn/mkldnn/mkldnn_rnn.cc          | 196 +++++++++++++++++---------
 src/operator/rnn.cc                           |   6 +-
 tests/cpp/operator/mkldnn_test.cc             |   2 +-
 6 files changed, 179 insertions(+), 100 deletions(-)

diff --git a/3rdparty/mkldnn b/3rdparty/mkldnn
index 07579e6..1b05a28 160000
--- a/3rdparty/mkldnn
+++ b/3rdparty/mkldnn
@@ -1 +1 @@
-Subproject commit 07579e6c0c6839a390a6f3040e05a2b2c71e628a
+Subproject commit 1b05a28eb9666efef83b281e4cc1936db5e6cf6c
diff --git a/docs/static_site/src/pages/api/faq/env_var.md b/docs/static_site/src/pages/api/faq/env_var.md
index e0b70a6..7525521 100644
--- a/docs/static_site/src/pages/api/faq/env_var.md
+++ b/docs/static_site/src/pages/api/faq/env_var.md
@@ -362,6 +362,10 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`.
   - Values: 0(false) or 1(true) ```(default=0)```
   - If this variable is set to true, MXNet will perform fp16 accumulation when using cuBLAS and input datatype is set to float16. This could increase the speed of the computation, but might result in loss of accuracy. This makes this setting useful mainly for inference usecases.
 
+* MXNET_RNN_USE_WEIGHT_CACHE
+  - Values: 0(false) or 1(true) ```(default=0)```
+  - If this variable is set to true, MXNet will ignore changes to the version number of the NDArray that holds the input parameters of the RNN operator. In the Gluon API, a `_rnn_param_concat` operator concatenates the weights and biases of the RNN into a single parameter tensor, which changes the version number. Since the parameter values are invariant during the inference pass, the RNN operator can ignore these version changes and avoid the overhead of re-initializing the parameters.
+
 Settings for Minimum Memory Usage
 ---------------------------------
 - Make sure ```min(MXNET_EXEC_NUM_TEMP, MXNET_GPU_WORKER_NTHREADS) = 1```
diff --git a/src/operator/nn/mkldnn/mkldnn_rnn-inl.h b/src/operator/nn/mkldnn/mkldnn_rnn-inl.h
index 420633f..f47801a 100644
--- a/src/operator/nn/mkldnn/mkldnn_rnn-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_rnn-inl.h
@@ -44,17 +44,20 @@ struct MKLDNNRnnLayerParam {
   bool bidirectional;
   bool state_outputs;
   int num_layer;
-  int batch_size;
-  int input_size;
+  index_t batch_size;
+  index_t input_size;
   int state_size;
-  int seq_len;
+  int proj_size;
+  index_t seq_len;
 
   dims src_dims;           // Dimensions of source input in format_tag::tnc
   dims weight_layer_dims;  // Dimensions of layer weights in format_tag::ldigo
   dims weight_iter_dims;   // Dimensions of iter weights in format_tag::ldigo
+  dims weight_proj_dims;   // Dimensions of projection weights in format_tag::ldio
   dims bias_dims;          // Dimensions of bias in format_tag::ldgo
   dims dst_dims;           // Dimensions of output in format_tag::tnc
   dims state_dims;         // Dimensions of the state cell in format_tag::ldnc
+  dims cell_dims;          // Dimensions of LSTM cell state in format_tag::ldnc
 
   size_t workspace_size;  // used for the cached mkl-dnn memory in Forward inference
   size_t reserve_size;    // used for the reserved cached memory in Backward
@@ -63,12 +66,12 @@ struct MKLDNNRnnLayerParam {
   size_t native_single_b_size;  // bias size of a single cell from framework
   size_t single_state_size;     // state size of a single cell, hy, cy
 
-  MKLDNNRnnLayerParam(int num_layer, int batch_size, int seq_len,
-                      int input_size, int state_size,
+  MKLDNNRnnLayerParam(int num_layer, index_t batch_size, index_t seq_len,
+                      index_t input_size, int state_size, int proj_size,
                       int mode, bool bidirectional = true)
       : mode(mode), bidirectional(bidirectional), state_outputs(true),
         num_layer(num_layer), batch_size(batch_size), input_size(input_size),
-        state_size(state_size), seq_len(seq_len) { }
+        state_size(state_size), proj_size(proj_size), seq_len(seq_len) { }
 
   void SetDims();
 };
@@ -79,8 +82,8 @@ struct MKLDNNRnnFullParam {
   LayerParamVector layer_params;
 };
 
-MKLDNNRnnFullParam MKLDNNRnnFullParamParser(const RNNParam& rnn_param, const int seq_len,
-                                            const int batch_size, const int input_size);
+MKLDNNRnnFullParam MKLDNNRnnFullParamParser(const RNNParam& rnn_param, const index_t seq_len,
+                                            const index_t batch_size, const index_t input_size);
 
 /*
  * Use this to allocate memory from MKLDNNRnnOp temporary space.
@@ -100,7 +103,15 @@ class MKLDNNRnnMemMgr {
   std::vector<std::shared_ptr<const mkldnn::memory> > mem_holder;
 
  public:
-  void Init(dim_t size, const Context& ctx, int dtype = mshadow::kFloat32);
+  /*!
+   * \brief Initializer for RNN memory manager
+   * \param size size of the buffer, in bytes
+   * \param ctx Context of device environment
+   */
+  void Init(dim_t size, const Context& ctx);
+
+  // Return the size of the buffer in bytes
+  const size_t Size() { return mem_size; }
 
   void RegisterMem(std::shared_ptr<const mkldnn::memory> mem) {
     mem_holder.push_back(mem);
@@ -129,6 +140,7 @@ class RnnPrimitive {
     auto fwd_pd = reinterpret_cast<typename rnn_fwd::primitive_desc*>(rnn_fwd_prim.fwd_pd_.get());
     rnn_fwd_prim.weights_layer_desc_ = fwd_pd->weights_layer_desc();
     rnn_fwd_prim.weights_iter_desc_  = fwd_pd->weights_iter_desc();
+    rnn_fwd_prim.weights_proj_desc_  = fwd_pd->weights_projection_desc();
     rnn_fwd_prim.workspace_desc_ = fwd_pd->workspace_desc();
 
     rnn_fwd_prim.primitive_ = std::shared_ptr<mkldnn::primitive>(new rnn_fwd(*fwd_pd));
@@ -141,6 +153,7 @@ class RnnPrimitive {
     this->primitive_ = nullptr;
     this->weights_layer_desc_ = mkldnn::memory::desc();
     this->weights_iter_desc_ = mkldnn::memory::desc();
+    this->weights_proj_desc_ = mkldnn::memory::desc();
     this->workspace_desc_ = mkldnn::memory::desc();
   }
 
@@ -149,6 +162,7 @@ class RnnPrimitive {
     this->primitive_ = rnn_fwd_prim.primitive_;
     this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_;
     this->weights_iter_desc_ = rnn_fwd_prim.weights_iter_desc_;
+    this->weights_proj_desc_ = rnn_fwd_prim.weights_proj_desc_;
     this->workspace_desc_ = rnn_fwd_prim.workspace_desc_;
   }
 
@@ -158,6 +172,7 @@ class RnnPrimitive {
       this->primitive_ = rnn_fwd_prim.primitive_;
       this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_;
       this->weights_iter_desc_ = rnn_fwd_prim.weights_iter_desc_;
+      this->weights_proj_desc_ = rnn_fwd_prim.weights_proj_desc_;
       this->workspace_desc_ = rnn_fwd_prim.workspace_desc_;
     }
 
@@ -175,6 +190,10 @@ class RnnPrimitive {
     return weights_iter_desc_;
   }
 
+  const mkldnn::memory::desc& GetProjDesc() const {
+    return weights_proj_desc_;
+  }
+
   const mkldnn::memory::desc& GetWorkspaceDesc() const {
     return workspace_desc_;
   }
@@ -184,6 +203,7 @@ class RnnPrimitive {
   std::shared_ptr<mkldnn::primitive> primitive_;
   mkldnn::memory::desc weights_layer_desc_;
   mkldnn::memory::desc weights_iter_desc_;
+  mkldnn::memory::desc weights_proj_desc_;
   mkldnn::memory::desc workspace_desc_;
 };
 
@@ -195,27 +215,29 @@ RnnPrimitive GetRnnFwdPrim(const MKLDNNRnnLayerParam &layer_param, const bool is
  */
 class MKLDNNRnnForward {
  public:
-  MKLDNNRnnForward(const MKLDNNRnnLayerParam &layer_param, const bool is_train,
-                   const NDArray &data, const NDArray &params)
-      : initialized_(false), param_(layer_param),
+  MKLDNNRnnForward(const Context ctx,
+                   const MKLDNNRnnLayerParam &layer_param,
+                   const bool is_train,
+                   const NDArray &data,
+                   const NDArray &params)
+      : ctx_(ctx), initialized_(false), param_(layer_param),
         fwd_inf_(GetRnnFwdPrim(layer_param, false, data, params)) { }
 
   void SetNewDataMem(void* x, void* hx, void* cx,
                      void* y, void* hy, void* cy,
                      const int dtype = mshadow::kFloat32);
-  void SetWeightsMem(MKLDNNRnnMemMgr* mgr, void* w_ptr, void* b_ptr,
+  void SetWeightsMem(void* w_ptr, void* b_ptr,
                      const bool is_train = false,
                      const int dtype = mshadow::kFloat32);
   void ReorderWeights();
 
   const mkldnn::primitive& GetFwd() const { return fwd_inf_.GetPrim(); }
 
-  const size_t GetSize(int dtype) const {
-    size_t bytes = mshadow::mshadow_sizeof(dtype);
-    size_t size = 0;
-    size += fwd_inf_.GetLayerDesc().get_size();
-    size += fwd_inf_.GetIterDesc().get_size();
-    return size / bytes + 1;
+  const size_t GetSize() const {
+    const size_t size = fwd_inf_.GetLayerDesc().get_size()
+                        + fwd_inf_.GetIterDesc().get_size()
+                        + fwd_inf_.GetProjDesc().get_size();
+    return size;
   }
 
   const MKLDNNRnnLayerParam &GetParam() const { return param_; }
@@ -226,16 +248,20 @@ class MKLDNNRnnForward {
   void Reset() { initialized_ = false; }
 
  private:
+  Context ctx_;
   bool initialized_;
   MKLDNNRnnLayerParam param_;
   RnnPrimitive fwd_inf_;    // forward inference primitive
 
+  MKLDNNRnnMemMgr mem_mgr_;
   mkldnn::memory *weights_layer_ = nullptr;
   mkldnn::memory *weights_iter_ = nullptr;
+  mkldnn::memory *weights_proj_ = nullptr;
   mkldnn::memory *bias_ = nullptr;
 
   mkldnn::memory *weights_layer_r_ = nullptr;
   mkldnn::memory *weights_iter_r_ = nullptr;
+  mkldnn::memory *weights_proj_r_ = nullptr;
 
   /*
    * net_args must contain some keys as below:
@@ -447,11 +473,6 @@ inline bool SupportMKLDNNRnn(const int input_dtype) {
   return false;
 }
 
-inline bool SupportMKLDNNRnn(const RNNParam &param, const int input_dtype) {
-  if (param.projection_size.has_value()) return false;
-  return SupportMKLDNNRnn(input_dtype);
-}
-
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/nn/mkldnn/mkldnn_rnn.cc b/src/operator/nn/mkldnn/mkldnn_rnn.cc
index 5d3857e..c830080 100644
--- a/src/operator/nn/mkldnn/mkldnn_rnn.cc
+++ b/src/operator/nn/mkldnn/mkldnn_rnn.cc
@@ -53,18 +53,22 @@ void MKLDNNRnnLayerParam::SetDims() {
   const int nbias = mode == rnn_enum::kGru ? (ngates + 1) : ngates;
   const int num_direction = bidirectional ? 2 : 1;
 
+  const int iter_size = proj_size < 0 ? state_size : proj_size;
   src_dims.assign({seq_len, batch_size, input_size});
   weight_layer_dims.assign({num_layer, num_direction, input_size, ngates, state_size});
-  weight_iter_dims.assign({num_layer, num_direction, state_size, ngates, state_size});
+  weight_iter_dims.assign({num_layer, num_direction, iter_size, ngates, state_size});
+  weight_proj_dims.assign({num_layer, num_direction, state_size, iter_size});
   bias_dims.assign({num_layer, num_direction, nbias, state_size});
-  dst_dims.assign({seq_len, batch_size, state_size * num_direction});
-  state_dims.assign({num_layer, num_direction, batch_size, state_size});
+  dst_dims.assign({seq_len, batch_size, iter_size * num_direction});
+  state_dims.assign({num_layer, num_direction, batch_size, iter_size});
+  cell_dims.assign({num_layer, num_direction, batch_size, state_size});
 
   // unidirectional size of a single cell
-  single_w_size = (input_size + state_size) * ngates * state_size;
+  single_w_size = (input_size + iter_size) * ngates * state_size;
+  if (proj_size > 0) single_w_size += state_size * proj_size;
   single_b_size = nbias * state_size;
   native_single_b_size = ngates * state_size * 2;  // native RNN variants have double bias
-  single_state_size = batch_size * state_size;
+  single_state_size = batch_size * iter_size;
 
   // Get workspace size for cached weights memory
   // multiplication of tensor dimensions
@@ -75,14 +79,19 @@ void MKLDNNRnnLayerParam::SetDims() {
 
   workspace_size = tz_volume(weight_layer_dims) + tz_volume(weight_iter_dims) +
       tz_volume(bias_dims);
+  if (proj_size > 0) workspace_size += tz_volume(weight_proj_dims);
   reserve_size = 0;
 }
 
-MKLDNNRnnFullParam MKLDNNRnnFullParamParser(const RNNParam& rnn_param, const int seq_len,
-                                            const int batch_size, const int input_size) {
+MKLDNNRnnFullParam MKLDNNRnnFullParamParser(const RNNParam& rnn_param, const index_t seq_len,
+                                            const index_t batch_size, const index_t input_size) {
   MKLDNNRnnFullParam full_param;
   full_param.default_param = rnn_param;
-  size_t state_size = rnn_param.state_size;
+  const int state_size = rnn_param.state_size;
+  const int proj_size = rnn_param.projection_size.has_value() ?
+      rnn_param.projection_size.value() : -1;
+  const int iter_size = rnn_param.projection_size.has_value() ?
+      rnn_param.projection_size.value() : state_size;
   LayerParamVector &layer_params = full_param.layer_params;
 
   full_param.default_param.seq_length_ = seq_len;
@@ -90,20 +99,21 @@ MKLDNNRnnFullParam MKLDNNRnnFullParamParser(const RNNParam& rnn_param, const int
   full_param.default_param.input_size_ = input_size;
   // Set basic size by constructing MKLDNNRnnLayerParam instance(s)
   if (rnn_param.bidirectional) {  // unfused bidirectional multi-layer RNN
-    layer_params.emplace_back(1, batch_size, seq_len, input_size, state_size, rnn_param.mode);
+    layer_params.emplace_back(1, batch_size, seq_len, input_size, state_size, proj_size,
+        rnn_param.mode);
     for (size_t layer = 1; layer < rnn_param.num_layers; ++layer) {
-      layer_params.emplace_back(1, batch_size, seq_len, state_size * 2, state_size,
+      layer_params.emplace_back(1, batch_size, seq_len, iter_size * 2, state_size, proj_size,
           rnn_param.mode);
     }
-  } else if (input_size == static_cast<int>(state_size)) {  // fused multi-layer RNN
+  } else if (input_size == iter_size) {  // fused multi-layer
     layer_params.emplace_back(rnn_param.num_layers, batch_size, seq_len, input_size,
-        state_size, rnn_param.mode, false);
-  } else {  // unfused 1st layer, plus fused 2-end layers
-    layer_params.emplace_back(1, batch_size, seq_len, input_size, state_size, rnn_param.mode,
-        false);
+        state_size, proj_size, rnn_param.mode, false);
+  } else {  // unfused 1st layer, plus fused 2~end layers
+    layer_params.emplace_back(1, batch_size, seq_len, input_size, state_size, proj_size,
+        rnn_param.mode, false);
     if (rnn_param.num_layers > 1)
-      layer_params.emplace_back(rnn_param.num_layers - 1, batch_size, seq_len, state_size,
-          state_size, rnn_param.mode, false);
+      layer_params.emplace_back(rnn_param.num_layers - 1, batch_size, seq_len, iter_size,
+          state_size, proj_size, rnn_param.mode, false);
   }
 
   // Set dims, workspace size, and state_outputs flag
@@ -114,11 +124,13 @@ MKLDNNRnnFullParam MKLDNNRnnFullParamParser(const RNNParam& rnn_param, const int
   return full_param;
 }
 
-void MKLDNNRnnMemMgr::Init(dim_t size, const Context& ctx, int dtype) {
-  workspace_ = NDArray(TShape({size}), ctx, false, dtype);
+void MKLDNNRnnMemMgr::Init(dim_t size, const Context& ctx) {
+  workspace_ = NDArray(TShape({size}), ctx, false, mshadow::kUint8);
+  if (workspace_.data().dptr_ == nullptr)
+    LOG(FATAL) << "MKLDNN RNN operator memory allocation error.";
   curr_mem = static_cast<char *>(workspace_.data().dptr_);
-  mem_size = size * mshadow::mshadow_sizeof(dtype);
-  curr_size = size * mshadow::mshadow_sizeof(dtype);
+  mem_size = size;
+  curr_size = size;
 }
 
 mkldnn::memory *MKLDNNRnnMemMgr::Alloc(const mkldnn::memory::desc &md) {
@@ -162,16 +174,22 @@ RnnPrimitive GetRnnFwdPrim(
   auto bias_desc         = memory::desc(layer_param.bias_dims, data_type, tag::ldgo);
   auto dst_layer_desc    = memory::desc(layer_param.dst_dims, data_type, tag::tnc);
   auto src_state_desc    = memory::desc(layer_param.state_dims, data_type, tag::ldnc);
+  auto src_cell_desc     = memory::desc(layer_param.cell_dims, data_type, tag::ldnc);
+  auto weight_peep_desc  = memory::desc();
+  auto weight_proj_desc = layer_param.proj_size > 0 ? memory::desc(
+      layer_param.weight_proj_dims, weight_type, tag::any) : memory::desc();
   auto dst_state_desc = layer_param.state_outputs ? memory::desc(
       layer_param.state_dims, data_type, tag::ldnc) : memory::desc();
+  auto dst_cell_desc = layer_param.state_outputs ? memory::desc(
+      layer_param.cell_dims, data_type, tag::ldnc) : memory::desc();
 
   auto fwd = RnnPrimitive();
   switch (mode) {
     case rnn_enum::kLstm:
       fwd = RnnPrimitive::Create<lstm_forward>(prop, mkldnn_rnn_direction,
-          src_layer_desc, src_state_desc, src_state_desc, weight_layer_desc,
-          weight_iter_desc, bias_desc, dst_layer_desc, dst_state_desc,
-          dst_state_desc);
+          src_layer_desc, src_state_desc, src_cell_desc, weight_layer_desc,
+          weight_iter_desc, weight_peep_desc, weight_proj_desc, bias_desc,
+          dst_layer_desc, dst_state_desc, dst_cell_desc);
       break;
     case rnn_enum::kGru:
       fwd = RnnPrimitive::Create<lbr_gru_forward>(prop, mkldnn_rnn_direction,
@@ -287,7 +305,7 @@ static void ConcatWeights(const mkldnn::memory &dst,
   const memory::desc& dst_desc = dst.get_desc();
   // Use dst memory dims to initialize src memory dims, then set the concat
   // dim to 1. And Rnn weights are 5-dimension tensor.
-  memory::dims src_dims(dst_desc.data.dims, dst_desc.data.dims + 5);
+  memory::dims src_dims(dst_desc.data.dims, dst_desc.data.dims + dst_desc.data.ndims);
   src_dims.at(concat_dimension) = 1;
   std::vector<memory::desc> src_descs;
   std::unordered_map<int, memory> concat_args;
@@ -356,9 +374,9 @@ void MKLDNNRnnForward::SetNewDataMem(void* x, void* hx, void* cx,
   }
 
   if (param_.mode == rnn_enum::kLstm) {
-    RNN_FWD_SET(SRC_ITER_C, param_.state_dims, format_tag::ldnc, cx, dtype);
+    RNN_FWD_SET(SRC_ITER_C, param_.cell_dims, format_tag::ldnc, cx, dtype);
     if (param_.state_outputs) {
-      RNN_FWD_SET(DST_ITER_C, param_.state_dims, format_tag::ldnc, cy, dtype);
+      RNN_FWD_SET(DST_ITER_C, param_.cell_dims, format_tag::ldnc, cy, dtype);
     }
   }
 }
@@ -395,6 +413,7 @@ inline void MKLDNNMemoryReorder(const mkldnn::memory& src,
 void MKLDNNRnnForward::ReorderWeights() {
   MKLDNNMemoryReorder(*weights_layer_r_, *weights_layer_);
   MKLDNNMemoryReorder(*weights_iter_r_, *weights_iter_);
+  if (param_.proj_size > 0) MKLDNNMemoryReorder(*weights_proj_r_, *weights_proj_);
 }
 
 void AdjustGruGateOrder(char* weight,
@@ -469,66 +488,92 @@ inline void EmplaceNetArgs(mkldnn_args_map_t* net_args, const int arg_name,
  * memory with preferred format_tag. Finally, native bias is fused to MKLDNN
  * bias memory.
  */
-void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr, void *w_ptr, void *b_ptr,
+void MKLDNNRnnForward::SetWeightsMem(void *w_ptr, void *b_ptr,
                                      const bool is_train, const int dtype) {
   using format_tag = mkldnn::memory::format_tag;
   auto mkldnn_dtype = get_mkldnn_type(dtype);
+  const size_t dtype_bytes = mshadow::mshadow_sizeof(dtype);
+
+  const size_t buffer_bytes = this->GetSize()  // byte number of the buffer
+      + (param_.workspace_size + param_.reserve_size) * dtype_bytes
+      + kMKLDNNAlign * 7;     // Add margin for alignment of seven times allocation for the
+                              // dnnl memory handlers, i.e. weights_layer_, weights_iter_,
+                              // weights_proj_, bias_, weights_layer_r_, weights_iter_r_,
+                              // and weights_proj_r_.
+  if (mem_mgr_.Size() < buffer_bytes) mem_mgr_.Init(buffer_bytes, this->ctx_);
+
+  const bool use_proj = (param_.proj_size > 0);
   // Get the weights' memory for RNN forward primitive
   if (weights_layer_ == nullptr) {
-    weights_layer_ = mgr->Alloc(fwd_inf_.GetLayerDesc());
+    weights_layer_ = mem_mgr_.Alloc(fwd_inf_.GetLayerDesc());
   }
   if (weights_iter_ == nullptr) {
-    weights_iter_ = mgr->Alloc(fwd_inf_.GetIterDesc());
+    weights_iter_ = mem_mgr_.Alloc(fwd_inf_.GetIterDesc());
+  }
+  if (use_proj && weights_proj_ == nullptr) {
+    weights_proj_ = mem_mgr_.Alloc(fwd_inf_.GetProjDesc());
   }
   if (bias_ == nullptr) {
-    bias_ = mgr->Alloc(
+    bias_ = mem_mgr_.Alloc(
         {param_.bias_dims, mkldnn_dtype, format_tag::ldgo});
   }
 
   // Get the intermediate memory for weights concat & reorder
   if (weights_layer_r_ == nullptr) {
-    weights_layer_r_ = mgr->Alloc(
+    weights_layer_r_ = mem_mgr_.Alloc(
         {param_.weight_layer_dims, mkldnn_dtype, format_tag::ldgoi});
   }
   if (weights_iter_r_ == nullptr) {
-    weights_iter_r_ = mgr->Alloc(
+    weights_iter_r_ = mem_mgr_.Alloc(
         {param_.weight_iter_dims, mkldnn_dtype, format_tag::ldgoi});
   }
-
-  // Get the bytes of a real type
-  size_t dtype_bytes = mshadow::mshadow_sizeof(dtype);
+  if (use_proj && weights_proj_r_ == nullptr) {
+    weights_proj_r_ = mem_mgr_.Alloc(
+        {param_.weight_proj_dims, mkldnn_dtype, format_tag::ldoi});
+  }
 
   // convert void* to char* for arithmetic operations
+  const size_t iter_size = use_proj ? param_.proj_size : param_.state_size;
   char *weights_ptr = static_cast<char *>(w_ptr);
   size_t wx_bytes = GetRnnGatesNum(param_.mode) * param_.state_size *
         param_.input_size * dtype_bytes;  //* DIMS: ngates x state_size x input_size
   size_t wh_bytes = GetRnnGatesNum(param_.mode) * param_.state_size *
-        param_.state_size * dtype_bytes;  //* DIMS: ngates x state_size x state_size
+        iter_size * dtype_bytes;  //* DIMS: ngates x state_size x state_size, if not use projection.
+                                  // With projection, DIMS is ngates x state_size x projection_size
+  size_t wr_bytes = param_.state_size * iter_size * dtype_bytes;
   char *l2r_wx = weights_ptr;
   char *l2r_wh = l2r_wx + wx_bytes;       //* DIMS: ngates x state_size * state_size
+  char *l2r_wr = l2r_wh + wh_bytes;       //* DIMS: ngates x state_size * iter_size
 
   if (param_.num_layer == 1 && param_.bidirectional) {
     //* single bidirectinal layer, concat weights on direction axis
     char *r2l_wx = weights_ptr + param_.single_w_size * dtype_bytes;
-    char *r2l_wh = r2l_wx + wx_bytes;  //* DIMS: ngates x state_size * state_size
+    char *r2l_wh = r2l_wx + wx_bytes;  //* DIMS: ngates x state_size x state_size
+    char *r2l_wr = r2l_wh + wh_bytes;  //* DIMS: ngates x state_size x iter_size
     ConcatWeights(*weights_layer_r_, 1, {l2r_wx, r2l_wx}, format_tag::ldgoi);
     ConcatWeights(*weights_iter_r_, 1, {l2r_wh, r2l_wh}, format_tag::ldgoi);
+    if (use_proj) ConcatWeights(*weights_proj_r_, 1, {l2r_wr, r2l_wr}, format_tag::ldoi);
   } else if (param_.num_layer == 1 && !param_.bidirectional) {
     //* single uni-directional layer, no concatenate operator needed
     std::memcpy(weights_layer_r_->get_data_handle(), l2r_wx, wx_bytes);
     std::memcpy(weights_iter_r_->get_data_handle(), l2r_wh, wh_bytes);
+    if (use_proj) std::memcpy(weights_proj_r_->get_data_handle(), l2r_wr, wr_bytes);
   } else if (param_.num_layer > 1 && !param_.bidirectional) {
     //* concat fused multi-layer weights on layer axis
     std::vector<void *> l2r_wx_ptrs;
     std::vector<void *> l2r_wh_ptrs;
+    std::vector<void *> l2r_wr_ptrs;
     for (int lyr = 0; lyr < param_.num_layer; ++lyr) {
       char *lth_wx = l2r_wx + lyr * param_.single_w_size * dtype_bytes;
       char *lth_wh = lth_wx + wx_bytes;
+      char *lth_wr = lth_wh + wh_bytes;
       l2r_wx_ptrs.push_back(lth_wx);
       l2r_wh_ptrs.push_back(lth_wh);
+      if (use_proj) l2r_wr_ptrs.push_back(lth_wr);
     }
     ConcatWeights(*weights_layer_r_, 0, l2r_wx_ptrs, format_tag::ldgoi);
     ConcatWeights(*weights_iter_r_, 0, l2r_wh_ptrs, format_tag::ldgoi);
+    if (use_proj) ConcatWeights(*weights_proj_r_, 0, l2r_wr_ptrs, format_tag::ldoi);
   } else {
     LOG(FATAL) << "Undifined RNN fusion workflow for num_layer = " << param_.num_layer
                << ", and bidirectional is " << param_.bidirectional;
@@ -565,6 +610,7 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr, void *w_ptr, void *b_
   EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_WEIGHTS_LAYER, this->weights_layer_);
   EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_WEIGHTS_ITER,  this->weights_iter_);
   EmplaceNetArgs(&this->net_args_, MKLDNN_ARG_BIAS,          this->bias_);
+  if (use_proj) EmplaceNetArgs(&this->net_args_, DNNL_ARG_WEIGHTS_PROJECTION, this->weights_proj_);
 
   if (!is_train) {
     // Reorder after adjustment only when is_train == false. When is_train == true, i.e.
@@ -627,7 +673,7 @@ void MKLDNNRnnForwardTraining::FetchData(const MKLDNNRnnForward& fwd) {
   }
 }
 
-void MKLDNNRnnOp::Init(const OpContext &ctx,
+void MKLDNNRnnOp::Init(const OpContext &op_ctx,
                        const std::vector<NDArray> &inputs,
                        const std::vector<OpReqType> &req,
                        const std::vector<NDArray> &outputs) {
@@ -635,22 +681,14 @@ void MKLDNNRnnOp::Init(const OpContext &ctx,
 
   // In the `autograd.record()` context, RNNOp is required to run into
   // `forward_training` mode.
-  const bool is_training = (ctx.is_train || ctx.need_grad);
+  const bool is_training = (op_ctx.is_train || op_ctx.need_grad);
   const size_t num_fusion = full_param_.layer_params.size();
+  const Context& ctx = op_ctx.run_ctx.ctx;
   if (fwd_inf_vec_.size() < num_fusion) {
-    size_t buffer_size = 0;  // Element number, instead of bytes, in the buffer
     for (auto& layer_param : full_param_.layer_params) {
-      buffer_size += layer_param.workspace_size + layer_param.reserve_size;
-    }
-    buffer_size += outputs[rnn_enum::kOut].data().Size() * (num_fusion - 1);
-    buffer_size += kMKLDNNAlign * num_fusion * 5;  // Add margin for alignment
-
-    for (auto& layer_param : full_param_.layer_params) {
-      fwd_inf_vec_.emplace_back(layer_param,
-          ctx.is_train, inputs[rnn_enum::kData], inputs[rnn_enum::kParams]);
-      buffer_size += fwd_inf_vec_.back().GetSize(inputs[rnn_enum::kParams].dtype());
+      fwd_inf_vec_.emplace_back(ctx, layer_param, false, inputs[rnn_enum::kData],
+          inputs[rnn_enum::kParams]);
     }
-    mgr_.Init(buffer_size, ctx.run_ctx.ctx, inputs[rnn_enum::kParams].dtype());
   }
 
   if (is_training && fwd_trn_vec_.size() < num_fusion) {
@@ -678,7 +716,7 @@ void MKLDNNRnnOp::Init(const OpContext &ctx,
     size_t layer_bias_bytes = single_b_bytes * directions;  // Native MXNet has double bias
 
     if (!fwd_layer.IsInitialized() || is_training)
-      fwd_layer.SetWeightsMem(&(this->mgr_), weights_ptr, bias_ptr, is_training, dtype);
+      fwd_layer.SetWeightsMem(weights_ptr, bias_ptr, is_training, dtype);
     weights_ptr += layer_weights_bytes;
     bias_ptr += layer_bias_bytes;
   }
@@ -694,6 +732,10 @@ void MKLDNNRnnOp::Init(const OpContext &ctx,
       "Layer vector's size has a different value than the number of fusion.";
   if (dst_.size() < num_fusion - 1) {
     int data_dtype = outputs[rnn_enum::kOut].dtype();
+    const size_t data_dbytes = mshadow::mshadow_sizeof(data_dtype);
+    mgr_.Init(
+        (outputs[rnn_enum::kOut].data().Size() * data_dbytes + kMKLDNNAlign) * (num_fusion - 1),
+        op_ctx.run_ctx.ctx);
     // Here we need `fwd_inf_vec_.size() - 1` spaces for the intermediate results of the multiple
     // fused layers. And for the result of the last fused layer, `outputs[rnn_enum::kOut]` could
     // provide the space. Hence, `forward_inf_vec_.back()` is excluded when allocates the spaces
@@ -958,6 +1000,8 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
   // forward_training mode.
   const bool is_training = (ctx.is_train || ctx.need_grad);
   const RNNParam& default_param = full_param_.default_param;
+  if (is_training && default_param.projection_size.has_value())
+    LOG(FATAL) << "Backward/Training mode is not implemented!";
 
   // Initialize weights version
   if (!initialized_ && weights_version_ == 0) {
@@ -972,7 +1016,13 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
     weights_version_ = inputs[rnn_enum::kParams].version();
   }
 
-  if (!initialized_ || is_training || fwd_inf_vec_.size() == 0) {
+  if (dmlc::GetEnv("MXNET_RNN_USE_WEIGHT_CACHE", 0) && !initialized_) {
+    LOG(INFO) << "The current weight of RNN is assumed to be fixed and cached during "
+        "the whole inference pipeline. Please set MXNET_RNN_USE_WEIGHT_CACHE=0, if "
+        "the weight changed at runtime.";
+  }
+  if ((!dmlc::GetEnv("MXNET_RNN_USE_WEIGHT_CACHE", 0) && !initialized_) ||
+      is_training || fwd_inf_vec_.size() == 0) {
     Init(ctx, inputs, req, outputs);
   }
 
@@ -983,10 +1033,14 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
   const int seq_length = default_param.seq_length_;
   const int batch_size = default_param.batch_size_;
   const int state_size = default_param.state_size;
+  const int iter_size  = default_param.projection_size.has_value() ?
+      default_param.projection_size.value() : default_param.state_size;
   const int directions = default_param.bidirectional ? 2 : 1;
-  mkldnn::memory::desc dst_desc({seq_length, batch_size, directions * state_size},
+  mkldnn::memory::desc dst_desc({seq_length, batch_size, directions * iter_size},
       get_mkldnn_type(data_dtype), mkldnn::memory::format_tag::tnc);
-  mkldnn::memory::desc state_desc({num_layers, directions, batch_size, state_size},
+  mkldnn::memory::desc state_desc({num_layers, directions, batch_size, iter_size},
+      get_mkldnn_type(data_dtype), mkldnn::memory::format_tag::ldnc);
+  mkldnn::memory::desc cell_desc({num_layers, directions, batch_size, state_size},
       get_mkldnn_type(data_dtype), mkldnn::memory::format_tag::ldnc);
   auto out_mem = CreateMKLDNNMem(outputs[rnn_enum::kOut], dst_desc, req[rnn_enum::kOut]);
   mkldnn_output_t stateout_mem;
@@ -1010,7 +1064,7 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
     src_state_cell = static_cast<char *>(inputs[rnn_enum::kStateCell].data().dptr_);
     if (default_param.state_outputs && req[rnn_enum::kStateCellOut] != kNullOp) {
       statecellout_mem = CreateMKLDNNMem(
-          outputs[rnn_enum::kStateCellOut], state_desc, req[rnn_enum::kStateCellOut]);
+          outputs[rnn_enum::kStateCellOut], cell_desc, req[rnn_enum::kStateCellOut]);
       dst_state_cell = static_cast<char *>(statecellout_mem.second->get_data_handle());
     }
   }
@@ -1023,8 +1077,10 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
     }
   } else {
     CHECK_EQ(fwd_inf_vec_.size(), dst_.size() + 1) << "Output memory error.";
+    size_t state_bytes = (default_param.bidirectional + 1) * default_param.batch_size_ *
+        iter_size * mshadow::mshadow_sizeof(data_dtype);
     size_t cell_bytes = (default_param.bidirectional + 1) * default_param.batch_size_ *
-        default_param.state_size * mshadow::mshadow_sizeof(data_dtype);
+        state_size * mshadow::mshadow_sizeof(data_dtype);
 
     // Set input data memory for the first layer. This stores intermediate output
     // results in this->xxx, used as the source input of the next layer.
@@ -1035,9 +1091,9 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
     }
     // 1st_lyr -> dst_handle -> next_lyr -> dst_handle -> next_lyr -> ...
     for (size_t lyr = 1; lyr < fwd_inf_vec_.size() - 1; ++lyr) {
-      src_state += cell_bytes;
+      src_state += state_bytes;
+      if (dst_state) dst_state += state_bytes;
       if (src_state_cell) src_state_cell += cell_bytes;
-      if (dst_state) dst_state += cell_bytes;
       if (dst_state_cell) dst_state_cell += cell_bytes;
       fwd_inf_vec_.at(lyr).SetNewDataMem(this->dst_.at(lyr - 1)->get_data_handle(),
           src_state, src_state_cell,
@@ -1047,9 +1103,9 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
       }
     }
     // Set output data memory for the last layer.
-    src_state += cell_bytes;
+    src_state += state_bytes;
+    if (dst_state) dst_state += state_bytes;
     if (src_state_cell) src_state_cell += cell_bytes;
-    if (dst_state) dst_state += cell_bytes;
     if (dst_state_cell) dst_state_cell += cell_bytes;
     fwd_inf_vec_.back().SetNewDataMem(this->dst_.back()->get_data_handle(),
         src_state, src_state_cell, dst, dst_state, dst_state_cell, data_dtype);
@@ -1146,7 +1202,7 @@ void MKLDNNRnnOp::Backward(const OpContext& ctx,
     bwd_vec_.back().SetDataGradsMem(dx, dhx, dcx, dy, dhy, dcy, data_dtype);
     RegisterMKLDNNRnn(bwd_vec_.back());
   } else {
-    const size_t cell_bytes = (default_param.bidirectional + 1) * default_param.batch_size_ *
+    const size_t state_bytes = (default_param.bidirectional + 1) * default_param.batch_size_ *
         default_param.state_size * mshadow::mshadow_sizeof(data_dtype);
     if (diff_src == nullptr) {
       auto desc = mkldnn::memory::desc(full_param_.layer_params.back().src_dims,
@@ -1157,17 +1213,17 @@ void MKLDNNRnnOp::Backward(const OpContext& ctx,
     bwd_vec_.front().SetDataGradsMem(dx, dhx, dcx,
         diff_src->get_data_handle(), dhy, dcy, data_dtype);
     for (size_t lyr = 1; lyr < bwd_vec_.size() - 1; ++lyr) {
-      if (dhx) dhx += cell_bytes;
-      if (dcx) dcx += cell_bytes;
-      if (dhy) dhy += cell_bytes;
-      if (dcy) dcy += cell_bytes;
+      if (dhx) dhx += state_bytes;
+      if (dcx) dcx += state_bytes;
+      if (dhy) dhy += state_bytes;
+      if (dcy) dcy += state_bytes;
       bwd_vec_.at(lyr).SetDataGradsMem(diff_src->get_data_handle(), dhx, dcx,
           diff_src->get_data_handle(), dhy, dcy, data_dtype);
     }
-    if (dhx) dhx += cell_bytes;
-    if (dcx) dcx += cell_bytes;
-    if (dhy) dhy += cell_bytes;
-    if (dcy) dcy += cell_bytes;
+    if (dhx) dhx += state_bytes;
+    if (dcx) dcx += state_bytes;
+    if (dhy) dhy += state_bytes;
+    if (dcy) dcy += state_bytes;
     bwd_vec_.back().SetDataGradsMem(diff_src->get_data_handle(), dhx, dcx,
         dy, dhy, dcy, data_dtype);
 
diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc
index 31adfe3..ecd38a8 100644
--- a/src/operator/rnn.cc
+++ b/src/operator/rnn.cc
@@ -197,9 +197,7 @@ inline static bool RNNStorageType(const nnvm::NodeAttrs& attrs,
                                   DispatchMode* dispatch_mode,
                                   std::vector<int> *in_attrs,
                                   std::vector<int> *out_attrs) {
-  const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed);
-  const bool support_mkldnn_rnn =
-      !param.projection_size.has_value() && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1);
+  const bool support_mkldnn_rnn = dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1);
   return MKLDNNStorageType(attrs, dev_mask, support_mkldnn_rnn,
                            dispatch_mode, in_attrs, out_attrs);
 }
@@ -246,7 +244,7 @@ static OpStatePtr CreateRNNState(const nnvm::NodeAttrs &attrs,
   }
 
 #if MXNET_USE_MKLDNN == 1
-  if (ctx.dev_type == kCPU && SupportMKLDNNRnn(param, in_types[rnn_enum::kData])) {
+  if (ctx.dev_type == kCPU && SupportMKLDNNRnn(in_types[rnn_enum::kData])) {
     const mxnet::TShape& data_shape = in_shapes[rnn_enum::kData];
     state = OpStatePtr::Create<MKLDNNRnnOp>(param, data_shape[0],
         data_shape[1], data_shape[2]);
diff --git a/tests/cpp/operator/mkldnn_test.cc b/tests/cpp/operator/mkldnn_test.cc
index 973c398..73b9d93 100644
--- a/tests/cpp/operator/mkldnn_test.cc
+++ b/tests/cpp/operator/mkldnn_test.cc
@@ -100,7 +100,7 @@ static void VerifyDefMem(const mkldnn::memory &mem) {
 
 TEST(MKLDNN_UTIL_FUNC, MemFormat) {
   // Check whether the number of format is correct.
-  CHECK_EQ(mkldnn_format_tag_last, 154);
+  CHECK_EQ(mkldnn_format_tag_last, 168);
   CHECK_EQ(mkldnn_nchw, 5);
   CHECK_EQ(mkldnn_oihw, 5);
 }