Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/04/23 05:35:36 UTC

[GitHub] [incubator-mxnet] roywei commented on a change in pull request #17547: Fix cudnn Dropout reproducibility

roywei commented on a change in pull request #17547:
URL: https://github.com/apache/incubator-mxnet/pull/17547#discussion_r413521892



##########
File path: src/operator/nn/dropout-inl.h
##########
@@ -255,8 +255,13 @@ class DropoutOp {
       Stream<xpu> *s = ctx.get_stream<xpu>();
 
       // set dropout state.
-      ctx.requested[0].get_cudnn_dropout_desc(&dropout_desc_, s, 1.0f - this->pkeep_, seed_);
-
+      Random<xpu, unsigned> *prnd = ctx.requested[1].get_random<xpu, unsigned>(s);
+      uint64_t rng_seed = prnd->GetSeed();
+      // reset dropout descriptor if rng seed changed.
+      bool reset = seed_ != rng_seed;
+      seed_ = rng_seed;
+      ctx.requested[0].get_cudnn_dropout_desc(&dropout_desc_, s, 1.0f - this->pkeep_,
+          seed_, reset);

Review comment:
       I spent some time looking into this, and the actual problem is that dropout's forward enters the [re-init cudnn dropout desc logic](https://github.com/apache/incubator-mxnet/pull/17547/files#diff-cc7bb408eba92cdd6fad0e590b76fabeL438) on every forward pass, even without my change. This is true for nd/np and gluon, but not for symbol dropout. So the original code was already reinitializing on every forward. It did not cause a performance regression because it always used the `seed_` defined in `uint64_t seed_ = 17 + rand() % 4096;`, which never changed and did not listen to MXNet's PRNG (which is the problem this PR is trying to fix). If the seed does not change, re-initializing the cudnn dropout descriptor does not take any time. My PR changed the seed, so the descriptor was actually re-initialized on every forward, hence the regression.
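
For readers who do not follow the C++ closely, the seed comparison added in the diff above can be paraphrased in Python. This is only a sketch of the `reset = seed_ != rng_seed` logic; `SeedTracker` and `needs_reset` are hypothetical names, not part of MXNet.

```python
class SeedTracker:
    """Hypothetical paraphrase of the seed check in the diff; not MXNet code."""

    def __init__(self, seed):
        # Plays the role of the operator's cached `seed_` member.
        self.seed = seed

    def needs_reset(self, rng_seed):
        # Request a descriptor reset only when the PRNG seed differs
        # from the cached one, then remember the new seed.
        reset = self.seed != rng_seed
        self.seed = rng_seed
        return reset


tracker = SeedTracker(17)
print(tracker.needs_reset(17))  # False: seed unchanged
print(tracker.needs_reset(42))  # True: seed changed
print(tracker.needs_reset(42))  # False: seed unchanged again
```

The check only helps if the cached seed survives between forward calls, which is exactly what appears not to hold in the imperative case.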
   
   However, it works fine in the symbol case. My guess is that for symbolic, every forward uses the same Dropout node, so the state check takes effect and it does not go into the re-init logic. In the imperative case the state is always empty at the check, so it goes into the re-init logic.
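
The guess above can be modelled with a toy state object. This is an illustration only, under the assumption that the symbolic executor reuses one state across forwards while the imperative path sees an empty state each time; `ToyDropoutState`, `run_symbolic`, and `run_imperative` are hypothetical names and do not reflect MXNet's actual implementation.

```python
class ToyDropoutState:
    """Hypothetical stand-in for the cuDNN dropout state; not MXNet code."""

    def __init__(self):
        self.initialized = False
        self.reinit_count = 0

    def forward(self):
        # Models the "state check": initialize only when the state is empty.
        if not self.initialized:
            self.reinit_count += 1
            self.initialized = True


def run_symbolic(num_forwards):
    # Assumption: one operator node, so the state persists across forwards.
    state = ToyDropoutState()
    for _ in range(num_forwards):
        state.forward()
    return state.reinit_count


def run_imperative(num_forwards):
    # Assumption: every forward starts from an empty state,
    # so the check never skips the re-init.
    total = 0
    for _ in range(num_forwards):
        state = ToyDropoutState()
        state.forward()
        total += state.reinit_count
    return total


print(run_symbolic(2))    # 1
print(run_imperative(2))  # 2
```

Under these assumptions the model re-initializes once in the symbolic case and once per forward in the imperative case.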
   
   Compare the following two snippets, one gluon and one symbol. If I print a log message during reinitialization in the original code, without my change, [here](https://github.com/apache/incubator-mxnet/pull/17547/files#diff-cc7bb408eba92cdd6fad0e590b76fabeL438), it prints on every forward in ND/NP/Gluon, but not in Symbol.
   ```
   import mxnet as mx
   data = mx.nd.ones((10, 200, 300, 500), ctx=mx.gpu(0))
   dropout = mx.gluon.nn.Dropout(0.5)
   # with or without hybridize is the same result
   dropout.hybridize()
   with mx.autograd.record():
       result1 = dropout(data)
       result2 = dropout(result1)
   ```
   This prints 2 times:
   ```
   re-init dropout desc
   re-init dropout desc
   ```
   
   Symbol:
   ```
   import mxnet as mx
   data = mx.nd.ones((10, 200, 300, 500), ctx=mx.gpu(0))
   net = mx.sym.Variable("data")
   net = mx.sym.Dropout(data=net, p=0.5, cudnn_off=False)
   exe = net.simple_bind(mx.gpu(0), data=data.shape)
   result1 = exe.forward(is_train=True, data=data)
   result2 = exe.forward(is_train=True, data=result1[0])
   ```
   
   This prints 1 time:
   ```
   re-init dropout desc
   ```
   
   Given this situation, I don't have a good solution: checking the state handle size does not work in imperative mode, and checking the PRNG seed will not work either. I will revert this PR for now, as it is causing regressions for models that use dropout.
   
   cc @szha @sxjscience @apeforest 
   



