You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by lu...@apache.org on 2008/12/21 22:04:48 UTC

svn commit: r728500 - in /commons/proper/math/trunk/src: java/org/apache/commons/math/stat/regression/ test/org/apache/commons/math/stat/regression/

Author: luc
Date: Sun Dec 21 13:04:47 2008
New Revision: 728500

URL: http://svn.apache.org/viewvc?rev=728500&view=rev
Log:
Reverted some changes introduced yesterday, as they lead to unexpected test failures.

Modified:
    commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/AbstractMultipleLinearRegression.java
    commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/GLSMultipleLinearRegression.java
    commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java
    commons/proper/math/trunk/src/test/org/apache/commons/math/stat/regression/GLSMultipleLinearRegressionTest.java

Modified: commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/AbstractMultipleLinearRegression.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/AbstractMultipleLinearRegression.java?rev=728500&r1=728499&r2=728500&view=diff
==============================================================================
--- commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/AbstractMultipleLinearRegression.java (original)
+++ commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/AbstractMultipleLinearRegression.java Sun Dec 21 13:04:47 2008
@@ -16,10 +16,8 @@
  */
 package org.apache.commons.math.stat.regression;
 
-import org.apache.commons.math.linear.MatrixUtils;
 import org.apache.commons.math.linear.RealMatrix;
-import org.apache.commons.math.linear.RealVector;
-import org.apache.commons.math.linear.RealVectorImpl;
+import org.apache.commons.math.linear.RealMatrixImpl;
 
 /**
  * Abstract base class for implementations of MultipleLinearRegression.
@@ -33,7 +31,7 @@
     protected RealMatrix X;
 
     /** Y sample data. */
-    protected RealVector Y;
+    protected RealMatrix Y;
 
     /**
      * Loads model x and y sample data from a flat array of data, overriding any previous sample.
@@ -54,8 +52,8 @@
                 x[i][j] = data[pointer++];
             }
         }
-        this.X = MatrixUtils.createRealMatrix(x);
-        this.Y = new RealVectorImpl(y);
+        this.X = new RealMatrixImpl(x);
+        this.Y = new RealMatrixImpl(y);
     }
     
     /**
@@ -64,7 +62,7 @@
      * @param y the [n,1] array representing the y sample
      */
     protected void newYSampleData(double[] y) {
-        this.Y = new RealVectorImpl(y);
+        this.Y = new RealMatrixImpl(y);
     }
 
     /**
@@ -73,7 +71,7 @@
      * @param x the [n,k] array representing the x sample
      */
     protected void newXSampleData(double[][] x) {
-        this.X = MatrixUtils.createRealMatrix(x);
+        this.X = new RealMatrixImpl(x);
     }
 
     /**
@@ -122,14 +120,17 @@
      * {@inheritDoc}
      */
     public double[] estimateRegressionParameters() {
-        return calculateBeta().getData();
+        RealMatrix b = calculateBeta();
+        return b.getColumn(0);
     }
 
     /**
      * {@inheritDoc}
      */
     public double[] estimateResiduals() {
-        return Y.subtract(X.operate(calculateBeta())).getData();
+        RealMatrix b = calculateBeta();
+        RealMatrix e = Y.subtract(X.multiply(b));
+        return e.getColumn(0);
     }
 
     /**
@@ -151,7 +152,7 @@
      * 
      * @return beta
      */
-    protected abstract RealVector calculateBeta();
+    protected abstract RealMatrix calculateBeta();
 
     /**
      * Calculates the beta variance of multiple linear regression in matrix
@@ -178,8 +179,9 @@
      * 
      * @return The residuals [n,1] matrix
      */
-    protected RealVector calculateResiduals() {
-        return Y.subtract(X.operate(calculateBeta()));
+    protected RealMatrix calculateResiduals() {
+        RealMatrix b = calculateBeta();
+        return Y.subtract(X.multiply(b));
     }
 
 }

Modified: commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/GLSMultipleLinearRegression.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/GLSMultipleLinearRegression.java?rev=728500&r1=728499&r2=728500&view=diff
==============================================================================
--- commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/GLSMultipleLinearRegression.java (original)
+++ commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/GLSMultipleLinearRegression.java Sun Dec 21 13:04:47 2008
@@ -18,9 +18,8 @@
 
 import org.apache.commons.math.linear.LUDecompositionImpl;
 import org.apache.commons.math.linear.LUSolver;
-import org.apache.commons.math.linear.MatrixUtils;
 import org.apache.commons.math.linear.RealMatrix;
-import org.apache.commons.math.linear.RealVector;
+import org.apache.commons.math.linear.RealMatrixImpl;
 
 
 /**
@@ -69,7 +68,7 @@
      * @param omega the [n,n] array representing the covariance
      */
     protected void newCovarianceData(double[][] omega){
-        this.Omega = MatrixUtils.createRealMatrix(omega);
+        this.Omega = new RealMatrixImpl(omega);
         this.OmegaInverse = null;
     }
 
@@ -92,12 +91,12 @@
      * </pre>
      * @return beta
      */
-    protected RealVector calculateBeta() {
+    protected RealMatrix calculateBeta() {
         RealMatrix OI = getOmegaInverse();
         RealMatrix XT = X.transpose();
         RealMatrix XTOIX = XT.multiply(OI).multiply(X);
         RealMatrix inverse = new LUSolver(new LUDecompositionImpl(XTOIX)).getInverse();
-        return inverse.multiply(XT).multiply(OI).operate(Y);
+        return inverse.multiply(XT).multiply(OI).multiply(Y);
     }
 
     /**
@@ -121,9 +120,9 @@
      * @return The Y variance
      */
     protected double calculateYVariance() {
-        final RealVector u = calculateResiduals();
-        final double sse =  u.dotProduct(getOmegaInverse().operate(u));
-        return sse / (X.getRowDimension() - X.getColumnDimension());
+        RealMatrix u = calculateResiduals();
+        RealMatrix sse =  u.transpose().multiply(getOmegaInverse()).multiply(u);
+        return sse.getTrace()/(X.getRowDimension()-X.getColumnDimension());
     }
     
 }

Modified: commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java?rev=728500&r1=728499&r2=728500&view=diff
==============================================================================
--- commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java (original)
+++ commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java Sun Dec 21 13:04:47 2008
@@ -16,14 +16,12 @@
  */
 package org.apache.commons.math.stat.regression;
 
-import org.apache.commons.math.linear.DenseRealMatrix;
 import org.apache.commons.math.linear.LUDecompositionImpl;
 import org.apache.commons.math.linear.LUSolver;
 import org.apache.commons.math.linear.QRDecomposition;
 import org.apache.commons.math.linear.QRDecompositionImpl;
 import org.apache.commons.math.linear.RealMatrix;
-import org.apache.commons.math.linear.RealVector;
-import org.apache.commons.math.linear.RealVectorImpl;
+import org.apache.commons.math.linear.RealMatrixImpl;
 
 /**
  * <p>Implements ordinary least squares (OLS) to estimate the parameters of a 
@@ -88,7 +86,7 @@
      * @param x the [n,k] array representing the x sample
      */
     protected void newXSampleData(double[][] x) {
-        this.X = new DenseRealMatrix(x);
+        this.X = new RealMatrixImpl(x);
         qr = new QRDecompositionImpl(X);
     }
     
@@ -97,8 +95,8 @@
      * 
      * @return beta
      */
-    protected RealVector calculateBeta() {
-        return solveUpperTriangular(qr.getR(), qr.getQ().transpose().operate(Y));
+    protected RealMatrix calculateBeta() {
+        return solveUpperTriangular(qr.getR(), qr.getQ().transpose().multiply(Y));
     }
 
     /**
@@ -122,9 +120,9 @@
      * @return The Y variance
      */
     protected double calculateYVariance() {
-        final RealVector u = calculateResiduals();
-        final double sse = u.dotProduct(u);
-        return sse / (X.getRowDimension() - X.getColumnDimension());
+        RealMatrix u = calculateResiduals();
+        RealMatrix sse = u.transpose().multiply(u);
+        return sse.getTrace()/(X.getRowDimension()-X.getColumnDimension());
     }
     
     /** TODO:  Find a home for the following methods in the linear package */   
@@ -144,14 +142,19 @@
      * 
      * @param coefficients upper-triangular coefficients matrix
      * @param constants column RHS constants matrix
-     * @return solution matrix as a vector
+     * @return solution matrix as a column matrix
      * 
      */
-    private static RealVector solveUpperTriangular(RealMatrix coefficients, RealVector constants) {
+    private static RealMatrix solveUpperTriangular(RealMatrix coefficients,
+            RealMatrix constants) {
         if (!isUpperTriangular(coefficients, 1E-12)) {
             throw new IllegalArgumentException(
                    "Coefficients is not upper-triangular");
         }
+        if (constants.getColumnDimension() != 1) {
+            throw new IllegalArgumentException(
+                    "Constants not a column matrix.");
+        }
         int length = coefficients.getColumnDimension();
         double x[] = new double[length];
         for (int i = 0; i < length; i++) {
@@ -160,9 +163,9 @@
             for (int j = index + 1; j < length; j++) {
                 sum += coefficients.getEntry(index, j) * x[j];
             }
-            x[index] = (constants.getEntry(index) - sum) / coefficients.getEntry(index, index);
+            x[index] = (constants.getEntry(index, 0) - sum) / coefficients.getEntry(index, index);
         } 
-        return new RealVectorImpl(x);
+        return new RealMatrixImpl(x);
     }
     
     /**

Modified: commons/proper/math/trunk/src/test/org/apache/commons/math/stat/regression/GLSMultipleLinearRegressionTest.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/test/org/apache/commons/math/stat/regression/GLSMultipleLinearRegressionTest.java?rev=728500&r1=728499&r2=728500&view=diff
==============================================================================
--- commons/proper/math/trunk/src/test/org/apache/commons/math/stat/regression/GLSMultipleLinearRegressionTest.java (original)
+++ commons/proper/math/trunk/src/test/org/apache/commons/math/stat/regression/GLSMultipleLinearRegressionTest.java Sun Dec 21 13:04:47 2008
@@ -63,7 +63,7 @@
         createRegression().newSampleData(y, x, null);
     }
     
-    @Test(expected=ArrayIndexOutOfBoundsException.class)
+    @Test(expected=IllegalArgumentException.class)
     public void cannotAddNullCovarianceData() {
         createRegression().newSampleData(new double[]{}, new double[][]{}, null);
     }