You are viewing a plain-text version of this content. The link to the canonical version ("here" in the original) is not preserved in this plain-text copy.
Posted to commits@mahout.apache.org by gs...@apache.org on 2009/11/22 17:13:10 UTC

svn commit: r883094 - in /lucene/mahout/trunk/core/src: main/java/org/apache/mahout/matrix/ test/java/org/apache/mahout/matrix/

Author: gsingers
Date: Sun Nov 22 16:13:10 2009
New Revision: 883094

URL: http://svn.apache.org/viewvc?rev=883094&view=rev
Log:
MAHOUT-182: added helper methods on Matrix for times, numRows, etc.

Added:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/PlusWithScaleFunction.java   (with props)
Modified:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractMatrix.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Matrix.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseMatrix.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/MatrixTest.java

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractMatrix.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractMatrix.java?rev=883094&r1=883093&r2=883094&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractMatrix.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractMatrix.java Sun Nov 22 16:13:10 2009
@@ -125,6 +125,14 @@
 
   // index into int[2] for row value
   public static final int ROW = 0;
+  
+  /**
+   * Returns the cardinality of the row dimension, i.e. {@code size()[ROW]}.
+   *
+   * @return the number of rows
+   */
+  @Override
+  public int numRows() {
+    return size()[ROW];
+  }
+  
+  /**
+   * Returns the cardinality of the column dimension, i.e. {@code size()[COL]}.
+   *
+   * @return the number of columns
+   */
+  @Override
+  public int numCols() {
+    return size()[COL];
+  }
 
   public static Matrix decodeMatrix(String formatString) {
     Type vectorType = new TypeToken<Vector>() {
@@ -399,6 +407,31 @@
     return result;
   }
 
+  /**
+   * Multiplies this matrix by the given vector.
+   *
+   * @param v a vector whose cardinality equals the column cardinality of this matrix
+   * @return a new {@link DenseVector} with one element per row; element i is the
+   *         dot product of row i with {@code v}
+   * @throws CardinalityException if the column cardinality differs from {@code v.size()}
+   */
+  @Override
+  public Vector times(Vector v) {
+    int[] card = size();
+    if (card[COL] != v.size()) {
+      throw new CardinalityException();
+    }
+    Vector w = new DenseVector(card[ROW]);
+    for (int i = 0; i < card[ROW]; i++) {
+      w.setQuick(i, v.dot(getRow(i)));
+    }
+    return w;
+  }
+  
+  /**
+   * Computes {@code transpose().times(times(v))} with a single pass over the rows and
+   * without materializing the transpose.
+   * <p>
+   * The result is the sum, over every row x_i, of the scalar (x_i . v) times x_i.
+   *
+   * @param v a vector whose cardinality equals the column cardinality of this matrix
+   * @return a new {@link DenseVector} whose cardinality is the column cardinality
+   * @throws CardinalityException if the column cardinality differs from {@code v.size()}
+   */
+  @Override
+  public Vector timesSquared(Vector v) {
+    int[] card = size();
+    if (card[COL] != v.size()) {
+      throw new CardinalityException();
+    }
+    Vector w = new DenseVector(card[COL]);
+    for (int i = 0; i < card[ROW]; i++) {
+      Vector xi = getRow(i);
+      // Accumulate (xi . v) * xi into w.
+      // NOTE(review): this assumes assign(other, f) sets w[j] = f(w[j], other[j]);
+      // PlusWithScaleFunction computes arg1 + scale * arg2.
+      w.assign(xi, new PlusWithScaleFunction(xi.dot(v)));
+    }
+    return w;
+  }
+  
   @Override
   public Matrix transpose() {
     int[] card = size();

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Matrix.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Matrix.java?rev=883094&r1=883093&r2=883094&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Matrix.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Matrix.java Sun Nov 22 16:13:10 2009
@@ -99,6 +99,20 @@
   int[] size();
 
   /**
+   * Helper method returning the cardinality of the row dimension.
+   *
+   * @return the number of rows in this matrix
+   */
+  int numRows();
+  
+  /**
+   * Helper method returning the cardinality of the column dimension.
+   *
+   * @return the number of columns in this matrix
+   */
+  int numCols();
+  
+  /**
    * Return a copy of the recipient
    *
    * @return a new Matrix
@@ -235,7 +249,8 @@
   int[] getNumNondefaultElements();
 
   /**
-   * Return a new matrix containing the product of each value of the recipient and the argument
+   * Return a new matrix containing the product of each value of the recipient 
+   * and the argument
    *
    * @param x a double argument
    * @return a new Matrix
@@ -243,13 +258,36 @@
   Matrix times(double x);
 
   /**
-   * Return a new matrix containing the product of the recipient and the argument
+   * Return a new matrix containing the product of the recipient and
+   * the argument, i.e. the matrix product {@code this * x}
    *
    * @param x a Matrix argument
    * @return a new Matrix
    * @throws CardinalityException if the cardinalities are incompatible
    */
   Matrix times(Matrix x);
+  
+  /**
+   * Return a new vector, with cardinality equal to {@link #numRows()} of this
+   * matrix, which is the matrix product of the recipient and the argument.
+   *
+   * @param v a vector with cardinality equal to {@link #numCols()} of the recipient
+   * @return a new vector (typically a DenseVector)
+   * @throws CardinalityException if {@code this.numCols() != v.size()}
+   */
+  Vector times(Vector v);
+  
+  /**
+   * Convenience method for producing {@code this.transpose().times(this.times(v))}.
+   * It can be implemented with only one pass over the matrix, avoiding the
+   * transpose() call, which can be expensive if the matrix is sparse.
+   *
+   * @param v a vector with cardinality equal to {@link #numCols()} of the recipient
+   * @return a new vector (typically a DenseVector) with cardinality equal to
+   *         that of the argument
+   * @throws CardinalityException if {@code this.numCols() != v.size()}
+   */
+  Vector timesSquared(Vector v);
 
   /**
    * Return a new matrix that is the transpose of the receiver

Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/PlusWithScaleFunction.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/PlusWithScaleFunction.java?rev=883094&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/PlusWithScaleFunction.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/PlusWithScaleFunction.java Sun Nov 22 16:13:10 2009
@@ -0,0 +1,31 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ * <p/>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p/>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.matrix;
+
+/**
+ * A {@link BinaryFunction} that returns {@code arg1 + scale * arg2}.
+ * <p>
+ * Passed to {@code Vector.assign(Vector, BinaryFunction)} to add a scaled copy
+ * of one vector into another.
+ */
+public class PlusWithScaleFunction implements BinaryFunction {
+  /** Factor applied to the second argument before it is added to the first. */
+  private final double scale;
+
+  /**
+   * @param scale the factor by which the second argument is multiplied
+   */
+  public PlusWithScaleFunction(final double scale) {
+    this.scale = scale;
+  }
+
+  /**
+   * @param arg1 the value added unscaled
+   * @param arg2 the value multiplied by {@code scale} before being added
+   * @return {@code arg1 + scale * arg2}
+   */
+  @Override
+  public double apply(double arg1, double arg2) {
+    return arg1 + scale * arg2;
+  }
+}

Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/PlusWithScaleFunction.java
------------------------------------------------------------------------------
    svn:eol-style = native

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseMatrix.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseMatrix.java?rev=883094&r1=883093&r2=883094&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseMatrix.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseMatrix.java Sun Nov 22 16:13:10 2009
@@ -102,11 +102,10 @@
 
  @Override
   public void setQuick(int row, int column, double value) {
-    Integer rowKey = row;
-    Vector r = rows.get(rowKey);
+    // Rows are created lazily: the first write into a row that has no entry in
+    // `rows` creates an empty SparseVector sized to the column cardinality.
+    Vector r = rows.get(row);
     if (r == null) {
       r = new SparseVector(cardinality[COL]);
-      rows.put(rowKey, r);
+      rows.put(row, r);
     }
     r.setQuick(column, value);
   }
@@ -142,11 +141,10 @@
     for (int row = 0; row < cardinality[ROW]; row++) {
       double val = other.getQuick(row);
       if (val != 0.0) {
-        Integer rowKey = row;
-        Vector r = rows.get(rowKey);
+        Vector r = rows.get(row);
         if (r == null) {
           r = new SparseVector(cardinality[ROW]);
-          rows.put(rowKey, r);
+          rows.put(row, r);
         }
         r.setQuick(column, val);
       }

Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/MatrixTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/MatrixTest.java?rev=883094&r1=883093&r2=883094&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/MatrixTest.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/MatrixTest.java Sun Nov 22 16:13:10 2009
@@ -35,6 +35,10 @@
   protected final double[][] values = {{1.1, 2.2}, {3.3, 4.4},
       {5.5, 6.6}};
 
+  // Right-hand operand for the matrix-vector tests: multiplying the 3x2
+  // `values` fixture by this vector yields {5.0, 11.0, 17.0}.
+  protected final double[] vectorAValues = { 1.0/1.1, 2.0/1.1 };
+  
+  // NOTE(review): not referenced by any test added in this change; confirm it is needed.
+  protected final double[] vectorBValues = { 5.0, 10.0, 100.0 };
+  
   protected Matrix test;
 
   protected MatrixTest(String name) {
@@ -395,11 +399,48 @@
     int[] v = value.size();
     assertEquals("rows", c[ROW], v[ROW]);
     assertEquals("cols", c[ROW], v[COL]);
-    // TODO: check the math too, lazy
+    
+    Matrix expected = new DenseMatrix(new double[][] {{5.0, 11.0, 17.0}, 
+        {11.0, 25.0, 39.0}, {17.0, 39.0, 61.0}}).times(1.21);
+    
+    for(int i=0; i<expected.numCols(); i++) {
+      for(int j=0; j<expected.numRows(); j++) {
+        assertTrue("Matrix times transpose not correct: " + i + ", " + j 
+                   + "\nexpected:\n\t" + expected.asFormatString() + "\nactual:\n\t" 
+                   + value.asFormatString(), 
+                   Math.abs(expected.get(i, j) - value.get(i, j)) < 1e-12);
+      }
+    }
+    
     Matrix timestest = new DenseMatrix(10, 1);
     /* will throw ArrayIndexOutOfBoundsException exception without MAHOUT-26 */
     timestest.transpose().times(timestest);
   }
+  
+  /**
+   * Checks Matrix.times(Vector) against a hand-computed product. It also checks
+   * that a vector whose cardinality differs from numCols() is rejected.
+   */
+  public void testTimesVector() {
+    Vector vectorA = new DenseVector(vectorAValues);
+    Vector testTimesVectorA = test.times(vectorA);
+    Vector expected = new DenseVector(new double[] { 5.0, 11.0, 17.0 });
+    assertTrue("Matrix times vector not equals: " + expected.asFormatString()
+               + " != " + testTimesVectorA.asFormatString(),
+               expected.minus(testTimesVectorA).norm(2) < 1e-12);
+    try {
+      // The product has one element per row (3), which differs from the 2 columns.
+      test.times(testTimesVectorA);
+      fail("Cardinalities do not match, should throw exception");
+    } catch (CardinalityException expectedException) {
+      // expected: argument cardinality differs from numCols()
+    }
+  }
+  
+  /**
+   * Checks that timesSquared(v) equals transpose().times(times(v)). It also checks
+   * that a vector whose cardinality differs from numCols() is rejected.
+   */
+  public void testTimesSquaredTimesVector() {
+    Vector vectorA = new DenseVector(vectorAValues);
+    Vector ttA = test.timesSquared(vectorA);
+    Vector ttASlow = test.transpose().times(test.times(vectorA));
+    assertTrue("M'Mv != M.timesSquared(v): " + ttA.asFormatString()
+               + " != " + ttASlow.asFormatString(),
+               ttASlow.minus(ttA).norm(2) < 1e-12);
+    try {
+      // test.times(vectorA) has one element per row (3), not per column (2).
+      test.timesSquared(test.times(vectorA));
+      fail("Cardinalities do not match, should throw exception");
+    } catch (CardinalityException expectedException) {
+      // expected: argument cardinality differs from numCols()
+    }
+  }
 
   public void testTimesMatrixCardinality() {
     Matrix other = test.like(5, 8);