You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sr...@apache.org on 2011/10/03 09:51:04 UTC

svn commit: r1178324 - in /mahout/trunk/core/src: main/java/org/apache/mahout/classifier/ main/java/org/apache/mahout/math/ test/java/org/apache/mahout/classifier/ test/java/org/apache/mahout/math/

Author: srowen
Date: Mon Oct  3 07:51:03 2011
New Revision: 1178324

URL: http://svn.apache.org/viewvc?rev=1178324&view=rev
Log:
MAHOUT-812 help make confusion matrix writable

Added:
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java
Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
    mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java
    mahout/trunk/core/src/test/java/org/apache/mahout/math/VectorWritableTest.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java?rev=1178324&r1=1178323&r2=1178324&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java Mon Oct  3 07:51:03 2011
@@ -1,9 +1,9 @@
 /**
- * Licensed to the Apache Software Foundation (ASF) under one or more
+ * Licensed to the Apache Software Foundation (ASF) under one or more
  * contributor license agreements.  See the NOTICE file distributed with
  * this work for additional information regarding copyright ownership.
  * The ASF licenses this file to You under the Apache License, Version 2.0
  * (the "License"); you may not use this file except in compliance with
  * the License.  You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
@@ -17,12 +17,18 @@
 
 package org.apache.mahout.classifier;
 
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.Map;
 
+import com.google.common.collect.Maps;
 import org.apache.commons.lang.StringUtils;
+import org.apache.mahout.math.CardinalityException;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.Matrix;
 
 import com.google.common.base.Preconditions;
 
@@ -127,6 +133,46 @@ public class ConfusionMatrix {
     return this;
   }
   
+  public Matrix getMatrix() {
+	  int length = confusionMatrix.length;
+	  Matrix m = new DenseMatrix(length, length);
+	  for (int r = 0; r < length; r++) {
+		  for (int c = 0; c < length; c++) {
+			  m.set(r, c, confusionMatrix[r][c]);
+		  }
+	  }
+	  Map<String,Integer> labels = Maps.newHashMap();
+	  for (Map.Entry<String, Integer> entry : labelMap.entrySet()) {
+		  labels.put(entry.getKey(), entry.getValue());
+	  }
+	  m.setRowLabelBindings(labels);
+	  m.setColumnLabelBindings(labels);
+	  return m;
+  }
+
+  public void setMatrix(Matrix m) {
+	  int length = confusionMatrix.length;
+	  if (m.numRows() != m.numCols()) {
+      throw new CardinalityException(m.numRows(), m.numCols());
+    }
+    if (m.numRows() != length) {
+      throw new CardinalityException(m.numRows(), length);
+    }
+	  for (int r = 0; r < length; r++) {
+		  for (int c = 0; c < length; c++) {
+			  confusionMatrix[r][c] = (int) Math.round(m.get(r, c));
+		  }
+	  }
+	  Map<String,Integer> labels = m.getRowLabelBindings();
+	  if (labels == null) {
+      labels = m.getColumnLabelBindings();
+    }
+    labelMap.clear();    
+	  if (labels != null) {
+      labelMap.putAll(labels);
+	  }
+  }
+  
   @Override
   public String toString() {
     StringBuilder returnString = new StringBuilder(200);

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java?rev=1178324&r1=1178323&r2=1178324&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java Mon Oct  3 07:51:03 2011
@@ -18,19 +18,23 @@
 package org.apache.mahout.math;
 
 import com.google.common.base.Preconditions;
+import com.google.common.collect.Maps;
 import org.apache.hadoop.io.Writable;
 
 import java.io.DataInput;
 import java.io.DataOutput;
 import java.io.IOException;
+import java.util.HashMap;
 import java.util.Map;
 
 public class MatrixWritable implements Writable {
 
+  private static final int FLAG_DENSE = 0x01;
+  private static final int FLAG_SEQUENTIAL = 0x02;
+  private static final int FLAG_LABELS = 0x04;
+  private static final int NUM_FLAGS = 3;
+
   private Matrix matrix;
-  private static final int NUM_FLAGS = 2;
-  private static final int FLAG_DENSE = 1;
-  private static final int FLAG_SEQUENTIAL = 2;
 
   public MatrixWritable() {
   }
@@ -103,6 +107,7 @@ public class MatrixWritable implements W
     Preconditions.checkArgument(flags >> NUM_FLAGS == 0, "Unknown flags set: %d", Integer.toString(flags, 2));
     boolean dense = (flags & FLAG_DENSE) != 0;
     boolean sequential = (flags & FLAG_SEQUENTIAL) != 0;
+    boolean hasLabels = (flags & FLAG_LABELS) != 0;
 
     int rows = in.readInt();
     int columns = in.readInt();
@@ -118,6 +123,18 @@ public class MatrixWritable implements W
       r.viewRow(row).assign(VectorWritable.readVector(in));
     }
 
+    if (hasLabels) {
+    	Map<String,Integer> columnLabelBindings = Maps.newHashMap();
+    	Map<String,Integer> rowLabelBindings = Maps.newHashMap();
+    	readLabels(in, columnLabelBindings, rowLabelBindings);
+    	if (!columnLabelBindings.isEmpty()) {
+    		r.setColumnLabelBindings(columnLabelBindings);
+    	}
+    	if (!rowLabelBindings.isEmpty()) {
+    		r.setRowLabelBindings(rowLabelBindings);
+    	}
+    }
+
     return r;
   }
 
@@ -131,6 +148,9 @@ public class MatrixWritable implements W
     if (row.isSequentialAccess()) {
       flags |= FLAG_SEQUENTIAL;
     }
+    if (matrix.getRowLabelBindings() != null || matrix.getColumnLabelBindings() != null) {
+      flags |= FLAG_LABELS;
+    }
     out.writeInt(flags);
 
     out.writeInt(matrix.rowSize());
@@ -139,5 +159,8 @@ public class MatrixWritable implements W
     for (int i = 0; i < matrix.rowSize(); i++) {
       VectorWritable.writeVector(out, matrix.viewRow(i), false);
     }
+    if ((flags & FLAG_LABELS) != 0) {
+    	writeLabelBindings(out, matrix.getColumnLabelBindings(), matrix.getRowLabelBindings());
+    }
   }
 }

Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java?rev=1178324&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java Mon Oct  3 07:51:03 2011
@@ -0,0 +1,96 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Map;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.Matrix;
+import org.junit.Test;
+
+public final class ConfusionMatrixTest extends MahoutTestCase {
+
+  private static final int[][] VALUES = {{2, 3}, {10, 20}};
+  private static final String[] LABELS = {"Label1", "Label2"};
+  private static final String DEFAULT_LABEL = "other";
+  
+  @Test
+  public void testBuild() {
+    ConfusionMatrix cm = fillCM(VALUES, LABELS, DEFAULT_LABEL);
+    checkValues(cm);
+    checkAccuracy(cm);
+  }
+
+  @Test
+  public void testGetMatrix() {
+	    ConfusionMatrix cm = fillCM(VALUES, LABELS, DEFAULT_LABEL);
+	    Matrix m = cm.getMatrix();
+	    Map<String, Integer> rowLabels = m.getRowLabelBindings();
+	    assertEquals(cm.getLabels().size(), m.numCols());
+	    assertTrue(rowLabels.keySet().contains(LABELS[0]));
+	    assertTrue(rowLabels.keySet().contains(LABELS[1]));
+	    assertTrue(rowLabels.keySet().contains(DEFAULT_LABEL));
+	    assertEquals(2, cm.getCorrect(LABELS[0]));
+	    assertEquals(20, cm.getCorrect(LABELS[1]));
+	    assertEquals(0, cm.getCorrect(DEFAULT_LABEL));
+  }
+
+  private static void checkValues(ConfusionMatrix cm) {
+    int[][] counts = cm.getConfusionMatrix();
+    cm.toString();
+    assertEquals(counts.length, counts[0].length);
+    assertEquals(3, counts.length);
+    assertEquals(VALUES[0][0], counts[0][0]);
+    assertEquals(VALUES[0][1], counts[0][1]);
+    assertEquals(VALUES[1][0], counts[1][0]);
+    assertEquals(VALUES[1][1], counts[1][1]);
+    assertTrue(Arrays.equals(new int[3], counts[2])); // zeros
+    assertEquals(0, counts[0][2]);
+    assertEquals(0, counts[1][2]);
+    assertEquals(3, cm.getLabels().size());
+    assertTrue(cm.getLabels().contains(LABELS[0]));
+    assertTrue(cm.getLabels().contains(LABELS[1]));
+    assertTrue(cm.getLabels().contains(DEFAULT_LABEL));
+
+  }
+
+  private static void checkAccuracy(ConfusionMatrix cm) {
+    Collection<String> labelstrs = cm.getLabels();
+    assertEquals(3, labelstrs.size());
+    assertEquals(40.0, cm.getAccuracy("Label1"), EPSILON);
+    assertEquals(66.666666667, cm.getAccuracy("Label2"), EPSILON);
+    assertTrue(Double.isNaN(cm.getAccuracy("other")));
+  }
+  
+  private static ConfusionMatrix fillCM(int[][] values, String[] labels, String defaultLabel) {
+    Collection<String> labelList = new ArrayList<String>();
+    labelList.add(labels[0]);
+    labelList.add(labels[1]);
+    ConfusionMatrix cm = new ConfusionMatrix(labelList, defaultLabel);
+    int[][] v = cm.getConfusionMatrix();
+    v[0][0] = values[0][0];
+    v[0][1] = values[0][1];
+    v[1][0] = values[1][0];
+    v[1][1] = values[1][1];
+    return cm;
+  }
+  
+}

Added: mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java?rev=1178324&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java Mon Oct  3 07:51:03 2011
@@ -0,0 +1,120 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import com.google.common.io.Closeables;
+import org.apache.hadoop.io.Writable;
+import org.junit.Test;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+public final class MatrixWritableTest extends MahoutTestCase {
+
+	@Test
+	public void testSparseMatrixWritable() throws Exception {
+		Matrix m = new SparseMatrix(5, 5);
+		m.set(1, 2, 3.0);
+		m.set(3, 4, 5.0);
+		Map<String, Integer> bindings = new HashMap<String, Integer>();
+		bindings.put("A", 0);
+		bindings.put("B", 1);
+		bindings.put("C", 2);
+		bindings.put("D", 3);
+		bindings.put("default", 4);
+		m.setRowLabelBindings(bindings);
+		doTestMatrixWritableEquals(m);
+	}
+
+	@Test
+	public void testDenseMatrixWritable() throws Exception {
+		Matrix m = new DenseMatrix(5,5);
+		m.set(1, 2, 3.0);
+		m.set(3, 4, 5.0);
+		Map<String, Integer> bindings = new HashMap<String, Integer>();
+		bindings.put("A", 0);
+		bindings.put("B", 1);
+		bindings.put("C", 2);
+		bindings.put("D", 3);
+		bindings.put("default", 4);
+		m.setColumnLabelBindings(bindings);
+		doTestMatrixWritableEquals(m);
+	}
+
+	private static void doTestMatrixWritableEquals(Matrix m) throws IOException {
+		Writable matrixWritable = new MatrixWritable(m);
+		MatrixWritable matrixWritable2 = new MatrixWritable();
+		writeAndRead(matrixWritable, matrixWritable2);
+		Matrix m2 = matrixWritable2.get();
+		compareMatrices(m, m2);  // not sure this works?
+	}
+
+	private static void compareMatrices(Matrix m, Matrix m2) {
+		assertEquals(m.numRows(), m2.numRows());
+		assertEquals(m.numCols(), m2.numCols());
+		for(int r = 0; r < m.numRows(); r++) {
+			for(int c = 0; c < m.numCols(); c++) {
+				assertEquals(m.get(r, c), m2.get(r, c), EPSILON);
+			}
+		}
+		Map<String,Integer> bindings = m.getRowLabelBindings();
+		Map<String, Integer> bindings2 = m2.getRowLabelBindings();
+		assertEquals(bindings == null, bindings2 == null);
+		if (bindings != null) {
+			assertEquals(bindings.size(), m.numRows());
+			assertEquals(bindings.size(), bindings2.size());
+			for(Map.Entry<String,Integer> entry : bindings.entrySet()) {
+				assertEquals(entry.getValue(), bindings2.get(entry.getKey()));
+			}
+		}
+		bindings = m.getColumnLabelBindings();
+		bindings2 = m2.getColumnLabelBindings();
+		assertEquals(bindings == null, bindings2 == null);
+		if (bindings != null) {
+			assertEquals(bindings.size(), bindings2.size());
+			for(Map.Entry<String,Integer> entry : bindings.entrySet()) {
+				assertEquals(entry.getValue(), bindings2.get(entry.getKey()));
+			}
+		}
+	}
+
+	private static void writeAndRead(Writable toWrite, Writable toRead) throws IOException {
+		ByteArrayOutputStream baos = new ByteArrayOutputStream();
+		DataOutputStream dos = new DataOutputStream(baos);
+		try {
+			toWrite.write(dos);
+		} finally {
+			Closeables.closeQuietly(dos);
+		}
+
+		ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray());
+		DataInputStream dis = new DataInputStream(bais);
+		try {
+			toRead.readFields(dis);
+		} finally {
+			Closeables.closeQuietly(dis);
+		}
+	}
+
+
+}

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/math/VectorWritableTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/VectorWritableTest.java?rev=1178324&r1=1178323&r2=1178324&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/math/VectorWritableTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/math/VectorWritableTest.java Mon Oct  3 07:51:03 2011
@@ -53,11 +53,27 @@ public final class VectorWritableTest ex
     doTestVectorWritableEquals(v);
   }
 
+  @Test
+  public void testNamedVectorWritable() throws Exception {
+    Vector v = new DenseVector(5);
+    v = new NamedVector(v, "Victor");
+    v.set(1, 3.0);
+    v.set(3, 5.0);
+    doTestVectorWritableEquals(v);
+  }
+
   private static void doTestVectorWritableEquals(Vector v) throws IOException {
     Writable vectorWritable = new VectorWritable(v);
     VectorWritable vectorWritable2 = new VectorWritable();
     writeAndRead(vectorWritable, vectorWritable2);
     Vector v2 = vectorWritable2.get();
+    if (v instanceof NamedVector) {
+    	assertTrue(v2 instanceof NamedVector);
+    	NamedVector nv = (NamedVector) v;
+    	NamedVector nv2 = (NamedVector) v2;
+    	assertEquals(nv.getName(), nv2.getName());
+    	assertEquals("Victor", nv.getName());
+    }
     assertEquals(v, v2);
   }