You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by je...@apache.org on 2009/06/17 17:08:29 UTC
svn commit: r785649 - in /lucene/mahout/trunk: ./
core/src/main/java/org/apache/mahout/matrix/
core/src/test/java/org/apache/mahout/matrix/
Author: jeastman
Date: Wed Jun 17 15:08:29 2009
New Revision: 785649
URL: http://svn.apache.org/viewvc?rev=785649&view=rev
Log:
- MAHOUT-65: implemented Matrix label binding interface and unit tests
- fixed cut/paste error in JsonMatrixAdapter
Modified:
lucene/mahout/trunk/ (props changed)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractMatrix.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/JsonMatrixAdapter.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Matrix.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/MatrixTest.java
Propchange: lucene/mahout/trunk/
------------------------------------------------------------------------------
--- svn:ignore (original)
+++ svn:ignore Wed Jun 17 15:08:29 2009
@@ -10,3 +10,4 @@
.settings
atlassian-ide-plugin.xml
target
+input
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractMatrix.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractMatrix.java?rev=785649&r1=785648&r2=785649&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractMatrix.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/AbstractMatrix.java Wed Jun 17 15:08:29 2009
@@ -18,6 +18,8 @@
package org.apache.mahout.matrix;
import java.lang.reflect.Type;
+import java.util.HashMap;
+import java.util.Map;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
@@ -29,6 +31,86 @@
*/
public abstract class AbstractMatrix implements Matrix {
+ private Map<String, Integer> columnLabelBindings;
+
+ private Map<String, Integer> rowLabelBindings;
+
+ /**
+  * Looks up a cell by its row and column labels.
+  *
+  * @param rowLabel a String row label
+  * @param columnLabel a String column label
+  * @return the value stored at the indices the two labels are bound to
+  * @throws IndexException declared for the delegated index-based get
+  * @throws UnboundLabelException if either binding map is missing or either
+  *         label has no binding
+  */
+ @Override
+ public double get(String rowLabel, String columnLabel) throws IndexException,
+     UnboundLabelException {
+   // No label can be resolved until both binding maps exist.
+   if (rowLabelBindings == null || columnLabelBindings == null) {
+     throw new UnboundLabelException();
+   }
+   Integer rowIndex = rowLabelBindings.get(rowLabel);
+   Integer columnIndex = columnLabelBindings.get(columnLabel);
+   if (rowIndex == null || columnIndex == null) {
+     throw new UnboundLabelException();
+   }
+   return get(rowIndex, columnIndex);
+ }
+
+ /**
+  * Returns the column label bindings held by this matrix.
+  *
+  * @return the stored label-to-column-index map itself (not a copy); null
+  *         while no column bindings have been set or created
+  */
+ @Override
+ public Map<String, Integer> getColumnLabelBindings() {
+   return this.columnLabelBindings;
+ }
+
+ /**
+  * Returns the row label bindings held by this matrix.
+  *
+  * @return the stored label-to-row-index map itself (not a copy); null
+  *         while no row bindings have been set or created
+  */
+ @Override
+ public Map<String, Integer> getRowLabelBindings() {
+   return this.rowLabelBindings;
+ }
+
+ /**
+  * Replaces the values of the row that the given label is bound to.
+  *
+  * @param rowLabel a String row label that must already be bound to a row index
+  * @param rowData a double[] array of row data, handed to the index-based
+  *        row setter unchanged
+  * @throws UnboundLabelException if no row bindings exist or the label has
+  *         no binding
+  */
+ @Override
+ public void set(String rowLabel, double[] rowData) {
+   // Only the row bindings are consulted here, so they are what must be
+   // checked. Guarding on the column bindings caused an NPE when only
+   // column labels were bound and rejected valid row-only bindings.
+   if (rowLabelBindings == null) {
+     throw new UnboundLabelException();
+   }
+   Integer row = rowLabelBindings.get(rowLabel);
+   if (row == null) {
+     throw new UnboundLabelException();
+   }
+   set(row, rowData);
+ }
+
+ /**
+  * Binds the label to the given row index and then stores the row data at
+  * that index. The binding is recorded before the data is written.
+  *
+  * @param rowLabel the String row label to bind
+  * @param row an int row index
+  * @param rowData a double[] array of row data
+  */
+ @Override
+ public void set(String rowLabel, int row, double[] rowData) {
+   Map<String, Integer> bindings = rowLabelBindings;
+   if (bindings == null) {
+     // First labelled row: create the binding map on demand.
+     bindings = new HashMap<String, Integer>();
+     rowLabelBindings = bindings;
+   }
+   bindings.put(rowLabel, row);
+   set(row, rowData);
+ }
+
+ /**
+  * Stores a value in the cell addressed by a row label and a column label.
+  *
+  * @param rowLabel a String row label
+  * @param columnLabel a String column label
+  * @param value a double value to set
+  * @throws IndexException declared for the delegated index-based set
+  * @throws UnboundLabelException if either binding map is missing or either
+  *         label has no binding
+  */
+ @Override
+ public void set(String rowLabel, String columnLabel, double value)
+     throws IndexException, UnboundLabelException {
+   // No label can be resolved until both binding maps exist.
+   if (rowLabelBindings == null || columnLabelBindings == null) {
+     throw new UnboundLabelException();
+   }
+   Integer rowIndex = rowLabelBindings.get(rowLabel);
+   Integer columnIndex = columnLabelBindings.get(columnLabel);
+   if (rowIndex == null || columnIndex == null) {
+     throw new UnboundLabelException();
+   }
+   set(rowIndex, columnIndex, value);
+ }
+
+ /**
+  * Binds the row and column labels to the given indices and then stores the
+  * value at those indices. Both bindings are recorded before the value is
+  * written.
+  *
+  * @param rowLabel a String row label to bind to {@code row}
+  * @param columnLabel a String column label to bind to {@code column}
+  * @param row an int row index
+  * @param column an int column index
+  * @param value a double value
+  * @throws IndexException declared for the delegated index-based set
+  * @throws UnboundLabelException declared by the interface; not raised
+  *         directly by this implementation
+  */
+ @Override
+ public void set(String rowLabel, String columnLabel, int row, int column,
+     double value) throws IndexException, UnboundLabelException {
+   Map<String, Integer> rowMap = rowLabelBindings;
+   if (rowMap == null) {
+     rowMap = new HashMap<String, Integer>();
+     rowLabelBindings = rowMap;
+   }
+   rowMap.put(rowLabel, row);
+
+   Map<String, Integer> columnMap = columnLabelBindings;
+   if (columnMap == null) {
+     columnMap = new HashMap<String, Integer>();
+     columnLabelBindings = columnMap;
+   }
+   columnMap.put(columnLabel, column);
+
+   set(row, column, value);
+ }
+
+ /**
+  * Replaces the column label bindings of this matrix.
+  *
+  * @param bindings a Map<String, Integer> of label bindings; the reference
+  *        is stored as given, without copying
+  */
+ @Override
+ public void setColumnLabelBindings(Map<String, Integer> bindings) {
+   this.columnLabelBindings = bindings;
+ }
+
+ /**
+  * Replaces the row label bindings of this matrix.
+  *
+  * @param bindings a Map<String, Integer> of label bindings; the reference
+  *        is stored as given, without copying
+  */
+ @Override
+ public void setRowLabelBindings(Map<String, Integer> bindings) {
+   this.rowLabelBindings = bindings;
+ }
+
// index into int[2] for column value
public static final int COL = 1;
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/JsonMatrixAdapter.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/JsonMatrixAdapter.java?rev=785649&r1=785648&r2=785649&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/JsonMatrixAdapter.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/JsonMatrixAdapter.java Wed Jun 17 15:08:29 2009
@@ -36,7 +36,7 @@
public class JsonMatrixAdapter implements JsonSerializer<Matrix>,
JsonDeserializer<Matrix> {
- private static final Logger log = LoggerFactory.getLogger(JsonVectorAdapter.class);
+ private static final Logger log = LoggerFactory.getLogger(JsonMatrixAdapter.class);
public static final String CLASS = "class";
public static final String MATRIX = "matrix";
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Matrix.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Matrix.java?rev=785649&r1=785648&r2=785649&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Matrix.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/Matrix.java Wed Jun 17 15:08:29 2009
@@ -17,6 +17,8 @@
package org.apache.mahout.matrix;
+import java.util.Map;
+
import org.apache.hadoop.io.WritableComparable;
/**
@@ -303,6 +305,87 @@
*/
double zSum();
+ /**
+ * Return a map of the current column label bindings of the receiver
+ *
+ * @return a {@code Map<String, Integer>} from column label to column index
+ */
+ Map<String, Integer> getColumnLabelBindings();
+
+ /**
+ * Return a map of the current row label bindings of the receiver
+ *
+ * @return a {@code Map<String, Integer>} from row label to row index
+ */
+ Map<String, Integer> getRowLabelBindings();
+
+ /**
+ * Sets a map of column label bindings in the receiver
+ *
+ * @param bindings a {@code Map<String, Integer>} of label bindings
+ */
+ void setColumnLabelBindings(Map<String, Integer> bindings);
+
+ /**
+ * Sets a map of row label bindings in the receiver
+ *
+ * @param bindings a {@code Map<String, Integer>} of label bindings
+ */
+ void setRowLabelBindings(Map<String, Integer> bindings);
+
+ /**
+ * Return the value at the given labels
+ *
+ * @param rowLabel a String row label
+ * @param columnLabel a String column label
+ * @return the double at the indices the labels are bound to
+ * @throws IndexException if the index is out of bounds
+ * @throws UnboundLabelException if either label is not bound to an index
+ */
+ double get(String rowLabel, String columnLabel) throws IndexException,
+ UnboundLabelException;
+
+ /**
+ * Set the value at the given labels
+ *
+ * @param rowLabel a String row label
+ * @param columnLabel a String column label
+ * @param value a double value to set
+ * @throws IndexException if the index is out of bounds
+ * @throws UnboundLabelException if either label is not bound to an index
+ */
+ void set(String rowLabel, String columnLabel, double value) throws IndexException,
+ UnboundLabelException;
+
+ /**
+ * Set the value at the given index, updating the row and column label bindings
+ *
+ * @param rowLabel a String row label
+ * @param columnLabel a String column label
+ * @param row an int row index
+ * @param column an int column index
+ * @param value a double value
+ * @throws IndexException if the index is out of bounds
+ * @throws UnboundLabelException declared for implementations; the
+ * AbstractMatrix implementation binds both labels itself and does not
+ * raise it directly
+ */
+ void set(String rowLabel, String columnLabel, int row, int column, double value) throws IndexException,
+ UnboundLabelException;
+
+ /**
+ * Sets the row values at the given row label
+ *
+ * The AbstractMatrix implementation raises UnboundLabelException (not
+ * declared on this method) when the label cannot be resolved to a row index.
+ *
+ * @param rowLabel a String row label
+ * @param rowData a double[] array of row data
+ */
+ void set(String rowLabel, double[] rowData);
+
+ /**
+ * Sets the row values at the given row index and updates the row labels
+ *
+ * @param rowLabel the String row label
+ * @param row an int row index
+ * @param rowData a double[] array of row data
+ */
+ void set(String rowLabel, int row, double[] rowData);
+
/*
* Need stories for these but keeping them here for now.
*
Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/MatrixTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/MatrixTest.java?rev=785649&r1=785648&r2=785649&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/MatrixTest.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/MatrixTest.java Wed Jun 17 15:08:29 2009
@@ -17,6 +17,9 @@
package org.apache.mahout.matrix;
+import java.util.HashMap;
+import java.util.Map;
+
import junit.framework.TestCase;
public abstract class MatrixTest extends TestCase {
@@ -25,7 +28,8 @@
protected static final int COL = AbstractMatrix.COL;
- protected final double[][] values = {{1.1, 2.2}, {3.3, 4.4}, {5.5, 6.6}};
+ protected final double[][] values = { { 1.1, 2.2 }, { 3.3, 4.4 },
+ { 5.5, 6.6 } };
protected Matrix test;
@@ -54,7 +58,7 @@
for (int row = 0; row < c[ROW]; row++)
for (int col = 0; col < c[COL]; col++)
assertEquals("value[" + row + "][" + col + ']',
- test.getQuick(row, col), copy.getQuick(row, col));
+ test.getQuick(row, col), copy.getQuick(row, col));
}
public void testGetQuick() {
@@ -62,7 +66,7 @@
for (int row = 0; row < c[ROW]; row++)
for (int col = 0; col < c[COL]; col++)
assertEquals("value[" + row + "][" + col + ']', values[row][col], test
- .getQuick(row, col));
+ .getQuick(row, col));
}
public void testHaveSharedCells() {
@@ -90,7 +94,7 @@
for (int col = 0; col < c[COL]; col++) {
test.setQuick(row, col, 1.23);
assertEquals("value[" + row + "][" + col + ']', 1.23, test.getQuick(
- row, col));
+ row, col));
}
}
@@ -106,23 +110,23 @@
for (int row = 0; row < c[ROW]; row++)
for (int col = 0; col < c[COL]; col++)
assertEquals("value[" + row + "][" + col + ']', values[row][col],
- array[row][col]);
+ array[row][col]);
}
public void testViewPart() {
- int[] offset = {1, 1};
- int[] size = {2, 1};
+ int[] offset = { 1, 1 };
+ int[] size = { 2, 1 };
Matrix view = test.viewPart(offset, size);
int[] c = view.cardinality();
for (int row = 0; row < c[ROW]; row++)
for (int col = 0; col < c[COL]; col++)
assertEquals("value[" + row + "][" + col + ']',
- values[row + 1][col + 1], view.getQuick(row, col));
+ values[row + 1][col + 1], view.getQuick(row, col));
}
public void testViewPartCardinality() {
- int[] offset = {1, 1};
- int[] size = {3, 3};
+ int[] offset = { 1, 1 };
+ int[] size = { 3, 3 };
try {
test.viewPart(offset, size);
fail("exception expected");
@@ -134,8 +138,8 @@
}
public void testViewPartIndexOver() {
- int[] offset = {1, 1};
- int[] size = {2, 2};
+ int[] offset = { 1, 1 };
+ int[] size = { 2, 2 };
try {
test.viewPart(offset, size);
fail("exception expected");
@@ -147,8 +151,8 @@
}
public void testViewPartIndexUnder() {
- int[] offset = {-1, -1};
- int[] size = {2, 2};
+ int[] offset = { -1, -1 };
+ int[] size = { 2, 2 };
try {
test.viewPart(offset, size);
fail("exception expected");
@@ -165,7 +169,7 @@
for (int row = 0; row < c[ROW]; row++)
for (int col = 0; col < c[COL]; col++)
assertEquals("value[" + row + "][" + col + ']', 4.53, test.getQuick(
- row, col));
+ row, col));
}
public void testAssignDoubleArrayArray() {
@@ -174,7 +178,7 @@
for (int row = 0; row < c[ROW]; row++)
for (int col = 0; col < c[COL]; col++)
assertEquals("value[" + row + "][" + col + ']', 0.0, test.getQuick(row,
- col));
+ col));
}
public void testAssignDoubleArrayArrayCardinality() {
@@ -193,7 +197,7 @@
for (int row = 0; row < c[ROW]; row++)
for (int col = 0; col < c[COL]; col++)
assertEquals("value[" + row + "][" + col + ']', 2 * values[row][col],
- test.getQuick(row, col));
+ test.getQuick(row, col));
}
public void testAssignMatrixBinaryFunctionCardinality() {
@@ -212,7 +216,7 @@
for (int row = 0; row < c[ROW]; row++)
for (int col = 0; col < c[COL]; col++)
assertEquals("value[" + row + "][" + col + ']',
- test.getQuick(row, col), value.getQuick(row, col));
+ test.getQuick(row, col), value.getQuick(row, col));
}
public void testAssignMatrixCardinality() {
@@ -230,7 +234,7 @@
for (int row = 0; row < c[ROW]; row++)
for (int col = 0; col < c[COL]; col++)
assertEquals("value[" + row + "][" + col + ']', -values[row][col], test
- .getQuick(row, col));
+ .getQuick(row, col));
}
public void testDivide() {
@@ -239,7 +243,7 @@
for (int row = 0; row < c[ROW]; row++)
for (int col = 0; col < c[COL]; col++)
assertEquals("value[" + row + "][" + col + ']',
- values[row][col] / 4.53, value.getQuick(row, col));
+ values[row][col] / 4.53, value.getQuick(row, col));
}
public void testGet() {
@@ -247,7 +251,7 @@
for (int row = 0; row < c[ROW]; row++)
for (int col = 0; col < c[COL]; col++)
assertEquals("value[" + row + "][" + col + ']', values[row][col], test
- .get(row, col));
+ .get(row, col));
}
public void testGetIndexUnder() {
@@ -280,7 +284,7 @@
for (int row = 0; row < c[ROW]; row++)
for (int col = 0; col < c[COL]; col++)
assertEquals("value[" + row + "][" + col + ']', 0.0, value.getQuick(
- row, col));
+ row, col));
}
public void testMinusCardinality() {
@@ -298,7 +302,7 @@
for (int row = 0; row < c[ROW]; row++)
for (int col = 0; col < c[COL]; col++)
assertEquals("value[" + row + "][" + col + ']',
- values[row][col] + 4.53, value.getQuick(row, col));
+ values[row][col] + 4.53, value.getQuick(row, col));
}
public void testPlusMatrix() {
@@ -307,7 +311,7 @@
for (int row = 0; row < c[ROW]; row++)
for (int col = 0; col < c[COL]; col++)
assertEquals("value[" + row + "][" + col + ']', values[row][col] * 2,
- value.getQuick(row, col));
+ value.getQuick(row, col));
}
public void testPlusMatrixCardinality() {
@@ -351,7 +355,7 @@
for (int row = 0; row < c[ROW]; row++)
for (int col = 0; col < c[COL]; col++)
assertEquals("value[" + row + "][" + col + ']',
- values[row][col] * 4.53, value.getQuick(row, col));
+ values[row][col] * 4.53, value.getQuick(row, col));
}
public void testTimesMatrix() {
@@ -362,7 +366,7 @@
assertEquals("rows", c[ROW], v[ROW]);
assertEquals("cols", c[ROW], v[COL]);
// TODO: check the math too, lazy
- Matrix timestest = new DenseMatrix(10,1);
+ Matrix timestest = new DenseMatrix(10, 1);
/* will throw ArrayIndexOutOfBoundsException exception without MAHOUT-26 */
timestest.transpose().times(timestest);
}
@@ -386,7 +390,7 @@
for (int row = 0; row < c[ROW]; row++)
for (int col = 0; col < c[COL]; col++)
assertEquals("value[" + row + "][" + col + ']',
- test.getQuick(row, col), transpose.getQuick(col, row));
+ test.getQuick(row, col), transpose.getQuick(col, row));
}
public void testZSum() {
@@ -395,14 +399,14 @@
}
public void testAssignRow() {
- double[] data = {2.1, 3.2};
+ double[] data = { 2.1, 3.2 };
test.assignRow(1, new DenseVector(data));
assertEquals("test[1][0]", 2.1, test.getQuick(1, 0));
assertEquals("test[1][1]", 3.2, test.getQuick(1, 1));
}
public void testAssignRowCardinality() {
- double[] data = {2.1, 3.2, 4.3};
+ double[] data = { 2.1, 3.2, 4.3 };
try {
test.assignRow(1, new DenseVector(data));
fail("expecting cardinality exception");
@@ -412,7 +416,7 @@
}
public void testAssignColumn() {
- double[] data = {2.1, 3.2, 4.3};
+ double[] data = { 2.1, 3.2, 4.3 };
test.assignColumn(1, new DenseVector(data));
assertEquals("test[0][1]", 2.1, test.getQuick(0, 1));
assertEquals("test[1][1]", 3.2, test.getQuick(1, 1));
@@ -420,7 +424,7 @@
}
public void testAssignColumnCardinality() {
- double[] data = {2.1, 3.2};
+ double[] data = { 2.1, 3.2 };
try {
test.assignColumn(1, new DenseVector(data));
fail("expecting cardinality exception");
@@ -476,7 +480,8 @@
}
public void testDetermitant() {
- Matrix m = matrixFactory(new double[][] { {1,3,4},{5,2,3},{1,4,2} });
+ Matrix m = matrixFactory(new double[][] { { 1, 3, 4 }, { 5, 2, 3 },
+ { 1, 4, 2 } });
assertEquals("determinant", 43.0, m.determinant());
}
@@ -490,4 +495,69 @@
row, col));
}
+ /**
+  * Checks that explicitly supplied row and column bindings are stored,
+  * resolve labelled reads to the same cells as index-based reads, and route
+  * a labelled row assignment to the bound row.
+  */
+ public void testLabelBindings() {
+   Matrix m = matrixFactory(new double[][] { { 1, 3, 4 }, { 5, 2, 3 },
+       { 1, 4, 2 } });
+   assertNull("row bindings", m.getRowLabelBindings());
+   assertNull("col bindings", m.getColumnLabelBindings());
+   Map<String, Integer> rowBindings = new HashMap<String, Integer>();
+   rowBindings.put("Fee", 0);
+   rowBindings.put("Fie", 1);
+   rowBindings.put("Foe", 2);
+   m.setRowLabelBindings(rowBindings);
+   assertEquals("row", rowBindings, m.getRowLabelBindings());
+   Map<String, Integer> colBindings = new HashMap<String, Integer>();
+   colBindings.put("Foo", 0);
+   colBindings.put("Bar", 1);
+   colBindings.put("Baz", 2);
+   m.setColumnLabelBindings(colBindings);
+   // Verify the column bindings just set; the row bindings were already
+   // checked above.
+   assertEquals("col", colBindings, m.getColumnLabelBindings());
+   assertEquals("Fee", m.get(0, 1), m.get("Fee", "Bar"));
+
+   double[] newrow = { 9, 8, 7 };
+   m.set("Foe", newrow);
+   assertEquals("FeeBaz", m.get(0, 2), m.get("Fee", "Baz"));
+   // "Foe" is bound to row 2, so the labelled set must have written there.
+   for (int col = 0; col < newrow.length; col++)
+     assertEquals("Foe[" + col + ']', newrow[col], m.get(2, col));
+ }
+
+ /**
+  * Checks that the five-argument set creates both binding maps on demand,
+  * records the supplied indices, and that reading through unbound labels
+  * raises UnboundLabelException.
+  */
+ public void testSettingLabelBindings() {
+   Matrix m = matrixFactory(new double[][] { { 1, 3, 4 }, { 5, 2, 3 },
+       { 1, 4, 2 } });
+   assertNull("row bindings", m.getRowLabelBindings());
+   assertNull("col bindings", m.getColumnLabelBindings());
+   m.set("Fee", "Foo", 1, 2, 9);
+   assertNotNull("row", m.getRowLabelBindings());
+   // The original asserted the row bindings twice and never null-checked
+   // the column bindings before dereferencing them below.
+   assertNotNull("col", m.getColumnLabelBindings());
+   assertEquals("Fee", 1, m.getRowLabelBindings().get("Fee").intValue());
+   assertEquals("Foo", 2, m.getColumnLabelBindings().get("Foo").intValue());
+   assertEquals("FeeFoo", m.get(1, 2), m.get("Fee", "Foo"));
+   try {
+     m.get("Fie", "Foe");
+     fail("Expected UnboundLabelException");
+   } catch (IndexException e) {
+     fail("Expected UnboundLabelException");
+   } catch (UnboundLabelException e) {
+     // expected: neither "Fie" nor "Foe" has been bound
+   }
+ }
+
+ /**
+  * Round-trips a labelled matrix through its format string and checks that
+  * a labelled read on the decoded matrix matches an index-based read on the
+  * original.
+  */
+ public void testLabelBindingSerialization() {
+   double[][] cells = { { 1, 3, 4 }, { 5, 2, 3 }, { 1, 4, 2 } };
+   Matrix original = matrixFactory(cells);
+   assertNull("row bindings", original.getRowLabelBindings());
+   assertNull("col bindings", original.getColumnLabelBindings());
+
+   Map<String, Integer> rowLabels = new HashMap<String, Integer>();
+   rowLabels.put("Fee", 0);
+   rowLabels.put("Fie", 1);
+   rowLabels.put("Foe", 2);
+   original.setRowLabelBindings(rowLabels);
+   assertEquals("row", rowLabels, original.getRowLabelBindings());
+
+   Map<String, Integer> columnLabels = new HashMap<String, Integer>();
+   columnLabels.put("Foo", 0);
+   columnLabels.put("Bar", 1);
+   columnLabels.put("Baz", 2);
+   original.setColumnLabelBindings(columnLabels);
+
+   // Encode, decode, then read the decoded copy through its labels.
+   String encoded = original.asFormatString();
+   Matrix decoded = AbstractMatrix.decodeMatrix(encoded);
+   assertEquals("Fee", original.get(0, 1), decoded.get("Fee", "Bar"));
+ }
}