You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by ah...@apache.org on 2023/02/16 12:27:34 UTC
[commons-statistics] 03/07: STATISTICS-64: Add Fisher's exact test
This is an automated email from the ASF dual-hosted git repository.
aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-statistics.git
commit 9851570a67e61a00fcd9718f9365d37dc750d4c9
Author: Alex Herbert <a....@sussex.ac.uk>
AuthorDate: Mon Feb 13 13:29:11 2023 +0000
STATISTICS-64: Add Fisher's exact test
---
.../statistics/inference/FisherExactTest.java | 243 +++++++++++++++++++++
.../statistics/inference/FisherExactTestTest.java | 223 +++++++++++++++++++
src/conf/pmd/pmd-ruleset.xml | 2 +-
3 files changed, 467 insertions(+), 1 deletion(-)
diff --git a/commons-statistics-inference/src/main/java/org/apache/commons/statistics/inference/FisherExactTest.java b/commons-statistics-inference/src/main/java/org/apache/commons/statistics/inference/FisherExactTest.java
new file mode 100644
index 0000000..b8954fe
--- /dev/null
+++ b/commons-statistics-inference/src/main/java/org/apache/commons/statistics/inference/FisherExactTest.java
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.statistics.inference;
+
+import java.util.Objects;
+import org.apache.commons.statistics.distribution.HypergeometricDistribution;
+
+/**
+ * Implements Fisher's exact test.
+ *
+ * <p>Performs an exact test for the statistical significance of the association (contingency)
+ * between two kinds of categorical classification.
+ *
+ * <p>Fisher's test applies in the case that the row sums and column sums are fixed in advance
+ * and not random.
+ *
+ * @see <a href="https://en.wikipedia.org/wiki/Fisher%27s_exact_test">Fisher's exact test (Wikipedia)</a>
+ * @since 1.1
+ */
+public final class FisherExactTest {
+ /** Required number of rows, and of columns, in the contingency table. */
+ private static final int TWO = 2;
+ /** Shared default instance; uses the two-sided alternative hypothesis. */
+ private static final FisherExactTest DEFAULT = new FisherExactTest(AlternativeHypothesis.TWO_SIDED);
+
+ /** Alternative hypothesis; selects which p-value is computed by {@link #test(int[][])}. */
+ private final AlternativeHypothesis alternative;
+
+ /**
+ * Create an instance configured with the alternative hypothesis.
+ *
+ * <p>The constructor is private; instances are obtained using
+ * {@link #withDefaults()} and {@link #with(AlternativeHypothesis)}.
+ *
+ * @param alternative Alternative hypothesis.
+ */
+ private FisherExactTest(AlternativeHypothesis alternative) {
+ this.alternative = alternative;
+ }
+
+ /**
+ * Return an instance using the default options.
+ *
+ * <ul>
+ * <li>{@link AlternativeHypothesis#TWO_SIDED}
+ * </ul>
+ *
+ * <p>The same shared instance is returned by each call.
+ *
+ * @return default instance
+ */
+ public static FisherExactTest withDefaults() {
+ return DEFAULT;
+ }
+
+    /**
+     * Create a new instance that uses the given alternative hypothesis.
+     *
+     * @param v Alternative hypothesis.
+     * @return a new instance configured with {@code v}
+     * @throws NullPointerException if {@code v} is null
+     */
+    public FisherExactTest with(AlternativeHypothesis v) {
+        final AlternativeHypothesis hypothesis = Objects.requireNonNull(v);
+        return new FisherExactTest(hypothesis);
+    }
+
+    /**
+     * Compute the prior odds ratio for the 2-by-2 contingency table. This is the
+     * "sample" or "unconditional" maximum likelihood estimate of the odds ratio.
+     * Given the table:
+     *
+     * <p>\[ \left[ {\begin{array}{cc}
+     * a & b \\
+     * c & d \\
+     * \end{array} } \right] \]
+     *
+     * <p>the odds ratio is:
+     *
+     * <p>\[ r = \frac{a d}{b c} \]
+     *
+     * <p>Special cases:
+     * <ul>
+     * <li>If a row or column sum is zero, both products are zero and the value is
+     * {@link Double#NaN NaN}.
+     * <li>Otherwise, if the denominator is zero, the value is
+     * {@link Double#POSITIVE_INFINITY infinity}.
+     * </ul>
+     *
+     * <p>Note: This value matches the statistic computed by the SciPy function
+     * {@code scipy.stats.fisher_exact}. It differs from the conditional maximum
+     * likelihood estimate computed by the R function {@code fisher.test}.
+     *
+     * @param table 2-by-2 contingency table.
+     * @return odds ratio
+     * @throws IllegalArgumentException if the {@code table} is not a 2-by-2 table; any
+     * table entry is negative; or the sum of the table is 0 or larger than a 32-bit signed integer.
+     * @see #with(AlternativeHypothesis)
+     * @see #test(int[][])
+     */
+    public double statistic(int[][] table) {
+        checkTable(table);
+        // Multiply as doubles: the product of two ints may exceed the int range.
+        final double numerator = (double) table[0][0] * table[1][1];
+        final double denominator = (double) table[0][1] * table[1][0];
+        return numerator / denominator;
+    }
+
+ /**
+ * Performs Fisher's exact test on the 2-by-2 contingency table.
+ *
+ * <p>The test statistic is equal to the prior odds ratio. This is the
+ * "sample" or "unconditional" maximum likelihood estimate.
+ *
+ * <p>The test is defined by the {@link AlternativeHypothesis}.
+ *
+ * <p>For a table of [[a, b], [c, d]] the possible values of any table are conditioned
+ * with the same marginals (row and column totals). In this case the possible values {@code x}
+ * of the upper-left element {@code a} are {@code max(0, a - d) <= x <= a + min(b, c)}.
+ * <ul>
+ * <li>'two-sided': the odds ratio of the underlying population is not one; the p-value
+ * is the probability that a random table has probability equal to or less than the input table.
+ * <li>'greater': the odds ratio of the underlying population is greater than one; the p-value
+ * is the probability that a random table has {@code x >= a}.
+ * <li>'less': the odds ratio of the underlying population is less than one; the p-value
+ * is the probability that a random table has {@code x <= a}.
+ * </ul>
+ *
+ * @param table 2-by-2 contingency table.
+ * @return test result
+ * @throws IllegalArgumentException if the {@code table} is not a 2-by-2 table; any
+ * table entry is negative; or the sum of the table is 0 or larger than a 32-bit signed integer.
+ * @see #with(AlternativeHypothesis)
+ * @see #statistic(int[][])
+ */
+ public SignificanceResult test(int[][] table) {
+ checkTable(table);
+ final int a = table[0][0];
+ final int b = table[0][1];
+ final int c = table[1][0];
+ final int d = table[1][1];
+
+ // Odds ratio: the same expression as computed by statistic(int[][]).
+ // The products are computed as doubles to avoid int overflow.
+ final double statistic = ((double) a * d) / ((double) b * c);
+
+ // Table total, first row total and first column total.
+ // checkTable has verified the total does not overflow an int.
+ final int nn = a + b + c + d;
+ final int k = a + b;
+ final int n = a + c;
+
+ // Note: The distribution validates the population size is > 0
+ final HypergeometricDistribution distribution = HypergeometricDistribution.of(nn, k, n);
+ double p;
+ if (alternative == AlternativeHypothesis.GREATER_THAN) {
+ // Pr(X >= a) = sf(a - 1)
+ p = distribution.survivalProbability(a - 1);
+ } else if (alternative == AlternativeHypothesis.LESS_THAN) {
+ // Pr(X <= a) = cdf(a)
+ p = distribution.cumulativeProbability(a);
+ } else {
+ // Two-sided
+ p = twoSidedTest(a, distribution);
+ }
+ return new BaseSignificanceResult(statistic, p);
+ }
+
+ /**
+ * Returns the <i>observed significance level</i>, or p-value, associated with a
+ * two-sided test about the observed value.
+ *
+ * <p>The p-value is the sum of {@code Pr(X = i)} over all {@code i} where
+ * {@code Pr(X = i) <= Pr(X = k)}. When {@code k} is not a mode this is computed as the
+ * sum of a lower tail (CDF) and an upper tail (survival function).
+ *
+ * @param k Observed value.
+ * @param distribution Hypergeometric distribution.
+ * @return p-value
+ */
+ private static double twoSidedTest(int k, HypergeometricDistribution distribution) {
+ // Find all i where Pr(X = i) <= Pr(X = k) and sum them.
+ // Exploit the known unimodal distribution to increase the
+ // search speed. Note the search depends only on magnitude differences.
+ // The current HypergeometricDistribution is faster using log probability
+ // as it omits a call to Math.exp.
+
+ // Use the mode as the point of largest probability.
+ // The lower or upper mode is important for the search below.
+ final int nn = distribution.getPopulationSize();
+ final int kk = distribution.getNumberOfSuccesses();
+ final int n = distribution.getSampleSize();
+ final double v = ((double) n + 1) * ((double) kk + 1) / (nn + 2.0);
+ // m1 is the lower mode and m2 the upper mode. These are equal unless v is
+ // an integer, in which case m1 = v - 1 and m2 = v.
+ final int m1 = (int) Math.ceil(v) - 1;
+ final int m2 = (int) Math.floor(v);
+ if (k < m1) {
+ final double pk = distribution.logProbability(k);
+ // Lower half = cdf(k)
+ // Find upper half. As k < lower mode i should never
+ // reach the lower mode based on the probability alone.
+ // Bracket with the upper mode.
+ // NOTE(review): Searches is defined elsewhere; i is used below as the first
+ // value included in the upper tail - confirm against Searches.searchDescending.
+ final int i = Searches.searchDescending(m2, distribution.getSupportUpperBound(), pk,
+ distribution::logProbability);
+ return distribution.cumulativeProbability(k) +
+ distribution.survivalProbability(i - 1);
+ } else if (k > m2) {
+ final double pk = distribution.logProbability(k);
+ // Upper half = sf(k - 1)
+ // Find lower half. As k > upper mode i should never
+ // reach the upper mode based on the probability alone.
+ // Bracket with the lower mode.
+ // NOTE(review): Searches is defined elsewhere; i is used below as the last
+ // value included in the lower tail - confirm against Searches.searchAscending.
+ final int i = Searches.searchAscending(distribution.getSupportLowerBound(), m1, pk,
+ distribution::logProbability);
+ return distribution.cumulativeProbability(i) +
+ distribution.survivalProbability(k - 1);
+ }
+ // k == mode
+ // Here m1 <= k <= m2, i.e. k is m1 or m2.
+ // Edge case where the sum of probabilities will be either
+ // 1 or 1 - Pr(X = mode) where mode != k
+ // If the other mode has a larger computed probability than k it is the only
+ // value excluded from the sum; otherwise every value is included.
+ final double pk = distribution.probability(k);
+ final double pm = distribution.probability(k == m1 ? m2 : m1);
+ return pm > pk ? 1 - pm : 1;
+ }
+
+    /**
+     * Validate that the input is a usable 2-by-2 contingency table.
+     *
+     * @param table Table.
+     * @throws IllegalArgumentException if the {@code table} is not a 2-by-2 table; any
+     * table entry is negative; or the sum of the entries is zero or is larger than
+     * {@link Integer#MAX_VALUE}
+     */
+    private static void checkTable(int[][] table) {
+        // Short-circuit evaluation ensures a row is only inspected if it exists.
+        final boolean twoByTwo = table.length == TWO &&
+            table[0].length == TWO &&
+            table[1].length == TWO;
+        if (!twoByTwo) {
+            throw new InferenceException("Require a 2-by-2 contingency table");
+        }
+        final int[] row0 = table[0];
+        final int[] row1 = table[1];
+        // All entries must be non-negative. The bitwise OR of the values has the
+        // sign bit set if any single value is negative.
+        Arguments.checkNonNegative(row0[0] | row0[1] | row1[0] | row1[1]);
+        // Accumulate as a long so the total cannot overflow; it must fit in an int.
+        long sum = row0[0];
+        sum += row0[1];
+        sum += row1[0];
+        sum += row1[1];
+        if (sum > Integer.MAX_VALUE) {
+            throw new InferenceException(InferenceException.X_GT_Y, sum, Integer.MAX_VALUE);
+        }
+        // An empty table (all zeros) is not allowed.
+        Arguments.checkStrictlyPositive((int) sum);
+    }
+}
diff --git a/commons-statistics-inference/src/test/java/org/apache/commons/statistics/inference/FisherExactTestTest.java b/commons-statistics-inference/src/test/java/org/apache/commons/statistics/inference/FisherExactTestTest.java
new file mode 100644
index 0000000..9b157ff
--- /dev/null
+++ b/commons-statistics-inference/src/test/java/org/apache/commons/statistics/inference/FisherExactTestTest.java
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.statistics.inference;
+
+import java.util.Arrays;
+import java.util.function.Consumer;
+import java.util.stream.IntStream;
+import java.util.stream.Stream;
+import org.apache.commons.statistics.distribution.HypergeometricDistribution;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.CsvSource;
+import org.junit.jupiter.params.provider.MethodSource;
+
+/**
+ * Test cases for {@link FisherExactTest}.
+ */
+class FisherExactTestTest {
+
+    /**
+     * Test a null alternative hypothesis is rejected.
+     */
+    @Test
+    void testInvalidOptionsThrows() {
+        final FisherExactTest defaults = FisherExactTest.withDefaults();
+        final AlternativeHypothesis missing = null;
+        Assertions.assertThrows(NullPointerException.class, () -> defaults.with(missing));
+    }
+
+ @Test
+ void testFisherExactTestInvalidTableThrows() {
+ assertFisherExactTestInvalidTableThrows(FisherExactTest.withDefaults()::statistic);
+ assertFisherExactTestInvalidTableThrows(FisherExactTest.withDefaults()::test);
+ }
+
+ private void assertFisherExactTestInvalidTableThrows(Consumer<int[][]> action) {
+ // Non 2-by-2 input
+ Assertions.assertThrows(IllegalArgumentException.class, () ->
+ FisherExactTest.withDefaults().test(new int[3][3]));
+ Assertions.assertThrows(IllegalArgumentException.class, () ->
+ FisherExactTest.withDefaults().test(new int[2][1]));
+ Assertions.assertThrows(IllegalArgumentException.class, () ->
+ FisherExactTest.withDefaults().test(new int[1][2]));
+ // Non-square input
+ Assertions.assertThrows(IllegalArgumentException.class, () ->
+ FisherExactTest.withDefaults().test(new int[][] {
+ new int[2], new int[1]
+ }));
+ Assertions.assertThrows(IllegalArgumentException.class, () ->
+ FisherExactTest.withDefaults().test(new int[][] {
+ new int[1], new int[2]
+ }));
+ }
+
+    /**
+     * Test 2-by-2 tables with an invalid sum (zero, or too large for an int) are rejected
+     * by both the statistic and the test methods.
+     */
+    @ParameterizedTest
+    @CsvSource({
+        // Zero sum
+        "0, 0, 0, 0",
+        // Overflow
+        "2147483647, 1, 0, 0",
+        "2147483647, 0, 1, 0",
+        "2147483647, 0, 0, 1",
+        "2147483647, 2147483647, 0, 0",
+        "2147483647, 0, 2147483647, 0",
+        "2147483647, 0, 0, 2147483647",
+        "2147483647, 0, 2147483647, 2147483647",
+        "2147483647, 2147483647, 0, 2147483647",
+        "2147483647, 2147483647, 2147483647, 2147483647",
+    })
+    void testFisherExactTestThrows(int a, int b, int c, int d) {
+        final FisherExactTest fisher = FisherExactTest.withDefaults();
+        final int[][] counts = {{a, b}, {c, d}};
+        Assertions.assertThrows(IllegalArgumentException.class, () -> fisher.statistic(counts), "statistic");
+        Assertions.assertThrows(IllegalArgumentException.class, () -> fisher.test(counts), "test");
+    }
+
+ /**
+ * Test the Fisher exact test for each alternative hypothesis and all possible k given
+ * the input table.
+ *
+ * <p>The expected p-values are computed by direct summation of the probabilities of a
+ * hypergeometric distribution created with the same parameters.
+ *
+ * @param n Total of the first column of the table.
+ * @param kk Total of the first row of the table.
+ * @param nn Total of the table.
+ * @param eps Tolerance passed to the p-value assertion.
+ */
+ @ParameterizedTest
+ @CsvSource({
+ // The epsilon here is due to a difference in summation of p-values. This test
+ // sums all p-values in the same stream which uses an extended precision sum.
+ // The FisherExactTest adds the CDF and SF which use standard precision summations.
+ "0, 0, 1, 2e-16",
+ "1, 0, 1, 2e-16",
+ "0, 1, 1, 2e-16",
+ "1, 1, 1, 2e-16",
+ "8, 7, 13, 2e-16",
+ "10, 12, 24, 2e-16",
+ "20, 25, 43, 3e-16",
+ // Create a contingency table where the hypergeometric mode is 1.5
+ "4, 2, 8, 2e-16",
+ })
+ void testFisherExactTest(int n, int kk, int nn, double eps) {
+ final HypergeometricDistribution dist = HypergeometricDistribution.of(nn, kk, n);
+ final int low = dist.getSupportLowerBound();
+ final int high = dist.getSupportUpperBound();
+ // Probabilities are stored from 0 (not low) so that pk[i] is Pr(X = i).
+ final double[] pk = IntStream.rangeClosed(0, high).mapToDouble(dist::probability).toArray();
+
+ // Note: TestUtils.assertProbability expects exact equality when p is 0 or 1.
+ // We *could* set the maximum for the sum to below 1 to avoid this.
+ final double maxP = 1.0;
+
+ final FisherExactTest twoSided = FisherExactTest.withDefaults();
+ final FisherExactTest less = FisherExactTest.withDefaults().with(AlternativeHypothesis.LESS_THAN);
+ final FisherExactTest greater = FisherExactTest.withDefaults().with(AlternativeHypothesis.GREATER_THAN);
+
+ IntStream.rangeClosed(low, high).forEach(k -> {
+ // Table with upper-left element k and the fixed marginals:
+ // first row total kk, first column total n, table total nn.
+ final int[][] table = {
+ {k, kk - k},
+ {n - k, nn - (n + kk) + k}
+ };
+ double expected;
+
+ // One-sided
+ // Pr(X <= k); exactly 1 when k is the upper bound of the support.
+ expected = k == high ? 1 :
+ Math.min(maxP, IntStream.rangeClosed(low, k).mapToDouble(i -> pk[i]).sum());
+ TestUtils.assertProbability(expected,
+ less.test(table).getPValue(), eps,
+ () -> "less than: k=" + k);
+
+ // Pr(X >= k); exactly 1 when k is the lower bound of the support.
+ expected = k == low ? 1 :
+ Math.min(maxP, IntStream.rangeClosed(k, high).mapToDouble(i -> pk[i]).sum());
+ TestUtils.assertProbability(expected,
+ greater.test(table).getPValue(), eps,
+ () -> "greater than: k=" + k);
+
+ // Two-sided
+ // Find all i where Pr(X = i) <= Pr(X = k) and sum them.
+ // Create an exact sum of 1.0 when all Pr(X = i) <= Pr(X = k).
+ expected = IntStream.rangeClosed(low, high).noneMatch(i -> pk[i] > pk[k]) ? 1 :
+ Math.min(maxP, Arrays.stream(pk).filter(x -> x <= pk[k]).sum());
+ TestUtils.assertProbability(expected,
+ twoSided.test(table).getPValue(), eps,
+ () -> "two-sided: k=" + k);
+ });
+ }
+
+    /**
+     * Test the two-sided p-value does not exceed 1 when the observed value is at the mode.
+     * See also Math-1644 which is relevant to the same situation in the BinomialTest
+     */
+    @ParameterizedTest
+    @CsvSource({
+        // k = mode = ceil((n+1)(K+1) / (N+2)) - 1, floor((n+1)(K+1) / (N+2))
+        // Exact mode
+        "2, 2, 7, 1",
+        "4, 3, 8, 2",
+        // Rounded
+        "2, 2, 5, 1",
+        "2, 2, 5, 2",
+        "4, 3, 10, 1",
+        "4, 3, 10, 2",
+        // mode == 1.5
+        "4, 2, 8, 1",
+        "4, 2, 8, 2",
+    })
+    void testMode(int n, int kk, int nn, int k) {
+        // Fill the table from the upper-left element k and the fixed marginals:
+        // first row total kk, first column total n, table total nn.
+        final int b = kk - k;
+        final int c = n - k;
+        final int d = nn - (n + kk) + k;
+        final int[][] table = {{k, b}, {c, d}};
+        final SignificanceResult result = FisherExactTest.withDefaults().test(table);
+        final double pval = result.getPValue();
+        Assertions.assertTrue(pval <= 1, () -> "pval=" + pval);
+    }
+
+    /**
+     * Test the statistic and the p-value for each alternative hypothesis against
+     * reference values.
+     *
+     * @param a Table entry [0][0].
+     * @param b Table entry [0][1].
+     * @param c Table entry [1][0].
+     * @param d Table entry [1][1].
+     * @param ratio Expected odds ratio.
+     * @param p Expected p-values in the order of the {@link AlternativeHypothesis} enum values.
+     */
+    @ParameterizedTest
+    @MethodSource
+    void testFisherExactTest(int a, int b, int c, int d, double ratio, double[] p) {
+        final FisherExactTest defaults = FisherExactTest.withDefaults();
+        final int[][] table = {{a, b}, {c, d}};
+        final double statistic = defaults.statistic(table);
+        TestUtils.assertRelativelyEquals(ratio, statistic, 2e-16, "statistic");
+        // The expected p-values are indexed in the declaration order of the enum.
+        final AlternativeHypothesis[] hypotheses = AlternativeHypothesis.values();
+        for (int i = 0; i < hypotheses.length; i++) {
+            final SignificanceResult r = defaults.with(hypotheses[i]).test(table);
+            Assertions.assertEquals(statistic, r.getStatistic(), "statistic mismatch");
+            TestUtils.assertProbability(p[i], r.getPValue(), 8e-15, "p-value");
+        }
+    }
+
+    /**
+     * Provide the reference cases. Each case is: the table entries a, b, c, d; the expected
+     * odds ratio; and the expected p-values.
+     *
+     * @return the test arguments
+     */
+    static Stream<Arguments> testFisherExactTest() {
+        // p-values are in the AlternativeHypothesis enum order: two-sided, greater, less
+        // scipy.stats.fisher_exact (version 1.9.3)
+        return Stream.of(
+            // SciPy's examples
+            Arguments.of(6, 2, 1, 4, 12,
+                new double[] {0.10256410256410256, 0.08624708624708625, 0.9953379953379954}),
+            Arguments.of(8, 2, 1, 5, 20,
+                new double[] {0.034965034965034975, 0.024475524475524483, 0.9991258741258742}),
+            // Wikipedia example
+            Arguments.of(1, 9, 11, 3, 0.030303030303030304,
+                new double[] {0.0027594561852200836, 0.9999663480953022, 0.0013797280926100418}),
+            // Larger tables
+            Arguments.of(123, 92, 424, 313, 0.986951394585726,
+                new double[] {0.9376713563018861, 0.5652358165696375, 0.4969777824620944}),
+            Arguments.of(123, 92, 424, 113, 0.3563115258408532,
+                new double[] {4.434744613652278e-09, 0.9999999990211362, 2.826445350300923e-09}),
+            Arguments.of(67, 42, 23, 88, 6.10351966873706,
+                new double[] {8.732787245248246e-10, 5.169903676658518e-10, 0.9999999999188225}),
+            // Edge case tables
+            Arguments.of(0, 0, 0, 1, Double.NaN,
+                new double[] {1, 1, 1}),
+            Arguments.of(1, 0, 0, 1, Double.POSITIVE_INFINITY,
+                new double[] {1, 0.5, 1}),
+            Arguments.of(0, 1, 1, 0, 0,
+                new double[] {1, 1, 0.5})
+        );
+    }
+}
diff --git a/src/conf/pmd/pmd-ruleset.xml b/src/conf/pmd/pmd-ruleset.xml
index 2a80c30..a66da52 100644
--- a/src/conf/pmd/pmd-ruleset.xml
+++ b/src/conf/pmd/pmd-ruleset.xml
@@ -202,7 +202,7 @@
<rule ref="category/java/errorprone.xml/TestClassWithoutTestCases">
<properties>
<property name="violationSuppressRegex"
- value=".*'BinomialTest'.*|.*'ChiSquareTest'.*|.*'GTest'.*|.*'KolmogorovSmirnovTest'.*|.*'MannWhitneyUTest'.*|.*'TTest'.*|.*'WilcoxonSignedRankTest'.*"/>
+ value=".*'BinomialTest'.*|.*'ChiSquareTest'.*|.*'FisherExactTest'.*|.*'GTest'.*|.*'KolmogorovSmirnovTest'.*|.*'MannWhitneyUTest'.*|.*'TTest'.*|.*'WilcoxonSignedRankTest'.*"/>
</properties>
</rule>