You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by ah...@apache.org on 2022/01/20 16:27:09 UTC
[commons-numbers] 02/03: NUMBERS-183: Improve binomial coefficient classes
This is an automated email from the ASF dual-hosted git repository.
aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-numbers.git
commit d22325416794452dcf19d998374b4dd1519523bf
Author: aherbert <ah...@apache.org>
AuthorDate: Thu Jan 20 14:58:26 2022 +0000
NUMBERS-183: Improve binomial coefficient classes
Refactor the tests to use JUnit 5 parameterized test features.
Refactor the test for all binomial coefficients up to n=200 to compute
the expected result exactly using n! / (n-k)! / k! with BigInteger.
This removes the caching test implementation, which only works for
results that fit in a long. The tests for BinomialCoefficientDouble and
LogBinomialCoefficient have been updated to verify the result is correct
even when a long would overflow.
Update the implementations:
- Avoid a recursive method call when k > n/2 by computing with
m = min(k, n - k).
- Add an early exit if the result cannot fit into the output datatype.
This is min(k, n - k) >= 34 for a long, and min(k, n - k) >= 515 for a
double.
- BinomialCoefficientDouble: Use the precomputed factorials to compute
n! / k! / (n-k)! when n <= 170.
- BinomialCoefficientDouble: Avoid intermediate overflow by checking the
intermediate result and reversing the order of the multiply and divide
when required.
- LogBinomialCoefficient: Compute the result using the LogBeta class when
n > 1029 and min(k, n - k) > 37.
---
.../numbers/combinatorics/BinomialCoefficient.java | 84 +++++---
.../combinatorics/BinomialCoefficientDouble.java | 93 +++++++--
.../commons/numbers/combinatorics/Factorial.java | 13 ++
.../combinatorics/LogBinomialCoefficient.java | 64 +++---
.../BinomialCoefficientDoubleTest.java | 189 ++++++++++-------
.../combinatorics/BinomialCoefficientTest.java | 229 +++++++--------------
.../combinatorics/LogBinomialCoefficientTest.java | 180 ++++++++++------
7 files changed, 471 insertions(+), 381 deletions(-)
diff --git a/commons-numbers-combinatorics/src/main/java/org/apache/commons/numbers/combinatorics/BinomialCoefficient.java b/commons-numbers-combinatorics/src/main/java/org/apache/commons/numbers/combinatorics/BinomialCoefficient.java
index d19f592..807b427 100644
--- a/commons-numbers-combinatorics/src/main/java/org/apache/commons/numbers/combinatorics/BinomialCoefficient.java
+++ b/commons-numbers-combinatorics/src/main/java/org/apache/commons/numbers/combinatorics/BinomialCoefficient.java
@@ -26,6 +26,15 @@ import org.apache.commons.numbers.core.ArithmeticUtils;
* can be selected from an {@code n}-element set.
*/
public final class BinomialCoefficient {
+ /** The maximum m that can be computed without overflow of a long.
+ * C(68, 34) > 2^63. */
+ private static final int MAX_M = 33;
+ /** The maximum n that can be computed without intermediate overflow for any m.
+ * C(61, 30) * 30 < 2^63. */
+ private static final int SMALL_N = 61;
+ /** The maximum n that can be computed without overflow of a long for any m.
+ * C(66, 33) < 2^63. */
+ private static final int LIMIT_N = 66;
/** Private constructor. */
private BinomialCoefficient() {
@@ -34,51 +43,49 @@ public final class BinomialCoefficient {
/**
* Computes the binomial coefficient.
- * The largest value of {@code n} for which all coefficients can
- * fit into a {@code long} is 66.
+ *
+ * <p>The largest value of {@code n} for which <em>all</em> coefficients can
+ * fit into a {@code long} is 66. Larger {@code n} may result in an
+ * {@link ArithmeticException} depending on the value of {@code k}.
+ *
+ * <p>Any {@code min(k, n - k) >= 34} cannot fit into a {@code long}
+ * and will result in an {@link ArithmeticException}.
*
* @param n Size of the set.
* @param k Size of the subsets to be counted.
* @return {@code n choose k}.
- * @throws IllegalArgumentException if {@code n < 0}.
- * @throws IllegalArgumentException if {@code k > n}.
+ * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0} or {@code k > n}.
* @throws ArithmeticException if the result is too large to be
* represented by a {@code long}.
*/
public static long value(int n, int k) {
- checkBinomial(n, k);
+ final int m = checkBinomial(n, k);
- if (n == k ||
- k == 0) {
+ if (m == 0) {
return 1;
}
- if (k == 1 ||
- k == n - 1) {
+ if (m == 1) {
return n;
}
- // Use symmetry for large k.
- if (k > n / 2) {
- return value(n, n - k);
- }
// We use the formulae:
- // (n choose k) = n! / (n-k)! / k!
- // (n choose k) = ((n-k+1)*...*n) / (1*...*k)
+ // (n choose m) = n! / (n-m)! / m!
+ // (n choose m) = ((n-m+1)*...*n) / (1*...*m)
// which can be written
- // (n choose k) = (n-1 choose k-1) * n / k
+ // (n choose m) = (n-1 choose m-1) * n / m
long result = 1;
- if (n <= 61) {
+ if (n <= SMALL_N) {
// For n <= 61, the naive implementation cannot overflow.
- int i = n - k + 1;
- for (int j = 1; j <= k; j++) {
+ int i = n - m + 1;
+ for (int j = 1; j <= m; j++) {
result = result * i / j;
i++;
}
- } else if (n <= 66) {
+ } else if (n <= LIMIT_N) {
// For n > 61 but n <= 66, the result cannot overflow,
// but we must take care not to overflow intermediate values.
- int i = n - k + 1;
- for (int j = 1; j <= k; j++) {
+ int i = n - m + 1;
+ for (int j = 1; j <= m; j++) {
// We know that (result * i) is divisible by j,
// but (result * i) may overflow, so we split j:
// Filter out the gcd, d, so j/d and i/d are integer.
@@ -90,11 +97,15 @@ public final class BinomialCoefficient {
++i;
}
} else {
+ if (m > MAX_M) {
+ throw new ArithmeticException(n + " choose " + k);
+ }
+
// For n > 66, a result overflow might occur, so we check
// the multiplication, taking care to not overflow
// unnecessary.
- int i = n - k + 1;
- for (int j = 1; j <= k; j++) {
+ int i = n - m + 1;
+ for (int j = 1; j <= m; j++) {
final long d = ArithmeticUtils.gcd(i, j);
result = Math.multiplyExact(result / (j / d), i / d);
++i;
@@ -107,19 +118,30 @@ public final class BinomialCoefficient {
/**
* Check binomial preconditions.
*
+ * <p>For convenience in implementations this returns the smaller of
+ * {@code k} or {@code n - k} allowing symmetry to be exploited in
+ * computing the binomial coefficient.
+ *
* @param n Size of the set.
* @param k Size of the subsets to be counted.
+ * @return min(k, n - k)
* @throws IllegalArgumentException if {@code n < 0}.
* @throws IllegalArgumentException if {@code k > n} or {@code k < 0}.
*/
- static void checkBinomial(int n,
- int k) {
- if (n < 0) {
- throw new CombinatoricsException(CombinatoricsException.NEGATIVE, n);
- }
- if (k > n ||
- k < 0) {
+ static int checkBinomial(int n,
+ int k) {
+ // Combine all checks with a single branch:
+ // 0 <= n; 0 <= k <= n
+ // Note: If n >= 0 && k >= 0 && n - k < 0 then k > n.
+ final int m = n - k;
+ // Bitwise or will detect a negative sign bit in any of the numbers
+ if ((n | k | m) < 0) {
+ // Raise the correct exception
+ if (n < 0) {
+ throw new CombinatoricsException(CombinatoricsException.NEGATIVE, n);
+ }
throw new CombinatoricsException(CombinatoricsException.OUT_OF_RANGE, k, 0, n);
}
+ return m < k ? m : k;
}
}
diff --git a/commons-numbers-combinatorics/src/main/java/org/apache/commons/numbers/combinatorics/BinomialCoefficientDouble.java b/commons-numbers-combinatorics/src/main/java/org/apache/commons/numbers/combinatorics/BinomialCoefficientDouble.java
index bd95280..bdcc059 100644
--- a/commons-numbers-combinatorics/src/main/java/org/apache/commons/numbers/combinatorics/BinomialCoefficientDouble.java
+++ b/commons-numbers-combinatorics/src/main/java/org/apache/commons/numbers/combinatorics/BinomialCoefficientDouble.java
@@ -24,6 +24,20 @@ package org.apache.commons.numbers.combinatorics;
* can be selected from an {@code n}-element set.
*/
public final class BinomialCoefficientDouble {
+ /** The maximum factorial that can be represented as a double. */
+ private static final int MAX_FACTORIAL = 170;
+ /** The maximum n that can be computed without overflow of a long for any m.
+ * C(66, 33) < 2^63. */
+ private static final int LIMIT_N_LONG = 66;
+ /** The maximum m that can be computed without overflow of a double.
+ * C(1030, 515) ~ 2.85e308. */
+ private static final int MAX_M = 514;
+ /** The maximum n that can be computed without intermediate overflow for any m.
+ * C(1020, 510) * 510 ~ 1.43e308. */
+ private static final int SMALL_N = 1020;
+ /** The maximum m that can be computed without intermediate overflow for any n.
+ * C(2147483647, 37) * 37 ~ 5.13e303. */
+ private static final int SMALL_M = 37;
/** Private constructor. */
private BinomialCoefficientDouble() {
@@ -32,37 +46,78 @@ public final class BinomialCoefficientDouble {
/**
* Computes the binomial coefficient.
- * The largest value of {@code n} for which all coefficients can
- * fit into a {@code long} is 66.
+ *
+ * <p>The largest value of {@code n} for which <em>all</em> coefficients can
+ * fit into a {@code double} is 1029. Larger {@code n} may result in
+ * infinity depending on the value of {@code k}.
+ *
+ * <p>Any {@code min(k, n - k) >= 515} cannot fit into a {@code double}
+ * and will result in infinity.
*
* @param n Size of the set.
* @param k Size of the subsets to be counted.
* @return {@code n choose k}.
- * @throws IllegalArgumentException if {@code n < 0}.
- * @throws IllegalArgumentException if {@code k > n}.
+ * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0} or {@code k > n}.
*/
public static double value(int n, int k) {
- BinomialCoefficient.checkBinomial(n, k);
+ if (n <= LIMIT_N_LONG) {
+ // Delegate to the exact long result
+ return BinomialCoefficient.value(n, k);
+ }
+ final int m = BinomialCoefficient.checkBinomial(n, k);
- if (n == k ||
- k == 0) {
+ if (m == 0) {
return 1;
}
- if (k == 1 ||
- k == n - 1) {
+ if (m == 1) {
return n;
}
- if (k > n / 2) {
- return value(n, n - k);
- }
- if (n < 67) {
- return BinomialCoefficient.value(n, k);
- }
- double result = 1;
- for (int i = 1; i <= k; i++) {
- result *= n - k + i;
- result /= i;
+ double result;
+ if (n <= MAX_FACTORIAL) {
+ // Small factorials are tabulated exactly
+ // n! / m! / (n-m)!
+ result = Factorial.uncheckedFactorial(n) /
+ Factorial.uncheckedFactorial(m) /
+ Factorial.uncheckedFactorial(n - m);
+ } else {
+ // Compute recursively using:
+ // (n choose m) = (n-1 choose m-1) * n / m
+
+ if (n <= SMALL_N || m <= SMALL_M) {
+ // No overflow possible
+ result = 1;
+ for (int i = 1; i <= m; i++) {
+ result *= n - m + i;
+ result /= i;
+ }
+ } else {
+ if (m > MAX_M) {
+ return Double.POSITIVE_INFINITY;
+ }
+
+ // Compute the initial part without overflow checks
+ result = 1;
+ for (int i = 1; i <= SMALL_M; i++) {
+ result *= n - m + i;
+ result /= i;
+ }
+ // Careful of overflow
+ for (int i = SMALL_M + 1; i <= m; i++) {
+ final double next = result * (n - m + i);
+ if (next > Double.MAX_VALUE) {
+ // Reverse order of terms
+ result /= i;
+ result *= n - m + i;
+ if (result > Double.MAX_VALUE) {
+ // Definite overflow
+ return Double.POSITIVE_INFINITY;
+ }
+ } else {
+ result = next / i;
+ }
+ }
+ }
}
return Math.floor(result + 0.5);
diff --git a/commons-numbers-combinatorics/src/main/java/org/apache/commons/numbers/combinatorics/Factorial.java b/commons-numbers-combinatorics/src/main/java/org/apache/commons/numbers/combinatorics/Factorial.java
index 7a47070..9d16bfd 100644
--- a/commons-numbers-combinatorics/src/main/java/org/apache/commons/numbers/combinatorics/Factorial.java
+++ b/commons-numbers-combinatorics/src/main/java/org/apache/commons/numbers/combinatorics/Factorial.java
@@ -254,4 +254,17 @@ public final class Factorial {
}
return Double.POSITIVE_INFINITY;
}
+
+ /**
+ * Return the factorial of {@code n}.
+ *
+ * <p>Note: This is an internal method that exposes the tabulated factorials that can
+ * be represented as a double. No checks are performed on the argument.
+ *
+ * @param n Argument (must be in [0, 170])
+ * @return n!
+ */
+ static double uncheckedFactorial(int n) {
+ return DOUBLE_FACTORIALS[n];
+ }
}
diff --git a/commons-numbers-combinatorics/src/main/java/org/apache/commons/numbers/combinatorics/LogBinomialCoefficient.java b/commons-numbers-combinatorics/src/main/java/org/apache/commons/numbers/combinatorics/LogBinomialCoefficient.java
index dad945c..e398f4e 100644
--- a/commons-numbers-combinatorics/src/main/java/org/apache/commons/numbers/combinatorics/LogBinomialCoefficient.java
+++ b/commons-numbers-combinatorics/src/main/java/org/apache/commons/numbers/combinatorics/LogBinomialCoefficient.java
@@ -17,6 +17,8 @@
package org.apache.commons.numbers.combinatorics;
+import org.apache.commons.numbers.gamma.LogBeta;
+
/**
* Natural logarithm of the <a href="http://mathworld.wolfram.com/BinomialCoefficient.html">
* binomial coefficient</a>.
@@ -24,6 +26,15 @@ package org.apache.commons.numbers.combinatorics;
* can be selected from an {@code n}-element set.
*/
public final class LogBinomialCoefficient {
+ /** The maximum n that can be computed without overflow of a long for any m.
+ * C(66, 33) < 2^63. */
+ private static final int LIMIT_N_LONG = 66;
+ /** The maximum n that can be computed without overflow of a double for any m.
+ * C(1029, 514) ~ 1.43e308. */
+ private static final int LIMIT_N_DOUBLE = 1029;
+ /** The maximum m that can be computed without overflow of a double for any n.
+ * C(2147483647, 37) ~ 1.39e302. */
+ private static final int LIMIT_M_DOUBLE = 37;
/** Private constructor. */
private LogBinomialCoefficient() {
@@ -32,56 +43,43 @@ public final class LogBinomialCoefficient {
/**
* Computes the logarithm of the binomial coefficient.
- * The largest value of {@code n} for which all coefficients can
- * fit into a {@code long} is 66.
+ *
+ * <p>This returns a finite result for any valid {@code n choose k}.
*
* @param n Size of the set.
* @param k Size of the subsets to be counted.
* @return {@code log(n choose k)}.
- * @throws IllegalArgumentException if {@code n < 0}.
- * @throws IllegalArgumentException if {@code k > n}.
+ * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0} or {@code k > n}.
*/
public static double value(int n, int k) {
- BinomialCoefficient.checkBinomial(n, k);
+ final int m = BinomialCoefficient.checkBinomial(n, k);
- if (n == k ||
- k == 0) {
+ if (m == 0) {
return 0;
}
- if (k == 1 ||
- k == n - 1) {
+ if (m == 1) {
return Math.log(n);
}
- // For values small enough to do exact integer computation,
- // return the log of the exact value.
- if (n < 67) {
+ if (n <= LIMIT_N_LONG) {
+ // Delegate to the exact long result
return Math.log(BinomialCoefficient.value(n, k));
}
-
- // Logarithm of "BinomialCoefficientDouble" for values that
- // will not overflow.
- if (n < 1030) {
+ if (n <= LIMIT_N_DOUBLE || m <= LIMIT_M_DOUBLE) {
+ // Delegate to the double result
return Math.log(BinomialCoefficientDouble.value(n, k));
}
- if (k > n / 2) {
- return value(n, n - k);
- }
-
- // Sum for values that could overflow.
- double logSum = 0;
-
- // n! / (n - k)!
- for (int i = n - k + 1; i <= n; i++) {
- logSum += Math.log(i);
- }
-
- // Divide by k!
- for (int i = 2; i <= k; i++) {
- logSum -= Math.log(i);
- }
+ //    n!              gamma(n+1)               gamma(k+1) * gamma(n-k+1)
+ // --------- = ------------------------- = 1 / -------------------------
+ // k! (n-k)!   gamma(k+1) * gamma(n-k+1)              gamma(n+1)
+ //
+ //
+ // = 1 / (k * beta(k, n-k+1))
+ //
+ // where: beta(a, b) = gamma(a) * gamma(b) / gamma(a+b)
- return logSum;
+ // Delegate to LogBeta
+ return -Math.log(m) - LogBeta.value(m, n - m + 1);
}
}
diff --git a/commons-numbers-combinatorics/src/test/java/org/apache/commons/numbers/combinatorics/BinomialCoefficientDoubleTest.java b/commons-numbers-combinatorics/src/test/java/org/apache/commons/numbers/combinatorics/BinomialCoefficientDoubleTest.java
index 367aa50..c9432d2 100644
--- a/commons-numbers-combinatorics/src/test/java/org/apache/commons/numbers/combinatorics/BinomialCoefficientDoubleTest.java
+++ b/commons-numbers-combinatorics/src/test/java/org/apache/commons/numbers/combinatorics/BinomialCoefficientDoubleTest.java
@@ -16,105 +16,134 @@
*/
package org.apache.commons.numbers.combinatorics;
+import java.math.BigInteger;
+import org.apache.commons.numbers.core.Precision;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.CsvSource;
/**
- * Test cases for the {@link BinomialCoefficient} class.
+ * Test cases for the {@link BinomialCoefficientDouble} class.
*/
class BinomialCoefficientDoubleTest {
- /** Verify that b(0,0) = 1 */
- @Test
- void test0Choose0() {
- Assertions.assertEquals(1d, BinomialCoefficientDouble.value(0, 0));
- }
-
- @Test
- void testBinomialCoefficient() {
- final long[] bcoef5 = {1, 5, 10, 10, 5, 1};
- final long[] bcoef6 = {1, 6, 15, 20, 15, 6, 1};
-
- for (int n = 1; n < 10; n++) {
- for (int k = 0; k <= n; k++) {
- Assertions.assertEquals(
- BinomialCoefficientTest.binomialCoefficient(n, k),
- BinomialCoefficientDouble.value(n, k),
- Double.MIN_VALUE,
- n + " choose " + k
- );
- }
- }
-
- final int[] n = {34, 66, 100, 1500, 1500};
- final int[] k = {17, 33, 10, 1500 - 4, 4};
- for (int i = 0; i < n.length; i++) {
- final long expected = BinomialCoefficientTest.binomialCoefficient(n[i], k[i]);
- Assertions.assertEquals(
- expected,
- BinomialCoefficientDouble.value(n[i], k[i]),
- 0.0,
- n[i] + " choose " + k[i]
- );
- }
+ @ParameterizedTest
+ @CsvSource({
+ "4, 5",
+ "-1, 1",
+ "10, -1",
+ "-1, -1",
+ "-1, -2",
+ })
+ void testBinomialCoefficientIllegalArguments(int n, int k) {
+ Assertions.assertThrows(CombinatoricsException.class, () -> BinomialCoefficientDouble.value(n, k),
+ () -> n + " choose " + k);
}
- @Test
- void testBinomialCoefficientFail1() {
- Assertions.assertThrows(CombinatoricsException.class,
- () -> BinomialCoefficientDouble.value(4, 5)
- );
+ @ParameterizedTest
+ @CsvSource({
+ // Data verified using maxima: bfloat(binomial(n, k)) using 30 digits of precision.
+ // Note: This test will correctly assert infinite expected values.
+ "0, 0, 1, 0",
+ "5, 0, 1, 0",
+ "5, 1, 5, 0",
+ "5, 2, 10, 0",
+ "6, 0, 1, 0",
+ "6, 1, 6, 0",
+ "6, 2, 15, 0",
+ "6, 3, 20, 0",
+ "34, 17, 2333606220, 0",
+ "66, 33, 7219428434016265740, 0",
+ "100, 10, 17310309456440, 0",
+ "1500, 4, 210094780875, 0",
+ "300, 3, 4455100, 0",
+ "700, 697, 56921900, 0",
+ "10000, 3, 166616670000, 0",
+ "412, 9, 863177604710689620, 0",
+ "678, 7, 12667255449994080, 0",
+ "66, 33, 7219428434016265740, 0",
+ // Overflow as a long
+ "67, 30, 9989690752182277136, 1",
+ "67, 33, 14226520737620288370, 0",
+ "68, 34, 28453041475240576740, 0",
+ // Overflow without special handling
+ // See NUMBERS-183
+ "1040, 450, 2.3101613255412135615e307, 11",
+ "1029, 514, 1.4298206864989040819e308, 5",
+ "1786388282, 38, 7.187239013254065384599502085053593e306, 0",
+ "1914878305, 38, 100.6570419073661447979173868523364e306, 1",
+ "1179067476, 39, 30.22890249420109200962786203300876e306, 2",
+ "2147483647, 37, 1.388890512412231479281222156415993e302, 4",
+ "20000, 116, 1.75293130532995289393810309132e308, 8",
+ "20000, 117, 2.97908427992998148231326853571e310, 0",
+ "1028, 514, 7.156051054877897008430135897e307, 8",
+ "1030, 496, 1.41941785031194251722295917039e308, 0",
+ "1030, 497, 1.52508879691464246317315935007e308, 32",
+ "1030, 498, 1.63227375252109323869737737668e308, 0",
+ "1030, 499, 1.74021971210665651901203359598e308, 12",
+ "1030, 500, 1.84811333425726922319077967894e308, 0",
+ "1020, 510, 2.80626776829962271039414307883e305, 8",
+ "1022, 511, 1.12140876377061244121816833013e306, 14",
+ "1024, 512, 4.48125455209897081002416485048e306, 3",
+ })
+ void testBinomialCoefficient(int n, int k, double nCk, int ulp) {
+ assertBinomial(n, k, nCk, ulp);
}
- @Test
- void testBinomialCoefficientFail2() {
- Assertions.assertThrows(CombinatoricsException.class,
- () -> BinomialCoefficientDouble.value(-1, -2)
- );
- }
-
- @Test
- void testBinomialCoefficientFail3() {
- final double x = BinomialCoefficientDouble.value(1030, 515);
- Assertions.assertTrue(Double.isInfinite(x), "expecting infinite binomial coefficient");
+ @ParameterizedTest
+ @CsvSource({
+ "1030, 515",
+ "10000000, 10000",
+ })
+ void testBinomialCoefficientOverflow(int n, int k) {
+ Assertions.assertEquals(Double.POSITIVE_INFINITY, BinomialCoefficientDouble.value(n, k),
+ () -> n + " choose " + k);
}
/**
- * Tests correctness for large n and sharpness of upper bound in API doc
+ * Tests correctness for large n and sharpness of upper bound in API doc.
* JIRA: MATH-241
*/
@Test
- void testBinomialCoefficientLarge() throws Exception {
- // This tests all legal and illegal values for n <= 200.
- for (int n = 0; n <= 200; n++) {
- for (int k = 0; k <= n; k++) {
- long exactResult = -1;
- boolean shouldThrow = false;
- boolean didThrow = false;
- try {
- BinomialCoefficient.value(n, k);
- } catch (ArithmeticException ex) {
- didThrow = true;
- }
- try {
- exactResult = BinomialCoefficientTest.binomialCoefficient(n, k);
- } catch (ArithmeticException ex) {
- shouldThrow = true;
- }
+ void testBinomialCoefficientLarge() {
+ // This tests all values for n <= 200.
+ final int size = 200;
+ final BigInteger[] factorials = new BigInteger[size + 1];
+ factorials[0] = BigInteger.ONE;
+ for (int n = 1; n <= size; n++) {
+ factorials[n] = factorials[n - 1].multiply(BigInteger.valueOf(n));
+ }
- if (!shouldThrow && exactResult > 1) {
- Assertions.assertEquals(
- 1.,
- BinomialCoefficientDouble.value(n, k) / exactResult,
- 1e-10,
- n + " choose " + k
- );
- }
+ for (int n = 0; n <= size; n++) {
+ int ulp;
+ if (n <= 66) {
+ ulp = 0;
+ } else if (n <= 100) {
+ ulp = 5;
+ } else if (n <= 150) {
+ ulp = 10;
+ } else {
+ ulp = 15;
+ }
+ for (int k = 0; k <= n / 2; k++) {
+ final BigInteger nCk = factorials[n].divide(factorials[n - k]).divide(factorials[k]);
+ final double expected = nCk.doubleValue();
+ assertBinomial(n, k, expected, ulp);
}
}
+ }
- final int n = 10000;
- final double actualOverExpected = BinomialCoefficientDouble.value(n, 3) /
- BinomialCoefficientTest.binomialCoefficient(n, 3);
- Assertions.assertEquals(1, actualOverExpected, 1e-10);
+ private static void assertBinomial(int n, int k, double expected, int ulp) {
+ final double actual = BinomialCoefficientDouble.value(n, k);
+ if (expected == Double.POSITIVE_INFINITY) {
+ Assertions.assertEquals(expected, actual, () -> n + " choose " + k);
+ } else {
+ Assertions.assertTrue(Precision.equals(expected, actual, ulp),
+ () -> String.format("C(%d, %d) = %s : actual %s : ULP error = %d", n, k,
+ expected, actual,
+ Double.doubleToRawLongBits(actual) - Double.doubleToRawLongBits(expected)));
+ }
+ // Test symmetry
+ Assertions.assertEquals(actual, BinomialCoefficientDouble.value(n, n - k), () -> n + " choose " + k);
}
}
diff --git a/commons-numbers-combinatorics/src/test/java/org/apache/commons/numbers/combinatorics/BinomialCoefficientTest.java b/commons-numbers-combinatorics/src/test/java/org/apache/commons/numbers/combinatorics/BinomialCoefficientTest.java
index a843fcb..e68d81f 100644
--- a/commons-numbers-combinatorics/src/test/java/org/apache/commons/numbers/combinatorics/BinomialCoefficientTest.java
+++ b/commons-numbers-combinatorics/src/test/java/org/apache/commons/numbers/combinatorics/BinomialCoefficientTest.java
@@ -16,186 +16,103 @@
*/
package org.apache.commons.numbers.combinatorics;
-import java.util.List;
-import java.util.ArrayList;
-import java.util.Map;
-import java.util.HashMap;
+import java.math.BigInteger;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.CsvSource;
/**
* Test cases for the {@link BinomialCoefficient} class.
*/
class BinomialCoefficientTest {
- /** Cached binomial coefficients. */
- private static final List<Map<Integer, Long>> binomialCache = new ArrayList<>();
-
- /** Verify that b(0,0) = 1 */
- @Test
- void test0Choose0() {
- Assertions.assertEquals(1, BinomialCoefficient.value(0, 0));
- }
-
- @Test
- void testBinomialCoefficient() {
- final long[] bcoef5 = {1, 5, 10, 10, 5, 1};
- final long[] bcoef6 = {1, 6, 15, 20, 15, 6, 1};
-
- for (int i = 0; i < 6; i++) {
- Assertions.assertEquals(bcoef5[i], BinomialCoefficient.value(5, i), "5 choose " + i);
- }
- for (int i = 0; i < 7; i++) {
- Assertions.assertEquals(bcoef6[i], BinomialCoefficient.value(6, i), "6 choose " + i);
- }
-
- for (int n = 1; n < 10; n++) {
- for (int k = 0; k <= n; k++) {
- Assertions.assertEquals(
- binomialCoefficient(n, k),
- BinomialCoefficient.value(n, k),
- n + " choose " + k
- );
- }
- }
-
- final int[] n = {34, 66, 100, 1500, 1500};
- final int[] k = {17, 33, 10, 1500 - 4, 4};
- for (int i = 0; i < n.length; i++) {
- final long expected = binomialCoefficient(n[i], k[i]);
- Assertions.assertEquals(
- expected,
- BinomialCoefficient.value(n[i], k[i]),
- n[i] + " choose " + k[i]
- );
- }
- }
-
- @Test
- void testBinomialCoefficientKLargerThanN() {
- Assertions.assertThrows(CombinatoricsException.class,
- () -> BinomialCoefficient.value(4, 5)
- );
- }
-
- @Test
- void testBinomialCoefficientNegativeN() {
- Assertions.assertThrows(CombinatoricsException.class,
- () -> BinomialCoefficient.value(-1, 1)
- );
+ @ParameterizedTest
+ @CsvSource({
+ "4, 5",
+ "-1, 1",
+ "10, -1",
+ "-1, -1",
+ "-1, -2",
+ })
+ void testBinomialCoefficientIllegalArguments(int n, int k) {
+ Assertions.assertThrows(CombinatoricsException.class, () -> BinomialCoefficient.value(n, k),
+ () -> n + " choose " + k);
}
- @Test
- void testBinomialCoefficientNegativeK() {
- Assertions.assertThrows(CombinatoricsException.class,
- () -> BinomialCoefficient.value(10, -1)
- );
+ @ParameterizedTest
+ @CsvSource({
+ // Data verified using maxima: binomial(n, k)
+ "0, 0, 1",
+ "5, 0, 1",
+ "5, 1, 5",
+ "5, 2, 10",
+ "6, 0, 1",
+ "6, 1, 6",
+ "6, 2, 15",
+ "6, 3, 20",
+ "34, 17, 2333606220",
+ "66, 33, 7219428434016265740",
+ "100, 10, 17310309456440",
+ "1500, 4, 210094780875",
+ "300, 3, 4455100",
+ "700, 697, 56921900",
+ "10000, 3, 166616670000",
+ "412, 9, 863177604710689620",
+ "678, 7, 12667255449994080",
+ "66, 33, 7219428434016265740",
+ })
+ void testBinomialCoefficient(int n, int k, long nCk) {
+ Assertions.assertEquals(nCk, BinomialCoefficient.value(n, k), () -> n + " choose " + k);
+ final int m = n - k;
+ Assertions.assertEquals(nCk, BinomialCoefficient.value(n, m), () -> n + " choose " + m);
}
- @Test
- void testBinomialCoefficientNAbove66ResultOverflow() {
- Assertions.assertThrows(ArithmeticException.class,
- () -> BinomialCoefficient.value(67, 30)
- );
+ @ParameterizedTest
+ @CsvSource({
+ "67, 30",
+ "67, 33",
+ "68, 34",
+ })
+ void testBinomialCoefficientOverflow(int n, int k) {
+ Assertions.assertThrows(ArithmeticException.class, () -> BinomialCoefficient.value(n, k),
+ () -> n + " choose " + k);
}
/**
- * Tests correctness for large n and sharpness of upper bound in API doc
+ * Tests correctness for large n and sharpness of upper bound in API doc.
* JIRA: MATH-241
*/
@Test
- void testBinomialCoefficientLarge() throws Exception {
+ void testBinomialCoefficientLarge() {
// This tests all legal and illegal values for n <= 200.
- for (int n = 0; n <= 200; n++) {
- for (int k = 0; k <= n; k++) {
- long ourResult = -1;
- long exactResult = -1;
- boolean shouldThrow = false;
- boolean didThrow = false;
+ final int size = 200;
+ final BigInteger[] factorials = new BigInteger[size + 1];
+ factorials[0] = BigInteger.ONE;
+ for (int n = 1; n <= size; n++) {
+ factorials[n] = factorials[n - 1].multiply(BigInteger.valueOf(n));
+ }
+
+ for (int i = 0; i <= size; i++) {
+ final int n = i;
+ for (int j = 0; j <= n; j++) {
+ final int k = j;
+ final BigInteger nCk = factorials[n].divide(factorials[n - k]).divide(factorials[k]);
+ // Exceptions are ignored. If both throw then the results will match as -1.
+ long actual = -1;
+ long expected = -1;
try {
- ourResult = BinomialCoefficient.value(n, k);
- } catch (ArithmeticException ex) {
- didThrow = true;
+ actual = BinomialCoefficient.value(n, k);
+ } catch (final ArithmeticException ex) {
+ // Ignore
}
try {
- exactResult = binomialCoefficient(n, k);
- } catch (ArithmeticException ex) {
- shouldThrow = true;
+ expected = nCk.longValueExact();
+ } catch (final ArithmeticException ex) {
+ // Ignore
}
- Assertions.assertEquals(exactResult, ourResult, n + " choose " + k);
- Assertions.assertEquals(shouldThrow, didThrow, n + " choose " + k);
- Assertions.assertTrue(n > 66 || !didThrow, n + " choose " + k);
+ Assertions.assertEquals(expected, actual, () -> n + " choose " + k);
}
}
-
- long ourResult = BinomialCoefficient.value(300, 3);
- long exactResult = binomialCoefficient(300, 3);
- Assertions.assertEquals(exactResult, ourResult);
-
- ourResult = BinomialCoefficient.value(700, 697);
- exactResult = binomialCoefficient(700, 697);
- Assertions.assertEquals(exactResult, ourResult);
-
- final int n = 10000;
- ourResult = BinomialCoefficient.value(n, 3);
- exactResult = binomialCoefficient(n, 3);
- Assertions.assertEquals(exactResult, ourResult);
-
- }
-
- @Test
- void checkNLessThanOne() {
- Assertions.assertThrows(CombinatoricsException.class,
- () -> BinomialCoefficient.checkBinomial(-1, -2)
- );
- }
-
- @Test
- void checkKGreaterThanN() {
- Assertions.assertThrows(CombinatoricsException.class,
- () -> BinomialCoefficient.checkBinomial(4, 5)
- );
- }
-
- @Test
- void testCheckBinomial3() {
- // OK (no exception thrown)
- BinomialCoefficient.checkBinomial(5, 4);
- }
-
- /**
- * Exact (caching) recursive implementation to test against.
- */
- static long binomialCoefficient(int n, int k) {
- if (binomialCache.size() > n) {
- final Long cachedResult = binomialCache.get(n).get(Integer.valueOf(k));
- if (cachedResult != null) {
- return cachedResult.longValue();
- }
- }
- long result = -1;
- if ((n == k) || (k == 0)) {
- result = 1;
- } else if ((k == 1) || (k == n - 1)) {
- result = n;
- } else {
- // Reduce stack depth for larger values of n.
- if (k < n - 100) {
- binomialCoefficient(n - 100, k);
- }
- if (k > 100) {
- binomialCoefficient(n - 100, k - 100);
- }
- result = Math.addExact(binomialCoefficient(n - 1, k - 1),
- binomialCoefficient(n - 1, k));
- }
- if (result == -1) {
- throw new IllegalArgumentException();
- }
- for (int i = binomialCache.size(); i < n + 1; i++) {
- binomialCache.add(new HashMap<Integer, Long>());
- }
- binomialCache.get(n).put(Integer.valueOf(k), Long.valueOf(result));
- return result;
}
}
diff --git a/commons-numbers-combinatorics/src/test/java/org/apache/commons/numbers/combinatorics/LogBinomialCoefficientTest.java b/commons-numbers-combinatorics/src/test/java/org/apache/commons/numbers/combinatorics/LogBinomialCoefficientTest.java
index 88f05a1..e15909a 100644
--- a/commons-numbers-combinatorics/src/test/java/org/apache/commons/numbers/combinatorics/LogBinomialCoefficientTest.java
+++ b/commons-numbers-combinatorics/src/test/java/org/apache/commons/numbers/combinatorics/LogBinomialCoefficientTest.java
@@ -16,90 +16,146 @@
*/
package org.apache.commons.numbers.combinatorics;
+import java.math.BigInteger;
+import org.apache.commons.numbers.core.Precision;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.CsvSource;
/**
* Test cases for the {@link LogBinomialCoefficient} class.
*/
class LogBinomialCoefficientTest {
- /** Verify that b(0,0) = 1 */
- @Test
- void test0Choose0() {
- Assertions.assertEquals(0d, LogBinomialCoefficient.value(0, 0));
+ @ParameterizedTest
+ @CsvSource({
+ "4, 5",
+ "-1, 1",
+ "10, -1",
+ "-1, -1",
+ "-1, -2",
+ })
+ void testBinomialCoefficientIllegalArguments(int n, int k) {
+ Assertions.assertThrows(CombinatoricsException.class, () -> LogBinomialCoefficient.value(n, k),
+ () -> n + " choose " + k);
}
- @Test
- void testBinomialCoefficient() {
- final long[] bcoef5 = {1, 5, 10, 10, 5, 1};
- final long[] bcoef6 = {1, 6, 15, 20, 15, 6, 1};
-
- for (int n = 1; n < 10; n++) {
- for (int k = 0; k <= n; k++) {
- Assertions.assertEquals(
- Math.log(BinomialCoefficientTest.binomialCoefficient(n, k)),
- LogBinomialCoefficient.value(n, k), 1e-12, n + " choose " + k);
- }
- }
-
- final int[] n = {34, 66, 100, 1500, 1500};
- final int[] k = {17, 33, 10, 1500 - 4, 4};
- for (int i = 0; i < n.length; i++) {
- final long expected = BinomialCoefficientTest.binomialCoefficient(n[i], k[i]);
- Assertions.assertEquals(
- Math.log(expected),
- LogBinomialCoefficient.value(n[i], k[i]),
- 0d, "log(" + n[i] + " choose " + k[i] + ")");
- }
+ @ParameterizedTest
+ @CsvSource({
+ // Data verified using maxima: bfloat(binomial(n, k)) using 30 digits of precision.
+ // Note: This test avoids infinite expected values.
+ // Within the range of a long
+ "0, 0, 1, 0",
+ "5, 0, 1, 0",
+ "5, 1, 5, 0",
+ "5, 2, 10, 0",
+ "6, 0, 1, 0",
+ "6, 1, 6, 0",
+ "6, 2, 15, 0",
+ "6, 3, 20, 0",
+ "34, 17, 2333606220, 0",
+ "66, 33, 7219428434016265740, 0",
+ "100, 10, 17310309456440, 0",
+ "1500, 4, 210094780875, 0",
+ "300, 3, 4455100, 0",
+ "700, 697, 56921900, 0",
+ "10000, 3, 166616670000, 0",
+ "412, 9, 863177604710689620, 0",
+ "678, 7, 12667255449994080, 0",
+ // Overflow as a long
+ "67, 30, 9989690752182277136, 0",
+ "67, 33, 14226520737620288370, 0",
+ "68, 34, 28453041475240576740, 0",
+ // Overflow a double without special handling
+ // See NUMBERS-183
+ "1040, 450, 2.3101613255412135615e307, 1",
+ "1029, 514, 1.4298206864989040819e308, 0",
+ "1786388282, 38, 7.187239013254065384599502085053593e306, 1",
+ "1914878305, 38, 100.6570419073661447979173868523364e306, 1",
+ "1179067476, 39, 30.22890249420109200962786203300876e306, 0",
+ "2147483647, 37, 1.388890512412231479281222156415993e302, 0",
+ "20000, 116, 1.75293130532995289393810309132e308, 0",
+ "1028, 514, 7.156051054877897008430135897e307, 0",
+ "1030, 496, 1.41941785031194251722295917039e308, 0",
+ "1030, 497, 1.52508879691464246317315935007e308, 1",
+ "1030, 498, 1.63227375252109323869737737668e308, 1",
+ "1030, 499, 1.74021971210665651901203359598e308, 1",
+ "1020, 510, 2.80626776829962271039414307883e305, 0",
+ "1022, 511, 1.12140876377061244121816833013e306, 0",
+ "1024, 512, 4.48125455209897081002416485048e306, 0",
+ })
+ void testBinomialCoefficient(int n, int k, double nCk, int ulp) {
+ Assertions.assertTrue(Double.isFinite(nCk));
+ assertBinomial(n, k, Math.log(nCk), ulp);
}
- @Test
- void testBinomialCoefficientFail1() {
- Assertions.assertThrows(CombinatoricsException.class,
- () -> LogBinomialCoefficient.value(4, 5)
- );
- }
+ @ParameterizedTest
+ @CsvSource({
+ // Data verified using maxima: bfloat(log(binomial(n, k))) using 30 digits of precision.
+ // Note: C(n, k) overflows a double in all these cases (log(Double.MAX_VALUE) ~ 709.78).
+ "20000, 117, 7.14892994792834554505294427064e2, 0",
+ "1030, 500, 7.09810373941566367297112517919e2, 0",
+ "1030, 515, 7.10246904865078629457587850298e2, 1",
+ "10000000, 10000, 7.90670275055185025423945062124e4, 0",
+ "152635712, 789, 1.03890814359013076045677773736e4, 1",
+ "152635712, 4546, 5.19172217038600425710151646693e4, 1",
+ "152635712, 125636, 1.01789719965975898402939070835e6, 1",
+ "2147483647, 107, 1.90292037817495610257804518397e3, 0",
+ "2147483647, 207, 3.54746695691646741842751657371e3, 1",
+ "2147483647, 407, 6.70292742211067648162528876066e3, 0",
+ "2147483647, 807, 1.27416849603252171413378310472e4, 0",
+ "2147483647, 1607, 2.42698285839945068392543390151e4, 0",
- @Test
- void testBinomialCoefficientFail2() {
- Assertions.assertThrows(CombinatoricsException.class,
- () -> LogBinomialCoefficient.value(-1, -2)
- );
+ // Maxima cannot handle very large arguments to binomial or beta.
+ // fpprec : 64;
+ // a(n, k) := bfloat(log(gamma(n+1)) - log(gamma(k+1)) - log(gamma(n-k+1)));
+ // a(2147483647, 1073741824);
+ "2147483647, 1073741824, 1.488522224247066203747566030677421662382623181089980272736272874e9, 1",
+ })
+ void testLogBinomialCoefficient(int n, int k, double lognCk, int ulp) {
+ assertBinomial(n, k, lognCk, ulp);
}
/**
- * Tests correctness for large n and sharpness of upper bound in API doc
+ * Tests correctness for all n <= 200 against exact values computed from BigInteger factorials.
* JIRA: MATH-241
*/
@Test
- void testBinomialCoefficientLarge() throws Exception {
- // This tests all legal and illegal values for n <= 200.
- for (int n = 0; n <= 200; n++) {
- for (int k = 0; k <= n; k++) {
- long exactResult = -1;
- boolean shouldThrow = false;
- boolean didThrow = false;
- try {
- BinomialCoefficient.value(n, k);
- } catch (ArithmeticException ex) {
- didThrow = true;
- }
- try {
- exactResult = BinomialCoefficientTest.binomialCoefficient(n, k);
- } catch (ArithmeticException ex) {
- shouldThrow = true;
- }
+ void testBinomialCoefficientLarge() {
+ // All k for n <= 200: k <= n/2 directly, k > n/2 via the symmetry check in assertBinomial.
+ final int size = 200;
+ final BigInteger[] factorials = new BigInteger[size + 1];
+ factorials[0] = BigInteger.ONE;
+ for (int n = 1; n <= size; n++) {
+ factorials[n] = factorials[n - 1].multiply(BigInteger.valueOf(n));
+ }
- if (!shouldThrow && exactResult > 1) {
- Assertions.assertEquals(1,
- LogBinomialCoefficient.value(n, k) / Math.log(exactResult), 1e-10, n + " choose " + k);
+ for (int n = 0; n <= size; n++) {
+ int ulp;
+ if (n <= 66) {
+ ulp = 0;
+ } else {
+ ulp = 1;
+ }
+ for (int k = 0; k <= n / 2; k++) {
+ final BigInteger nCk = factorials[n].divide(factorials[n - k]).divide(factorials[k]);
+ final double expected = nCk.doubleValue();
+ // Cannot log infinite result
+ if (expected == Double.POSITIVE_INFINITY) {
+ Assertions.fail("Incorrect limit for n: " + size + "; " + n + " choose " + k + " is not finite");
}
+ assertBinomial(n, k, Math.log(expected), ulp);
}
}
+ }
- final int n = 10000;
- final double actualOverExpected = LogBinomialCoefficient.value(n, 3) /
- Math.log(BinomialCoefficientTest.binomialCoefficient(n, 3));
- Assertions.assertEquals(1, actualOverExpected, 1e-10);
+ private static void assertBinomial(int n, int k, double expected, int ulp) {
+ final double actual = LogBinomialCoefficient.value(n, k);
+ Assertions.assertTrue(Precision.equals(expected, actual, ulp),
+ () -> String.format("Log C(%d, %d) = %s : actual %s : ULP error = %d", n, k,
+ expected, actual,
+ Double.doubleToRawLongBits(actual) - Double.doubleToRawLongBits(expected)));
+ // Test symmetry
+ Assertions.assertEquals(actual, LogBinomialCoefficient.value(n, n - k), () -> n + " choose " + k);
}
}