You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by ah...@apache.org on 2021/11/22 13:11:12 UTC

[commons-statistics] 01/03: Add probability range implementation for discrete uniform distribution.

This is an automated email from the ASF dual-hosted git repository.

aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-statistics.git

commit 005a31719a607a850fc263a34cc696152efcb3c2
Author: aherbert <ah...@apache.org>
AuthorDate: Mon Nov 22 12:04:32 2021 +0000

    Add probability range implementation for discrete uniform distribution.
---
 .../distribution/UniformDiscreteDistribution.java  | 24 ++++++
 .../UniformDiscreteDistributionTest.java           | 86 ++++++++++++++++++++++
 .../distribution/test.uniformdiscrete.3.properties | 10 +++
 3 files changed, 120 insertions(+)

diff --git a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/UniformDiscreteDistribution.java b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/UniformDiscreteDistribution.java
index 4841dd9..b57f54c 100644
--- a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/UniformDiscreteDistribution.java
+++ b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/UniformDiscreteDistribution.java
@@ -81,6 +81,30 @@ public final class UniformDiscreteDistribution extends AbstractDiscreteDistribut
 
     /** {@inheritDoc} */
     @Override
+    public double probability(int x0,
+                              int x1) {
+        if (x0 > x1) {
+            throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH, x0, x1);
+        }
+        if (x0 >= upper || x1 < lower) {
+            // The interval (x0, x1] lies entirely outside the support [lower, upper].
+            return 0;
+        }
+
+        // The remaining cases satisfy x0 < upper
+        // and x1 >= lower (together with x0 <= x1).
+
+        // Clip the interval (x0, x1], x0 exclusive and x1 inclusive, to [lower, upper].
+        // When x0 < lower, l becomes lower - 1 so that u - l == (u - lower) + 1.
+        // Long arithmetic avoids int overflow in lower - 1 and in u - l.
+        final long l = Math.max(lower - 1L, x0);
+        final long u = Math.min(upper, x1);
+
+        return (u - l) / upperMinusLowerPlus1;
+    }
+
+    /** {@inheritDoc} */
+    @Override
     public double logProbability(int x) {
         if (x < lower || x > upper) {
             return Double.NEGATIVE_INFINITY;
diff --git a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/UniformDiscreteDistributionTest.java b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/UniformDiscreteDistributionTest.java
index 3f976c7..c8f6912 100644
--- a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/UniformDiscreteDistributionTest.java
+++ b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/UniformDiscreteDistributionTest.java
@@ -169,4 +169,90 @@ class UniformDiscreteDistributionTest extends BaseDiscreteDistributionTest {
         final int[] x = MathArrays.sequence(upper - lower, lower, 1);
         testSurvivalProbabilityInverseMapping(dist, x);
     }
+
+    /**
+     * Test that the probability of the range {@code (x0, x1]} is computed exactly as
+     * {@code (x1 - x0) / (upper - lower + 1)} when x0 and x1 are within [lower, upper].
+     * Both x0 and x1 are stepped with a stride of about 1/20th of the range (minimum 1). This test
+     * will fail if the distribution uses the default implementation in {@link AbstractDiscreteDistribution}.
+     */
+    @ParameterizedTest
+    @CsvSource(value = {
+        // Extreme bounds
+        "-2147483648, -2147483648",
+        "-2147483648, -2147483647",
+        "-2147483648, -2147483646",
+        "-2147483648, -2147483638",
+        "2147483647, 2147483647",
+        "2147483646, 2147483647",
+        "2147483645, 2147483647",
+        "2147483637, 2147483647",
+        // Range is a prime number
+        "-10, 2", // 13
+        "10, 16",  // 7
+        "-20, -10", // 11
+        // Range is even
+        "-10, 3", // 14
+        "10, 17",  // 8
+        "-20, -9", // 12
+        // Large range
+        "-2147483648, 2147483647",
+        "-2147483648, 1263781682",
+        "-2147483648, 1781682",
+        "-2147483648, -231781682",
+        "-1324234584, 2147483647",
+        "-324234584, 2147483647",
+        "6234584, 2147483647",
+        "-1256362376, 125637",
+        "-62378468, 1325657374",
+    })
+    void testProbabilityRange(int lower, int upper) {
+        final UniformDiscreteDistribution dist = UniformDiscreteDistribution.of(lower, upper);
+        final double r = (double) upper - lower + 1;
+        final long stride = r < 20 ? 1 : (long) (r / 20);
+        for (long x0 = lower; x0 <= upper; x0 += stride) {
+            for (long x1 = x0; x1 <= upper; x1 += stride) {
+                final double p = (x1 - x0) / r;
+                Assertions.assertEquals(p, dist.probability((int) x0, (int) x1));
+            }
+        }
+    }
+
+    @Test
+    void testProbabilityRangeEdgeCases() {
+        final UniformDiscreteDistribution dist = UniformDiscreteDistribution.of(3, 5);
+
+        Assertions.assertThrows(DistributionException.class, () -> dist.probability(4, 3));
+
+        // x0 >= upper: the interval (x0, x1] is above the support
+        Assertions.assertEquals(0, dist.probability(5, 6));
+        Assertions.assertEquals(0, dist.probability(15, 16));
+        // x1 < lower: the interval (x0, x1] is below the support
+        Assertions.assertEquals(0, dist.probability(-3, 1));
+
+        // x0 == x1: the interval (x0, x1] is empty
+        Assertions.assertEquals(0, dist.probability(3, 3));
+        Assertions.assertEquals(0, dist.probability(4, 4));
+        Assertions.assertEquals(0, dist.probability(5, 5));
+        Assertions.assertEquals(0, dist.probability(6, 6));
+
+        // x0 + 1 == x1: the interval contains the single value x1
+        Assertions.assertEquals(1.0 / 3, dist.probability(3, 4));
+        Assertions.assertEquals(1.0 / 3, dist.probability(4, 5));
+
+        // x1 > upper: the interval is clipped at upper
+        Assertions.assertEquals(1, dist.probability(2, 6));
+        Assertions.assertEquals(2.0 / 3, dist.probability(3, 6));
+        Assertions.assertEquals(1.0 / 3, dist.probability(4, 6));
+        Assertions.assertEquals(0, dist.probability(5, 6));
+
+        // x0 < lower: the interval is clipped at lower
+        Assertions.assertEquals(0, dist.probability(-2, 2));
+        Assertions.assertEquals(1.0 / 3, dist.probability(-2, 3));
+        Assertions.assertEquals(2.0 / 3, dist.probability(-2, 4));
+        Assertions.assertEquals(1.0, dist.probability(-2, 5));
+
+        // x1 > upper && x0 < lower: the interval covers the entire support
+        Assertions.assertEquals(1, dist.probability(-2, 6));
+    }
 }
diff --git a/commons-statistics-distribution/src/test/resources/org/apache/commons/statistics/distribution/test.uniformdiscrete.3.properties b/commons-statistics-distribution/src/test/resources/org/apache/commons/statistics/distribution/test.uniformdiscrete.3.properties
index 3e7a72d..73a2499 100644
--- a/commons-statistics-distribution/src/test/resources/org/apache/commons/statistics/distribution/test.uniformdiscrete.3.properties
+++ b/commons-statistics-distribution/src/test/resources/org/apache/commons/statistics/distribution/test.uniformdiscrete.3.properties
@@ -15,6 +15,16 @@
 
 # Big range that will overflow an integer
 parameters = -1234682638, 1825371824
+
+# The absolute tolerance is limited by the test of probability(-123, 13).
+# The computation is:
+# (13 - -123) / (1825371824L + 1234682638 + 1) = 136 / 3060054463 = 4.4443653420033914E-8
+# The test expects this to be cdf(13) - cdf(-123):
+# 0.40348388139129665 - 0.4034838369476432 = 4.444365342415324E-8
+# The difference is: rel.error: <9.268649280418735E-11>, abs.error: <4.119326363289576E-18>
+# Configure the absolute error to allow this.
+tolerance.absolute = 5e-18
+
 # Computed using Wolfram Mathematica
 mean = 295344593
 variance = 780327776377184864