Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/03/31 09:28:23 UTC

[incubator-mxnet] branch v1.9.x updated: Port BRGEMM (#20910)

This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch v1.9.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.9.x by this push:
     new 76fc3ef  Port BRGEMM (#20910)
76fc3ef is described below

commit 76fc3efe02ef55be6fcf55f3de1ecdb6e704b626
Author: bgawrych <ba...@intel.com>
AuthorDate: Thu Mar 31 11:26:38 2022 +0200

    Port BRGEMM (#20910)
---
 docs/static_site/src/pages/api/faq/env_var.md    |  4 ++++
 src/operator/nn/mkldnn/mkldnn_base-inl.h         | 17 ++++++++++++++---
 src/operator/nn/mkldnn/mkldnn_fully_connected.cc | 16 +++++++++-------
 3 files changed, 27 insertions(+), 10 deletions(-)

diff --git a/docs/static_site/src/pages/api/faq/env_var.md b/docs/static_site/src/pages/api/faq/env_var.md
index 831f7ee..1f5debd 100644
--- a/docs/static_site/src/pages/api/faq/env_var.md
+++ b/docs/static_site/src/pages/api/faq/env_var.md
@@ -304,6 +304,10 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`.
   If no such algorithm exists given other constraints, MXNet will error out. This variable affects the choice
   of CUDNN convolution algorithms. Please see [CUDNN developer guide](https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html) for more details.
 
+* MXNET_MKLDNN_FORCE_FC_AB_FORMAT
+  - Values: 0, 1 ```(default=0)```
+  - If set to 1 (true), FullyConnected will use only the AB format for weights; as a result, MXNet won't use the BRGEMM implementation of FC on machines with AVX512-VNNI support, since BRGEMM requires a special weights format.
+
 * MXNET_CPU_PARALLEL_SIZE
   - Values: Int ```(default=200000)```
   - The minimum size to call parallel operations by OpenMP for CPU context.
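
The new flag is consumed through dmlc::GetEnv in the C++ changes below. As a rough, self-contained sketch (not MXNet code), the lookup behaves approximately like the following; the helper name ReadBoolEnv is hypothetical, and dmlc::GetEnv's exact parsing of the value may differ:

    #include <cstdlib>
    #include <cstring>
    #include <iostream>

    // Hypothetical approximation of dmlc::GetEnv("MXNET_MKLDNN_FORCE_FC_AB_FORMAT", false):
    // an unset variable yields the default; here, any value other than "0" counts as true.
    static bool ReadBoolEnv(const char* name, bool default_value) {
      const char* raw = std::getenv(name);
      if (raw == nullptr) return default_value;
      return std::strcmp(raw, "0") != 0;
    }

    int main() {
      const bool force_ab = ReadBoolEnv("MXNET_MKLDNN_FORCE_FC_AB_FORMAT", false);
      std::cout << "force AB weight format: " << force_ab << '\n';
      return 0;
    }

In practice the variable is set in the environment before launching MXNet, e.g. MXNET_MKLDNN_FORCE_FC_AB_FORMAT=1 python train.py.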
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 3eb42b4..57cae5b 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -285,15 +285,26 @@ inline static mkldnn::memory::desc GetMemDesc(const NDArray &arr, int dtype = -1
   return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), mkldnn::memory::format_tag::any};
 }
 
-inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray &arr, int dtype = -1) {
+inline static bool ChooseBRGEMMImpl(const mkldnn::memory::dims& weight_dims, size_t batch_size) {
+  // Conditions based on measurement results done on CLX8280
+  // https://github.com/apache/incubator-mxnet/pull/20533
+  return weight_dims[0] >= 1024 && weight_dims[1] >= 1024 && batch_size >= 16384 &&
+         weight_dims[0] % 64 == 0 && weight_dims[1] % 64 == 0;
+}
+
+inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray& arr,
+                                                   size_t batch_size,
+                                                   int dtype = -1) {
   int ndim = arr.shape().ndim();
   mkldnn::memory::dims dims(ndim);
   dtype = (dtype == -1) ? arr.dtype() : dtype;
   for (size_t i = 0; i < dims.size(); i++) dims[i] = arr.shape()[i];
   auto format = mkldnn::memory::format_tag::any;
-  // for batch 256 alexnet benchmark test
+  const bool force_fc_ab_format = dmlc::GetEnv("MXNET_MKLDNN_FORCE_FC_AB_FORMAT", false);
   if (dims.size() == 2) {
-    format = mkldnn::memory::format_tag::ab;
+    if (force_fc_ab_format || !ChooseBRGEMMImpl(dims, batch_size)) {
+      format = mkldnn::memory::format_tag::ab;
+    }
   }
 
   return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), format};
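
For reference, the heuristic added above can be exercised in isolation. A minimal sketch, using std::vector<int64_t> in place of mkldnn::memory::dims (which oneDNN defines as a vector of 64-bit dims), with the thresholds copied verbatim from the patch:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // BRGEMM is selected only for large, 64-aligned weight matrices combined with
    // big batches; the thresholds come from measurements on CLX8280 (PR #20533).
    static bool ChooseBRGEMMImpl(const std::vector<int64_t>& weight_dims, size_t batch_size) {
      return weight_dims[0] >= 1024 && weight_dims[1] >= 1024 && batch_size >= 16384 &&
             weight_dims[0] % 64 == 0 && weight_dims[1] % 64 == 0;
    }

    int main() {
      std::cout << ChooseBRGEMMImpl({1024, 1024}, 16384) << '\n';  // 1: BRGEMM eligible
      std::cout << ChooseBRGEMMImpl({1024, 1000}, 16384) << '\n';  // 0: 1000 not a multiple of 64
      std::cout << ChooseBRGEMMImpl({1024, 1024}, 256) << '\n';    // 0: batch too small
      return 0;
    }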
diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
index 4205dd4..fef0d28 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -35,12 +35,14 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(
     const MKLDNNFCFullParam &full_param, const bool is_train,
     const NDArray &data, const NDArray &weight, const NDArray *bias,
     const mkldnn::memory::desc &out_md) {
-  auto data_md = GetMemDesc(data);
-  auto weight_md = full_param.mkldnn_param.quantized ?
-    GetFCWeightDesc(weight, mshadow::kInt8) : GetFCWeightDesc(weight);
   auto engine = CpuEngine::Get()->get_engine();
-  auto propagation =
-    is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
+  auto data_md = GetMemDesc(data);
+  auto weight_md =
+      full_param.mkldnn_param.quantized
+          ? GetFCWeightDesc(weight, data.shape()[0], mshadow::kInt8)
+          : GetFCWeightDesc(weight, data.shape()[0]);
+  auto propagation = is_train ? mkldnn::prop_kind::forward_training
+                              : mkldnn::prop_kind::forward_scoring;
 
   mkldnn::primitive_attr attr;
   mkldnn::post_ops ops;
@@ -91,7 +93,7 @@ inline static mkldnn::inner_product_backward_data::primitive_desc GetFCBwdData(
     const NDArray &data, const NDArray &weight, const NDArray &output,
     mkldnn::inner_product_forward::primitive_desc fwd_pd) {
   auto data_md = GetMemDesc(data);
-  auto weight_md = GetFCWeightDesc(weight);
+  auto weight_md = GetFCWeightDesc(weight, data.shape()[0]);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
   mkldnn::inner_product_backward_data::desc desc(data_md, weight_md, out_md);
@@ -102,7 +104,7 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetFCBwdWei
     const NDArray &data, const NDArray &weight, const NDArray *bias,
     const NDArray &output, mkldnn::inner_product_forward::primitive_desc fwd_pd) {
   auto data_md = GetMemDesc(data);
-  auto weight_md = GetFCWeightDesc(weight);
+  auto weight_md = GetFCWeightDesc(weight, data.shape()[0]);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
   if (bias) {
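
All call sites now derive the batch size from data.shape()[0], the leading (batch) dimension of the input NDArray. Putting the pieces together, the weight-format decision flows roughly as in this self-contained sketch (not MXNet code; Shape2 and PickWeightFormat are hypothetical stand-ins):

    #include <array>
    #include <cstdint>
    #include <cstdlib>
    #include <cstring>
    #include <iostream>

    // {batch_size, num_inputs} for data, {num_hidden, num_inputs} for FC weights.
    using Shape2 = std::array<int64_t, 2>;

    // Mirrors the gating in GetFCWeightDesc: keep format_tag::any (so oneDNN may pick
    // the blocked layout BRGEMM needs) only when the heuristic fires and no override is set.
    static const char* PickWeightFormat(const Shape2& weight, int64_t batch_size) {
      const char* env = std::getenv("MXNET_MKLDNN_FORCE_FC_AB_FORMAT");
      const bool force_ab = env != nullptr && std::strcmp(env, "0") != 0;
      const bool brgemm = weight[0] >= 1024 && weight[1] >= 1024 && batch_size >= 16384 &&
                          weight[0] % 64 == 0 && weight[1] % 64 == 0;
      return (force_ab || !brgemm) ? "ab" : "any";
    }

    int main() {
      const Shape2 data{20000, 1024};   // batch_size = data.shape()[0] = 20000
      const Shape2 weight{1024, 1024};
      std::cout << PickWeightFormat(weight, data[0]) << '\n';  // "any": BRGEMM-capable path
      return 0;
    }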