You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sr...@apache.org on 2011/10/03 09:51:04 UTC
svn commit: r1178324 - in /mahout/trunk/core/src:
main/java/org/apache/mahout/classifier/ main/java/org/apache/mahout/math/
test/java/org/apache/mahout/classifier/ test/java/org/apache/mahout/math/
Author: srowen
Date: Mon Oct 3 07:51:03 2011
New Revision: 1178324
URL: http://svn.apache.org/viewvc?rev=1178324&view=rev
Log:
MAHOUT-812 help make confusion matrix writable
Added:
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java
mahout/trunk/core/src/test/java/org/apache/mahout/math/VectorWritableTest.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java?rev=1178324&r1=1178323&r2=1178324&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java Mon Oct 3 07:51:03 2011
@@ -1,9 +1,9 @@
/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
+ * Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
@@ -17,12 +17,18 @@
package org.apache.mahout.classifier;
+import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
+import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
+import com.google.common.collect.Maps;
import org.apache.commons.lang.StringUtils;
+import org.apache.mahout.math.CardinalityException;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.Matrix;
import com.google.common.base.Preconditions;
@@ -127,6 +133,46 @@ public class ConfusionMatrix {
return this;
}
+  /**
+   * Exports the current counts as a new square {@link DenseMatrix}. Cell {@code (row, col)} of
+   * the result holds {@code confusionMatrix[row][col]}, and a copy of {@code labelMap} is bound
+   * to both the rows and the columns.
+   *
+   * @return a new matrix carrying the counts and label bindings of this confusion matrix
+   */
+  public Matrix getMatrix() {
+    int size = confusionMatrix.length;
+    Matrix result = new DenseMatrix(size, size);
+    for (int row = 0; row < size; row++) {
+      int[] counts = confusionMatrix[row];
+      for (int col = 0; col < size; col++) {
+        result.set(row, col, counts[col]);
+      }
+    }
+    // One snapshot of the labels is shared by the row and the column bindings.
+    Map<String,Integer> bindings = Maps.newHashMap();
+    bindings.putAll(labelMap);
+    result.setRowLabelBindings(bindings);
+    result.setColumnLabelBindings(bindings);
+    return result;
+  }
+
+  /**
+   * Replaces the counts, and when available the labels, of this confusion matrix with those of
+   * the given matrix.
+   *
+   * <p>Every cell is rounded to the nearest integer. Labels are taken from the row bindings of
+   * {@code m}, falling back to its column bindings. When {@code m} carries no bindings at all,
+   * the existing labels are kept instead of being discarded.</p>
+   *
+   * @param m square matrix with as many rows as this confusion matrix
+   * @throws CardinalityException if {@code m} is not square, or if its size differs from the
+   *         size of this confusion matrix
+   */
+  public void setMatrix(Matrix m) {
+    int length = confusionMatrix.length;
+    if (m.numRows() != m.numCols()) {
+      throw new CardinalityException(m.numRows(), m.numCols());
+    }
+    if (m.numRows() != length) {
+      throw new CardinalityException(m.numRows(), length);
+    }
+    for (int r = 0; r < length; r++) {
+      for (int c = 0; c < length; c++) {
+        confusionMatrix[r][c] = (int) Math.round(m.get(r, c));
+      }
+    }
+    Map<String,Integer> labels = m.getRowLabelBindings();
+    if (labels == null) {
+      labels = m.getColumnLabelBindings();
+    }
+    // Swap the labels only when the matrix actually supplies some: clearing unconditionally
+    // would leave this confusion matrix without any labels after loading an unlabeled matrix.
+    if (labels != null) {
+      labelMap.clear();
+      labelMap.putAll(labels);
+    }
+  }
+
@Override
public String toString() {
StringBuilder returnString = new StringBuilder(200);
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java?rev=1178324&r1=1178323&r2=1178324&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/MatrixWritable.java Mon Oct 3 07:51:03 2011
@@ -18,19 +18,23 @@
package org.apache.mahout.math;
import com.google.common.base.Preconditions;
+import com.google.common.collect.Maps;
import org.apache.hadoop.io.Writable;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
+import java.util.HashMap;
import java.util.Map;
public class MatrixWritable implements Writable {
+ private static final int FLAG_DENSE = 0x01; // header bit tested when reading to decide the "dense" case
+ private static final int FLAG_SEQUENTIAL = 0x02; // header bit set on write when the row is sequential-access
+ private static final int FLAG_LABELS = 0x04; // header bit set on write when row or column label bindings exist
+ private static final int NUM_FLAGS = 3; // number of defined flag bits; any higher bit is rejected on read
+
private Matrix matrix;
- private static final int NUM_FLAGS = 2;
- private static final int FLAG_DENSE = 1;
- private static final int FLAG_SEQUENTIAL = 2;
public MatrixWritable() {
}
@@ -103,6 +107,7 @@ public class MatrixWritable implements W
Preconditions.checkArgument(flags >> NUM_FLAGS == 0, "Unknown flags set: %d", Integer.toString(flags, 2));
boolean dense = (flags & FLAG_DENSE) != 0;
boolean sequential = (flags & FLAG_SEQUENTIAL) != 0;
+ boolean hasLabels = (flags & FLAG_LABELS) != 0;
int rows = in.readInt();
int columns = in.readInt();
@@ -118,6 +123,18 @@ public class MatrixWritable implements W
r.viewRow(row).assign(VectorWritable.readVector(in));
}
+ if (hasLabels) {
+ Map<String,Integer> columnLabelBindings = Maps.newHashMap();
+ Map<String,Integer> rowLabelBindings = Maps.newHashMap();
+ readLabels(in, columnLabelBindings, rowLabelBindings);
+ if (!columnLabelBindings.isEmpty()) {
+ r.setColumnLabelBindings(columnLabelBindings);
+ }
+ if (!rowLabelBindings.isEmpty()) {
+ r.setRowLabelBindings(rowLabelBindings);
+ }
+ }
+
return r;
}
@@ -131,6 +148,9 @@ public class MatrixWritable implements W
if (row.isSequentialAccess()) {
flags |= FLAG_SEQUENTIAL;
}
+ if (matrix.getRowLabelBindings() != null || matrix.getColumnLabelBindings() != null) {
+ flags |= FLAG_LABELS;
+ }
out.writeInt(flags);
out.writeInt(matrix.rowSize());
@@ -139,5 +159,8 @@ public class MatrixWritable implements W
for (int i = 0; i < matrix.rowSize(); i++) {
VectorWritable.writeVector(out, matrix.viewRow(i), false);
}
+ if ((flags & FLAG_LABELS) != 0) {
+ writeLabelBindings(out, matrix.getColumnLabelBindings(), matrix.getRowLabelBindings());
+ }
}
}
Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java?rev=1178324&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java Mon Oct 3 07:51:03 2011
@@ -0,0 +1,96 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Map;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.Matrix;
+import org.junit.Test;
+
+/**
+ * Tests for {@link ConfusionMatrix}: direct access to the count table, per-label accuracy and
+ * the export to a {@link Matrix}.
+ */
+public final class ConfusionMatrixTest extends MahoutTestCase {
+
+  /** Counts written into the top-left 2x2 corner of the confusion matrix by {@code fillCM}. */
+  private static final int[][] VALUES = {{2, 3}, {10, 20}};
+  private static final String[] LABELS = {"Label1", "Label2"};
+  private static final String DEFAULT_LABEL = "other";
+
+  @Test
+  public void testBuild() {
+    ConfusionMatrix cm = fillCM(VALUES, LABELS, DEFAULT_LABEL);
+    checkValues(cm);
+    checkAccuracy(cm);
+  }
+
+  /** The exported matrix must mirror the counts and carry the labels on both axes. */
+  @Test
+  public void testGetMatrix() {
+    ConfusionMatrix cm = fillCM(VALUES, LABELS, DEFAULT_LABEL);
+    Matrix m = cm.getMatrix();
+    int size = cm.getLabels().size();
+    assertEquals(size, m.numRows());
+    assertEquals(size, m.numCols());
+    checkBindings(m.getRowLabelBindings());
+    checkBindings(m.getColumnLabelBindings());
+    // Check the exported cells themselves, not just the source confusion matrix.
+    int[][] counts = cm.getConfusionMatrix();
+    for (int r = 0; r < counts.length; r++) {
+      for (int c = 0; c < counts[r].length; c++) {
+        assertEquals(counts[r][c], m.get(r, c), EPSILON);
+      }
+    }
+    assertEquals(2, cm.getCorrect(LABELS[0]));
+    assertEquals(20, cm.getCorrect(LABELS[1]));
+    assertEquals(0, cm.getCorrect(DEFAULT_LABEL));
+  }
+
+  /** Asserts that the bindings exist and contain both test labels and the default label. */
+  private static void checkBindings(Map<String, Integer> bindings) {
+    assertNotNull(bindings);
+    assertTrue(bindings.containsKey(LABELS[0]));
+    assertTrue(bindings.containsKey(LABELS[1]));
+    assertTrue(bindings.containsKey(DEFAULT_LABEL));
+  }
+
+  /** Asserts the 3x3 count table: {@code VALUES} in the corner, zeros for the default label. */
+  private static void checkValues(ConfusionMatrix cm) {
+    int[][] counts = cm.getConfusionMatrix();
+    cm.toString(); // smoke check only: rendering must not throw
+    assertEquals(counts.length, counts[0].length);
+    assertEquals(3, counts.length);
+    assertEquals(VALUES[0][0], counts[0][0]);
+    assertEquals(VALUES[0][1], counts[0][1]);
+    assertEquals(VALUES[1][0], counts[1][0]);
+    assertEquals(VALUES[1][1], counts[1][1]);
+    assertTrue(Arrays.equals(new int[3], counts[2])); // zeros
+    assertEquals(0, counts[0][2]);
+    assertEquals(0, counts[1][2]);
+    assertEquals(3, cm.getLabels().size());
+    assertTrue(cm.getLabels().contains(LABELS[0]));
+    assertTrue(cm.getLabels().contains(LABELS[1]));
+    assertTrue(cm.getLabels().contains(DEFAULT_LABEL));
+  }
+
+  /** Asserts the per-label accuracy; the never-filled default label yields NaN. */
+  private static void checkAccuracy(ConfusionMatrix cm) {
+    Collection<String> labelNames = cm.getLabels();
+    assertEquals(3, labelNames.size());
+    assertEquals(40.0, cm.getAccuracy("Label1"), EPSILON);
+    assertEquals(66.666666667, cm.getAccuracy("Label2"), EPSILON);
+    assertTrue(Double.isNaN(cm.getAccuracy("other")));
+  }
+
+  /**
+   * Builds a confusion matrix over the first two {@code labels} plus {@code defaultLabel} and
+   * writes the 2x2 {@code values} directly into its count table.
+   */
+  private static ConfusionMatrix fillCM(int[][] values, String[] labels, String defaultLabel) {
+    Collection<String> labelList = new ArrayList<String>();
+    labelList.add(labels[0]);
+    labelList.add(labels[1]);
+    ConfusionMatrix cm = new ConfusionMatrix(labelList, defaultLabel);
+    int[][] v = cm.getConfusionMatrix();
+    v[0][0] = values[0][0];
+    v[0][1] = values[0][1];
+    v[1][0] = values[1][0];
+    v[1][1] = values[1][1];
+    return cm;
+  }
+
+}
Added: mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java?rev=1178324&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/math/MatrixWritableTest.java Mon Oct 3 07:51:03 2011
@@ -0,0 +1,120 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import com.google.common.io.Closeables;
+import org.apache.hadoop.io.Writable;
+import org.junit.Test;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+public final class MatrixWritableTest extends MahoutTestCase {
+
+ @Test
+ public void testSparseMatrixWritable() throws Exception {
+ Matrix m = new SparseMatrix(5, 5);
+ m.set(1, 2, 3.0);
+ m.set(3, 4, 5.0);
+ Map<String, Integer> bindings = new HashMap<String, Integer>();
+ bindings.put("A", 0);
+ bindings.put("B", 1);
+ bindings.put("C", 2);
+ bindings.put("D", 3);
+ bindings.put("default", 4);
+ m.setRowLabelBindings(bindings);
+ doTestMatrixWritableEquals(m);
+ }
+
+ @Test
+ public void testDenseMatrixWritable() throws Exception {
+ Matrix m = new DenseMatrix(5,5);
+ m.set(1, 2, 3.0);
+ m.set(3, 4, 5.0);
+ Map<String, Integer> bindings = new HashMap<String, Integer>();
+ bindings.put("A", 0);
+ bindings.put("B", 1);
+ bindings.put("C", 2);
+ bindings.put("D", 3);
+ bindings.put("default", 4);
+ m.setColumnLabelBindings(bindings);
+ doTestMatrixWritableEquals(m);
+ }
+
+ private static void doTestMatrixWritableEquals(Matrix m) throws IOException {
+ Writable matrixWritable = new MatrixWritable(m);
+ MatrixWritable matrixWritable2 = new MatrixWritable();
+ writeAndRead(matrixWritable, matrixWritable2);
+ Matrix m2 = matrixWritable2.get();
+ compareMatrices(m, m2); // not sure this works?
+ }
+
+ private static void compareMatrices(Matrix m, Matrix m2) {
+ assertEquals(m.numRows(), m2.numRows());
+ assertEquals(m.numCols(), m2.numCols());
+ for(int r = 0; r < m.numRows(); r++) {
+ for(int c = 0; c < m.numCols(); c++) {
+ assertEquals(m.get(r, c), m2.get(r, c), EPSILON);
+ }
+ }
+ Map<String,Integer> bindings = m.getRowLabelBindings();
+ Map<String, Integer> bindings2 = m2.getRowLabelBindings();
+ assertEquals(bindings == null, bindings2 == null);
+ if (bindings != null) {
+ assertEquals(bindings.size(), m.numRows());
+ assertEquals(bindings.size(), bindings2.size());
+ for(Map.Entry<String,Integer> entry : bindings.entrySet()) {
+ assertEquals(entry.getValue(), bindings2.get(entry.getKey()));
+ }
+ }
+ bindings = m.getColumnLabelBindings();
+ bindings2 = m2.getColumnLabelBindings();
+ assertEquals(bindings == null, bindings2 == null);
+ if (bindings != null) {
+ assertEquals(bindings.size(), bindings2.size());
+ for(Map.Entry<String,Integer> entry : bindings.entrySet()) {
+ assertEquals(entry.getValue(), bindings2.get(entry.getKey()));
+ }
+ }
+ }
+
+ private static void writeAndRead(Writable toWrite, Writable toRead) throws IOException {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ DataOutputStream dos = new DataOutputStream(baos);
+ try {
+ toWrite.write(dos);
+ } finally {
+ Closeables.closeQuietly(dos);
+ }
+
+ ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray());
+ DataInputStream dis = new DataInputStream(bais);
+ try {
+ toRead.readFields(dis);
+ } finally {
+ Closeables.closeQuietly(dis);
+ }
+ }
+
+
+}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/math/VectorWritableTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/VectorWritableTest.java?rev=1178324&r1=1178323&r2=1178324&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/math/VectorWritableTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/math/VectorWritableTest.java Mon Oct 3 07:51:03 2011
@@ -53,11 +53,27 @@ public final class VectorWritableTest ex
doTestVectorWritableEquals(v);
}
+  /** Checks the write/read round trip for a {@link NamedVector} wrapping a dense vector. */
+  @Test
+  public void testNamedVectorWritable() throws Exception {
+    Vector named = new NamedVector(new DenseVector(5), "Victor");
+    named.set(1, 3.0);
+    named.set(3, 5.0);
+    doTestVectorWritableEquals(named);
+  }
+
private static void doTestVectorWritableEquals(Vector v) throws IOException {
Writable vectorWritable = new VectorWritable(v);
VectorWritable vectorWritable2 = new VectorWritable();
writeAndRead(vectorWritable, vectorWritable2);
Vector v2 = vectorWritable2.get();
+ if (v instanceof NamedVector) {
+ assertTrue(v2 instanceof NamedVector);
+ NamedVector nv = (NamedVector) v;
+ NamedVector nv2 = (NamedVector) v2;
+ assertEquals(nv.getName(), nv2.getName());
+ assertEquals("Victor", nv.getName());
+ }
assertEquals(v, v2);
}