You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by lu...@apache.org on 2008/12/21 22:04:48 UTC
svn commit: r728500 - in /commons/proper/math/trunk/src:
java/org/apache/commons/math/stat/regression/
test/org/apache/commons/math/stat/regression/
Author: luc
Date: Sun Dec 21 13:04:47 2008
New Revision: 728500
URL: http://svn.apache.org/viewvc?rev=728500&view=rev
Log:
Reverted some changes introduced yesterday, as they led to unexpected test failures.
Modified:
commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/AbstractMultipleLinearRegression.java
commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/GLSMultipleLinearRegression.java
commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java
commons/proper/math/trunk/src/test/org/apache/commons/math/stat/regression/GLSMultipleLinearRegressionTest.java
Modified: commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/AbstractMultipleLinearRegression.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/AbstractMultipleLinearRegression.java?rev=728500&r1=728499&r2=728500&view=diff
==============================================================================
--- commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/AbstractMultipleLinearRegression.java (original)
+++ commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/AbstractMultipleLinearRegression.java Sun Dec 21 13:04:47 2008
@@ -16,10 +16,8 @@
*/
package org.apache.commons.math.stat.regression;
-import org.apache.commons.math.linear.MatrixUtils;
import org.apache.commons.math.linear.RealMatrix;
-import org.apache.commons.math.linear.RealVector;
-import org.apache.commons.math.linear.RealVectorImpl;
+import org.apache.commons.math.linear.RealMatrixImpl;
/**
* Abstract base class for implementations of MultipleLinearRegression.
@@ -33,7 +31,7 @@
protected RealMatrix X;
/** Y sample data. */
- protected RealVector Y;
+ protected RealMatrix Y;
/**
* Loads model x and y sample data from a flat array of data, overriding any previous sample.
@@ -54,8 +52,8 @@
x[i][j] = data[pointer++];
}
}
- this.X = MatrixUtils.createRealMatrix(x);
- this.Y = new RealVectorImpl(y);
+ this.X = new RealMatrixImpl(x);
+ this.Y = new RealMatrixImpl(y);
}
/**
@@ -64,7 +62,7 @@
* @param y the [n,1] array representing the y sample
*/
protected void newYSampleData(double[] y) {
- this.Y = new RealVectorImpl(y);
+ this.Y = new RealMatrixImpl(y);
}
/**
@@ -73,7 +71,7 @@
* @param x the [n,k] array representing the x sample
*/
protected void newXSampleData(double[][] x) {
- this.X = MatrixUtils.createRealMatrix(x);
+ this.X = new RealMatrixImpl(x);
}
/**
@@ -122,14 +120,17 @@
* {@inheritDoc}
*/
public double[] estimateRegressionParameters() {
- return calculateBeta().getData();
+ RealMatrix b = calculateBeta();
+ return b.getColumn(0);
}
/**
* {@inheritDoc}
*/
public double[] estimateResiduals() {
- return Y.subtract(X.operate(calculateBeta())).getData();
+ RealMatrix b = calculateBeta();
+ RealMatrix e = Y.subtract(X.multiply(b));
+ return e.getColumn(0);
}
/**
@@ -151,7 +152,7 @@
*
* @return beta
*/
- protected abstract RealVector calculateBeta();
+ protected abstract RealMatrix calculateBeta();
/**
* Calculates the beta variance of multiple linear regression in matrix
@@ -178,8 +179,9 @@
*
* @return The residuals [n,1] matrix
*/
- protected RealVector calculateResiduals() {
- return Y.subtract(X.operate(calculateBeta()));
+ protected RealMatrix calculateResiduals() {
+ RealMatrix b = calculateBeta();
+ return Y.subtract(X.multiply(b));
}
}
Modified: commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/GLSMultipleLinearRegression.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/GLSMultipleLinearRegression.java?rev=728500&r1=728499&r2=728500&view=diff
==============================================================================
--- commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/GLSMultipleLinearRegression.java (original)
+++ commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/GLSMultipleLinearRegression.java Sun Dec 21 13:04:47 2008
@@ -18,9 +18,8 @@
import org.apache.commons.math.linear.LUDecompositionImpl;
import org.apache.commons.math.linear.LUSolver;
-import org.apache.commons.math.linear.MatrixUtils;
import org.apache.commons.math.linear.RealMatrix;
-import org.apache.commons.math.linear.RealVector;
+import org.apache.commons.math.linear.RealMatrixImpl;
/**
@@ -69,7 +68,7 @@
* @param omega the [n,n] array representing the covariance
*/
protected void newCovarianceData(double[][] omega){
- this.Omega = MatrixUtils.createRealMatrix(omega);
+ this.Omega = new RealMatrixImpl(omega);
this.OmegaInverse = null;
}
@@ -92,12 +91,12 @@
* </pre>
* @return beta
*/
- protected RealVector calculateBeta() {
+ protected RealMatrix calculateBeta() {
RealMatrix OI = getOmegaInverse();
RealMatrix XT = X.transpose();
RealMatrix XTOIX = XT.multiply(OI).multiply(X);
RealMatrix inverse = new LUSolver(new LUDecompositionImpl(XTOIX)).getInverse();
- return inverse.multiply(XT).multiply(OI).operate(Y);
+ return inverse.multiply(XT).multiply(OI).multiply(Y);
}
/**
@@ -121,9 +120,9 @@
* @return The Y variance
*/
protected double calculateYVariance() {
- final RealVector u = calculateResiduals();
- final double sse = u.dotProduct(getOmegaInverse().operate(u));
- return sse / (X.getRowDimension() - X.getColumnDimension());
+ RealMatrix u = calculateResiduals();
+ RealMatrix sse = u.transpose().multiply(getOmegaInverse()).multiply(u);
+ return sse.getTrace()/(X.getRowDimension()-X.getColumnDimension());
}
}
Modified: commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java?rev=728500&r1=728499&r2=728500&view=diff
==============================================================================
--- commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java (original)
+++ commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java Sun Dec 21 13:04:47 2008
@@ -16,14 +16,12 @@
*/
package org.apache.commons.math.stat.regression;
-import org.apache.commons.math.linear.DenseRealMatrix;
import org.apache.commons.math.linear.LUDecompositionImpl;
import org.apache.commons.math.linear.LUSolver;
import org.apache.commons.math.linear.QRDecomposition;
import org.apache.commons.math.linear.QRDecompositionImpl;
import org.apache.commons.math.linear.RealMatrix;
-import org.apache.commons.math.linear.RealVector;
-import org.apache.commons.math.linear.RealVectorImpl;
+import org.apache.commons.math.linear.RealMatrixImpl;
/**
* <p>Implements ordinary least squares (OLS) to estimate the parameters of a
@@ -88,7 +86,7 @@
* @param x the [n,k] array representing the x sample
*/
protected void newXSampleData(double[][] x) {
- this.X = new DenseRealMatrix(x);
+ this.X = new RealMatrixImpl(x);
qr = new QRDecompositionImpl(X);
}
@@ -97,8 +95,8 @@
*
* @return beta
*/
- protected RealVector calculateBeta() {
- return solveUpperTriangular(qr.getR(), qr.getQ().transpose().operate(Y));
+ protected RealMatrix calculateBeta() {
+ return solveUpperTriangular(qr.getR(), qr.getQ().transpose().multiply(Y));
}
/**
@@ -122,9 +120,9 @@
* @return The Y variance
*/
protected double calculateYVariance() {
- final RealVector u = calculateResiduals();
- final double sse = u.dotProduct(u);
- return sse / (X.getRowDimension() - X.getColumnDimension());
+ RealMatrix u = calculateResiduals();
+ RealMatrix sse = u.transpose().multiply(u);
+ return sse.getTrace()/(X.getRowDimension()-X.getColumnDimension());
}
/** TODO: Find a home for the following methods in the linear package */
@@ -144,14 +142,19 @@
*
* @param coefficients upper-triangular coefficients matrix
* @param constants column RHS constants matrix
- * @return solution matrix as a vector
+ * @return solution matrix as a column matrix
*
*/
- private static RealVector solveUpperTriangular(RealMatrix coefficients, RealVector constants) {
+ private static RealMatrix solveUpperTriangular(RealMatrix coefficients,
+ RealMatrix constants) {
if (!isUpperTriangular(coefficients, 1E-12)) {
throw new IllegalArgumentException(
"Coefficients is not upper-triangular");
}
+ if (constants.getColumnDimension() != 1) {
+ throw new IllegalArgumentException(
+ "Constants not a column matrix.");
+ }
int length = coefficients.getColumnDimension();
double x[] = new double[length];
for (int i = 0; i < length; i++) {
@@ -160,9 +163,9 @@
for (int j = index + 1; j < length; j++) {
sum += coefficients.getEntry(index, j) * x[j];
}
- x[index] = (constants.getEntry(index) - sum) / coefficients.getEntry(index, index);
+ x[index] = (constants.getEntry(index, 0) - sum) / coefficients.getEntry(index, index);
}
- return new RealVectorImpl(x);
+ return new RealMatrixImpl(x);
}
/**
Modified: commons/proper/math/trunk/src/test/org/apache/commons/math/stat/regression/GLSMultipleLinearRegressionTest.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/test/org/apache/commons/math/stat/regression/GLSMultipleLinearRegressionTest.java?rev=728500&r1=728499&r2=728500&view=diff
==============================================================================
--- commons/proper/math/trunk/src/test/org/apache/commons/math/stat/regression/GLSMultipleLinearRegressionTest.java (original)
+++ commons/proper/math/trunk/src/test/org/apache/commons/math/stat/regression/GLSMultipleLinearRegressionTest.java Sun Dec 21 13:04:47 2008
@@ -63,7 +63,7 @@
createRegression().newSampleData(y, x, null);
}
- @Test(expected=ArrayIndexOutOfBoundsException.class)
+ @Test(expected=IllegalArgumentException.class)
public void cannotAddNullCovarianceData() {
createRegression().newSampleData(new double[]{}, new double[][]{}, null);
}