You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by er...@apache.org on 2018/01/17 12:49:14 UTC

commons-rng git commit: RNG-47: Sampling from discrete probability distribution.

Repository: commons-rng
Updated Branches:
  refs/heads/master 01a2c09ce -> 6ec1d323e


RNG-47: Sampling from discrete probability distribution.


Project: http://git-wip-us.apache.org/repos/asf/commons-rng/repo
Commit: http://git-wip-us.apache.org/repos/asf/commons-rng/commit/6ec1d323
Tree: http://git-wip-us.apache.org/repos/asf/commons-rng/tree/6ec1d323
Diff: http://git-wip-us.apache.org/repos/asf/commons-rng/diff/6ec1d323

Branch: refs/heads/master
Commit: 6ec1d323e1747152411fa0d52128614bc8ea0f30
Parents: 01a2c09
Author: Gilles <er...@apache.org>
Authored: Wed Jan 17 13:47:24 2018 +0100
Committer: Gilles <er...@apache.org>
Committed: Wed Jan 17 13:47:24 2018 +0100

----------------------------------------------------------------------
 .../DiscreteProbabilityCollectionSampler.java   | 185 +++++++++++++++++++
 ...iscreteProbabilityCollectionSamplerTest.java |  87 +++++++++
 2 files changed, 272 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/commons-rng/blob/6ec1d323/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java
----------------------------------------------------------------------
diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java
new file mode 100644
index 0000000..8f87c15
--- /dev/null
+++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.rng.sampling;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.HashMap;
+import java.util.ArrayList;
+import java.util.Arrays;
+
+import org.apache.commons.rng.UniformRandomProvider;
+
+/**
+ * Sampling from a {@link Collection} of items with user-defined
+ * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">
+ * probabilities</a>.
+ * Note that if all unique items are assigned the same probability,
+ * it is much more efficient to use {@link CollectionSampler}.
+ *
+ * <p>
+ * Sampling inverts the cumulative distribution function. A value is drawn
+ * with {@code rng.nextDouble()}, and the first item whose cumulative
+ * probability is strictly greater than that value is returned. The item is
+ * located by binary search over the cumulative probabilities.
+ * </p>
+ *
+ * @param <T> Type of items in the collection.
+ *
+ * @since 1.1
+ */
+public class DiscreteProbabilityCollectionSampler<T> {
+    /** Collection to be sampled from. */
+    private final List<T> items;
+    /** RNG. */
+    private final UniformRandomProvider rng;
+    /**
+     * Normalized cumulative probabilities.
+     * This array is non-decreasing and its last element is exactly 1.
+     */
+    private final double[] cumulativeProbabilities;
+
+    /**
+     * Creates a sampler.
+     *
+     * @param rng Generator of uniformly distributed random numbers.
+     * @param collection Collection to be sampled, with the probabilities
+     * associated to each of its items.
+     * A (shallow) copy of the items will be stored in the created instance.
+     * The probabilities must be non-negative, but zero values are allowed
+     * and their sum does not have to equal one (input will be normalized
+     * to make the probabilities sum to one).
+     * @throws IllegalArgumentException if {@code collection} is empty, a
+     * probability is negative, infinite or {@code NaN}, or the sum of all
+     * probabilities is not strictly positive or is not finite (overflow).
+     * @throws NullPointerException if a probability is {@code null}.
+     */
+    public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
+                                                Map<T, Double> collection) {
+        if (collection.isEmpty()) {
+            throw new IllegalArgumentException("Empty collection");
+        }
+
+        this.rng = rng;
+        final int size = collection.size();
+        items = new ArrayList<T>(size);
+        cumulativeProbabilities = new double[size];
+
+        // Store the running (not yet normalized) sum of the probabilities.
+        double sumProb = 0;
+        int count = 0;
+        for (Map.Entry<T, Double> e : collection.entrySet()) {
+            items.add(e.getKey());
+
+            final double prob = e.getValue();
+            validateProbability(prob);
+
+            sumProb += prob;
+            cumulativeProbabilities[count++] = sumProb;
+        }
+
+        // All terms are non-negative, so an infinite sum means the finite
+        // inputs overflowed. Normalizing by it would turn every cumulative
+        // probability into zero.
+        if (!(sumProb > 0) ||
+            Double.isInfinite(sumProb)) {
+            throw new IllegalArgumentException("Invalid sum of probabilities");
+        }
+
+        // Normalize. The last element holds the same value as "sumProb", so
+        // it becomes exactly 1. Division by a positive constant preserves
+        // ordering, so the array stays non-decreasing.
+        for (int i = 0; i < size; i++) {
+            cumulativeProbabilities[i] /= sumProb;
+        }
+    }
+
+    /**
+     * Creates a sampler.
+     *
+     * @param rng Generator of uniformly distributed random numbers.
+     * @param collection Collection to be sampled.
+     * A (shallow) copy of the items will be stored in the created instance.
+     * Probabilities of equal items are summed.
+     * @param probabilities Probability associated to each item of the
+     * {@code collection}.
+     * The probabilities must be non-negative, but zero values are allowed
+     * and their sum does not have to equal one (input will be normalized
+     * to make the probabilities sum to one).
+     * @throws IllegalArgumentException if {@code collection} is empty,
+     * a probability is negative, infinite or {@code NaN}, the sum of all
+     * probabilities is not strictly positive or is not finite, or the number
+     * of items in the {@code collection} is not equal to the number of
+     * provided {@code probabilities}.
+     */
+    public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
+                                                List<T> collection,
+                                                double[] probabilities) {
+        this(rng, consolidate(collection, probabilities));
+    }
+
+    /**
+     * Picks one of the items from the collection passed to the constructor.
+     * Each item is returned with its normalized probability. An item with a
+     * zero probability is never returned, provided that
+     * {@code rng.nextDouble()} returns values in {@code [0, 1)}.
+     *
+     * @return a random sample.
+     */
+    public T sample() {
+        final double rand = rng.nextDouble();
+        final int len = cumulativeProbabilities.length;
+
+        int index = Arrays.binarySearch(cumulativeProbabilities, rand);
+        if (index < 0) {
+            // The insertion point is the first element strictly greater
+            // than "rand".
+            index = -index - 1;
+        } else {
+            // Exact match: the item covering "rand" is the first one whose
+            // cumulative probability is strictly greater. The search may
+            // land anywhere in a run of equal values (zero-probability
+            // items), so scan forward.
+            while (index < len &&
+                   cumulativeProbabilities[index] <= rand) {
+                ++index;
+            }
+        }
+
+        if (index < len) {
+            return items.get(index);
+        }
+
+        // Only reachable if the generator returned a value that is not
+        // below the last cumulative probability, which is exactly 1.
+        return items.get(items.size() - 1);
+    }
+
+    /**
+     * Checks that a value is usable as a (not yet normalized) probability.
+     *
+     * @param prob Probability.
+     * @throws IllegalArgumentException if {@code prob} is negative,
+     * infinite or {@code NaN}.
+     */
+    private static void validateProbability(double prob) {
+        if (prob < 0 ||
+            Double.isInfinite(prob) ||
+            Double.isNaN(prob)) {
+            throw new IllegalArgumentException("Invalid probability: " +
+                                               prob);
+        }
+    }
+
+    /**
+     * Merges items and their probabilities into a map. Equality of items is
+     * determined by a {@link HashMap}, i.e. by {@code equals} and
+     * {@code hashCode}.
+     *
+     * @param collection Collection to be sampled.
+     * @param probabilities Probability associated to each item of the
+     * {@code collection}.
+     * @return a consolidated map (where probabilities of equal items
+     * have been summed).
+     * @throws IllegalArgumentException if the number of items in the
+     * {@code collection} is not equal to the number of provided
+     * {@code probabilities}, or if any individual probability is
+     * negative, infinite or {@code NaN}.
+     */
+    private static <T> Map<T, Double> consolidate(List<T> collection,
+                                                  double[] probabilities) {
+        final int len = probabilities.length;
+        if (len != collection.size()) {
+            throw new IllegalArgumentException("Size mismatch: " +
+                                               len + " != " +
+                                               collection.size());
+        }
+
+        final Map<T, Double> map = new HashMap<T, Double>();
+        int i = 0;
+        for (T item : collection) {
+            final double prob = probabilities[i++];
+            // Validate each value before summing. Otherwise a negative
+            // probability could be hidden by a larger positive one assigned
+            // to an equal item (e.g. -1 + 2 = 1).
+            validateProbability(prob);
+
+            final Double currentProb = map.get(item);
+            map.put(item, currentProb == null ? prob : currentProb + prob);
+        }
+
+        return map;
+    }
+}

http://git-wip-us.apache.org/repos/asf/commons-rng/blob/6ec1d323/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java
----------------------------------------------------------------------
diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java
new file mode 100644
index 0000000..757d44e
--- /dev/null
+++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.rng.sampling;
+
+import java.util.Arrays;
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.simple.RandomSource;
+
+/**
+ * Test class for {@link DiscreteProbabilityCollectionSampler}.
+ */
+public class DiscreteProbabilityCollectionSamplerTest {
+    /** RNG. */
+    private static final UniformRandomProvider rng = RandomSource.create(RandomSource.WELL_1024_A);
+
+    /** The number of items differs from the number of probabilities. */
+    @Test(expected=IllegalArgumentException.class)
+    public void testPrecondition1() {
+        new DiscreteProbabilityCollectionSampler<Double>(rng,
+                                                         Arrays.asList(new Double[] {1d, 2d}),
+                                                         new double[] {0d});
+    }
+    /** Negative probability. */
+    @Test(expected=IllegalArgumentException.class)
+    public void testPrecondition2() {
+        new DiscreteProbabilityCollectionSampler<Double>(rng,
+                                                         Arrays.asList(new Double[] {1d, 2d}),
+                                                         new double[] {0d, -1d});
+    }
+    /** Sum of probabilities is zero. */
+    @Test(expected=IllegalArgumentException.class)
+    public void testPrecondition3() {
+        new DiscreteProbabilityCollectionSampler<Double>(rng,
+                                                         Arrays.asList(new Double[] {1d, 2d}),
+                                                         new double[] {0d, 0d});
+    }
+    /** NaN probability. */
+    @Test(expected=IllegalArgumentException.class)
+    public void testPrecondition4() {
+        new DiscreteProbabilityCollectionSampler<Double>(rng,
+                                                         Arrays.asList(new Double[] {1d, 2d}),
+                                                         new double[] {0d, Double.NaN});
+    }
+    /** Infinite probability. */
+    @Test(expected=IllegalArgumentException.class)
+    public void testPrecondition5() {
+        new DiscreteProbabilityCollectionSampler<Double>(rng,
+                                                         Arrays.asList(new Double[] {1d, 2d}),
+                                                         new double[] {0d, Double.POSITIVE_INFINITY});
+    }
+    /** Empty collection. */
+    @Test(expected=IllegalArgumentException.class)
+    public void testPrecondition6() {
+        new DiscreteProbabilityCollectionSampler<Double>(rng,
+                                                         Arrays.asList(new Double[] {}),
+                                                         new double[] {});
+    }
+
+    /** A single item must always be returned, whatever its (positive) weight. */
+    @Test
+    public void testSampleSingleItem() {
+        final DiscreteProbabilityCollectionSampler<String> sampler =
+            new DiscreteProbabilityCollectionSampler<String>(rng,
+                                                             Arrays.asList(new String[] {"a"}),
+                                                             new double[] {0.25});
+        for (int i = 0; i < 1000; i++) {
+            Assert.assertEquals("a", sampler.sample());
+        }
+    }
+
+    @Test
+    public void testSample() {
+        final DiscreteProbabilityCollectionSampler<Double> sampler =
+            new DiscreteProbabilityCollectionSampler<Double>(rng,
+                                                             Arrays.asList(new Double[] {3d, -1d, 3d, 7d, -2d, 8d}),
+                                                             new double[] {0.2, 0.2, 0.3, 0.3, 0, 0});
+        // After consolidation of the duplicated item "3":
+        // P(3) = 0.5, P(-1) = 0.2, P(7) = 0.3, P(-2) = P(8) = 0.
+        final double expectedMean = 3.4;
+        final double expectedVariance = 7.84;
+        // Fourth central moment E[(X - mean)^4], needed for the standard
+        // error of the variance estimate.
+        final double expectedFourthCentralMoment = 125.3632;
+
+        final int n = 10000000;
+        double sum = 0;
+        double sumOfSquares = 0;
+        for (int i = 0; i < n; i++) {
+            final double rand = sampler.sample();
+            if (rand == -2d || rand == 8d) {
+                Assert.fail("Item with zero probability was sampled: " + rand);
+            }
+            sum += rand;
+            sumOfSquares += rand * rand;
+        }
+
+        // Tolerances of 5 standard errors. The RNG is not seeded, so a
+        // tighter bound would make the test fail randomly. At 5 sigma the
+        // chance of a spurious failure is about 1e-6.
+        final double meanTolerance = 5 * Math.sqrt(expectedVariance / n);
+        final double varianceTolerance =
+            5 * Math.sqrt((expectedFourthCentralMoment - expectedVariance * expectedVariance) / n);
+
+        final double mean = sum / n;
+        Assert.assertEquals(expectedMean, mean, meanTolerance);
+        final double variance = sumOfSquares / n - mean * mean;
+        Assert.assertEquals(expectedVariance, variance, varianceTolerance);
+    }
+}