You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by ah...@apache.org on 2019/08/03 11:52:08 UTC

[commons-rng] 01/02: RNG-95: Update the DiscreteUniformSampler using faster algorithms.

This is an automated email from the ASF dual-hosted git repository.

aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-rng.git

commit 5976d7d9d34e457e8d2af689c3a64c7937db84c8
Author: Alex Herbert <ah...@apache.org>
AuthorDate: Wed Apr 24 21:57:27 2019 +0100

    RNG-95: Update the DiscreteUniformSampler using faster algorithms.
    
    Algorithms are added for ranges that are a power of 2 and for ranges
    that are not a power of 2.
    
    A lower bound of 0 is now handled as a special case (no offset is added).
---
 ...iscreteUniformSamplerGenerationPerformance.java | 184 +++++++++++++
 .../distribution/DiscreteUniformSampler.java       | 294 ++++++++++++++++++---
 .../distribution/DiscreteSamplersList.java         |   8 +-
 .../distribution/DiscreteUniformSamplerTest.java   | 274 ++++++++++++++++++-
 4 files changed, 724 insertions(+), 36 deletions(-)

diff --git a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/distribution/DiscreteUniformSamplerGenerationPerformance.java b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/distribution/DiscreteUniformSamplerGenerationPerformance.java
new file mode 100644
index 0000000..d1ab8a7
--- /dev/null
+++ b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/distribution/DiscreteUniformSamplerGenerationPerformance.java
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.rng.examples.jmh.distribution;
+
+import org.apache.commons.rng.RestorableUniformRandomProvider;
+import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.sampling.distribution.DiscreteUniformSampler;
+import org.apache.commons.rng.sampling.distribution.SharedStateDiscreteSampler;
+import org.apache.commons.rng.simple.RandomSource;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.infra.Blackhole;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Executes a benchmark to compare the speed of generating integers in a positive range
+ * using either the {@link DiscreteUniformSampler} or {@link UniformRandomProvider#nextInt(int)}.
+ */
+@BenchmarkMode(Mode.AverageTime)
+@OutputTimeUnit(TimeUnit.NANOSECONDS)
+@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
+@Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
+@State(Scope.Benchmark)
+@Fork(value = 1, jvmArgs = { "-server", "-Xms128M", "-Xmx128M" })
+public class DiscreteUniformSamplerGenerationPerformance {
+    /** The number of samples generated per benchmark invocation. */
+    @Param({
+        "1",
+        "2",
+        "4",
+        "8",
+        "16",
+        "1000000",
+        })
+    private int samples;
+
+    /**
+     * The benchmark state (creates the generator from the selected {@code RandomSource}).
+     */
+    @State(Scope.Benchmark)
+    public static class Sources {
+        /**
+         * RNG providers.
+         *
+         * <p>Comment in the slower generators below to compare generators of different speeds.</p>
+         *
+         * @see <a href="https://commons.apache.org/proper/commons-rng/userguide/rng.html">
+         *      Commons RNG user guide</a>
+         */
+        @Param({"SPLIT_MIX_64",
+                // Comment in for slower generators
+                //"MWC_256", "KISS", "WELL_1024_A",
+                //"WELL_44497_B"
+                })
+        private String randomSourceName;
+
+        /** RNG. */
+        private RestorableUniformRandomProvider generator;
+
+        /**
+         * @return the RNG created in {@link #setup()}.
+         */
+        public UniformRandomProvider getGenerator() {
+            return generator;
+        }
+
+        /** Instantiates the generator from {@code randomSourceName}. */
+        @Setup
+        public void setup() {
+            final RandomSource randomSource = RandomSource.valueOf(randomSourceName);
+            generator = RandomSource.create(randomSource);
+        }
+    }
+
+    /**
+     * Holds the exclusive upper bound for the {@code int} generation.
+     */
+    @State(Scope.Benchmark)
+    public static class IntRange {
+        /**
+         * The exclusive upper bound for the {@code int} generation.
+         *
+         * <p>Bounded generation typically uses a rejection loop. From the Javadoc for java.util.Random#nextInt(int):</p>
+         *
+         * <pre>
+         * "The probability of a value being rejected depends on n. The
+         * worst case is n=2^30+1, for which the probability of a reject is 1/2,
+         * and the expected number of iterations before the loop terminates is 2."
+         * </pre>
+         */
+        @Param({
+            "256", // Power of 2: 1 << 8
+            "257", // Prime number
+            "1073741825", // Worst case for java.util.Random: (1 << 30) + 1
+            })
+        private int upperBound;
+
+        /**
+         * Gets the upper bound.
+         *
+         * @return the upper bound
+         */
+        public int getUpperBound() {
+            return upperBound;
+        }
+    }
+
+    // Benchmark methods.
+    // Avoid consuming the generated values inside the loop. Use a sum and
+    // consume at the end. This reduces the run-time as the Blackhole has
+    // a relatively high overhead compared with number generation.
+    // Subtracting the baseline from the other timings provides a measure
+    // of the extra work done by the algorithm to produce unbiased samples in a range.
+
+    /** Baseline: sums unbounded values from {@link UniformRandomProvider#nextInt()}.
+     * @param bh the data sink
+     * @param source the source
+     */
+    @Benchmark
+    public void nextIntBaseline(Blackhole bh, Sources source) {
+        int sum = 0;
+        for (int i = samples; i-- != 0;) {
+            sum += source.getGenerator().nextInt();
+        }
+        bh.consume(sum);
+    }
+
+    /** Sums values in {@code [0, upperBound)} from {@link UniformRandomProvider#nextInt(int)}.
+     * @param bh the data sink
+     * @param source the source
+     * @param range the range
+     */
+    @Benchmark
+    public void nextIntRange(Blackhole bh, Sources source, IntRange range) {
+        final int n = range.getUpperBound();
+        int sum = 0;
+        for (int i = samples; i-- != 0;) {
+            sum += source.getGenerator().nextInt(n);
+        }
+        bh.consume(sum);
+    }
+
+    /** Sums values in {@code [0, upperBound)} from a {@link DiscreteUniformSampler} built per invocation.
+     * @param bh the data sink
+     * @param source the source
+     * @param range the range
+     */
+    @Benchmark
+    public void nextDiscreteUniformSampler(Blackhole bh, Sources source, IntRange range) {
+        // Note: The sampler upper bound is inclusive. Sampler construction is included in the timing.
+        final SharedStateDiscreteSampler sampler = DiscreteUniformSampler.of(
+                source.getGenerator(), 0, range.getUpperBound() - 1);
+        int sum = 0;
+        for (int i = samples; i-- != 0;) {
+            sum += sampler.sample();
+        }
+        bh.consume(sum);
+    }
+}
diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSampler.java
index 2814799..9eebc2c 100644
--- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSampler.java
+++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSampler.java
@@ -22,9 +22,30 @@ import org.apache.commons.rng.UniformRandomProvider;
 /**
  * Discrete uniform distribution sampler.
  *
- * <p>Sampling uses {@link UniformRandomProvider#nextInt(int)} when
- * the range {@code (upper - lower) <} {@link Integer#MAX_VALUE}, otherwise
- * {@link UniformRandomProvider#nextInt()}.</p>
+ * <p>Sampling uses {@link UniformRandomProvider#nextInt}.</p>
+ *
+ * <p>When the range is a power of two the number of calls is 1 per sample.
+ * Otherwise a rejection algorithm is used to ensure uniformity. In the worst
+ * case scenario where the range spans half the range of an {@code int}
+ * (2<sup>31</sup> + 1) the expected number of calls is 2 per sample.</p>
+ *
+ * <p>This sampler can be used as a replacement for {@link UniformRandomProvider#nextInt}
+ * with appropriate adjustment of the upper bound to be inclusive and will outperform that
+ * method when the range is not a power of two. The advantage is gained by pre-computation
+ * of the rejection threshold.</p>
+ *
+ * <p>The sampling algorithm is described in:</p>
+ *
+ * <blockquote>
+ *  Lemire, D (2019). <i>Fast Random Integer Generation in an Interval.</i>
+ *  <strong>ACM Transactions on Modeling and Computer Simulation</strong> 29 (1).
+ * </blockquote>
+ *
+ * <p>The number of {@code int} values required per sample follows a geometric distribution with
+ * a probability of success p of {@code 1 - ((2^32 % n) / 2^32)}. This requires on average 1/p random
+ * {@code int} values per sample.</p>
+ *
+ * @see <a href="https://arxiv.org/abs/1805.10941">Fast Random Integer Generation in an Interval</a>
  *
  * @since 1.0
  */
@@ -36,24 +57,19 @@ public class DiscreteUniformSampler
     private final SharedStateDiscreteSampler delegate;
 
     /**
-     * Base class for a sampler from a discrete uniform distribution.
+     * Base class for a sampler from a discrete uniform distribution. This contains the
+     * source of randomness.
      */
     private abstract static class AbstractDiscreteUniformSampler
-        implements SharedStateDiscreteSampler {
-
+            implements SharedStateDiscreteSampler {
         /** Underlying source of randomness. */
         protected final UniformRandomProvider rng;
-        /** Lower bound. */
-        protected final int lower;
 
         /**
          * @param rng Generator of uniformly distributed random numbers.
-         * @param lower Lower bound (inclusive) of the distribution.
          */
-        AbstractDiscreteUniformSampler(UniformRandomProvider rng,
-                                       int lower) {
+        AbstractDiscreteUniformSampler(UniformRandomProvider rng) {
             this.rng = rng;
-            this.lower = lower;
         }
 
         @Override
@@ -63,35 +79,158 @@ public class DiscreteUniformSampler
     }
 
     /**
-     * Discrete uniform distribution sampler when the range between lower and upper is small
+     * Discrete uniform distribution sampler when the range is 1 and the sample value is fixed.
+     */
+    private static class FixedDiscreteUniformSampler
+            extends AbstractDiscreteUniformSampler {
+        /** The value returned by every sample. */
+        private final int value;
+
+        /**
+         * @param rng Generator of uniformly distributed random numbers.
+         * @param value The value to return.
+         */
+        FixedDiscreteUniformSampler(UniformRandomProvider rng,
+                                    int value) {
+            super(rng);
+            this.value = value;
+        }
+
+        @Override
+        public int sample() {
+            return value;
+        }
+
+        @Override
+        public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
+            // No requirement for the RNG: sample() never uses it, so this instance can be shared.
+            return this;
+        }
+    }
+
+    /**
+     * Discrete uniform distribution sampler when the range is a power of 2 and greater than 1.
+     * This sampler assumes the lower bound of the range is 0.
+     *
+     * <p>Note: This cannot be used when the range is 1 (2^0) as the shift would be 32 bits,
+     * which the {@code >>>} operator masks to a shift of 0.</p>
+     */
+    private static class PowerOf2RangeDiscreteUniformSampler
+            extends AbstractDiscreteUniformSampler {
+        /** Unsigned right shift applied to the 32-bit sample: 32 - log2(range). */
+        private final int shift;
+
+        /**
+         * @param rng Generator of uniformly distributed random numbers.
+         * @param range Maximum range of the sample (exclusive).
+         * Must be a power of 2 greater than 2^0 ({@code Integer.MIN_VALUE} is treated as 2^31).
+         */
+        PowerOf2RangeDiscreteUniformSampler(UniformRandomProvider rng,
+                                            int range) {
+            super(rng);
+            this.shift = Integer.numberOfLeadingZeros(range) + 1; // == 32 - log2(range)
+        }
+
+        /**
+         * @param rng Generator of uniformly distributed random numbers.
+         * @param source Source to copy.
+         */
+        PowerOf2RangeDiscreteUniformSampler(UniformRandomProvider rng,
+                                            PowerOf2RangeDiscreteUniformSampler source) {
+            super(rng);
+            this.shift = source.shift;
+        }
+
+        @Override
+        public int sample() {
+            // Use a bit shift to favour the most significant bits.
+            // Note: The result would be the same as the rejection method used in the
+            // small range sampler when there is no rejection threshold.
+            return rng.nextInt() >>> shift;
+        }
+
+        @Override
+        public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
+            return new PowerOf2RangeDiscreteUniformSampler(rng, this);
+        }
+    }
+
+    /**
+     * Discrete uniform distribution sampler when the range is small
      * enough to fit in a positive integer.
+     * This sampler assumes the lower bound of the range is 0.
+     *
+     * <p>Implements the algorithm of Lemire (2019).</p>
+     *
+     * @see <a href="https://arxiv.org/abs/1805.10941">Fast Random Integer Generation in an Interval</a>
      */
     private static class SmallRangeDiscreteUniformSampler
-        extends AbstractDiscreteUniformSampler {
+            extends AbstractDiscreteUniformSampler {
+        /** Maximum range of the sample (exclusive), stored as an unsigned 32-bit value. */
+        private final long n;

-        /** Maximum range of the sample from the lower bound (exclusive). */
-        private final int range;
+        /**
+         * The level below which samples are rejected based on the fraction remainder.
+         *
+         * <p>Any remainder below this denotes that there are still floor(2^32 / n) more
+         * observations of this sample from the interval [0, 2^32), where n is the range.</p>
+         */
+        private final long threshold;

         /**
          * @param rng Generator of uniformly distributed random numbers.
-         * @param lower Lower bound (inclusive) of the distribution.
-         * @param range Maximum range of the sample from the lower bound (exclusive).
+         * @param range Maximum range of the sample (exclusive).
          */
         SmallRangeDiscreteUniformSampler(UniformRandomProvider rng,
-                                         int lower,
                                          int range) {
-            super(rng, lower);
-            this.range = range;
+            super(rng);
+            // Handle range as an unsigned 32-bit integer
+            this.n = range & 0xffffffffL;
+            // Compute 2^32 % n: the rejection threshold for the remainder
+            threshold = (1L << 32) % n;
+        }
+
+        /**
+         * @param rng Generator of uniformly distributed random numbers.
+         * @param source Source to copy.
+         */
+        SmallRangeDiscreteUniformSampler(UniformRandomProvider rng,
+                                         SmallRangeDiscreteUniformSampler source) {
+            super(rng);
+            this.n = source.n;
+            this.threshold = source.threshold;
         }

         @Override
         public int sample() {
-            return lower + rng.nextInt(range);
+            // Rejection method using multiply by a fraction:
+            // n * [0, 2^32 - 1]
+            //     -------------
+            //         2^32
+            // The result is mapped back to an integer and will be in the range [0, n).
+            // Note this is comparable to n * rng.nextDouble() but with compensation for
+            // non-uniformity due to round-off.
+            long result;
+            do {
+                // Compute 64-bit unsigned product of n * [0, 2^32 - 1].
+                // The upper 32-bits contains the sample value in the range [0, n), i.e. result / 2^32.
+                // The lower 32-bits contains the remainder, i.e. result % 2^32.
+                result = n * (rng.nextInt() & 0xffffffffL);
+
+                // Test the sample uniformity.
+                // Samples are observed on average (2^32 / n) times at a frequency of either
+                // floor(2^32 / n) or ceil(2^32 / n).
+                // To ensure all samples have a frequency of floor(2^32 / n) reject any results with
+                // a remainder < (2^32 % n), i.e. the level below which denotes that there are still
+                // floor(2^32 / n) more observations of this sample.
+            } while ((result & 0xffffffffL) < threshold);
+            // Divide by 2^32 to get the sample.
+            return (int)(result >>> 32);
         }

         @Override
         public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
-            return new SmallRangeDiscreteUniformSampler(rng, lower, range);
+            return new SmallRangeDiscreteUniformSampler(rng, this);
         }
     }
 
@@ -100,8 +239,9 @@ public class DiscreteUniformSampler
      * to fit in a positive integer.
      */
     private static class LargeRangeDiscreteUniformSampler
-        extends AbstractDiscreteUniformSampler {
-
+            extends AbstractDiscreteUniformSampler {
+        /** Lower bound. */
+        private final int lower;
         /** Upper bound. */
         private final int upper;
 
@@ -113,7 +253,8 @@ public class DiscreteUniformSampler
         LargeRangeDiscreteUniformSampler(UniformRandomProvider rng,
                                          int lower,
                                          int upper) {
-            super(rng, lower);
+            super(rng);
+            this.lower = lower;
             this.upper = upper;
         }
 
@@ -139,6 +280,44 @@ public class DiscreteUniformSampler
     }
 
     /**
+     * Adds an offset to an underlying discrete sampler (e.g. shifts a range starting at 0 to a lower bound).
+     */
+    private static class OffsetDiscreteUniformSampler
+            extends AbstractDiscreteUniformSampler {
+        /** The offset added to each sample. */
+        private final int offset;
+        /** The discrete sampler; it holds the source of randomness. */
+        private final SharedStateDiscreteSampler sampler;
+
+        /**
+         * @param offset The offset for the sample.
+         * @param sampler The discrete sampler.
+         */
+        OffsetDiscreteUniformSampler(int offset,
+                                     SharedStateDiscreteSampler sampler) {
+            super(null); // No RNG held here: randomness comes from the wrapped sampler
+            this.offset = offset;
+            this.sampler = sampler;
+        }
+
+        @Override
+        public int sample() {
+            return offset + sampler.sample();
+        }
+
+        /** {@inheritDoc} */
+        @Override
+        public String toString() {
+            return sampler.toString(); // Describe the wrapped sampler (this instance holds no RNG)
+        }
+
+        @Override
+        public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
+            return new OffsetDiscreteUniformSampler(offset, sampler.withUniformRandomProvider(rng));
+        }
+    }
+
+    /**
      * This instance delegates sampling. Use the factory method
      * {@link #of(UniformRandomProvider, int, int)} to create an optimal sampler.
      *
@@ -188,13 +367,68 @@ public class DiscreteUniformSampler
         if (lower > upper) {
             throw new IllegalArgumentException(lower  + " > " + upper);
         }
+
         // Choose the algorithm depending on the range
+
+        // Edge case for no range.
+        // This must be done first as the methods to handle lower == 0
+        // do not handle upper == 0.
+        if (upper == lower) {
+            return new FixedDiscreteUniformSampler(rng, lower);
+        }
+
+        // Algorithms to ignore the lower bound if it is zero.
+        if (lower == 0) {
+            return createZeroBoundedSampler(rng, upper);
+        }
+
         final int range = (upper - lower) + 1;
-        return range <= 0 ?
+        // Check power of 2 first to handle range == 2^31.
+        if (isPowerOf2(range)) {
+            return new OffsetDiscreteUniformSampler(lower,
+                                                    new PowerOf2RangeDiscreteUniformSampler(rng, range));
+        }
+        if (range <= 0) {
             // The range is too wide to fit in a positive int (larger
             // than 2^31); use a simple rejection method.
-            new LargeRangeDiscreteUniformSampler(rng, lower, upper) :
-            // Use a sample from the range added to the lower bound.
-            new SmallRangeDiscreteUniformSampler(rng, lower, range);
+            // Note: if range == 0 then the input is [Integer.MIN_VALUE, Integer.MAX_VALUE].
+            // No specialisation exists for this and it is handled as a large range.
+            return new LargeRangeDiscreteUniformSampler(rng, lower, upper);
+        }
+        // Use a sample from the range added to the lower bound.
+        return new OffsetDiscreteUniformSampler(lower,
+                                                new SmallRangeDiscreteUniformSampler(rng, range));
+    }
+
+    /**
+     * Creates a new sampler for the range {@code 0} inclusive to {@code upper} inclusive.
+     *
+     * <p>This can handle any positive {@code upper}.</p>
+     *
+     * @param rng Generator of uniformly distributed random numbers.
+     * @param upper Upper bound (inclusive) of the distribution. Must be positive.
+     * @return the sampler
+     */
+    private static AbstractDiscreteUniformSampler createZeroBoundedSampler(UniformRandomProvider rng,
+                                                                           int upper) {
+        // Note: Handle any range up to 2^31 (which is negative as a signed
+        // 32-bit integer but handled as a power of 2)
+        final int range = upper + 1;
+        return isPowerOf2(range) ?
+            new PowerOf2RangeDiscreteUniformSampler(rng, range) :
+            new SmallRangeDiscreteUniformSampler(rng, range);
+    }
+
+    /**
+     * Checks if the value is a power of 2.
+     *
+     * <p>This returns {@code true} for the value {@code Integer.MIN_VALUE} which can be
+     * handled as an unsigned integer of 2^31.</p>
+     *
+     * @param value Value (interpreted as an unsigned 32-bit integer).
+     * @return {@code true} if a power of 2
+     */
+    private static boolean isPowerOf2(final int value) {
+        return value != 0 && (value & (value - 1)) == 0; // value & (value - 1) clears the lowest set bit
     }
 }
diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java
index dd57177..7df26a0 100644
--- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java
+++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java
@@ -91,7 +91,7 @@ public final class DiscreteSamplersList {
             add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(unusedRng, loUniform, hiUniform),
                 MathArrays.sequence(8, -3, 1),
                 RandomSource.create(RandomSource.SPLIT_MIX_64));
-            // Uniform.
+            // Uniform (power of 2 range).
             add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(unusedRng, loUniform, hiUniform),
                 MathArrays.sequence(8, -3, 1),
                 DiscreteUniformSampler.of(RandomSource.create(RandomSource.MT_64), loUniform, hiUniform));
@@ -102,6 +102,12 @@ public final class DiscreteSamplersList {
             add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(unusedRng, loLargeUniform, hiLargeUniform),
                 MathArrays.sequence(20, -halfMax, halfMax / 10),
                 DiscreteUniformSampler.of(RandomSource.create(RandomSource.WELL_1024_A), loLargeUniform, hiLargeUniform));
+            // Uniform (non-power of 2 range).
+            final int rangeNonPowerOf2Uniform = 11;
+            final int hiNonPowerOf2Uniform = loUniform + rangeNonPowerOf2Uniform;
+            add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(unusedRng, loUniform, hiNonPowerOf2Uniform),
+                MathArrays.sequence(rangeNonPowerOf2Uniform, -3, 1),
+                DiscreteUniformSampler.of(RandomSource.create(RandomSource.XO_SHI_RO_256_SS), loUniform, hiNonPowerOf2Uniform));
 
             // Zipf ("inverse method").
             final int numElementsZipf = 5;
diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSamplerTest.java
index 0fd662a..747bfc5 100644
--- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSamplerTest.java
+++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSamplerTest.java
@@ -17,13 +17,17 @@
 package org.apache.commons.rng.sampling.distribution;
 
 import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.core.source32.IntProvider;
 import org.apache.commons.rng.sampling.RandomAssert;
 import org.apache.commons.rng.simple.RandomSource;
 import org.junit.Assert;
 import org.junit.Test;
 
+import java.util.Locale;
+
 /**
- * Test for the {@link DiscreteUniformSampler}. The tests hit edge cases for the sampler.
+ * Test for the {@link DiscreteUniformSampler}. The tests hit edge cases for the sampler
+ * and demonstrates uniformity of output when the underlying RNG output is uniform.
  */
 public class DiscreteUniformSamplerTest {
     /**
@@ -37,6 +41,224 @@ public class DiscreteUniformSamplerTest {
         DiscreteUniformSampler.of(rng, lower, upper);
     }
 
+    @Test
+    public void testSamplesWithRangeOf1() {
+        final int upper = 99;
+        final int lower = upper; // Range of 1: every sample must be this value
+        final UniformRandomProvider rng = RandomSource.create(RandomSource.SPLIT_MIX_64);
+        final SharedStateDiscreteSampler sampler = DiscreteUniformSampler.of(rng, lower, upper);
+        for (int i = 0; i < 5; i++) {
+            Assert.assertEquals(lower, sampler.sample());
+        }
+    }
+
+    /**
+     * Test samples with a full integer range.
+     * The output should be the same as the int values produced by an identically seeded RNG.
+     */
+    @Test
+    public void testSamplesWithFullRange() {
+        final int upper = Integer.MAX_VALUE;
+        final int lower = Integer.MIN_VALUE;
+        final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
+        final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
+        final SharedStateDiscreteSampler sampler = DiscreteUniformSampler.of(rng2, lower, upper);
+        for (int i = 0; i < 5; i++) {
+            Assert.assertEquals(rng1.nextInt(), sampler.sample());
+        }
+    }
+
+    @Test
+    public void testSamplesWithPowerOf2Range() {
+        final UniformRandomProvider rngZeroBits = new IntProvider() {
+            @Override
+            public int next() {
+                return 0;
+            }
+        };
+        final UniformRandomProvider rngAllBits = new IntProvider() {
+            @Override
+            public int next() {
+                return 0xffffffff;
+            }
+        };
+
+        final int lower = -3;
+        DiscreteUniformSampler sampler;
+        // The upper range for a positive integer is 2^31-1. So the max positive power of
+        // 2 is 2^30. However the sampler should handle a bit shift of 31 to create a range
+        // of Integer.MIN_VALUE (0x80000000) as this is a power of 2 as an unsigned int (2^31).
+        for (int i = 0; i < 32; i++) {
+            final int range = 1 << i;
+            final int upper = lower + range - 1; // Wraps for i == 31 (range 2^31); upper stays valid
+            sampler = new DiscreteUniformSampler(rngZeroBits, lower, upper);
+            Assert.assertEquals("Zero bits sample", lower, sampler.sample());
+            sampler = new DiscreteUniformSampler(rngAllBits, lower, upper);
+            Assert.assertEquals("All bits sample", upper, sampler.sample());
+        }
+    }
+
+    @Test
+    public void testOffsetSamplesWithNonPowerOf2Range() {
+        assertOffsetSamples(257); // Prime, so not a power of 2
+    }
+
+    @Test
+    public void testOffsetSamplesWithPowerOf2Range() {
+        assertOffsetSamples(256); // 2^8
+    }
+
+    @Test
+    public void testOffsetSamplesWithRangeOf1() {
+        assertOffsetSamples(1); // Single value (fixed sampler)
+    }
+
+    private static void assertOffsetSamples(int range) { // range = number of values (exclusive bound)
+        final Long seed = RandomSource.createLong(); // Shared so all three samplers see identical RNG output
+        final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, seed);
+        final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, seed);
+        final UniformRandomProvider rng3 = RandomSource.create(RandomSource.SPLIT_MIX_64, seed);
+
+        // Convert to the inclusive upper limit expected by the sampler
+        range = range - 1;
+        final int offsetLo = -13;
+        final int offsetHi = 42;
+        final SharedStateDiscreteSampler sampler = DiscreteUniformSampler.of(rng1, 0, range);
+        final SharedStateDiscreteSampler samplerLo = DiscreteUniformSampler.of(rng2, offsetLo, offsetLo + range);
+        final SharedStateDiscreteSampler samplerHi = DiscreteUniformSampler.of(rng3, offsetHi, offsetHi + range);
+        for (int i = 0; i < 10; i++) {
+            final int sample1 = sampler.sample();
+            final int sample2 = samplerLo.sample();
+            final int sample3 = samplerHi.sample();
+            Assert.assertEquals("Incorrect negative offset sample", sample1 + offsetLo, sample2);
+            Assert.assertEquals("Incorrect positive offset sample", sample1 + offsetHi, sample3);
+        }
+    }
+
+    /**
+     * Test the sample uniformity when using a small range that is not a power of 2.
+     */
+    @Test
+    public void testSampleUniformityWithNonPowerOf2Range() {
+        // Test using a RNG that outputs an evenly spaced set of integers.
+        // Create a Weyl sequence using George Marsaglia’s increment from:
+        // Marsaglia, G (July 2003). "Xorshift RNGs". Journal of Statistical Software. 8 (14).
+        // https://en.wikipedia.org/wiki/Weyl_sequence
+        final UniformRandomProvider rng = new IntProvider() {
+            private final int increment = 362437;
+            // Start so that the first output is Integer.MIN_VALUE (start itself is a large positive value)
+            private final int start = Integer.MIN_VALUE - increment;
+
+            private int bits = start;
+
+            @Override
+            public int next() {
+                // Output until the first wrap. The entire sequence will be uniform.
+                // Note this is not the full period of the sequence.
+                // Expect ((1L << 32) / increment) numbers = 11850
+                int result = bits += increment;
+                if (result < start) {
+                    return result;
+                }
+                throw new IllegalStateException("end of sequence");
+            }
+        };
+
+        // n = upper range exclusive
+        final int n = 37; // prime
+        final int[] histogram = new int[n];
+
+        final int lower = 0;
+        final int upper = n - 1;
+
+        final SharedStateDiscreteSampler sampler = DiscreteUniformSampler.of(rng, lower, upper);
+
+        try {
+            while (true) {
+                histogram[sampler.sample()]++;
+            }
+        } catch (IllegalStateException ex) {
+            // Expected end of sequence
+        }
+
+        // The sequence will result in either x or (x+1) samples in each bin (i.e. uniform).
+        int min = histogram[0];
+        int max = histogram[0];
+        for (int value : histogram) {
+            min = Math.min(min, value);
+            max = Math.max(max, value);
+        }
+        Assert.assertTrue("Not uniform, max = " + max + ", min=" + min, max - min <= 1);
+    }
+
+    /**
+     * Test the sample uniformity when using a small range that is a power of 2.
+     */
+    @Test
+    public void testSampleUniformityWithPowerOf2Range() {
+        // Test using an RNG that outputs a bit-reversed counter of integers.
+        final UniformRandomProvider rng = new IntProvider() {
+            private int bits = 0;
+
+            @Override
+            public int next() {
+                // We reverse the bits because the most significant bits are used
+                return Integer.reverse(bits++);
+            }
+        };
+
+        // n = upper range exclusive
+        final int n = 32; // power of 2
+        final int[] histogram = new int[n];
+
+        final int lower = 0;
+        final int upper = n - 1;
+
+        final SharedStateDiscreteSampler sampler = DiscreteUniformSampler.of(rng, lower, upper);
+
+        final int expected = 2;
+        for (int i = expected * n; i-- > 0;) {
+            histogram[sampler.sample()]++;
+        }
+
+        // This should be even across the entire range
+        for (int value : histogram) {
+            Assert.assertEquals(expected, value);
+        }
+    }
+
+    /**
+     * Test the sample rejection when using a range that is not a power of 2. The rejection
+     * algorithm of Lemire (2019) splits the entire 32-bit range into intervals of size 2^32/n. It
+     * will reject the lowest value in each interval that is over-sampled. This test uses 0
+     * as the first value from the RNG and tests it is rejected.
+     */
+    @Test
+    public void testSampleRejectionWithNonPowerOf2Range() {
+        // Test using a RNG that returns a sequence.
+        // The first value of zero should produce a sample that is rejected.
+        final int[] value = new int[1];
+        final UniformRandomProvider rng = new IntProvider() {
+            @Override
+            public int next() {
+                return value[0]++;
+            }
+        };
+
+        // n = upper range exclusive.
+        // Use a prime number to select the rejection algorithm.
+        final int n = 37;
+        final int lower = 0;
+        final int upper = n - 1;
+
+        final SharedStateDiscreteSampler sampler = DiscreteUniformSampler.of(rng, lower, upper);
+
+        final int sample = sampler.sample();
+
+        Assert.assertEquals("Sample is incorrect", 0, sample);
+        Assert.assertEquals("Sample should be produced from 2nd value", 2, value[0]);
+    }
+
     /**
      * Test the SharedStateSampler implementation.
      */
@@ -50,7 +272,24 @@ public class DiscreteUniformSamplerTest {
      */
     @Test
     public void testSharedStateSamplerWithLargeRange() {
-        testSharedStateSampler(-99999999, Integer.MAX_VALUE);
+        // Range of 2^31 + 2 values: rejection below or above it each occurs with p ~= 0.25
+        testSharedStateSampler(Integer.MIN_VALUE / 2 - 1, Integer.MAX_VALUE / 2 + 1);
+    }
+
+    /**
+     * Test the SharedStateSampler implementation with a power of 2 range (32 values).
+     */
+    @Test
+    public void testSharedStateSamplerWithPowerOf2Range() {
+        testSharedStateSampler(0, 31);
+    }
+
+    /**
+     * Test the SharedStateSampler implementation with a range of 1 (lower == upper).
+     */
+    @Test
+    public void testSharedStateSamplerWithRangeOf1() {
+        testSharedStateSampler(9, 9);
     }
 
     /**
@@ -69,13 +308,38 @@ public class DiscreteUniformSamplerTest {
         RandomAssert.assertProduceSameSequence(sampler1, sampler2);
     }
 
+    @Test
+    public void testToStringWithSmallRange() {
+        assertToString(5, 67); // Range of 63 values
+    }
+
+    @Test
+    public void testToStringWithLargeRange() {
+        assertToString(-99999999, Integer.MAX_VALUE); // Range of more than 2^31 values
+    }
+
+    @Test
+    public void testToStringWithPowerOf2Range() {
+        // Note the range is upper - lower + 1 = 32, a power of 2
+        assertToString(0, 31);
+    }
+
+    @Test
+    public void testToStringWithRangeOf1() {
+        assertToString(9, 9); // Range of 1 (lower == upper)
+    }
+
     /**
      * Test the toString method. This is added to ensure coverage as the factory constructor
      * used in other tests does not create an instance of the wrapper class.
+     *
+     * @param lower Lower bound of the sampler range (inclusive).
+     * @param upper Upper bound of the sampler range (inclusive).
      */
-    @Test
-    public void testToString() {
+    private static void assertToString(int lower, int upper) {
         final UniformRandomProvider rng = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
-        Assert.assertTrue(new DiscreteUniformSampler(rng, 1, 2).toString().toLowerCase().contains("uniform"));
+        final DiscreteUniformSampler sampler = // Constructor (not the factory) for coverage
+            new DiscreteUniformSampler(rng, lower, upper);
+        Assert.assertTrue(sampler.toString().toLowerCase(Locale.US).contains("uniform"));
     }
 }