You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by gs...@apache.org on 2009/11/22 17:13:10 UTC
svn commit: r883094 - in /lucene/mahout/trunk/core/src:
main/java/org/apache/mahout/matrix/ test/java/org/apache/mahout/matrix/
Author: gsingers
Date: Sun Nov 22 16:13:10 2009
New Revision: 883094
URL: http://svn.apache.org/viewvc?rev=883094&view=rev
Log:
MAHOUT-182: added helper methods on Matrix for times, numRows, etc.
Added:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/PlusWithScaleFunction.java (with props)
Modified:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractMatrix.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Matrix.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseMatrix.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/MatrixTest.java
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractMatrix.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractMatrix.java?rev=883094&r1=883093&r2=883094&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractMatrix.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractMatrix.java Sun Nov 22 16:13:10 2009
@@ -125,6 +125,14 @@
// index into int[2] for row value
public static final int ROW = 0;
+
+ /** Returns the row entry of {@code size()}, i.e. the cardinality of the row dimension. */
+ public int numRows() {
+   final int[] dims = size();
+   return dims[ROW];
+ }
+
+ /** Returns the column entry of {@code size()}, i.e. the cardinality of the column dimension. */
+ public int numCols() {
+   final int[] dims = size();
+   return dims[COL];
+ }
public static Matrix decodeMatrix(String formatString) {
Type vectorType = new TypeToken<Vector>() {
@@ -399,6 +407,31 @@
return result;
}
+ /**
+  * Multiplies this matrix by a vector: entry {@code i} of the result is the
+  * dot product of {@code v} with row {@code i}.
+  *
+  * @param v the vector to multiply by; its size must equal the column count
+  * @return a new DenseVector with one entry per row of this matrix
+  * @throws CardinalityException if the column count differs from {@code v.size()}
+  */
+ public Vector times(Vector v) {
+   final int[] dims = size();
+   final int columnCount = dims[COL];
+   if (columnCount != v.size()) {
+     throw new CardinalityException();
+   }
+   final int rowCount = dims[ROW];
+   final Vector product = new DenseVector(rowCount);
+   int row = 0;
+   while (row < rowCount) {
+     final double entry = v.dot(getRow(row));
+     product.setQuick(row, entry);
+     row++;
+   }
+   return product;
+ }
+
+ /**
+  * Single-pass counterpart of {@code transpose().times(times(v))}: every row
+  * is visited once and folded into the result, weighted by its dot product
+  * with {@code v}.
+  *
+  * @param v the vector to multiply by; its size must equal the column count
+  * @return a new DenseVector with one entry per column of this matrix
+  * @throws CardinalityException if the column count differs from {@code v.size()}
+  */
+ public Vector timesSquared(Vector v) {
+   final int[] dims = size();
+   final int columnCount = dims[COL];
+   if (columnCount != v.size()) {
+     throw new CardinalityException();
+   }
+   final Vector accumulated = new DenseVector(columnCount);
+   final int rowCount = dims[ROW];
+   for (int row = 0; row < rowCount; row++) {
+     final Vector rowVector = getRow(row);
+     final double weight = rowVector.dot(v);
+     // NOTE(review): assumes assign(other, f) stores f(this[k], other[k]) elementwise,
+     // which makes this step accumulated += weight * rowVector -- confirm against Vector.assign.
+     accumulated.assign(rowVector, new PlusWithScaleFunction(weight));
+   }
+   return accumulated;
+ }
+
@Override
public Matrix transpose() {
int[] card = size();
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Matrix.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Matrix.java?rev=883094&r1=883093&r2=883094&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Matrix.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Matrix.java Sun Nov 22 16:13:10 2009
@@ -99,6 +99,20 @@
int[] size();
/**
+ * Helper method to return the cardinality of the row dimension
+ *
+ * @return the number of rows in this matrix
+ */
+ int numRows();
+
+ /**
+ * Helper method to return the cardinality of the column dimension
+ *
+ * @return the number of columns in this matrix
+ */
+ int numCols();
+
+ /**
* Return a copy of the recipient
*
* @return a new Matrix
@@ -235,7 +249,8 @@
int[] getNumNondefaultElements();
/**
- * Return a new matrix containing the product of each value of the recipient and the argument
+ * Return a new matrix containing the product of each value of the recipient
+ * and the argument
*
* @param x a double argument
* @return a new Matrix
@@ -243,13 +258,36 @@
Matrix times(double x);
/**
- * Return a new matrix containing the product of the recipient and the argument
+ * Return a new matrix containing the product of the recipient and
+ * the argument
*
* @param x a Matrix argument
* @return a new Matrix
* @throws CardinalityException if the cardinalities are incompatible
*/
Matrix times(Matrix x);
+
+ /**
+ * Return a new vector with cardinality equal to numRows() of this
+ * matrix which is the matrix product of the recipient and the argument
+ *
+ * @param v a vector with cardinality equal to numCols() of the recipient
+ * @return a new vector (typically a DenseVector)
+ * @throws CardinalityException if this.numCols() != v.size()
+ */
+ Vector times(Vector v);
+
+ /**
+ * Convenience method for producing this.transpose().times(this.times(v)),
+ * which can be implemented with only one pass over the matrix, without
+ * making the transpose() call (which can be expensive if the matrix is sparse)
+ *
+ * @param v a vector with cardinality equal to numCols() of the recipient
+ * @return a new vector (typically a DenseVector) with cardinality equal to
+ * that of the argument.
+ * @throws CardinalityException if this.numCols() != v.size()
+ */
+ Vector timesSquared(Vector v);
/**
* Return a new matrix that is the transpose of the receiver
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/PlusWithScaleFunction.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/PlusWithScaleFunction.java?rev=883094&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/PlusWithScaleFunction.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/PlusWithScaleFunction.java Sun Nov 22 16:13:10 2009
@@ -0,0 +1,31 @@
+package org.apache.mahout.matrix;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * <p/>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p/>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Binary function that adds a scaled copy of its second argument to its
+ * first: {@code apply(a, b) = a + scale * b}.
+ */
+public class PlusWithScaleFunction implements BinaryFunction {
+
+  /** Multiplier applied to the second argument before the addition. */
+  private final double factor;
+
+  /**
+   * @param scale the multiplier applied to the second argument of {@link #apply}
+   */
+  public PlusWithScaleFunction(final double scale) {
+    this.factor = scale;
+  }
+
+  /**
+   * @return {@code arg1 + scale * arg2}
+   */
+  @Override
+  public double apply(double arg1, double arg2) {
+    final double scaled = factor * arg2;
+    return arg1 + scaled;
+  }
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/PlusWithScaleFunction.java
------------------------------------------------------------------------------
svn:eol-style = native
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseMatrix.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseMatrix.java?rev=883094&r1=883093&r2=883094&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseMatrix.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseMatrix.java Sun Nov 22 16:13:10 2009
@@ -102,11 +102,10 @@
@Override
public void setQuick(int row, int column, double value) {
- Integer rowKey = row;
- Vector r = rows.get(rowKey);
+ Vector r = rows.get(row);
if (r == null) {
r = new SparseVector(cardinality[COL]);
- rows.put(rowKey, r);
+ rows.put(row, r);
}
r.setQuick(column, value);
}
@@ -142,11 +141,10 @@
for (int row = 0; row < cardinality[ROW]; row++) {
double val = other.getQuick(row);
if (val != 0.0) {
- Integer rowKey = row;
- Vector r = rows.get(rowKey);
+ Vector r = rows.get(row);
if (r == null) {
r = new SparseVector(cardinality[ROW]);
- rows.put(rowKey, r);
+ rows.put(row, r);
}
r.setQuick(column, val);
}
Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/MatrixTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/MatrixTest.java?rev=883094&r1=883093&r2=883094&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/MatrixTest.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/MatrixTest.java Sun Nov 22 16:13:10 2009
@@ -35,6 +35,10 @@
protected final double[][] values = {{1.1, 2.2}, {3.3, 4.4},
{5.5, 6.6}};
+ protected final double[] vectorAValues = { 1.0/1.1, 2.0/1.1 };
+
+ protected final double[] vectorBValues = { 5.0, 10.0, 100.0 };
+
protected Matrix test;
protected MatrixTest(String name) {
@@ -395,11 +399,48 @@
int[] v = value.size();
assertEquals("rows", c[ROW], v[ROW]);
assertEquals("cols", c[ROW], v[COL]);
- // TODO: check the math too, lazy
+
+ Matrix expected = new DenseMatrix(new double[][] {{5.0, 11.0, 17.0},
+ {11.0, 25.0, 39.0}, {17.0, 39.0, 61.0}}).times(1.21);
+
+ for(int i=0; i<expected.numCols(); i++) {
+ for(int j=0; j<expected.numRows(); j++) {
+ assertTrue("Matrix times transpose not correct: " + i + ", " + j
+ + "\nexpected:\n\t" + expected.asFormatString() + "\nactual:\n\t"
+ + value.asFormatString(),
+ Math.abs(expected.get(i, j) - value.get(i, j)) < 1e-12);
+ }
+ }
+
Matrix timestest = new DenseMatrix(10, 1);
/* will throw ArrayIndexOutOfBoundsException exception without MAHOUT-26 */
timestest.transpose().times(timestest);
}
+
+ /**
+  * Checks matrix-times-vector against a hand-computed product and verifies
+  * that a vector of the wrong cardinality is rejected.
+  */
+ public void testTimesVector() {
+   Vector vectorA = new DenseVector(vectorAValues);
+   Vector testTimesVectorA = test.times(vectorA);
+   // NOTE(review): expected values assume `test` holds the 3x2 `values` fixture
+   // ({1.1, 2.2}, {3.3, 4.4}, {5.5, 6.6}) multiplied by (1/1.1, 2/1.1) -- confirm in setUp.
+   Vector expected = new DenseVector(new double[] { 5.0, 11.0, 17.0 });
+   // Report the expected vector (not the input vector) next to the actual result.
+   assertTrue("Matrix times vector not equals: " + expected.asFormatString()
+       + " != " + testTimesVectorA.asFormatString(),
+       expected.minus(testTimesVectorA).norm(2) < 1e-12);
+   try {
+     // The product vector has one entry per row, so feeding it back in must fail
+     // the column-cardinality check whenever the matrix is not square.
+     test.times(testTimesVectorA);
+     fail("Cardinalities do not match, should throw exception");
+   } catch (CardinalityException ce) {
+     // expected: the cardinality mismatch was detected
+   }
+ }
+
+ public void testTimesSquaredTimesVector() {
+ Vector vectorA = new DenseVector(vectorAValues);
+ Vector ttA = test.timesSquared(vectorA);
+ Vector ttASlow = test.transpose().times(test.times(vectorA));
+ assertTrue("M'Mv != M.timesSquared(v): " + ttA.asFormatString()
+ + " != " + ttASlow.asFormatString(),
+ ttASlow.minus(ttA).norm(2) < 1e-12);
+
+ }
public void testTimesMatrixCardinality() {
Matrix other = test.like(5, 8);