You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by je...@apache.org on 2009/06/17 17:08:29 UTC

svn commit: r785649 - in /lucene/mahout/trunk: ./ core/src/main/java/org/apache/mahout/matrix/ core/src/test/java/org/apache/mahout/matrix/

Author: jeastman
Date: Wed Jun 17 15:08:29 2009
New Revision: 785649

URL: http://svn.apache.org/viewvc?rev=785649&view=rev
Log:
- MAHOUT-65: implemented Matrix label binding interface and unit tests
- fixed cut/paste error in JsonMatrixAdapter

Modified:
    lucene/mahout/trunk/   (props changed)
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractMatrix.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/JsonMatrixAdapter.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Matrix.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/MatrixTest.java

Propchange: lucene/mahout/trunk/
------------------------------------------------------------------------------
--- svn:ignore (original)
+++ svn:ignore Wed Jun 17 15:08:29 2009
@@ -10,3 +10,4 @@
 .settings
 atlassian-ide-plugin.xml
 target
+input

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractMatrix.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractMatrix.java?rev=785649&r1=785648&r2=785649&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractMatrix.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractMatrix.java Wed Jun 17 15:08:29 2009
@@ -18,6 +18,8 @@
 package org.apache.mahout.matrix;
 
 import java.lang.reflect.Type;
+import java.util.HashMap;
+import java.util.Map;
 
 import com.google.gson.Gson;
 import com.google.gson.GsonBuilder;
@@ -29,6 +31,86 @@
  */
 public abstract class AbstractMatrix implements Matrix {
 
+  /**
+   * Column label to column index bindings; null until set via
+   * setColumnLabelBindings or created lazily by a label-binding set call.
+   */
+  private Map<String, Integer> columnLabelBindings;
+  
+  /**
+   * Row label to row index bindings; null until set via setRowLabelBindings
+   * or created lazily by a label-binding set call.
+   */
+  private Map<String, Integer> rowLabelBindings;
+  
+  /**
+   * Returns the cell addressed by a row label and a column label, resolved
+   * through the current label bindings.
+   *
+   * @param rowLabel a String row label
+   * @param columnLabel a String column label
+   * @return the double stored at the bound (row, column) position
+   * @throws IndexException propagated from get(int, int)
+   * @throws UnboundLabelException if either binding map is unset or either
+   *           label has no binding
+   */
+  @Override
+  public double get(String rowLabel, String columnLabel) throws IndexException,
+      UnboundLabelException {
+    if (rowLabelBindings != null && columnLabelBindings != null) {
+      Integer rowIndex = rowLabelBindings.get(rowLabel);
+      Integer colIndex = columnLabelBindings.get(columnLabel);
+      if (rowIndex != null && colIndex != null) {
+        return get(rowIndex.intValue(), colIndex.intValue());
+      }
+    }
+    throw new UnboundLabelException();
+  }
+
+  /**
+   * Returns the column label bindings held by this matrix.
+   *
+   * @return the stored (uncopied) {@code Map<String, Integer>} from column
+   *         label to column index, or null if none has been set yet
+   */
+  @Override
+  public Map<String, Integer> getColumnLabelBindings() {
+    return this.columnLabelBindings;
+  }
+
+  /**
+   * Returns the row label bindings held by this matrix.
+   *
+   * @return the stored (uncopied) {@code Map<String, Integer>} from row label
+   *         to row index, or null if none has been set yet
+   */
+  @Override
+  public Map<String, Integer> getRowLabelBindings() {
+    return this.rowLabelBindings;
+  }
+
+  /**
+   * Replaces the contents of the row bound to the given label.
+   *
+   * @param rowLabel a String row label that must already be bound
+   * @param rowData a double[] array of row data
+   * @throws UnboundLabelException if no row bindings exist or rowLabel is not
+   *           bound
+   */
+  @Override
+  public void set(String rowLabel, double[] rowData) {
+    // Only the row map is consulted here, so it is the one that must exist;
+    // column bindings are irrelevant to a whole-row write.
+    if (rowLabelBindings == null) {
+      throw new UnboundLabelException();
+    }
+    Integer row = rowLabelBindings.get(rowLabel);
+    if (row == null) {
+      throw new UnboundLabelException();
+    }
+    set(row.intValue(), rowData);
+  }
+
+  /**
+   * Replaces the contents of the given row and binds rowLabel to that row,
+   * creating the row binding map if necessary.
+   *
+   * @param rowLabel the String row label to bind
+   * @param row an int the row index
+   * @param rowData a double[] array of row data
+   */
+  @Override
+  public void set(String rowLabel, int row, double[] rowData) {
+    // Write first so that, if set(row, rowData) throws, rowLabel is not left
+    // bound to a row that was never written.
+    set(row, rowData);
+    if (rowLabelBindings == null) {
+      rowLabelBindings = new HashMap<String, Integer>();
+    }
+    rowLabelBindings.put(rowLabel, row);
+  }
+
+  @Override
+  public void set(String rowLabel, String columnLabel, double value)
+      throws IndexException, UnboundLabelException {
+    if (columnLabelBindings == null || rowLabelBindings == null)
+      throw new UnboundLabelException();
+    Integer row = rowLabelBindings.get(rowLabel);
+    Integer col = columnLabelBindings.get(columnLabel);
+    if (row == null || col == null)
+      throw new UnboundLabelException();
+    set(row, col, value);
+  }
+
+  /**
+   * Stores a value at the given indices and binds rowLabel to row and
+   * columnLabel to column, creating either binding map if necessary.
+   *
+   * @param rowLabel a String row label to bind
+   * @param columnLabel a String column label to bind
+   * @param row an int row index
+   * @param column an int column index
+   * @param value a double value
+   * @throws IndexException propagated from set(int, int, double)
+   * @throws UnboundLabelException declared by the interface; not thrown
+   *           directly here since bindings are created as needed
+   */
+  @Override
+  public void set(String rowLabel, String columnLabel, int row, int column,
+      double value) throws IndexException, UnboundLabelException {
+    // Write first so that, if set(row, column, value) throws, neither label
+    // is left bound to an index that was never written.
+    set(row, column, value);
+    if (rowLabelBindings == null) {
+      rowLabelBindings = new HashMap<String, Integer>();
+    }
+    rowLabelBindings.put(rowLabel, row);
+    if (columnLabelBindings == null) {
+      columnLabelBindings = new HashMap<String, Integer>();
+    }
+    columnLabelBindings.put(columnLabel, column);
+  }
+
+  /**
+   * Installs the given map as this matrix's column label bindings, replacing
+   * any previous bindings. The map is stored by reference, not copied.
+   *
+   * @param bindings a {@code Map<String, Integer>} from column label to
+   *          column index
+   */
+  @Override
+  public void setColumnLabelBindings(Map<String, Integer> bindings) {
+    this.columnLabelBindings = bindings;
+  }
+
+  /**
+   * Installs the given map as this matrix's row label bindings, replacing any
+   * previous bindings. The map is stored by reference, not copied.
+   *
+   * @param bindings a {@code Map<String, Integer>} from row label to row
+   *          index
+   */
+  @Override
+  public void setRowLabelBindings(Map<String, Integer> bindings) {
+    this.rowLabelBindings = bindings;
+  }
+
   // index into int[2] for column value
   public static final int COL = 1;
 

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/JsonMatrixAdapter.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/JsonMatrixAdapter.java?rev=785649&r1=785648&r2=785649&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/JsonMatrixAdapter.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/JsonMatrixAdapter.java Wed Jun 17 15:08:29 2009
@@ -36,7 +36,7 @@
 public class JsonMatrixAdapter implements JsonSerializer<Matrix>,
     JsonDeserializer<Matrix> {
 
-  private static final Logger log = LoggerFactory.getLogger(JsonVectorAdapter.class);
+  private static final Logger log = LoggerFactory.getLogger(JsonMatrixAdapter.class);
   public static final String CLASS = "class";
   public static final String MATRIX = "matrix";
 

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Matrix.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Matrix.java?rev=785649&r1=785648&r2=785649&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Matrix.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Matrix.java Wed Jun 17 15:08:29 2009
@@ -17,6 +17,8 @@
 
 package org.apache.mahout.matrix;
 
+import java.util.Map;
+
 import org.apache.hadoop.io.WritableComparable;
 
 /**
@@ -303,6 +305,87 @@
    */
   double zSum();
 
+  /**
+   * Return the current column label bindings of the receiver, i.e. the
+   * mapping from each column label to its column index
+   * 
+   * @return a {@code Map<String, Integer>} of column label to column index;
+   *         NOTE(review): AbstractMatrix returns null until bindings are set
+   *         and returns its internal map rather than a copy — confirm this is
+   *         the intended contract for all implementations
+   */
+  Map<String, Integer> getColumnLabelBindings();
+
+  /**
+   * Return the current row label bindings of the receiver, i.e. the mapping
+   * from each row label to its row index
+   * 
+   * @return a {@code Map<String, Integer>} of row label to row index;
+   *         NOTE(review): AbstractMatrix returns null until bindings are set
+   *         and returns its internal map rather than a copy — confirm this is
+   *         the intended contract for all implementations
+   */
+  Map<String, Integer> getRowLabelBindings();
+
+  /**
+   * Sets the column label bindings (column label to column index) of the
+   * receiver, replacing any existing column bindings
+   * 
+   * @param bindings a {@code Map<String, Integer>} of label bindings;
+   *          NOTE(review): AbstractMatrix keeps a reference to this map rather
+   *          than copying it, so later caller changes are visible
+   */
+  void setColumnLabelBindings(Map<String, Integer> bindings);
+
+  /**
+   * Sets the row label bindings (row label to row index) of the receiver,
+   * replacing any existing row bindings
+   * 
+   * @param bindings a {@code Map<String, Integer>} of label bindings;
+   *          NOTE(review): AbstractMatrix keeps a reference to this map rather
+   *          than copying it, so later caller changes are visible
+   */
+  void setRowLabelBindings(Map<String, Integer> bindings);
+
+  /**
+   * Return the value at the given labels, resolved through the receiver's
+   * current row and column label bindings
+   * 
+   * @param rowLabel a String row label
+   * @param columnLabel a String column label
+   * @return the double at the bound index
+   * @throws IndexException if the bound index is out of bounds
+   * @throws UnboundLabelException if either label has no binding, including
+   *           when no bindings of that kind have been set
+   */
+  double get(String rowLabel, String columnLabel) throws IndexException,
+    UnboundLabelException;
+
+  /**
+   * Set the value at the given labels, resolved through the receiver's
+   * current row and column label bindings; no bindings are created or changed
+   * 
+   * @param rowLabel a String row label
+   * @param columnLabel a String column label
+   * @param value a double value to set
+   * @throws IndexException if the bound index is out of bounds
+   * @throws UnboundLabelException if either label has no binding, including
+   *           when no bindings of that kind have been set
+   */
+  void set(String rowLabel, String columnLabel, double value) throws IndexException,
+    UnboundLabelException;
+
+  /**
+   * Set the value at the given index, binding rowLabel to row and columnLabel
+   * to column (adding to, or creating, the row and column label bindings)
+   * 
+   * @param rowLabel a String row label
+   * @param columnLabel a String column label
+   * @param row an int row index
+   * @param column an int column index
+   * @param value a double value
+   * @throws IndexException if the index is out of bounds
+   * @throws UnboundLabelException declared for symmetry with the other
+   *           label-based accessors; NOTE(review): AbstractMatrix does not
+   *           throw it here because it creates bindings as needed
+   */
+  void set(String rowLabel, String columnLabel, int row, int column, double value) throws IndexException,
+    UnboundLabelException;
+
+  /**
+   * Sets the row values at the given row label, which must already be bound
+   * 
+   * @param rowLabel a String row label
+   * @param rowData a double[] array of row data
+   * @throws UnboundLabelException if the row label has no binding
+   */
+  void set(String rowLabel, double[] rowData);
+
+  /**
+   * Sets the row values at the given row index and binds the row label to
+   * that index, adding to (or creating) the row label bindings
+   * 
+   * @param rowLabel the String row label
+   * @param row an int the row index
+   * @param rowData a double[] array of row data
+   */
+  void set(String rowLabel, int row, double[] rowData);
+
   /*
    * Need stories for these but keeping them here for now.
    * 

Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/MatrixTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/MatrixTest.java?rev=785649&r1=785648&r2=785649&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/MatrixTest.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/MatrixTest.java Wed Jun 17 15:08:29 2009
@@ -17,6 +17,9 @@
 
 package org.apache.mahout.matrix;
 
+import java.util.HashMap;
+import java.util.Map;
+
 import junit.framework.TestCase;
 
 public abstract class MatrixTest extends TestCase {
@@ -25,7 +28,8 @@
 
   protected static final int COL = AbstractMatrix.COL;
 
-  protected final double[][] values = {{1.1, 2.2}, {3.3, 4.4}, {5.5, 6.6}};
+  protected final double[][] values = { { 1.1, 2.2 }, { 3.3, 4.4 },
+      { 5.5, 6.6 } };
 
   protected Matrix test;
 
@@ -54,7 +58,7 @@
     for (int row = 0; row < c[ROW]; row++)
       for (int col = 0; col < c[COL]; col++)
         assertEquals("value[" + row + "][" + col + ']',
-                test.getQuick(row, col), copy.getQuick(row, col));
+            test.getQuick(row, col), copy.getQuick(row, col));
   }
 
   public void testGetQuick() {
@@ -62,7 +66,7 @@
     for (int row = 0; row < c[ROW]; row++)
       for (int col = 0; col < c[COL]; col++)
         assertEquals("value[" + row + "][" + col + ']', values[row][col], test
-                .getQuick(row, col));
+            .getQuick(row, col));
   }
 
   public void testHaveSharedCells() {
@@ -90,7 +94,7 @@
       for (int col = 0; col < c[COL]; col++) {
         test.setQuick(row, col, 1.23);
         assertEquals("value[" + row + "][" + col + ']', 1.23, test.getQuick(
-                row, col));
+            row, col));
       }
   }
 
@@ -106,23 +110,23 @@
     for (int row = 0; row < c[ROW]; row++)
       for (int col = 0; col < c[COL]; col++)
         assertEquals("value[" + row + "][" + col + ']', values[row][col],
-                array[row][col]);
+            array[row][col]);
   }
 
   public void testViewPart() {
-    int[] offset = {1, 1};
-    int[] size = {2, 1};
+    int[] offset = { 1, 1 };
+    int[] size = { 2, 1 };
     Matrix view = test.viewPart(offset, size);
     int[] c = view.cardinality();
     for (int row = 0; row < c[ROW]; row++)
       for (int col = 0; col < c[COL]; col++)
         assertEquals("value[" + row + "][" + col + ']',
-                values[row + 1][col + 1], view.getQuick(row, col));
+            values[row + 1][col + 1], view.getQuick(row, col));
   }
 
   public void testViewPartCardinality() {
-    int[] offset = {1, 1};
-    int[] size = {3, 3};
+    int[] offset = { 1, 1 };
+    int[] size = { 3, 3 };
     try {
       test.viewPart(offset, size);
       fail("exception expected");
@@ -134,8 +138,8 @@
   }
 
   public void testViewPartIndexOver() {
-    int[] offset = {1, 1};
-    int[] size = {2, 2};
+    int[] offset = { 1, 1 };
+    int[] size = { 2, 2 };
     try {
       test.viewPart(offset, size);
       fail("exception expected");
@@ -147,8 +151,8 @@
   }
 
   public void testViewPartIndexUnder() {
-    int[] offset = {-1, -1};
-    int[] size = {2, 2};
+    int[] offset = { -1, -1 };
+    int[] size = { 2, 2 };
     try {
       test.viewPart(offset, size);
       fail("exception expected");
@@ -165,7 +169,7 @@
     for (int row = 0; row < c[ROW]; row++)
       for (int col = 0; col < c[COL]; col++)
         assertEquals("value[" + row + "][" + col + ']', 4.53, test.getQuick(
-                row, col));
+            row, col));
   }
 
   public void testAssignDoubleArrayArray() {
@@ -174,7 +178,7 @@
     for (int row = 0; row < c[ROW]; row++)
       for (int col = 0; col < c[COL]; col++)
         assertEquals("value[" + row + "][" + col + ']', 0.0, test.getQuick(row,
-                col));
+            col));
   }
 
   public void testAssignDoubleArrayArrayCardinality() {
@@ -193,7 +197,7 @@
     for (int row = 0; row < c[ROW]; row++)
       for (int col = 0; col < c[COL]; col++)
         assertEquals("value[" + row + "][" + col + ']', 2 * values[row][col],
-                test.getQuick(row, col));
+            test.getQuick(row, col));
   }
 
   public void testAssignMatrixBinaryFunctionCardinality() {
@@ -212,7 +216,7 @@
     for (int row = 0; row < c[ROW]; row++)
       for (int col = 0; col < c[COL]; col++)
         assertEquals("value[" + row + "][" + col + ']',
-                test.getQuick(row, col), value.getQuick(row, col));
+            test.getQuick(row, col), value.getQuick(row, col));
   }
 
   public void testAssignMatrixCardinality() {
@@ -230,7 +234,7 @@
     for (int row = 0; row < c[ROW]; row++)
       for (int col = 0; col < c[COL]; col++)
         assertEquals("value[" + row + "][" + col + ']', -values[row][col], test
-                .getQuick(row, col));
+            .getQuick(row, col));
   }
 
   public void testDivide() {
@@ -239,7 +243,7 @@
     for (int row = 0; row < c[ROW]; row++)
       for (int col = 0; col < c[COL]; col++)
         assertEquals("value[" + row + "][" + col + ']',
-                values[row][col] / 4.53, value.getQuick(row, col));
+            values[row][col] / 4.53, value.getQuick(row, col));
   }
 
   public void testGet() {
@@ -247,7 +251,7 @@
     for (int row = 0; row < c[ROW]; row++)
       for (int col = 0; col < c[COL]; col++)
         assertEquals("value[" + row + "][" + col + ']', values[row][col], test
-                .get(row, col));
+            .get(row, col));
   }
 
   public void testGetIndexUnder() {
@@ -280,7 +284,7 @@
     for (int row = 0; row < c[ROW]; row++)
       for (int col = 0; col < c[COL]; col++)
         assertEquals("value[" + row + "][" + col + ']', 0.0, value.getQuick(
-                row, col));
+            row, col));
   }
 
   public void testMinusCardinality() {
@@ -298,7 +302,7 @@
     for (int row = 0; row < c[ROW]; row++)
       for (int col = 0; col < c[COL]; col++)
         assertEquals("value[" + row + "][" + col + ']',
-                values[row][col] + 4.53, value.getQuick(row, col));
+            values[row][col] + 4.53, value.getQuick(row, col));
   }
 
   public void testPlusMatrix() {
@@ -307,7 +311,7 @@
     for (int row = 0; row < c[ROW]; row++)
       for (int col = 0; col < c[COL]; col++)
         assertEquals("value[" + row + "][" + col + ']', values[row][col] * 2,
-                value.getQuick(row, col));
+            value.getQuick(row, col));
   }
 
   public void testPlusMatrixCardinality() {
@@ -351,7 +355,7 @@
     for (int row = 0; row < c[ROW]; row++)
       for (int col = 0; col < c[COL]; col++)
         assertEquals("value[" + row + "][" + col + ']',
-                values[row][col] * 4.53, value.getQuick(row, col));
+            values[row][col] * 4.53, value.getQuick(row, col));
   }
 
   public void testTimesMatrix() {
@@ -362,7 +366,7 @@
     assertEquals("rows", c[ROW], v[ROW]);
     assertEquals("cols", c[ROW], v[COL]);
     // TODO: check the math too, lazy
-    Matrix timestest = new DenseMatrix(10,1);
+    Matrix timestest = new DenseMatrix(10, 1);
     /* will throw ArrayIndexOutOfBoundsException exception without MAHOUT-26 */
     timestest.transpose().times(timestest);
   }
@@ -386,7 +390,7 @@
     for (int row = 0; row < c[ROW]; row++)
       for (int col = 0; col < c[COL]; col++)
         assertEquals("value[" + row + "][" + col + ']',
-                test.getQuick(row, col), transpose.getQuick(col, row));
+            test.getQuick(row, col), transpose.getQuick(col, row));
   }
 
   public void testZSum() {
@@ -395,14 +399,14 @@
   }
 
   public void testAssignRow() {
-    double[] data = {2.1, 3.2};
+    double[] data = { 2.1, 3.2 };
     test.assignRow(1, new DenseVector(data));
     assertEquals("test[1][0]", 2.1, test.getQuick(1, 0));
     assertEquals("test[1][1]", 3.2, test.getQuick(1, 1));
   }
 
   public void testAssignRowCardinality() {
-    double[] data = {2.1, 3.2, 4.3};
+    double[] data = { 2.1, 3.2, 4.3 };
     try {
       test.assignRow(1, new DenseVector(data));
       fail("expecting cardinality exception");
@@ -412,7 +416,7 @@
   }
 
   public void testAssignColumn() {
-    double[] data = {2.1, 3.2, 4.3};
+    double[] data = { 2.1, 3.2, 4.3 };
     test.assignColumn(1, new DenseVector(data));
     assertEquals("test[0][1]", 2.1, test.getQuick(0, 1));
     assertEquals("test[1][1]", 3.2, test.getQuick(1, 1));
@@ -420,7 +424,7 @@
   }
 
   public void testAssignColumnCardinality() {
-    double[] data = {2.1, 3.2};
+    double[] data = { 2.1, 3.2 };
     try {
       test.assignColumn(1, new DenseVector(data));
       fail("expecting cardinality exception");
@@ -476,7 +480,8 @@
   }
 
   public void testDetermitant() {
-    Matrix m = matrixFactory(new double[][] { {1,3,4},{5,2,3},{1,4,2} });
+    Matrix m = matrixFactory(new double[][] { { 1, 3, 4 }, { 5, 2, 3 },
+        { 1, 4, 2 } });
     assertEquals("determinant", 43.0, m.determinant());
   }
 
@@ -490,4 +495,69 @@
             row, col));
   }
 
+  /**
+   * Installs explicit row/column bindings, reads through them, and writes a
+   * whole row through its label.
+   */
+  public void testLabelBindings() {
+    Matrix m = matrixFactory(new double[][] { { 1, 3, 4 }, { 5, 2, 3 },
+        { 1, 4, 2 } });
+    assertNull("row bindings", m.getRowLabelBindings());
+    assertNull("col bindings", m.getColumnLabelBindings());
+    Map<String, Integer> rowBindings = new HashMap<String, Integer>();
+    rowBindings.put("Fee", 0);
+    rowBindings.put("Fie", 1);
+    rowBindings.put("Foe", 2);
+    m.setRowLabelBindings(rowBindings);
+    assertEquals("row", rowBindings, m.getRowLabelBindings());
+    Map<String, Integer> colBindings = new HashMap<String, Integer>();
+    colBindings.put("Foo", 0);
+    colBindings.put("Bar", 1);
+    colBindings.put("Baz", 2);
+    m.setColumnLabelBindings(colBindings);
+    assertEquals("col", colBindings, m.getColumnLabelBindings());
+    assertEquals("FeeBar", m.get(0, 1), m.get("Fee", "Bar"));
+
+    // replace the "Foe" row through its label and check every cell landed
+    double[] newrow = { 9, 8, 7 };
+    m.set("Foe", newrow);
+    assertEquals("FoeFoo", 9.0, m.get("Foe", "Foo"));
+    assertEquals("FoeBar", 8.0, m.get("Foe", "Bar"));
+    assertEquals("FoeBaz", 7.0, m.get("Foe", "Baz"));
+    // rows other than "Foe" must be untouched
+    assertEquals("FeeBaz", 4.0, m.get("Fee", "Baz"));
+  }
+
+  /**
+   * Binds labels implicitly through the five-argument set and checks that
+   * unbound labels are rejected with UnboundLabelException.
+   */
+  public void testSettingLabelBindings() {
+    Matrix m = matrixFactory(new double[][] { { 1, 3, 4 }, { 5, 2, 3 },
+        { 1, 4, 2 } });
+    assertNull("row bindings", m.getRowLabelBindings());
+    assertNull("col bindings", m.getColumnLabelBindings());
+    m.set("Fee", "Foo", 1, 2, 9);
+    assertNotNull("row", m.getRowLabelBindings());
+    assertNotNull("col", m.getColumnLabelBindings());
+    assertEquals("Fee", 1, m.getRowLabelBindings().get("Fee").intValue());
+    assertEquals("Foo", 2, m.getColumnLabelBindings().get("Foo").intValue());
+    assertEquals("FeeFoo", 9.0, m.get("Fee", "Foo"));
+    assertEquals("FeeFoo", m.get(1, 2), m.get("Fee", "Foo"));
+    try {
+      m.get("Fie", "Foe");
+      fail("Expected UnboundLabelException");
+    } catch (IndexException e) {
+      fail("Expected UnboundLabelException, got IndexException");
+    } catch (UnboundLabelException e) {
+      // expected: neither "Fie" nor "Foe" has been bound
+    }
+  }
+  
+  /**
+   * Round-trips a labelled matrix through its JSON format string and checks
+   * every labelled cell is still reachable through the same labels.
+   */
+  public void testLabelBindingSerialization() {
+    Matrix m = matrixFactory(new double[][] { { 1, 3, 4 }, { 5, 2, 3 },
+        { 1, 4, 2 } });
+    assertNull("row bindings", m.getRowLabelBindings());
+    assertNull("col bindings", m.getColumnLabelBindings());
+    Map<String, Integer> rowBindings = new HashMap<String, Integer>();
+    rowBindings.put("Fee", 0);
+    rowBindings.put("Fie", 1);
+    rowBindings.put("Foe", 2);
+    m.setRowLabelBindings(rowBindings);
+    assertEquals("row", rowBindings, m.getRowLabelBindings());
+    Map<String, Integer> colBindings = new HashMap<String, Integer>();
+    colBindings.put("Foo", 0);
+    colBindings.put("Bar", 1);
+    colBindings.put("Baz", 2);
+    m.setColumnLabelBindings(colBindings);
+    assertEquals("col", colBindings, m.getColumnLabelBindings());
+    String json = m.asFormatString();
+    Matrix mm = AbstractMatrix.decodeMatrix(json);
+    // every labelled cell of the decoded matrix must match the original
+    for (Map.Entry<String, Integer> row : rowBindings.entrySet()) {
+      for (Map.Entry<String, Integer> col : colBindings.entrySet()) {
+        assertEquals(row.getKey() + col.getKey(),
+            m.get(row.getValue().intValue(), col.getValue().intValue()),
+            mm.get(row.getKey(), col.getKey()));
+      }
+    }
+  }
 }