You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by tn...@apache.org on 2015/05/01 13:50:22 UTC
[math] [MATH-1220] Improve performance of ZipfDistribution.sample.
Thanks to Otmar Ertl.
Repository: commons-math
Updated Branches:
refs/heads/master 5597ed7ea -> 002276ea3
[MATH-1220] Improve performance of ZipfDistribution.sample. Thanks to Otmar Ertl.
Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo
Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/002276ea
Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/002276ea
Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/002276ea
Branch: refs/heads/master
Commit: 002276ea313fd880122502e9840b43f996acd537
Parents: 5597ed7
Author: Thomas Neidhart <th...@gmail.com>
Authored: Fri May 1 13:50:10 2015 +0200
Committer: Thomas Neidhart <th...@gmail.com>
Committed: Fri May 1 13:50:10 2015 +0200
----------------------------------------------------------------------
pom.xml | 3 +
src/changes/changes.xml | 3 +
.../math4/distribution/ZipfDistribution.java | 151 +++++++++++-
.../distribution/ZipfDistributionTest.java | 230 ++++++++++++++++++-
4 files changed, 379 insertions(+), 8 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/commons-math/blob/002276ea/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index ae022e8..b88cf28 100644
--- a/pom.xml
+++ b/pom.xml
@@ -207,6 +207,9 @@
<name>Ole Ersoy</name>
</contributor>
<contributor>
+ <name>Otmar Ertl</name>
+ </contributor>
+ <contributor>
<name>Ajo Fod</name>
</contributor>
<contributor>
http://git-wip-us.apache.org/repos/asf/commons-math/blob/002276ea/src/changes/changes.xml
----------------------------------------------------------------------
diff --git a/src/changes/changes.xml b/src/changes/changes.xml
index 27705e2..1fb151f 100644
--- a/src/changes/changes.xml
+++ b/src/changes/changes.xml
@@ -54,6 +54,9 @@ If the output is not quite correct, check for invisible trailing spaces!
</release>
<release version="4.0" date="XXXX-XX-XX" description="">
+ <action dev="tn" type="fix" issue="MATH-1220" due-to="Otmar Ertl"> <!-- backported to 3.6 -->
+ Improve performance of "ZipfDistribution#sample()" by using a rejection algorithm.
+ </action>
<action dev="tn" type="fix" issue="MATH-1153" due-to="Sergei Lebedev"> <!-- backported to 3.6 -->
Improve performance of "BetaDistribution#sample()" by using Cheng's algorithm.
</action>
http://git-wip-us.apache.org/repos/asf/commons-math/blob/002276ea/src/main/java/org/apache/commons/math4/distribution/ZipfDistribution.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math4/distribution/ZipfDistribution.java b/src/main/java/org/apache/commons/math4/distribution/ZipfDistribution.java
index 04bb522..366af69 100644
--- a/src/main/java/org/apache/commons/math4/distribution/ZipfDistribution.java
+++ b/src/main/java/org/apache/commons/math4/distribution/ZipfDistribution.java
@@ -43,6 +43,8 @@ public class ZipfDistribution extends AbstractIntegerDistribution {
private double numericalVariance = Double.NaN;
/** Whether or not the numerical variance has been calculated */
private boolean numericalVarianceIsCalculated = false;
+ /** The sampler to be used for the sample() method */
+ private transient ZipfRejectionSampler sampler;
/**
* Create a new Zipf distribution with the given number of elements and
@@ -265,5 +267,152 @@ public class ZipfDistribution extends AbstractIntegerDistribution {
public boolean isSupportConnected() {
return true;
}
-}
+    /**
+     * {@inheritDoc}
+     * <p>
+     * Random values are generated by rejection sampling from an instrumental
+     * distribution g(k), which is defined as g(1) := 1 and
+     * g(k) := I(-s,k-1/2,k+1/2) for k larger than 1. Here s denotes the
+     * exponent of the Zipf distribution and I(r,a,b) is the integral of x^r
+     * for x from a to b.
+     * <p>
+     * Since 1/x^s is a convex function, Jensen's inequality gives
+     * I(-s,k-1/2,k+1/2) >= 1/k^s for all positive k and non-negative s.
+     * The weight of the value 1 is defined differently in order to limit the
+     * rejection rate for large exponents s.
+     */
+    @Override
+    public int sample() {
+        // Lazily create the sampler on first use.
+        ZipfRejectionSampler rejectionSampler = sampler;
+        if (rejectionSampler == null) {
+            rejectionSampler = new ZipfRejectionSampler(numberOfElements, exponent);
+            sampler = rejectionSampler;
+        }
+        return rejectionSampler.sample(random);
+    }
+
+    /**
+     * Utility class implementing a rejection sampling method for a discrete,
+     * bounded Zipf distribution.
+     * <p>
+     * The instrumental distribution assigns the weight 1 to the value 1 and the
+     * weight I(-s,k-1/2,k+1/2) to every value k larger than 1, where s is the
+     * exponent and I(r,a,b) is the integral of x^r for x from a to b.
+     *
+     * @since 3.6
+     */
+    static final class ZipfRejectionSampler {
+
+        /** Number of elements. */
+        private final int numberOfElements;
+        /** Exponent parameter of the distribution. */
+        private final double exponent;
+        /** Cached tail weight of instrumental distribution used for rejection sampling, NaN until first computed. */
+        private double instrumentalDistributionTailWeight = Double.NaN;
+
+        /**
+         * Stores the distribution parameters; no validation is done here.
+         *
+         * @param numberOfElements the number of elements
+         * @param exponent the exponent parameter of the distribution
+         */
+        ZipfRejectionSampler(final int numberOfElements, final double exponent) {
+            this.numberOfElements = numberOfElements;
+            this.exponent = exponent;
+        }
+
+        /**
+         * Generates one random value by rejection sampling.
+         *
+         * @param random the source of uniformly distributed random values
+         * @return either 1 or an accepted sample from the instrumental distribution tail
+         */
+        int sample(final RandomGenerator random) {
+            if (Double.isNaN(instrumentalDistributionTailWeight)) {
+                // instrumental weight of all values larger than 1
+                instrumentalDistributionTailWeight = integratePowerFunction(-exponent, 1.5, numberOfElements + 0.5);
+            }
+
+            while (true) {
+                // the total instrumental weight is the tail weight plus the weight 1 of the value 1
+                final double randomValue = random.nextDouble() * (instrumentalDistributionTailWeight + 1.);
+                if (randomValue < instrumentalDistributionTailWeight) {
+                    final double q = randomValue / instrumentalDistributionTailWeight;
+                    final int sample = sampleFromInstrumentalDistributionTail(q);
+                    if (random.nextDouble() < acceptanceRateForTailSample(sample)) {
+                        return sample;
+                    }
+                } else {
+                    // the value 1 is never rejected
+                    return 1;
+                }
+            }
+        }
+
+        /**
+         * Returns a sample from the instrumental distribution tail for a given
+         * uniformly distributed random value.
+         *
+         * @param q a uniformly distributed random value taken from [0,1]
+         * @return a sample in the range [2, {@link #numberOfElements}]
+         */
+        int sampleFromInstrumentalDistributionTail(double q) {
+            final double a = 1.5;
+            final double b = numberOfElements + 0.5;
+            final double logBdivA = FastMath.log(b / a);
+
+            // map q to a point of the interval [a, b] and round it to the nearest integer
+            final int result = (int) (a * FastMath.exp(logBdivA * helper1(q, logBdivA * (1. - exponent))) + 0.5);
+            // clamp the result to the documented range
+            if (result < 2) {
+                return 2;
+            }
+            if (result > numberOfElements) {
+                return numberOfElements;
+            }
+            return result;
+        }
+
+        /**
+         * Helper function that calculates log((1-q)+q*exp(x))/x.
+         * <p>
+         * A Taylor series expansion is used, if x is close to 0.
+         *
+         * @param q a value in the range [0,1]
+         * @param x the argument of the exponential function
+         * @return log((1-q)+q*exp(x))/x
+         */
+        static double helper1(final double q, final double x) {
+            if (FastMath.abs(x) > 1e-8) {
+                return FastMath.log((1. - q) + q * FastMath.exp(x)) / x;
+            } else {
+                // Expansion up to third order in x:
+                //   q + q(1-q) x/2 + q(1-q)(1-2q) x^2/6 + q(1-q)(1-6q(1-q)) x^3/24
+                // written in nested form; note that 1-6q(1-q) = 6q(q-1)+1.
+                return q * (1. + (1. / 2.) * x * (1. - q) * (1. + (1. / 3.) * x * ((1. - 2. * q) + (1. / 4.) * x * (6. * q * (q - 1.) + 1.))));
+            }
+        }
+
+        /**
+         * Helper function to calculate (exp(x)-1)/x.
+         * <p>
+         * A Taylor series expansion is used, if x is close to 0.
+         *
+         * @param x the argument
+         * @return (exp(x)-1)/x if x is non-zero, 1 if x=0
+         */
+        static double helper2(final double x) {
+            if (FastMath.abs(x) > 1e-8) {
+                return FastMath.expm1(x) / x;
+            } else {
+                // 1 + x/2 + x^2/6 + x^3/24 in nested form
+                return 1. + x * (1. / 2.) * (1. + x * (1. / 3.) * (1. + x * (1. / 4.)));
+            }
+        }
+
+        /**
+         * Integrates the power function x^r from x=a to b.
+         *
+         * @param r the exponent
+         * @param a the integral lower bound
+         * @param b the integral upper bound
+         * @return the calculated integral value
+         */
+        static double integratePowerFunction(final double r, final double a, final double b) {
+            final double logA = FastMath.log(a);
+            final double logBdivA = FastMath.log(b / a);
+            // a^(1+r) * (exp((1+r)*log(b/a)) - 1) / (1+r); helper2 covers the case 1+r close to 0
+            return FastMath.exp((1. + r) * logA) * helper2((1. + r) * logBdivA) * logBdivA;
+        }
+
+        /**
+         * Calculates the acceptance rate for a sample taken from the tail of the instrumental distribution.
+         * <p>
+         * The acceptance rate is given by the ratio k^(-s)/I(-s,k-0.5, k+0.5)
+         * where I(r,a,b) is the integral of x^r for x from a to b.
+         *
+         * @param k the value which has been sampled using the instrumental distribution
+         * @return the acceptance rate
+         */
+        double acceptanceRateForTailSample(int k) {
+            // a = log(k/(k-0.5)) and b = log((k+0.5)/(k-0.5))
+            final double a = FastMath.log1p(1. / (2. * k - 1.));
+            final double b = FastMath.log1p(2. / (2. * k - 1.));
+            return FastMath.exp((1. - exponent) * a) / (k * b * helper2((1. - exponent) * b));
+        }
+    }
+}
http://git-wip-us.apache.org/repos/asf/commons-math/blob/002276ea/src/test/java/org/apache/commons/math4/distribution/ZipfDistributionTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/commons/math4/distribution/ZipfDistributionTest.java b/src/test/java/org/apache/commons/math4/distribution/ZipfDistributionTest.java
index 661eb33..44bad94 100644
--- a/src/test/java/org/apache/commons/math4/distribution/ZipfDistributionTest.java
+++ b/src/test/java/org/apache/commons/math4/distribution/ZipfDistributionTest.java
@@ -17,18 +17,28 @@
package org.apache.commons.math4.distribution;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import org.apache.commons.math4.TestUtils;
+import org.apache.commons.math4.analysis.UnivariateFunction;
+import org.apache.commons.math4.analysis.integration.SimpsonIntegrator;
import org.apache.commons.math4.distribution.IntegerDistribution;
import org.apache.commons.math4.distribution.ZipfDistribution;
+import org.apache.commons.math4.distribution.ZipfDistribution.ZipfRejectionSampler;
import org.apache.commons.math4.exception.NotStrictlyPositiveException;
+import org.apache.commons.math4.random.AbstractRandomGenerator;
+import org.apache.commons.math4.random.RandomGenerator;
+import org.apache.commons.math4.random.Well1024a;
import org.apache.commons.math4.util.FastMath;
import org.junit.Assert;
+import org.junit.Ignore;
import org.junit.Test;
/**
* Test cases for {@link ZipfDistribution}.
- * Extends IntegerDistributionAbstractTest. See class javadoc for
- * IntegerDistributionAbstractTest for details.
- *
+ * Extends IntegerDistributionAbstractTest.
+ * See class javadoc for IntegerDistributionAbstractTest for details.
*/
public class ZipfDistributionTest extends IntegerDistributionAbstractTest {
@@ -38,7 +48,7 @@ public class ZipfDistributionTest extends IntegerDistributionAbstractTest {
public ZipfDistributionTest() {
setTolerance(1e-12);
}
-
+
@Test(expected=NotStrictlyPositiveException.class)
public void testPreconditions1() {
new ZipfDistribution(0, 1);
@@ -63,9 +73,9 @@ public class ZipfDistributionTest extends IntegerDistributionAbstractTest {
return new int[] {-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
}
- /**
+ /**
* Creates the default probability density test expected values.
- * Reference values are from R, version 2.15.3 (VGAM package 0.9-0).
+ * Reference values are from R, version 2.15.3 (VGAM package 0.9-0).
*/
@Override
public double[] makeDensityTestValues() {
@@ -73,7 +83,7 @@ public class ZipfDistributionTest extends IntegerDistributionAbstractTest {
0.0569028586912, 0.0487738788782, 0.0426771440184, 0.0379352391275, 0.0341417152147, 0};
}
- /**
+ /**
* Creates the default logarithmic probability density test expected values.
* Reference values are from R, version 2.14.1.
*/
@@ -120,4 +130,210 @@ public class ZipfDistributionTest extends IntegerDistributionAbstractTest {
Assert.assertEquals(dist.getNumericalMean(), FastMath.sqrt(2), tol);
Assert.assertEquals(dist.getNumericalVariance(), 0.24264068711928521, tol);
}
+
+
+    /**
+     * Checks with a chi-square test that the sampled frequencies agree with the
+     * Zipf probabilities for various numbers of points and exponents.
+     */
+    @Test
+    public void testSamplingExtended() {
+        final int sampleSize = 1000;
+        final double significance = 0.001;
+
+        final int[] numPointsValues = {
+            2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25, 30, 35, 40, 45, 50, 60, 70, 80, 90, 100
+        };
+        final double[] exponentValues = {
+            1e-10, 1e-9, 1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1,
+            1. - 1e-9, 1.0, 1. + 1e-9, 1.1, 1.2, 1.3, 1.5, 1.6, 1.7, 1.8, 2.0,
+            2.5, 3.0, 4., 5., 6., 7., 8., 9., 10., 20., 30.
+        };
+
+        for (int numPoints : numPointsValues) {
+            for (double exponent : exponentValues) {
+                // unnormalized weights k^(-exponent), summed starting with the largest k
+                final double[] weights = new double[numPoints];
+                double weightSum = 0.;
+                for (int idx = numPoints - 1; idx >= 0; --idx) {
+                    weights[idx] = Math.pow(idx + 1, -exponent);
+                    weightSum += weights[idx];
+                }
+
+                // Use a fixed seed: the test is expected to fail for more than 50% of
+                // all seeds because each test case can fail with probability 0.001;
+                // the chance that all test cases do not fail is
+                // 0.999^(32*22) = 0.49442874426
+                final ZipfDistribution distribution = new ZipfDistribution(numPoints, exponent);
+                distribution.reseedRandomGenerator(6);
+
+                final double[] expectedCounts = new double[numPoints];
+                for (int idx = 0; idx < numPoints; ++idx) {
+                    expectedCounts[idx] = sampleSize * (weights[idx] / weightSum);
+                }
+
+                final long[] observedCounts = new long[numPoints];
+                for (int drawn : distribution.sample(sampleSize)) {
+                    observedCounts[drawn - 1] += 1;
+                }
+                TestUtils.assertChiSquareAccept(expectedCounts, observedCounts, significance);
+            }
+        }
+    }
+
+    /**
+     * Compares the closed-form integral of x^r with a numerical Simpson
+     * integration for every ordered pair of limits and several exponents.
+     */
+    @Test
+    public void testSamplerIntegratePowerFunction() {
+        final double relativeTolerance = 1e-6;
+        final double[] exponents = {
+            -1e-5, -1e-4, -1e-3, -1e-2, -1e-1, -1e0, -1e1
+        };
+        final double[] limits = {
+            0.5, 1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6.0, 6.5, 7.0,
+            7.5, 8.0, 8.5, 9.0, 9.5, 10.0
+        };
+
+        for (final double exponent : exponents) {
+            final UnivariateFunction powerFunction = new UnivariateFunction() {
+                @Override
+                public double value(double x) {
+                    return Math.pow(x, exponent);
+                }
+            };
+            for (int lo = 0; lo < limits.length; ++lo) {
+                for (int hi = lo + 1; hi < limits.length; ++hi) {
+                    final double numerical =
+                        new SimpsonIntegrator().integrate(10000, powerFunction, limits[lo], limits[hi]);
+                    final double analytical =
+                        ZipfRejectionSampler.integratePowerFunction(exponent, limits[lo], limits[hi]);
+                    assertEquals(numerical, analytical, (numerical + analytical) * relativeTolerance);
+                }
+            }
+        }
+    }
+
+    /**
+     * Checks that the acceptance rate multiplied by the instrumental weight
+     * I(-s,k-1/2,k+1/2) reproduces the Zipf weight k^(-s), and that the
+     * acceptance rate itself does not exceed 1.
+     */
+    @Test
+    public void testSamplerAcceptanceRate() {
+        final double tol = 1e-12;
+        final double[] exponents = {
+            1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 2e0, 5e0, 1e1, 1e2, 1e3
+        };
+        final int[] values = {
+            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14
+        };
+        final int numberOfElements = 1000;
+        for (final double exponent : exponents) {
+            ZipfRejectionSampler sampler = new ZipfRejectionSampler(numberOfElements, exponent);
+            for (final int value : values) {
+                double expected = FastMath.pow(value, -exponent);
+                double acceptanceRate = sampler.acceptanceRateForTailSample(value);
+                double result = acceptanceRate *
+                    ZipfRejectionSampler.integratePowerFunction(-exponent, value - 0.5, value + 0.5);
+                TestUtils.assertRelativelyEquals(expected, result, tol);
+                // Test Jensen's inequality on the rate itself: the product "result"
+                // equals value^(-exponent), which is <= 1 regardless of the inequality.
+                assertTrue(acceptanceRate <= 1.);
+            }
+        }
+    }
+
+    /**
+     * Checks that the value returned for a given q brackets q: the normalized
+     * tail weight below the value's lower edge must not exceed q, and the
+     * weight up to its upper edge must not fall below q (within tolerance).
+     */
+    @Test
+    public void testSamplerInverseInstrumentalDistribution() {
+        final double tol = 1e-14;
+        final double[] exponentValues = {
+            1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 2E0, 3e0, 4e0, 5e0, 6., 7., 8., 9., 10., 50.
+        };
+        final double[] qValues = {
+            0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0
+        };
+        final int[] numberOfElementsValues = {
+            2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 40, 50, 100
+        };
+
+        for (final double exponent : exponentValues) {
+            for (final int numberOfElements : numberOfElementsValues) {
+                final ZipfRejectionSampler sampler = new ZipfRejectionSampler(numberOfElements, exponent);
+                // weight of the whole tail, used for normalization
+                final double tailWeight =
+                    ZipfRejectionSampler.integratePowerFunction(-exponent, 1.5, numberOfElements + 0.5);
+                for (final double q : qValues) {
+                    final int k = sampler.sampleFromInstrumentalDistributionTail(q);
+                    final double fractionBelow =
+                        ZipfRejectionSampler.integratePowerFunction(-exponent, 1.5, k - 0.5) / tailWeight;
+                    final double fractionUpTo =
+                        ZipfRejectionSampler.integratePowerFunction(-exponent, 1.5, k + 0.5) / tailWeight;
+                    assertTrue(fractionBelow <= q * (1. + tol));
+                    assertTrue(fractionUpTo >= q * (1. - tol));
+                }
+            }
+        }
+    }
+
+    /**
+     * Checks helper1 by inverting it: exp(helper1(q,x)*x) must reproduce
+     * (1-q)+q*exp(x) for a wide range of q and x.
+     */
+    @Test
+    public void testSamplerHelper1() {
+        final double tol = 1e-14;
+        final double[] qValues = {
+            0., 1e-12, 1e-11, 1e-10, 1e-9, 9e-9, 1e-8, 1e-7, 1e-6, 1e-5, 1e-4,
+            1e-3, 1e-2, 1e-1, 1e0
+        };
+        final double[] xValues = {
+            -Double.MAX_VALUE, -1e10, -1e9, -1e8, -1e7, -1e6, -1e5, -1e4, -1e3,
+            -1e2, -1e1, -1e0, -1e-1, -1e-2, -1e-3, -1e-4, -1e-5, -1e-6, -1e-7,
+            -1e-8, -1e-9, -1e-10, -Double.MIN_VALUE, 0.0, Double.MIN_VALUE,
+            1e-10, 1e-9, 1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0,
+            1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9, 1e10, Double.MAX_VALUE
+        };
+
+        for (final double q : qValues) {
+            for (final double x : xValues) {
+                final double expected = (1. - q) + q * Math.exp(x);
+                final double helperValue = ZipfRejectionSampler.helper1(q, x);
+                final double reconstructed = FastMath.exp(helperValue * x);
+                TestUtils.assertRelativelyEquals(expected, reconstructed, tol);
+            }
+        }
+    }
+
+    /**
+     * Checks helper2 by multiplying it with its argument: helper2(x)*x must
+     * equal expm1(x) for arguments on both sides of the series threshold.
+     */
+    @Test
+    public void testSamplerHelper2() {
+        final double tol = 1e-12;
+        final double[] testValues = {
+            -1e0, -1e-1, -1e-2, -1e-3, -1e-4, -1e-5, -1e-6, -1e-7, -1e-8,
+            -1e-9, -1e-10, -1e-11, 0., 1e-11, 1e-10, 1e-9, 1e-8, 1e-7, 1e-6,
+            1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0
+        };
+        for (final double x : testValues) {
+            final double actual = ZipfRejectionSampler.helper2(x) * x;
+            TestUtils.assertRelativelyEquals(FastMath.expm1(x), actual, tol);
+        }
+    }
+
+    /**
+     * Manual benchmark (ignored by default): for each combination of element
+     * count and exponent, prints the average number of random doubles consumed
+     * per sample and the elapsed wall-clock time.
+     */
+    @Ignore
+    @Test
+    public void testSamplerPerformance() {
+        final int[] numPointsValues = {1, 2, 5, 10, 100, 1000, 10000};
+        final double[] exponentValues = {1e-3, 1e-2, 1e-1, 1., 2., 5., 10.};
+        final int numGeneratedSamples = 1000000;
+
+        // accumulated and printed at the end so that the samples are actually used
+        long sum = 0;
+
+        for (int numPoints : numPointsValues) {
+            for (double exponent : exponentValues) {
+                final long start = System.currentTimeMillis();
+                // single-element array so that the anonymous class can update the counter
+                final int[] randomNumberCounter = new int[1];
+
+                final RandomGenerator countingGenerator = new AbstractRandomGenerator() {
+
+                    private final RandomGenerator r = new Well1024a(0L);
+
+                    @Override
+                    public void setSeed(long seed) {
+                    }
+
+                    @Override
+                    public double nextDouble() {
+                        ++randomNumberCounter[0];
+                        return r.nextDouble();
+                    }
+                };
+
+                final ZipfDistribution distribution = new ZipfDistribution(countingGenerator, numPoints, exponent);
+                for (int i = 0; i < numGeneratedSamples; ++i) {
+                    sum += distribution.sample();
+                }
+
+                final long end = System.currentTimeMillis();
+                final double avgConsumed = (double) (randomNumberCounter[0]) / numGeneratedSamples;
+                final double seconds = (end - start) / 1000.;
+                System.out.println("n = " + numPoints + ", exponent = " + exponent + ", avg number consumed random values = " + avgConsumed + ", measured time = " + seconds + "s");
+            }
+        }
+        System.out.println(sum);
+    }
+
}