You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by ah...@apache.org on 2023/12/12 11:18:41 UTC

(commons-statistics) 03/04: Update FirstMoment to use a half-representation

This is an automated email from the ASF dual-hosted git repository.

aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-statistics.git

commit 0805920f3118e034fc957715e333a0a82d7b0a4a
Author: Alex Herbert <ah...@apache.org>
AuthorDate: Tue Dec 12 10:54:59 2023 +0000

    Update FirstMoment to use a half-representation
    
    This maintains the overflow protection of downscaling but avoids
    re-upscaling the moment and stored deviations for each input value.
    Upscaling is only required when computing the final result.
    
    This gives a performance gain of 30-40%. Performance is approximately the
    same as a rolling algorithm with no downscaling. Thus this modification
    provides overflow protection at negligible cost.
    
    All sub-class moments must adjust their scaling factors by the appropriate
    powers of 2 when using the deviations.
---
 .../statistics/descriptive/FirstMoment.java        | 129 ++++++++++++++++-----
 .../descriptive/SumOfCubedDeviations.java          |   8 +-
 .../descriptive/SumOfFourthDeviations.java         |   9 +-
 .../descriptive/SumOfSquaredDeviations.java        |   6 +-
 4 files changed, 112 insertions(+), 40 deletions(-)

diff --git a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/FirstMoment.java b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/FirstMoment.java
index 932ebe3..18d4ed2 100644
--- a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/FirstMoment.java
+++ b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/FirstMoment.java
@@ -21,7 +21,7 @@ import java.util.function.DoubleConsumer;
 /**
  * Computes the first moment (arithmetic mean) using the definitional formula:
  *
- * <p> mean = sum(x_i) / n
+ * <pre>mean = sum(x_i) / n</pre>
  *
  * <p> To limit numeric errors, the value of the statistic is computed using the
  * following recursive updating algorithm:
@@ -59,12 +59,14 @@ import java.util.function.DoubleConsumer;
  * </ul>
  */
 class FirstMoment implements DoubleConsumer {
+    /** The downscale constant. Used to avoid overflow for all finite input. */
+    private static final double DOWNSCALE = 0.5;
+    /** The rescale constant. */
+    private static final double RESCALE = 2;
+
     /** Count of values that have been added. */
     protected long n;
 
-    /** First moment of values that have been added. */
-    protected double m1;
-
     /**
      * Half the deviation of most recently added value from the previous first moment.
      * Retained to prevent repeated computation in higher order moments.
@@ -74,16 +76,25 @@ class FirstMoment implements DoubleConsumer {
      *
      * <p>This value is not used in the {@link #combine(FirstMoment)} method.
      */
-    protected double halfDev;
+    protected double dev;
 
     /**
-     * Deviation of most recently added value from the previous first moment,
+     * Half the deviation of most recently added value from the previous first moment,
      * normalized by current sample size. Retained to prevent repeated
      * computation in higher order moments.
+     *
+     * <p>Note: This is (x - m1) / (2n). It is computed as a half value to prevent overflow
+     * for any finite values of x and m1.
+     *
      * Note: This value is not used in the {@link #combine(FirstMoment)} method.
      */
     protected double nDev;
 
+    /** First moment of values that have been added.
+     * This is stored as a half value to prevent overflow for any finite input.
+     * Benchmarks show this has negligible performance impact. */
+    private double m1;
+
     /**
      * Running sum of values seen so far.
      * This is not used in the computation of mean. Used as a return value for first moment when
@@ -122,7 +133,7 @@ class FirstMoment implements DoubleConsumer {
         // "Corrected two-pass algorithm"
 
         // First pass
-        final FirstMoment m1 = Statistics.add(new FirstMoment(), values);
+        final FirstMoment m1 = create(values);
         final double xbar = m1.getFirstMoment();
         if (!Double.isFinite(xbar)) {
             // Note: Also occurs when the input is empty
@@ -135,11 +146,56 @@ class FirstMoment implements DoubleConsumer {
         }
         // Note: Correction may be infinite
         if (Double.isFinite(correction)) {
-            m1.m1 += correction / values.length;
+            // Down scale the correction to the half representation
+            m1.m1 += DOWNSCALE * correction / values.length;
         }
         return m1;
     }
 
+    /**
+     * Creates the first moment using a rolling algorithm.
+     *
+     * <p>This duplicates the algorithm in the {@link #accept(double)} method
+     * with optimisations due to the processing of an entire array:
+     * <ul>
+     *  <li>Avoid updating (unused) class level working variables.
+     *  <li>Only compute the non-finite value if required.
+     * </ul>
+     *
+     * @param values Values.
+     * @return the first moment
+     */
+    private static FirstMoment create(double[] values) {
+        double m1 = 0;
+        int n = 0;
+        for (final double x : values) {
+            // Downscale to avoid overflow for all finite input
+            m1 += (x * DOWNSCALE - m1) / ++n;
+        }
+        final FirstMoment m = new FirstMoment();
+        m.n = n;
+        m.m1 = m1;
+        // The non-finite value is only relevant if the data contains inf/nan
+        if (!Double.isFinite(m1 * RESCALE)) {
+            m.nonFiniteValue = sum(values);
+        }
+        return m;
+    }
+
+    /**
+     * Compute the sum of the values.
+     *
+     * @param values Values.
+     * @return the sum
+     */
+    private static double sum(double[] values) {
+        double sum = 0;
+        for (final double x : values) {
+            sum += x;
+        }
+        return sum;
+    }
+
     /**
      * Updates the state of the statistic to reflect the addition of {@code value}.
      *
@@ -151,14 +207,13 @@ class FirstMoment implements DoubleConsumer {
         // See: Chan et al (1983) Equation 1.3a
         // m_{i+1} = m_i + (x - m_i) / (i + 1)
         // This is modified with scaling to avoid overflow for all finite input.
+        // Scaling the input down by a factor of two is exact unless the result is sub-normal.
+        // Sub-classes must alter their scaling factors when using the computed deviations.
 
-        n++;
         nonFiniteValue += value;
-        // To prevent overflow, dev is computed by scaling down and then scaling up.
-        // We choose to scale down by a factor of two to ensure that the scaling is lossless.
-        halfDev = value * 0.5 - m1 * 0.5;
-        // nDev cannot overflow as halfDev is <= MAX_VALUE when n > 1; or <= MAX_VALUE / 2 when n = 1
-        nDev = (halfDev / n) * 2;
+        // Scale down the input
+        dev = value * DOWNSCALE - m1;
+        nDev = dev / ++n;
         m1 += nDev;
     }
 
@@ -172,8 +227,10 @@ class FirstMoment implements DoubleConsumer {
      *         {@code NaN} otherwise.
      */
     double getFirstMoment() {
-        if (Double.isFinite(m1)) {
-            return n == 0 ? Double.NaN : m1;
+        // Scale back to the original magnitude
+        final double m = m1 * RESCALE;
+        if (Double.isFinite(m)) {
+            return n == 0 ? Double.NaN : m;
         }
         // A non-finite value must have been encountered, return nonFiniteValue which represents m1.
         return nonFiniteValue;
@@ -194,22 +251,13 @@ class FirstMoment implements DoubleConsumer {
         n = n1 + n2;
         // Adjust the mean with the weighted difference:
         // m1 = m1 + (m2 - m1) * n2 / (n1 + n2)
-        // The difference between means can be 2 * MAX_VALUE so the computation optionally
-        // scales by a factor of 2. Avoiding scaling if possible preserves sub-normals.
+        // The half-representation ensures the difference of means is at most MAX_VALUE
+        // so the combine can avoid scaling.
         if (n1 == n2) {
             // Optimisation for equal sizes: m1 = (m1 + m2) / 2
-            // Use scaling for a large sum
-            final double sum = mu1 + mu2;
-            m1 = Double.isFinite(sum) ?
-                sum * 0.5 :
-                mu1 * 0.5 + mu2 * 0.5;
+            m1 = (mu1 + mu2) * 0.5;
         } else {
-            // Use scaling for a large difference
-            if (Double.isFinite(mu2 - mu1)) {
-                m1 = combine(mu1, mu2, n1, n2);
-            } else {
-                m1 = 2 * combine(mu1 * 0.5, mu2 * 0.5, n1, n2);
-            }
+            m1 = combine(mu1, mu2, n1, n2);
         }
         return this;
     }
@@ -231,4 +279,27 @@ class FirstMoment implements DoubleConsumer {
             m1 + (m2 - m1) * ((double) n2 / (n1 + n2)) :
             m2 + (m1 - m2) * ((double) n1 / (n1 + n2));
     }
+
+    /**
+     * Gets the difference of the first moment between {@code this} moment and the
+     * {@code other} moment. This is provided for sub-classes.
+     *
+     * @param other Other moment.
+     * @return the difference
+     */
+    double getFirstMomentDifference(FirstMoment other) {
+        // Scale back to the original magnitude
+        return (m1 - other.m1) * RESCALE;
+    }
+
+    /**
+     * Gets half the difference of the first moment between {@code this} moment and
+     * the {@code other} moment. This is provided for sub-classes.
+     *
+     * @param other Other moment.
+     * @return half the difference
+     */
+    double getFirstMomentHalfDifference(FirstMoment other) {
+        return m1 - other.m1;
+    }
 }
diff --git a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfCubedDeviations.java b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfCubedDeviations.java
index 915f52d..e58610f 100644
--- a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfCubedDeviations.java
+++ b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfCubedDeviations.java
@@ -156,10 +156,10 @@ class SumOfCubedDeviations extends SumOfSquaredDeviations {
         // multiplication of later terms (nDev * 3 and nDev^2).
         // This handles initialisation when np in {0, 1) to zero
         // for any deviation (e.g. series MAX_VALUE, -MAX_VALUE).
-        // Note: account for the half-deviation representation.
+        // Note: account for the half-deviation representation by scaling by 6=3*2; 8=2^3
         sumCubedDev = sumCubedDev -
-            ss * nDev * 3 +
-            (np - 1.0) * np * nDev * nDev * halfDev * 2;
+            ss * nDev * 6 +
+            (np - 1.0) * np * nDev * nDev * dev * 8;
     }
 
     /**
@@ -197,7 +197,7 @@ class SumOfCubedDeviations extends SumOfSquaredDeviations {
             // Avoid overflow to compute the difference.
             // This allows any samples of size n=1 to be combined as their SS=0.
             // The result is a SC=0 for the combined n=2.
-            final double halfDiffOfMean = m1 * 0.5 - other.m1 * 0.5;
+            final double halfDiffOfMean = getFirstMomentHalfDifference(other);
             sumCubedDev += other.sumCubedDev;
             // Add additional terms that do not cancel to zero
             if (halfDiffOfMean != 0) {
diff --git a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfFourthDeviations.java b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfFourthDeviations.java
index 79e8e33..b0bc155 100644
--- a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfFourthDeviations.java
+++ b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfFourthDeviations.java
@@ -141,11 +141,12 @@ class SumOfFourthDeviations extends SumOfCubedDeviations {
         // This handles initialisation when np in {0, 1) to zero
         // for any deviation (e.g. series MAX_VALUE, -MAX_VALUE).
         // Note: (np1 * np1 - 3 * np) = (np+1)^2 - 3np = np^2 - np + 1
+        // Note: account for the half-deviation representation by scaling by 8=4*2; 24=6*2^2; 16=2^4
         final double np1 = n;
         sumFourthDev = sumFourthDev -
-            sc * nDev * 4 +
-            ss * nDev * nDev * 6 +
-            np * (np1 * np1 - 3 * np) * nDev * nDev * nDev * nDev * n;
+            sc * nDev * 8 +
+            ss * nDev * nDev * 24 +
+            np * (np1 * np1 - 3 * np) * nDev * nDev * nDev * nDev * n * 16;
     }
 
     /**
@@ -180,7 +181,7 @@ class SumOfFourthDeviations extends SumOfCubedDeviations {
             sumFourthDev = other.sumFourthDev;
         } else if (other.n != 0) {
             // Avoid overflow to compute the difference.
-            final double halfDiffOfMean = m1 * 0.5 - other.m1 * 0.5;
+            final double halfDiffOfMean = getFirstMomentHalfDifference(other);
             sumFourthDev += other.sumFourthDev;
             // Add additional terms that do not cancel to zero
             if (halfDiffOfMean != 0) {
diff --git a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfSquaredDeviations.java b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfSquaredDeviations.java
index 2e574f6..93cc421 100644
--- a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfSquaredDeviations.java
+++ b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfSquaredDeviations.java
@@ -133,8 +133,8 @@ class SumOfSquaredDeviations extends FirstMoment {
         // "Updating one-pass algorithm"
         // See: Chan et al (1983) Equation 1.3b
         super.accept(value);
-        // Note: account for the half-deviation representation
-        sumSquaredDev += (n - 1) * halfDev * nDev * 2;
+        // Note: account for the half-deviation representation by scaling by 4=2^2
+        sumSquaredDev += (n - 1) * dev * nDev * 4;
     }
 
     /**
@@ -159,7 +159,7 @@ class SumOfSquaredDeviations extends FirstMoment {
         } else if (m != 0) {
             // "Updating one-pass algorithm"
             // See: Chan et al (1983) Equation 1.5b (modified for the mean)
-            final double diffOfMean = other.m1 - m1;
+            final double diffOfMean = getFirstMomentDifference(other);
             final double sqDiffOfMean = diffOfMean * diffOfMean;
             // Enforce symmetry
             sumSquaredDev = (sumSquaredDev + other.sumSquaredDev) +