You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by ah...@apache.org on 2019/06/18 21:15:33 UTC

[commons-rng] 01/02: RNG-100: Add a GuideTableDiscreteSampler.

This is an automated email from the ASF dual-hosted git repository.

aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-rng.git

commit 2070821b022cee2e79954953c1400166dc2c6b0b
Author: Alex Herbert <ah...@apache.org>
AuthorDate: Sun Jun 16 21:55:18 2019 +0100

    RNG-100: Add a GuideTableDiscreteSampler.
    
    This can sample from any distribution defined by an array of
    probabilities.
---
 .../distribution/DiscreteSamplersPerformance.java  |  18 +-
 .../distribution/GuideTableDiscreteSampler.java    | 201 +++++++++++++++++
 .../rng/sampling/distribution/InternalUtils.java   |  17 +-
 .../MarsagliaTsangWangDiscreteSampler.java         |   7 +-
 .../distribution/DiscreteSamplersList.java         |   2 +
 .../GuideTableDiscreteSamplerTest.java             | 237 +++++++++++++++++++++
 6 files changed, 473 insertions(+), 9 deletions(-)

diff --git a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/distribution/DiscreteSamplersPerformance.java b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/distribution/DiscreteSamplersPerformance.java
index 0a72b23..641b652 100644
--- a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/distribution/DiscreteSamplersPerformance.java
+++ b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/distribution/DiscreteSamplersPerformance.java
@@ -22,6 +22,7 @@ import org.apache.commons.rng.examples.jmh.RandomSources;
 import org.apache.commons.rng.sampling.distribution.DiscreteSampler;
 import org.apache.commons.rng.sampling.distribution.DiscreteUniformSampler;
 import org.apache.commons.rng.sampling.distribution.GeometricSampler;
+import org.apache.commons.rng.sampling.distribution.GuideTableDiscreteSampler;
 import org.apache.commons.rng.sampling.distribution.LargeMeanPoissonSampler;
 import org.apache.commons.rng.sampling.distribution.MarsagliaTsangWangDiscreteSampler;
 import org.apache.commons.rng.sampling.distribution.RejectionInversionZipfSampler;
@@ -59,6 +60,17 @@ public class DiscreteSamplersPerformance {
      */
     @State(Scope.Benchmark)
     public static class Sources extends RandomSources {
+        /** The probabilities for the discrete distribution. */
+        private static final double[] DISCRETE_PROBABILITIES;
+
+        static {
+            // This is not normalised to sum to 1. The samplers should handle this.
+            DISCRETE_PROBABILITIES = new double[25];
+            for (int i = 0; i < DISCRETE_PROBABILITIES.length; i++) {
+                DISCRETE_PROBABILITIES[i] = (i + 1.0) / DISCRETE_PROBABILITIES.length;
+            }
+        }
+
         /**
          * The sampler type.
          */
@@ -70,6 +82,7 @@ public class DiscreteSamplersPerformance {
                 "MarsagliaTsangWangDiscreteSampler",
                 "MarsagliaTsangWangPoissonSampler",
                 "MarsagliaTsangWangBinomialSampler",
+                "GuideTableDiscreteSampler",
                 })
         private String samplerType;
 
@@ -101,12 +114,13 @@ public class DiscreteSamplersPerformance {
             } else if ("GeometricSampler".equals(samplerType)) {
                 sampler = new GeometricSampler(rng, 0.21);
             } else if ("MarsagliaTsangWangDiscreteSampler".equals(samplerType)) {
-                sampler = MarsagliaTsangWangDiscreteSampler.createDiscreteDistribution(rng,
-                        new double[] {0.1, 0.2, 0.3, 0.4});
+                sampler = MarsagliaTsangWangDiscreteSampler.createDiscreteDistribution(rng, DISCRETE_PROBABILITIES);
             } else if ("MarsagliaTsangWangPoissonSampler".equals(samplerType)) {
                 sampler = MarsagliaTsangWangDiscreteSampler.createPoissonDistribution(rng, 8.9);
             } else if ("MarsagliaTsangWangBinomialSampler".equals(samplerType)) {
                 sampler = MarsagliaTsangWangDiscreteSampler.createBinomialDistribution(rng, 20, 0.33);
+            } else if ("GuideTableDiscreteSampler".equals(samplerType)) {
+                sampler = new GuideTableDiscreteSampler(rng, DISCRETE_PROBABILITIES);
             }
         }
     }
diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GuideTableDiscreteSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GuideTableDiscreteSampler.java
new file mode 100644
index 0000000..beb7a5f
--- /dev/null
+++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GuideTableDiscreteSampler.java
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.rng.sampling.distribution;
+
+import org.apache.commons.rng.UniformRandomProvider;
+
+/**
+ * Compute a sample from a discrete probability distribution. The cumulative probability
+ * distribution is searched using a guide table to set an initial start point. This implementation
+ * is based on:
+ *
+ * <ul>
+ *  <li>
+ *   <blockquote>
+ *    Devroye, Luc (1986). Non-Uniform Random Variate Generation.
+ *    New York: Springer-Verlag. Chapter 3.2.4 "The method of guide tables" p. 96.
+ *   </blockquote>
+ *  </li>
+ * </ul>
+ *
+ * <p>The size of the guide table can be controlled using a parameter. A larger guide table
+ * will improve performance at the cost of storage space.</p>
+ *
+ * <p>Sampling uses {@link UniformRandomProvider#nextDouble()}.</p>
+ *
+ * @since 1.3
+ */
+public class GuideTableDiscreteSampler
+    implements DiscreteSampler {
+    /** The default value for {@code alpha}. */
+    private static final double DEFAULT_ALPHA = 1.0;
+    /** Underlying source of randomness. */
+    private final UniformRandomProvider rng;
+    /**
+     * The cumulative probability table ({@code f(x)}).
+     */
+    private final double[] cumulativeProbabilities;
+    /**
+     * The inverse cumulative probability guide table. This is a guide map between the cumulative
+     * probability (f(x)) and the value x. It is used to set the initial point for search
+     * of the cumulative probability table.
+     *
+     * <p>The index in the map is obtained using {@code p * (map.length - 1)} where {@code p} is
+     * the known cumulative probability {@code f(x)} or a uniform random deviate {@code u}. The
+     * value stored at the index is value {@code x+1} when {@code p = f(x)} such that it is the
+     * exclusive upper bound on the sample value {@code x} for searching the cumulative probability
+     * table {@code f(x)}. The search of the cumulative probability is towards zero.</p>
+     */
+    private final int[] guideTable;
+
+    /**
+     * Create a new instance using the default guide table size.
+     *
+     * @param rng Generator of uniformly distributed random numbers.
+     * @param probabilities The probabilities.
+     * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
+     * probability is negative, infinite or {@code NaN}, or the sum of all
+     * probabilities is not strictly positive.
+     */
+    public GuideTableDiscreteSampler(UniformRandomProvider rng,
+                                     double[] probabilities) {
+        this(rng, probabilities, DEFAULT_ALPHA);
+    }
+
+    /**
+     * Create a new instance.
+     *
+     * <p>The guide table holds {@code ceil(alpha * probabilities.length) + 1} entries.
+     *
+     * @param rng Generator of uniformly distributed random numbers.
+     * @param probabilities The probabilities.
+     * @param alpha The alpha factor used to set the guide table size.
+     * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
+     * probability is negative, infinite or {@code NaN}, the sum of all
+     * probabilities is not strictly positive, or {@code alpha} is not strictly positive.
+     */
+    public GuideTableDiscreteSampler(UniformRandomProvider rng,
+                                     double[] probabilities,
+                                     double alpha) {
+        validateParameters(probabilities, alpha);
+
+        final int size = probabilities.length;
+        cumulativeProbabilities = new double[size];
+
+        double sumProb = 0;
+        int count = 0;
+        for (final double prob : probabilities) {
+            InternalUtils.validateProbability(prob);
+
+            // Compute and store cumulative probability.
+            sumProb += prob;
+            cumulativeProbabilities[count++] = sumProb;
+        }
+
+        if (Double.isInfinite(sumProb) || sumProb <= 0) {
+            throw new IllegalArgumentException("Invalid sum of probabilities: " + sumProb);
+        }
+
+        this.rng = rng;
+
+        // Note: The guide table is at least length 1. Math.max guards (guideTableSize + 1)
+        // against int overflow when (alpha * size) saturates the cast at Integer.MAX_VALUE.
+        final int guideTableSize = (int) Math.ceil(alpha * size);
+        guideTable = new int[Math.max(guideTableSize, guideTableSize + 1)];
+
+        // Normalise the cumulative probabilities and populate the guide table.
+        for (int x = 0; x < size; x++) {
+            final double norm = cumulativeProbabilities[x] / sumProb;
+            cumulativeProbabilities[x] = (norm < 1) ? norm : 1.0;
+
+            // Set the guide table value as an exclusive upper bound (x + 1)
+            guideTable[getGuideTableIndex(cumulativeProbabilities[x])] = x + 1;
+        }
+
+        // Edge case for round-off
+        cumulativeProbabilities[size - 1] = 1.0;
+        // The final guide table entry is (maximum value of x + 1)
+        guideTable[guideTable.length - 1] = size;
+
+        // The first non-zero value in the guide table is from f(x=0).
+        // Any probabilities mapped below this must be sample x=0 so the
+        // table may initially be filled with zeros.
+
+        // Fill missing values in the guide table.
+        for (int i = 1; i < guideTable.length; i++) {
+            guideTable[i] = Math.max(guideTable[i - 1], guideTable[i]);
+        }
+    }
+
+    /**
+     * Validate the parameters.
+     *
+     * @param probabilities The probabilities.
+     * @param alpha The alpha factor used to set the guide table size.
+     * @throws IllegalArgumentException if {@code probabilities} is null or empty, or
+     * {@code alpha} is not strictly positive.
+     */
+    private static void validateParameters(double[] probabilities, double alpha) {
+        if (probabilities == null || probabilities.length == 0) {
+            throw new IllegalArgumentException("Probabilities must not be empty.");
+        }
+        if (alpha <= 0) {
+            throw new IllegalArgumentException("Alpha must be strictly positive.");
+        }
+    }
+
+    /**
+     * Gets the guide table index for the probability. This is obtained using
+     * {@code p * (guideTable.length - 1)} so is inside the length of the table.
+     *
+     * @param p Cumulative probability.
+     * @return the guide table index.
+     */
+    private int getGuideTableIndex(double p) {
+        // Note: This is only ever called when p is in the range of the cumulative
+        // probability table. So assume 0 <= p <= 1.
+        return (int) (p * (guideTable.length - 1));
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public int sample() {
+        // Compute a probability
+        final double u = rng.nextDouble();
+
+        // Initialise the search using the guide table to find an initial guess.
+        // The table provides an upper bound on the sample (x+1) for a known
+        // cumulative probability (f(x)).
+        int x = guideTable[getGuideTableIndex(u)];
+        // Search down.
+        // In the edge case where u is 1.0 then 'x' will be 1 outside the range of the
+        // cumulative probability table and this will decrement to a valid range.
+        // In the case where 'u' is mapped to the same guide table index as a lower
+        // cumulative probability f(x) (due to rounding down) then this will not decrement
+        // and return the exclusive upper bound (x+1).
+        while (x != 0 && u <= cumulativeProbabilities[x - 1]) {
+            x--;
+        }
+        return x;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public String toString() {
+        return "Guide table deviate [" + rng.toString() + "]";
+    }
+}
diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InternalUtils.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InternalUtils.java
index 8d8e010..73d6f16 100644
--- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InternalUtils.java
+++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InternalUtils.java
@@ -20,7 +20,7 @@ package org.apache.commons.rng.sampling.distribution;
 /**
  * Functions used by some of the samplers.
  * This class is not part of the public API, as it would be
- * better to group these utilities in a dedicated components.
+ * better to group these utilities in a dedicated component.
  */
 final class InternalUtils { // Class is package-private on purpose; do not make it public.
     /** All long-representable factorials. */
@@ -50,6 +50,21 @@ final class InternalUtils { // Class is package-private on purpose; do not make
     }
 
     /**
+     * Validate the probability is a finite non-negative number (zero is allowed).
+     *
+     * @param probability Probability.
+     * @throws IllegalArgumentException if {@code probability} is negative, infinite or {@code NaN}.
+     */
+    public static void validateProbability(double probability) {
+        if (probability < 0 ||
+            Double.isInfinite(probability) ||
+            Double.isNaN(probability)) {
+            throw new IllegalArgumentException("Invalid probability: " +
+                                               probability);
+        }
+    }
+
+    /**
      * Class for computing the natural logarithm of the factorial of {@code n}.
      * It allows to allocate a cache of precomputed values.
      * In case of cache miss, computation is performed by a call to
diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSampler.java
index 59e5618..e8e4685 100644
--- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSampler.java
+++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSampler.java
@@ -669,12 +669,7 @@ public abstract class MarsagliaTsangWangDiscreteSampler implements DiscreteSampl
 
         double sumProb = 0;
         for (final double prob : probabilities) {
-            if (prob < 0 ||
-                Double.isInfinite(prob) ||
-                Double.isNaN(prob)) {
-                throw new IllegalArgumentException("Invalid probability: " +
-                                                   prob);
-            }
+            InternalUtils.validateProbability(prob);
             sumProb += prob;
         }
 
diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java
index b021b50..bf8e2fd 100644
--- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java
+++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java
@@ -169,6 +169,8 @@ public final class DiscreteSamplersList {
             final double[] discreteProbabilities = new double[] {0.1, 0.2, 0.3, 0.4, 0.5};
             add(LIST, discreteProbabilities,
                 MarsagliaTsangWangDiscreteSampler.createDiscreteDistribution(RandomSource.create(RandomSource.XO_SHI_RO_512_PLUS), discreteProbabilities));
+            add(LIST, discreteProbabilities,
+                new GuideTableDiscreteSampler(RandomSource.create(RandomSource.XO_SHI_RO_512_SS), discreteProbabilities));
         } catch (Exception e) {
             // CHECKSTYLE: stop Regexp
             System.err.println("Unexpected exception while creating the list of samplers: " + e);
diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GuideTableDiscreteSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GuideTableDiscreteSamplerTest.java
new file mode 100644
index 0000000..f312c93
--- /dev/null
+++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GuideTableDiscreteSamplerTest.java
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.rng.sampling.distribution;
+
+import org.apache.commons.math3.distribution.BinomialDistribution;
+import org.apache.commons.math3.distribution.PoissonDistribution;
+import org.apache.commons.math3.stat.inference.ChiSquareTest;
+import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.simple.RandomSource;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Test for the {@link GuideTableDiscreteSampler}.
+ */
+public class GuideTableDiscreteSamplerTest {
+    @Test(expected = IllegalArgumentException.class)
+    public void testConstructorThrowsWithNullProbabilites() {
+        createSampler(null, 1.0);
+    }
+
+    @Test(expected = IllegalArgumentException.class)
+    public void testConstructorThrowsWithZeroLengthProbabilites() {
+        createSampler(new double[0], 1.0);
+    }
+
+    @Test(expected = IllegalArgumentException.class)
+    public void testConstructorThrowsWithNegativeProbabilites() {
+        createSampler(new double[] {-1, 0.1, 0.2}, 1.0);
+    }
+
+    @Test(expected = IllegalArgumentException.class)
+    public void testConstructorThrowsWithNaNProbabilites() {
+        createSampler(new double[] {0.1, Double.NaN, 0.2}, 1.0);
+    }
+
+    @Test(expected = IllegalArgumentException.class)
+    public void testConstructorThrowsWithInfiniteProbabilites() {
+        createSampler(new double[] {0.1, Double.POSITIVE_INFINITY, 0.2}, 1.0);
+    }
+
+    @Test(expected = IllegalArgumentException.class)
+    public void testConstructorThrowsWithInfiniteSumProbabilites() {
+        createSampler(new double[] {Double.MAX_VALUE, Double.MAX_VALUE}, 1.0);
+    }
+
+    @Test(expected = IllegalArgumentException.class)
+    public void testConstructorThrowsWithZeroSumProbabilites() {
+        createSampler(new double[4], 1.0);
+    }
+
+    @Test(expected = IllegalArgumentException.class)
+    public void testConstructorThrowsWithZeroAlpha() {
+        createSampler(new double[] {0.5, 0.5}, 0.0);
+    }
+
+    @Test(expected = IllegalArgumentException.class)
+    public void testConstructorThrowsWithNegativeAlpha() {
+        createSampler(new double[] {0.5, 0.5}, -1.0);
+    }
+
+    @Test
+    public void testToString() {
+        final GuideTableDiscreteSampler sampler = createSampler(new double[] {0.5, 0.5}, 1.0);
+        Assert.assertTrue(sampler.toString().toLowerCase().contains("guide table"));
+    }
+
+    /**
+     * Creates the sampler.
+     * @param probabilities the probabilities
+     * @param alpha the alpha factor used to set the guide table size
+     * @return the guide table discrete sampler
+     */
+    private static GuideTableDiscreteSampler createSampler(double[] probabilities, double alpha) {
+        final UniformRandomProvider rng = RandomSource.create(RandomSource.SPLIT_MIX_64);
+        return new GuideTableDiscreteSampler(rng, probabilities, alpha);
+    }
+
+    /**
+     * Test sampling from a binomial distribution.
+     */
+    @Test
+    public void testBinomialSamples() {
+        final int trials = 67;
+        final double probabilityOfSuccess = 0.345;
+        final BinomialDistribution dist = new BinomialDistribution(null, trials, probabilityOfSuccess);
+        final double[] expected = new double[trials + 1];
+        for (int i = 0; i < expected.length; i++) {
+            expected[i] = dist.probability(i);
+        }
+        checkSamples(expected, 1.0);
+    }
+
+    /**
+     * Test sampling from a Poisson distribution.
+     */
+    @Test
+    public void testPoissonSamples() {
+        final double mean = 3.14;
+        final PoissonDistribution dist = new PoissonDistribution(null, mean,
+            PoissonDistribution.DEFAULT_EPSILON, PoissonDistribution.DEFAULT_MAX_ITERATIONS);
+        final int maxN = dist.inverseCumulativeProbability(1 - 1e-6);
+        final double[] expected = new double[maxN];
+        for (int i = 0; i < expected.length; i++) {
+            expected[i] = dist.probability(i);
+        }
+        checkSamples(expected, 1.0);
+    }
+
+    /**
+     * Test sampling from a non-uniform distribution of probabilities (these sum to 1).
+     */
+    @Test
+    public void testNonUniformSamplesWithProbabilities() {
+        final double[] expected = {0.1, 0.2, 0.3, 0.1, 0.3};
+        checkSamples(expected, 1.0);
+    }
+
+    /**
+     * Test sampling from a non-uniform distribution of probabilities with an alpha smaller than
+     * the default.
+     */
+    @Test
+    public void testNonUniformSamplesWithProbabilitiesWithSmallAlpha() {
+        final double[] expected = {0.1, 0.2, 0.3, 0.1, 0.3};
+        checkSamples(expected, 0.1);
+    }
+
+    /**
+     * Test sampling from a non-uniform distribution of probabilities with an alpha larger than
+     * the default.
+     */
+    @Test
+    public void testNonUniformSamplesWithProbabilitiesWithLargeAlpha() {
+        final double[] expected = {0.1, 0.2, 0.3, 0.1, 0.3};
+        checkSamples(expected, 10.0);
+    }
+
+    /**
+     * Test sampling from a non-uniform distribution of observations (i.e. the sum is not 1 as per
+     * probabilities).
+     */
+    @Test
+    public void testNonUniformSamplesWithObservations() {
+        final double[] expected = {1, 2, 3, 1, 3};
+        checkSamples(expected, 1.0);
+    }
+
+    /**
+     * Test sampling from a non-uniform distribution of probabilities (these sum to 1).
+     * Extra zero-values are added.
+     */
+    @Test
+    public void testNonUniformSamplesWithZeroProbabilities() {
+        final double[] expected = {0.1, 0, 0.2, 0.3, 0.1, 0.3, 0};
+        checkSamples(expected, 1.0);
+    }
+
+    /**
+     * Test sampling from a non-uniform distribution of observations (i.e. the sum is not 1 as per
+     * probabilities). Extra zero-values are added.
+     */
+    @Test
+    public void testNonUniformSamplesWithZeroObservations() {
+        final double[] expected = {1, 2, 3, 0, 1, 3, 0};
+        checkSamples(expected, 1.0);
+    }
+
+    /**
+     * Test sampling from a uniform distribution. This is an edge case where there
+     * are no probabilities less than the mean.
+     */
+    @Test
+    public void testUniformSamplesWithNoObservationLessThanTheMean() {
+        final double[] expected = {2, 2, 2, 2, 2, 2};
+        checkSamples(expected, 1.0);
+    }
+
+    /**
+     * Check the distribution of samples match the expected probabilities.
+     *
+     * <p>If the expected probability is zero then this should never be sampled. The non-zero
+     * probabilities are compared to the sample distribution using a Chi-square test.</p>
+     *
+     * @param probabilies the probabilities
+     * @param alpha the alpha
+     */
+    private static void checkSamples(double[] probabilies, double alpha) {
+        final GuideTableDiscreteSampler sampler = createSampler(probabilies, alpha);
+
+        final int numberOfSamples = 10000;
+        final long[] samples = new long[probabilies.length];
+        for (int i = 0; i < numberOfSamples; i++) {
+            samples[sampler.sample()]++;
+        }
+
+        // Handle a test with some zero-probability observations by mapping them out so that
+        // the Chi-square test is performed using only the non-zero probabilities.
+        int mapSize = 0;
+        for (int i = 0; i < probabilies.length; i++) {
+            if (probabilies[i] != 0) {
+                mapSize++;
+            }
+        }
+
+        final double[] expected = new double[mapSize];
+        final long[] observed = new long[mapSize];
+        for (int i = 0; i < probabilies.length; i++) {
+            if (probabilies[i] == 0) {
+                Assert.assertEquals("No samples expected from zero probability", 0, samples[i]);
+            } else {
+                // This can be added for the Chi-square test
+                --mapSize;
+                expected[mapSize] = probabilies[i];
+                observed[mapSize] = samples[i];
+            }
+        }
+
+        final ChiSquareTest chiSquareTest = new ChiSquareTest();
+        // Pass if we cannot reject the null hypothesis that the distributions are the same.
+        Assert.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001));
+    }
+}