You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/06/30 08:39:02 UTC

[incubator-mxnet] branch v1.9.x updated: Fix oneDNN RNN weights reorder (#21065)

This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch v1.9.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.9.x by this push:
     new a0def52a3e Fix oneDNN RNN weights reorder (#21065)
a0def52a3e is described below

commit a0def52a3ef15ac9118f8015e13be09f1ad5d6ff
Author: bgawrych <ba...@intel.com>
AuthorDate: Thu Jun 30 10:38:45 2022 +0200

    Fix oneDNN RNN weights reorder (#21065)
    
    * fix rnn weights reorder
    
    * fix sanity
    
    * fix sanity
    
    * fix windows build
---
 src/operator/nn/mkldnn/mkldnn_rnn.cc           | 63 ++++++++++++++++++--------
 tests/python/quantization/test_quantization.py |  4 --
 2 files changed, 44 insertions(+), 23 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_rnn.cc b/src/operator/nn/mkldnn/mkldnn_rnn.cc
index 114cff57a8..590eea6e43 100644
--- a/src/operator/nn/mkldnn/mkldnn_rnn.cc
+++ b/src/operator/nn/mkldnn/mkldnn_rnn.cc
@@ -443,12 +443,11 @@ void MKLDNNRnnForward::ReorderWeights() {
 }
 
 void AdjustGruGateOrder(char* weight,
-                        const size_t input_size,
                         const size_t hidden_size,
                         const int dtype) {
   // mxnet gru gate order is reset, update and new gates
   // mkldnn gru gate order is update, reset and new gates
-  size_t single_weight_bytes = input_size * hidden_size * mshadow::mshadow_sizeof(dtype);
+  size_t single_weight_bytes = hidden_size * mshadow::mshadow_sizeof(dtype);
   char* weight_reset = weight;
   char* weight_update = weight + single_weight_bytes;
   std::swap_ranges(weight_reset, weight_update, weight_update);
@@ -521,6 +520,10 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr,
                                      const int dtype) {
   using format_tag  = mkldnn::memory::format_tag;
   const auto mkldnn_dtype = get_mkldnn_type(dtype);
+  const size_t input_size = param_.input_size;
+  const size_t state_size = param_.state_size;
+  const size_t directions = param_.bidirectional + 1U;
+
   // Get the weights' memory for RNN forward primitive
   if (weights_layer_ == nullptr) {
     weights_layer_ = mgr->Alloc(fwd_inf_.GetLayerDesc());
@@ -534,13 +537,14 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr,
   }
 
   // Get the intermediate memory for weights concat & reorder
+  // use ldigo format instead of ldgoi to support weights reorder for BRGEMM implementation
   if (weights_layer_r_ == nullptr) {
     weights_layer_r_ = mgr->Alloc(
-        {param_.weight_layer_dims, mkldnn_dtype, format_tag::ldgoi});
+        {param_.weight_layer_dims, mkldnn_dtype, format_tag::ldigo});
   }
   if (weights_iter_r_ == nullptr) {
     weights_iter_r_ = mgr->Alloc(
-        {param_.weight_iter_dims, mkldnn_dtype, format_tag::ldgoi});
+        {param_.weight_iter_dims, mkldnn_dtype, format_tag::ldigo});
   }
 
   // Get the bytes of a real type
@@ -548,10 +552,10 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr,
 
   // convert void* to char* for arithmetic operations
   char *weights_ptr = static_cast<char *>(w_ptr);
-  size_t wx_bytes = GetRnnGatesNum(param_.mode) * param_.state_size *
-        param_.input_size * dtype_bytes;  //* DIMS: ngates x state_size x input_size
-  size_t wh_bytes = GetRnnGatesNum(param_.mode) * param_.state_size *
-        param_.state_size * dtype_bytes;  //* DIMS: ngates x state_size x state_size
+  size_t wx_bytes = GetRnnGatesNum(param_.mode) * state_size *
+        input_size * dtype_bytes;  //* DIMS: ngates x state_size x input_size
+  size_t wh_bytes = GetRnnGatesNum(param_.mode) * state_size *
+        state_size * dtype_bytes;  //* DIMS: ngates x state_size x state_size
   char *l2r_wx = weights_ptr;
   char *l2r_wh = l2r_wx + wx_bytes;       //* DIMS: ngates x state_size * state_size
 
@@ -562,9 +566,16 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr,
     ConcatWeights(*weights_layer_r_, 1, {l2r_wx, r2l_wx}, format_tag::ldgoi);
     ConcatWeights(*weights_iter_r_, 1, {l2r_wh, r2l_wh}, format_tag::ldgoi);
   } else if (param_.num_layer == 1 && !param_.bidirectional) {
-    //* single uni-directional layer, no concatenate operator needed
-    std::memcpy(weights_layer_r_->get_data_handle(), l2r_wx, wx_bytes);
-    std::memcpy(weights_iter_r_->get_data_handle(), l2r_wh, wh_bytes);
+    // single uni-directional layer, no concatenate operator needed
+    // reorder from ldgoi to ldigo
+    mkldnn::memory l2r_wx_mem = mkldnn::memory(
+        {param_.weight_layer_dims, mkldnn_dtype, format_tag::ldgoi},
+        CpuEngine::Get()->get_engine(), l2r_wx);
+    mkldnn::memory l2r_wh_mem = mkldnn::memory(
+        {param_.weight_iter_dims, mkldnn_dtype, format_tag::ldgoi},
+        CpuEngine::Get()->get_engine(), l2r_wh);
+    ReorderTo(&l2r_wx_mem, weights_layer_r_);
+    ReorderTo(&l2r_wh_mem, weights_iter_r_);
   } else if (param_.num_layer > 1 && !param_.bidirectional) {
     //* concat fused multi-layer weights on layer axis
     std::vector<void *> l2r_wx_ptrs;
@@ -582,16 +593,30 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr,
                << ", and bidirectional is " << param_.bidirectional;
   }
 
+  // recalculate weight bytes after reorder (ldgoi => ldigo)
+  wx_bytes = GetRnnGatesNum(param_.mode) * state_size * dtype_bytes;  //* DIMS: ngates x state_size
+  wh_bytes = GetRnnGatesNum(param_.mode) * state_size * dtype_bytes;  //* DIMS: ngates x state_size
+
   // Adjust gates order of LBR-GRU among concatenated memory inplace.
   char* fused_wx = static_cast<char*>(weights_layer_r_->get_data_handle());
   char* fused_wh = static_cast<char*>(weights_iter_r_->get_data_handle());
   if (param_.mode == rnn_enum::kGru) {
+    const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
     for (size_t lyr = 0; lyr < static_cast<size_t>(param_.num_layer); ++lyr) {
-      for (size_t d = 0; d < param_.bidirectional + 1U; ++d) {
-        AdjustGruGateOrder(fused_wx, param_.input_size, param_.state_size, dtype);
-        AdjustGruGateOrder(fused_wh, param_.state_size, param_.state_size, dtype);
-        fused_wx += wx_bytes;
-        fused_wh += wh_bytes;
+      for (size_t d = 0; d < directions; ++d) {
+        #pragma omp parallel num_threads(omp_threads)
+        {
+          #pragma omp for
+          for (int i = 0; i < static_cast<int>(input_size); ++i) {
+            int offset_fused_wx = i + d * input_size + lyr * directions * input_size;
+            AdjustGruGateOrder(fused_wx + wx_bytes * offset_fused_wx, state_size, dtype);
+          }
+          #pragma omp for
+          for (int s = 0; s < static_cast<int>(state_size); ++s) {
+            int offset_fused_wh = s + d * state_size + lyr * directions * state_size;
+            AdjustGruGateOrder(fused_wh + wh_bytes * offset_fused_wh, state_size, dtype);
+          }
+        }
       }
     }
   }
@@ -600,9 +625,9 @@ void MKLDNNRnnForward::SetWeightsMem(MKLDNNRnnMemMgr* mgr,
   MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
     DType* native_b_ptr = static_cast<DType*>(b_ptr);
     DType* fused_bias = static_cast<DType*>(bias_->get_data_handle());
-    for (int lyr = 0; lyr < param_.num_layer; ++lyr) {
-      for (int d = 0; d < param_.bidirectional + 1; ++d) {
-        FuseBias<DType>(fused_bias, native_b_ptr, param_.mode, param_.state_size);
+    for (size_t lyr = 0; lyr < static_cast<size_t>(param_.num_layer); ++lyr) {
+      for (size_t d = 0; d < directions; ++d) {
+        FuseBias<DType>(fused_bias, native_b_ptr, param_.mode, state_size);
         fused_bias += param_.single_b_size;
         native_b_ptr += param_.native_single_b_size;
       }
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index f30fc2f938..dea189c54c 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -603,10 +603,6 @@ def test_quantized_rnn():
         if is_test_for_native_cpu():
             print('skipped testing test_quantized_rnn for native cpu since it is not supported yet')
             return
-        # skip test for mkldnn, flakey and failing - tracked in https://github.com/apache/incubator-mxnet/issues/21061
-        if is_test_for_mkldnn():
-            print('skipped flakey test test_quantized_rnn for mkldnn, see issue #21061')
-            return
 
         data_shape = (seq_len, batch_size, input_dim)
         data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')