You are viewing a plain-text version of this content. The canonical link to it was not preserved in this plain-text rendering.
Posted to commits@mxnet.apache.org by pa...@apache.org on 2020/04/21 11:21:13 UTC

[incubator-mxnet] branch v1.6.x updated: [v1.6.x] Quantized LSTMP operator (#18107)

This is an automated email from the ASF dual-hosted git repository.

patriczhao pushed a commit to branch v1.6.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.6.x by this push:
     new 0bdde3e  [v1.6.x] Quantized LSTMP operator (#18107)
0bdde3e is described below

commit 0bdde3e35f098a6e8e09720a6c3c3960ea6d8638
Author: Zixuan Wei <zi...@intel.com>
AuthorDate: Tue Apr 21 19:20:15 2020 +0800

    [v1.6.x] Quantized LSTMP operator (#18107)
    
    * Quantized LSTMP operator
    
    * Check out oneDNN dev-lstmp-int8 branch
    
    * Fix wrong size of projection weights
    
    * Add unit test for INT8 LSTMP
---
 3rdparty/mkldnn                                    |   2 +-
 src/operator/nn/mkldnn/mkldnn_rnn-inl.h            |   3 +-
 src/operator/nn/mkldnn/mkldnn_rnn.cc               |  25 ++++-
 .../quantization/mkldnn/mkldnn_quantized_rnn.cc    | 123 +++++++++++++++------
 src/operator/quantization/quantized_rnn.cc         |  11 +-
 tests/python/quantization/test_quantization.py     |  11 +-
 6 files changed, 129 insertions(+), 46 deletions(-)

diff --git a/3rdparty/mkldnn b/3rdparty/mkldnn
index f7c41dc..33aca9b 160000
--- a/3rdparty/mkldnn
+++ b/3rdparty/mkldnn
@@ -1 +1 @@
-Subproject commit f7c41dc7b5471ad8bf7905e459bbed27f9094caa
+Subproject commit 33aca9ba23a977282bf1d64c34cecc2abb4e019a
diff --git a/src/operator/nn/mkldnn/mkldnn_rnn-inl.h b/src/operator/nn/mkldnn/mkldnn_rnn-inl.h
index f039184..7e90e67 100644
--- a/src/operator/nn/mkldnn/mkldnn_rnn-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_rnn-inl.h
@@ -322,7 +322,8 @@ class MKLDNNRnnForward {
 
   mkldnn::memory *weights_layer_r_ = nullptr;
   mkldnn::memory *weights_iter_r_ = nullptr;
-  mkldnn::memory *weights_proj_r_ = nullptr;
+  mkldnn::memory *weights_proj_r_ = nullptr;    // format_tag::ldoi
+  mkldnn::memory *weights_proj_io_ = nullptr;   // format_tag::ldio, used in quantization
 
   /*
    * net_args must contain some keys as below:
diff --git a/src/operator/nn/mkldnn/mkldnn_rnn.cc b/src/operator/nn/mkldnn/mkldnn_rnn.cc
index 3775e9c..2435a6a 100644
--- a/src/operator/nn/mkldnn/mkldnn_rnn.cc
+++ b/src/operator/nn/mkldnn/mkldnn_rnn.cc
@@ -81,7 +81,11 @@ void MKLDNNRnnLayerParam::SetDims() {
 
   workspace_size = tz_volume(weight_layer_dims) + tz_volume(weight_iter_dims) +
       tz_volume(bias_dims);
-  if (proj_size > 0) workspace_size += tz_volume(weight_proj_dims);
+  if (proj_size > 0) {
+    workspace_size += tz_volume(weight_proj_dims);
+    //* NOTE: Quantized Op needs one more projection weights in ldio format.
+    if (quantized) workspace_size += tz_volume(weight_proj_dims);
+  }
   reserve_size = 0;
 }
 
@@ -134,10 +138,10 @@ MKLDNNRnnFullParam MKLDNNRnnFullParamParser(const NodeAttrs& attrs, const int se
 
   // Set dims, workspace size, state_outputs, quantized and enable_u8_output flag
   for (auto& layer_param : layer_params) {
-    layer_param.SetDims();
     layer_param.state_outputs = rnn_param.state_outputs;
     layer_param.quantized = full_param.mkldnn_param.quantized;
     layer_param.enable_u8_output = true;
+    layer_param.SetDims();
   }
   // Quantized RNN operator produces kFloat32 outputs.
   if (full_param.mkldnn_param.quantized) layer_params.back().enable_u8_output = false;
@@ -211,7 +215,6 @@ RnnPrimitive GetRnnFwdPrim(const MKLDNNRnnLayerParam &layer_param,
   auto dst_cell_desc = layer_param.state_outputs ? memory::desc(
       layer_param.cell_dims, iter_dtype, tag::ldnc) : memory::desc();
 
-
   auto fwd = RnnPrimitive();
   switch (mode) {
     case rnn_enum::kLstm:
@@ -433,6 +436,10 @@ void MKLDNNRnnForward::ReorderWeights() {
       };
     ReorderWithAttr(*weights_layer_r_, *weights_layer_);
     ReorderWithAttr(*weights_iter_r_, *weights_iter_);
+    if (param_.proj_size > 0) {
+      MKLDNNMemoryReorder(*weights_proj_r_, *weights_proj_io_);
+      ReorderWithAttr(*weights_proj_io_, *weights_proj_);
+    }
   } else {
     MKLDNNMemoryReorder(*weights_layer_r_, *weights_layer_);
     MKLDNNMemoryReorder(*weights_iter_r_, *weights_iter_);
@@ -518,12 +525,14 @@ void MKLDNNRnnForward::SetWeightsMem(void *w_ptr, void *b_ptr,
   const auto mkldnn_dtype = get_mkldnn_type(dtype);
   const size_t dtype_bytes = mshadow::mshadow_sizeof(dtype);
 
-  const size_t buffer_bytes = this->GetSize()  // byte number of the buffer
+  size_t buffer_bytes = this->GetSize()  // byte number of the buffer
       + (param_.workspace_size + param_.reserve_size) * dtype_bytes
       + kMKLDNNAlign * 7;     // Add margin for alignment of seven times allocation for the
                               // dnnl memory handlers, i.e. weights_layer_, weights_iter_,
                               // weights_proj_, bias_, weights_layer_r_, weights_iter_r_,
                               // and weights_proj_r_.
+  if (param_.quantized && param_.proj_size > 0)
+    buffer_bytes += kMKLDNNAlign;  // Quantized Op needs another one for weights_proj_io_.
   if (mem_mgr_.Size() < buffer_bytes) mem_mgr_.Init(buffer_bytes, this->ctx_);
 
   const bool use_proj = (param_.proj_size > 0);
@@ -555,6 +564,10 @@ void MKLDNNRnnForward::SetWeightsMem(void *w_ptr, void *b_ptr,
     weights_proj_r_ = mem_mgr_.Alloc(
         {param_.weight_proj_dims, mkldnn_dtype, format_tag::ldoi});
   }
+  if (param_.quantized && use_proj && weights_proj_io_ == nullptr) {
+    weights_proj_io_ = mem_mgr_.Alloc(
+        {param_.weight_proj_dims, mkldnn_dtype, format_tag::ldio});
+  }
 
   // convert void* to char* for arithmetic operations
   const size_t iter_size = use_proj ? param_.proj_size : param_.state_size;
@@ -1034,9 +1047,9 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
   }
 
   if (dmlc::GetEnv("MXNET_RNN_USE_WEIGHT_CACHE", 0) && !initialized_) {
-    LOG(INFO) << "The current weight of RNN is assumed to be fixed and cached during "
+    common::LogOnce("The current weight of RNN is assumed to be fixed and cached during "
         "the whole inference pipeline. Please set MXNET_RNN_USE_WEIGHT_CACHE=0, if "
-        "the weight changed at runtime.";
+        "the weight changed at runtime.");
   }
   // Check if weights NDArray was changed. If so, reset initialized_
   if (!is_training && fwd_inf_vec_.size() > 0
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_rnn.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_rnn.cc
index 9ae9bb2..ebf4cce 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_rnn.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_rnn.cc
@@ -32,31 +32,46 @@
 namespace mxnet {
 namespace op {
 
-std::vector<float> GetMKLDNNRnnWeightsQParams(const MKLDNNRnnFullParam& full_param,
-                                              float* w_ptr) {
+/*!
+ * \brief Quantization parameters of rnn weights' scales in an order of weights_qparams,
+ *        weights_projection_qparams.
+ */
+typedef std::tuple<std::vector<float>, std::vector<float> > rnn_weights_qparams_t;
+
+rnn_weights_qparams_t GetMKLDNNRnnWeightsQParams(
+    const MKLDNNRnnFullParam& full_param, float* w_ptr) {
   const int nthreads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
   const RNNParam& default_param = full_param.default_param;
   const LayerParamVector& layer_params = full_param.layer_params;
+  const bool use_proj = default_param.projection_size.has_value();
+  const size_t iter_size = use_proj ?
+      default_param.projection_size.value() : default_param.state_size;
 
   const MKLDNNRnnLayerParam& layer_param0 = layer_params.at(0);
   const size_t w_size0 = layer_param0.single_w_size;
   const size_t wx_size0 = 4 * layer_param0.state_size * layer_param0.input_size;
-  const size_t wh_size0 = 4 * layer_param0.state_size * layer_param0.state_size;
+  const size_t wh_size0 = 4 * layer_param0.state_size * iter_size;
+  const size_t wr_size = default_param.state_size * iter_size;
 
   int directions = 1;
   float* wx = w_ptr;
   float* wh = wx + wx_size0;
+  float* wr = wh + wh_size0;
   float* fake_wx = wx;
   float* fake_wh = wh;
+  float* fake_wr = wr;
 
   std::vector<float> wx_goi_max;
   std::vector<float> wh_goi_max;
+  std::vector<float> wr_oi_max;
   if (default_param.bidirectional) {
     directions = 2;
     wx_goi_max.resize(wx_size0);
     wh_goi_max.resize(wh_size0);
+    wr_oi_max.resize(wr_size);
     fake_wx = wx_goi_max.data();
     fake_wh = wh_goi_max.data();
+    fake_wr = wr_oi_max.data();
     #pragma omp parallel for num_threads(nthreads)
     for (index_t i = 0; i < static_cast<index_t>(wx_size0); ++i) {
       fake_wx[i] = MaxAbs(wx[i], wx[i + w_size0]);
@@ -65,8 +80,15 @@ std::vector<float> GetMKLDNNRnnWeightsQParams(const MKLDNNRnnFullParam& full_par
     for (index_t i = 0; i < static_cast<index_t>(wh_size0); ++i) {
       fake_wh[i] = MaxAbs(wh[i], wh[i + w_size0]);
     }
+    if (use_proj) {
+      #pragma omp parallel for num_threads(nthreads)
+      for (index_t i = 0; i < static_cast<index_t>(wr_size); ++i) {
+        fake_wr[i] = MaxAbs(wr[i], wr[i + w_size0]);
+      }
+    }
   }
   std::vector<float> w_max(4 * layer_param0.state_size, 0.0);
+  std::vector<float> proj_max(iter_size, 0.0);
   const index_t input_size = layer_param0.input_size;          // input
   const index_t state_size = layer_param0.state_size;          // state
   const index_t gates_nblks = 4 * layer_param0.state_size;     // gates * state
@@ -75,44 +97,68 @@ std::vector<float> GetMKLDNNRnnWeightsQParams(const MKLDNNRnnFullParam& full_par
     for (index_t i = 0; i < input_size; ++i) {
       tmp_max = MaxAbs(fake_wx[go * input_size + i], tmp_max);
     }
-    for (index_t i = 0; i < state_size; ++i) {
-      tmp_max = MaxAbs(fake_wh[go * state_size + i], tmp_max);
+    for (index_t i = 0; i < static_cast<index_t>(iter_size); ++i) {
+      tmp_max = MaxAbs(fake_wh[go * iter_size + i], tmp_max);
     }
     w_max[go] = tmp_max;
   }
+  if (use_proj) {
+    for (index_t i = 0; i < static_cast<index_t>(iter_size); ++i) {
+      for (index_t s = 0; s < state_size; ++s) {
+        proj_max[i] = MaxAbs(fake_wr[iter_size * state_size + s], proj_max[i]);
+      }
+    }
+  }
   wx += layer_param0.single_w_size * directions;
   wh += layer_param0.single_w_size * directions;
+  wr += layer_param0.single_w_size * directions;
 
-  std::vector<float> goi_max(wh_size0, 0.0);
+  const size_t wx_size1 = 4 * default_param.state_size * default_param.state_size;
+  const size_t wh_size1 = wh_size0;
+  std::vector<float> go_max(gates_nblks, 0.0);
   for (size_t lyr = 1; lyr < layer_params.size(); ++lyr) {
     const MKLDNNRnnLayerParam& layer_param = layer_params.at(lyr);
     const int weight_nblks = layer_param.num_layer * directions;
     for (int blk = 0; blk < weight_nblks; ++blk) {
-      #pragma omp parallel for num_threads(nthreads)
-      for (index_t i = 0; i < static_cast<index_t>(wh_size0); ++i) {
-        goi_max[i] = MaxAbs(wx[i], wh[i]);
+      for (index_t go = 0; go < gates_nblks; ++go) {
+        float tmp = Abs(wx[0]);
+        for (index_t i = 1; i < layer_param.input_size; ++i) {
+          tmp = MaxAbs(wx[go * layer_param.input_size + i], tmp);
+        }
+        go_max[go] = Max(tmp, go_max[go]);
+      }
+      for (index_t go = 0; go < gates_nblks; ++go) {
+        float tmp = Abs(wh[0]);
+        for (index_t i = 1; i < static_cast<index_t>(iter_size); ++i) {
+          tmp = MaxAbs(wh[go * iter_size + i], tmp);
+        }
+        go_max[go] = Max(tmp, go_max[go]);
       }
+      #pragma omp parallel for num_threads(nthreads)
       for (index_t go = 0; go < gates_nblks; ++go) {
-        float tmp = w_max[go];
-        //* NOTES: min/max reductions were supported since OpenMP 3.1, which was released in
-        //  Jul 2011 (hence the version number).
-        #if _OPENMP >= 201107
-        #pragma omp parallel for reduction(max : tmp) num_threads(nthreads)
-        #endif
-        for (index_t i = 0; i < state_size; ++i) {
-          tmp = Max(goi_max[go * state_size + i], tmp);
+        w_max[go] = Max(go_max[go], w_max[go]);
+      }
+      if (use_proj) {
+        for (index_t i = 0; i < static_cast<index_t>(iter_size); ++i) {
+          for (index_t s = 0; s < state_size; ++s) {
+            proj_max[i] = MaxAbs(fake_wr[iter_size * state_size + s], proj_max[i]);
+          }
         }
-        w_max[go] = tmp;
       }
+      wx += layer_param.single_w_size;
+      wh = wx + wx_size1;
+      wr = wh + wh_size1;
     }
-    wx += layer_param.single_w_size * directions;
-    wh = wx + wh_size0;
   }
   #pragma omp parallel for num_threads(nthreads)
   for (index_t i = 0; i < static_cast<index_t>(w_max.size()); ++i) {
     w_max[i] = mshadow::red::limits::MaxValue<int8_t>() / w_max[i];
   }
-  return w_max;
+  #pragma omp parallel for num_threads(nthreads)
+  for (index_t i = 0; i < static_cast<index_t>(proj_max.size()); ++i) {
+    proj_max[i] = mshadow::red::limits::MaxValue<int8_t>() / proj_max[i];
+  }
+  return std::make_tuple(w_max, proj_max);
 }
 
 void MKLDNNQuantizedRnnOp::Forward(const OpContext &op_ctx,
@@ -133,9 +179,9 @@ void MKLDNNQuantizedRnnOp::Forward(const OpContext &op_ctx,
   float *bias_ptr = weights_ptr + weights_size;
 
   if (dmlc::GetEnv("MXNET_RNN_USE_WEIGHT_CACHE", 0) && !initialized_) {
-    LOG(INFO) << "The current weight of RNN is assumed to be fixed and cached during "
+    common::LogOnce("The current weight of RNN is assumed to be fixed and cached during "
         "the whole inference pipeline. Please set MXNET_RNN_USE_WEIGHT_CACHE=0, if "
-        "the weight changed at runtime.";
+        "the weight changed at runtime.");
   }
   const bool need_reset_weight = (!dmlc::GetEnv("MXNET_RNN_USE_WEIGHT_CACHE", 0) &&
       weights_ver_ != inputs[rnn_enum::kParams].version()) ? true : false;
@@ -154,9 +200,14 @@ void MKLDNNQuantizedRnnOp::Forward(const OpContext &op_ctx,
     cached_data_scale_ = data_scale;
     cached_data_shift_ = data_shift;
     rnn_attr_->set_rnn_data_qparams(data_scale, data_shift);
-    if (need_reset_weight || rnn_layers_.empty())
-      rnn_attr_->set_rnn_weights_qparams(0 + (1 << 3) + (1 << 4),
-          GetMKLDNNRnnWeightsQParams(full_param_, weights_ptr));
+    if (need_reset_weight || rnn_layers_.empty()) {
+      rnn_weights_qparams_t weights_qparams =
+          GetMKLDNNRnnWeightsQParams(full_param_, weights_ptr);
+      rnn_attr_->set_rnn_weights_qparams(0 + (1 << 3) + (1 << 4), std::get<0>(weights_qparams));
+      if (default_param.projection_size.has_value()) {
+        rnn_attr_->set_rnn_weights_projection_qparams(0 + (1 << 3), std::get<1>(weights_qparams));
+      }
+    }
   }
 
   // Get data type
@@ -167,10 +218,14 @@ void MKLDNNQuantizedRnnOp::Forward(const OpContext &op_ctx,
   const int seq_length = default_param.seq_length_;
   const int batch_size = default_param.batch_size_;
   const int state_size = default_param.state_size;
+  const int iter_size  = default_param.projection_size.has_value() ?
+      default_param.projection_size.value() : state_size;
   const int directions = default_param.bidirectional ? 2 : 1;
-  mkldnn::memory::desc dst_desc({seq_length, batch_size, directions * state_size},
+  mkldnn::memory::desc dst_desc({seq_length, batch_size, directions * iter_size},
       get_mkldnn_type(data_dtype), mkldnn::memory::format_tag::tnc);
-  mkldnn::memory::desc state_desc({num_layers, directions, batch_size, state_size},
+  mkldnn::memory::desc state_desc({num_layers, directions, batch_size, iter_size},
+      get_mkldnn_type(data_dtype), mkldnn::memory::format_tag::ldnc);
+  mkldnn::memory::desc cell_desc({num_layers, directions, batch_size, state_size},
       get_mkldnn_type(data_dtype), mkldnn::memory::format_tag::ldnc);
   auto out_mem = CreateMKLDNNMem(outputs[rnn_enum::kOut], dst_desc, req[rnn_enum::kOut]);
   mkldnn_output_t stateout_mem;
@@ -183,8 +238,10 @@ void MKLDNNQuantizedRnnOp::Forward(const OpContext &op_ctx,
   char *dst_state = nullptr;          // Output state
   char *src_state_cell = nullptr;     // Used in LSTM for cell state
   char *dst_state_cell = nullptr;     // Used in LSTM for cell state
-  const size_t cell_bytes = (default_param.bidirectional + 1) * default_param.batch_size_ *
-      default_param.state_size * mshadow::mshadow_sizeof(data_dtype);
+  const size_t state_bytes = (default_param.bidirectional + 1) * batch_size *
+      iter_size * mshadow::mshadow_sizeof(data_dtype);
+  const size_t cell_bytes = (default_param.bidirectional + 1) * batch_size *
+      state_size * mshadow::mshadow_sizeof(data_dtype);
 
   const LayerParamVector& layer_params = full_param_.layer_params;
   for (size_t lyr = 0; lyr < layer_params.size(); ++lyr) {
@@ -218,7 +275,7 @@ void MKLDNNQuantizedRnnOp::Forward(const OpContext &op_ctx,
       src_state_cell = static_cast<char *>(inputs[rnn_enum::kStateCell].data().dptr_);
       if (default_param.state_outputs && req[rnn_enum::kStateCellOut] != kNullOp) {
         statecellout_mem = CreateMKLDNNMem(
-            outputs[rnn_enum::kStateCellOut], state_desc, req[rnn_enum::kStateCellOut]);
+            outputs[rnn_enum::kStateCellOut], cell_desc, req[rnn_enum::kStateCellOut]);
         dst_state_cell = static_cast<char *>(statecellout_mem.second->get_data_handle());
       }
     }
@@ -229,9 +286,9 @@ void MKLDNNQuantizedRnnOp::Forward(const OpContext &op_ctx,
     MKLDNNStream::Get()->RegisterPrimArgs(rnn_layer.GetFwd(), rnn_layer.GetArgsMap());
 
     if (lyr < default_param.num_layers - 1U) {
-      src_state += cell_bytes;
+      src_state += state_bytes;
+      if (dst_state) dst_state += state_bytes;
       if (src_state_cell) src_state_cell += cell_bytes;
-      if (dst_state) dst_state += cell_bytes;
       if (dst_state_cell) dst_state_cell += cell_bytes;
     }
   }
diff --git a/src/operator/quantization/quantized_rnn.cc b/src/operator/quantization/quantized_rnn.cc
index b2864ff..31944c9 100644
--- a/src/operator/quantization/quantized_rnn.cc
+++ b/src/operator/quantization/quantized_rnn.cc
@@ -92,7 +92,9 @@ bool QuantizedRnnShape(const nnvm::NodeAttrs& attrs,
   const dim_t directions = param.bidirectional ? 2 : 1;
   const dim_t total_lyrs = directions * param.num_layers;
   const dim_t state_size = param.state_size;
-  SHAPE_ASSIGN_CHECK(*in_shape, quantized_rnn::kState, Shape3(total_lyrs, batch_size, state_size));
+  const dim_t iter_size  = param.projection_size.has_value() ?
+      param.projection_size.value() : state_size;
+  SHAPE_ASSIGN_CHECK(*in_shape, quantized_rnn::kState, Shape3(total_lyrs, batch_size, iter_size));
   if (param.mode == rnn_enum::kLstm)
     SHAPE_ASSIGN_CHECK(*in_shape, quantized_rnn::kStateCell,
         Shape3(total_lyrs, batch_size, state_size));
@@ -109,9 +111,9 @@ bool QuantizedRnnShape(const nnvm::NodeAttrs& attrs,
     SHAPE_ASSIGN_CHECK(*in_shape, i, Shape1(1));
 
   out_shape->clear();
-  out_shape->push_back({dshape[0], batch_size, directions * state_size});  // output dim: [T, N, C]
+  out_shape->push_back({dshape[0], batch_size, directions * iter_size});  // output dim: [T, N, C]
   if (param.state_outputs) {
-    out_shape->push_back({total_lyrs, batch_size, state_size});    // state dim: [L*D, N, C]
+    out_shape->push_back({total_lyrs, batch_size, iter_size});    // state dim: [L*D, N, C]
     if (param.mode == rnn_enum::kLstm)
       out_shape->push_back({total_lyrs, batch_size, state_size});  // cell dim: [L*D, N, C]
   }
@@ -216,6 +218,9 @@ OpStatePtr CreateQuantizedRnnState(const nnvm::NodeAttrs& attrs,
     const mxnet::TShape& data_shape = in_shapes[quantized_rnn::kData];
     state = OpStatePtr::Create<MKLDNNQuantizedRnnOp>(attrs, data_shape[0],
         data_shape[1], data_shape[2]);
+  } else {
+    LOG(FATAL) << "MKLDNN quantized rnn operator only supports inputs in U8 type and weight"
+        " in FP32 type.";
   }
 #else
   LOG(FATAL) << "Quantized RNN operator relies on MKL-DNN library."
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index 48968c2..bbe3008 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -519,7 +519,7 @@ def test_quantized_fc():
 
 @with_seed()
 def test_quantized_rnn():
-    def check_quantized_rnn(num_layers, bidirectional, seq_len, batch_size, input_dim, state_dim):
+    def check_quantized_rnn(num_layers, bidirectional, seq_len, batch_size, input_dim, state_dim, projection_dim=None):
         if is_test_for_gpu():
             print('skipped testing test_quantized_rnn for gpu since it is not supported yet')
             return
@@ -534,6 +534,7 @@ def test_quantized_rnn():
                               bidirectional=bidirectional,
                               state_outputs=True,
                               state_size=state_dim,
+                              projection_size=projection_dim,
                               mode='lstm',
                               name='rnn')
         arg_shapes, _, _ = rnn_fp32.infer_shape(data=data_shape)
@@ -559,6 +560,7 @@ def test_quantized_rnn():
                                                 bidirectional=bidirectional,
                                                 state_outputs=True,
                                                 state_size=state_dim,
+                                                projection_size=projection_dim,
                                                 mode='lstm',
                                                 name='qrnn')
         qarg_names = rnn_int8.list_arguments()
@@ -575,10 +577,15 @@ def test_quantized_rnn():
         qoutput = rnn_int8_exe.forward()[0]
 
         mse = np.mean((output.asnumpy() - qoutput.asnumpy())**2)
-        assert mse < 0.001
+        if projection_dim:
+            assert mse < 2
+        else:
+            assert mse < 0.001
 
     check_quantized_rnn(1, False, 5, 2, 16, 16)
     check_quantized_rnn(1, True, 5, 2, 16, 16)
+    check_quantized_rnn(1, False, 5, 2, 16, 16, 8)
+    check_quantized_rnn(1, True, 5, 2, 16, 16, 8)
 
 
 @with_seed()