You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by ah...@apache.org on 2022/11/09 12:33:56 UTC

[commons-numbers] 01/04: Numbers-191: Compute Stirling number of the first kind

This is an automated email from the ASF dual-hosted git repository.

aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-numbers.git

commit 2ec97e42be39f410a06a3ba9c60f89ddea65614c
Author: Alex Herbert <ah...@apache.org>
AuthorDate: Mon Nov 7 17:25:35 2022 +0000

    Numbers-191: Compute Stirling number of the first kind
---
 .../commons/numbers/combinatorics/Stirling.java    | 174 ++++++++++++++---
 .../numbers/combinatorics/StirlingTest.java        | 211 +++++++++++++++++++--
 2 files changed, 343 insertions(+), 42 deletions(-)

diff --git a/commons-numbers-combinatorics/src/main/java/org/apache/commons/numbers/combinatorics/Stirling.java b/commons-numbers-combinatorics/src/main/java/org/apache/commons/numbers/combinatorics/Stirling.java
index d5250fc4..2d301eae 100644
--- a/commons-numbers-combinatorics/src/main/java/org/apache/commons/numbers/combinatorics/Stirling.java
+++ b/commons-numbers-combinatorics/src/main/java/org/apache/commons/numbers/combinatorics/Stirling.java
@@ -23,11 +23,47 @@ package org.apache.commons.numbers.combinatorics;
  * @since 1.2
  */
 public final class Stirling {
+    /** Stirling S1 error message. */
+    private static final String S1_ERROR_FORMAT = "s(n=%d, k=%d)";
     /** Stirling S2 error message. */
     private static final String S2_ERROR_FORMAT = "S(n=%d, k=%d)";
+    /** Overflow threshold for n when computing s(n, 1). */
+    private static final int S1_OVERFLOW_K_EQUALS_1 = 21;
+    /** Overflow threshold for n when computing s(n, n-2). */
+    private static final int S1_OVERFLOW_K_EQUALS_NM2 = 92682;
+    /** Overflow threshold for n when computing s(n, n-3). */
+    private static final int S1_OVERFLOW_K_EQUALS_NM3 = 2761;
     /** Overflow threshold for n when computing S(n, n-2). */
     private static final int S2_OVERFLOW_K_EQUALS_NM2 = 92683;
 
+    /**
+     * Precomputed Stirling numbers of the first kind.
+     * Provides a thread-safe lazy initialization of the cache.
+     */
+    private static class StirlingS1Cache {
+        /** Maximum n to compute (exclusive).
+         * As s(21,3) = 13803759753640704000 is larger than Long.MAX_VALUE
+         * we must stop computation at row 21. */
+        static final int MAX_N = 21;
+        /** Stirling numbers of the first kind. */
+        static final long[][] S1;
+
+        static {
+            S1 = new long[MAX_N][];
+            // Initialise first two rows to allow s(2, 1) to use s(1, 1)
+            S1[0] = new long[] {1};
+            S1[1] = new long[] {0, 1};
+            for (int n = 2; n < S1.length; n++) {
+                S1[n] = new long[n + 1];
+                S1[n][0] = 0;
+                S1[n][n] = 1;
+                for (int k = 1; k < n; k++) {
+                    S1[n][k] = S1[n - 1][k - 1] - (n - 1) * S1[n - 1][k];
+                }
+            }
+        }
+    }
+
     /**
      * Precomputed Stirling numbers of the second kind.
      * Provides a thread-safe lazy initialization of the cache.
@@ -38,18 +74,18 @@ public final class Stirling {
          * we must stop computation at row 26. */
         static final int MAX_N = 26;
         /** Stirling numbers of the second kind. */
-        static final long[][] STIRLING_S2;
+        static final long[][] S2;
 
         static {
-            STIRLING_S2 = new long[MAX_N][];
-            STIRLING_S2[0] = new long[] {1};
-            for (int n = 1; n < STIRLING_S2.length; n++) {
-                STIRLING_S2[n] = new long[n + 1];
-                STIRLING_S2[n][0] = 0;
-                STIRLING_S2[n][1] = 1;
-                STIRLING_S2[n][n] = 1;
+            S2 = new long[MAX_N][];
+            S2[0] = new long[] {1};
+            for (int n = 1; n < S2.length; n++) {
+                S2[n] = new long[n + 1];
+                S2[n][0] = 0;
+                S2[n][1] = 1;
+                S2[n][n] = 1;
                 for (int k = 2; k < n; k++) {
-                    STIRLING_S2[n][k] = k * STIRLING_S2[n - 1][k] + STIRLING_S2[n - 1][k - 1];
+                    S2[n][k] = k * S2[n - 1][k] + S2[n - 1][k - 1];
                 }
             }
         }
@@ -60,6 +96,81 @@ public final class Stirling {
         // intentionally empty.
     }
 
+    /**
+     * Returns the <em>signed</em> <a
+     * href="https://mathworld.wolfram.com/StirlingNumberoftheFirstKind.html">
+     * Stirling number of the first kind</a>, "{@code s(n,k)}". The number of permutations of
+     * {@code n} elements which contain exactly {@code k} permutation cycles is the
+     * nonnegative number: {@code |s(n,k)| = (-1)^(n-k) s(n,k)}
+     *
+     * @param n Size of the set
+     * @param k Number of permutation cycles ({@code 0 <= k <= n})
+     * @return {@code s(n,k)}
+     * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0} or {@code k > n}.
+     * @throws ArithmeticException if some overflow happens, typically for n exceeding 20
+     * (s(n,n-1) is handled specifically and does not overflow)
+     */
+    public static long stirlingS1(int n, int k) {
+        checkArguments(n, k);
+
+        if (n < StirlingS1Cache.MAX_N) {
+            // The number is in the small cache
+            return StirlingS1Cache.S1[n][k];
+        }
+
+        // Simple cases
+        // https://en.wikipedia.org/wiki/Stirling_numbers_of_the_first_kind#Simple_identities
+        if (k == 0) {
+            return 0;
+        } else if (k == n) {
+            return 1;
+        } else if (k == 1) {
+            checkN(n, k, S1_OVERFLOW_K_EQUALS_1, S1_ERROR_FORMAT);
+            // Note: Only reached for n=21 where the sign (-1)^(n-1) is +1, so (n-1)! is returned without a sign
+            return Factorial.value(n - 1);
+        } else if (k == n - 1) {
+            return -BinomialCoefficient.value(n, 2);
+        } else if (k == n - 2) {
+            checkN(n, k, S1_OVERFLOW_K_EQUALS_NM2, S1_ERROR_FORMAT);
+            // (3n-1) * binom(n, 3) / 4
+            final long a = 3L * n - 1;
+            final long b = BinomialCoefficient.value(n, 3);
+            // Compute (a*b/4). The intermediate (a*b/2) may exceed Long.MAX_VALUE but fits in 64 unsigned bits: use >>>
+            // The product (a*b) must be an exact multiple of 4.
+            // Conditional branch on b which is typically large and even (a is 50% even)
+            // If b is even: ((b/2) * a) / 2
+            // If b is odd then a must be even to make a*b even: ((a/2) * b) / 2
+            return (b & 1) == 0 ? ((b >>> 1) * a) >>> 1 : ((a >>> 1) * b) >>> 1;
+        } else if (k == n - 3) {
+            checkN(n, k, S1_OVERFLOW_K_EQUALS_NM3, S1_ERROR_FORMAT);
+            return -BinomialCoefficient.value(n, 2) * BinomialCoefficient.value(n, 4);
+        }
+
+        // Compute using:
+        // s(n + 1, k) = s(n, k - 1)     - n       * s(n, k)
+        // s(n, k)     = s(n - 1, k - 1) - (n - 1) * s(n - 1, k)
+
+        // n >= 21 (MAX_N)
+        // 2 <= k <= n-4
+
+        // Start at the largest easily computed value: n < MAX_N or k < 2
+        final int reduction = Math.min(n - StirlingS1Cache.MAX_N, k - 2) + 1;
+        int n0 = n - reduction;
+        int k0 = k - reduction;
+
+        long sum = stirlingS1(n0, k0);
+        while (n0 < n) {
+            k0++;
+            sum = Math.subtractExact(
+                sum,
+                Math.multiplyExact(n0, stirlingS1(n0, k0))
+            );
+            n0++;
+        }
+
+        return sum;
+    }
+
     /**
      * Returns the <a
      * href="https://mathworld.wolfram.com/StirlingNumberoftheSecondKind.html">
@@ -70,21 +181,16 @@ public final class Stirling {
      * @param n Size of the set
      * @param k Number of non-empty subsets ({@code 0 <= k <= n})
      * @return {@code S(n,k)}
-     * @throws IllegalArgumentException if {@code k < 0} or {@code k > n}.
+     * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0} or {@code k > n}.
      * @throws ArithmeticException if some overflow happens, typically for n exceeding 25 and
      * k between 20 and n-2 (S(n,n-1) is handled specifically and does not overflow)
      */
     public static long stirlingS2(int n, int k) {
-        if (k < 0) {
-            throw new CombinatoricsException(CombinatoricsException.NEGATIVE, k);
-        }
-        if (k > n) {
-            throw new CombinatoricsException(CombinatoricsException.OUT_OF_RANGE, k, 0, n);
-        }
+        checkArguments(n, k);
 
         if (n < StirlingS2Cache.MAX_N) {
             // The number is in the small cache
-            return StirlingS2Cache.STIRLING_S2[n][k];
+            return StirlingS2Cache.S2[n][k];
         }
 
         // Simple cases
@@ -93,7 +199,7 @@ public final class Stirling {
         } else if (k == 1 || k == n) {
             return 1;
         } else if (k == 2) {
-            checkN(n, k, 64);
+            checkN(n, k, 64, S2_ERROR_FORMAT);
             return (1L << (n - 1)) - 1L;
         } else if (k == n - 1) {
             return BinomialCoefficient.value(n, 2);
@@ -108,7 +214,7 @@ public final class Stirling {
             //   for i in [1, k]:
             //     sum (i * binom(i+1, 2))
             // Avoid overflow checks using the known limit for n when k=n-2
-            checkN(n, k, S2_OVERFLOW_K_EQUALS_NM2);
+            checkN(n, k, S2_OVERFLOW_K_EQUALS_NM2, S2_ERROR_FORMAT);
             long binom = BinomialCoefficient.value(k + 1, 2);
             long sum = 0;
             for (int i = k; i > 0; i--) {
@@ -130,28 +236,50 @@ public final class Stirling {
 
         long sum = stirlingS2(n0, k0);
         while (n0 < n) {
-            n0++;
             k0++;
             sum = Math.addExact(
-                Math.multiplyExact(k0, stirlingS2(n0 - 1, k0)),
+                Math.multiplyExact(k0, stirlingS2(n0, k0)),
                 sum
             );
+            n0++;
         }
 
         return sum;
     }
 
+    /**
+     * Check {@code 0 <= k <= n}.
+     *
+     * @param n N
+     * @param k K
+     * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0} or {@code k > n}.
+     */
+    private static void checkArguments(int n, int k) {
+        // Combine all checks with a single branch:
+        // 0 <= n; 0 <= k <= n
+        // Note: If n >= 0 && k >= 0 && n - k < 0 then k > n.
+        // Bitwise or will detect a negative sign bit in any of the numbers
+        if ((n | k | (n - k)) < 0) {
+            // Raise the correct exception
+            if (n < 0) {
+                throw new CombinatoricsException(CombinatoricsException.NEGATIVE, n);
+            }
+            throw new CombinatoricsException(CombinatoricsException.OUT_OF_RANGE, k, 0, n);
+        }
+    }
+
     /**
      * Check {@code n <= threshold}, or else throw an {@link ArithmeticException}.
      *
      * @param n N
      * @param k K
      * @param threshold Threshold for {@code n}
+     * @param msgFormat Error message format
      * @throws ArithmeticException if overflow is expected to happen
      */
-    private static void checkN(int n, int k, int threshold) {
+    private static void checkN(int n, int k, int threshold, String msgFormat) {
         if (n > threshold) {
-            throw new ArithmeticException(String.format(S2_ERROR_FORMAT, n, k));
+            throw new ArithmeticException(String.format(msgFormat, n, k));
         }
     }
 }
diff --git a/commons-numbers-combinatorics/src/test/java/org/apache/commons/numbers/combinatorics/StirlingTest.java b/commons-numbers-combinatorics/src/test/java/org/apache/commons/numbers/combinatorics/StirlingTest.java
index 5d797db8..26587619 100644
--- a/commons-numbers-combinatorics/src/test/java/org/apache/commons/numbers/combinatorics/StirlingTest.java
+++ b/commons-numbers-combinatorics/src/test/java/org/apache/commons/numbers/combinatorics/StirlingTest.java
@@ -16,23 +16,209 @@
  */
 package org.apache.commons.numbers.combinatorics;
 
+import java.util.stream.Stream;
+import org.apache.commons.numbers.core.ArithmeticUtils;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
 import org.junit.jupiter.params.provider.CsvSource;
+import org.junit.jupiter.params.provider.MethodSource;
 
 /**
  * Test cases for the {@link Stirling} class.
  */
 class StirlingTest {
 
+    /**
+     * Arguments that are illegal for the Stirling number computations.
+     *
+     * @return the arguments
+     */
+    static Stream<Arguments> stirlingIllegalArguments() {
+        return Stream.of(
+            Arguments.of(1, -1),
+            Arguments.of(-1, -1),
+            Arguments.of(-1, 1),
+            Arguments.of(10, 15),
+            Arguments.of(Integer.MIN_VALUE, 1),
+            Arguments.of(1, Integer.MIN_VALUE),
+            Arguments.of(Integer.MIN_VALUE, Integer.MIN_VALUE),
+            Arguments.of(Integer.MAX_VALUE - 1, Integer.MAX_VALUE)
+        );
+    }
+
+    /**
+     * Arguments that should easily overflow the Stirling number computations.
+     * Used to verify the exception is correct
+     * (e.g. no StackOverflowError occurs due to recursion).
+     *
+     * @return the arguments
+     */
+    static Stream<Arguments> stirlingOverflowArguments() {
+        return Stream.of(
+            Arguments.of(123, 32),
+            Arguments.of(612534, 56123),
+            Arguments.of(261388631, 213),
+            Arguments.of(678688997, 213879),
+            Arguments.of(1000000002, 1000000000),
+            Arguments.of(1000000003, 1000000000),
+            Arguments.of(1000000004, 1000000000),
+            Arguments.of(1000000005, 1000000000),
+            Arguments.of(1000000010, 1000000000),
+            Arguments.of(1000000100, 1000000000)
+        );
+    }
+
+    @ParameterizedTest
+    @MethodSource(value = {"stirlingIllegalArguments"})
+    void testStirlingS1IllegalArgument(int n, int k) {
+        Assertions.assertThrows(IllegalArgumentException.class, () -> Stirling.stirlingS1(n, k));
+    }
+
+    @Test
+    void testStirlingS1StandardCases() {
+        Assertions.assertEquals(1, Stirling.stirlingS1(0, 0));
+
+        for (int n = 1; n < 64; ++n) {
+            Assertions.assertEquals(0, Stirling.stirlingS1(n, 0));
+            if (n < 21) {
+                Assertions.assertEquals(ArithmeticUtils.pow(-1, n - 1) * Factorial.value(n - 1),
+                                        Stirling.stirlingS1(n, 1));
+                if (n > 2) {
+                    Assertions.assertEquals(-BinomialCoefficient.value(n, 2),
+                                            Stirling.stirlingS1(n, n - 1));
+                }
+            }
+            Assertions.assertEquals(1, Stirling.stirlingS1(n, n));
+        }
+    }
+
     @ParameterizedTest
     @CsvSource({
-        "1, -1",
-        "-1, -1",
-        "-1, 1",
-        "10, 15",
+        // Data verified using Mathematica StirlingS1[n, k]
+        "5, 3, 35",
+        "6, 3, -225",
+        "6, 4, 85",
+        "7, 3, 1624",
+        "7, 4, -735",
+        "7, 5, 175",
+        "8, 3, -13132",
+        "8, 4, 6769",
+        "8, 5, -1960",
+        "8, 6, 322",
+        "9, 3, 118124",
+        "9, 4, -67284",
+        "9, 5, 22449",
+        "9, 6, -4536",
+        "9, 7, 546",
+        "10, 3, -1172700",
+        "10, 4, 723680",
+        "10, 5, -269325",
+        "10, 6, 63273",
+        "10, 7, -9450",
+        "10, 8, 870",
+        // n >= 21 is not cached
+        // ... k in [1, 7] require n <= 21
+        "21, 8, -311333643161390640",
+        "21, 9, 63030812099294896",
+        "22, 10, 276019109275035346",
+        "22, 11, -37600535086859745",
+        "23, 12, -129006659818331295",
+        "23, 13, 12363045847086207",
+        "24, 14, 34701806448704206",
+        "25, 15, 92446911376173550",
+        "25, 16, -5700586321864500",
+        "26, 17, -12972753318542875",
+        "27, 18, -28460103232088385",
+        "28, 19, -60383004803151030",
+        "29, 20, -124243455209483610",
+        // k in [n-8, n-2]
+        "33, 25, 42669229615802790",
+        "40, 33, -16386027912368400",
+        "66, 60, 98715435586436240",
+        "155, 150, -1849441185054164625",
+        "404, 400, 1793805203416799170",
+        "1003, 1000, -21063481189500750",
+        "10002, 10000, 1250583420837500",
+        // Limits for k in [n-1, n] use n = Integer.MAX_VALUE
+        "2147483647, 2147483646, -2305843005992468481",
+        "2147483647, 2147483647, 1",
+        // Data for s(n, n-2)
+        "21, 19, 20615",
+        "22, 20, 25025",
+        "23, 21, 30107",
+        "24, 22, 35926",
+        "25, 23, 42550",
+        "26, 24, 50050",
+        "27, 25, 58500",
+        "92679, 92677, 9221886003909976111",
+        "92680, 92678, 9222284027979459010",
+        "92681, 92679, 9222682064933083810",
+        // Data for s(n, n-3)
+        "21, 18, -1256850",
+        "22, 19, -1689765",
+        "23, 20, -2240315",
+        "24, 21, -2932776",
+        "25, 22, -3795000",
+        "26, 23, -4858750",
+        "27, 24, -6160050",
+        "2758, 2755, -9145798629595485585",
+        "2759, 2756, -9165721700732052911",
+        "2760, 2757, -9185680925511388200",
     })
+    void testStirlingS1(int n, int k, long expected) {
+        Assertions.assertEquals(expected, Stirling.stirlingS1(n, k));
+    }
+
+    @ParameterizedTest
+    @CsvSource({
+        // Upper limits for n with k in [1, 20]
+        "21, 1, 2432902008176640000",
+        "21, 2, -8752948036761600000",
+        "20, 3, -668609730341153280",
+        "20, 4, 610116075740491776",
+        "21, 5, 8037811822645051776",
+        "21, 6, -3599979517947607200",
+        "21, 7, 1206647803780373360",
+        "22, 8, 7744654310169576800",
+        "22, 9, -1634980697246583456",
+        "23, 10, -7707401101297361068",
+        "23, 11, 1103230881185949736",
+        "24, 12, 4070384057007569521",
+        "24, 13, -413356714301314056",
+        "25, 14, -1246200069070215000",
+        "26, 15, -3557372853474553750",
+        "26, 16, 234961569422786050",
+        "27, 17, 572253155704900800",
+        "28, 18, 1340675942971287195",
+        "29, 19, 3031400077459516035",
+        "30, 20, 6634460278534540725",
+        // Upper limits for n with k in [n-9, n-2]
+        "35, 26, -5576855646887454930",
+        "44, 36, 6364808704290634598",
+        "61, 54, -8424028440309413250",
+        "95, 89, 8864929183170733205",
+        "181, 176, -8872439767850041020",
+        "495, 491, 9161199664152744351",
+        "2761, 2758, -9205676356399769400",
+        "92682, 92680, 9223080114771128550",
+    })
+    void testStirlingS1LimitsN(int n, int k, long expected) {
+        Assertions.assertEquals(expected, Stirling.stirlingS1(n, k));
+        Assertions.assertThrows(ArithmeticException.class, () -> Stirling.stirlingS1(n + 1, k));
+        Assertions.assertThrows(ArithmeticException.class, () -> Stirling.stirlingS1(n + 100, k));
+        Assertions.assertThrows(ArithmeticException.class, () -> Stirling.stirlingS1(n + 10000, k));
+    }
+
+    @ParameterizedTest
+    @MethodSource(value = {"stirlingOverflowArguments"})
+    void testStirlingS1Overflow(int n, int k) {
+        Assertions.assertThrows(ArithmeticException.class, () -> Stirling.stirlingS1(n, k));
+    }
+
+    @ParameterizedTest
+    @MethodSource(value = {"stirlingIllegalArguments"})
     void testStirlingS2IllegalArgument(int n, int k) {
         Assertions.assertThrows(IllegalArgumentException.class, () -> Stirling.stirlingS2(n, k));
     }
@@ -129,7 +315,7 @@ class StirlingTest {
         "30, 20, 581535955088511150",
         "31, 21, 1359760239259935240",
         "32, 22, 3069483578649883980",
-        // Upper limits for n with with k in [n-10, n-2]
+        // Upper limits for n with k in [n-10, n-2]
         "33, 23, 6708404338089491700",
         "38, 29, 6766081393022256030",
         "47, 39, 8248929419122431611",
@@ -148,20 +334,7 @@ class StirlingTest {
     }
 
     @ParameterizedTest
-    @CsvSource({
-        // Large numbers that should easily overflow. Verifies the exception is correct
-        // (e.g. no StackOverflowError occurs due to recursion)
-        "123, 32",
-        "612534, 56123",
-        "261388631, 213",
-        "678688997, 213879",
-        "1000000002, 1000000000",
-        "1000000003, 1000000000",
-        "1000000004, 1000000000",
-        "1000000005, 1000000000",
-        "1000000010, 1000000000",
-        "1000000100, 1000000000",
-    })
+    @MethodSource(value = {"stirlingOverflowArguments"})
     void testStirlingS2Overflow(int n, int k) {
         Assertions.assertThrows(ArithmeticException.class, () -> Stirling.stirlingS2(n, k));
     }