You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by to...@apache.org on 2015/05/12 19:05:41 UTC

svn commit: r1679006 - in /lucene/dev/trunk/lucene/classification/src: java/org/apache/lucene/classification/utils/ test/org/apache/lucene/classification/ test/org/apache/lucene/classification/utils/

Author: tommaso
Date: Tue May 12 17:05:41 2015
New Revision: 1679006

URL: http://svn.apache.org/r1679006
Log:
LUCENE-6479 - added ConfusionMatrixGenerator

Added:
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java   (with props)
    lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java   (with props)
Modified:
    lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java

Added: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java?rev=1679006&view=auto
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java (added)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java Tue May 12 17:05:41 2015
@@ -0,0 +1,111 @@
+package org.apache.lucene.classification.utils;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.lucene.classification.ClassificationResult;
+import org.apache.lucene.classification.Classifier;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.StoredDocument;
+import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
+
+/**
+ * Utility class to generate the confusion matrix of a {@link Classifier}
+ */
+public class ConfusionMatrixGenerator {
+
+  private ConfusionMatrixGenerator() {
+    // static utility class: no instances
+  }
+
+  /**
+   * get the {@link org.apache.lucene.classification.utils.ConfusionMatrixGenerator.ConfusionMatrix} of a given {@link Classifier},
+   * generated on the given {@link LeafReader}, class and text fields.
+   * <p>
+   * Only live (non-deleted) documents are considered. A document is classified when it has a non-empty value
+   * in {@code classFieldName} and a value in {@code textFieldName}. Documents for which the classifier returns
+   * no result, or a result without an assigned class, are skipped.
+   *
+   * @param reader         the {@link LeafReader} containing the index used for creating the {@link Classifier}
+   * @param classifier     the {@link Classifier} whose confusion matrix has to be generated
+   * @param classFieldName the name of the Lucene field used as the classifier's output
+   * @param textFieldName  the name of the Lucene field used as the classifier's input
+   * @param <T>            the return type of the {@link ClassificationResult} returned by the given {@link Classifier}
+   * @return a {@link org.apache.lucene.classification.utils.ConfusionMatrixGenerator.ConfusionMatrix}
+   * @throws IOException if problems occur while reading the index or using the classifier
+   */
+  public static <T> ConfusionMatrix getConfusionMatrix(LeafReader reader, Classifier<T> classifier, String classFieldName, String textFieldName) throws IOException {
+
+    // correct answer -> (assigned class -> count)
+    Map<String, Map<String, Long>> counts = new HashMap<>();
+    // per the LeafReader contract this is null when the segment has no deletions
+    Bits liveDocs = reader.getLiveDocs();
+
+    for (int i = 0; i < reader.maxDoc(); i++) {
+      if (liveDocs != null && !liveDocs.get(i)) {
+        continue; // deleted documents must not contribute to the matrix
+      }
+      StoredDocument doc = reader.document(i);
+      String correctAnswer = doc.get(classFieldName);
+      String text = doc.get(textFieldName);
+      if (correctAnswer == null || correctAnswer.isEmpty() || text == null) {
+        continue; // nothing to compare against, or nothing to classify
+      }
+
+      ClassificationResult<T> result = classifier.assignClass(text);
+      if (result == null || result.getAssignedClass() == null) {
+        continue; // no usable classification for this document
+      }
+      T assignedClass = result.getAssignedClass();
+      // BytesRef.toString() renders raw bytes, so decode it explicitly to get comparable labels
+      String classified = assignedClass instanceof BytesRef ? ((BytesRef) assignedClass).utf8ToString() : assignedClass.toString();
+
+      Map<String, Long> row = counts.get(correctAnswer);
+      if (row == null) {
+        row = new HashMap<>();
+        counts.put(correctAnswer, row);
+      }
+      Long count = row.get(classified);
+      row.put(classified, count == null ? 1L : count + 1L);
+    }
+    return new ConfusionMatrix(counts);
+  }
+
+  /**
+   * a confusion matrix, backed by a {@link Map} representing the linearized matrix
+   */
+  public static class ConfusionMatrix {
+
+    // fully unmodifiable snapshot: both the outer map and each row are read-only copies
+    private final Map<String, Map<String, Long>> linearizedMatrix;
+
+    private ConfusionMatrix(Map<String, Map<String, Long>> linearizedMatrix) {
+      Map<String, Map<String, Long>> copy = new HashMap<>();
+      for (Map.Entry<String, Map<String, Long>> entry : linearizedMatrix.entrySet()) {
+        copy.put(entry.getKey(), Collections.unmodifiableMap(new HashMap<>(entry.getValue())));
+      }
+      this.linearizedMatrix = Collections.unmodifiableMap(copy);
+    }
+
+    /**
+     * get the linearized confusion matrix as a {@link Map}
+     * @return an unmodifiable {@link Map} whose keys are the correct answers and whose values are unmodifiable maps
+     * from each assigned class to the number of documents classified that way
+     */
+    public Map<String, Map<String, Long>> getLinearizedMatrix() {
+      return linearizedMatrix;
+    }
+
+    @Override
+    public String toString() {
+      return "ConfusionMatrix{" +
+              "linearizedMatrix=" + linearizedMatrix +
+              '}';
+    }
+  }
+}

Modified: lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java?rev=1679006&r1=1679005&r2=1679006&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java (original)
+++ lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java Tue May 12 17:05:41 2015
@@ -52,9 +52,9 @@ public abstract class ClassificationTest
   private Directory dir;
   private FieldType ft;
 
-  String textFieldName;
-  String categoryFieldName;
-  String booleanFieldName;
+  protected String textFieldName;
+  protected String categoryFieldName;
+  protected String booleanFieldName;
 
   @Override
   @Before

Added: lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java?rev=1679006&view=auto
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java (added)
+++ lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java Tue May 12 17:05:41 2015
@@ -0,0 +1,103 @@
+package org.apache.lucene.classification.utils;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.lucene.analysis.MockAnalyzer;
+import org.apache.lucene.classification.BooleanPerceptronClassifier;
+import org.apache.lucene.classification.CachingNaiveBayesClassifier;
+import org.apache.lucene.classification.ClassificationTestBase;
+import org.apache.lucene.classification.Classifier;
+import org.apache.lucene.classification.KNearestNeighborClassifier;
+import org.apache.lucene.classification.SimpleNaiveBayesClassifier;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.util.BytesRef;
+import org.junit.Test;
+
+/**
+ * Tests for {@link ConfusionMatrixGenerator}
+ */
+public class ConfusionMatrixGeneratorTest extends ClassificationTestBase {
+
+  @Test
+  public void testGetConfusionMatrixWithSNB() throws Exception {
+    MockAnalyzer analyzer = new MockAnalyzer(random());
+    try (LeafReader reader = populateSampleIndex(analyzer)) {
+      Classifier<BytesRef> classifier = new SimpleNaiveBayesClassifier(reader, analyzer, null, categoryFieldName, textFieldName);
+      checkConfusionMatrix(reader, classifier, categoryFieldName);
+    }
+  }
+
+  @Test
+  public void testGetConfusionMatrixWithCNB() throws Exception {
+    MockAnalyzer analyzer = new MockAnalyzer(random());
+    try (LeafReader reader = populateSampleIndex(analyzer)) {
+      Classifier<BytesRef> classifier = new CachingNaiveBayesClassifier(reader, analyzer, null, categoryFieldName, textFieldName);
+      checkConfusionMatrix(reader, classifier, categoryFieldName);
+    }
+  }
+
+  @Test
+  public void testGetConfusionMatrixWithKNN() throws Exception {
+    MockAnalyzer analyzer = new MockAnalyzer(random());
+    try (LeafReader reader = populateSampleIndex(analyzer)) {
+      Classifier<BytesRef> classifier = new KNearestNeighborClassifier(reader, analyzer, null, 1, 0, 0, categoryFieldName, textFieldName);
+      checkConfusionMatrix(reader, classifier, categoryFieldName);
+    }
+  }
+
+  @Test
+  public void testGetConfusionMatrixWithBP() throws Exception {
+    MockAnalyzer analyzer = new MockAnalyzer(random());
+    try (LeafReader reader = populateSampleIndex(analyzer)) {
+      Classifier<Boolean> classifier = new BooleanPerceptronClassifier(reader, analyzer, null, 1, null, booleanFieldName, textFieldName);
+      checkConfusionMatrix(reader, classifier, booleanFieldName);
+    }
+  }
+
+  /**
+   * Generates the confusion matrix of {@code classifier} over {@code reader} and checks its structural invariants:
+   * every correct-answer key is non-empty and every recorded count is strictly positive.
+   */
+  private <T> void checkConfusionMatrix(LeafReader reader, Classifier<T> classifier, String classFieldName) throws Exception {
+    ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader, classifier, classFieldName, textFieldName);
+    assertNotNull(confusionMatrix);
+    assertNotNull(confusionMatrix.getLinearizedMatrix());
+    for (String correctAnswer : confusionMatrix.getLinearizedMatrix().keySet()) {
+      assertNotNull(correctAnswer);
+      assertFalse("empty correct answer in " + confusionMatrix, correctAnswer.isEmpty());
+      for (Long count : confusionMatrix.getLinearizedMatrix().get(correctAnswer).values()) {
+        assertNotNull(count);
+        assertTrue("non positive count in " + confusionMatrix, count > 0);
+      }
+    }
+  }
+}
\ No newline at end of file