You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by th...@apache.org on 2016/04/14 01:21:29 UTC

[13/50] lucene-solr:jira/SOLR-8908: LUCENE-7196 - guaranteed class coverage in split indexes through grouping by class

LUCENE-7196 - guaranteed class coverage in split indexes through grouping by class


Project: http://git-wip-us.apache.org/repos/asf/lucene-solr/repo
Commit: http://git-wip-us.apache.org/repos/asf/lucene-solr/commit/112078ea
Tree: http://git-wip-us.apache.org/repos/asf/lucene-solr/tree/112078ea
Diff: http://git-wip-us.apache.org/repos/asf/lucene-solr/diff/112078ea

Branch: refs/heads/jira/SOLR-8908
Commit: 112078eaf909bdda27a092d13a8431c046a82ac0
Parents: d7867b8
Author: Tommaso Teofili <to...@apache.org>
Authored: Fri Apr 8 16:59:41 2016 +0200
Committer: Tommaso Teofili <to...@apache.org>
Committed: Mon Apr 11 09:59:22 2016 +0200

----------------------------------------------------------------------
 lucene/classification/build.xml                 |   6 +-
 .../classification/utils/DatasetSplitter.java   | 126 +++++++++++++------
 .../classification/utils/DataSplitterTest.java  |  33 ++---
 3 files changed, 106 insertions(+), 59 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/112078ea/lucene/classification/build.xml
----------------------------------------------------------------------
diff --git a/lucene/classification/build.xml b/lucene/classification/build.xml
index 930b1fa..fd15239 100644
--- a/lucene/classification/build.xml
+++ b/lucene/classification/build.xml
@@ -27,6 +27,7 @@
   <path id="classpath">
     <path refid="base.classpath"/>
     <pathelement path="${queries.jar}"/>
+    <pathelement path="${grouping.jar}"/>
   </path>
 
   <path id="test.classpath">
@@ -35,15 +36,16 @@
     <path refid="test.base.classpath"/>
   </path>
 
-  <target name="compile-core" depends="jar-queries,jar-analyzers-common,common.compile-core" />
+  <target name="compile-core" depends="jar-grouping,jar-queries,jar-analyzers-common,common.compile-core" />
 
   <target name="jar-core" depends="common.jar-core" />
 
-  <target name="javadocs" depends="javadocs-queries,compile-core,check-javadocs-uptodate"
+  <target name="javadocs" depends="javadocs-grouping,javadocs-misc,compile-core,check-javadocs-uptodate"
           unless="javadocs-uptodate-${name}">
     <invoke-module-javadoc>
       <links>
         <link href="../queries"/>
+        <link href="../group"/>
       </links>
     </invoke-module-javadoc>
   </target>

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/112078ea/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java
----------------------------------------------------------------------
diff --git a/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java b/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java
index 0b03b94..fce786b 100644
--- a/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java
+++ b/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java
@@ -18,6 +18,7 @@ package org.apache.lucene.classification.utils;
 
 
 import java.io.IOException;
+import java.util.HashMap;
 
 import org.apache.lucene.analysis.Analyzer;
 import org.apache.lucene.document.Document;
@@ -28,11 +29,16 @@ import org.apache.lucene.index.IndexWriter;
 import org.apache.lucene.index.IndexWriterConfig;
 import org.apache.lucene.index.IndexableField;
 import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.Terms;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.MatchAllDocsQuery;
 import org.apache.lucene.search.ScoreDoc;
-import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.Sort;
+import org.apache.lucene.search.grouping.GroupDocs;
+import org.apache.lucene.search.grouping.GroupingSearch;
+import org.apache.lucene.search.grouping.TopGroups;
 import org.apache.lucene.store.Directory;
+import org.apache.lucene.uninverting.UninvertingReader;
 
 /**
  * Utility class for creating training / test / cross validation indexes from the original index.
@@ -61,67 +67,78 @@ public class DatasetSplitter {
    * @param testIndex            a {@link Directory} used to write the test index
    * @param crossValidationIndex a {@link Directory} used to write the cross validation index
    * @param analyzer             {@link Analyzer} used to create the new docs
+   * @param termVectors          {@code true} if term vectors should be kept
+   * @param classFieldName       name of the field used as the label for classification
    * @param fieldNames           names of fields that need to be put in the new indexes or <code>null</code> if all should be used
    * @throws IOException if any writing operation fails on any of the indexes
    */
   public void split(LeafReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex,
-                    Analyzer analyzer, String... fieldNames) throws IOException {
+                    Analyzer analyzer, boolean termVectors, String classFieldName, String... fieldNames) throws IOException {
 
     // create IWs for train / test / cv IDXs
     IndexWriter testWriter = new IndexWriter(testIndex, new IndexWriterConfig(analyzer));
     IndexWriter cvWriter = new IndexWriter(crossValidationIndex, new IndexWriterConfig(analyzer));
     IndexWriter trainingWriter = new IndexWriter(trainingIndex, new IndexWriterConfig(analyzer));
 
+    // try to get the exact no. of existing classes
+    Terms terms = originalIndex.terms(classFieldName);
+    long noOfClasses = -1;
+    if (terms != null) {
+      noOfClasses = terms.size();
+
+    }
+    if (noOfClasses == -1) {
+      noOfClasses = 10000; // fallback
+    }
+
+    HashMap<String, UninvertingReader.Type> mapping = new HashMap<>();
+    mapping.put(classFieldName, UninvertingReader.Type.SORTED);
+    UninvertingReader uninvertingReader = new UninvertingReader(originalIndex, mapping);
+
     try {
-      int size = originalIndex.maxDoc();
 
-      IndexSearcher indexSearcher = new IndexSearcher(originalIndex);
-      TopDocs topDocs = indexSearcher.search(new MatchAllDocsQuery(), Integer.MAX_VALUE);
+      IndexSearcher indexSearcher = new IndexSearcher(uninvertingReader);
+      GroupingSearch gs = new GroupingSearch(classFieldName);
+      gs.setGroupSort(Sort.INDEXORDER);
+      gs.setSortWithinGroup(Sort.INDEXORDER);
+      gs.setAllGroups(true);
+      gs.setGroupDocsLimit(originalIndex.maxDoc());
+      TopGroups<Object> topGroups = gs.search(indexSearcher, new MatchAllDocsQuery(), 0, (int) noOfClasses);
 
       // set the type to be indexed, stored, with term vectors
       FieldType ft = new FieldType(TextField.TYPE_STORED);
-      ft.setStoreTermVectors(true);
-      ft.setStoreTermVectorOffsets(true);
-      ft.setStoreTermVectorPositions(true);
+      if (termVectors) {
+        ft.setStoreTermVectors(true);
+        ft.setStoreTermVectorOffsets(true);
+        ft.setStoreTermVectorPositions(true);
+      }
 
       int b = 0;
 
       // iterate over existing documents
-      for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
-
-        // create a new document for indexing
-        Document doc = new Document();
-        Document document = originalIndex.document(scoreDoc.doc);
-        if (fieldNames != null && fieldNames.length > 0) {
-          for (String fieldName : fieldNames) {
-            IndexableField field = document.getField(fieldName);
-            if (field != null) {
-              doc.add(new Field(fieldName, field.stringValue(), ft));
-            }
+      for (GroupDocs group : topGroups.groups) {
+        int totalHits = group.totalHits;
+        double testSize = totalHits * testRatio;
+        int tc = 0;
+        double cvSize = totalHits * crossValidationRatio;
+        int cvc = 0;
+        for (ScoreDoc scoreDoc : group.scoreDocs) {
+
+          // create a new document for indexing
+          Document doc = createNewDoc(originalIndex, ft, scoreDoc, fieldNames);
+
+          // add it to one of the IDXs
+          if (b % 2 == 0 && tc < testSize) {
+            testWriter.addDocument(doc);
+            tc++;
+          } else if (cvc < cvSize) {
+            cvWriter.addDocument(doc);
+            cvc++;
+          } else {
+            trainingWriter.addDocument(doc);
           }
-        } else {
-          for (IndexableField field : document.getFields()) {
-            if (field.readerValue() != null) {
-              doc.add(new Field(field.name(), field.readerValue(), ft));
-            } else if (field.binaryValue() != null) {
-              doc.add(new Field(field.name(), field.binaryValue(), ft));
-            } else if (field.stringValue() != null) {
-              doc.add(new Field(field.name(), field.stringValue(), ft));
-            } else if (field.numericValue() != null) {
-              doc.add(new Field(field.name(), field.numericValue().toString(), ft));
-            }
-          }
-        }
-
-        // add it to one of the IDXs
-        if (b % 2 == 0 && testWriter.maxDoc() < size * testRatio) {
-          testWriter.addDocument(doc);
-        } else if (cvWriter.maxDoc() < size * crossValidationRatio) {
-          cvWriter.addDocument(doc);
-        } else {
-          trainingWriter.addDocument(doc);
+          b++;
         }
-        b++;
       }
       // commit
       testWriter.commit();
@@ -139,7 +156,34 @@ public class DatasetSplitter {
       testWriter.close();
       cvWriter.close();
       trainingWriter.close();
+      uninvertingReader.close();
+    }
+  }
+
+  private Document createNewDoc(LeafReader originalIndex, FieldType ft, ScoreDoc scoreDoc, String[] fieldNames) throws IOException {
+    Document doc = new Document();
+    Document document = originalIndex.document(scoreDoc.doc);
+    if (fieldNames != null && fieldNames.length > 0) {
+      for (String fieldName : fieldNames) {
+        IndexableField field = document.getField(fieldName);
+        if (field != null) {
+          doc.add(new Field(fieldName, field.stringValue(), ft));
+        }
+      }
+    } else {
+      for (IndexableField field : document.getFields()) {
+        if (field.readerValue() != null) {
+          doc.add(new Field(field.name(), field.readerValue(), ft));
+        } else if (field.binaryValue() != null) {
+          doc.add(new Field(field.name(), field.binaryValue(), ft));
+        } else if (field.stringValue() != null) {
+          doc.add(new Field(field.name(), field.stringValue(), ft));
+        } else if (field.numericValue() != null) {
+          doc.add(new Field(field.name(), field.numericValue().toString(), ft));
+        }
+      }
     }
+    return doc;
   }
 
 }

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/112078ea/lucene/classification/src/test/org/apache/lucene/classification/utils/DataSplitterTest.java
----------------------------------------------------------------------
diff --git a/lucene/classification/src/test/org/apache/lucene/classification/utils/DataSplitterTest.java b/lucene/classification/src/test/org/apache/lucene/classification/utils/DataSplitterTest.java
index 2984bb5..0b6f077 100644
--- a/lucene/classification/src/test/org/apache/lucene/classification/utils/DataSplitterTest.java
+++ b/lucene/classification/src/test/org/apache/lucene/classification/utils/DataSplitterTest.java
@@ -17,27 +17,28 @@
 package org.apache.lucene.classification.utils;
 
 
-import org.apache.lucene.analysis.Analyzer;
+import java.io.IOException;
+import java.util.Random;
+
 import org.apache.lucene.analysis.MockAnalyzer;
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
 import org.apache.lucene.document.FieldType;
+import org.apache.lucene.document.SortedDocValuesField;
 import org.apache.lucene.document.TextField;
-import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.DirectoryReader;
 import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.RandomIndexWriter;
 import org.apache.lucene.store.BaseDirectoryWrapper;
 import org.apache.lucene.store.Directory;
-import org.apache.lucene.util.TestUtil;
+import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.LuceneTestCase;
+import org.apache.lucene.util.TestUtil;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 
-import java.io.IOException;
-import java.util.Random;
-
 /**
  * Testcase for {@link org.apache.lucene.classification.utils.DatasetSplitter}
  */
@@ -47,9 +48,9 @@ public class DataSplitterTest extends LuceneTestCase {
   private RandomIndexWriter indexWriter;
   private Directory dir;
 
-  private String textFieldName = "text";
-  private String classFieldName = "class";
-  private String idFieldName = "id";
+  private static final String textFieldName = "text";
+  private static final String classFieldName = "class";
+  private static final String idFieldName = "id";
 
   @Override
   @Before
@@ -65,11 +66,11 @@ public class DataSplitterTest extends LuceneTestCase {
 
     Document doc;
     Random rnd = random();
-    for (int i = 0; i < 100; i++) {
+    for (int i = 0; i < 1000; i++) {
       doc = new Document();
-      doc.add(new Field(idFieldName, Integer.toString(i), ft));
+      doc.add(new Field(idFieldName, "id" + Integer.toString(i), ft));
       doc.add(new Field(textFieldName, TestUtil.randomUnicodeString(rnd, 1024), ft));
-      doc.add(new Field(classFieldName, TestUtil.randomUnicodeString(rnd, 10), ft));
+      doc.add(new Field(classFieldName, Integer.toString(rnd.nextInt(10)), ft));
       indexWriter.addDocument(doc);
     }
 
@@ -108,18 +109,18 @@ public class DataSplitterTest extends LuceneTestCase {
 
     try {
       DatasetSplitter datasetSplitter = new DatasetSplitter(testRatio, crossValidationRatio);
-      datasetSplitter.split(originalIndex, trainingIndex, testIndex, crossValidationIndex, new MockAnalyzer(random()), fieldNames);
+      datasetSplitter.split(originalIndex, trainingIndex, testIndex, crossValidationIndex, new MockAnalyzer(random()), true, classFieldName, fieldNames);
 
       assertNotNull(trainingIndex);
       assertNotNull(testIndex);
       assertNotNull(crossValidationIndex);
 
       DirectoryReader trainingReader = DirectoryReader.open(trainingIndex);
-      assertTrue((int) (originalIndex.maxDoc() * (1d - testRatio - crossValidationRatio)) == trainingReader.maxDoc());
+      assertEquals((int) (originalIndex.maxDoc() * (1d - testRatio - crossValidationRatio)), trainingReader.maxDoc(), 20);
       DirectoryReader testReader = DirectoryReader.open(testIndex);
-      assertTrue((int) (originalIndex.maxDoc() * testRatio) == testReader.maxDoc());
+      assertEquals((int) (originalIndex.maxDoc() * testRatio), testReader.maxDoc(), 20);
       DirectoryReader cvReader = DirectoryReader.open(crossValidationIndex);
-      assertTrue((int) (originalIndex.maxDoc() * crossValidationRatio) == cvReader.maxDoc());
+      assertEquals((int) (originalIndex.maxDoc() * crossValidationRatio), cvReader.maxDoc(), 20);
 
       trainingReader.close();
       testReader.close();