You are viewing a plain-text version of this content. The canonical link to the original archived message was not preserved in this copy.
Posted to commits@mahout.apache.org by td...@apache.org on 2010/08/03 02:12:24 UTC
svn commit: r981711 - in /mahout/trunk:
core/src/main/java/org/apache/mahout/classifier/discriminative/
math/src/main/java/org/apache/mahout/math/
math/src/test/java/org/apache/mahout/math/
Author: tdunning
Date: Tue Aug 3 00:12:23 2010
New Revision: 981711
URL: http://svn.apache.org/viewvc?rev=981711&view=rev
Log:
MAHOUT-452 - more aggregates, row and column views
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearTrainer.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractMatrix.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/Matrix.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/VectorView.java
mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractTestVector.java
mahout/trunk/math/src/test/java/org/apache/mahout/math/MatrixTest.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearTrainer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearTrainer.java?rev=981711&r1=981710&r2=981711&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearTrainer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/discriminative/LinearTrainer.java Tue Aug 3 00:12:23 2010
@@ -100,6 +100,7 @@ public abstract class LinearTrainer {
}
}
}
+ iteration++;
}
}
Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractMatrix.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractMatrix.java?rev=981711&r1=981710&r2=981711&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractMatrix.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractMatrix.java Tue Aug 3 00:12:23 2010
@@ -19,9 +19,7 @@ package org.apache.mahout.math;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
-import org.apache.mahout.math.function.BinaryFunction;
-import org.apache.mahout.math.function.PlusMult;
-import org.apache.mahout.math.function.UnaryFunction;
+import org.apache.mahout.math.function.*;
import java.util.HashMap;
import java.util.Iterator;
@@ -253,6 +251,69 @@ public abstract class AbstractMatrix imp
return this;
}
+ /**
+ * Applies {@code f} to a view of every row and gathers the scalar results.
+ *
+ * @param f Function evaluated once per row.
+ * @return A dense vector whose i-th entry is {@code f} applied to row i.
+ */
+ public Vector aggregateRows(VectorFunction f) {
+ int rowCount = numRows();
+ Vector results = new DenseVector(rowCount);
+ for (int i = 0; i < rowCount; i++) {
+ results.set(i, f.apply(viewRow(i)));
+ }
+ return results;
+ }
+
+ /**
+ * Returns a view of a row. Changes to the view will affect the original.
+ * @param row Which row to return.
+ * @return A vector that references the desired row.
+ */
+ public Vector viewRow(int row) {
+ return new MatrixVectorView(this, row, 0, 0, 1); // NOTE(review): args assumed (row0, col0, rowStride, colStride)
+ }
+
+ /**
+ * Returns a view of a column. Changes to the view will affect the original.
+ * This is the column counterpart of {@link #viewRow(int)}.
+ * @param column Which column to return.
+ * @return A vector that references the desired column.
+ */
+ public Vector viewColumn(int column) {
+ return new MatrixVectorView(this, 0, column, 1, 0); // NOTE(review): same assumed arg order as viewRow
+ }
+
+ /**
+ * Applies {@code f} to a view of every column and gathers the scalar results.
+ *
+ * @param f Function evaluated once per column.
+ * @return A dense vector whose j-th entry is {@code f} applied to column j.
+ */
+ public Vector aggregateColumns(VectorFunction f) {
+ Vector results = new DenseVector(numCols());
+ for (int j = 0, colCount = numCols(); j < colCount; j++) {
+ results.set(j, f.apply(viewColumn(j)));
+ }
+ return results;
+ }
+
+ /**
+ * Maps every element with {@code mapper} and folds the mapped values with {@code combiner}.
+ * Each row is reduced first and the per-row results are then combined, so {@code combiner}
+ * should be associative for the result to match a single element-by-element fold.
+ * @param combiner A function that combines the results of the mapper.
+ * @param mapper A function to apply to each element.
+ * @return The combined value.
+ */
+ public double aggregate(final BinaryFunction combiner, final UnaryFunction mapper) {
+ Vector perRow = aggregateRows(new VectorFunction() {
+ public double apply(Vector row) { return row.aggregate(combiner, mapper); }
+ });
+ return perRow.aggregate(combiner, Functions.identity);
+ }
+
public double determinant() {
int[] card = size();
int rowSize = card[ROW];
Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java?rev=981711&r1=981710&r2=981711&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java Tue Aug 3 00:12:23 2010
@@ -71,13 +71,22 @@ public abstract class AbstractVector imp
}
@Override
- public abstract Vector clone();
+ public Vector clone() {
+ try {
+ // Object.clone() already performs a shallow copy of every field, including size and
+ // lengthSquared; subclasses owning mutable storage must deep-copy it themselves.
+ return (AbstractVector) super.clone();
+ } catch (CloneNotSupportedException e) {
+ // only reachable if this class is not Cloneable; keep the cause for diagnosis
+ throw new IllegalStateException("Can't happen", e);
+ }
+ }
public Vector divide(double x) {
if (x == 1.0) {
- return clone();
+ return like().assign(this);
}
- Vector result = clone();
+ Vector result = like().assign(this);
Iterator<Element> iter = result.iterateNonZero();
while (iter.hasNext()) {
Element element = iter.next();
@@ -94,9 +103,9 @@ public abstract class AbstractVector imp
return dotSelf();
}
double result = 0.0;
- Iterator<Element> iter = iterateNonZero();
+ Iterator<Vector.Element> iter = iterateNonZero();
while (iter.hasNext()) {
- Element element = iter.next();
+ Vector.Element element = iter.next();
result += element.get() * x.getQuick(element.index());
}
return result;
@@ -104,7 +113,7 @@ public abstract class AbstractVector imp
public double dotSelf() {
double result = 0.0;
- Iterator<Element> iter = iterateNonZero();
+ Iterator<Vector.Element> iter = iterateNonZero();
while (iter.hasNext()) {
double value = iter.next().get();
result += value * value;
@@ -119,26 +128,17 @@ public abstract class AbstractVector imp
return getQuick(index);
}
- public Element getElement(final int index) {
- return new Element() {
- public double get() {
- return AbstractVector.this.get(index);
- }
- public int index() {
- return index;
- }
- public void set(double value) {
- AbstractVector.this.set(index, value);
- }
- };
+ public Vector.Element getElement(int index) { // note: index is not bounds-checked here
+ return new LocalElement(index); // element reads/writes go through getQuick/setQuick
 }
public Vector minus(Vector that) {
if (size != that.size()) {
throw new CardinalityException(size, that.size());
}
+
// TODO: check the numNonDefault elements to further optimize
- Vector result = this.clone();
+ Vector result = like().assign(this);
Iterator<Element> iter = that.iterateNonZero();
while (iter.hasNext()) {
Element thatElement = iter.next();
@@ -163,7 +163,7 @@ public abstract class AbstractVector imp
// we can special case certain powers
if (Double.isInfinite(power)) {
double val = 0.0;
- Iterator<Element> iter = this.iterateNonZero();
+ Iterator<Vector.Element> iter = this.iterateNonZero();
while (iter.hasNext()) {
val = Math.max(val, Math.abs(iter.next().get()));
}
@@ -172,7 +172,7 @@ public abstract class AbstractVector imp
return Math.sqrt(dotSelf());
} else if (power == 1.0) {
double val = 0.0;
- Iterator<Element> iter = this.iterateNonZero();
+ Iterator<Vector.Element> iter = this.iterateNonZero();
while (iter.hasNext()) {
val += Math.abs(iter.next().get());
}
@@ -180,16 +180,16 @@ public abstract class AbstractVector imp
} else if (power == 0.0) {
// this is the number of non-zero elements
double val = 0.0;
- Iterator<Element> iter = this.iterateNonZero();
+ Iterator<Vector.Element> iter = this.iterateNonZero();
while (iter.hasNext()) {
val += iter.next().get() == 0 ? 0 : 1;
}
return val;
} else {
double val = 0.0;
- Iterator<Element> iter = this.iterateNonZero();
+ Iterator<Vector.Element> iter = this.iterateNonZero();
while (iter.hasNext()) {
- Element element = iter.next();
+ Vector.Element element = iter.next();
val += Math.pow(element.get(), power);
}
return Math.pow(val, 1.0 / power);
@@ -212,7 +212,7 @@ public abstract class AbstractVector imp
return lengthSquared + v.getLengthSquared() - 2 * this.dot(v);
}
Vector randomlyAccessed;
- Iterator<Element> it;
+ Iterator<Vector.Element> it;
double d = 0.0;
if (lengthSquared >= 0.0) {
it = v.iterateNonZero();
@@ -224,7 +224,7 @@ public abstract class AbstractVector imp
d += v.getLengthSquared();
}
while(it.hasNext()) {
- Element e = it.next();
+ Vector.Element e = it.next();
double value = e.get();
d += value * (value - 2.0 * randomlyAccessed.getQuick(e.index()));
}
@@ -235,10 +235,10 @@ public abstract class AbstractVector imp
public double maxValue() {
double result = Double.NEGATIVE_INFINITY;
int nonZeroElements = 0;
- Iterator<Element> iter = this.iterateNonZero();
+ Iterator<Vector.Element> iter = this.iterateNonZero();
while (iter.hasNext()) {
nonZeroElements++;
- Element element = iter.next();
+ Vector.Element element = iter.next();
result = Math.max(result, element.get());
}
if (nonZeroElements < size) {
@@ -251,10 +251,10 @@ public abstract class AbstractVector imp
int result = -1;
double max = Double.NEGATIVE_INFINITY;
int nonZeroElements = 0;
- Iterator<Element> iter = this.iterateNonZero();
+ Iterator<Vector.Element> iter = this.iterateNonZero();
while (iter.hasNext()) {
nonZeroElements++;
- Element element = iter.next();
+ Vector.Element element = iter.next();
double tmp = element.get();
if (tmp > max) {
max = tmp;
@@ -265,7 +265,7 @@ public abstract class AbstractVector imp
// unfilled element(0.0) could be the maxValue hence we need to
// find one of those elements
if (nonZeroElements < size && max < 0.0) {
- for (Element element : this) {
+ for (Vector.Element element : this) {
if (element.get() == 0.0) {
return element.index();
}
@@ -277,10 +277,10 @@ public abstract class AbstractVector imp
public double minValue() {
double result = Double.POSITIVE_INFINITY;
int nonZeroElements = 0;
- Iterator<Element> iter = this.iterateNonZero();
+ Iterator<Vector.Element> iter = this.iterateNonZero();
while (iter.hasNext()) {
nonZeroElements++;
- Element element = iter.next();
+ Vector.Element element = iter.next();
result = Math.min(result, element.get());
}
if (nonZeroElements < size) {
@@ -293,10 +293,10 @@ public abstract class AbstractVector imp
int result = -1;
double min = Double.POSITIVE_INFINITY;
int nonZeroElements = 0;
- Iterator<Element> iter = this.iterateNonZero();
+ Iterator<Vector.Element> iter = this.iterateNonZero();
while (iter.hasNext()) {
nonZeroElements++;
- Element element = iter.next();
+ Vector.Element element = iter.next();
double tmp = element.get();
if (tmp < min) {
min = tmp;
@@ -307,7 +307,7 @@ public abstract class AbstractVector imp
// unfilled element(0.0) could be the maxValue hence we need to
// find one of those elements
if (nonZeroElements < size && min > 0.0) {
- for (Element element : this) {
+ for (Vector.Element element : this) {
if (element.get() == 0.0) {
return element.index();
}
@@ -317,10 +317,10 @@ public abstract class AbstractVector imp
}
public Vector plus(double x) {
+ Vector result = like().assign(this);
if (x == 0.0) {
- return clone();
+ return result;
}
- Vector result = clone();
int size = result.size();
for (int i = 0; i < size; i++) {
result.setQuick(i, getQuick(i) + x);
@@ -333,30 +333,25 @@ public abstract class AbstractVector imp
throw new CardinalityException(size, x.size());
}
- Vector to = this;
- Vector from = x;
- // Clone and edit to the sparse one; if both are sparse, add from the more sparse one
- if (isDense() || (!x.isDense() &&
- getNumNondefaultElements() < x.getNumNondefaultElements())) {
- to = x;
- from = this;
+ // prefer to have this be the denser than x
+ if (!isDense() && (x.isDense() || x.getNumNondefaultElements() > this.getNumNondefaultElements())) {
+ return x.plus(this);
}
- //TODO: get smarter about this, if we are adding a dense to a sparse, then we should return a dense
- Vector result = to.clone();
- Iterator<Element> iter = from.iterateNonZero();
+ Vector result = like().assign(this);
+ Iterator<Vector.Element> iter = x.iterateNonZero();
while (iter.hasNext()) {
- Element e = iter.next();
+ Vector.Element e = iter.next();
int index = e.index();
- result.setQuick(index, result.getQuick(index) + e.get());
+ result.setQuick(index, this.getQuick(index) + e.get());
}
return result;
}
public void addTo(Vector v) {
- Iterator<Element> it = iterateNonZero();
+ Iterator<Vector.Element> it = iterateNonZero();
while(it.hasNext() ) {
- Element e = it.next();
+ Vector.Element e = it.next();
int index = e.index();
v.setQuick(index, v.getQuick(index) + e.get());
}
@@ -370,13 +365,13 @@ public abstract class AbstractVector imp
}
public Vector times(double x) {
+ Vector result = like().assign(this);
if (x == 1.0) {
- return clone();
+ return result;
}
if (x == 0.0) {
return like();
}
- Vector result = clone();
Iterator<Element> iter = result.iterateNonZero();
while (iter.hasNext()) {
Element element = iter.next();
@@ -399,7 +394,7 @@ public abstract class AbstractVector imp
from = this;
}
- Vector result = to.clone();
+ Vector result = to.like().assign(to);
Iterator<Element> iter = result.iterateNonZero();
while (iter.hasNext()) {
Element element = iter.next();
@@ -411,7 +406,7 @@ public abstract class AbstractVector imp
public double zSum() {
double result = 0.0;
- Iterator<Element> iter = iterateNonZero();
+ Iterator<Vector.Element> iter = iterateNonZero();
while (iter.hasNext()) {
result += iter.next().get();
}
@@ -447,28 +442,28 @@ public abstract class AbstractVector imp
}
public Vector assign(BinaryFunction f, double y) {
- Iterator<Element> it;
+ Iterator<Vector.Element> it;
if(f.apply(0, y) == 0) {
it = iterateNonZero();
} else {
it = iterator();
}
while(it.hasNext()) {
- Element e = it.next();
+ Vector.Element e = it.next();
e.set(f.apply(e.get(), y));
}
return this;
}
public Vector assign(UnaryFunction function) {
- Iterator<Element> it;
+ Iterator<Vector.Element> it;
if(function.apply(0) == 0) {
it = iterateNonZero();
} else {
it = iterator();
}
while(it.hasNext()) {
- Element e = it.next();
+ Vector.Element e = it.next();
e.set(function.apply(e.get()));
}
return this;
@@ -520,9 +515,9 @@ public abstract class AbstractVector imp
@Override
public int hashCode() {
int result = size;
- Iterator<Element> iter = iterateNonZero();
+ Iterator<Vector.Element> iter = iterateNonZero();
while (iter.hasNext()) {
- Element ele = iter.next();
+ Vector.Element ele = iter.next();
long v = Double.doubleToLongBits(ele.get());
result += ele.index() * (int) (v ^ (v >>> 32));
}
@@ -573,4 +568,24 @@ public abstract class AbstractVector imp
return result.toString();
}
+
+ protected final class LocalElement implements Vector.Element {
+ protected final int index; // fixed position in the enclosing vector; never reassigned
+
+ LocalElement(int index) {
+ this.index = index;
+ }
+
+ public double get() {
+ return getQuick(index); // delegates to the enclosing vector
+ }
+
+ public int index() {
+ return index;
+ }
+
+ public void set(double value) {
+ setQuick(index, value); // writes straight through to the enclosing vector
+ }
+ }
}
Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/Matrix.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/Matrix.java?rev=981711&r1=981710&r2=981711&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/Matrix.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/Matrix.java Tue Aug 3 00:12:23 2010
@@ -19,6 +19,7 @@ package org.apache.mahout.math;
import org.apache.mahout.math.function.BinaryFunction;
import org.apache.mahout.math.function.UnaryFunction;
+import org.apache.mahout.math.function.VectorFunction;
import java.util.Map;
@@ -93,6 +94,29 @@ public interface Matrix extends Cloneabl
Matrix assignRow(int row, Vector other);
/**
+ * Collects the results of a function applied to each row of a matrix.
+ * @param f The function to be applied to each row.
+ * @return A vector holding one result per row, indexed by row number.
+ */
+ Vector aggregateRows(VectorFunction f);
+
+ /**
+ * Collects the results of a function applied to each column of a matrix.
+ * @param f The function to be applied to each column.
+ * @return A vector holding one result per column, indexed by column number.
+ */
+ Vector aggregateColumns(VectorFunction f);
+
+ /**
+ * Applies {@code mapper} to each element of the matrix and folds the mapped values with
+ * {@code combiner}; AbstractMatrix combines per-row results, so assume an associative combiner.
+ * @param combiner A function that combines the results of the mapper.
+ * @param mapper A function to apply to each element.
+ * @return The combined value.
+ */
+ double aggregate(BinaryFunction combiner, UnaryFunction mapper);
+
+ /**
* Return the cardinality of the recipient (the maximum number of values)
*
* @return an int[2]
@@ -359,4 +383,8 @@ public interface Matrix extends Cloneabl
// BinaryFunction map);
// NewMatrix assign(Matrix y, BinaryFunction function, IntArrayList
// nonZeroIndexes);
+ /** Returns a view of the given row; changes made through the view affect this matrix. */
+ Vector viewRow(int row);
+ /** Returns a view of the given column; changes made through the view affect this matrix. */
+ Vector viewColumn(int column);
}
Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/VectorView.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/VectorView.java?rev=981711&r1=981710&r2=981711&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/VectorView.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/VectorView.java Tue Aug 3 00:12:23 2010
@@ -46,7 +46,10 @@ public class VectorView extends Abstract
@Override
public Vector clone() {
- return new VectorView(vector.clone(), offset, size());
+ VectorView r = (VectorView) super.clone();
+ r.vector = vector.clone();
+ r.offset = offset;
+ return r;
}
public boolean isDense() {
Modified: mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractTestVector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractTestVector.java?rev=981711&r1=981710&r2=981711&view=diff
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractTestVector.java (original)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractTestVector.java Tue Aug 3 00:12:23 2010
@@ -91,7 +91,7 @@ public abstract class AbstractTestVector
}
}
- public void testIteratorSet() {
+ public void testIteratorSet() throws CloneNotSupportedException {
Vector clone = test.clone();
Iterator<Vector.Element> it = clone.iterateNonZero();
while (it.hasNext()) {
@@ -219,7 +219,7 @@ public abstract class AbstractTestVector
assertEquals("dot", expected, res, EPSILON);
}
- public void testDot2() {
+ public void testDot2() throws CloneNotSupportedException {
Vector test2 = test.clone();
test2.set(1, 0.0);
test2.set(3, 0.0);
Modified: mahout/trunk/math/src/test/java/org/apache/mahout/math/MatrixTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/MatrixTest.java?rev=981711&r1=981710&r2=981711&view=diff
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/MatrixTest.java (original)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/MatrixTest.java Tue Aug 3 00:12:23 2010
@@ -17,11 +17,15 @@
package org.apache.mahout.math;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.function.VectorFunction;
+
import static org.apache.mahout.math.function.Functions.*;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
+import java.util.Random;
public abstract class MatrixTest extends MahoutTestCase {
@@ -239,6 +243,85 @@ public abstract class MatrixTest extends
}
}
+ public void testRowView() {
+ int[] dims = test.size();
+ // a row view must hold exactly the same values as a copied row
+ for (int r = 0; r < dims[ROW]; r++) {
+ assertEquals(0.0, test.getRow(r).minus(test.viewRow(r)).norm(1), 0);
+ }
+ assertEquals(dims[COL], test.viewRow(3).size());
+ assertEquals(dims[COL], test.viewRow(5).size());
+ // writes through a view reach the matrix, and matrix writes show up in fresh views
+ Random rand = new Random(1);
+ for (int r = 0; r < dims[ROW]; r++) {
+ int col = rand.nextInt(dims[COL]);
+ double saved = test.get(r, col);
+ double fresh = rand.nextGaussian();
+ test.viewRow(r).set(col, fresh);
+ assertEquals(fresh, test.get(r, col), 0);
+ assertEquals(fresh, test.viewRow(r).get(col), 0);
+ test.set(r, col, saved);
+ assertEquals(saved, test.get(r, col), 0);
+ assertEquals(saved, test.viewRow(r).get(col), 0);
+ }
+ }
+
+ public void testColumnView() {
+ int[] c = test.size();
+ for (int col = 0; col < c[COL]; col++) {
+ assertEquals(0.0, test.getColumn(col).minus(test.viewColumn(col)).norm(1), 0);
+ }
+
+ assertEquals(c[ROW], test.viewColumn(3).size());
+ assertEquals(c[ROW], test.viewColumn(5).size());
+ // j indexes rows within a column, so it must range over c[ROW] and be read as get(j, col)
+ Random gen = new Random(1);
+ for (int col = 0; col < c[COL]; col++) {
+ int j = gen.nextInt(c[ROW]);
+ double old = test.get(j, col);
+ double v = gen.nextGaussian();
+ test.viewColumn(col).set(j, v);
+ assertEquals(v, test.get(j, col), 0);
+ assertEquals(v, test.viewColumn(col).get(j), 0);
+ test.set(j, col, old);
+ assertEquals(old, test.get(j, col), 0);
+ assertEquals(old, test.viewColumn(col).get(j), 0);
+ }
+ }
+
+ public void testAggregateRows() {
+ Vector v = test.aggregateRows(new VectorFunction() {
+ public double apply(Vector row) {
+ return row.zSum();
+ }
+ });
+ // compare with a tolerance: the two sums may accumulate in different orders
+ for (int i = 0; i < test.numRows(); i++) {
+ assertEquals(test.getRow(i).zSum(), v.get(i), 1.0e-9);
+ }
+ }
+
+ public void testAggregateCols() {
+ Vector v = test.aggregateColumns(new VectorFunction() {
+ public double apply(Vector column) {
+ return column.zSum();
+ }
+ });
+ // compare with a tolerance: the two sums may accumulate in different orders
+ for (int i = 0; i < test.numCols(); i++) {
+ assertEquals(test.getColumn(i).zSum(), v.get(i), 1.0e-9);
+ }
+ }
+
+ public void testAggregate() {
+ double total = test.aggregate(Functions.plus, Functions.identity);
+ double expected = test.aggregateRows(new VectorFunction() {
+ public double apply(Vector row) { return row.zSum(); }
+ }).zSum();
+ // tolerance: the element-wise fold and zSum may add values in different orders
+ assertEquals(expected, total, 1.0e-9);
+ }
+
public void testDivide() {
int[] c = test.size();
Matrix value = test.divide(4.53);