You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by ah...@apache.org on 2019/09/10 19:32:52 UTC

[commons-rng] 02/03: RNG-114: Update ListSampler to detect RandomAccess lists.

This is an automated email from the ASF dual-hosted git repository.

aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-rng.git

commit 2d2716523e8de53f5da221b2425058743562627d
Author: Alex Herbert <ah...@apache.org>
AuthorDate: Sun Sep 8 22:04:06 2019 +0100

    RNG-114: Update ListSampler to detect RandomAccess lists.
    
    RandomAccess lists can be shuffled in-place. Non-RandomAccess lists
    should be shuffled as an extracted array and updated using their list
    iterator.
---
 .../apache/commons/rng/sampling/ListSampler.java   |  79 +++++++++++--
 .../commons/rng/sampling/ListSamplerTest.java      | 130 ++++++++++++++++++++-
 2 files changed, 196 insertions(+), 13 deletions(-)

diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/ListSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/ListSampler.java
index 88e9477..b837615 100644
--- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/ListSampler.java
+++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/ListSampler.java
@@ -18,6 +18,8 @@
 package org.apache.commons.rng.sampling;
 
 import java.util.List;
+import java.util.ListIterator;
+import java.util.RandomAccess;
 import java.util.ArrayList;
 
 import org.apache.commons.rng.UniformRandomProvider;
@@ -31,6 +33,12 @@ import org.apache.commons.rng.UniformRandomProvider;
  */
 public final class ListSampler {
     /**
+     * The size threshold for using the random access algorithm
+     * when the list does not implement java.util.RandomAccess.
+     */
+    private static final int RANDOM_ACCESS_SIZE_THRESHOLD = 4;
+
+    /**
      * Class contains only static methods.
      */
     private ListSampler() {}
@@ -73,25 +81,51 @@ public final class ListSampler {
     }
 
     /**
-     * Shuffles the entries of the given array.
+     * Shuffles the entries of the given array, using the
+     * <a href="http://en.wikipedia.org/wiki/Fisher-Yates_shuffle#The_modern_algorithm">
+     * Fisher-Yates</a> algorithm.
      *
-     * @see #shuffle(UniformRandomProvider,List,int,boolean)
+     * <p>
+     * Sampling uses {@link UniformRandomProvider#nextInt(int)}.
+     * </p>
      *
      * @param <T> Type of the list items.
      * @param rng Random number generator.
      * @param list List whose entries will be shuffled (in-place).
      */
+    @SuppressWarnings({"rawtypes", "unchecked"})
     public static <T> void shuffle(UniformRandomProvider rng,
                                    List<T> list) {
-        shuffle(rng, list, 0, false);
+        if (list instanceof RandomAccess || list.size() < RANDOM_ACCESS_SIZE_THRESHOLD) {
+            // Shuffle list in-place
+            for (int i = list.size(); i > 1; i--) {
+                swap(list, i - 1, rng.nextInt(i));
+            }
+        } else {
+            // Shuffle as an array
+            final Object[] array = list.toArray();
+            for (int i = array.length; i > 1; i--) {
+                swap(array, i - 1, rng.nextInt(i));
+            }
+
+            // Copy back. Use raw types.
+            final ListIterator it = list.listIterator();
+            for (final Object item : array) {
+                it.next();
+                it.set(item);
+            }
+        }
     }
 
     /**
      * Shuffles the entries of the given array, using the
      * <a href="http://en.wikipedia.org/wiki/Fisher-Yates_shuffle#The_modern_algorithm">
      * Fisher-Yates</a> algorithm.
+     *
+     * <p>
      * The {@code start} and {@code pos} parameters select which part
      * of the array is randomized and which is left untouched.
+     * </p>
      *
      * <p>
      * Sampling uses {@link UniformRandomProvider#nextInt(int)}.
@@ -109,13 +143,38 @@ public final class ListSampler {
                                    List<T> list,
                                    int start,
                                    boolean towardHead) {
-        final int len = list.size();
-        final int[] indices = PermutationSampler.natural(len);
-        PermutationSampler.shuffle(rng, indices, start, towardHead);
-
-        final ArrayList<T> items = new ArrayList<T>(list);
-        for (int i = 0; i < len; i++) {
-            list.set(i, items.get(indices[i]));
+        // Shuffle in-place as a sub-list.
+        if (towardHead) {
+            shuffle(rng, list.subList(0, start + 1));
+        } else {
+            shuffle(rng, list.subList(start, list.size()));
         }
     }
+
+    /**
+     * Swaps the two specified elements in the list.
+     *
+     * @param <T> Type of the list items.
+     * @param list List.
+     * @param i First index.
+     * @param j Second index.
+     */
+    private static <T> void swap(List<T> list, int i, int j) {
+        final T tmp = list.get(i);
+        list.set(i, list.get(j));
+        list.set(j, tmp);
+    }
+
+    /**
+     * Swaps the two specified elements in the array.
+     *
+     * @param array Array.
+     * @param i First index.
+     * @param j Second index.
+     */
+    private static void swap(Object[] array, int i, int j) {
+        final Object tmp = array[i];
+        array[i] = array[j];
+        array[j] = tmp;
+    }
 }
diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/ListSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/ListSamplerTest.java
index 3f1c91c..7dec2da 100644
--- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/ListSamplerTest.java
+++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/ListSamplerTest.java
@@ -18,7 +18,9 @@ package org.apache.commons.rng.sampling;
 
 import java.util.Set;
 import java.util.HashSet;
+import java.util.LinkedList;
 import java.util.List;
+import java.util.ListIterator;
 import java.util.ArrayList;
 import java.util.Collection;
 
@@ -99,11 +101,18 @@ public class ListSamplerTest {
         for (int i = 0; i < 10; i++) {
             orig.add((i + 1) * rng.nextInt());
         }
-        final List<Integer> list = new ArrayList<Integer>(orig);
 
-        ListSampler.shuffle(rng, list);
+        final List<Integer> arrayList = new ArrayList<Integer>(orig);
+
+        ListSampler.shuffle(rng, arrayList);
         // Ensure that at least one entry has moved.
-        Assert.assertTrue(compare(orig, list, 0, orig.size(), false));
+        Assert.assertTrue("ArrayList", compare(orig, arrayList, 0, orig.size(), false));
+
+        final List<Integer> linkedList = new LinkedList<Integer>(orig);
+
+        ListSampler.shuffle(rng, linkedList);
+        // Ensure that at least one entry has moved.
+        Assert.assertTrue("LinkedList", compare(orig, linkedList, 0, orig.size(), false));
     }
 
     @Test
@@ -142,6 +151,61 @@ public class ListSamplerTest {
         Assert.assertTrue(compare(orig, list, 0, start + 1, false));
     }
 
+    /**
+     * Test shuffle matches {@link PermutationSampler#shuffle(UniformRandomProvider, int[])}.
+     * The implementation may be different but the result is a Fisher-Yates shuffle so the
+     * output order should match.
+     */
+    @Test
+    public void testShuffleMatchesPermutationSamplerShuffle() {
+        final List<Integer> orig = new ArrayList<Integer>();
+        for (int i = 0; i < 10; i++) {
+            orig.add((i + 1) * rng.nextInt());
+        }
+
+        assertShuffleMatchesPermutationSamplerShuffle(new ArrayList<Integer>(orig));
+        assertShuffleMatchesPermutationSamplerShuffle(new LinkedList<Integer>(orig));
+    }
+
+    /**
+     * Test shuffle matches {@link PermutationSampler#shuffle(UniformRandomProvider, int[], int, boolean)}.
+     * The implementation may be different but the result is a Fisher-Yates shuffle so the
+     * output order should match.
+     */
+    @Test
+    public void testShuffleMatchesPermutationSamplerShuffleDirectional() {
+        final List<Integer> orig = new ArrayList<Integer>();
+        for (int i = 0; i < 10; i++) {
+            orig.add((i + 1) * rng.nextInt());
+        }
+
+        assertShuffleMatchesPermutationSamplerShuffle(new ArrayList<Integer>(orig), 4, true);
+        assertShuffleMatchesPermutationSamplerShuffle(new ArrayList<Integer>(orig), 4, false);
+        assertShuffleMatchesPermutationSamplerShuffle(new LinkedList<Integer>(orig), 4, true);
+        assertShuffleMatchesPermutationSamplerShuffle(new LinkedList<Integer>(orig), 4, false);
+    }
+
+    /**
+     * This test hits the edge case when a LinkedList is small enough that the algorithm
+     * using a RandomAccess list is faster than the one with an iterator.
+     */
+    @Test
+    public void testShuffleWithSmallLinkedList() {
+        final int size = 3;
+        final List<Integer> orig = new ArrayList<Integer>();
+        for (int i = 0; i < size; i++) {
+            orig.add((i + 1) * rng.nextInt());
+        }
+
+        // When the size is small there is a chance that the list has no entries that move.
+        // E.g. The number of permutations of 3 items is only 6 giving a 1/6 chance of no change.
+        // So repeat test that the small shuffle matches the PermutationSampler.
+        // 10 times is (1/6)^10 or 1 in 60,466,176 of no change.
+        for (int i = 0; i < 10; i++) {
+            assertShuffleMatchesPermutationSamplerShuffle(new LinkedList<Integer>(orig), size - 1, true);
+        }
+    }
+
     //// Support methods.
 
     /**
@@ -180,4 +244,64 @@ public class ListSamplerTest {
                     samp[0] + ", " + samp[1] + " }");
         return -1;
     }
+
+    /**
+     * Assert the shuffle matches {@link PermutationSampler#shuffle(UniformRandomProvider, int[])}.
+     *
+     * @param list Array whose entries will be shuffled (in-place).
+     */
+    private static void assertShuffleMatchesPermutationSamplerShuffle(List<Integer> list) {
+        final int[] array = new int[list.size()];
+        ListIterator<Integer> it = list.listIterator();
+        for (int i = 0; i < array.length; i++) {
+            array[i] = it.next();
+        }
+
+        // Identical RNGs
+        final long seed = RandomSource.createLong();
+        final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, seed);
+        final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, seed);
+
+        ListSampler.shuffle(rng1, list);
+        PermutationSampler.shuffle(rng2, array);
+
+        final String msg = "Type=" + list.getClass().getSimpleName();
+        it = list.listIterator();
+        for (int i = 0; i < array.length; i++) {
+            Assert.assertEquals(msg, array[i], it.next().intValue());
+        }
+    }
+    /**
+     * Assert the shuffle matches {@link PermutationSampler#shuffle(UniformRandomProvider, int[], int, boolean)}.
+     *
+     * @param list Array whose entries will be shuffled (in-place).
+     * @param start Index at which shuffling begins.
+     * @param towardHead Shuffling is performed for index positions between
+     * {@code start} and either the end (if {@code false}) or the beginning
+     * (if {@code true}) of the array.
+     */
+    private static void assertShuffleMatchesPermutationSamplerShuffle(List<Integer> list,
+                                                                    int start,
+                                                                    boolean towardHead) {
+        final int[] array = new int[list.size()];
+        ListIterator<Integer> it = list.listIterator();
+        for (int i = 0; i < array.length; i++) {
+            array[i] = it.next();
+        }
+
+        // Identical RNGs
+        final long seed = RandomSource.createLong();
+        final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, seed);
+        final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, seed);
+
+        ListSampler.shuffle(rng1, list, start, towardHead);
+        PermutationSampler.shuffle(rng2, array, start, towardHead);
+
+        final String msg = String.format("Type=%s start=%d towardHead=%b",
+                list.getClass().getSimpleName(), start, towardHead);
+        it = list.listIterator();
+        for (int i = 0; i < array.length; i++) {
+            Assert.assertEquals(msg, array[i], it.next().intValue());
+        }
+    }
 }