You are viewing a plain-text version of this content. The canonical (HTML) version is available in the Apache mailing-list archive.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/05/20 18:55:32 UTC

[incubator-mxnet] branch master updated: Only allocate cudnn-rnn dropout memory if dropout p > 0 and acquire descriptors during initialization (#11004)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 10ac529  Only allocate cudnn-rnn dropout memory if dropout p > 0 and acquire descriptors during initialization (#11004)
10ac529 is described below

commit 10ac52993aeaf2fa589a6b3636c8e23a65c8e639
Author: Leonard Lausen <le...@lausen.nl>
AuthorDate: Sun May 20 11:55:25 2018 -0700

    Only allocate cudnn-rnn dropout memory if dropout p > 0 and acquire descriptors during initialization (#11004)
    
    * cudnn-rnn: Only allocate dropout memory if dropout p > 0
    
    Also request cudnn descriptors during class initialization
    
    * Don't call cudnnDropoutGetStatesSize when not allocating states
    
    * Fixes
---
 src/operator/cudnn_rnn-inl.h | 87 ++++++++++++++++++++++++--------------------
 1 file changed, 47 insertions(+), 40 deletions(-)

diff --git a/src/operator/cudnn_rnn-inl.h b/src/operator/cudnn_rnn-inl.h
index 033d30e..b33a717 100644
--- a/src/operator/cudnn_rnn-inl.h
+++ b/src/operator/cudnn_rnn-inl.h
@@ -76,9 +76,39 @@ class CuDNNRNNOp : public Operator{
       param_.lstm_q_ = true;
     else
       param_.lstm_q_ = false;
+
+    // Create descriptors
+    CUDNN_CALL(cudnnCreateTensorDescriptor(&hx_desc_));
+    CUDNN_CALL(cudnnCreateTensorDescriptor(&cx_desc_));
+    CUDNN_CALL(cudnnCreateTensorDescriptor(&hy_desc_));
+    CUDNN_CALL(cudnnCreateTensorDescriptor(&cy_desc_));
+    CUDNN_CALL(cudnnCreateTensorDescriptor(&dhx_desc_));
+    CUDNN_CALL(cudnnCreateTensorDescriptor(&dcx_desc_));
+    CUDNN_CALL(cudnnCreateTensorDescriptor(&dhy_desc_));
+    CUDNN_CALL(cudnnCreateTensorDescriptor(&dcy_desc_));
+
+    CUDNN_CALL(cudnnCreateFilterDescriptor(&w_desc_));
+    CUDNN_CALL(cudnnCreateFilterDescriptor(&dw_desc_));
+
+    CUDNN_CALL(cudnnCreateRNNDescriptor(&rnn_desc_));
+    CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_));
   }
 
   ~CuDNNRNNOp() {
+    CUDNN_CALL(cudnnDestroyTensorDescriptor(hx_desc_));
+    CUDNN_CALL(cudnnDestroyTensorDescriptor(cx_desc_));
+    CUDNN_CALL(cudnnDestroyTensorDescriptor(hy_desc_));
+    CUDNN_CALL(cudnnDestroyTensorDescriptor(cy_desc_));
+    CUDNN_CALL(cudnnDestroyTensorDescriptor(dhx_desc_));
+    CUDNN_CALL(cudnnDestroyTensorDescriptor(dcx_desc_));
+    CUDNN_CALL(cudnnDestroyTensorDescriptor(dhy_desc_));
+    CUDNN_CALL(cudnnDestroyTensorDescriptor(dcy_desc_));
+
+    CUDNN_CALL(cudnnDestroyFilterDescriptor(w_desc_));
+    CUDNN_CALL(cudnnDestroyFilterDescriptor(dw_desc_));
+    CUDNN_CALL(cudnnDestroyRNNDescriptor(rnn_desc_));
+    CUDNN_CALL(cudnnDestroyDropoutDescriptor(dropout_desc_));
+
     if (init_cudnn_) {
       for (size_t i = 0; i < x_desc_vec_.size(); ++i) {
         CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc_vec_[i]));
@@ -86,27 +116,16 @@ class CuDNNRNNOp : public Operator{
         CUDNN_CALL(cudnnDestroyTensorDescriptor(dx_desc_vec_[i]));
         CUDNN_CALL(cudnnDestroyTensorDescriptor(dy_desc_vec_[i]));
       }
-      CUDNN_CALL(cudnnDestroyTensorDescriptor(hx_desc_));
-      CUDNN_CALL(cudnnDestroyTensorDescriptor(cx_desc_));
-      CUDNN_CALL(cudnnDestroyTensorDescriptor(hy_desc_));
-      CUDNN_CALL(cudnnDestroyTensorDescriptor(cy_desc_));
-      CUDNN_CALL(cudnnDestroyTensorDescriptor(dhx_desc_));
-      CUDNN_CALL(cudnnDestroyTensorDescriptor(dcx_desc_));
-      CUDNN_CALL(cudnnDestroyTensorDescriptor(dhy_desc_));
-      CUDNN_CALL(cudnnDestroyTensorDescriptor(dcy_desc_));
-
-      CUDNN_CALL(cudnnDestroyFilterDescriptor(w_desc_));
-      CUDNN_CALL(cudnnDestroyFilterDescriptor(dw_desc_));
-      CUDNN_CALL(cudnnDestroyRNNDescriptor(rnn_desc_));
-      CUDNN_CALL(cudnnDestroyDropoutDescriptor(dropout_desc_));
-      Storage::Get()->Free(dropout_states_);
-      Storage::Get()->Free(reserve_space_);
       init_cudnn_ = false;
+
+      Storage::Get()->Free(reserve_space_);
+      if (param_.p > 0) {
+        Storage::Get()->Free(dropout_states_);
+      }
     }
   }
 
-  virtual void Forward(const OpContext &ctx,
-                       const std::vector<TBlob> &in_data,
+  virtual void Forward(const OpContext &ctx, const std::vector<TBlob> &in_data,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &out_data,
                        const std::vector<TBlob> &aux_args) {
@@ -395,15 +414,6 @@ class CuDNNRNNOp : public Operator{
       strideA[1] = dimA[2];
       strideA[2] = 1;
 
-      CUDNN_CALL(cudnnCreateTensorDescriptor(&hx_desc_));
-      CUDNN_CALL(cudnnCreateTensorDescriptor(&cx_desc_));
-      CUDNN_CALL(cudnnCreateTensorDescriptor(&hy_desc_));
-      CUDNN_CALL(cudnnCreateTensorDescriptor(&cy_desc_));
-      CUDNN_CALL(cudnnCreateTensorDescriptor(&dhx_desc_));
-      CUDNN_CALL(cudnnCreateTensorDescriptor(&dcx_desc_));
-      CUDNN_CALL(cudnnCreateTensorDescriptor(&dhy_desc_));
-      CUDNN_CALL(cudnnCreateTensorDescriptor(&dcy_desc_));
-
       CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_,
                                             dtype_,
                                             3,
@@ -446,20 +456,19 @@ class CuDNNRNNOp : public Operator{
                                             strideA));
 
       // Create Dropout descriptors
-      CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_));
-      CUDNN_CALL(cudnnDropoutGetStatesSize(s->dnn_handle_,
-                                           &dropout_byte_));
-      dropout_size_ = dropout_byte_ / sizeof(DType);
-      dropout_states_ = Storage::Get()->Alloc(dropout_byte_, Context::GPU());
-      CUDNN_CALL(cudnnSetDropoutDescriptor(dropout_desc_,
-                                           s->dnn_handle_,
-                                           param_.p,  // keep probability
-                                           dropout_states_.dptr,
-                                           dropout_byte_,
+      if (param_.p > 0) {
+        CUDNN_CALL(cudnnDropoutGetStatesSize(s->dnn_handle_, &dropout_byte_));
+        dropout_size_ = dropout_byte_ / sizeof(DType);
+        dropout_states_ = Storage::Get()->Alloc(dropout_byte_, Context::GPU());
+      } else {
+        dropout_states_ = {};
+        dropout_byte_ = 0;
+      }
+      CUDNN_CALL(cudnnSetDropoutDescriptor(dropout_desc_, s->dnn_handle_,
+                                           param_.p,  // discard probability
+                                           dropout_states_.dptr, dropout_byte_,
                                            seed_));
       // RNN descriptors
-      CUDNN_CALL(cudnnCreateRNNDescriptor(&rnn_desc_));
-
       #if CUDNN_MAJOR >= 6
         cudnnRNNAlgo_t rnn_algo = CUDNN_RNN_ALGO_STANDARD;
         CUDNN_CALL(cudnnSetRNNDescriptor_v6(s->dnn_handle_,
@@ -514,8 +523,6 @@ class CuDNNRNNOp : public Operator{
       CHECK_EQ(w.shape_[0] * sizeof(DType), cudnn_param_size);
 
       // Set param descriptors
-      CUDNN_CALL(cudnnCreateFilterDescriptor(&w_desc_));
-      CUDNN_CALL(cudnnCreateFilterDescriptor(&dw_desc_));
       int dim_w[3] = {1, 1, 1};
       dim_w[0] = w.shape_[0];
       CUDNN_CALL(cudnnSetFilterNdDescriptor(w_desc_,

-- 
To stop receiving notification emails like this one, please contact
zhasheng@apache.org.