You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by ah...@apache.org on 2023/12/12 11:18:41 UTC
(commons-statistics) 03/04: Update FirstMoment to use a half-representation
This is an automated email from the ASF dual-hosted git repository.
aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-statistics.git
commit 0805920f3118e034fc957715e333a0a82d7b0a4a
Author: Alex Herbert <ah...@apache.org>
AuthorDate: Tue Dec 12 10:54:59 2023 +0000
Update FirstMoment to use a half-representation
This maintains the overflow protection of downscaling but avoids
re-upscaling the moment and stored deviations for each input value.
Upscaling is only required when computing the final result.
This gives a performance gain of 30-40%. Performance is approximately the
same as a rolling algorithm with no downscaling. Thus this modification
provides the overflow protection at negligible cost.
All sub-class moments that use the deviations must update their scaling
factors by the appropriate powers of 2.
---
.../statistics/descriptive/FirstMoment.java | 129 ++++++++++++++++-----
.../descriptive/SumOfCubedDeviations.java | 8 +-
.../descriptive/SumOfFourthDeviations.java | 9 +-
.../descriptive/SumOfSquaredDeviations.java | 6 +-
4 files changed, 112 insertions(+), 40 deletions(-)
diff --git a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/FirstMoment.java b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/FirstMoment.java
index 932ebe3..18d4ed2 100644
--- a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/FirstMoment.java
+++ b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/FirstMoment.java
@@ -21,7 +21,7 @@ import java.util.function.DoubleConsumer;
/**
* Computes the first moment (arithmetic mean) using the definitional formula:
*
- * <p> mean = sum(x_i) / n
+ * <pre>mean = sum(x_i) / n</pre>
*
* <p> To limit numeric errors, the value of the statistic is computed using the
* following recursive updating algorithm:
@@ -59,12 +59,14 @@ import java.util.function.DoubleConsumer;
* </ul>
*/
class FirstMoment implements DoubleConsumer {
+ /** The downscale constant. Used to avoid overflow for all finite input. */
+ private static final double DOWNSCALE = 0.5;
+ /** The rescale constant. */
+ private static final double RESCALE = 2;
+
/** Count of values that have been added. */
protected long n;
- /** First moment of values that have been added. */
- protected double m1;
-
/**
* Half the deviation of most recently added value from the previous first moment.
* Retained to prevent repeated computation in higher order moments.
@@ -74,16 +76,25 @@ class FirstMoment implements DoubleConsumer {
*
* <p>This value is not used in the {@link #combine(FirstMoment)} method.
*/
- protected double halfDev;
+ protected double dev;
/**
- * Deviation of most recently added value from the previous first moment,
+ * Half the deviation of most recently added value from the previous first moment,
* normalized by current sample size. Retained to prevent repeated
* computation in higher order moments.
+ *
+ * <p>Note: This is (x - m1) / 2n. It is computed as a half value to prevent overflow
+ * when computing for any finite value x and m.
+ *
* Note: This value is not used in the {@link #combine(FirstMoment)} method.
*/
protected double nDev;
+ /** First moment of values that have been added.
+ * This is stored as a half value to prevent overflow for any finite input.
+ * Benchmarks show this has negligible performance impact. */
+ private double m1;
+
/**
* Running sum of values seen so far.
* This is not used in the computation of mean. Used as a return value for first moment when
@@ -122,7 +133,7 @@ class FirstMoment implements DoubleConsumer {
// "Corrected two-pass algorithm"
// First pass
- final FirstMoment m1 = Statistics.add(new FirstMoment(), values);
+ final FirstMoment m1 = create(values);
final double xbar = m1.getFirstMoment();
if (!Double.isFinite(xbar)) {
// Note: Also occurs when the input is empty
@@ -135,11 +146,56 @@ class FirstMoment implements DoubleConsumer {
}
// Note: Correction may be infinite
if (Double.isFinite(correction)) {
- m1.m1 += correction / values.length;
+ // Down scale the correction to the half representation
+ m1.m1 += DOWNSCALE * correction / values.length;
}
return m1;
}
+ /**
+ * Creates the first moment using a rolling algorithm.
+ *
+ * <p>This duplicates the algorithm in the {@link #accept(double)} method
+ * with optimisations due to the processing of an entire array:
+ * <ul>
+ * <li>Avoid updating (unused) class level working variables.
+ * <li>Only computing the non-finite value if required.
+ * </ul>
+ *
+ * @param values Values.
+ * @return the first moment
+ */
+ private static FirstMoment create(double[] values) {
+ double m1 = 0;
+ int n = 0;
+ for (final double x : values) {
+ // Downscale to avoid overflow for all finite input
+ m1 += (x * DOWNSCALE - m1) / ++n;
+ }
+ final FirstMoment m = new FirstMoment();
+ m.n = n;
+ m.m1 = m1;
+ // The non-finite value is only relevant if the data contains inf/nan
+ if (!Double.isFinite(m1 * RESCALE)) {
+ m.nonFiniteValue = sum(values);
+ }
+ return m;
+ }
+
+ /**
+ * Compute the sum of the values.
+ *
+ * @param values Values.
+ * @return the sum
+ */
+ private static double sum(double[] values) {
+ double sum = 0;
+ for (final double x : values) {
+ sum += x;
+ }
+ return sum;
+ }
+
/**
* Updates the state of the statistic to reflect the addition of {@code value}.
*
@@ -151,14 +207,13 @@ class FirstMoment implements DoubleConsumer {
// See: Chan et al (1983) Equation 1.3a
// m_{i+1} = m_i + (x - m_i) / (i + 1)
// This is modified with scaling to avoid overflow for all finite input.
+ // Scaling the input down by a factor of two ensures that the scaling is lossless.
+ // Sub-classes must alter their scaling factors when using the computed deviations.
- n++;
nonFiniteValue += value;
- // To prevent overflow, dev is computed by scaling down and then scaling up.
- // We choose to scale down by a factor of two to ensure that the scaling is lossless.
- halfDev = value * 0.5 - m1 * 0.5;
- // nDev cannot overflow as halfDev is <= MAX_VALUE when n > 1; or <= MAX_VALUE / 2 when n = 1
- nDev = (halfDev / n) * 2;
+ // Scale down the input
+ dev = value * DOWNSCALE - m1;
+ nDev = dev / ++n;
m1 += nDev;
}
@@ -172,8 +227,10 @@ class FirstMoment implements DoubleConsumer {
* {@code NaN} otherwise.
*/
double getFirstMoment() {
- if (Double.isFinite(m1)) {
- return n == 0 ? Double.NaN : m1;
+ // Scale back to the original magnitude
+ final double m = m1 * RESCALE;
+ if (Double.isFinite(m)) {
+ return n == 0 ? Double.NaN : m;
}
// A non-finite value must have been encountered, return nonFiniteValue which represents m1.
return nonFiniteValue;
@@ -194,22 +251,13 @@ class FirstMoment implements DoubleConsumer {
n = n1 + n2;
// Adjust the mean with the weighted difference:
// m1 = m1 + (m2 - m1) * n2 / (n1 + n2)
- // The difference between means can be 2 * MAX_VALUE so the computation optionally
- // scales by a factor of 2. Avoiding scaling if possible preserves sub-normals.
+ // The half-representation ensures the difference of means is at most MAX_VALUE
+ // so the combine can avoid scaling.
if (n1 == n2) {
// Optimisation for equal sizes: m1 = (m1 + m2) / 2
- // Use scaling for a large sum
- final double sum = mu1 + mu2;
- m1 = Double.isFinite(sum) ?
- sum * 0.5 :
- mu1 * 0.5 + mu2 * 0.5;
+ m1 = (mu1 + mu2) * 0.5;
} else {
- // Use scaling for a large difference
- if (Double.isFinite(mu2 - mu1)) {
- m1 = combine(mu1, mu2, n1, n2);
- } else {
- m1 = 2 * combine(mu1 * 0.5, mu2 * 0.5, n1, n2);
- }
+ m1 = combine(mu1, mu2, n1, n2);
}
return this;
}
@@ -231,4 +279,27 @@ class FirstMoment implements DoubleConsumer {
m1 + (m2 - m1) * ((double) n2 / (n1 + n2)) :
m2 + (m1 - m2) * ((double) n1 / (n1 + n2));
}
+
+ /**
+ * Gets the difference of the first moment between {@code this} moment and the
+ * {@code other} moment. This is provided for sub-classes.
+ *
+ * @param other Other moment.
+ * @return the difference
+ */
+ double getFirstMomentDifference(FirstMoment other) {
+ // Scale back to the original magnitude
+ return (m1 - other.m1) * RESCALE;
+ }
+
+ /**
+ * Gets the half the difference of the first moment between {@code this} moment and
+ * the {@code other} moment. This is provided for sub-classes.
+ *
+ * @param other Other moment.
+ * @return the difference
+ */
+ double getFirstMomentHalfDifference(FirstMoment other) {
+ return m1 - other.m1;
+ }
}
diff --git a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfCubedDeviations.java b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfCubedDeviations.java
index 915f52d..e58610f 100644
--- a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfCubedDeviations.java
+++ b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfCubedDeviations.java
@@ -156,10 +156,10 @@ class SumOfCubedDeviations extends SumOfSquaredDeviations {
// multiplication of later terms (nDev * 3 and nDev^2).
// This handles initialisation when np in {0, 1) to zero
// for any deviation (e.g. series MAX_VALUE, -MAX_VALUE).
- // Note: account for the half-deviation representation.
+ // Note: account for the half-deviation representation by scaling by 6=3*2; 8=2^3
sumCubedDev = sumCubedDev -
- ss * nDev * 3 +
- (np - 1.0) * np * nDev * nDev * halfDev * 2;
+ ss * nDev * 6 +
+ (np - 1.0) * np * nDev * nDev * dev * 8;
}
/**
@@ -197,7 +197,7 @@ class SumOfCubedDeviations extends SumOfSquaredDeviations {
// Avoid overflow to compute the difference.
// This allows any samples of size n=1 to be combined as their SS=0.
// The result is a SC=0 for the combined n=2.
- final double halfDiffOfMean = m1 * 0.5 - other.m1 * 0.5;
+ final double halfDiffOfMean = getFirstMomentHalfDifference(other);
sumCubedDev += other.sumCubedDev;
// Add additional terms that do not cancel to zero
if (halfDiffOfMean != 0) {
diff --git a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfFourthDeviations.java b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfFourthDeviations.java
index 79e8e33..b0bc155 100644
--- a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfFourthDeviations.java
+++ b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfFourthDeviations.java
@@ -141,11 +141,12 @@ class SumOfFourthDeviations extends SumOfCubedDeviations {
// This handles initialisation when np in {0, 1) to zero
// for any deviation (e.g. series MAX_VALUE, -MAX_VALUE).
// Note: (np1 * np1 - 3 * np) = (np+1)^2 - 3np = np^2 - np + 1
+ // Note: account for the half-deviation representation by scaling by 8=4*2; 24=6*2^2; 16=2^4
final double np1 = n;
sumFourthDev = sumFourthDev -
- sc * nDev * 4 +
- ss * nDev * nDev * 6 +
- np * (np1 * np1 - 3 * np) * nDev * nDev * nDev * nDev * n;
+ sc * nDev * 8 +
+ ss * nDev * nDev * 24 +
+ np * (np1 * np1 - 3 * np) * nDev * nDev * nDev * nDev * n * 16;
}
/**
@@ -180,7 +181,7 @@ class SumOfFourthDeviations extends SumOfCubedDeviations {
sumFourthDev = other.sumFourthDev;
} else if (other.n != 0) {
// Avoid overflow to compute the difference.
- final double halfDiffOfMean = m1 * 0.5 - other.m1 * 0.5;
+ final double halfDiffOfMean = getFirstMomentHalfDifference(other);
sumFourthDev += other.sumFourthDev;
// Add additional terms that do not cancel to zero
if (halfDiffOfMean != 0) {
diff --git a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfSquaredDeviations.java b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfSquaredDeviations.java
index 2e574f6..93cc421 100644
--- a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfSquaredDeviations.java
+++ b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfSquaredDeviations.java
@@ -133,8 +133,8 @@ class SumOfSquaredDeviations extends FirstMoment {
// "Updating one-pass algorithm"
// See: Chan et al (1983) Equation 1.3b
super.accept(value);
- // Note: account for the half-deviation representation
- sumSquaredDev += (n - 1) * halfDev * nDev * 2;
+ // Note: account for the half-deviation representation by scaling by 4=2^2
+ sumSquaredDev += (n - 1) * dev * nDev * 4;
}
/**
@@ -159,7 +159,7 @@ class SumOfSquaredDeviations extends FirstMoment {
} else if (m != 0) {
// "Updating one-pass algorithm"
// See: Chan et al (1983) Equation 1.5b (modified for the mean)
- final double diffOfMean = other.m1 - m1;
+ final double diffOfMean = getFirstMomentDifference(other);
final double sqDiffOfMean = diffOfMean * diffOfMean;
// Enforce symmetry
sumSquaredDev = (sumSquaredDev + other.sumSquaredDev) +