You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/03/20 22:30:32 UTC

[GitHub] piiswrong closed pull request #10048: [MXNET-68] Random shuffle implementation

piiswrong closed pull request #10048: [MXNET-68] Random shuffle implementation
URL: https://github.com/apache/incubator-mxnet/pull/10048
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/docs/api/python/ndarray/random.md b/docs/api/python/ndarray/random.md
index ae9e69f758f..4341a3ce2cd 100644
--- a/docs/api/python/ndarray/random.md
+++ b/docs/api/python/ndarray/random.md
@@ -35,6 +35,8 @@ In the rest of this document, we list routines provided by the `ndarray.random`
     normal
     poisson
     uniform
+    multinomial
+    shuffle
     mxnet.random.seed
 ```
 
diff --git a/docs/api/python/symbol/random.md b/docs/api/python/symbol/random.md
index a3492f6f840..22c686ff2fd 100644
--- a/docs/api/python/symbol/random.md
+++ b/docs/api/python/symbol/random.md
@@ -35,6 +35,8 @@ In the rest of this document, we list routines provided by the `symbol.random` p
     normal
     poisson
     uniform
+    multinomial
+    shuffle
     mxnet.random.seed
 ```
 
diff --git a/python/mxnet/ndarray/random.py b/python/mxnet/ndarray/random.py
index af125753e5e..93f97e80b47 100644
--- a/python/mxnet/ndarray/random.py
+++ b/python/mxnet/ndarray/random.py
@@ -24,7 +24,7 @@
 
 
 __all__ = ['uniform', 'normal', 'poisson', 'exponential', 'gamma', 'multinomial',
-           'negative_binomial', 'generalized_negative_binomial']
+           'negative_binomial', 'generalized_negative_binomial', 'shuffle']
 
 
 def _random_helper(random, sampler, params, shape, dtype, ctx, out, kwargs):
@@ -431,3 +431,35 @@ def multinomial(data, shape=_Null, get_prob=False, out=None, **kwargs):
     <NDArray 2 @cpu(0)>
     """
     return _internal._sample_multinomial(data, shape, get_prob, out=out, **kwargs)
+
+
+def shuffle(data, **kwargs):
+    """Shuffle the elements randomly.
+
+    This shuffles the array along the first axis.
+    The order of the elements in each subarray does not change.
+    For example, if a 2D array is given, the order of the rows randomly changes,
+    but the order of the elements in each row does not change.
+
+    Parameters
+    ----------
+    data : NDArray
+        Input data array.
+    out : NDArray, optional
+        Array to store the result. It is not a named parameter of this function;
+        pass it as a keyword argument and it is forwarded through `**kwargs`.
+
+    Returns
+    -------
+    NDArray
+        The shuffled array.
+
+    Examples
+    --------
+    >>> data = mx.nd.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
+    >>> mx.nd.random.shuffle(data)
+    [[ 0.  1.  2.]
+     [ 6.  7.  8.]
+     [ 3.  4.  5.]]
+    <NDArray 3x3 @cpu(0)>
+    >>> mx.nd.random.shuffle(data)
+    [[ 3.  4.  5.]
+     [ 0.  1.  2.]
+     [ 6.  7.  8.]]
+    <NDArray 3x3 @cpu(0)>
+    """
+    return _internal._shuffle(data, **kwargs)
diff --git a/python/mxnet/symbol/random.py b/python/mxnet/symbol/random.py
index f0d05ad0561..721a1daa95e 100644
--- a/python/mxnet/symbol/random.py
+++ b/python/mxnet/symbol/random.py
@@ -23,7 +23,7 @@
 
 
 __all__ = ['uniform', 'normal', 'poisson', 'exponential', 'gamma', 'multinomial',
-           'negative_binomial', 'generalized_negative_binomial']
+           'negative_binomial', 'generalized_negative_binomial', 'shuffle']
 
 
 def _random_helper(random, sampler, params, shape, dtype, kwargs):
@@ -247,3 +247,34 @@ def multinomial(data, shape=_Null, get_prob=True, **kwargs):
         reward as head gradient w.r.t. this array to estimate gradient.
     """
     return _internal._sample_multinomial(data, shape, get_prob, **kwargs)
+
+
+def shuffle(data, **kwargs):
+    """Shuffle the elements randomly.
+
+    This shuffles the array along the first axis.
+    The order of the elements in each subarray does not change.
+    For example, if a 2D array is given, the order of the rows randomly changes,
+    but the order of the elements in each row does not change.
+
+    Parameters
+    ----------
+    data : Symbol
+        Input data array.
+
+    Examples
+    --------
+    >>> data = mx.nd.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
+    >>> a = mx.sym.Variable('a')
+    >>> b = mx.sym.random.shuffle(a)
+    >>> b.eval(a=data)
+    [[ 0.  1.  2.]
+     [ 6.  7.  8.]
+     [ 3.  4.  5.]]
+    <NDArray 3x3 @cpu(0)>
+    >>> b.eval(a=data)
+    [[ 3.  4.  5.]
+     [ 0.  1.  2.]
+     [ 6.  7.  8.]]
+    <NDArray 3x3 @cpu(0)>
+    """
+    return _internal._shuffle(data, **kwargs)
diff --git a/src/operator/random/shuffle_op.cc b/src/operator/random/shuffle_op.cc
new file mode 100644
index 00000000000..d2a3e2d3df0
--- /dev/null
+++ b/src/operator/random/shuffle_op.cc
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file shuffle_op.cc
+ * \brief Operator to shuffle elements of an NDArray
+ */
+// Select the __gnu_parallel shuffle (from <parallel/algorithm>, included below) when
+// building with GCC newer than 4, or with Clang newer than 4 on Linux.
+// Clang's version macro is `__clang_major__`; the first clause must test that exact
+// name so that it only matches real GCC and Clang is handled by the second clause.
+#if (__GNUC__ > 4 && !defined(__clang_major__)) || (__clang_major__ > 4 && __linux__)
+  #define USE_GNU_PARALLEL_SHUFFLE
+#endif
+
+#include <mxnet/operator_util.h>
+#include <algorithm>
+#include <random>
+#include <vector>
+#ifdef USE_GNU_PARALLEL_SHUFFLE
+  #include <parallel/algorithm>
+#endif
+#include "../elemwise_op_common.h"
+
+namespace mxnet {
+namespace op {
+
+namespace {
+
+// Shuffles the `size` elements starting at `out` in place, drawing randomness from
+// the engine `*prnd`. With USE_GNU_PARALLEL_SHUFFLE defined this calls
+// __gnu_parallel::random_shuffle; otherwise it calls std::shuffle.
+template<typename DType, typename Rand>
+void Shuffle1D(DType* const out, const index_t size, Rand* const prnd) {
+  #ifdef USE_GNU_PARALLEL_SHUFFLE
+    // Generator handed to random_shuffle: returns a uniform index in [0, n - 1].
+    auto rand_n = [prnd](index_t n) {
+      std::uniform_int_distribution<index_t> dist(0, n - 1);
+      return dist(*prnd);
+    };
+    __gnu_parallel::random_shuffle(out, out + size, rand_n);
+  #else
+    std::shuffle(out, out + size, *prnd);
+  #endif
+}
+
+// Shuffles `out` in place along its first axis. The buffer of `size` elements is
+// treated as `first_axis_len` contiguous chunks of `size / first_axis_len` elements;
+// whole chunks are swapped, so the order of the elements inside a chunk is preserved.
+// Randomness is drawn from the engine `*prnd`. Fails the CHECK if `first_axis_len` is 0.
+template<typename DType, typename Rand>
+void ShuffleND(DType* const out, const index_t size, const index_t first_axis_len,
+                Rand* const prnd) {
+  // Validate before dividing: a zero-length first axis would otherwise divide by zero.
+  CHECK_GT(first_axis_len, 0U);
+  const index_t stride = size / first_axis_len;
+  // Returns a uniform index in [0, n - 1].
+  auto rand_n = [prnd](index_t n) {
+    std::uniform_int_distribution<index_t> dist(0, n - 1);
+    return dist(*prnd);
+  };
+  // Fisher-Yates shuffling
+  for (index_t i = first_axis_len - 1; i > 0; --i) {
+    const index_t j = rand_n(i + 1);
+    if (i != j) {
+      // Swap chunk i with chunk j.
+      std::swap_ranges(out + stride * i, out + stride * (i + 1), out + stride * j);
+    }
+  }
+}
+
+}  // namespace
+
+// CPU forward pass of the shuffle operator. Copies inputs[0] into outputs[0] (unless
+// the request is kWriteInplace) and then shuffles outputs[0] in place along its first
+// axis. Returns immediately for kNullOp and fails the CHECK for kAddTo.
+void ShuffleForwardCPU(const nnvm::NodeAttrs& attrs,
+                       const OpContext& ctx,
+                       const std::vector<TBlob>& inputs,
+                       const std::vector<OpReqType>& req,
+                       const std::vector<TBlob>& outputs) {
+  using namespace mxnet_op;
+  if (req[0] == kNullOp) {
+    return;
+  }
+  CHECK_NE(req[0], kAddTo) << "Shuffle does not support AddTo";
+  const TShape& input_shape = inputs[0].shape_;
+  const index_t size = inputs[0].Size();
+  const index_t first_axis_len = input_shape[0];
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    // View input and output as flat 1-D tensors of `size` elements.
+    Tensor<cpu, 1, DType> in = inputs[0].get_with_shape<cpu, 1, DType>(Shape1(size), s);
+    Tensor<cpu, 1, DType> out = outputs[0].get_with_shape<cpu, 1, DType>(Shape1(size), s);
+    // requested[0] is the random resource (first entry of the operator's FResourceRequest).
+    auto& prnd = ctx.requested[0].get_random<cpu, index_t>(ctx.get_stream<cpu>())->GetRndEngine();
+    // The shuffle itself works in place on `out`, so the data must be there first.
+    if (req[0] != kWriteInplace) {
+      std::copy(in.dptr_, in.dptr_ + size, out.dptr_);
+    }
+    if (input_shape.ndim() == 1) {
+      Shuffle1D(out.dptr_, size, &prnd);
+    } else {
+      ShuffleND(out.dptr_, size, first_axis_len, &prnd);
+    }
+  });
+}
+
+
+// No parameter is declared.
+// No backward computation is registered. Shuffling is not differentiable.
+
+// Registers the `_shuffle` operator (alias `shuffle`): one input, one output, with the
+// CPU compute function attached here.
+NNVM_REGISTER_OP(_shuffle)
+.add_alias("shuffle")
+.describe(R"code(Randomly shuffle the elements.
+
+This shuffles the array along the first axis.
+The order of the elements in each subarray does not change.
+For example, if a 2D array is given, the order of the rows randomly changes,
+but the order of the elements in each row does not change.
+)code")
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+// The order of the resources matters: the compute functions read ctx.requested[0] as
+// the random generator, and the GPU implementation reads ctx.requested[1] as temp space.
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const nnvm::NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kRandom, ResourceRequest::kTempSpace};
+  })
+// Pairs input 0 with output 0 for in-place execution; the compute functions handle
+// the resulting kWriteInplace request explicitly.
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::pair<int, int>>{{0, 0}};
+  })
+.set_attr<FCompute>("FCompute<cpu>", ShuffleForwardCPU)
+.add_argument("data", "NDArray-or-Symbol", "Data to be shuffled.");
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/random/shuffle_op.cu b/src/operator/random/shuffle_op.cu
new file mode 100644
index 00000000000..5bf8320c078
--- /dev/null
+++ b/src/operator/random/shuffle_op.cu
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file shuffle_op.cu
+ * \brief Operator to shuffle elements of an NDArray
+ */
+#include <mxnet/operator_util.h>
+#include <algorithm>
+#include <random>
+#include <vector>
+#include "../elemwise_op_common.h"
+#include "../tensor/init_op.h"
+
+namespace mxnet {
+namespace op {
+
+namespace {
+
+// Per-element gather used to apply a row permutation. Output element `i` lies in
+// row `i / stride` at offset `i % stride`; it is read from row `indices[i / stride]`
+// of `in` at the same offset.
+struct CopyForShuffle {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, const DType* const in, DType* out,
+                                  const index_t* indices, const index_t stride) {
+    out[i] = in[indices[i / stride] * stride + i % stride];
+  }
+};
+
+}  // namespace
+
+// GPU forward pass of the shuffle operator. Returns immediately for kNullOp and fails
+// the CHECK for kAddTo.
+//   * 1-D input: copies the data to the output (unless in place) and sorts it by
+//     freshly generated random keys.
+//   * N-D input: sorts a list of row indices by random keys, then gathers the rows
+//     into the output with the CopyForShuffle kernel.
+// ctx.requested[0] supplies the random generator and ctx.requested[1] the temp space.
+void ShuffleForwardGPU(const nnvm::NodeAttrs& attrs,
+                       const OpContext& ctx,
+                       const std::vector<TBlob>& inputs,
+                       const std::vector<OpReqType>& req,
+                       const std::vector<TBlob>& outputs) {
+  using namespace mxnet_op;
+  if (req[0] == kNullOp) {
+    return;
+  }
+  CHECK_NE(req[0], kAddTo) << "Shuffle does not support AddTo";
+  const TShape& input_shape = inputs[0].shape_;
+  const index_t size = inputs[0].Size();
+  const index_t first_axis_len = input_shape[0];
+  // NOTE(review): this divides by first_axis_len without checking that it is non-zero;
+  // confirm that an empty first axis cannot reach this point.
+  const index_t stride = size / first_axis_len;
+  Stream<gpu> *s = ctx.get_stream<gpu>();
+  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    using KeyType = index_t;
+    // View input and output as flat 1-D tensors of `size` elements.
+    Tensor<gpu, 1, DType> in = inputs[0].get_with_shape<gpu, 1, DType>(Shape1(size), s);
+    Tensor<gpu, 1, DType> out = outputs[0].get_with_shape<gpu, 1, DType>(Shape1(size), s);
+    Random<gpu, KeyType> *prnd = ctx.requested[0].get_random<gpu, KeyType>(s);
+    if (input_shape.ndim() == 1) {
+      if (req[0] != kWriteInplace) {
+        Copy(out, in, s);
+      }
+      // One random key per element; SortByKey presumably reorders `out` along with
+      // the sorted keys (it is defined elsewhere) — TODO confirm.
+      Tensor<gpu, 1, KeyType> keys =
+        ctx.requested[1].get_space_typed<gpu, 1, KeyType>(Shape1(size), s);
+      prnd->GetRandInt(keys);
+      SortByKey(keys, out, true);
+    } else {
+      // Temp space layout, in order:
+      //   indices : first_axis_len * index_t
+      //   keys    : first_axis_len * KeyType (KeyType is index_t)
+      //   buf     : size * DType, only reserved for the in-place case
+      const size_t tmp_space_size = req[0] == kWriteInplace ?
+        2 * first_axis_len * sizeof(index_t) + size * sizeof(DType) :
+        2 * first_axis_len * sizeof(index_t);
+      Tensor<gpu, 1, char> tmp_space =
+        ctx.requested[1].get_space_typed<gpu, 1, char>(Shape1(tmp_space_size), s);
+      char* tmp_space_ptr = tmp_space.dptr_;
+      Tensor<gpu, 1, index_t> indices(reinterpret_cast<index_t*>(tmp_space_ptr),
+                                      Shape1(first_axis_len), s);
+      tmp_space_ptr += sizeof(index_t) * first_axis_len;
+      // Presumably fills `indices` with 0, 1, ..., first_axis_len - 1 (range_fwd is
+      // defined elsewhere) — TODO confirm.
+      Kernel<range_fwd, gpu>::Launch(s, first_axis_len, 1, 0U, 1U, kWriteTo, indices.dptr_);
+      Tensor<gpu, 1, KeyType> keys(reinterpret_cast<KeyType*>(tmp_space_ptr),
+                                   Shape1(first_axis_len), s);
+      tmp_space_ptr += sizeof(KeyType) * first_axis_len;
+      // One random key per row; sorting by key turns `indices` into the row permutation.
+      prnd->GetRandInt(keys);
+      SortByKey(keys, indices, true);
+      if (req[0] == kWriteInplace) {
+        // The request is in place, so gather from a snapshot of the input in `buf`
+        // while the kernel writes to `out`.
+        Tensor<gpu, 1, DType> buf(reinterpret_cast<DType*>(tmp_space_ptr), Shape1(size), s);
+        Copy(buf, in, s);
+        Kernel<CopyForShuffle, gpu>::Launch(s, size, buf.dptr_, out.dptr_, indices.dptr_, stride);
+      } else {
+        Kernel<CopyForShuffle, gpu>::Launch(s, size, in.dptr_, out.dptr_, indices.dptr_, stride);
+      }
+    }
+  });
+}
+
+// Attaches the GPU compute function to the `_shuffle` operator.
+NNVM_REGISTER_OP(_shuffle)
+.set_attr<FCompute>("FCompute<gpu>", ShuffleForwardGPU);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py
index f042f57c4e9..c8dc3c97a81 100644
--- a/tests/python/unittest/test_random.py
+++ b/tests/python/unittest/test_random.py
@@ -16,6 +16,8 @@
 # under the License.
 
 import os
+import math
+import itertools
 import mxnet as mx
 from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf
 import numpy as np
@@ -552,6 +554,81 @@ def compute_expected_prob():
     mx.test_utils.assert_almost_equal(exp_cnt_sampled.asnumpy(), exp_cnt[sampled_classes].asnumpy(), rtol=1e-1, atol=1e-2)
     mx.test_utils.assert_almost_equal(exp_cnt_true.asnumpy(), exp_cnt[true_classes].asnumpy(), rtol=1e-1, atol=1e-2)
 
+@with_seed()
+def test_shuffle():
+    # Checks that `arr` is a first-axis shuffle of an array whose flattened form is
+    # 0, 1, 2, ...: every subarray appears exactly once and its elements are in order.
+    def check_first_axis_shuffle(arr):
+        stride = int(arr.size / arr.shape[0])
+        # The first elements of the subarrays, once sorted, must be 0, stride, 2 * stride, ...
+        column0 = arr.reshape((arr.size,))[::stride].sort()
+        seq = mx.nd.arange(0, arr.size - stride + 1, stride, ctx=arr.context)
+        assert (column0 == seq).prod() == 1
+        for i in range(arr.shape[0]):
+            # Each subarray must still be a run of consecutive integers.
+            subarr = arr[i].reshape((arr[i].size,))
+            start = subarr[0].asscalar()
+            seq = mx.nd.arange(start, start + stride, ctx=arr.context)
+            assert (subarr == seq).prod() == 1
+
+    # This tests that the shuffling is along the first axis (using `repeat1` shufflings)
+    # and that the outcomes are uniformly distributed (using `repeat2` shufflings).
+    # Note that the number of samples (`repeat2`) needed to verify the uniformity of the
+    # distribution of the outcomes grows factorially with the length of the first axis of
+    # the array `data`, so in practice we have to settle for small arrays.
+    # `data` must be a consecutive sequence of integers starting from 0 if it is flattened.
+    def testSmall(data, repeat1, repeat2):
+        # Check that the shuffling is along the first axis.
+        # The order of the elements in each subarray must not change.
+        # This takes a long time, so `repeat1` needs to be small.
+        for i in range(repeat1):
+            ret = mx.nd.random.shuffle(data)
+            check_first_axis_shuffle(ret)
+        # Count the number of occurrences of each different outcome.
+        # The sequence composed of the first elements of the subarrays is enough to discriminate
+        # the outcomes as long as the order of the elements in each subarray does not change.
+        count = {}
+        stride = int(data.size / data.shape[0])
+        for i in range(repeat2):
+            ret = mx.nd.random.shuffle(data)
+            h = str(ret.reshape((ret.size,))[::stride])
+            c = count.get(h, 0)
+            count[h] = c + 1
+        # Check the total number of possible outcomes.
+        # If `repeat2` is not large enough, this could fail with high probability.
+        assert len(count) == math.factorial(data.shape[0])
+        # The outcomes must be uniformly distributed.
+        # If `repeat2` is not large enough, this could fail with high probability.
+        for p in itertools.permutations(range(0, data.size - stride + 1, stride)):
+            assert abs(1. * count[str(mx.nd.array(p))] / repeat2 - 1. / math.factorial(data.shape[0])) < 0.01
+        # Check the symbol interface: shuffling twice and then sorting along the first
+        # axis must give back the original data.
+        a = mx.sym.Variable('a')
+        b = mx.sym.random.shuffle(a)
+        c = mx.sym.random.shuffle(data=b, name='c')
+        d = mx.sym.sort(c, axis=0)
+        assert (d.eval(a=data, ctx=mx.current_context())[0] == data).prod() == 1
+
+    # This test is weaker than `testSmall` and is meant for larger arrays.
+    # `repeat` should be much smaller than the factorial of `data.shape[0]`.
+    # `data` must be a consecutive sequence of integers starting from 0 if it is flattened.
+    def testLarge(data, repeat):
+        # Check that the shuffling is along the first axis
+        # and count the number of different outcomes.
+        stride = int(data.size / data.shape[0])
+        count = {}
+        for i in range(repeat):
+            ret = mx.nd.random.shuffle(data)
+            check_first_axis_shuffle(ret)
+            h = str(ret.reshape((ret.size,))[::stride])
+            c = count.get(h, 0)
+            count[h] = c + 1
+        # The probability of duplicated outcomes is very low for large arrays.
+        assert len(count) == repeat
+
+    # Test small arrays with different shapes
+    testSmall(mx.nd.arange(0, 3), 100, 20000)
+    testSmall(mx.nd.arange(0, 9).reshape((3, 3)), 100, 20000)
+    testSmall(mx.nd.arange(0, 18).reshape((3, 2, 3)), 100, 20000)
+    # Test larger arrays
+    testLarge(mx.nd.arange(0, 100000).reshape((10, 10000)), 10)
+    testLarge(mx.nd.arange(0, 100000).reshape((10000, 10)), 10)
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services