You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/03/31 09:28:23 UTC
[incubator-mxnet] branch v1.9.x updated: Port BRGEMM (#20910)
This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch v1.9.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.9.x by this push:
new 76fc3ef Port BRGEMM (#20910)
76fc3ef is described below
commit 76fc3efe02ef55be6fcf55f3de1ecdb6e704b626
Author: bgawrych <ba...@intel.com>
AuthorDate: Thu Mar 31 11:26:38 2022 +0200
Port BRGEMM (#20910)
---
docs/static_site/src/pages/api/faq/env_var.md | 4 ++++
src/operator/nn/mkldnn/mkldnn_base-inl.h | 17 ++++++++++++++---
src/operator/nn/mkldnn/mkldnn_fully_connected.cc | 16 +++++++++-------
3 files changed, 27 insertions(+), 10 deletions(-)
diff --git a/docs/static_site/src/pages/api/faq/env_var.md b/docs/static_site/src/pages/api/faq/env_var.md
index 831f7ee..1f5debd 100644
--- a/docs/static_site/src/pages/api/faq/env_var.md
+++ b/docs/static_site/src/pages/api/faq/env_var.md
@@ -304,6 +304,10 @@ If ctypes is used, it must be `mxnet._ctypes.ndarray.NDArrayBase`.
If no such algorithm exists given other constraints, MXNet will error out. This variable affects the choice
of CUDNN convolution algorithms. Please see [CUDNN developer guide](https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html) for more details.
+* MXNET_MKLDNN_FORCE_FC_AB_FORMAT
+ - Values: 0, 1 ```(default=0)```
+ - If set to true, FullyConnected will use only the AB format for weights; MXNet will therefore not use the BRGEMM implementation of FC, which requires a special weights format, on machines with AVX512-VNNI support.
+
* MXNET_CPU_PARALLEL_SIZE
- Values: Int ```(default=200000)```
- The minimum size to call parallel operations by OpenMP for CPU context.
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 3eb42b4..57cae5b 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -285,15 +285,26 @@ inline static mkldnn::memory::desc GetMemDesc(const NDArray &arr, int dtype = -1
return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), mkldnn::memory::format_tag::any};
}
-inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray &arr, int dtype = -1) {
+inline static bool ChooseBRGEMMImpl(const mkldnn::memory::dims& weight_dims, size_t batch_size) {
+ // Conditions based on measurement results done on CLX8280
+ // https://github.com/apache/incubator-mxnet/pull/20533
+ return weight_dims[0] >= 1024 && weight_dims[1] >= 1024 && batch_size >= 16384 &&
+ weight_dims[0] % 64 == 0 && weight_dims[1] % 64 == 0;
+}
+
+inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray& arr,
+ size_t batch_size,
+ int dtype = -1) {
int ndim = arr.shape().ndim();
mkldnn::memory::dims dims(ndim);
dtype = (dtype == -1) ? arr.dtype() : dtype;
for (size_t i = 0; i < dims.size(); i++) dims[i] = arr.shape()[i];
auto format = mkldnn::memory::format_tag::any;
- // for batch 256 alexnet benchmark test
+ const bool force_fc_ab_format = dmlc::GetEnv("MXNET_MKLDNN_FORCE_FC_AB_FORMAT", false);
if (dims.size() == 2) {
- format = mkldnn::memory::format_tag::ab;
+ if (force_fc_ab_format || !ChooseBRGEMMImpl(dims, batch_size)) {
+ format = mkldnn::memory::format_tag::ab;
+ }
}
return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), format};
diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
index 4205dd4..fef0d28 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -35,12 +35,14 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(
const MKLDNNFCFullParam &full_param, const bool is_train,
const NDArray &data, const NDArray &weight, const NDArray *bias,
const mkldnn::memory::desc &out_md) {
- auto data_md = GetMemDesc(data);
- auto weight_md = full_param.mkldnn_param.quantized ?
- GetFCWeightDesc(weight, mshadow::kInt8) : GetFCWeightDesc(weight);
auto engine = CpuEngine::Get()->get_engine();
- auto propagation =
- is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
+ auto data_md = GetMemDesc(data);
+ auto weight_md =
+ full_param.mkldnn_param.quantized
+ ? GetFCWeightDesc(weight, data.shape()[0], mshadow::kInt8)
+ : GetFCWeightDesc(weight, data.shape()[0]);
+ auto propagation = is_train ? mkldnn::prop_kind::forward_training
+ : mkldnn::prop_kind::forward_scoring;
mkldnn::primitive_attr attr;
mkldnn::post_ops ops;
@@ -91,7 +93,7 @@ inline static mkldnn::inner_product_backward_data::primitive_desc GetFCBwdData(
const NDArray &data, const NDArray &weight, const NDArray &output,
mkldnn::inner_product_forward::primitive_desc fwd_pd) {
auto data_md = GetMemDesc(data);
- auto weight_md = GetFCWeightDesc(weight);
+ auto weight_md = GetFCWeightDesc(weight, data.shape()[0]);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
mkldnn::inner_product_backward_data::desc desc(data_md, weight_md, out_md);
@@ -102,7 +104,7 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetFCBwdWei
const NDArray &data, const NDArray &weight, const NDArray *bias,
const NDArray &output, mkldnn::inner_product_forward::primitive_desc fwd_pd) {
auto data_md = GetMemDesc(data);
- auto weight_md = GetFCWeightDesc(weight);
+ auto weight_md = GetFCWeightDesc(weight, data.shape()[0]);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
if (bias) {