You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by er...@apache.org on 2018/01/17 12:49:14 UTC
commons-rng git commit: RNG-47: Sampling from discrete probability
distribution.
Repository: commons-rng
Updated Branches:
refs/heads/master 01a2c09ce -> 6ec1d323e
RNG-47: Sampling from discrete probability distribution.
Project: http://git-wip-us.apache.org/repos/asf/commons-rng/repo
Commit: http://git-wip-us.apache.org/repos/asf/commons-rng/commit/6ec1d323
Tree: http://git-wip-us.apache.org/repos/asf/commons-rng/tree/6ec1d323
Diff: http://git-wip-us.apache.org/repos/asf/commons-rng/diff/6ec1d323
Branch: refs/heads/master
Commit: 6ec1d323e1747152411fa0d52128614bc8ea0f30
Parents: 01a2c09
Author: Gilles <er...@apache.org>
Authored: Wed Jan 17 13:47:24 2018 +0100
Committer: Gilles <er...@apache.org>
Committed: Wed Jan 17 13:47:24 2018 +0100
----------------------------------------------------------------------
.../DiscreteProbabilityCollectionSampler.java | 185 +++++++++++++++++++
...iscreteProbabilityCollectionSamplerTest.java | 87 +++++++++
2 files changed, 272 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/commons-rng/blob/6ec1d323/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java
----------------------------------------------------------------------
diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java
new file mode 100644
index 0000000..8f87c15
--- /dev/null
+++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.rng.sampling;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.HashMap;
+import java.util.ArrayList;
+import java.util.Arrays;
+
+import org.apache.commons.rng.UniformRandomProvider;
+
+/**
+ * Sampling from a {@link Collection} of items with user-defined
+ * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">
+ * probabilities</a>.
+ * Note that if all unique items are assigned the same probability,
+ * it is much more efficient to use {@link CollectionSampler}.
+ *
+ * @param <T> Type of items in the collection.
+ *
+ * @since 1.1
+ */
+public class DiscreteProbabilityCollectionSampler<T> {
+    /** Collection to be sampled from. */
+    private final List<T> items;
+    /** RNG. */
+    private final UniformRandomProvider rng;
+    /** Cumulative probabilities (normalized, non-decreasing). */
+    private final double[] cumulativeProbabilities;
+
+    /**
+     * Creates a sampler.
+     *
+     * @param rng Generator of uniformly distributed random numbers.
+     * @param collection Collection to be sampled, with the probabilities
+     * associated with each of its items.
+     * A (shallow) copy of the items will be stored in the created instance.
+     * The probabilities must be non-negative, but zero values are allowed
+     * and their sum does not have to equal one (input will be normalized
+     * to make the probabilities sum to one).
+     * @throws IllegalArgumentException if {@code collection} is empty, a
+     * probability is negative, infinite or {@code NaN}, or the sum of all
+     * probabilities is not strictly positive and finite.
+     */
+    public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
+                                                Map<T, Double> collection) {
+        if (collection.isEmpty()) {
+            throw new IllegalArgumentException("Empty collection");
+        }
+
+        this.rng = rng;
+        final int size = collection.size();
+        items = new ArrayList<T>(size);
+        cumulativeProbabilities = new double[size];
+
+        double sumProb = 0;
+        int count = 0;
+        for (Map.Entry<T, Double> e : collection.entrySet()) {
+            items.add(e.getKey());
+
+            final double prob = e.getValue();
+            if (prob < 0 || Double.isInfinite(prob) || Double.isNaN(prob)) {
+                throw new IllegalArgumentException("Invalid probability: " + prob);
+            }
+
+            // Temporarily store probability.
+            cumulativeProbabilities[count++] = prob;
+            sumProb += prob;
+        }
+
+        // An overflowing sum would normalize every entry to zero: reject it too.
+        if (!(sumProb > 0) || Double.isInfinite(sumProb)) {
+            throw new IllegalArgumentException("Invalid sum of probabilities");
+        }
+        // Compute and store cumulative probability.
+        for (int i = 0; i < size; i++) {
+            cumulativeProbabilities[i] /= sumProb;
+            if (i > 0) {
+                cumulativeProbabilities[i] += cumulativeProbabilities[i - 1];
+            }
+        }
+    }
+
+    /**
+     * Creates a sampler.
+     *
+     * @param rng Generator of uniformly distributed random numbers.
+     * @param collection Collection to be sampled.
+     * A (shallow) copy of the items will be stored in the created instance.
+     * @param probabilities Probability associated with each item of the
+     * {@code collection}.
+     * The probabilities must be non-negative, but zero values are allowed
+     * and their sum does not have to equal one (input will be normalized
+     * to make the probabilities sum to one).
+     * @throws IllegalArgumentException if {@code collection} is empty, a
+     * probability is negative, infinite or {@code NaN}, the sum of all
+     * probabilities is not strictly positive and finite, or the sizes of
+     * {@code collection} and {@code probabilities} differ.
+     */
+    public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
+                                                List<T> collection,
+                                                double[] probabilities) {
+        this(rng, consolidate(collection, probabilities));
+    }
+
+    /**
+     * Picks one of the items from the collection passed to the constructor.
+     * Item {@code i} is selected when the random value is at least the
+     * cumulative probability of item {@code i - 1} and below that of item {@code i}.
+     *
+     * @return a random sample.
+     */
+    public T sample() {
+        final double rand = rng.nextDouble();
+        final int len = cumulativeProbabilities.length;
+
+        int index = Arrays.binarySearch(cumulativeProbabilities, rand);
+        if (index < 0) {
+            // Not found: insertion point is the first entry above "rand".
+            index = -index - 1;
+        } else {
+            // Exact hit on an upper (exclusive) bound: skip to the next
+            // entry that is strictly larger (also skips zero-width items).
+            while (index < len && cumulativeProbabilities[index] <= rand) {
+                ++index;
+            }
+        }
+
+        // No entry exceeds "rand" (e.g. rounding left the last entry below
+        // 1): fall back to the last item that has a non-zero probability.
+        if (index >= len) {
+            index = len - 1;
+            while (index > 0 &&
+                   cumulativeProbabilities[index] == cumulativeProbabilities[index - 1]) {
+                --index;
+            }
+        }
+        return items.get(index);
+    }
+
+    /**
+     * Merges equal items into a single entry, summing their probabilities.
+     *
+     * @param collection Collection to be sampled.
+     * @param probabilities Probability associated with each item of the
+     * {@code collection}.
+     * @return a consolidated map (where probabilities of equal items
+     * have been summed).
+     * @throws IllegalArgumentException if the number of items in the
+     * {@code collection} is not equal to the number of provided
+     * {@code probabilities}.
+     */
+    private static <T> Map<T, Double> consolidate(List<T> collection,
+                                                  double[] probabilities) {
+        final int len = probabilities.length;
+        if (len != collection.size()) {
+            throw new IllegalArgumentException("Size mismatch: " + len + " != " + collection.size());
+        }
+
+        final Map<T, Double> map = new HashMap<T, Double>();
+        for (int i = 0; i < len; i++) {
+            final T item = collection.get(i);
+            final Double currentProb = map.get(item);
+            map.put(item, (currentProb == null ? 0d : currentProb) + probabilities[i]);
+        }
+        return map;
+    }
+}
http://git-wip-us.apache.org/repos/asf/commons-rng/blob/6ec1d323/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java
----------------------------------------------------------------------
diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java
new file mode 100644
index 0000000..757d44e
--- /dev/null
+++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.rng.sampling;
+
+import java.util.Arrays;
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.simple.RandomSource;
+
+/**
+ * Test class for {@link DiscreteProbabilityCollectionSampler}.
+ */
+public class DiscreteProbabilityCollectionSamplerTest {
+    /** RNG. */
+    private static final UniformRandomProvider rng = RandomSource.create(RandomSource.WELL_1024_A);
+
+    /** Two items but a single probability: size mismatch. */
+    @Test(expected=IllegalArgumentException.class)
+    public void testPrecondition1() {
+        createForTwoItems(0d);
+    }
+    /** A negative probability is rejected. */
+    @Test(expected=IllegalArgumentException.class)
+    public void testPrecondition2() {
+        createForTwoItems(0d, -1d);
+    }
+    /** Probabilities that sum to zero are rejected. */
+    @Test(expected=IllegalArgumentException.class)
+    public void testPrecondition3() {
+        createForTwoItems(0d, 0d);
+    }
+    /** A NaN probability is rejected. */
+    @Test(expected=IllegalArgumentException.class)
+    public void testPrecondition4() {
+        createForTwoItems(0d, Double.NaN);
+    }
+    /** An infinite probability is rejected. */
+    @Test(expected=IllegalArgumentException.class)
+    public void testPrecondition5() {
+        createForTwoItems(0d, Double.POSITIVE_INFINITY);
+    }
+
+    @Test
+    public void testSample() {
+        final Double[] values = {3d, -1d, 3d, 7d, -2d, 8d};
+        final double[] weights = {0.2, 0.2, 0.3, 0.3, 0, 0};
+        final DiscreteProbabilityCollectionSampler<Double> sampler =
+            new DiscreteProbabilityCollectionSampler<Double>(rng, Arrays.asList(values), weights);
+        // Value 3 appears twice (0.2 + 0.3); -2 and 8 have zero weight.
+        final double expectedMean = 3.4;
+        final double expectedVariance = 7.84;
+
+        final int n = 100000000;
+        double total = 0;
+        double totalOfSquares = 0;
+        for (int k = 0; k < n; k++) {
+            final double draw = sampler.sample();
+            total += draw;
+            totalOfSquares += draw * draw;
+        }
+        final double mean = total / n;
+        final double variance = totalOfSquares / n - mean * mean;
+        Assert.assertEquals(expectedMean, mean, 1e-3);
+        Assert.assertEquals(expectedVariance, variance, 1e-3);
+    }
+
+    /** Builds a sampler over the items {1, 2} with the given probabilities. */
+    private static void createForTwoItems(double... probabilities) {
+        new DiscreteProbabilityCollectionSampler<Double>(rng, Arrays.asList(1d, 2d), probabilities);
+    }
+}