You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by ah...@apache.org on 2019/08/07 12:12:33 UTC
[commons-rng] 02/03: RNG-109: Delegate sampling in probability collection sampler.
This is an automated email from the ASF dual-hosted git repository.
aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-rng.git
commit 341af54ecbef6988b28be8ca4ec1db246362f4d5
Author: aherbert <ah...@apache.org>
AuthorDate: Tue Aug 6 14:54:16 2019 +0100
RNG-109: Delegate sampling in probability collection sampler.
---
.../DiscreteProbabilityCollectionSampler.java | 119 ++++++---------------
.../DiscreteProbabilityCollectionSamplerTest.java | 52 ++++++---
2 files changed, 70 insertions(+), 101 deletions(-)
diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java
index 4bc50b4..69b45bc 100644
--- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java
+++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java
@@ -19,11 +19,11 @@ package org.apache.commons.rng.sampling;
import java.util.List;
import java.util.Map;
-import java.util.HashMap;
import java.util.ArrayList;
-import java.util.Arrays;
import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.sampling.distribution.GuideTableDiscreteSampler;
+import org.apache.commons.rng.sampling.distribution.SharedStateDiscreteSampler;
/**
* Sampling from a collection of items with user-defined
@@ -40,12 +40,12 @@ import org.apache.commons.rng.UniformRandomProvider;
*/
public class DiscreteProbabilityCollectionSampler<T>
implements SharedStateSampler<DiscreteProbabilityCollectionSampler<T>> {
+ /** The error message for an empty collection. */
+ private static final String EMPTY_COLLECTION = "Empty collection";
/** Collection to be sampled from. */
private final List<T> items;
- /** RNG. */
- private final UniformRandomProvider rng;
- /** Cumulative probabilities. */
- private final double[] cumulativeProbabilities;
+ /** Sampler for the probabilities. */
+ private final SharedStateDiscreteSampler sampler;
/**
* Creates a sampler.
@@ -64,43 +64,22 @@ public class DiscreteProbabilityCollectionSampler<T>
public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
Map<T, Double> collection) {
if (collection.isEmpty()) {
- throw new IllegalArgumentException("Empty collection");
+ throw new IllegalArgumentException(EMPTY_COLLECTION);
}
- this.rng = rng;
+ // Extract the items and probabilities
final int size = collection.size();
items = new ArrayList<T>(size);
- cumulativeProbabilities = new double[size];
+ final double[] probabilities = new double[size];
- double sumProb = 0;
int count = 0;
for (final Map.Entry<T, Double> e : collection.entrySet()) {
items.add(e.getKey());
-
- final double prob = e.getValue();
- if (prob < 0 ||
- Double.isInfinite(prob) ||
- Double.isNaN(prob)) {
- throw new IllegalArgumentException("Invalid probability: " +
- prob);
- }
-
- // Temporarily store probability.
- cumulativeProbabilities[count++] = prob;
- sumProb += prob;
+ probabilities[count++] = e.getValue();
}
- if (sumProb <= 0) {
- throw new IllegalArgumentException("Invalid sum of probabilities");
- }
-
- // Compute and store cumulative probability.
- for (int i = 0; i < size; i++) {
- cumulativeProbabilities[i] /= sumProb;
- if (i > 0) {
- cumulativeProbabilities[i] += cumulativeProbabilities[i - 1];
- }
- }
+ // Delegate sampling
+ sampler = createSampler(rng, probabilities);
}
/**
@@ -122,7 +101,19 @@ public class DiscreteProbabilityCollectionSampler<T>
public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
List<T> collection,
double[] probabilities) {
- this(rng, consolidate(collection, probabilities));
+ if (collection.isEmpty()) {
+ throw new IllegalArgumentException(EMPTY_COLLECTION);
+ }
+ final int len = probabilities.length;
+ if (len != collection.size()) {
+ throw new IllegalArgumentException("Size mismatch: " +
+ len + " != " +
+ collection.size());
+ }
+ // Shallow copy the list
+ items = new ArrayList<T>(collection);
+ // Delegate sampling
+ sampler = createSampler(rng, probabilities);
}
/**
@@ -131,9 +122,8 @@ public class DiscreteProbabilityCollectionSampler<T>
*/
private DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
DiscreteProbabilityCollectionSampler<T> source) {
- this.rng = rng;
this.items = source.items;
- this.cumulativeProbabilities = source.cumulativeProbabilities;
+ this.sampler = source.sampler.withUniformRandomProvider(rng);
}
/**
@@ -142,22 +132,7 @@ public class DiscreteProbabilityCollectionSampler<T>
* @return a random sample.
*/
public T sample() {
- final double rand = rng.nextDouble();
-
- int index = Arrays.binarySearch(cumulativeProbabilities, rand);
- if (index < 0) {
- index = -index - 1;
- }
-
- if (index < cumulativeProbabilities.length &&
- rand < cumulativeProbabilities[index]) {
- return items.get(index);
- }
-
- // This should never happen, but it ensures we will return a correct
- // object in case there is some floating point inequality problem
- // wrt the cumulative probabilities.
- return items.get(items.size() - 1);
+ return items.get(sampler.sample());
}
/** {@inheritDoc} */
@@ -167,38 +142,14 @@ public class DiscreteProbabilityCollectionSampler<T>
}
/**
- * @param collection Collection to be sampled.
- * @param probabilities Probability associated to each item of the
- * {@code collection}.
- * @return a consolidated map (where probabilities of equal items
- * have been summed).
- * @throws IllegalArgumentException if the number of items in the
- * {@code collection} is not equal to the number of provided
- * {@code probabilities}.
- * @param <T> Type of items in the collection.
+ * Creates the sampler of the enumerated probability distribution.
+ *
+ * @param rng Generator of uniformly distributed random numbers.
+ * @param probabilities Probability associated to each item.
+ * @return the sampler
*/
- private static <T> Map<T, Double> consolidate(List<T> collection,
- double[] probabilities) {
- final int len = probabilities.length;
- if (len != collection.size()) {
- throw new IllegalArgumentException("Size mismatch: " +
- len + " != " +
- collection.size());
- }
-
- final Map<T, Double> map = new HashMap<T, Double>();
- for (int i = 0; i < len; i++) {
- final T item = collection.get(i);
- final Double prob = probabilities[i];
-
- Double currentProb = map.get(item);
- if (currentProb == null) {
- currentProb = 0d;
- }
-
- map.put(item, currentProb + prob);
- }
-
- return map;
+ private static SharedStateDiscreteSampler createSampler(UniformRandomProvider rng,
+ double[] probabilities) {
+ return GuideTableDiscreteSampler.of(rng, probabilities);
}
}
diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java
index 78a4391..e5ffc6a 100644
--- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java
+++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java
@@ -18,8 +18,11 @@
package org.apache.commons.rng.sampling;
import java.util.Arrays;
+import java.util.Collections;
import java.util.HashMap;
import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
import org.junit.Assert;
import org.junit.Test;
@@ -74,6 +77,13 @@ public class DiscreteProbabilityCollectionSamplerTest {
new DiscreteProbabilityCollectionSampler<Double>(rng,
new HashMap<Double, Double>());
}
+ @Test(expected = IllegalArgumentException.class)
+ public void testPrecondition7() {
+ // Empty List<T> not allowed
+ new DiscreteProbabilityCollectionSampler<Double>(rng,
+ Collections.<Double>emptyList(),
+ new double[0]);
+ }
@Test
public void testSample() {
@@ -99,30 +109,38 @@ public class DiscreteProbabilityCollectionSamplerTest {
Assert.assertEquals(expectedVariance, variance, 2e-3);
}
- /**
- * Edge-case test:
- * Create a sampler that will return 1 for nextDouble() forcing the binary search to
- * identify the end item of the cumulative probability array.
- */
+
@Test
- public void testSampleWithProbabilityAtLastItem() {
- sampleWithProbabilityForLastItem(false);
+ public void testSampleUsingMap() {
+ final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
+ final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
+ final List<Integer> items = Arrays.asList(1, 3, 4, 6, 9);
+ final double[] probabilities = {0.1, 0.2, 0.3, 0.4, 0.5};
+ final DiscreteProbabilityCollectionSampler<Integer> sampler1 =
+ new DiscreteProbabilityCollectionSampler<Integer>(rng1, items, probabilities);
+
+ // Create a map version. The map iterator must be ordered so use a TreeMap.
+ final Map<Integer, Double> map = new TreeMap<Integer, Double>();
+ for (int i = 0; i < probabilities.length; i++) {
+ map.put(items.get(i), probabilities[i]);
+ }
+ final DiscreteProbabilityCollectionSampler<Integer> sampler2 =
+ new DiscreteProbabilityCollectionSampler<Integer>(rng2, map);
+
+ for (int i = 0; i < 50; i++) {
+ Assert.assertEquals(sampler1.sample(), sampler2.sample());
+ }
}
/**
* Edge-case test:
- * Create a sampler that will return over 1 for nextDouble() forcing the binary search to
- * identify insertion at the end of the cumulative probability array.
+ * Create a sampler that will return 1 for nextDouble() forcing the search to
+ * identify the end item of the cumulative probability array.
*/
@Test
- public void testSampleWithProbabilityPastLastItem() {
- sampleWithProbabilityForLastItem(true);
- }
-
- private static void sampleWithProbabilityForLastItem(boolean pastLast) {
+ public void testSampleWithProbabilityAtLastItem() {
// Ensure the samples pick probability 0 (the first item) and then
// a probability (for the second item) that hits an edge case.
- final double probability = pastLast ? 1.1 : 1;
final UniformRandomProvider dummyRng = new UniformRandomProvider() {
private int count;
// CHECKSTYLE: stop all
@@ -132,7 +150,7 @@ public class DiscreteProbabilityCollectionSamplerTest {
public int nextInt() { return 0; }
public float nextFloat() { return 0; }
// Return 0 then the given probability
- public double nextDouble() { return (count++ == 0) ? 0 : probability; }
+ public double nextDouble() { return (count++ == 0) ? 0 : 1.0; }
public void nextBytes(byte[] bytes, int start, int len) {}
public void nextBytes(byte[] bytes) {}
public boolean nextBoolean() { return false; }
@@ -164,7 +182,7 @@ public class DiscreteProbabilityCollectionSamplerTest {
final DiscreteProbabilityCollectionSampler<Double> sampler1 =
new DiscreteProbabilityCollectionSampler<Double>(rng1,
items,
- new double[] {0.1, 0.2, 0.3, 04});
+ new double[] {0.1, 0.2, 0.3, 0.4});
final DiscreteProbabilityCollectionSampler<Double> sampler2 = sampler1.withUniformRandomProvider(rng2);
RandomAssert.assertProduceSameSequence(
new RandomAssert.Sampler<Double>() {