You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink itself is not preserved in this plain-text copy).
Posted to commits@commons.apache.org by ah...@apache.org on 2021/11/22 13:11:12 UTC
[commons-statistics] 01/03: Add probability range implementation for discrete uniform distribution.
This is an automated email from the ASF dual-hosted git repository.
aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-statistics.git
commit 005a31719a607a850fc263a34cc696152efcb3c2
Author: aherbert <ah...@apache.org>
AuthorDate: Mon Nov 22 12:04:32 2021 +0000
Add probability range implementation for discrete uniform distribution.
---
.../distribution/UniformDiscreteDistribution.java | 24 ++++++
.../UniformDiscreteDistributionTest.java | 86 ++++++++++++++++++++++
.../distribution/test.uniformdiscrete.3.properties | 10 +++
3 files changed, 120 insertions(+)
diff --git a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/UniformDiscreteDistribution.java b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/UniformDiscreteDistribution.java
index 4841dd9..b57f54c 100644
--- a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/UniformDiscreteDistribution.java
+++ b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/UniformDiscreteDistribution.java
@@ -81,6 +81,30 @@ public final class UniformDiscreteDistribution extends AbstractDiscreteDistribut
/** {@inheritDoc} */
@Override
+ public double probability(int x0,
+ int x1) {
+ if (x0 > x1) {
+ throw new DistributionException(DistributionException.INVALID_RANGE_LOW_GT_HIGH, x0, x1);
+ }
+ if (x0 >= upper || x1 < lower) {
+ // (x0, x1] does not overlap [lower, upper]
+ return 0;
+ }
+
+ // x0 < upper
+ // x1 >= lower
+
+ // Find the range between x0 (exclusive) and x1 (inclusive) within [lower, upper].
+ // In the case of x0 < lower set l so that u - l == (u - lower) + 1
+ // long arithmetic prevents overflow
+ final long l = Math.max(lower - 1L, x0);
+ final long u = Math.min(upper, x1);
+
+ return (u - l) / upperMinusLowerPlus1;
+ }
+
+ /** {@inheritDoc} */
+ @Override
public double logProbability(int x) {
if (x < lower || x > upper) {
return Double.NEGATIVE_INFINITY;
diff --git a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/UniformDiscreteDistributionTest.java b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/UniformDiscreteDistributionTest.java
index 3f976c7..c8f6912 100644
--- a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/UniformDiscreteDistributionTest.java
+++ b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/UniformDiscreteDistributionTest.java
@@ -169,4 +169,90 @@ class UniformDiscreteDistributionTest extends BaseDiscreteDistributionTest {
final int[] x = MathArrays.sequence(upper - lower, lower, 1);
testSurvivalProbabilityInverseMapping(dist, x);
}
+
+    /**
+     * Test that the probability of a range uses the exact computation
+     * {@code (x1 - x0) / (upper - lower + 1)} when x0 and x1 are within [lower, upper].
+     * The inherited default implementation in {@link AbstractDiscreteDistribution}
+     * is expected to fail this test.
+     */
+    @ParameterizedTest
+    @CsvSource(value = {
+        // Support touching the limits of the int range
+        "-2147483648, -2147483648",
+        "-2147483648, -2147483647",
+        "-2147483648, -2147483646",
+        "-2147483648, -2147483638",
+        "2147483647, 2147483647",
+        "2147483646, 2147483647",
+        "2147483645, 2147483647",
+        "2147483637, 2147483647",
+        // Support size is a prime number
+        "-10, 2", // size 13
+        "10, 16", // size 7
+        "-20, -10", // size 11
+        // Support size is even
+        "-10, 3", // size 14
+        "10, 17", // size 8
+        "-20, -9", // size 12
+        // Large support
+        "-2147483648, 2147483647",
+        "-2147483648, 1263781682",
+        "-2147483648, 1781682",
+        "-2147483648, -231781682",
+        "-1324234584, 2147483647",
+        "-324234584, 2147483647",
+        "6234584, 2147483647",
+        "-1256362376, 125637",
+        "-62378468, 1325657374",
+    })
+    void testProbabilityRange(int lower, int upper) {
+        final UniformDiscreteDistribution dist = UniformDiscreteDistribution.of(lower, upper);
+        final double size = (double) upper - lower + 1;
+        // Visit every point of a small support; otherwise step by size / 20 to bound the work
+        final long step = size < 20 ? 1 : (long) (size / 20);
+        for (long lo = lower; lo <= upper; lo += step) {
+            for (long hi = lo; hi <= upper; hi += step) {
+                Assertions.assertEquals((hi - lo) / size, dist.probability((int) lo, (int) hi));
+            }
+        }
+    }
+
+    @Test
+    void testProbabilityRangeEdgeCases() {
+        final UniformDiscreteDistribution dist = UniformDiscreteDistribution.of(3, 5);
+
+        Assertions.assertThrows(DistributionException.class, () -> dist.probability(4, 3));
+
+        // x0 >= upper: (x0, x1] lies above the support
+        Assertions.assertEquals(0, dist.probability(5, 6));
+        Assertions.assertEquals(0, dist.probability(15, 16));
+        // x1 < lower: (x0, x1] lies below the support
+        Assertions.assertEquals(0, dist.probability(-3, 1));
+
+        // x0 == x1: empty interval
+        Assertions.assertEquals(0, dist.probability(3, 3));
+        Assertions.assertEquals(0, dist.probability(4, 4));
+        Assertions.assertEquals(0, dist.probability(5, 5));
+        Assertions.assertEquals(0, dist.probability(6, 6));
+
+        // x0+1 == x1: a single support point
+        Assertions.assertEquals(1.0 / 3, dist.probability(3, 4));
+        Assertions.assertEquals(1.0 / 3, dist.probability(4, 5));
+
+        // x1 > upper with lower <= x0 < upper
+        Assertions.assertEquals(2.0 / 3, dist.probability(3, 6));
+        Assertions.assertEquals(1.0 / 3, dist.probability(4, 6));
+
+        // x0 < lower with x1 <= upper
+        Assertions.assertEquals(0, dist.probability(-2, 2));
+        Assertions.assertEquals(1.0 / 3, dist.probability(-2, 3));
+        Assertions.assertEquals(2.0 / 3, dist.probability(-2, 4));
+        Assertions.assertEquals(1.0, dist.probability(-2, 5));
+
+        // x1 > upper && x0 < lower: the whole support, including extreme int arguments
+        Assertions.assertEquals(1, dist.probability(2, 6));
+        Assertions.assertEquals(1, dist.probability(-2, 6));
+        Assertions.assertEquals(1, dist.probability(Integer.MIN_VALUE, Integer.MAX_VALUE));
+    }
}
diff --git a/commons-statistics-distribution/src/test/resources/org/apache/commons/statistics/distribution/test.uniformdiscrete.3.properties b/commons-statistics-distribution/src/test/resources/org/apache/commons/statistics/distribution/test.uniformdiscrete.3.properties
index 3e7a72d..73a2499 100644
--- a/commons-statistics-distribution/src/test/resources/org/apache/commons/statistics/distribution/test.uniformdiscrete.3.properties
+++ b/commons-statistics-distribution/src/test/resources/org/apache/commons/statistics/distribution/test.uniformdiscrete.3.properties
@@ -15,6 +15,16 @@
# Big range that will overflow an integer
parameters = -1234682638, 1825371824
+
+# The absolute tolerance is limited by probability(-123, 13).
+# The computation is:
+# (13 - -123) / (1825371824L + 1234682638 + 1) = 136 / 3060054463 = 4.4443653420033914E-8
+# The test expects this to be cdf(13) - cdf(-123):
+# 0.40348388139129665 - 0.4034838369476432 = 4.444365342415324E-8
+# The difference is: rel.error: <9.268649280418735E-11>, abs.error: <4.119326363289576E-18>
+# Configure the absolute error to allow this.
+tolerance.absolute = 5e-18
+
# Computed using Wolfram Mathematica
mean = 295344593
variance = 780327776377184864