You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by ah...@apache.org on 2019/08/07 12:12:31 UTC

[commons-rng] branch master updated (a85dd39 -> f6d59a1)

This is an automated email from the ASF dual-hosted git repository.

aherbert pushed a change to branch master
in repository https://gitbox.apache.org/repos/asf/commons-rng.git.


    from a85dd39  Exclude local maven meta-data from checkstyle.
     new 855915b  RNG-109: Benchmark enumerated probability distributed samplers.
     new 341af54  RNG-109: Delegate sampling in probability collection sampler.
     new f6d59a1  RNG-109: Track changes.

The 3 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 .../EnumeratedDistributionSamplersPerformance.java | 555 +++++++++++++++++++++
 .../DiscreteProbabilityCollectionSampler.java      | 119 ++---
 .../DiscreteProbabilityCollectionSamplerTest.java  |  52 +-
 src/changes/changes.xml                            |   4 +
 4 files changed, 629 insertions(+), 101 deletions(-)
 create mode 100644 commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/distribution/EnumeratedDistributionSamplersPerformance.java


[commons-rng] 01/03: RNG-109: Benchmark enumerated probability distributed samplers.

Posted by ah...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-rng.git

commit 855915bb0f5a2cb23158441869928ad7cd49a165
Author: aherbert <ah...@apache.org>
AuthorDate: Tue Aug 6 14:56:27 2019 +0100

    RNG-109: Benchmark enumerated probability distributed samplers.
---
 .../EnumeratedDistributionSamplersPerformance.java | 555 +++++++++++++++++++++
 1 file changed, 555 insertions(+)

diff --git a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/distribution/EnumeratedDistributionSamplersPerformance.java b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/distribution/EnumeratedDistributionSamplersPerformance.java
new file mode 100644
index 0000000..fc28515
--- /dev/null
+++ b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/distribution/EnumeratedDistributionSamplersPerformance.java
@@ -0,0 +1,555 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.rng.examples.jmh.distribution;
+
+import org.apache.commons.math3.distribution.BinomialDistribution;
+import org.apache.commons.math3.distribution.IntegerDistribution;
+import org.apache.commons.math3.distribution.PoissonDistribution;
+import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.examples.jmh.RandomSources;
+import org.apache.commons.rng.sampling.distribution.AliasMethodDiscreteSampler;
+import org.apache.commons.rng.sampling.distribution.DiscreteSampler;
+import org.apache.commons.rng.sampling.distribution.GuideTableDiscreteSampler;
+import org.apache.commons.rng.sampling.distribution.MarsagliaTsangWangDiscreteSampler;
+import org.apache.commons.rng.simple.RandomSource;
+
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Level;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.Warmup;
+
+import java.util.Arrays;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Executes benchmark to compare the speed of generation of random numbers from an enumerated
+ * discrete probability distribution.
+ */
+@BenchmarkMode(Mode.AverageTime)
+@OutputTimeUnit(TimeUnit.NANOSECONDS)
+@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
+@Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
+@State(Scope.Benchmark)
+@Fork(value = 1, jvmArgs = {"-server", "-Xms128M", "-Xmx128M"})
+public class EnumeratedDistributionSamplersPerformance {
+    /**
+     * The {@link DiscreteSampler} samplers to use for testing. Creates the sampler for each
+     * {@link RandomSource} in the default {@link RandomSources}.
+     *
+     * <p>This class is abstract so cannot be a benchmark parameter. Subclasses create the probabilities.</p>
+     */
+    @State(Scope.Benchmark)
+    public abstract static class SamplerSources {
+        /**
+         * A factory for creating DiscreteSampler objects.
+         */
+        interface DiscreteSamplerFactory {
+            /**
+             * Creates the sampler.
+             *
+             * @return the sampler
+             */
+            DiscreteSampler create();
+        }
+
+        /**
+         * RNG providers.
+         *
+         * <p>Use different speeds.</p>
+         *
+         * @see <a href="https://commons.apache.org/proper/commons-rng/userguide/rng.html">
+         *      Commons RNG user guide</a>
+         */
+        @Param({
+                //"WELL_44497_B",
+                //"ISAAC",
+                "XO_RO_SHI_RO_128_PLUS",
+                })
+        private String randomSourceName;
+
+        /**
+         * The sampler type.
+         */
+        @Param({"BinarySearchDiscreteSampler",
+                "AliasMethodDiscreteSampler",
+                "GuideTableDiscreteSampler",
+                "MarsagliaTsangWangDiscreteSampler",
+
+                // Uncomment to test non-default parameters
+                //"AliasMethodDiscreteSamplerNoPad", // Not optimal for sampling
+                //"AliasMethodDiscreteSamplerAlpha1",
+                //"AliasMethodDiscreteSamplerAlpha2",
+
+                // The AliasMethod memory requirement doubles for each alpha increment.
+                // A fair comparison is to use 2^alpha for the equivalent guide table method.
+                //"GuideTableDiscreteSamplerAlpha2",
+                //"GuideTableDiscreteSamplerAlpha4",
+                })
+        private String samplerType;
+
+        /** RNG. */
+        private UniformRandomProvider generator;
+
+        /** The factory. */
+        private DiscreteSamplerFactory factory;
+
+        /** The sampler. */
+        private DiscreteSampler sampler;
+
+        /**
+         * @return the RNG.
+         */
+        public UniformRandomProvider getGenerator() {
+            return generator;
+        }
+
+        /**
+         * Gets the sampler.
+         *
+         * @return the sampler.
+         */
+        public DiscreteSampler getSampler() {
+            return sampler;
+        }
+
+        /** Creates the distribution (per iteration as it may vary) and instantiates the sampler. */
+        @Setup(Level.Iteration)
+        public void setup() {
+            final RandomSource randomSource = RandomSource.valueOf(randomSourceName);
+            generator = RandomSource.create(randomSource);
+
+            final double[] probabilities = createProbabilities();
+            createSamplerFactory(generator, probabilities);
+            sampler = factory.create();
+        }
+
+        /**
+         * Creates the probabilities for the distribution.
+         *
+         * @return The probabilities.
+         */
+        protected abstract double[] createProbabilities();
+
+        /**
+         * Creates the sampler factory.
+         *
+         * @param rng The random generator.
+         * @param probabilities The probabilities.
+         */
+        private void createSamplerFactory(final UniformRandomProvider rng,
+            final double[] probabilities) {
+            // This would benefit from Java 8 lambda functions
+            if ("BinarySearchDiscreteSampler".equals(samplerType)) {
+                factory = new DiscreteSamplerFactory() {
+                    @Override
+                    public DiscreteSampler create() {
+                        return new BinarySearchDiscreteSampler(rng, probabilities);
+                    }
+                };
+            } else if ("AliasMethodDiscreteSampler".equals(samplerType)) {
+                factory = new DiscreteSamplerFactory() {
+                    @Override
+                    public DiscreteSampler create() {
+                        return AliasMethodDiscreteSampler.of(rng, probabilities);
+                    }
+                };
+            } else if ("AliasMethodDiscreteSamplerNoPad".equals(samplerType)) {
+                factory = new DiscreteSamplerFactory() {
+                    @Override
+                    public DiscreteSampler create() {
+                        return AliasMethodDiscreteSampler.of(rng, probabilities, -1);
+                    }
+                };
+            } else if ("AliasMethodDiscreteSamplerAlpha1".equals(samplerType)) {
+                factory = new DiscreteSamplerFactory() {
+                    @Override
+                    public DiscreteSampler create() {
+                        return AliasMethodDiscreteSampler.of(rng, probabilities, 1);
+                    }
+                };
+            } else if ("AliasMethodDiscreteSamplerAlpha2".equals(samplerType)) {
+                factory = new DiscreteSamplerFactory() {
+                    @Override
+                    public DiscreteSampler create() {
+                        return AliasMethodDiscreteSampler.of(rng, probabilities, 2);
+                    }
+                };
+            } else if ("GuideTableDiscreteSampler".equals(samplerType)) {
+                factory = new DiscreteSamplerFactory() {
+                    @Override
+                    public DiscreteSampler create() {
+                        return GuideTableDiscreteSampler.of(rng, probabilities);
+                    }
+                };
+            } else if ("GuideTableDiscreteSamplerAlpha2".equals(samplerType)) {
+                factory = new DiscreteSamplerFactory() {
+                    @Override
+                    public DiscreteSampler create() {
+                        return GuideTableDiscreteSampler.of(rng, probabilities, 2);
+                    }
+                };
+            } else if ("GuideTableDiscreteSamplerAlpha4".equals(samplerType)) {
+                factory = new DiscreteSamplerFactory() {
+                    @Override
+                    public DiscreteSampler create() {
+                        return GuideTableDiscreteSampler.of(rng, probabilities, 4);
+                    }
+                };
+            } else if ("MarsagliaTsangWangDiscreteSampler".equals(samplerType)) {
+                factory = new DiscreteSamplerFactory() {
+                    @Override
+                    public DiscreteSampler create() {
+                        return MarsagliaTsangWangDiscreteSampler.Enumerated.of(rng, probabilities);
+                    }
+                };
+            } else {
+                throw new IllegalStateException("Unknown sampler type: " + samplerType);
+            }
+        }
+
+        /**
+         * Creates a new instance of the sampler.
+         *
+         * @return The sampler.
+         */
+        public DiscreteSampler createSampler() {
+            return factory.create();
+        }
+    }
+
+    /**
+     * Define known probability distributions for testing. These are expected to have well
+     * behaved cumulative probability functions.
+     */
+    @State(Scope.Benchmark)
+    public static class KnownDistributionSources extends SamplerSources {
+        /** The cumulative probability limit for unbounded distributions. */
+        private static final double CUMULATIVE_PROBABILITY_LIMIT = 1 - 1e-9;
+
+        /**
+         * The distribution.
+         */
+        @Param({"Binomial_N67_P0.7",
+                "Geometric_P0.2",
+                "4SidedLoadedDie",
+                "Poisson_Mean3.14",
+                "Poisson_Mean10_Mean20",
+                })
+        private String distribution;
+
+        /** {@inheritDoc} */
+        @Override
+        protected double[] createProbabilities() {
+            if ("Binomial_N67_P0.7".equals(distribution)) {
+                final int trials = 67;
+                final double probabilityOfSuccess = 0.7;
+                final BinomialDistribution dist = new BinomialDistribution(null, trials, probabilityOfSuccess);
+                return createProbabilities(dist, 0, trials);
+            } else if ("Geometric_P0.2".equals(distribution)) {
+                final double probabilityOfSuccess = 0.2;
+                final double probabilityOfFailure = 1 - probabilityOfSuccess;
+                // https://en.wikipedia.org/wiki/Geometric_distribution
+                // PMF = (1-p)^k * p
+                // k is number of failures before a success
+                double p = 1.0; // (1-p)^0
+                // Build until the cumulative probability exceeds the limit
+                double[] probabilities = new double[100];
+                double sum = 0;
+                int k = 0;
+                while (k < probabilities.length) {
+                    probabilities[k] = p * probabilityOfSuccess;
+                    sum += probabilities[k++];
+                    if (sum > CUMULATIVE_PROBABILITY_LIMIT) {
+                        break;
+                    }
+                    // For the next PMF
+                    p *= probabilityOfFailure;
+                }
+                return Arrays.copyOf(probabilities, k);
+            } else if ("4SidedLoadedDie".equals(distribution)) {
+                return new double[] {1.0 / 2, 1.0 / 3, 1.0 / 12, 1.0 / 12};
+            } else if ("Poisson_Mean3.14".equals(distribution)) {
+                final double mean = 3.14;
+                final IntegerDistribution dist = createPoissonDistribution(mean);
+                final int max = dist.inverseCumulativeProbability(CUMULATIVE_PROBABILITY_LIMIT);
+                return createProbabilities(dist, 0, max);
+            } else if ("Poisson_Mean10_Mean20".equals(distribution)) {
+                // Create a bimodal distribution from two Poisson distributions (larger mean sets the range)
+                final double mean1 = 10;
+                final double mean2 = 20;
+                final IntegerDistribution dist2 = createPoissonDistribution(mean2);
+                final int max = dist2.inverseCumulativeProbability(CUMULATIVE_PROBABILITY_LIMIT);
+                final double[] p1 = createProbabilities(createPoissonDistribution(mean1), 0, max);
+                final double[] p2 = createProbabilities(dist2, 0, max);
+                for (int i = 0; i < p1.length; i++) {
+                    p1[i] += p2[i];
+                }
+                // Leave it to the sampler to normalise the sum
+                return p1;
+            }
+            throw new IllegalStateException("Unknown distribution: " + distribution);
+        }
+
+        /**
+         * Creates the Poisson distribution.
+         *
+         * @param mean the mean
+         * @return the distribution
+         */
+        private static IntegerDistribution createPoissonDistribution(double mean) {
+            return new PoissonDistribution(null, mean,
+                PoissonDistribution.DEFAULT_EPSILON, PoissonDistribution.DEFAULT_MAX_ITERATIONS);
+        }
+
+        /**
+         * Creates the probabilities from the distribution.
+         *
+         * @param dist the distribution
+         * @param lower the lower bounds (inclusive)
+         * @param upper the upper bounds (inclusive)
+         * @return the probabilities
+         */
+        private static double[] createProbabilities(IntegerDistribution dist, int lower, int upper) {
+            double[] probabilities = new double[upper - lower + 1];
+            for (int i = 0, x = lower; x <= upper; i++, x++) {
+                probabilities[i] = dist.probability(x);
+            }
+            return probabilities;
+        }
+    }
+
+    /**
+     * Define random probability distributions of known size for testing. These are random but
+     * the average cumulative probability function will be a straight line given the increment
+     * average is 0.5.
+     */
+    @State(Scope.Benchmark)
+    public static class RandomDistributionSources extends SamplerSources {
+        /**
+         * The distribution size.
+         * These are spaced half-way between powers-of-2 to minimise the advantage of
+         * padding by the Alias method sampler.
+         */
+        @Param({"6",
+                //"12",
+                //"24",
+                //"48",
+                "96",
+                //"192",
+                //"384",
+                // Above 2048 forces the Alias method to use more than 64-bits for sampling
+                "3072"
+                })
+        private int randomNonUniformSize;
+
+        /** {@inheritDoc} */
+        @Override
+        protected double[] createProbabilities() {
+            final double[] probabilities = new double[randomNonUniformSize];
+            final ThreadLocalRandom rng = ThreadLocalRandom.current();
+            for (int i = 0; i < probabilities.length; i++) {
+                probabilities[i] = rng.nextDouble();
+            }
+            return probabilities;
+        }
+    }
+
+    /**
+     * Compute a sample by binary search of the cumulative probability distribution.
+     */
+    static final class BinarySearchDiscreteSampler
+        implements DiscreteSampler {
+        /** Underlying source of randomness. */
+        private final UniformRandomProvider rng;
+        /**
+         * The cumulative probability table.
+         */
+        private final double[] cumulativeProbabilities;
+
+        /**
+         * @param rng Generator of uniformly distributed random numbers.
+         * @param probabilities The probabilities.
+         * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
+         * probability is negative, infinite or {@code NaN}, or the sum of all
+         * probabilities is not strictly positive.
+         */
+        BinarySearchDiscreteSampler(UniformRandomProvider rng,
+                                    double[] probabilities) {
+            // Minimal set-up validation
+            if (probabilities == null || probabilities.length == 0) {
+                throw new IllegalArgumentException("Probabilities must not be empty.");
+            }
+
+            final int size = probabilities.length;
+            cumulativeProbabilities = new double[size];
+
+            double sumProb = 0;
+            int count = 0;
+            for (final double prob : probabilities) {
+                if (prob < 0 ||
+                    Double.isInfinite(prob) ||
+                    Double.isNaN(prob)) {
+                    throw new IllegalArgumentException("Invalid probability: " +
+                                                       prob);
+                }
+
+                // Compute and store cumulative probability.
+                sumProb += prob;
+                cumulativeProbabilities[count++] = sumProb;
+            }
+
+            if (Double.isInfinite(sumProb) || sumProb <= 0) {
+                throw new IllegalArgumentException("Invalid sum of probabilities: " + sumProb);
+            }
+
+            this.rng = rng;
+
+            // Normalise cumulative probability.
+            for (int i = 0; i < size; i++) {
+                final double norm = cumulativeProbabilities[i] / sumProb;
+                cumulativeProbabilities[i] = (norm < 1) ? norm : 1.0;
+            }
+        }
+
+        /** {@inheritDoc} */
+        @Override
+        public int sample() {
+            final double u = rng.nextDouble();
+
+            // Java binary search
+            //int index = Arrays.binarySearch(cumulativeProbabilities, u);
+            //if (index < 0) {
+            //    index = -index - 1;
+            //}
+            //
+            //return index < cumulativeProbabilities.length ?
+            //    index :
+            //    cumulativeProbabilities.length - 1;
+
+            // Binary search within known cumulative probability table.
+            // Find x so that u > f[x-1] and u <= f[x].
+            // This is a looser search than Arrays.binarySearch:
+            // - The output is x = upper.
+            // - The table stores probabilities where f[0] is >= 0 and the max == 1.0.
+            // - u should be >= 0 and <= 1 (or the random generator is broken).
+            // - It avoids comparisons using Double.doubleToLongBits.
+            // - It avoids the low likelihood of equality between two doubles for fast exit
+            //   so uses only 1 compare per loop.
+            int lower = 0;
+            int upper = cumulativeProbabilities.length - 1;
+            while (lower < upper) {
+                final int mid = (lower + upper) >>> 1;
+                final double midVal = cumulativeProbabilities[mid];
+                if (u > midVal) {
+                    // Change lower such that
+                    // u > f[lower - 1]
+                    lower = mid + 1;
+                } else {
+                    // Change upper such that
+                    // u <= f[upper]
+                    upper = mid;
+                }
+            }
+            return upper;
+        }
+    }
+
+    /**
+     * The value for the baseline generation of an {@code int} value.
+     *
+     * <p>This must NOT be final!</p>
+     */
+    private int value;
+
+    // Benchmarks methods below.
+
+    /**
+     * Baseline for the JMH timing overhead for production of an {@code int} value.
+     *
+     * @return the {@code int} value
+     */
+    @Benchmark
+    public int baselineInt() {
+        return value;
+    }
+
+    /**
+     * Baseline for the production of a {@code double} value.
+     * This is used to assess the performance of the underlying random source.
+     *
+     * @param sources Source of randomness (a concrete state; JMH cannot instantiate abstract states).
+     * @return the {@code int} value
+     */
+    @Benchmark
+    public int baselineNextDouble(KnownDistributionSources sources) {
+        return sources.getGenerator().nextDouble() < 0.5 ? 1 : 0;
+    }
+
+    /**
+     * Run the sampler.
+     *
+     * @param sources Source of randomness.
+     * @return the sample value
+     */
+    @Benchmark
+    public int sampleKnown(KnownDistributionSources sources) {
+        return sources.getSampler().sample();
+    }
+
+    /**
+     * Run the sampler.
+     *
+     * @param sources Source of randomness.
+     * @return the sample value
+     */
+    @Benchmark
+    public int singleSampleKnown(KnownDistributionSources sources) {
+        return sources.createSampler().sample();
+    }
+
+    /**
+     * Run the sampler.
+     *
+     * @param sources Source of randomness.
+     * @return the sample value
+     */
+    @Benchmark
+    public int sampleRandom(RandomDistributionSources sources) {
+        return sources.getSampler().sample();
+    }
+
+    /**
+     * Run the sampler.
+     *
+     * @param sources Source of randomness.
+     * @return the sample value
+     */
+    @Benchmark
+    public int singleSampleRandom(RandomDistributionSources sources) {
+        return sources.createSampler().sample();
+    }
+}


[commons-rng] 03/03: RNG-109: Track changes.

Posted by ah...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-rng.git

commit f6d59a14140e907d2c09f309598610f12be9ecdc
Author: Alex Herbert <ah...@apache.org>
AuthorDate: Wed Aug 7 13:03:23 2019 +0100

    RNG-109: Track changes.
---
 src/changes/changes.xml | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/changes/changes.xml b/src/changes/changes.xml
index 905f6a6..5e848c5 100644
--- a/src/changes/changes.xml
+++ b/src/changes/changes.xml
@@ -75,6 +75,10 @@ re-run tests that fail, and pass the build if they succeed
 within the allotted number of reruns (the test will be marked
 as 'flaky' in the report).
 ">
+      <action dev="aherbert" type="update" issue="RNG-109">
+        "DiscreteProbabilityCollectionSampler": Use a faster enumerated probability
+        distribution sampler to replace the binary search algorithm.
+      </action>
       <action dev="aherbert" type="add" issue="RNG-85">
         New "MiddleSquareWeylSequence" generator.
       </action>


[commons-rng] 02/03: RNG-109: Delegate sampling in probability collection sampler.

Posted by ah...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-rng.git

commit 341af54ecbef6988b28be8ca4ec1db246362f4d5
Author: aherbert <ah...@apache.org>
AuthorDate: Tue Aug 6 14:54:16 2019 +0100

    RNG-109: Delegate sampling in probability collection sampler.
---
 .../DiscreteProbabilityCollectionSampler.java      | 119 ++++++---------------
 .../DiscreteProbabilityCollectionSamplerTest.java  |  52 ++++++---
 2 files changed, 70 insertions(+), 101 deletions(-)

diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java
index 4bc50b4..69b45bc 100644
--- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java
+++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java
@@ -19,11 +19,11 @@ package org.apache.commons.rng.sampling;
 
 import java.util.List;
 import java.util.Map;
-import java.util.HashMap;
 import java.util.ArrayList;
-import java.util.Arrays;
 
 import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.sampling.distribution.GuideTableDiscreteSampler;
+import org.apache.commons.rng.sampling.distribution.SharedStateDiscreteSampler;
 
 /**
  * Sampling from a collection of items with user-defined
@@ -40,12 +40,12 @@ import org.apache.commons.rng.UniformRandomProvider;
  */
 public class DiscreteProbabilityCollectionSampler<T>
     implements SharedStateSampler<DiscreteProbabilityCollectionSampler<T>> {
+    /** The error message for an empty collection. */
+    private static final String EMPTY_COLLECTION = "Empty collection";
     /** Collection to be sampled from. */
     private final List<T> items;
-    /** RNG. */
-    private final UniformRandomProvider rng;
-    /** Cumulative probabilities. */
-    private final double[] cumulativeProbabilities;
+    /** Sampler for the probabilities. */
+    private final SharedStateDiscreteSampler sampler;
 
     /**
      * Creates a sampler.
@@ -64,43 +64,22 @@ public class DiscreteProbabilityCollectionSampler<T>
     public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
                                                 Map<T, Double> collection) {
         if (collection.isEmpty()) {
-            throw new IllegalArgumentException("Empty collection");
+            throw new IllegalArgumentException(EMPTY_COLLECTION);
         }
 
-        this.rng = rng;
+        // Extract the items and their probabilities as parallel arrays, in map iteration order
         final int size = collection.size();
         items = new ArrayList<T>(size);
-        cumulativeProbabilities = new double[size];
+        final double[] probabilities = new double[size];
 
-        double sumProb = 0;
         int count = 0;
         for (final Map.Entry<T, Double> e : collection.entrySet()) {
             items.add(e.getKey());
-
-            final double prob = e.getValue();
-            if (prob < 0 ||
-                Double.isInfinite(prob) ||
-                Double.isNaN(prob)) {
-                throw new IllegalArgumentException("Invalid probability: " +
-                                                   prob);
-            }
-
-            // Temporarily store probability.
-            cumulativeProbabilities[count++] = prob;
-            sumProb += prob;
+            probabilities[count++] = e.getValue();
         }
 
-        if (sumProb <= 0) {
-            throw new IllegalArgumentException("Invalid sum of probabilities");
-        }
-
-        // Compute and store cumulative probability.
-        for (int i = 0; i < size; i++) {
-            cumulativeProbabilities[i] /= sumProb;
-            if (i > 0) {
-                cumulativeProbabilities[i] += cumulativeProbabilities[i - 1];
-            }
-        }
+        // Delegate index sampling; NOTE(review): assumes the delegate validates probabilities
+        sampler = createSampler(rng, probabilities);
     }
 
     /**
@@ -122,7 +101,19 @@ public class DiscreteProbabilityCollectionSampler<T>
     public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
                                                 List<T> collection,
                                                 double[] probabilities) {
-        this(rng, consolidate(collection, probabilities));
+        if (collection.isEmpty()) {
+            throw new IllegalArgumentException(EMPTY_COLLECTION);
+        }
+        final int len = probabilities.length;
+        if (len != collection.size()) {
+            throw new IllegalArgumentException("Size mismatch: " +
+                                               len + " != " +
+                                               collection.size());
+        }
+        // Shallow copy the list; duplicate items are kept as separate entries (not consolidated)
+        items = new ArrayList<T>(collection);
+        // Delegate sampling; the sampled index selects the item at the same list position
+        sampler = createSampler(rng, probabilities);
     }
 
     /**
@@ -131,9 +122,8 @@ public class DiscreteProbabilityCollectionSampler<T>
      */
     private DiscreteProbabilityCollectionSampler(UniformRandomProvider rng,
                                                  DiscreteProbabilityCollectionSampler<T> source) {
-        this.rng = rng;
         this.items = source.items;
-        this.cumulativeProbabilities = source.cumulativeProbabilities;
+        this.sampler = source.sampler.withUniformRandomProvider(rng); // NOTE(review): assumed same distribution, new RNG
     }
 
     /**
@@ -142,22 +132,7 @@ public class DiscreteProbabilityCollectionSampler<T>
      * @return a random sample.
      */
     public T sample() {
-        final double rand = rng.nextDouble();
-
-        int index = Arrays.binarySearch(cumulativeProbabilities, rand);
-        if (index < 0) {
-            index = -index - 1;
-        }
-
-        if (index < cumulativeProbabilities.length &&
-            rand < cumulativeProbabilities[index]) {
-            return items.get(index);
-        }
-
-        // This should never happen, but it ensures we will return a correct
-        // object in case there is some floating point inequality problem
-        // wrt the cumulative probabilities.
-        return items.get(items.size() - 1);
+        return items.get(sampler.sample()); // the delegate samples an index into items
     }
 
     /** {@inheritDoc} */
@@ -167,38 +142,14 @@ public class DiscreteProbabilityCollectionSampler<T>
     }
 
     /**
-     * @param collection Collection to be sampled.
-     * @param probabilities Probability associated to each item of the
-     * {@code collection}.
-     * @return a consolidated map (where probabilities of equal items
-     * have been summed).
-     * @throws IllegalArgumentException if the number of items in the
-     * {@code collection} is not equal to the number of provided
-     * {@code probabilities}.
-     * @param <T> Type of items in the collection.
+     * Creates the sampler of the enumerated probability distribution.
+     *
+     * @param rng Generator of uniformly distributed random numbers.
+     * @param probabilities Probability associated to each item.
+     * @return the sampler
      */
-    private static <T> Map<T, Double> consolidate(List<T> collection,
-                                                  double[] probabilities) {
-        final int len = probabilities.length;
-        if (len != collection.size()) {
-            throw new IllegalArgumentException("Size mismatch: " +
-                                               len + " != " +
-                                               collection.size());
-        }
-
-        final Map<T, Double> map = new HashMap<T, Double>();
-        for (int i = 0; i < len; i++) {
-            final T item = collection.get(i);
-            final Double prob = probabilities[i];
-
-            Double currentProb = map.get(item);
-            if (currentProb == null) {
-                currentProb = 0d;
-            }
-
-            map.put(item, currentProb + prob);
-        }
-
-        return map;
+    private static SharedStateDiscreteSampler createSampler(UniformRandomProvider rng,
+                                                            double[] probabilities) {
+        return GuideTableDiscreteSampler.of(rng, probabilities);
     }
 }
diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java
index 78a4391..e5ffc6a 100644
--- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java
+++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java
@@ -18,8 +18,11 @@
 package org.apache.commons.rng.sampling;
 
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
 
 import org.junit.Assert;
 import org.junit.Test;
@@ -74,6 +77,13 @@ public class DiscreteProbabilityCollectionSamplerTest {
         new DiscreteProbabilityCollectionSampler<Double>(rng,
                                                          new HashMap<Double, Double>());
     }
+    @Test(expected = IllegalArgumentException.class)
+    public void testPrecondition7() {
+        // Empty List<T> not allowed
+        new DiscreteProbabilityCollectionSampler<Double>(rng,
+                                                         Collections.<Double>emptyList(),
+                                                         new double[0]);
+    }
 
     @Test
     public void testSample() {
@@ -99,30 +109,38 @@ public class DiscreteProbabilityCollectionSamplerTest {
         Assert.assertEquals(expectedVariance, variance, 2e-3);
     }
 
-    /**
-     * Edge-case test:
-     * Create a sampler that will return 1 for nextDouble() forcing the binary search to
-     * identify the end item of the cumulative probability array.
-     */
+
     @Test
-    public void testSampleWithProbabilityAtLastItem() {
-        sampleWithProbabilityForLastItem(false);
+    public void testSampleUsingMap() {
+        final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
+        final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
+        final List<Integer> items = Arrays.asList(1, 3, 4, 6, 9);
+        final double[] probabilities = {0.1, 0.2, 0.3, 0.4, 0.5};
+        final DiscreteProbabilityCollectionSampler<Integer> sampler1 =
+            new DiscreteProbabilityCollectionSampler<Integer>(rng1, items, probabilities);
+
+        // Create a map version. The map iterator must be ordered so use a TreeMap.
+        final Map<Integer, Double> map = new TreeMap<Integer, Double>();
+        for (int i = 0; i < probabilities.length; i++) {
+            map.put(items.get(i), probabilities[i]);
+        }
+        final DiscreteProbabilityCollectionSampler<Integer> sampler2 =
+            new DiscreteProbabilityCollectionSampler<Integer>(rng2, map);
+
+        for (int i = 0; i < 50; i++) {
+            Assert.assertEquals(sampler1.sample(), sampler2.sample());
+        }
     }
 
     /**
      * Edge-case test:
-     * Create a sampler that will return over 1 for nextDouble() forcing the binary search to
-     * identify insertion at the end of the cumulative probability array.
+     * Create a sampler that will return 1 for nextDouble() forcing the search to
+     * identify the end item of the cumulative probability array.
      */
     @Test
-    public void testSampleWithProbabilityPastLastItem() {
-        sampleWithProbabilityForLastItem(true);
-    }
-
-    private static void sampleWithProbabilityForLastItem(boolean pastLast) {
+    public void testSampleWithProbabilityAtLastItem() {
         // Ensure the samples pick probability 0 (the first item) and then
         // a probability (for the second item) that hits an edge case.
-        final double probability = pastLast ? 1.1 : 1;
         final UniformRandomProvider dummyRng = new UniformRandomProvider() {
             private int count;
             // CHECKSTYLE: stop all
@@ -132,7 +150,7 @@ public class DiscreteProbabilityCollectionSamplerTest {
             public int nextInt() { return 0; }
             public float nextFloat() { return 0; }
             // Return 0 then the given probability
-            public double nextDouble() { return (count++ == 0) ? 0 : probability; }
+            public double nextDouble() { return (count++ == 0) ? 0 : 1.0; }
             public void nextBytes(byte[] bytes, int start, int len) {}
             public void nextBytes(byte[] bytes) {}
             public boolean nextBoolean() { return false; }
@@ -164,7 +182,7 @@ public class DiscreteProbabilityCollectionSamplerTest {
         final DiscreteProbabilityCollectionSampler<Double> sampler1 =
             new DiscreteProbabilityCollectionSampler<Double>(rng1,
                                                              items,
-                                                             new double[] {0.1, 0.2, 0.3, 04});
+                                                             new double[] {0.1, 0.2, 0.3, 0.4});
         final DiscreteProbabilityCollectionSampler<Double> sampler2 = sampler1.withUniformRandomProvider(rng2);
         RandomAssert.assertProduceSameSequence(
             new RandomAssert.Sampler<Double>() {