Posted to commits@commons.apache.org by ah...@apache.org on 2023/02/17 18:48:28 UTC

[commons-statistics] branch master updated: STATISTICS-70: Improve Hypergeometric distribution probability sums

This is an automated email from the ASF dual-hosted git repository.

aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-statistics.git


The following commit(s) were added to refs/heads/master by this push:
     new 584bf89  STATISTICS-70: Improve Hypergeometric distribution probability sums
584bf89 is described below

commit 584bf8966b999e542d389cbe7f8f76516d5dbacf
Author: aherbert <ah...@apache.org>
AuthorDate: Fri Feb 17 18:47:49 2023 +0000

    STATISTICS-70: Improve Hypergeometric distribution probability sums
    
    Cache the midpoint of the CDF closest to 0.5. This midpoint is used to
    compute the CDF or SF from whichever side of the support sums the
    smaller terms.
    
    The probability(x0, x1) function is implemented to sum the smallest
    range of the PDF, avoiding the duplicate summation performed in the
    default implementation using CDF(x1) - CDF(x0).
    
    The inverse CDF or SF is computed using a single summation to find the
    quantile.
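
For readers skimming the diff, a minimal sketch of the summation strategy described
above. This is illustrative only: the class, method and parameter names are invented
for the example and are not part of the commit; the committed code is in
HypergeometricDistribution.java below.

    import java.util.function.IntToDoubleFunction;

    /** Sketch: evaluate a discrete CDF on [lower, upper] by summing the PMF from
     *  whichever side of the support contributes the smaller terms, switching sides
     *  at a precomputed midpoint where the CDF is closest to 0.5. */
    final class TwoSidedCdfSketch {
        static double cdf(IntToDoubleFunction pmf, int lower, int upper, int midpoint, int x) {
            if (x < lower) {
                return 0;
            }
            if (x >= upper) {
                return 1;
            }
            if (x < midpoint) {
                // Lower tail: sum the small PMF terms directly.
                double s = 0;
                for (int i = lower; i <= x; i++) {
                    s += pmf.applyAsDouble(i);
                }
                return s;
            }
            // Upper tail: sum the small survival terms and take the complement.
            double s = 0;
            for (int i = upper; i > x; i--) {
                s += pmf.applyAsDouble(i);
            }
            return 1 - s;
        }
    }
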
---
 .../distribution/HypergeometricDistribution.java   | 205 ++++++++++++++++++++-
 .../HypergeometricDistributionTest.java            |  33 ++++
 src/changes/changes.xml                            |   4 +
 src/conf/pmd/pmd-ruleset.xml                       |   3 +-
 src/conf/spotbugs/spotbugs-exclude-filter.xml      |   5 +
 5 files changed, 240 insertions(+), 10 deletions(-)

diff --git a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/HypergeometricDistribution.java b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/HypergeometricDistribution.java
index 58b2ac3..e8ed472 100644
--- a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/HypergeometricDistribution.java
+++ b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/HypergeometricDistribution.java
@@ -17,6 +17,8 @@
 
 package org.apache.commons.statistics.distribution;
 
+import java.util.function.DoublePredicate;
+
 /**
  * Implementation of the hypergeometric distribution.
  *
@@ -37,6 +39,8 @@ package org.apache.commons.statistics.distribution;
  * @see <a href="https://mathworld.wolfram.com/HypergeometricDistribution.html">Hypergeometric distribution (MathWorld)</a>
  */
 public final class HypergeometricDistribution extends AbstractDiscreteDistribution {
+    /** 1/2. */
+    private static final double HALF = 0.5;
     /** The number of successes in the population. */
     private final int numberOfSuccesses;
     /** The population size. */
@@ -48,9 +52,12 @@ public final class HypergeometricDistribution extends AbstractDiscreteDistributi
     /** The upper bound of the support (inclusive). */
     private final int upperBound;
     /** Binomial probability of success (sampleSize / populationSize). */
-    private final double p;
+    private final double bp;
     /** Binomial probability of failure ((populationSize - sampleSize) / populationSize). */
-    private final double q;
+    private final double bq;
+    /** Cached midpoint of the CDF/SF. The array holds [x, cdf(x)] for the midpoint x.
+     * Used for the cumulative probability functions. */
+    private double[] midpoint;
 
     /**
      * @param populationSize Population size.
@@ -65,8 +72,8 @@ public final class HypergeometricDistribution extends AbstractDiscreteDistributi
         this.sampleSize = sampleSize;
         lowerBound = getLowerDomain(populationSize, numberOfSuccesses, sampleSize);
         upperBound = getUpperDomain(numberOfSuccesses, sampleSize);
-        p = (double) sampleSize / (double) populationSize;
-        q = (double) (populationSize - sampleSize) / (double) populationSize;
+        bp = (double) sampleSize / (double) populationSize;
+        bq = (double) (populationSize - sampleSize) / (double) populationSize;
     }
 
     /**
@@ -167,6 +174,33 @@ public final class HypergeometricDistribution extends AbstractDiscreteDistributi
         return Math.exp(logProbability(x));
     }
 
+    /** {@inheritDoc} */
+    @Override
+    public double probability(int x0, int x1) {
+        if (x0 > x1) {
+            throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH, x0, x1);
+        }
+        if (x0 == x1 || x1 < lowerBound) {
+            return 0;
+        }
+        // If the range is outside the bounds use the appropriate cumulative probability
+        if (x0 < lowerBound) {
+            return cumulativeProbability(x1);
+        }
+        if (x1 >= upperBound) {
+            // 1 - cdf(x0)
+            return survivalProbability(x0);
+        }
+        // Here: lower <= x0 < x1 < upper:
+        // sum(pdf(x)) for x in (x0, x1]
+        final int lo = x0 + 1;
+        // Sum small values first by starting at the point furthest from the mode.
+        final int mode = (int) Math.floor((sampleSize + 1.0) * (numberOfSuccesses + 1.0) / (populationSize + 2.0));
+        return Math.abs(mode - lo) > Math.abs(mode - x1) ?
+            innerCumulativeProbability(lo, x1) :
+            innerCumulativeProbability(x1, lo);
+    }
+
     /** {@inheritDoc} */
     @Override
     public double logProbability(int x) {
@@ -184,12 +218,12 @@ public final class HypergeometricDistribution extends AbstractDiscreteDistributi
      */
     private double computeLogProbability(int x) {
         final double p1 =
-                SaddlePointExpansionUtils.logBinomialProbability(x, numberOfSuccesses, p, q);
+                SaddlePointExpansionUtils.logBinomialProbability(x, numberOfSuccesses, bp, bq);
         final double p2 =
                 SaddlePointExpansionUtils.logBinomialProbability(sampleSize - x,
-                        populationSize - numberOfSuccesses, p, q);
+                        populationSize - numberOfSuccesses, bp, bq);
         final double p3 =
-                SaddlePointExpansionUtils.logBinomialProbability(sampleSize, populationSize, p, q);
+                SaddlePointExpansionUtils.logBinomialProbability(sampleSize, populationSize, bp, bq);
         return p1 + p2 - p3;
     }
 
@@ -201,7 +235,15 @@ public final class HypergeometricDistribution extends AbstractDiscreteDistributi
         } else if (x >= upperBound) {
             return 1.0;
         }
-        return innerCumulativeProbability(lowerBound, x);
+        final double[] mid = getMidPoint();
+        final int m = (int) mid[0];
+        if (x < m) {
+            return innerCumulativeProbability(lowerBound, x);
+        } else if (x > m) {
+            return 1 - innerCumulativeProbability(upperBound, x + 1);
+        }
+        // cdf(x)
+        return mid[1];
     }
 
     /** {@inheritDoc} */
@@ -212,7 +254,15 @@ public final class HypergeometricDistribution extends AbstractDiscreteDistributi
         } else if (x >= upperBound) {
             return 0.0;
         }
-        return innerCumulativeProbability(upperBound, x + 1);
+        final double[] mid = getMidPoint();
+        final int m = (int) mid[0];
+        if (x < m) {
+            return 1 - innerCumulativeProbability(lowerBound, x);
+        } else if (x > m) {
+            return innerCumulativeProbability(upperBound, x + 1);
+        }
+        // 1 - cdf(x)
+        return 1 - mid[1];
     }
 
     /**
@@ -247,6 +297,110 @@ public final class HypergeometricDistribution extends AbstractDiscreteDistributi
         return ret;
     }
 
+    @Override
+    public int inverseCumulativeProbability(double p) {
+        ArgumentUtils.checkProbability(p);
+        return computeInverseProbability(p, 1 - p, false);
+    }
+
+    @Override
+    public int inverseSurvivalProbability(double p) {
+        ArgumentUtils.checkProbability(p);
+        return computeInverseProbability(1 - p, p, true);
+    }
+
+    /**
+     * Implementation for the inverse cumulative or survival probability.
+     *
+     * @param p Cumulative probability.
+     * @param q Survival probability.
+     * @param complement Set to true to compute the inverse survival probability.
+     * @return the value
+     */
+    private int computeInverseProbability(double p, double q, boolean complement) {
+        if (p == 0) {
+            return lowerBound;
+        }
+        if (q == 0) {
+            return upperBound;
+        }
+
+        // Sum the PDF(x) until the appropriate p-value is obtained
+        // CDF: require smallest x where P(X<=x) >= p
+        // SF:  require smallest x where P(X>x) <= q
+        // The choice of summation uses the mid-point.
+        // The test on the CDF or SF is based on the appropriate input p-value.
+
+        final double[] mid = getMidPoint();
+        final int m = (int) mid[0];
+        final double mp = mid[1];
+
+        final int midPointComparison = complement ?
+            Double.compare(1 - mp, q) :
+            Double.compare(p, mp);
+
+        if (midPointComparison < 0) {
+            return inverseLower(p, q, complement);
+        } else if (midPointComparison > 0) {
+            // Avoid floating-point summation error when the mid-point computed using the
+            // lower sum is different to the midpoint computed using the upper sum.
+            // Here we know the result must be above the midpoint so we can clip the result.
+            return Math.max(m + 1, inverseUpper(p, q, complement));
+        }
+        // Exact mid-point
+        return m;
+    }
+
+    /**
+     * Compute the inverse cumulative or survival probability using the lower sum.
+     *
+     * @param p Cumulative probability.
+     * @param q Survival probability.
+     * @param complement Set to true to compute the inverse survival probability.
+     * @return the value
+     */
+    private int inverseLower(double p, double q, boolean complement) {
+        // Sum from the lower bound (computing the cdf)
+        int x = lowerBound;
+        final DoublePredicate test = complement ?
+            i -> 1 - i > q :
+            i -> i < p;
+        double cdf = Math.exp(computeLogProbability(x));
+        while (test.test(cdf)) {
+            x++;
+            cdf += Math.exp(computeLogProbability(x));
+        }
+        return x;
+    }
+
+    /**
+     * Compute the inverse cumulative or survival probability using the upper sum.
+     *
+     * @param p Cumulative probability.
+     * @param q Survival probability.
+     * @param complement Set to true to compute the inverse survival probability.
+     * @return the value
+     */
+    private int inverseUpper(double p, double q, boolean complement) {
+        // Sum from the upper bound (computing the sf)
+        int x = upperBound;
+        final DoublePredicate test = complement ?
+            i -> i < q :
+            i -> 1 - i > p;
+        double sf = 0;
+        while (test.test(sf)) {
+            sf += Math.exp(computeLogProbability(x));
+            x--;
+        }
+        // Here either sf(x) >= q, or cdf(x) <= p
+        // Ensure sf(x) <= q, or cdf(x) >= p
+        if (complement && sf > q ||
+            !complement && 1 - sf < p) {
+            x++;
+        }
+        return x;
+    }
+
     /**
      * {@inheritDoc}
      *
@@ -301,4 +455,37 @@ public final class HypergeometricDistribution extends AbstractDiscreteDistributi
     public int getSupportUpperBound() {
         return upperBound;
     }
+
+    /**
+     * Return the mid-point {@code x} of the distribution, and the cdf(x).
+     *
+     * <p>This is not the true median. It is the value x where CDF(x) is closest to 0.5;
+     * as such CDF(x) may be below 0.5 if the next value of x is further from 0.5.
+     *
+     * @return the mid-point ([x, cdf(x)])
+     */
+    private double[] getMidPoint() {
+        double[] v = midpoint;
+        if (v == null) {
+            // Find the closest sum(PDF) to 0.5
+            int x = lowerBound;
+            double p0 = 0;
+            double p1 = Math.exp(computeLogProbability(x));
+            // No check of the upper bound required here as the CDF should sum to 1 and 0.5
+            // is exceeded before a bounds error.
+            while (p1 < HALF) {
+                x++;
+                p0 = p1;
+                p1 += Math.exp(computeLogProbability(x));
+            }
+            // p1 >= 0.5 > p0
+            // Pick closest
+            if (p1 - HALF >= HALF - p0) {
+                x--;
+                p1 = p0;
+            }
+            midpoint = v = new double[] {x, p1};
+        }
+        return v;
+    }
 }
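
A hedged usage example of the methods touched above (the class name Example and the
parameter values are illustrative and not asserted by the commit):

    import org.apache.commons.statistics.distribution.HypergeometricDistribution;

    public class Example {
        public static void main(String[] args) {
            // population = 50, successes = 25, sample size = 10
            HypergeometricDistribution dist = HypergeometricDistribution.of(50, 25, 10);
            System.out.println(dist.probability(2, 7));                  // P(2 < X <= 7), summed directly over the range
            System.out.println(dist.cumulativeProbability(4));           // summation side chosen via the cached midpoint
            System.out.println(dist.survivalProbability(4));             // complement computed from the same midpoint
            System.out.println(dist.inverseCumulativeProbability(0.95)); // quantile located with a single summation
            System.out.println(dist.inverseSurvivalProbability(0.05));   // inverse SF, likewise a single summation
        }
    }
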
diff --git a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/HypergeometricDistributionTest.java b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/HypergeometricDistributionTest.java
index e6016aa..6fadd7d 100644
--- a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/HypergeometricDistributionTest.java
+++ b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/HypergeometricDistributionTest.java
@@ -17,6 +17,7 @@
 
 package org.apache.commons.statistics.distribution;
 
+import java.util.stream.IntStream;
 import java.util.stream.Stream;
 import org.apache.commons.numbers.core.Precision;
 import org.apache.commons.rng.simple.RandomSource;
@@ -24,6 +25,7 @@ import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.CsvSource;
 import org.junit.jupiter.params.provider.MethodSource;
 
 /**
@@ -244,4 +246,35 @@ class HypergeometricDistributionTest extends BaseDiscreteDistributionTest {
             new double[] {4.570379934029859e-16, 7.4187180434325268e-18},
             DoubleTolerances.relative(5e-14));
     }
+
+    @ParameterizedTest
+    @CsvSource({
+        "1, 0, 0",
+        "1, 1, 0",
+        "1, 0, 1",
+        "1, 1, 1",
+        "2, 1, 1",
+        "2, 1, 2",
+        "2, 2, 1",
+        "2, 2, 2",
+        "3, 1, 1",
+        "3, 1, 2",
+        "3, 1, 3",
+        "3, 2, 1",
+        "3, 2, 2",
+        "3, 2, 3",
+        "3, 3, 1",
+        "3, 3, 2",
+        "3, 3, 3",
+        // Mean = n * K / N
+        "15, 9, 7", // 4.2
+        "23, 13, 11", // 6.22
+        "200, 130, 70", // 45.5
+    })
+    void testAdditionalInverseMapping(int populationSize, int numberOfSuccesses, int sampleSize) {
+        final HypergeometricDistribution dist = HypergeometricDistribution.of(populationSize, numberOfSuccesses, sampleSize);
+        final int[] points = IntStream.rangeClosed(dist.getSupportLowerBound(), dist.getSupportUpperBound()).toArray();
+        testCumulativeProbabilityInverseMapping(dist, points);
+        testSurvivalProbabilityInverseMapping(dist, points);
+    }
 }
diff --git a/src/changes/changes.xml b/src/changes/changes.xml
index 137c6cd..eddd61f 100644
--- a/src/changes/changes.xml
+++ b/src/changes/changes.xml
@@ -56,6 +56,10 @@ If the output is not quite correct, check for invisible trailing spaces!
     <release version="1.1" date="TBD" description="
 Adds ranking, inference and bom modules. (requires Java 8).
 ">
+      <action dev="aherbert" type="add" issue="STATISTICS-70">
+        "HypergeometricDistribution": Improve the summation used for the cumulative
+        probability functions.
+      </action>
       <action dev="aherbert" type="add" issue="STATISTICS-66">
         Add a Bill of Materials (BOM) to aid in dependency management when referencing multiple
         Apache Commons Statistics artifacts. The BOM should be used to ensure all imported
diff --git a/src/conf/pmd/pmd-ruleset.xml b/src/conf/pmd/pmd-ruleset.xml
index a66da52..b7e816f 100644
--- a/src/conf/pmd/pmd-ruleset.xml
+++ b/src/conf/pmd/pmd-ruleset.xml
@@ -157,7 +157,8 @@
       <property name="violationSuppressXPath"
         value="./ancestor-or-self::ClassOrInterfaceDeclaration[@SimpleName='NaturalRanking'
           or @SimpleName='KolmogorovSmirnovTest' or @SimpleName='DD' or @SimpleName='Arguments'
-          or @SimpleName='MannWhitneyUTest' or @SimpleName='WilcoxonSignedRankTest']"/>
+          or @SimpleName='MannWhitneyUTest' or @SimpleName='WilcoxonSignedRankTest'
+          or @SimpleName='HypergeometricDistribution']"/>
     </properties>
   </rule>
   <rule ref="category/java/design.xml/LogicInversion">
diff --git a/src/conf/spotbugs/spotbugs-exclude-filter.xml b/src/conf/spotbugs/spotbugs-exclude-filter.xml
index 629481e..084b52d 100644
--- a/src/conf/spotbugs/spotbugs-exclude-filter.xml
+++ b/src/conf/spotbugs/spotbugs-exclude-filter.xml
@@ -61,6 +61,11 @@
     <Method name="lambda$createSampler$2" />
     <Bug pattern="FL_FLOATS_AS_LOOP_COUNTERS" />
   </Match>
+  <Match>
+    <Class name="org.apache.commons.statistics.distribution.HypergeometricDistribution" />
+    <Method name="getMidPoint" />
+    <Bug pattern="FL_FLOATS_AS_LOOP_COUNTERS" />
+  </Match>
 
   <Match>
     <Class name="org.apache.commons.statistics.ranking.NaturalRanking$DataPosition" />