You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by td...@apache.org on 2010/08/03 02:12:24 UTC

svn commit: r981711 - in /mahout/trunk: core/src/main/java/org/apache/mahout/classifier/discriminative/ math/src/main/java/org/apache/mahout/math/ math/src/test/java/org/apache/mahout/math/

Author: tdunning
Date: Tue Aug  3 00:12:23 2010
New Revision: 981711

URL: http://svn.apache.org/viewvc?rev=981711&view=rev
Log:
MAHOUT-452 - more aggregates, row and column views

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearTrainer.java
    mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractMatrix.java
    mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java
    mahout/trunk/math/src/main/java/org/apache/mahout/math/Matrix.java
    mahout/trunk/math/src/main/java/org/apache/mahout/math/VectorView.java
    mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractTestVector.java
    mahout/trunk/math/src/test/java/org/apache/mahout/math/MatrixTest.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearTrainer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearTrainer.java?rev=981711&r1=981710&r2=981711&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearTrainer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearTrainer.java Tue Aug  3 00:12:23 2010
@@ -100,6 +100,7 @@ public abstract class LinearTrainer {
           }
         }
       }
+      iteration++;
     }
   }
   

Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractMatrix.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractMatrix.java?rev=981711&r1=981710&r2=981711&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractMatrix.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractMatrix.java Tue Aug  3 00:12:23 2010
@@ -19,9 +19,7 @@ package org.apache.mahout.math;
 
 import com.google.gson.Gson;
 import com.google.gson.GsonBuilder;
-import org.apache.mahout.math.function.BinaryFunction;
-import org.apache.mahout.math.function.PlusMult;
-import org.apache.mahout.math.function.UnaryFunction;
+import org.apache.mahout.math.function.*;
 
 import java.util.HashMap;
 import java.util.Iterator;
@@ -253,6 +251,69 @@ public abstract class AbstractMatrix imp
     return this;
   }
 
+  /**
+   * Evaluates a function on every row of this matrix and gathers the values it returns.
+   *
+   * @param f The function to be applied to a view of each row.
+   * @return A dense vector whose i-th entry is the value of f on row i.
+   */
+  public Vector aggregateRows(VectorFunction f) {
+    int rowCount = numRows();
+    Vector results = new DenseVector(rowCount);
+    for (int i = 0; i < rowCount; i++) {
+      results.set(i, f.apply(viewRow(i)));
+    }
+    return results;
+  }
+
+  /**
+   * Returns a view of a row.  Changes made through the view are applied to this matrix.
+   * @param row  Which row to return.
+   * @return A MatrixVectorView over this matrix that references the desired row.
+   */
+  public Vector viewRow(int row) {
+    return new MatrixVectorView(this, row, 0, 0, 1);
+  }
+
+
+  /**
+   * Returns a view of a column.  Changes made through the view are applied to this matrix.
+   * @param column Which column to return.
+   * @return A MatrixVectorView over this matrix that references the desired column.
+   */
+  public Vector viewColumn(int column) {
+    return new MatrixVectorView(this, 0, column, 1, 0);
+  }
+
+  /**
+   * Evaluates a function on every column of this matrix and gathers the values it returns.
+   * @param f The function to be applied to a view of each column.
+   * @return A dense vector whose i-th entry is the value of f on column i.
+   */
+  public Vector aggregateColumns(VectorFunction f) {
+    int columnCount = numCols();
+    Vector results = new DenseVector(columnCount);
+    for (int i = 0; i < columnCount; i++) {
+      results.set(i, f.apply(viewColumn(i)));
+    }
+    return results;
+  }
+
+  /**
+   * Reduces the whole matrix to a single number: every row is aggregated, then the row results are.
+   *
+   * @param combiner Merges the mapped values inside a row and afterwards merges the per-row results.
+   * @param mapper   A function applied to each element before it is combined.
+   * @return The combined value.
+   */
+  public double aggregate(final BinaryFunction combiner, final UnaryFunction mapper) {
+    VectorFunction rowReducer = new VectorFunction() {
+      public double apply(Vector v) { return v.aggregate(combiner, mapper); }
+    };
+    Vector perRow = aggregateRows(rowReducer);
+    return perRow.aggregate(combiner, Functions.identity);
+  }
+
   public double determinant() {
     int[] card = size();
     int rowSize = card[ROW];

Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java?rev=981711&r1=981710&r2=981711&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java Tue Aug  3 00:12:23 2010
@@ -71,13 +71,22 @@ public abstract class AbstractVector imp
   }
 
   @Override
-  public abstract Vector clone();
+  public Vector clone() {  // copies via super.clone(); the checked exception is rethrown unchecked
+    try {
+      AbstractVector r = (AbstractVector) super.clone();
+      r.size = size;
+      r.lengthSquared = lengthSquared;
+      return r;
+    } catch (CloneNotSupportedException e) {
+      throw new IllegalStateException("Can't happen", e);  // keep the cause in case it ever does
+    }
+  }
 
   public Vector divide(double x) {
     if (x == 1.0) {
-      return clone();
+      return like().assign(this);
     }
-    Vector result = clone();
+    Vector result = like().assign(this);
     Iterator<Element> iter = result.iterateNonZero();
     while (iter.hasNext()) {
       Element element = iter.next();
@@ -94,9 +103,9 @@ public abstract class AbstractVector imp
       return dotSelf();
     }
     double result = 0.0;
-    Iterator<Element> iter = iterateNonZero();
+    Iterator<Vector.Element> iter = iterateNonZero();
     while (iter.hasNext()) {
-      Element element = iter.next();
+      Vector.Element element = iter.next();
       result += element.get() * x.getQuick(element.index());
     }
     return result;
@@ -104,7 +113,7 @@ public abstract class AbstractVector imp
   
   public double dotSelf() {
     double result = 0.0;
-    Iterator<Element> iter = iterateNonZero();
+    Iterator<Vector.Element> iter = iterateNonZero();
     while (iter.hasNext()) {
       double value = iter.next().get();
       result += value * value;
@@ -119,26 +128,17 @@ public abstract class AbstractVector imp
     return getQuick(index);
   }
 
-  public Element getElement(final int index) {
-    return new Element() {
-      public double get() {
-        return AbstractVector.this.get(index);
-      }
-      public int index() {
-        return index;
-      }
-      public void set(double value) {
-        AbstractVector.this.set(index, value);
-      }
-    };
+  public Vector.Element getElement(final int index) {
+    return new LocalElement(index);  // unlike the removed class (get/set), LocalElement uses getQuick/setQuick
   }
 
   public Vector minus(Vector that) {
     if (size != that.size()) {
       throw new CardinalityException(size, that.size());
     }
+
     // TODO: check the numNonDefault elements to further optimize
-    Vector result = this.clone();
+    Vector result = like().assign(this);
     Iterator<Element> iter = that.iterateNonZero();
     while (iter.hasNext()) {
       Element thatElement = iter.next();
@@ -163,7 +163,7 @@ public abstract class AbstractVector imp
     // we can special case certain powers
     if (Double.isInfinite(power)) {
       double val = 0.0;
-      Iterator<Element> iter = this.iterateNonZero();
+      Iterator<Vector.Element> iter = this.iterateNonZero();
       while (iter.hasNext()) {
         val = Math.max(val, Math.abs(iter.next().get()));
       }
@@ -172,7 +172,7 @@ public abstract class AbstractVector imp
       return Math.sqrt(dotSelf());
     } else if (power == 1.0) {
       double val = 0.0;
-      Iterator<Element> iter = this.iterateNonZero();
+      Iterator<Vector.Element> iter = this.iterateNonZero();
       while (iter.hasNext()) {
         val += Math.abs(iter.next().get());
       }
@@ -180,16 +180,16 @@ public abstract class AbstractVector imp
     } else if (power == 0.0) {
       // this is the number of non-zero elements
       double val = 0.0;
-      Iterator<Element> iter = this.iterateNonZero();
+      Iterator<Vector.Element> iter = this.iterateNonZero();
       while (iter.hasNext()) {
         val += iter.next().get() == 0 ? 0 : 1;
       }
       return val;
     } else {
       double val = 0.0;
-      Iterator<Element> iter = this.iterateNonZero();
+      Iterator<Vector.Element> iter = this.iterateNonZero();
       while (iter.hasNext()) {
-        Element element = iter.next();
+        Vector.Element element = iter.next();
         val += Math.pow(element.get(), power);
       }
       return Math.pow(val, 1.0 / power);
@@ -212,7 +212,7 @@ public abstract class AbstractVector imp
       return lengthSquared + v.getLengthSquared() - 2 * this.dot(v);
     }
     Vector randomlyAccessed;
-    Iterator<Element> it;
+    Iterator<Vector.Element> it;
     double d = 0.0;
     if (lengthSquared >= 0.0) {
       it = v.iterateNonZero();
@@ -224,7 +224,7 @@ public abstract class AbstractVector imp
       d += v.getLengthSquared();
     }
     while(it.hasNext()) {
-      Element e = it.next();
+      Vector.Element e = it.next();
       double value = e.get();
       d += value * (value - 2.0 * randomlyAccessed.getQuick(e.index()));
     }
@@ -235,10 +235,10 @@ public abstract class AbstractVector imp
   public double maxValue() {
     double result = Double.NEGATIVE_INFINITY;
     int nonZeroElements = 0;
-    Iterator<Element> iter = this.iterateNonZero();
+    Iterator<Vector.Element> iter = this.iterateNonZero();
     while (iter.hasNext()) {
       nonZeroElements++;
-      Element element = iter.next();
+      Vector.Element element = iter.next();
       result = Math.max(result, element.get());
     }
     if (nonZeroElements < size) {
@@ -251,10 +251,10 @@ public abstract class AbstractVector imp
     int result = -1;
     double max = Double.NEGATIVE_INFINITY;
     int nonZeroElements = 0;
-    Iterator<Element> iter = this.iterateNonZero();
+    Iterator<Vector.Element> iter = this.iterateNonZero();
     while (iter.hasNext()) {
       nonZeroElements++;
-      Element element = iter.next();
+      Vector.Element element = iter.next();
       double tmp = element.get();
       if (tmp > max) {
         max = tmp;
@@ -265,7 +265,7 @@ public abstract class AbstractVector imp
     // unfilled element(0.0) could be the maxValue hence we need to
     // find one of those elements
     if (nonZeroElements < size && max < 0.0) {
-      for (Element element : this) {
+      for (Vector.Element element : this) {
         if (element.get() == 0.0) {
           return element.index();
         }
@@ -277,10 +277,10 @@ public abstract class AbstractVector imp
   public double minValue() {
     double result = Double.POSITIVE_INFINITY;
     int nonZeroElements = 0;
-    Iterator<Element> iter = this.iterateNonZero();
+    Iterator<Vector.Element> iter = this.iterateNonZero();
     while (iter.hasNext()) {
       nonZeroElements++;
-      Element element = iter.next();
+      Vector.Element element = iter.next();
       result = Math.min(result, element.get());
     }
     if (nonZeroElements < size) {
@@ -293,10 +293,10 @@ public abstract class AbstractVector imp
     int result = -1;
     double min = Double.POSITIVE_INFINITY;
     int nonZeroElements = 0;
-    Iterator<Element> iter = this.iterateNonZero();
+    Iterator<Vector.Element> iter = this.iterateNonZero();
     while (iter.hasNext()) {
       nonZeroElements++;
-      Element element = iter.next();
+      Vector.Element element = iter.next();
       double tmp = element.get();
       if (tmp < min) {
         min = tmp;
@@ -307,7 +307,7 @@ public abstract class AbstractVector imp
     // unfilled element(0.0) could be the maxValue hence we need to
     // find one of those elements
     if (nonZeroElements < size && min > 0.0) {
-      for (Element element : this) {
+      for (Vector.Element element : this) {
         if (element.get() == 0.0) {
           return element.index();
         }
@@ -317,10 +317,10 @@ public abstract class AbstractVector imp
   }
 
   public Vector plus(double x) {
+    Vector result = like().assign(this);
     if (x == 0.0) {
-      return clone();
+      return result;
     }
-    Vector result = clone();
     int size = result.size();
     for (int i = 0; i < size; i++) {
       result.setQuick(i, getQuick(i) + x);
@@ -333,30 +333,25 @@ public abstract class AbstractVector imp
       throw new CardinalityException(size, x.size());
     }
 
-    Vector to = this;
-    Vector from = x;
-    // Clone and edit to the sparse one; if both are sparse, add from the more sparse one
-    if (isDense() || (!x.isDense() &&
-        getNumNondefaultElements() < x.getNumNondefaultElements())) {
-      to = x;
-      from = this;
+    // prefer to have this be the denser operand; otherwise swap roles and let x do the addition
+    if (!isDense() && (x.isDense() || x.getNumNondefaultElements() > this.getNumNondefaultElements())) {
+      return x.plus(this);
     }
 
-    //TODO: get smarter about this, if we are adding a dense to a sparse, then we should return a dense
-    Vector result = to.clone();
-    Iterator<Element> iter = from.iterateNonZero();
+    Vector result = like().assign(this);
+    Iterator<Vector.Element> iter = x.iterateNonZero();
     while (iter.hasNext()) {
-      Element e = iter.next();
+      Vector.Element e = iter.next();
       int index = e.index();
-      result.setQuick(index, result.getQuick(index) + e.get());
+      result.setQuick(index, this.getQuick(index) + e.get());
     }
     return result;
   }
 
   public void addTo(Vector v) {
-    Iterator<Element> it = iterateNonZero();
+    Iterator<Vector.Element> it = iterateNonZero();
     while(it.hasNext() ) {
-      Element e = it.next();
+      Vector.Element e = it.next();
       int index = e.index();
       v.setQuick(index, v.getQuick(index) + e.get());
     }
@@ -370,13 +365,13 @@ public abstract class AbstractVector imp
   }
 
   public Vector times(double x) {
+    Vector result = like().assign(this);
     if (x == 1.0) {
-      return clone();
+      return result;
     }
     if (x == 0.0) {
       return like();
     }
-    Vector result = clone();
     Iterator<Element> iter = result.iterateNonZero();
     while (iter.hasNext()) {
       Element element = iter.next();
@@ -399,7 +394,7 @@ public abstract class AbstractVector imp
       from = this;
     }
 
-    Vector result = to.clone();
+    Vector result = to.like().assign(to);
     Iterator<Element> iter = result.iterateNonZero();
     while (iter.hasNext()) {
       Element element = iter.next();
@@ -411,7 +406,7 @@ public abstract class AbstractVector imp
 
   public double zSum() {
     double result = 0.0;
-    Iterator<Element> iter = iterateNonZero();
+    Iterator<Vector.Element> iter = iterateNonZero();
     while (iter.hasNext()) {
       result += iter.next().get();
     }
@@ -447,28 +442,28 @@ public abstract class AbstractVector imp
   }
 
   public Vector assign(BinaryFunction f, double y) {
-    Iterator<Element> it;
+    Iterator<Vector.Element> it;  // non-zero entries only when f(0, y) == 0, otherwise every entry
     if(f.apply(0, y) == 0) {
       it = iterateNonZero();
     } else {
       it = iterator();
     }
     while(it.hasNext()) {
-      Element e = it.next();
+      Vector.Element e = it.next();  // updated in place below with f(e, y)
       e.set(f.apply(e.get(), y));
     }
     return this;
   }
 
   public Vector assign(UnaryFunction function) {
-    Iterator<Element> it;
+    Iterator<Vector.Element> it;
     if(function.apply(0) == 0) {
       it = iterateNonZero();
     } else {
       it = iterator();
     }
     while(it.hasNext()) {
-      Element e = it.next();
+      Vector.Element e = it.next();
       e.set(function.apply(e.get()));
     }
     return this;
@@ -520,9 +515,9 @@ public abstract class AbstractVector imp
   @Override
   public int hashCode() {
     int result = size;
-    Iterator<Element> iter = iterateNonZero();
+    Iterator<Vector.Element> iter = iterateNonZero();
     while (iter.hasNext()) {
-      Element ele = iter.next();
+      Vector.Element ele = iter.next();
       long v = Double.doubleToLongBits(ele.get());
       result += ele.index() * (int) (v ^ (v >>> 32));
     }
@@ -573,4 +568,24 @@ public abstract class AbstractVector imp
     return result.toString();
   }
 
+  /** Element handle bound to this vector: get/set go through getQuick/setQuick at {@code index}. */
+  protected final class LocalElement implements Vector.Element {
+    protected int index;
+
+    LocalElement(int position) {
+      index = position;
+    }
+
+    public double get() {
+      return AbstractVector.this.getQuick(index);
+    }
+
+    public int index() {
+      return index;
+    }
+
+    public void set(double value) {
+      AbstractVector.this.setQuick(index, value);
+    }
+  }
 }

Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/Matrix.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/Matrix.java?rev=981711&r1=981710&r2=981711&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/Matrix.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/Matrix.java Tue Aug  3 00:12:23 2010
@@ -19,6 +19,7 @@ package org.apache.mahout.math;
 
 import org.apache.mahout.math.function.BinaryFunction;
 import org.apache.mahout.math.function.UnaryFunction;
+import org.apache.mahout.math.function.VectorFunction;
 
 import java.util.Map;
 
@@ -93,6 +94,29 @@ public interface Matrix extends Cloneabl
   Matrix assignRow(int row, Vector other);
 
   /**
+   * Collects the results of a function applied to each row of a matrix.
+   * @param f  The function to be applied to each row.
+   * @return  The vector of results.
+   */
+  Vector aggregateRows(VectorFunction f);
+
+  /**
+   * Collects the results of a function applied to each column of a matrix.
+   * @param f  The function to be applied to each column.
+   * @return  The vector of results.
+   */
+  Vector aggregateColumns(VectorFunction f);
+
+  /**
+   * Collects the results of a function applied to each element of a matrix and then
+   * aggregated.
+   * @param combiner  A function that combines the results of the mapper.
+   * @param mapper  A function to apply to each element.
+   * @return  The result.
+   */
+  double aggregate(BinaryFunction combiner, UnaryFunction mapper);
+
+  /**
    * Return the cardinality of the recipient (the maximum number of values)
    *
    * @return an int[2]
@@ -359,4 +383,8 @@ public interface Matrix extends Cloneabl
   // BinaryFunction map);
   // NewMatrix assign(Matrix y, BinaryFunction function, IntArrayList
   // nonZeroIndexes);
+  /** Returns a vector view of the given row; changes to the view affect this matrix. */
+  Vector viewRow(int row);
+  /** Returns a vector view of the given column; changes to the view affect this matrix. */
+  Vector viewColumn(int column);
 }

Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/VectorView.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/VectorView.java?rev=981711&r1=981710&r2=981711&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/VectorView.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/VectorView.java Tue Aug  3 00:12:23 2010
@@ -46,7 +46,10 @@ public class VectorView extends Abstract
 
   @Override
   public Vector clone() {
-    return new VectorView(vector.clone(), offset, size());
+    VectorView r = (VectorView) super.clone();  // start from the copy produced by the superclass clone()
+    r.vector = vector.clone();  // the copy gets its own clone of the underlying vector
+    r.offset = offset;
+    return r;
   }
 
   public boolean isDense() {

Modified: mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractTestVector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractTestVector.java?rev=981711&r1=981710&r2=981711&view=diff
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractTestVector.java (original)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractTestVector.java Tue Aug  3 00:12:23 2010
@@ -91,7 +91,7 @@ public abstract class AbstractTestVector
     }
   }
 
-  public void testIteratorSet() {
+  public void testIteratorSet() throws CloneNotSupportedException {
     Vector clone = test.clone();
     Iterator<Vector.Element> it = clone.iterateNonZero();
     while (it.hasNext()) {
@@ -219,7 +219,7 @@ public abstract class AbstractTestVector
     assertEquals("dot", expected, res, EPSILON);
   }
 
-  public void testDot2() {
+  public void testDot2() throws CloneNotSupportedException {
     Vector test2 = test.clone();
     test2.set(1, 0.0);
     test2.set(3, 0.0);

Modified: mahout/trunk/math/src/test/java/org/apache/mahout/math/MatrixTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/MatrixTest.java?rev=981711&r1=981710&r2=981711&view=diff
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/MatrixTest.java (original)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/MatrixTest.java Tue Aug  3 00:12:23 2010
@@ -17,11 +17,15 @@
 
 package org.apache.mahout.math;
 
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.function.VectorFunction;
+
 import static org.apache.mahout.math.function.Functions.*;
 
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.Map;
+import java.util.Random;
 
 public abstract class MatrixTest extends MahoutTestCase {
 
@@ -239,6 +243,85 @@ public abstract class MatrixTest extends
     }
   }
 
+  public void testRowView() {
+    int[] dims = test.size();
+    // every row view must match the row returned by getRow
+    for (int r = 0; r < dims[ROW]; r++) {
+      assertEquals(0.0, test.getRow(r).minus(test.viewRow(r)).norm(1), 0);
+    }
+    assertEquals(dims[COL], test.viewRow(3).size());
+    assertEquals(dims[COL], test.viewRow(5).size());
+    // a write through a view must show up in the matrix; the old value is then restored
+    Random rng = new Random(1);
+    for (int r = 0; r < dims[ROW]; r++) {
+      int k = rng.nextInt(dims[COL]);
+      double saved = test.get(r, k);
+      double fresh = rng.nextGaussian();
+      test.viewRow(r).set(k, fresh);
+      assertEquals(fresh, test.get(r, k), 0);
+      assertEquals(fresh, test.viewRow(r).get(k), 0);
+      test.set(r, k, saved);
+      assertEquals(saved, test.get(r, k), 0);
+      assertEquals(saved, test.viewRow(r).get(k), 0);
+    }
+  }
+
+  public void testColumnView() {
+    int[] c = test.size();
+    for (int col = 0; col < c[COL]; col++) {
+      assertEquals(0.0, test.getColumn(col).minus(test.viewColumn(col)).norm(1), 0);
+    }
+    // a column view has one entry per matrix row
+    assertEquals(c[ROW], test.viewColumn(3).size());
+    assertEquals(c[ROW], test.viewColumn(5).size());
+    // a write through a view must reach matrix cell (j, col); the old value is then restored
+    Random gen = new Random(1);
+    for (int col = 0; col < c[COL]; col++) {
+      int j = gen.nextInt(c[ROW]);  // j indexes into a column, so it is bounded by the row count
+      double old = test.get(j, col);  // save the same cell that is overwritten and restored below
+      double v = gen.nextGaussian();
+      test.viewColumn(col).set(j, v);
+      assertEquals(v, test.get(j, col), 0);
+      assertEquals(v, test.viewColumn(col).get(j), 0);
+      test.set(j, col, old);
+      assertEquals(old, test.get(j, col), 0);
+      assertEquals(old, test.viewColumn(col).get(j), 0);
+    }
+  }
+
+  public void testAggregateRows() {
+    VectorFunction sumOfEntries = new VectorFunction() {
+      public double apply(Vector row) {
+        return row.zSum();
+      }
+    };
+    Vector rowSums = test.aggregateRows(sumOfEntries);
+    for (int i = 0; i < test.numRows(); i++) {  // compare against sums taken via getRow
+      assertEquals(test.getRow(i).zSum(), rowSums.get(i));
+    }
+  }
+
+  public void testAggregateCols() {
+    VectorFunction sumOfEntries = new VectorFunction() {
+      public double apply(Vector column) {
+        return column.zSum();
+      }
+    };
+    Vector columnSums = test.aggregateColumns(sumOfEntries);
+    for (int i = 0; i < test.numCols(); i++) {  // compare against sums taken via getColumn
+      assertEquals(test.getColumn(i).zSum(), columnSums.get(i));
+    }
+  }
+
+  public void testAggregate() {
+    double grandTotal = test.aggregate(Functions.plus, Functions.identity);
+    VectorFunction rowSum = new VectorFunction() {
+      public double apply(Vector row) { return row.zSum(); }
+    };
+    // the element-wise total must equal the sum of the per-row sums
+    assertEquals(test.aggregateRows(rowSum).zSum(), grandTotal);
+  }
+
   public void testDivide() {
     int[] c = test.size();
     Matrix value = test.divide(4.53);