You are viewing a plain text version of this content. The canonical link for it was not preserved in this plain text copy.
Posted to commits@commons.apache.org by ah...@apache.org on 2019/08/07 12:12:33 UTC

[commons-rng] 02/03: RNG-109: Delegate sampling in probability collection sampler.

This is an automated email from the ASF dual-hosted git repository.

aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-rng.git

commit 341af54ecbef6988b28be8ca4ec1db246362f4d5
Author: aherbert <ah...@apache.org>
AuthorDate: Tue Aug 6 14:54:16 2019 +0100

    RNG-109: Delegate sampling in probability collection sampler.
---
 .../DiscreteProbabilityCollectionSampler.java      | 119 ++++++---------------
 .../DiscreteProbabilityCollectionSamplerTest.java  |  52 ++++++---
 2 files changed, 70 insertions(+), 101 deletions(-)

diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java
index 4bc50b4..69b45bc 100644
--- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java
+++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java
@@ -19,11 +19,11 @@ package org.apache.commons.rng.sampling;
 
 import java.util.List;
 import java.util.Map;
-import java.util.HashMap;
 import java.util.ArrayList;
-import java.util.Arrays;
 
 import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.sampling.distribution.GuideTableDiscreteSampler;
+import org.apache.commons.rng.sampling.distribution.SharedStateDiscreteSampler;
 
 /**
  * Sampling from a collection of items with user-defined
@@ -40,12 +40,12 @@ import org.apache.commons.rng.UniformRandomProvider;
  */
 public class DiscreteProbabilityCollectionSampler<T>
     implements SharedStateSampler<DiscreteProbabilityCollectionSampler<T>> {
+    /** The error message for an empty collection. */
+    private static final String EMPTY_COLLECTION = "Empty collection";
     /** Collection to be sampled from. */
     private final List<T> items;
-    /** RNG. */
-    private final UniformRandomProvider rng;
-    /** Cumulative probabilities. */
-    private final double[] cumulativeProbabilities;
+    /** Sampler for the probabilities. */
+    private final SharedStateDiscreteSampler sampler;
 
     /**
      * Creates a sampler.
@@ -64,43 +64,22 @@ public class DiscreteProbabilityCollectionSampler<T>
     public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
                                                 Map<T, Double> collection) {
         if (collection.isEmpty()) {
-            throw new IllegalArgumentException("Empty collection");
+            throw new IllegalArgumentException(EMPTY_COLLECTION);
         }
 
-        this.rng = rng;
+        // Extract the items and probabilities
         final int size = collection.size();
         items = new ArrayList<T>(size);
-        cumulativeProbabilities = new double[size];
+        final double[] probabilities = new double[size];
 
-        double sumProb = 0;
         int count = 0;
         for (final Map.Entry<T, Double> e : collection.entrySet()) {
             items.add(e.getKey());
-
-            final double prob = e.getValue();
-            if (prob < 0 ||
-                Double.isInfinite(prob) ||
-                Double.isNaN(prob)) {
-                throw new IllegalArgumentException("Invalid probability: " +
-                                                   prob);
-            }
-
-            // Temporarily store probability.
-            cumulativeProbabilities[count++] = prob;
-            sumProb += prob;
+            probabilities[count++] = e.getValue();
         }
 
-        if (sumProb <= 0) {
-            throw new IllegalArgumentException("Invalid sum of probabilities");
-        }
-
-        // Compute and store cumulative probability.
-        for (int i = 0; i < size; i++) {
-            cumulativeProbabilities[i] /= sumProb;
-            if (i > 0) {
-                cumulativeProbabilities[i] += cumulativeProbabilities[i - 1];
-            }
-        }
+        // Delegate sampling
+        sampler = createSampler(rng, probabilities);
     }
 
     /**
@@ -122,7 +101,19 @@ public class DiscreteProbabilityCollectionSampler<T>
     public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
                                                 List<T> collection,
                                                 double[] probabilities) {
-        this(rng, consolidate(collection, probabilities));
+        if (collection.isEmpty()) {
+            throw new IllegalArgumentException(EMPTY_COLLECTION);
+        }
+        final int len = probabilities.length;
+        if (len != collection.size()) {
+            throw new IllegalArgumentException("Size mismatch: " +
+                                               len + " != " +
+                                               collection.size());
+        }
+        // Shallow copy the list
+        items = new ArrayList<T>(collection);
+        // Delegate sampling
+        sampler = createSampler(rng, probabilities);
     }
 
     /**
@@ -131,9 +122,8 @@ public class DiscreteProbabilityCollectionSampler<T>
      */
     private DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
                                                  DiscreteProbabilityCollectionSampler<T> source) {
-        this.rng = rng;
         this.items = source.items;
-        this.cumulativeProbabilities = source.cumulativeProbabilities;
+        this.sampler = source.sampler.withUniformRandomProvider(rng);
     }
 
     /**
@@ -142,22 +132,7 @@ public class DiscreteProbabilityCollectionSampler<T>
      * @return a random sample.
      */
     public T sample() {
-        final double rand = rng.nextDouble();
-
-        int index = Arrays.binarySearch(cumulativeProbabilities, rand);
-        if (index < 0) {
-            index = -index - 1;
-        }
-
-        if (index < cumulativeProbabilities.length &&
-            rand < cumulativeProbabilities[index]) {
-            return items.get(index);
-        }
-
-        // This should never happen, but it ensures we will return a correct
-        // object in case there is some floating point inequality problem
-        // wrt the cumulative probabilities.
-        return items.get(items.size() - 1);
+        return items.get(sampler.sample());
     }
 
     /** {@inheritDoc} */
@@ -167,38 +142,14 @@ public class DiscreteProbabilityCollectionSampler<T>
     }
 
     /**
-     * @param collection Collection to be sampled.
-     * @param probabilities Probability associated to each item of the
-     * {@code collection}.
-     * @return a consolidated map (where probabilities of equal items
-     * have been summed).
-     * @throws IllegalArgumentException if the number of items in the
-     * {@code collection} is not equal to the number of provided
-     * {@code probabilities}.
-     * @param <T> Type of items in the collection.
+     * Creates the sampler of the enumerated probability distribution.
+     *
+     * @param rng Generator of uniformly distributed random numbers.
+     * @param probabilities Probability associated to each item.
+     * @return the sampler
      */
-    private static <T> Map<T, Double> consolidate(List<T> collection,
-                                                  double[] probabilities) {
-        final int len = probabilities.length;
-        if (len != collection.size()) {
-            throw new IllegalArgumentException("Size mismatch: " +
-                                               len + " != " +
-                                               collection.size());
-        }
-
-        final Map<T, Double> map = new HashMap<T, Double>();
-        for (int i = 0; i < len; i++) {
-            final T item = collection.get(i);
-            final Double prob = probabilities[i];
-
-            Double currentProb = map.get(item);
-            if (currentProb == null) {
-                currentProb = 0d;
-            }
-
-            map.put(item, currentProb + prob);
-        }
-
-        return map;
+    private static SharedStateDiscreteSampler createSampler(UniformRandomProvider rng,
+                                                            double[] probabilities) {
+        return GuideTableDiscreteSampler.of(rng, probabilities);
     }
 }
diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java
index 78a4391..e5ffc6a 100644
--- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java
+++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java
@@ -18,8 +18,11 @@
 package org.apache.commons.rng.sampling;
 
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
 
 import org.junit.Assert;
 import org.junit.Test;
@@ -74,6 +77,13 @@ public class DiscreteProbabilityCollectionSamplerTest {
         new DiscreteProbabilityCollectionSampler<Double>(rng,
                                                          new HashMap<Double, Double>());
     }
+    @Test(expected = IllegalArgumentException.class)
+    public void testPrecondition7() {
+        // Empty List<T> not allowed
+        new DiscreteProbabilityCollectionSampler<Double>(rng,
+                                                         Collections.<Double>emptyList(),
+                                                         new double[0]);
+    }
 
     @Test
     public void testSample() {
@@ -99,30 +109,38 @@ public class DiscreteProbabilityCollectionSamplerTest {
         Assert.assertEquals(expectedVariance, variance, 2e-3);
     }
 
-    /**
-     * Edge-case test:
-     * Create a sampler that will return 1 for nextDouble() forcing the binary search to
-     * identify the end item of the cumulative probability array.
-     */
+
     @Test
-    public void testSampleWithProbabilityAtLastItem() {
-        sampleWithProbabilityForLastItem(false);
+    public void testSampleUsingMap() {
+        final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
+        final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
+        final List<Integer> items = Arrays.asList(1, 3, 4, 6, 9);
+        final double[] probabilities = {0.1, 0.2, 0.3, 0.4, 0.5};
+        final DiscreteProbabilityCollectionSampler<Integer> sampler1 =
+            new DiscreteProbabilityCollectionSampler<Integer>(rng1, items, probabilities);
+
+        // Create a map version. The map iterator must be ordered so use a TreeMap.
+        final Map<Integer, Double> map = new TreeMap<Integer, Double>();
+        for (int i = 0; i < probabilities.length; i++) {
+            map.put(items.get(i), probabilities[i]);
+        }
+        final DiscreteProbabilityCollectionSampler<Integer> sampler2 =
+            new DiscreteProbabilityCollectionSampler<Integer>(rng2, map);
+
+        for (int i = 0; i < 50; i++) {
+            Assert.assertEquals(sampler1.sample(), sampler2.sample());
+        }
     }
 
     /**
      * Edge-case test:
-     * Create a sampler that will return over 1 for nextDouble() forcing the binary search to
-     * identify insertion at the end of the cumulative probability array.
+     * Create a sampler that will return 1 for nextDouble() forcing the search to
+     * identify the end item of the cumulative probability array.
      */
     @Test
-    public void testSampleWithProbabilityPastLastItem() {
-        sampleWithProbabilityForLastItem(true);
-    }
-
-    private static void sampleWithProbabilityForLastItem(boolean pastLast) {
+    public void testSampleWithProbabilityAtLastItem() {
         // Ensure the samples pick probability 0 (the first item) and then
         // a probability (for the second item) that hits an edge case.
-        final double probability = pastLast ? 1.1 : 1;
         final UniformRandomProvider dummyRng = new UniformRandomProvider() {
             private int count;
             // CHECKSTYLE: stop all
@@ -132,7 +150,7 @@ public class DiscreteProbabilityCollectionSamplerTest {
             public int nextInt() { return 0; }
             public float nextFloat() { return 0; }
             // Return 0 then the given probability
-            public double nextDouble() { return (count++ == 0) ? 0 : probability; }
+            public double nextDouble() { return (count++ == 0) ? 0 : 1.0; }
             public void nextBytes(byte[] bytes, int start, int len) {}
             public void nextBytes(byte[] bytes) {}
             public boolean nextBoolean() { return false; }
@@ -164,7 +182,7 @@ public class DiscreteProbabilityCollectionSamplerTest {
         final DiscreteProbabilityCollectionSampler<Double> sampler1 =
             new DiscreteProbabilityCollectionSampler<Double>(rng1,
                                                              items,
-                                                             new double[] {0.1, 0.2, 0.3, 04});
+                                                             new double[] {0.1, 0.2, 0.3, 0.4});
         final DiscreteProbabilityCollectionSampler<Double> sampler2 = sampler1.withUniformRandomProvider(rng2);
         RandomAssert.assertProduceSameSequence(
             new RandomAssert.Sampler<Double>() {