You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/11/24 17:55:55 UTC

[GitHub] anirudh2290 closed pull request #13324: Fix/env disable mkldnn cache map

anirudh2290 closed pull request #13324: Fix/env disable mkldnn cache map
URL: https://github.com/apache/incubator-mxnet/pull/13324
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/docs/faq/env_var.md b/docs/faq/env_var.md
index e373377ee8d..c7d3b284721 100644
--- a/docs/faq/env_var.md
+++ b/docs/faq/env_var.md
@@ -195,6 +195,10 @@ When USE_PROFILER is enabled in Makefile or CMake, the following environments ca
   - Values: 0, 1 ```(default=1)```
   - Flag to enable or disable MKLDNN accelerator. On by default.
   - Only applies to mxnet that has been compiled with MKLDNN (```pip install mxnet-mkl``` or built from source with ```USE_MKLDNN=1```)
+  
+* MXNET_MKLDNN_CACHE_NUM
+  - Values: Int ```(default=-1)```
+  - Maximum number of entries each MKLDNN primitive cache can hold. The default, -1, means the cache size is unbounded. Set this only if your model has variable input shapes, since otherwise the cache may grow without bound. The value counts cache entries; the number of entries needed grows with the number of layers that use MKLDNN and with the number of distinct input shapes.
 
 * MXNET_ENFORCE_DETERMINISM
   - Values: 0(false) or 1(true) ```(default=0)```
diff --git a/src/common/utils.h b/src/common/utils.h
index 26889792e53..92b7c209318 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -474,6 +474,10 @@ inline void LogStorageFallback(const nnvm::NodeAttrs& attrs,
 #if MXNET_USE_MKLDNN == 1
   if (!MKLDNNEnvSet()) common::LogOnce("MXNET_MKLDNN_ENABLED flag is off. "
                                        "You can re-enable by setting MXNET_MKLDNN_ENABLED=1");
+  if (GetMKLDNNCacheSize() != -1) common::LogOnce("MXNET_MKLDNN_CACHE_NUM is set."
+                                       "Should only be set if "
+                                       "your model has variable input shapes, "
+                                       "as cache size may grow unbounded");
 #endif
 }
 
diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc
index c914b38b542..440705884b3 100644
--- a/src/operator/nn/mkldnn/mkldnn_act.cc
+++ b/src/operator/nn/mkldnn/mkldnn_act.cc
@@ -147,10 +147,7 @@ static MKLDNNActForward &GetActForward(const ActivationParam& param,
   auto it = fwds.find(key);
   if (it == fwds.end()) {
     MKLDNNActForward fwd(param, ctx.is_train, in_data, in_mem);
-    auto ins_ret = fwds.insert(std::pair<MKLDNNActSignature, MKLDNNActForward>(
-            key, fwd));
-    CHECK(ins_ret.second);
-    it = ins_ret.first;
+    it = AddToCache(&fwds, key, fwd);
   }
   return it->second;
 }
@@ -261,10 +258,7 @@ static inline MKLDNNActBackward &GetActBackward(const ActivationParam &param,
   auto it = bwds.find(key);
   if (it == bwds.end()) {
     MKLDNNActBackward bwd(param, in_data, in_mem, *out_grad.GetMKLDNNData());
-    auto ins_ret =
-        bwds.insert(std::pair<MKLDNNActSignature, MKLDNNActBackward>(key, bwd));
-    CHECK(ins_ret.second);
-    it = ins_ret.first;
+    it = AddToCache(&bwds, key, bwd);
   }
   return it->second;
 }
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index d8651c83d0c..17e74094c2b 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -147,6 +147,23 @@ static inline bool MKLDNNEnvSet() {
   return is_mkldnn_enabled;
 }
 
+/*!
+ * \brief Upper bound on the number of entries in each MKLDNN primitive cache.
+ *
+ * The value of MXNET_MKLDNN_CACHE_NUM is read once, on the first call, and
+ * memoized in a function-local static. When the variable is unset the
+ * fallback is -1, which callers treat as "unbounded".
+ */
+static inline int GetMKLDNNCacheSize() {
+  static const int cached_limit = dmlc::GetEnv("MXNET_MKLDNN_CACHE_NUM", -1);
+  return cached_limit;
+}
+
+// TODO(alex): (MXNET-1075) Will remove env variable and calculate cache size during runtime
+/*!
+ * \brief Insert `item` under `key` into an MKLDNN primitive cache, first
+ *        evicting one existing entry if the cache is bounded and full.
+ * \param cache Cache to insert into. `key` must not already be present;
+ *        a duplicate insertion fails the CHECK below.
+ * \param key   Signature identifying the cached primitive.
+ * \param item  Primitive to store (copied into the map).
+ * \return Iterator to the newly inserted entry.
+ *
+ * The bound comes from GetMKLDNNCacheSize() (MXNET_MKLDNN_CACHE_NUM). Any
+ * negative value means the cache is unbounded. Eviction removes whatever
+ * element cache->begin() refers to; for std::unordered_map that is an
+ * unspecified entry, not the least recently used one.
+ *
+ * NOTE: erasing an entry invalidates references to it, so a reference
+ * obtained from `it->second` must not be held across a later AddToCache
+ * call on the same cache.
+ */
+template<typename S, typename I, typename H>
+static typename std::unordered_map<S, I, H>::iterator AddToCache(
+    std::unordered_map<S, I, H>* cache, const S &key, const I &item) {
+  const int mkldnn_cache_size = GetMKLDNNCacheSize();
+  // Evict before inserting once the size has reached the limit (>=, not >),
+  // so the cache never holds more than mkldnn_cache_size entries. The
+  // empty() guard avoids erase(begin()) == erase(end()) on an empty map,
+  // which is undefined behavior (reachable with a limit of 0).
+  if (mkldnn_cache_size >= 0 && !cache->empty() &&
+      cache->size() >= static_cast<size_t>(mkldnn_cache_size))
+    cache->erase(cache->begin());
+  auto ins_return = cache->insert(std::pair<S, I>(key, item));
+  CHECK(ins_return.second);
+  return ins_return.first;
+}
+
 /*
  * This is to align address to a certain alignment.
  */
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index e605c9bb19c..403baaa94ab 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -198,10 +198,7 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
     auto fwd_pd = _GetFwd(*in_data.GetMKLDNNData(), ctx.is_train,
                           (DType) param.eps, flags);
     MKLDNNBNForward fwd(fwd_pd, ctx.is_train);
-    auto ins_ret = fwds.insert(std::pair<MKLDNNBNSignature, MKLDNNBNForward>(
-            key, fwd));
-    CHECK(ins_ret.second);
-    it = ins_ret.first;
+    it = AddToCache(&fwds, key, fwd);
   }
   return it->second;
 }
@@ -360,10 +357,7 @@ static MKLDNNBNBackward &GetBNBackward(
   if (it == bwds.end()) {
     auto bwd_pd = _GetBwd(in_mem, diff_mem, param.eps, flags);
     MKLDNNBNBackward bwd(bwd_pd);
-    auto ins_ret =
-        bwds.insert(std::pair<MKLDNNBNSignature, MKLDNNBNBackward>(key, bwd));
-    CHECK(ins_ret.second);
-    it = ins_ret.first;
+    it = AddToCache(&bwds, key, bwd);
   }
   return it->second;
 }
diff --git a/src/operator/nn/mkldnn/mkldnn_concat.cc b/src/operator/nn/mkldnn/mkldnn_concat.cc
index af81e1fe3ee..03eeb61eccb 100644
--- a/src/operator/nn/mkldnn/mkldnn_concat.cc
+++ b/src/operator/nn/mkldnn/mkldnn_concat.cc
@@ -87,10 +87,7 @@ static MKLDNNConcatFwd &GetConcatForward(
   auto it = fwds.find(key);
   if (it == fwds.end()) {
     MKLDNNConcatFwd fwd(concat_dim, data_md);
-    auto ins_ret = fwds.insert(std::pair<OpSignature, MKLDNNConcatFwd>(
-            key, fwd));
-    CHECK(ins_ret.second);
-    it = ins_ret.first;
+    it = AddToCache(&fwds, key, fwd);
   }
   return it->second;
 }
diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc
index 985a9655b10..dd1f3ec07d7 100644
--- a/src/operator/nn/mkldnn/mkldnn_convolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -338,10 +338,7 @@ MKLDNNConvForward &GetConvFwd(const ConvolutionParam &param,
     full_param.conv_param = param;
     full_param.mkldnn_param.Init(std::unordered_map<std::string, std::string>());
     MKLDNNConvForward fwd(full_param, is_train, data, weights, bias, output);
-    auto ins_ret = fwds.insert(
-        std::pair<MKLDNNConvSignature, MKLDNNConvForward>(key, fwd));
-    CHECK(ins_ret.second);
-    it = ins_ret.first;
+    it = AddToCache(&fwds, key, fwd);
   }
   return it->second;
 }
@@ -566,13 +563,11 @@ static inline MKLDNNConvBackward &GetConvBwd(
   if (bias)
     key.AddSign(*bias);
 
+
   auto it = bwds.find(key);
   if (it == bwds.end()) {
     MKLDNNConvBackward bwd(param, data, weights, bias, output, fwd_pd);
-    auto ins_ret = bwds.insert(
-        std::pair<MKLDNNConvSignature, MKLDNNConvBackward>(key, bwd));
-    CHECK(ins_ret.second);
-    it = ins_ret.first;
+    it = AddToCache(&bwds, key, bwd);
   }
   return it->second;
 }
diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
index 577fae0d716..a6d6b24235c 100644
--- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
@@ -328,10 +328,7 @@ static inline MKLDNNDeconvForward &GetDeconvFwd(
   if (it == fwds.end()) {
     bool has_bias = (bias != nullptr);
     MKLDNNDeconvForward fwd(param, data, weights, has_bias, output);
-    auto ins_ret = fwds.insert(
-        std::pair<DeconvSignature, MKLDNNDeconvForward>(key, fwd));
-    CHECK(ins_ret.second);
-    it = ins_ret.first;
+    it = AddToCache(&fwds, key, fwd);
   }
   return it->second;
 }
@@ -425,10 +422,7 @@ static inline MKLDNNDeconvBackwardData &GetDeconvBwdData(
   auto it = bwds.find(key);
   if (it == bwds.end()) {
     MKLDNNDeconvBackwardData bwd(param, data, weights, output);
-    auto ins_ret = bwds.insert(
-        std::pair<MKLDNNDeconvSignature, MKLDNNDeconvBackwardData>(key, bwd));
-    CHECK(ins_ret.second);
-    it = ins_ret.first;
+    it = AddToCache(&bwds, key, bwd);
   }
   return it->second;
 }
diff --git a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
index bc386bedde1..31b293a14c2 100644
--- a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
@@ -189,10 +189,7 @@ static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param,
   auto it = lrn_fwds.find(key);
   if (it == lrn_fwds.end()) {
     MKLDNNLRNFwd fwd(param, ctx.is_train, in_data);
-    auto ins_ret = lrn_fwds.insert(std::pair<MKLDNNLRNSignature, MKLDNNLRNFwd>
-                                   (key, fwd));
-    CHECK(ins_ret.second);
-    it = ins_ret.first;
+    it = AddToCache(&lrn_fwds, key, fwd);
   }
   return it->second;
 }
@@ -284,10 +281,7 @@ static MKLDNNLRNBwd &GetLRNBwd(const LRNParam &param, const NDArray &in_data,
     const mkldnn::memory::desc diff_md =
         out_grad.GetMKLDNNData()->get_primitive_desc().desc();
     MKLDNNLRNBwd bwd(param, in_data_md, diff_md);
-    auto ins_ret =
-        lrn_bwds.insert(std::pair<MKLDNNLRNSignature, MKLDNNLRNBwd>(key, bwd));
-    CHECK(ins_ret.second);
-    it = ins_ret.first;
+    it = AddToCache(&lrn_bwds, key, bwd);
   }
   return it->second;
 }
diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc
index 18dc835c0d0..f4d681ded78 100644
--- a/src/operator/nn/mkldnn/mkldnn_pooling.cc
+++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc
@@ -261,10 +261,7 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam &param,
     const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
     MKLDNNPoolingFwd fwd(data, output, kernel_h_, kernel_w_, stride_h_, stride_w_,
                          pad_t_, pad_b_, pad_l_, pad_r_, alg, with_workspace, is_train);
-    auto ins_ret = pooling_fwds.insert(
-        std::pair<MKLDNNPoolingSignature, MKLDNNPoolingFwd>(key, fwd));
-    CHECK(ins_ret.second);
-    it = ins_ret.first;
+    it = AddToCache(&pooling_fwds, key, fwd);
   }
   return it->second;
 }
@@ -388,10 +385,7 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam &param,
         mkldnn::padding_kind::zero);
     const auto pdesc = pooling_backward::primitive_desc(desc, cpu_engine, fwd_pd);
     MKLDNNPoolingBwd bwd(pdesc, with_workspace);
-    auto ins_ret = pooling_bwds.insert(
-        std::pair<MKLDNNPoolingSignature, MKLDNNPoolingBwd>(key, bwd));
-    CHECK(ins_ret.second);
-    it = ins_ret.first;
+    it = AddToCache(&pooling_bwds, key, bwd);
   }
   return it->second;
 }


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services