You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by iv...@apache.org on 2019/10/04 08:19:31 UTC

[lucene-solr] branch branch_8x updated: LUCENE-8990: Add estimateDocCount(visitor) method to PointValues (#905)

This is an automated email from the ASF dual-hosted git repository.

ivera pushed a commit to branch branch_8x
in repository https://gitbox.apache.org/repos/asf/lucene-solr.git


The following commit(s) were added to refs/heads/branch_8x by this push:
     new e4ceb97  LUCENE-8990: Add estimateDocCount(visitor) method to PointValues (#905)
e4ceb97 is described below

commit e4ceb9763f1ca64d0603747add87646eec78c368
Author: Ignacio Vera <iv...@apache.org>
AuthorDate: Fri Oct 4 10:13:55 2019 +0200

    LUCENE-8990: Add estimateDocCount(visitor) method to PointValues (#905)
---
 lucene/CHANGES.txt                                 |   4 +
 .../lucene/document/LatLonPointDistanceQuery.java  |   2 +-
 .../lucene/document/LatLonPointInPolygonQuery.java |   2 +-
 .../apache/lucene/document/RangeFieldQuery.java    |   2 +-
 .../java/org/apache/lucene/index/PointValues.java  |  28 +-
 .../org/apache/lucene/search/PointRangeQuery.java  |   2 +-
 .../codecs/lucene60/TestLucene60PointsFormat.java  | 316 ++++++++++++++-------
 .../lucene/search/TestIndexOrDocValuesQuery.java   |  67 +++++
 .../org/apache/lucene/document/ShapeQuery.java     |   2 +-
 .../org/apache/lucene/search/MultiRangeQuery.java  |   2 +-
 10 files changed, 320 insertions(+), 107 deletions(-)

diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 4a00908..57592fe 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -22,6 +22,10 @@ API Changes
   And don't call if docFreq <= 0.  The previous implementation survives as deprecated and final.  It's removed in 9.0.
   (Bruno Roustant, David Smiley, Alan Woodward)
 
+* LUCENE-8990: PointValues#estimateDocCount(visitor) estimates the number of documents that would be matched by
+  the given IntersectVisitor. The method is used to compute the cost() of ScorerSuppliers instead of
+  PointValues#estimatePointCount(visitor). (Ignacio Vera, Adrien Grand)
+
 New Features
 
 * LUCENE-8936: Add SpanishMinimalStemFilter (vinod kumar via Tomoko Uchida)
diff --git a/lucene/core/src/java/org/apache/lucene/document/LatLonPointDistanceQuery.java b/lucene/core/src/java/org/apache/lucene/document/LatLonPointDistanceQuery.java
index 62a70da..28ed1b0 100644
--- a/lucene/core/src/java/org/apache/lucene/document/LatLonPointDistanceQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/document/LatLonPointDistanceQuery.java
@@ -177,7 +177,7 @@ final class LatLonPointDistanceQuery extends Query {
           @Override
           public long cost() {
             if (cost == -1) {
-              cost = values.estimatePointCount(visitor);
+              cost = values.estimateDocCount(visitor);
             }
             assert cost >= 0;
             return cost;
diff --git a/lucene/core/src/java/org/apache/lucene/document/LatLonPointInPolygonQuery.java b/lucene/core/src/java/org/apache/lucene/document/LatLonPointInPolygonQuery.java
index 90e47a9..9006f16 100644
--- a/lucene/core/src/java/org/apache/lucene/document/LatLonPointInPolygonQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/document/LatLonPointInPolygonQuery.java
@@ -191,7 +191,7 @@ final class LatLonPointInPolygonQuery extends Query {
           public long cost() {
             if (cost == -1) {
                // Computing the cost may be expensive, so only do it if necessary
-              cost = values.estimatePointCount(visitor);
+              cost = values.estimateDocCount(visitor);
               assert cost >= 0;
             }
             return cost;
diff --git a/lucene/core/src/java/org/apache/lucene/document/RangeFieldQuery.java b/lucene/core/src/java/org/apache/lucene/document/RangeFieldQuery.java
index 4d254c5..57d4019 100644
--- a/lucene/core/src/java/org/apache/lucene/document/RangeFieldQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/document/RangeFieldQuery.java
@@ -362,7 +362,7 @@ abstract class RangeFieldQuery extends Query {
             public long cost() {
               if (cost == -1) {
                 // Computing the cost may be expensive, so only do it if necessary
-                cost = values.estimatePointCount(visitor);
+                cost = values.estimateDocCount(visitor);
                 assert cost >= 0;
               }
               return cost;
diff --git a/lucene/core/src/java/org/apache/lucene/index/PointValues.java b/lucene/core/src/java/org/apache/lucene/index/PointValues.java
index 0e3f27e..78c72ba 100644
--- a/lucene/core/src/java/org/apache/lucene/index/PointValues.java
+++ b/lucene/core/src/java/org/apache/lucene/index/PointValues.java
@@ -233,9 +233,35 @@ public abstract class PointValues {
 
   /** Estimate the number of points that would be visited by {@link #intersect}
    * with the given {@link IntersectVisitor}. This should run many times faster
+   * than {@link #intersect(IntersectVisitor)}. */
+  public abstract long estimatePointCount(IntersectVisitor visitor);
+
+  /** Estimate the number of documents that would be matched by {@link #intersect}
+   * with the given {@link IntersectVisitor}. This should run many times faster
    * than {@link #intersect(IntersectVisitor)}.
    * @see DocIdSetIterator#cost */
-  public abstract long estimatePointCount(IntersectVisitor visitor);
+  public long estimateDocCount(IntersectVisitor visitor) {
+    long estimatedPointCount = estimatePointCount(visitor);
+    int docCount = getDocCount();
+    double size = size(); // held as a double so the comparisons and ratios below use floating-point arithmetic
+    if (estimatedPointCount >= size) {
+      // the estimate covers every point, so match all docs
+      return docCount;
+    } else if (size == docCount || estimatedPointCount == 0L ) {
+      // if the point count estimate is 0 or we have only single values
+      // return this estimate
+      return  estimatedPointCount;
+    } else {
+      // in case of multi values estimate the number of docs using the solution provided in
+      // https://math.stackexchange.com/questions/1175295/urn-problem-probability-of-drawing-balls-of-k-unique-colors
+      // then approximate the solution for points per doc << size() which results in the expression
+      // D * (1 - ((N - n) / N)^(N/D))
+      // where D is the total number of docs, N the total number of points and n the estimated point count
+      long docEstimate = (long) (docCount * (1d - Math.pow((size - estimatedPointCount) / size, size / docCount)));
+      return docEstimate == 0L ? 1L : docEstimate; // the point estimate is non-zero here, so report at least one doc
+    }
+  }
+
 
   /** Returns minimum value for each dimension, packed, or null if {@link #size} is <code>0</code> */
   public abstract byte[] getMinPackedValue() throws IOException;
diff --git a/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java b/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java
index 8aa87a3..6660b3a 100644
--- a/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java
@@ -317,7 +317,7 @@ public abstract class PointRangeQuery extends Query {
             public long cost() {
               if (cost == -1) {
                 // Computing the cost may be expensive, so only do it if necessary
-                cost = values.estimatePointCount(visitor);
+                cost = values.estimateDocCount(visitor);
                 assert cost >= 0;
               }
               return cost;
diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene60/TestLucene60PointsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene60/TestLucene60PointsFormat.java
index 87fc5e2..4f4de39 100644
--- a/lucene/core/src/test/org/apache/lucene/codecs/lucene60/TestLucene60PointsFormat.java
+++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene60/TestLucene60PointsFormat.java
@@ -50,7 +50,7 @@ import org.apache.lucene.util.bkd.BKDWriter;
 public class TestLucene60PointsFormat extends BasePointsFormatTestCase {
   private final Codec codec;
   private final int maxPointsInLeafNode;
-  
+
   public TestLucene60PointsFormat() {
     // standard issue
     Codec defaultCodec = TestUtil.getDefaultCodec();
@@ -110,15 +110,19 @@ public class TestLucene60PointsFormat extends BasePointsFormatTestCase {
     byte[] uniquePointValue = new byte[3];
     random().nextBytes(uniquePointValue);
     final int numDocs = atLeast(10000); // make sure we have several leaves
+    final boolean multiValues = random().nextBoolean();
     for (int i = 0; i < numDocs; ++i) {
       Document doc = new Document();
       if (i == numDocs / 2) {
         doc.add(new BinaryPoint("f", uniquePointValue));
       } else {
-        do {
-          random().nextBytes(pointValue);
-        } while (Arrays.equals(pointValue, uniquePointValue));
-        doc.add(new BinaryPoint("f", pointValue));
+        final int numValues = (multiValues) ? TestUtil.nextInt(random(), 2, 100) : 1;
+        for (int j = 0; j < numValues; j ++) {
+          do {
+            random().nextBytes(pointValue);
+          } while (Arrays.equals(pointValue, uniquePointValue));
+          doc.add(new BinaryPoint("f", pointValue));
+        }
       }
       w.addDocument(doc);
     }
@@ -129,58 +133,72 @@ public class TestLucene60PointsFormat extends BasePointsFormatTestCase {
     PointValues points = lr.getPointValues("f");
 
     // If all points match, then the point count is numLeaves * maxPointsInLeafNode
-    final int numLeaves = (int) Math.ceil((double) numDocs / maxPointsInLeafNode);
-    assertEquals(numLeaves * maxPointsInLeafNode,
-        points.estimatePointCount(new IntersectVisitor() {
-          @Override
-          public void visit(int docID, byte[] packedValue) throws IOException {}
-          
-          @Override
-          public void visit(int docID) throws IOException {}
-          
-          @Override
-          public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
-            return Relation.CELL_INSIDE_QUERY;
-          }
-        }));
+    final int numLeaves = (int) Math.ceil((double) points.size() / maxPointsInLeafNode);
+
+    IntersectVisitor allPointsVisitor = new IntersectVisitor() {
+      @Override
+      public void visit(int docID, byte[] packedValue) throws IOException {}
+
+      @Override
+      public void visit(int docID) throws IOException {}
+
+      @Override
+      public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
+        return Relation.CELL_INSIDE_QUERY;
+      }
+    };
+
+    assertEquals(numLeaves * maxPointsInLeafNode, points.estimatePointCount(allPointsVisitor));
+    assertEquals(numDocs, points.estimateDocCount(allPointsVisitor));
+
+    IntersectVisitor noPointsVisitor = new IntersectVisitor() {
+      @Override
+      public void visit(int docID, byte[] packedValue) throws IOException {}
+
+      @Override
+      public void visit(int docID) throws IOException {}
+
+      @Override
+      public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
+        return Relation.CELL_OUTSIDE_QUERY;
+      }
+    };
 
     // Return 0 if no points match
-    assertEquals(0,
-        points.estimatePointCount(new IntersectVisitor() {
-          @Override
-          public void visit(int docID, byte[] packedValue) throws IOException {}
-          
-          @Override
-          public void visit(int docID) throws IOException {}
-          
-          @Override
-          public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
-            return Relation.CELL_OUTSIDE_QUERY;
-          }
-        }));
+    assertEquals(0, points.estimatePointCount(noPointsVisitor));
+    assertEquals(0, points.estimateDocCount(noPointsVisitor));
+
+    IntersectVisitor onePointMatchVisitor = new IntersectVisitor() {
+      @Override
+      public void visit(int docID, byte[] packedValue) throws IOException {}
+
+      @Override
+      public void visit(int docID) throws IOException {}
+
+      @Override
+      public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
+        if (FutureArrays.compareUnsigned(uniquePointValue, 0, 3, maxPackedValue, 0, 3) > 0 ||
+            FutureArrays.compareUnsigned(uniquePointValue, 0, 3, minPackedValue, 0, 3) < 0) {
+          return Relation.CELL_OUTSIDE_QUERY;
+        }
+        return Relation.CELL_CROSSES_QUERY;
+      }
+    };
 
     // If only one point matches, then the point count is (maxPointsInLeafNode + 1) / 2
     // in general, or maybe 2x that if the point is a split value
-    final long pointCount = points.estimatePointCount(new IntersectVisitor() {
-          @Override
-          public void visit(int docID, byte[] packedValue) throws IOException {}
-          
-          @Override
-          public void visit(int docID) throws IOException {}
-          
-          @Override
-          public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
-            if (FutureArrays.compareUnsigned(uniquePointValue, 0, 3, maxPackedValue, 0, 3) > 0 ||
-                FutureArrays.compareUnsigned(uniquePointValue, 0, 3, minPackedValue, 0, 3) < 0) {
-              return Relation.CELL_OUTSIDE_QUERY;
-            }
-            return Relation.CELL_CROSSES_QUERY;
-          }
-        });
+    final long pointCount = points.estimatePointCount(onePointMatchVisitor);
     assertTrue(""+pointCount,
         pointCount == (maxPointsInLeafNode + 1) / 2 || // common case
-        pointCount == 2*((maxPointsInLeafNode + 1) / 2)); // if the point is a split value
+            pointCount == 2*((maxPointsInLeafNode + 1) / 2)); // if the point is a split value
+
+    final long docCount = points.estimateDocCount(onePointMatchVisitor);
 
+    if (multiValues) {
+      assertEquals(docCount, (long) (docCount * (1d - Math.pow( (numDocs -  pointCount) / points.size() , points.size() / docCount))));
+    } else {
+      assertEquals(pointCount, docCount);
+    }
     r.close();
     dir.close();
   }
@@ -199,16 +217,20 @@ public class TestLucene60PointsFormat extends BasePointsFormatTestCase {
     random().nextBytes(uniquePointValue[0]);
     random().nextBytes(uniquePointValue[1]);
     final int numDocs = atLeast(10000); // make sure we have several leaves
+    final boolean multiValues = random().nextBoolean();
     for (int i = 0; i < numDocs; ++i) {
       Document doc = new Document();
       if (i == numDocs / 2) {
         doc.add(new BinaryPoint("f", uniquePointValue));
       } else {
-        do {
-          random().nextBytes(pointValue[0]);
-          random().nextBytes(pointValue[1]);
-        } while (Arrays.equals(pointValue[0], uniquePointValue[0]) || Arrays.equals(pointValue[1], uniquePointValue[1]));
-        doc.add(new BinaryPoint("f", pointValue));
+        final int numValues = (multiValues) ? TestUtil.nextInt(random(), 2, 100) : 1;
+        for (int j = 0; j < numValues; j ++) {
+          do {
+            random().nextBytes(pointValue[0]);
+            random().nextBytes(pointValue[1]);
+          } while (Arrays.equals(pointValue[0], uniquePointValue[0]) || Arrays.equals(pointValue[1], uniquePointValue[1]));
+          doc.add(new BinaryPoint("f", pointValue));
+        }
       }
       w.addDocument(doc);
     }
@@ -219,67 +241,161 @@ public class TestLucene60PointsFormat extends BasePointsFormatTestCase {
     PointValues points = lr.getPointValues("f");
 
     // With >1 dims, the tree is balanced
-    int actualMaxPointsInLeafNode = numDocs;
+    long actualMaxPointsInLeafNode = points.size();
     while (actualMaxPointsInLeafNode > maxPointsInLeafNode) {
       actualMaxPointsInLeafNode = (actualMaxPointsInLeafNode + 1) / 2;
     }
 
+    IntersectVisitor allPointsVisitor = new IntersectVisitor() {
+      @Override
+      public void visit(int docID, byte[] packedValue) throws IOException {}
+
+      @Override
+      public void visit(int docID) throws IOException {}
+
+      @Override
+      public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
+        return Relation.CELL_INSIDE_QUERY;
+      }
+    };
+
     // If all points match, then the point count is numLeaves * maxPointsInLeafNode
-    final int numLeaves = Integer.highestOneBit((numDocs - 1) / actualMaxPointsInLeafNode) << 1;
-    assertEquals(numLeaves * actualMaxPointsInLeafNode,
-        points.estimatePointCount(new IntersectVisitor() {
-          @Override
-          public void visit(int docID, byte[] packedValue) throws IOException {}
-          
-          @Override
-          public void visit(int docID) throws IOException {}
-          
-          @Override
-          public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
-            return Relation.CELL_INSIDE_QUERY;
-          }
-        }));
+    final int numLeaves = (int) Long.highestOneBit( ((points.size() - 1) / actualMaxPointsInLeafNode)) << 1;
+
+    assertEquals(numLeaves * actualMaxPointsInLeafNode, points.estimatePointCount(allPointsVisitor));
+    assertEquals(numDocs, points.estimateDocCount(allPointsVisitor));
+
+    IntersectVisitor noPointsVisitor = new IntersectVisitor() {
+      @Override
+      public void visit(int docID, byte[] packedValue) throws IOException {}
+
+      @Override
+      public void visit(int docID) throws IOException {}
+
+      @Override
+      public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
+        return Relation.CELL_OUTSIDE_QUERY;
+      }
+    };
 
     // Return 0 if no points match
-    assertEquals(0,
-        points.estimatePointCount(new IntersectVisitor() {
-          @Override
-          public void visit(int docID, byte[] packedValue) throws IOException {}
-          
-          @Override
-          public void visit(int docID) throws IOException {}
-          
-          @Override
-          public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
+    assertEquals(0, points.estimatePointCount(noPointsVisitor));
+    assertEquals(0, points.estimateDocCount(noPointsVisitor));
+
+    IntersectVisitor onePointMatchVisitor = new IntersectVisitor() {
+      @Override
+      public void visit(int docID, byte[] packedValue) throws IOException {}
+
+      @Override
+      public void visit(int docID) throws IOException {}
+
+      @Override
+      public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
+        for (int dim = 0; dim < 2; ++dim) {
+          if (FutureArrays.compareUnsigned(uniquePointValue[dim], 0, 3, maxPackedValue, dim * 3, dim * 3 + 3) > 0 ||
+              FutureArrays.compareUnsigned(uniquePointValue[dim], 0, 3, minPackedValue, dim * 3, dim * 3 + 3) < 0) {
             return Relation.CELL_OUTSIDE_QUERY;
           }
-        }));
-
+        }
+        return Relation.CELL_CROSSES_QUERY;
+      }
+    };
     // If only one point matches, then the point count is (actualMaxPointsInLeafNode + 1) / 2
     // in general, or maybe 2x that if the point is a split value
-    final long pointCount = points.estimatePointCount(new IntersectVisitor() {
-        @Override
-        public void visit(int docID, byte[] packedValue) throws IOException {}
-        
-        @Override
-        public void visit(int docID) throws IOException {}
-        
-        @Override
-        public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
-          for (int dim = 0; dim < 2; ++dim) {
-            if (FutureArrays.compareUnsigned(uniquePointValue[dim], 0, 3, maxPackedValue, dim * 3, dim * 3 + 3) > 0 ||
-                FutureArrays.compareUnsigned(uniquePointValue[dim], 0, 3, minPackedValue, dim * 3, dim * 3 + 3) < 0) {
-              return Relation.CELL_OUTSIDE_QUERY;
-            }
-          }
-          return Relation.CELL_CROSSES_QUERY;
-        }
-      });
+    final long pointCount = points.estimatePointCount(onePointMatchVisitor);
     assertTrue(""+pointCount,
         pointCount == (actualMaxPointsInLeafNode + 1) / 2 || // common case
-        pointCount == 2*((actualMaxPointsInLeafNode + 1) / 2)); // if the point is a split value
+            pointCount == 2*((actualMaxPointsInLeafNode + 1) / 2)); // if the point is a split value
 
+    final long docCount = points.estimateDocCount(onePointMatchVisitor);
+    if (multiValues) {
+      assertEquals(docCount, (long) (docCount * (1d - Math.pow( (numDocs -  pointCount) / points.size() , points.size() / docCount))));
+    } else {
+      assertEquals(pointCount, docCount);
+    }
     r.close();
     dir.close();
   }
+
+  public void testDocCountEdgeCases() { // checks estimateDocCount at extreme size / docCount / point-estimate values
+    PointValues values = getPointValues(Long.MAX_VALUE, 1, Long.MAX_VALUE); // args: (size, docCount, estimatedPointCount); one doc, estimate equals size
+    long docs = values.estimateDocCount(null); // the stub's estimatePointCount ignores the visitor, so null is passed
+    assertEquals(1, docs);
+    values = getPointValues(Long.MAX_VALUE, 1, 1); // one doc, estimate of a single point
+    docs = values.estimateDocCount(null);
+    assertEquals(1, docs);
+    values = getPointValues(Long.MAX_VALUE, Integer.MAX_VALUE, Long.MAX_VALUE); // max docs, estimate equals size
+    docs = values.estimateDocCount(null);
+    assertEquals(Integer.MAX_VALUE, docs);
+    values = getPointValues(Long.MAX_VALUE, Integer.MAX_VALUE, Long.MAX_VALUE / 2); // max docs, estimate is half of size
+    docs = values.estimateDocCount(null);
+    assertEquals(Integer.MAX_VALUE, docs);
+    values = getPointValues(Long.MAX_VALUE, Integer.MAX_VALUE, 1); // max docs, estimate of a single point
+    docs = values.estimateDocCount(null);
+    assertEquals(1, docs);
+  }
+
+  public void testRandomDocCount() { // checks bounds of estimateDocCount over 100 random (size, docCount, estimate) triples
+    for (int i = 0; i < 100; i++) {
+      long size = TestUtil.nextLong(random(), 1, Long.MAX_VALUE);
+      int maxDoc = (size > Integer.MAX_VALUE) ? Integer.MAX_VALUE : Math.toIntExact(size); // upper bound for docCount: min(size, Integer.MAX_VALUE)
+      int docCount = TestUtil.nextInt(random(), 1, maxDoc);
+      long estimatedPointCount = TestUtil.nextLong(random(), 0, size);
+      PointValues values = getPointValues(size, docCount, estimatedPointCount);
+      long docs = values.estimateDocCount(null); // the stub's estimatePointCount ignores the visitor
+      assertTrue(docs <= estimatedPointCount); // never more docs than estimated points
+      assertTrue(docs <= maxDoc); // never more docs than the bound docCount was drawn from
+      assertTrue(docs >= estimatedPointCount / (size/docCount)); // at least the estimate divided by the integer average of points per doc
+    }
+  }
+
+
+  private PointValues getPointValues(long size, int docCount, long estimatedPointCount) { // builds a stub PointValues that reports only the given numbers
+    return new PointValues() {
+      @Override
+      public void intersect(IntersectVisitor visitor) {
+        throw new UnsupportedOperationException(); // not supported by this stub
+      }
+
+      @Override
+      public long estimatePointCount(IntersectVisitor visitor) {
+        return estimatedPointCount; // fixed value; the visitor is ignored
+      }
+
+      @Override
+      public byte[] getMinPackedValue() throws IOException {
+        throw new UnsupportedOperationException(); // not supported by this stub
+      }
+
+      @Override
+      public byte[] getMaxPackedValue() throws IOException {
+        throw new UnsupportedOperationException(); // not supported by this stub
+      }
+
+      @Override
+      public int getNumDataDimensions() throws IOException {
+        throw new UnsupportedOperationException(); // not supported by this stub
+      }
+
+      @Override
+      public int getNumIndexDimensions() throws IOException {
+        throw new UnsupportedOperationException(); // not supported by this stub
+      }
+
+      @Override
+      public int getBytesPerDimension() throws IOException {
+        throw new UnsupportedOperationException(); // not supported by this stub
+      }
+
+      @Override
+      public long size() {
+        return size; // total number of points reported by the stub
+      }
+
+      @Override
+      public int getDocCount() {
+        return docCount; // number of docs reported by the stub
+      }
+    };
+  }
 }
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestIndexOrDocValuesQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestIndexOrDocValuesQuery.java
index d784b12..de5d947 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestIndexOrDocValuesQuery.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestIndexOrDocValuesQuery.java
@@ -22,6 +22,7 @@ import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field.Store;
 import org.apache.lucene.document.LongPoint;
 import org.apache.lucene.document.NumericDocValuesField;
+import org.apache.lucene.document.SortedNumericDocValuesField;
 import org.apache.lucene.document.StringField;
 import org.apache.lucene.index.DirectoryReader;
 import org.apache.lucene.index.IndexReader;
@@ -86,4 +87,70 @@ public class TestIndexOrDocValuesQuery extends LuceneTestCase {
     dir.close();
   }
 
+  public void testUseIndexForSelectiveMultiValueQueries() throws IOException { // checks which side IndexOrDocValuesQuery picks when docs carry many values
+    Directory dir = newDirectory();
+    IndexWriter w = new IndexWriter(dir, newIndexWriterConfig()
+        // relies on costs and PointValues.estimateCost so we need the default codec
+        .setCodec(TestUtil.getDefaultCodec()));
+    for (int i = 0; i < 2000; ++i) {
+      Document doc = new Document();
+      if (i < 1000) { // docs 0..999: term "bar" with 500 copies of the value 42
+        doc.add(new StringField("f1", "bar", Store.NO));
+        for (int j =0; j < 500; j++) {
+          doc.add(new LongPoint("f2", 42L));
+          doc.add(new SortedNumericDocValuesField("f2", 42L));
+        }
+      } else if (i == 1001) { // the only "foo" doc; NOTE(review): its point is 2 but its doc value is 42 - confirm the mismatch is intended
+        doc.add(new StringField("f1", "foo", Store.NO));
+        doc.add(new LongPoint("f2", 2L));
+        doc.add(new SortedNumericDocValuesField("f2", 42L));
+      } else { // remaining docs: term "bar" with 100 copies of the value 2
+        doc.add(new StringField("f1", "bar", Store.NO));
+        for (int j =0; j < 100; j++) {
+          doc.add(new LongPoint("f2", 2L));
+          doc.add(new SortedNumericDocValuesField("f2", 2L));
+        }
+      }
+      w.addDocument(doc);
+    }
+    w.forceMerge(1); // presumably leaves a single segment so leaves().get(0) below covers the whole index - TODO confirm
+    IndexReader reader = DirectoryReader.open(w);
+    IndexSearcher searcher = newSearcher(reader);
+    searcher.setQueryCache(null);
+
+    // The term query is less selective, so the IndexOrDocValuesQuery should use points
+    final Query q1 = new BooleanQuery.Builder()
+        .add(new TermQuery(new Term("f1", "bar")), Occur.MUST)
+        .add(new IndexOrDocValuesQuery(LongPoint.newExactQuery("f2", 2), SortedNumericDocValuesField.newSlowRangeQuery("f2", 2L, 2L)), Occur.MUST)
+        .build();
+
+    final Weight w1 = searcher.createWeight(searcher.rewrite(q1), ScoreMode.COMPLETE, 1);
+    final Scorer s1 = w1.scorer(searcher.getIndexReader().leaves().get(0));
+    assertNull(s1.twoPhaseIterator()); // means we use points
+
+    // The term query is less selective, so the IndexOrDocValuesQuery should use points
+    final Query q2 = new BooleanQuery.Builder()
+        .add(new TermQuery(new Term("f1", "bar")), Occur.MUST)
+        .add(new IndexOrDocValuesQuery(LongPoint.newExactQuery("f2", 42), SortedNumericDocValuesField.newSlowRangeQuery("f2", 42, 42L)), Occur.MUST)
+        .build();
+
+    final Weight w2 = searcher.createWeight(searcher.rewrite(q2), ScoreMode.COMPLETE, 1);
+    final Scorer s2 = w2.scorer(searcher.getIndexReader().leaves().get(0));
+    assertNull(s2.twoPhaseIterator()); // means we use points
+
+    // The term query is more selective, so the IndexOrDocValuesQuery should use doc values
+    final Query q3 = new BooleanQuery.Builder()
+        .add(new TermQuery(new Term("f1", "foo")), Occur.MUST)
+        .add(new IndexOrDocValuesQuery(LongPoint.newExactQuery("f2", 42), SortedNumericDocValuesField.newSlowRangeQuery("f2", 42, 42L)), Occur.MUST)
+        .build();
+
+    final Weight w3 = searcher.createWeight(searcher.rewrite(q3), ScoreMode.COMPLETE, 1);
+    final Scorer s3 = w3.scorer(searcher.getIndexReader().leaves().get(0));
+    assertNotNull(s3.twoPhaseIterator()); // means we use doc values
+
+    reader.close();
+    w.close();
+    dir.close();
+  }
+
 }
diff --git a/lucene/sandbox/src/java/org/apache/lucene/document/ShapeQuery.java b/lucene/sandbox/src/java/org/apache/lucene/document/ShapeQuery.java
index a2ba95e..d27a3ea 100644
--- a/lucene/sandbox/src/java/org/apache/lucene/document/ShapeQuery.java
+++ b/lucene/sandbox/src/java/org/apache/lucene/document/ShapeQuery.java
@@ -282,7 +282,7 @@ abstract class ShapeQuery extends Query {
     public long cost() {
       if (cost == -1) {
         // Computing the cost may be expensive, so only do it if necessary
-        cost = values.estimatePointCount(getEstimateVisitor(query));
+        cost = values.estimateDocCount(getEstimateVisitor(query));
         assert cost >= 0;
       }
       return cost;
diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/MultiRangeQuery.java b/lucene/sandbox/src/java/org/apache/lucene/search/MultiRangeQuery.java
index 1bb4f4b..988c20b 100644
--- a/lucene/sandbox/src/java/org/apache/lucene/search/MultiRangeQuery.java
+++ b/lucene/sandbox/src/java/org/apache/lucene/search/MultiRangeQuery.java
@@ -279,7 +279,7 @@ public abstract class MultiRangeQuery extends Query {
             public long cost() {
               if (cost == -1) {
                 // Computing the cost may be expensive, so only do it if necessary
-                cost = values.estimatePointCount(visitor) * rangeClauses.size();
+                cost = values.estimateDocCount(visitor) * rangeClauses.size();
                 assert cost >= 0;
               }
               return cost;