You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by tn...@apache.org on 2015/05/01 13:25:03 UTC

[math] [MATH-1220] Improve performance of ZipfDistribution.sample(). Thanks to Otmar Ertl.

Repository: commons-math
Updated Branches:
  refs/heads/MATH_3_X ab2b01168 -> 321269ed9


[MATH-1220] Improve performance of ZipfDistribution.sample(). Thanks to Otmar Ertl.


Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo
Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/321269ed
Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/321269ed
Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/321269ed

Branch: refs/heads/MATH_3_X
Commit: 321269ed9aa84d15b18296ee6e73d53489efb622
Parents: ab2b011
Author: Thomas Neidhart <th...@gmail.com>
Authored: Fri May 1 13:24:48 2015 +0200
Committer: Thomas Neidhart <th...@gmail.com>
Committed: Fri May 1 13:24:48 2015 +0200

----------------------------------------------------------------------
 pom.xml                                         |   3 +
 src/changes/changes.xml                         |   3 +
 .../math3/distribution/ZipfDistribution.java    | 152 ++++++++++++-
 .../distribution/ZipfDistributionTest.java      | 228 ++++++++++++++++++-
 4 files changed, 378 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/commons-math/blob/321269ed/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index 7c55d33..6820934 100644
--- a/pom.xml
+++ b/pom.xml
@@ -207,6 +207,9 @@
       <name>Ole Ersoy</name>
     </contributor>
     <contributor>
+      <name>Otmar Ertl</name>
+    </contributor>
+    <contributor>
       <name>Ajo Fod</name>
     </contributor>
     <contributor>

http://git-wip-us.apache.org/repos/asf/commons-math/blob/321269ed/src/changes/changes.xml
----------------------------------------------------------------------
diff --git a/src/changes/changes.xml b/src/changes/changes.xml
index 2e818b2..274cb50 100644
--- a/src/changes/changes.xml
+++ b/src/changes/changes.xml
@@ -51,6 +51,9 @@ If the output is not quite correct, check for invisible trailing spaces!
   </properties>
   <body>
     <release version="3.6" date="XXXX-XX-XX" description="">
+      <action dev="tn" type="fix" issue="MATH-1220" due-to="Otmar Ertl">
+        Improve performance of "ZipfDistribution#sample()" by using a rejection algorithm.
+      </action>
       <action dev="tn" type="fix" issue="MATH-1153" due-to="Sergei Lebedev">
         Improve performance of "BetaDistribution#sample()" by using Cheng's algorithm.
       </action>

http://git-wip-us.apache.org/repos/asf/commons-math/blob/321269ed/src/main/java/org/apache/commons/math3/distribution/ZipfDistribution.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math3/distribution/ZipfDistribution.java b/src/main/java/org/apache/commons/math3/distribution/ZipfDistribution.java
index 18cb2f4..3755407 100644
--- a/src/main/java/org/apache/commons/math3/distribution/ZipfDistribution.java
+++ b/src/main/java/org/apache/commons/math3/distribution/ZipfDistribution.java
@@ -43,6 +43,8 @@ public class ZipfDistribution extends AbstractIntegerDistribution {
     private double numericalVariance = Double.NaN;
     /** Whether or not the numerical variance has been calculated */
     private boolean numericalVarianceIsCalculated = false;
+    /** Sampler used by {@link #sample()}; created on first use, transient so it is not serialized. */
+    private transient ZipfRejectionSampler sampler;
 
     /**
      * Create a new Zipf distribution with the given number of elements and
@@ -258,5 +260,153 @@ public class ZipfDistribution extends AbstractIntegerDistribution {
     public boolean isSupportConnected() {
         return true;
     }
-}
 
+    /**
+     * {@inheritDoc}
+     * <p>
+     * An instrumental distribution g(k) is used to generate random values by
+     * rejection sampling. g(k) is defined as g(1) := 1 and g(k) := I(-s,k-1/2,k+1/2)
+     * for k larger than 1, where s denotes the exponent of the Zipf distribution
+     * and I(r,a,b) is the integral of x^r for x from a to b.
+     * <p>
+     * Since x^(-s) is convex for x > 0 and s >= 0, Jensen's inequality gives
+     * I(-s,k-1/2,k+1/2) >= 1/k^s for all positive k. To limit the rejection rate
+     * for large exponents s, g(1) is set to 1 rather than to I(-s,1/2,3/2).
+     */
+    @Override
+    public int sample() {
+        // Created on first use; the field is transient, so it is not serialized.
+        if (sampler == null) {
+            sampler = new ZipfRejectionSampler(numberOfElements, exponent);
+        }
+        return sampler.sample(random);
+    }
+
+    /**
+     * Utility class implementing a rejection sampling method for a discrete,
+     * bounded Zipf distribution.
+     *
+     * @since 3.6
+     */
+    static final class ZipfRejectionSampler {
+        /** Number of elements. */
+        private final int numberOfElements;
+        /** Exponent parameter of the distribution. */
+        private final double exponent;
+        /** Tail weight I(-s,1.5,n+0.5) of the instrumental distribution, computed once in the constructor. */
+        private final double instrumentalDistributionTailWeight;
+
+        /** Creates a sampler for n = numberOfElements elements and exponent s = exponent. */
+        ZipfRejectionSampler(final int numberOfElements, final double exponent) {
+            this.numberOfElements = numberOfElements;
+            this.exponent = exponent;
+            this.instrumentalDistributionTailWeight = integratePowerFunction(-exponent, 1.5, numberOfElements + 0.5);
+        }
+
+        /**
+         * Draws a value in [1, {@link #numberOfElements}] by rejection sampling, reading uniform deviates
+         * from {@code random}; a tail candidate k is accepted with rate k^(-s)/I(-s,k-0.5,k+0.5).
+         */
+        int sample(final RandomGenerator random) {
+            while (true) {
+                final double randomValue = random.nextDouble() * (instrumentalDistributionTailWeight + 1.);
+                if (randomValue < instrumentalDistributionTailWeight) {
+                    final double q = randomValue / instrumentalDistributionTailWeight;
+                    final int sample = sampleFromInstrumentalDistributionTail(q);
+                    if (random.nextDouble() < acceptanceRateForTailSample(sample)) {
+                        return sample;
+                    }
+                } else {
+                    return 1; // g(1) = 1 = 1^(-s): value 1 is always accepted
+                }
+            }
+        }
+
+        /**
+         * Returns a sample from the instrumental distribution tail for a given
+         * uniformly distributed random value.
+         *
+         * @param q a uniformly distributed random value taken from [0,1]
+         * @return a sample in the range [2, {@link #numberOfElements}]
+         */
+        int sampleFromInstrumentalDistributionTail(double q) {
+            final double a = 1.5;
+            final double b = numberOfElements + 0.5;
+            final double logBdivA = FastMath.log(b / a);
+            // Inverse CDF of the density proportional to x^(-s) on [a, b], rounded to the nearest integer.
+            final int result = (int) (a * FastMath.exp(logBdivA * helper1(q, logBdivA * (1. - exponent))) + 0.5);
+            if (result < 2) {
+                return 2;
+            }
+            if (result > numberOfElements) {
+                return numberOfElements;
+            }
+            return result;
+        }
+
+        /**
+         * Helper function that calculates log((1-q)+q*exp(x))/x.
+         * <p>
+         * A Taylor series expansion is used, if x is close to 0.
+         *
+         * @param q a value in the range [0,1]
+         * @param x the exponent argument; for |x| <= 1e-8 the series is used
+         * @return log((1-q)+q*exp(x))/x
+         */
+        static double helper1(final double q, final double x) {
+            if (FastMath.abs(x) > 1e-8) {
+                return FastMath.log((1.-q)+q*FastMath.exp(x))/x;
+            } else {
+                // Horner form of q + q(1-q)x/2 + q(1-q)(1-2q)x^2/6 + q(1-q)(1-6q+6q^2)x^3/24.
+                return q*(1.+(1./2.)*x*(1.-q)*(1.+(1./3.)*x*((1.-2.*q) + (1./4.)*x*(6.*q*(q-1.)+1.))));
+            }
+        }
+
+        /**
+         * Helper function to calculate (exp(x)-1)/x.
+         * <p>
+         * A Taylor series expansion is used, if x is close to 0.
+         *
+         * @param x the argument; for |x| <= 1e-8 the series 1 + x/2 + x^2/6 + x^3/24 is used
+         * @return (exp(x)-1)/x if x is non-zero, 1 if x=0
+         */
+        static double helper2(final double x) {
+            if (FastMath.abs(x) > 1e-8) {
+                return FastMath.expm1(x)/x;
+            } else {
+                return 1.+x*(1./2.)*(1.+x*(1./3.)*(1.+x*(1./4.)));
+            }
+        }
+
+        /**
+         * Integrates the power function x^r from x=a to b.
+         *
+         * @param r the exponent
+         * @param a the integral lower bound (must be positive)
+         * @param b the integral upper bound (must be positive)
+         * @return the calculated integral value
+         */
+        static double integratePowerFunction(final double r, final double a, final double b) {
+            final double logA = FastMath.log(a);
+            final double logBdivA = FastMath.log(b/a);
+            // a^(r+1) * ((b/a)^(r+1) - 1) / (r+1); helper2 keeps this stable for r close to -1.
+            return FastMath.exp((1.+r)*logA)*helper2((1.+r)*logBdivA)*logBdivA;
+        }
+
+        /**
+         * Calculates the acceptance rate for a sample taken from the tail of the instrumental distribution.
+         * <p>
+         * The acceptance rate is given by the ratio k^(-s)/I(-s,k-0.5, k+0.5)
+         * where I(r,a,b) is the integral of x^r for x from a to b.
+         *
+         * @param k the value which has been sampled using the instrumental distribution
+         * @return the acceptance rate
+         */
+        double acceptanceRateForTailSample(int k) {
+            final double a = FastMath.log1p(1./(2.*k-1.)); // log(k / (k - 0.5))
+            final double b = FastMath.log1p(2./(2.*k-1.)); // log((k + 0.5) / (k - 0.5))
+            return FastMath.exp((1.-exponent)*a)/(k*b*helper2((1.-exponent)*b));
+        }
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/commons-math/blob/321269ed/src/test/java/org/apache/commons/math3/distribution/ZipfDistributionTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/commons/math3/distribution/ZipfDistributionTest.java b/src/test/java/org/apache/commons/math3/distribution/ZipfDistributionTest.java
index 06ec3c4..3c177ef 100644
--- a/src/test/java/org/apache/commons/math3/distribution/ZipfDistributionTest.java
+++ b/src/test/java/org/apache/commons/math3/distribution/ZipfDistributionTest.java
@@ -17,17 +17,26 @@
 
 package org.apache.commons.math3.distribution;
 
-import org.apache.commons.math3.exception.NotStrictlyPositiveException;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
 
+import org.apache.commons.math3.TestUtils;
+import org.apache.commons.math3.analysis.UnivariateFunction;
+import org.apache.commons.math3.analysis.integration.SimpsonIntegrator;
+import org.apache.commons.math3.distribution.ZipfDistribution.ZipfRejectionSampler;
+import org.apache.commons.math3.exception.NotStrictlyPositiveException;
+import org.apache.commons.math3.random.AbstractRandomGenerator;
+import org.apache.commons.math3.random.RandomGenerator;
+import org.apache.commons.math3.random.Well1024a;
 import org.apache.commons.math3.util.FastMath;
 import org.junit.Assert;
+import org.junit.Ignore;
 import org.junit.Test;
 
 /**
  * Test cases for {@link ZipfDistribution}.
- * Extends IntegerDistributionAbstractTest.  See class javadoc for
- * IntegerDistributionAbstractTest for details.
- *
+ * Extends IntegerDistributionAbstractTest.
+ * See class javadoc for IntegerDistributionAbstractTest for details.
  */
 public class ZipfDistributionTest extends IntegerDistributionAbstractTest {
 
@@ -37,7 +46,7 @@ public class ZipfDistributionTest extends IntegerDistributionAbstractTest {
     public ZipfDistributionTest() {
         setTolerance(1e-12);
     }
-    
+
     @Test(expected=NotStrictlyPositiveException.class)
     public void testPreconditions1() {
         new ZipfDistribution(0, 1);
@@ -62,7 +71,7 @@ public class ZipfDistributionTest extends IntegerDistributionAbstractTest {
         return new int[] {-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
     }
 
-    /** 
+    /**
      * Creates the default probability density test expected values.
      *  Reference values are from R, version 2.15.3 (VGAM package 0.9-0).
      */
@@ -72,7 +81,7 @@ public class ZipfDistributionTest extends IntegerDistributionAbstractTest {
             0.0569028586912, 0.0487738788782, 0.0426771440184, 0.0379352391275, 0.0341417152147, 0};
     }
 
-    /** 
+    /**
      * Creates the default logarithmic probability density test expected values.
      * Reference values are from R, version 2.14.1.
      */
@@ -119,4 +128,209 @@ public class ZipfDistributionTest extends IntegerDistributionAbstractTest {
         Assert.assertEquals(dist.getNumericalMean(), FastMath.sqrt(2), tol);
         Assert.assertEquals(dist.getNumericalVariance(), 0.24264068711928521, tol);
     }
+
+    /**
+     * Checks, with a chi-square goodness-of-fit test, that samples drawn for many
+     * combinations of element count and exponent follow the Zipf probabilities.
+     */
+    @Test
+    public void testSamplingExtended() {
+        final int sampleSize = 1000;
+        final int[] numPointsValues = {
+            2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25, 30, 35, 40, 45, 50, 60, 70, 80, 90, 100
+        };
+        final double[] exponentValues = {
+            1e-10, 1e-9, 1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1,
+            1. - 1e-9, 1.0, 1. + 1e-9, 1.1, 1.2, 1.3, 1.5, 1.6, 1.7, 1.8, 2.0,
+            2.5, 3.0, 4., 5., 6., 7., 8., 9., 10., 20., 30.
+        };
+
+        for (final int n : numPointsValues) {
+            for (final double s : exponentValues) {
+                // Unnormalized Zipf weights k^(-s), accumulated from the largest k downwards.
+                final double[] weights = new double[n];
+                double totalWeight = 0.;
+                for (int k = n; k >= 1; --k) {
+                    weights[k - 1] = Math.pow(k, -s);
+                    totalWeight += weights[k - 1];
+                }
+                final double[] expectedCounts = new double[n];
+                for (int k = 0; k < n; ++k) {
+                    expectedCounts[k] = sampleSize * (weights[k] / totalWeight);
+                }
+                // Fixed seed: each of the 32*22 cases fails with probability 0.001, so all of
+                // them pass for only 0.999^(32*22) = 0.494 of the seeds.
+                final ZipfDistribution distribution = new ZipfDistribution(n, s);
+                distribution.reseedRandomGenerator(6);
+                final long[] observedCounts = new long[n];
+                for (final int value : distribution.sample(sampleSize)) {
+                    observedCounts[value - 1]++;
+                }
+                TestUtils.assertChiSquareAccept(expectedCounts, observedCounts, 0.001);
+            }
+        }
+    }
+
+    /** Compares integratePowerFunction with Simpson-rule quadrature of x^exponent. */
+    @Test
+    public void testSamplerIntegratePowerFunction() {
+        final double tol = 1e-6;
+        final double[] exponents = {-1e-5, -1e-4, -1e-3, -1e-2, -1e-1, -1e0, -1e1};
+        final double[] limits = {
+            0.5, 1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6.0, 6.5, 7.0,
+            7.5, 8.0, 8.5, 9.0, 9.5, 10.0
+        };
+
+        for (final double exponent : exponents) {
+            // Reference integrand x^exponent for the numerical quadrature.
+            final UnivariateFunction powerFunction = new UnivariateFunction() {
+                public double value(double x) {
+                    return Math.pow(x, exponent);
+                }
+            };
+            for (int i = 0; i < limits.length; ++i) {
+                for (int j = i + 1; j < limits.length; ++j) {
+                    final double a = limits[i];
+                    final double b = limits[j];
+                    final double numerical =
+                        new SimpsonIntegrator().integrate(10000, powerFunction, a, b);
+                    final double closedForm = ZipfRejectionSampler.integratePowerFunction(exponent, a, b);
+                    // Relative tolerance; both integrals are positive on these positive limits.
+                    assertEquals(numerical, closedForm, (numerical + closedForm) * tol);
+                }
+            }
+        }
+    }
+
+    /**
+     * Checks that acceptanceRateForTailSample(k) * I(-s, k-0.5, k+0.5) equals k^(-s) and that,
+     * by Jensen's inequality, the acceptance rate itself never exceeds 1.
+     */
+    @Test
+    public void testSamplerAcceptanceRate() {
+        final double tol = 1e-12;
+        final double[] exponents = {1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 2e0, 5e0, 1e1, 1e2, 1e3};
+        final int[] values = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14};
+        final int numberOfElements = 1000;
+        for (final double exponent : exponents) {
+            ZipfRejectionSampler sampler = new ZipfRejectionSampler(numberOfElements, exponent);
+            for (final int value : values) {
+                final double acceptanceRate = sampler.acceptanceRateForTailSample(value);
+                final double result = acceptanceRate *
+                                ZipfRejectionSampler.integratePowerFunction(-exponent, value - 0.5, value + 0.5);
+                TestUtils.assertRelativelyEquals(FastMath.pow(value, -exponent), result, tol);
+                assertTrue(acceptanceRate <= 1.); // Jensen: k^(-s) <= I(-s, k-0.5, k+0.5)
+            }
+        }
+    }
+
+    /**
+     * Checks that sampleFromInstrumentalDistributionTail(q) returns the k whose bucket
+     * [k-0.5, k+0.5] of the tail CDF on [1.5, n+0.5] contains q (up to a relative tolerance).
+     */
+    @Test
+    public void testSamplerInverseInstrumentalDistribution() {
+        final double tol = 1e-14;
+        final double[] exponentValues = {
+            1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 2E0, 3e0, 4e0, 5e0, 6., 7., 8., 9., 10., 50.
+        };
+        final double[] qValues = {0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0};
+        final int[] numberOfElementsValues = {2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 40, 100};
+
+        for (final double exponent : exponentValues) {
+            for (final int n : numberOfElementsValues) {
+                final ZipfRejectionSampler sampler = new ZipfRejectionSampler(n, exponent);
+                // Normalizing constant: total tail weight of the instrumental distribution.
+                final double total = ZipfRejectionSampler.integratePowerFunction(-exponent, 1.5, n + 0.5);
+                for (final double q : qValues) {
+                    final int k = sampler.sampleFromInstrumentalDistributionTail(q);
+                    final double cdfBelow =
+                        ZipfRejectionSampler.integratePowerFunction(-exponent, 1.5, k - 0.5) / total;
+                    final double cdfAbove =
+                        ZipfRejectionSampler.integratePowerFunction(-exponent, 1.5, k + 0.5) / total;
+                    assertTrue(cdfBelow <= q * (1. + tol));
+                    assertTrue(cdfAbove >= q * (1. - tol));
+                }
+            }
+        }
+    }
+
+    /** Checks helper1 through the identity exp(helper1(q, x) * x) == (1-q) + q*exp(x). */
+    @Test
+    public void testSamplerHelper1() {
+        final double tol = 1e-14;
+        final double[] qValues = {
+            0., 1e-12, 1e-11, 1e-10, 1e-9, 9e-9, 1e-8, 1e-7, 1e-6, 1e-5, 1e-4,
+            1e-3, 1e-2, 1e-1, 1e0
+        };
+        final double[] xValues = {
+            -Double.MAX_VALUE, -1e10, -1e9, -1e8, -1e7, -1e6, -1e5, -1e4, -1e3,
+            -1e2, -1e1, -1e0, -1e-1, -1e-2, -1e-3, -1e-4, -1e-5, -1e-6, -1e-7,
+            -1e-8, -1e-9, -1e-10, -Double.MIN_VALUE, 0.0, Double.MIN_VALUE,
+            1e-10, 1e-9, 1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0,
+            1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9, 1e10, Double.MAX_VALUE
+        };
+        for (final double q : qValues) {
+            for (final double x : xValues) {
+                final double reconstructed = FastMath.exp(ZipfRejectionSampler.helper1(q, x) * x);
+                TestUtils.assertRelativelyEquals((1. - q) + q * Math.exp(x), reconstructed, tol);
+            }
+        }
+    }
+
+    /** Checks helper2 against expm1, i.e. helper2(x) * x == exp(x) - 1. */
+    @Test
+    public void testSamplerHelper2() {
+        final double tol = 1e-12;
+        final double[] xValues = {
+            -1e0, -1e-1, -1e-2, -1e-3, -1e-4, -1e-5, -1e-6, -1e-7, -1e-8,
+            -1e-9, -1e-10, -1e-11, 0., 1e-11, 1e-10, 1e-9, 1e-8, 1e-7, 1e-6,
+            1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0
+        };
+        for (final double x : xValues) {
+            TestUtils.assertRelativelyEquals(FastMath.expm1(x), ZipfRejectionSampler.helper2(x) * x, tol);
+        }
+    }
+
+    /**
+     * Ignored micro-benchmark: for each (n, exponent) pair draws one million samples and prints the
+     * average number of uniform random values consumed per sample and the elapsed time. The final
+     * sum is printed, presumably to keep the JIT from eliminating the sampling loop.
+     */
+    @Ignore
+    @Test
+    public void testSamplerPerformance() {
+        final int[] numPointsValues = {1, 2, 5, 10, 100, 1000, 10000};
+        final double[] exponentValues = {1e-3, 1e-2, 1e-1, 1., 2., 5., 10.};
+        final int numGeneratedSamples = 1000000;
+        long sum = 0;
+        for (int numPoints : numPointsValues) {
+            for (double exponent : exponentValues) {
+                // nanoTime() is monotonic (unlike currentTimeMillis()), so it is meant for elapsed time.
+                final long start = System.nanoTime();
+                // Counts the nextDouble() calls made through the wrapper below.
+                final long[] randomNumberCounter = new long[1];
+                final RandomGenerator randomGenerator = new AbstractRandomGenerator() {
+                    private final RandomGenerator r = new Well1024a(0L);
+                    @Override
+                    public void setSeed(long seed) {
+                        // Intentionally ignored: the wrapped generator keeps its fixed seed.
+                    }
+                    @Override
+                    public double nextDouble() {
+                        randomNumberCounter[0] += 1;
+                        return r.nextDouble();
+                    }
+                };
+                final ZipfDistribution distribution = new ZipfDistribution(randomGenerator, numPoints, exponent);
+                for (int i = 0; i < numGeneratedSamples; ++i) {
+                    sum += distribution.sample();
+                }
+                final double seconds = (System.nanoTime() - start) * 1e-9;
+                System.out.println("n = " + numPoints + ", exponent = " + exponent + ", avg number consumed random values = " + (double)(randomNumberCounter[0])/numGeneratedSamples + ", measured time = " + seconds + "s");
+            }
+        }
+        System.out.println(sum);
+    }
+
 }