You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/11/22 19:50:46 UTC

[GitHub] azai91 closed pull request #12746: Add env flag to disable MKLDNN cache (MXNET_MKLDNN_CACHE_ENABLED)

azai91 closed pull request #12746: Add env flag to disable MKLDNN cache (MXNET_MKLDNN_CACHE_ENABLED)
URL: https://github.com/apache/incubator-mxnet/pull/12746
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/docs/faq/env_var.md b/docs/faq/env_var.md
index e373377ee8d..2f8489b2ab9 100644
--- a/docs/faq/env_var.md
+++ b/docs/faq/env_var.md
@@ -195,6 +195,10 @@ When USE_PROFILER is enabled in Makefile or CMake, the following environments ca
   - Values: 0, 1 ```(default=1)```
   - Flag to enable or disable MKLDNN accelerator. On by default.
   - Only applies to mxnet that has been compiled with MKLDNN (```pip install mxnet-mkl``` or built from source with ```USE_MKLDNN=1```)
+
+* MXNET_MKLDNN_CACHE_SIZE
+  - Values: Int ```(default=-1)```
+  - Maximum number of entries kept in each MKLDNN operator cache. Default is -1, which means the cache size is unbounded. When the limit is reached, an existing entry is evicted before a new one is added. Set this only if your model has variable input shapes, since the cache may otherwise grow without bound.
 
 * MXNET_ENFORCE_DETERMINISM
   - Values: 0(false) or 1(true) ```(default=0)```
diff --git a/src/common/utils.h b/src/common/utils.h
index 26889792e53..87b9bedb58a 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -474,6 +474,10 @@ inline void LogStorageFallback(const nnvm::NodeAttrs& attrs,
 #if MXNET_USE_MKLDNN == 1
   if (!MKLDNNEnvSet()) common::LogOnce("MXNET_MKLDNN_ENABLED flag is off. "
                                        "You can re-enable by setting MXNET_MKLDNN_ENABLED=1");
+  if (GetMKLDNNCacheSize() != -1) common::LogOnce("MXNET_MKLDNN_CACHE_SIZE is set. "
+                                                  "Should only be set if "
+                                                  "your model has variable input shapes, "
+                                                  "as cache size may grow unbounded");
 #endif
 }
 
diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc
index c914b38b542..473423658b3 100644
--- a/src/operator/nn/mkldnn/mkldnn_act.cc
+++ b/src/operator/nn/mkldnn/mkldnn_act.cc
@@ -147,8 +147,7 @@ static MKLDNNActForward &GetActForward(const ActivationParam& param,
   auto it = fwds.find(key);
   if (it == fwds.end()) {
     MKLDNNActForward fwd(param, ctx.is_train, in_data, in_mem);
-    auto ins_ret = fwds.insert(std::pair<MKLDNNActSignature, MKLDNNActForward>(
-            key, fwd));
+    auto ins_ret = AddToCache(fwds, std::pair<MKLDNNActSignature, MKLDNNActForward>(key, fwd));
     CHECK(ins_ret.second);
     it = ins_ret.first;
   }
@@ -261,8 +260,7 @@ static inline MKLDNNActBackward &GetActBackward(const ActivationParam &param,
   auto it = bwds.find(key);
   if (it == bwds.end()) {
     MKLDNNActBackward bwd(param, in_data, in_mem, *out_grad.GetMKLDNNData());
-    auto ins_ret =
-        bwds.insert(std::pair<MKLDNNActSignature, MKLDNNActBackward>(key, bwd));
+    auto ins_ret = AddToCache(bwds, std::pair<MKLDNNActSignature, MKLDNNActBackward>(key, bwd));
     CHECK(ins_ret.second);
     it = ins_ret.first;
   }
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index d8651c83d0c..ca0122a118b 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -147,6 +147,29 @@ static inline bool MKLDNNEnvSet() {
   return is_mkldnn_enabled;
 }
 
+static inline int GetMKLDNNCacheSize() {
+  static int mkldnn_cache_size = dmlc::GetEnv("MXNET_MKLDNN_CACHE_SIZE", -1);
+  return mkldnn_cache_size;
+}
+
+// TODO(alex): (MXNET-1075) Will remove env variable and calculate cache size during runtime
+/*
+ * Inserts `item` into an MKLDNN operator cache and returns the result of
+ * unordered_map::insert. If MXNET_MKLDNN_CACHE_SIZE is non-negative and the cache
+ * already holds that many entries, cache.begin() (an arbitrary entry) is erased
+ * first, which invalidates references to it. Negative (default -1) = unbounded.
+ */
+template<class S, class I, class H>
+static std::pair<typename std::unordered_map<S, I, H>::iterator, bool> AddToCache(
+    std::unordered_map<S, I, H> &cache, const std::pair<S, I> &item) {  // NOLINT(*)
+  const int mkldnn_cache_size = GetMKLDNNCacheSize();
+  // !empty() guards erase(begin()) when the limit is 0; the new entry is always kept.
+  if (mkldnn_cache_size >= 0 && !cache.empty() &&
+      static_cast<int>(cache.size()) >= mkldnn_cache_size)
+    cache.erase(cache.begin());
+  return cache.insert(item);
+}
+
 /*
  * This is to align address to a certain alignment.
  */
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index e605c9bb19c..9556e332bc8 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -198,8 +198,7 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
     auto fwd_pd = _GetFwd(*in_data.GetMKLDNNData(), ctx.is_train,
                           (DType) param.eps, flags);
     MKLDNNBNForward fwd(fwd_pd, ctx.is_train);
-    auto ins_ret = fwds.insert(std::pair<MKLDNNBNSignature, MKLDNNBNForward>(
-            key, fwd));
+    auto ins_ret = AddToCache(fwds, std::pair<MKLDNNBNSignature, MKLDNNBNForward>(key, fwd));
     CHECK(ins_ret.second);
     it = ins_ret.first;
   }
@@ -360,8 +359,7 @@ static MKLDNNBNBackward &GetBNBackward(
   if (it == bwds.end()) {
     auto bwd_pd = _GetBwd(in_mem, diff_mem, param.eps, flags);
     MKLDNNBNBackward bwd(bwd_pd);
-    auto ins_ret =
-        bwds.insert(std::pair<MKLDNNBNSignature, MKLDNNBNBackward>(key, bwd));
+    auto ins_ret = AddToCache(bwds, std::pair<MKLDNNBNSignature, MKLDNNBNBackward>(key, bwd));
     CHECK(ins_ret.second);
     it = ins_ret.first;
   }
diff --git a/src/operator/nn/mkldnn/mkldnn_concat.cc b/src/operator/nn/mkldnn/mkldnn_concat.cc
index af81e1fe3ee..cf434c0717e 100644
--- a/src/operator/nn/mkldnn/mkldnn_concat.cc
+++ b/src/operator/nn/mkldnn/mkldnn_concat.cc
@@ -87,8 +87,7 @@ static MKLDNNConcatFwd &GetConcatForward(
   auto it = fwds.find(key);
   if (it == fwds.end()) {
     MKLDNNConcatFwd fwd(concat_dim, data_md);
-    auto ins_ret = fwds.insert(std::pair<OpSignature, MKLDNNConcatFwd>(
-            key, fwd));
+    auto ins_ret = AddToCache(fwds, std::pair<OpSignature, MKLDNNConcatFwd>(key, fwd));
     CHECK(ins_ret.second);
     it = ins_ret.first;
   }
diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc
index 9cf1b71880a..fff6db41cc1 100644
--- a/src/operator/nn/mkldnn/mkldnn_convolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -338,8 +338,7 @@ MKLDNNConvForward &GetConvFwd(const ConvolutionParam &param,
     full_param.conv_param = param;
     full_param.mkldnn_param.Init(std::unordered_map<std::string, std::string>());
     MKLDNNConvForward fwd(full_param, is_train, data, weights, bias, output);
-    auto ins_ret = fwds.insert(
-        std::pair<MKLDNNConvSignature, MKLDNNConvForward>(key, fwd));
+    auto ins_ret = AddToCache(fwds, std::pair<MKLDNNConvSignature, MKLDNNConvForward>(key, fwd));
     CHECK(ins_ret.second);
     it = ins_ret.first;
   }
@@ -557,11 +556,10 @@ static inline MKLDNNConvBackward &GetConvBwd(
   if (bias)
     key.AddSign(*bias);
 
   auto it = bwds.find(key);
   if (it == bwds.end()) {
     MKLDNNConvBackward bwd(param, data, weights, bias, output, fwd_pd);
-    auto ins_ret = bwds.insert(
-        std::pair<MKLDNNConvSignature, MKLDNNConvBackward>(key, bwd));
+    auto ins_ret = AddToCache(bwds, std::pair<MKLDNNConvSignature, MKLDNNConvBackward>(key, bwd));
     CHECK(ins_ret.second);
     it = ins_ret.first;
   }
diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
index 93032f7c92d..cd348dfd219 100644
--- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
@@ -328,9 +328,8 @@ static inline MKLDNNDeconvForward &GetDeconvFwd(
   auto it = fwds.find(key);
   if (it == fwds.end()) {
     bool has_bias = (bias != nullptr);
     MKLDNNDeconvForward fwd(param, data, weights, has_bias, output);
-    auto ins_ret = fwds.insert(
-        std::pair<DeconvSignature, MKLDNNDeconvForward>(key, fwd));
+    auto ins_ret = AddToCache(fwds, std::pair<DeconvSignature, MKLDNNDeconvForward>(key, fwd));
     CHECK(ins_ret.second);
     it = ins_ret.first;
   }
@@ -417,8 +416,8 @@ static inline MKLDNNDeconvBackwardData &GetDeconvBwdData(
   auto it = bwds.find(key);
   if (it == bwds.end()) {
     MKLDNNDeconvBackwardData bwd(param, data, weights, output);
-    auto ins_ret = bwds.insert(
-        std::pair<MKLDNNDeconvSignature, MKLDNNDeconvBackwardData>(key, bwd));
+    auto ins_ret = AddToCache(
+        bwds, std::pair<MKLDNNDeconvSignature, MKLDNNDeconvBackwardData>(key, bwd));
     CHECK(ins_ret.second);
     it = ins_ret.first;
   }
diff --git a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
index bc386bedde1..a39060f5750 100644
--- a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
@@ -189,7 +189,6 @@ static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param,
   auto it = lrn_fwds.find(key);
   if (it == lrn_fwds.end()) {
     MKLDNNLRNFwd fwd(param, ctx.is_train, in_data);
-    auto ins_ret = lrn_fwds.insert(std::pair<MKLDNNLRNSignature, MKLDNNLRNFwd>
-                                   (key, fwd));
+    auto ins_ret = AddToCache(lrn_fwds, std::pair<MKLDNNLRNSignature, MKLDNNLRNFwd>(key, fwd));
     CHECK(ins_ret.second);
     it = ins_ret.first;
@@ -284,8 +283,7 @@ static MKLDNNLRNBwd &GetLRNBwd(const LRNParam &param, const NDArray &in_data,
     const mkldnn::memory::desc diff_md =
         out_grad.GetMKLDNNData()->get_primitive_desc().desc();
     MKLDNNLRNBwd bwd(param, in_data_md, diff_md);
-    auto ins_ret =
-        lrn_bwds.insert(std::pair<MKLDNNLRNSignature, MKLDNNLRNBwd>(key, bwd));
+    auto ins_ret = AddToCache(lrn_bwds, std::pair<MKLDNNLRNSignature, MKLDNNLRNBwd>(key, bwd));
     CHECK(ins_ret.second);
     it = ins_ret.first;
   }
diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc
index 1610944304e..9c70dfc58e0 100644
--- a/src/operator/nn/mkldnn/mkldnn_pooling.cc
+++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc
@@ -243,8 +243,8 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam &param,
     const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
     MKLDNNPoolingFwd fwd(data, output, kernel_h_, kernel_w_, stride_h_, stride_w_,
                          pad_t_, pad_b_, pad_l_, pad_r_, alg, with_workspace, is_train);
-    auto ins_ret = pooling_fwds.insert(
-        std::pair<MKLDNNPoolingSignature, MKLDNNPoolingFwd>(key, fwd));
+    auto ins_ret = AddToCache(
+        pooling_fwds, std::pair<MKLDNNPoolingSignature, MKLDNNPoolingFwd>(key, fwd));
     CHECK(ins_ret.second);
     it = ins_ret.first;
   }
@@ -364,8 +364,8 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam &param,
         mkldnn::padding_kind::zero);
     const auto pdesc = pooling_backward::primitive_desc(desc, cpu_engine, fwd_pd);
     MKLDNNPoolingBwd bwd(pdesc, with_workspace);
-    auto ins_ret = pooling_bwds.insert(
-        std::pair<MKLDNNPoolingSignature, MKLDNNPoolingBwd>(key, bwd));
+    auto ins_ret = AddToCache(
+        pooling_bwds, std::pair<MKLDNNPoolingSignature, MKLDNNPoolingBwd>(key, bwd));
     CHECK(ins_ret.second);
     it = ins_ret.first;
   }


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services