You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by an...@apache.org on 2018/11/24 17:56:06 UTC

[incubator-mxnet] branch master updated: Fix/env disable mkldnn cache map (#13324)

This is an automated email from the ASF dual-hosted git repository.

anirudh2290 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 3d64d15  Fix/env disable mkldnn cache map (#13324)
3d64d15 is described below

commit 3d64d15e69ce6afba728a92b18753a868b6c3298
Author: Alexander Zai <az...@gmail.com>
AuthorDate: Sat Nov 24 09:55:53 2018 -0800

    Fix/env disable mkldnn cache map (#13324)
    
    * add flag to disable mkldnn cache
    
    * update docs
    
    * fix typos
    
    * update var name
    
    * fix ordering
    
    * set cache size
    
    * fix log message
    
    * update docs
    
    * fix lint
    
    * fix lint
    
    * fix comparison
    
    * update method name
    
    * fix missing
    
    * fix logging
    
    * remove random item when cache exceeded
    
    * update helper name
    
    * update hash namespace
    
    * make ophash template
    
    * update function params
    
    * fix return
    
    * fix return
    
    * update return for helper
    
    * change class to typename
    
    * add typename
    
    * fix lint
    
    * update doc
    
    * pass ptr to cache
    
    * retrigger
    
    * retrigger
    
    * retrigger
    
    * change env var name to MXNET_MKLDNN_CACHE_NUM
    
    * fix log env name
    
    * retrigger
---
 docs/faq/env_var.md                            |  4 ++++
 src/common/utils.h                             |  4 ++++
 src/operator/nn/mkldnn/mkldnn_act.cc           | 10 ++--------
 src/operator/nn/mkldnn/mkldnn_base-inl.h       | 17 +++++++++++++++++
 src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h | 10 ++--------
 src/operator/nn/mkldnn/mkldnn_concat.cc        |  5 +----
 src/operator/nn/mkldnn/mkldnn_convolution.cc   | 11 +++--------
 src/operator/nn/mkldnn/mkldnn_deconvolution.cc | 10 ++--------
 src/operator/nn/mkldnn/mkldnn_lrn-inl.h        | 10 ++--------
 src/operator/nn/mkldnn/mkldnn_pooling.cc       | 10 ++--------
 10 files changed, 39 insertions(+), 52 deletions(-)

diff --git a/docs/faq/env_var.md b/docs/faq/env_var.md
index e373377..c7d3b28 100644
--- a/docs/faq/env_var.md
+++ b/docs/faq/env_var.md
@@ -195,6 +195,10 @@ When USE_PROFILER is enabled in Makefile or CMake, the following environments ca
   - Values: 0, 1 ```(default=1)```
   - Flag to enable or disable MKLDNN accelerator. On by default.
   - Only applies to mxnet that has been compiled with MKLDNN (```pip install mxnet-mkl``` or built from source with ```USE_MKLDNN=1```)
+  
+* MXNET_MKLDNN_CACHE_NUM
+  - Values: Int ```(default=-1)```
+  - Flag to set the number of elements that the MKLDNN cache can hold. The default is -1, which means the cache size is unbounded. It should only be set if your model has variable input shapes, since the cache may otherwise grow without bound. The number represents the number of items in the cache and is proportional to the number of layers that use MKLDNN and the number of different input shapes.
 
 * MXNET_ENFORCE_DETERMINISM
   - Values: 0(false) or 1(true) ```(default=0)```
diff --git a/src/common/utils.h b/src/common/utils.h
index 2688979..92b7c20 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -474,6 +474,10 @@ inline void LogStorageFallback(const nnvm::NodeAttrs& attrs,
 #if MXNET_USE_MKLDNN == 1
   if (!MKLDNNEnvSet()) common::LogOnce("MXNET_MKLDNN_ENABLED flag is off. "
                                        "You can re-enable by setting MXNET_MKLDNN_ENABLED=1");
+  if (GetMKLDNNCacheSize() != -1) common::LogOnce("MXNET_MKLDNN_CACHE_NUM is set."
+                                       "Should only be set if "
+                                       "your model has variable input shapes, "
+                                       "as cache size may grow unbounded");
 #endif
 }
 
diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc
index c914b38..4407058 100644
--- a/src/operator/nn/mkldnn/mkldnn_act.cc
+++ b/src/operator/nn/mkldnn/mkldnn_act.cc
@@ -147,10 +147,7 @@ static MKLDNNActForward &GetActForward(const ActivationParam& param,
   auto it = fwds.find(key);
   if (it == fwds.end()) {
     MKLDNNActForward fwd(param, ctx.is_train, in_data, in_mem);
-    auto ins_ret = fwds.insert(std::pair<MKLDNNActSignature, MKLDNNActForward>(
-            key, fwd));
-    CHECK(ins_ret.second);
-    it = ins_ret.first;
+    it = AddToCache(&fwds, key, fwd);
   }
   return it->second;
 }
@@ -261,10 +258,7 @@ static inline MKLDNNActBackward &GetActBackward(const ActivationParam &param,
   auto it = bwds.find(key);
   if (it == bwds.end()) {
     MKLDNNActBackward bwd(param, in_data, in_mem, *out_grad.GetMKLDNNData());
-    auto ins_ret =
-        bwds.insert(std::pair<MKLDNNActSignature, MKLDNNActBackward>(key, bwd));
-    CHECK(ins_ret.second);
-    it = ins_ret.first;
+    it = AddToCache(&bwds, key, bwd);
   }
   return it->second;
 }
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index d8651c8..17e7409 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -147,6 +147,23 @@ static inline bool MKLDNNEnvSet() {
   return is_mkldnn_enabled;
 }
 
+// Returns the value of the MXNET_MKLDNN_CACHE_NUM environment variable,
+// falling back to -1 when dmlc::GetEnv supplies the default. The value is
+// stored in a function-local static, so it is read on the first call only.
+static inline int GetMKLDNNCacheSize() {
+  static const int cache_num = dmlc::GetEnv("MXNET_MKLDNN_CACHE_NUM", -1);
+  return cache_num;
+}
+
+// TODO(alex): (MXNET-1075) Will remove env variable and calculate cache size during runtime
+//
+// Inserts (key, item) into *cache and returns an iterator to the new entry.
+//
+// When GetMKLDNNCacheSize() is not -1 and the cache already holds at least
+// that many entries, one existing entry is erased first so the cache does
+// not grow past the configured limit. The erased entry is whichever one
+// begin() yields, i.e. an arbitrary one for an unordered_map; iterators and
+// references to that erased entry become invalid.
+//
+// \param cache map to insert into; must not be null.
+// \param key   key of the new entry; must not already be present in *cache
+//              (the CHECK below fails otherwise).
+// \param item  value copied into the cache.
+// \return iterator pointing at the newly inserted entry.
+template<typename S, typename I, typename H>
+static typename std::unordered_map<S, I, H>::iterator AddToCache(
+    std::unordered_map<S, I, H>* cache, const S &key, const I &item) {
+  const int mkldnn_cache_size = GetMKLDNNCacheSize();
+  // Evict before inserting. `>=` (rather than `>`) keeps the entry count at
+  // the limit instead of limit + 1. The empty() guard avoids erasing end()
+  // when the limit is 0 or a negative value other than -1; in that case the
+  // cache still keeps the single entry being returned to the caller.
+  if (mkldnn_cache_size != -1 && !cache->empty() &&
+      static_cast<int>(cache->size()) >= mkldnn_cache_size) {
+    cache->erase(cache->begin());
+  }
+  auto ins_return = cache->insert(std::pair<S, I>(key, item));
+  CHECK(ins_return.second);
+  return ins_return.first;
+}
+
 /*
  * This is to align address to a certain alignment.
  */
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index e605c9b..403baaa 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -198,10 +198,7 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
     auto fwd_pd = _GetFwd(*in_data.GetMKLDNNData(), ctx.is_train,
                           (DType) param.eps, flags);
     MKLDNNBNForward fwd(fwd_pd, ctx.is_train);
-    auto ins_ret = fwds.insert(std::pair<MKLDNNBNSignature, MKLDNNBNForward>(
-            key, fwd));
-    CHECK(ins_ret.second);
-    it = ins_ret.first;
+    it = AddToCache(&fwds, key, fwd);
   }
   return it->second;
 }
@@ -360,10 +357,7 @@ static MKLDNNBNBackward &GetBNBackward(
   if (it == bwds.end()) {
     auto bwd_pd = _GetBwd(in_mem, diff_mem, param.eps, flags);
     MKLDNNBNBackward bwd(bwd_pd);
-    auto ins_ret =
-        bwds.insert(std::pair<MKLDNNBNSignature, MKLDNNBNBackward>(key, bwd));
-    CHECK(ins_ret.second);
-    it = ins_ret.first;
+    it = AddToCache(&bwds, key, bwd);
   }
   return it->second;
 }
diff --git a/src/operator/nn/mkldnn/mkldnn_concat.cc b/src/operator/nn/mkldnn/mkldnn_concat.cc
index af81e1f..03eeb61 100644
--- a/src/operator/nn/mkldnn/mkldnn_concat.cc
+++ b/src/operator/nn/mkldnn/mkldnn_concat.cc
@@ -87,10 +87,7 @@ static MKLDNNConcatFwd &GetConcatForward(
   auto it = fwds.find(key);
   if (it == fwds.end()) {
     MKLDNNConcatFwd fwd(concat_dim, data_md);
-    auto ins_ret = fwds.insert(std::pair<OpSignature, MKLDNNConcatFwd>(
-            key, fwd));
-    CHECK(ins_ret.second);
-    it = ins_ret.first;
+    it = AddToCache(&fwds, key, fwd);
   }
   return it->second;
 }
diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc
index 985a965..dd1f3ec 100644
--- a/src/operator/nn/mkldnn/mkldnn_convolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -338,10 +338,7 @@ MKLDNNConvForward &GetConvFwd(const ConvolutionParam &param,
     full_param.conv_param = param;
     full_param.mkldnn_param.Init(std::unordered_map<std::string, std::string>());
     MKLDNNConvForward fwd(full_param, is_train, data, weights, bias, output);
-    auto ins_ret = fwds.insert(
-        std::pair<MKLDNNConvSignature, MKLDNNConvForward>(key, fwd));
-    CHECK(ins_ret.second);
-    it = ins_ret.first;
+    it = AddToCache(&fwds, key, fwd);
   }
   return it->second;
 }
@@ -566,13 +563,11 @@ static inline MKLDNNConvBackward &GetConvBwd(
   if (bias)
     key.AddSign(*bias);
 
+
   auto it = bwds.find(key);
   if (it == bwds.end()) {
     MKLDNNConvBackward bwd(param, data, weights, bias, output, fwd_pd);
-    auto ins_ret = bwds.insert(
-        std::pair<MKLDNNConvSignature, MKLDNNConvBackward>(key, bwd));
-    CHECK(ins_ret.second);
-    it = ins_ret.first;
+    it = AddToCache(&bwds, key, bwd);
   }
   return it->second;
 }
diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
index 577fae0..a6d6b24 100644
--- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
@@ -328,10 +328,7 @@ static inline MKLDNNDeconvForward &GetDeconvFwd(
   if (it == fwds.end()) {
     bool has_bias = (bias != nullptr);
     MKLDNNDeconvForward fwd(param, data, weights, has_bias, output);
-    auto ins_ret = fwds.insert(
-        std::pair<DeconvSignature, MKLDNNDeconvForward>(key, fwd));
-    CHECK(ins_ret.second);
-    it = ins_ret.first;
+    it = AddToCache(&fwds, key, fwd);
   }
   return it->second;
 }
@@ -425,10 +422,7 @@ static inline MKLDNNDeconvBackwardData &GetDeconvBwdData(
   auto it = bwds.find(key);
   if (it == bwds.end()) {
     MKLDNNDeconvBackwardData bwd(param, data, weights, output);
-    auto ins_ret = bwds.insert(
-        std::pair<MKLDNNDeconvSignature, MKLDNNDeconvBackwardData>(key, bwd));
-    CHECK(ins_ret.second);
-    it = ins_ret.first;
+    it = AddToCache(&bwds, key, bwd);
   }
   return it->second;
 }
diff --git a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
index bc386be..31b293a 100644
--- a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
@@ -189,10 +189,7 @@ static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param,
   auto it = lrn_fwds.find(key);
   if (it == lrn_fwds.end()) {
     MKLDNNLRNFwd fwd(param, ctx.is_train, in_data);
-    auto ins_ret = lrn_fwds.insert(std::pair<MKLDNNLRNSignature, MKLDNNLRNFwd>
-                                   (key, fwd));
-    CHECK(ins_ret.second);
-    it = ins_ret.first;
+    it = AddToCache(&lrn_fwds, key, fwd);
   }
   return it->second;
 }
@@ -284,10 +281,7 @@ static MKLDNNLRNBwd &GetLRNBwd(const LRNParam &param, const NDArray &in_data,
     const mkldnn::memory::desc diff_md =
         out_grad.GetMKLDNNData()->get_primitive_desc().desc();
     MKLDNNLRNBwd bwd(param, in_data_md, diff_md);
-    auto ins_ret =
-        lrn_bwds.insert(std::pair<MKLDNNLRNSignature, MKLDNNLRNBwd>(key, bwd));
-    CHECK(ins_ret.second);
-    it = ins_ret.first;
+    it = AddToCache(&lrn_bwds, key, bwd);
   }
   return it->second;
 }
diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc
index 18dc835..f4d681d 100644
--- a/src/operator/nn/mkldnn/mkldnn_pooling.cc
+++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc
@@ -261,10 +261,7 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam &param,
     const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
     MKLDNNPoolingFwd fwd(data, output, kernel_h_, kernel_w_, stride_h_, stride_w_,
                          pad_t_, pad_b_, pad_l_, pad_r_, alg, with_workspace, is_train);
-    auto ins_ret = pooling_fwds.insert(
-        std::pair<MKLDNNPoolingSignature, MKLDNNPoolingFwd>(key, fwd));
-    CHECK(ins_ret.second);
-    it = ins_ret.first;
+    it = AddToCache(&pooling_fwds, key, fwd);
   }
   return it->second;
 }
@@ -388,10 +385,7 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam &param,
         mkldnn::padding_kind::zero);
     const auto pdesc = pooling_backward::primitive_desc(desc, cpu_engine, fwd_pd);
     MKLDNNPoolingBwd bwd(pdesc, with_workspace);
-    auto ins_ret = pooling_bwds.insert(
-        std::pair<MKLDNNPoolingSignature, MKLDNNPoolingBwd>(key, bwd));
-    CHECK(ins_ret.second);
-    it = ins_ret.first;
+    it = AddToCache(&pooling_bwds, key, bwd);
   }
   return it->second;
 }