You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by ah...@apache.org on 2019/05/30 19:47:08 UTC
[commons-rng] 01/07: RNG-101: Add MarsagliaTsangWang discrete
probability sampler.
This is an automated email from the ASF dual-hosted git repository.
aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-rng.git
commit 00b7cc08f706e9d4fd26c3de6ecdd8c926984538
Author: aherbert <ah...@apache.org>
AuthorDate: Tue May 7 17:24:02 2019 +0100
RNG-101: Add MarsagliaTsangWang discrete probability sampler.
This adds support for a generic distribution defined by an array of
probabilities and also a Poisson and Binomial distribution.
---
.../MarsagliaTsangWangBinomialSampler.java | 257 ++++++++++
.../MarsagliaTsangWangDiscreteSampler.java | 540 +++++++++++++++++++++
.../MarsagliaTsangWangSmallMeanPoissonSampler.java | 218 +++++++++
.../distribution/DiscreteSamplersList.java | 27 ++
.../MarsagliaTsangWangBinomialSamplerTest.java | 242 +++++++++
.../MarsagliaTsangWangDiscreteSamplerTest.java | 332 +++++++++++++
...sagliaTsangWangSmallMeanPoissonSamplerTest.java | 119 +++++
7 files changed, 1735 insertions(+)
diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangBinomialSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangBinomialSampler.java
new file mode 100644
index 0000000..5b13155
--- /dev/null
+++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangBinomialSampler.java
@@ -0,0 +1,257 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.rng.sampling.distribution;
+
+import org.apache.commons.rng.UniformRandomProvider;
+
+/**
+ * Sampler for the <a href="https://en.wikipedia.org/wiki/Binomial_distribution">Binomial
+ * distribution</a> using an optimised look-up table.
+ *
+ * <ul>
+ * <li>
+ * A Binomial process is simulated using pre-tabulated probabilities, as
+ * described in George Marsaglia, Wai Wan Tsang, Jingbo Wang (2004) Fast Generation of
+ * Discrete Random Variables. Journal of Statistical Software. Vol. 11, Issue. 3, pp. 1-11.
+ * </li>
+ * </ul>
+ *
+ * <p>The sampler will fail on construction if the distribution cannot be computed. This
+ * occurs when {@code trials} is large and probability of success is close to {@code 0.5}.
+ * The exact failure condition is:</p>
+ *
+ * <pre>
+ * {@code Math.exp(trials * Math.log(Math.max(p, 1 - p))) < Double.MIN_VALUE}
+ * </pre>
+ *
+ * <p>In this case the distribution can be approximated using a limiting distribution,
+ * either a Poisson or a Normal distribution as appropriate.</p>
+ *
+ * <p>Note: The algorithm ignores any observation where for a sample size of
+ * 2<sup>31</sup> the expected number of occurrences is {@code < 0.5}.</p>
+ *
+ * <p>Sampling uses 1 call to {@link UniformRandomProvider#nextInt()}. Storage
+ * requirements depend on the probabilities and are capped at 2<sup>17</sup> bytes, or 131
+ * kB.</p>
+ *
+ * @see <a href="http://dx.doi.org/10.18637/jss.v011.i03">Marsaglia, et al (2004) JSS Vol.
+ * 11, Issue 3</a>
+ * @since 1.3
+ */
+public class MarsagliaTsangWangBinomialSampler implements DiscreteSampler {
+ /**
+ * The value 2<sup>30</sup> as an {@code int}: the denominator of the 30-bit integer probabilities.
+ */
+ private static final int INT_30 = 1 << 30;
+ /**
+ * The value 2<sup>16</sup> as an {@code int}: the exclusive upper limit for the number of trials.
+ */
+ private static final int INT_16 = 1 << 16;
+ /**
+ * The value 2<sup>31</sup> as a {@code double}: probabilities below 2<sup>-31</sup> are not tabulated.
+ */
+ private static final double DOUBLE_31 = 1L << 31;
+
+ /** The sampler selected at construction; every sample is obtained from it. */
+ private final DiscreteSampler delegate;
+
+    /**
+     * A sampler for a Binomial distribution that has only one possible outcome.
+     */
+    private static class FixedResultDiscreteSampler implements DiscreteSampler {
+        /** The only value that is ever returned. */
+        private final int fixedValue;
+
+        /**
+         * @param result Value returned by every call to {@link #sample()}.
+         */
+        FixedResultDiscreteSampler(int result) {
+            fixedValue = result;
+        }
+
+        @Override
+        public int sample() {
+            return fixedValue;
+        }
+
+        @Override
+        public String toString() {
+            return "Binomial deviate";
+        }
+    }
+
+ /**
+ * Return an inversion result for the Binomial distribution. This assumes the
+ * following, where {@code n} is the number of trials:
+ *
+ * <pre>
+ * Binomial(n, p) = n - Binomial(n, 1 - p)
+ * </pre>
+ */
+ private static class InversionBinomialDiscreteSampler implements DiscreteSampler {
+ /** The number of trials ({@code n}); each wrapped sample is subtracted from it. */
+ private final int trials;
+ /** The wrapped Binomial distribution sampler whose samples are inverted. */
+ private final DiscreteSampler sampler;
+
+ /**
+ * @param trials Number of trials.
+ * @param sampler Binomial distribution sampler to invert.
+ */
+ InversionBinomialDiscreteSampler(int trials, DiscreteSampler sampler) {
+ this.trials = trials;
+ this.sampler = sampler;
+ }
+
+ @Override
+ public int sample() {
+ return trials - sampler.sample();
+ }
+
+ @Override
+ public String toString() {
+ return sampler.toString();
+ }
+ }
+
+    /**
+     * Create a new instance.
+     *
+     * @param rng Generator of uniformly distributed random numbers.
+     * @param trials Number of trials.
+     * @param p Probability of success.
+     * @throws IllegalArgumentException if {@code trials < 0} or {@code trials >= 2^16},
+     * {@code p} is {@code NaN} or not in the range {@code [0-1]}, or the probability
+     * distribution cannot be computed.
+     */
+    public MarsagliaTsangWangBinomialSampler(UniformRandomProvider rng, int trials, double p) {
+        if (trials < 0) {
+            throw new IllegalArgumentException("Trials is not positive: " + trials);
+        }
+        if (!(p >= 0 && p <= 1)) {
+            throw new IllegalArgumentException("Probability is not in range [0,1]: " + p);
+        }
+
+        // Handle edge cases with a single outcome (zero trials can only produce zero)
+        if (p == 0 || trials == 0) {
+            delegate = new FixedResultDiscreteSampler(0);
+            return;
+        }
+        if (p == 1) {
+            delegate = new FixedResultDiscreteSampler(trials);
+            return;
+        }
+
+        // A simple check using the supported index size.
+        if (trials >= INT_16) {
+            throw new IllegalArgumentException("Unsupported number of trials: " + trials);
+        }
+
+        // Math.exp underflows to zero when its argument is below approximately -744.
+        // This occurs when trials is large and p is close to 1.
+        // Handle this by using an inversion: generate j=Binomial(n,1-p), return n-j
+        final boolean inversion = p > 0.5;
+        if (inversion) {
+            p = 1 - p;
+        }
+
+        // Check if the distribution can be computed
+        final double p0 = Math.exp(trials * Math.log(1 - p));
+        if (p0 < Double.MIN_VALUE) {
+            throw new IllegalArgumentException("Unable to compute distribution");
+        }
+
+        // First find size of probability array
+        double t = p0;
+        final double h = p / (1 - p);
+        // Find first probability
+        int begin = 0;
+        if (t * DOUBLE_31 < 1) {
+            // Somewhere after p(0)
+            // Note:
+            // If this loop is entered p(0) is < 2^-31.
+            // This has been tested at the extreme for p(0)=Double.MIN_VALUE and either
+            // p=0.5 or trials=2^16-1 and does not fail to find the beginning.
+            for (int i = 1; i <= trials; i++) {
+                t *= (trials + 1 - i) * h / i;
+                if (t * DOUBLE_31 >= 1) {
+                    begin = i;
+                    break;
+                }
+            }
+        }
+        // Find last probability
+        int end = trials;
+        for (int i = begin + 1; i <= trials; i++) {
+            t *= (trials + 1 - i) * h / i;
+            if (t * DOUBLE_31 < 1) {
+                end = i - 1;
+                break;
+            }
+        }
+        final int size = end - begin + 1;
+        final int offset = begin;
+
+        // Then assign probability values as 30-bit integers
+        final int[] prob = new int[size];
+        t = p0;
+        for (int i = 1; i <= begin; i++) {
+            t *= (trials + 1 - i) * h / i;
+        }
+        int sum = toUnsignedInt30(t);
+        prob[0] = sum;
+        for (int i = begin + 1; i <= end; i++) {
+            t *= (trials + 1 - i) * h / i;
+            prob[i - begin] = toUnsignedInt30(t);
+            sum += prob[i - begin];
+        }
+
+        // If the sum is < 2^30 add the remaining sum to the mode (floor((n+1)p))).
+        final int mode = (int) ((trials + 1) * p) - offset;
+        prob[mode] += Math.max(0, INT_30 - sum);
+
+        final MarsagliaTsangWangDiscreteSampler sampler = new MarsagliaTsangWangDiscreteSampler(rng, prob, offset);
+
+        if (inversion) {
+            delegate = new InversionBinomialDiscreteSampler(trials, sampler);
+        } else {
+            delegate = sampler;
+        }
+    }
+
+    /**
+     * Scale the probability by 2<sup>30</sup> and round to the nearest integer.
+     *
+     * @param p the probability
+     * @return the integer (in the range [0,2^30] when {@code p} is in the range [0,1])
+     */
+    private static int toUnsignedInt30(double p) {
+        return (int) (0.5 + p * INT_30);
+    }
+
+ /** {@inheritDoc} <p>The value is obtained from the delegate selected at construction.</p> */
+ @Override
+ public int sample() {
+ return delegate.sample();
+ }
+
+    /** {@inheritDoc} <p>The description of the delegate is prefixed with the distribution name.</p> */
+    @Override
+    public String toString() {
+        return "Binomial " + delegate;
+    }
+}
diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSampler.java
new file mode 100644
index 0000000..a1fc5a7
--- /dev/null
+++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSampler.java
@@ -0,0 +1,540 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.rng.sampling.distribution;
+
+import org.apache.commons.rng.UniformRandomProvider;
+
+/**
+ * Sampler for a discrete distribution using an optimised look-up table.
+ *
+ * <ul>
+ * <li>
+ * The method requires 30-bit integer probabilities that sum to 2<sup>30</sup> as described
+ * in George Marsaglia, Wai Wan Tsang, Jingbo Wang (2004) Fast Generation of Discrete
+ * Random Variables. Journal of Statistical Software. Vol. 11, Issue. 3, pp. 1-11.
+ * </li>
+ * </ul>
+ *
+ * <p>Sampling uses 1 call to {@link UniformRandomProvider#nextInt()}.</p>
+ *
+ * <p>Memory requirements depend on the maximum number of possible sample values, {@code n},
+ * and the values for the probabilities. Storage is optimised for {@code n}. The worst case
+ * scenario is a uniform distribution of the maximum sample size. This is capped at 0.06MB for
+ * {@code n <= } 2<sup>8</sup>, 17.0MB for {@code n <= } 2<sup>16</sup>, and 4.3GB for
+ * {@code n <=} 2<sup>30</sup>. Realistic requirements will be in the kB range.</p>
+ *
+ * @since 1.3
+ * @see <a href="http://dx.doi.org/10.18637/jss.v011.i03">Marsaglia, et al (2004) JSS Vol.
+ * 11, Issue 3</a>
+ */
+public class MarsagliaTsangWangDiscreteSampler implements DiscreteSampler {
+ /** The exclusive upper bound for an unsigned 8-bit integer; smaller sample values are stored as bytes. */
+ private static final int UNSIGNED_INT_8 = 1 << 8;
+ /** The exclusive upper bound for an unsigned 16-bit integer; smaller sample values are stored as shorts. */
+ private static final int UNSIGNED_INT_16 = 1 << 16;
+
+ /** Exclusive upper limit of the 30-bit random index that is resolved using look-up table 1. */
+ private final int t1;
+ /** Exclusive upper limit of the 30-bit random index that is resolved using look-up table 2. */
+ private final int t2;
+ /** Exclusive upper limit of the 30-bit random index that is resolved using look-up table 3. */
+ private final int t3;
+ /** Exclusive upper limit of the 30-bit random index that is resolved using look-up table 4. */
+ private final int t4;
+
+ /** Index look-up table containing the sample values. */
+ private final IndexTable indexTable;
+
+ /** Underlying source of randomness. */
+ private final UniformRandomProvider rng;
+
+ /**
+ * An index table contains the sample values. This is efficiently accessed for any index in the
+ * range {@code [0,2^30)} by using an algorithm based on the decomposition of the index into
+ * 5 base-64 digits.
+ *
+ * <p>This interface defines the methods for filling and accessing the values of 5 tables.
+ * It allows a concrete implementation to allocate appropriate tables to optimise memory
+ * requirements.</p>
+ */
+ private interface IndexTable {
+ /**
+ * @param from Lower bound index of table 1 (inclusive).
+ * @param to Upper bound index of table 1 (exclusive).
+ * @param value Sample value to store at each index in the range.
+ */
+ void fillTable1(int from, int to, int value);
+ /**
+ * @param from Lower bound index of table 2 (inclusive).
+ * @param to Upper bound index of table 2 (exclusive).
+ * @param value Sample value to store at each index in the range.
+ */
+ void fillTable2(int from, int to, int value);
+ /**
+ * @param from Lower bound index of table 3 (inclusive).
+ * @param to Upper bound index of table 3 (exclusive).
+ * @param value Sample value to store at each index in the range.
+ */
+ void fillTable3(int from, int to, int value);
+ /**
+ * @param from Lower bound index of table 4 (inclusive).
+ * @param to Upper bound index of table 4 (exclusive).
+ * @param value Sample value to store at each index in the range.
+ */
+ void fillTable4(int from, int to, int value);
+ /**
+ * @param from Lower bound index of table 5 (inclusive).
+ * @param to Upper bound index of table 5 (exclusive).
+ * @param value Sample value to store at each index in the range.
+ */
+ void fillTable5(int from, int to, int value);
+
+ /**
+ * @param index Index into table 1.
+ * @return Sample value stored at the index.
+ */
+ int getTable1(int index);
+ /**
+ * @param index Index into table 2.
+ * @return Sample value stored at the index.
+ */
+ int getTable2(int index);
+ /**
+ * @param index Index into table 3.
+ * @return Sample value stored at the index.
+ */
+ int getTable3(int index);
+ /**
+ * @param index Index into table 4.
+ * @return Sample value stored at the index.
+ */
+ int getTable4(int index);
+ /**
+ * @param index Index into table 5.
+ * @return Sample value stored at the index.
+ */
+ int getTable5(int index);
+ }
+
+    /**
+     * Index table that stores each sample value in a {@code byte} (8-bit index).
+     */
+    private static class IndexTable8 implements IndexTable {
+        /** The mask to read a stored {@code byte} as an unsigned 8-bit integer. */
+        private static final int UNSIGNED_BYTE = 0xff;
+
+        /** Look-up table for the 1st base-64 digit. */
+        private final byte[] table1;
+        /** Look-up table for the 2nd base-64 digit. */
+        private final byte[] table2;
+        /** Look-up table for the 3rd base-64 digit. */
+        private final byte[] table3;
+        /** Look-up table for the 4th base-64 digit. */
+        private final byte[] table4;
+        /** Look-up table for the 5th base-64 digit. */
+        private final byte[] table5;
+
+        /**
+         * @param n1 Number of entries in table 1.
+         * @param n2 Number of entries in table 2.
+         * @param n3 Number of entries in table 3.
+         * @param n4 Number of entries in table 4.
+         * @param n5 Number of entries in table 5.
+         */
+        IndexTable8(int n1, int n2, int n3, int n4, int n5) {
+            table1 = new byte[n1];
+            table2 = new byte[n2];
+            table3 = new byte[n3];
+            table4 = new byte[n4];
+            table5 = new byte[n5];
+        }
+
+        @Override
+        public void fillTable1(int from, int to, int value) { store(table1, from, to, value); }
+        @Override
+        public void fillTable2(int from, int to, int value) { store(table2, from, to, value); }
+        @Override
+        public void fillTable3(int from, int to, int value) { store(table3, from, to, value); }
+        @Override
+        public void fillTable4(int from, int to, int value) { store(table4, from, to, value); }
+        @Override
+        public void fillTable5(int from, int to, int value) { store(table5, from, to, value); }
+
+        /**
+         * Store the value at every index of the table in the range.
+         *
+         * @param table Table.
+         * @param from Lower bound index (inclusive)
+         * @param to Upper bound index (exclusive)
+         * @param value Value.
+         */
+        private static void store(byte[] table, int from, int to, int value) {
+            for (int i = from; i < to; i++) {
+                // The narrowing primitive conversion keeps only the lower 8 bits
+                table[i] = (byte) value;
+            }
+        }
+
+        @Override
+        public int getTable1(int index) { return table1[index] & UNSIGNED_BYTE; }
+        @Override
+        public int getTable2(int index) { return table2[index] & UNSIGNED_BYTE; }
+        @Override
+        public int getTable3(int index) { return table3[index] & UNSIGNED_BYTE; }
+        @Override
+        public int getTable4(int index) { return table4[index] & UNSIGNED_BYTE; }
+        @Override
+        public int getTable5(int index) { return table5[index] & UNSIGNED_BYTE; }
+    }
+
+    /**
+     * Index table that stores each sample value in a {@code short} (16-bit index).
+     */
+    private static class IndexTable16 implements IndexTable {
+        /** The mask to read a stored {@code short} as an unsigned 16-bit integer. */
+        private static final int UNSIGNED_SHORT = 0xffff;
+
+        /** Look-up table for the 1st base-64 digit. */
+        private final short[] table1;
+        /** Look-up table for the 2nd base-64 digit. */
+        private final short[] table2;
+        /** Look-up table for the 3rd base-64 digit. */
+        private final short[] table3;
+        /** Look-up table for the 4th base-64 digit. */
+        private final short[] table4;
+        /** Look-up table for the 5th base-64 digit. */
+        private final short[] table5;
+
+        /**
+         * @param n1 Number of entries in table 1.
+         * @param n2 Number of entries in table 2.
+         * @param n3 Number of entries in table 3.
+         * @param n4 Number of entries in table 4.
+         * @param n5 Number of entries in table 5.
+         */
+        IndexTable16(int n1, int n2, int n3, int n4, int n5) {
+            table1 = new short[n1];
+            table2 = new short[n2];
+            table3 = new short[n3];
+            table4 = new short[n4];
+            table5 = new short[n5];
+        }
+
+        @Override
+        public void fillTable1(int from, int to, int value) { store(table1, from, to, value); }
+        @Override
+        public void fillTable2(int from, int to, int value) { store(table2, from, to, value); }
+        @Override
+        public void fillTable3(int from, int to, int value) { store(table3, from, to, value); }
+        @Override
+        public void fillTable4(int from, int to, int value) { store(table4, from, to, value); }
+        @Override
+        public void fillTable5(int from, int to, int value) { store(table5, from, to, value); }
+
+        /**
+         * Store the value at every index of the table in the range.
+         *
+         * @param table Table.
+         * @param from Lower bound index (inclusive)
+         * @param to Upper bound index (exclusive)
+         * @param value Value.
+         */
+        private static void store(short[] table, int from, int to, int value) {
+            for (int i = from; i < to; i++) {
+                // The narrowing primitive conversion keeps only the lower 16 bits
+                table[i] = (short) value;
+            }
+        }
+
+        @Override
+        public int getTable1(int index) { return table1[index] & UNSIGNED_SHORT; }
+        @Override
+        public int getTable2(int index) { return table2[index] & UNSIGNED_SHORT; }
+        @Override
+        public int getTable3(int index) { return table3[index] & UNSIGNED_SHORT; }
+        @Override
+        public int getTable4(int index) { return table4[index] & UNSIGNED_SHORT; }
+        @Override
+        public int getTable5(int index) { return table5[index] & UNSIGNED_SHORT; }
+    }
+
+    /**
+     * Index table that stores each sample value in an {@code int} (32-bit index).
+     */
+    private static class IndexTable32 implements IndexTable {
+        /** Look-up table for the 1st base-64 digit. */
+        private final int[] table1;
+        /** Look-up table for the 2nd base-64 digit. */
+        private final int[] table2;
+        /** Look-up table for the 3rd base-64 digit. */
+        private final int[] table3;
+        /** Look-up table for the 4th base-64 digit. */
+        private final int[] table4;
+        /** Look-up table for the 5th base-64 digit. */
+        private final int[] table5;
+
+        /**
+         * @param n1 Number of entries in table 1.
+         * @param n2 Number of entries in table 2.
+         * @param n3 Number of entries in table 3.
+         * @param n4 Number of entries in table 4.
+         * @param n5 Number of entries in table 5.
+         */
+        IndexTable32(int n1, int n2, int n3, int n4, int n5) {
+            table1 = new int[n1];
+            table2 = new int[n2];
+            table3 = new int[n3];
+            table4 = new int[n4];
+            table5 = new int[n5];
+        }
+
+        @Override
+        public void fillTable1(int from, int to, int value) { store(table1, from, to, value); }
+        @Override
+        public void fillTable2(int from, int to, int value) { store(table2, from, to, value); }
+        @Override
+        public void fillTable3(int from, int to, int value) { store(table3, from, to, value); }
+        @Override
+        public void fillTable4(int from, int to, int value) { store(table4, from, to, value); }
+        @Override
+        public void fillTable5(int from, int to, int value) { store(table5, from, to, value); }
+
+        /**
+         * Store the value at every index of the table in the range.
+         *
+         * @param table Table.
+         * @param from Lower bound index (inclusive)
+         * @param to Upper bound index (exclusive)
+         * @param value Value.
+         */
+        private static void store(int[] table, int from, int to, int value) {
+            for (int i = from; i < to; i++) {
+                table[i] = value;
+            }
+        }
+
+        @Override
+        public int getTable1(int index) { return table1[index]; }
+        @Override
+        public int getTable2(int index) { return table2[index]; }
+        @Override
+        public int getTable3(int index) { return table3[index]; }
+        @Override
+        public int getTable4(int index) { return table4[index]; }
+        @Override
+        public int getTable5(int index) { return table5[index]; }
+    }
+
+    /**
+     * Create a new instance for probabilities {@code p(i)} where the sample value {@code x} is
+     * {@code i + offset}.
+     *
+     * <p>The sum of the probabilities must be >= 2<sup>30</sup>. Only the
+     * values for cumulative probability up to 2<sup>30</sup> will be sampled.</p>
+     *
+     * <p>Note: This is package-private for use by discrete distribution samplers that can
+     * compute their probability distribution.</p>
+     *
+     * @param rng Generator of uniformly distributed random numbers.
+     * @param prob The probabilities.
+     * @param offset The offset (must not be negative).
+     * @throws IllegalArgumentException if the offset is negative or the maximum sample index
+     * exceeds the maximum positive {@code int} value (2<sup>31</sup> - 1).
+     */
+    MarsagliaTsangWangDiscreteSampler(UniformRandomProvider rng,
+                                      int[] prob,
+                                      int offset) {
+        if (offset < 0) {
+            throw new IllegalArgumentException("Unsupported offset: " + offset);
+        }
+        // Use long arithmetic (also for the message) as the int sum overflows in this case
+        if ((long) prob.length + offset > Integer.MAX_VALUE) {
+            throw new IllegalArgumentException("Unsupported sample index: " + ((long) prob.length + offset));
+        }
+        this.rng = rng;
+
+        // Get table sizes for each base-64 digit
+        int n1 = 0;
+        int n2 = 0;
+        int n3 = 0;
+        int n4 = 0;
+        int n5 = 0;
+        for (final int m : prob) {
+            n1 += getBase64Digit(m, 1);
+            n2 += getBase64Digit(m, 2);
+            n3 += getBase64Digit(m, 3);
+            n4 += getBase64Digit(m, 4);
+            n5 += getBase64Digit(m, 5);
+        }
+
+        // Allocate tables based on the maximum index
+        final int maxIndex = prob.length + offset - 1;
+        if (maxIndex < UNSIGNED_INT_8) {
+            indexTable = new IndexTable8(n1, n2, n3, n4, n5);
+        } else if (maxIndex < UNSIGNED_INT_16) {
+            indexTable = new IndexTable16(n1, n2, n3, n4, n5);
+        } else {
+            indexTable = new IndexTable32(n1, n2, n3, n4, n5);
+        }
+
+        // Compute offsets
+        t1 = n1 << 24;
+        t2 = t1 + (n2 << 18);
+        t3 = t2 + (n3 << 12);
+        t4 = t3 + (n4 << 6);
+        n1 = n2 = n3 = n4 = n5 = 0;
+
+        // Fill tables
+        for (int i = 0; i < prob.length; i++) {
+            final int m = prob[i];
+            final int k = i + offset;
+            indexTable.fillTable1(n1, n1 += getBase64Digit(m, 1), k);
+            indexTable.fillTable2(n2, n2 += getBase64Digit(m, 2), k);
+            indexTable.fillTable3(n3, n3 += getBase64Digit(m, 3), k);
+            indexTable.fillTable4(n4, n4 += getBase64Digit(m, 4), k);
+            indexTable.fillTable5(n5, n5 += getBase64Digit(m, 5), k);
+        }
+    }
+
+ /**
+ * Creates a sampler where the sample value is the index of the probability in the array.
+ *
+ * <p>The probabilities will be normalised using their sum. The only requirement is the sum
+ * is finite and strictly positive.</p>
+ *
+ * <p>The sum of the probabilities is normalised to 2<sup>30</sup>. Any normalised probability
+ * less than 2<sup>-31</sup> rounds to zero and will not be observed in samples. An adjustment
+ * is made to the maximum probability to compensate for round-off during conversion.</p>
+ *
+ * @param rng Generator of uniformly distributed random numbers.
+ * @param probabilities The list of probabilities.
+ * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
+ * probability is negative, infinite or {@code NaN}, or the sum of all
+ * probabilities is not strictly positive.
+ */
+ public MarsagliaTsangWangDiscreteSampler(UniformRandomProvider rng,
+ double[] probabilities) {
+ this(rng, normaliseProbabilities(probabilities), 0);
+ }
+
+    /**
+     * Normalise the probabilities to non-negative integers that sum to 2<sup>30</sup>.
+     * Round-off is compensated on the mode, or spread over the values if it is too large.
+     * @param probabilities The list of probabilities.
+     * @return the normalised probabilities.
+     * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
+     * probability is negative, infinite or {@code NaN}, or the sum of all
+     * probabilities is not strictly positive.
+     */
+    private static int[] normaliseProbabilities(double[] probabilities) {
+        // Compute the normalisation: 2^30 / sum
+        final double normalisation = (1 << 30) / validateProbabilities(probabilities);
+        final int[] prob = new int[probabilities.length];
+        int sum = 0;
+        int mode = 0;
+        for (int i = 0; i < prob.length; i++) {
+            // Add 0.5 for rounding; track the mode (first maximum probability)
+            prob[i] = (int) (probabilities[i] * normalisation + 0.5);
+            sum += prob[i];
+            mode = prob[mode] < prob[i] ? i : mode;
+        }
+        // The sum must be 2^30: compensate the round-off difference on the mode.
+        int surplus = sum - (1 << 30);
+        if (surplus <= prob[mode]) {
+            prob[mode] -= surplus;
+            return prob;
+        }
+        // The surplus exceeds the mode (many near-equal values). Remove it one unit at a
+        // time from the non-zero values so that no probability becomes negative.
+        for (int i = 0; surplus > 0; i = (i + 1) % prob.length) {
+            final int unit = prob[i] > 0 ? 1 : 0;
+            prob[i] -= unit;
+            surplus -= unit;
+        }
+        return prob;
+    }
+
+    /**
+     * Check the probabilities are usable and compute their sum.
+     *
+     * @param probabilities the probabilities
+     * @return the sum (finite and strictly positive)
+     * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
+     * probability is negative, infinite or {@code NaN}, or the sum of all
+     * probabilities is not strictly positive.
+     */
+    private static double validateProbabilities(double[] probabilities) {
+        final boolean empty = probabilities == null || probabilities.length == 0;
+        if (empty) {
+            throw new IllegalArgumentException("Probabilities must not be empty.");
+        }
+        double total = 0;
+        for (int i = 0; i < probabilities.length; i++) {
+            final double value = probabilities[i];
+            // Accept only finite values that are not negative (the comparison rejects NaN)
+            final boolean valid = value >= 0 && !Double.isInfinite(value);
+            if (!valid) {
+                throw new IllegalArgumentException("Invalid probability: " + value);
+            }
+            total += value;
+        }
+        // The total can overflow to infinity even when every value is finite
+        if (total <= 0 || Double.isInfinite(total)) {
+            throw new IllegalArgumentException("Invalid sum of probabilities: " + total);
+        }
+        return total;
+    }
+
+    /**
+     * Gets the k<sup>th</sup> base 64 digit of {@code m}. The leading digit ({@code k = 1})
+     * is not masked so that {@code m = 2^30} (a probability of 1) yields 64 instead of 0.
+     *
+     * @param m the value m (in the range [0, 2^30]).
+     * @param k the digit (in the range [1, 5]).
+     * @return the base 64 digit
+     */
+    private static int getBase64Digit(int m, int k) {
+        return k == 1 ? m >>> 24 : (m >>> (30 - 6 * k)) & 63;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public int sample() {
+        // Discard the lowest 2 bits to create a 30-bit index in the range [0, 2^30)
+        final int index = rng.nextInt() >>> 2;
+        if (index < t1) {
+            return indexTable.getTable1(index >>> 24);
+        } else if (index < t2) {
+            return indexTable.getTable2((index - t1) >>> 18);
+        } else if (index < t3) {
+            return indexTable.getTable3((index - t2) >>> 12);
+        } else if (index < t4) {
+            return indexTable.getTable4((index - t3) >>> 6);
+        }
+        // All remaining indices are resolved using the final table.
+        // Note the tables are filled on the assumption that the sum of the
+        // probabilities is >=2^30. If this is not true then the final table
+        // table5 will be smaller by the difference. So the tables *must* be
+        // constructed correctly.
+        return indexTable.getTable5(index - t4);
+    }
+
+    /** {@inheritDoc} <p>The description includes the underlying source of randomness.</p> */
+    @Override
+    public String toString() {
+        return String.format("Marsaglia Tsang Wang discrete deviate [%s]", rng.toString());
+    }
+}
diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangSmallMeanPoissonSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangSmallMeanPoissonSampler.java
new file mode 100644
index 0000000..4ac66e8
--- /dev/null
+++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangSmallMeanPoissonSampler.java
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.rng.sampling.distribution;
+
+import org.apache.commons.rng.UniformRandomProvider;
+
+/**
+ * Sampler for the <a href="http://mathworld.wolfram.com/PoissonDistribution.html">Poisson
+ * distribution</a> using an optimised look-up table.
+ *
+ * <ul>
+ * <li>
+ * A Poisson process is simulated using pre-tabulated probabilities, as described
+ * in George Marsaglia, Wai Wan Tsang, Jingbo Wang (2004) Fast Generation of Discrete
+ * Random Variables. Journal of Statistical Software. Vol. 11, Issue. 3, pp. 1-11.
+ * </li>
+ * </ul>
+ *
+ * <p>This sampler is suitable for {@code mean <= 1024}. Larger means accumulate errors
+ * when tabulating the Poisson probability. For large means, {@link LargeMeanPoissonSampler}
+ * should be used instead.</p>
+ *
+ * <p>Note: The algorithm ignores any observation where for a sample size of
+ * 2<sup>31</sup> the expected number of occurrences is {@code < 0.5}.</p>
+ *
+ * <p>Sampling uses 1 call to {@link UniformRandomProvider#nextInt()}. Storage requirements
+ * depend on the tabulated probability values. Example storage requirements are listed below.</p>
+ *
+ * <pre>
+ * mean table size kB
+ * 0.25 882 0.88
+ * 0.5 1135 1.14
+ * 1 1200 1.20
+ * 2 1451 1.45
+ * 4 1955 1.96
+ * 8 2961 2.96
+ * 16 4410 4.41
+ * 32 6115 6.11
+ * 64 8499 8.50
+ * 128 11528 11.53
+ * 256 15935 31.87
+ * 512 20912 41.82
+ * 1024 30614 61.23
+ * </pre>
+ *
+ * <p>Note: Storage changes to 2 bytes per index when {@code mean=256}.</p>
+ *
+ * @since 1.3
+ * @see <a href="http://dx.doi.org/10.18637/jss.v011.i03">Marsaglia, et al (2004) JSS Vol.
+ * 11, Issue 3</a>
+ */
+public class MarsagliaTsangWangSmallMeanPoissonSampler implements DiscreteSampler {
+ /**
+ * The value 2<sup>30</sup> as an {@code int}: the denominator of the 30-bit integer probabilities.
+ */
+ private static final int INT_30 = 1 << 30;
+ /**
+ * The value 2<sup>31</sup> as a {@code double}: probabilities below 2<sup>-31</sup> are not tabulated.
+ */
+ private static final double DOUBLE_31 = 1L << 31;
+ /**
+ * Upper bound for the mean to avoid exceeding the table sizes.
+ *
+ * <p>The number of possible values of the distribution should not exceed 2<sup>16</sup>.</p>
+ *
+ * <p>The original source code provided in Marsaglia, et al (2004) has no explicit
+ * limit but the code fails at {@code mean >= 1941} as the transform to compute p(x=mode)
+ * produces infinity. Use a conservative limit of 1024.</p>
+ */
+ private static final double MAX_MEAN = 1024;
+
+ /** The table sampler created at construction; every sample is obtained from it. */
+ private final DiscreteSampler delegate;
+
+    /**
+     * Create a new instance.
+     *
+     * @param rng Generator of uniformly distributed random numbers.
+     * @param mean Mean.
+     * @throws IllegalArgumentException if {@code mean <= 0}, {@code mean > 1024} or {@code mean} is {@code NaN}.
+     */
+    public MarsagliaTsangWangSmallMeanPoissonSampler(UniformRandomProvider rng, double mean) {
+        if (!(mean > 0)) {
+            throw new IllegalArgumentException("mean is not strictly positive: " + mean);
+        }
+        // Larger means are not supported: the tabulation of p(x=mode) fails for mean >= 1941.
+        if (mean > MAX_MEAN) {
+            throw new IllegalArgumentException("mean " + mean + " > " + MAX_MEAN);
+        }
+
+        // Probabilities are 30-bit integers, assumed denominator 2^30
+        int[] prob;
+        // This is the minimum sample value: prob[x - offset] = p(x)
+        int offset;
+
+        // Generate P's from 0 if mean < 21.4
+        if (mean < 21.4) {
+            final double p0 = Math.exp(-mean);
+
+            // Recursive update of Poisson probability until the value is too small
+            // p(x + 1) = p(x) * mean / (x + 1)
+            double p = p0;
+            int i;
+            for (i = 1; p * DOUBLE_31 >= 1; i++) {
+                p *= mean / i;
+            }
+
+            // Fill P as (30-bit integers)
+            offset = 0;
+            final int size = i - 1;
+            prob = new int[size];
+
+            p = p0;
+            prob[0] = toUnsignedInt30(p);
+            // The sum must exceed 2^30. In edge cases this is false due to round-off.
+            int sum = prob[0];
+            for (i = 1; i < prob.length; i++) {
+                p *= mean / i;
+                prob[i] = toUnsignedInt30(p);
+                sum += prob[i];
+            }
+
+            // If the sum is < 2^30 add the remaining sum to the mode (floor(mean)).
+            prob[(int) mean] += Math.max(0, INT_30 - sum);
+        } else {
+            // If mean >= 21.4, generate from largest p-value up, then largest down.
+            // The largest p-value will be at the mode (floor(mean)).
+
+            // Find p(x=mode)
+            final int mode = (int) mean;
+            // This transform is stable until mean >= 1941 where p will result in Infinity
+            // before the divisor i is large enough to start reducing the product (i.e. i > c).
+            final double c = mean * Math.exp(-mean / mode);
+            double p = 1.0;
+            int i;
+            for (i = 1; i <= mode; i++) {
+                p *= c / i;
+            }
+            final double pX = p;
+            // Note this will exit when i overflows to negative so no check on the range
+            for (i = mode + 1; p * DOUBLE_31 >= 1; i++) {
+                p *= mean / i;
+            }
+            final int last = i - 2;
+            p = pX;
+            int j = -1;
+            for (i = mode - 1; i >= 0; i--) {
+                p *= (i + 1) / mean;
+                if (p * DOUBLE_31 < 1) {
+                    j = i;
+                    break;
+                }
+            }
+
+            // Fill P as (30-bit integers)
+            offset = j + 1;
+            final int size = last - offset + 1;
+            prob = new int[size];
+
+            p = pX;
+            prob[mode - offset] = toUnsignedInt30(p);
+            // The sum must exceed 2^30. In edge cases this is false due to round-off.
+            int sum = prob[mode - offset];
+            for (i = mode + 1; i <= last; i++) {
+                p *= mean / i;
+                prob[i - offset] = toUnsignedInt30(p);
+                sum += prob[i - offset];
+            }
+            p = pX;
+            for (i = mode - 1; i >= offset; i--) {
+                p *= (i + 1) / mean;
+                prob[i - offset] = toUnsignedInt30(p);
+                sum += prob[i - offset];
+            }
+
+            // If the sum is < 2^30 add the remaining sum to the mode
+            prob[mode - offset] += Math.max(0, INT_30 - sum);
+        }
+
+        delegate = new MarsagliaTsangWangDiscreteSampler(rng, prob, offset);
+    }
+
+    /**
+     * Scale the probability by 2<sup>30</sup> and round to the nearest integer.
+     *
+     * @param p the probability
+     * @return the integer (in the range [0,2^30] when {@code p} is in the range [0,1])
+     */
+    private static int toUnsignedInt30(double p) {
+        return (int) (0.5 + p * INT_30);
+    }
+
+ /** {@inheritDoc} <p>The value is obtained from the table sampler created at construction.</p> */
+ @Override
+ public int sample() {
+ return delegate.sample();
+ }
+
+    /** {@inheritDoc} <p>The description of the delegate is prefixed with the distribution name.</p> */
+    @Override
+    public String toString() {
+        return "Small Mean Poisson " + delegate;
+    }
+}
diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java
index 5dab832..6158a2d 100644
--- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java
+++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java
@@ -50,6 +50,15 @@ public class DiscreteSamplersList {
add(LIST, new org.apache.commons.math3.distribution.BinomialDistribution(unusedRng, trialsBinomial, probSuccessBinomial),
MathArrays.sequence(8, 9, 1),
RandomSource.create(RandomSource.KISS));
+ add(LIST, new org.apache.commons.math3.distribution.BinomialDistribution(unusedRng, trialsBinomial, probSuccessBinomial),
+ // range [9,16]
+ MathArrays.sequence(8, 9, 1),
+ new MarsagliaTsangWangBinomialSampler(RandomSource.create(RandomSource.WELL_19937_A), trialsBinomial, probSuccessBinomial));
+ // Inverted
+ add(LIST, new org.apache.commons.math3.distribution.BinomialDistribution(unusedRng, trialsBinomial, 1 - probSuccessBinomial),
+ // range [4,11] = [20-16, 20-9]
+ MathArrays.sequence(8, 4, 1),
+ new MarsagliaTsangWangBinomialSampler(RandomSource.create(RandomSource.WELL_19937_C), trialsBinomial, 1 - probSuccessBinomial));
// Geometric ("inverse method").
final double probSuccessGeometric = 0.21;
@@ -146,6 +155,11 @@ public class DiscreteSamplersList {
add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, veryLargeMeanPoisson, epsilonPoisson, maxIterationsPoisson),
MathArrays.sequence(100, (int) (veryLargeMeanPoisson - 50), 1),
new LargeMeanPoissonSampler(RandomSource.create(RandomSource.SPLIT_MIX_64), veryLargeMeanPoisson));
+
+ // Any discrete distribution
+ double[] discreteProbabilities = new double[] { 0.1, 0.2, 0.3, 0.4 };
+ add(LIST, discreteProbabilities,
+ new MarsagliaTsangWangDiscreteSampler(RandomSource.create(RandomSource.XO_SHI_RO_512_PLUS), discreteProbabilities));
} catch (Exception e) {
System.err.println("Unexpected exception while creating the list of samplers: " + e);
e.printStackTrace(System.err);
@@ -201,6 +215,19 @@ public class DiscreteSamplersList {
}
/**
+ * @param list List of data (one of the "parameters" tested by the Junit parametric test).
+ * @param probabilities Probability distribution to which the samples are supposed to conform.
+ * @param sampler Sampler.
+ */
+ private static void add(List<DiscreteSamplerTestData[]> list,
+ final double[] probabilities,
+ final DiscreteSampler sampler) {
+ list.add(new DiscreteSamplerTestData[] { new DiscreteSamplerTestData(sampler,
+ MathArrays.natural(probabilities.length),
+ probabilities) });
+ }
+
+ /**
* Subclasses that are "parametric" tests can forward the call to
* the "@Parameters"-annotated method to this method.
*
diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangBinomialSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangBinomialSamplerTest.java
new file mode 100644
index 0000000..bfe5052
--- /dev/null
+++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangBinomialSamplerTest.java
@@ -0,0 +1,242 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.rng.sampling.distribution;
+
+import org.apache.commons.rng.UniformRandomProvider;
+import org.junit.Test;
+
+import org.junit.Assert;
+
+/**
+ * Test for the {@link MarsagliaTsangWangBinomialSampler}. The tests hit edge cases for
+ * the sampler.
+ */
+public class MarsagliaTsangWangBinomialSamplerTest {
+ @Test(expected = IllegalArgumentException.class)
+ public void testConstructorThrowsWithTrialsBelow0() {
+ final UniformRandomProvider rng = new FixedRNG(0);
+ final int trials = -1;
+ final double p = 0.5;
+ @SuppressWarnings("unused")
+ final MarsagliaTsangWangBinomialSampler sampler = new MarsagliaTsangWangBinomialSampler(rng, trials, p);
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testConstructorThrowsWithTrialsAboveMax() {
+ final UniformRandomProvider rng = new FixedRNG(0);
+ final int trials = 1 << 16; // 2^16
+ final double p = 0.5;
+ @SuppressWarnings("unused")
+ final MarsagliaTsangWangBinomialSampler sampler = new MarsagliaTsangWangBinomialSampler(rng, trials, p);
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testConstructorThrowsWithProbabilityBelow0() {
+ final UniformRandomProvider rng = new FixedRNG(0);
+ final int trials = 1;
+ final double p = -0.5;
+ @SuppressWarnings("unused")
+ final MarsagliaTsangWangBinomialSampler sampler = new MarsagliaTsangWangBinomialSampler(rng, trials, p);
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testConstructorThrowsWithProbabilityAbove1() {
+ final UniformRandomProvider rng = new FixedRNG(0);
+ final int trials = 1;
+ final double p = 1.5;
+ @SuppressWarnings("unused")
+ final MarsagliaTsangWangBinomialSampler sampler = new MarsagliaTsangWangBinomialSampler(rng, trials, p);
+ }
+
+ /**
+ * Test the constructor with distribution parameters that create a very small p(0)
+ * with a high probability of success.
+ */
+ @Test
+ public void testSamplerWithSmallestP0ValueAndHighestProbabilityOfSuccess() {
+ final UniformRandomProvider rng = new FixedRNG(0xffffffff);
+ // p(0) = Math.exp(trials * Math.log(1-p))
+ // p(0) will be smaller as Math.log(1-p) is more negative, which occurs when p is
+ // larger.
+ // Since the sampler uses inversion the largest value for p is 0.5.
+ // At the extreme for p = 0.5:
+ // trials = Math.log(p(0)) / Math.log(1-p)
+ // = Math.log(Double.MIN_VALUE) / Math.log(0.5)
+ // = 1074
+ final int trials = (int) Math.floor(Math.log(Double.MIN_VALUE) / Math.log(0.5));
+ final double p = 0.5;
+ // Validate set-up
+ Assert.assertEquals("Invalid test set-up for p(0)", Double.MIN_VALUE, getP0(trials, p), 0);
+ Assert.assertEquals("Invalid test set-up for p(0)", 0, getP0(trials + 1, p), 0);
+
+ // This will throw if the table does not sum to 2^30
+ final MarsagliaTsangWangBinomialSampler sampler = new MarsagliaTsangWangBinomialSampler(rng, trials, p);
+ sampler.sample();
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testConstructorThrowsWhenP0IsZero() {
+ final UniformRandomProvider rng = new FixedRNG(0);
+ // As above but increase the trials so p(0) should be zero
+ final int trials = 1 + (int) Math.floor(Math.log(Double.MIN_VALUE) / Math.log(0.5));
+ final double p = 0.5;
+ // Validate set-up
+ Assert.assertEquals("Invalid test set-up for p(0)", 0, getP0(trials, p), 0);
+ @SuppressWarnings("unused")
+ final MarsagliaTsangWangBinomialSampler sampler = new MarsagliaTsangWangBinomialSampler(rng, trials, p);
+ }
+
+ /**
+ * Test the constructor with distribution parameters that create a very small p(0)
+ * with a high number of trials.
+ */
+ @Test
+ public void testSamplerWithLargestTrialsAndSmallestProbabilityOfSuccess() {
+ final UniformRandomProvider rng = new FixedRNG(0xffffffff);
+ // p(0) = Math.exp(trials * Math.log(1-p))
+ // p(0) will be smaller as Math.log(1-p) is more negative, which occurs when p is
+ // larger.
+ // Since the sampler uses inversion the largest value for p is 0.5.
+ // At the extreme for trials = 2^16-1:
+ // p = 1 - Math.exp(Math.log(p(0)) / trials)
+ // = 1 - Math.exp(Math.log(Double.MIN_VALUE) / trials)
+ // = 0.011295152668039599
+ final int trials = (1 << 16) - 1;
+ double p = 1 - Math.exp(Math.log(Double.MIN_VALUE) / trials);
+
+ // Validate set-up
+ Assert.assertEquals("Invalid test set-up for p(0)", Double.MIN_VALUE, getP0(trials, p), 0);
+
+ // Search for larger p until Math.nextAfter(p, 1) produces 0
+ double upper = p * 2;
+ Assert.assertEquals("Invalid test set-up for p(0)", 0, getP0(trials, upper), 0);
+
+ double lower = p;
+ while (Double.doubleToRawLongBits(lower) + 1 < Double.doubleToRawLongBits(upper)) {
+ final double mid = (upper + lower) / 2;
+ if (getP0(trials, mid) == 0) {
+ upper = mid;
+ } else {
+ lower = mid;
+ }
+ }
+ p = lower;
+
+ // Re-validate
+ Assert.assertEquals("Invalid test set-up for p(0)", Double.MIN_VALUE, getP0(trials, p), 0);
+ Assert.assertEquals("Invalid test set-up for p(0)", 0, getP0(trials, Math.nextAfter(p, 1)), 0);
+
+ final MarsagliaTsangWangBinomialSampler sampler = new MarsagliaTsangWangBinomialSampler(rng, trials, p);
+ // This will throw if the table does not sum to 2^30
+ sampler.sample();
+ }
+
+ /**
+ * Computes p(0), the probability of zero successes, as
+ * {@code exp(trials * log(1 - probabilityOfSuccess))}.
+ *
+ * @param trials the trials
+ * @param probabilityOfSuccess the probability of success
+ * @return the p(0) value
+ */
+ private static double getP0(int trials, double probabilityOfSuccess) {
+ final double logProbabilityOfFailure = Math.log(1 - probabilityOfSuccess);
+ return Math.exp(trials * logProbabilityOfFailure);
+ }
+
+ @Test
+ public void testSamplerWithProbability0() {
+ final UniformRandomProvider rng = new FixedRNG(0);
+ final int trials = 1000000;
+ final double p = 0;
+ final MarsagliaTsangWangBinomialSampler sampler = new MarsagliaTsangWangBinomialSampler(rng, trials, p);
+ for (int i = 0; i < 5; i++) {
+ Assert.assertEquals(0, sampler.sample());
+ }
+ // Hit the toString() method
+ Assert.assertTrue(sampler.toString().contains("Binomial"));
+ }
+
+ @Test
+ public void testSamplerWithProbability1() {
+ final UniformRandomProvider rng = new FixedRNG(0);
+ final int trials = 1000000;
+ final double p = 1;
+ final MarsagliaTsangWangBinomialSampler sampler = new MarsagliaTsangWangBinomialSampler(rng, trials, p);
+ for (int i = 0; i < 5; i++) {
+ Assert.assertEquals(trials, sampler.sample());
+ }
+ // Hit the toString() method
+ Assert.assertTrue(sampler.toString().contains("Binomial"));
+ }
+
+ /**
+ * Test the sampler with a large number of trials. This tests the sampler can create the
+ * Binomial distribution for a large size when a limiting distribution (e.g. the Normal distribution)
+ * could be used instead.
+ */
+ @Test
+ public void testSamplerWithLargeNumberOfTrials() {
+ final UniformRandomProvider rng = new FixedRNG(0xffffffff);
+ final int trials = 65000;
+ final double p = 0.01;
+ final MarsagliaTsangWangBinomialSampler sampler = new MarsagliaTsangWangBinomialSampler(rng, trials, p);
+ // This will throw if the table does not sum to 2^30
+ sampler.sample();
+ }
+
+ /**
+ * Test the sampler with a probability of 0.5. This should hit the edge case in the loop to
+ * search for the last probability of the Binomial distribution.
+ */
+ @Test
+ public void testSamplerWithProbability0_5() {
+ final UniformRandomProvider rng = new FixedRNG(0xffffffff);
+ final int trials = 10;
+ final double p = 0.5;
+ final MarsagliaTsangWangBinomialSampler sampler = new MarsagliaTsangWangBinomialSampler(rng, trials, p);
+ // This will throw if the table does not sum to 2^30
+ sampler.sample();
+ }
+
+ /**
+ * A RNG returning a fixed value from {@link #nextInt()}. All other methods
+ * are stubs that do nothing or return zero/false.
+ */
+ private static class FixedRNG implements UniformRandomProvider {
+ /** The value. */
+ private final int value;
+
+ /**
+ * @param value the value
+ */
+ FixedRNG(int value) {
+ this.value = value;
+ }
+
+ @Override
+ public int nextInt() {
+ return value;
+ }
+
+ // The remaining interface methods are stubs. They are annotated so the
+ // compiler flags any signature that stops matching the interface.
+ @Override
+ public void nextBytes(byte[] bytes) {}
+ @Override
+ public void nextBytes(byte[] bytes, int start, int len) {}
+ @Override
+ public int nextInt(int n) { return 0; }
+ @Override
+ public long nextLong() { return 0; }
+ @Override
+ public long nextLong(long n) { return 0; }
+ @Override
+ public boolean nextBoolean() { return false; }
+ @Override
+ public float nextFloat() { return 0; }
+ @Override
+ public double nextDouble() { return 0; }
+ }
+}
diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSamplerTest.java
new file mode 100644
index 0000000..d1fce42
--- /dev/null
+++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSamplerTest.java
@@ -0,0 +1,332 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.rng.sampling.distribution;
+
+import org.apache.commons.math3.stat.inference.ChiSquareTest;
+import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.core.source32.IntProvider;
+import org.apache.commons.rng.core.source64.SplitMix64;
+import org.apache.commons.rng.simple.RandomSource;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Test for the {@link MarsagliaTsangWangDiscreteSampler}. The tests hit edge cases for
+ * the sampler.
+ */
+public class MarsagliaTsangWangDiscreteSamplerTest {
+ // Tests for the package-private constructor using int[] + offset
+
+ /**
+ * Test constructor throws with max index above integer max.
+ */
+ @Test(expected = IllegalArgumentException.class)
+ public void testConstructorThrowsWithMaxIndexAboveIntegerMax() {
+ final int[] prob = new int[1];
+ final int offset = Integer.MAX_VALUE;
+ createSampler(prob, offset);
+ }
+
+ /**
+ * Test constructor throws with negative offset.
+ */
+ @Test(expected = IllegalArgumentException.class)
+ public void testConstructorThrowsWithNegativeOffset() {
+ final int[] prob = new int[1];
+ final int offset = -1;
+ createSampler(prob, offset);
+ }
+
+ /**
+ * Test construction is allowed when max index equals integer max.
+ */
+ @Test
+ public void testConstructorWhenMaxIndexEqualsIntegerMax() {
+ final int[] prob = new int[1];
+ prob[0] = 1 << 30; // So the total probability is 2^30
+ final int offset = Integer.MAX_VALUE - 1;
+ createSampler(prob, offset);
+ }
+
+ /**
+ * Creates the sampler using a {@code SplitMix64} generator seeded with zero.
+ *
+ * @param probabilities the probabilities
+ * @param offset the offset
+ * @return the sampler
+ */
+ private static MarsagliaTsangWangDiscreteSampler createSampler(final int[] probabilities, int offset) {
+ final UniformRandomProvider rng = new SplitMix64(0L);
+ return new MarsagliaTsangWangDiscreteSampler(rng, probabilities, offset);
+ }
+
+ // Tests for the public constructor using double[]
+
+ @Test(expected=IllegalArgumentException.class)
+ public void testConstructorThrowsWithNullProbabilites() {
+ createSampler(null);
+ }
+
+ @Test(expected=IllegalArgumentException.class)
+ public void testConstructorThrowsWithZeroLengthProbabilites() {
+ createSampler(new double[0]);
+ }
+
+ @Test(expected=IllegalArgumentException.class)
+ public void testConstructorThrowsWithNegativeProbabilites() {
+ createSampler(new double[] { -1, 0.1, 0.2 });
+ }
+
+ @Test(expected=IllegalArgumentException.class)
+ public void testConstructorThrowsWithNaNProbabilites() {
+ createSampler(new double[] { 0.1, Double.NaN, 0.2 });
+ }
+
+ @Test(expected=IllegalArgumentException.class)
+ public void testConstructorThrowsWithInfiniteProbabilites() {
+ createSampler(new double[] { 0.1, Double.POSITIVE_INFINITY, 0.2 });
+ }
+
+ @Test(expected=IllegalArgumentException.class)
+ public void testConstructorThrowsWithInfiniteSumProbabilites() {
+ createSampler(new double[] { Double.MAX_VALUE, Double.MAX_VALUE });
+ }
+
+ @Test(expected=IllegalArgumentException.class)
+ public void testConstructorThrowsWithZeroSumProbabilites() {
+ createSampler(new double[4]);
+ }
+
+ /**
+ * Creates the sampler.
+ *
+ * @param probabilities the probabilities
+ * @return the sampler
+ */
+ private static MarsagliaTsangWangDiscreteSampler createSampler(double[] probabilities) {
+ final UniformRandomProvider rng = new SplitMix64(0L);
+ return new MarsagliaTsangWangDiscreteSampler(rng, probabilities);
+ }
+
+ // Sampling tests
+
+ /**
+ * Test offset samples. This test hits all code paths in the sampler for 8, 16, and 32-bit
+ * storage using different offsets to control the maximum sample value.
+ */
+ @Test
+ public void testOffsetSamples() {
+ // This is filled with probabilities to hit all edge cases in the fill procedure.
+ // The probabilities must have a digit from each of the 5 possible.
+ // NOTE(review): '+' binds tighter than '<<' in Java, so "1 + 1 << 6" evaluates as
+ // (1 + 1) << 6 = 128 (a single base-64 digit equal to 2), not 1 + (1 << 6) = 65.
+ // Each of the 5 digit positions is still populated by prob[0]..prob[4];
+ // confirm which value was intended and add parentheses accordingly.
+ final int[] prob = new int[6];
+ prob[0] = 1;
+ prob[1] = 1 + 1 << 6;
+ prob[2] = 1 + 1 << 12;
+ prob[3] = 1 + 1 << 18;
+ prob[4] = 1 + 1 << 24;
+ // Ensure probabilities sum to 2^30
+ prob[5] = (1 << 30) - (prob[0] + prob[1] + prob[2] + prob[3] + prob[4]);
+
+ // To hit all samples requires integers that are under the look-up table limits.
+ // So compute the limits here.
+ int n1 = 0;
+ int n2 = 0;
+ int n3 = 0;
+ int n4 = 0;
+ for (final int m : prob) {
+ n1 += getBase64Digit(m, 1);
+ n2 += getBase64Digit(m, 2);
+ n3 += getBase64Digit(m, 3);
+ n4 += getBase64Digit(m, 4);
+ }
+
+ // Cumulative thresholds built from the digit totals of the first four digits
+ final int t1 = n1 << 24;
+ final int t2 = t1 + (n2 << 18);
+ final int t3 = t2 + (n3 << 12);
+ final int t4 = t3 + (n4 << 6);
+
+ // Create values under the limits and bit shift by 2 to reverse what the sampler does.
+ final int[] values = new int[] { 0, t1, t2, t3, t4, 0xffffffff };
+ for (int i = 0; i < values.length; i++) {
+ values[i] <<= 2;
+ }
+
+ // Each sampler gets its own provider over the same array so all three see the same sequence
+ final UniformRandomProvider rng1 = new FixedSequenceIntProvider(values);
+ final UniformRandomProvider rng2 = new FixedSequenceIntProvider(values);
+ final UniformRandomProvider rng3 = new FixedSequenceIntProvider(values);
+
+ // Create offsets to force storage as 8, 16, or 32-bit
+ final int offset1 = 1;
+ final int offset2 = 1 << 8;
+ final int offset3 = 1 << 16;
+
+ final MarsagliaTsangWangDiscreteSampler sampler1 = new MarsagliaTsangWangDiscreteSampler(rng1, prob, offset1);
+ final MarsagliaTsangWangDiscreteSampler sampler2 = new MarsagliaTsangWangDiscreteSampler(rng2, prob, offset2);
+ final MarsagliaTsangWangDiscreteSampler sampler3 = new MarsagliaTsangWangDiscreteSampler(rng3, prob, offset3);
+
+ for (int i = 0; i < values.length; i++) {
+ // Remove offsets
+ final int s1 = sampler1.sample() - offset1;
+ final int s2 = sampler2.sample() - offset2;
+ final int s3 = sampler3.sample() - offset3;
+ Assert.assertEquals("Offset sample 1 and 2 do not match", s1, s2);
+ Assert.assertEquals("Offset Sample 1 and 3 do not match", s1, s3);
+ }
+ }
+
+ /**
+ * Test samples from a distribution expressed using {@code double} probabilities.
+ */
+ @Test
+ public void testRealProbabilityDistributionSamples() {
+ // These do not have to sum to 1
+ final double[] probabilities = new double[11];
+ final UniformRandomProvider rng = RandomSource.create(RandomSource.SPLIT_MIX_64);
+ for (int i = 0; i < probabilities.length; i++) {
+ probabilities[i] = rng.nextDouble();
+ }
+
+ // First test the table is completely filled to 2^30
+ final UniformRandomProvider dummyRng = new FixedSequenceIntProvider(new int[] { 0xffffffff});
+ final MarsagliaTsangWangDiscreteSampler dummySampler = new MarsagliaTsangWangDiscreteSampler(dummyRng, probabilities);
+ // This will throw if the table is incomplete as it hits the upper limit
+ dummySampler.sample();
+
+ // Do a test of the actual sampler
+ final MarsagliaTsangWangDiscreteSampler sampler = new MarsagliaTsangWangDiscreteSampler(rng, probabilities);
+
+ final int numberOfSamples = 10000;
+ final long[] samples = new long[probabilities.length];
+ for (int i = 0; i < numberOfSamples; i++) {
+ samples[sampler.sample()]++;
+ }
+
+ final ChiSquareTest chiSquareTest = new ChiSquareTest();
+ // Pass if we cannot reject null hypothesis that the distributions are the same.
+ Assert.assertFalse(chiSquareTest.chiSquareTest(probabilities, samples, 0.001));
+ }
+
+ /**
+ * Test the storage requirements for a worst case set of 2^8 probabilities. This tests that the
+ * limits described in the class Javadoc are correct.
+ */
+ @Test
+ public void testStorageRequirements8() {
+ // Max digits from 2^22:
+ // (2^4 + 2^6 + 2^6 + 2^6)
+ // Storage in bytes
+ // = (15 + 3 * 63) * 2^8
+ // = 52224 B
+ // = 0.0522 MB
+ checkStorageRequirements(8, 0.06);
+ }
+
+ /**
+ * Test the storage requirements for a worst case set of 2^16 probabilities. This tests that the
+ * limits described in the class Javadoc are correct.
+ */
+ @Test
+ public void testStorageRequirements16() {
+ // Max digits from 2^14:
+ // (2^2 + 2^6 + 2^6)
+ // Storage in bytes
+ // = 2 * (3 + 2 * 63) * 2^16
+ // = 16908288 B
+ // = 16.91 MB
+ checkStorageRequirements(16, 17.0);
+ }
+
+ /**
+ * Test the storage requirements for a worst case set of 2^k probabilities. This
+ * tests that the limits described in the class Javadoc are correct.
+ *
+ * @param k Base is 2^k.
+ * @param expectedLimitMB the expected limit in MB
+ */
+ private static void checkStorageRequirements(int k, double expectedLimitMB) {
+ // Worst case scenario is a uniform distribution of 2^k samples each with the highest
+ // mask set for base 64 digits.
+ // The max number of samples: 2^k
+ final int maxSamples = (1 << k);
+
+ // The highest value for each sample:
+ // 2^30 / 2^k = 2^(30-k)
+ // The highest mask is all bits set
+ final int m = (1 << (30 - k)) - 1;
+
+ // Check the sum is less than 2^30
+ final long sum = (long) maxSamples * m;
+ final int total = 1 << 30;
+ Assert.assertTrue("Worst case uniform distribution is above 2^30", sum < total);
+
+ // Get the digits as per the sampler and compute storage
+ final int d1 = getBase64Digit(m, 1);
+ final int d2 = getBase64Digit(m, 2);
+ final int d3 = getBase64Digit(m, 3);
+ final int d4 = getBase64Digit(m, 4);
+ final int d5 = getBase64Digit(m, 5);
+ // Compute storage in MB using 1, 2 or 4 bytes per entry depending on k
+ int bytes;
+ if (k <= 8) {
+ bytes = 1;
+ } else if (k <= 16) {
+ bytes = 2;
+ } else {
+ bytes = 4;
+ }
+ final double storageMB = bytes * 1e-6 * (d1 + d2 + d3 + d4 + d5) * maxSamples;
+ Assert.assertTrue(
+ "Worst case uniform distribution storage " + storageMB + "MB is above expected limit: " + expectedLimitMB,
+ storageMB < expectedLimitMB);
+ }
+
+ /**
+ * Extracts the k<sup>th</sup> base 64 digit of {@code m}, where digit 1 is the
+ * most significant 6 bits of a 30-bit value and digit 5 the least significant.
+ *
+ * @param m the value m.
+ * @param k the digit.
+ * @return the base 64 digit
+ */
+ private static int getBase64Digit(int m, int k) {
+ final int shift = 30 - 6 * k;
+ return (m >>> shift) & 0x3f;
+ }
+
+ /**
+ * Return a fixed sequence of {@code int} output. The sequence wraps around to
+ * the first value once every value has been output.
+ *
+ * <p>Declared {@code static} because it uses no state of the enclosing test
+ * instance, which avoids holding a hidden reference to it.</p>
+ */
+ private static class FixedSequenceIntProvider extends IntProvider {
+ /** The count of values output. */
+ private int count;
+ /** The values. */
+ private final int[] values;
+
+ /**
+ * Instantiates a new fixed sequence int provider.
+ *
+ * @param values Values.
+ */
+ FixedSequenceIntProvider(int[] values) {
+ this.values = values;
+ }
+
+ @Override
+ public int next() {
+ // This should not be called enough to overflow count
+ return values[count++ % values.length];
+ }
+ }
+}
diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangSmallMeanPoissonSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangSmallMeanPoissonSamplerTest.java
new file mode 100644
index 0000000..207840f
--- /dev/null
+++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangSmallMeanPoissonSamplerTest.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.rng.sampling.distribution;
+
+import org.apache.commons.rng.UniformRandomProvider;
+import org.junit.Test;
+
+/**
+ * Test for the {@link MarsagliaTsangWangSmallMeanPoissonSampler}. The tests hit edge
+ * cases for the sampler.
+ */
+public class MarsagliaTsangWangSmallMeanPoissonSamplerTest {
+ /**
+ * Test the constructor with a bad mean.
+ */
+ @Test(expected = IllegalArgumentException.class)
+ public void testConstructorThrowsWithMeanLargerThanUpperBound() {
+ final UniformRandomProvider rng = new FixedRNG(0);
+ final double mean = 1025;
+ @SuppressWarnings("unused")
+ final MarsagliaTsangWangSmallMeanPoissonSampler sampler = new MarsagliaTsangWangSmallMeanPoissonSampler(rng,
+ mean);
+ }
+
+ /**
+ * Test the constructor with a bad mean.
+ */
+ @Test(expected = IllegalArgumentException.class)
+ public void testConstructorThrowsWithZeroMean() {
+ final UniformRandomProvider rng = new FixedRNG(0);
+ final double mean = 0;
+ @SuppressWarnings("unused")
+ final MarsagliaTsangWangSmallMeanPoissonSampler sampler = new MarsagliaTsangWangSmallMeanPoissonSampler(rng,
+ mean);
+ }
+
+ /**
+ * Test the constructor with the maximum mean.
+ */
+ @Test
+ public void testConstructorWithMaximumMean() {
+ final UniformRandomProvider rng = new FixedRNG(0);
+ final double mean = 1024;
+ @SuppressWarnings("unused")
+ final MarsagliaTsangWangSmallMeanPoissonSampler sampler = new MarsagliaTsangWangSmallMeanPoissonSampler(rng,
+ mean);
+ }
+
+ /**
+ * Test the constructor with a small mean that hits the edge case where the
+ * probability sum is not 2^30.
+ */
+ @Test
+ public void testConstructorWithSmallMean() {
+ final UniformRandomProvider rng = new FixedRNG(0xffffffff);
+ final double mean = 0.25;
+ final MarsagliaTsangWangSmallMeanPoissonSampler sampler = new MarsagliaTsangWangSmallMeanPoissonSampler(rng,
+ mean);
+ // This will throw if the table does not sum to 2^30
+ sampler.sample();
+ }
+
+ /**
+ * Test the constructor with a medium mean that is at the switch point for how the probability
+ * distribution is computed.
+ */
+ @Test
+ public void testConstructorWithMediumMean() {
+ final UniformRandomProvider rng = new FixedRNG(0xffffffff);
+ final double mean = 21.4;
+ final MarsagliaTsangWangSmallMeanPoissonSampler sampler = new MarsagliaTsangWangSmallMeanPoissonSampler(rng,
+ mean);
+ // This will throw if the table does not sum to 2^30
+ sampler.sample();
+ }
+
+ /**
+ * A RNG returning a fixed value from {@link #nextInt()}. All other methods
+ * are stubs that do nothing or return zero/false.
+ */
+ private static class FixedRNG implements UniformRandomProvider {
+ /** The value. */
+ private final int value;
+
+ /**
+ * @param value the value
+ */
+ FixedRNG(int value) {
+ this.value = value;
+ }
+
+ @Override
+ public int nextInt() {
+ return value;
+ }
+
+ // The remaining interface methods are stubs. They are annotated so the
+ // compiler flags any signature that stops matching the interface.
+ @Override
+ public void nextBytes(byte[] bytes) {}
+ @Override
+ public void nextBytes(byte[] bytes, int start, int len) {}
+ @Override
+ public int nextInt(int n) { return 0; }
+ @Override
+ public long nextLong() { return 0; }
+ @Override
+ public long nextLong(long n) { return 0; }
+ @Override
+ public boolean nextBoolean() { return false; }
+ @Override
+ public float nextFloat() { return 0; }
+ @Override
+ public double nextDouble() { return 0; }
+ }
+}