You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by mi...@apache.org on 2014/03/23 12:48:56 UTC

svn commit: r1580491 - in /lucene/dev/branches/branch_4x: ./ lucene/ lucene/core/ lucene/core/src/java/org/apache/lucene/search/ lucene/core/src/test/org/apache/lucene/search/ lucene/expressions/ lucene/expressions/src/java/org/apache/lucene/expression...

Author: mikemccand
Date: Sun Mar 23 11:48:56 2014
New Revision: 1580491

URL: http://svn.apache.org/r1580491
Log:
LUCENE-5545: add SortRescorer and Expression.getRescorer

Added:
    lucene/dev/branches/branch_4x/lucene/core/src/java/org/apache/lucene/search/SortRescorer.java
      - copied unchanged from r1580490, lucene/dev/trunk/lucene/core/src/java/org/apache/lucene/search/SortRescorer.java
    lucene/dev/branches/branch_4x/lucene/core/src/test/org/apache/lucene/search/TestSortRescorer.java
      - copied unchanged from r1580490, lucene/dev/trunk/lucene/core/src/test/org/apache/lucene/search/TestSortRescorer.java
    lucene/dev/branches/branch_4x/lucene/expressions/src/java/org/apache/lucene/expressions/ExpressionRescorer.java
      - copied unchanged from r1580490, lucene/dev/trunk/lucene/expressions/src/java/org/apache/lucene/expressions/ExpressionRescorer.java
    lucene/dev/branches/branch_4x/lucene/expressions/src/test/org/apache/lucene/expressions/TestExpressionRescorer.java
      - copied unchanged from r1580490, lucene/dev/trunk/lucene/expressions/src/test/org/apache/lucene/expressions/TestExpressionRescorer.java
Modified:
    lucene/dev/branches/branch_4x/   (props changed)
    lucene/dev/branches/branch_4x/lucene/   (props changed)
    lucene/dev/branches/branch_4x/lucene/CHANGES.txt
    lucene/dev/branches/branch_4x/lucene/core/   (props changed)
    lucene/dev/branches/branch_4x/lucene/core/src/java/org/apache/lucene/search/QueryRescorer.java
    lucene/dev/branches/branch_4x/lucene/core/src/test/org/apache/lucene/search/TestQueryRescorer.java
    lucene/dev/branches/branch_4x/lucene/expressions/   (props changed)
    lucene/dev/branches/branch_4x/lucene/expressions/src/java/org/apache/lucene/expressions/Expression.java

Modified: lucene/dev/branches/branch_4x/lucene/CHANGES.txt
URL: http://svn.apache.org/viewvc/lucene/dev/branches/branch_4x/lucene/CHANGES.txt?rev=1580491&r1=1580490&r2=1580491&view=diff
==============================================================================
--- lucene/dev/branches/branch_4x/lucene/CHANGES.txt (original)
+++ lucene/dev/branches/branch_4x/lucene/CHANGES.txt Sun Mar 23 11:48:56 2014
@@ -74,6 +74,10 @@ New Features
   first pass search using scores from a more costly second pass
   search. (Simon Willnauer, Robert Muir, Mike McCandless)
 
+* LUCENE-5545: Add SortRescorer and Expression.getRescorer, to
+  resort the hits from a first pass search using a Sort or an
+  Expression. (Simon Willnauer, Robert Muir, Mike McCandless)
+
 API Changes
 
 * LUCENE-5454: Add RandomAccessOrds, an optional extension of SortedSetDocValues

Modified: lucene/dev/branches/branch_4x/lucene/core/src/java/org/apache/lucene/search/QueryRescorer.java
URL: http://svn.apache.org/viewvc/lucene/dev/branches/branch_4x/lucene/core/src/java/org/apache/lucene/search/QueryRescorer.java?rev=1580491&r1=1580490&r2=1580491&view=diff
==============================================================================
--- lucene/dev/branches/branch_4x/lucene/core/src/java/org/apache/lucene/search/QueryRescorer.java (original)
+++ lucene/dev/branches/branch_4x/lucene/core/src/java/org/apache/lucene/search/QueryRescorer.java Sun Mar 23 11:48:56 2014
@@ -20,13 +20,9 @@ package org.apache.lucene.search;
 import java.io.IOException;
 import java.util.Arrays;
 import java.util.Comparator;
-import java.util.HashMap;
-import java.util.Map;
+import java.util.List;
 
 import org.apache.lucene.index.AtomicReaderContext;
-import org.apache.lucene.util.Bits;
-
-// TODO: we could also have an ExpressionRescorer
 
 /** A {@link Rescorer} that uses a provided Query to assign
  *  scores to the first-pass hits.
@@ -52,43 +48,65 @@ public abstract class QueryRescorer exte
   protected abstract float combine(float firstPassScore, boolean secondPassMatches, float secondPassScore);
 
   @Override
-  public TopDocs rescore(IndexSearcher searcher, TopDocs topDocs, int topN) throws IOException {
-    int[] docIDs = new int[topDocs.scoreDocs.length];
-    for(int i=0;i<docIDs.length;i++) {
-      docIDs[i] = topDocs.scoreDocs[i].doc;
-    }
+  public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int topN) throws IOException {
+    ScoreDoc[] hits = firstPassTopDocs.scoreDocs.clone();
+    Arrays.sort(hits,
+                new Comparator<ScoreDoc>() {
+                  @Override
+                  public int compare(ScoreDoc a, ScoreDoc b) {
+                    return a.doc - b.doc;
+                  }
+                });
 
-    TopDocs topDocs2 = searcher.search(query, new OnlyDocIDsFilter(docIDs), topDocs.scoreDocs.length);
+    List<AtomicReaderContext> leaves = searcher.getIndexReader().leaves();
 
-    // TODO: we could save small young GC cost here if we
-    // cloned the incoming ScoreDoc[], sorted that by doc,
-    // passed that to OnlyDocIDsFilter, sorted 2nd pass
-    // TopDocs by doc, did a merge sort to combine the
-    // scores, and finally re-sorted by the combined score,
-    // but that is sizable added code complexity for minor
-    // GC savings:
-    Map<Integer,Float> newScores = new HashMap<Integer,Float>();
-    for(ScoreDoc sd : topDocs2.scoreDocs) {
-      newScores.put(sd.doc, sd.score);
-    }
+    Weight weight = searcher.createNormalizedWeight(query);
+
+    // Now merge sort docIDs from hits, with reader's leaves:
+    int hitUpto = 0;
+    int readerUpto = -1;
+    int endDoc = 0;
+    int docBase = 0;
+    Scorer scorer = null;
+
+    while (hitUpto < hits.length) {
+      ScoreDoc hit = hits[hitUpto];
+      int docID = hit.doc;
+      AtomicReaderContext readerContext = null;
+      while (docID >= endDoc) {
+        readerUpto++;
+        readerContext = leaves.get(readerUpto);
+        endDoc = readerContext.docBase + readerContext.reader().maxDoc();
+      }
+
+      if (readerContext != null) {
+        // We advanced to another segment:
+        docBase = readerContext.docBase;
+        scorer = weight.scorer(readerContext, null);
+      }
+
+      int targetDoc = docID - docBase;
+      int actualDoc = scorer.docID();
+      if (actualDoc < targetDoc) {
+        actualDoc = scorer.advance(targetDoc);
+      }
 
-    ScoreDoc[] newHits = new ScoreDoc[topDocs.scoreDocs.length];
-    for(int i=0;i<topDocs.scoreDocs.length;i++) {
-      ScoreDoc sd = topDocs.scoreDocs[i];
-      Float newScore = newScores.get(sd.doc);
-      float combinedScore;
-      if (newScore == null) {
-        combinedScore = combine(sd.score, false, 0.0f);
+      if (actualDoc == targetDoc) {
+        // Query did match this doc:
+        hit.score = combine(hit.score, true, scorer.score());
       } else {
-        combinedScore = combine(sd.score, true, newScore.floatValue());
+        // Query did not match this doc:
+        assert actualDoc > targetDoc;
+        hit.score = combine(hit.score, false, 0.0f);
       }
-      newHits[i] = new ScoreDoc(sd.doc, combinedScore);
+
+      hitUpto++;
     }
 
     // TODO: we should do a partial sort (of only topN)
     // instead, but typically the number of hits is
     // smallish:
-    Arrays.sort(newHits,
+    Arrays.sort(hits,
                 new Comparator<ScoreDoc>() {
                   @Override
                   public int compare(ScoreDoc a, ScoreDoc b) {
@@ -105,13 +123,13 @@ public abstract class QueryRescorer exte
                   }
                 });
 
-    if (topN < newHits.length) {
+    if (topN < hits.length) {
       ScoreDoc[] subset = new ScoreDoc[topN];
-      System.arraycopy(newHits, 0, subset, 0, topN);
-      newHits = subset;
+      System.arraycopy(hits, 0, subset, 0, topN);
+      hits = subset;
     }
 
-    return new TopDocs(topDocs.totalHits, newHits, newHits[0].score);
+    return new TopDocs(firstPassTopDocs.totalHits, hits, hits[0].score);
   }
 
   @Override
@@ -159,80 +177,4 @@ public abstract class QueryRescorer exte
       }
     }.rescore(searcher, topDocs, topN);
   }
-
-  /** Filter accepting only the specified docIDs */
-  private static class OnlyDocIDsFilter extends Filter {
-
-    private final int[] docIDs;
-
-    /** Sole constructor. */
-    public OnlyDocIDsFilter(int[] docIDs) {
-      this.docIDs = docIDs;
-      Arrays.sort(docIDs);
-    }
-
-    @Override
-    public DocIdSet getDocIdSet(final AtomicReaderContext context, final Bits acceptDocs) throws IOException {
-      int loc = Arrays.binarySearch(docIDs, context.docBase);
-      if (loc < 0) {
-        loc = -loc-1;
-      }
-
-      final int startLoc = loc;
-      final int endDoc = context.docBase + context.reader().maxDoc();
-
-      return new DocIdSet() {
-
-        int pos = startLoc;
-
-        @Override
-        public DocIdSetIterator iterator() throws IOException {
-          return new DocIdSetIterator() {
-
-            int docID;
-
-            @Override
-            public int docID() {
-              return docID;
-            }
-
-            @Override
-            public int nextDoc() {
-              if (pos == docIDs.length) {
-                return NO_MORE_DOCS;
-              }
-              int docID = docIDs[pos];
-              if (docID >= endDoc) {
-                return NO_MORE_DOCS;
-              }
-              pos++;
-              assert acceptDocs == null || acceptDocs.get(docID-context.docBase);
-              return docID-context.docBase;
-            }
-
-            @Override
-            public long cost() {
-              // NOTE: not quite right, since this is cost
-              // across all segments, and we are supposed to
-              // return cost for just this segment:
-              return docIDs.length;
-            }
-
-            @Override
-            public int advance(int target) {
-              // TODO: this is a full binary search; we
-              // could optimize (a bit) by setting lower
-              // bound to current pos instead:
-              int loc = Arrays.binarySearch(docIDs, target + context.docBase);
-              if (loc < 0) {
-                loc = -loc-1;
-              }
-              pos = loc;
-              return nextDoc();
-            }
-          };
-        }
-      };
-    }
-  }
 }

Modified: lucene/dev/branches/branch_4x/lucene/core/src/test/org/apache/lucene/search/TestQueryRescorer.java
URL: http://svn.apache.org/viewvc/lucene/dev/branches/branch_4x/lucene/core/src/test/org/apache/lucene/search/TestQueryRescorer.java?rev=1580491&r1=1580490&r2=1580491&view=diff
==============================================================================
--- lucene/dev/branches/branch_4x/lucene/core/src/test/org/apache/lucene/search/TestQueryRescorer.java (original)
+++ lucene/dev/branches/branch_4x/lucene/core/src/test/org/apache/lucene/search/TestQueryRescorer.java Sun Mar 23 11:48:56 2014
@@ -17,8 +17,15 @@ package org.apache.lucene.search;
  * limitations under the License.
  */
 
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.Set;
+
 import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field;
+import org.apache.lucene.document.NumericDocValuesField;
+import org.apache.lucene.index.AtomicReaderContext;
 import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.RandomIndexWriter;
 import org.apache.lucene.index.Term;
@@ -28,7 +35,9 @@ import org.apache.lucene.search.spans.Sp
 import org.apache.lucene.search.spans.SpanQuery;
 import org.apache.lucene.search.spans.SpanTermQuery;
 import org.apache.lucene.store.Directory;
+import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.LuceneTestCase;
+import org.apache.lucene.util.TestUtil;
 
 public class TestQueryRescorer extends LuceneTestCase {
 
@@ -62,6 +71,7 @@ public class TestQueryRescorer extends L
     bq.add(new TermQuery(new Term("field", "wizard")), Occur.SHOULD);
     bq.add(new TermQuery(new Term("field", "oz")), Occur.SHOULD);
     IndexSearcher searcher = getSearcher(r);
+    searcher.setSimilarity(new DefaultSimilarity());
 
     TopDocs hits = searcher.search(bq, 10);
     assertEquals(2, hits.totalHits);
@@ -283,4 +293,206 @@ public class TestQueryRescorer extends L
     r.close();
     dir.close();
   }
+
+  public void testRandom() throws Exception {
+    Directory dir = newDirectory();
+    int numDocs = atLeast(1000);
+    RandomIndexWriter w = new RandomIndexWriter(random(), dir);
+
+    final int[] idToNum = new int[numDocs];
+    int maxValue = TestUtil.nextInt(random(), 10, 1000000);
+    for(int i=0;i<numDocs;i++) {
+      Document doc = new Document();
+      doc.add(newStringField("id", ""+i, Field.Store.YES));
+      int numTokens = TestUtil.nextInt(random(), 1, 10);
+      StringBuilder b = new StringBuilder();
+      for(int j=0;j<numTokens;j++) {
+        b.append("a ");
+      }
+      doc.add(newTextField("field", b.toString(), Field.Store.NO));
+      idToNum[i] = random().nextInt(maxValue);
+      doc.add(new NumericDocValuesField("num", idToNum[i]));
+      w.addDocument(doc);
+    }
+    final IndexReader r = w.getReader();
+    w.close();
+
+    IndexSearcher s = newSearcher(r);
+    int numHits = TestUtil.nextInt(random(), 1, numDocs);
+    boolean reverse = random().nextBoolean();
+
+    //System.out.println("numHits=" + numHits + " reverse=" + reverse);
+    TopDocs hits = s.search(new TermQuery(new Term("field", "a")), numHits);
+
+    TopDocs hits2 = new QueryRescorer(new FixedScoreQuery(idToNum, reverse)) {
+        @Override
+        protected float combine(float firstPassScore, boolean secondPassMatches, float secondPassScore) {
+          return secondPassScore;
+        }
+      }.rescore(s, hits, numHits);
+
+    Integer[] expected = new Integer[numHits];
+    for(int i=0;i<numHits;i++) {
+      expected[i] = hits.scoreDocs[i].doc;
+    }
+
+    final int reverseInt = reverse ? -1 : 1;
+
+    Arrays.sort(expected,
+                new Comparator<Integer>() {
+                  @Override
+                  public int compare(Integer a, Integer b) {
+                    try {
+                      int av = idToNum[Integer.parseInt(r.document(a).get("id"))];
+                      int bv = idToNum[Integer.parseInt(r.document(b).get("id"))];
+                      if (av < bv) {
+                        return -reverseInt;
+                      } else if (bv < av) {
+                        return reverseInt;
+                      } else {
+                        // Tie break by docID, ascending
+                        return a - b;
+                      }
+                    } catch (IOException ioe) {
+                      throw new RuntimeException(ioe);
+                    }
+                  }
+                });
+
+    boolean fail = false;
+    for(int i=0;i<numHits;i++) {
+      //System.out.println("expected=" + expected[i] + " vs " + hits2.scoreDocs[i].doc + " v=" + idToNum[Integer.parseInt(r.document(expected[i]).get("id"))]);
+      if (expected[i].intValue() != hits2.scoreDocs[i].doc) {
+        //System.out.println("  diff!");
+        fail = true;
+      }
+    }
+    assertFalse(fail);
+
+    r.close();
+    dir.close();
+  }
+
+  /** Just assigns score == idToNum[doc("id")] for each doc. */
+  private static class FixedScoreQuery extends Query {
+    private final int[] idToNum;
+    private final boolean reverse;
+
+    public FixedScoreQuery(int[] idToNum, boolean reverse) {
+      this.idToNum = idToNum;
+      this.reverse = reverse;
+    }
+
+    @Override
+    public Weight createWeight(IndexSearcher searcher) throws IOException {
+
+      return new Weight() {
+
+        @Override
+        public Query getQuery() {
+          return FixedScoreQuery.this;
+        }
+
+        @Override
+        public float getValueForNormalization() {
+          return 1.0f;
+        }
+
+        @Override
+        public void normalize(float queryNorm, float topLevelBoost) {
+        }
+
+        @Override
+        public Scorer scorer(final AtomicReaderContext context, Bits acceptDocs) throws IOException {
+
+          return new Scorer(null) {
+            int docID = -1;
+
+            @Override
+            public int docID() {
+              return docID;
+            }
+
+            @Override
+            public int freq() {
+              return 1;
+            }
+
+            @Override
+            public long cost() {
+              return 1;
+            }
+
+            @Override
+            public int nextDoc() {
+              docID++;
+              if (docID >= context.reader().maxDoc()) {
+                return NO_MORE_DOCS;
+              }
+              return docID;
+            }
+
+            @Override
+            public int advance(int target) {
+              docID = target;
+              return docID;
+            }
+
+            @Override
+            public float score() throws IOException {
+              int num = idToNum[Integer.parseInt(context.reader().document(docID).get("id"))];
+              if (reverse) {
+                //System.out.println("score doc=" + docID + " num=" + num);
+                return num;
+              } else {
+                //System.out.println("score doc=" + docID + " num=" + -num);
+                return -num;
+              }
+            }
+          };
+        }
+
+        @Override
+        public Explanation explain(AtomicReaderContext context, int doc) throws IOException {
+          return null;
+        }
+      };
+    }
+
+    @Override
+    public void extractTerms(Set<Term> terms) {
+    }
+
+    @Override
+    public String toString(String field) {
+      return "FixedScoreQuery " + idToNum.length + " ids; reverse=" + reverse;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if ((o instanceof FixedScoreQuery) == false) {
+        return false;
+      }
+      FixedScoreQuery other = (FixedScoreQuery) o;
+      return Float.floatToIntBits(getBoost()) == Float.floatToIntBits(other.getBoost()) &&
+        reverse == other.reverse &&
+        Arrays.equals(idToNum, other.idToNum);
+    }
+
+    @Override
+    public Query clone() {
+      return new FixedScoreQuery(idToNum, reverse);
+    }
+
+    @Override
+    public int hashCode() {
+      int PRIME = 31;
+      int hash = super.hashCode();
+      if (reverse) {
+        hash = PRIME * hash + 3623;
+      }
+      hash = PRIME * hash + Arrays.hashCode(idToNum);
+      return hash;
+    }
+  }
 }

Modified: lucene/dev/branches/branch_4x/lucene/expressions/src/java/org/apache/lucene/expressions/Expression.java
URL: http://svn.apache.org/viewvc/lucene/dev/branches/branch_4x/lucene/expressions/src/java/org/apache/lucene/expressions/Expression.java?rev=1580491&r1=1580490&r2=1580491&view=diff
==============================================================================
--- lucene/dev/branches/branch_4x/lucene/expressions/src/java/org/apache/lucene/expressions/Expression.java (original)
+++ lucene/dev/branches/branch_4x/lucene/expressions/src/java/org/apache/lucene/expressions/Expression.java Sun Mar 23 11:48:56 2014
@@ -19,6 +19,7 @@ package org.apache.lucene.expressions;
 import org.apache.lucene.expressions.js.JavascriptCompiler; // javadocs
 import org.apache.lucene.queries.function.FunctionValues;
 import org.apache.lucene.queries.function.ValueSource;
+import org.apache.lucene.search.Rescorer;
 import org.apache.lucene.search.SortField;
 
 /**
@@ -83,4 +84,10 @@ public abstract class Expression {
   public SortField getSortField(Bindings bindings, boolean reverse) {
     return getValueSource(bindings).getSortField(reverse);
   }
+
+  /** Get a {@link Rescorer}, to rescore first-pass hits
+   *  using this expression. */
+  public Rescorer getRescorer(Bindings bindings) {
+    return new ExpressionRescorer(this, bindings);
+  }
 }