You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by ah...@apache.org on 2019/08/03 11:52:08 UTC
[commons-rng] 01/02: RNG-95: Update the DiscreteUniformSampler
using faster algorithms.
This is an automated email from the ASF dual-hosted git repository.
aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-rng.git
commit 5976d7d9d34e457e8d2af689c3a64c7937db84c8
Author: Alex Herbert <ah...@apache.org>
AuthorDate: Wed Apr 24 21:57:27 2019 +0100
RNG-95: Update the DiscreteUniformSampler using faster algorithms.
Algorithms are added for ranges that are a power of 2 and non-power of
2.
Now specifically handles a lower bound of 0.
---
...iscreteUniformSamplerGenerationPerformance.java | 184 +++++++++++++
.../distribution/DiscreteUniformSampler.java | 294 ++++++++++++++++++---
.../distribution/DiscreteSamplersList.java | 8 +-
.../distribution/DiscreteUniformSamplerTest.java | 274 ++++++++++++++++++-
4 files changed, 724 insertions(+), 36 deletions(-)
diff --git a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/distribution/DiscreteUniformSamplerGenerationPerformance.java b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/distribution/DiscreteUniformSamplerGenerationPerformance.java
new file mode 100644
index 0000000..d1ab8a7
--- /dev/null
+++ b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/distribution/DiscreteUniformSamplerGenerationPerformance.java
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.rng.examples.jmh.distribution;
+
+import org.apache.commons.rng.RestorableUniformRandomProvider;
+import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.sampling.distribution.DiscreteUniformSampler;
+import org.apache.commons.rng.sampling.distribution.SharedStateDiscreteSampler;
+import org.apache.commons.rng.simple.RandomSource;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.infra.Blackhole;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Executes benchmark to compare the speed of generation of integer numbers in a positive range
+ * using the {@link DiscreteUniformSampler} or {@link UniformRandomProvider#nextInt(int)}.
+ */
+@BenchmarkMode(Mode.AverageTime)
+@OutputTimeUnit(TimeUnit.NANOSECONDS)
+@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
+@Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
+@State(Scope.Benchmark)
+@Fork(value = 1, jvmArgs = { "-server", "-Xms128M", "-Xmx128M" })
+public class DiscreteUniformSamplerGenerationPerformance {
+ /** The number of samples. */
+ @Param({
+ "1",
+ "2",
+ "4",
+ "8",
+ "16",
+ "1000000",
+ })
+ private int samples;
+
+ /**
+ * The benchmark state (retrieve the various "RandomSource"s).
+ */
+ @State(Scope.Benchmark)
+ public static class Sources {
+ /**
+ * RNG providers.
+ *
+ * <p>Use different speeds.</p>
+ *
+ * @see <a href="https://commons.apache.org/proper/commons-rng/userguide/rng.html">
+ * Commons RNG user guide</a>
+ */
+ @Param({"SPLIT_MIX_64",
+ // Uncomment to include slower generators
+ //"MWC_256", "KISS", "WELL_1024_A",
+ //"WELL_44497_B"
+ })
+ private String randomSourceName;
+
+ /** RNG. */
+ private RestorableUniformRandomProvider generator;
+
+ /**
+ * @return the RNG.
+ */
+ public UniformRandomProvider getGenerator() {
+ return generator;
+ }
+
+ /** Instantiates generator. */
+ @Setup
+ public void setup() {
+ final RandomSource randomSource = RandomSource.valueOf(randomSourceName);
+ generator = RandomSource.create(randomSource);
+ }
+ }
+
+ /**
+ * The upper range for the {@code int} generation.
+ */
+ @State(Scope.Benchmark)
+ public static class IntRange {
+ /**
+ * The upper range for the {@code int} generation.
+ *
+ * <p>Note that the while loop uses a rejection algorithm. From the Javadoc for java.util.Random:</p>
+ *
+ * <pre>
+ * "The probability of a value being rejected depends on n. The
+ * worst case is n=2^30+1, for which the probability of a reject is 1/2,
+ * and the expected number of iterations before the loop terminates is 2."
+ * </pre>
+ */
+ @Param({
+ "256", // Even: 1 << 8
+ "257", // Prime number
+ "1073741825", // Worst case: (1 << 30) + 1
+ })
+ private int upperBound;
+
+ /**
+ * Gets the upper bound.
+ *
+ * @return the upper bound
+ */
+ public int getUpperBound() {
+ return upperBound;
+ }
+ }
+
+ // Benchmark methods.
+ // Avoid consuming the generated values inside the loop. Use a sum and
+ // consume at the end. This reduces the run-time as the BlackHole has
+ // a relatively high overhead compared with number generation.
+ // Subtracting the baseline from the other timings provides a measure
+ // of the extra work done by the algorithm to produce unbiased samples in a range.
+
+ /**
+ * @param bh the data sink
+ * @param source the source
+ */
+ @Benchmark
+ public void nextIntBaseline(Blackhole bh, Sources source) {
+ int sum = 0;
+ for (int i = samples; i-- != 0;) {
+ sum += source.getGenerator().nextInt();
+ }
+ bh.consume(sum);
+ }
+
+ /**
+ * @param bh the data sink
+ * @param source the source
+ * @param range the range
+ */
+ @Benchmark
+ public void nextIntRange(Blackhole bh, Sources source, IntRange range) {
+ final int n = range.getUpperBound();
+ int sum = 0;
+ for (int i = samples; i-- != 0;) {
+ sum += source.getGenerator().nextInt(n);
+ }
+ bh.consume(sum);
+ }
+
+ /**
+ * @param bh the data sink
+ * @param source the source
+ * @param range the range
+ */
+ @Benchmark
+ public void nextDiscreteUniformSampler(Blackhole bh, Sources source, IntRange range) {
+ // Note: The sampler upper bound is inclusive.
+ final SharedStateDiscreteSampler sampler = DiscreteUniformSampler.of(
+ source.getGenerator(), 0, range.getUpperBound() - 1);
+ int sum = 0;
+ for (int i = samples; i-- != 0;) {
+ sum += sampler.sample();
+ }
+ bh.consume(sum);
+ }
+}
diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSampler.java
index 2814799..9eebc2c 100644
--- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSampler.java
+++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSampler.java
@@ -22,9 +22,30 @@ import org.apache.commons.rng.UniformRandomProvider;
/**
* Discrete uniform distribution sampler.
*
- * <p>Sampling uses {@link UniformRandomProvider#nextInt(int)} when
- * the range {@code (upper - lower) <} {@link Integer#MAX_VALUE}, otherwise
- * {@link UniformRandomProvider#nextInt()}.</p>
+ * <p>Sampling uses {@link UniformRandomProvider#nextInt}.</p>
+ *
+ * <p>When the range is a power of two the number of calls is 1 per sample.
+ * Otherwise a rejection algorithm is used to ensure uniformity. In the worst
+ * case scenario where the range spans half the range of an {@code int}
+ * (2<sup>31</sup> + 1) the expected number of calls is 2 per sample.</p>
+ *
+ * <p>This sampler can be used as a replacement for {@link UniformRandomProvider#nextInt}
+ * with appropriate adjustment of the upper bound to be inclusive and will outperform that
+ * method when the range is not a power of two. The advantage is gained by pre-computation
+ * of the rejection threshold.</p>
+ *
+ * <p>The sampling algorithm is described in:</p>
+ *
+ * <blockquote>
+ * Lemire, D (2019). <i>Fast Random Integer Generation in an Interval.</i>
+ * <strong>ACM Transactions on Modeling and Computer Simulation</strong> 29 (1).
+ * </blockquote>
+ *
+ * <p>The number of {@code int} values required per sample follows a geometric distribution with
+ * a probability of success p of {@code 1 - ((2^32 % n) / 2^32)}. This requires on average 1/p random
+ * {@code int} values per sample.</p>
+ *
+ * @see <a href="https://arxiv.org/abs/1805.10941">Fast Random Integer Generation in an Interval</a>
*
* @since 1.0
*/
@@ -36,24 +57,19 @@ public class DiscreteUniformSampler
private final SharedStateDiscreteSampler delegate;
/**
- * Base class for a sampler from a discrete uniform distribution.
+ * Base class for a sampler from a discrete uniform distribution. This contains the
+ * source of randomness.
*/
private abstract static class AbstractDiscreteUniformSampler
- implements SharedStateDiscreteSampler {
-
+ implements SharedStateDiscreteSampler {
/** Underlying source of randomness. */
protected final UniformRandomProvider rng;
- /** Lower bound. */
- protected final int lower;
/**
* @param rng Generator of uniformly distributed random numbers.
- * @param lower Lower bound (inclusive) of the distribution.
*/
- AbstractDiscreteUniformSampler(UniformRandomProvider rng,
- int lower) {
+ AbstractDiscreteUniformSampler(UniformRandomProvider rng) {
this.rng = rng;
- this.lower = lower;
}
@Override
@@ -63,35 +79,158 @@ public class DiscreteUniformSampler
}
/**
- * Discrete uniform distribution sampler when the range between lower and upper is small
+ * Discrete uniform distribution sampler when the sample value is fixed.
+ */
+ private static class FixedDiscreteUniformSampler
+ extends AbstractDiscreteUniformSampler {
+ /** The value. */
+ private final int value;
+
+ /**
+ * @param rng Generator of uniformly distributed random numbers.
+ * @param value The value.
+ */
+ FixedDiscreteUniformSampler(UniformRandomProvider rng,
+ int value) {
+ super(rng);
+ this.value = value;
+ }
+
+ @Override
+ public int sample() {
+ return value;
+ }
+
+ @Override
+ public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
+ // No requirement for the RNG
+ return this;
+ }
+ }
+
+ /**
+ * Discrete uniform distribution sampler when the range is a power of 2 and greater than 1.
+ * This sampler assumes the lower bound of the range is 0.
+ *
+ * <p>Note: This cannot be used when the range is 1 (2^0) as the shift would be 32-bits
+ * which is ignored by the shift operator.</p>
+ */
+ private static class PowerOf2RangeDiscreteUniformSampler
+ extends AbstractDiscreteUniformSampler {
+ /** Bit shift to apply to the integer sample. */
+ private final int shift;
+
+ /**
+ * @param rng Generator of uniformly distributed random numbers.
+ * @param range Maximum range of the sample (exclusive).
+ * Must be a power of 2 greater than 2^0.
+ */
+ PowerOf2RangeDiscreteUniformSampler(UniformRandomProvider rng,
+ int range) {
+ super(rng);
+ this.shift = Integer.numberOfLeadingZeros(range) + 1;
+ }
+
+ /**
+ * @param rng Generator of uniformly distributed random numbers.
+ * @param source Source to copy.
+ */
+ PowerOf2RangeDiscreteUniformSampler(UniformRandomProvider rng,
+ PowerOf2RangeDiscreteUniformSampler source) {
+ super(rng);
+ this.shift = source.shift;
+ }
+
+ @Override
+ public int sample() {
+ // Use a bit shift to favour the most significant bits.
+ // Note: The result would be the same as the rejection method used in the
+ // small range sampler when there is no rejection threshold.
+ return rng.nextInt() >>> shift;
+ }
+
+ @Override
+ public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
+ return new PowerOf2RangeDiscreteUniformSampler(rng, this);
+ }
+ }
+
+ /**
+ * Discrete uniform distribution sampler when the range is small
* enough to fit in a positive integer.
+ * This sampler assumes the lower bound of the range is 0.
+ *
+ * <p>Implements the algorithm of Lemire (2019).</p>
+ *
+ * @see <a href="https://arxiv.org/abs/1805.10941">Fast Random Integer Generation in an Interval</a>
*/
private static class SmallRangeDiscreteUniformSampler
- extends AbstractDiscreteUniformSampler {
+ extends AbstractDiscreteUniformSampler {
+ /** Maximum range of the sample (exclusive). */
+ private final long n;
- /** Maximum range of the sample from the lower bound (exclusive). */
- private final int range;
+ /**
+ * The level below which samples are rejected based on the fraction remainder.
+ *
+ * <p>Any remainder below this denotes that there are still floor(2^32 / n) more
+ * observations of this sample from the interval [0, 2^32), where n is the range.</p>
+ */
+ private final long threshold;
/**
* @param rng Generator of uniformly distributed random numbers.
- * @param lower Lower bound (inclusive) of the distribution.
- * @param range Maximum range of the sample from the lower bound (exclusive).
+ * @param range Maximum range of the sample (exclusive).
*/
SmallRangeDiscreteUniformSampler(UniformRandomProvider rng,
- int lower,
int range) {
- super(rng, lower);
- this.range = range;
+ super(rng);
+ // Handle range as an unsigned 32-bit integer
+ this.n = range & 0xffffffffL;
+ // Compute 2^32 % n
+ threshold = (1L << 32) % n;
+ }
+
+ /**
+ * @param rng Generator of uniformly distributed random numbers.
+ * @param source Source to copy.
+ */
+ SmallRangeDiscreteUniformSampler(UniformRandomProvider rng,
+ SmallRangeDiscreteUniformSampler source) {
+ super(rng);
+ this.n = source.n;
+ this.threshold = source.threshold;
}
@Override
public int sample() {
- return lower + rng.nextInt(range);
+ // Rejection method using multiply by a fraction:
+ // n * [0, 2^32)
+ // -------------
+ // 2^32
+ // The result is mapped back to an integer and will be in the range [0, n).
+ // Note this is comparable to range * rng.nextDouble() but with compensation for
+ // non-uniformity due to round-off.
+ long result;
+ do {
+ // Compute 64-bit unsigned product of n * [0, 2^32).
+ // The upper 32-bits contains the sample value in the range [0, n), i.e. result / 2^32.
+ // The lower 32-bits contains the remainder, i.e. result % 2^32.
+ result = n * (rng.nextInt() & 0xffffffffL);
+
+ // Test the sample uniformity.
+ // Samples are observed on average (2^32 / n) times at a frequency of either
+ // floor(2^32 / n) or ceil(2^32 / n).
+ // To ensure all samples have a frequency of floor(2^32 / n) reject any results with
+ // a remainder < (2^32 % n), i.e. the level below which denotes that there are still
+ // floor(2^32 / n) more observations of this sample.
+ } while ((result & 0xffffffffL) < threshold);
+ // Divide by 2^32 to get the sample.
+ return (int)(result >>> 32);
}
@Override
public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
- return new SmallRangeDiscreteUniformSampler(rng, lower, range);
+ return new SmallRangeDiscreteUniformSampler(rng, this);
}
}
@@ -100,8 +239,9 @@ public class DiscreteUniformSampler
* to fit in a positive integer.
*/
private static class LargeRangeDiscreteUniformSampler
- extends AbstractDiscreteUniformSampler {
-
+ extends AbstractDiscreteUniformSampler {
+ /** Lower bound. */
+ private final int lower;
/** Upper bound. */
private final int upper;
@@ -113,7 +253,8 @@ public class DiscreteUniformSampler
LargeRangeDiscreteUniformSampler(UniformRandomProvider rng,
int lower,
int upper) {
- super(rng, lower);
+ super(rng);
+ this.lower = lower;
this.upper = upper;
}
@@ -139,6 +280,44 @@ public class DiscreteUniformSampler
}
/**
+ * Adds an offset to an underlying discrete sampler.
+ */
+ private static class OffsetDiscreteUniformSampler
+ extends AbstractDiscreteUniformSampler {
+ /** The offset. */
+ private final int offset;
+ /** The discrete sampler. */
+ private final SharedStateDiscreteSampler sampler;
+
+ /**
+ * @param offset The offset for the sample.
+ * @param sampler The discrete sampler.
+ */
+ OffsetDiscreteUniformSampler(int offset,
+ SharedStateDiscreteSampler sampler) {
+ super(null);
+ this.offset = offset;
+ this.sampler = sampler;
+ }
+
+ @Override
+ public int sample() {
+ return offset + sampler.sample();
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public String toString() {
+ return sampler.toString();
+ }
+
+ @Override
+ public SharedStateDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) {
+ return new OffsetDiscreteUniformSampler(offset, sampler.withUniformRandomProvider(rng));
+ }
+ }
+
+ /**
* This instance delegates sampling. Use the factory method
* {@link #of(UniformRandomProvider, int, int)} to create an optimal sampler.
*
@@ -188,13 +367,68 @@ public class DiscreteUniformSampler
if (lower > upper) {
throw new IllegalArgumentException(lower + " > " + upper);
}
+
// Choose the algorithm depending on the range
+
+ // Edge case for no range.
+ // This must be done first as the methods to handle lower == 0
+ // do not handle upper == 0.
+ if (upper == lower) {
+ return new FixedDiscreteUniformSampler(rng, lower);
+ }
+
+ // Algorithms to ignore the lower bound if it is zero.
+ if (lower == 0) {
+ return createZeroBoundedSampler(rng, upper);
+ }
+
final int range = (upper - lower) + 1;
- return range <= 0 ?
+ // Check power of 2 first to handle range == 2^31.
+ if (isPowerOf2(range)) {
+ return new OffsetDiscreteUniformSampler(lower,
+ new PowerOf2RangeDiscreteUniformSampler(rng, range));
+ }
+ if (range <= 0) {
// The range is too wide to fit in a positive int (larger
// than 2^31); use a simple rejection method.
- new LargeRangeDiscreteUniformSampler(rng, lower, upper) :
- // Use a sample from the range added to the lower bound.
- new SmallRangeDiscreteUniformSampler(rng, lower, range);
+ // Note: if range == 0 then the input is [Integer.MIN_VALUE, Integer.MAX_VALUE].
+ // No specialisation exists for this and it is handled as a large range.
+ return new LargeRangeDiscreteUniformSampler(rng, lower, upper);
+ }
+ // Use a sample from the range added to the lower bound.
+ return new OffsetDiscreteUniformSampler(lower,
+ new SmallRangeDiscreteUniformSampler(rng, range));
+ }
+
+ /**
+ * Create a new sampler for the range {@code 0} inclusive to {@code upper} inclusive.
+ *
+ * <p>This can handle any positive {@code upper}.</p>
+ *
+ * @param rng Generator of uniformly distributed random numbers.
+ * @param upper Upper bound (inclusive) of the distribution. Must be positive.
+ * @return the sampler
+ */
+ private static AbstractDiscreteUniformSampler createZeroBoundedSampler(UniformRandomProvider rng,
+ int upper) {
+ // Note: Handle any range up to 2^31 (which is negative as a signed
+ // 32-bit integer but handled as a power of 2)
+ final int range = upper + 1;
+ return isPowerOf2(range) ?
+ new PowerOf2RangeDiscreteUniformSampler(rng, range) :
+ new SmallRangeDiscreteUniformSampler(rng, range);
+ }
+
+ /**
+ * Checks if the value is a power of 2.
+ *
+ * <p>This returns {@code true} for the value {@code Integer.MIN_VALUE} which can be
+ * handled as an unsigned integer of 2^31.</p>
+ *
+ * @param value Value.
+ * @return {@code true} if a power of 2
+ */
+ private static boolean isPowerOf2(final int value) {
+ return value != 0 && (value & (value - 1)) == 0;
}
}
diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java
index dd57177..7df26a0 100644
--- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java
+++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java
@@ -91,7 +91,7 @@ public final class DiscreteSamplersList {
add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(unusedRng, loUniform, hiUniform),
MathArrays.sequence(8, -3, 1),
RandomSource.create(RandomSource.SPLIT_MIX_64));
- // Uniform.
+ // Uniform (power of 2 range).
add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(unusedRng, loUniform, hiUniform),
MathArrays.sequence(8, -3, 1),
DiscreteUniformSampler.of(RandomSource.create(RandomSource.MT_64), loUniform, hiUniform));
@@ -102,6 +102,12 @@ public final class DiscreteSamplersList {
add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(unusedRng, loLargeUniform, hiLargeUniform),
MathArrays.sequence(20, -halfMax, halfMax / 10),
DiscreteUniformSampler.of(RandomSource.create(RandomSource.WELL_1024_A), loLargeUniform, hiLargeUniform));
+ // Uniform (non-power of 2 range).
+ final int rangeNonPowerOf2Uniform = 11;
+ final int hiNonPowerOf2Uniform = loUniform + rangeNonPowerOf2Uniform;
+ add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(unusedRng, loUniform, hiNonPowerOf2Uniform),
+ MathArrays.sequence(rangeNonPowerOf2Uniform, -3, 1),
+ DiscreteUniformSampler.of(RandomSource.create(RandomSource.XO_SHI_RO_256_SS), loUniform, hiNonPowerOf2Uniform));
// Zipf ("inverse method").
final int numElementsZipf = 5;
diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSamplerTest.java
index 0fd662a..747bfc5 100644
--- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSamplerTest.java
+++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSamplerTest.java
@@ -17,13 +17,17 @@
package org.apache.commons.rng.sampling.distribution;
import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.core.source32.IntProvider;
import org.apache.commons.rng.sampling.RandomAssert;
import org.apache.commons.rng.simple.RandomSource;
import org.junit.Assert;
import org.junit.Test;
+import java.util.Locale;
+
/**
- * Test for the {@link DiscreteUniformSampler}. The tests hit edge cases for the sampler.
+ * Test for the {@link DiscreteUniformSampler}. The tests hit edge cases for the sampler
+ * and demonstrate uniformity of output when the underlying RNG output is uniform.
*/
public class DiscreteUniformSamplerTest {
/**
@@ -37,6 +41,224 @@ public class DiscreteUniformSamplerTest {
DiscreteUniformSampler.of(rng, lower, upper);
}
+ @Test
+ public void testSamplesWithRangeOf1() {
+ final int upper = 99;
+ final int lower = upper;
+ final UniformRandomProvider rng = RandomSource.create(RandomSource.SPLIT_MIX_64);
+ final SharedStateDiscreteSampler sampler = DiscreteUniformSampler.of(rng, lower, upper);
+ for (int i = 0; i < 5; i++) {
+ Assert.assertEquals(lower, sampler.sample());
+ }
+ }
+
+ /**
+ * Test samples with a full integer range.
+ * The output should be the same as the int values produced from a RNG.
+ */
+ @Test
+ public void testSamplesWithFullRange() {
+ final int upper = Integer.MAX_VALUE;
+ final int lower = Integer.MIN_VALUE;
+ final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
+ final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
+ final SharedStateDiscreteSampler sampler = DiscreteUniformSampler.of(rng2, lower, upper);
+ for (int i = 0; i < 5; i++) {
+ Assert.assertEquals(rng1.nextInt(), sampler.sample());
+ }
+ }
+
+ @Test
+ public void testSamplesWithPowerOf2Range() {
+ final UniformRandomProvider rngZeroBits = new IntProvider() {
+ @Override
+ public int next() {
+ return 0;
+ }
+ };
+ final UniformRandomProvider rngAllBits = new IntProvider() {
+ @Override
+ public int next() {
+ return 0xffffffff;
+ }
+ };
+
+ final int lower = -3;
+ DiscreteUniformSampler sampler;
+ // The upper range for a positive integer is 2^31-1. So the max positive power of
+ // 2 is 2^30. However the sampler should handle a bit shift of 31 to create a range
+ // of Integer.MIN_VALUE (0x80000000) as this is a power of 2 as an unsigned int (2^31).
+ for (int i = 0; i < 32; i++) {
+ final int range = 1 << i;
+ final int upper = lower + range - 1;
+ sampler = new DiscreteUniformSampler(rngZeroBits, lower, upper);
+ Assert.assertEquals("Zero bits sample", lower, sampler.sample());
+ sampler = new DiscreteUniformSampler(rngAllBits, lower, upper);
+ Assert.assertEquals("All bits sample", upper, sampler.sample());
+ }
+ }
+
+ @Test
+ public void testOffsetSamplesWithNonPowerOf2Range() {
+ assertOffsetSamples(257);
+ }
+
+ @Test
+ public void testOffsetSamplesWithPowerOf2Range() {
+ assertOffsetSamples(256);
+ }
+
+ @Test
+ public void testOffsetSamplesWithRangeOf1() {
+ assertOffsetSamples(1);
+ }
+
+ private static void assertOffsetSamples(int range) {
+ final Long seed = RandomSource.createLong();
+ final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, seed);
+ final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, seed);
+ final UniformRandomProvider rng3 = RandomSource.create(RandomSource.SPLIT_MIX_64, seed);
+
+ // Since the upper limit is inclusive
+ range = range - 1;
+ final int offsetLo = -13;
+ final int offsetHi = 42;
+ final SharedStateDiscreteSampler sampler = DiscreteUniformSampler.of(rng1, 0, range);
+ final SharedStateDiscreteSampler samplerLo = DiscreteUniformSampler.of(rng2, offsetLo, offsetLo + range);
+ final SharedStateDiscreteSampler samplerHi = DiscreteUniformSampler.of(rng3, offsetHi, offsetHi + range);
+ for (int i = 0; i < 10; i++) {
+ final int sample1 = sampler.sample();
+ final int sample2 = samplerLo.sample();
+ final int sample3 = samplerHi.sample();
+ Assert.assertEquals("Incorrect negative offset sample", sample1 + offsetLo, sample2);
+ Assert.assertEquals("Incorrect positive offset sample", sample1 + offsetHi, sample3);
+ }
+ }
+
+ /**
+ * Test the sample uniformity when using a small range that is not a power of 2.
+ */
+ @Test
+ public void testSampleUniformityWithNonPowerOf2Range() {
+ // Test using a RNG that outputs an evenly spaced set of integers.
+ // Create a Weyl sequence using George Marsaglia’s increment from:
+ // Marsaglia, G (July 2003). "Xorshift RNGs". Journal of Statistical Software. 8 (14).
+ // https://en.wikipedia.org/wiki/Weyl_sequence
+ final UniformRandomProvider rng = new IntProvider() {
+ private final int increment = 362437;
+ // Start one increment below Integer.MIN_VALUE (wraps to a large positive value) so the first output is Integer.MIN_VALUE
+ private final int start = Integer.MIN_VALUE - increment;
+
+ private int bits = start;
+
+ @Override
+ public int next() {
+ // Output until the first wrap. The entire sequence will be uniform.
+ // Note this is not the full period of the sequence.
+ // Expect ((1L << 32) / increment) numbers = 11850
+ int result = bits += increment;
+ if (result < start) {
+ return result;
+ }
+ throw new IllegalStateException("end of sequence");
+ }
+ };
+
+ // n = upper range exclusive
+ final int n = 37; // prime
+ final int[] histogram = new int[n];
+
+ final int lower = 0;
+ final int upper = n - 1;
+
+ final SharedStateDiscreteSampler sampler = DiscreteUniformSampler.of(rng, lower, upper);
+
+ try {
+ while (true) {
+ histogram[sampler.sample()]++;
+ }
+ } catch (IllegalStateException ex) {
+ // Expected end of sequence
+ }
+
+ // The sequence will result in either x or (x+1) samples in each bin (i.e. uniform).
+ int min = histogram[0];
+ int max = histogram[0];
+ for (int value : histogram) {
+ min = Math.min(min, value);
+ max = Math.max(max, value);
+ }
+ Assert.assertTrue("Not uniform, max = " + max + ", min=" + min, max - min <= 1);
+ }
+
+ /**
+ * Test the sample uniformity when using a small range that is a power of 2.
+ */
+ @Test
+ public void testSampleUniformityWithPowerOf2Range() {
+ // Test using a RNG that outputs a counter of integers.
+ final UniformRandomProvider rng = new IntProvider() {
+ private int bits = 0;
+
+ @Override
+ public int next() {
+ // We reverse the bits because the most significant bits are used
+ return Integer.reverse(bits++);
+ }
+ };
+
+ // n = upper range exclusive
+ final int n = 32; // power of 2
+ final int[] histogram = new int[n];
+
+ final int lower = 0;
+ final int upper = n - 1;
+
+ final SharedStateDiscreteSampler sampler = DiscreteUniformSampler.of(rng, lower, upper);
+
+ final int expected = 2;
+ for (int i = expected * n; i-- > 0;) {
+ histogram[sampler.sample()]++;
+ }
+
+ // This should be even across the entire range
+ for (int value : histogram) {
+ Assert.assertEquals(expected, value);
+ }
+ }
+
+ /**
+ * Test the sample rejection when using a range that is not a power of 2. The rejection
+ * algorithm of Lemire (2019) splits the entire 32-bit range into intervals of size 2^32/n. It
+ * will reject the lowest value in each interval that is over sampled. This test uses 0
+ * as the first value from the RNG and tests it is rejected.
+ */
+ @Test
+ public void testSampleRejectionWithNonPowerOf2Range() {
+ // Test using a RNG that returns a sequence.
+ // The first value of zero should produce a sample that is rejected.
+ final int[] value = new int[1];
+ final UniformRandomProvider rng = new IntProvider() {
+ @Override
+ public int next() {
+ return value[0]++;
+ }
+ };
+
+ // n = upper range exclusive.
+ // Use a prime number to select the rejection algorithm.
+ final int n = 37;
+ final int lower = 0;
+ final int upper = n - 1;
+
+ final SharedStateDiscreteSampler sampler = DiscreteUniformSampler.of(rng, lower, upper);
+
+ final int sample = sampler.sample();
+
+ Assert.assertEquals("Sample is incorrect", 0, sample);
+ Assert.assertEquals("Sample should be produced from 2nd value", 2, value[0]);
+ }
+
/**
* Test the SharedStateSampler implementation.
*/
@@ -50,7 +272,24 @@ public class DiscreteUniformSamplerTest {
*/
@Test
public void testSharedStateSamplerWithLargeRange() {
- testSharedStateSampler(-99999999, Integer.MAX_VALUE);
+ // Set the range so rejection below or above the threshold occurs with approximately p=0.25
+ testSharedStateSampler(Integer.MIN_VALUE / 2 - 1, Integer.MAX_VALUE / 2 + 1);
+ }
+
+ /**
+ * Test the SharedStateSampler implementation.
+ */
+ @Test
+ public void testSharedStateSamplerWithPowerOf2Range() {
+ testSharedStateSampler(0, 31);
+ }
+
+ /**
+ * Test the SharedStateSampler implementation.
+ */
+ @Test
+ public void testSharedStateSamplerWithRangeOf1() {
+ testSharedStateSampler(9, 9);
}
/**
@@ -69,13 +308,38 @@ public class DiscreteUniformSamplerTest {
RandomAssert.assertProduceSameSequence(sampler1, sampler2);
}
+ @Test
+ public void testToStringWithSmallRange() {
+ assertToString(5, 67);
+ }
+
+ @Test
+ public void testToStringWithLargeRange() {
+ assertToString(-99999999, Integer.MAX_VALUE);
+ }
+
+ @Test
+ public void testToStringWithPowerOf2Range() {
+ // Note the range is upper - lower + 1
+ assertToString(0, 31);
+ }
+
+ @Test
+ public void testToStringWithRangeOf1() {
+ assertToString(9, 9);
+ }
+
/**
* Test the toString method. This is added to ensure coverage as the factory constructor
* used in other tests does not create an instance of the wrapper class.
+ *
+ * @param lower Lower.
+ * @param upper Upper.
*/
- @Test
- public void testToString() {
+ private static void assertToString(int lower, int upper) {
final UniformRandomProvider rng = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L);
- Assert.assertTrue(new DiscreteUniformSampler(rng, 1, 2).toString().toLowerCase().contains("uniform"));
+ final DiscreteUniformSampler sampler =
+ new DiscreteUniformSampler(rng, lower, upper);
+ Assert.assertTrue(sampler.toString().toLowerCase(Locale.US).contains("uniform"));
}
}