You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/01/16 18:28:23 UTC

[incubator-mxnet] branch master updated: Refactor dropout operator to use ParallelRandom generator and also react deterministically when seeding (#9366)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 16746c1  Refactor dropout operator to use ParallelRandom generator and also react deterministically when seeding (#9366)
16746c1 is described below

commit 16746c177d557a1c6774cf4da07c70e4045e7599
Author: Chris Olivier <cj...@gmail.com>
AuthorDate: Tue Jan 16 10:28:19 2018 -0800

    Refactor dropout operator to use ParallelRandom generator and also react deterministically when seeding (#9366)
    
    * Refactor dropout operator to use ParallelRandom generator and also react deterministically when seeding
    
    * lint fix
    
    * Add more dropout unit testing
    
    * Reintroduced deterministic version of mkl dropout implementation
    
    * Fix a couple of unused variable warnings
    
    * MKL mode handle types smaller than int
    
    * Rearrange MKL code forward and backward passes into separate functions
    
    * fix typo
---
 src/operator/nn/dropout-inl.h          | 264 +++++++++++++++++++++++----------
 src/operator/nn/dropout.cc             |   1 +
 src/operator/optimizer_op-inl.h        |   2 +-
 tests/cpp/include/test_legacy_op.h     |   7 +
 tests/cpp/operator/dropout_perf.cc     | 104 +++++++++++++
 tests/python/unittest/test_operator.py | 118 +++++++++++----
 6 files changed, 391 insertions(+), 105 deletions(-)

diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h
index 4c8a5ee..715a6f4 100644
--- a/src/operator/nn/dropout-inl.h
+++ b/src/operator/nn/dropout-inl.h
@@ -34,9 +34,9 @@
 #include <string>
 #include <utility>
 #include <algorithm>
-#include "../../engine/openmp.h"
-#include "../operator_common.h"
+#include "../mxnet_op.h"
 #include "../mshadow_op.h"
+#include "../random/sampler.h"
 
 #if defined(USE_MKL) && defined(_OPENMP)
 #include <omp.h>
@@ -55,28 +55,6 @@ enum DropoutOpMode {kTraining, kAlways};
 namespace mxnet {
 namespace op {
 
-#if defined(USE_MKL) && defined(_OPENMP)
-static void bernoulli_generate(int n, double p, int* r) {
-  const int seed = 17 + rand() % 4096;  // NOLINT(runtime/threadsafe_fn)
-  const int nthr = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
-# pragma omp parallel num_threads(nthr)
-  {
-    const int ithr = omp_get_thread_num();
-    const int avg_amount = (n + nthr - 1) / nthr;
-    const int my_offset = ithr * avg_amount;
-    const int my_amount = std::min(my_offset + avg_amount, n) - my_offset;
-    if (my_amount > 0) {
-      VSLStreamStatePtr stream;
-      vslNewStream(&stream, VSL_BRNG_MCG31, seed);
-      vslSkipAheadStream(stream, my_offset);
-      viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, stream, my_amount,
-        r + my_offset, p);
-      vslDeleteStream(&stream);
-    }
-  }
-}
-#endif  // USE_MKL && _OPENMP
-
 struct DropoutParam : public dmlc::Parameter<DropoutParam> {
   float p;
   int mode;
@@ -94,10 +72,143 @@ struct DropoutParam : public dmlc::Parameter<DropoutParam> {
 
 template<typename xpu, typename DType>
 class DropoutOp : public Operator {
+#if defined(USE_MKL) && defined(_OPENMP)
+  static void BernoulliGenerate(common::random::RandGenerator<cpu, DType> gen,
+                                int n, double p, int* r) {
+    typename RandGenerator<cpu, DType>::Impl genImpl(&gen, 1);
+    const int seed = 17 + static_cast<int>(static_cast<unsigned int>(genImpl.rand()) % 4096U);
+    const int nthr = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+#pragma omp parallel num_threads(nthr)
+    {  // thread ithr fills r[my_offset, my_offset + my_amount)
+      const int ithr = omp_get_thread_num();
+      const int avg_amount = (n + nthr - 1) / nthr;
+      const int my_offset = ithr * avg_amount;
+      const int my_amount = std::min(my_offset + avg_amount, n) - my_offset;
+      if (my_amount > 0) {
+        VSLStreamStatePtr stream;
+        vslNewStream(&stream, VSL_BRNG_MCG31, seed);  // one sequence shared by all threads;
+        vslSkipAheadStream(stream, my_offset);  // so the mask does not depend on nthr
+        viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, stream, my_amount, r + my_offset, p);
+        vslDeleteStream(&stream);
+      }
+    }
+  }
+
+  // MKL forward pass; returns false without writing anything when DType is narrower than int
+  static bool MSHADOW_CINLINE MKLForward(mshadow::Stream<cpu> *s, RandGenerator<cpu, DType> *pgen,
+                                         const double pkeep,
+                                         const std::vector<TBlob> &in_data,
+                                         const std::vector<TBlob> &out_data) {
+    // BernoulliGenerate writes an int array, so for types smaller than int the mask buffer
+    // would be too small to hold it; MKL cannot be used in those cases
+    if (sizeof(DType) >= sizeof(int)) {
+      Tensor<cpu, 2, DType> mask = out_data[dropout::kMask].FlatTo2D<cpu, DType>(s);
+      Tensor<cpu, 2, DType> data = in_data[dropout::kData].FlatTo2D<cpu, DType>(s);
+      Tensor<cpu, 2, DType> out = out_data[dropout::kOut].FlatTo2D<cpu, DType>(s);
+      DType *outptr = out.dptr_;
+      const DType *dataptr = data.dptr_;
+      auto maskptr = reinterpret_cast<int *>(mask.dptr_);  // mask buffer holds ints, not DType
+      int count = mask.shape_[0] * mask.shape_[1];
+      BernoulliGenerate(*pgen, count, pkeep, maskptr);
+      const float pk_1 = 1.0f / pkeep;
+#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
+      for (int i = 0; i < count; ++i) {
+        outptr[i] = dataptr[i] * maskptr[i] * pk_1;
+      }
+      return true;
+    }
+    return false;
+  }
+
+  // MKL backward pass; returns false when DType is narrower than int (mask not int-encoded)
+  static bool MSHADOW_CINLINE MKLBackward(mshadow::Stream<cpu> *s, const double pkeep,
+                                          const std::vector<TBlob> &in_grad,
+                                          const std::vector<TBlob> &out_data,
+                                          const std::vector<TBlob> &out_grad) {
+    if (sizeof(DType) < sizeof(int)) {
+      return false;  // the caller falls back to the generic kernel
+    }
+    Tensor<xpu, 2, DType> mask_t = out_data[dropout::kMask].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> ograd_t = out_grad[dropout::kOut].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> igrad_t = in_grad[dropout::kData].FlatTo2D<xpu, DType>(s);
+    const int *const keep = reinterpret_cast<const int *>(mask_t.dptr_);
+    const DType *const src = ograd_t.dptr_;
+    DType *const dst = igrad_t.dptr_;
+    const int total = mask_t.shape_[0] * mask_t.shape_[1];
+    const float inv_keep = 1.0f / pkeep;
+#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
+    for (int j = 0; j < total; ++j) {
+      dst[j] = src[j] * keep[j] * inv_keep;
+    }
+    return true;
+  }
+
+#ifdef __CUDACC__
+  // GPU never uses MKL: these overloads return false so callers use the generic kernels
+  static bool MSHADOW_CINLINE MKLForward(mshadow::Stream<gpu> *s, RandGenerator<gpu, DType> *pgen,
+                                         const double pkeep,
+                                         const std::vector<TBlob> &in_data,
+                                         const std::vector<TBlob> &out_data) {
+    return false;
+  }
+  static bool MSHADOW_CINLINE MKLBackward(mshadow::Stream<gpu> *s, const double pkeep,
+                                          const std::vector<TBlob> &in_grad,
+                                          const std::vector<TBlob> &out_data,
+                                          const std::vector<TBlob> &out_grad) {
+    return false;
+  }
+#endif  // __CUDACC__
+
+#else  // #if defined(USE_MKL) && defined(_OPENMP) -- no MKL: stubs always return false
+  static bool MSHADOW_CINLINE MKLForward(mshadow::Stream<xpu> *s, RandGenerator<xpu, DType> *pgen,
+                                         const double pkeep,
+                                         const std::vector<TBlob> &in_data,
+                                         const std::vector<TBlob> &out_data) {
+    return false;
+  }
+  static bool MSHADOW_CINLINE MKLBackward(mshadow::Stream<xpu> *s, const double pkeep,
+                                          const std::vector<TBlob> &in_grad,
+                                          const std::vector<TBlob> &out_data,
+                                          const std::vector<TBlob> &out_grad) {
+    return false;
+  }
+#endif  // #if defined(USE_MKL) && defined(_OPENMP)
+
  public:
+  /*!
+   * \brief Dropout kernel: draws the random mask and computes the dropout output
+   */
+  struct DropoutKernel {
+    /*!
+     * \brief Dropout kernel function
+     * \param id Parallel worker index (0-based), forwarded to RNG_KERNEL_LOOP
+     * \param gen Random number generator
+     * \param N Total number of items in the output
+     * \param step Step between items, related to parallelism
+     * \param dropout_out Output values: input_data[i] * mask_out[i]
+     * \param mask_out Output mask, 0 or 1/pkeep per element (scaling is folded into the mask)
+     * \param input_data Input data to perform the dropout on
+     * \param pkeep Keep probability, 1 - p (keep if random number < pkeep); 0 yields NaN
+     */
+    MSHADOW_XINLINE static void Map(int id,
+                                    RandGenerator<xpu, DType> gen,
+                                    const int N,
+                                    const int step,
+                                    DType *dropout_out,
+                                    DType *mask_out,
+                                    const DType *input_data,
+                                    const real_t pkeep) {
+      RNG_KERNEL_LOOP(xpu, DType, id, gen, N, step, {
+        const real_t rand_num = static_cast<real_t>(genImpl.uniform());
+        mask_out[i] = mshadow_op::threshold::Map<real_t>(rand_num, pkeep) * (1.0f / pkeep);
+        dropout_out[i] = input_data[i] * mask_out[i];
+      });
+    }
+  };
+
   explicit DropoutOp(DropoutParam param) {
     this->pkeep_ = 1.0f - param.p;
-    this->mode_ = param.mode;
+    this->mode_ = static_cast<dropout::DropoutOpMode>(param.mode);
   }
 
   virtual void Forward(const OpContext &ctx,
@@ -105,36 +216,36 @@ class DropoutOp : public Operator {
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &out_data,
                        const std::vector<TBlob> &aux_states) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(in_data.size(), 1U);
-    if (ctx.is_train) {
-      CHECK_EQ(out_data.size(), 2U);
-    }
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 2, DType> data = in_data[dropout::kData].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> out = out_data[dropout::kOut].FlatTo2D<xpu, DType>(s);
-    if (ctx.is_train || mode_ == dropout::kAlways) {
-      Tensor<xpu, 2, DType> mask = out_data[dropout::kMask].FlatTo2D<xpu, DType>(s);
-#if !defined(__CUDACC__) && defined(USE_MKL) && defined(_OPENMP)
-      DType* outptr = out.dptr_;
-      DType* dataptr = data.dptr_;
-      auto maskptr = reinterpret_cast<int*>(mask.dptr_);
-      int count = mask.shape_[0]*mask.shape_[1];
-      bernoulli_generate(count, this->pkeep_, maskptr);
-      const float pk_1 = 1.0f / pkeep_;
-      #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
-      for (int i = 0; i < count; ++i) {
-        outptr[i] = dataptr[i] * maskptr[i] * pk_1;
+    if (req[dropout::kOut] != kNullOp) {
+      CHECK_EQ(in_data.size(), 1U);
+      if (ctx.is_train) {
+        CHECK_EQ(out_data.size(), 2U);
+      }
+      Stream<xpu> *s = ctx.get_stream<xpu>();
+      const TBlob &out = out_data[dropout::kOut];
+      if (ctx.is_train || this->mode_ == dropout::kAlways) {
+        RandGenerator<xpu, DType> *pgen = ctx.requested[0].get_parallel_random<xpu, DType>();
+        CHECK_NOTNULL(pgen);
+        CHECK(req[dropout::kOut] != kAddTo);  // MKL and generic kernels both overwrite out
+        if (!MKLForward(s, pgen, this->pkeep_, in_data, out_data)) {
+          const TBlob &mask = out_data[dropout::kMask];
+          LaunchRNG<DropoutKernel, xpu>(s, pgen, out.Size(),
+                                        out.dptr<DType>(),
+                                        mask.dptr<DType>(),
+                                        in_data[dropout::kData].dptr<DType>(),
+                                        this->pkeep_);
+        }
+      } else {
+        const TBlob& data = in_data[dropout::kData];
+        if (req[dropout::kOut] == kWriteTo) {
+          mxnet_op::copy(s, out, data);
+        } else {
+          MXNET_ASSIGN_REQ_SWITCH(req[dropout::kOut], Req, {
+            mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, Req>, xpu>::Launch(
+              s, out.Size(), out.dptr<DType>(), data.dptr<DType>());
+          });
+        }
       }
-#else
-      Random<xpu> *prnd = ctx.requested[dropout::kRandom].get_random<xpu, real_t>(s);
-      mask = tcast<DType>(F<mshadow_op::threshold>(
-             prnd->uniform(mask.shape_), pkeep_) * (1.0f / pkeep_));
-      Assign(out, req[dropout::kOut], data * mask);
-#endif  // USE_MKL && _OPENMP
-    } else {
-      Assign(out, req[dropout::kOut], F<mshadow_op::identity>(data));
     }
   }
 
@@ -150,32 +261,36 @@ class DropoutOp : public Operator {
     CHECK_EQ(out_grad.size(), 1U);
     CHECK_EQ(in_grad.size(), 1U);
     Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 2, DType> grad = out_grad[dropout::kOut].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> mask = out_data[dropout::kMask].FlatTo2D<xpu, DType>(s);
-    Tensor<xpu, 2, DType> gdata = in_grad[dropout::kData].FlatTo2D<xpu, DType>(s);
     if (ctx.is_train || mode_ == dropout::kAlways) {
-#if !defined(__CUDACC__) && defined(USE_MKL) && defined(_OPENMP)
-      DType* ingradptr = gdata.dptr_;
-      DType* outgradptr = grad.dptr_;
-      auto maskptr = reinterpret_cast<int*>(mask.dptr_);
-      int count = mask.shape_[0]*mask.shape_[1];
-      const float pk_1 = 1.0f / pkeep_;
-      #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
-      for (int i = 0; i < count; ++i) {
-        ingradptr[i] = outgradptr[i] * maskptr[i] * pk_1;
+      if (!MKLBackward(s, this->pkeep_, in_grad, out_data, out_grad)) {
+        const TBlob &gdata = in_grad[dropout::kData];
+        const TBlob &grad = out_grad[dropout::kOut];
+        const TBlob &mask = out_data[dropout::kMask];
+        CHECK_EQ(grad.Size(), mask.Size());
+        MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
+          mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
+            s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>(), mask.dptr<DType>());
+        });
       }
-#else  // USE_MKL && _OPENMP
-      CHECK_EQ(grad.shape_.Size(), mask.shape_.Size());
-      Assign(gdata, req[dropout::kData], grad * mask);
-#endif  // USE_MKL && _OPENMP
     } else {
-      Assign(gdata, req[dropout::kData], F<mshadow_op::identity>(grad));
+      const TBlob &in_g = in_grad[dropout::kData];
+      const OpReqType grad_req = req[dropout::kData];
+      if (grad_req != kWriteTo) {
+        MXNET_ASSIGN_REQ_SWITCH(grad_req, Req, {
+          mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, Req>, xpu>::Launch(
+            s, in_g.Size(), in_g.dptr<DType>(), out_grad[dropout::kOut].dptr<DType>());
+        });
+      } else {  // plain overwrite: a straight copy of the incoming gradient
+        mxnet_op::copy(s, in_g, out_grad[dropout::kOut]);
+      }
     }
   }
 
  private:
+  /*! \brief Keep probability, 1 - p (an element is kept when its random number is below it) */
   real_t pkeep_;
-  int mode_;
+  /*! \brief Dropout mode: kTraining (only when is_train) or kAlways */
+  dropout::DropoutOpMode mode_;
 };  // class DropoutOp
 
 
@@ -254,9 +369,8 @@ class DropoutProp : public OperatorProperty {
     return {{in_data[dropout::kData], out_data[dropout::kOut]}};
   }
 
-  std::vector<ResourceRequest> ForwardResource(
-    const std::vector<TShape> &in_shape) const override {
-    return {ResourceRequest::kRandom};
+  std::vector<ResourceRequest> ForwardResource(const std::vector<TShape> &in_shape) const override {
+    return { ResourceRequest::kParallelRandom };
   }
 
   int NumVisibleOutputs() const override {
diff --git a/src/operator/nn/dropout.cc b/src/operator/nn/dropout.cc
index bbf5e2d..3aa832a 100644
--- a/src/operator/nn/dropout.cc
+++ b/src/operator/nn/dropout.cc
@@ -25,6 +25,7 @@
 */
 
 #include "./dropout-inl.h"
+#include "../operator_common.h"
 
 namespace mxnet {
 namespace op {
diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index 5c3cab9..42721a9 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -463,7 +463,7 @@ inline void SGDMomUpdateRspRspRspImpl(const SGDMomParam& param,
                                  mom.data(), req, &out_blob);
 }
 
-/*! 
+/*!
  * \brief Storge type inference function in optimizer.
  * \param n_rsp     The number of inputs that should be of row_sparse storage type
  *                  if kFComputeEx is dispatched
diff --git a/tests/cpp/include/test_legacy_op.h b/tests/cpp/include/test_legacy_op.h
index 6d326fc..498fa06 100644
--- a/tests/cpp/include/test_legacy_op.h
+++ b/tests/cpp/include/test_legacy_op.h
@@ -503,6 +503,13 @@ class LegacyOperatorExecutor : public OperatorDataInitializer<DType>
         }
       } else if (req.type == ResourceRequest::kRandom) {
         opContext_.requested.emplace_back(ResourceManager::Get()->Request(ctx, req));
+      } else if (req.type == ResourceRequest::kParallelRandom) {
+        Resource rm = ResourceManager::Get()->Request(ctx, req);
+        if (ctx.dev_mask() == Context::kCPU) {
+          common::random::RandGenerator<cpu, DType>::AllocState(
+            rm.get_parallel_random<cpu, DType>());
+        }
+        opContext_.requested.emplace_back(rm);
       } else {
         LOG(FATAL) << "resource type not yet supported";
       }
diff --git a/tests/cpp/operator/dropout_perf.cc b/tests/cpp/operator/dropout_perf.cc
new file mode 100644
index 0000000..90bf6eb
--- /dev/null
+++ b/tests/cpp/operator/dropout_perf.cc
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  \file dropout_perf.cc
+ *  \brief Perf/profile run of DropoutOp
+ *  \author Chris Olivier
+ */
+
+#include <gtest/gtest.h>
+#include <mxnet/tensor_blob.h>
+#include "../include/test_op_runner.h"
+#include "../include/test_legacy_op.h"
+#include "../../src/operator/nn/dropout-inl.h"
+
+using namespace mxnet;
+
+typedef std::vector<std::pair<std::string, std::string> > kwargs_t;
+const kwargs_t basic_dropout_args = { };
+
+/*!
+ * \brief Sanity test: run DropoutOp forward and backward on a 5x5 input in 'always' mode
+ */
+TEST(DROPOUT_PERF, ExecuteBidirectional) {
+  TShape shape({5, 5});
+  kwargs_t kwargs = basic_dropout_args;
+  kwargs.push_back({"mode", "always"});
+  test::op::LegacyOpRunner<mxnet::op::DropoutProp, float, float> runner;
+  runner.RunBidirectional(false, { shape }, kwargs, 1);
+}
+
+/*!
+ * \brief DropoutOp timing test for CPU
+ */
+TEST(DROPOUT_PERF, TimingCPU) {
+  kwargs_t kwargs = basic_dropout_args;
+  // 'always' mode so that dropout is applied even when not training
+  kwargs.push_back({"mode", "always"});
+  test::op::LegacyOpRunner<mxnet::op::DropoutProp, float, float> runner;
+  runner.RunBidirectional(false,
+                          { TShape({10, 10, 10, 10}) },
+                          kwargs, 1);  // prime code and cache
+  std::vector <TShape> shapes;
+  if (test::performance_run) {  // the full shape list only for performance runs
+    shapes = {
+      {1,  1, 28,  28},
+      {1,  3, 28,  28},
+      {50, 1, 18,  32},
+      {50, 3, 18,  32},
+      {20, 3, 128, 128}
+    };
+  } else {
+    shapes = {
+      {1,  1, 28,  28},
+      {50, 3, 18,  32},
+    };
+  }
+  for (const TShape &shape : shapes) {
+    runner.TimingTest("Dropout Operator CPU", false, false, kwargs, 2, 10, { shape });
+  }
+}
+
+#if MXNET_USE_CUDA == 1
+/*!
+ * \brief DropoutOp timing test for GPU
+ */
+TEST(DROPOUT_PERF, TimingGPU) {
+  kwargs_t kwargs = basic_dropout_args;
+  // 'always' mode so that dropout is applied even when not training
+  kwargs.push_back({"mode", "always"});
+  test::OperatorRunner<mxnet::op::DropoutProp,
+    test::op::LegacyOperatorExecutor<float, float>> runner;
+  runner.RunBidirectional(true,
+                          { TShape({10, 10, 10, 10}) },
+                          kwargs, 1);  // prime code and cache
+  std::vector <TShape> shapes = {
+    {1,  1, 28,  28},
+    {1,  3, 28,  28},
+    {50, 1, 18,  32},
+    {50, 3, 18,  32},
+    {20, 3, 128, 128}
+  };
+  for (const TShape &shape : shapes) {
+    runner.TimingTest("Dropout Operator GPU", true, false, kwargs, 2, 10, { shape });
+  }
+}
+#endif  // MXNET_USE_CUDA == 1
+
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 56dc27c..966a955 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -19,6 +19,7 @@
 from __future__ import print_function
 import numpy as np
 import mxnet as mx
+import math
 import random
 import itertools
 from numpy.testing import assert_allclose, assert_array_equal
@@ -4344,41 +4345,100 @@ def test_stack():
 
 
 def test_dropout():
-    # test dropout
-    x = mx.sym.var('data')
-    y = mx.sym.Dropout(x, p=0.5)
-    exe = y.simple_bind(ctx=default_context(), data=(10, 10))
+    def zero_count(array, ratio):
+        count = 0
+        for value in array:
+            is_nan = math.isnan(value)
+            if is_nan:
+                assert ratio == 1  # NaN outputs are only expected when ratio == 1
+            if is_nan or value == 0:
+                count += 1
+        return count
+
+    def check_correctness(executor, input, ratio):
+        input = input.ravel()
+        output = executor.outputs[0].asnumpy().ravel()
+        input_sum = np.sum(input)
+        output_sum = np.sum(output)
+
+        # The input must contain no zeros (sanity check on the test data)
+        assert zero_count(input, ratio) == 0
+
+        # Count zeros in the output (NaNs also count, but only for ratio == 1)
+        output_zeroes = zero_count(output, ratio)
+
+        # For 0.2 < ratio < 1 the relative change of the sum must stay below ratio/2
+        error = abs(output_sum - input_sum) / input_sum
+        if ratio == 1.0:
+            assert output_zeroes == len(input)
+        elif ratio > 0.2:
+            assert output_zeroes > 0
+            assert error < (ratio/2)
+        elif ratio == 0:
+            assert output_zeroes == 0
+
+    def check_dropout_ratio(ratio, shape):
+        # test dropout
+        x = mx.sym.var('data')
+        y = mx.sym.Dropout(x, p=ratio)
+        exe = y.simple_bind(ctx=default_context(), data=shape)
+
+        if ratio == 1:
+            max_value = float('nan')  # ratio == 1 gives 0 * (1 / 0) = NaN
+        else:
+            max_value = 1 / (1.0 - ratio)  # kept elements are scaled by 1 / (1 - ratio)
 
-    exe.arg_arrays[0][:] = 1
-    exe.forward(is_train=True)
-    assert exe.outputs[0].asnumpy().max() == 2
-    assert exe.outputs[0].asnumpy().min() == 0
-    exe.backward([mx.nd.ones((10, 10))])
-    assert (exe.grad_arrays[0].asnumpy() == exe.outputs[0].asnumpy()).all()
+        if ratio == 1:
+            min_value = float('nan')
+        else:
+            min_value = 1 if ratio == 0 else 0
 
-    exe.forward(is_train=False)
-    assert (exe.outputs[0].asnumpy() == exe.arg_arrays[0].asnumpy()).all()
-    exe.backward([mx.nd.ones((10, 10))], is_train=False)
-    assert (exe.grad_arrays[0].asnumpy() == exe.arg_arrays[0].asnumpy()).all()
+        exe.arg_arrays[0][:] = 1
+        exe.forward(is_train=True)
+        if not math.isnan(max_value):
+            assert exe.outputs[0].asnumpy().max() > 0
+        else:
+            assert math.isnan(exe.outputs[0].asnumpy().max())
+        if not math.isnan(min_value):
+            assert exe.outputs[0].asnumpy().min() == min_value
+        else:
+            assert math.isnan(exe.outputs[0].asnumpy().min())
 
-    # test permanent dropout
-    x = mx.sym.var('data')
-    y = mx.sym.Dropout(x, p=0.5, mode='always')
-    exe = y.simple_bind(ctx=default_context(), data=(10, 10))
+        check_correctness(exe, exe.arg_arrays[0].asnumpy(), ratio)
 
-    exe.arg_arrays[0][:] = 1
-    exe.forward(is_train=True)
-    assert exe.outputs[0].asnumpy().max() == 2
-    assert exe.outputs[0].asnumpy().min() == 0
-    exe.backward([mx.nd.ones((10, 10))])
-    assert (exe.grad_arrays[0].asnumpy() == exe.outputs[0].asnumpy()).all()
+        if ratio == 0.5:
+            exe.backward([mx.nd.ones(shape)])
+            assert (exe.grad_arrays[0].asnumpy() == exe.outputs[0].asnumpy()).all()
 
-    exe.forward(is_train=False)
-    assert exe.outputs[0].asnumpy().max() == 2
-    assert exe.outputs[0].asnumpy().min() == 0
-    exe.backward([mx.nd.ones((10, 10))], is_train=False)
-    assert (exe.grad_arrays[0].asnumpy() == exe.outputs[0].asnumpy()).all()
+            exe.forward(is_train=False)
+            assert (exe.outputs[0].asnumpy() == exe.arg_arrays[0].asnumpy()).all()
+            exe.backward([mx.nd.ones(shape)], is_train=False)
+            assert (exe.grad_arrays[0].asnumpy() == exe.arg_arrays[0].asnumpy()).all()
 
+            # test permanent dropout
+            x = mx.sym.var('data')
+            y = mx.sym.Dropout(x, p=ratio, mode='always')
+            exe = y.simple_bind(ctx=default_context(), data=shape)
+
+            exe.arg_arrays[0][:] = 1
+            exe.forward(is_train=True)
+            assert exe.outputs[0].asnumpy().max() == max_value
+            assert exe.outputs[0].asnumpy().min() == min_value
+            exe.backward([mx.nd.ones(shape)])
+            assert (exe.grad_arrays[0].asnumpy() == exe.outputs[0].asnumpy()).all()
+
+            exe.forward(is_train=False)
+            assert exe.outputs[0].asnumpy().max() == max_value
+            assert exe.outputs[0].asnumpy().min() == min_value
+            exe.backward([mx.nd.ones(shape)], is_train=False)
+            assert (exe.grad_arrays[0].asnumpy() == exe.outputs[0].asnumpy()).all()
+
+    shape = (100, 100)
+    check_dropout_ratio(0.5, shape)
+    check_dropout_ratio(0.0, shape)
+    check_dropout_ratio(1.0, shape)
+    check_dropout_ratio(0.75, shape)
+    check_dropout_ratio(0.25, shape)
 
 def test_scatter_gather_nd():
     def check(data, idx):

-- 
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].