You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by tf...@apache.org on 2017/05/12 23:39:11 UTC

[44/58] [abbrv] lucene-solr:jira/solr-10233: LUCENE-7823 - added bm25 nb classifier

LUCENE-7823 - added bm25 nb classifier


Project: http://git-wip-us.apache.org/repos/asf/lucene-solr/repo
Commit: http://git-wip-us.apache.org/repos/asf/lucene-solr/commit/89905001
Tree: http://git-wip-us.apache.org/repos/asf/lucene-solr/tree/89905001
Diff: http://git-wip-us.apache.org/repos/asf/lucene-solr/diff/89905001

Branch: refs/heads/jira/solr-10233
Commit: 89905001831de12dd8cb18647d3d54944899ccdf
Parents: fb56948
Author: Tommaso Teofili <to...@apache.org>
Authored: Thu May 11 16:26:01 2017 +0200
Committer: Tommaso Teofili <to...@apache.org>
Committed: Thu May 11 16:26:32 2017 +0200

----------------------------------------------------------------------
 .../lucene/classification/BM25NBClassifier.java | 243 +++++++++++++++++++
 .../classification/BM25NBClassifierTest.java    | 154 ++++++++++++
 2 files changed, 397 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/89905001/lucene/classification/src/java/org/apache/lucene/classification/BM25NBClassifier.java
----------------------------------------------------------------------
diff --git a/lucene/classification/src/java/org/apache/lucene/classification/BM25NBClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/BM25NBClassifier.java
new file mode 100644
index 0000000..1a74416
--- /dev/null
+++ b/lucene/classification/src/java/org/apache/lucene/classification/BM25NBClassifier.java
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.classification;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.LinkedList;
+import java.util.List;
+
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.MultiFields;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.index.Terms;
+import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.BooleanQuery;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.similarities.BM25Similarity;
+import org.apache.lucene.util.BytesRef;
+
+/**
+ * A classifier approximating naive bayes classifier by using pure queries on BM25.
+ *
+ * @lucene.experimental
+ */
+public class BM25NBClassifier implements Classifier<BytesRef> {
+
+  /**
+   * {@link IndexReader} giving access to the index this {@link Classifier} works on;
+   * class terms are enumerated from it and the {@link IndexSearcher} is built on top of it
+   */
+  private final IndexReader indexReader;
+
+  /**
+   * names of the fields to be used as input text; each of them is used both for analyzing
+   * the unseen text and as a target field of the per-token term queries
+   */
+  private final String[] textFieldNames;
+
+  /**
+   * name of the field to be used as a class / category output
+   */
+  private final String classFieldName;
+
+  /**
+   * {@link Analyzer} to be used for tokenizing unseen input text
+   */
+  private final Analyzer analyzer;
+
+  /**
+   * {@link IndexSearcher} (configured with BM25 similarity in the constructor) used to run the
+   * scoring queries on the index
+   */
+  private final IndexSearcher indexSearcher;
+
+  /**
+   * optional {@link Query} restricting the set of documents used to classify;
+   * {@code null} means no restriction is applied
+   */
+  private final Query query;
+
+  /**
+   * Creates a new BM25-based naive Bayes classifier. The given reader is wrapped in an
+   * {@link IndexSearcher} whose similarity is set to {@link BM25Similarity}.
+   *
+   * @param indexReader    the reader on the index to be used for classification
+   * @param analyzer       an {@link Analyzer} used to analyze unseen text
+   * @param query          a {@link Query} used to filter the docs used for training the classifier, or {@code null}
+   *                       if all the indexed docs should be used
+   * @param classFieldName the name of the field used as the output for the classifier; NOTE: it must not be heavily
+   *                       analyzed, as the returned class will be a token indexed for this field
+   * @param textFieldNames the names of the fields used as the inputs for the classifier, NO boosting supported per field
+   */
+  public BM25NBClassifier(IndexReader indexReader, Analyzer analyzer, Query query, String classFieldName, String... textFieldNames) {
+    this.indexReader = indexReader;
+    this.analyzer = analyzer;
+    this.query = query;
+    this.classFieldName = classFieldName;
+    this.textFieldNames = textFieldNames;
+
+    // all scoring queries issued by this classifier go through a BM25-configured searcher
+    IndexSearcher bm25Searcher = new IndexSearcher(indexReader);
+    bm25Searcher.setSimilarity(new BM25Similarity());
+    this.indexSearcher = bm25Searcher;
+  }
+
+  /**
+   * {@inheritDoc}
+   * <p>
+   * The result is the head of the sorted, normalized list of per-class results.
+   *
+   * @return the first normalized result, or {@code null} if no class could be scored
+   */
+  @Override
+  public ClassificationResult<BytesRef> assignClass(String inputDocument) throws IOException {
+    List<ClassificationResult<BytesRef>> assignedClasses = assignClassNormalizedList(inputDocument);
+    // with no scorable class (e.g. the class field holds no non-empty terms) report "no class"
+    // instead of failing with an IndexOutOfBoundsException on get(0)
+    return assignedClasses.isEmpty() ? null : assignedClasses.get(0);
+  }
+
+  /**
+   * {@inheritDoc}
+   * <p>
+   * Every class found in the class field is returned, with normalized scores, in the natural
+   * order of {@link ClassificationResult}.
+   */
+  @Override
+  public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
+    List<ClassificationResult<BytesRef>> rankedClasses = assignClassNormalizedList(text);
+    Collections.sort(rankedClasses);
+    return rankedClasses;
+  }
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
+    List<ClassificationResult<BytesRef>> assignedClasses = assignClassNormalizedList(text);
+    Collections.sort(assignedClasses);
+    return assignedClasses.subList(0, max);
+  }
+
+  /**
+   * Calculate probabilities for all classes for a given input text.
+   * <p>
+   * Each non-empty term of the class field is scored as log prior plus log likelihood of the
+   * tokenized input, and the scores are then normalized.
+   *
+   * @param inputDocument the input text as a {@code String}
+   * @return a {@code List} of {@code ClassificationResult}, one for each existing class; empty if the class field
+   *         has no terms in the index
+   * @throws IOException if assigning probabilities fails
+   */
+  private List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
+    List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>();
+
+    Terms classes = MultiFields.getTerms(indexReader, classFieldName);
+    if (classes == null) {
+      // the class field is missing from the index: there is no class to score
+      return assignedClasses;
+    }
+
+    // the input is analyzed once and reused for every candidate class
+    String[] tokenizedText = tokenize(inputDocument);
+    TermsEnum classesEnum = classes.iterator();
+    BytesRef next;
+    while ((next = classesEnum.next()) != null) {
+      // empty terms are not considered classes
+      if (next.length > 0) {
+        Term term = new Term(this.classFieldName, next);
+        assignedClasses.add(new ClassificationResult<>(term.bytes(), calculateLogPrior(term) + calculateLogLikelihood(tokenizedText, term)));
+      }
+    }
+
+    return normClassificationResults(assignedClasses);
+  }
+
+  /**
+   * Normalize the classification results so that the returned scores lie in the 0-1 range and sum up to 1.
+   * <p>
+   * The scores are treated as logarithms; the normalizer log(sum(exp(x_n))) is computed relative to the
+   * score of the first result after sorting. Note that the given list is sorted in place.
+   *
+   * @param assignedClasses the list of assigned classes (sorted as a side effect)
+   * @return the normalized results, in the sorted order; empty if the input is empty
+   */
+  private ArrayList<ClassificationResult<BytesRef>> normClassificationResults(List<ClassificationResult<BytesRef>> assignedClasses) {
+    ArrayList<ClassificationResult<BytesRef>> normalized = new ArrayList<>();
+    if (assignedClasses.isEmpty()) {
+      return normalized;
+    }
+
+    Collections.sort(assignedClasses);
+    // reference score a = score of the head of the sorted list
+    // (presumably the highest one; depends on ClassificationResult's natural order)
+    double reference = assignedClasses.get(0).getScore();
+
+    // sum(exp(x_n - a)): shifting by the reference keeps the exponentials in a safe range
+    double shiftedExpSum = 0;
+    for (ClassificationResult<BytesRef> result : assignedClasses) {
+      shiftedExpSum += Math.exp(result.getScore() - reference);
+    }
+
+    // a + log(sum(exp(x_n - a))) = log(sum(exp(x_n)))
+    double logNormalizer = reference + Math.log(shiftedExpSum);
+
+    // exp(x_n) / sum(exp(x_m)) = exp(x_n - log(sum(exp(x_m))))
+    for (ClassificationResult<BytesRef> result : assignedClasses) {
+      double normalizedScore = Math.exp(result.getScore() - logNormalizer);
+      normalized.add(new ClassificationResult<>(result.getAssignedClass(), normalizedScore));
+    }
+    return normalized;
+  }
+
+  /**
+   * Analyze a <code>String</code> once per text field of this classifier and collect all produced tokens.
+   *
+   * @param text the <code>String</code> representing an input text (to be classified)
+   * @return a <code>String</code> array with the tokens of every text field, in field order and then token order
+   * @throws IOException if tokenization fails
+   */
+  private String[] tokenize(String text) throws IOException {
+    List<String> tokens = new ArrayList<>();
+    for (String fieldName : textFieldNames) {
+      // the stream is closed automatically, also when analysis fails midway
+      try (TokenStream stream = analyzer.tokenStream(fieldName, text)) {
+        CharTermAttribute termAttribute = stream.addAttribute(CharTermAttribute.class);
+        stream.reset();
+        while (stream.incrementToken()) {
+          tokens.add(termAttribute.toString());
+        }
+        stream.end();
+      }
+    }
+    return tokens.toArray(new String[0]);
+  }
+
+  /**
+   * Sum, over all input tokens, the logarithm of the per-token score for the given class.
+   *
+   * @param tokens the tokens of the text to classify
+   * @param term   the class term
+   * @return the accumulated log score; {@code 0} for an empty token array
+   * @throws IOException if running a scoring query fails
+   */
+  private double calculateLogLikelihood(String[] tokens, Term term) throws IOException {
+    double logLikelihood = 0d;
+    for (int i = 0; i < tokens.length; i++) {
+      double tokenScore = getTermProbForClass(term, tokens[i]);
+      logLikelihood += Math.log(tokenScore);
+    }
+    return logLikelihood;
+  }
+
+  /**
+   * Score the given words against the documents of a class.
+   * <p>
+   * Builds a boolean query requiring the class term (and the filter query, if any), with one optional
+   * clause per text field and word, and returns the max score among the hits.
+   *
+   * @param classTerm the class term that matching documents must contain
+   * @param words     the words to look for in the text fields
+   * @return the max score of the query, or {@code 1} if nothing matches
+   * @throws IOException if the search fails
+   */
+  private double getTermProbForClass(Term classTerm, String... words) throws IOException {
+    BooleanQuery.Builder classAndWords = new BooleanQuery.Builder();
+    classAndWords.add(new TermQuery(classTerm), BooleanClause.Occur.MUST);
+    for (String fieldName : textFieldNames) {
+      for (String word : words) {
+        classAndWords.add(new TermQuery(new Term(fieldName, word)), BooleanClause.Occur.SHOULD);
+      }
+    }
+    if (query != null) {
+      classAndWords.add(query, BooleanClause.Occur.MUST);
+    }
+
+    // only the best hit is needed, as just the max score is read
+    TopDocs hits = indexSearcher.search(classAndWords.build(), 1);
+    if (hits.totalHits > 0) {
+      return hits.getMaxScore();
+    }
+    // no match: 1 contributes log(1) = 0 to the log likelihood
+    return 1;
+  }
+
+  /**
+   * Compute the prior part of a class score: the logarithm of the max score of a query
+   * requiring the class term (and the filter query, if any).
+   *
+   * @param term the class term
+   * @return the log of the max score, or {@code 0} if no document matches
+   * @throws IOException if the search fails
+   */
+  private double calculateLogPrior(Term term) throws IOException {
+    BooleanQuery.Builder classOnly = new BooleanQuery.Builder();
+    classOnly.add(new TermQuery(term), BooleanClause.Occur.MUST);
+    if (query != null) {
+      classOnly.add(query, BooleanClause.Occur.MUST);
+    }
+
+    TopDocs hits = indexSearcher.search(classOnly.build(), 1);
+    if (hits.totalHits > 0) {
+      return Math.log(hits.getMaxScore());
+    }
+    return 0;
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/89905001/lucene/classification/src/test/org/apache/lucene/classification/BM25NBClassifierTest.java
----------------------------------------------------------------------
diff --git a/lucene/classification/src/test/org/apache/lucene/classification/BM25NBClassifierTest.java b/lucene/classification/src/test/org/apache/lucene/classification/BM25NBClassifierTest.java
new file mode 100644
index 0000000..724f2a6
--- /dev/null
+++ b/lucene/classification/src/test/org/apache/lucene/classification/BM25NBClassifierTest.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.classification;
+
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.analysis.MockAnalyzer;
+import org.apache.lucene.analysis.Tokenizer;
+import org.apache.lucene.analysis.core.KeywordTokenizer;
+import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
+import org.apache.lucene.analysis.reverse.ReverseStringFilter;
+import org.apache.lucene.classification.utils.ConfusionMatrixGenerator;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.MultiFields;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.index.Terms;
+import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.util.BytesRef;
+import org.junit.Test;
+
+/**
+ * Tests for {@link BM25NBClassifier}
+ */
+public class BM25NBClassifierTest extends ClassificationTestBase<BytesRef> {
+
+  /**
+   * Classifies the technology sample input on the sample index, without any filter query.
+   */
+  @Test
+  public void testBasicUsage() throws Exception {
+    MockAnalyzer analyzer = new MockAnalyzer(random());
+    // try-with-resources closes the reader and keeps a failing close() from masking a test failure
+    try (LeafReader leafReader = getSampleIndex(analyzer)) {
+      BM25NBClassifier classifier = new BM25NBClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName);
+      checkCorrectClassification(classifier, TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
+    }
+  }
+
+  /**
+   * Classifies the technology sample input while restricting the documents used for
+   * classification to those matching a term query on the text field.
+   */
+  @Test
+  public void testBasicUsageWithQuery() throws Exception {
+    MockAnalyzer analyzer = new MockAnalyzer(random());
+    // try-with-resources closes the reader and keeps a failing close() from masking a test failure
+    try (LeafReader leafReader = getSampleIndex(analyzer)) {
+      TermQuery query = new TermQuery(new Term(textFieldName, "not"));
+      BM25NBClassifier classifier = new BM25NBClassifier(leafReader, analyzer, query, categoryFieldName, textFieldName);
+      checkCorrectClassification(classifier, TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
+    }
+  }
+
+  /**
+   * Classifies the technology sample input using the n-gram based analyzer for both
+   * indexing and classification.
+   */
+  @Test
+  public void testNGramUsage() throws Exception {
+    Analyzer analyzer = new NGramAnalyzer();
+    // try-with-resources closes the reader and keeps a failing close() from masking a test failure
+    try (LeafReader leafReader = getSampleIndex(analyzer)) {
+      BM25NBClassifier classifier = new BM25NBClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName);
+      checkCorrectClassification(classifier, TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
+    }
+  }
+
+  /**
+   * Analyzer chaining a {@link KeywordTokenizer} with reverse, edge n-gram (sizes 10 to 20)
+   * and reverse filters. Declared static as it needs no access to the enclosing test instance.
+   */
+  private static class NGramAnalyzer extends Analyzer {
+    @Override
+    protected TokenStreamComponents createComponents(String fieldName) {
+      final Tokenizer tokenizer = new KeywordTokenizer();
+      return new TokenStreamComponents(tokenizer, new ReverseStringFilter(new EdgeNGramTokenFilter(new ReverseStringFilter(tokenizer), 10, 20)));
+    }
+  }
+
+  /**
+   * Builds a classifier on a random index, checks that construction and evaluation stay within
+   * loose time bounds and that every metric of the resulting confusion matrix lies in [0, 1].
+   */
+  @Test
+  public void testPerformance() throws Exception {
+    MockAnalyzer analyzer = new MockAnalyzer(random());
+    // try-with-resources closes the reader and keeps a failing close() from masking a test failure
+    try (LeafReader leafReader = getRandomIndex(analyzer, 100)) {
+      long trainStart = System.currentTimeMillis();
+      BM25NBClassifier classifier = new BM25NBClassifier(leafReader,
+          analyzer, null, categoryFieldName, textFieldName);
+      long trainEnd = System.currentTimeMillis();
+      long trainTime = trainEnd - trainStart;
+      assertTrue("training took more than 10s: " + trainTime / 1000 + "s", trainTime < 10000);
+
+      long evaluationStart = System.currentTimeMillis();
+      ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(leafReader,
+          classifier, categoryFieldName, textFieldName, -1);
+      assertNotNull(confusionMatrix);
+      long evaluationEnd = System.currentTimeMillis();
+      long evaluationTime = evaluationEnd - evaluationStart;
+      assertTrue("evaluation took more than 2m: " + evaluationTime / 1000 + "s", evaluationTime < 120000);
+      double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
+      assertTrue("avg classification time: " + avgClassificationTime, 5000 > avgClassificationTime);
+
+      // overall metrics
+      double f1 = confusionMatrix.getF1Measure();
+      assertTrue(f1 >= 0d);
+      assertTrue(f1 <= 1d);
+
+      double accuracy = confusionMatrix.getAccuracy();
+      assertTrue(accuracy >= 0d);
+      assertTrue(accuracy <= 1d);
+
+      double recall = confusionMatrix.getRecall();
+      assertTrue(recall >= 0d);
+      assertTrue(recall <= 1d);
+
+      double precision = confusionMatrix.getPrecision();
+      assertTrue(precision >= 0d);
+      assertTrue(precision <= 1d);
+
+      // per-class metrics, one round for every term of the category field
+      Terms terms = MultiFields.getTerms(leafReader, categoryFieldName);
+      TermsEnum iterator = terms.iterator();
+      BytesRef term;
+      while ((term = iterator.next()) != null) {
+        String s = term.utf8ToString();
+        recall = confusionMatrix.getRecall(s);
+        assertTrue(recall >= 0d);
+        assertTrue(recall <= 1d);
+        precision = confusionMatrix.getPrecision(s);
+        assertTrue(precision >= 0d);
+        assertTrue(precision <= 1d);
+        double f1Measure = confusionMatrix.getF1Measure(s);
+        assertTrue(f1Measure >= 0d);
+        assertTrue(f1Measure <= 1d);
+      }
+    }
+  }
+
+}