You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@commons.apache.org by md...@apache.org on 2003/06/21 04:13:41 UTC
cvs commit: jakarta-commons-sandbox/math/src/java/org/apache/commons/math/stat BivariateRegression.java
mdiggory 2003/06/20 19:13:41
Modified: math/src/test/org/apache/commons/math/stat
BivariateRegressionTest.java
math/src/java/org/apache/commons/math/stat
BivariateRegression.java
Log:
PR: http://nagoya.apache.org/bugzilla/show_bug.cgi?id=20979
Submitted by: phil@steitz.com
Revision Changes Path
1.3 +5 -5 jakarta-commons-sandbox/math/src/test/org/apache/commons/math/stat/BivariateRegressionTest.java
Index: BivariateRegressionTest.java
===================================================================
RCS file: /home/cvs/jakarta-commons-sandbox/math/src/test/org/apache/commons/math/stat/BivariateRegressionTest.java,v
retrieving revision 1.2
retrieving revision 1.3
diff -u -r1.2 -r1.3
--- BivariateRegressionTest.java 11 Jun 2003 03:33:05 -0000 1.2
+++ BivariateRegressionTest.java 21 Jun 2003 02:13:41 -0000 1.3
@@ -130,15 +130,15 @@
assertEquals("r-square",0.999993745883712,
regression.getRSquare(),10E-12);
assertEquals("SSR",4255954.13232369,
- regression.getRegressionSumSquares(),10E-8);
+ regression.getRegressionSumSquares(),10E-9);
assertEquals("MSE",0.782864662630069,
- regression.getMeanSquareError(),10E-8);
+ regression.getMeanSquareError(),10E-10);
assertEquals("SSE",26.6173985294224,
- regression.getSumSquaredErrors(),10E-8);
+ regression.getSumSquaredErrors(),10E-9);
assertEquals("predict(0)",-0.262323073774029,
regression.predict(0),10E-12);
assertEquals("predict(1)",1.00211681802045-0.262323073774029,
- regression.predict(1),10E-11);
+ regression.predict(1),10E-12);
}
public void testCorr() {
1.3 +53 -42 jakarta-commons-sandbox/math/src/java/org/apache/commons/math/stat/BivariateRegression.java
Index: BivariateRegression.java
===================================================================
RCS file: /home/cvs/jakarta-commons-sandbox/math/src/java/org/apache/commons/math/stat/BivariateRegression.java,v
retrieving revision 1.2
retrieving revision 1.3
diff -u -r1.2 -r1.3
--- BivariateRegression.java 11 Jun 2003 03:33:05 -0000 1.2
+++ BivariateRegression.java 21 Jun 2003 02:13:41 -0000 1.3
@@ -92,14 +92,14 @@
/** sum of x values */
private double sumX = 0d;
- /** sum of squared x values */
- private double sumSqX = 0d;
+ /** total variation in x (sum of squared deviations from xbar) */
+ private double sumXX = 0d;
/** sum of y values */
private double sumY = 0d;
- /** sum of squared y values */
- private double sumSqY = 0d;
+ /** total variation in y (sum of squared deviations from ybar) */
+ private double sumYY = 0d;
/** sum of products */
private double sumXY = 0d;
@@ -107,20 +107,41 @@
/** number of observations */
private long n = 0;
+ /** mean of accumulated x values, used in updating formulas */
+ private double xbar = 0;
+
+ /** mean of accumulated y values, used in updating formulas */
+ private double ybar = 0;
+
+
// ---------------------Public methods--------------------------------------
/**
- * Adds the observation (x,y) to the regression data set
+ * Adds the observation (x,y) to the regression data set.
+ * <p>
+ * Uses updating formulas for means and sums of squares defined in
+ * "Algorithms for Computing the Sample Variance: Analysis and
+ * Recommendations", Chan, T.F., Golub, G.H., and LeVeque, R.J.
+ * 1983, American Statistician, vol. 37, pp. 242-247, referenced in
+ * Weisberg, S. "Applied Linear Regression". 2nd Ed. 1985
+ *
*
* @param x independent variable value
* @param y dependent variable value
*/
public void addData(double x, double y) {
+ if (n == 0) {
+ xbar = x;
+ ybar = y;
+ } else {
+ sumXX += ((double) n / (double) (n + 1)) * (x - xbar) * (x - xbar);
+ sumYY += ((double) n / (double) (n + 1)) * (y - ybar) * (y - ybar);
+ sumXY += ((double) n / (double) (n + 1)) * (x - xbar) * (y - ybar);
+ xbar += (1d / (double) (n + 1)) * (x - xbar);
+ ybar += (1d / (double) (n + 1)) * (y - ybar);
+ }
sumX += x;
- sumSqX += x * x;
sumY += y;
- sumSqY += y * y;
- sumXY += x * y;
n++;
}
@@ -148,9 +169,9 @@
*/
public void clear() {
sumX = 0d;
- sumSqX = 0d;
+ sumXX = 0d;
sumY = 0d;
- sumSqY = 0d;
+ sumYY = 0d;
sumXY = 0d;
n = 0;
}
@@ -215,7 +236,7 @@
* <strong>Preconditions</strong>: <ul>
* <li>At least two observations (with at least two different x values)
* must have been added before invoking this method. If this method is
- * invoked before a model can be estimated, <code>Double,NaN</code> is
+ * invoked before a model can be estimated, <code>Double.NaN</code> is
* returned.
* </li></ul>
*
@@ -225,12 +246,10 @@
if (n < 2) {
return Double.NaN; //not enough data
}
- double dn = (double) n;
- double denom = sumSqX - (sumX * sumX / dn);
- if (Math.abs(denom) < 10 * Double.MIN_VALUE) {
+ if (Math.abs(sumXX) < 10 * Double.MIN_VALUE) {
return Double.NaN; //not enough variation in x
}
- return (sumXY - (sumX * sumY / dn)) / denom;
+ return sumXY / sumXX;
}
/**
@@ -265,7 +284,7 @@
if (n < 2) {
return Double.NaN;
}
- return sumSqY - sumY * sumY / (double) n;
+ return sumYY;
}
/**
@@ -282,11 +301,10 @@
* returned.
* </li></ul>
*
- * @return sum of squared deviations of y values
+ * @return sum of squared deviations of predicted y values
*/
public double getRegressionSumSquares() {
- double b1 = getSlope();
- return b1 * (sumXY - sumX * sumY / (double) n);
+ return getRegressionSumSquares(getSlope());
}
/**
@@ -303,8 +321,7 @@
if (n < 3) {
return Double.NaN;
}
- double sse = getSumSquaredErrors();
- return sse / (double) (n - 2);
+ return getSumSquaredErrors() / (double) (n - 2);
}
/**
@@ -361,8 +378,8 @@
* @return standard error associated with intercept estimate
*/
public double getInterceptStdErr() {
- double ssx = getSumSquaresX();
- return Math.sqrt(getMeanSquareError() * sumSqX / (((double) n) * ssx));
+ return Math.sqrt(getMeanSquareError() * ((1d / (double) n) +
+ (xbar * xbar) / sumXX));
}
/**
@@ -376,8 +393,7 @@
* @return standard error associated with slope estimate
*/
public double getSlopeStdErr() {
- double ssx = getSumSquaresX();
- return Math.sqrt(getMeanSquareError() / ssx);
+ return Math.sqrt(getMeanSquareError() / sumXX);
}
/**
@@ -492,24 +508,9 @@
* @return sum of squared errors associated with the regression model
*/
private double getSumSquaredErrors(double b1) {
- double b0 = getIntercept(b1);
- return sumSqY - b0 * sumY - b1 * sumXY;
+ return sumYY - sumXY * sumXY / sumXX;
}
- /**
- * Returns the sum of squared deviations of the x values about their mean.
- * <p>
- * If n < 2, this returns NaN.
- *
- * @return sum of squared deviations of x values
- */
- private double getSumSquaresX() {
- if (n < 2) {
- return Double.NaN;
- }
- return sumSqX - sumX * sumX / (double) n;
- }
-
/**
* Computes r-square from the slope.
* <p>
@@ -521,6 +522,16 @@
private double getRSquare(double b1) {
double ssto = getTotalSumSquares();
return (ssto - getSumSquaredErrors(b1)) / ssto;
+ }
+
+ /**
+ * Computes SSR from b1.
+ *
+ * @param slope regression slope estimate
+ * @return sum of squared deviations of predicted y values
+ */
+ private double getRegressionSumSquares(double slope) {
+ return slope * slope * sumXX;
}
/**
---------------------------------------------------------------------
To unsubscribe, e-mail: commons-dev-unsubscribe@jakarta.apache.org
For additional commands, e-mail: commons-dev-help@jakarta.apache.org