You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2021/08/19 06:37:11 UTC

[GitHub] [incubator-mxnet] bgawrych commented on a change in pull request #20533: [v1.x] Enabling BRGEMM FullyConnected based on shapes

bgawrych commented on a change in pull request #20533:
URL: https://github.com/apache/incubator-mxnet/pull/20533#discussion_r691824518



##########
File path: src/operator/nn/mkldnn/mkldnn_base-inl.h
##########
@@ -305,17 +305,26 @@ inline static mkldnn::memory::desc GetMemDesc(const NDArray& arr, int dtype = -1
   return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), mkldnn::memory::format_tag::any};
 }
 
-inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray& arr, int dtype = -1) {
+inline static bool SupportBRGEMMImpl(mkldnn::memory::dims weight_dims, size_t batch_size) {
+  return weight_dims[0] % 64 == 0 && weight_dims[1] % 64 == 0 && weight_dims[0] >= 1024 &&
+         weight_dims[1] >= 1024 && batch_size >= 2 << 13;
+}
+
+inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray& arr,
+                                                   size_t batch_size,
+                                                   int dtype = -1) {
   int ndim = arr.shape().ndim();
   mkldnn::memory::dims dims(ndim);
   dtype = (dtype == -1) ? arr.dtype() : dtype;
   for (size_t i = 0; i < dims.size(); i++)
     dims[i] = arr.shape()[i];
   auto format = mkldnn::memory::format_tag::any;
   // for batch 256 alexnet benchmark test
-  const bool brgemm_disabled = dmlc::GetEnv("MXNET_MKLDNN_DISABLE_BRGEMM_FC", true);
-  if (dims.size() == 2 && brgemm_disabled) {
-    format = mkldnn::memory::format_tag::ab;
+  const bool force_fc_ab_format = dmlc::GetEnv("MXNET_MKLDNN_FORCE_FC_AB_FORMAT", false);

Review comment:
       done

##########
File path: src/operator/nn/mkldnn/mkldnn_fully_connected.cc
##########
@@ -42,9 +42,10 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(const MKLDNNFCFullPar
                                                            const NDArray* bias,
                                                            const mkldnn::memory::desc& out_md) {
   auto data_md   = GetMemDesc(data);
-  auto weight_md = full_param.mkldnn_param.quantized ? GetFCWeightDesc(weight, mshadow::kInt8)
-                                                     : GetFCWeightDesc(weight);
-  auto engine    = CpuEngine::Get()->get_engine();
+  auto weight_md = full_param.mkldnn_param.quantized
+                       ? GetFCWeightDesc(weight, data.shape()[0], mshadow::kInt8)
+                       : GetFCWeightDesc(weight, data.shape()[0]);
+  auto engine = CpuEngine::Get()->get_engine();

Review comment:
       This is how clang-format formats it. Should I ignore it and align manually?

##########
File path: src/operator/nn/mkldnn/mkldnn_base-inl.h
##########
@@ -305,17 +305,26 @@ inline static mkldnn::memory::desc GetMemDesc(const NDArray& arr, int dtype = -1
   return mkldnn::memory::desc{dims, get_mkldnn_type(dtype), mkldnn::memory::format_tag::any};
 }
 
-inline static mkldnn::memory::desc GetFCWeightDesc(const NDArray& arr, int dtype = -1) {
+inline static bool SupportBRGEMMImpl(mkldnn::memory::dims weight_dims, size_t batch_size) {

Review comment:
       Done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@mxnet.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org