You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by tn...@apache.org on 2015/05/01 12:08:12 UTC
[math] [MATH-1153] Improve performance of BetaDistribution.sample().
Thanks to Sergei Lebedev.
Repository: commons-math
Updated Branches:
refs/heads/MATH_3_X 2011e11e5 -> f5d028ca6
[MATH-1153] Improve performance of BetaDistribution.sample(). Thanks to Sergei Lebedev.
Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo
Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/f5d028ca
Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/f5d028ca
Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/f5d028ca
Branch: refs/heads/MATH_3_X
Commit: f5d028ca6af5591ca51785da7c15d7bd81d4215f
Parents: 2011e11
Author: Thomas Neidhart <th...@gmail.com>
Authored: Fri May 1 12:07:52 2015 +0200
Committer: Thomas Neidhart <th...@gmail.com>
Committed: Fri May 1 12:07:52 2015 +0200
----------------------------------------------------------------------
pom.xml | 3 +
src/changes/changes.xml | 3 +
.../math3/distribution/BetaDistribution.java | 134 +++++++++++++++++++
.../distribution/BetaDistributionTest.java | 73 ++++++++++
.../math3/random/RandomDataGeneratorTest.java | 94 +++++--------
5 files changed, 248 insertions(+), 59 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/commons-math/blob/f5d028ca/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index 223b316..7c55d33 100644
--- a/pom.xml
+++ b/pom.xml
@@ -252,6 +252,9 @@
<name>Piotr Kochanski</name>
</contributor>
<contributor>
+ <name>Sergei Lebedev</name>
+ </contributor>
+ <contributor>
<name>Bob MacCallum</name>
</contributor>
<contributor>
http://git-wip-us.apache.org/repos/asf/commons-math/blob/f5d028ca/src/changes/changes.xml
----------------------------------------------------------------------
diff --git a/src/changes/changes.xml b/src/changes/changes.xml
index 0759e8e..2e818b2 100644
--- a/src/changes/changes.xml
+++ b/src/changes/changes.xml
@@ -51,6 +51,9 @@ If the output is not quite correct, check for invisible trailing spaces!
</properties>
<body>
<release version="3.6" date="XXXX-XX-XX" description="">
+ <action dev="tn" type="fix" issue="MATH-1153" due-to="Sergei Lebedev">
+ Improve performance of "BetaDistribution#sample()" by using Cheng's algorithm.
+ </action>
<action dev="tn" type="fix" issue="MATH-1197">
Computation of 2-sample Kolmogorov-Smirnov statistic in case of ties
was not correct.
http://git-wip-us.apache.org/repos/asf/commons-math/blob/f5d028ca/src/main/java/org/apache/commons/math3/distribution/BetaDistribution.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math3/distribution/BetaDistribution.java b/src/main/java/org/apache/commons/math3/distribution/BetaDistribution.java
index 3f62f64..19b19e0 100644
--- a/src/main/java/org/apache/commons/math3/distribution/BetaDistribution.java
+++ b/src/main/java/org/apache/commons/math3/distribution/BetaDistribution.java
@@ -23,6 +23,7 @@ import org.apache.commons.math3.random.Well19937c;
import org.apache.commons.math3.special.Beta;
import org.apache.commons.math3.special.Gamma;
import org.apache.commons.math3.util.FastMath;
+import org.apache.commons.math3.util.Precision;
/**
* Implements the Beta distribution.
@@ -148,6 +149,7 @@ public class BetaDistribution extends AbstractRealDistribution {
}
/** {@inheritDoc} */
+ @Override
public double density(double x) {
final double logDensity = logDensity(x);
return logDensity == Double.NEGATIVE_INFINITY ? 0 : FastMath.exp(logDensity);
@@ -177,6 +179,7 @@ public class BetaDistribution extends AbstractRealDistribution {
}
/** {@inheritDoc} */
+ @Override
public double cumulativeProbability(double x) {
if (x <= 0) {
return 0;
@@ -205,6 +208,7 @@ public class BetaDistribution extends AbstractRealDistribution {
* For first shape parameter {@code alpha} and second shape parameter
* {@code beta}, the mean is {@code alpha / (alpha + beta)}.
*/
+ @Override
public double getNumericalMean() {
final double a = getAlpha();
return a / (a + getBeta());
@@ -217,6 +221,7 @@ public class BetaDistribution extends AbstractRealDistribution {
* {@code beta}, the variance is
* {@code (alpha * beta) / [(alpha + beta)^2 * (alpha + beta + 1)]}.
*/
+ @Override
public double getNumericalVariance() {
final double a = getAlpha();
final double b = getBeta();
@@ -231,6 +236,7 @@ public class BetaDistribution extends AbstractRealDistribution {
*
* @return lower bound of the support (always 0)
*/
+ @Override
public double getSupportLowerBound() {
return 0;
}
@@ -242,16 +248,19 @@ public class BetaDistribution extends AbstractRealDistribution {
*
* @return upper bound of the support (always 1)
*/
+ @Override
public double getSupportUpperBound() {
return 1;
}
/** {@inheritDoc} */
+ @Override
public boolean isSupportLowerBoundInclusive() {
return false;
}
/** {@inheritDoc} */
+ @Override
public boolean isSupportUpperBoundInclusive() {
return false;
}
@@ -263,7 +272,132 @@ public class BetaDistribution extends AbstractRealDistribution {
*
* @return {@code true}
*/
+ @Override
public boolean isSupportConnected() {
return true;
}
+
+
+ /** {@inheritDoc}
+ * <p>
+ * Sampling is delegated to {@code ChengBetaSampler}, which implements
+ * Cheng's BB and BC algorithms:
+ * </p>
+ * <p>
+ * R. C. H. Cheng, "Generating beta variates with nonintegral shape parameters",
+ * Communications of the ACM, 21, 317–322, 1978.
+ * </p>
+ *
+ * @return a value sampled with the generator {@code random} for the shape
+ * parameters {@code alpha} and {@code beta}
+ */
+ @Override
+ public double sample() {
+ return ChengBetaSampler.sample(random, alpha, beta);
+ }
+
+ /** Utility class implementing Cheng's algorithms for beta distribution sampling.
+ * <p>
+ * Both algorithms loop until a candidate is accepted, drawing two uniform
+ * deviates from the supplied generator on every iteration.
+ * </p>
+ * <p>
+ * R. C. H. Cheng, "Generating beta variates with nonintegral shape parameters",
+ * Communications of the ACM, 21, 317–322, 1978.
+ * </p>
+ * @since 3.6
+ */
+ private static final class ChengBetaSampler {
+
+ /**
+ * Returns one sample using Cheng's sampling algorithm.
+ * <p>
+ * Algorithm BB is used when the smaller of the two shape parameters is
+ * greater than 1, algorithm BC otherwise. The two algorithms expect their
+ * ordered parameters in opposite order: BB receives (min, max) whereas BC
+ * receives (max, min).
+ * </p>
+ * @param random random generator to use
+ * @param alpha distribution first shape parameter
+ * @param beta distribution second shape parameter
+ * @return sampled value
+ */
+ static double sample(RandomGenerator random, final double alpha, final double beta) {
+ final double a = FastMath.min(alpha, beta);
+ final double b = FastMath.max(alpha, beta);
+
+ if (a > 1) {
+ return algorithmBB(random, alpha, a, b); // BB takes (min, max)
+ } else {
+ return algorithmBC(random, alpha, b, a); // BC takes (max, min)
+ }
+ }
+
+ /**
+ * Returns one sample using Cheng's BB algorithm, when both α and β are greater than 1.
+ * <p>
+ * Each iteration draws two uniform deviates and builds a candidate {@code w}.
+ * The candidate is accepted by one of two inexpensive pre-tests or, failing
+ * those, by the logarithmic test in the loop condition. The literals
+ * {@code 1.3862944} and {@code 2.609438} are numerically ln(4) and 1 + ln(5).
+ * </p>
+ * @param random random generator to use
+ * @param a0 first shape parameter as passed to {@code sample}
+ * @param a smaller of the two shape parameters
+ * @param b larger of the two shape parameters
+ * @return sampled value
+ */
+ private static double algorithmBB(RandomGenerator random,
+ final double a0,
+ final double a,
+ final double b) {
+ final double alpha = a + b;
+ final double beta = FastMath.sqrt((alpha - 2.) / (2. * a * b - alpha));
+ final double gamma = a + 1. / beta;
+
+ double r, w, t;
+ do {
+ final double u1 = random.nextDouble();
+ final double u2 = random.nextDouble();
+ final double v = beta * (FastMath.log(u1) - FastMath.log1p(-u1)); // beta * log(u1 / (1 - u1))
+ w = a * FastMath.exp(v);
+ final double z = u1 * u1 * u2;
+ r = gamma * v - 1.3862944;
+ final double s = a + r - w;
+ if (s + 2.609438 >= 5 * z) { // first pre-test: accept without evaluating log(z)
+ break;
+ }
+
+ t = FastMath.log(z);
+ if (s >= t) { // second pre-test: accept
+ break;
+ }
+ } while (r + alpha * (FastMath.log(alpha) - FastMath.log(b + w)) < t); // reject and redraw while true
+
+ w = FastMath.min(w, Double.MAX_VALUE); // keep w finite so that w / (b + w) is not Infinity / Infinity
+ return Precision.equals(a, a0) ? w / (b + w) : b / (b + w); // b / (b + w) is the complement 1 - w / (b + w)
+ }
+
+ /**
+ * Returns one sample using Cheng's BC algorithm, when the smaller of α and β is not greater than 1.
+ * <p>
+ * Each iteration draws two uniform deviates. Depending on the first deviate
+ * and on the thresholds {@code k1} and {@code k2}, the iteration is rejected
+ * outright, accepted outright, or decided by the final logarithmic test.
+ * The literal {@code 1.3862944} is numerically ln(4).
+ * </p>
+ * @param random random generator to use
+ * @param a0 first shape parameter as passed to {@code sample}
+ * @param a larger of the two shape parameters
+ * @param b smaller of the two shape parameters
+ * @return sampled value
+ */
+ private static double algorithmBC(RandomGenerator random,
+ final double a0,
+ final double a,
+ final double b) {
+ final double alpha = a + b;
+ final double beta = 1. / b;
+ final double delta = 1. + a - b;
+ final double k1 = delta * (0.0138889 + 0.0416667 * b) / (a * beta - 0.777778);
+ final double k2 = 0.25 + (0.5 + 0.25 / delta) * b;
+
+ double w;
+ for (;;) {
+ final double u1 = random.nextDouble();
+ final double u2 = random.nextDouble();
+ final double y = u1 * u2;
+ final double z = u1 * y;
+ if (u1 < 0.5) {
+ if (0.25 * u2 + z - y >= k1) {
+ continue; // reject and redraw
+ }
+ } else {
+ if (z <= 0.25) {
+ final double v = beta * (FastMath.log(u1) - FastMath.log1p(-u1));
+ w = a * FastMath.exp(v);
+ break; // accept without the final test
+ }
+
+ if (z >= k2) {
+ continue; // reject and redraw
+ }
+ }
+
+ final double v = beta * (FastMath.log(u1) - FastMath.log1p(-u1)); // beta * log(u1 / (1 - u1))
+ w = a * FastMath.exp(v);
+ if (alpha * (FastMath.log(alpha) - FastMath.log(b + w) + v) - 1.3862944 >= FastMath.log(z)) {
+ break; // final test passed: accept
+ }
+ }
+
+ w = FastMath.min(w, Double.MAX_VALUE); // keep w finite so that w / (b + w) is not Infinity / Infinity
+ return Precision.equals(a, a0) ? w / (b + w) : b / (b + w); // b / (b + w) is the complement 1 - w / (b + w)
+ }
+
+ }
}
http://git-wip-us.apache.org/repos/asf/commons-math/blob/f5d028ca/src/test/java/org/apache/commons/math3/distribution/BetaDistributionTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/commons/math3/distribution/BetaDistributionTest.java b/src/test/java/org/apache/commons/math3/distribution/BetaDistributionTest.java
index 217ae66..3778bfe 100644
--- a/src/test/java/org/apache/commons/math3/distribution/BetaDistributionTest.java
+++ b/src/test/java/org/apache/commons/math3/distribution/BetaDistributionTest.java
@@ -16,10 +16,22 @@
*/
package org.apache.commons.math3.distribution;
+import java.util.Arrays;
+
+import org.apache.commons.math3.random.RandomGenerator;
+import org.apache.commons.math3.random.Well1024a;
+import org.apache.commons.math3.random.Well19937a;
+import org.apache.commons.math3.stat.StatUtils;
+import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest;
+import org.apache.commons.math3.stat.inference.TestUtils;
import org.junit.Assert;
import org.junit.Test;
public class BetaDistributionTest {
+
+ static final double[] alphaBetas = {0.1, 1, 10, 100, 1000};
+ static final double epsilon = StatUtils.min(alphaBetas);
+
@Test
public void testCumulative() {
double[] x = new double[]{-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1};
@@ -303,4 +315,65 @@ public class BetaDistributionTest {
Assert.assertEquals(dist.getNumericalMean(), 2.0 / 7.0, tol);
Assert.assertEquals(dist.getNumericalVariance(), 10.0 / (49.0 * 8.0), tol);
}
+
+ /**
+ * Checks that the mean and variance of 1000 sampled values agree with
+ * {@code getNumericalMean()} and {@code getNumericalVariance()} to within
+ * {@code epsilon}, for every (alpha, beta) pair taken from {@code alphaBetas}.
+ * <p>
+ * The values are drawn from the distribution that was built on the seeded
+ * generator, so the seed actually determines the data under test.
+ * </p>
+ */
+ @Test
+ public void testMomentsSampling() {
+ final RandomGenerator random = new Well1024a(0x7829862c82fec2daL);
+ final int numSamples = 1000;
+ for (final double alpha : alphaBetas) {
+ for (final double beta : alphaBetas) {
+ final BetaDistribution betaDistribution = new BetaDistribution(random, alpha, beta);
+ // Sample from the seeded instance rather than from a second, unseeded distribution.
+ // No sorting is needed: the moments do not depend on the order of the values.
+ final double[] observed = betaDistribution.sample(numSamples);
+
+ final String distribution = String.format("Beta(%.2f, %.2f)", alpha, beta);
+ Assert.assertEquals(String.format("E[%s]", distribution),
+ betaDistribution.getNumericalMean(),
+ StatUtils.mean(observed), epsilon);
+ Assert.assertEquals(String.format("Var[%s]", distribution),
+ betaDistribution.getNumericalVariance(),
+ StatUtils.variance(observed), epsilon);
+ }
+ }
+ }
+
+ /**
+ * Draws 1000 values for every (alpha, beta) pair taken from {@code alphaBetas}
+ * and requires that neither the G-test nor the Kolmogorov-Smirnov test returns
+ * a value below the 0.01 level.
+ */
+ @Test
+ public void testGoodnessOfFit() {
+ final double level = 0.01;
+ final int numSamples = 1000;
+ final RandomGenerator random = new Well19937a(0x237db1db907b089fL);
+ for (final double alpha : alphaBetas) {
+ for (final double beta : alphaBetas) {
+ final BetaDistribution betaDistribution = new BetaDistribution(random, alpha, beta);
+ final double[] observed = betaDistribution.sample(numSamples);
+
+ final double gPValue = gTest(betaDistribution, observed);
+ Assert.assertFalse("G goodness-of-fit test rejected null at alpha = " + level,
+ gPValue < level);
+
+ final KolmogorovSmirnovTest ksTest = new KolmogorovSmirnovTest(random);
+ final double ksPValue = ksTest.kolmogorovSmirnovTest(betaDistribution, observed);
+ Assert.assertFalse("KS goodness-of-fit test rejected null at alpha = " + level,
+ ksPValue < level);
+ }
+ }
+ }
+
+ /**
+ * Bins the given values and runs {@code TestUtils.gTest} on the bin counts.
+ * <p>
+ * The number of bins is {@code values.length / 30}. The lower edge of bin
+ * {@code i} is the {@code i / binCount} quantile of the expected distribution,
+ * and every bin is given the same expected count.
+ * </p>
+ * @param expectedDistribution distribution whose quantiles define the bin edges
+ * @param values sampled values to bin
+ * @return the value returned by {@code TestUtils.gTest} for the expected and observed counts
+ */
+ private double gTest(final RealDistribution expectedDistribution, final double[] values) {
+ final int binCount = values.length / 30;
+
+ final double[] lowerEdges = new double[binCount];
+ for (int i = 0; i < binCount; i++) {
+ lowerEdges[i] = expectedDistribution.inverseCumulativeProbability((double) i / binCount);
+ }
+
+ final long[] counts = new long[binCount];
+ for (final double value : values) {
+ // Advance until the next lower edge exceeds the value; the value then belongs to bin (next - 1).
+ int next = 1;
+ while (next < binCount && value >= lowerEdges[next]) {
+ next++;
+ }
+ counts[next - 1]++;
+ }
+
+ final double expectedPerBin = (double) values.length / binCount;
+ final double[] expectedCounts = new double[binCount];
+ Arrays.fill(expectedCounts, expectedPerBin);
+
+ return TestUtils.gTest(expectedCounts, counts);
+ }
}
http://git-wip-us.apache.org/repos/asf/commons-math/blob/f5d028ca/src/test/java/org/apache/commons/math3/random/RandomDataGeneratorTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/commons/math3/random/RandomDataGeneratorTest.java b/src/test/java/org/apache/commons/math3/random/RandomDataGeneratorTest.java
index a0b6e26..839b1e6 100644
--- a/src/test/java/org/apache/commons/math3/random/RandomDataGeneratorTest.java
+++ b/src/test/java/org/apache/commons/math3/random/RandomDataGeneratorTest.java
@@ -83,7 +83,7 @@ public class RandomDataGeneratorTest {
long y = randomData.nextLong(Long.MIN_VALUE, Long.MAX_VALUE);
Assert.assertFalse(x == y);
}
-
+
@Test
public void testNextUniformExtremeValues() {
double x = randomData.nextUniform(-Double.MAX_VALUE, Double.MAX_VALUE);
@@ -94,7 +94,7 @@ public class RandomDataGeneratorTest {
Assert.assertFalse(Double.isInfinite(x));
Assert.assertFalse(Double.isInfinite(y));
}
-
+
@Test
public void testNextIntIAE() {
try {
@@ -104,7 +104,7 @@ public class RandomDataGeneratorTest {
// ignored
}
}
-
+
@Test
public void testNextIntNegativeToPositiveRange() {
for (int i = 0; i < 5; i++) {
@@ -113,7 +113,7 @@ public class RandomDataGeneratorTest {
}
}
- @Test
+ @Test
public void testNextIntNegativeRange() {
for (int i = 0; i < 5; i++) {
checkNextIntUniform(-7, -4);
@@ -122,7 +122,7 @@ public class RandomDataGeneratorTest {
}
}
- @Test
+ @Test
public void testNextIntPositiveRange() {
for (int i = 0; i < 5; i++) {
checkNextIntUniform(0, 3);
@@ -148,7 +148,7 @@ public class RandomDataGeneratorTest {
for (int i = 0; i < len; i++) {
expected[i] = 1d / len;
}
-
+
TestUtils.assertChiSquareAccept(expected, observed, 0.001);
}
@@ -169,7 +169,7 @@ public class RandomDataGeneratorTest {
(((double) upper) - ((double) lower));
Assert.assertTrue(ratio > 0.99999);
}
-
+
@Test
public void testNextLongIAE() {
try {
@@ -188,7 +188,7 @@ public class RandomDataGeneratorTest {
}
}
- @Test
+ @Test
public void testNextLongNegativeRange() {
for (int i = 0; i < 5; i++) {
checkNextLongUniform(-7, -4);
@@ -197,7 +197,7 @@ public class RandomDataGeneratorTest {
}
}
- @Test
+ @Test
public void testNextLongPositiveRange() {
for (int i = 0; i < 5; i++) {
checkNextLongUniform(0, 3);
@@ -223,7 +223,7 @@ public class RandomDataGeneratorTest {
for (int i = 0; i < len; i++) {
expected[i] = 1d / len;
}
-
+
TestUtils.assertChiSquareAccept(expected, observed, 0.01);
}
@@ -244,7 +244,7 @@ public class RandomDataGeneratorTest {
(((double) upper) - ((double) lower));
Assert.assertTrue(ratio > 0.99999);
}
-
+
@Test
public void testNextSecureLongIAE() {
try {
@@ -254,7 +254,7 @@ public class RandomDataGeneratorTest {
// ignored
}
}
-
+
@Test
@Retry(3)
public void testNextSecureLongNegativeToPositiveRange() {
@@ -263,7 +263,7 @@ public class RandomDataGeneratorTest {
checkNextSecureLongUniform(-3, 6);
}
}
-
+
@Test
@Retry(3)
public void testNextSecureLongNegativeRange() {
@@ -272,7 +272,7 @@ public class RandomDataGeneratorTest {
checkNextSecureLongUniform(-15, -2);
}
}
-
+
@Test
@Retry(3)
public void testNextSecureLongPositiveRange() {
@@ -281,7 +281,7 @@ public class RandomDataGeneratorTest {
checkNextSecureLongUniform(2, 12);
}
}
-
+
private void checkNextSecureLongUniform(int min, int max) {
final Frequency freq = new Frequency();
for (int i = 0; i < smallSampleSize; i++) {
@@ -298,7 +298,7 @@ public class RandomDataGeneratorTest {
for (int i = 0; i < len; i++) {
expected[i] = 1d / len;
}
-
+
TestUtils.assertChiSquareAccept(expected, observed, 0.0001);
}
@@ -311,7 +311,7 @@ public class RandomDataGeneratorTest {
// ignored
}
}
-
+
@Test
@Retry(3)
public void testNextSecureIntNegativeToPositiveRange() {
@@ -320,7 +320,7 @@ public class RandomDataGeneratorTest {
checkNextSecureIntUniform(-3, 6);
}
}
-
+
@Test
@Retry(3)
public void testNextSecureIntNegativeRange() {
@@ -329,8 +329,8 @@ public class RandomDataGeneratorTest {
checkNextSecureIntUniform(-15, -2);
}
}
-
- @Test
+
+ @Test
@Retry(3)
public void testNextSecureIntPositiveRange() {
for (int i = 0; i < 5; i++) {
@@ -338,7 +338,7 @@ public class RandomDataGeneratorTest {
checkNextSecureIntUniform(2, 12);
}
}
-
+
private void checkNextSecureIntUniform(int min, int max) {
final Frequency freq = new Frequency();
for (int i = 0; i < smallSampleSize; i++) {
@@ -355,11 +355,11 @@ public class RandomDataGeneratorTest {
for (int i = 0; i < len; i++) {
expected[i] = 1d / len;
}
-
+
TestUtils.assertChiSquareAccept(expected, observed, 0.0001);
}
-
-
+
+
/**
* Make sure that empirical distribution of random Poisson(4)'s has P(X <=
@@ -386,7 +386,7 @@ public class RandomDataGeneratorTest {
} catch (MathIllegalArgumentException ex) {
// ignored
}
-
+
final double mean = 4.0d;
final int len = 5;
PoissonDistribution poissonDistribution = new PoissonDistribution(mean);
@@ -403,7 +403,7 @@ public class RandomDataGeneratorTest {
for (int i = 0; i < len; i++) {
expected[i] = poissonDistribution.probability(i + 1) * largeSampleSize;
}
-
+
TestUtils.assertChiSquareAccept(expected, observed, 0.0001);
}
@@ -683,35 +683,35 @@ public class RandomDataGeneratorTest {
// ignored
}
}
-
+
@Test
public void testNextUniformUniformPositiveBounds() {
for (int i = 0; i < 5; i++) {
checkNextUniformUniform(0, 10);
}
}
-
+
@Test
public void testNextUniformUniformNegativeToPositiveBounds() {
for (int i = 0; i < 5; i++) {
checkNextUniformUniform(-3, 5);
}
}
-
+
@Test
public void testNextUniformUniformNegaiveBounds() {
for (int i = 0; i < 5; i++) {
checkNextUniformUniform(-7, -3);
}
}
-
+
@Test
public void testNextUniformUniformMaximalInterval() {
for (int i = 0; i < 5; i++) {
checkNextUniformUniform(-Double.MAX_VALUE, Double.MAX_VALUE);
}
}
-
+
private void checkNextUniformUniform(double min, double max) {
// Set up bin bounds - min, binBound[0], ..., binBound[binCount-2], max
final int binCount = 5;
@@ -721,7 +721,7 @@ public class RandomDataGeneratorTest {
for (int i = 1; i < binCount - 1; i++) {
binBounds[i] = binBounds[i - 1] + binSize; // + instead of * to avoid overflow in extreme case
}
-
+
final Frequency freq = new Frequency();
for (int i = 0; i < smallSampleSize; i++) {
final double value = randomData.nextUniform(min, max);
@@ -733,7 +733,7 @@ public class RandomDataGeneratorTest {
}
freq.addValue(j);
}
-
+
final long[] observed = new long[binCount];
for (int i = 0; i < binCount; i++) {
observed[i] = freq.getCount(i);
@@ -742,7 +742,7 @@ public class RandomDataGeneratorTest {
for (int i = 0; i < binCount; i++) {
expected[i] = 1d / binCount;
}
-
+
TestUtils.assertChiSquareAccept(expected, observed, 0.01);
}
@@ -951,7 +951,7 @@ public class RandomDataGeneratorTest {
int[] perm = randomData.nextPermutation(3, 3);
observed[findPerm(p, perm)]++;
}
-
+
String[] labels = {"{0, 1, 2}", "{ 0, 2, 1 }", "{ 1, 0, 2 }",
"{ 1, 2, 0 }", "{ 2, 0, 1 }", "{ 2, 1, 0 }"};
TestUtils.assertChiSquareAccept(labels, expected, observed, 0.001);
@@ -1010,30 +1010,6 @@ public class RandomDataGeneratorTest {
}
@Test
- public void testNextInversionDeviate() {
- // Set the seed for the default random generator
- RandomGenerator rg = new Well19937c(100);
- RandomDataGenerator rdg = new RandomDataGenerator(rg);
- double[] quantiles = new double[10];
- for (int i = 0; i < 10; i++) {
- quantiles[i] = rdg.nextUniform(0, 1);
- }
- // Reseed again so the inversion generator gets the same sequence
- rg.setSeed(100);
- BetaDistribution betaDistribution = new BetaDistribution(rg, 2, 4,
- BetaDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY);
- /*
- * Generate a sequence of deviates using inversion - the distribution function
- * evaluated at the random value from the distribution should match the uniform
- * random value used to generate it, which is stored in the quantiles[] array.
- */
- for (int i = 0; i < 10; i++) {
- double value = betaDistribution.sample();
- Assert.assertEquals(betaDistribution.cumulativeProbability(value), quantiles[i], 10E-9);
- }
- }
-
- @Test
public void testNextBeta() {
double[] quartiles = TestUtils.getDistributionQuartiles(new BetaDistribution(2,5));
long[] counts = new long[4];