You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by pa...@apache.org on 2020/04/21 11:21:13 UTC
[incubator-mxnet] branch v1.6.x updated: [v1.6.x] Quantized LSTMP
operator (#18107)
This is an automated email from the ASF dual-hosted git repository.
patriczhao pushed a commit to branch v1.6.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.6.x by this push:
new 0bdde3e [v1.6.x] Quantized LSTMP operator (#18107)
0bdde3e is described below
commit 0bdde3e35f098a6e8e09720a6c3c3960ea6d8638
Author: Zixuan Wei <zi...@intel.com>
AuthorDate: Tue Apr 21 19:20:15 2020 +0800
[v1.6.x] Quantized LSTMP operator (#18107)
* Quantized LSTMP operator
* Checkout onednn dev-lstmp-int8 branch
* Fix wrong size of projection weights
* Add unit test for INT8 LSTMP
---
3rdparty/mkldnn | 2 +-
src/operator/nn/mkldnn/mkldnn_rnn-inl.h | 3 +-
src/operator/nn/mkldnn/mkldnn_rnn.cc | 25 ++++-
.../quantization/mkldnn/mkldnn_quantized_rnn.cc | 123 +++++++++++++++------
src/operator/quantization/quantized_rnn.cc | 11 +-
tests/python/quantization/test_quantization.py | 11 +-
6 files changed, 129 insertions(+), 46 deletions(-)
diff --git a/3rdparty/mkldnn b/3rdparty/mkldnn
index f7c41dc..33aca9b 160000
--- a/3rdparty/mkldnn
+++ b/3rdparty/mkldnn
@@ -1 +1 @@
-Subproject commit f7c41dc7b5471ad8bf7905e459bbed27f9094caa
+Subproject commit 33aca9ba23a977282bf1d64c34cecc2abb4e019a
diff --git a/src/operator/nn/mkldnn/mkldnn_rnn-inl.h b/src/operator/nn/mkldnn/mkldnn_rnn-inl.h
index f039184..7e90e67 100644
--- a/src/operator/nn/mkldnn/mkldnn_rnn-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_rnn-inl.h
@@ -322,7 +322,8 @@ class MKLDNNRnnForward {
mkldnn::memory *weights_layer_r_ = nullptr;
mkldnn::memory *weights_iter_r_ = nullptr;
- mkldnn::memory *weights_proj_r_ = nullptr;
+ mkldnn::memory *weights_proj_r_ = nullptr; // format_tag::ldoi
+ mkldnn::memory *weights_proj_io_ = nullptr; // format_tag::ldio, used in quantization
/*
* net_args must contain some keys as below:
diff --git a/src/operator/nn/mkldnn/mkldnn_rnn.cc b/src/operator/nn/mkldnn/mkldnn_rnn.cc
index 3775e9c..2435a6a 100644
--- a/src/operator/nn/mkldnn/mkldnn_rnn.cc
+++ b/src/operator/nn/mkldnn/mkldnn_rnn.cc
@@ -81,7 +81,11 @@ void MKLDNNRnnLayerParam::SetDims() {
workspace_size = tz_volume(weight_layer_dims) + tz_volume(weight_iter_dims) +
tz_volume(bias_dims);
- if (proj_size > 0) workspace_size += tz_volume(weight_proj_dims);
+ if (proj_size > 0) {
+ workspace_size += tz_volume(weight_proj_dims);
+ //* NOTE: The quantized op needs one more copy of the projection weights, in ldio format.
+ if (quantized) workspace_size += tz_volume(weight_proj_dims);
+ }
reserve_size = 0;
}
@@ -134,10 +138,10 @@ MKLDNNRnnFullParam MKLDNNRnnFullParamParser(const NodeAttrs& attrs, const int se
// Set dims, workspace size, state_outputs, quantized and enable_u8_output flag
for (auto& layer_param : layer_params) {
- layer_param.SetDims();
layer_param.state_outputs = rnn_param.state_outputs;
layer_param.quantized = full_param.mkldnn_param.quantized;
layer_param.enable_u8_output = true;
+ layer_param.SetDims();
}
// Quantized RNN operator produces kFloat32 outputs.
if (full_param.mkldnn_param.quantized) layer_params.back().enable_u8_output = false;
@@ -211,7 +215,6 @@ RnnPrimitive GetRnnFwdPrim(const MKLDNNRnnLayerParam &layer_param,
auto dst_cell_desc = layer_param.state_outputs ? memory::desc(
layer_param.cell_dims, iter_dtype, tag::ldnc) : memory::desc();
-
auto fwd = RnnPrimitive();
switch (mode) {
case rnn_enum::kLstm:
@@ -433,6 +436,10 @@ void MKLDNNRnnForward::ReorderWeights() {
};
ReorderWithAttr(*weights_layer_r_, *weights_layer_);
ReorderWithAttr(*weights_iter_r_, *weights_iter_);
+ if (param_.proj_size > 0) {
+ MKLDNNMemoryReorder(*weights_proj_r_, *weights_proj_io_);
+ ReorderWithAttr(*weights_proj_io_, *weights_proj_);
+ }
} else {
MKLDNNMemoryReorder(*weights_layer_r_, *weights_layer_);
MKLDNNMemoryReorder(*weights_iter_r_, *weights_iter_);
@@ -518,12 +525,14 @@ void MKLDNNRnnForward::SetWeightsMem(void *w_ptr, void *b_ptr,
const auto mkldnn_dtype = get_mkldnn_type(dtype);
const size_t dtype_bytes = mshadow::mshadow_sizeof(dtype);
- const size_t buffer_bytes = this->GetSize() // byte number of the buffer
+ size_t buffer_bytes = this->GetSize() // byte number of the buffer
+ (param_.workspace_size + param_.reserve_size) * dtype_bytes
+ kMKLDNNAlign * 7; // Add margin for alignment of seven times allocation for the
// dnnl memory handlers, i.e. weights_layer_, weights_iter_,
// weights_proj_, bias_, weights_layer_r_, weights_iter_r_,
// and weights_proj_r_.
+ if (param_.quantized && param_.proj_size > 0)
+ buffer_bytes += kMKLDNNAlign; // Quantized Op needs another one for weights_proj_io_.
if (mem_mgr_.Size() < buffer_bytes) mem_mgr_.Init(buffer_bytes, this->ctx_);
const bool use_proj = (param_.proj_size > 0);
@@ -555,6 +564,10 @@ void MKLDNNRnnForward::SetWeightsMem(void *w_ptr, void *b_ptr,
weights_proj_r_ = mem_mgr_.Alloc(
{param_.weight_proj_dims, mkldnn_dtype, format_tag::ldoi});
}
+ if (param_.quantized && use_proj && weights_proj_io_ == nullptr) {
+ weights_proj_io_ = mem_mgr_.Alloc(
+ {param_.weight_proj_dims, mkldnn_dtype, format_tag::ldio});
+ }
// convert void* to char* for arithmetic operations
const size_t iter_size = use_proj ? param_.proj_size : param_.state_size;
@@ -1034,9 +1047,9 @@ void MKLDNNRnnOp::Forward(const OpContext &ctx,
}
if (dmlc::GetEnv("MXNET_RNN_USE_WEIGHT_CACHE", 0) && !initialized_) {
- LOG(INFO) << "The current weight of RNN is assumed to be fixed and cached during "
+ common::LogOnce("The current weight of RNN is assumed to be fixed and cached during "
"the whole inference pipeline. Please set MXNET_RNN_USE_WEIGHT_CACHE=0, if "
- "the weight changed at runtime.";
+ "the weight changed at runtime.");
}
// Check if weights NDArray was changed. If so, reset initialized_
if (!is_training && fwd_inf_vec_.size() > 0
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_rnn.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_rnn.cc
index 9ae9bb2..ebf4cce 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_rnn.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_rnn.cc
@@ -32,31 +32,46 @@
namespace mxnet {
namespace op {
-std::vector<float> GetMKLDNNRnnWeightsQParams(const MKLDNNRnnFullParam& full_param,
- float* w_ptr) {
+/*!
+ * \brief Quantization scales of the rnn weights, stored as a tuple in the order
+ * (weights_qparams, weights_projection_qparams).
+ */
+typedef std::tuple<std::vector<float>, std::vector<float> > rnn_weights_qparams_t;
+
+rnn_weights_qparams_t GetMKLDNNRnnWeightsQParams(
+ const MKLDNNRnnFullParam& full_param, float* w_ptr) {
const int nthreads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
const RNNParam& default_param = full_param.default_param;
const LayerParamVector& layer_params = full_param.layer_params;
+ const bool use_proj = default_param.projection_size.has_value();
+ const size_t iter_size = use_proj ?
+ default_param.projection_size.value() : default_param.state_size;
const MKLDNNRnnLayerParam& layer_param0 = layer_params.at(0);
const size_t w_size0 = layer_param0.single_w_size;
const size_t wx_size0 = 4 * layer_param0.state_size * layer_param0.input_size;
- const size_t wh_size0 = 4 * layer_param0.state_size * layer_param0.state_size;
+ const size_t wh_size0 = 4 * layer_param0.state_size * iter_size;
+ const size_t wr_size = default_param.state_size * iter_size;
int directions = 1;
float* wx = w_ptr;
float* wh = wx + wx_size0;
+ float* wr = wh + wh_size0;
float* fake_wx = wx;
float* fake_wh = wh;
+ float* fake_wr = wr;
std::vector<float> wx_goi_max;
std::vector<float> wh_goi_max;
+ std::vector<float> wr_oi_max;
if (default_param.bidirectional) {
directions = 2;
wx_goi_max.resize(wx_size0);
wh_goi_max.resize(wh_size0);
+ wr_oi_max.resize(wr_size);
fake_wx = wx_goi_max.data();
fake_wh = wh_goi_max.data();
+ fake_wr = wr_oi_max.data();
#pragma omp parallel for num_threads(nthreads)
for (index_t i = 0; i < static_cast<index_t>(wx_size0); ++i) {
fake_wx[i] = MaxAbs(wx[i], wx[i + w_size0]);
@@ -65,8 +80,15 @@ std::vector<float> GetMKLDNNRnnWeightsQParams(const MKLDNNRnnFullParam& full_par
for (index_t i = 0; i < static_cast<index_t>(wh_size0); ++i) {
fake_wh[i] = MaxAbs(wh[i], wh[i + w_size0]);
}
+ if (use_proj) {
+ #pragma omp parallel for num_threads(nthreads)
+ for (index_t i = 0; i < static_cast<index_t>(wr_size); ++i) {
+ fake_wr[i] = MaxAbs(wr[i], wr[i + w_size0]);
+ }
+ }
}
std::vector<float> w_max(4 * layer_param0.state_size, 0.0);
+ std::vector<float> proj_max(iter_size, 0.0);
const index_t input_size = layer_param0.input_size; // input
const index_t state_size = layer_param0.state_size; // state
const index_t gates_nblks = 4 * layer_param0.state_size; // gates * state
@@ -75,44 +97,68 @@ std::vector<float> GetMKLDNNRnnWeightsQParams(const MKLDNNRnnFullParam& full_par
for (index_t i = 0; i < input_size; ++i) {
tmp_max = MaxAbs(fake_wx[go * input_size + i], tmp_max);
}
- for (index_t i = 0; i < state_size; ++i) {
- tmp_max = MaxAbs(fake_wh[go * state_size + i], tmp_max);
+ for (index_t i = 0; i < static_cast<index_t>(iter_size); ++i) {
+ tmp_max = MaxAbs(fake_wh[go * iter_size + i], tmp_max);
}
w_max[go] = tmp_max;
}
+ if (use_proj) {
+ for (index_t i = 0; i < static_cast<index_t>(iter_size); ++i) {
+ for (index_t s = 0; s < state_size; ++s) {
+ proj_max[i] = MaxAbs(fake_wr[iter_size * state_size + s], proj_max[i]);
+ }
+ }
+ }
wx += layer_param0.single_w_size * directions;
wh += layer_param0.single_w_size * directions;
+ wr += layer_param0.single_w_size * directions;
- std::vector<float> goi_max(wh_size0, 0.0);
+ const size_t wx_size1 = 4 * default_param.state_size * default_param.state_size;
+ const size_t wh_size1 = wh_size0;
+ std::vector<float> go_max(gates_nblks, 0.0);
for (size_t lyr = 1; lyr < layer_params.size(); ++lyr) {
const MKLDNNRnnLayerParam& layer_param = layer_params.at(lyr);
const int weight_nblks = layer_param.num_layer * directions;
for (int blk = 0; blk < weight_nblks; ++blk) {
- #pragma omp parallel for num_threads(nthreads)
- for (index_t i = 0; i < static_cast<index_t>(wh_size0); ++i) {
- goi_max[i] = MaxAbs(wx[i], wh[i]);
+ for (index_t go = 0; go < gates_nblks; ++go) {
+ float tmp = Abs(wx[0]);
+ for (index_t i = 1; i < layer_param.input_size; ++i) {
+ tmp = MaxAbs(wx[go * layer_param.input_size + i], tmp);
+ }
+ go_max[go] = Max(tmp, go_max[go]);
+ }
+ for (index_t go = 0; go < gates_nblks; ++go) {
+ float tmp = Abs(wh[0]);
+ for (index_t i = 1; i < static_cast<index_t>(iter_size); ++i) {
+ tmp = MaxAbs(wh[go * iter_size + i], tmp);
+ }
+ go_max[go] = Max(tmp, go_max[go]);
}
+ #pragma omp parallel for num_threads(nthreads)
for (index_t go = 0; go < gates_nblks; ++go) {
- float tmp = w_max[go];
- //* NOTES: min/max reductions were supported since OpenMP 3.1, which was released in
- // Jul 2011 (hence the version number).
- #if _OPENMP >= 201107
- #pragma omp parallel for reduction(max : tmp) num_threads(nthreads)
- #endif
- for (index_t i = 0; i < state_size; ++i) {
- tmp = Max(goi_max[go * state_size + i], tmp);
+ w_max[go] = Max(go_max[go], w_max[go]);
+ }
+ if (use_proj) {
+ for (index_t i = 0; i < static_cast<index_t>(iter_size); ++i) {
+ for (index_t s = 0; s < state_size; ++s) {
+ proj_max[i] = MaxAbs(fake_wr[iter_size * state_size + s], proj_max[i]);
+ }
}
- w_max[go] = tmp;
}
+ wx += layer_param.single_w_size;
+ wh = wx + wx_size1;
+ wr = wh + wh_size1;
}
- wx += layer_param.single_w_size * directions;
- wh = wx + wh_size0;
}
#pragma omp parallel for num_threads(nthreads)
for (index_t i = 0; i < static_cast<index_t>(w_max.size()); ++i) {
w_max[i] = mshadow::red::limits::MaxValue<int8_t>() / w_max[i];
}
- return w_max;
+ #pragma omp parallel for num_threads(nthreads)
+ for (index_t i = 0; i < static_cast<index_t>(proj_max.size()); ++i) {
+ proj_max[i] = mshadow::red::limits::MaxValue<int8_t>() / proj_max[i];
+ }
+ return std::make_tuple(w_max, proj_max);
}
void MKLDNNQuantizedRnnOp::Forward(const OpContext &op_ctx,
@@ -133,9 +179,9 @@ void MKLDNNQuantizedRnnOp::Forward(const OpContext &op_ctx,
float *bias_ptr = weights_ptr + weights_size;
if (dmlc::GetEnv("MXNET_RNN_USE_WEIGHT_CACHE", 0) && !initialized_) {
- LOG(INFO) << "The current weight of RNN is assumed to be fixed and cached during "
+ common::LogOnce("The current weight of RNN is assumed to be fixed and cached during "
"the whole inference pipeline. Please set MXNET_RNN_USE_WEIGHT_CACHE=0, if "
- "the weight changed at runtime.";
+ "the weight changed at runtime.");
}
const bool need_reset_weight = (!dmlc::GetEnv("MXNET_RNN_USE_WEIGHT_CACHE", 0) &&
weights_ver_ != inputs[rnn_enum::kParams].version()) ? true : false;
@@ -154,9 +200,14 @@ void MKLDNNQuantizedRnnOp::Forward(const OpContext &op_ctx,
cached_data_scale_ = data_scale;
cached_data_shift_ = data_shift;
rnn_attr_->set_rnn_data_qparams(data_scale, data_shift);
- if (need_reset_weight || rnn_layers_.empty())
- rnn_attr_->set_rnn_weights_qparams(0 + (1 << 3) + (1 << 4),
- GetMKLDNNRnnWeightsQParams(full_param_, weights_ptr));
+ if (need_reset_weight || rnn_layers_.empty()) {
+ rnn_weights_qparams_t weights_qparams =
+ GetMKLDNNRnnWeightsQParams(full_param_, weights_ptr);
+ rnn_attr_->set_rnn_weights_qparams(0 + (1 << 3) + (1 << 4), std::get<0>(weights_qparams));
+ if (default_param.projection_size.has_value()) {
+ rnn_attr_->set_rnn_weights_projection_qparams(0 + (1 << 3), std::get<1>(weights_qparams));
+ }
+ }
}
// Get data type
@@ -167,10 +218,14 @@ void MKLDNNQuantizedRnnOp::Forward(const OpContext &op_ctx,
const int seq_length = default_param.seq_length_;
const int batch_size = default_param.batch_size_;
const int state_size = default_param.state_size;
+ const int iter_size = default_param.projection_size.has_value() ?
+ default_param.projection_size.value() : state_size;
const int directions = default_param.bidirectional ? 2 : 1;
- mkldnn::memory::desc dst_desc({seq_length, batch_size, directions * state_size},
+ mkldnn::memory::desc dst_desc({seq_length, batch_size, directions * iter_size},
get_mkldnn_type(data_dtype), mkldnn::memory::format_tag::tnc);
- mkldnn::memory::desc state_desc({num_layers, directions, batch_size, state_size},
+ mkldnn::memory::desc state_desc({num_layers, directions, batch_size, iter_size},
+ get_mkldnn_type(data_dtype), mkldnn::memory::format_tag::ldnc);
+ mkldnn::memory::desc cell_desc({num_layers, directions, batch_size, state_size},
get_mkldnn_type(data_dtype), mkldnn::memory::format_tag::ldnc);
auto out_mem = CreateMKLDNNMem(outputs[rnn_enum::kOut], dst_desc, req[rnn_enum::kOut]);
mkldnn_output_t stateout_mem;
@@ -183,8 +238,10 @@ void MKLDNNQuantizedRnnOp::Forward(const OpContext &op_ctx,
char *dst_state = nullptr; // Output state
char *src_state_cell = nullptr; // Used in LSTM for cell state
char *dst_state_cell = nullptr; // Used in LSTM for cell state
- const size_t cell_bytes = (default_param.bidirectional + 1) * default_param.batch_size_ *
- default_param.state_size * mshadow::mshadow_sizeof(data_dtype);
+ const size_t state_bytes = (default_param.bidirectional + 1) * batch_size *
+ iter_size * mshadow::mshadow_sizeof(data_dtype);
+ const size_t cell_bytes = (default_param.bidirectional + 1) * batch_size *
+ state_size * mshadow::mshadow_sizeof(data_dtype);
const LayerParamVector& layer_params = full_param_.layer_params;
for (size_t lyr = 0; lyr < layer_params.size(); ++lyr) {
@@ -218,7 +275,7 @@ void MKLDNNQuantizedRnnOp::Forward(const OpContext &op_ctx,
src_state_cell = static_cast<char *>(inputs[rnn_enum::kStateCell].data().dptr_);
if (default_param.state_outputs && req[rnn_enum::kStateCellOut] != kNullOp) {
statecellout_mem = CreateMKLDNNMem(
- outputs[rnn_enum::kStateCellOut], state_desc, req[rnn_enum::kStateCellOut]);
+ outputs[rnn_enum::kStateCellOut], cell_desc, req[rnn_enum::kStateCellOut]);
dst_state_cell = static_cast<char *>(statecellout_mem.second->get_data_handle());
}
}
@@ -229,9 +286,9 @@ void MKLDNNQuantizedRnnOp::Forward(const OpContext &op_ctx,
MKLDNNStream::Get()->RegisterPrimArgs(rnn_layer.GetFwd(), rnn_layer.GetArgsMap());
if (lyr < default_param.num_layers - 1U) {
- src_state += cell_bytes;
+ src_state += state_bytes;
+ if (dst_state) dst_state += state_bytes;
if (src_state_cell) src_state_cell += cell_bytes;
- if (dst_state) dst_state += cell_bytes;
if (dst_state_cell) dst_state_cell += cell_bytes;
}
}
diff --git a/src/operator/quantization/quantized_rnn.cc b/src/operator/quantization/quantized_rnn.cc
index b2864ff..31944c9 100644
--- a/src/operator/quantization/quantized_rnn.cc
+++ b/src/operator/quantization/quantized_rnn.cc
@@ -92,7 +92,9 @@ bool QuantizedRnnShape(const nnvm::NodeAttrs& attrs,
const dim_t directions = param.bidirectional ? 2 : 1;
const dim_t total_lyrs = directions * param.num_layers;
const dim_t state_size = param.state_size;
- SHAPE_ASSIGN_CHECK(*in_shape, quantized_rnn::kState, Shape3(total_lyrs, batch_size, state_size));
+ const dim_t iter_size = param.projection_size.has_value() ?
+ param.projection_size.value() : state_size;
+ SHAPE_ASSIGN_CHECK(*in_shape, quantized_rnn::kState, Shape3(total_lyrs, batch_size, iter_size));
if (param.mode == rnn_enum::kLstm)
SHAPE_ASSIGN_CHECK(*in_shape, quantized_rnn::kStateCell,
Shape3(total_lyrs, batch_size, state_size));
@@ -109,9 +111,9 @@ bool QuantizedRnnShape(const nnvm::NodeAttrs& attrs,
SHAPE_ASSIGN_CHECK(*in_shape, i, Shape1(1));
out_shape->clear();
- out_shape->push_back({dshape[0], batch_size, directions * state_size}); // output dim: [T, N, C]
+ out_shape->push_back({dshape[0], batch_size, directions * iter_size}); // output dim: [T, N, C]
if (param.state_outputs) {
- out_shape->push_back({total_lyrs, batch_size, state_size}); // state dim: [L*D, N, C]
+ out_shape->push_back({total_lyrs, batch_size, iter_size}); // state dim: [L*D, N, C]
if (param.mode == rnn_enum::kLstm)
out_shape->push_back({total_lyrs, batch_size, state_size}); // cell dim: [L*D, N, C]
}
@@ -216,6 +218,9 @@ OpStatePtr CreateQuantizedRnnState(const nnvm::NodeAttrs& attrs,
const mxnet::TShape& data_shape = in_shapes[quantized_rnn::kData];
state = OpStatePtr::Create<MKLDNNQuantizedRnnOp>(attrs, data_shape[0],
data_shape[1], data_shape[2]);
+ } else {
+ LOG(FATAL) << "MKLDNN quantized rnn operator only supports inputs in U8 type and weight"
+ " in FP32 type.";
}
#else
LOG(FATAL) << "Quantized RNN operator relies on MKL-DNN library."
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index 48968c2..bbe3008 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -519,7 +519,7 @@ def test_quantized_fc():
@with_seed()
def test_quantized_rnn():
- def check_quantized_rnn(num_layers, bidirectional, seq_len, batch_size, input_dim, state_dim):
+ def check_quantized_rnn(num_layers, bidirectional, seq_len, batch_size, input_dim, state_dim, projection_dim=None):
if is_test_for_gpu():
print('skipped testing test_quantized_rnn for gpu since it is not supported yet')
return
@@ -534,6 +534,7 @@ def test_quantized_rnn():
bidirectional=bidirectional,
state_outputs=True,
state_size=state_dim,
+ projection_size=projection_dim,
mode='lstm',
name='rnn')
arg_shapes, _, _ = rnn_fp32.infer_shape(data=data_shape)
@@ -559,6 +560,7 @@ def test_quantized_rnn():
bidirectional=bidirectional,
state_outputs=True,
state_size=state_dim,
+ projection_size=projection_dim,
mode='lstm',
name='qrnn')
qarg_names = rnn_int8.list_arguments()
@@ -575,10 +577,15 @@ def test_quantized_rnn():
qoutput = rnn_int8_exe.forward()[0]
mse = np.mean((output.asnumpy() - qoutput.asnumpy())**2)
- assert mse < 0.001
+ if projection_dim:
+ assert mse < 2
+ else:
+ assert mse < 0.001
check_quantized_rnn(1, False, 5, 2, 16, 16)
check_quantized_rnn(1, True, 5, 2, 16, 16)
+ check_quantized_rnn(1, False, 5, 2, 16, 16, 8)
+ check_quantized_rnn(1, True, 5, 2, 16, 16, 8)
@with_seed()