You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by ps...@apache.org on 2008/07/27 20:52:39 UTC
svn commit: r680162 - in /commons/proper/math/branches/MATH_2_0/src:
java/org/apache/commons/math/stat/regression/ site/xdoc/
test/org/apache/commons/math/stat/regression/
Author: psteitz
Date: Sun Jul 27 11:52:38 2008
New Revision: 680162
URL: http://svn.apache.org/viewvc?rev=680162&view=rev
Log:
Changed OLSMultipleLinearRegression implementation to use QR decomposition to
solve the normal equations.
JIRA: MATH-217
Modified:
commons/proper/math/branches/MATH_2_0/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java
commons/proper/math/branches/MATH_2_0/src/site/xdoc/changes.xml
commons/proper/math/branches/MATH_2_0/src/test/org/apache/commons/math/stat/regression/AbstractMultipleLinearRegressionTest.java
commons/proper/math/branches/MATH_2_0/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java
Modified: commons/proper/math/branches/MATH_2_0/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java
URL: http://svn.apache.org/viewvc/commons/proper/math/branches/MATH_2_0/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java?rev=680162&r1=680161&r2=680162&view=diff
==============================================================================
--- commons/proper/math/branches/MATH_2_0/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java (original)
+++ commons/proper/math/branches/MATH_2_0/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java Sun Jul 27 11:52:38 2008
@@ -16,30 +16,52 @@
*/
package org.apache.commons.math.stat.regression;
+import org.apache.commons.math.linear.QRDecomposition;
+import org.apache.commons.math.linear.QRDecompositionImpl;
import org.apache.commons.math.linear.RealMatrix;
-
+import org.apache.commons.math.linear.RealMatrixImpl;
/**
- * The OLS implementation of the multiple linear regression.
+ * <p>Implements ordinary least squares (OLS) to estimate the parameters of a
+ * multiple linear regression model.</p>
*
- * OLS assumes the covariance matrix of the error to be diagonal and with equal variance.
+ * <p>OLS assumes the covariance matrix of the error to be diagonal and with
+ * equal variance.
* <pre>
* u ~ N(0, sigma^2*I)
- * </pre>
+ * </pre></p>
*
- * Estimated by OLS,
+ * <p>The regression coefficients, b, satisfy the normal equations:
* <pre>
- * b=(X'X)^-1X'y
- * </pre>
- * whose variance is
+ * X^T X b = X^T y
+ * </pre></p>
+ *
+ * <p>To solve the normal equations, this implementation uses QR decomposition
+ * of the X matrix. (See {@link QRDecompositionImpl} for details on the
+ * decomposition algorithm.)
* <pre>
- * Var(b)=MSE*(X'X)^-1, MSE=u'u/(n-k)
+ * X^T X b = X^T y
+ * (QR)^T (QR) b = (QR)^T y
+ * R^T (Q^T Q) R b = R^T Q^T y
+ * R^T R b = R^T Q^T y
+ * (R^T)^{-1} R^T R b = (R^T)^{-1} R^T Q^T y
+ * R b = Q^T y
* </pre>
+ * Given Q and R, the last equation is solved by back-substitution.</p>
+ *
* @version $Revision$ $Date$
* @since 2.0
*/
public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {
+
+ /** Cached QR decomposition of X matrix */
+ private QRDecomposition qr = null;
+ /**
+ * {@inheritDoc}
+ *
+ * Computes and caches QR decomposition of the X matrix.
+ */
public void newSampleData(double[] y, double[][] x) {
validateSampleData(x, y);
newYSampleData(y);
@@ -47,15 +69,33 @@
}
/**
- * Calculates beta by OLS.
- * <pre>
- * b=(X'X)^-1X'y
- * </pre>
+ * {@inheritDoc}
+ *
+ * Computes and caches QR decomposition of the X matrix
+ */
+ public void newSampleData(double[] data, int nobs, int nvars) {
+ super.newSampleData(data, nobs, nvars);
+ qr = new QRDecompositionImpl(X);
+ }
+
+ /**
+ * Loads new x sample data, overriding any previous sample
+ *
+ * @param x the [n,k] array representing the x sample
+ */
+ protected void newXSampleData(double[][] x) {
+ this.X = new RealMatrixImpl(x);
+ qr = new QRDecompositionImpl(X);
+ }
+
+ /**
+ * Calculates regression coefficients using OLS.
+ *
* @return beta
*/
protected RealMatrix calculateBeta() {
- RealMatrix XTX = X.transpose().multiply(X);
- return XTX.inverse().multiply(X.transpose()).multiply(Y);
+ return solveUpperTriangular((RealMatrixImpl) qr.getR(),
+ (RealMatrixImpl) qr.getQ().transpose().multiply(Y));
}
/**
@@ -83,5 +123,76 @@
RealMatrix sse = u.transpose().multiply(u);
return sse.getTrace()/(X.getRowDimension()-X.getColumnDimension());
}
-
+
+ /** TODO: Find a home for the following methods in the linear package */
+
+ /**
+ * <p>Uses back substitution to solve the system</p>
+ *
+ * <p>coefficients X = constants</p>
+ *
+ * <p>coefficients must be upper-triangular and constants must be a column
+ * matrix. The solution is returned as a column matrix.</p>
+ *
+ * <p>The number of columns in coefficients determines the length
+ * of the returned solution vector (column matrix). If constants
+ * has more rows than coefficients has columns, excess rows are ignored.
+ * Similarly, extra (zero) rows in coefficients are ignored</p>
+ *
+ * @param coefficients upper-triangular coefficients matrix
+ * @param constants column RHS constants matrix
+ * @return solution matrix as a column matrix
+ *
+ */
+ private static RealMatrix solveUpperTriangular(RealMatrixImpl coefficients,
+ RealMatrixImpl constants) {
+ if (!isUpperTriangular(coefficients, 1E-12)) {
+ throw new IllegalArgumentException(
+ "Coefficients is not upper-triangular");
+ }
+ if (constants.getColumnDimension() != 1) {
+ throw new IllegalArgumentException(
+ "Constants not a column matrix.");
+ }
+ int length = coefficients.getColumnDimension();
+ double[][] cons = constants.getDataRef();
+ double[][] coef = coefficients.getDataRef();
+ double x[] = new double[length];
+ for (int i = 0; i < length; i++) {
+ int index = length - 1 - i;
+ double sum = 0;
+ for (int j = index + 1; j < length; j++) {
+ sum += coef[index][j] * x[j];
+ }
+ x[index] = (cons[index][0] - sum) / coef[index][index];
+ }
+ return new RealMatrixImpl(x);
+ }
+
+ /**
+ * <p>Returns true iff m is an upper-triangular matrix.</p>
+ *
+ * <p>Makes sure all below-diagonal elements are within epsilon of 0.</p>
+ *
+ * @param m matrix to check
+ * @param epsilon maximum allowable absolute value for elements below
+ * the main diagonal
+ *
+ * @return true if m is upper-triangular; false otherwise
+ * @throws NullPointerException if m is null
+ */
+ private static boolean isUpperTriangular(RealMatrixImpl m, double epsilon) {
+ double[][] data = m.getDataRef();
+ int nCols = m.getColumnDimension();
+ int nRows = m.getRowDimension();
+ for (int r = 0; r < nRows; r++) {
+ int bound = Math.min(r, nCols);
+ for (int c = 0; c < bound; c++) {
+ if (Math.abs(data[r][c]) > epsilon) {
+ return false;
+ }
+ }
+ }
+ return true;
+ }
}
Modified: commons/proper/math/branches/MATH_2_0/src/site/xdoc/changes.xml
URL: http://svn.apache.org/viewvc/commons/proper/math/branches/MATH_2_0/src/site/xdoc/changes.xml?rev=680162&r1=680161&r2=680162&view=diff
==============================================================================
--- commons/proper/math/branches/MATH_2_0/src/site/xdoc/changes.xml (original)
+++ commons/proper/math/branches/MATH_2_0/src/site/xdoc/changes.xml Sun Jul 27 11:52:38 2008
@@ -39,6 +39,10 @@
</properties>
<body>
<release version="2.0" date="TBD" description="TBD">
+ <action dev="psteitz" type="update" issue="MATH-217">
+ Changed OLS regression implementation added in MATH-203 to use
+ QR decomposition to solve the normal equations.
+ </action>
<action dev="luc" type="add">
New ODE integrators have been added: the explicit Adams-Bashforth and implicit
Adams-Moulton multistep methods. These methods support customizable starter
Modified: commons/proper/math/branches/MATH_2_0/src/test/org/apache/commons/math/stat/regression/AbstractMultipleLinearRegressionTest.java
URL: http://svn.apache.org/viewvc/commons/proper/math/branches/MATH_2_0/src/test/org/apache/commons/math/stat/regression/AbstractMultipleLinearRegressionTest.java?rev=680162&r1=680161&r2=680162&view=diff
==============================================================================
--- commons/proper/math/branches/MATH_2_0/src/test/org/apache/commons/math/stat/regression/AbstractMultipleLinearRegressionTest.java (original)
+++ commons/proper/math/branches/MATH_2_0/src/test/org/apache/commons/math/stat/regression/AbstractMultipleLinearRegressionTest.java Sun Jul 27 11:52:38 2008
@@ -58,8 +58,10 @@
@Test
public void canEstimateRegressandVariance(){
- double variance = regression.estimateRegressandVariance();
- assertTrue(variance > 0.0);
+ if (getSampleSize() > getNumberOfRegressors()) {
+ double variance = regression.estimateRegressandVariance();
+ assertTrue(variance > 0.0);
+ }
}
}
Modified: commons/proper/math/branches/MATH_2_0/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java
URL: http://svn.apache.org/viewvc/commons/proper/math/branches/MATH_2_0/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java?rev=680162&r1=680161&r2=680162&view=diff
==============================================================================
--- commons/proper/math/branches/MATH_2_0/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java (original)
+++ commons/proper/math/branches/MATH_2_0/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java Sun Jul 27 11:52:38 2008
@@ -18,7 +18,10 @@
import org.junit.Before;
import org.junit.Test;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
import org.apache.commons.math.TestUtils;
+import org.apache.commons.math.linear.RealMatrixImpl;
public class OLSMultipleLinearRegressionTest extends AbstractMultipleLinearRegressionTest {
@@ -131,7 +134,7 @@
new double[]{-3482258.63459582, 15.0618722713733,
-0.358191792925910E-01,-2.02022980381683,
-1.03322686717359,-0.511041056535807E-01,
- 1829.15146461355}, 1E-1); // <- UGH! need better accuracy!
+ 1829.15146461355}, 1E-8); //
// Check expected residuals from R
double[] residuals = model.estimateResiduals();
@@ -142,7 +145,7 @@
455.394094551857,-17.26892711483297,-39.0550425226967,
-155.5499735953195,-85.6713080421283,341.9315139607727,
-206.7578251937366},
- 1E-2); // <- UGH again! need better accuracy!
+ 1E-8);
// Check standard errors from NIST
double[][] errors = model.estimateRegressionParametersVariance();
[math] Re: svn commit: r680162 - in /commons/proper/math/branches/MATH_2_0/src:
java/org/apache/commons/math/stat/regression/ site/xdoc/ test/org/apache/commons/math/stat/regression/
Posted by Luc Maisonobe <Lu...@free.fr>.
Phil Steitz a écrit :
> Any ideas on where this stuff can go???
I am beginning to think we should have a few different matrices shapes:
upper triangular, lower triangular, diagonal, tri-diagonal and perhaps
sparse.
In this case, this could be put as a specialized implementation of
solve in a triangular matrix class.
Luc
>> -
>> +
>> + /** TODO: Find a home for the following methods in the linear package */
>> +
>> + /**
>> + * <p>Uses back substitution to solve the system</p>
>> + *
>> + * <p>coefficients X = constants</p>
>> + *
>> + * <p>coefficients must be upper-triangular and constants must be a column
>> + * matrix. The solution is returned as a column matrix.</p>
>> + *
>> + * <p>The number of columns in coefficients determines the length
>> + * of the returned solution vector (column matrix). If constants
>> + * has more rows than coefficients has columns, excess rows are ignored.
>> + * Similarly, extra (zero) rows in coefficients are ignored</p>
>> + *
>> + * @param coefficients upper-triangular coefficients matrix
>> + * @param constants column RHS constants matrix
>> + * @return solution matrix as a column matrix
>> + *
>> + */
>> + private static RealMatrix solveUpperTriangular(RealMatrixImpl coefficients,
>> + RealMatrixImpl constants) {
>> + if (!isUpperTriangular(coefficients, 1E-12)) {
>> + throw new IllegalArgumentException(
>> + "Coefficients is not upper-triangular");
>> + }
>> + if (constants.getColumnDimension() != 1) {
>> + throw new IllegalArgumentException(
>> + "Constants not a column matrix.");
>> + }
>> + int length = coefficients.getColumnDimension();
>> + double[][] cons = constants.getDataRef();
>> + double[][] coef = coefficients.getDataRef();
>> + double x[] = new double[length];
>> + for (int i = 0; i < length; i++) {
>> + int index = length - 1 - i;
>> + double sum = 0;
>> + for (int j = index + 1; j < length; j++) {
>> + sum += coef[index][j] * x[j];
>> + }
>> + x[index] = (cons[index][0] - sum) / coef[index][index];
>> + }
>> + return new RealMatrixImpl(x);
>> + }
>> +
>> + /**
>> + * <p>Returns true iff m is an upper-triangular matrix.</p>
>> + *
>> + * <p>Makes sure all below-diagonal elements are within epsilon of 0.</p>
>> + *
>> + * @param m matrix to check
>> + * @param epsilon maximum allowable absolute value for elements below
>> + * the main diagonal
>> + *
>> + * @return true if m is upper-triangular; false otherwise
>> + * @throws NullPointerException if m is null
>> + */
>> + private static boolean isUpperTriangular(RealMatrixImpl m, double epsilon) {
>> + double[][] data = m.getDataRef();
>> + int nCols = m.getColumnDimension();
>> + int nRows = m.getRowDimension();
>> + for (int r = 0; r < nRows; r++) {
>> + int bound = Math.min(r, nCols);
>> + for (int c = 0; c < bound; c++) {
>> + if (Math.abs(data[r][c]) > epsilon) {
>> + return false;
>> + }
>> + }
>> + }
>> + return true;
>> + }
>> }
>>
>>
>
>
> ---------------------------------------------------------------------
> To unsubscribe, e-mail: dev-unsubscribe@commons.apache.org
> For additional commands, e-mail: dev-help@commons.apache.org
>
>
---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscribe@commons.apache.org
For additional commands, e-mail: dev-help@commons.apache.org
Re: svn commit: r680162 - in /commons/proper/math/branches/MATH_2_0/src:
java/org/apache/commons/math/stat/regression/ site/xdoc/ test/org/apache/commons/math/stat/regression/
Posted by Phil Steitz <ph...@steitz.com>.
Any ideas on where this stuff can go???
> -
> +
> + /** TODO: Find a home for the following methods in the linear package */
> +
> + /**
> + * <p>Uses back substitution to solve the system</p>
> + *
> + * <p>coefficients X = constants</p>
> + *
> + * <p>coefficients must be upper-triangular and constants must be a column
> + * matrix. The solution is returned as a column matrix.</p>
> + *
> + * <p>The number of columns in coefficients determines the length
> + * of the returned solution vector (column matrix). If constants
> + * has more rows than coefficients has columns, excess rows are ignored.
> + * Similarly, extra (zero) rows in coefficients are ignored</p>
> + *
> + * @param coefficients upper-triangular coefficients matrix
> + * @param constants column RHS constants matrix
> + * @return solution matrix as a column matrix
> + *
> + */
> + private static RealMatrix solveUpperTriangular(RealMatrixImpl coefficients,
> + RealMatrixImpl constants) {
> + if (!isUpperTriangular(coefficients, 1E-12)) {
> + throw new IllegalArgumentException(
> + "Coefficients is not upper-triangular");
> + }
> + if (constants.getColumnDimension() != 1) {
> + throw new IllegalArgumentException(
> + "Constants not a column matrix.");
> + }
> + int length = coefficients.getColumnDimension();
> + double[][] cons = constants.getDataRef();
> + double[][] coef = coefficients.getDataRef();
> + double x[] = new double[length];
> + for (int i = 0; i < length; i++) {
> + int index = length - 1 - i;
> + double sum = 0;
> + for (int j = index + 1; j < length; j++) {
> + sum += coef[index][j] * x[j];
> + }
> + x[index] = (cons[index][0] - sum) / coef[index][index];
> + }
> + return new RealMatrixImpl(x);
> + }
> +
> + /**
> + * <p>Returns true iff m is an upper-triangular matrix.</p>
> + *
> + * <p>Makes sure all below-diagonal elements are within epsilon of 0.</p>
> + *
> + * @param m matrix to check
> + * @param epsilon maximum allowable absolute value for elements below
> + * the main diagonal
> + *
> + * @return true if m is upper-triangular; false otherwise
> + * @throws NullPointerException if m is null
> + */
> + private static boolean isUpperTriangular(RealMatrixImpl m, double epsilon) {
> + double[][] data = m.getDataRef();
> + int nCols = m.getColumnDimension();
> + int nRows = m.getRowDimension();
> + for (int r = 0; r < nRows; r++) {
> + int bound = Math.min(r, nCols);
> + for (int c = 0; c < bound; c++) {
> + if (Math.abs(data[r][c]) > epsilon) {
> + return false;
> + }
> + }
> + }
> + return true;
> + }
> }
>
>
---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscribe@commons.apache.org
For additional commands, e-mail: dev-help@commons.apache.org