You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/10/22 20:01:16 UTC

[GitHub] KellenSunderland closed pull request #12824: WIP Test batchnorm conv.

KellenSunderland closed pull request #12824: WIP Test batchnorm conv.
URL: https://github.com/apache/incubator-mxnet/pull/12824
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/include/mxnet/base.h b/include/mxnet/base.h
index dfe18996aec..86114323f14 100644
--- a/include/mxnet/base.h
+++ b/include/mxnet/base.h
@@ -225,11 +225,11 @@ struct Context {
   /*!
    * \brief get the free and total available memory on a GPU
    * \param dev the GPU number to query
-   * \param free_mem pointer to the integer holding free GPU memory
-   * \param total_mem pointer to the integer holding total GPU memory
+   * \param free_mem pointer to the uint64_t holding free GPU memory
+   * \param total_mem pointer to the uint64_t holding total GPU memory
    * \return No return value
    */
-  inline static void GetGPUMemoryInformation(int dev, int *free, int *total);
+  inline static void GetGPUMemoryInformation(int dev, uint64_t *free, uint64_t *total);
   /*!
    * Create a pinned CPU context.
    * \param dev_id the device id for corresponding GPU.
@@ -334,8 +334,8 @@ inline int32_t Context::GetGPUCount() {
 #endif
 }
 
-inline void Context::GetGPUMemoryInformation(int dev, int *free_mem,
-                                             int *total_mem) {
+inline void Context::GetGPUMemoryInformation(int dev, uint64_t *free_mem,
+                                             uint64_t *total_mem) {
 #if MXNET_USE_CUDA
 
   size_t memF, memT;
@@ -354,8 +354,8 @@ inline void Context::GetGPUMemoryInformation(int dev, int *free_mem,
   e = cudaSetDevice(curDevice);
   CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e);
 
-  *free_mem = static_cast<int>(memF);
-  *total_mem = static_cast<int>(memT);
+  *free_mem = static_cast<uint64_t>(memF);
+  *total_mem = static_cast<uint64_t>(memT);
 
 #else
   LOG(FATAL)
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index dc33c95437f..1e6f73bc80d 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -442,11 +442,11 @@ MXNET_DLL int MXGetGPUCount(int* out);
 /*!
  * \brief get the free and total available memory on a GPU
  * \param dev the GPU number to query
- * \param free_mem pointer to the integer holding free GPU memory
- * \param total_mem pointer to the integer holding total GPU memory
+ * \param free_mem pointer to the uint64_t holding free GPU memory
+ * \param total_mem pointer to the uint64_t holding total GPU memory
  * \return 0 when success, -1 when failure happens
  */
-MXNET_DLL int MXGetGPUMemoryInformation(int dev, int *free_mem, int *total_mem);
+MXNET_DLL int MXGetGPUMemoryInformation(int dev, uint64_t *free_mem, uint64_t *total_mem);
 
 /*!
  * \brief get the MXNet library version as an integer
diff --git a/perl-package/AI-MXNetCAPI/mxnet.i b/perl-package/AI-MXNetCAPI/mxnet.i
index 38665748a0b..3176a86959a 100644
--- a/perl-package/AI-MXNetCAPI/mxnet.i
+++ b/perl-package/AI-MXNetCAPI/mxnet.i
@@ -349,7 +349,7 @@ int MXGetGPUCount(int* out);
  * \param total_mem pointer to the integer holding total GPU memory
  * \return 0 when success, -1 when failure happens
  */
-int MXGetGPUMemoryInformation(int dev, int *out, int *out);
+int MXGetGPUMemoryInformation(int dev, uint64_t *out, uint64_t *out);
 
 
 //-------------------------------------
diff --git a/python/mxnet/context.py b/python/mxnet/context.py
index 61b70532dd7..176f11d5c37 100644
--- a/python/mxnet/context.py
+++ b/python/mxnet/context.py
@@ -258,6 +258,30 @@ def num_gpus():
     check_call(_LIB.MXGetGPUCount(ctypes.byref(count)))
     return count.value
 
+def gpu_memory_info(device_id=0):
+    """Query CUDA for the free and total bytes of GPU global memory.
+
+    Parameters
+    ----------
+    device_id : int, optional
+        The device id of the GPU device.
+
+    Raises
+    ------
+    Will raise an exception on any CUDA error.
+
+    Returns
+    -------
+    (free, total) : (int, int)
+        The free and total bytes of GPU global memory on the device.
+
+    """
+    free = ctypes.c_uint64()
+    total = ctypes.c_uint64()
+    dev_id = ctypes.c_int(device_id)
+    check_call(_LIB.MXGetGPUMemoryInformation(dev_id, ctypes.byref(free), ctypes.byref(total)))
+    return (free.value, total.value)
+
 def current_context():
     """Returns the current context.
 
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 56e318097a3..72ec341b5a4 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -122,7 +122,7 @@ int MXGetGPUCount(int* out) {
   API_END();
 }
 
-int MXGetGPUMemoryInformation(int dev, int *free_mem, int *total_mem) {
+int MXGetGPUMemoryInformation(int dev, uint64_t *free_mem, uint64_t *total_mem) {
   API_BEGIN();
   Context::GetGPUMemoryInformation(dev, free_mem, total_mem);
   API_END();
diff --git a/src/operator/nn/cudnn/cudnn_algoreg-inl.h b/src/operator/nn/cudnn/cudnn_algoreg-inl.h
index 3b59fd1c3ce..ce4af6a089d 100644
--- a/src/operator/nn/cudnn/cudnn_algoreg-inl.h
+++ b/src/operator/nn/cudnn/cudnn_algoreg-inl.h
@@ -30,6 +30,8 @@
 #include <mutex>
 #include <string>
 #include <vector>
+#include <functional>
+#include <utility>
 #include "../../../common/cuda_utils.h"
 #include "../convolution-inl.h"
 #include "../deconvolution-inl.h"
@@ -65,7 +67,11 @@ class CuDNNAlgo {
 template<typename ParamType>
 class CuDNNAlgoReg {
  public:
-  bool Find(const ParamType &param,
+  using AlgoSetter_t = std::function<void(CuDNNAlgo<cudnnConvolutionFwdAlgo_t> *,
+                            CuDNNAlgo<cudnnConvolutionBwdDataAlgo_t> *,
+                            CuDNNAlgo<cudnnConvolutionBwdFilterAlgo_t> *)>;
+
+  void FindOrElseRegister(const ParamType &param,
             const std::vector<TShape> &in_shape,
             const std::vector<TShape> &out_shape,
             cudnnDataType_t cudnn_data_type,
@@ -75,7 +81,7 @@ class CuDNNAlgoReg {
             bool add_to_weight,
             CuDNNAlgo<cudnnConvolutionFwdAlgo_t> *fwd,
             CuDNNAlgo<cudnnConvolutionBwdDataAlgo_t> *bwd,
-            CuDNNAlgo<cudnnConvolutionBwdFilterAlgo_t> *flt) {
+            CuDNNAlgo<cudnnConvolutionBwdFilterAlgo_t> *flt, AlgoSetter_t algo_setter) {
     CHECK(in_shape.size() == 2 || in_shape.size() == 3);
     ParamKey key{param, in_shape[0], in_shape[1], out_shape[0], cudnn_data_type,
                  cudnn_forward_compute_type, cudnn_backward_compute_type, sm_arch, add_to_weight};
@@ -85,45 +91,28 @@ class CuDNNAlgoReg {
       *fwd = i->second.fwd;
       *bwd = i->second.bwd;
       *flt = i->second.flt;
-      return true;
-    }
-    return false;
-  }
-
-  void Register(const ParamType &param,
-                const std::vector<TShape> &in_shape,
-                const std::vector<TShape> &out_shape,
-                cudnnDataType_t cudnn_data_type,
-                cudnnDataType_t cudnn_forward_compute_type,
-                cudnnDataType_t cudnn_backward_compute_type,
-                int sm_arch,
-                bool add_to_weight,
-                const CuDNNAlgo<cudnnConvolutionFwdAlgo_t> &fwd,
-                const CuDNNAlgo<cudnnConvolutionBwdDataAlgo_t> &bwd,
-                const CuDNNAlgo<cudnnConvolutionBwdFilterAlgo_t> &flt) {
-    CHECK(in_shape.size() == 2 || in_shape.size() == 3);
-    ParamKey key{param, in_shape[0], in_shape[1], out_shape[0], cudnn_data_type,
-                 cudnn_forward_compute_type, cudnn_backward_compute_type, sm_arch, add_to_weight};
-    std::lock_guard<std::mutex> guard(lock_);
-    if (param.cudnn_tune.value() && reg_.size() % 50 == 0) {
-      LOG(INFO) << "Running performance tests to find the best convolution "
-                   "algorithm, "
-                   "this can take a while... (setting env variable "
-                   "MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)";
-      if (reg_.size() >= 1000) {
-        // Many people are very concerned about this warning, so change the warning once.
-        if (!is_warning_autotune_) {
-          LOG(INFO)
-            << "If you see this message in the middle of training, you are "
-            "probably using bucketing. Consider setting env variable "
-            "MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable cudnn tuning.";
-          is_warning_autotune_ = true;
+    } else {
+      if (param.cudnn_tune.value() && reg_.size() % 50 == 0) {
+        LOG(INFO) << "Running performance tests to find the best convolution "
+            "algorithm, "
+            "this can take a while... (setting env variable "
+            "MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)";
+        if (reg_.size() >= 1000) {
+          // Many people are very concerned about this warning, so only print it once.
+          if (!is_warning_autotune_) {
+            LOG(INFO)
+                << "If you see this message in the middle of training, you are "
+                    "probably using bucketing. Consider setting env variable "
+                    "MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable cudnn tuning.";
+            is_warning_autotune_ = true;
+          }
         }
       }
+      // Call provided function to determine the algos - likely uses cudnnFind() or cudnnGet()
+      algo_setter(fwd, bwd, flt);
+      // Save result so future lookups hit in this registry
+      reg_.insert(std::pair<ParamKey, CudnnAlgorithms>(key, CudnnAlgorithms{*fwd, *bwd, *flt}));
     }
-    reg_[key].fwd = fwd;
-    reg_[key].bwd = bwd;
-    reg_[key].flt = flt;
   }
 
   static CuDNNAlgoReg *Get();
diff --git a/src/operator/nn/cudnn/cudnn_convolution-inl.h b/src/operator/nn/cudnn/cudnn_convolution-inl.h
index 53bd76c9c3e..51abbde8060 100644
--- a/src/operator/nn/cudnn/cudnn_convolution-inl.h
+++ b/src/operator/nn/cudnn/cudnn_convolution-inl.h
@@ -26,6 +26,7 @@
 #ifndef MXNET_OPERATOR_NN_CUDNN_CUDNN_CONVOLUTION_INL_H_
 #define MXNET_OPERATOR_NN_CUDNN_CUDNN_CONVOLUTION_INL_H_
 
+#include <mxnet/storage.h>
 #include <algorithm>
 #include <vector>
 #include <mutex>
@@ -611,236 +612,274 @@ class CuDNNConvolutionOp {
     }
   }
 
-  void SelectAlgo(const RunContext& rctx,
+  void CuDNNAlgoSetter(const RunContext& rctx,
                   const std::vector<TShape>& in_shape,
                   const std::vector<TShape>& out_shape,
                   cudnnDataType_t cudnn_forward_compute_type,
-                  cudnnDataType_t cudnn_backward_compute_type) {
-    if (!CuDNNConvAlgoReg::Get()->Find(param_, in_shape, out_shape, dtype_,
-                                       cudnn_forward_compute_type, cudnn_backward_compute_type,
-                                       SMArch(rctx.ctx.dev_id), add_to_weight_,
-                                       &forward_algo_, &back_algo_, &back_algo_w_)) {
-      mshadow::Stream<gpu> *s = rctx.get_stream<gpu>();
-      CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
-      size_t workspace_byte = static_cast<size_t>(param_.workspace * sizeof(DType));
-      #if CUDNN_MAJOR >= 7
-      // Starting with cuDNNv7, the algo number returned by *Get*() is not the entire
-      // story: the notion of whether the algo ran in Tensor Core mode is not known.
-      // Since we want to report the Tensor Core mode in the verbose output, we switch
-      // to using the new *Get*_v7() call.  Since the function signature of *Get*_v7() matches
-      // that of *Find*(), we can unify the find-vs-get logic by using function pointers.
-
-      // Forward Algorithm Find/Get() v7
-      std::vector<cudnnConvolutionFwdAlgoPerf_t> fwd_results(MaxForwardAlgos(s->dnn_handle_));
-      int actual_fwd_algos = 0;
-      auto fwd_algo_discoverer =
-        param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionForwardAlgorithm_v7
-                                                : cudnnFindConvolutionForwardAlgorithm;
-      CUDNN_CALL((*fwd_algo_discoverer)(s->dnn_handle_,
-                                        in_desc_,
-                                        filter_desc_,
-                                        forward_conv_desc_,
-                                        out_desc_,
-                                        fwd_results.size(),
-                                        &actual_fwd_algos,
-                                        fwd_results.data()));
-      fwd_results.resize(actual_fwd_algos);
-      AlgoFinalSelect<cudnnConvolutionFwdAlgoPerf_t,
-                      cudnnConvolutionFwdAlgo_t>(fwd_results, "forward",
-                                                 workspace_byte, &forward_algo_);
-
-      // Backprop-to-Filter Algorithm Find/Get() v7
-      auto max_bwd_filt_algos = MaxBackwardFilterAlgos(s->dnn_handle_);
-      std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filt_results(max_bwd_filt_algos);
-      int actual_bwd_filter_algos = 0;
-      // In cudnn v7.1.4, find() returned wgrad algos that could fail for large c if we
-      // were summing into the output (i.e. beta != 0).  Get() returned OK algos though.
-      auto bwd_filter_algo_discoverer =
-        param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardFilterAlgorithm_v7
-                                                : cudnnFindConvolutionBackwardFilterAlgorithm;
-      CUDNN_CALL((*bwd_filter_algo_discoverer)(s->dnn_handle_,
+                  cudnnDataType_t cudnn_backward_compute_type,
+                  CuDNNAlgo<cudnnConvolutionFwdAlgo_t> *fwd,
+                  CuDNNAlgo<cudnnConvolutionBwdDataAlgo_t> *bwd,
+                  CuDNNAlgo<cudnnConvolutionBwdFilterAlgo_t> *flt) {
+    // Not in algo registry, must determine via *Get*() or *Find*()
+    mshadow::Stream<gpu> *s = rctx.get_stream<gpu>();
+    CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
+    size_t workspace_byte = static_cast<size_t>(param_.workspace * sizeof(DType));
+#if CUDNN_MAJOR >= 7
+    // Starting with cuDNNv7, the algo number returned by *Get*() is not the entire
+    // story: the notion of whether the algo ran in Tensor Core mode is not known.
+    // Since we want to report the Tensor Core mode in the verbose output, we switch
+    // to using the new *Get*_v7() call.  Since the function signature of *Get*_v7() matches
+    // that of *Find*(), we can unify the find-vs-get logic by using function pointers.
+
+    // Forward Algorithm Find/Get() v7
+    std::vector<cudnnConvolutionFwdAlgoPerf_t> fwd_results(MaxForwardAlgos(s->dnn_handle_));
+    int actual_fwd_algos = 0;
+    auto fwd_algo_discoverer =
+      param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionForwardAlgorithm_v7
+                                              : cudnnFindConvolutionForwardAlgorithm;
+    CUDNN_CALL((*fwd_algo_discoverer)(s->dnn_handle_,
+                                      in_desc_,
+                                      filter_desc_,
+                                      forward_conv_desc_,
+                                      out_desc_,
+                                      fwd_results.size(),
+                                      &actual_fwd_algos,
+                                      fwd_results.data()));
+    fwd_results.resize(actual_fwd_algos);
+    AlgoFinalSelect<cudnnConvolutionFwdAlgoPerf_t,
+                    cudnnConvolutionFwdAlgo_t>(fwd_results, "forward",
+                                               workspace_byte, fwd);
+
+    // Backprop-to-Filter Algorithm Find/Get() v7
+    auto max_bwd_filt_algos = MaxBackwardFilterAlgos(s->dnn_handle_);
+    std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filt_results(max_bwd_filt_algos);
+    int actual_bwd_filter_algos = 0;
+    // In cudnn v7.1.4, find() returned wgrad algos that could fail for large c if we
+    // were summing into the output (i.e. beta != 0).  Get() returned OK algos though.
+    auto bwd_filter_algo_discoverer =
+      param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardFilterAlgorithm_v7
+                                              : cudnnFindConvolutionBackwardFilterAlgorithm;
+    CUDNN_CALL((*bwd_filter_algo_discoverer)(s->dnn_handle_,
+                                             in_desc_,
+                                             out_desc_,
+                                             back_conv_desc_w_,
+                                             filter_desc_,
+                                             bwd_filt_results.size(),
+                                             &actual_bwd_filter_algos,
+                                             bwd_filt_results.data()));
+    bwd_filt_results.resize(actual_bwd_filter_algos);
+    AlgoFinalSelect<cudnnConvolutionBwdFilterAlgoPerf_t,
+                    cudnnConvolutionBwdFilterAlgo_t>(bwd_filt_results, "backprop-to-filter",
+                                                     workspace_byte, flt);
+
+    // Backprop-to-Data Algorithm Find/Get() v7
+    auto max_bwd_data_algos = MaxBackwardDataAlgos(s->dnn_handle_);
+    std::vector<cudnnConvolutionBwdDataAlgoPerf_t> bwd_data_results(max_bwd_data_algos);
+    int actual_bwd_data_algos = 0;
+    auto bwd_data_algo_discoverer =
+      param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardDataAlgorithm_v7
+                                              : cudnnFindConvolutionBackwardDataAlgorithm;
+    CUDNN_CALL((*bwd_data_algo_discoverer)(s->dnn_handle_,
+                                           filter_desc_,
+                                           out_desc_,
+                                           back_conv_desc_,
+                                           in_desc_,
+                                           bwd_data_results.size(),
+                                           &actual_bwd_data_algos,
+                                           bwd_data_results.data()));
+    bwd_data_results.resize(actual_bwd_data_algos);
+    AlgoFinalSelect<cudnnConvolutionBwdDataAlgoPerf_t,
+                    cudnnConvolutionBwdDataAlgo_t>(bwd_data_results, "backprop-to-data",
+                                                   workspace_byte, bwd);
+#else
+    // CUDNN_MAJOR < 7
+    const int kMaxAlgos = 10;
+    int nalgo = kMaxAlgos;
+    int i = 0;
+    size_t min_memory_needs = 0;
+    // Forward Algorithm Find/Get, v6 and earlier
+    if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) {
+      // In cuDNNv6, for kNHWC, only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM is
+      // supported.  Hard-coded this since the algo find() or get() throws an FPE.
+      fwd->Set(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, false);
+    } else if (!param_.cudnn_tune.value()) {
+      cudnnConvolutionFwdAlgo_t fastest_fwd_algo;
+      CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_,
                                                in_desc_,
-                                               out_desc_,
-                                               back_conv_desc_w_,
                                                filter_desc_,
-                                               bwd_filt_results.size(),
-                                               &actual_bwd_filter_algos,
-                                               bwd_filt_results.data()));
-      bwd_filt_results.resize(actual_bwd_filter_algos);
-      AlgoFinalSelect<cudnnConvolutionBwdFilterAlgoPerf_t,
-                      cudnnConvolutionBwdFilterAlgo_t>(bwd_filt_results, "backprop-to-filter",
-                                   workspace_byte, &back_algo_w_);
-
-      // Backprop-to-Data Algorithm Find/Get() v7
-      auto max_bwd_data_algos = MaxBackwardDataAlgos(s->dnn_handle_);
-      std::vector<cudnnConvolutionBwdDataAlgoPerf_t> bwd_data_results(max_bwd_data_algos);
-      int actual_bwd_data_algos = 0;
-      auto bwd_data_algo_discoverer =
-        param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardDataAlgorithm_v7
-                                                : cudnnFindConvolutionBackwardDataAlgorithm;
-      CUDNN_CALL((*bwd_data_algo_discoverer)(s->dnn_handle_,
-                                             filter_desc_,
-                                             out_desc_,
-                                             back_conv_desc_,
-                                             in_desc_,
-                                             bwd_data_results.size(),
-                                             &actual_bwd_data_algos,
-                                             bwd_data_results.data()));
-      bwd_data_results.resize(actual_bwd_data_algos);
-      AlgoFinalSelect<cudnnConvolutionBwdDataAlgoPerf_t,
-                      cudnnConvolutionBwdDataAlgo_t>(bwd_data_results, "backprop-to-data",
-                                    workspace_byte, &back_algo_);
-      #else
-      // CUDNN_MAJOR < 7
-      const int kMaxAlgos = 10;
-      int nalgo = kMaxAlgos;
-      int i = 0;
-      size_t min_memory_needs = 0;
-      // Forward Algorithm Find/Get, v6 and earlier
-      if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) {
-        // In cuDNNv6, for kNHWC, only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM is
-        // supported.  Hard-coded this since the algo find() or get() throws an FPE.
-        forward_algo_.Set(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, false);
-      } else if (!param_.cudnn_tune.value()) {
-        cudnnConvolutionFwdAlgo_t fastest_fwd_algo;
-        CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_,
-                                                 in_desc_,
-                                                 filter_desc_,
-                                                 forward_conv_desc_,
-                                                 out_desc_,
-                                                 CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
-                                                 workspace_byte,
-                                                 &fastest_fwd_algo));
-        forward_algo_.Set(fastest_fwd_algo, false);
-      } else {
-        cudnnConvolutionFwdAlgoPerf_t fwd_algo[kMaxAlgos];
-        CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(s->dnn_handle_,
-                                                        in_desc_,
-                                                        filter_desc_,
-                                                        forward_conv_desc_,
-                                                        out_desc_,
-                                                        kMaxAlgos,
-                                                        &nalgo,
-                                                        fwd_algo));
-        i = 0;
-        while (i < nalgo
-               && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS
-                   || (param_.cudnn_tune.value() == conv::kLimited
-                       && fwd_algo[i].memory > workspace_byte))) {
-          ++i;
-          min_memory_needs =
-            (i == 0) ? fwd_algo[i].memory : std::min(min_memory_needs, fwd_algo[i].memory);
-        }
-        if (i == nalgo) {
-          LOG(FATAL) << nalgo << " forward algorithms with minimum memory requirement "
-                     << min_memory_needs << " bytes have been tried. Workspace size is set to "
-                     << workspace_byte << " bytes, please consider reducing the batch/model size, "
-                     << "or increasing workspace size.";
-        } else {
-          forward_algo_.Set(fwd_algo[i].algo, false);
-        }
+                                               forward_conv_desc_,
+                                               out_desc_,
+                                               CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
+                                               workspace_byte,
+                                               &fastest_fwd_algo));
+      fwd->Set(fastest_fwd_algo, false);
+    } else {
+      cudnnConvolutionFwdAlgoPerf_t fwd_algo[kMaxAlgos];
+      CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(s->dnn_handle_,
+                                                      in_desc_,
+                                                      filter_desc_,
+                                                      forward_conv_desc_,
+                                                      out_desc_,
+                                                      kMaxAlgos,
+                                                      &nalgo,
+                                                      fwd_algo));
+      i = 0;
+      while (i < nalgo
+             && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS
+                 || (param_.cudnn_tune.value() == conv::kLimited
+                     && fwd_algo[i].memory > workspace_byte))) {
+        ++i;
+        min_memory_needs =
+          (i == 0) ? fwd_algo[i].memory : std::min(min_memory_needs, fwd_algo[i].memory);
       }
-      // Backprop-to-Filter Algorithm Find/Get, v6 and earlier
-      if (!param_.cudnn_tune.value()) {
-        cudnnConvolutionBwdFilterAlgo_t fastest_bwd_filt_algo;
-        CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
-                                          in_desc_,
-                                          out_desc_,
-                                          back_conv_desc_w_,
-                                          filter_desc_,
-                                          CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
-                                          workspace_byte,
-                                          &fastest_bwd_filt_algo));
-        back_algo_w_.Set(fastest_bwd_filt_algo, false);
+      if (i == nalgo) {
+        LOG(FATAL) << nalgo << " forward algorithms with minimum memory requirement "
+                   << min_memory_needs << " bytes have been tried. Workspace size is set to "
+                   << workspace_byte << " bytes, please consider reducing the batch/model size, "
+                   << "or increasing workspace size.";
       } else {
-        cudnnConvolutionBwdFilterAlgoPerf_t bwd_filter_algo[kMaxAlgos];
-        CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
-                                                               in_desc_,
-                                                               out_desc_,
-                                                               back_conv_desc_w_,
-                                                               filter_desc_,
-                                                               kMaxAlgos,
-                                                               &nalgo,
-                                                               bwd_filter_algo));
-        i = 0;
-        while (i < nalgo
-               && (bwd_filter_algo[i].status != CUDNN_STATUS_SUCCESS
-                   || (param_.cudnn_tune.value() == conv::kLimited
-                       && bwd_filter_algo[i].memory > workspace_byte))) {
-          ++i;
-          min_memory_needs = (i == 0) ?
-                             bwd_filter_algo[i].memory :
-                             std::min(min_memory_needs, bwd_filter_algo[i].memory);
-        }
-        if (i == nalgo) {
-          LOG(FATAL) << nalgo << " backward filter algorithms with minimum memory requirement "
-                     << min_memory_needs << " bytes have been tried. Workspace size is set to "
-                     << workspace_byte << " bytes, please consider reducing the batch/model size, "
-                     << "or increasing workspace size.";
-        } else {
-          back_algo_w_.Set(bwd_filter_algo[i].algo, false);
-        }
+        fwd->Set(fwd_algo[i].algo, false);
       }
-      // Backprop-to-Data Algorithm Get(), v6 and earlier
-      if (!param_.cudnn_tune.value()) {
-        cudnnConvolutionBwdDataAlgo_t fastest_bwd_data_algo;
-        CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_,
-                                            filter_desc_,
-                                            out_desc_,
-                                            back_conv_desc_,
-                                            in_desc_,
-                                            CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
-                                            workspace_byte,
-                                            &fastest_bwd_data_algo));
-        back_algo_.Set(fastest_bwd_data_algo, false);
-      } else {
-        cudnnConvolutionBwdDataAlgoPerf_t bwd_data_algo[kMaxAlgos];
-        CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(s->dnn_handle_,
-                                                             filter_desc_,
-                                                             out_desc_,
-                                                             back_conv_desc_,
+    }
+    // Backprop-to-Filter Algorithm Find/Get, v6 and earlier
+    if (!param_.cudnn_tune.value()) {
+      cudnnConvolutionBwdFilterAlgo_t fastest_bwd_filt_algo;
+      CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
+                                        in_desc_,
+                                        out_desc_,
+                                        back_conv_desc_w_,
+                                        filter_desc_,
+                                        CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
+                                        workspace_byte,
+                                        &fastest_bwd_filt_algo));
+      flt->Set(fastest_bwd_filt_algo, false);
+    } else {
+      cudnnConvolutionBwdFilterAlgoPerf_t bwd_filter_algo[kMaxAlgos];
+      CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
                                                              in_desc_,
+                                                             out_desc_,
+                                                             back_conv_desc_w_,
+                                                             filter_desc_,
                                                              kMaxAlgos,
                                                              &nalgo,
-                                                             bwd_data_algo));
-        i = 0;
-        while (i < nalgo
-               && (bwd_data_algo[i].status != CUDNN_STATUS_SUCCESS
-                   || (param_.cudnn_tune.value() == conv::kLimited
-                       && bwd_data_algo[i].memory > workspace_byte))) {
-          ++i;
-          min_memory_needs = (i == 0) ?
-                             bwd_data_algo[i].memory :
-                             std::min(min_memory_needs, bwd_data_algo[i].memory);
-        }
-        if (i == nalgo) {
-          LOG(FATAL) << nalgo << " backward data algorithms with minimum memory requirement "
-                     << min_memory_needs << " bytes have been tried. Workspace size is set to "
-                     << workspace_byte << " bytes, please consider reducing the batch/model size, "
-                     << "or increasing workspace size.";
-        } else {
-          back_algo_.Set(bwd_data_algo[i].algo, false);
-        }
+                                                             bwd_filter_algo));
+      i = 0;
+      while (i < nalgo
+             && (bwd_filter_algo[i].status != CUDNN_STATUS_SUCCESS
+                 || (param_.cudnn_tune.value() == conv::kLimited
+                     && bwd_filter_algo[i].memory > workspace_byte))) {
+        ++i;
+        min_memory_needs = (i == 0) ?
+                           bwd_filter_algo[i].memory :
+                           std::min(min_memory_needs, bwd_filter_algo[i].memory);
       }
-      #endif  // CUDNN_MAJOR < 7
-
-      // Fix for issue #11241
-      int cudnn_find_issue_max_features = 64 * 1024;
-      if (add_to_weight_ && Features(in_shape[conv::kData]) >= cudnn_find_issue_max_features) {
-        this->back_algo_w_.Set(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true);
+      if (i == nalgo) {
+        LOG(FATAL) << nalgo << " backward filter algorithms with minimum memory requirement "
+                   << min_memory_needs << " bytes have been tried. Workspace size is set to "
+                   << workspace_byte << " bytes, please consider reducing the batch/model size, "
+                   << "or increasing workspace size.";
+      } else {
+        flt->Set(bwd_filter_algo[i].algo, false);
       }
+    }
+    // Backprop-to-Data Algorithm Get(), v6 and earlier
+    if (!param_.cudnn_tune.value()) {
+      cudnnConvolutionBwdDataAlgo_t fastest_bwd_data_algo;
+      CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_,
+                                          filter_desc_,
+                                          out_desc_,
+                                          back_conv_desc_,
+                                          in_desc_,
+                                          CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
+                                          workspace_byte,
+                                          &fastest_bwd_data_algo));
+      bwd->Set(fastest_bwd_data_algo, false);
+    } else {
+      cudnnConvolutionBwdDataAlgoPerf_t bwd_data_algo[kMaxAlgos];
+      CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(s->dnn_handle_,
+                                                           filter_desc_,
+                                                           out_desc_,
+                                                           back_conv_desc_,
+                                                           in_desc_,
+                                                           kMaxAlgos,
+                                                           &nalgo,
+                                                           bwd_data_algo));
+      i = 0;
+      while (i < nalgo
+             && (bwd_data_algo[i].status != CUDNN_STATUS_SUCCESS
+                 || (param_.cudnn_tune.value() == conv::kLimited
+                     && bwd_data_algo[i].memory > workspace_byte))) {
+        ++i;
+        min_memory_needs = (i == 0) ?
+                           bwd_data_algo[i].memory :
+                           std::min(min_memory_needs, bwd_data_algo[i].memory);
+      }
+      if (i == nalgo) {
+        LOG(FATAL) << nalgo << " backward data algorithms with minimum memory requirement "
+                   << min_memory_needs << " bytes have been tried. Workspace size is set to "
+                   << workspace_byte << " bytes, please consider reducing the batch/model size, "
+                   << "or increasing workspace size.";
+      } else {
+        bwd->Set(bwd_data_algo[i].algo, false);
+      }
+    }
+#endif  // CUDNN_MAJOR < 7
 
-      // An algo specification by the user may be cached here, but another
-      // convolution will match only if identically specified.
-      // We're caching results of *Get* as well as *Find*, but these records
-      // will be held distinctly because param_.cudnn_tune is part of the key.
-      CuDNNConvAlgoReg::Get()->Register(param_, in_shape, out_shape, dtype_,
-                                        cudnn_forward_compute_type,
-                                        cudnn_backward_compute_type,
-                                        SMArch(rctx.ctx.dev_id), this->add_to_weight_,
-                                        this->forward_algo_,
-                                        this->back_algo_, this->back_algo_w_);
+    // Fix for issue #11241
+    int cudnn_find_issue_max_features = 64 * 1024;
+    if (add_to_weight_ && Features(in_shape[conv::kData]) >= cudnn_find_issue_max_features) {
+      flt->Set(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true);
     }
+  }
+
+  void SelectAlgo(const RunContext& rctx,
+                  const std::vector<TShape>& in_shape,
+                  const std::vector<TShape>& out_shape,
+                  cudnnDataType_t cudnn_forward_compute_type,
+                  cudnnDataType_t cudnn_backward_compute_type) {
+    auto algo_setter = [&](CuDNNAlgo<cudnnConvolutionFwdAlgo_t> *fwd,
+                           CuDNNAlgo<cudnnConvolutionBwdDataAlgo_t> *bwd,
+                           CuDNNAlgo<cudnnConvolutionBwdFilterAlgo_t> *flt) {
+      if (param_.cudnn_tune.value() == conv::kOff) {
+        // The routine will only be calling cudnnGet, so no need to grab the Storage lock.
+        this->CuDNNAlgoSetter(rctx, in_shape, out_shape,
+                              cudnn_forward_compute_type,
+                              cudnn_backward_compute_type,
+                              fwd, bwd, flt);
+      } else {
+        // One potential problem is that cudnnFind() uses cudaMalloc() to directly allocate
+        // I/O and workspace areas, and these allocations may result in an out-of-memory
+        // error even though the StorageManager free pool is not empty.  Ideally, cudnnFind
+        // would use MXNet's storage allocator for its I/O and workspace areas, instead of using
+        // the area carved out by MXNET_GPU_MEM_POOL_RESERVE.
+        // To get somewhat the same effect as this, we can pre-allocate the areas needed for the
+        // I/Os (possibly triggering a desirable StorageManager::ReleaseAll()), followed by a
+        // DirectFree(), which makes these areas available for cudnn's subsequent cudaMalloc().
+
+        // Allocate for x (or dx), w (or dw) and y (or dy).
+        ReserveElements({in_shape[conv::kData].Size(),
+                         in_shape[conv::kWeight].Size(),
+                         out_shape[conv::kOut].Size()});
+
+        // We're about to call cudnnFind so we need to quiet the system by grabbing
+        // the Storage lock.  Concurrent cudaMalloc's can disrupt the accurate timing
+        // measurements of the algos, and can prevent the cuda driver's proper freeing
+        // of cudnnFind's internal temporary allocations.  Grabbing the lock might also
+        // impede other threads from launching work on the GPU.
+        std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
+        this->CuDNNAlgoSetter(rctx, in_shape, out_shape,
+                              cudnn_forward_compute_type,
+                              cudnn_backward_compute_type,
+                              fwd, bwd, flt);
+      }
+    };
+
+    CuDNNConvAlgoReg::Get()->FindOrElseRegister(param_, in_shape, out_shape, dtype_,
+                                       cudnn_forward_compute_type,
+                                       cudnn_backward_compute_type,
+                                       SMArch(rctx.ctx.dev_id), add_to_weight_,
+                                       &forward_algo_, &back_algo_, &back_algo_w_, algo_setter);
+
     // If we're allowing Tensor Core variants of the algos to be considered in
     // *Find*() or *Get*(), but a non-Tensor-Core algo variant is the fastest,
     // we must change the descriptor to preclude Tensor Core.  Simplest is to
@@ -877,6 +916,7 @@ class CuDNNConvolutionOp {
                << " please consider reducing batch/model size or increasing the workspace size";
   }
 
+
   void GetTempSize(const OpContext& ctx) {
     mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
     size_t back_size = 0, back_size_w = 0;
@@ -975,6 +1015,16 @@ class CuDNNConvolutionOp {
     return c;
   }
 
+  // Allocate one buffer per entry of `elements`, then DirectFree() them all, leaving room for an
+  // equivalent set of cudaMalloc() calls (e.g. by cudnnFind()).  Entries count DTypes, not bytes.
+  void ReserveElements(const std::vector<size_t> &elements) {
+    std::vector<Storage::Handle> reserved;
+    for (size_t dtype_count : elements)
+      reserved.push_back(Storage::Get()->Alloc(dtype_count * sizeof(DType), Context::GPU()));
+    for (auto &buffer : reserved)
+      Storage::Get()->DirectFree(buffer);
+  }
+
   std::vector<int> param_stride_;
   std::vector<int> param_dilate_;
   std::vector<int> param_pad_;
diff --git a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h
index 041bea66f7b..eca92d13295 100644
--- a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h
+++ b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h
@@ -26,6 +26,7 @@
 #ifndef MXNET_OPERATOR_NN_CUDNN_CUDNN_DECONVOLUTION_INL_H_
 #define MXNET_OPERATOR_NN_CUDNN_CUDNN_DECONVOLUTION_INL_H_
 
+#include <mxnet/storage.h>
 #include <algorithm>
 #include <vector>
 #include <mutex>
@@ -538,245 +539,285 @@ class CuDNNDeconvolutionOp {
     }
   }
 
-  void SelectAlgo(const RunContext& rctx,
-                  const std::vector<TShape>& in_shape,
-                  const std::vector<TShape>& out_shape,
-                  cudnnDataType_t cudnn_forward_compute_type,
-                  cudnnDataType_t cudnn_backward_compute_type) {
-    if (!CuDNNDeconvAlgoReg::Get()->Find(param_, in_shape, out_shape, dtype_,
-                                         cudnn_forward_compute_type,
-                                         cudnn_backward_compute_type,
-                                         SMArch(rctx.ctx.dev_id), add_to_weight_,
-                                         &forward_algo_, &back_algo_, &back_algo_w_)) {
-      mshadow::Stream <gpu> *s = rctx.get_stream<gpu>();
-      CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
-      size_t workspace_byte = static_cast<size_t>(param_.workspace * sizeof(DType));
-      #if CUDNN_MAJOR >= 7
-      // Starting with cuDNNv7, the algo number returned by *Get*() is not the entire
-      // story: the notion of whether the algo ran in Tensor Core mode is not known.
-      // Since we want to report the Tensor Core mode in the verbose output, we switch
-      // to using the new *Get*_v7() call.  Since the function signature of *Get*_v7() matches
-      // that of *Find*(), we can unify the find-vs-get logic by using function pointers.
-
-      // Forward Algorithm Find/Get() v7
-      std::vector<cudnnConvolutionFwdAlgoPerf_t> fwd_results(MaxForwardAlgos(s->dnn_handle_));
-      int actual_fwd_algos = 0;
-      auto fwd_algo_discoverer =
-        param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionForwardAlgorithm_v7
+  // Not in algo registry: determine algos via cudnn *Get*() or *Find*(), writing to fwd/bwd/flt.
+  void CuDNNAlgoSetter(const RunContext& rctx,
+                       const std::vector<TShape>& in_shape,
+                       const std::vector<TShape>& out_shape,
+                       cudnnDataType_t cudnn_forward_compute_type,
+                       cudnnDataType_t cudnn_backward_compute_type,
+                       CuDNNAlgo<cudnnConvolutionFwdAlgo_t> *fwd,
+                       CuDNNAlgo<cudnnConvolutionBwdDataAlgo_t> *bwd,
+                       CuDNNAlgo<cudnnConvolutionBwdFilterAlgo_t> *flt) {
+    mshadow::Stream <gpu> *s = rctx.get_stream<gpu>();
+    CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
+    size_t workspace_byte = static_cast<size_t>(param_.workspace * sizeof(DType));
+#if CUDNN_MAJOR >= 7
+    // Starting with cuDNNv7, the algo number returned by *Get*() is not the entire
+    // story: the notion of whether the algo ran in Tensor Core mode is not known.
+    // Since we want to report the Tensor Core mode in the verbose output, we switch
+    // to using the new *Get*_v7() call.  Since the function signature of *Get*_v7() matches
+    // that of *Find*(), we can unify the find-vs-get logic by using function pointers.
+
+    // Forward Algorithm Find/Get() v7
+    std::vector<cudnnConvolutionFwdAlgoPerf_t> fwd_results(MaxForwardAlgos(s->dnn_handle_));
+    int actual_fwd_algos = 0;
+    auto fwd_algo_discoverer =
+        param_.cudnn_tune.value() == deconv::kOff ? cudnnGetConvolutionForwardAlgorithm_v7
                                                 : cudnnFindConvolutionForwardAlgorithm;
-      CUDNN_CALL((*fwd_algo_discoverer)(s->dnn_handle_,
-                                        out_desc_,
-                                        filter_desc_,
-                                        back_conv_desc_,  // fwd algo used to backprop-to-data
-                                        in_desc_,
-                                        fwd_results.size(),
-                                        &actual_fwd_algos,
-                                        fwd_results.data()));
-      fwd_results.resize(actual_fwd_algos);
-      AlgoFinalSelect<cudnnConvolutionFwdAlgoPerf_t,
-                      cudnnConvolutionFwdAlgo_t>(fwd_results, "forward",
-                                                 workspace_byte, &forward_algo_);
-
-      // Backprop-to-Filter Algorithm Find/Get() v7
-      auto max_bwd_filt_algos = MaxBackwardFilterAlgos(s->dnn_handle_);
-      std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filt_results(max_bwd_filt_algos);
-      int actual_bwd_filter_algos = 0;
-      // In cudnn v7.1.4, find() returned wgrad algos that could fail for large c if we
-      // were summing into the output (i.e. beta != 0).  Get() returned OK algos though.
-      auto bwd_filter_algo_discoverer =
-        param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardFilterAlgorithm_v7
-                                                : cudnnFindConvolutionBackwardFilterAlgorithm;
-      CUDNN_CALL((*bwd_filter_algo_discoverer)(s->dnn_handle_,
-                                               out_desc_,
-                                               in_desc_,
-                                               back_conv_desc_,
-                                               filter_desc_,
-                                               bwd_filt_results.size(),
-                                               &actual_bwd_filter_algos,
-                                               bwd_filt_results.data()));
-      bwd_filt_results.resize(actual_bwd_filter_algos);
-      AlgoFinalSelect<cudnnConvolutionBwdFilterAlgoPerf_t,
-                      cudnnConvolutionBwdFilterAlgo_t>(bwd_filt_results, "backprop-to-filter",
-                                                       workspace_byte, &back_algo_w_);
-
-      // Backprop-to-Data Algorithm Find/Get() v7
-      auto max_bwd_data_algos = MaxBackwardDataAlgos(s->dnn_handle_);
-      std::vector<cudnnConvolutionBwdDataAlgoPerf_t> bwd_data_results(max_bwd_data_algos);
-      int actual_bwd_data_algos = 0;
-      auto bwd_data_algo_discoverer =
-        param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardDataAlgorithm_v7
+    CUDNN_CALL((*fwd_algo_discoverer)(s->dnn_handle_,
+                                      out_desc_,
+                                      filter_desc_,
+                                      back_conv_desc_,  // fwd algo used to backprop-to-data
+                                      in_desc_,
+                                      fwd_results.size(),
+                                      &actual_fwd_algos,
+                                      fwd_results.data()));
+    fwd_results.resize(actual_fwd_algos);
+    AlgoFinalSelect<cudnnConvolutionFwdAlgoPerf_t,
+        cudnnConvolutionFwdAlgo_t>(fwd_results, "forward",
+                                   workspace_byte, fwd);
+
+    // Backprop-to-Filter Algorithm Find/Get() v7
+    auto max_bwd_filt_algos = MaxBackwardFilterAlgos(s->dnn_handle_);
+    std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filt_results(max_bwd_filt_algos);
+    int actual_bwd_filter_algos = 0;
+    // In cudnn v7.1.4, find() returned wgrad algos that could fail for large c if we
+    // were summing into the output (i.e. beta != 0).  Get() returned OK algos though.
+    auto bwd_filter_algo_discoverer =
+        param_.cudnn_tune.value() == deconv::kOff ? cudnnGetConvolutionBackwardFilterAlgorithm_v7
+                                                  : cudnnFindConvolutionBackwardFilterAlgorithm;
+    CUDNN_CALL((*bwd_filter_algo_discoverer)(s->dnn_handle_,
+                                             out_desc_,
+                                             in_desc_,
+                                             back_conv_desc_,
+                                             filter_desc_,
+                                             bwd_filt_results.size(),
+                                             &actual_bwd_filter_algos,
+                                             bwd_filt_results.data()));
+    bwd_filt_results.resize(actual_bwd_filter_algos);
+    AlgoFinalSelect<cudnnConvolutionBwdFilterAlgoPerf_t,
+        cudnnConvolutionBwdFilterAlgo_t>(bwd_filt_results, "backprop-to-filter",
+                                         workspace_byte, flt);
+    // Backprop-to-Data Algorithm Find/Get() v7
+    auto max_bwd_data_algos = MaxBackwardDataAlgos(s->dnn_handle_);
+    std::vector<cudnnConvolutionBwdDataAlgoPerf_t> bwd_data_results(max_bwd_data_algos);
+    int actual_bwd_data_algos = 0;
+    auto bwd_data_algo_discoverer =
+        param_.cudnn_tune.value() == deconv::kOff ? cudnnGetConvolutionBackwardDataAlgorithm_v7
                                                 : cudnnFindConvolutionBackwardDataAlgorithm;
-      CUDNN_CALL((*bwd_data_algo_discoverer)(s->dnn_handle_,
+    CUDNN_CALL((*bwd_data_algo_discoverer)(s->dnn_handle_,
+                                           filter_desc_,
+                                           in_desc_,
+                                           forward_conv_desc_,  // bwd algo used in inference
+                                           out_desc_,
+                                           bwd_data_results.size(),
+                                           &actual_bwd_data_algos,
+                                           bwd_data_results.data()));
+    bwd_data_results.resize(actual_bwd_data_algos);
+    AlgoFinalSelect<cudnnConvolutionBwdDataAlgoPerf_t,
+        cudnnConvolutionBwdDataAlgo_t>(bwd_data_results, "backprop-to-data",
+                                       workspace_byte, bwd);
+#else
+    // CUDNN_MAJOR < 7
+    const int kMaxAlgos = 10;
+    int nalgo = kMaxAlgos;
+    int i = 0;
+    size_t min_memory_needs = 0;
+    // Forward Algorithm Find/Get, v6 and earlier
+    if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) {
+      // In cuDNNv6, for kNHWC, only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM is
+      // supported.  Hard-coded this since the algo find() or get() throws an FPE.
+      fwd->Set(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, false);
+    } else if (!param_.cudnn_tune.value()) {
+      cudnnConvolutionFwdAlgo_t fastest_fwd_algo;
+      CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_,
+                                                 out_desc_,
+                                                 filter_desc_,
+                                                 back_conv_desc_,  // fwd algo used in dgrad
+                                                 in_desc_,
+                                                 CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
+                                                 workspace_byte,
+                                                 &fastest_fwd_algo));
+      fwd->Set(fastest_fwd_algo, false);
+    } else {
+      cudnnConvolutionFwdAlgoPerf_t fwd_algo[kMaxAlgos];
+      CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(s->dnn_handle_,
+                                                    out_desc_,
+                                                    filter_desc_,
+                                                    back_conv_desc_,  // fwd algo used in dgrad
+                                                    in_desc_,
+                                                    kMaxAlgos,
+                                                    &nalgo,
+                                                    fwd_algo));
+      i = 0;
+      while (i < nalgo
+             && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS
+                 || (param_.cudnn_tune.value() == deconv::kLimited
+                     && fwd_algo[i].memory > workspace_byte))) {
+        // Record the rejected algo's memory need before advancing: never reads fwd_algo[nalgo].
+        min_memory_needs = (i == 0) ? fwd_algo[i].memory
+                                    : std::min(min_memory_needs, fwd_algo[i].memory);
+        ++i;
+      }
+      if (i == nalgo) {
+        LOG(FATAL) << nalgo << " forward algorithms"
+                   << " (for use in deconvolution operator backprop-to-data)"
+                   << " with minimum memory requirement " << min_memory_needs
+                   << " bytes have been tried. Workspace size is set to " << workspace_byte
+                   << " bytes, please consider reducing the batch/model size,"
+                   << " or increasing workspace size.";
+      } else {
+        fwd->Set(fwd_algo[i].algo, false);
+      }
+    }
+    // Backprop-to-Filter Algorithm Find/Get, v6 and earlier
+    if (!param_.cudnn_tune.value()) {
+      cudnnConvolutionBwdFilterAlgo_t fastest_bwd_filt_algo;
+      CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
+                                          out_desc_,
+                                          in_desc_,
+                                          back_conv_desc_,
+                                          filter_desc_,
+                                          CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
+                                          workspace_byte,
+                                          &fastest_bwd_filt_algo));
+      flt->Set(fastest_bwd_filt_algo, false);
+    } else {
+      cudnnConvolutionBwdFilterAlgoPerf_t bwd_filter_algo[kMaxAlgos];
+      CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
+                                                             out_desc_,
+                                                             in_desc_,
+                                                             back_conv_desc_,
+                                                             filter_desc_,
+                                                             kMaxAlgos,
+                                                             &nalgo,
+                                                             bwd_filter_algo));
+      i = 0;
+      while (i < nalgo
+             && (bwd_filter_algo[i].status != CUDNN_STATUS_SUCCESS
+                 || (param_.cudnn_tune.value() == deconv::kLimited
+                     && bwd_filter_algo[i].memory > workspace_byte))) {
+        // Record the rejected algo's memory need before advancing (no read at index nalgo).
+        min_memory_needs = (i == 0) ? bwd_filter_algo[i].memory
+                                    : std::min(min_memory_needs, bwd_filter_algo[i].memory);
+        ++i;
+      }
+      if (i == nalgo) {
+        LOG(FATAL) << nalgo << " backward filter algorithms"
+                   << " (for use in deconvolution operator backprop-to-filter)"
+                   << " with minimum memory requirement " << min_memory_needs
+                   << " bytes have been tried. Workspace size is set to " << workspace_byte
+                   << " bytes, please consider reducing the batch/model size,"
+                   << " or increasing workspace size.";
+      } else {
+        flt->Set(bwd_filter_algo[i].algo, false);
+      }
+    }
+    // Backprop-to-Data Algorithm Get(), v6 and earlier
+    if (!param_.cudnn_tune.value()) {
+      cudnnConvolutionBwdDataAlgo_t fastest_bwd_data_algo;
+      CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_,
+                                            filter_desc_,
+                                            in_desc_,
+                                            forward_conv_desc_,  // bwd algo used for inference
+                                            out_desc_,
+                                            CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
+                                            workspace_byte,
+                                            &fastest_bwd_data_algo));
+      bwd->Set(fastest_bwd_data_algo, false);
+    } else {
+      cudnnConvolutionBwdDataAlgoPerf_t bwd_data_algo[kMaxAlgos];
+      CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(s->dnn_handle_,
                                              filter_desc_,
                                              in_desc_,
                                              forward_conv_desc_,  // bwd algo used in inference
                                              out_desc_,
-                                             bwd_data_results.size(),
-                                             &actual_bwd_data_algos,
-                                             bwd_data_results.data()));
-      bwd_data_results.resize(actual_bwd_data_algos);
-      AlgoFinalSelect<cudnnConvolutionBwdDataAlgoPerf_t,
-                      cudnnConvolutionBwdDataAlgo_t>(bwd_data_results, "backprop-to-data",
-                                                     workspace_byte, &back_algo_);
-      #else
-      // CUDNN_MAJOR < 7
-      const int kMaxAlgos = 10;
-      int nalgo = kMaxAlgos;
-      int i = 0;
-      size_t min_memory_needs = 0;
-      // Forward Algorithm Find/Get, v6 and earlier
-      if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) {
-        // In cuDNNv6, for kNHWC, only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM is
-        // supported.  Hard-coded this since the algo find() or get() throws an FPE.
-        forward_algo_.Set(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, false);
-      } else if (!param_.cudnn_tune.value()) {
-        cudnnConvolutionFwdAlgo_t fastest_fwd_algo;
-        CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_,
-                                                   out_desc_,
-                                                   filter_desc_,
-                                                   back_conv_desc_,  // fwd algo used in dgrad
-                                                   in_desc_,
-                                                   CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
-                                                   workspace_byte,
-                                                   &fastest_fwd_algo));
-        forward_algo_.Set(fastest_fwd_algo, false);
-      } else {
-        cudnnConvolutionFwdAlgoPerf_t fwd_algo[kMaxAlgos];
-        CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(s->dnn_handle_,
-                                                      out_desc_,
-                                                      filter_desc_,
-                                                      back_conv_desc_,  // fwd algo used in dgrad
-                                                      in_desc_,
-                                                      kMaxAlgos,
-                                                      &nalgo,
-                                                      fwd_algo));
-        i = 0;
-        while (i < nalgo
-               && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS
-                   || (param_.cudnn_tune.value() == deconv::kLimited
-                       && fwd_algo[i].memory > workspace_byte))) {
-          ++i;
-          min_memory_needs = (i == 0) ?
-                             fwd_algo[i].memory :
-                             std::min(min_memory_needs, fwd_algo[i].memory);
-        }
-        if (i == nalgo) {
-          LOG(FATAL) << nalgo << " forward algorithms"
-                     << " (for use in deconvolution operator backprop-to-data)"
-                     << " with minimum memory requirement " << min_memory_needs
-                     << " bytes have been tried. Workspace size is set to " << workspace_byte
-                     << " bytes, please consider reducing the batch/model size,"
-                     << " or increasing workspace size.";
-        } else {
-          forward_algo_.Set(fwd_algo[i].algo, false);
-        }
+                                             kMaxAlgos,
+                                             &nalgo,
+                                             bwd_data_algo));
+      i = 0;
+      while (i < nalgo
+             && (bwd_data_algo[i].status != CUDNN_STATUS_SUCCESS
+                 || (param_.cudnn_tune.value() == deconv::kLimited
+                     && bwd_data_algo[i].memory > workspace_byte))) {
+        // Record the rejected algo's memory need before advancing (no read at index nalgo).
+        min_memory_needs = (i == 0) ? bwd_data_algo[i].memory
+                                    : std::min(min_memory_needs, bwd_data_algo[i].memory);
+        ++i;
       }
-      // Backprop-to-Filter Algorithm Find/Get, v6 and earlier
-      if (!param_.cudnn_tune.value()) {
-        cudnnConvolutionBwdFilterAlgo_t fastest_bwd_filt_algo;
-        CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
-                                            out_desc_,
-                                            in_desc_,
-                                            back_conv_desc_,
-                                            filter_desc_,
-                                            CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
-                                            workspace_byte,
-                                            &fastest_bwd_filt_algo));
-        back_algo_w_.Set(fastest_bwd_filt_algo, false);
+      if (i == nalgo) {
+        LOG(FATAL) << nalgo << " backward data algorithms"
+                   << " (for use in deconvolution operator forward inference) with"
+                   << " minimum memory requirement " << min_memory_needs
+                   << " bytes have been tried. Workspace size is set to " << workspace_byte
+                   << " bytes, please consider reducing the batch/model size,"
+                   << " or increasing workspace size.";
       } else {
-        cudnnConvolutionBwdFilterAlgoPerf_t bwd_filter_algo[kMaxAlgos];
-        CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
-                                                               out_desc_,
-                                                               in_desc_,
-                                                               back_conv_desc_,
-                                                               filter_desc_,
-                                                               kMaxAlgos,
-                                                               &nalgo,
-                                                               bwd_filter_algo));
-        i = 0;
-        while (i < nalgo
-               && (bwd_filter_algo[i].status != CUDNN_STATUS_SUCCESS
-                   || (param_.cudnn_tune.value() == deconv::kLimited
-                       && bwd_filter_algo[i].memory > workspace_byte))) {
-          ++i;
-          min_memory_needs = (i == 0) ?
-                             bwd_filter_algo[i].memory :
-                             std::min(min_memory_needs, bwd_filter_algo[i].memory);
-        }
-        if (i == nalgo) {
-          LOG(FATAL) << nalgo << " backward filter algorithms"
-                     << " (for use in deconvolution operator backprop-to-filter)"
-                     << " with minimum memory requirement " << min_memory_needs
-                     << " bytes have been tried. Workspace size is set to " << workspace_byte
-                     << " bytes, please consider reducing the batch/model size,"
-                     << " or increasing workspace size.";
-        } else {
-          back_algo_w_.Set(bwd_filter_algo[i].algo, false);
-        }
+        bwd->Set(bwd_data_algo[i].algo, false);
       }
-      // Backprop-to-Data Algorithm Get(), v6 and earlier
-      if (!param_.cudnn_tune.value()) {
-        cudnnConvolutionBwdDataAlgo_t fastest_bwd_data_algo;
-        CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_,
-                                              filter_desc_,
-                                              in_desc_,
-                                              forward_conv_desc_,  // bwd algo used for inference
-                                              out_desc_,
-                                              CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
-                                              workspace_byte,
-                                              &fastest_bwd_data_algo));
-        back_algo_.Set(fastest_bwd_data_algo, false);
+    }
+#endif  // CUDNN_MAJOR < 7
+
+    // Fix for issue #11241
+    int cudnn_find_issue_max_features = 64 * 1024;
+    // With deconvolution, the algo sensitivity is to a large number of output features
+    if (add_to_weight_ && Features(out_shape[deconv::kOut]) >= cudnn_find_issue_max_features) {
+      flt->Set(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true);
+    }
+  }
+
+  void SelectAlgo(const RunContext& rctx,
+                  const std::vector<TShape>& in_shape,
+                  const std::vector<TShape>& out_shape,
+                  cudnnDataType_t cudnn_forward_compute_type,
+                  cudnnDataType_t cudnn_backward_compute_type) {
+    auto algo_setter = [&](CuDNNAlgo<cudnnConvolutionFwdAlgo_t> *fwd,
+                           CuDNNAlgo<cudnnConvolutionBwdDataAlgo_t> *bwd,
+                           CuDNNAlgo<cudnnConvolutionBwdFilterAlgo_t> *flt) {
+      if (param_.cudnn_tune.value() == deconv::kOff) {
+        // The routine will only be calling cudnnGet, so no need to grab the Storage lock.
+        this->CuDNNAlgoSetter(rctx, in_shape, out_shape,
+                              cudnn_forward_compute_type,
+                              cudnn_backward_compute_type,
+                              fwd, bwd, flt);
       } else {
-        cudnnConvolutionBwdDataAlgoPerf_t bwd_data_algo[kMaxAlgos];
-        CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(s->dnn_handle_,
-                                               filter_desc_,
-                                               in_desc_,
-                                               forward_conv_desc_,  // bwd algo used in inference
-                                               out_desc_,
-                                               kMaxAlgos,
-                                               &nalgo,
-                                               bwd_data_algo));
-        i = 0;
-        while (i < nalgo
-               && (bwd_data_algo[i].status != CUDNN_STATUS_SUCCESS
-                   || (param_.cudnn_tune.value() == deconv::kLimited
-                       && bwd_data_algo[i].memory > workspace_byte))) {
-          ++i;
-          min_memory_needs = (i == 0) ?
-                             bwd_data_algo[i].memory :
-                             std::min(min_memory_needs, bwd_data_algo[i].memory);
-        }
-        if (i == nalgo) {
-          LOG(FATAL) << nalgo << " backward data algorithms"
-                     << " (for use in deconvolution operator forward inference) with"
-                     << " minimum memory requirement " << min_memory_needs
-                     << " bytes have been tried. Workspace size is set to " << workspace_byte
-                     << " bytes, please consider reducing the batch/model size,"
-                     << " or increasing workspace size.";
-        } else {
-          back_algo_.Set(bwd_data_algo[i].algo, false);
-        }
-      }
-      #endif  // CUDNN_MAJOR < 7
+        // One potential problem is that cudnnFind() uses cudaMalloc() to directly allocate
+        // I/O and workspace areas, and these allocations may result in an out-of-memory
+        // error even though the StorageManager free pool is not empty.  Ideally, cudnnFind
+        // would use MXNet's storage allocator for its I/O and workspace areas, instead of using
+        // the area carved out by MXNET_GPU_MEM_POOL_RESERVE.
+        // To get somewhat the same effect as this, we can pre-allocate the areas needed for the
+        // I/Os (possibly triggering a desirable StorageManager::ReleaseAll()), followed by a
+        // DirectFree(), which makes these areas available for cudnn's subsequent cudaMalloc().
 
-      // Fix for issue #11241
-      int cudnn_find_issue_max_features = 64 * 1024;
-      // With deconvolution, the algo sensitivity is to a large number of output features
-      if (add_to_weight_ && Features(out_shape[deconv::kOut]) >= cudnn_find_issue_max_features) {
-        this->back_algo_w_.Set(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true);
+        // Allocate for x (or dx), w (or dw) and y (or dy).
+        ReserveElements({in_shape[conv::kData].Size(),
+                         in_shape[conv::kWeight].Size(),
+                         out_shape[conv::kOut].Size()});
+
+        // We're about to call cudnnFind so we need to quiet the system by grabbing
+        // the Storage lock.  Concurrent cudaMalloc's can disrupt the accurate timing
+        // measurements of the algos, and can prevent the cuda driver's proper freeing
+        // of cudnnFind's internal temporary allocations.  Grabbing the lock might also
+        // impede other threads from launching work on the GPU.
+        std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
+        this->CuDNNAlgoSetter(rctx, in_shape, out_shape,
+                              cudnn_forward_compute_type,
+                              cudnn_backward_compute_type,
+                              fwd, bwd, flt);
       }
+    };
+
+    // An algo specification by the user may be cached here, but another
+    // convolution will match only if identically specified.
+    // We're caching results of *Get* as well as *Find*, but these records
+    // will be held distinctly because param_.cudnn_tune is part of the key.
+    CuDNNDeconvAlgoReg::Get()->FindOrElseRegister(param_, in_shape, out_shape, dtype_,
+                                         cudnn_forward_compute_type,
+                                         cudnn_backward_compute_type,
+                                         SMArch(rctx.ctx.dev_id), add_to_weight_,
+                                         &forward_algo_, &back_algo_, &back_algo_w_, algo_setter);
 
-      // An algo specification by the user may be cached here, but another
-      // convolution will match only if identically specified.
-      // We're caching results of *Get* as well as *Find*, but these records
-      // will be held distinctly because param_.cudnn_tune is part of the key.
-      CuDNNDeconvAlgoReg::Get()->Register(param_, in_shape, out_shape, dtype_,
-                                          cudnn_forward_compute_type,
-                                          cudnn_backward_compute_type,
-                                          SMArch(rctx.ctx.dev_id), this->add_to_weight_,
-                                          this->forward_algo_,
-                                          this->back_algo_, this->back_algo_w_);
-    }
     // If we're allowing Tensor Core variants of the algos to be considered in
     // *Find*() or *Get*(), but a non-Tensor-Core algo variant is the fastest,
     // we must change the descriptor to preclude Tensor Core.  Simplest is to
@@ -818,6 +859,7 @@ class CuDNNDeconvolutionOp {
                << " please consider reducing batch/model size or increasing the workspace size";
   }
 
+
   void GetTempSize(const OpContext& ctx) {
     mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
     size_t back_data_algo_workspace_size = 0;
@@ -921,6 +963,16 @@ class CuDNNDeconvolutionOp {
     return c;
   }
 
+  // Make a number of allocations and directly free them, ensuring room for an equivalent set of
+  // cudaMalloc() calls by (say) cudnnFind().  `elements` lists alloc sizes in DTypes, not bytes.
+  void ReserveElements(const std::vector<size_t> &elements) {
+    std::vector<Storage::Handle> handles;
+    for (size_t alloc_element : elements)
+        handles.push_back(Storage::Get()->Alloc(alloc_element * sizeof(DType), Context::GPU()));
+    for (auto &handle : handles)
+        Storage::Get()->DirectFree(handle);
+  }
+
   std::vector<int> param_stride_;
   std::vector<int> param_dilate_;
 
diff --git a/src/storage/pooled_storage_manager.h b/src/storage/pooled_storage_manager.h
index f3a9b16cdd8..8694ab23a87 100644
--- a/src/storage/pooled_storage_manager.h
+++ b/src/storage/pooled_storage_manager.h
@@ -57,6 +57,7 @@ class GPUPooledStorageManager final : public StorageManager {
   GPUPooledStorageManager() {
     reserve_ = dmlc::GetEnv("MXNET_GPU_MEM_POOL_RESERVE", 5);
     page_size_ = dmlc::GetEnv("MXNET_GPU_MEM_POOL_PAGE_SIZE", 4096);
+    large_alloc_round_size_ = dmlc::GetEnv("MXNET_GPU_MEM_LARGE_ALLOC_ROUND_SIZE", 2 * 1024 * 1024);
     if (page_size_ < NDEV) {
       LOG(FATAL) << "MXNET_GPU_MEM_POOL_PAGE_SIZE cannot be set to a value smaller than " << NDEV \
                  << ". Got " << page_size_ << ".";
@@ -80,7 +81,7 @@ class GPUPooledStorageManager final : public StorageManager {
  private:
   void DirectFreeNoLock(Storage::Handle handle) {
     cudaError_t err = cudaFree(handle.dptr);
-    size_t size = std::max(handle.size, page_size_);
+    size_t size = RoundAllocSize(handle.size);
     // ignore unloading error, as memory has already been recycled
     if (err != cudaSuccess && err != cudaErrorCudartUnloading) {
       LOG(FATAL) << "CUDA: " << cudaGetErrorString(err);
@@ -88,12 +89,31 @@ class GPUPooledStorageManager final : public StorageManager {
     used_memory_ -= size;
   }
 
+  // Round a value 'x' up to the next multiple of 'multiple'
+  size_t RoundToMultiple(size_t x, size_t multiple) {
+    size_t retVal = ((x + multiple - 1) / multiple) * multiple;
+    return retVal;
+  }
+
+  size_t RoundAllocSize(size_t size) {
+    // Round up small allocs to the page_size_ to consolidate the pool lookups
+    size = std::max(size, page_size_);
+    // To ensure proper freeing under some driver variants, make sure
+    // large allocs entirely occupy their slabs, which cannot then be
+    // locked by smaller permanent allocations sharing the slab.
+    if (size > large_alloc_round_size_)
+      size = RoundToMultiple(size, large_alloc_round_size_);
+    return size;
+  }
+
  private:
   void ReleaseAll();
   // used memory
   size_t used_memory_ = 0;
   // page size
   size_t page_size_;
+  // size that large allocations should be rounded to, for proper freeing.
+  size_t large_alloc_round_size_;
   // percentage of reserved memory
   int reserve_;
   // number of devices
@@ -105,7 +125,7 @@ class GPUPooledStorageManager final : public StorageManager {
 
 void GPUPooledStorageManager::Alloc(Storage::Handle* handle) {
   std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
-  size_t size = std::max(handle->size, page_size_);
+  size_t size = RoundAllocSize(handle->size);
   auto&& reuse_it = memory_pool_.find(size);
   if (reuse_it == memory_pool_.end() || reuse_it->second.size() == 0) {
     size_t free, total;
@@ -130,7 +150,7 @@ void GPUPooledStorageManager::Alloc(Storage::Handle* handle) {
 
 void GPUPooledStorageManager::Free(Storage::Handle handle) {
   std::lock_guard<std::mutex> lock(Storage::Get()->GetMutex(Context::kGPU));
-  size_t size = std::max(handle.size, page_size_);
+  size_t size = RoundAllocSize(handle.size);
   auto&& reuse_pool = memory_pool_[size];
   reuse_pool.push_back(handle.dptr);
 }
diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index 8394276c8ef..8ada95b3bf2 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -25,12 +25,14 @@
 import mxnet as mx
 import numpy as np
 import unittest
+import math
 from nose.tools import assert_raises
 from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal
 from mxnet.base import MXNetError
 from mxnet import autograd
 from numpy.testing import assert_allclose
 
+
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
 sys.path.insert(0, os.path.join(curr_path, '../unittest'))
 from common import setup_module, with_seed, teardown, assert_raises_cudnn_disabled
@@ -57,7 +59,7 @@ def check_rnn_layer(layer):
     for g, c in zip(gs, cs):
         assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6)
 
-
+@with_seed()
 def check_rnn_layer_w_rand_inputs(layer):
     layer.collect_params().initialize(ctx=[mx.cpu(0), mx.gpu(0)])
     x = mx.nd.uniform(shape=(10, 16, 30))
@@ -186,7 +188,7 @@ def _syncParameters(bn1, bn2, ctx):
     input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0)
     assert_almost_equal(input1.grad.asnumpy(), input2grad.asnumpy(), atol=1e-3, rtol=1e-3)
 
-
+@with_seed()
 def test_sync_batchnorm():
     def get_num_devices():
         for i in range(100):
@@ -203,6 +205,7 @@ def get_num_devices():
         _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)),
                                 num_devices=ndev, cuda=True)
 
+
 @with_seed()
 def test_symbol_block_fp16():
     # Test case to verify if initializing the SymbolBlock from a model with params
@@ -233,6 +236,45 @@ def test_symbol_block_fp16():
             break
     assert np.dtype(net_fp16.params[param_name].dtype) == np.dtype(np.float16)
 
+
+@with_seed()
+def test_large_models():
+    ctx = default_context()
+    # Create model
+    net = gluon.nn.HybridSequential()
+
+    largest_num_features = 256
+    with net.name_scope():
+        net.add(nn.Conv2D(largest_num_features, 3))
+
+    net.hybridize()
+    net.initialize(mx.init.Normal(sigma=0.01), ctx=ctx)
+
+    # Compute the height (=width) of the square tensor of the given size in bytes
+    def tensor_size(big_tensor_bytes):
+        bytes_per_float = 4
+        sz = int(math.sqrt(big_tensor_bytes / largest_num_features / bytes_per_float))
+        return (sz // 100) * 100
+
+    # The idea is to create models with large tensors of (say) 20% of the total memory.
+    # This in the past has given cudnnFind() trouble when it needed to allocate similar I/O's
+    # from the area carved out by the MXNET_GPU_MEM_POOL_RESERVE setting (by default 5%).
+    (free_mem_bytes, total_mem_bytes) = mx.context.gpu_memory_info(ctx.device_id)
+    start_size = tensor_size(0.20 * total_mem_bytes)
+    num_trials = 10
+    sys.stderr.write(' testing global memory of size {} ... '.format(total_mem_bytes))
+    sys.stderr.flush()
+    for i in range(num_trials):
+        sz = start_size - 10 * i
+        (height, width) = (sz,sz)
+        sys.stderr.write(" {}x{} ".format(height,width))
+        sys.stderr.flush()
+        data_in = nd.random_uniform(low=0, high=255, shape=(1, 3, height, width),
+                                    ctx=ctx, dtype="float32")
+        # Evaluate model
+        net(data_in).asnumpy()
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index a6932d252dd..02dc6cee4a6 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -1931,7 +1931,6 @@ def hybrid_forward(self, F, x):
 
 
 @with_seed()
-@unittest.skip('Test failing, tracked by https://github.com/apache/incubator-mxnet/issues/12715')
 def test_slice_batchnorm():
     class Net(gluon.HybridBlock):
         def __init__(self, slice, **kwargs):


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services