You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/06/30 08:39:02 UTC
[incubator-mxnet] branch v1.9.x updated: Fix oneDNN RNN weights reorder (#21065)
This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch v1.9.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.9.x by this push:
new a0def52a3e Fix oneDNN RNN weights reorder (#21065)
a0def52a3e is described below
commit a0def52a3ef15ac9118f8015e13be09f1ad5d6ff
Author: bgawrych <ba...@intel.com>
AuthorDate: Thu Jun 30 10:38:45 2022 +0200
Fix oneDNN RNN weights reorder (#21065)
* fix rnn weights reorder
* fix sanity
* fix sanity
* fix windows build
---
src/operator/nn/mkldnn/mkldnn_rnn.cc | 63 ++++++++++++++++++--------
tests/python/quantization/test_quantization.py | 4 --
2 files changed, 44 insertions(+), 23 deletions(-)
diff --git a/src/operator/nn/mkldnn/mkldnn_rnn.cc b/src/operator/nn/mkldnn/mkldnn_rnn.cc
index 114cff57a8..590eea6e43 100644
--- a/src/operator/nn/mkldnn/mkldnn_rnn.cc
+++ b/src/operator/nn/mkldnn/mkldnn_rnn.cc
@@ -443,12 +443,11 @@ void MKLDNNRnnForward::ReorderWeights() {
}
void AdjustGruGateOrder(char* weight,
- const size_t input_size,
const size_t hidden_size,
const int dtype) {
// mxnet gru gate order is reset, update and new gates
// mkldnn gru gate order is update, reset and new gates
- size_t single_weight_bytes = input_size * hidden_size * mshadow::mshadow_sizeof(dtype);
+ size_t single_weight_bytes = hidden_size * mshadow::mshadow_sizeof(dtype);
char* weight_reset = weight;
char* weight_update = weight + single_weight_bytes;
std::swap_ranges(weight_reset, weight_update, weight_update);
@@ -521,6 +520,10 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr,
const int dtype) {
using format_tag = mkldnn::memory::format_tag;
const auto mkldnn_dtype = get_mkldnn_type(dtype);
+ const size_t input_size = param_.input_size;
+ const size_t state_size = param_.state_size;
+ const size_t directions = param_.bidirectional + 1U;
+
// Get the weights' memory for RNN forward primitive
if (weights_layer_ == nullptr) {
weights_layer_ = mgr->Alloc(fwd_inf_.GetLayerDesc());
@@ -534,13 +537,14 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr,
}
// Get the intermediate memory for weights concat & reorder
+ // use ldigo format instead of ldgoi to support weights reorder for BRGEMM implementation
if (weights_layer_r_ == nullptr) {
weights_layer_r_ = mgr->Alloc(
- {param_.weight_layer_dims, mkldnn_dtype, format_tag::ldgoi});
+ {param_.weight_layer_dims, mkldnn_dtype, format_tag::ldigo});
}
if (weights_iter_r_ == nullptr) {
weights_iter_r_ = mgr->Alloc(
- {param_.weight_iter_dims, mkldnn_dtype, format_tag::ldgoi});
+ {param_.weight_iter_dims, mkldnn_dtype, format_tag::ldigo});
}
// Get the bytes of a real type
@@ -548,10 +552,10 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr,
// convert void* to char* for arithmetic operations
char *weights_ptr = static_cast<char *>(w_ptr);
- size_t wx_bytes = GetRnnGatesNum(param_.mode) * param_.state_size *
- param_.input_size * dtype_bytes; //* DIMS: ngates x state_size x input_size
- size_t wh_bytes = GetRnnGatesNum(param_.mode) * param_.state_size *
- param_.state_size * dtype_bytes; //* DIMS: ngates x state_size x state_size
+ size_t wx_bytes = GetRnnGatesNum(param_.mode) * state_size *
+ input_size * dtype_bytes; //* DIMS: ngates x state_size x input_size
+ size_t wh_bytes = GetRnnGatesNum(param_.mode) * state_size *
+ state_size * dtype_bytes; //* DIMS: ngates x state_size x state_size
char *l2r_wx = weights_ptr;
char *l2r_wh = l2r_wx + wx_bytes; //* DIMS: ngates x state_size * state_size
@@ -562,9 +566,16 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr,
ConcatWeights(*weights_layer_r_, 1, {l2r_wx, r2l_wx}, format_tag::ldgoi);
ConcatWeights(*weights_iter_r_, 1, {l2r_wh, r2l_wh}, format_tag::ldgoi);
} else if (param_.num_layer == 1 && !param_.bidirectional) {
- //* single uni-directional layer, no concatenate operator needed
- std::memcpy(weights_layer_r_->get_data_handle(), l2r_wx, wx_bytes);
- std::memcpy(weights_iter_r_->get_data_handle(), l2r_wh, wh_bytes);
+ // single uni-directional layer, no concatenate operator needed
+ // reorder from ldgoi to ldigo
+ mkldnn::memory l2r_wx_mem = mkldnn::memory(
+ {param_.weight_layer_dims, mkldnn_dtype, format_tag::ldgoi},
+ CpuEngine::Get()->get_engine(), l2r_wx);
+ mkldnn::memory l2r_wh_mem = mkldnn::memory(
+ {param_.weight_iter_dims, mkldnn_dtype, format_tag::ldgoi},
+ CpuEngine::Get()->get_engine(), l2r_wh);
+ ReorderTo(&l2r_wx_mem, weights_layer_r_);
+ ReorderTo(&l2r_wh_mem, weights_iter_r_);
} else if (param_.num_layer > 1 && !param_.bidirectional) {
//* concat fused multi-layer weights on layer axis
std::vector<void *> l2r_wx_ptrs;
@@ -582,16 +593,30 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr,
<< ", and bidirectional is " << param_.bidirectional;
}
+ // recalculate weight bytes after reorder (ldgoi => ldigo)
+ wx_bytes = GetRnnGatesNum(param_.mode) * state_size * dtype_bytes; //* DIMS: ngates x state_size
+ wh_bytes = GetRnnGatesNum(param_.mode) * state_size * dtype_bytes; //* DIMS: ngates x state_size
+
// Adjust gates order of LBR-GRU among concatenated memory inplace.
char* fused_wx = static_cast<char*>(weights_layer_r_->get_data_handle());
char* fused_wh = static_cast<char*>(weights_iter_r_->get_data_handle());
if (param_.mode == rnn_enum::kGru) {
+ const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
for (size_t lyr = 0; lyr < static_cast<size_t>(param_.num_layer); ++lyr) {
- for (size_t d = 0; d < param_.bidirectional + 1U; ++d) {
- AdjustGruGateOrder(fused_wx, param_.input_size, param_.state_size, dtype);
- AdjustGruGateOrder(fused_wh, param_.state_size, param_.state_size, dtype);
- fused_wx += wx_bytes;
- fused_wh += wh_bytes;
+ for (size_t d = 0; d < directions; ++d) {
+ #pragma omp parallel num_threads(omp_threads)
+ {
+ #pragma omp for
+ for (int i = 0; i < static_cast<int>(input_size); ++i) {
+ int offset_fused_wx = i + d * input_size + lyr * directions * input_size;
+ AdjustGruGateOrder(fused_wx + wx_bytes * offset_fused_wx, state_size, dtype);
+ }
+ #pragma omp for
+ for (int s = 0; s < static_cast<int>(state_size); ++s) {
+ int offset_fused_wh = s + d * state_size + lyr * directions * state_size;
+ AdjustGruGateOrder(fused_wh + wh_bytes * offset_fused_wh, state_size, dtype);
+ }
+ }
}
}
}
@@ -600,9 +625,9 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr,
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
DType* native_b_ptr = static_cast<DType*>(b_ptr);
DType* fused_bias = static_cast<DType*>(bias_->get_data_handle());
- for (int lyr = 0; lyr < param_.num_layer; ++lyr) {
- for (int d = 0; d < param_.bidirectional + 1; ++d) {
- FuseBias<DType>(fused_bias, native_b_ptr, param_.mode, param_.state_size);
+ for (size_t lyr = 0; lyr < static_cast<size_t>(param_.num_layer); ++lyr) {
+ for (size_t d = 0; d < directions; ++d) {
+ FuseBias<DType>(fused_bias, native_b_ptr, param_.mode, state_size);
fused_bias += param_.single_b_size;
native_b_ptr += param_.native_single_b_size;
}
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index f30fc2f938..dea189c54c 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -603,10 +603,6 @@ def test_quantized_rnn():
if is_test_for_native_cpu():
print('skipped testing test_quantized_rnn for native cpu since it is not supported yet')
return
- # skip test for mkldnn, flakey and failing - tracked in https://github.com/apache/incubator-mxnet/issues/21061
- if is_test_for_mkldnn():
- print('skipped flakey test test_quantized_rnn for mkldnn, see issue #21061')
- return
data_shape = (seq_len, batch_size, input_dim)
data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')