Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/06/12 17:19:42 UTC

[incubator-mxnet] branch v1.2.0 updated: [MXNET-491] Use depthwise convolution by cuDNNv7 if available, updated version (#11076) (#11233)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch v1.2.0
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.2.0 by this push:
     new 546a233  [MXNET-491] Use depthwise convolution by cuDNNv7 if available, updated version (#11076) (#11233)
546a233 is described below

commit 546a2333385592e8df24a0b85072ac40c558b7c2
Author: Anirudh Subramanian <an...@apache.org>
AuthorDate: Tue Jun 12 10:19:35 2018 -0700

    [MXNET-491] Use depthwise convolution by cuDNNv7 if available, updated version (#11076) (#11233)
    
    * Use group convolution via cuDNN v7 if available

    * Fix coding style

    * Decrease indentation of #if statements

    * Further decrease indentation

    * Prefer cuDNN v7 depthwise convolution
---
 src/operator/nn/convolution.cu                | 10 ++-
 src/operator/nn/cudnn/cudnn_convolution-inl.h | 92 +++++++++++++++++++++++++++
 2 files changed, 100 insertions(+), 2 deletions(-)
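
The heart of this change: cuDNN 7 supports grouped convolution natively
via cudnnSetConvolutionGroupCount(), so depthwise convolution (num_group
equal to the input channel count) no longer needs a hand-written
per-group dispatch. A minimal sketch of that setup follows; it is
illustrative, not code from this patch, and the function name
ConfigureDepthwise plus the fixed pad/stride/dilation values are
assumptions.

#include <cudnn.h>

// Hedged sketch: configure a depthwise convolution with cuDNN v7's native
// group support. With G groups the filter's channel dimension is
// in_channels / G; for depthwise (G == in_channels) it is 1, and cuDNN
// can run a single fused kernel instead of G separate launches.
cudnnStatus_t ConfigureDepthwise(cudnnConvolutionDescriptor_t conv_desc,
                                 cudnnFilterDescriptor_t filter_desc,
                                 int in_channels, int kernel_h, int kernel_w) {
  // Filter shape is (k, c, h, w) = (in_channels, 1, kernel_h, kernel_w).
  cudnnStatus_t st = cudnnSetFilter4dDescriptor(
      filter_desc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW,
      /*k=*/in_channels, /*c=*/1, kernel_h, kernel_w);
  if (st != CUDNN_STATUS_SUCCESS) return st;
  // Illustrative pad/stride/dilation of 1; real values come from the op.
  st = cudnnSetConvolution2dDescriptor(
      conv_desc, /*pad_h=*/1, /*pad_w=*/1, /*stride_h=*/1, /*stride_w=*/1,
      /*dilation_h=*/1, /*dilation_w=*/1,
      CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT);
  if (st != CUDNN_STATUS_SUCCESS) return st;
  // The cuDNN v7 call this commit builds on: one descriptor covers all groups.
  return cudnnSetConvolutionGroupCount(conv_desc, /*groupCount=*/in_channels);
}

With the group count set, a single cudnnConvolutionForward() call covers
all groups, which is exactly what the new CUDNN_MAJOR >= 7 paths in the
diff below do.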

diff --git a/src/operator/nn/convolution.cu b/src/operator/nn/convolution.cu
index 045e570..65a320d 100644
--- a/src/operator/nn/convolution.cu
+++ b/src/operator/nn/convolution.cu
@@ -97,7 +97,9 @@ void ConvolutionCompute<gpu>(const nnvm::NodeAttrs& attrs,
       op.Forward(ctx, inputs, req, outputs);
     })
     return;
-  } else if (param.num_filter == param.num_group &&
+  }
+#if MXNET_USE_CUDNN == 0 || CUDNN_MAJOR < 7
+  if (param.num_filter == param.num_group &&
       param.layout.value() == mshadow::kNCHW &&
       param.num_filter == inputs[conv::kData].shape_[1] &&
       param.kernel.ndim() == 2 &&
@@ -112,6 +114,7 @@ void ConvolutionCompute<gpu>(const nnvm::NodeAttrs& attrs,
     op.Forward(ctx, inputs, req, outputs);
     return;
   }
+#endif
 
 #if MXNET_USE_CUDNN == 1
   // On fp16-I/O instances, use fp32 compute (i.e. pseudo-fp16).
@@ -167,7 +170,9 @@ void ConvolutionGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
       op.Backward(ctx, std::vector<TBlob>{out_grad}, in_data, req, in_grad);
     })
     return;
-  } else if (param.num_filter == param.num_group &&
+  }
+#if MXNET_USE_CUDNN == 0 || CUDNN_MAJOR < 7
+  if (param.num_filter == param.num_group &&
       param.layout.value() == mshadow::kNCHW &&
       param.num_filter == in_data[conv::kData].shape_[1] &&
       param.kernel.ndim() == 2 &&
@@ -183,6 +188,7 @@ void ConvolutionGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
     op.Backward(ctx, std::vector<TBlob>{out_grad}, in_data, req, in_grad);
     return;
   }
+#endif
 
 #if MXNET_USE_CUDNN == 1
   // On fp16-I/O instances, use fp32 compute (i.e. pseudo-fp16).
diff --git a/src/operator/nn/cudnn/cudnn_convolution-inl.h b/src/operator/nn/cudnn/cudnn_convolution-inl.h
index ca60c99..4b1cbbe 100644
--- a/src/operator/nn/cudnn/cudnn_convolution-inl.h
+++ b/src/operator/nn/cudnn/cudnn_convolution-inl.h
@@ -137,6 +137,35 @@ class CuDNNConvolutionOp {
     DType *wmat_ptr = GetNdPtr(in_data[conv::kWeight], param_.kernel.ndim() + 2, s);
     DType *out_ptr = GetNdPtr(out_data[conv::kOut], param_.kernel.ndim() + 2, s);
 
+    #if CUDNN_MAJOR >= 7
+    typename DataType<DType>::ScaleType alpha = 1.0f;
+    typename DataType<DType>::ScaleType beta = 0.0f;
+    typename DataType<DType>::ScaleType beta_add = 1.0f;
+    CUDNN_CALL(cudnnConvolutionForward(s->dnn_handle_,
+                    &alpha,
+                    in_desc_,
+                    data_ptr,
+                    filter_desc_,
+                    wmat_ptr,
+                    forward_conv_desc_,
+                    forward_algo_.AlgoNumber(),
+                    workspace.dptr_,
+                    workspace_size,
+                    req[conv::kOut] == kAddTo ? &beta_add : &beta,
+                    out_desc_,
+                    out_ptr));
+
+    if (!param_.no_bias) {
+      Tensor<gpu, 1, DType> bias = in_data[conv::kBias].get<gpu, 1, DType>(s);
+      CUDNN_CALL(cudnnAddTensor(s->dnn_handle_,
+                              &alpha,
+                              bias_desc_,
+                              bias.dptr_,
+                              &beta_add,
+                              out_desc_,
+                              out_ptr));
+    }
+    #else
     for (uint32_t g = 0; g < param_.num_group; ++g) {
       typename DataType<DType>::ScaleType alpha = 1.0f;
       typename DataType<DType>::ScaleType beta = 0.0f;
@@ -177,6 +206,7 @@ class CuDNNConvolutionOp {
         #endif
       }
     }
+    #endif  // CUDNN_MAJOR >= 7
   }
 
   void Backward(const OpContext &ctx,
@@ -202,6 +232,51 @@ class CuDNNConvolutionOp {
     GetTempSize(ctx);
     Tensor<gpu, 1, DType> workspace = AllocateTempWorkspace(ctx, backward_workspace_byte_);
     size_t workspace_size = TensorSizeBytes(workspace);
+    #if CUDNN_MAJOR >= 7
+    typename DataType<DType>::ScaleType alpha = 1.0f;
+    typename DataType<DType>::ScaleType beta = 0.0f;
+    typename DataType<DType>::ScaleType beta_add = 1.0f;
+    if (!param_.no_bias && (req[conv::kBias] != kNullOp)) {
+        Tensor<gpu, 1, DType> gbias = in_grad[conv::kBias].get<gpu, 1, DType>(s);
+        CUDNN_CALL(cudnnConvolutionBackwardBias(s->dnn_handle_,
+                                            &alpha,
+                                            out_desc_,
+                                            grad_ptr,
+                                            req[conv::kBias] == kAddTo ? &beta_add : &beta,
+                                            bias_desc_,
+                                            gbias.dptr_));
+    }
+    if (req[conv::kWeight] != kNullOp) {
+        CUDNN_CALL(cudnnConvolutionBackwardFilter(s->dnn_handle_,
+            &alpha,
+            in_desc_,
+            data_ptr,
+            out_desc_,
+            grad_ptr,
+            back_conv_desc_w_,
+            back_algo_w_.AlgoNumber(),
+            workspace.dptr_,
+            workspace_size,
+            req[conv::kWeight] == kAddTo ? &beta_add : &beta,
+            filter_desc_,
+            gwmat_ptr));
+    }
+    if (req[conv::kData] != kNullOp) {
+        CUDNN_CALL(cudnnConvolutionBackwardData(s->dnn_handle_,
+            &alpha,
+            filter_desc_,
+            wmat_ptr,
+            out_desc_,
+            grad_ptr,
+            back_conv_desc_,
+            back_algo_.AlgoNumber(),
+            workspace.dptr_,
+            workspace_size,
+            req[conv::kData] == kAddTo ? &beta_add : &beta,
+            in_desc_,
+            gdata_ptr));
+    }
+    #else
     for (uint32_t g = 0; g < param_.num_group; ++g) {
       typename DataType<DType>::ScaleType alpha = 1.0f;
       typename DataType<DType>::ScaleType beta = 0.0f;
@@ -279,6 +354,7 @@ class CuDNNConvolutionOp {
         #endif
       }
     }
+    #endif  // CUDNN_MAJOR >= 7
   }
 
 /*!
@@ -342,7 +418,10 @@ class CuDNNConvolutionOp {
     TShape wshape = in_shape[conv::kWeight];
     TShape oshape = out_shape[conv::kOut];
     TShape dstride, ostride;
+#if CUDNN_MAJOR <= 6
     wshape[0] /= param_.num_group;
+#endif
+
 #if CUDNN_MAJOR <= 5
       // As of cuDNN_v6, the unsuffixed version of cudnnSetConvolution2dDescriptor()
       // takes an additional 'computeType' parameter to set the precision of the
@@ -464,9 +543,15 @@ class CuDNNConvolutionOp {
       CUDNN_CALL(cudnnSetConvolutionMathType(forward_conv_desc_, math_type));
       CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_, math_type));
       CUDNN_CALL(cudnnSetConvolutionMathType(back_conv_desc_w_, math_type));
+      CUDNN_CALL(cudnnSetConvolutionGroupCount(forward_conv_desc_, param_.num_group));
+      CUDNN_CALL(cudnnSetConvolutionGroupCount(back_conv_desc_, param_.num_group));
+      CUDNN_CALL(cudnnSetConvolutionGroupCount(back_conv_desc_w_, param_.num_group));
     #endif
+
+  #if CUDNN_MAJOR <= 6
     dshape[1] /= param_.num_group;
     oshape[1] /= param_.num_group;
+  #endif
     weight_offset_ = wshape.Size();
     data_offset_ = dstride[1] * dshape[1];
     out_offset_ = ostride[1] * oshape[1];
@@ -494,10 +579,17 @@ class CuDNNConvolutionOp {
 
     if (!param_.no_bias) {
       TShape bias = in_shape[conv::kBias];
+      #if CUDNN_MAJOR >= 7
+      bias_offset_ = bias[0];
+      std::vector<int> bias_shape = {1,
+                                     static_cast<int>(bias[0]),
+                                     1, 1};
+      #else
       bias_offset_ = bias[0] / param_.num_group;
       std::vector<int> bias_shape = {1,
                                      static_cast<int>(bias[0] / param_.num_group),
                                      1, 1};
+      #endif
       std::vector<int> bias_stride = {static_cast<int>(bias_offset_), 1, 1, 1};
       if (param_.kernel.ndim() == 3) {
         bias_shape.push_back(1);

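For comparison, the #else branches preserved above keep the pre-v7
emulation: each group is dispatched as an independent convolution over a
1/num_group channel slice, selected by pointer offsets. That is also why
wshape[0], dshape[1], and oshape[1] are divided by num_group only under
CUDNN_MAJOR <= 6: there, the descriptors describe a single group's slice.
A simplified, hypothetical sketch of that loop (variable names
illustrative, not the actual class members):

// Pre-cuDNN-v7 emulation of grouped convolution: num_group small
// convolutions, one per channel slice, selected via pointer offsets.
for (uint32_t g = 0; g < num_group; ++g) {
  float alpha = 1.0f, beta = 0.0f;
  CUDNN_CALL(cudnnConvolutionForward(handle, &alpha,
                                     in_desc, data_ptr + g * data_offset,
                                     filter_desc, wmat_ptr + g * weight_offset,
                                     conv_desc, algo,
                                     workspace_ptr, workspace_size,
                                     &beta,
                                     out_desc, out_ptr + g * out_offset));
}

The v7 path removes this loop entirely and lets cuDNN pick a fused
algorithm, which is the motivation for preferring the cuDNN v7 depthwise
path in this commit.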
-- 
To stop receiving notification emails like this one, please contact
jxie@apache.org.