You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/08/09 23:54:31 UTC

[incubator-mxnet] branch master updated: Tensorcore conv deconv support (#7347)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new e61fa7c  Tensorcore conv deconv support (#7347)
e61fa7c is described below

commit e61fa7c86db06ac960e76bbe577480d446084502
Author: Dick Carter <di...@comcast.net>
AuthorDate: Wed Aug 9 16:54:24 2017 -0700

    Tensorcore conv deconv support (#7347)
    
    * Adds support for TensorCore in conv and deconv.
    
    * Style correction: Adding '_' to cudnn_tensor_core member variable in conv and deconv.
    
    * Style correction: Adding '_' to cudnn_tensor_core member variable in rnn.
    
    * Adding missing includes needed for compile on Windows.
    
    * Empty commit to test CI failure repeatability.
    
    * Changed cached algo selections to be per sm_arch, not device_id.
---
 src/common/cuda_utils.h                | 126 ++++++++++
 src/operator/convolution.cu            |  14 +-
 src/operator/cudnn_algoreg-inl.h       |  51 +++-
 src/operator/cudnn_convolution-inl.h   | 436 ++++++++++++++++++++++-----------
 src/operator/cudnn_deconvolution-inl.h | 422 +++++++++++++++++++++----------
 src/operator/cudnn_rnn-inl.h           |  46 +++-
 src/operator/deconvolution.cu          |   6 +-
 7 files changed, 806 insertions(+), 295 deletions(-)

diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h
index 3c4d1a8..2879ab3 100644
--- a/src/common/cuda_utils.h
+++ b/src/common/cuda_utils.h
@@ -25,6 +25,8 @@
 #define MXNET_COMMON_CUDA_UTILS_H_
 
 #include <dmlc/logging.h>
+#include <dmlc/parameter.h>
+#include <dmlc/optional.h>
 #include <mshadow/base.h>
 
 /*! \brief Macros/inlines to assist CLion to parse Cuda files (*.cu, *.cuh) */
@@ -175,6 +177,79 @@ inline const char* CurandGetErrorString(curandStatus_t status) {
         << "cuRAND: " << common::cuda::CurandGetErrorString(e); \
   }
 
+/*!
+ * \brief Query the major part of a GPU's CUDA compute capability.
+ * \param device_id Index of the CUDA-capable device to query.
+ * \return The major compute-capability number (e.g. 7 for an sm_70 device).
+ */
+inline int ComputeCapabilityMajor(int device_id) {
+  int cc_major = 0;
+  CUDA_CALL(cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor,
+                                   device_id));
+  return cc_major;
+}
+
+/*!
+ * \brief Query the minor part of a GPU's CUDA compute capability.
+ * \param device_id Index of the CUDA-capable device to query.
+ * \return The minor compute-capability number (e.g. 3 for an sm_53 device).
+ */
+inline int ComputeCapabilityMinor(int device_id) {
+  int cc_minor = 0;
+  CUDA_CALL(cudaDeviceGetAttribute(&cc_minor, cudaDevAttrComputeCapabilityMinor,
+                                   device_id));
+  return cc_minor;
+}
+
+/*!
+ * \brief Encode a GPU's compute capability as one integer: major * 10 + minor.
+ * \param device_id Index of the CUDA-capable device to query.
+ * \return The SM architecture number (e.g. 70 for Volta, 53 for sm_53).
+ */
+inline int SMArch(int device_id) {
+  const int cc_major = ComputeCapabilityMajor(device_id);
+  const int cc_minor = ComputeCapabilityMinor(device_id);
+  return cc_major * 10 + cc_minor;
+}
+
+/*!
+ * \brief Report whether a GPU's architecture supports float16 math.
+ * \param device_id Index of the CUDA-capable device to query.
+ * \return true for compute capability 5.3 or newer, false otherwise.
+ */
+inline bool SupportsFloat16Compute(int device_id) {
+  // fp16 compute starts at sm_53; Kepler and most Maxwell GPUs lack it.
+  const int cc_major = ComputeCapabilityMajor(device_id);
+  const int cc_minor = ComputeCapabilityMinor(device_id);
+  const bool sm53_or_newer = cc_major > 5 || (cc_major == 5 && cc_minor >= 3);
+  return sm53_or_newer;
+}
+
+/*!
+ * \brief Report whether a GPU's architecture supports Tensor Core math.
+ * \param device_id Index of the CUDA-capable device to query.
+ * \return true when the compute-capability major number is 7 (Volta) or above.
+ */
+inline bool SupportsTensorCore(int device_id) {
+  // TensorCore algos arrived with Volta (sm_70), so the major number alone decides.
+  const int cc_major = ComputeCapabilityMajor(device_id);
+  return cc_major >= 7;
+}
+
+// Policy used when the MXNET_CUDA_ALLOW_TENSOR_CORE environment variable is not set.
+#define MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT true
+
+/*!
+ * \brief Read the global TensorCore algo policy from the environment.
+ * \return whether TensorCore algos are allowed (unless an Operator decides locally).
+ */
+inline bool GetEnvAllowTensorCore() {
+  // Parsing into optional<bool> accepts "0", "1", "true" and "false" alike.
+  dmlc::optional<bool> fallback(MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT);
+  dmlc::optional<bool> setting = dmlc::GetEnv("MXNET_CUDA_ALLOW_TENSOR_CORE", fallback);
+  return setting.value();
+}
+
 #endif  // MXNET_USE_CUDA
 
 #if MXNET_USE_CUDNN
@@ -187,6 +262,57 @@ inline const char* CurandGetErrorString(curandStatus_t status) {
     CHECK_EQ(e, CUDNN_STATUS_SUCCESS) << "cuDNN: " << cudnnGetErrorString(e); \
   }
 
+/*!
+ * \brief Number of cudnnConvolutionFwdAlgoPerf_t entries worth allocating for a
+ *        call to cudnnFindConvolutionForwardAlgorithm().
+ * \param cudnn_handle cuDNN handle the count is queried through (cuDNN v7+ only).
+ * \return The upper bound reported by cuDNN on v7 and later, or a fixed 10 on
+ *         earlier releases.
+ */
+inline int MaxForwardAlgos(cudnnHandle_t cudnn_handle) {
+#if CUDNN_MAJOR < 7
+  return 10;  // no max-count query before cuDNN v7; 10 matches the legacy kMaxAlgos
+#else
+  int algo_count = 0;
+  CUDNN_CALL(cudnnGetConvolutionForwardAlgorithmMaxCount(cudnn_handle, &algo_count));
+  return algo_count;
+#endif
+}
+
+/*!
+ * \brief Number of cudnnConvolutionBwdFilterAlgoPerf_t entries worth allocating
+ *        for a call to cudnnFindConvolutionBackwardFilterAlgorithm().
+ * \param cudnn_handle cuDNN handle the count is queried through (cuDNN v7+ only).
+ * \return The upper bound reported by cuDNN on v7 and later, or a fixed 10 on
+ *         earlier releases.
+ */
+inline int MaxBackwardFilterAlgos(cudnnHandle_t cudnn_handle) {
+#if CUDNN_MAJOR < 7
+  return 10;  // no max-count query before cuDNN v7; 10 matches the legacy kMaxAlgos
+#else
+  int algo_count = 0;
+  CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(cudnn_handle, &algo_count));
+  return algo_count;
+#endif
+}
+
+/*!
+ * \brief Number of cudnnConvolutionBwdDataAlgoPerf_t entries worth allocating
+ *        for a call to cudnnFindConvolutionBackwardDataAlgorithm().
+ * \param cudnn_handle cuDNN handle the count is queried through (cuDNN v7+ only).
+ * \return The upper bound reported by cuDNN on v7 and later, or a fixed 10 on
+ *         earlier releases.
+ */
+inline int MaxBackwardDataAlgos(cudnnHandle_t cudnn_handle) {
+#if CUDNN_MAJOR < 7
+  return 10;  // no max-count query before cuDNN v7; 10 matches the legacy kMaxAlgos
+#else
+  int algo_count = 0;
+  CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(cudnn_handle, &algo_count));
+  return algo_count;
+#endif
+}
+
 #endif  // MXNET_USE_CUDNN
 
 // Overload atomicAdd to work for floats on all architectures
diff --git a/src/operator/convolution.cu b/src/operator/convolution.cu
index bf5f305..ab35484 100644
--- a/src/operator/convolution.cu
+++ b/src/operator/convolution.cu
@@ -71,14 +71,14 @@ Operator* CreateOp<gpu>(ConvolutionParam param, int dtype,
       int backward_compute_type = desired_backward_compute_type;
       bool convolutionIsSupported = CuDNNConvolutionOp<DType>::Supports(param,
                                           forward_compute_type,
-                                          backward_compute_type);
+                                          backward_compute_type, ctx);
 
       // If cuDNN can't handle this case with fp16 backprop kernels, try fp32 backprop.
       if (!convolutionIsSupported && backward_compute_type == mshadow::kFloat16) {
         backward_compute_type = mshadow::kFloat32;
         convolutionIsSupported = CuDNNConvolutionOp<DType>::Supports(param,
                                           forward_compute_type,
-                                          backward_compute_type);
+                                          backward_compute_type, ctx);
       }
 
       // If cuDNN can't handle this case with fp16 forward kernels, try fp32
@@ -86,16 +86,16 @@ Operator* CreateOp<gpu>(ConvolutionParam param, int dtype,
         forward_compute_type = mshadow::kFloat32;
         convolutionIsSupported = CuDNNConvolutionOp<DType>::Supports(param,
                                           forward_compute_type,
-                                          backward_compute_type);
+                                          backward_compute_type, ctx);
       }
       if (!convolutionIsSupported) {
         LOG(WARNING) << "This convolution is not supported by cudnn, MXNET convolution is applied.";
         op = new ConvolutionOp<gpu, DType>(param);
       } else {
-        if ((forward_compute_type != desired_forward_compute_type) ||
-            (backward_compute_type != desired_backward_compute_type))
-          LOG(WARNING) << "True fp16 convolution by cudnn not supported in this configuration.  " <<
-                       "Falling back to pseudo fp16.";
+        if (forward_compute_type != desired_forward_compute_type)
+          LOG(WARNING) << "Requested forward compute precision not supported, using fp32.";
+        if (backward_compute_type != desired_backward_compute_type)
+          LOG(WARNING) << "Requested backward compute precision not supported, using fp32.";
         op = new CuDNNConvolutionOp<DType>(param,
                                          forward_compute_type,
                                          backward_compute_type,
diff --git a/src/operator/cudnn_algoreg-inl.h b/src/operator/cudnn_algoreg-inl.h
index 1078d65..dc5db6b 100644
--- a/src/operator/cudnn_algoreg-inl.h
+++ b/src/operator/cudnn_algoreg-inl.h
@@ -32,11 +32,35 @@
 #include "../common/cuda_utils.h"
 #include "./convolution-inl.h"
 #include "./deconvolution-inl.h"
-
 namespace mxnet {
 namespace op {
 #if MXNET_USE_CUDNN == 1
 
+/*!
+ * \brief A cuDNN algorithm: an algo number and whether it should run in TENSOR CORE mode.
+ */
+template <typename CuDNNAlgoType>
+class CuDNNAlgo {
+ public:
+  /*! \brief Starts as algo number 0 with Tensor Core mode off. */
+  CuDNNAlgo() : algo_number_(static_cast<CuDNNAlgoType>(0)), is_tensor_core_algo_(false) {}
+  void Set(CuDNNAlgoType algo, bool is_tensor_core) {
+    algo_number_ = algo;
+    is_tensor_core_algo_ = is_tensor_core;
+  }
+  CuDNNAlgoType AlgoNumber() const { return algo_number_; }
+  bool IsTensorCoreAlgo() const { return is_tensor_core_algo_; }
+#if CUDNN_MAJOR >= 7
+  /*! \brief cuDNN math mode for this algo; const so it is usable via const references. */
+  cudnnMathType_t MathType() const {
+    return IsTensorCoreAlgo() ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH;
+  }
+#endif
+ private:
+  CuDNNAlgoType algo_number_;
+  bool is_tensor_core_algo_;
+};
+
 class CuDNNAlgoReg {
  public:
   template <typename Param>
@@ -44,7 +68,8 @@ class CuDNNAlgoReg {
                      const std::vector<TShape> &out_shape,
                      cudnnDataType_t cudnn_data_type,
                      cudnnDataType_t cudnn_forward_compute_type,
-                     cudnnDataType_t cudnn_backward_compute_type) {
+                     cudnnDataType_t cudnn_backward_compute_type,
+                     int sm_arch) {
     std::ostringstream oss;
     oss << "inputs=";
     for (auto &i : in_shape)
@@ -58,12 +83,15 @@ class CuDNNAlgoReg {
     oss << "cudnn_data_type=" << cudnn_data_type << ";";
     oss << "cudnn_forward_compute_type=" << cudnn_forward_compute_type << ";";
     oss << "cudnn_backward_compute_type=" << cudnn_backward_compute_type << ";";
+    // All GPUs of the same compute capability (SM arch) share an algo selection.
+    oss << "sm_arch=" << sm_arch << ";";
     return oss.str();
   }
 
-  bool Find(std::string key, cudnnConvolutionFwdAlgo_t *fwd,
-            cudnnConvolutionBwdDataAlgo_t *bwd,
-            cudnnConvolutionBwdFilterAlgo_t *flt) {
+  bool Find(std::string key,
+            CuDNNAlgo<cudnnConvolutionFwdAlgo_t> *fwd,
+            CuDNNAlgo<cudnnConvolutionBwdDataAlgo_t> *bwd,
+            CuDNNAlgo<cudnnConvolutionBwdFilterAlgo_t> *flt) {
     std::lock_guard<std::mutex> guard(lock_);
     auto i = reg_.find(key);
     if (i != reg_.end()) {
@@ -75,9 +103,10 @@ class CuDNNAlgoReg {
     return false;
   }
 
-  void Register(std::string key, cudnnConvolutionFwdAlgo_t fwd,
-                cudnnConvolutionBwdDataAlgo_t bwd,
-                cudnnConvolutionBwdFilterAlgo_t flt) {
+  void Register(std::string key,
+                const CuDNNAlgo<cudnnConvolutionFwdAlgo_t> &fwd,
+                const CuDNNAlgo<cudnnConvolutionBwdDataAlgo_t> &bwd,
+                const CuDNNAlgo<cudnnConvolutionBwdFilterAlgo_t> &flt) {
     std::lock_guard<std::mutex> guard(lock_);
     if (reg_.size() % 50 == 0) {
       LOG(INFO) << "Running performance tests to find the best convolution "
@@ -100,9 +129,9 @@ class CuDNNAlgoReg {
 
  private:
   struct CudnnAlgorithms {
-    cudnnConvolutionFwdAlgo_t fwd;
-    cudnnConvolutionBwdDataAlgo_t bwd;
-    cudnnConvolutionBwdFilterAlgo_t flt;
+    CuDNNAlgo<cudnnConvolutionFwdAlgo_t> fwd;
+    CuDNNAlgo<cudnnConvolutionBwdDataAlgo_t> bwd;
+    CuDNNAlgo<cudnnConvolutionBwdFilterAlgo_t> flt;
   };
 
   std::mutex lock_;
diff --git a/src/operator/cudnn_convolution-inl.h b/src/operator/cudnn_convolution-inl.h
index e966b56..4282784 100644
--- a/src/operator/cudnn_convolution-inl.h
+++ b/src/operator/cudnn_convolution-inl.h
@@ -59,6 +59,8 @@ class CuDNNConvolutionOp : public Operator {
     init_cudnn_ = false;
     init_temp_size_ = false;
     dtype_ = DataType<DType>::kCudnnFlag;
+    // TensorCore algos only allowed on fp16-I/O convolutions if permitted by the global policy.
+    cudnn_tensor_core_ = DataType<DType>::kFlag == kFloat16 && GetEnvAllowTensorCore();
 
 #if CUDNN_MAJOR >= 5
     MSHADOW_LAYOUT_SWITCH(param_.layout.value(), Layout, {
@@ -69,7 +71,7 @@ class CuDNNConvolutionOp : public Operator {
       << "Need CuDNN > 5.0 for layout support";
 #endif
     // Double check to make sure this class supports the operation
-    if (!Supports(param, forward_compute_type, backward_compute_type))
+    if (!Supports(param, forward_compute_type, backward_compute_type, ctx))
       LOG(FATAL) << "Need CuDNN >= 6.0 for dilated convolution.";
 
     InitDescriptors(ctx, in_shape, out_shape,
@@ -95,7 +97,8 @@ class CuDNNConvolutionOp : public Operator {
       CUDNN_CALL(cudnnDestroyTensorDescriptor(bias_desc_));
       CUDNN_CALL(cudnnDestroyFilterDescriptor(filter_desc_));
       CUDNN_CALL(cudnnDestroyConvolutionDescriptor(forward_conv_desc_));
-      CUDNN_CALL(cudnnDestroyConvolutionDescriptor(backward_conv_desc_));
+      CUDNN_CALL(cudnnDestroyConvolutionDescriptor(back_conv_desc_));
+      CUDNN_CALL(cudnnDestroyConvolutionDescriptor(back_conv_desc_w_));
     }
   }
 
@@ -148,7 +151,7 @@ class CuDNNConvolutionOp : public Operator {
                                        filter_desc_,
                                        wmat_ptr + weight_offset_ * g,
                                        forward_conv_desc_,
-                                       algo_,
+                                       forward_algo_.AlgoNumber(),
                                        workspace.dptr_,
                                        workspace_size,
                                        req[conv::kOut] == kAddTo? &beta_add : &beta,
@@ -244,8 +247,8 @@ class CuDNNConvolutionOp : public Operator {
                data_ptr + data_offset_ * g,
                out_desc_,
                grad_ptr + out_offset_ * g,
-               backward_conv_desc_,
-               back_algo_w_,
+               back_conv_desc_w_,
+               back_algo_w_.AlgoNumber(),
                workspace.dptr_,
                workspace_size,
                req[conv::kWeight] == kAddTo? &beta_add : &beta,
@@ -258,8 +261,8 @@ class CuDNNConvolutionOp : public Operator {
                data_ptr + data_offset_ * g,
                out_desc_,
                grad_ptr + out_offset_ * g,
-               backward_conv_desc_,
-               back_algo_w_,
+               back_conv_desc_w_,
+               back_algo_w_.AlgoNumber(),
                workspace.dptr_,
                workspace_size,
                req[conv::kWeight] == kAddTo? &beta_add : &beta,
@@ -275,8 +278,8 @@ class CuDNNConvolutionOp : public Operator {
                wmat_ptr + weight_offset_ * g,
                out_desc_,
                grad_ptr + out_offset_ * g,
-               backward_conv_desc_,
-               back_algo_,
+               back_conv_desc_,
+               back_algo_.AlgoNumber(),
                workspace.dptr_,
                workspace_size,
                req[conv::kData] == kAddTo? &beta_add : &beta,
@@ -289,8 +292,8 @@ class CuDNNConvolutionOp : public Operator {
                wmat_ptr + weight_offset_ * g,
                out_desc_,
                grad_ptr + out_offset_ * g,
-               backward_conv_desc_,
-               back_algo_,
+               back_conv_desc_,
+               back_algo_.AlgoNumber(),
                workspace.dptr_,
                workspace_size,
                req[conv::kData] == kAddTo? &beta_add : &beta,
@@ -308,7 +311,8 @@ class CuDNNConvolutionOp : public Operator {
  */
   static bool Supports(ConvolutionParam param,
                        int forward_compute_type,
-                       int backward_compute_type) {
+                       int backward_compute_type,
+                       const Context &ctx) {
     using namespace mshadow;
 
     // NDHWC not supported, NHWC not supported in true fp16
@@ -318,6 +322,12 @@ class CuDNNConvolutionOp : public Operator {
     if (layout_val == kNDHWC || layout_val == kNHWC && true_fp16)
       return false;
 
+    // Permits graceful fallback to pseudo-fp16 on heterogenous systems
+    if (!SupportsFloat16Compute(ctx.dev_id) &&
+        (forward_compute_type == kFloat16 || backward_compute_type == kFloat16)) {
+      return false;
+    }
+
     // The factor by which the effective filter size grows based on dilation.
     auto filterDilationFactor = param.dilate.Size();
 
@@ -355,7 +365,8 @@ class CuDNNConvolutionOp : public Operator {
     CUDNN_CALL(cudnnCreateTensorDescriptor(&bias_desc_));
     CUDNN_CALL(cudnnCreateFilterDescriptor(&filter_desc_));
     CUDNN_CALL(cudnnCreateConvolutionDescriptor(&forward_conv_desc_));
-    CUDNN_CALL(cudnnCreateConvolutionDescriptor(&backward_conv_desc_));
+    CUDNN_CALL(cudnnCreateConvolutionDescriptor(&back_conv_desc_));
+    CUDNN_CALL(cudnnCreateConvolutionDescriptor(&back_conv_desc_w_));
 
     TShape dshape = in_shape[conv::kData];
     TShape wshape = in_shape[conv::kWeight];
@@ -379,7 +390,16 @@ class CuDNNConvolutionOp : public Operator {
                                                param_.dilate[1],
                                                CUDNN_CROSS_CORRELATION,
                                                cudnn_forward_compute_type));
-      CUDNN_CALL(cudnnSetConvolution2dDescriptor(backward_conv_desc_,
+      CUDNN_CALL(cudnnSetConvolution2dDescriptor(back_conv_desc_,
+                                               param_.pad[0],
+                                               param_.pad[1],
+                                               param_.stride[0],
+                                               param_.stride[1],
+                                               param_.dilate[0],
+                                               param_.dilate[1],
+                                               CUDNN_CROSS_CORRELATION,
+                                               cudnn_backward_compute_type));
+      CUDNN_CALL(cudnnSetConvolution2dDescriptor(back_conv_desc_w_,
                                                param_.pad[0],
                                                param_.pad[1],
                                                param_.stride[0],
@@ -397,7 +417,15 @@ class CuDNNConvolutionOp : public Operator {
                                                param_.dilate[0],
                                                param_.dilate[1],
                                                CUDNN_CROSS_CORRELATION));
-      CUDNN_CALL(cudnnSetConvolution2dDescriptor(backward_conv_desc_,
+      CUDNN_CALL(cudnnSetConvolution2dDescriptor(back_conv_desc_,
+                                               param_.pad[0],
+                                               param_.pad[1],
+                                               param_.stride[0],
+                                               param_.stride[1],
+                                               param_.dilate[0],
+                                               param_.dilate[1],
+                                               CUDNN_CROSS_CORRELATION));
+      CUDNN_CALL(cudnnSetConvolution2dDescriptor(back_conv_desc_w_,
                                                param_.pad[0],
                                                param_.pad[1],
                                                param_.stride[0],
@@ -460,7 +488,15 @@ class CuDNNConvolutionOp : public Operator {
                                                CUDNN_CROSS_CORRELATION,
                                                cudnn_forward_compute_type));
 
-      CUDNN_CALL(cudnnSetConvolutionNdDescriptor(backward_conv_desc_,
+      CUDNN_CALL(cudnnSetConvolutionNdDescriptor(back_conv_desc_,
+                                               3,
+                                               param_pad_.data(),
+                                               param_stride_.data(),
+                                               param_dilate_.data(),
+                                               CUDNN_CROSS_CORRELATION,
+                                               cudnn_backward_compute_type));
+
+      CUDNN_CALL(cudnnSetConvolutionNdDescriptor(back_conv_desc_w_,
                                                3,
                                                param_pad_.data(),
                                                param_stride_.data(),
@@ -484,6 +520,14 @@ class CuDNNConvolutionOp : public Operator {
                               param_.layout.value(), kNCDHW);
       oshape = ConvertLayout(oshape.get<5>(), param_.layout.value(), kNCDHW);
     }
+    // Set "allow tensor core" flag in convolution descriptors, if available.
+    #if CUDNN_MAJOR >= 7
+      cudnnMathType_t math_type = cudnn_tensor_core_ ? CUDNN_TENSOR_OP_MATH
+                                                    : CUDNN_DEFAULT_MATH;
+      CUDNN_CALL(cudnnSetConvolutionMathType(forward_conv_desc_, math_type));
+      CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_, math_type));
+      CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_w_, math_type));
+    #endif
     dshape[1] /= param_.num_group;
     oshape[1] /= param_.num_group;
     weight_offset_ = wshape.Size();
@@ -538,122 +582,234 @@ class CuDNNConvolutionOp : public Operator {
                   cudnnDataType_t cudnn_backward_compute_type) {
     std::string key = CuDNNAlgoReg::Get()->GetKey(param_, in_shape, out_shape, dtype_,
                                                   cudnn_forward_compute_type,
-                                                  cudnn_backward_compute_type);
-    if (CuDNNAlgoReg::Get()->Find(key, &algo_, &back_algo_, &back_algo_w_))
-      return;
-
-    Engine::VarHandle var = Engine::Get()->NewVariable();
-    Engine::Get()->PushSync([=](RunContext rctx) {
-      mshadow::Stream<gpu> *s = rctx.get_stream<gpu>();
-      CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
-      size_t workspace_byte = static_cast<size_t>(param_.workspace * sizeof(DType));
-      if (!param_.cudnn_tune.value()) {
-        // In cuDNNv6, for kNHWC, only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM is
-        // supported.  Hard-coded this since the algo find() or get() throws an FPE.
-        if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) {
-          algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
-        } else {
-          CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_,
-                 in_desc_,
-                 filter_desc_,
-                 forward_conv_desc_,
-                 out_desc_,
-                 CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
-                 workspace_byte,
-                 &(this->algo_)));
-        }
-        CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
-                 in_desc_,
-                 out_desc_,
-                 backward_conv_desc_,
-                 filter_desc_,
-                 CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
-                 workspace_byte,
-                 &(this->back_algo_w_)));
-          CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_,
-                 filter_desc_,
-                 out_desc_,
-                 backward_conv_desc_,
-                 in_desc_,
-                 CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
-                 workspace_byte,
-                 &(this->back_algo_)));
-      } else {
-        const int kMaxAlgos = 10;
-        int nalgo = kMaxAlgos;
-        int i;
-
-        // In cuDNNv6, for kNHWC, only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM is
-        // supported.  Hard-coded this since the algo find() or get() throws an FPE.
-        if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) {
-          algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
-        } else {
-          cudnnConvolutionFwdAlgoPerf_t fwd_algo[kMaxAlgos];
-          CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(s->dnn_handle_,
-                 in_desc_,
-                 filter_desc_,
-                 forward_conv_desc_,
-                 out_desc_,
-                 kMaxAlgos,
-                 &nalgo,
-                 fwd_algo));
-          i = 0;
-          while (i < nalgo
-               && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS
-               || (param_.cudnn_tune.value() == conv::kLimited
-               && fwd_algo[i].memory > workspace_byte))) ++i;
-          if (i == nalgo) {
-            LOG(FATAL) << "Failed to find a forward convolution algorithm.";
+                                                  cudnn_backward_compute_type,
+                                                  SMArch(ctx.dev_id));
+    if (!CuDNNAlgoReg::Get()->Find(key, &forward_algo_, &back_algo_, &back_algo_w_)) {
+      // Not in algo registry, must determine via *Get*() or *Find*()
+      Engine::VarHandle var = Engine::Get()->NewVariable();
+      Engine::Get()->PushSync([=](RunContext rctx) {
+        mshadow::Stream<gpu> *s = rctx.get_stream<gpu>();
+        CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
+        size_t workspace_byte = static_cast<size_t>(param_.workspace * sizeof(DType));
+        #if CUDNN_MAJOR >= 7
+          // Starting with cuDNNv7, the algo number returned by *Get*() is not the entire
+          // story: the notion of whether the algo ran in Tensor Core mode is not known.
+          // Since we want to report the Tensor Core mode in the verbose output, we switch
+          // to using the new *Get*_v7() call.  Since the function signature of *Get*_v7() matches
+          // that of *Find*(), we can unify the find-vs-get logic by using function pointers.
+
+          // Forward Algorithm Find/Get() v7
+          std::vector<cudnnConvolutionFwdAlgoPerf_t> fwd_results(MaxForwardAlgos(s->dnn_handle_));
+          int actual_fwd_algos = 0;
+          auto fwd_algo_discoverer =
+            param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionForwardAlgorithm_v7
+                                                    : cudnnFindConvolutionForwardAlgorithm;
+          CUDNN_CALL((*fwd_algo_discoverer)(s->dnn_handle_,
+                                            in_desc_,
+                                            filter_desc_,
+                                            forward_conv_desc_,
+                                            out_desc_,
+                                            fwd_results.size(),
+                                            &actual_fwd_algos,
+                                            fwd_results.data()));
+          fwd_results.resize(actual_fwd_algos);
+          AlgoFinalSelect<cudnnConvolutionFwdAlgoPerf_t,
+                          cudnnConvolutionFwdAlgo_t>(fwd_results, "forward",
+                                                     workspace_byte, &forward_algo_);
+
+          // Backprop-to-Filter Algorithm Find/Get() v7
+          auto max_bwd_filt_algos = MaxBackwardFilterAlgos(s->dnn_handle_);
+          std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filt_results(max_bwd_filt_algos);
+          int actual_bwd_filter_algos = 0;
+          auto bwd_filter_algo_discoverer =
+            param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardFilterAlgorithm_v7
+                                                    : cudnnFindConvolutionBackwardFilterAlgorithm;
+          CUDNN_CALL((*bwd_filter_algo_discoverer)(s->dnn_handle_,
+                                                   in_desc_,
+                                                   out_desc_,
+                                                   back_conv_desc_w_,
+                                                   filter_desc_,
+                                                   bwd_filt_results.size(),
+                                                   &actual_bwd_filter_algos,
+                                                   bwd_filt_results.data()));
+          bwd_filt_results.resize(actual_bwd_filter_algos);
+          AlgoFinalSelect<cudnnConvolutionBwdFilterAlgoPerf_t,
+                          cudnnConvolutionBwdFilterAlgo_t>(bwd_filt_results, "backprop-to-filter",
+                                       workspace_byte, &back_algo_w_);
+
+          // Backprop-to-Data Algorithm Find/Get() v7
+          auto max_bwd_data_algos = MaxBackwardDataAlgos(s->dnn_handle_);
+          std::vector<cudnnConvolutionBwdDataAlgoPerf_t> bwd_data_results(max_bwd_data_algos);
+          int actual_bwd_data_algos = 0;
+          auto bwd_data_algo_discoverer =
+            param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardDataAlgorithm_v7
+                                                    : cudnnFindConvolutionBackwardDataAlgorithm;
+          CUDNN_CALL((*bwd_data_algo_discoverer)(s->dnn_handle_,
+                                                 filter_desc_,
+                                                 out_desc_,
+                                                 back_conv_desc_,
+                                                 in_desc_,
+                                                 bwd_data_results.size(),
+                                                 &actual_bwd_data_algos,
+                                                 bwd_data_results.data()));
+          bwd_data_results.resize(actual_bwd_data_algos);
+          AlgoFinalSelect<cudnnConvolutionBwdDataAlgoPerf_t,
+                          cudnnConvolutionBwdDataAlgo_t>(bwd_data_results, "backprop-to-data",
+                                        workspace_byte, &back_algo_);
+        #else
+          // CUDNN_MAJOR < 7
+          const int kMaxAlgos = 10;
+          int nalgo = kMaxAlgos;
+          int i = 0;
+          // Forward Algorithm Find/Get, v6 and earlier
+          if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) {
+            // In cuDNNv6, for kNHWC, only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM is
+            // supported.  Hard-coded this since the algo find() or get() throws an FPE.
+            forward_algo_.Set(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, false);
+          } else if (!param_.cudnn_tune.value()) {
+            cudnnConvolutionFwdAlgo_t fastest_fwd_algo;
+            CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_,
+                                                     in_desc_,
+                                                     filter_desc_,
+                                                     forward_conv_desc_,
+                                                     out_desc_,
+                                                     CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
+                                                     workspace_byte,
+                                                     &fastest_fwd_algo));
+            forward_algo_.Set(fastest_fwd_algo, false);
           } else {
-            this->algo_ = fwd_algo[i].algo;
+            cudnnConvolutionFwdAlgoPerf_t fwd_algo[kMaxAlgos];
+            CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(s->dnn_handle_,
+                                                            in_desc_,
+                                                            filter_desc_,
+                                                            forward_conv_desc_,
+                                                            out_desc_,
+                                                            kMaxAlgos,
+                                                            &nalgo,
+                                                            fwd_algo));
+            i = 0;
+            while (i < nalgo
+                   && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS
+                       || (param_.cudnn_tune.value() == conv::kLimited
+                           && fwd_algo[i].memory > workspace_byte)))
+              ++i;
+            if (i == nalgo) {
+              LOG(FATAL) << "Failed to find a forward convolution algorithm.";
+            } else {
+              forward_algo_.Set(fwd_algo[i].algo, false);
+            }
           }
-        }
-
-        cudnnConvolutionBwdFilterAlgoPerf_t bwd_filter_algo[kMaxAlgos];
-        CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
-                 in_desc_,
-                 out_desc_,
-                 backward_conv_desc_,
-                 filter_desc_,
-                 kMaxAlgos,
-                 &nalgo,
-                 bwd_filter_algo));
-        i = 0;
-        while (i < nalgo
-               && (bwd_filter_algo[i].status != CUDNN_STATUS_SUCCESS
-               || (param_.cudnn_tune.value() == conv::kLimited
-               && bwd_filter_algo[i].memory > workspace_byte))) ++i;
-        if (i == nalgo) {
-          LOG(FATAL) << "Failed to find a backward filter convolution algorithm.";
-        } else {
-          this->back_algo_w_ = bwd_filter_algo[i].algo;
-        }
-
-        cudnnConvolutionBwdDataAlgoPerf_t bwd_data_algo[kMaxAlgos];
-        CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(s->dnn_handle_,
-                 filter_desc_,
-                 out_desc_,
-                 backward_conv_desc_,
-                 in_desc_,
-                 kMaxAlgos,
-                 &nalgo,
-                 bwd_data_algo));
-        i = 0;
-        while (i < nalgo
-               && (bwd_data_algo[i].status != CUDNN_STATUS_SUCCESS
-               || (param_.cudnn_tune.value() == conv::kLimited
-               && bwd_data_algo[i].memory > workspace_byte))) ++i;
-        if (i == nalgo) {
-          LOG(FATAL) << "Failed to find a backward data convolution algorithm.";
-        } else {
-          this->back_algo_ = bwd_data_algo[i].algo;
-        }
-        CuDNNAlgoReg::Get()->Register(key, this->algo_, this->back_algo_,
+          // Backprop-to-Filter Algorithm Find/Get, v6 and earlier
+          if (!param_.cudnn_tune.value()) {
+            cudnnConvolutionBwdFilterAlgo_t fastest_bwd_filt_algo;
+            CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
+                                              in_desc_,
+                                              out_desc_,
+                                              back_conv_desc_w_,
+                                              filter_desc_,
+                                              CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
+                                              workspace_byte,
+                                              &fastest_bwd_filt_algo));
+            back_algo_w_.Set(fastest_bwd_filt_algo, false);
+          } else {
+            cudnnConvolutionBwdFilterAlgoPerf_t bwd_filter_algo[kMaxAlgos];
+            CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
+                                                                   in_desc_,
+                                                                   out_desc_,
+                                                                   back_conv_desc_w_,
+                                                                   filter_desc_,
+                                                                   kMaxAlgos,
+                                                                   &nalgo,
+                                                                   bwd_filter_algo));
+            i = 0;
+            while (i < nalgo
+                   && (bwd_filter_algo[i].status != CUDNN_STATUS_SUCCESS
+                       || (param_.cudnn_tune.value() == conv::kLimited
+                           && bwd_filter_algo[i].memory > workspace_byte)))
+              ++i;
+            if (i == nalgo) {
+              LOG(FATAL) << "Failed to find a backward filter convolution algorithm.";
+            } else {
+              back_algo_w_.Set(bwd_filter_algo[i].algo, false);
+            }
+          }
+          // Backprop-to-Data Algorithm Get(), v6 and earlier
+          if (!param_.cudnn_tune.value()) {
+            cudnnConvolutionBwdDataAlgo_t fastest_bwd_data_algo;
+            CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_,
+                                                filter_desc_,
+                                                out_desc_,
+                                                back_conv_desc_,
+                                                in_desc_,
+                                                CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
+                                                workspace_byte,
+                                                &fastest_bwd_data_algo));
+            back_algo_.Set(fastest_bwd_data_algo, false);
+          } else {
+            cudnnConvolutionBwdDataAlgoPerf_t bwd_data_algo[kMaxAlgos];
+            CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(s->dnn_handle_,
+                                                                 filter_desc_,
+                                                                 out_desc_,
+                                                                 back_conv_desc_,
+                                                                 in_desc_,
+                                                                 kMaxAlgos,
+                                                                 &nalgo,
+                                                                 bwd_data_algo));
+            i = 0;
+            while (i < nalgo
+                   && (bwd_data_algo[i].status != CUDNN_STATUS_SUCCESS
+                       || (param_.cudnn_tune.value() == conv::kLimited
+                           && bwd_data_algo[i].memory > workspace_byte)))
+              ++i;
+            if (i == nalgo) {
+              LOG(FATAL) << "Failed to find a backward data convolution algorithm.";
+            } else {
+              back_algo_.Set(bwd_data_algo[i].algo, false);
+            }
+          }
+        #endif  // CUDNN_MAJOR < 7
+        // An algo specification by the user may be cached here, but another
+        // convolution will match only if identically specified.
+        // We're caching results of *Get* as well as *Find*, but these records
+        // will be held distinctly because param_.cudnn_tune is part of the key.
+        CuDNNAlgoReg::Get()->Register(key, this->forward_algo_, this->back_algo_,
                                       this->back_algo_w_);
+      }, ctx, {}, {var});
+      Engine::Get()->WaitForVar(var);
+      Engine::Get()->DeleteVariable([](RunContext s) {}, ctx, var);
+    }
+    // If we're allowing Tensor Core variants of the algos to be considered in
+    // *Find*() or *Get*(), but a non-Tensor-Core algo variant is the fastest,
+    // we must change the descriptor to preclude Tensor Core.  Simplest is to
+    // once again set the mathType in all cases.
+    #if CUDNN_MAJOR >= 7
+      CUDNN_CALL(cudnnSetConvolutionMathType(forward_conv_desc_, forward_algo_.MathType()));
+      CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_, back_algo_.MathType()));
+      CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_w_, back_algo_w_.MathType()));
+    #endif
+  }
+
+  // Pick the first usable algo from the *Find*() or *Get*() results (assumed fastest-first)
+  // and store it in *algo; kernel_name only labels the LOG(FATAL) raised if none is usable.
+  template <typename PerfType, typename AlgoType>
+  void AlgoFinalSelect(const std::vector<PerfType> &perf_results, std::string kernel_name,
+                       size_t workspace_byte, CuDNNAlgo<AlgoType> *algo) {
+    // Usable: CUDNN_STATUS_SUCCESS, and memory <= workspace_byte if cudnn_tune is kLimited.
+    // mathType does not affect the choice; it is only recorded via algo->Set().
+    for (decltype(perf_results.size()) i = 0; i != perf_results.size(); ++i) {
+      const auto &result = perf_results[i];
+      bool algo_is_tensor_core = false;
+      #if CUDNN_MAJOR >= 7
+        algo_is_tensor_core = result.mathType == CUDNN_TENSOR_OP_MATH;
+      #endif
+      if (result.status == CUDNN_STATUS_SUCCESS &&
+          (param_.cudnn_tune.value() != conv::kLimited || result.memory <= workspace_byte)) {
+        algo->Set(result.algo, algo_is_tensor_core);
+        return;
       }
-    }, ctx, {}, {var});
-    Engine::Get()->WaitForVar(var);
-    Engine::Get()->DeleteVariable([](RunContext s) {}, ctx, var);
+    }
+    auto mode = param_.cudnn_tune.value() == conv::kOff ? " get " : " find ";
+    LOG(FATAL) << "Failed to" << mode << "any " << kernel_name << " convolution algorithm.";
   }
 
   void GetTempSize(const OpContext& ctx) {
@@ -663,16 +819,16 @@ class CuDNNConvolutionOp : public Operator {
     CUDNN_CALL(cudnnGetConvolutionBackwardDataWorkspaceSize(s->dnn_handle_,
                filter_desc_,
                out_desc_,
-               backward_conv_desc_,
+               back_conv_desc_,
                in_desc_,
-               back_algo_,
+               back_algo_.AlgoNumber(),
                &back_size));
     CUDNN_CALL(cudnnGetConvolutionBackwardFilterWorkspaceSize(s->dnn_handle_,
                in_desc_,
                out_desc_,
-               backward_conv_desc_,
+               back_conv_desc_w_,
                filter_desc_,
-               back_algo_w_,
+               back_algo_w_.AlgoNumber(),
                &back_size_w));
     backward_workspace_byte_ = std::max(back_size, back_size_w);
     CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(s->dnn_handle_,
@@ -680,7 +836,7 @@ class CuDNNConvolutionOp : public Operator {
                filter_desc_,
                forward_conv_desc_,
                out_desc_,
-               algo_,
+               forward_algo_.AlgoNumber(),
                &forward_workspace_byte_));
 
     init_temp_size_ = true;
@@ -733,15 +889,19 @@ class CuDNNConvolutionOp : public Operator {
   cudnnFilterDescriptor_t filter_desc_;
   // Convolution descriptor for forward inference operation
   cudnnConvolutionDescriptor_t forward_conv_desc_;
-  // Convolution descriptor for back-prop operations to data and filter
-  cudnnConvolutionDescriptor_t backward_conv_desc_;
+  // Convolution descriptor for back-prop operations to the data
+  cudnnConvolutionDescriptor_t back_conv_desc_;
+  // Convolution descriptor for back-prop operations to the weights
+  cudnnConvolutionDescriptor_t back_conv_desc_w_;
   // Algorithm for the forward inference operation
-  cudnnConvolutionFwdAlgo_t algo_;
+  CuDNNAlgo<cudnnConvolutionFwdAlgo_t> forward_algo_;
   // Algorithm for the back-prop operation to the data
-  cudnnConvolutionBwdDataAlgo_t back_algo_;
+  CuDNNAlgo<cudnnConvolutionBwdDataAlgo_t> back_algo_;
   // Algorithm for the back-prop operation to the weights
-  cudnnConvolutionBwdFilterAlgo_t back_algo_w_;
+  CuDNNAlgo<cudnnConvolutionBwdFilterAlgo_t> back_algo_w_;
   cudnnTensorFormat_t format_;
+  // Allow TensorCore algo policy
+  bool cudnn_tensor_core_;
   ConvolutionParam param_;
 };
 #endif  // __CUDACC__ && CUDNN
diff --git a/src/operator/cudnn_deconvolution-inl.h b/src/operator/cudnn_deconvolution-inl.h
index 8c8f055..de3e70c 100644
--- a/src/operator/cudnn_deconvolution-inl.h
+++ b/src/operator/cudnn_deconvolution-inl.h
@@ -56,6 +56,8 @@ class CuDNNDeconvolutionOp : public Operator {
     init_cudnn_ = false;
     init_temp_size_ = false;
     dtype_ = mshadow::DataType<DType>::kCudnnFlag;
+    // TensorCore algos only allowed on fp16-I/O deconvolutions if permitted by the global policy.
+    cudnn_tensor_core_ = DataType<DType>::kFlag == kFloat16 && GetEnvAllowTensorCore();
 
 #if CUDNN_MAJOR >= 5
     MSHADOW_LAYOUT_SWITCH(param_.layout.value(), Layout, {
@@ -66,7 +68,7 @@ class CuDNNDeconvolutionOp : public Operator {
       << "Need CuDNN > 5.0 for layout support";
 #endif
     // Double check to make sure this class supports the operation
-    if (!Supports(param, forward_compute_type, backward_compute_type))
+    if (!Supports(param, forward_compute_type, backward_compute_type, ctx))
       LOG(FATAL) << "Need CuDNN >= 6.0 for dilated convolution.";
 
     InitDescriptors(ctx, in_shape, out_shape,
@@ -92,7 +94,8 @@ class CuDNNDeconvolutionOp : public Operator {
       CUDNN_CALL(cudnnDestroyTensorDescriptor(bias_desc_));
       CUDNN_CALL(cudnnDestroyFilterDescriptor(filter_desc_));
       CUDNN_CALL(cudnnDestroyConvolutionDescriptor(forward_conv_desc_));
-      CUDNN_CALL(cudnnDestroyConvolutionDescriptor(backward_conv_desc_));
+      CUDNN_CALL(cudnnDestroyConvolutionDescriptor(back_conv_desc_));
+      CUDNN_CALL(cudnnDestroyConvolutionDescriptor(back_conv_desc_w_));
     }
   }
 
@@ -146,7 +149,7 @@ class CuDNNDeconvolutionOp : public Operator {
                  in_desc_,
                  data_ptr + data_offset_ * g,
                  forward_conv_desc_,  // this backward algorithm used for inference
-                 back_algo_,
+                 back_algo_.AlgoNumber(),
                  workspace.dptr_,
                  workspace_size,
                  &beta,
@@ -160,7 +163,7 @@ class CuDNNDeconvolutionOp : public Operator {
                  in_desc_,
                  data_ptr + data_offset_ * g,
                  forward_conv_desc_,  // this backward algorithm used for inference
-                 back_algo_,
+                 back_algo_.AlgoNumber(),
                  workspace.dptr_,
                  workspace_size,
                  &beta,
@@ -270,8 +273,8 @@ class CuDNNDeconvolutionOp : public Operator {
           grad_ptr + out_offset_ * g,
           in_desc_,
           data_ptr + data_offset_ * g,
-          backward_conv_desc_,
-          back_algo_w_,
+          back_conv_desc_,
+          back_algo_w_.AlgoNumber(),
           workspace.dptr_,
           workspace_size,
           &weight_beta,
@@ -285,8 +288,8 @@ class CuDNNDeconvolutionOp : public Operator {
           grad_ptr + out_offset_ * g,
           in_desc_,
           data_ptr + data_offset_ * g,
-          backward_conv_desc_,
-          back_algo_w_,
+          back_conv_desc_,
+          back_algo_w_.AlgoNumber(),
           workspace.dptr_,
           workspace_size,
           &weight_beta,
@@ -301,8 +304,8 @@ class CuDNNDeconvolutionOp : public Operator {
                                            grad_ptr + out_offset_ * g,
                                            filter_desc_,
                                            wmat_ptr + weight_offset_ * g,
-                                           backward_conv_desc_,
-                                           algo_,
+                                           back_conv_desc_,
+                                           forward_algo_.AlgoNumber(),
                                            workspace.dptr_,
                                            workspace_size,
                                            &data_beta,
@@ -319,7 +322,8 @@ class CuDNNDeconvolutionOp : public Operator {
  */
   static bool Supports(DeconvolutionParam param,
                        int forward_compute_type,
-                       int backward_compute_type) {
+                       int backward_compute_type,
+                       const Context &ctx) {
     using namespace mshadow;
 
     // NDHWC not supported, NHWC not supported in true fp16
@@ -329,6 +333,12 @@ class CuDNNDeconvolutionOp : public Operator {
     if (layout_val == kNDHWC || layout_val == kNHWC && true_fp16)
       return false;
 
+    // Permits graceful fallback to pseudo-fp16 on heterogenous systems
+    if (!SupportsFloat16Compute(ctx.dev_id) &&
+        (forward_compute_type == kFloat16 || backward_compute_type == kFloat16)) {
+      return false;
+    }
+
     // The factor by which the effective filter size grows based on dilation.
     auto filterDilationFactor = param.dilate.Size();
 
@@ -374,7 +384,8 @@ class CuDNNDeconvolutionOp : public Operator {
     CUDNN_CALL(cudnnCreateTensorDescriptor(&bias_desc_));
     CUDNN_CALL(cudnnCreateFilterDescriptor(&filter_desc_));
     CUDNN_CALL(cudnnCreateConvolutionDescriptor(&forward_conv_desc_));
-    CUDNN_CALL(cudnnCreateConvolutionDescriptor(&backward_conv_desc_));
+    CUDNN_CALL(cudnnCreateConvolutionDescriptor(&back_conv_desc_));
+    CUDNN_CALL(cudnnCreateConvolutionDescriptor(&back_conv_desc_w_));
 
     TShape dshape = in_shape[deconv::kData];
     TShape wshape = in_shape[deconv::kWeight];
@@ -398,7 +409,16 @@ class CuDNNDeconvolutionOp : public Operator {
                                                  param_.dilate[1],
                                                  CUDNN_CROSS_CORRELATION,
                                                  cudnn_forward_compute_type));
-      CUDNN_CALL(cudnnSetConvolution2dDescriptor(backward_conv_desc_,
+      CUDNN_CALL(cudnnSetConvolution2dDescriptor(back_conv_desc_,
+                                                 o_pad[0],
+                                                 o_pad[1],
+                                                 param_.stride[0],
+                                                 param_.stride[1],
+                                                 param_.dilate[0],
+                                                 param_.dilate[1],
+                                                 CUDNN_CROSS_CORRELATION,
+                                                 cudnn_backward_compute_type));
+      CUDNN_CALL(cudnnSetConvolution2dDescriptor(back_conv_desc_w_,
                                                  o_pad[0],
                                                  o_pad[1],
                                                  param_.stride[0],
@@ -416,7 +436,15 @@ class CuDNNDeconvolutionOp : public Operator {
                                                  param_.dilate[0],
                                                  param_.dilate[1],
                                                  CUDNN_CROSS_CORRELATION));
-      CUDNN_CALL(cudnnSetConvolution2dDescriptor(backward_conv_desc_,
+      CUDNN_CALL(cudnnSetConvolution2dDescriptor(back_conv_desc_,
+                                                 o_pad[0],
+                                                 o_pad[1],
+                                                 param_.stride[0],
+                                                 param_.stride[1],
+                                                 param_.dilate[0],
+                                                 param_.dilate[1],
+                                                 CUDNN_CROSS_CORRELATION));
+      CUDNN_CALL(cudnnSetConvolution2dDescriptor(back_conv_desc_w_,
                                                  o_pad[0],
                                                  o_pad[1],
                                                  param_.stride[0],
@@ -483,7 +511,15 @@ class CuDNNDeconvolutionOp : public Operator {
                                                  CUDNN_CROSS_CORRELATION,
                                                  cudnn_forward_compute_type));
 
-      CUDNN_CALL(cudnnSetConvolutionNdDescriptor(backward_conv_desc_,
+      CUDNN_CALL(cudnnSetConvolutionNdDescriptor(back_conv_desc_,
+                                                 3,
+                                                 reinterpret_cast<int*>(&o_pad[0]),
+                                                 param_stride_.data(),
+                                                 param_dilate_.data(),
+                                                 CUDNN_CROSS_CORRELATION,
+                                                 cudnn_backward_compute_type));
+
+      CUDNN_CALL(cudnnSetConvolutionNdDescriptor(back_conv_desc_w_,
                                                  3,
                                                  reinterpret_cast<int*>(&o_pad[0]),
                                                  param_stride_.data(),
@@ -507,6 +543,14 @@ class CuDNNDeconvolutionOp : public Operator {
                               param_.layout.value(), kNCDHW);
       oshape = ConvertLayout(oshape.get<5>(), param_.layout.value(), kNCDHW);
     }
+    // Set "allow tensor core" flag in convolution descriptors, if available.
+#if CUDNN_MAJOR >= 7
+    cudnnMathType_t math_type = cudnn_tensor_core_ ? CUDNN_TENSOR_OP_MATH
+                                                  : CUDNN_DEFAULT_MATH;
+    CUDNN_CALL(cudnnSetConvolutionMathType(forward_conv_desc_, math_type));
+    CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_, math_type));
+    CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_w_, math_type));
+#endif
     dshape[1] /= param_.num_group;
     oshape[1] /= param_.num_group;
     weight_offset_ = wshape.Size();
@@ -556,125 +600,242 @@ class CuDNNDeconvolutionOp : public Operator {
                   cudnnDataType_t cudnn_backward_compute_type) {
     std::string key = CuDNNAlgoReg::Get()->GetKey(param_, in_shape, out_shape, dtype_,
                                                   cudnn_forward_compute_type,
-                                                  cudnn_backward_compute_type);
-    if (CuDNNAlgoReg::Get()->Find(key, &algo_, &back_algo_, &back_algo_w_))
-      return;
-
-    Engine::VarHandle var = Engine::Get()->NewVariable();
-    Engine::Get()->PushSync([=](RunContext rctx) {
-      mshadow::Stream<gpu> *s = rctx.get_stream<gpu>();
-      CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
-      size_t workspace_byte = static_cast<size_t>(param_.workspace * sizeof(DType));
-      if (!param_.cudnn_tune.value()) {
-        // In cuDNNv6, for kNHWC, only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM is
-        // supported.  Hard-coded this since the algo find() or get() throws an FPE.
-        if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) {
-          algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
-        } else {
-          CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_,
-                     out_desc_,
-                     filter_desc_,
-                     backward_conv_desc_,  // forward algorithm used to backprop-to-data
-                     in_desc_,
-                     CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
-                     workspace_byte,
-                     &(this->algo_)));
-        }
-        CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
-                   out_desc_,
-                   in_desc_,
-                   backward_conv_desc_,
-                   filter_desc_,
-                   CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
-                   workspace_byte,
-                   &(this->back_algo_w_)));
-        CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_,
-                   filter_desc_,
-                   in_desc_,
-                   forward_conv_desc_,  // this backward algorithm used for inference
-                   out_desc_,
-                   CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
-                   workspace_byte,
-                   &(this->back_algo_)));
-      } else {
+                                                  cudnn_backward_compute_type,
+                                                  SMArch(ctx.dev_id));
+    if (!CuDNNAlgoReg::Get()->Find(key, &forward_algo_, &back_algo_, &back_algo_w_)) {
+      // Not in algo registry, must determine via *Get*() or *Find*()
+      Engine::VarHandle var = Engine::Get()->NewVariable();
+      Engine::Get()->PushSync([=](RunContext rctx) {
+        mshadow::Stream <gpu> *s = rctx.get_stream<gpu>();
+        CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
+        size_t workspace_byte = static_cast<size_t>(param_.workspace * sizeof(DType));
+        #if CUDNN_MAJOR >= 7
+          // Starting with cuDNNv7, the algo number returned by *Get*() is not the entire
+          // story: the notion of whether the algo ran in Tensor Core mode is not known.
+          // Since we want to report the Tensor Core mode in the verbose output, we switch
+          // to using the new *Get*_v7() call.  Since the function signature of *Get*_v7() matches
+          // that of *Find*(), we can unify the find-vs-get logic by using function pointers.
+
+          // Forward Algorithm Find/Get() v7
+          std::vector<cudnnConvolutionFwdAlgoPerf_t> fwd_results(MaxForwardAlgos(s->dnn_handle_));
+          int actual_fwd_algos = 0;
+          auto fwd_algo_discoverer =
+            param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionForwardAlgorithm_v7
+                                                    : cudnnFindConvolutionForwardAlgorithm;
+          CUDNN_CALL((*fwd_algo_discoverer)(s->dnn_handle_,
+                                            out_desc_,
+                                            filter_desc_,
+                                            back_conv_desc_,  // fwd algo used to backprop-to-data
+                                            in_desc_,
+                                            fwd_results.size(),
+                                            &actual_fwd_algos,
+                                            fwd_results.data()));
+          fwd_results.resize(actual_fwd_algos);
+          AlgoFinalSelect<cudnnConvolutionFwdAlgoPerf_t,
+                          cudnnConvolutionFwdAlgo_t>(fwd_results, "forward",
+                                                     workspace_byte, &forward_algo_);
+
+          // Backprop-to-Filter Algorithm Find/Get() v7
+          auto max_bwd_filt_algos = MaxBackwardFilterAlgos(s->dnn_handle_);
+          std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filt_results(max_bwd_filt_algos);
+          int actual_bwd_filter_algos = 0;
+          auto bwd_filter_algo_discoverer =
+            param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardFilterAlgorithm_v7
+                                                    : cudnnFindConvolutionBackwardFilterAlgorithm;
+          CUDNN_CALL((*bwd_filter_algo_discoverer)(s->dnn_handle_,
+                                                   out_desc_,
+                                                   in_desc_,
+                                                   back_conv_desc_,
+                                                   filter_desc_,
+                                                   bwd_filt_results.size(),
+                                                   &actual_bwd_filter_algos,
+                                                   bwd_filt_results.data()));
+          bwd_filt_results.resize(actual_bwd_filter_algos);
+          AlgoFinalSelect<cudnnConvolutionBwdFilterAlgoPerf_t,
+                          cudnnConvolutionBwdFilterAlgo_t>(bwd_filt_results, "backprop-to-filter",
+                                                           workspace_byte, &back_algo_w_);
+
+          // Backprop-to-Data Algorithm Find/Get() v7
+          auto max_bwd_data_algos = MaxBackwardDataAlgos(s->dnn_handle_);
+          std::vector<cudnnConvolutionBwdDataAlgoPerf_t> bwd_data_results(max_bwd_data_algos);
+          int actual_bwd_data_algos = 0;
+          auto bwd_data_algo_discoverer =
+            param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardDataAlgorithm_v7
+                                                    : cudnnFindConvolutionBackwardDataAlgorithm;
+          CUDNN_CALL((*bwd_data_algo_discoverer)(s->dnn_handle_,
+                                                 filter_desc_,
+                                                 in_desc_,
+                                                 forward_conv_desc_,  // bwd algo used in inference
+                                                 out_desc_,
+                                                 bwd_data_results.size(),
+                                                 &actual_bwd_data_algos,
+                                                 bwd_data_results.data()));
+          bwd_data_results.resize(actual_bwd_data_algos);
+          AlgoFinalSelect<cudnnConvolutionBwdDataAlgoPerf_t,
+                          cudnnConvolutionBwdDataAlgo_t>(bwd_data_results, "backprop-to-data",
+                                                         workspace_byte, &back_algo_);
+        #else
+        // CUDNN_MAJOR < 7
         const int kMaxAlgos = 10;
         int nalgo = kMaxAlgos;
-        int i;
-
-        // In cuDNNv6, for kNHWC, only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM is
-        // supported.  Hard-coded this since the algo find() or get() throws an FPE.
+        int i = 0;
+        // Forward Algorithm Find/Get, v6 and earlier
         if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) {
-          algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
+          // In cuDNNv6, for kNHWC, only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM is
+          // supported.  Hard-coded this since the algo find() or get() throws an FPE.
+          forward_algo_.Set(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, false);
+        } else if (!param_.cudnn_tune.value()) {
+          cudnnConvolutionFwdAlgo_t fastest_fwd_algo;
+          CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_,
+                                                     out_desc_,
+                                                     filter_desc_,
+                                                     back_conv_desc_,  // fwd algo used in dgrad
+                                                     in_desc_,
+                                                     CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
+                                                     workspace_byte,
+                                                     &fastest_fwd_algo));
+          forward_algo_.Set(fastest_fwd_algo, false);
         } else {
           cudnnConvolutionFwdAlgoPerf_t fwd_algo[kMaxAlgos];
           CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(s->dnn_handle_,
-                     out_desc_,
-                     filter_desc_,
-                     backward_conv_desc_,  // forward algorithm used to backprop-to-data
-                     in_desc_,
-                     kMaxAlgos,
-                     &nalgo,
-                     fwd_algo));
+                                                        out_desc_,
+                                                        filter_desc_,
+                                                        back_conv_desc_,  // fwd algo used in dgrad
+                                                        in_desc_,
+                                                        kMaxAlgos,
+                                                        &nalgo,
+                                                        fwd_algo));
           i = 0;
           while (i < nalgo
-               && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS
-               || (param_.cudnn_tune.value() == deconv::kLimited
-               && fwd_algo[i].memory > workspace_byte))) ++i;
+                 && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS
+                     || (param_.cudnn_tune.value() == deconv::kLimited
+                         && fwd_algo[i].memory > workspace_byte)))
+            ++i;
           if (i == nalgo) {
             LOG(FATAL) << "Failed to find a 'forward' convolution algorithm " <<
-              "(for use in deconvolution operator backprop-to-data).";
+                       "(for use in deconvolution operator backprop-to-data).";
           } else {
-            this->algo_ = fwd_algo[i].algo;
+            forward_algo_.Set(fwd_algo[i].algo, false);
           }
         }
-
-        cudnnConvolutionBwdFilterAlgoPerf_t bwd_filter_algo[kMaxAlgos];
-        CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
-                   out_desc_,
-                   in_desc_,
-                   backward_conv_desc_,
-                   filter_desc_,
-                   kMaxAlgos,
-                   &nalgo,
-                   bwd_filter_algo));
-        i = 0;
-        while (i < nalgo
-               && (bwd_filter_algo[i].status != CUDNN_STATUS_SUCCESS
-               || (param_.cudnn_tune.value() == deconv::kLimited
-               && bwd_filter_algo[i].memory > workspace_byte))) ++i;
-        if (i == nalgo) {
-          LOG(FATAL) << "Failed to find a backward filter convolution algorithm " <<
-              "(for use in deconvolution operator backprop-to-filter).";
+        // Backprop-to-Filter Algorithm Find/Get, v6 and earlier
+        if (!param_.cudnn_tune.value()) {
+          cudnnConvolutionBwdFilterAlgo_t fastest_bwd_filt_algo;
+          CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
+                                              out_desc_,
+                                              in_desc_,
+                                              back_conv_desc_,
+                                              filter_desc_,
+                                              CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
+                                              workspace_byte,
+                                              &fastest_bwd_filt_algo));
+          back_algo_w_.Set(fastest_bwd_filt_algo, false);
         } else {
-          this->back_algo_w_ = bwd_filter_algo[i].algo;
+          cudnnConvolutionBwdFilterAlgoPerf_t bwd_filter_algo[kMaxAlgos];
+          CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(s->dnn_handle_,
+                                                                 out_desc_,
+                                                                 in_desc_,
+                                                                 back_conv_desc_,
+                                                                 filter_desc_,
+                                                                 kMaxAlgos,
+                                                                 &nalgo,
+                                                                 bwd_filter_algo));
+          i = 0;
+          while (i < nalgo
+                 && (bwd_filter_algo[i].status != CUDNN_STATUS_SUCCESS
+                     || (param_.cudnn_tune.value() == deconv::kLimited
+                         && bwd_filter_algo[i].memory > workspace_byte)))
+            ++i;
+          if (i == nalgo) {
+            LOG(FATAL) << "Failed to find a backward filter convolution algorithm " <<
+                       "(for use in deconvolution operator backprop-to-filter).";
+          } else {
+            back_algo_w_.Set(bwd_filter_algo[i].algo, false);
+          }
         }
-
-        cudnnConvolutionBwdDataAlgoPerf_t bwd_data_algo[kMaxAlgos];
-        CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(s->dnn_handle_,
-                   filter_desc_,
-                   in_desc_,
-                   forward_conv_desc_,  // this backward algorithm used for inference
-                   out_desc_,
-                   kMaxAlgos,
-                   &nalgo,
-                   bwd_data_algo));
-        i = 0;
-        while (i < nalgo
-               && (bwd_data_algo[i].status != CUDNN_STATUS_SUCCESS
-               || (param_.cudnn_tune.value() == deconv::kLimited
-               && bwd_data_algo[i].memory > workspace_byte))) ++i;
-        if (i == nalgo) {
-          LOG(FATAL) << "Failed to find a backward data convolution algorithm." <<
-              "(for use in deconvolution operator forward inference).";
+        // Backprop-to-Data Algorithm Get(), v6 and earlier
+        if (!param_.cudnn_tune.value()) {
+          cudnnConvolutionBwdDataAlgo_t fastest_bwd_data_algo;
+          CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_,
+                                                filter_desc_,
+                                                in_desc_,
+                                                forward_conv_desc_,  // bwd algo used for inference
+                                                out_desc_,
+                                                CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
+                                                workspace_byte,
+                                                &fastest_bwd_data_algo));
+          back_algo_.Set(fastest_bwd_data_algo, false);
         } else {
-          this->back_algo_ = bwd_data_algo[i].algo;
+          cudnnConvolutionBwdDataAlgoPerf_t bwd_data_algo[kMaxAlgos];
+          CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(s->dnn_handle_,
+                                                 filter_desc_,
+                                                 in_desc_,
+                                                 forward_conv_desc_,  // bwd algo used in inference
+                                                 out_desc_,
+                                                 kMaxAlgos,
+                                                 &nalgo,
+                                                 bwd_data_algo));
+          i = 0;
+          while (i < nalgo
+                 && (bwd_data_algo[i].status != CUDNN_STATUS_SUCCESS
+                     || (param_.cudnn_tune.value() == deconv::kLimited
+                         && bwd_data_algo[i].memory > workspace_byte)))
+            ++i;
+          if (i == nalgo) {
+            LOG(FATAL) << "Failed to find a backward data convolution algorithm." <<
+                       "(for use in deconvolution operator forward inference).";
+          } else {
+            back_algo_.Set(bwd_data_algo[i].algo, false);
+          }
         }
-        CuDNNAlgoReg::Get()->Register(key, this->algo_, this->back_algo_,
+        #endif  // CUDNN_MAJOR < 7
+        // An algo specification by the user may be cached here, but another
+        // convolution will match only if identically specified.
+        // We're caching results of *Get* as well as *Find*, but these records
+        // will be held distinctly because param_.cudnn_tune is part of the key.
+        CuDNNAlgoReg::Get()->Register(key, this->forward_algo_, this->back_algo_,
                                       this->back_algo_w_);
+      }, ctx, {}, {var});
+      Engine::Get()->WaitForVar(var);
+      Engine::Get()->DeleteVariable([](RunContext s) {}, ctx, var);
+    }
+    // If we're allowing Tensor Core variants of the algos to be considered in
+    // *Find*() or *Get*(), but a non-Tensor-Core algo variant is the fastest,
+    // we must change the descriptor to preclude Tensor Core.  Simplest is to
+    // once again set the mathType in all cases.
+    #if CUDNN_MAJOR >= 7
+      // The next two code lines will look like they have typos, but they don't!
+      // The forward_conv_desc_ is used during inference, which invokes the back_algo_.
+      // Thus, the mathType of the back_algo_ should be stored in the forward_conv_desc_.
+      // Conversely, the back_conv_desc_ is used during training backprop, which invokes
+      // the forward_algo_.  Thus, the mathType of the forward_algo_ should be stored
+      // in the back_conv_desc_.
+      CUDNN_CALL(cudnnSetConvolutionMathType(forward_conv_desc_, back_algo_.MathType()));
+      CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_, forward_algo_.MathType()));
+      CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_w_, back_algo_w_.MathType()));
+    #endif
+  }
+
+  // Pick the first successful *Find*()/*Get*() result (cuDNN documents these as ordered
+  // fastest-first); when cudnn_tune == deconv::kLimited it must also fit in workspace_byte.
+  template <typename PerfType, typename AlgoType>
+  void AlgoFinalSelect(const std::vector<PerfType> &perf_results, std::string kernel_name,
+                       size_t workspace_byte, CuDNNAlgo<AlgoType> *algo) {
+    // Any mathType qualifies; the chosen algo records whether it is a Tensor Core variant.
+    for (decltype(perf_results.size()) i = 0; i != perf_results.size(); ++i) {
+      const auto &result = perf_results[i];
+      bool algo_is_tensor_core = false;
+      #if CUDNN_MAJOR >= 7
+        algo_is_tensor_core = result.mathType == CUDNN_TENSOR_OP_MATH;
+      #endif
+      if (result.status == CUDNN_STATUS_SUCCESS &&
+          (param_.cudnn_tune.value() != deconv::kLimited || result.memory <= workspace_byte)) {
+        algo->Set(result.algo, algo_is_tensor_core);
+        return;
       }
-    }, ctx, {}, {var});
-    Engine::Get()->WaitForVar(var);
-    Engine::Get()->DeleteVariable([](RunContext s) {}, ctx, var);
+    }
+    auto mode = param_.cudnn_tune.value() == deconv::kOff ? " get " : " find ";
+    LOG(FATAL) << "Failed to" << mode << "any " << kernel_name << " deconvolution algorithm.";
   }
 
   void GetTempSize(const OpContext& ctx) {
@@ -688,21 +849,21 @@ class CuDNNDeconvolutionOp : public Operator {
                in_desc_,
                forward_conv_desc_,
                out_desc_,
-               back_algo_,
+               back_algo_.AlgoNumber(),
                &back_data_algo_workspace_size));
     CUDNN_CALL(cudnnGetConvolutionBackwardFilterWorkspaceSize(s->dnn_handle_,
                out_desc_,
                in_desc_,
-               backward_conv_desc_,
+               back_conv_desc_,
                filter_desc_,
-               back_algo_w_,
+               back_algo_w_.AlgoNumber(),
                &back_filter_algo_workspace_size));
     CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(s->dnn_handle_,
                out_desc_,
                filter_desc_,
-               backward_conv_desc_,
+               back_conv_desc_,
                in_desc_,
-               algo_,
+               forward_algo_.AlgoNumber(),
                &forward_algo_workspace_size));
 
     forward_workspace_byte_ = back_data_algo_workspace_size;
@@ -761,19 +922,24 @@ class CuDNNDeconvolutionOp : public Operator {
   // Note that in deconvolution, the forward operation is handled
   // by the cuDNN backprop-to-data kernel.
   cudnnConvolutionDescriptor_t forward_conv_desc_;
-  // Convolution descriptor for "back-prop" operations to data and filter.
+  // Convolution descriptor for "back-prop" operations to data.
+  // Note that in deconvolution, the backprop-to-data operation is handled
+  // by the cuDNN forward kernel.
+  cudnnConvolutionDescriptor_t back_conv_desc_;
+  // Convolution descriptor for "back-prop" operations to filter.
   // Note that in deconvolution, the backprop-to-data operation is handled
-  // by the cuDNN forward kernel, while the backprop-to-filter operation
-  // stays consistent with the convolution operator and is handled by
-  // the backprop-to-filter kernel.
-  cudnnConvolutionDescriptor_t backward_conv_desc_;
+  // by the backprop-to-filter kernel (so consistent with the treatment
+  // in convolution).
+  cudnnConvolutionDescriptor_t back_conv_desc_w_;
   // Algorithm for the cuDNN forward kernel (used in gradient backprop to input)
-  cudnnConvolutionFwdAlgo_t algo_;
+  CuDNNAlgo<cudnnConvolutionFwdAlgo_t> forward_algo_;
   // Algorithm for the cuDNN backprop-to-data kernel (used in inference)
-  cudnnConvolutionBwdDataAlgo_t back_algo_;
+  CuDNNAlgo<cudnnConvolutionBwdDataAlgo_t> back_algo_;
   // Algorithm for the cuDNN backprop-to-filter kernel
-  cudnnConvolutionBwdFilterAlgo_t back_algo_w_;
+  CuDNNAlgo<cudnnConvolutionBwdFilterAlgo_t> back_algo_w_;
   cudnnTensorFormat_t format_;
+  // Allow TensorCore algo policy
+  bool cudnn_tensor_core_;
   DeconvolutionParam param_;
 };
 #endif  // CUDNN
diff --git a/src/operator/cudnn_rnn-inl.h b/src/operator/cudnn_rnn-inl.h
index 1122aff..a260cb4 100644
--- a/src/operator/cudnn_rnn-inl.h
+++ b/src/operator/cudnn_rnn-inl.h
@@ -43,6 +43,12 @@ class CuDNNRNNOp : public Operator {
     this->param_ = param;
     init_cudnn_ = false;
     dtype_ = mshadow::DataType<DType>::kCudnnFlag;
+    // TensorCore algos only allowed on fp16-I/O convolutions if permitted by the global policy.
+    // No tests in place for fp16 RNNs, so leave TensorCore disabled for now.
+    cudnn_tensor_core_ = false;
+    // When fp16 RNN tests are introduced, we can enable TensorCore as follows:
+//    cudnn_tensor_core_ =
+//        mshadow::DataType<DType>::kFlag == mshadow::kFloat16 && GetEnvAllowTensorCore();
     // Defaults
     input_mode_ = CUDNN_LINEAR_INPUT;  // Don't support this yet
     // RNN Mode
@@ -450,14 +456,36 @@ class CuDNNRNNOp : public Operator {
                                            seed_));
       // RNN descriptors
       CUDNN_CALL(cudnnCreateRNNDescriptor(&rnn_desc_));
-      CUDNN_CALL(cudnnSetRNNDescriptor(rnn_desc_,
-                                       param_.state_size,
-                                       param_.num_layers,
-                                       dropout_desc_,
-                                       input_mode_,
-                                       direction_,
-                                       mode_,
-                                       dtype_));
+
+      #if CUDNN_MAJOR >= 6
+        cudnnRNNAlgo_t rnn_algo = CUDNN_RNN_ALGO_STANDARD;
+        CUDNN_CALL(cudnnSetRNNDescriptor_v6(s->dnn_handle_,
+                                            rnn_desc_,
+                                            param_.state_size,
+                                            param_.num_layers,
+                                            dropout_desc_,
+                                            input_mode_,
+                                            direction_,
+                                            mode_,
+                                            rnn_algo,
+                                            dtype_));
+      #else
+        CUDNN_CALL(cudnnSetRNNDescriptor(rnn_desc_,
+                                         param_.state_size,
+                                         param_.num_layers,
+                                         dropout_desc_,
+                                         input_mode_,
+                                         direction_,
+                                         mode_,
+                                         dtype_));
+      #endif
+      #if CUDNN_MAJOR >= 7
+        cudnnMathType_t math_type = CUDNN_DEFAULT_MATH;
+        if (cudnn_tensor_core_ && rnn_algo == CUDNN_RNN_ALGO_STANDARD) {
+          math_type = CUDNN_TENSOR_OP_MATH;
+        }
+        CUDNN_CALL(cudnnSetRNNMatrixMathType(rnn_desc_, math_type));
+      #endif
       // Get temp space sizes
       CUDNN_CALL(cudnnGetRNNWorkspaceSize(s->dnn_handle_,
                                           rnn_desc_,
@@ -554,6 +582,8 @@ class CuDNNRNNOp : public Operator {
   cudnnTensorDescriptor_t dhy_desc_, dcy_desc_;
 
   cudnnFilterDescriptor_t w_desc_, dw_desc_;
+  // Allow TensorCore algo policy
+  bool cudnn_tensor_core_;
 
   #if CUDNN_MAJOR >= 5
   cudnnTensorFormat_t format_;
diff --git a/src/operator/deconvolution.cu b/src/operator/deconvolution.cu
index b9dd1c1..e9b5cb8 100644
--- a/src/operator/deconvolution.cu
+++ b/src/operator/deconvolution.cu
@@ -70,14 +70,14 @@ Operator* CreateOp<gpu>(DeconvolutionParam param, int dtype,
       int backward_compute_type = desired_backward_compute_type;
       bool deconvolutionIsSupported = CuDNNDeconvolutionOp<DType>::Supports(param,
                                           forward_compute_type,
-                                          backward_compute_type);
+                                          backward_compute_type, ctx);
 
       // If cuDNN can't handle this case with fp16 backprop kernels, try fp32 backprop.
       if (!deconvolutionIsSupported && backward_compute_type == mshadow::kFloat16) {
         backward_compute_type = mshadow::kFloat32;
         deconvolutionIsSupported = CuDNNDeconvolutionOp<DType>::Supports(param,
                                           forward_compute_type,
-                                          backward_compute_type);
+                                          backward_compute_type, ctx);
       }
 
       // If cuDNN can't handle this case with fp16 forward kernels, try fp32
@@ -85,7 +85,7 @@ Operator* CreateOp<gpu>(DeconvolutionParam param, int dtype,
         forward_compute_type = mshadow::kFloat32;
         deconvolutionIsSupported = CuDNNDeconvolutionOp<DType>::Supports(param,
                                           forward_compute_type,
-                                          backward_compute_type);
+                                          backward_compute_type, ctx);
       }
       if (!deconvolutionIsSupported) {
         LOG(WARNING) <<

-- 
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].