You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by iv...@apache.org on 2021/11/19 07:29:04 UTC

[lucene] branch main updated: LUCENE-9820: Separate logic for reading the BKD index from logic for intersecting it (#7)

This is an automated email from the ASF dual-hosted git repository.

ivera pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/lucene.git


The following commit(s) were added to refs/heads/main by this push:
     new ad911df  LUCENE-9820: Separate logic for reading the BKD index from logic for intersecting it (#7)
ad911df is described below

commit ad911df2605591a63ca58d237182f84aa8384325
Author: Ignacio Vera <iv...@apache.org>
AuthorDate: Fri Nov 19 08:28:01 2021 +0100

    LUCENE-9820: Separate logic for reading the BKD index from logic for intersecting it (#7)
    
    Extract BKD tree interface and move intersecting logic to the PointValues abstract class.
---
 lucene/CHANGES.txt                                 |    3 +
 .../lucene60/Lucene60PointsReader.java             |    6 +-
 .../lucene86/Lucene86PointsReader.java             |    6 +-
 .../lucene60/Lucene60PointsWriter.java             |   22 +-
 .../lucene60/TestLucene60PointsFormat.java         |   42 +-
 .../lucene86/Lucene86PointsWriter.java             |   21 +-
 .../lucene86/TestLucene86PointsFormat.java         |   42 +-
 .../codecs/simpletext/SimpleTextBKDReader.java     |  560 +++++----
 .../codecs/simpletext/SimpleTextBKDWriter.java     |   22 +-
 .../codecs/simpletext/SimpleTextPointsWriter.java  |    4 +-
 ...tablePointValues.java => MutablePointTree.java} |   46 +-
 .../org/apache/lucene/codecs/PointsWriter.java     |  142 ++-
 .../codecs/lucene90/Lucene90PointsReader.java      |    4 +-
 .../codecs/lucene90/Lucene90PointsWriter.java      |   21 +-
 .../lucene/index/ExitableDirectoryReader.java      |  116 +-
 .../java/org/apache/lucene/index/PointValues.java  |  132 +-
 .../org/apache/lucene/index/PointValuesWriter.java |  158 +--
 .../apache/lucene/index/SortingCodecReader.java    |  115 +-
 .../java/org/apache/lucene/util/bkd/BKDReader.java | 1287 +++++++++-----------
 .../java/org/apache/lucene/util/bkd/BKDWriter.java |  224 ++--
 ...Utils.java => MutablePointTreeReaderUtils.java} |   15 +-
 .../codecs/lucene90/TestLucene90PointsFormat.java  |   42 +-
 .../apache/lucene/util/TestDocIdSetBuilder.java    |    7 +-
 .../apache/lucene/util/bkd/Test2BBKDPoints.java    |    5 +-
 .../test/org/apache/lucene/util/bkd/TestBKD.java   |  625 ++++------
 ...s.java => TestMutablePointTreeReaderUtils.java} |   69 +-
 .../apache/lucene/index/memory/MemoryIndex.java    |   63 +-
 .../document/FloatPointNearestNeighbor.java        |  103 +-
 .../search/LatLonPointPrototypeQueries.java        |   10 +-
 .../lucene/sandbox/search/NearestNeighbor.java     |  108 +-
 .../lucene/codecs/cranky/CrankyPointsFormat.java   |   73 +-
 .../apache/lucene/index/AssertingLeafReader.java   |  141 ++-
 .../lucene/index/BasePointsFormatTestCase.java     |   91 +-
 .../java/org/apache/lucene/index/RandomCodec.java  |    4 +-
 34 files changed, 2296 insertions(+), 2033 deletions(-)

diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index e6af773..341d80b5 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -50,6 +50,9 @@ Improvements
 
 * LUCENE-10238: Upgrade icu4j dependency to 70.1. (Dawid Weiss)
 
+* LUCENE-9820: Extract BKD tree interface and move intersecting logic to the 
+  PointValues abstract class. (Ignacio Vera, Adrien Grand)
+
 Optimizations
 ---------------------
 
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene60/Lucene60PointsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene60/Lucene60PointsReader.java
index c314a5d..9a7f6bf 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene60/Lucene60PointsReader.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene60/Lucene60PointsReader.java
@@ -35,7 +35,7 @@ import org.apache.lucene.util.bkd.BKDReader;
 public class Lucene60PointsReader extends PointsReader {
   final IndexInput dataIn;
   final SegmentReadState readState;
-  final Map<Integer, BKDReader> readers = new HashMap<>();
+  final Map<Integer, PointValues> readers = new HashMap<>();
 
   /** Sole constructor */
   public Lucene60PointsReader(SegmentReadState readState) throws IOException {
@@ -102,7 +102,7 @@ public class Lucene60PointsReader extends PointsReader {
         int fieldNumber = ent.getKey();
         long fp = ent.getValue();
         dataIn.seek(fp);
-        BKDReader reader = new BKDReader(dataIn, dataIn, dataIn);
+        PointValues reader = new BKDReader(dataIn, dataIn, dataIn);
         readers.put(fieldNumber, reader);
       }
 
@@ -115,7 +115,7 @@ public class Lucene60PointsReader extends PointsReader {
   }
 
   /**
-   * Returns the underlying {@link BKDReader}.
+   * Returns the underlying {@link PointValues}.
    *
    * @lucene.internal
    */
diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene86/Lucene86PointsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene86/Lucene86PointsReader.java
index e49562a..efe6f5d 100644
--- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene86/Lucene86PointsReader.java
+++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene86/Lucene86PointsReader.java
@@ -36,7 +36,7 @@ import org.apache.lucene.util.bkd.BKDReader;
 public class Lucene86PointsReader extends PointsReader {
   final IndexInput indexIn, dataIn;
   final SegmentReadState readState;
-  final Map<Integer, BKDReader> readers = new HashMap<>();
+  final Map<Integer, PointValues> readers = new HashMap<>();
 
   /** Sole constructor */
   public Lucene86PointsReader(SegmentReadState readState) throws IOException {
@@ -101,7 +101,7 @@ public class Lucene86PointsReader extends PointsReader {
             } else if (fieldNumber < 0) {
               throw new CorruptIndexException("Illegal field number: " + fieldNumber, metaIn);
             }
-            BKDReader reader = new BKDReader(metaIn, indexIn, dataIn);
+            PointValues reader = new BKDReader(metaIn, indexIn, dataIn);
             readers.put(fieldNumber, reader);
           }
           indexLength = metaIn.readLong();
@@ -125,7 +125,7 @@ public class Lucene86PointsReader extends PointsReader {
   }
 
   /**
-   * Returns the underlying {@link BKDReader}.
+   * Returns the underlying {@link PointValues}.
    *
    * @lucene.internal
    */
diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene60/Lucene60PointsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene60/Lucene60PointsWriter.java
index c0276f9..b6e022e 100644
--- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene60/Lucene60PointsWriter.java
+++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene60/Lucene60PointsWriter.java
@@ -23,7 +23,7 @@ import java.util.List;
 import java.util.Map;
 import org.apache.lucene.backward_codecs.store.EndiannessReverserUtil;
 import org.apache.lucene.codecs.CodecUtil;
-import org.apache.lucene.codecs.MutablePointValues;
+import org.apache.lucene.codecs.MutablePointTree;
 import org.apache.lucene.codecs.PointsReader;
 import org.apache.lucene.codecs.PointsWriter;
 import org.apache.lucene.index.FieldInfo;
@@ -37,7 +37,6 @@ import org.apache.lucene.index.SegmentWriteState;
 import org.apache.lucene.store.IndexOutput;
 import org.apache.lucene.util.IOUtils;
 import org.apache.lucene.util.bkd.BKDConfig;
-import org.apache.lucene.util.bkd.BKDReader;
 import org.apache.lucene.util.bkd.BKDWriter;
 
 /** Writes dimensional values */
@@ -99,7 +98,7 @@ public class Lucene60PointsWriter extends PointsWriter {
   @Override
   public void writeField(FieldInfo fieldInfo, PointsReader reader) throws IOException {
 
-    PointValues values = reader.getValues(fieldInfo.name);
+    PointValues.PointTree values = reader.getValues(fieldInfo.name).getPointTree();
 
     BKDConfig config =
         new BKDConfig(
@@ -117,10 +116,9 @@ public class Lucene60PointsWriter extends PointsWriter {
             maxMBSortInHeap,
             values.size())) {
 
-      if (values instanceof MutablePointValues) {
+      if (values instanceof MutablePointTree) {
         Runnable finalizer =
-            writer.writeField(
-                dataOut, dataOut, dataOut, fieldInfo.name, (MutablePointValues) values);
+            writer.writeField(dataOut, dataOut, dataOut, fieldInfo.name, (MutablePointTree) values);
         if (finalizer != null) {
           indexFPs.put(fieldInfo.name, dataOut.getFilePointer());
           finalizer.run();
@@ -128,7 +126,7 @@ public class Lucene60PointsWriter extends PointsWriter {
         return;
       }
 
-      values.intersect(
+      values.visitDocValues(
           new IntersectVisitor() {
             @Override
             public void visit(int docID) {
@@ -214,7 +212,7 @@ public class Lucene60PointsWriter extends PointsWriter {
                   config,
                   maxMBSortInHeap,
                   totMaxSize)) {
-            List<BKDReader> bkdReaders = new ArrayList<>();
+            List<PointValues> pointValues = new ArrayList<>();
             List<MergeState.DocMap> docMaps = new ArrayList<>();
             for (int i = 0; i < mergeState.pointsReaders.length; i++) {
               PointsReader reader = mergeState.pointsReaders[i];
@@ -233,16 +231,16 @@ public class Lucene60PointsWriter extends PointsWriter {
                 FieldInfos readerFieldInfos = mergeState.fieldInfos[i];
                 FieldInfo readerFieldInfo = readerFieldInfos.fieldInfo(fieldInfo.name);
                 if (readerFieldInfo != null && readerFieldInfo.getPointDimensionCount() > 0) {
-                  BKDReader bkdReader = reader60.readers.get(readerFieldInfo.number);
-                  if (bkdReader != null) {
-                    bkdReaders.add(bkdReader);
+                  PointValues aPointValues = reader60.readers.get(readerFieldInfo.number);
+                  if (aPointValues != null) {
+                    pointValues.add(aPointValues);
                     docMaps.add(mergeState.docMaps[i]);
                   }
                 }
               }
             }
 
-            Runnable finalizer = writer.merge(dataOut, dataOut, dataOut, docMaps, bkdReaders);
+            Runnable finalizer = writer.merge(dataOut, dataOut, dataOut, docMaps, pointValues);
             if (finalizer != null) {
               indexFPs.put(fieldInfo.name, dataOut.getFilePointer());
               finalizer.run();
diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene60/TestLucene60PointsFormat.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene60/TestLucene60PointsFormat.java
index 0d1d5fa..d3d2826 100644
--- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene60/TestLucene60PointsFormat.java
+++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene60/TestLucene60PointsFormat.java
@@ -66,9 +66,11 @@ public class TestLucene60PointsFormat extends BasePointsFormatTestCase {
     final int numDocs =
         TEST_NIGHTLY ? atLeast(10000) : atLeast(500); // at night, make sure we have several leaves
     final boolean multiValues = random().nextBoolean();
+    int totalValues = 0;
     for (int i = 0; i < numDocs; ++i) {
       Document doc = new Document();
       if (i == numDocs / 2) {
+        totalValues++;
         doc.add(new BinaryPoint("f", uniquePointValue));
       } else {
         final int numValues = (multiValues) ? TestUtil.nextInt(random(), 2, 100) : 1;
@@ -77,6 +79,7 @@ public class TestLucene60PointsFormat extends BasePointsFormatTestCase {
             random().nextBytes(pointValue);
           } while (Arrays.equals(pointValue, uniquePointValue));
           doc.add(new BinaryPoint("f", pointValue));
+          totalValues++;
         }
       }
       w.addDocument(doc);
@@ -87,9 +90,6 @@ public class TestLucene60PointsFormat extends BasePointsFormatTestCase {
     final LeafReader lr = getOnlyLeafReader(r);
     PointValues points = lr.getPointValues("f");
 
-    // If all points match, then the point count is numLeaves * maxPointsInLeafNode
-    final int numLeaves = (int) Math.ceil((double) points.size() / maxPointsInLeafNode);
-
     IntersectVisitor allPointsVisitor =
         new IntersectVisitor() {
           @Override
@@ -104,7 +104,7 @@ public class TestLucene60PointsFormat extends BasePointsFormatTestCase {
           }
         };
 
-    assertEquals(numLeaves * maxPointsInLeafNode, points.estimatePointCount(allPointsVisitor));
+    assertEquals(totalValues, points.estimatePointCount(allPointsVisitor));
     assertEquals(numDocs, points.estimateDocCount(allPointsVisitor));
 
     IntersectVisitor noPointsVisitor =
@@ -146,11 +146,16 @@ public class TestLucene60PointsFormat extends BasePointsFormatTestCase {
     // If only one point matches, then the point count is (maxPointsInLeafNode + 1) / 2
     // in general, or maybe 2x that if the point is a split value
     final long pointCount = points.estimatePointCount(onePointMatchVisitor);
+    final long lastNodePointCount = totalValues % maxPointsInLeafNode;
     assertTrue(
         "" + pointCount,
-        pointCount == (maxPointsInLeafNode + 1) / 2
-            || // common case
-            pointCount == 2 * ((maxPointsInLeafNode + 1) / 2)); // if the point is a split value
+        pointCount == (maxPointsInLeafNode + 1) / 2 // common case
+            || pointCount == (lastNodePointCount + 1) / 2 // not fully populated leaf
+            || pointCount == 2 * ((maxPointsInLeafNode + 1) / 2) // if the point is a split value
+            || pointCount
+                == ((maxPointsInLeafNode + 1) / 2)
+                    + ((lastNodePointCount + 1)
+                        / 2)); // if the point is a split value and one leaf is not fully populated
 
     final long docCount = points.estimateDocCount(onePointMatchVisitor);
 
@@ -187,10 +192,12 @@ public class TestLucene60PointsFormat extends BasePointsFormatTestCase {
             ? atLeast(10000)
             : atLeast(1000); // in nightly, make sure we have several leaves
     final boolean multiValues = random().nextBoolean();
+    int totalValues = 0;
     for (int i = 0; i < numDocs; ++i) {
       Document doc = new Document();
       if (i == numDocs / 2) {
         doc.add(new BinaryPoint("f", uniquePointValue));
+        totalValues++;
       } else {
         final int numValues = (multiValues) ? TestUtil.nextInt(random(), 2, 100) : 1;
         for (int j = 0; j < numValues; j++) {
@@ -200,6 +207,7 @@ public class TestLucene60PointsFormat extends BasePointsFormatTestCase {
           } while (Arrays.equals(pointValue[0], uniquePointValue[0])
               || Arrays.equals(pointValue[1], uniquePointValue[1]));
           doc.add(new BinaryPoint("f", pointValue));
+          totalValues++;
         }
       }
       w.addDocument(doc);
@@ -224,10 +232,7 @@ public class TestLucene60PointsFormat extends BasePointsFormatTestCase {
           }
         };
 
-    // If all points match, then the point count is numLeaves * maxPointsInLeafNode
-    final int numLeaves = (int) Math.ceil((double) points.size() / maxPointsInLeafNode);
-
-    assertEquals(numLeaves * maxPointsInLeafNode, points.estimatePointCount(allPointsVisitor));
+    assertEquals(totalValues, points.estimatePointCount(allPointsVisitor));
     assertEquals(numDocs, points.estimateDocCount(allPointsVisitor));
 
     IntersectVisitor noPointsVisitor =
@@ -273,11 +278,16 @@ public class TestLucene60PointsFormat extends BasePointsFormatTestCase {
         };
 
     final long pointCount = points.estimatePointCount(onePointMatchVisitor);
-    // The number of matches needs to be multiple of count per leaf
-    final long countPerLeaf = (maxPointsInLeafNode + 1) / 2;
-    assertTrue("" + pointCount, pointCount % countPerLeaf == 0);
-    // in extreme cases, a point can be be shared by 4 leaves
-    assertTrue("" + pointCount, pointCount / countPerLeaf <= 4 && pointCount / countPerLeaf >= 1);
+    final long lastNodePointCount = totalValues % maxPointsInLeafNode;
+    assertTrue(
+        "" + pointCount,
+        pointCount == (maxPointsInLeafNode + 1) / 2 // common case
+            || pointCount == (lastNodePointCount + 1) / 2 // not fully populated leaf
+            || pointCount == 2 * ((maxPointsInLeafNode + 1) / 2) // if the point is a split value
+            || pointCount == ((maxPointsInLeafNode + 1) / 2) + ((lastNodePointCount + 1) / 2)
+            // in extreme cases, a point can be shared by 4 leaves
+            || pointCount == 4 * ((maxPointsInLeafNode + 1) / 2)
+            || pointCount == 3 * ((maxPointsInLeafNode + 1) / 2) + ((lastNodePointCount + 1) / 2));
 
     final long docCount = points.estimateDocCount(onePointMatchVisitor);
     if (multiValues) {
diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene86/Lucene86PointsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene86/Lucene86PointsWriter.java
index a342f77..677f64b 100644
--- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene86/Lucene86PointsWriter.java
+++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene86/Lucene86PointsWriter.java
@@ -21,7 +21,7 @@ import java.util.ArrayList;
 import java.util.List;
 import org.apache.lucene.backward_codecs.store.EndiannessReverserUtil;
 import org.apache.lucene.codecs.CodecUtil;
-import org.apache.lucene.codecs.MutablePointValues;
+import org.apache.lucene.codecs.MutablePointTree;
 import org.apache.lucene.codecs.PointsReader;
 import org.apache.lucene.codecs.PointsWriter;
 import org.apache.lucene.index.FieldInfo;
@@ -35,7 +35,6 @@ import org.apache.lucene.index.SegmentWriteState;
 import org.apache.lucene.store.IndexOutput;
 import org.apache.lucene.util.IOUtils;
 import org.apache.lucene.util.bkd.BKDConfig;
-import org.apache.lucene.util.bkd.BKDReader;
 import org.apache.lucene.util.bkd.BKDWriter;
 
 /** Writes dimensional values */
@@ -125,7 +124,7 @@ public class Lucene86PointsWriter extends PointsWriter {
   @Override
   public void writeField(FieldInfo fieldInfo, PointsReader reader) throws IOException {
 
-    PointValues values = reader.getValues(fieldInfo.name);
+    PointValues.PointTree values = reader.getValues(fieldInfo.name).getPointTree();
 
     BKDConfig config =
         new BKDConfig(
@@ -143,10 +142,10 @@ public class Lucene86PointsWriter extends PointsWriter {
             maxMBSortInHeap,
             values.size())) {
 
-      if (values instanceof MutablePointValues) {
+      if (values instanceof MutablePointTree) {
         Runnable finalizer =
             writer.writeField(
-                metaOut, indexOut, dataOut, fieldInfo.name, (MutablePointValues) values);
+                metaOut, indexOut, dataOut, fieldInfo.name, (MutablePointTree) values);
         if (finalizer != null) {
           metaOut.writeInt(fieldInfo.number);
           finalizer.run();
@@ -154,7 +153,7 @@ public class Lucene86PointsWriter extends PointsWriter {
         return;
       }
 
-      values.intersect(
+      values.visitDocValues(
           new IntersectVisitor() {
             @Override
             public void visit(int docID) {
@@ -240,7 +239,7 @@ public class Lucene86PointsWriter extends PointsWriter {
                   config,
                   maxMBSortInHeap,
                   totMaxSize)) {
-            List<BKDReader> bkdReaders = new ArrayList<>();
+            List<PointValues> pointValues = new ArrayList<>();
             List<MergeState.DocMap> docMaps = new ArrayList<>();
             for (int i = 0; i < mergeState.pointsReaders.length; i++) {
               PointsReader reader = mergeState.pointsReaders[i];
@@ -259,16 +258,16 @@ public class Lucene86PointsWriter extends PointsWriter {
                 FieldInfos readerFieldInfos = mergeState.fieldInfos[i];
                 FieldInfo readerFieldInfo = readerFieldInfos.fieldInfo(fieldInfo.name);
                 if (readerFieldInfo != null && readerFieldInfo.getPointDimensionCount() > 0) {
-                  BKDReader bkdReader = reader60.readers.get(readerFieldInfo.number);
-                  if (bkdReader != null) {
-                    bkdReaders.add(bkdReader);
+                  PointValues aPointValues = reader60.readers.get(readerFieldInfo.number);
+                  if (aPointValues != null) {
+                    pointValues.add(aPointValues);
                     docMaps.add(mergeState.docMaps[i]);
                   }
                 }
               }
             }
 
-            Runnable finalizer = writer.merge(metaOut, indexOut, dataOut, docMaps, bkdReaders);
+            Runnable finalizer = writer.merge(metaOut, indexOut, dataOut, docMaps, pointValues);
             if (finalizer != null) {
               metaOut.writeInt(fieldInfo.number);
               finalizer.run();
diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene86/TestLucene86PointsFormat.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene86/TestLucene86PointsFormat.java
index 54c5235..20a0de3 100644
--- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene86/TestLucene86PointsFormat.java
+++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene86/TestLucene86PointsFormat.java
@@ -114,9 +114,11 @@ public class TestLucene86PointsFormat extends BasePointsFormatTestCase {
     final int numDocs =
         TEST_NIGHTLY ? atLeast(10000) : atLeast(500); // at night, make sure we have several leaves
     final boolean multiValues = random().nextBoolean();
+    int totalValues = 0;
     for (int i = 0; i < numDocs; ++i) {
       Document doc = new Document();
       if (i == numDocs / 2) {
+        totalValues++;
         doc.add(new BinaryPoint("f", uniquePointValue));
       } else {
         final int numValues = (multiValues) ? TestUtil.nextInt(random(), 2, 100) : 1;
@@ -125,6 +127,7 @@ public class TestLucene86PointsFormat extends BasePointsFormatTestCase {
             random().nextBytes(pointValue);
           } while (Arrays.equals(pointValue, uniquePointValue));
           doc.add(new BinaryPoint("f", pointValue));
+          totalValues++;
         }
       }
       w.addDocument(doc);
@@ -135,9 +138,6 @@ public class TestLucene86PointsFormat extends BasePointsFormatTestCase {
     final LeafReader lr = getOnlyLeafReader(r);
     PointValues points = lr.getPointValues("f");
 
-    // If all points match, then the point count is numLeaves * maxPointsInLeafNode
-    final int numLeaves = (int) Math.ceil((double) points.size() / maxPointsInLeafNode);
-
     IntersectVisitor allPointsVisitor =
         new IntersectVisitor() {
           @Override
@@ -152,7 +152,7 @@ public class TestLucene86PointsFormat extends BasePointsFormatTestCase {
           }
         };
 
-    assertEquals(numLeaves * maxPointsInLeafNode, points.estimatePointCount(allPointsVisitor));
+    assertEquals(totalValues, points.estimatePointCount(allPointsVisitor));
     assertEquals(numDocs, points.estimateDocCount(allPointsVisitor));
 
     IntersectVisitor noPointsVisitor =
@@ -194,11 +194,16 @@ public class TestLucene86PointsFormat extends BasePointsFormatTestCase {
     // If only one point matches, then the point count is (maxPointsInLeafNode + 1) / 2
     // in general, or maybe 2x that if the point is a split value
     final long pointCount = points.estimatePointCount(onePointMatchVisitor);
+    final long lastNodePointCount = totalValues % maxPointsInLeafNode;
     assertTrue(
         "" + pointCount,
-        pointCount == (maxPointsInLeafNode + 1) / 2
-            || // common case
-            pointCount == 2 * ((maxPointsInLeafNode + 1) / 2)); // if the point is a split value
+        pointCount == (maxPointsInLeafNode + 1) / 2 // common case
+            || pointCount == (lastNodePointCount + 1) / 2 // not fully populated leaf
+            || pointCount == 2 * ((maxPointsInLeafNode + 1) / 2) // if the point is a split value
+            || pointCount
+                == ((maxPointsInLeafNode + 1) / 2)
+                    + ((lastNodePointCount + 1)
+                        / 2)); // if the point is a split value and one leaf is not fully populated
 
     final long docCount = points.estimateDocCount(onePointMatchVisitor);
 
@@ -235,10 +240,12 @@ public class TestLucene86PointsFormat extends BasePointsFormatTestCase {
             ? atLeast(10000)
             : atLeast(1000); // in nightly, make sure we have several leaves
     final boolean multiValues = random().nextBoolean();
+    int totalValues = 0;
     for (int i = 0; i < numDocs; ++i) {
       Document doc = new Document();
       if (i == numDocs / 2) {
         doc.add(new BinaryPoint("f", uniquePointValue));
+        totalValues++;
       } else {
         final int numValues = (multiValues) ? TestUtil.nextInt(random(), 2, 100) : 1;
         for (int j = 0; j < numValues; j++) {
@@ -248,6 +255,7 @@ public class TestLucene86PointsFormat extends BasePointsFormatTestCase {
           } while (Arrays.equals(pointValue[0], uniquePointValue[0])
               || Arrays.equals(pointValue[1], uniquePointValue[1]));
           doc.add(new BinaryPoint("f", pointValue));
+          totalValues++;
         }
       }
       w.addDocument(doc);
@@ -272,10 +280,7 @@ public class TestLucene86PointsFormat extends BasePointsFormatTestCase {
           }
         };
 
-    // If all points match, then the point count is numLeaves * maxPointsInLeafNode
-    final int numLeaves = (int) Math.ceil((double) points.size() / maxPointsInLeafNode);
-
-    assertEquals(numLeaves * maxPointsInLeafNode, points.estimatePointCount(allPointsVisitor));
+    assertEquals(totalValues, points.estimatePointCount(allPointsVisitor));
     assertEquals(numDocs, points.estimateDocCount(allPointsVisitor));
 
     IntersectVisitor noPointsVisitor =
@@ -321,11 +326,16 @@ public class TestLucene86PointsFormat extends BasePointsFormatTestCase {
         };
 
     final long pointCount = points.estimatePointCount(onePointMatchVisitor);
-    // The number of matches needs to be multiple of count per leaf
-    final long countPerLeaf = (maxPointsInLeafNode + 1) / 2;
-    assertTrue("" + pointCount, pointCount % countPerLeaf == 0);
-    // in extreme cases, a point can be be shared by 4 leaves
-    assertTrue("" + pointCount, pointCount / countPerLeaf <= 4 && pointCount / countPerLeaf >= 1);
+    final long lastNodePointCount = totalValues % maxPointsInLeafNode;
+    assertTrue(
+        "" + pointCount,
+        pointCount == (maxPointsInLeafNode + 1) / 2 // common case
+            || pointCount == (lastNodePointCount + 1) / 2 // not fully populated leaf
+            || pointCount == 2 * ((maxPointsInLeafNode + 1) / 2) // if the point is a split value
+            || pointCount == ((maxPointsInLeafNode + 1) / 2) + ((lastNodePointCount + 1) / 2)
+            // in extreme cases, a point can be shared by 4 leaves
+            || pointCount == 4 * ((maxPointsInLeafNode + 1) / 2)
+            || pointCount == 3 * ((maxPointsInLeafNode + 1) / 2) + ((lastNodePointCount + 1) / 2));
 
     final long docCount = points.estimateDocCount(onePointMatchVisitor);
     if (multiValues) {
diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDReader.java
index 15427cd..c21a010 100644
--- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDReader.java
+++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDReader.java
@@ -22,34 +22,30 @@ import static org.apache.lucene.codecs.simpletext.SimpleTextPointsWriter.BLOCK_V
 
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
 import org.apache.lucene.index.PointValues;
 import org.apache.lucene.store.IndexInput;
-import org.apache.lucene.util.Accountable;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.BytesRefBuilder;
-import org.apache.lucene.util.RamUsageEstimator;
+import org.apache.lucene.util.MathUtil;
 import org.apache.lucene.util.StringHelper;
+import org.apache.lucene.util.bkd.BKDConfig;
 import org.apache.lucene.util.bkd.BKDReader;
 
 /** Forked from {@link BKDReader} and simplified/specialized for SimpleText's usage */
-final class SimpleTextBKDReader extends PointValues implements Accountable {
+final class SimpleTextBKDReader extends PointValues {
   // Packed array of byte[] holding all split values in the full binary tree:
   private final byte[] splitPackedValues;
   final long[] leafBlockFPs;
   private final int leafNodeOffset;
-  final int numDims;
-  final int numIndexDims;
-  final int bytesPerDim;
+  final BKDConfig config;
   final int bytesPerIndexEntry;
   final IndexInput in;
-  final int maxPointsInLeafNode;
   final byte[] minPackedValue;
   final byte[] maxPackedValue;
   final long pointCount;
   final int docCount;
   final int version;
-  protected final int packedBytesLength;
-  protected final int packedIndexBytesLength;
 
   public SimpleTextBKDReader(
       IndexInput in,
@@ -62,17 +58,11 @@ final class SimpleTextBKDReader extends PointValues implements Accountable {
       byte[] minPackedValue,
       byte[] maxPackedValue,
       long pointCount,
-      int docCount)
-      throws IOException {
+      int docCount) {
     this.in = in;
-    this.numDims = numDims;
-    this.numIndexDims = numIndexDims;
-    this.maxPointsInLeafNode = maxPointsInLeafNode;
-    this.bytesPerDim = bytesPerDim;
+    this.config = new BKDConfig(numDims, numIndexDims, bytesPerDim, maxPointsInLeafNode);
     // no version check here because callers of this API (SimpleText) have no back compat:
     bytesPerIndexEntry = numIndexDims == 1 ? bytesPerDim : bytesPerDim + 1;
-    packedBytesLength = numDims * bytesPerDim;
-    packedIndexBytesLength = numIndexDims * bytesPerDim;
     this.leafNodeOffset = leafBlockFPs.length;
     this.leafBlockFPs = leafBlockFPs;
     this.splitPackedValues = splitPackedValues;
@@ -81,265 +71,339 @@ final class SimpleTextBKDReader extends PointValues implements Accountable {
     this.pointCount = pointCount;
     this.docCount = docCount;
     this.version = SimpleTextBKDWriter.VERSION_CURRENT;
-    assert minPackedValue.length == packedIndexBytesLength;
-    assert maxPackedValue.length == packedIndexBytesLength;
-  }
-
-  /** Used to track all state for a single call to {@link #intersect}. */
-  public static final class IntersectState {
-    final IndexInput in;
-    final int[] scratchDocIDs;
-    final byte[] scratchPackedValue;
-    final int[] commonPrefixLengths;
-
-    final IntersectVisitor visitor;
-
-    public IntersectState(
-        IndexInput in,
-        int numDims,
-        int packedBytesLength,
-        int maxPointsInLeafNode,
-        IntersectVisitor visitor) {
-      this.in = in;
-      this.visitor = visitor;
-      this.commonPrefixLengths = new int[numDims];
-      this.scratchDocIDs = new int[maxPointsInLeafNode];
-      this.scratchPackedValue = new byte[packedBytesLength];
-    }
+    assert minPackedValue.length == config.packedIndexBytesLength;
+    assert maxPackedValue.length == config.packedIndexBytesLength;
   }
 
   @Override
-  public void intersect(IntersectVisitor visitor) throws IOException {
-    intersect(getIntersectState(visitor), 1, minPackedValue, maxPackedValue);
+  public PointTree getPointTree() {
+    return new SimpleTextPointTree(in.clone(), 1, 1, minPackedValue, maxPackedValue);
   }
 
-  /** Fast path: this is called when the query box fully encompasses all cells under this node. */
-  private void addAll(IntersectState state, int nodeID) throws IOException {
-    // System.out.println("R: addAll nodeID=" + nodeID);
-
-    if (nodeID >= leafNodeOffset) {
-      // System.out.println("ADDALL");
-      visitDocIDs(state.in, leafBlockFPs[nodeID - leafNodeOffset], state.visitor);
-      // TODO: we can assert that the first value here in fact matches what the index claimed?
-    } else {
-      addAll(state, 2 * nodeID);
-      addAll(state, 2 * nodeID + 1);
-    }
-  }
+  private class SimpleTextPointTree implements PointTree {
 
-  /** Create a new {@link IntersectState} */
-  public IntersectState getIntersectState(IntersectVisitor visitor) {
-    return new IntersectState(in.clone(), numDims, packedBytesLength, maxPointsInLeafNode, visitor);
-  }
+    final int[] scratchDocIDs;
+    final byte[] scratchPackedValue;
+    int nodeID;
+    int level;
+    final int rootNode;
+    final int lastLeafNodeCount;
+    // holds the min / max value of the current node.
+    private final byte[] minPackedValue, maxPackedValue;
+    // holds the previous value of the split dimension
+    private final byte[][] splitDimValueStack;
+    // holds the splitDim for each level:
+    private final int[] splitDims;
+    // private clone of the index input: leaf reads seek() it, so it must never be shared
+    private final IndexInput in;
+
+    private SimpleTextPointTree(
+        IndexInput in, int nodeID, int level, byte[] minPackedValue, byte[] maxPackedValue) {
+      this.in = in;
+      this.scratchDocIDs = new int[config.maxPointsInLeafNode];
+      this.scratchPackedValue = new byte[config.packedBytesLength];
+      this.nodeID = nodeID;
+      this.rootNode = nodeID;
+      this.level = level;
+      this.maxPackedValue = maxPackedValue.clone();
+      this.minPackedValue = minPackedValue.clone();
+      int treeDepth = getTreeDepth(leafNodeOffset);
+      splitDimValueStack = new byte[treeDepth + 1][];
+      splitDims = new int[treeDepth + 1];
+      int lastLeafNodeCount = Math.toIntExact(pointCount % config.maxPointsInLeafNode);
+      this.lastLeafNodeCount =
+          lastLeafNodeCount == 0 ? config.maxPointsInLeafNode : lastLeafNodeCount;
+    }
 
-  /** Visits all docIDs and packed values in a single leaf block */
-  public void visitLeafBlockValues(int nodeID, IntersectState state) throws IOException {
-    int leafID = nodeID - leafNodeOffset;
-
-    // Leaf node; scan and filter all points in this block:
-    int count = readDocIDs(state.in, leafBlockFPs[leafID], state.scratchDocIDs);
-
-    // Again, this time reading values and checking with the visitor
-    visitDocValues(
-        state.commonPrefixLengths,
-        state.scratchPackedValue,
-        state.in,
-        state.scratchDocIDs,
-        count,
-        state.visitor);
-  }
+    private int getTreeDepth(int numLeaves) {
+      // First +1 because all the non-leaf nodes make another power
+      // of 2; e.g. to have a fully balanced tree with 4 leaves you
+      // need a depth=3 tree:
 
-  void visitDocIDs(IndexInput in, long blockFP, IntersectVisitor visitor) throws IOException {
-    BytesRefBuilder scratch = new BytesRefBuilder();
-    in.seek(blockFP);
-    readLine(in, scratch);
-    int count = parseInt(scratch, BLOCK_COUNT);
-    visitor.grow(count);
-    for (int i = 0; i < count; i++) {
-      readLine(in, scratch);
-      visitor.visit(parseInt(scratch, BLOCK_DOC_ID));
+      // Second +1 because MathUtil.log computes floor of the logarithm; e.g.
+      // with 5 leaves you need a depth=4 tree:
+      return MathUtil.log(numLeaves, 2) + 2;
     }
-  }
 
-  int readDocIDs(IndexInput in, long blockFP, int[] docIDs) throws IOException {
-    BytesRefBuilder scratch = new BytesRefBuilder();
-    in.seek(blockFP);
-    readLine(in, scratch);
-    int count = parseInt(scratch, BLOCK_COUNT);
-    for (int i = 0; i < count; i++) {
-      readLine(in, scratch);
-      docIDs[i] = parseInt(scratch, BLOCK_DOC_ID);
+    @Override
+    public PointTree clone() {
+      SimpleTextPointTree index =
+          new SimpleTextPointTree(in.clone(), nodeID, level, minPackedValue, maxPackedValue);
+      if (isLeafNode() == false) {
+        // copy node data; push/pop overwrite the split value buffer in place, so the clone
+        // needs its own copy instead of sharing this tree's array
+        index.splitDims[level] = splitDims[level];
+        index.splitDimValueStack[level] =
+            splitDimValueStack[level] == null ? null : splitDimValueStack[level].clone();
+      }
+      return index;
     }
-    return count;
-  }
 
-  void visitDocValues(
-      int[] commonPrefixLengths,
-      byte[] scratchPackedValue,
-      IndexInput in,
-      int[] docIDs,
-      int count,
-      IntersectVisitor visitor)
-      throws IOException {
-    visitor.grow(count);
-    // NOTE: we don't do prefix coding, so we ignore commonPrefixLengths
-    assert scratchPackedValue.length == packedBytesLength;
-    BytesRefBuilder scratch = new BytesRefBuilder();
-    for (int i = 0; i < count; i++) {
-      readLine(in, scratch);
-      assert startsWith(scratch, BLOCK_VALUE);
-      BytesRef br = SimpleTextUtil.fromBytesRefString(stripPrefix(scratch, BLOCK_VALUE));
-      assert br.length == packedBytesLength;
-      System.arraycopy(br.bytes, br.offset, scratchPackedValue, 0, packedBytesLength);
-      visitor.visit(docIDs[i], scratchPackedValue);
+    @Override
+    public boolean moveToChild() {
+      if (isLeafNode()) {
+        return false;
+      }
+      pushLeft();
+      return true;
     }
-  }
 
-  private void intersect(
-      IntersectState state, int nodeID, byte[] cellMinPacked, byte[] cellMaxPacked)
-      throws IOException {
-
-    /*
-    System.out.println("\nR: intersect nodeID=" + nodeID);
-    for(int dim=0;dim<numDims;dim++) {
-      System.out.println("  dim=" + dim + "\n    cellMin=" + new BytesRef(cellMinPacked, dim*bytesPerDim, bytesPerDim) + "\n    cellMax=" + new BytesRef(cellMaxPacked, dim*bytesPerDim, bytesPerDim));
-    }
-    */
-
-    Relation r = state.visitor.compare(cellMinPacked, cellMaxPacked);
-
-    if (r == Relation.CELL_OUTSIDE_QUERY) {
-      // This cell is fully outside of the query shape: stop recursing
-      return;
-    } else if (r == Relation.CELL_INSIDE_QUERY) {
-      // This cell is fully inside of the query shape: recursively add all points in this cell
-      // without filtering
-      addAll(state, nodeID);
-      return;
-    } else {
-      // The cell crosses the shape boundary, or the cell fully contains the query, so we fall
-      // through and do full filtering
+    private void pushLeft() {
+      int address = nodeID * bytesPerIndexEntry;
+      if (config.numIndexDims == 1) {
+        splitDims[level] = 0;
+      } else {
+        splitDims[level] = (splitPackedValues[address++] & 0xff);
+      }
+      final int splitDimPos = splitDims[level] * config.bytesPerDim;
+      if (splitDimValueStack[level] == null) {
+        splitDimValueStack[level] = new byte[config.bytesPerDim];
+      }
+      // save the dimension we are going to change
+      System.arraycopy(
+          maxPackedValue, splitDimPos, splitDimValueStack[level], 0, config.bytesPerDim);
+      assert Arrays.compareUnsigned(
+                  maxPackedValue,
+                  splitDimPos,
+                  splitDimPos + config.bytesPerDim,
+                  splitPackedValues,
+                  address,
+                  address + config.bytesPerDim)
+              >= 0
+          : "config.bytesPerDim="
+              + config.bytesPerDim
+              + " splitDim="
+              + splitDims[level]
+              + " config.numIndexDims="
+              + config.numIndexDims
+              + " config.numDims="
+              + config.numDims;
+      nodeID *= 2;
+      level++;
+      // add the split dim value:
+      System.arraycopy(splitPackedValues, address, maxPackedValue, splitDimPos, config.bytesPerDim);
     }
 
-    if (nodeID >= leafNodeOffset) {
-      // TODO: we can assert that the first value here in fact matches what the index claimed?
-
-      int leafID = nodeID - leafNodeOffset;
-
-      // In the unbalanced case it's possible the left most node only has one child:
-      if (leafID < leafBlockFPs.length) {
-        // Leaf node; scan and filter all points in this block:
-        int count = readDocIDs(state.in, leafBlockFPs[leafID], state.scratchDocIDs);
-
-        // Again, this time reading values and checking with the visitor
-        visitDocValues(
-            state.commonPrefixLengths,
-            state.scratchPackedValue,
-            state.in,
-            state.scratchDocIDs,
-            count,
-            state.visitor);
+    @Override
+    public boolean moveToSibling() {
+      if (nodeID != rootNode && (nodeID & 1) == 0) {
+        pop(true);
+        pushRight();
+        return true;
       }
+      return false;
+    }
 
-    } else {
-
-      // Non-leaf node: recurse on the split left and right nodes
-
+    private void pushRight() {
       int address = nodeID * bytesPerIndexEntry;
-      int splitDim;
-      if (numIndexDims == 1) {
-        splitDim = 0;
+      if (config.numIndexDims == 1) {
+        splitDims[level] = 0;
       } else {
-        splitDim = splitPackedValues[address++] & 0xff;
+        splitDims[level] = (splitPackedValues[address++] & 0xff);
       }
+      final int splitDimPos = splitDims[level] * config.bytesPerDim;
+      // we should have already visited the left node
+      assert splitDimValueStack[level] != null;
+      // save the dimension we are going to change
+      System.arraycopy(
+          minPackedValue, splitDimPos, splitDimValueStack[level], 0, config.bytesPerDim);
+      assert Arrays.compareUnsigned(
+                  minPackedValue,
+                  splitDimPos,
+                  splitDimPos + config.bytesPerDim,
+                  splitPackedValues,
+                  address,
+                  address + config.bytesPerDim)
+              <= 0
+          : "config.bytesPerDim="
+              + config.bytesPerDim
+              + " splitDim="
+              + splitDims[level]
+              + " config.numIndexDims="
+              + config.numIndexDims
+              + " config.numDims="
+              + config.numDims;
+      nodeID = 2 * nodeID + 1;
+      level++;
+      // add the split dim value:
+      System.arraycopy(splitPackedValues, address, minPackedValue, splitDimPos, config.bytesPerDim);
+    }
 
-      assert splitDim < numIndexDims;
-
-      // TODO: can we alloc & reuse this up front?
+    @Override
+    public boolean moveToParent() {
+      if (nodeID == rootNode) {
+        return false;
+      }
+      pop((nodeID & 1) == 0);
+      return true;
+    }
 
-      byte[] splitPackedValue = new byte[packedIndexBytesLength];
+    private void pop(boolean isLeft) {
+      nodeID /= 2;
+      level--;
+      // restore the split dimension
+      if (isLeft) {
+        System.arraycopy(
+            splitDimValueStack[level],
+            0,
+            maxPackedValue,
+            splitDims[level] * config.bytesPerDim,
+            config.bytesPerDim);
+      } else {
 
-      // Recurse on left sub-tree:
-      System.arraycopy(cellMaxPacked, 0, splitPackedValue, 0, packedIndexBytesLength);
-      System.arraycopy(
-          splitPackedValues, address, splitPackedValue, splitDim * bytesPerDim, bytesPerDim);
-      intersect(state, 2 * nodeID, cellMinPacked, splitPackedValue);
+        System.arraycopy(
+            splitDimValueStack[level],
+            0,
+            minPackedValue,
+            splitDims[level] * config.bytesPerDim,
+            config.bytesPerDim);
+      }
+    }
 
-      // Recurse on right sub-tree:
-      System.arraycopy(cellMinPacked, 0, splitPackedValue, 0, packedIndexBytesLength);
-      System.arraycopy(
-          splitPackedValues, address, splitPackedValue, splitDim * bytesPerDim, bytesPerDim);
-      intersect(state, 2 * nodeID + 1, splitPackedValue, cellMaxPacked);
+    @Override
+    public byte[] getMinPackedValue() {
+      return minPackedValue.clone();
     }
-  }
 
-  @Override
-  public long estimatePointCount(IntersectVisitor visitor) {
-    return estimatePointCount(getIntersectState(visitor), 1, minPackedValue, maxPackedValue);
-  }
+    @Override
+    public byte[] getMaxPackedValue() {
+      return maxPackedValue.clone();
+    }
 
-  private long estimatePointCount(
-      IntersectState state, int nodeID, byte[] cellMinPacked, byte[] cellMaxPacked) {
-    Relation r = state.visitor.compare(cellMinPacked, cellMaxPacked);
+    @Override
+    public long size() {
+      int leftMostLeafNode = nodeID;
+      while (leftMostLeafNode < leafNodeOffset) {
+        leftMostLeafNode = leftMostLeafNode * 2;
+      }
+      int rightMostLeafNode = nodeID;
+      while (rightMostLeafNode < leafNodeOffset) {
+        rightMostLeafNode = rightMostLeafNode * 2 + 1;
+      }
+      final int numLeaves;
+      if (rightMostLeafNode >= leftMostLeafNode) {
+        // both are on the same level
+        numLeaves = rightMostLeafNode - leftMostLeafNode + 1;
+      } else {
+        // left is one level deeper than right
+        numLeaves = rightMostLeafNode - leftMostLeafNode + 1 + leafNodeOffset;
+      }
+      assert numLeaves == getNumLeavesSlow(nodeID) : numLeaves + " " + getNumLeavesSlow(nodeID);
+      return rightMostLeafNode == (1 << getTreeDepth(leafNodeOffset) - 1) - 1
+          ? (long) (numLeaves - 1) * config.maxPointsInLeafNode + lastLeafNodeCount
+          : (long) numLeaves * config.maxPointsInLeafNode;
+    }
 
-    if (r == Relation.CELL_OUTSIDE_QUERY) {
-      // This cell is fully outside of the query shape: stop recursing
-      return 0L;
-    } else if (nodeID >= leafNodeOffset) {
-      // Assume all points match and there are no dups
-      return maxPointsInLeafNode;
-    } else {
+    private int getNumLeavesSlow(int node) {
+      if (node >= 2 * leafNodeOffset) {
+        return 0;
+      } else if (node >= leafNodeOffset) {
+        return 1;
+      } else {
+        final int leftCount = getNumLeavesSlow(node * 2);
+        final int rightCount = getNumLeavesSlow(node * 2 + 1);
+        return leftCount + rightCount;
+      }
+    }
 
-      // Non-leaf node: recurse on the split left and right nodes
+    @Override
+    public void visitDocIDs(PointValues.IntersectVisitor visitor) throws IOException {
+      addAll(visitor, false);
+    }
 
-      int address = nodeID * bytesPerIndexEntry;
-      int splitDim;
-      if (numIndexDims == 1) {
-        splitDim = 0;
+    public void addAll(PointValues.IntersectVisitor visitor, boolean grown) throws IOException {
+      if (grown == false) {
+        final long size = size();
+        if (size <= Integer.MAX_VALUE) {
+          visitor.grow((int) size);
+          grown = true;
+        }
+      }
+      if (isLeafNode()) {
+        // Leaf node
+        BytesRefBuilder scratch = new BytesRefBuilder();
+        in.seek(leafBlockFPs[nodeID - leafNodeOffset]);
+        readLine(in, scratch);
+        int count = parseInt(scratch, BLOCK_COUNT);
+        for (int i = 0; i < count; i++) {
+          readLine(in, scratch);
+          visitor.visit(parseInt(scratch, BLOCK_DOC_ID));
+        }
       } else {
-        splitDim = splitPackedValues[address++] & 0xff;
+        pushLeft();
+        addAll(visitor, grown);
+        pop(true);
+        pushRight();
+        addAll(visitor, grown);
+        pop(false);
       }
+    }
 
-      assert splitDim < numIndexDims;
+    @Override
+    public void visitDocValues(PointValues.IntersectVisitor visitor) throws IOException {
+      if (isLeafNode()) {
+        // Leaf node
+        int leafID = nodeID - leafNodeOffset;
 
-      // TODO: can we alloc & reuse this up front?
+        // Leaf node; scan and filter all points in this block:
+        int count = readDocIDs(in, leafBlockFPs[leafID], scratchDocIDs);
+
+        // Again, this time reading values and checking with the visitor
+        visitor.grow(count);
+        // NOTE: we don't do prefix coding, so we ignore commonPrefixLengths
+        assert scratchPackedValue.length == config.packedBytesLength;
+        BytesRefBuilder scratch = new BytesRefBuilder();
+        for (int i = 0; i < count; i++) {
+          readLine(in, scratch);
+          assert startsWith(scratch, BLOCK_VALUE);
+          BytesRef br = SimpleTextUtil.fromBytesRefString(stripPrefix(scratch, BLOCK_VALUE));
+          assert br.length == config.packedBytesLength;
+          System.arraycopy(br.bytes, br.offset, scratchPackedValue, 0, config.packedBytesLength);
+          visitor.visit(scratchDocIDs[i], scratchPackedValue);
+        }
+      } else {
+        pushLeft();
+        visitDocValues(visitor);
+        pop(true);
+        pushRight();
+        visitDocValues(visitor);
+        pop(false);
+      }
+    }
 
-      byte[] splitPackedValue = new byte[packedIndexBytesLength];
+    int readDocIDs(IndexInput in, long blockFP, int[] docIDs) throws IOException {
+      BytesRefBuilder scratch = new BytesRefBuilder();
+      in.seek(blockFP);
+      readLine(in, scratch);
+      int count = parseInt(scratch, BLOCK_COUNT);
+      for (int i = 0; i < count; i++) {
+        readLine(in, scratch);
+        docIDs[i] = parseInt(scratch, BLOCK_DOC_ID);
+      }
+      return count;
+    }
 
-      // Recurse on left sub-tree:
-      System.arraycopy(cellMaxPacked, 0, splitPackedValue, 0, packedIndexBytesLength);
-      System.arraycopy(
-          splitPackedValues, address, splitPackedValue, splitDim * bytesPerDim, bytesPerDim);
-      final long leftCost = estimatePointCount(state, 2 * nodeID, cellMinPacked, splitPackedValue);
+    public boolean isLeafNode() {
+      return nodeID >= leafNodeOffset;
+    }
 
-      // Recurse on right sub-tree:
-      System.arraycopy(cellMinPacked, 0, splitPackedValue, 0, packedIndexBytesLength);
-      System.arraycopy(
-          splitPackedValues, address, splitPackedValue, splitDim * bytesPerDim, bytesPerDim);
-      final long rightCost =
-          estimatePointCount(state, 2 * nodeID + 1, splitPackedValue, cellMaxPacked);
-      return leftCost + rightCost;
+    private int parseInt(BytesRefBuilder scratch, BytesRef prefix) {
+      assert startsWith(scratch, prefix);
+      return Integer.parseInt(stripPrefix(scratch, prefix));
     }
-  }
 
-  /** Copies the split value for this node into the provided byte array */
-  public void copySplitValue(int nodeID, byte[] splitPackedValue) {
-    int address = nodeID * bytesPerIndexEntry;
-    int splitDim;
-    if (numIndexDims == 1) {
-      splitDim = 0;
-    } else {
-      splitDim = splitPackedValues[address++] & 0xff;
+    private String stripPrefix(BytesRefBuilder scratch, BytesRef prefix) {
+      return new String(
+          scratch.bytes(), prefix.length, scratch.length() - prefix.length, StandardCharsets.UTF_8);
     }
 
-    assert splitDim < numIndexDims;
-    System.arraycopy(
-        splitPackedValues, address, splitPackedValue, splitDim * bytesPerDim, bytesPerDim);
-  }
+    private boolean startsWith(BytesRefBuilder scratch, BytesRef prefix) {
+      return StringHelper.startsWith(scratch.get(), prefix);
+    }
 
-  @Override
-  public long ramBytesUsed() {
-    return RamUsageEstimator.sizeOf(splitPackedValues) + RamUsageEstimator.sizeOf(leafBlockFPs);
+    private void readLine(IndexInput in, BytesRefBuilder scratch) throws IOException {
+      SimpleTextUtil.readLine(in, scratch);
+    }
   }
 
   @Override
@@ -353,18 +413,18 @@ final class SimpleTextBKDReader extends PointValues implements Accountable {
   }
 
   @Override
-  public int getNumDimensions() {
-    return numDims;
+  public int getNumDimensions() throws IOException {
+    return config.numDims; // answered from the final config field; nothing is read from disk
   }
 
   @Override
-  public int getNumIndexDimensions() {
-    return numIndexDims;
+  public int getNumIndexDimensions() throws IOException {
+    return config.numIndexDims; // answered from the final config field; nothing is read from disk
   }
 
   @Override
-  public int getBytesPerDimension() {
-    return bytesPerDim;
+  public int getBytesPerDimension() throws IOException {
+    return config.bytesPerDim; // answered from the final config field; nothing is read from disk
   }
 
   @Override
@@ -376,26 +436,4 @@ final class SimpleTextBKDReader extends PointValues implements Accountable {
   public int getDocCount() {
     return docCount;
   }
-
-  public boolean isLeafNode(int nodeID) {
-    return nodeID >= leafNodeOffset;
-  }
-
-  private int parseInt(BytesRefBuilder scratch, BytesRef prefix) {
-    assert startsWith(scratch, prefix);
-    return Integer.parseInt(stripPrefix(scratch, prefix));
-  }
-
-  private String stripPrefix(BytesRefBuilder scratch, BytesRef prefix) {
-    return new String(
-        scratch.bytes(), prefix.length, scratch.length() - prefix.length, StandardCharsets.UTF_8);
-  }
-
-  private boolean startsWith(BytesRefBuilder scratch, BytesRef prefix) {
-    return StringHelper.startsWith(scratch.get(), prefix);
-  }
-
-  private void readLine(IndexInput in, BytesRefBuilder scratch) throws IOException {
-    SimpleTextUtil.readLine(in, scratch);
-  }
 }
diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDWriter.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDWriter.java
index d46bad2..33fae93 100644
--- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDWriter.java
+++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDWriter.java
@@ -40,7 +40,7 @@ import java.util.Arrays;
 import java.util.List;
 import java.util.function.IntFunction;
 import org.apache.lucene.codecs.CodecUtil;
-import org.apache.lucene.codecs.MutablePointValues;
+import org.apache.lucene.codecs.MutablePointTree;
 import org.apache.lucene.index.PointValues.IntersectVisitor;
 import org.apache.lucene.index.PointValues.Relation;
 import org.apache.lucene.store.ChecksumIndexInput;
@@ -58,7 +58,7 @@ import org.apache.lucene.util.bkd.BKDConfig;
 import org.apache.lucene.util.bkd.BKDRadixSelector;
 import org.apache.lucene.util.bkd.BKDWriter;
 import org.apache.lucene.util.bkd.HeapPointWriter;
-import org.apache.lucene.util.bkd.MutablePointsReaderUtils;
+import org.apache.lucene.util.bkd.MutablePointTreeReaderUtils;
 import org.apache.lucene.util.bkd.OfflinePointWriter;
 import org.apache.lucene.util.bkd.PointReader;
 import org.apache.lucene.util.bkd.PointValue;
@@ -248,12 +248,12 @@ final class SimpleTextBKDWriter implements Closeable {
   }
 
   /**
-   * Write a field from a {@link MutablePointValues}. This way of writing points is faster than
+   * Writes a field from a {@link MutablePointTree}. This way of writing points is faster than
    * regular writes with {@link BKDWriter#add} since there is opportunity for reordering points
    * before writing them to disk. This method does not use transient disk in order to reorder
    * points.
    */
-  public long writeField(IndexOutput out, String fieldName, MutablePointValues reader)
+  public long writeField(IndexOutput out, String fieldName, MutablePointTree reader)
       throws IOException {
     if (config.numIndexDims == 1) {
       return writeField1Dim(out, fieldName, reader);
@@ -264,7 +264,7 @@ final class SimpleTextBKDWriter implements Closeable {
 
   /* In the 2+D case, we recursively pick the split dimension, compute the
    * median value and partition other values around it. */
-  private long writeFieldNDims(IndexOutput out, String fieldName, MutablePointValues values)
+  private long writeFieldNDims(IndexOutput out, String fieldName, MutablePointTree values)
       throws IOException {
     if (pointCount != 0) {
       throw new IllegalStateException("cannot mix add and writeField");
@@ -355,13 +355,13 @@ final class SimpleTextBKDWriter implements Closeable {
 
   /* In the 1D case, we can simply sort points in ascending order and use the
    * same writing logic as we use at merge time. */
-  private long writeField1Dim(IndexOutput out, String fieldName, MutablePointValues reader)
+  private long writeField1Dim(IndexOutput out, String fieldName, MutablePointTree reader)
       throws IOException {
-    MutablePointsReaderUtils.sort(config, maxDoc, reader, 0, Math.toIntExact(reader.size()));
+    MutablePointTreeReaderUtils.sort(config, maxDoc, reader, 0, Math.toIntExact(reader.size()));
 
     final OneDimensionBKDWriter oneDimWriter = new OneDimensionBKDWriter(out);
 
-    reader.intersect(
+    reader.visitDocValues( // reader was sorted by the MutablePointTreeReaderUtils.sort call above
         new IntersectVisitor() {
 
           @Override
@@ -919,7 +919,7 @@ final class SimpleTextBKDWriter implements Closeable {
   private void build(
       int nodeID,
       int leafNodeOffset,
-      MutablePointValues reader,
+      MutablePointTree reader,
       int from,
       int to,
       IndexOutput out,
@@ -980,7 +980,7 @@ final class SimpleTextBKDWriter implements Closeable {
       }
 
       // sort by sortedDim
-      MutablePointsReaderUtils.sortByDim(
+      MutablePointTreeReaderUtils.sortByDim(
           config,
           sortedDim,
           commonPrefixLengths,
@@ -1033,7 +1033,7 @@ final class SimpleTextBKDWriter implements Closeable {
           break;
         }
       }
-      MutablePointsReaderUtils.partition(
+      MutablePointTreeReaderUtils.partition(
           config,
           maxDoc,
           splitDim,
diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextPointsWriter.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextPointsWriter.java
index 2061ec6..6825239 100644
--- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextPointsWriter.java
+++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextPointsWriter.java
@@ -73,7 +73,7 @@ class SimpleTextPointsWriter extends PointsWriter {
   @Override
   public void writeField(FieldInfo fieldInfo, PointsReader reader) throws IOException {
 
-    PointValues values = reader.getValues(fieldInfo.name);
+    PointValues.PointTree values = reader.getValues(fieldInfo.name).getPointTree(); // NOTE(review): getValues() result is dereferenced unchecked; assumes non-null for this field
 
     BKDConfig config =
         new BKDConfig(
@@ -92,7 +92,7 @@ class SimpleTextPointsWriter extends PointsWriter {
             SimpleTextBKDWriter.DEFAULT_MAX_MB_SORT_IN_HEAP,
             values.size())) {
 
-      values.intersect(
+      values.visitDocValues(
           new IntersectVisitor() {
             @Override
             public void visit(int docID) {
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/MutablePointValues.java b/lucene/core/src/java/org/apache/lucene/codecs/MutablePointTree.java
similarity index 62%
rename from lucene/core/src/java/org/apache/lucene/codecs/MutablePointValues.java
rename to lucene/core/src/java/org/apache/lucene/codecs/MutablePointTree.java
index a0b1ccc..dcb5694 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/MutablePointValues.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/MutablePointTree.java
@@ -16,19 +16,20 @@
  */
 package org.apache.lucene.codecs;
 
-import org.apache.lucene.index.PointValues;
+import org.apache.lucene.index.PointValues.IntersectVisitor;
+import org.apache.lucene.index.PointValues.PointTree;
 import org.apache.lucene.util.BytesRef;
 
 /**
- * {@link PointValues} whose order of points can be changed. This class is useful for codecs to
- * optimize flush.
+ * A {@link PointTree} made of a single leaf whose order of points can be changed. This class is
+ * useful for codecs to optimize flush.
  *
  * @lucene.internal
  */
-public abstract class MutablePointValues extends PointValues {
+public abstract class MutablePointTree implements PointTree {
 
   /** Sole constructor. */
-  protected MutablePointValues() {}
+  protected MutablePointTree() {}
 
   /** Set {@code packedValue} with a reference to the packed bytes of the i-th value. */
   public abstract void getValue(int i, BytesRef packedValue);
@@ -47,4 +48,39 @@ public abstract class MutablePointValues extends PointValues {
 
   /** Restore values between i-th and j-th(excluding) in temporary storage into original storage. */
   public abstract void restore(int i, int j);
+
+  @Override
+  public final PointTree clone() {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public final boolean moveToChild() {
+    return false; // single-leaf tree: there are no children to descend into
+  }
+
+  @Override
+  public final boolean moveToSibling() {
+    return false; // single-leaf tree: the only node has no sibling
+  }
+
+  @Override
+  public final boolean moveToParent() {
+    return false; // single-leaf tree: the only node has no parent
+  }
+
+  @Override
+  public byte[] getMinPackedValue() {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public byte[] getMaxPackedValue() {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void visitDocIDs(IntersectVisitor visitor) {
+    throw new UnsupportedOperationException();
+  }
 }
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/PointsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/PointsWriter.java
index ea20789..9d756e3 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/PointsWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/PointsWriter.java
@@ -41,7 +41,6 @@ public abstract class PointsWriter implements Closeable {
    */
   protected void mergeOneField(MergeState mergeState, FieldInfo fieldInfo) throws IOException {
     long maxPointCount = 0;
-    int docCount = 0;
     for (int i = 0; i < mergeState.pointsReaders.length; i++) {
       PointsReader pointsReader = mergeState.pointsReaders[i];
       if (pointsReader != null) {
@@ -50,13 +49,11 @@ public abstract class PointsWriter implements Closeable {
           PointValues values = pointsReader.getValues(fieldInfo.name);
           if (values != null) {
             maxPointCount += values.size();
-            docCount += values.getDocCount();
           }
         }
       }
     }
     final long finalMaxPointCount = maxPointCount;
-    final int finalDocCount = docCount;
     writeField(
         fieldInfo,
         new PointsReader() {
@@ -73,59 +70,104 @@ public abstract class PointsWriter implements Closeable {
             return new PointValues() {
 
               @Override
-              public void intersect(IntersectVisitor mergedVisitor) throws IOException {
-                for (int i = 0; i < mergeState.pointsReaders.length; i++) {
-                  PointsReader pointsReader = mergeState.pointsReaders[i];
-                  if (pointsReader == null) {
-                    // This segment has no points
-                    continue;
+              public PointTree getPointTree() {
+                return new PointTree() {
+
+                  @Override
+                  public PointTree clone() {
+                    throw new UnsupportedOperationException();
                   }
-                  FieldInfo readerFieldInfo = mergeState.fieldInfos[i].fieldInfo(fieldName);
-                  if (readerFieldInfo == null) {
-                    // This segment never saw this field
-                    continue;
+
+                  @Override
+                  public boolean moveToChild() {
+                    return false;
                   }
 
-                  if (readerFieldInfo.getPointDimensionCount() == 0) {
-                    // This segment saw this field, but the field did not index points in it:
-                    continue;
+                  @Override
+                  public boolean moveToSibling() {
+                    return false;
                   }
 
-                  PointValues values = pointsReader.getValues(fieldName);
-                  if (values == null) {
-                    continue;
+                  @Override
+                  public boolean moveToParent() {
+                    return false;
                   }
-                  MergeState.DocMap docMap = mergeState.docMaps[i];
-                  values.intersect(
-                      new IntersectVisitor() {
-                        @Override
-                        public void visit(int docID) {
-                          // Should never be called because our compare method never returns
-                          // Relation.CELL_INSIDE_QUERY
-                          throw new IllegalStateException();
-                        }
-
-                        @Override
-                        public void visit(int docID, byte[] packedValue) throws IOException {
-                          int newDocID = docMap.get(docID);
-                          if (newDocID != -1) {
-                            // Not deleted:
-                            mergedVisitor.visit(newDocID, packedValue);
-                          }
-                        }
-
-                        @Override
-                        public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
-                          // Forces this segment's PointsReader to always visit all docs + values:
-                          return Relation.CELL_CROSSES_QUERY;
-                        }
-                      });
-                }
-              }
 
-              @Override
-              public long estimatePointCount(IntersectVisitor visitor) {
-                throw new UnsupportedOperationException();
+                  @Override
+                  public byte[] getMinPackedValue() {
+                    throw new UnsupportedOperationException();
+                  }
+
+                  @Override
+                  public byte[] getMaxPackedValue() {
+                    throw new UnsupportedOperationException();
+                  }
+
+                  @Override
+                  public long size() {
+                    return finalMaxPointCount;
+                  }
+
+                  @Override
+                  public void visitDocIDs(IntersectVisitor visitor) {
+                    throw new UnsupportedOperationException();
+                  }
+
+                  @Override
+                  public void visitDocValues(IntersectVisitor mergedVisitor) throws IOException {
+                    for (int i = 0; i < mergeState.pointsReaders.length; i++) {
+                      PointsReader pointsReader = mergeState.pointsReaders[i];
+                      if (pointsReader == null) {
+                        // This segment has no points
+                        continue;
+                      }
+                      FieldInfo readerFieldInfo = mergeState.fieldInfos[i].fieldInfo(fieldName);
+                      if (readerFieldInfo == null) {
+                        // This segment never saw this field
+                        continue;
+                      }
+
+                      if (readerFieldInfo.getPointDimensionCount() == 0) {
+                        // This segment saw this field, but the field did not index points in it:
+                        continue;
+                      }
+
+                      PointValues values = pointsReader.getValues(fieldName);
+                      if (values == null) {
+                        continue;
+                      }
+                      MergeState.DocMap docMap = mergeState.docMaps[i];
+                      values
+                          .getPointTree()
+                          .visitDocValues(
+                              new IntersectVisitor() {
+                                @Override
+                                public void visit(int docID) {
+                                  // Should never be called during #visitDocValues()
+                                  throw new IllegalStateException();
+                                }
+
+                                @Override
+                                public void visit(int docID, byte[] packedValue)
+                                    throws IOException {
+                                  int newDocID = docMap.get(docID);
+                                  if (newDocID != -1) {
+                                    // Not deleted:
+                                    mergedVisitor.visit(newDocID, packedValue);
+                                  }
+                                }
+
+                                @Override
+                                public Relation compare(
+                                    byte[] minPackedValue, byte[] maxPackedValue) {
+                                  // Forces this segment's PointsReader to always visit all docs +
+                                  // values:
+                                  return Relation.CELL_CROSSES_QUERY;
+                                }
+                              });
+                    }
+                  }
+                };
               }
 
               @Override
@@ -160,7 +202,7 @@ public abstract class PointsWriter implements Closeable {
 
               @Override
               public int getDocCount() {
-                return finalDocCount;
+                throw new UnsupportedOperationException();
               }
             };
           }
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90PointsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90PointsReader.java
index 0e6b5d6..dc3d8c9 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90PointsReader.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90PointsReader.java
@@ -35,7 +35,7 @@ import org.apache.lucene.util.bkd.BKDReader;
 public class Lucene90PointsReader extends PointsReader {
   final IndexInput indexIn, dataIn;
   final SegmentReadState readState;
-  final Map<Integer, BKDReader> readers = new HashMap<>();
+  final Map<Integer, PointValues> readers = new HashMap<>();
 
   /** Sole constructor */
   public Lucene90PointsReader(SegmentReadState readState) throws IOException {
@@ -97,7 +97,7 @@ public class Lucene90PointsReader extends PointsReader {
             } else if (fieldNumber < 0) {
               throw new CorruptIndexException("Illegal field number: " + fieldNumber, metaIn);
             }
-            BKDReader reader = new BKDReader(metaIn, indexIn, dataIn);
+            PointValues reader = new BKDReader(metaIn, indexIn, dataIn);
             readers.put(fieldNumber, reader);
           }
           indexLength = metaIn.readLong();
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90PointsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90PointsWriter.java
index 3dc7dfa..5ea4732 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90PointsWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90PointsWriter.java
@@ -20,7 +20,7 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 import org.apache.lucene.codecs.CodecUtil;
-import org.apache.lucene.codecs.MutablePointValues;
+import org.apache.lucene.codecs.MutablePointTree;
 import org.apache.lucene.codecs.PointsReader;
 import org.apache.lucene.codecs.PointsWriter;
 import org.apache.lucene.index.FieldInfo;
@@ -34,7 +34,6 @@ import org.apache.lucene.index.SegmentWriteState;
 import org.apache.lucene.store.IndexOutput;
 import org.apache.lucene.util.IOUtils;
 import org.apache.lucene.util.bkd.BKDConfig;
-import org.apache.lucene.util.bkd.BKDReader;
 import org.apache.lucene.util.bkd.BKDWriter;
 
 /** Writes dimensional values */
@@ -119,7 +118,7 @@ public class Lucene90PointsWriter extends PointsWriter {
   @Override
   public void writeField(FieldInfo fieldInfo, PointsReader reader) throws IOException {
 
-    PointValues values = reader.getValues(fieldInfo.name);
+    PointValues.PointTree values = reader.getValues(fieldInfo.name).getPointTree();
 
     BKDConfig config =
         new BKDConfig(
@@ -137,10 +136,10 @@ public class Lucene90PointsWriter extends PointsWriter {
             maxMBSortInHeap,
             values.size())) {
 
-      if (values instanceof MutablePointValues) {
+      if (values instanceof MutablePointTree) {
         Runnable finalizer =
             writer.writeField(
-                metaOut, indexOut, dataOut, fieldInfo.name, (MutablePointValues) values);
+                metaOut, indexOut, dataOut, fieldInfo.name, (MutablePointTree) values);
         if (finalizer != null) {
           metaOut.writeInt(fieldInfo.number);
           finalizer.run();
@@ -148,7 +147,7 @@ public class Lucene90PointsWriter extends PointsWriter {
         return;
       }
 
-      values.intersect(
+      values.visitDocValues(
           new IntersectVisitor() {
             @Override
             public void visit(int docID) {
@@ -234,7 +233,7 @@ public class Lucene90PointsWriter extends PointsWriter {
                   config,
                   maxMBSortInHeap,
                   totMaxSize)) {
-            List<BKDReader> bkdReaders = new ArrayList<>();
+            List<PointValues> pointValues = new ArrayList<>();
             List<MergeState.DocMap> docMaps = new ArrayList<>();
             for (int i = 0; i < mergeState.pointsReaders.length; i++) {
               PointsReader reader = mergeState.pointsReaders[i];
@@ -253,16 +252,16 @@ public class Lucene90PointsWriter extends PointsWriter {
                 FieldInfos readerFieldInfos = mergeState.fieldInfos[i];
                 FieldInfo readerFieldInfo = readerFieldInfos.fieldInfo(fieldInfo.name);
                 if (readerFieldInfo != null && readerFieldInfo.getPointDimensionCount() > 0) {
-                  BKDReader bkdReader = reader90.readers.get(readerFieldInfo.number);
-                  if (bkdReader != null) {
-                    bkdReaders.add(bkdReader);
+                  PointValues aPointValues = reader90.readers.get(readerFieldInfo.number);
+                  if (aPointValues != null) {
+                    pointValues.add(aPointValues);
                     docMaps.add(mergeState.docMaps[i]);
                   }
                 }
               }
             }
 
-            Runnable finalizer = writer.merge(metaOut, indexOut, dataOut, docMaps, bkdReaders);
+            Runnable finalizer = writer.merge(metaOut, indexOut, dataOut, docMaps, pointValues);
             if (finalizer != null) {
               metaOut.writeInt(fieldInfo.number);
               finalizer.run();
diff --git a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java
index 0313e7a..dfdd4e6 100644
--- a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java
+++ b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java
@@ -373,15 +373,9 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
     }
 
     @Override
-    public void intersect(IntersectVisitor visitor) throws IOException {
+    public PointTree getPointTree() throws IOException {
       checkAndThrow();
-      in.intersect(new ExitableIntersectVisitor(visitor, queryTimeout));
-    }
-
-    @Override
-    public long estimatePointCount(IntersectVisitor visitor) {
-      checkAndThrow();
-      return in.estimatePointCount(visitor);
+      return new ExitablePointTree(in, in.getPointTree(), queryTimeout);
     }
 
     @Override
@@ -427,17 +421,117 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
     }
   }
 
-  private static class ExitableIntersectVisitor implements PointValues.IntersectVisitor {
+  private static class ExitablePointTree implements PointValues.PointTree {
 
     private static final int MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK = 10;
 
-    private final PointValues.IntersectVisitor in;
+    private final PointValues pointValues;
+    private final PointValues.PointTree in;
+    private final ExitableIntersectVisitor exitableIntersectVisitor;
     private final QueryTimeout queryTimeout;
     private int calls;
 
-    private ExitableIntersectVisitor(PointValues.IntersectVisitor in, QueryTimeout queryTimeout) {
+    private ExitablePointTree(
+        PointValues pointValues, PointValues.PointTree in, QueryTimeout queryTimeout) {
+      this.pointValues = pointValues;
       this.in = in;
       this.queryTimeout = queryTimeout;
+      this.exitableIntersectVisitor = new ExitableIntersectVisitor(queryTimeout);
+    }
+
+    /**
+     * Runs {@link #checkAndThrow()} on the first call and then once every {@link
+     * #MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK} calls, so the check is not paid on every call.
+     */
+    private void checkAndThrowWithSampling() {
+      if (calls++ % MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK == 0) {
+        checkAndThrow();
+      }
+    }
+
+    private void checkAndThrow() { // throws ExitingReaderException on timeout or interrupt
+      if (queryTimeout.shouldExit()) {
+        throw new ExitingReaderException(
+            "The request took too long to intersect point values. Timeout: "
+                + queryTimeout.toString()
+                + ", PointValues="
+                + pointValues);
+      } else if (Thread.interrupted()) {
+        throw new ExitingReaderException(
+            "Interrupted while intersecting point values. PointValues=" + pointValues);
+      }
+    }
+
+    @Override
+    public PointValues.PointTree clone() {
+      checkAndThrow();
+      return new ExitablePointTree(pointValues, in.clone(), queryTimeout);
+    }
+
+    @Override
+    public boolean moveToChild() throws IOException {
+      checkAndThrowWithSampling();
+      return in.moveToChild();
+    }
+
+    @Override
+    public boolean moveToSibling() throws IOException {
+      checkAndThrowWithSampling();
+      return in.moveToSibling();
+    }
+
+    @Override
+    public boolean moveToParent() throws IOException {
+      checkAndThrowWithSampling();
+      return in.moveToParent();
+    }
+
+    @Override
+    public byte[] getMinPackedValue() {
+      checkAndThrowWithSampling();
+      return in.getMinPackedValue();
+    }
+
+    @Override
+    public byte[] getMaxPackedValue() {
+      checkAndThrowWithSampling();
+      return in.getMaxPackedValue();
+    }
+
+    @Override
+    public long size() {
+      checkAndThrow();
+      return in.size();
+    }
+
+    @Override
+    public void visitDocIDs(PointValues.IntersectVisitor visitor) throws IOException {
+      checkAndThrow();
+      in.visitDocIDs(visitor);
+    }
+
+    @Override
+    public void visitDocValues(PointValues.IntersectVisitor visitor) throws IOException {
+      checkAndThrow();
+      exitableIntersectVisitor.setIntersectVisitor(visitor);
+      in.visitDocValues(exitableIntersectVisitor);
+    }
+  }
+
+  private static class ExitableIntersectVisitor implements PointValues.IntersectVisitor {
+
+    private static final int MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK = 10;
+
+    private PointValues.IntersectVisitor in;
+    private final QueryTimeout queryTimeout;
+    private int calls;
+
+    private ExitableIntersectVisitor(QueryTimeout queryTimeout) {
+      this.queryTimeout = queryTimeout;
+    }
+
+    private void setIntersectVisitor(PointValues.IntersectVisitor in) {
+      this.in = in;
     }
 
     /**
diff --git a/lucene/core/src/java/org/apache/lucene/index/PointValues.java b/lucene/core/src/java/org/apache/lucene/index/PointValues.java
index 3a7fbfb..51fcf1f 100644
--- a/lucene/core/src/java/org/apache/lucene/index/PointValues.java
+++ b/lucene/core/src/java/org/apache/lucene/index/PointValues.java
@@ -17,6 +17,7 @@
 package org.apache.lucene.index;
 
 import java.io.IOException;
+import java.io.UncheckedIOException;
 import java.math.BigInteger;
 import java.net.InetAddress;
 import org.apache.lucene.document.BinaryPoint;
@@ -227,8 +228,59 @@ public abstract class PointValues {
     CELL_CROSSES_QUERY
   };
 
+  /** Create a new {@link PointTree} to navigate the index */
+  public abstract PointTree getPointTree() throws IOException;
+
   /**
-   * We recurse the BKD tree, using a provided instance of this to guide the recursion.
+   * Basic operations to read the KD-tree.
+   *
+   * @lucene.experimental
+   */
+  public interface PointTree extends Cloneable {
+
+    /**
+     * Returns a clone of this tree whose root is the current node. This method should not be
+     * called after a successful call to {@link #moveToParent()}.
+     */
+    PointTree clone();
+
+    /**
+     * Move to the first child node and return {@code true} upon success. Returns {@code false} for
+     * leaf nodes and {@code true} otherwise. This method should not be called after a successful
+     * call to {@link #moveToParent()}.
+     */
+    boolean moveToChild() throws IOException;
+
+    /**
+     * Move to the next sibling node and return {@code true} upon success. Returns {@code false} if
+     * the current node has no more siblings.
+     */
+    boolean moveToSibling() throws IOException;
+
+    /**
+     * Move to the parent node and return {@code true} upon success. Returns {@code false} for the
+     * root node and {@code true} otherwise.
+     */
+    boolean moveToParent() throws IOException;
+
+    /** Return the minimum packed value (lower bound of the cell) of the current node. */
+    byte[] getMinPackedValue();
+
+    /** Return the maximum packed value (upper bound of the cell) of the current node. */
+    byte[] getMaxPackedValue();
+
+    /** Return the number of points below the current node. */
+    long size();
+
+    /** Visit the doc IDs of all points below the current node, without their values. */
+    void visitDocIDs(IntersectVisitor visitor) throws IOException;
+
+    /** Visit the doc IDs together with the packed values of all points below the current node. */
+    void visitDocValues(IntersectVisitor visitor) throws IOException;
+  }
+
+  /**
+   * We recurse the {@link PointTree}, using a provided instance of this to guide the recursion.
    *
    * @lucene.experimental
    */
@@ -273,13 +325,85 @@ public abstract class PointValues {
    * Finds all documents and points matching the provided visitor. This method does not enforce live
    * documents, so it's up to the caller to test whether each document is deleted, if necessary.
    */
-  public abstract void intersect(IntersectVisitor visitor) throws IOException;
+  public final void intersect(IntersectVisitor visitor) throws IOException {
+    final PointTree pointTree = getPointTree();
+    intersect(visitor, pointTree);
+    assert pointTree.moveToParent() == false;
+  }
+
+  private void intersect(IntersectVisitor visitor, PointTree pointTree) throws IOException {
+    Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue());
+    switch (r) {
+      case CELL_OUTSIDE_QUERY:
+        // This cell is fully outside the query shape: stop recursing
+        break;
+      case CELL_INSIDE_QUERY:
+        // This cell is fully inside the query shape: recursively add all points in this cell
+        // without filtering
+        pointTree.visitDocIDs(visitor);
+        break;
+      case CELL_CROSSES_QUERY:
+        // The cell crosses the shape boundary, or the cell fully contains the query, so we fall
+        // through and do full filtering:
+        if (pointTree.moveToChild()) {
+          do {
+            intersect(visitor, pointTree);
+          } while (pointTree.moveToSibling());
+          pointTree.moveToParent();
+        } else {
+          // TODO: we can assert that the first value here in fact matches what the pointTree
+          // claimed?
+          // Leaf node; scan and filter all points in this block:
+          pointTree.visitDocValues(visitor);
+        }
+        break;
+      default:
+        throw new IllegalArgumentException("Unreachable code");
+    }
+  }
 
   /**
    * Estimate the number of points that would be visited by {@link #intersect} with the given {@link
    * IntersectVisitor}. This should run many times faster than {@link #intersect(IntersectVisitor)}.
    */
-  public abstract long estimatePointCount(IntersectVisitor visitor);
+  public final long estimatePointCount(IntersectVisitor visitor) {
+    try {
+      final PointTree pointTree = getPointTree();
+      final long count = estimatePointCount(visitor, pointTree);
+      assert pointTree.moveToParent() == false;
+      return count;
+    } catch (IOException ioe) {
+      throw new UncheckedIOException(ioe);
+    }
+  }
+
+  private long estimatePointCount(IntersectVisitor visitor, PointTree pointTree)
+      throws IOException {
+    Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue());
+    switch (r) {
+      case CELL_OUTSIDE_QUERY:
+        // This cell is fully outside the query shape: no points added
+        return 0L;
+      case CELL_INSIDE_QUERY:
+        // This cell is fully inside the query shape: add all points
+        return pointTree.size();
+      case CELL_CROSSES_QUERY:
+        // The cell crosses the shape boundary: keep recursing
+        if (pointTree.moveToChild()) {
+          long cost = 0;
+          do {
+            cost += estimatePointCount(visitor, pointTree);
+          } while (pointTree.moveToSibling());
+          pointTree.moveToParent();
+          return cost;
+        } else {
+          // Assume half the points matched
+          return (pointTree.size() + 1) / 2;
+        }
+      default:
+        throw new IllegalArgumentException("Unreachable code");
+    }
+  }
 
   /**
    * Estimate the number of documents that would be matched by {@link #intersect} with the given
@@ -288,7 +412,7 @@ public abstract class PointValues {
    *
    * @see DocIdSetIterator#cost
    */
-  public long estimateDocCount(IntersectVisitor visitor) {
+  public final long estimateDocCount(IntersectVisitor visitor) {
     long estimatedPointCount = estimatePointCount(visitor);
     int docCount = getDocCount();
     double size = size();
diff --git a/lucene/core/src/java/org/apache/lucene/index/PointValuesWriter.java b/lucene/core/src/java/org/apache/lucene/index/PointValuesWriter.java
index 599ca77..0dae19a 100644
--- a/lucene/core/src/java/org/apache/lucene/index/PointValuesWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/index/PointValuesWriter.java
@@ -17,7 +17,7 @@
 package org.apache.lucene.index;
 
 import java.io.IOException;
-import org.apache.lucene.codecs.MutablePointValues;
+import org.apache.lucene.codecs.MutablePointTree;
 import org.apache.lucene.codecs.PointsReader;
 import org.apache.lucene.codecs.PointsWriter;
 import org.apache.lucene.store.DataOutput;
@@ -92,8 +92,8 @@ class PointValuesWriter {
   public void flush(SegmentWriteState state, Sorter.DocMap sortMap, PointsWriter writer)
       throws IOException {
     final PagedBytes.Reader bytesReader = bytes.freeze(false);
-    PointValues points =
-        new MutablePointValues() {
+    MutablePointTree points =
+        new MutablePointTree() {
           final int[] ords = new int[numPoints];
           int[] temp;
 
@@ -104,7 +104,12 @@ class PointValuesWriter {
           }
 
           @Override
-          public void intersect(IntersectVisitor visitor) throws IOException {
+          public long size() {
+            return numPoints;
+          }
+
+          @Override
+          public void visitDocValues(PointValues.IntersectVisitor visitor) throws IOException {
             final BytesRef scratch = new BytesRef();
             final byte[] packedValue = new byte[packedBytesLength];
             for (int i = 0; i < numPoints; i++) {
@@ -116,46 +121,6 @@ class PointValuesWriter {
           }
 
           @Override
-          public long estimatePointCount(IntersectVisitor visitor) {
-            throw new UnsupportedOperationException();
-          }
-
-          @Override
-          public byte[] getMinPackedValue() {
-            throw new UnsupportedOperationException();
-          }
-
-          @Override
-          public byte[] getMaxPackedValue() {
-            throw new UnsupportedOperationException();
-          }
-
-          @Override
-          public int getNumDimensions() {
-            throw new UnsupportedOperationException();
-          }
-
-          @Override
-          public int getNumIndexDimensions() {
-            throw new UnsupportedOperationException();
-          }
-
-          @Override
-          public int getBytesPerDimension() {
-            throw new UnsupportedOperationException();
-          }
-
-          @Override
-          public long size() {
-            return numPoints;
-          }
-
-          @Override
-          public int getDocCount() {
-            return numDocs;
-          }
-
-          @Override
           public void swap(int i, int j) {
             int tmp = ords[i];
             ords[i] = ords[j];
@@ -195,11 +160,11 @@ class PointValuesWriter {
           }
         };
 
-    final PointValues values;
+    final PointValues.PointTree values;
     if (sortMap == null) {
       values = points;
     } else {
-      values = new MutableSortingPointValues((MutablePointValues) points, sortMap);
+      values = new MutableSortingPointValues(points, sortMap);
     }
     PointsReader reader =
         new PointsReader() {
@@ -208,7 +173,47 @@ class PointValuesWriter {
             if (fieldName.equals(fieldInfo.name) == false) {
               throw new IllegalArgumentException("fieldName must be the same");
             }
-            return values;
+            return new PointValues() {
+              @Override
+              public PointTree getPointTree() throws IOException {
+                return values;
+              }
+
+              @Override
+              public byte[] getMinPackedValue() throws IOException {
+                throw new UnsupportedOperationException();
+              }
+
+              @Override
+              public byte[] getMaxPackedValue() throws IOException {
+                throw new UnsupportedOperationException();
+              }
+
+              @Override
+              public int getNumDimensions() throws IOException {
+                throw new UnsupportedOperationException();
+              }
+
+              @Override
+              public int getNumIndexDimensions() throws IOException {
+                throw new UnsupportedOperationException();
+              }
+
+              @Override
+              public int getBytesPerDimension() throws IOException {
+                throw new UnsupportedOperationException();
+              }
+
+              @Override
+              public long size() {
+                throw new UnsupportedOperationException();
+              }
+
+              @Override
+              public int getDocCount() {
+                throw new UnsupportedOperationException();
+              }
+            };
           }
 
           @Override
@@ -222,20 +227,25 @@ class PointValuesWriter {
     writer.writeField(fieldInfo, reader);
   }
 
-  static final class MutableSortingPointValues extends MutablePointValues {
+  static final class MutableSortingPointValues extends MutablePointTree {
 
-    private final MutablePointValues in;
+    private final MutablePointTree in;
     private final Sorter.DocMap docMap;
 
-    public MutableSortingPointValues(final MutablePointValues in, Sorter.DocMap docMap) {
+    public MutableSortingPointValues(final MutablePointTree in, Sorter.DocMap docMap) {
       this.in = in;
       this.docMap = docMap;
     }
 
     @Override
-    public void intersect(IntersectVisitor visitor) throws IOException {
-      in.intersect(
-          new IntersectVisitor() {
+    public long size() {
+      return in.size();
+    }
+
+    @Override
+    public void visitDocValues(PointValues.IntersectVisitor visitor) throws IOException {
+      in.visitDocValues(
+          new PointValues.IntersectVisitor() {
             @Override
             public void visit(int docID) throws IOException {
               visitor.visit(docMap.oldToNew(docID));
@@ -247,53 +257,13 @@ class PointValuesWriter {
             }
 
             @Override
-            public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
+            public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
               return visitor.compare(minPackedValue, maxPackedValue);
             }
           });
     }
 
     @Override
-    public long estimatePointCount(IntersectVisitor visitor) {
-      return in.estimatePointCount(visitor);
-    }
-
-    @Override
-    public byte[] getMinPackedValue() throws IOException {
-      return in.getMinPackedValue();
-    }
-
-    @Override
-    public byte[] getMaxPackedValue() throws IOException {
-      return in.getMaxPackedValue();
-    }
-
-    @Override
-    public int getNumDimensions() throws IOException {
-      return in.getNumDimensions();
-    }
-
-    @Override
-    public int getNumIndexDimensions() throws IOException {
-      return in.getNumIndexDimensions();
-    }
-
-    @Override
-    public int getBytesPerDimension() throws IOException {
-      return in.getBytesPerDimension();
-    }
-
-    @Override
-    public long size() {
-      return in.size();
-    }
-
-    @Override
-    public int getDocCount() {
-      return in.getDocCount();
-    }
-
-    @Override
     public void getValue(int i, BytesRef packedValue) {
       in.getValue(i, packedValue);
     }
diff --git a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java
index f808c90..32fbb8d 100644
--- a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java
+++ b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java
@@ -82,29 +82,8 @@ public final class SortingCodecReader extends FilterCodecReader {
     }
 
     @Override
-    public void intersect(IntersectVisitor visitor) throws IOException {
-      in.intersect(
-          new IntersectVisitor() {
-            @Override
-            public void visit(int docID) throws IOException {
-              visitor.visit(docMap.oldToNew(docID));
-            }
-
-            @Override
-            public void visit(int docID, byte[] packedValue) throws IOException {
-              visitor.visit(docMap.oldToNew(docID), packedValue);
-            }
-
-            @Override
-            public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
-              return visitor.compare(minPackedValue, maxPackedValue);
-            }
-          });
-    }
-
-    @Override
-    public long estimatePointCount(IntersectVisitor visitor) {
-      return in.estimatePointCount(visitor);
+    public PointTree getPointTree() throws IOException {
+      return new SortingPointTree(in.getPointTree(), docMap);
     }
 
     @Override
@@ -143,6 +122,111 @@ public final class SortingCodecReader extends FilterCodecReader {
     }
   }
 
+  /**
+   * {@link PointValues.PointTree} that delegates navigation to another tree and remaps visited
+   * doc IDs through {@link Sorter.DocMap#oldToNew}.
+   */
+  private static class SortingPointTree implements PointValues.PointTree {
+
+    private final PointValues.PointTree indexTree;
+    private final Sorter.DocMap docMap;
+    // single wrapper instance, re-targeted at the caller's visitor on every visit call
+    private final SortingIntersectVisitor sortingIntersectVisitor;
+
+    SortingPointTree(PointValues.PointTree indexTree, Sorter.DocMap docMap) {
+      this.indexTree = indexTree;
+      this.docMap = docMap;
+      this.sortingIntersectVisitor = new SortingIntersectVisitor(docMap);
+    }
+
+    @Override
+    public PointValues.PointTree clone() {
+      return new SortingPointTree(indexTree.clone(), docMap);
+    }
+
+    @Override
+    public boolean moveToChild() throws IOException {
+      return indexTree.moveToChild();
+    }
+
+    @Override
+    public boolean moveToSibling() throws IOException {
+      return indexTree.moveToSibling();
+    }
+
+    @Override
+    public boolean moveToParent() throws IOException {
+      return indexTree.moveToParent();
+    }
+
+    @Override
+    public byte[] getMinPackedValue() {
+      return indexTree.getMinPackedValue();
+    }
+
+    @Override
+    public byte[] getMaxPackedValue() {
+      return indexTree.getMaxPackedValue();
+    }
+
+    @Override
+    public long size() {
+      return indexTree.size();
+    }
+
+    @Override
+    public void visitDocIDs(PointValues.IntersectVisitor visitor) throws IOException {
+      sortingIntersectVisitor.setIntersectVisitor(visitor);
+      indexTree.visitDocIDs(sortingIntersectVisitor);
+    }
+
+    @Override
+    public void visitDocValues(PointValues.IntersectVisitor visitor) throws IOException {
+      sortingIntersectVisitor.setIntersectVisitor(visitor);
+      indexTree.visitDocValues(sortingIntersectVisitor);
+    }
+  }
+
+  /**
+   * Wraps a visitor, mapping doc IDs through {@link Sorter.DocMap#oldToNew} before forwarding
+   * them; {@code compare} and {@code grow} calls are forwarded unchanged.
+   */
+  private static class SortingIntersectVisitor implements PointValues.IntersectVisitor {
+
+    private final Sorter.DocMap docMap;
+
+    private PointValues.IntersectVisitor visitor;
+
+    SortingIntersectVisitor(Sorter.DocMap docMap) {
+      this.docMap = docMap;
+    }
+
+    private void setIntersectVisitor(PointValues.IntersectVisitor visitor) {
+      this.visitor = visitor;
+    }
+
+    @Override
+    public void visit(int docID) throws IOException {
+      visitor.visit(docMap.oldToNew(docID));
+    }
+
+    @Override
+    public void visit(int docID, byte[] packedValue) throws IOException {
+      visitor.visit(docMap.oldToNew(docID), packedValue);
+    }
+
+    @Override
+    public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
+      return visitor.compare(minPackedValue, maxPackedValue);
+    }
+
+    @Override
+    public void grow(int count) {
+      // each visit is forwarded exactly once, so the count is unchanged; pass the hint through
+      visitor.grow(count);
+    }
+  }
+
   /**
    * Return a sorted view of <code>reader</code> according to the order defined by <code>sort</code>
    * . If the reader is already sorted, this method might return the reader as-is.
diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/BKDReader.java b/lucene/core/src/java/org/apache/lucene/util/bkd/BKDReader.java
index 6a9683d..5205367 100644
--- a/lucene/core/src/java/org/apache/lucene/util/bkd/BKDReader.java
+++ b/lucene/core/src/java/org/apache/lucene/util/bkd/BKDReader.java
@@ -17,26 +17,22 @@
 package org.apache.lucene.util.bkd;
 
 import java.io.IOException;
-import java.io.UncheckedIOException;
-import java.util.Arrays;
 import org.apache.lucene.codecs.CodecUtil;
 import org.apache.lucene.index.CorruptIndexException;
 import org.apache.lucene.index.PointValues;
 import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.util.ArrayUtil;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.MathUtil;
 
 /**
- * Handles intersection of an multi-dimensional shape in byte[] space with a block KD-tree
- * previously written with {@link BKDWriter}.
+ * Handles reading a block KD-tree in byte[] space previously written with {@link BKDWriter}.
  *
  * @lucene.experimental
  */
-public final class BKDReader extends PointValues {
+public class BKDReader extends PointValues {
 
-  // Packed array of byte[] holding all split values in the full binary tree:
-  final int leafNodeOffset;
   final BKDConfig config;
   final int numLeaves;
   final IndexInput in;
@@ -71,22 +67,17 @@ public final class BKDReader extends PointValues {
     // Read index:
     numLeaves = metaIn.readVInt();
     assert numLeaves > 0;
-    leafNodeOffset = numLeaves;
 
     minPackedValue = new byte[config.packedIndexBytesLength];
     maxPackedValue = new byte[config.packedIndexBytesLength];
 
     metaIn.readBytes(minPackedValue, 0, config.packedIndexBytesLength);
     metaIn.readBytes(maxPackedValue, 0, config.packedIndexBytesLength);
-
+    final ArrayUtil.ByteArrayComparator comparator =
+        ArrayUtil.getUnsignedComparator(config.bytesPerDim);
     for (int dim = 0; dim < config.numIndexDims; dim++) {
-      if (Arrays.compareUnsigned(
-              minPackedValue,
-              dim * config.bytesPerDim,
-              dim * config.bytesPerDim + config.bytesPerDim,
-              maxPackedValue,
-              dim * config.bytesPerDim,
-              dim * config.bytesPerDim + config.bytesPerDim)
+      if (comparator.compare(
+              minPackedValue, dim * config.bytesPerDim, maxPackedValue, dim * config.bytesPerDim)
           > 0) {
         throw new CorruptIndexException(
             "minPackedValue "
@@ -116,152 +107,319 @@ public final class BKDReader extends PointValues {
     this.in = dataIn;
   }
 
-  long getMinLeafBlockFP() {
-    return minLeafBlockFP;
-  }
-
-  /**
-   * Used to walk the off-heap index. The format takes advantage of the limited access pattern to
-   * the BKD tree at search time, i.e. starting at the root node and recursing downwards one child
-   * at a time.
-   *
-   * @lucene.internal
-   */
-  public class IndexTree implements Cloneable {
+  @Override
+  public PointTree getPointTree() throws IOException {
+    return new BKDPointTree(
+        packedIndex.clone(),
+        this.in.clone(),
+        config,
+        numLeaves,
+        version,
+        pointCount,
+        minPackedValue,
+        maxPackedValue);
+  }
+
+  private static class BKDPointTree implements PointTree {
     private int nodeID;
+    // node ID this tree was created at (1, or the cloned node); navigation never goes above it
+    private final int nodeRoot;
     // level is 1-based so that we can do level-1 w/o checking each time:
     private int level;
-    private int splitDim;
-    private final byte[][] splitPackedValueStack;
     // used to read the packed tree off-heap
-    private final IndexInput in;
+    private final IndexInput innerNodes;
+    // used to read the packed leaves off-heap
+    private final IndexInput leafNodes;
     // holds the minimum (left most) leaf block file pointer for each level we've recursed to:
     private final long[] leafBlockFPStack;
     // holds the address, in the off-heap index, of the right-node of each level:
     private final int[] rightNodePositions;
-    // holds the splitDim for each level:
-    private final int[] splitDims;
+    // holds the byte offset (splitDim * bytesPerDim) of the split dimension for each level:
+    private final int[] splitDimsPos;
     // true if the per-dim delta we read for the node at this level is a negative offset vs. the
     // last split on this dim; this is a packed
     // 2D array, i.e. to access array[level][dim] you read from negativeDeltas[level*numDims+dim].
     // this will be true if the last time we
     // split on this dimension, we next pushed to the left sub-tree:
     private final boolean[] negativeDeltas;
-    // holds the packed per-level split values; the intersect method uses this to save the cell
-    // min/max as it recurses:
+    // holds the packed per-level split values
     private final byte[][] splitValuesStack;
-    // scratch value to return from getPackedValue:
-    private final BytesRef scratch;
-
-    IndexTree() {
-      this(packedIndex.clone(), 1, 1);
+    // holds the min / max value of the current node.
+    private final byte[] minPackedValue, maxPackedValue;
+    // per level, the split-dimension bytes overwritten by pushBounds*, restored by popBounds
+    private final byte[][] splitDimValueStack;
+    // tree parameters
+    private final BKDConfig config;
+    // node IDs >= leafNodeOffset are leaves; equal to the number of leaves
+    private final int leafNodeOffset;
+    // version of the index
+    private final int version;
+    // number of points in the right-most leaf, which might not be fully populated
+    private final int lastLeafNodePointCount;
+    // right most leaf node ID
+    private final int rightMostLeafNode;
+    // helper objects for reading doc values
+    private final byte[] scratchDataPackedValue,
+        scratchMinIndexPackedValue,
+        scratchMaxIndexPackedValue;
+    private final int[] commonPrefixLengths;
+    private final BKDReaderDocIDSetIterator scratchIterator;
+
+    private BKDPointTree(
+        IndexInput innerNodes,
+        IndexInput leafNodes,
+        BKDConfig config,
+        int numLeaves,
+        int version,
+        long pointCount,
+        byte[] minPackedValue,
+        byte[] maxPackedValue)
+        throws IOException {
+      this(
+          innerNodes,
+          leafNodes,
+          config,
+          numLeaves,
+          version,
+          Math.toIntExact(pointCount % config.maxPointsInLeafNode),
+          1,
+          1,
+          minPackedValue,
+          maxPackedValue,
+          new BKDReaderDocIDSetIterator(config.maxPointsInLeafNode),
+          new byte[config.packedBytesLength],
+          new byte[config.packedIndexBytesLength],
+          new byte[config.packedIndexBytesLength],
+          new int[config.numDims]);
       // read root node
       readNodeData(false);
     }
 
-    private IndexTree(IndexInput in, int nodeID, int level) {
-      int treeDepth = getTreeDepth();
-      splitPackedValueStack = new byte[treeDepth + 1][];
+    private BKDPointTree(
+        IndexInput innerNodes,
+        IndexInput leafNodes,
+        BKDConfig config,
+        int numLeaves,
+        int version,
+        int lastLeafNodePointCount,
+        int nodeID,
+        int level,
+        byte[] minPackedValue,
+        byte[] maxPackedValue,
+        BKDReaderDocIDSetIterator scratchIterator,
+        byte[] scratchDataPackedValue,
+        byte[] scratchMinIndexPackedValue,
+        byte[] scratchMaxIndexPackedValue,
+        int[] commonPrefixLengths) {
+      this.config = config;
+      this.version = version;
       this.nodeID = nodeID;
+      this.nodeRoot = nodeID;
       this.level = level;
-      splitPackedValueStack[level] = new byte[config.packedIndexBytesLength];
-      leafBlockFPStack = new long[treeDepth + 1];
-      rightNodePositions = new int[treeDepth + 1];
-      splitValuesStack = new byte[treeDepth + 1][];
-      splitDims = new int[treeDepth + 1];
-      negativeDeltas = new boolean[config.numIndexDims * (treeDepth + 1)];
-      this.in = in;
+      leafNodeOffset = numLeaves;
+      this.innerNodes = innerNodes;
+      this.leafNodes = leafNodes;
+      this.minPackedValue = minPackedValue.clone();
+      this.maxPackedValue = maxPackedValue.clone();
+      // stack arrays that keep information at different levels
+      int treeDepth = getTreeDepth(numLeaves);
+      splitDimValueStack = new byte[treeDepth][];
+      splitValuesStack = new byte[treeDepth][];
       splitValuesStack[0] = new byte[config.packedIndexBytesLength];
-      scratch = new BytesRef();
-      scratch.length = config.bytesPerDim;
+      leafBlockFPStack = new long[treeDepth + 1];
+      rightNodePositions = new int[treeDepth];
+      splitDimsPos = new int[treeDepth];
+      negativeDeltas = new boolean[config.numIndexDims * treeDepth];
+      // information about the imbalance of the tree so we can report the exact size below a node
+      rightMostLeafNode = (1 << (treeDepth - 1)) - 1;
+      this.lastLeafNodePointCount =
+          lastLeafNodePointCount == 0 ? config.maxPointsInLeafNode : lastLeafNodePointCount;
+      // scratch objects are shared with clones so NN searches do not allocate them for every
+      // clone. They hold mutable state, so clones are presumably not safe to use concurrently.
+      this.scratchIterator = scratchIterator;
+      this.commonPrefixLengths = commonPrefixLengths;
+      this.scratchDataPackedValue = scratchDataPackedValue;
+      this.scratchMinIndexPackedValue = scratchMinIndexPackedValue;
+      this.scratchMaxIndexPackedValue = scratchMaxIndexPackedValue;
+    }
+
+    @Override
+    public PointTree clone() {
+      BKDPointTree index =
+          new BKDPointTree(
+              innerNodes.clone(),
+              leafNodes.clone(),
+              config,
+              leafNodeOffset,
+              version,
+              lastLeafNodePointCount,
+              nodeID,
+              level,
+              minPackedValue,
+              maxPackedValue,
+              scratchIterator,
+              scratchDataPackedValue,
+              scratchMinIndexPackedValue,
+              scratchMaxIndexPackedValue,
+              commonPrefixLengths);
+      index.leafBlockFPStack[index.level] = leafBlockFPStack[level];
+      if (isLeafNode() == false) {
+        // copy node data
+        index.rightNodePositions[index.level] = rightNodePositions[level];
+        index.splitValuesStack[index.level] = splitValuesStack[level].clone();
+        System.arraycopy(
+            negativeDeltas,
+            level * config.numIndexDims,
+            index.negativeDeltas,
+            level * config.numIndexDims,
+            config.numIndexDims);
+        index.splitDimsPos[index.level] = splitDimsPos[level];
+      }
+      return index;
+    }
+
+    @Override
+    public byte[] getMinPackedValue() {
+      return minPackedValue;
+    }
+
+    @Override
+    public byte[] getMaxPackedValue() {
+      return maxPackedValue;
+    }
+
+    @Override
+    public boolean moveToChild() throws IOException {
+      if (isLeafNode()) {
+        return false;
+      }
+      pushBoundsLeft();
+      pushLeft();
+      return true;
     }
 
-    public void pushLeft() {
+    private void pushBoundsLeft() {
+      final int splitDimPos = splitDimsPos[level];
+      if (splitDimValueStack[level] == null) {
+        splitDimValueStack[level] = new byte[config.bytesPerDim];
+      }
+      // save the dimension we are going to change
+      System.arraycopy(
+          maxPackedValue, splitDimPos, splitDimValueStack[level], 0, config.bytesPerDim);
+      assert ArrayUtil.getUnsignedComparator(config.bytesPerDim)
+                  .compare(maxPackedValue, splitDimPos, splitValuesStack[level], splitDimPos)
+              >= 0
+          : "config.bytesPerDim="
+              + config.bytesPerDim
+              + " splitDimPos="
+              + splitDimsPos[level]
+              + " config.numIndexDims="
+              + config.numIndexDims
+              + " config.numDims="
+              + config.numDims;
+      // add the split dim value:
+      System.arraycopy(
+          splitValuesStack[level], splitDimPos, maxPackedValue, splitDimPos, config.bytesPerDim);
+    }
+
+    private void pushLeft() throws IOException {
       nodeID *= 2;
       level++;
       readNodeData(true);
     }
 
-    /** Clone, but you are not allowed to pop up past the point where the clone happened. */
-    @Override
-    public IndexTree clone() {
-      IndexTree index = new IndexTree(in.clone(), nodeID, level);
-      // copy node data
-      index.splitDim = splitDim;
-      index.leafBlockFPStack[level] = leafBlockFPStack[level];
-      index.rightNodePositions[level] = rightNodePositions[level];
-      index.splitValuesStack[index.level] = splitValuesStack[index.level].clone();
+    private void pushBoundsRight() {
+      final int splitDimPos = splitDimsPos[level];
+      // we should have already visited the left node
+      assert splitDimValueStack[level] != null;
+      // save the dimension we are going to change
       System.arraycopy(
-          negativeDeltas,
-          level * config.numIndexDims,
-          index.negativeDeltas,
-          level * config.numIndexDims,
-          config.numIndexDims);
-      index.splitDims[level] = splitDims[level];
-      return index;
+          minPackedValue, splitDimPos, splitDimValueStack[level], 0, config.bytesPerDim);
+      assert ArrayUtil.getUnsignedComparator(config.bytesPerDim)
+                  .compare(minPackedValue, splitDimPos, splitValuesStack[level], splitDimPos)
+              <= 0
+          : "config.bytesPerDim="
+              + config.bytesPerDim
+              + " splitDimPos="
+              + splitDimsPos[level]
+              + " config.numIndexDims="
+              + config.numIndexDims
+              + " config.numDims="
+              + config.numDims;
+      // add the split dim value:
+      System.arraycopy(
+          splitValuesStack[level], splitDimPos, minPackedValue, splitDimPos, config.bytesPerDim);
     }
 
-    public void pushRight() {
+    private void pushRight() throws IOException {
       final int nodePosition = rightNodePositions[level];
-      assert nodePosition >= in.getFilePointer()
-          : "nodePosition = " + nodePosition + " < currentPosition=" + in.getFilePointer();
-      nodeID = nodeID * 2 + 1;
+      assert nodePosition >= innerNodes.getFilePointer()
+          : "nodePosition = " + nodePosition + " < currentPosition=" + innerNodes.getFilePointer();
+      innerNodes.seek(nodePosition);
+      nodeID = 2 * nodeID + 1;
       level++;
-      try {
-        in.seek(nodePosition);
-      } catch (IOException e) {
-        throw new UncheckedIOException(e);
-      }
       readNodeData(false);
     }
 
-    public void pop() {
+    @Override
+    public boolean moveToSibling() throws IOException {
+      if (isLeftNode() == false || isRootNode()) {
+        return false;
+      }
+      pop();
+      popBounds(maxPackedValue);
+      pushBoundsRight();
+      pushRight();
+      assert nodeExists();
+      return true;
+    }
+
+    private void pop() {
       nodeID /= 2;
       level--;
-      splitDim = splitDims[level];
-      // System.out.println("  pop nodeID=" + nodeID);
     }
 
-    public boolean isLeafNode() {
-      return nodeID >= leafNodeOffset;
+    private void popBounds(byte[] packedValue) {
+      // restore the split dimension
+      System.arraycopy(
+          splitDimValueStack[level], 0, packedValue, splitDimsPos[level], config.bytesPerDim);
     }
 
-    public boolean nodeExists() {
-      return nodeID - leafNodeOffset < leafNodeOffset;
+    @Override
+    public boolean moveToParent() {
+      if (isRootNode()) {
+        return false;
+      }
+      final byte[] packedValue = isLeftNode() ? maxPackedValue : minPackedValue;
+      pop();
+      popBounds(packedValue);
+      return true;
     }
 
-    public int getNodeID() {
-      return nodeID;
+    private boolean isRootNode() {
+      return nodeID == nodeRoot;
     }
 
-    public byte[] getSplitPackedValue() {
-      assert isLeafNode() == false;
-      assert splitPackedValueStack[level] != null : "level=" + level;
-      return splitPackedValueStack[level];
+    private boolean isLeftNode() {
+      return (nodeID & 1) == 0;
     }
 
-    /** Only valid after pushLeft or pushRight, not pop! */
-    public int getSplitDim() {
-      assert isLeafNode() == false;
-      return splitDim;
+    private boolean isLeafNode() {
+      return nodeID >= leafNodeOffset;
     }
 
-    /** Only valid after pushLeft or pushRight, not pop! */
-    public BytesRef getSplitDimValue() {
-      assert isLeafNode() == false;
-      scratch.bytes = splitValuesStack[level];
-      scratch.offset = splitDim * config.bytesPerDim;
-      return scratch;
+    private boolean nodeExists() {
+      return nodeID - leafNodeOffset < leafNodeOffset;
     }
 
     /** Only valid after pushLeft or pushRight, not pop! */
-    public long getLeafBlockFP() {
+    private long getLeafBlockFP() {
       assert isLeafNode() : "nodeID=" + nodeID + " is not a leaf";
       return leafBlockFPStack[level];
     }
 
-    /** Return the number of leaves below the current node. */
-    public int getNumLeaves() {
+    @Override
+    public long size() {
       int leftMostLeafNode = nodeID;
       while (leftMostLeafNode < leafNodeOffset) {
         leftMostLeafNode = leftMostLeafNode * 2;
@@ -279,7 +437,91 @@ public final class BKDReader extends PointValues {
         numLeaves = rightMostLeafNode - leftMostLeafNode + 1 + leafNodeOffset;
       }
       assert numLeaves == getNumLeavesSlow(nodeID) : numLeaves + " " + getNumLeavesSlow(nodeID);
-      return numLeaves;
+      return rightMostLeafNode == this.rightMostLeafNode
+          ? (long) (numLeaves - 1) * config.maxPointsInLeafNode + lastLeafNodePointCount
+          : (long) numLeaves * config.maxPointsInLeafNode;
+    }
+
+    @Override
+    public void visitDocIDs(PointValues.IntersectVisitor visitor) throws IOException {
+      addAll(visitor, false);
+    }
+
+    private void addAll(PointValues.IntersectVisitor visitor, boolean grown) throws IOException {
+      if (grown == false) {
+        final long size = size();
+        if (size <= Integer.MAX_VALUE) {
+          visitor.grow((int) size);
+          grown = true;
+        }
+      }
+      if (isLeafNode()) {
+        assert grown; // a leaf's size always fits in an int, so grow() has been called by now
+        leafNodes.seek(getLeafBlockFP());
+        // How many points are stored in this leaf cell:
+        int count = leafNodes.readVInt();
+        // No need to call grow(), it has been called up-front
+        DocIdsWriter.readInts(leafNodes, count, visitor);
+      } else {
+        pushLeft();
+        addAll(visitor, grown);
+        pop();
+        pushRight();
+        addAll(visitor, grown);
+        pop();
+      }
+    }
+
+    @Override
+    public void visitDocValues(PointValues.IntersectVisitor visitor) throws IOException {
+      if (isLeafNode()) {
+        // Leaf node
+        visitDocValues(visitor, getLeafBlockFP());
+      } else {
+        pushLeft();
+        visitDocValues(visitor);
+        pop();
+        pushRight();
+        visitDocValues(visitor);
+        pop();
+      }
+    }
+
+    private void visitDocValues(PointValues.IntersectVisitor visitor, long fp) throws IOException {
+      // Leaf node; scan and filter all points in this block:
+      int count = readDocIDs(leafNodes, fp, scratchIterator);
+      if (version >= BKDWriter.VERSION_LOW_CARDINALITY_LEAVES) {
+        visitDocValuesWithCardinality(
+            commonPrefixLengths,
+            scratchDataPackedValue,
+            scratchMinIndexPackedValue,
+            scratchMaxIndexPackedValue,
+            leafNodes,
+            scratchIterator,
+            count,
+            visitor);
+      } else {
+        visitDocValuesNoCardinality(
+            commonPrefixLengths,
+            scratchDataPackedValue,
+            scratchMinIndexPackedValue,
+            scratchMaxIndexPackedValue,
+            leafNodes,
+            scratchIterator,
+            count,
+            visitor);
+      }
+    }
+
+    private int readDocIDs(IndexInput in, long blockFP, BKDReaderDocIDSetIterator iterator)
+        throws IOException {
+      in.seek(blockFP);
+      // How many points are stored in this leaf cell:
+      int count = in.readVInt();
+
+      DocIdsWriter.readInts(in, count, iterator.docIDs);
+
+      return count;
     }
 
     // for assertions
@@ -295,309 +537,89 @@ public final class BKDReader extends PointValues {
       }
     }
 
-    private void readNodeData(boolean isLeft) {
-      if (splitPackedValueStack[level] == null) {
-        splitPackedValueStack[level] = new byte[config.packedIndexBytesLength];
-      }
-      System.arraycopy(
-          negativeDeltas,
-          (level - 1) * config.numIndexDims,
-          negativeDeltas,
-          level * config.numIndexDims,
-          config.numIndexDims);
-      assert splitDim != -1;
-      negativeDeltas[level * config.numIndexDims + splitDim] = isLeft;
-
-      try {
-        leafBlockFPStack[level] = leafBlockFPStack[level - 1];
-
+    private void readNodeData(boolean isLeft) throws IOException {
+      leafBlockFPStack[level] = leafBlockFPStack[level - 1];
+      if (isLeft == false) {
         // read leaf block FP delta
-        if (isLeft == false) {
-          leafBlockFPStack[level] += in.readVLong();
-        }
+        leafBlockFPStack[level] += innerNodes.readVLong();
+      }
 
-        if (isLeafNode()) {
-          splitDim = -1;
+      if (isLeafNode() == false) {
+        System.arraycopy(
+            negativeDeltas,
+            (level - 1) * config.numIndexDims,
+            negativeDeltas,
+            level * config.numIndexDims,
+            config.numIndexDims);
+        negativeDeltas[
+                level * config.numIndexDims + (splitDimsPos[level - 1] / config.bytesPerDim)] =
+            isLeft;
+
+        if (splitValuesStack[level] == null) {
+          splitValuesStack[level] = splitValuesStack[level - 1].clone();
         } else {
-
-          // read split dim, prefix, firstDiffByteDelta encoded as int:
-          int code = in.readVInt();
-          splitDim = code % config.numIndexDims;
-          splitDims[level] = splitDim;
-          code /= config.numIndexDims;
-          int prefix = code % (1 + config.bytesPerDim);
-          int suffix = config.bytesPerDim - prefix;
-
-          if (splitValuesStack[level] == null) {
-            splitValuesStack[level] = new byte[config.packedIndexBytesLength];
-          }
           System.arraycopy(
               splitValuesStack[level - 1],
               0,
               splitValuesStack[level],
               0,
               config.packedIndexBytesLength);
-          if (suffix > 0) {
-            int firstDiffByteDelta = code / (1 + config.bytesPerDim);
-            if (negativeDeltas[level * config.numIndexDims + splitDim]) {
-              firstDiffByteDelta = -firstDiffByteDelta;
-            }
-            int oldByte = splitValuesStack[level][splitDim * config.bytesPerDim + prefix] & 0xFF;
-            splitValuesStack[level][splitDim * config.bytesPerDim + prefix] =
-                (byte) (oldByte + firstDiffByteDelta);
-            in.readBytes(
-                splitValuesStack[level], splitDim * config.bytesPerDim + prefix + 1, suffix - 1);
-          } else {
-            // our split value is == last split value in this dim, which can happen when there are
-            // many duplicate values
-          }
+        }
 
-          int leftNumBytes;
-          if (nodeID * 2 < leafNodeOffset) {
-            leftNumBytes = in.readVInt();
-          } else {
-            leftNumBytes = 0;
+        // read split dim, prefix, firstDiffByteDelta encoded as int:
+        int code = innerNodes.readVInt();
+        final int splitDim = code % config.numIndexDims;
+        splitDimsPos[level] = splitDim * config.bytesPerDim;
+        code /= config.numIndexDims;
+        final int prefix = code % (1 + config.bytesPerDim);
+        final int suffix = config.bytesPerDim - prefix;
+
+        if (suffix > 0) {
+          int firstDiffByteDelta = code / (1 + config.bytesPerDim);
+          if (negativeDeltas[level * config.numIndexDims + splitDim]) {
+            firstDiffByteDelta = -firstDiffByteDelta;
           }
-          rightNodePositions[level] = Math.toIntExact(in.getFilePointer()) + leftNumBytes;
+          final int startPos = splitDimsPos[level] + prefix;
+          final int oldByte = splitValuesStack[level][startPos] & 0xFF;
+          splitValuesStack[level][startPos] = (byte) (oldByte + firstDiffByteDelta);
+          innerNodes.readBytes(splitValuesStack[level], startPos + 1, suffix - 1);
+        } else {
+          // our split value is == last split value in this dim, which can happen when there are
+          // many duplicate values
         }
-      } catch (IOException e) {
-        throw new UncheckedIOException(e);
-      }
-    }
-  }
-
-  private int getTreeDepth() {
-    // First +1 because all the non-leave nodes makes another power
-    // of 2; e.g. to have a fully balanced tree with 4 leaves you
-    // need a depth=3 tree:
-
-    // Second +1 because MathUtil.log computes floor of the logarithm; e.g.
-    // with 5 leaves you need a depth=4 tree:
-    return MathUtil.log(numLeaves, 2) + 2;
-  }
-
-  /** Used to track all state for a single call to {@link #intersect}. */
-  public static final class IntersectState {
-    final IndexInput in;
-    final BKDReaderDocIDSetIterator scratchIterator;
-    final byte[] scratchDataPackedValue, scratchMinIndexPackedValue, scratchMaxIndexPackedValue;
-    final int[] commonPrefixLengths;
-
-    final IntersectVisitor visitor;
-    public final IndexTree index;
-
-    public IntersectState(
-        IndexInput in, BKDConfig config, IntersectVisitor visitor, IndexTree indexVisitor) {
-      this.in = in;
-      this.visitor = visitor;
-      this.commonPrefixLengths = new int[config.numDims];
-      this.scratchIterator = new BKDReaderDocIDSetIterator(config.maxPointsInLeafNode);
-      this.scratchDataPackedValue = new byte[config.packedBytesLength];
-      this.scratchMinIndexPackedValue = new byte[config.packedIndexBytesLength];
-      this.scratchMaxIndexPackedValue = new byte[config.packedIndexBytesLength];
-      this.index = indexVisitor;
-    }
-  }
-
-  @Override
-  public void intersect(IntersectVisitor visitor) throws IOException {
-    intersect(getIntersectState(visitor), minPackedValue, maxPackedValue);
-  }
-
-  @Override
-  public long estimatePointCount(IntersectVisitor visitor) {
-    return estimatePointCount(getIntersectState(visitor), minPackedValue, maxPackedValue);
-  }
 
-  /** Fast path: this is called when the query box fully encompasses all cells under this node. */
-  private void addAll(IntersectState state, boolean grown) throws IOException {
-    // System.out.println("R: addAll nodeID=" + nodeID);
-
-    if (grown == false) {
-      final long maxPointCount = (long) config.maxPointsInLeafNode * state.index.getNumLeaves();
-      if (maxPointCount
-          <= Integer.MAX_VALUE) { // could be >MAX_VALUE if there are more than 2B points in total
-        state.visitor.grow((int) maxPointCount);
-        grown = true;
-      }
-    }
-
-    if (state.index.isLeafNode()) {
-      assert grown;
-      // System.out.println("ADDALL");
-      if (state.index.nodeExists()) {
-        visitDocIDs(state.in, state.index.getLeafBlockFP(), state.visitor);
-      }
-      // TODO: we can assert that the first value here in fact matches what the index claimed?
-    } else {
-      state.index.pushLeft();
-      addAll(state, grown);
-      state.index.pop();
-
-      state.index.pushRight();
-      addAll(state, grown);
-      state.index.pop();
-    }
-  }
-
-  /** Create a new {@link IntersectState} */
-  public IntersectState getIntersectState(IntersectVisitor visitor) {
-    IndexTree index = new IndexTree();
-    return new IntersectState(in.clone(), config, visitor, index);
-  }
-
-  /** Visits all docIDs and packed values in a single leaf block */
-  public void visitLeafBlockValues(IndexTree index, IntersectState state) throws IOException {
-
-    // Leaf node; scan and filter all points in this block:
-    int count = readDocIDs(state.in, index.getLeafBlockFP(), state.scratchIterator);
-
-    // Again, this time reading values and checking with the visitor
-    visitDocValues(
-        state.commonPrefixLengths,
-        state.scratchDataPackedValue,
-        state.scratchMinIndexPackedValue,
-        state.scratchMaxIndexPackedValue,
-        state.in,
-        state.scratchIterator,
-        count,
-        state.visitor);
-  }
-
-  private void visitDocIDs(IndexInput in, long blockFP, IntersectVisitor visitor)
-      throws IOException {
-    // Leaf node
-    in.seek(blockFP);
-
-    // How many points are stored in this leaf cell:
-    int count = in.readVInt();
-    // No need to call grow(), it has been called up-front
-
-    DocIdsWriter.readInts(in, count, visitor);
-  }
-
-  int readDocIDs(IndexInput in, long blockFP, BKDReaderDocIDSetIterator iterator)
-      throws IOException {
-    in.seek(blockFP);
-
-    // How many points are stored in this leaf cell:
-    int count = in.readVInt();
-
-    DocIdsWriter.readInts(in, count, iterator.docIDs);
-
-    return count;
-  }
-
-  void visitDocValues(
-      int[] commonPrefixLengths,
-      byte[] scratchDataPackedValue,
-      byte[] scratchMinIndexPackedValue,
-      byte[] scratchMaxIndexPackedValue,
-      IndexInput in,
-      BKDReaderDocIDSetIterator scratchIterator,
-      int count,
-      IntersectVisitor visitor)
-      throws IOException {
-    if (version >= BKDWriter.VERSION_LOW_CARDINALITY_LEAVES) {
-      visitDocValuesWithCardinality(
-          commonPrefixLengths,
-          scratchDataPackedValue,
-          scratchMinIndexPackedValue,
-          scratchMaxIndexPackedValue,
-          in,
-          scratchIterator,
-          count,
-          visitor);
-    } else {
-      visitDocValuesNoCardinality(
-          commonPrefixLengths,
-          scratchDataPackedValue,
-          scratchMinIndexPackedValue,
-          scratchMaxIndexPackedValue,
-          in,
-          scratchIterator,
-          count,
-          visitor);
-    }
-  }
-
-  void visitDocValuesNoCardinality(
-      int[] commonPrefixLengths,
-      byte[] scratchDataPackedValue,
-      byte[] scratchMinIndexPackedValue,
-      byte[] scratchMaxIndexPackedValue,
-      IndexInput in,
-      BKDReaderDocIDSetIterator scratchIterator,
-      int count,
-      IntersectVisitor visitor)
-      throws IOException {
-    readCommonPrefixes(commonPrefixLengths, scratchDataPackedValue, in);
-
-    if (config.numIndexDims != 1 && version >= BKDWriter.VERSION_LEAF_STORES_BOUNDS) {
-      byte[] minPackedValue = scratchMinIndexPackedValue;
-      System.arraycopy(scratchDataPackedValue, 0, minPackedValue, 0, config.packedIndexBytesLength);
-      byte[] maxPackedValue = scratchMaxIndexPackedValue;
-      // Copy common prefixes before reading adjusted box
-      System.arraycopy(minPackedValue, 0, maxPackedValue, 0, config.packedIndexBytesLength);
-      readMinMax(commonPrefixLengths, minPackedValue, maxPackedValue, in);
-
-      // The index gives us range of values for each dimension, but the actual range of values
-      // might be much more narrow than what the index told us, so we double check the relation
-      // here, which is cheap yet might help figure out that the block either entirely matches
-      // or does not match at all. This is especially more likely in the case that there are
-      // multiple dimensions that have correlation, ie. splitting on one dimension also
-      // significantly changes the range of values in another dimension.
-      Relation r = visitor.compare(minPackedValue, maxPackedValue);
-      if (r == Relation.CELL_OUTSIDE_QUERY) {
-        return;
-      }
-      visitor.grow(count);
-
-      if (r == Relation.CELL_INSIDE_QUERY) {
-        for (int i = 0; i < count; ++i) {
-          visitor.visit(scratchIterator.docIDs[i]);
+        final int leftNumBytes;
+        if (nodeID * 2 < leafNodeOffset) {
+          leftNumBytes = innerNodes.readVInt();
+        } else {
+          leftNumBytes = 0;
         }
-        return;
+        rightNodePositions[level] = Math.toIntExact(innerNodes.getFilePointer()) + leftNumBytes;
       }
-    } else {
-      visitor.grow(count);
     }
 
-    int compressedDim = readCompressedDim(in);
-
-    if (compressedDim == -1) {
-      visitUniqueRawDocValues(scratchDataPackedValue, scratchIterator, count, visitor);
-    } else {
-      visitCompressedDocValues(
-          commonPrefixLengths,
-          scratchDataPackedValue,
-          in,
-          scratchIterator,
-          count,
-          visitor,
-          compressedDim);
-    }
-  }
-
-  void visitDocValuesWithCardinality(
-      int[] commonPrefixLengths,
-      byte[] scratchDataPackedValue,
-      byte[] scratchMinIndexPackedValue,
-      byte[] scratchMaxIndexPackedValue,
-      IndexInput in,
-      BKDReaderDocIDSetIterator scratchIterator,
-      int count,
-      IntersectVisitor visitor)
-      throws IOException {
-
-    readCommonPrefixes(commonPrefixLengths, scratchDataPackedValue, in);
-    int compressedDim = readCompressedDim(in);
-    if (compressedDim == -1) {
-      // all values are the same
-      visitor.grow(count);
-      visitUniqueRawDocValues(scratchDataPackedValue, scratchIterator, count, visitor);
-    } else {
-      if (config.numIndexDims != 1) {
+    private int getTreeDepth(int numLeaves) {
+      // First +1 because all the non-leave nodes makes another power
+      // of 2; e.g. to have a fully balanced tree with 4 leaves you
+      // need a depth=3 tree:
+
+      // Second +1 because MathUtil.log computes floor of the logarithm; e.g.
+      // with 5 leaves you need a depth=4 tree:
+      return MathUtil.log(numLeaves, 2) + 2;
+    }
+
+    private void visitDocValuesNoCardinality(
+        int[] commonPrefixLengths,
+        byte[] scratchDataPackedValue,
+        byte[] scratchMinIndexPackedValue,
+        byte[] scratchMaxIndexPackedValue,
+        IndexInput in,
+        BKDReaderDocIDSetIterator scratchIterator,
+        int count,
+        PointValues.IntersectVisitor visitor)
+        throws IOException {
+      readCommonPrefixes(commonPrefixLengths, scratchDataPackedValue, in);
+      if (config.numIndexDims != 1 && version >= BKDWriter.VERSION_LEAF_STORES_BOUNDS) {
         byte[] minPackedValue = scratchMinIndexPackedValue;
         System.arraycopy(
             scratchDataPackedValue, 0, minPackedValue, 0, config.packedIndexBytesLength);
@@ -612,13 +634,12 @@ public final class BKDReader extends PointValues {
         // or does not match at all. This is especially more likely in the case that there are
         // multiple dimensions that have correlation, ie. splitting on one dimension also
         // significantly changes the range of values in another dimension.
-        Relation r = visitor.compare(minPackedValue, maxPackedValue);
-        if (r == Relation.CELL_OUTSIDE_QUERY) {
+        PointValues.Relation r = visitor.compare(minPackedValue, maxPackedValue);
+        if (r == PointValues.Relation.CELL_OUTSIDE_QUERY) {
           return;
         }
         visitor.grow(count);
-
-        if (r == Relation.CELL_INSIDE_QUERY) {
+        if (r == PointValues.Relation.CELL_INSIDE_QUERY) {
           for (int i = 0; i < count; ++i) {
             visitor.visit(scratchIterator.docIDs[i]);
           }
@@ -627,12 +648,12 @@ public final class BKDReader extends PointValues {
       } else {
         visitor.grow(count);
       }
-      if (compressedDim == -2) {
-        // low cardinality values
-        visitSparseRawDocValues(
-            commonPrefixLengths, scratchDataPackedValue, in, scratchIterator, count, visitor);
+
+      int compressedDim = readCompressedDim(in);
+
+      if (compressedDim == -1) {
+        visitUniqueRawDocValues(scratchDataPackedValue, scratchIterator, count, visitor);
       } else {
-        // high cardinality
         visitCompressedDocValues(
             commonPrefixLengths,
             scratchDataPackedValue,
@@ -643,335 +664,183 @@ public final class BKDReader extends PointValues {
             compressedDim);
       }
     }
-  }
 
-  private void readMinMax(
-      int[] commonPrefixLengths, byte[] minPackedValue, byte[] maxPackedValue, IndexInput in)
-      throws IOException {
-    for (int dim = 0; dim < config.numIndexDims; dim++) {
-      int prefix = commonPrefixLengths[dim];
-      in.readBytes(minPackedValue, dim * config.bytesPerDim + prefix, config.bytesPerDim - prefix);
-      in.readBytes(maxPackedValue, dim * config.bytesPerDim + prefix, config.bytesPerDim - prefix);
+    private void visitDocValuesWithCardinality(
+        int[] commonPrefixLengths,
+        byte[] scratchDataPackedValue,
+        byte[] scratchMinIndexPackedValue,
+        byte[] scratchMaxIndexPackedValue,
+        IndexInput in,
+        BKDReaderDocIDSetIterator scratchIterator,
+        int count,
+        PointValues.IntersectVisitor visitor)
+        throws IOException {
+
+      readCommonPrefixes(commonPrefixLengths, scratchDataPackedValue, in);
+      int compressedDim = readCompressedDim(in);
+      if (compressedDim == -1) {
+        // all values are the same
+        visitor.grow(count);
+        visitUniqueRawDocValues(scratchDataPackedValue, scratchIterator, count, visitor);
+      } else {
+        if (config.numIndexDims != 1) {
+          byte[] minPackedValue = scratchMinIndexPackedValue;
+          System.arraycopy(
+              scratchDataPackedValue, 0, minPackedValue, 0, config.packedIndexBytesLength);
+          byte[] maxPackedValue = scratchMaxIndexPackedValue;
+          // Copy common prefixes before reading adjusted box
+          System.arraycopy(minPackedValue, 0, maxPackedValue, 0, config.packedIndexBytesLength);
+          readMinMax(commonPrefixLengths, minPackedValue, maxPackedValue, in);
+
+          // The index gives us range of values for each dimension, but the actual range of values
+          // might be much more narrow than what the index told us, so we double check the relation
+          // here, which is cheap yet might help figure out that the block either entirely matches
+          // or does not match at all. This is especially more likely in the case that there are
+          // multiple dimensions that have correlation, ie. splitting on one dimension also
+          // significantly changes the range of values in another dimension.
+          PointValues.Relation r = visitor.compare(minPackedValue, maxPackedValue);
+          if (r == PointValues.Relation.CELL_OUTSIDE_QUERY) {
+            return;
+          }
+          visitor.grow(count);
+
+          if (r == PointValues.Relation.CELL_INSIDE_QUERY) {
+            for (int i = 0; i < count; ++i) {
+              visitor.visit(scratchIterator.docIDs[i]);
+            }
+            return;
+          }
+        } else {
+          visitor.grow(count);
+        }
+
+        if (compressedDim == -2) {
+          // low cardinality values
+          visitSparseRawDocValues(
+              commonPrefixLengths, scratchDataPackedValue, in, scratchIterator, count, visitor);
+        } else {
+          // high cardinality
+          visitCompressedDocValues(
+              commonPrefixLengths,
+              scratchDataPackedValue,
+              in,
+              scratchIterator,
+              count,
+              visitor,
+              compressedDim);
+        }
+      }
     }
-  }
 
-  // read cardinality and point
-  private void visitSparseRawDocValues(
-      int[] commonPrefixLengths,
-      byte[] scratchPackedValue,
-      IndexInput in,
-      BKDReaderDocIDSetIterator scratchIterator,
-      int count,
-      IntersectVisitor visitor)
-      throws IOException {
-    int i;
-    for (i = 0; i < count; ) {
-      int length = in.readVInt();
-      for (int dim = 0; dim < config.numDims; dim++) {
+    private void readMinMax(
+        int[] commonPrefixLengths, byte[] minPackedValue, byte[] maxPackedValue, IndexInput in)
+        throws IOException {
+      for (int dim = 0; dim < config.numIndexDims; dim++) {
         int prefix = commonPrefixLengths[dim];
         in.readBytes(
-            scratchPackedValue, dim * config.bytesPerDim + prefix, config.bytesPerDim - prefix);
+            minPackedValue, dim * config.bytesPerDim + prefix, config.bytesPerDim - prefix);
+        in.readBytes(
+            maxPackedValue, dim * config.bytesPerDim + prefix, config.bytesPerDim - prefix);
       }
-      scratchIterator.reset(i, length);
-      visitor.visit(scratchIterator, scratchPackedValue);
-      i += length;
-    }
-    if (i != count) {
-      throw new CorruptIndexException(
-          "Sub blocks do not add up to the expected count: " + count + " != " + i, in);
     }
-  }
-
-  // point is under commonPrefix
-  private void visitUniqueRawDocValues(
-      byte[] scratchPackedValue,
-      BKDReaderDocIDSetIterator scratchIterator,
-      int count,
-      IntersectVisitor visitor)
-      throws IOException {
-    scratchIterator.reset(0, count);
-    visitor.visit(scratchIterator, scratchPackedValue);
-  }
 
-  private void visitCompressedDocValues(
-      int[] commonPrefixLengths,
-      byte[] scratchPackedValue,
-      IndexInput in,
-      BKDReaderDocIDSetIterator scratchIterator,
-      int count,
-      IntersectVisitor visitor,
-      int compressedDim)
-      throws IOException {
-    // the byte at `compressedByteOffset` is compressed using run-length compression,
-    // other suffix bytes are stored verbatim
-    final int compressedByteOffset =
-        compressedDim * config.bytesPerDim + commonPrefixLengths[compressedDim];
-    commonPrefixLengths[compressedDim]++;
-    int i;
-    for (i = 0; i < count; ) {
-      scratchPackedValue[compressedByteOffset] = in.readByte();
-      final int runLen = Byte.toUnsignedInt(in.readByte());
-      for (int j = 0; j < runLen; ++j) {
+    // read cardinality and point
+    private void visitSparseRawDocValues(
+        int[] commonPrefixLengths,
+        byte[] scratchPackedValue,
+        IndexInput in,
+        BKDReaderDocIDSetIterator scratchIterator,
+        int count,
+        PointValues.IntersectVisitor visitor)
+        throws IOException {
+      int i;
+      for (i = 0; i < count; ) {
+        int length = in.readVInt();
         for (int dim = 0; dim < config.numDims; dim++) {
           int prefix = commonPrefixLengths[dim];
           in.readBytes(
               scratchPackedValue, dim * config.bytesPerDim + prefix, config.bytesPerDim - prefix);
         }
-        visitor.visit(scratchIterator.docIDs[i + j], scratchPackedValue);
+        scratchIterator.reset(i, length);
+        visitor.visit(scratchIterator, scratchPackedValue);
+        i += length;
+      }
+      if (i != count) {
+        throw new CorruptIndexException(
+            "Sub blocks do not add up to the expected count: " + count + " != " + i, in);
       }
-      i += runLen;
-    }
-    if (i != count) {
-      throw new CorruptIndexException(
-          "Sub blocks do not add up to the expected count: " + count + " != " + i, in);
     }
-  }
 
-  private int readCompressedDim(IndexInput in) throws IOException {
-    int compressedDim = in.readByte();
-    if (compressedDim < -2
-        || compressedDim >= config.numDims
-        || (version < BKDWriter.VERSION_LOW_CARDINALITY_LEAVES && compressedDim == -2)) {
-      throw new CorruptIndexException("Got compressedDim=" + compressedDim, in);
+    // point is under commonPrefix
+    private void visitUniqueRawDocValues(
+        byte[] scratchPackedValue,
+        BKDReaderDocIDSetIterator scratchIterator,
+        int count,
+        PointValues.IntersectVisitor visitor)
+        throws IOException {
+      scratchIterator.reset(0, count);
+      visitor.visit(scratchIterator, scratchPackedValue);
     }
-    return compressedDim;
-  }
 
-  private void readCommonPrefixes(
-      int[] commonPrefixLengths, byte[] scratchPackedValue, IndexInput in) throws IOException {
-    for (int dim = 0; dim < config.numDims; dim++) {
-      int prefix = in.readVInt();
-      commonPrefixLengths[dim] = prefix;
-      if (prefix > 0) {
-        in.readBytes(scratchPackedValue, dim * config.bytesPerDim, prefix);
+    private void visitCompressedDocValues(
+        int[] commonPrefixLengths,
+        byte[] scratchPackedValue,
+        IndexInput in,
+        BKDReaderDocIDSetIterator scratchIterator,
+        int count,
+        PointValues.IntersectVisitor visitor,
+        int compressedDim)
+        throws IOException {
+      // the byte at `compressedByteOffset` is compressed using run-length compression,
+      // other suffix bytes are stored verbatim
+      final int compressedByteOffset =
+          compressedDim * config.bytesPerDim + commonPrefixLengths[compressedDim];
+      commonPrefixLengths[compressedDim]++;
+      int i;
+      for (i = 0; i < count; ) {
+        scratchPackedValue[compressedByteOffset] = in.readByte();
+        final int runLen = Byte.toUnsignedInt(in.readByte());
+        for (int j = 0; j < runLen; ++j) {
+          for (int dim = 0; dim < config.numDims; dim++) {
+            int prefix = commonPrefixLengths[dim];
+            in.readBytes(
+                scratchPackedValue, dim * config.bytesPerDim + prefix, config.bytesPerDim - prefix);
+          }
+          visitor.visit(scratchIterator.docIDs[i + j], scratchPackedValue);
+        }
+        i += runLen;
+      }
+      if (i != count) {
+        throw new CorruptIndexException(
+            "Sub blocks do not add up to the expected count: " + count + " != " + i, in);
       }
-      // System.out.println("R: " + dim + " of " + numDims + " prefix=" + prefix);
     }
-  }
 
-  private void intersect(IntersectState state, byte[] cellMinPacked, byte[] cellMaxPacked)
-      throws IOException {
-
-    /*
-    System.out.println("\nR: intersect nodeID=" + state.index.getNodeID());
-    for(int dim=0;dim<numDims;dim++) {
-      System.out.println("  dim=" + dim + "\n    cellMin=" + new BytesRef(cellMinPacked, dim*config.bytesPerDim, config.bytesPerDim) + "\n    cellMax=" + new BytesRef(cellMaxPacked, dim*config.bytesPerDim, config.bytesPerDim));
-    }
-    */
-
-    Relation r = state.visitor.compare(cellMinPacked, cellMaxPacked);
-
-    if (r == Relation.CELL_OUTSIDE_QUERY) {
-      // This cell is fully outside of the query shape: stop recursing
-    } else if (r == Relation.CELL_INSIDE_QUERY) {
-      // This cell is fully inside of the query shape: recursively add all points in this cell
-      // without filtering
-      addAll(state, false);
-      // The cell crosses the shape boundary, or the cell fully contains the query, so we fall
-      // through and do full filtering:
-    } else if (state.index.isLeafNode()) {
-
-      // TODO: we can assert that the first value here in fact matches what the index claimed?
-
-      // In the unbalanced case it's possible the left most node only has one child:
-      if (state.index.nodeExists()) {
-        // Leaf node; scan and filter all points in this block:
-        int count = readDocIDs(state.in, state.index.getLeafBlockFP(), state.scratchIterator);
-
-        // Again, this time reading values and checking with the visitor
-        visitDocValues(
-            state.commonPrefixLengths,
-            state.scratchDataPackedValue,
-            state.scratchMinIndexPackedValue,
-            state.scratchMaxIndexPackedValue,
-            state.in,
-            state.scratchIterator,
-            count,
-            state.visitor);
+    private int readCompressedDim(IndexInput in) throws IOException {
+      int compressedDim = in.readByte();
+      if (compressedDim < -2
+          || compressedDim >= config.numDims
+          || (version < BKDWriter.VERSION_LOW_CARDINALITY_LEAVES && compressedDim == -2)) {
+        throw new CorruptIndexException("Got compressedDim=" + compressedDim, in);
       }
-
-    } else {
-
-      // Non-leaf node: recurse on the split left and right nodes
-      int splitDim = state.index.getSplitDim();
-      assert splitDim >= 0
-          : "splitDim=" + splitDim + ", config.numIndexDims=" + config.numIndexDims;
-      assert splitDim < config.numIndexDims
-          : "splitDim=" + splitDim + ", config.numIndexDims=" + config.numIndexDims;
-
-      byte[] splitPackedValue = state.index.getSplitPackedValue();
-      BytesRef splitDimValue = state.index.getSplitDimValue();
-      assert splitDimValue.length == config.bytesPerDim;
-      // System.out.println("  splitDimValue=" + splitDimValue + " splitDim=" + splitDim);
-
-      // make sure cellMin <= splitValue <= cellMax:
-      assert Arrays.compareUnsigned(
-                  cellMinPacked,
-                  splitDim * config.bytesPerDim,
-                  splitDim * config.bytesPerDim + config.bytesPerDim,
-                  splitDimValue.bytes,
-                  splitDimValue.offset,
-                  splitDimValue.offset + config.bytesPerDim)
-              <= 0
-          : "config.bytesPerDim="
-              + config.bytesPerDim
-              + " splitDim="
-              + splitDim
-              + " config.numIndexDims="
-              + config.numIndexDims
-              + " config.numDims="
-              + config.numDims;
-      assert Arrays.compareUnsigned(
-                  cellMaxPacked,
-                  splitDim * config.bytesPerDim,
-                  splitDim * config.bytesPerDim + config.bytesPerDim,
-                  splitDimValue.bytes,
-                  splitDimValue.offset,
-                  splitDimValue.offset + config.bytesPerDim)
-              >= 0
-          : "config.bytesPerDim="
-              + config.bytesPerDim
-              + " splitDim="
-              + splitDim
-              + " config.numIndexDims="
-              + config.numIndexDims
-              + " config.numDims="
-              + config.numDims;
-
-      // Recurse on left sub-tree:
-      System.arraycopy(cellMaxPacked, 0, splitPackedValue, 0, config.packedIndexBytesLength);
-      System.arraycopy(
-          splitDimValue.bytes,
-          splitDimValue.offset,
-          splitPackedValue,
-          splitDim * config.bytesPerDim,
-          config.bytesPerDim);
-      state.index.pushLeft();
-      intersect(state, cellMinPacked, splitPackedValue);
-      state.index.pop();
-
-      // Restore the split dim value since it may have been overwritten while recursing:
-      System.arraycopy(
-          splitPackedValue,
-          splitDim * config.bytesPerDim,
-          splitDimValue.bytes,
-          splitDimValue.offset,
-          config.bytesPerDim);
-
-      // Recurse on right sub-tree:
-      System.arraycopy(cellMinPacked, 0, splitPackedValue, 0, config.packedIndexBytesLength);
-      System.arraycopy(
-          splitDimValue.bytes,
-          splitDimValue.offset,
-          splitPackedValue,
-          splitDim * config.bytesPerDim,
-          config.bytesPerDim);
-      state.index.pushRight();
-      intersect(state, splitPackedValue, cellMaxPacked);
-      state.index.pop();
+      return compressedDim;
     }
-  }
-
-  private long estimatePointCount(
-      IntersectState state, byte[] cellMinPacked, byte[] cellMaxPacked) {
 
-    /*
-    System.out.println("\nR: intersect nodeID=" + state.index.getNodeID());
-    for(int dim=0;dim<numDims;dim++) {
-      System.out.println("  dim=" + dim + "\n    cellMin=" + new BytesRef(cellMinPacked, dim*config.bytesPerDim, config.bytesPerDim) + "\n    cellMax=" + new BytesRef(cellMaxPacked, dim*config.bytesPerDim, config.bytesPerDim));
+    private void readCommonPrefixes(
+        int[] commonPrefixLengths, byte[] scratchPackedValue, IndexInput in) throws IOException {
+      for (int dim = 0; dim < config.numDims; dim++) {
+        int prefix = in.readVInt();
+        commonPrefixLengths[dim] = prefix;
+        if (prefix > 0) {
+          in.readBytes(scratchPackedValue, dim * config.bytesPerDim, prefix);
+        }
+        // System.out.println("R: " + dim + " of " + numDims + " prefix=" + prefix);
+      }
     }
-    */
-
-    Relation r = state.visitor.compare(cellMinPacked, cellMaxPacked);
-
-    if (r == Relation.CELL_OUTSIDE_QUERY) {
-      // This cell is fully outside of the query shape: stop recursing
-      return 0L;
-    } else if (r == Relation.CELL_INSIDE_QUERY) {
-      return (long) config.maxPointsInLeafNode * state.index.getNumLeaves();
-    } else if (state.index.isLeafNode()) {
-      // Assume half the points matched
-      return (config.maxPointsInLeafNode + 1) / 2;
-    } else {
 
-      // Non-leaf node: recurse on the split left and right nodes
-      int splitDim = state.index.getSplitDim();
-      assert splitDim >= 0
-          : "splitDim=" + splitDim + ", config.numIndexDims=" + config.numIndexDims;
-      assert splitDim < config.numIndexDims
-          : "splitDim=" + splitDim + ", config.numIndexDims=" + config.numIndexDims;
-
-      byte[] splitPackedValue = state.index.getSplitPackedValue();
-      BytesRef splitDimValue = state.index.getSplitDimValue();
-      assert splitDimValue.length == config.bytesPerDim;
-      // System.out.println("  splitDimValue=" + splitDimValue + " splitDim=" + splitDim);
-
-      // make sure cellMin <= splitValue <= cellMax:
-      assert Arrays.compareUnsigned(
-                  cellMinPacked,
-                  splitDim * config.bytesPerDim,
-                  splitDim * config.bytesPerDim + config.bytesPerDim,
-                  splitDimValue.bytes,
-                  splitDimValue.offset,
-                  splitDimValue.offset + config.bytesPerDim)
-              <= 0
-          : "config.bytesPerDim="
-              + config.bytesPerDim
-              + " splitDim="
-              + splitDim
-              + " config.numIndexDims="
-              + config.numIndexDims
-              + " config.numDims="
-              + config.numDims;
-      assert Arrays.compareUnsigned(
-                  cellMaxPacked,
-                  splitDim * config.bytesPerDim,
-                  splitDim * config.bytesPerDim + config.bytesPerDim,
-                  splitDimValue.bytes,
-                  splitDimValue.offset,
-                  splitDimValue.offset + config.bytesPerDim)
-              >= 0
-          : "config.bytesPerDim="
-              + config.bytesPerDim
-              + " splitDim="
-              + splitDim
-              + " config.numIndexDims="
-              + config.numIndexDims
-              + " config.numDims="
-              + config.numDims;
-
-      // Recurse on left sub-tree:
-      System.arraycopy(cellMaxPacked, 0, splitPackedValue, 0, config.packedIndexBytesLength);
-      System.arraycopy(
-          splitDimValue.bytes,
-          splitDimValue.offset,
-          splitPackedValue,
-          splitDim * config.bytesPerDim,
-          config.bytesPerDim);
-      state.index.pushLeft();
-      final long leftCost = estimatePointCount(state, cellMinPacked, splitPackedValue);
-      state.index.pop();
-
-      // Restore the split dim value since it may have been overwritten while recursing:
-      System.arraycopy(
-          splitPackedValue,
-          splitDim * config.bytesPerDim,
-          splitDimValue.bytes,
-          splitDimValue.offset,
-          config.bytesPerDim);
-
-      // Recurse on right sub-tree:
-      System.arraycopy(cellMinPacked, 0, splitPackedValue, 0, config.packedIndexBytesLength);
-      System.arraycopy(
-          splitDimValue.bytes,
-          splitDimValue.offset,
-          splitPackedValue,
-          splitDim * config.bytesPerDim,
-          config.bytesPerDim);
-      state.index.pushRight();
-      final long rightCost = estimatePointCount(state, splitPackedValue, cellMaxPacked);
-      state.index.pop();
-      return leftCost + rightCost;
+    @Override
+    public String toString() {
+      return "nodeID=" + nodeID;
     }
   }
 
@@ -986,17 +855,17 @@ public final class BKDReader extends PointValues {
   }
 
   @Override
-  public int getNumDimensions() {
+  public int getNumDimensions() throws IOException {
     return config.numDims;
   }
 
   @Override
-  public int getNumIndexDimensions() {
+  public int getNumIndexDimensions() throws IOException {
     return config.numIndexDims;
   }
 
   @Override
-  public int getBytesPerDimension() {
+  public int getBytesPerDimension() throws IOException {
     return config.bytesPerDim;
   }
 
@@ -1010,12 +879,8 @@ public final class BKDReader extends PointValues {
     return docCount;
   }
 
-  public boolean isLeafNode(int nodeID) {
-    return nodeID >= leafNodeOffset;
-  }
-
   /** Reusable {@link DocIdSetIterator} to handle low cardinality leaves. */
-  protected static class BKDReaderDocIDSetIterator extends DocIdSetIterator {
+  private static class BKDReaderDocIDSetIterator extends DocIdSetIterator {
 
     private int idx;
     private int length;
diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/BKDWriter.java b/lucene/core/src/java/org/apache/lucene/util/bkd/BKDWriter.java
index ab801d9..655cf8c 100644
--- a/lucene/core/src/java/org/apache/lucene/util/bkd/BKDWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/util/bkd/BKDWriter.java
@@ -24,8 +24,9 @@ import java.util.Arrays;
 import java.util.List;
 import java.util.function.IntFunction;
 import org.apache.lucene.codecs.CodecUtil;
-import org.apache.lucene.codecs.MutablePointValues;
+import org.apache.lucene.codecs.MutablePointTree;
 import org.apache.lucene.index.MergeState;
+import org.apache.lucene.index.PointValues;
 import org.apache.lucene.index.PointValues.IntersectVisitor;
 import org.apache.lucene.index.PointValues.Relation;
 import org.apache.lucene.store.ByteBuffersDataOutput;
@@ -239,83 +240,42 @@ public class BKDWriter implements Closeable {
   }
 
   private static class MergeReader {
-    final BKDReader bkd;
-    final BKDReader.IntersectState state;
-    final MergeState.DocMap docMap;
-
-    /** Current doc ID */
-    public int docID;
-
+    private final PointValues.PointTree pointTree;
+    private final int packedBytesLength;
+    private final MergeState.DocMap docMap;
+    private final MergeIntersectsVisitor mergeIntersectsVisitor;
     /** Which doc in this block we are up to */
     private int docBlockUpto;
-
-    /** How many docs in the current block */
-    private int docsInBlock;
-
-    /** Which leaf block we are up to */
-    private int blockID;
-
-    private final byte[] packedValues;
-
-    public MergeReader(BKDReader bkd, MergeState.DocMap docMap) throws IOException {
-      this.bkd = bkd;
-      state = new BKDReader.IntersectState(bkd.in.clone(), bkd.config, null, null);
+    /** Current doc ID */
+    public int docID;
+    /** Current packed value */
+    public final byte[] packedValue;
+
+    public MergeReader(PointValues pointValues, MergeState.DocMap docMap) throws IOException {
+      this.packedBytesLength = pointValues.getBytesPerDimension() * pointValues.getNumDimensions();
+      this.pointTree = pointValues.getPointTree();
+      this.mergeIntersectsVisitor = new MergeIntersectsVisitor(packedBytesLength);
+      // move to first child of the tree and collect docs
+      while (pointTree.moveToChild()) {}
+      pointTree.visitDocValues(mergeIntersectsVisitor);
       this.docMap = docMap;
-      state.in.seek(bkd.getMinLeafBlockFP());
-      this.packedValues = new byte[bkd.config.maxPointsInLeafNode * bkd.config.packedBytesLength];
+      this.packedValue = new byte[packedBytesLength];
     }
 
     public boolean next() throws IOException {
       // System.out.println("MR.next this=" + this);
       while (true) {
-        if (docBlockUpto == docsInBlock) {
-          if (blockID == bkd.leafNodeOffset) {
-            // System.out.println("  done!");
+        if (docBlockUpto == mergeIntersectsVisitor.docsInBlock) {
+          if (collectNextLeaf() == false) {
+            assert mergeIntersectsVisitor.docsInBlock == 0;
             return false;
           }
-          // System.out.println("  new block @ fp=" + state.in.getFilePointer());
-          docsInBlock = bkd.readDocIDs(state.in, state.in.getFilePointer(), state.scratchIterator);
-          assert docsInBlock > 0;
+          assert mergeIntersectsVisitor.docsInBlock > 0;
           docBlockUpto = 0;
-          bkd.visitDocValues(
-              state.commonPrefixLengths,
-              state.scratchDataPackedValue,
-              state.scratchMinIndexPackedValue,
-              state.scratchMaxIndexPackedValue,
-              state.in,
-              state.scratchIterator,
-              docsInBlock,
-              new IntersectVisitor() {
-                int i = 0;
-
-                @Override
-                public void visit(int docID) {
-                  throw new UnsupportedOperationException();
-                }
-
-                @Override
-                public void visit(int docID, byte[] packedValue) {
-                  assert docID == state.scratchIterator.docIDs[i];
-                  System.arraycopy(
-                      packedValue,
-                      0,
-                      packedValues,
-                      i * bkd.config.packedBytesLength,
-                      bkd.config.packedBytesLength);
-                  i++;
-                }
-
-                @Override
-                public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
-                  return Relation.CELL_CROSSES_QUERY;
-                }
-              });
-
-          blockID++;
         }
 
         final int index = docBlockUpto++;
-        int oldDocID = state.scratchIterator.docIDs[index];
+        int oldDocID = mergeIntersectsVisitor.docIDs[index];
 
         int mappedDocID;
         if (docMap == null) {
@@ -328,23 +288,89 @@ public class BKDWriter implements Closeable {
           // Not deleted!
           docID = mappedDocID;
           System.arraycopy(
-              packedValues,
-              index * bkd.config.packedBytesLength,
-              state.scratchDataPackedValue,
+              mergeIntersectsVisitor.packedValues,
+              index * packedBytesLength,
+              packedValue,
               0,
-              bkd.config.packedBytesLength);
+              packedBytesLength);
           return true;
         }
       }
     }
+
+    private boolean collectNextLeaf() throws IOException { // buffer next leaf or return false
+      assert pointTree.moveToChild() == false; // the cursor must currently sit on a leaf
+      mergeIntersectsVisitor.reset();
+      do {
+        if (pointTree.moveToSibling()) {
+          // a sibling exists: descend to its first leaf and buffer that leaf's points
+          while (pointTree.moveToChild()) {}
+          pointTree.visitDocValues(mergeIntersectsVisitor);
+          return true;
+        }
+      } while (pointTree.moveToParent()); // no sibling at this level: climb and try again
+      return false; // climbed to the top with no sibling left: no more leaves
+    }
+  }
+
+  private static class MergeIntersectsVisitor implements IntersectVisitor {
+    // Buffers the doc IDs and packed values of one leaf, in the order they are visited.
+    int docsInBlock = 0; // number of buffered points; entries [0, docsInBlock) are valid
+    byte[] packedValues; // point i is stored at offset i * packedBytesLength
+    int[] docIDs;
+    private final int packedBytesLength;
+    /** Starts with empty buffers; {@link #grow} enlarges them on demand. */
+    MergeIntersectsVisitor(int packedBytesLength) {
+      this.docIDs = new int[0];
+      this.packedValues = new byte[0];
+      this.packedBytesLength = packedBytesLength;
+    }
+    /** Discards the buffered points; array capacity is kept for the next leaf. */
+    void reset() {
+      docsInBlock = 0;
+    }
+    /** Ensures room for {@code count} points; must only be called while the buffer is empty. */
+    @Override
+    public void grow(int count) {
+      assert docsInBlock == 0;
+      if (docIDs.length < count) {
+        int[] grownDocIDs = ArrayUtil.grow(docIDs, count);
+        // long arithmetic: the product may not fit in an int, and must still reach the check
+        long packedValuesSize = grownDocIDs.length * (long) packedBytesLength;
+        if (packedValuesSize > ArrayUtil.MAX_ARRAY_LENGTH) {
+          throw new IllegalStateException(
+              "array length must be <= " + ArrayUtil.MAX_ARRAY_LENGTH + " but was: "
+                  + packedValuesSize);
+        }
+        docIDs = grownDocIDs; // assigned only after the size check, keeping both arrays in sync
+        packedValues = ArrayUtil.growExact(packedValues, (int) packedValuesSize);
+      }
+    }
+    /** Unsupported: points are only buffered together with their packed value. */
+    @Override
+    public void visit(int docID) {
+      throw new UnsupportedOperationException();
+    }
+    /** Appends one point; relies on {@link #grow} having reserved room for it. */
+    @Override
+    public void visit(int docID, byte[] packedValue) {
+      System.arraycopy(
+          packedValue, 0, packedValues, docsInBlock * packedBytesLength, packedBytesLength);
+      docIDs[docsInBlock++] = docID;
+    }
+    /** Always reports {@link Relation#CELL_CROSSES_QUERY}. */
+    @Override
+    public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
+      return Relation.CELL_CROSSES_QUERY;
+    }
   }
 
   private static class BKDMergeQueue extends PriorityQueue<MergeReader> {
-    private final ByteArrayComparator comparator;
+    private final int bytesPerDim;
 
     public BKDMergeQueue(int bytesPerDim, int maxSize) {
       super(maxSize);
-      this.comparator = ArrayUtil.getUnsignedComparator(bytesPerDim);
+      this.bytesPerDim = bytesPerDim;
     }
 
     @Override
@@ -352,7 +378,7 @@ public class BKDWriter implements Closeable {
       assert a != b;
 
       int cmp =
-          comparator.compare(a.state.scratchDataPackedValue, 0, b.state.scratchDataPackedValue, 0);
+          Arrays.compareUnsigned(a.packedValue, 0, bytesPerDim, b.packedValue, 0, bytesPerDim);
       if (cmp < 0) {
         return true;
       } else if (cmp > 0) {
@@ -387,7 +413,7 @@ public class BKDWriter implements Closeable {
   }
 
   /**
-   * Write a field from a {@link MutablePointValues}. This way of writing points is faster than
+   * Write a field from a {@link MutablePointTree}. This way of writing points is faster than
    * regular writes with {@link BKDWriter#add} since there is opportunity for reordering points
    * before writing them to disk. This method does not use transient disk in order to reorder
    * points.
@@ -397,7 +423,7 @@ public class BKDWriter implements Closeable {
       IndexOutput indexOut,
       IndexOutput dataOut,
       String fieldName,
-      MutablePointValues reader)
+      MutablePointTree reader)
       throws IOException {
     if (config.numDims == 1) {
       return writeField1Dim(metaOut, indexOut, dataOut, fieldName, reader);
@@ -407,7 +433,7 @@ public class BKDWriter implements Closeable {
   }
 
   private void computePackedValueBounds(
-      MutablePointValues values,
+      MutablePointTree values,
       int from,
       int to,
       byte[] minPackedValue,
@@ -425,8 +451,14 @@ public class BKDWriter implements Closeable {
       values.getValue(i, scratch);
       for (int dim = 0; dim < config.numIndexDims; dim++) {
         final int startOffset = dim * config.bytesPerDim;
-        if (comparator.compare(
-                scratch.bytes, scratch.offset + startOffset, minPackedValue, startOffset)
+        final int endOffset = startOffset + config.bytesPerDim;
+        if (Arrays.compareUnsigned(
+                scratch.bytes,
+                scratch.offset + startOffset,
+                scratch.offset + endOffset,
+                minPackedValue,
+                startOffset,
+                endOffset)
             < 0) {
           System.arraycopy(
               scratch.bytes,
@@ -434,8 +466,13 @@ public class BKDWriter implements Closeable {
               minPackedValue,
               startOffset,
               config.bytesPerDim);
-        } else if (comparator.compare(
-                scratch.bytes, scratch.offset + startOffset, maxPackedValue, startOffset)
+        } else if (Arrays.compareUnsigned(
+                scratch.bytes,
+                scratch.offset + startOffset,
+                scratch.offset + endOffset,
+                maxPackedValue,
+                startOffset,
+                endOffset)
             > 0) {
           System.arraycopy(
               scratch.bytes,
@@ -455,7 +492,7 @@ public class BKDWriter implements Closeable {
       IndexOutput indexOut,
       IndexOutput dataOut,
       String fieldName,
-      MutablePointValues values)
+      MutablePointTree values)
       throws IOException {
     if (pointCount != 0) {
       throw new IllegalStateException("cannot mix add and writeField");
@@ -549,14 +586,14 @@ public class BKDWriter implements Closeable {
       IndexOutput indexOut,
       IndexOutput dataOut,
       String fieldName,
-      MutablePointValues reader)
+      MutablePointTree reader)
       throws IOException {
-    MutablePointsReaderUtils.sort(config, maxDoc, reader, 0, Math.toIntExact(reader.size()));
+    MutablePointTreeReaderUtils.sort(config, maxDoc, reader, 0, Math.toIntExact(reader.size()));
 
     final OneDimensionBKDWriter oneDimWriter =
         new OneDimensionBKDWriter(metaOut, indexOut, dataOut);
 
-    reader.intersect(
+    reader.visitDocValues(
         new IntersectVisitor() {
 
           @Override
@@ -579,30 +616,33 @@ public class BKDWriter implements Closeable {
   }
 
   /**
-   * More efficient bulk-add for incoming {@link BKDReader}s. This does a merge sort of the already
-   * sorted values and currently only works when numDims==1. This returns -1 if all documents
-   * containing dimensional values were deleted.
+   * More efficient bulk-add for incoming {@link PointValues}s. This does a merge sort of the
+   * already sorted values and currently only works when numDims==1. This returns -1 if all
+   * documents containing dimensional values were deleted.
    */
   public Runnable merge(
       IndexOutput metaOut,
       IndexOutput indexOut,
       IndexOutput dataOut,
       List<MergeState.DocMap> docMaps,
-      List<BKDReader> readers)
+      List<PointValues> readers)
       throws IOException {
     assert docMaps == null || readers.size() == docMaps.size();
 
     BKDMergeQueue queue = new BKDMergeQueue(config.bytesPerDim, readers.size());
 
     for (int i = 0; i < readers.size(); i++) {
-      BKDReader bkd = readers.get(i);
+      PointValues pointValues = readers.get(i);
+      assert pointValues.getNumDimensions() == config.numDims
+          && pointValues.getBytesPerDimension() == config.bytesPerDim
+          && pointValues.getNumIndexDimensions() == config.numIndexDims;
       MergeState.DocMap docMap;
       if (docMaps == null) {
         docMap = null;
       } else {
         docMap = docMaps.get(i);
       }
-      MergeReader reader = new MergeReader(bkd, docMap);
+      MergeReader reader = new MergeReader(pointValues, docMap);
       if (reader.next()) {
         queue.add(reader);
       }
@@ -614,7 +654,7 @@ public class BKDWriter implements Closeable {
       MergeReader reader = queue.top();
       // System.out.println("iter reader=" + reader);
 
-      oneDimWriter.add(reader.state.scratchDataPackedValue, reader.docID);
+      oneDimWriter.add(reader.packedValue, reader.docID);
 
       if (reader.next()) {
         queue.updateTop();
@@ -1561,7 +1601,7 @@ public class BKDWriter implements Closeable {
   private void build(
       int leavesOffset,
       int numLeaves,
-      MutablePointValues reader,
+      MutablePointTree reader,
       int from,
       int to,
       IndexOutput out,
@@ -1626,7 +1666,7 @@ public class BKDWriter implements Closeable {
       }
 
       // sort by sortedDim
-      MutablePointsReaderUtils.sortByDim(
+      MutablePointTreeReaderUtils.sortByDim(
           config,
           sortedDim,
           commonPrefixLengths,
@@ -1722,7 +1762,7 @@ public class BKDWriter implements Closeable {
               maxPackedValue,
               splitDim * config.bytesPerDim);
 
-      MutablePointsReaderUtils.partition(
+      MutablePointTreeReaderUtils.partition(
           config,
           maxDoc,
           splitDim,
diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/MutablePointsReaderUtils.java b/lucene/core/src/java/org/apache/lucene/util/bkd/MutablePointTreeReaderUtils.java
similarity index 95%
rename from lucene/core/src/java/org/apache/lucene/util/bkd/MutablePointsReaderUtils.java
rename to lucene/core/src/java/org/apache/lucene/util/bkd/MutablePointTreeReaderUtils.java
index 24c8403..4a9a290 100644
--- a/lucene/core/src/java/org/apache/lucene/util/bkd/MutablePointsReaderUtils.java
+++ b/lucene/core/src/java/org/apache/lucene/util/bkd/MutablePointTreeReaderUtils.java
@@ -17,7 +17,7 @@
 package org.apache.lucene.util.bkd;
 
 import java.util.Arrays;
-import org.apache.lucene.codecs.MutablePointValues;
+import org.apache.lucene.codecs.MutablePointTree;
 import org.apache.lucene.util.ArrayUtil;
 import org.apache.lucene.util.ArrayUtil.ByteArrayComparator;
 import org.apache.lucene.util.BytesRef;
@@ -33,13 +33,12 @@ import org.apache.lucene.util.packed.PackedInts;
  *
  * @lucene.internal
  */
-public final class MutablePointsReaderUtils {
+public final class MutablePointTreeReaderUtils {
 
-  MutablePointsReaderUtils() {}
+  MutablePointTreeReaderUtils() {}
 
-  /** Sort the given {@link MutablePointValues} based on its packed value then doc ID. */
-  public static void sort(
-      BKDConfig config, int maxDoc, MutablePointValues reader, int from, int to) {
+  /** Sort the given {@link MutablePointTree} based on its packed value then doc ID. */
+  public static void sort(BKDConfig config, int maxDoc, MutablePointTree reader, int from, int to) {
 
     boolean sortedByDocID = true;
     int prevDoc = 0;
@@ -90,7 +89,7 @@ public final class MutablePointsReaderUtils {
       BKDConfig config,
       int sortedDim,
       int[] commonPrefixLengths,
-      MutablePointValues reader,
+      MutablePointTree reader,
       int from,
       int to,
       BytesRef scratch1,
@@ -149,7 +148,7 @@ public final class MutablePointsReaderUtils {
       int maxDoc,
       int splitDim,
       int commonPrefixLen,
-      MutablePointValues reader,
+      MutablePointTree reader,
       int from,
       int to,
       int mid,
diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene90/TestLucene90PointsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene90/TestLucene90PointsFormat.java
index fca1b83..f7284d0 100644
--- a/lucene/core/src/test/org/apache/lucene/codecs/lucene90/TestLucene90PointsFormat.java
+++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene90/TestLucene90PointsFormat.java
@@ -113,9 +113,11 @@ public class TestLucene90PointsFormat extends BasePointsFormatTestCase {
     final int numDocs =
         TEST_NIGHTLY ? atLeast(10000) : atLeast(500); // at night, make sure we have several leaves
     final boolean multiValues = random().nextBoolean();
+    int totalValues = 0;
     for (int i = 0; i < numDocs; ++i) {
       Document doc = new Document();
       if (i == numDocs / 2) {
+        totalValues++;
         doc.add(new BinaryPoint("f", uniquePointValue));
       } else {
         final int numValues = (multiValues) ? TestUtil.nextInt(random(), 2, 100) : 1;
@@ -124,6 +126,7 @@ public class TestLucene90PointsFormat extends BasePointsFormatTestCase {
             random().nextBytes(pointValue);
           } while (Arrays.equals(pointValue, uniquePointValue));
           doc.add(new BinaryPoint("f", pointValue));
+          totalValues++;
         }
       }
       w.addDocument(doc);
@@ -134,9 +137,6 @@ public class TestLucene90PointsFormat extends BasePointsFormatTestCase {
     final LeafReader lr = getOnlyLeafReader(r);
     PointValues points = lr.getPointValues("f");
 
-    // If all points match, then the point count is numLeaves * maxPointsInLeafNode
-    final int numLeaves = (int) Math.ceil((double) points.size() / maxPointsInLeafNode);
-
     IntersectVisitor allPointsVisitor =
         new IntersectVisitor() {
           @Override
@@ -151,7 +151,7 @@ public class TestLucene90PointsFormat extends BasePointsFormatTestCase {
           }
         };
 
-    assertEquals(numLeaves * maxPointsInLeafNode, points.estimatePointCount(allPointsVisitor));
+    assertEquals(totalValues, points.estimatePointCount(allPointsVisitor));
     assertEquals(numDocs, points.estimateDocCount(allPointsVisitor));
 
     IntersectVisitor noPointsVisitor =
@@ -193,11 +193,16 @@ public class TestLucene90PointsFormat extends BasePointsFormatTestCase {
     // If only one point matches, then the point count is (maxPointsInLeafNode + 1) / 2
     // in general, or maybe 2x that if the point is a split value
     final long pointCount = points.estimatePointCount(onePointMatchVisitor);
+    final long lastNodePointCount = totalValues % maxPointsInLeafNode;
     assertTrue(
         "" + pointCount,
-        pointCount == (maxPointsInLeafNode + 1) / 2
-            || // common case
-            pointCount == 2 * ((maxPointsInLeafNode + 1) / 2)); // if the point is a split value
+        pointCount == (maxPointsInLeafNode + 1) / 2 // common case
+            || pointCount == (lastNodePointCount + 1) / 2 // not fully populated leaf
+            || pointCount == 2 * ((maxPointsInLeafNode + 1) / 2) // if the point is a split value
+            || pointCount
+                == ((maxPointsInLeafNode + 1) / 2)
+                    + ((lastNodePointCount + 1)
+                        / 2)); // if the point is a split value and one leaf is not fully populated
 
     final long docCount = points.estimateDocCount(onePointMatchVisitor);
 
@@ -234,10 +239,12 @@ public class TestLucene90PointsFormat extends BasePointsFormatTestCase {
             ? atLeast(10000)
             : atLeast(1000); // in nightly, make sure we have several leaves
     final boolean multiValues = random().nextBoolean();
+    int totalValues = 0;
     for (int i = 0; i < numDocs; ++i) {
       Document doc = new Document();
       if (i == numDocs / 2) {
         doc.add(new BinaryPoint("f", uniquePointValue));
+        totalValues++;
       } else {
         final int numValues = (multiValues) ? TestUtil.nextInt(random(), 2, 100) : 1;
         for (int j = 0; j < numValues; j++) {
@@ -247,6 +254,7 @@ public class TestLucene90PointsFormat extends BasePointsFormatTestCase {
           } while (Arrays.equals(pointValue[0], uniquePointValue[0])
               || Arrays.equals(pointValue[1], uniquePointValue[1]));
           doc.add(new BinaryPoint("f", pointValue));
+          totalValues++;
         }
       }
       w.addDocument(doc);
@@ -271,10 +279,7 @@ public class TestLucene90PointsFormat extends BasePointsFormatTestCase {
           }
         };
 
-    // If all points match, then the point count is numLeaves * maxPointsInLeafNode
-    final int numLeaves = (int) Math.ceil((double) points.size() / maxPointsInLeafNode);
-
-    assertEquals(numLeaves * maxPointsInLeafNode, points.estimatePointCount(allPointsVisitor));
+    assertEquals(totalValues, points.estimatePointCount(allPointsVisitor));
     assertEquals(numDocs, points.estimateDocCount(allPointsVisitor));
 
     IntersectVisitor noPointsVisitor =
@@ -320,11 +325,16 @@ public class TestLucene90PointsFormat extends BasePointsFormatTestCase {
         };
 
     final long pointCount = points.estimatePointCount(onePointMatchVisitor);
-    // The number of matches needs to be multiple of count per leaf
-    final long countPerLeaf = (maxPointsInLeafNode + 1) / 2;
-    assertTrue("" + pointCount, pointCount % countPerLeaf == 0);
-    // in extreme cases, a point can be be shared by 4 leaves
-    assertTrue("" + pointCount, pointCount / countPerLeaf <= 4 && pointCount / countPerLeaf >= 1);
+    final long lastNodePointCount = totalValues % maxPointsInLeafNode;
+    assertTrue(
+        "" + pointCount,
+        pointCount == (maxPointsInLeafNode + 1) / 2 // common case
+            || pointCount == (lastNodePointCount + 1) / 2 // not fully populated leaf
+            || pointCount == 2 * ((maxPointsInLeafNode + 1) / 2) // if the point is a split value
+            || pointCount == ((maxPointsInLeafNode + 1) / 2) + ((lastNodePointCount + 1) / 2)
+            // in extreme cases, a point can be shared by 4 leaves
+            || pointCount == 4 * ((maxPointsInLeafNode + 1) / 2)
+            || pointCount == 3 * ((maxPointsInLeafNode + 1) / 2) + ((lastNodePointCount + 1) / 2));
 
     final long docCount = points.estimateDocCount(onePointMatchVisitor);
     if (multiValues) {
diff --git a/lucene/core/src/test/org/apache/lucene/util/TestDocIdSetBuilder.java b/lucene/core/src/test/org/apache/lucene/util/TestDocIdSetBuilder.java
index 4c1d7fa..1c7171a 100644
--- a/lucene/core/src/test/org/apache/lucene/util/TestDocIdSetBuilder.java
+++ b/lucene/core/src/test/org/apache/lucene/util/TestDocIdSetBuilder.java
@@ -308,12 +308,7 @@ public class TestDocIdSetBuilder extends LuceneTestCase {
     }
 
     @Override
-    public void intersect(IntersectVisitor visitor) throws IOException {
-      throw new UnsupportedOperationException();
-    }
-
-    @Override
-    public long estimatePointCount(IntersectVisitor visitor) {
+    public PointTree getPointTree() {
       throw new UnsupportedOperationException();
     }
 
diff --git a/lucene/core/src/test/org/apache/lucene/util/bkd/Test2BBKDPoints.java b/lucene/core/src/test/org/apache/lucene/util/bkd/Test2BBKDPoints.java
index 3667afb..09707fe 100644
--- a/lucene/core/src/test/org/apache/lucene/util/bkd/Test2BBKDPoints.java
+++ b/lucene/core/src/test/org/apache/lucene/util/bkd/Test2BBKDPoints.java
@@ -18,6 +18,7 @@ package org.apache.lucene.util.bkd;
 
 import com.carrotsearch.randomizedtesting.annotations.TimeoutSuite;
 import org.apache.lucene.index.CheckIndex;
+import org.apache.lucene.index.PointValues;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.store.FSDirectory;
 import org.apache.lucene.store.IOContext;
@@ -72,7 +73,7 @@ public class Test2BBKDPoints extends LuceneTestCase {
 
     IndexInput in = dir.openInput("1d.bkd", IOContext.DEFAULT);
     in.seek(indexFP);
-    BKDReader r = new BKDReader(in, in, in);
+    PointValues r = new BKDReader(in, in, in);
     CheckIndex.VerifyPointsVisitor visitor = new CheckIndex.VerifyPointsVisitor("1d", numDocs, r);
     r.intersect(visitor);
     assertEquals(r.size(), visitor.getPointCountSeen());
@@ -121,7 +122,7 @@ public class Test2BBKDPoints extends LuceneTestCase {
 
     IndexInput in = dir.openInput("2d.bkd", IOContext.DEFAULT);
     in.seek(indexFP);
-    BKDReader r = new BKDReader(in, in, in);
+    PointValues r = new BKDReader(in, in, in);
     CheckIndex.VerifyPointsVisitor visitor = new CheckIndex.VerifyPointsVisitor("2d", numDocs, r);
     r.intersect(visitor);
     assertEquals(r.size(), visitor.getPointCountSeen());
diff --git a/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKD.java b/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKD.java
index e1ff58e..3f4e6a0 100644
--- a/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKD.java
+++ b/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKD.java
@@ -23,7 +23,7 @@ import java.util.Arrays;
 import java.util.BitSet;
 import java.util.List;
 import java.util.Random;
-import org.apache.lucene.codecs.MutablePointValues;
+import org.apache.lucene.codecs.MutablePointTree;
 import org.apache.lucene.index.CorruptIndexException;
 import org.apache.lucene.index.MergeState;
 import org.apache.lucene.index.PointValues;
@@ -46,9 +46,14 @@ import org.apache.lucene.util.TestUtil;
 
 public class TestBKD extends LuceneTestCase {
 
+  protected PointValues getPointValues(IndexInput in) throws IOException {
+    return new BKDReader(in, in, in);
+  }
+
   public void testBasicInts1D() throws Exception {
+    final BKDConfig config = new BKDConfig(1, 1, 4, 2);
     try (Directory dir = getDirectory(100)) {
-      BKDWriter w = new BKDWriter(100, dir, "tmp", new BKDConfig(1, 1, 4, 2), 1.0f, 100);
+      BKDWriter w = new BKDWriter(100, dir, "tmp", config, 1.0f, 100);
       byte[] scratch = new byte[4];
       for (int docID = 0; docID < 100; docID++) {
         NumericUtils.intToSortableBytes(docID, scratch, 0);
@@ -64,63 +69,19 @@ public class TestBKD extends LuceneTestCase {
 
       try (IndexInput in = dir.openInput("bkd", IOContext.DEFAULT)) {
         in.seek(indexFP);
-        BKDReader r = new BKDReader(in, in, in);
+        PointValues r = getPointValues(in);
 
         // Simple 1D range query:
-        final int queryMin = 42;
-        final int queryMax = 87;
+        final byte[][] queryMin = new byte[1][4];
+        NumericUtils.intToSortableBytes(42, queryMin[0], 0);
+        final byte[][] queryMax = new byte[1][4];
+        NumericUtils.intToSortableBytes(87, queryMax[0], 0);
 
         final BitSet hits = new BitSet();
-        r.intersect(
-            new IntersectVisitor() {
-              @Override
-              public void visit(int docID) {
-                hits.set(docID);
-                if (VERBOSE) {
-                  System.out.println("visit docID=" + docID);
-                }
-              }
-
-              @Override
-              public void visit(int docID, byte[] packedValue) {
-                int x = NumericUtils.sortableBytesToInt(packedValue, 0);
-                if (VERBOSE) {
-                  System.out.println("visit docID=" + docID + " x=" + x);
-                }
-                if (x >= queryMin && x <= queryMax) {
-                  hits.set(docID);
-                }
-              }
-
-              @Override
-              public Relation compare(byte[] minPacked, byte[] maxPacked) {
-                int min = NumericUtils.sortableBytesToInt(minPacked, 0);
-                int max = NumericUtils.sortableBytesToInt(maxPacked, 0);
-                assert max >= min;
-                if (VERBOSE) {
-                  System.out.println(
-                      "compare: min="
-                          + min
-                          + " max="
-                          + max
-                          + " vs queryMin="
-                          + queryMin
-                          + " queryMax="
-                          + queryMax);
-                }
-
-                if (max < queryMin || min > queryMax) {
-                  return Relation.CELL_OUTSIDE_QUERY;
-                } else if (min >= queryMin && max <= queryMax) {
-                  return Relation.CELL_INSIDE_QUERY;
-                } else {
-                  return Relation.CELL_CROSSES_QUERY;
-                }
-              }
-            });
+        r.intersect(getIntersectVisitor(hits, queryMin, queryMax, config));
 
         for (int docID = 0; docID < 100; docID++) {
-          boolean expected = docID >= queryMin && docID <= queryMax;
+          boolean expected = docID >= 42 && docID <= 87;
           boolean actual = hits.get(docID);
           assertEquals("docID=" + docID, expected, actual);
         }
@@ -135,14 +96,8 @@ public class TestBKD extends LuceneTestCase {
       int numIndexDims = TestUtil.nextInt(random(), 1, numDims);
       int maxPointsInLeafNode = TestUtil.nextInt(random(), 50, 100);
       float maxMB = (float) 3.0 + (3 * random().nextFloat());
-      BKDWriter w =
-          new BKDWriter(
-              numDocs,
-              dir,
-              "tmp",
-              new BKDConfig(numDims, numIndexDims, 4, maxPointsInLeafNode),
-              maxMB,
-              numDocs);
+      BKDConfig config = new BKDConfig(numDims, numIndexDims, 4, maxPointsInLeafNode);
+      BKDWriter w = new BKDWriter(numDocs, dir, "tmp", config, maxMB, numDocs);
 
       if (VERBOSE) {
         System.out.println(
@@ -185,7 +140,7 @@ public class TestBKD extends LuceneTestCase {
 
       try (IndexInput in = dir.openInput("bkd", IOContext.DEFAULT)) {
         in.seek(indexFP);
-        BKDReader r = new BKDReader(in, in, in);
+        PointValues r = getPointValues(in);
 
         byte[] minPackedValue = r.getMinPackedValue();
         byte[] maxPackedValue = r.getMaxPackedValue();
@@ -204,7 +159,9 @@ public class TestBKD extends LuceneTestCase {
 
           // Random N dims rect query:
           int[] queryMin = new int[numDims];
+          byte[][] queryMinBytes = new byte[numDims][4];
           int[] queryMax = new int[numDims];
+          byte[][] queryMaxBytes = new byte[numDims][4];
           for (int dim = 0; dim < numIndexDims; dim++) {
             queryMin[dim] = random().nextInt();
             queryMax[dim] = random().nextInt();
@@ -213,54 +170,12 @@ public class TestBKD extends LuceneTestCase {
               queryMin[dim] = queryMax[dim];
               queryMax[dim] = x;
             }
+            NumericUtils.intToSortableBytes(queryMin[dim], queryMinBytes[dim], 0);
+            NumericUtils.intToSortableBytes(queryMax[dim], queryMaxBytes[dim], 0);
           }
 
           final BitSet hits = new BitSet();
-          r.intersect(
-              new IntersectVisitor() {
-                @Override
-                public void visit(int docID) {
-                  hits.set(docID);
-                  // System.out.println("visit docID=" + docID);
-                }
-
-                @Override
-                public void visit(int docID, byte[] packedValue) {
-                  // System.out.println("visit check docID=" + docID);
-                  for (int dim = 0; dim < numIndexDims; dim++) {
-                    int x = NumericUtils.sortableBytesToInt(packedValue, dim * Integer.BYTES);
-                    if (x < queryMin[dim] || x > queryMax[dim]) {
-                      // System.out.println("  no");
-                      return;
-                    }
-                  }
-
-                  // System.out.println("  yes");
-                  hits.set(docID);
-                }
-
-                @Override
-                public Relation compare(byte[] minPacked, byte[] maxPacked) {
-                  boolean crosses = false;
-                  for (int dim = 0; dim < numIndexDims; dim++) {
-                    int min = NumericUtils.sortableBytesToInt(minPacked, dim * Integer.BYTES);
-                    int max = NumericUtils.sortableBytesToInt(maxPacked, dim * Integer.BYTES);
-                    assert max >= min;
-
-                    if (max < queryMin[dim] || min > queryMax[dim]) {
-                      return Relation.CELL_OUTSIDE_QUERY;
-                    } else if (min < queryMin[dim] || max > queryMax[dim]) {
-                      crosses = true;
-                    }
-                  }
-
-                  if (crosses) {
-                    return Relation.CELL_CROSSES_QUERY;
-                  } else {
-                    return Relation.CELL_INSIDE_QUERY;
-                  }
-                }
-              });
+          r.intersect(getIntersectVisitor(hits, queryMinBytes, queryMaxBytes, config));
 
           for (int docID = 0; docID < numDocs; docID++) {
             int[] docValues = docs[docID];
@@ -289,14 +204,8 @@ public class TestBKD extends LuceneTestCase {
       int numDims = TestUtil.nextInt(random(), 1, 5);
       int maxPointsInLeafNode = TestUtil.nextInt(random(), 50, 100);
       float maxMB = (float) 3.0 + (3 * random().nextFloat());
-      BKDWriter w =
-          new BKDWriter(
-              numDocs,
-              dir,
-              "tmp",
-              new BKDConfig(numDims, numDims, numBytesPerDim, maxPointsInLeafNode),
-              maxMB,
-              numDocs);
+      BKDConfig config = new BKDConfig(numDims, numDims, numBytesPerDim, maxPointsInLeafNode);
+      BKDWriter w = new BKDWriter(numDocs, dir, "tmp", config, maxMB, numDocs);
       BigInteger[][] docs = new BigInteger[numDocs][];
 
       byte[] scratch = new byte[numBytesPerDim * numDims];
@@ -326,7 +235,7 @@ public class TestBKD extends LuceneTestCase {
 
       try (IndexInput in = dir.openInput("bkd", IOContext.DEFAULT)) {
         in.seek(indexFP);
-        BKDReader r = new BKDReader(in, in, in);
+        PointValues pointValues = getPointValues(in);
 
         int iters = atLeast(100);
         for (int iter = 0; iter < iters; iter++) {
@@ -336,7 +245,9 @@ public class TestBKD extends LuceneTestCase {
 
           // Random N dims rect query:
           BigInteger[] queryMin = new BigInteger[numDims];
+          byte[][] queryMinBytes = new byte[numDims][numBytesPerDim];
           BigInteger[] queryMax = new BigInteger[numDims];
+          byte[][] queryMaxBytes = new byte[numDims][numBytesPerDim];
           for (int dim = 0; dim < numDims; dim++) {
             queryMin[dim] = randomBigInt(numBytesPerDim);
             queryMax[dim] = randomBigInt(numBytesPerDim);
@@ -345,61 +256,14 @@ public class TestBKD extends LuceneTestCase {
               queryMin[dim] = queryMax[dim];
               queryMax[dim] = x;
             }
+            NumericUtils.bigIntToSortableBytes(
+                queryMin[dim], numBytesPerDim, queryMinBytes[dim], 0);
+            NumericUtils.bigIntToSortableBytes(
+                queryMax[dim], numBytesPerDim, queryMaxBytes[dim], 0);
           }
 
           final BitSet hits = new BitSet();
-          r.intersect(
-              new IntersectVisitor() {
-                @Override
-                public void visit(int docID) {
-                  hits.set(docID);
-                  // System.out.println("visit docID=" + docID);
-                }
-
-                @Override
-                public void visit(int docID, byte[] packedValue) {
-                  // System.out.println("visit check docID=" + docID);
-                  for (int dim = 0; dim < numDims; dim++) {
-                    BigInteger x =
-                        NumericUtils.sortableBytesToBigInt(
-                            packedValue, dim * numBytesPerDim, numBytesPerDim);
-                    if (x.compareTo(queryMin[dim]) < 0 || x.compareTo(queryMax[dim]) > 0) {
-                      // System.out.println("  no");
-                      return;
-                    }
-                  }
-
-                  // System.out.println("  yes");
-                  hits.set(docID);
-                }
-
-                @Override
-                public Relation compare(byte[] minPacked, byte[] maxPacked) {
-                  boolean crosses = false;
-                  for (int dim = 0; dim < numDims; dim++) {
-                    BigInteger min =
-                        NumericUtils.sortableBytesToBigInt(
-                            minPacked, dim * numBytesPerDim, numBytesPerDim);
-                    BigInteger max =
-                        NumericUtils.sortableBytesToBigInt(
-                            maxPacked, dim * numBytesPerDim, numBytesPerDim);
-                    assert max.compareTo(min) >= 0;
-
-                    if (max.compareTo(queryMin[dim]) < 0 || min.compareTo(queryMax[dim]) > 0) {
-                      return Relation.CELL_OUTSIDE_QUERY;
-                    } else if (min.compareTo(queryMin[dim]) < 0
-                        || max.compareTo(queryMax[dim]) > 0) {
-                      crosses = true;
-                    }
-                  }
-
-                  if (crosses) {
-                    return Relation.CELL_CROSSES_QUERY;
-                  } else {
-                    return Relation.CELL_INSIDE_QUERY;
-                  }
-                }
-              });
+          pointValues.intersect(getIntersectVisitor(hits, queryMinBytes, queryMaxBytes, config));
 
           for (int docID = 0; docID < numDocs; docID++) {
             BigInteger[] docValues = docs[docID];
@@ -925,10 +789,10 @@ public class TestBKD extends LuceneTestCase {
                 new BKDConfig(numDataDims, numIndexDims, numBytesPerDim, maxPointsInLeafNode),
                 maxMB,
                 docValues.length);
-        List<BKDReader> readers = new ArrayList<>();
+        List<PointValues> readers = new ArrayList<>();
         for (long fp : toMerge) {
           in.seek(fp);
-          readers.add(new BKDReader(in, in, in));
+          readers.add(getPointValues(in));
         }
         out = dir.createOutput("bkd2", IOContext.DEFAULT);
         Runnable finalizer = w.merge(out, out, out, docMaps, readers);
@@ -946,7 +810,9 @@ public class TestBKD extends LuceneTestCase {
       }
 
       in.seek(indexFP);
-      BKDReader r = new BKDReader(in, in, in);
+      PointValues pointValues = getPointValues(in);
+
+      assertSize(pointValues.getPointTree());
 
       int iters = atLeast(100);
       for (int iter = 0; iter < iters; iter++) {
@@ -971,115 +837,6 @@ public class TestBKD extends LuceneTestCase {
           }
         }
 
-        final BitSet hits = new BitSet();
-        r.intersect(
-            new IntersectVisitor() {
-              @Override
-              public void visit(int docID) {
-                hits.set(docID);
-                // System.out.println("visit docID=" + docID);
-              }
-
-              @Override
-              public void visit(int docID, byte[] packedValue) {
-                // System.out.println("visit check docID=" + docID);
-                for (int dim = 0; dim < numIndexDims; dim++) {
-                  if (Arrays.compareUnsigned(
-                              packedValue,
-                              dim * numBytesPerDim,
-                              dim * numBytesPerDim + numBytesPerDim,
-                              queryMin[dim],
-                              0,
-                              numBytesPerDim)
-                          < 0
-                      || Arrays.compareUnsigned(
-                              packedValue,
-                              dim * numBytesPerDim,
-                              dim * numBytesPerDim + numBytesPerDim,
-                              queryMax[dim],
-                              0,
-                              numBytesPerDim)
-                          > 0) {
-                    // System.out.println("  no");
-                    return;
-                  }
-                }
-
-                // System.out.println("  yes");
-                hits.set(docID);
-              }
-
-              @Override
-              public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException {
-                if (random().nextBoolean()) {
-                  // check the default method is correct
-                  IntersectVisitor.super.visit(iterator, packedValue);
-                } else {
-                  assertEquals(iterator.docID(), -1);
-                  int cost = Math.toIntExact(iterator.cost());
-                  int numberOfPoints = 0;
-                  int docID;
-                  while ((docID = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
-                    assertEquals(iterator.docID(), docID);
-                    visit(docID, packedValue);
-                    numberOfPoints++;
-                  }
-                  assertEquals(cost, numberOfPoints);
-                  assertEquals(iterator.docID(), DocIdSetIterator.NO_MORE_DOCS);
-                  assertEquals(iterator.nextDoc(), DocIdSetIterator.NO_MORE_DOCS);
-                  assertEquals(iterator.docID(), DocIdSetIterator.NO_MORE_DOCS);
-                }
-              }
-
-              @Override
-              public Relation compare(byte[] minPacked, byte[] maxPacked) {
-                boolean crosses = false;
-                for (int dim = 0; dim < numIndexDims; dim++) {
-                  if (Arrays.compareUnsigned(
-                              maxPacked,
-                              dim * numBytesPerDim,
-                              dim * numBytesPerDim + numBytesPerDim,
-                              queryMin[dim],
-                              0,
-                              numBytesPerDim)
-                          < 0
-                      || Arrays.compareUnsigned(
-                              minPacked,
-                              dim * numBytesPerDim,
-                              dim * numBytesPerDim + numBytesPerDim,
-                              queryMax[dim],
-                              0,
-                              numBytesPerDim)
-                          > 0) {
-                    return Relation.CELL_OUTSIDE_QUERY;
-                  } else if (Arrays.compareUnsigned(
-                              minPacked,
-                              dim * numBytesPerDim,
-                              dim * numBytesPerDim + numBytesPerDim,
-                              queryMin[dim],
-                              0,
-                              numBytesPerDim)
-                          < 0
-                      || Arrays.compareUnsigned(
-                              maxPacked,
-                              dim * numBytesPerDim,
-                              dim * numBytesPerDim + numBytesPerDim,
-                              queryMax[dim],
-                              0,
-                              numBytesPerDim)
-                          > 0) {
-                    crosses = true;
-                  }
-                }
-
-                if (crosses) {
-                  return Relation.CELL_CROSSES_QUERY;
-                } else {
-                  return Relation.CELL_INSIDE_QUERY;
-                }
-              }
-            });
-
         BitSet expected = new BitSet();
         for (int ord = 0; ord < numValues; ord++) {
           boolean matches = true;
@@ -1104,10 +861,17 @@ public class TestBKD extends LuceneTestCase {
           }
         }
 
-        int limit = Math.max(expected.length(), hits.length());
-        for (int docID = 0; docID < limit; docID++) {
-          assertEquals("docID=" + docID, expected.get(docID), hits.get(docID));
-        }
+        BKDConfig config =
+            new BKDConfig(numDataDims, numIndexDims, numBytesPerDim, maxPointsInLeafNode);
+        final BitSet hits = new BitSet();
+        pointValues.intersect(getIntersectVisitor(hits, queryMin, queryMax, config));
+        assertHits(hits, expected);
+
+        hits.clear();
+        pointValues
+            .getPointTree()
+            .visitDocValues(getIntersectVisitor(hits, queryMin, queryMax, config));
+        assertHits(hits, expected);
       }
       in.close();
       dir.deleteFile("bkd");
@@ -1123,6 +887,152 @@ public class TestBKD extends LuceneTestCase {
     }
   }
 
+  private void assertSize(PointValues.PointTree tree) throws IOException {
+    final PointValues.PointTree clone = tree.clone();
+    assertEquals(clone.size(), tree.size());
+    final long[] size = new long[] {0};
+    clone.visitDocIDs(
+        new IntersectVisitor() {
+          @Override
+          public void visit(int docID) {
+            size[0]++;
+          }
+
+          @Override
+          public void visit(int docID, byte[] packedValue) {
+            throw new UnsupportedOperationException();
+          }
+
+          @Override
+          public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
+            throw new UnsupportedOperationException();
+          }
+        });
+    assertEquals(size[0], tree.size());
+    if (tree.moveToChild()) {
+      do {
+        assertSize(tree);
+      } while (tree.moveToSibling());
+    }
+  }
+
+  private void assertHits(BitSet hits, BitSet expected) {
+    int limit = Math.max(expected.length(), hits.length());
+    for (int docID = 0; docID < limit; docID++) {
+      assertEquals("docID=" + docID, expected.get(docID), hits.get(docID));
+    }
+  }
+
+  private IntersectVisitor getIntersectVisitor(
+      BitSet hits, byte[][] queryMin, byte[][] queryMax, BKDConfig config) {
+    return new IntersectVisitor() {
+      @Override
+      public void visit(int docID) {
+        hits.set(docID);
+        // System.out.println("visit docID=" + docID);
+      }
+
+      @Override
+      public void visit(int docID, byte[] packedValue) {
+        // System.out.println("visit check docID=" + docID);
+        for (int dim = 0; dim < config.numIndexDims; dim++) {
+          if (Arrays.compareUnsigned(
+                      packedValue,
+                      dim * config.bytesPerDim,
+                      dim * config.bytesPerDim + config.bytesPerDim,
+                      queryMin[dim],
+                      0,
+                      config.bytesPerDim)
+                  < 0
+              || Arrays.compareUnsigned(
+                      packedValue,
+                      dim * config.bytesPerDim,
+                      dim * config.bytesPerDim + config.bytesPerDim,
+                      queryMax[dim],
+                      0,
+                      config.bytesPerDim)
+                  > 0) {
+            // System.out.println("  no");
+            return;
+          }
+        }
+
+        // System.out.println("  yes");
+        hits.set(docID);
+      }
+
+      @Override
+      public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException {
+        if (random().nextBoolean()) {
+          // check the default method is correct
+          IntersectVisitor.super.visit(iterator, packedValue);
+        } else {
+          assertEquals(iterator.docID(), -1);
+          int cost = Math.toIntExact(iterator.cost());
+          int numberOfPoints = 0;
+          int docID;
+          while ((docID = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
+            assertEquals(iterator.docID(), docID);
+            visit(docID, packedValue);
+            numberOfPoints++;
+          }
+          assertEquals(cost, numberOfPoints);
+          assertEquals(iterator.docID(), DocIdSetIterator.NO_MORE_DOCS);
+          assertEquals(iterator.nextDoc(), DocIdSetIterator.NO_MORE_DOCS);
+          assertEquals(iterator.docID(), DocIdSetIterator.NO_MORE_DOCS);
+        }
+      }
+
+      @Override
+      public Relation compare(byte[] minPacked, byte[] maxPacked) {
+        boolean crosses = false;
+        for (int dim = 0; dim < config.numIndexDims; dim++) {
+          if (Arrays.compareUnsigned(
+                      maxPacked,
+                      dim * config.bytesPerDim,
+                      dim * config.bytesPerDim + config.bytesPerDim,
+                      queryMin[dim],
+                      0,
+                      config.bytesPerDim)
+                  < 0
+              || Arrays.compareUnsigned(
+                      minPacked,
+                      dim * config.bytesPerDim,
+                      dim * config.bytesPerDim + config.bytesPerDim,
+                      queryMax[dim],
+                      0,
+                      config.bytesPerDim)
+                  > 0) {
+            return Relation.CELL_OUTSIDE_QUERY;
+          } else if (Arrays.compareUnsigned(
+                      minPacked,
+                      dim * config.bytesPerDim,
+                      dim * config.bytesPerDim + config.bytesPerDim,
+                      queryMin[dim],
+                      0,
+                      config.bytesPerDim)
+                  < 0
+              || Arrays.compareUnsigned(
+                      maxPacked,
+                      dim * config.bytesPerDim,
+                      dim * config.bytesPerDim + config.bytesPerDim,
+                      queryMax[dim],
+                      0,
+                      config.bytesPerDim)
+                  > 0) {
+            crosses = true;
+          }
+        }
+
+        if (crosses) {
+          return Relation.CELL_CROSSES_QUERY;
+        } else {
+          return Relation.CELL_INSIDE_QUERY;
+        }
+      }
+    };
+  }
+
   private BigInteger randomBigInt(int numBytes) {
     BigInteger x = new BigInteger(numBytes * 8 - 1, random());
     if (random().nextBoolean()) {
@@ -1283,7 +1193,7 @@ public class TestBKD extends LuceneTestCase {
 
       IndexInput in = dir.openInput("bkd", IOContext.DEFAULT);
       in.seek(fp);
-      BKDReader r = new BKDReader(in, in, in);
+      PointValues r = getPointValues(in);
       r.intersect(
           new IntersectVisitor() {
             int lastDocID = -1;
@@ -1350,7 +1260,7 @@ public class TestBKD extends LuceneTestCase {
 
     IndexInput pointsIn = dir.openInput("bkd", IOContext.DEFAULT);
     pointsIn.seek(indexFP);
-    BKDReader points = new BKDReader(pointsIn, pointsIn, pointsIn);
+    PointValues points = getPointValues(pointsIn);
 
     points.intersect(
         new IntersectVisitor() {
@@ -1411,7 +1321,7 @@ public class TestBKD extends LuceneTestCase {
 
       IndexInput in = dir.openInput("bkd", IOContext.DEFAULT);
       in.seek(fp);
-      BKDReader r = new BKDReader(in, in, in);
+      PointValues r = getPointValues(in);
       int[] count = new int[1];
       r.intersect(
           new IntersectVisitor() {
@@ -1478,7 +1388,7 @@ public class TestBKD extends LuceneTestCase {
 
     IndexInput in = dir.openInput("bkd", IOContext.DEFAULT);
     in.seek(fp);
-    BKDReader r = new BKDReader(in, in, in);
+    PointValues r = getPointValues(in);
     int[] count = new int[1];
     r.intersect(
         new IntersectVisitor() {
@@ -1547,15 +1457,11 @@ public class TestBKD extends LuceneTestCase {
 
     IndexInput pointsIn = dir.openInput("bkd", IOContext.DEFAULT);
     pointsIn.seek(indexFP);
-    BKDReader points = new BKDReader(pointsIn, pointsIn, pointsIn);
+    PointValues points = getPointValues(pointsIn);
 
-    // If all points match, then the point count is numLeaves * maxPointsInLeafNode
-    int numLeaves = numValues / maxPointsInLeafNode;
-    if (numValues % maxPointsInLeafNode != 0) {
-      numLeaves++;
-    }
+    // If all points match, then the point count is numValues
     assertEquals(
-        numLeaves * maxPointsInLeafNode,
+        numValues,
         points.estimatePointCount(
             new IntersectVisitor() {
               @Override
@@ -1611,11 +1517,16 @@ public class TestBKD extends LuceneTestCase {
                 return Relation.CELL_CROSSES_QUERY;
               }
             });
+    long lastNodePointCount = numValues % maxPointsInLeafNode;
     assertTrue(
         "" + pointCount,
-        pointCount == (maxPointsInLeafNode + 1) / 2
-            || // common case
-            pointCount == 2 * ((maxPointsInLeafNode + 1) / 2)); // if the point is a split value
+        pointCount == (maxPointsInLeafNode + 1) / 2 // common case
+            || pointCount == (lastNodePointCount + 1) / 2 // not fully populated leaf
+            || pointCount == 2 * ((maxPointsInLeafNode + 1) / 2) // if the point is a split value
+            || pointCount
+                == ((maxPointsInLeafNode + 1) / 2)
+                    + ((lastNodePointCount + 1)
+                        / 2)); // if the point is a split value and one leaf is not fully populated
 
     pointsIn.close();
     dir.close();
@@ -1629,55 +1540,8 @@ public class TestBKD extends LuceneTestCase {
     final byte[] pointValue = new byte[numBytesPerDim];
     random().nextBytes(pointValue);
 
-    MutablePointValues reader =
-        new MutablePointValues() {
-
-          @Override
-          public void intersect(IntersectVisitor visitor) throws IOException {
-            for (int i = 0; i < numPointsAdded; i++) {
-              visitor.visit(0, pointValue);
-            }
-          }
-
-          @Override
-          public long estimatePointCount(IntersectVisitor visitor) {
-            throw new UnsupportedOperationException();
-          }
-
-          @Override
-          public byte[] getMinPackedValue() {
-            throw new UnsupportedOperationException();
-          }
-
-          @Override
-          public byte[] getMaxPackedValue() {
-            throw new UnsupportedOperationException();
-          }
-
-          @Override
-          public int getNumDimensions() {
-            throw new UnsupportedOperationException();
-          }
-
-          @Override
-          public int getNumIndexDimensions() {
-            throw new UnsupportedOperationException();
-          }
-
-          @Override
-          public int getBytesPerDimension() {
-            throw new UnsupportedOperationException();
-          }
-
-          @Override
-          public long size() {
-            return numPointsAdded;
-          }
-
-          @Override
-          public int getDocCount() {
-            return numPointsAdded;
-          }
+    MutablePointTree reader =
+        new MutablePointTree() {
 
           @Override
           public void swap(int i, int j) {
@@ -1710,6 +1574,18 @@ public class TestBKD extends LuceneTestCase {
           public void restore(int i, int j) {
             throw new UnsupportedOperationException();
           }
+
+          @Override
+          public long size() {
+            return numPointsAdded;
+          }
+
+          @Override
+          public void visitDocValues(IntersectVisitor visitor) throws IOException {
+            for (int i = 0; i < numPointsAdded; i++) {
+              visitor.visit(0, pointValue);
+            }
+          }
         };
 
     BKDWriter w =
@@ -1779,8 +1655,8 @@ public class TestBKD extends LuceneTestCase {
     for (int i = 0; i < numValues + 1; i++) {
       random().nextBytes(pointValue[i]);
     }
-    MutablePointValues val =
-        new MutablePointValues() {
+    MutablePointTree val =
+        new MutablePointTree() {
           @Override
           public void getValue(int i, BytesRef packedValue) {
             packedValue.bytes = pointValue[i];
@@ -1806,40 +1682,13 @@ public class TestBKD extends LuceneTestCase {
           }
 
           @Override
-          public void intersect(IntersectVisitor visitor) throws IOException {
-            for (int i = 0; i < size(); i++) {
-              visitor.visit(i, pointValue[i]);
-            }
-          }
-
-          @Override
-          public long estimatePointCount(IntersectVisitor visitor) {
-            return 11;
-          }
-
-          @Override
-          public byte[] getMinPackedValue() {
-            return new byte[numBytesPerDim];
-          }
-
-          @Override
-          public byte[] getMaxPackedValue() {
-            return new byte[numBytesPerDim];
-          }
-
-          @Override
-          public int getNumDimensions() {
-            return 1;
-          }
-
-          @Override
-          public int getNumIndexDimensions() {
-            return 1;
+          public void save(int i, int j) {
+            throw new UnsupportedOperationException();
           }
 
           @Override
-          public int getBytesPerDimension() {
-            return numBytesPerDim;
+          public void restore(int i, int j) {
+            throw new UnsupportedOperationException();
           }
 
           @Override
@@ -1848,18 +1697,10 @@ public class TestBKD extends LuceneTestCase {
           }
 
           @Override
-          public int getDocCount() {
-            return 11;
-          }
-
-          @Override
-          public void save(int i, int j) {
-            throw new UnsupportedOperationException();
-          }
-
-          @Override
-          public void restore(int i, int j) {
-            throw new UnsupportedOperationException();
+          public void visitDocValues(IntersectVisitor visitor) throws IOException {
+            for (int i = 0; i < size(); i++) {
+              visitor.visit(i, pointValue[i]);
+            }
           }
         };
     try (IndexOutput out = dir.createOutput("bkd", IOContext.DEFAULT)) {
diff --git a/lucene/core/src/test/org/apache/lucene/util/bkd/TestMutablePointsReaderUtils.java b/lucene/core/src/test/org/apache/lucene/util/bkd/TestMutablePointTreeReaderUtils.java
similarity index 90%
rename from lucene/core/src/test/org/apache/lucene/util/bkd/TestMutablePointsReaderUtils.java
rename to lucene/core/src/test/org/apache/lucene/util/bkd/TestMutablePointTreeReaderUtils.java
index 4419156..e25824a 100644
--- a/lucene/core/src/test/org/apache/lucene/util/bkd/TestMutablePointsReaderUtils.java
+++ b/lucene/core/src/test/org/apache/lucene/util/bkd/TestMutablePointTreeReaderUtils.java
@@ -16,16 +16,16 @@
  */
 package org.apache.lucene.util.bkd;
 
-import java.io.IOException;
 import java.util.Arrays;
 import java.util.Comparator;
-import org.apache.lucene.codecs.MutablePointValues;
+import org.apache.lucene.codecs.MutablePointTree;
+import org.apache.lucene.index.PointValues;
 import org.apache.lucene.util.ArrayUtil;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.LuceneTestCase;
 import org.apache.lucene.util.TestUtil;
 
-public class TestMutablePointsReaderUtils extends LuceneTestCase {
+public class TestMutablePointTreeReaderUtils extends LuceneTestCase {
 
   public void testSort() {
     for (int iter = 0; iter < 10; ++iter) {
@@ -45,7 +45,7 @@ public class TestMutablePointsReaderUtils extends LuceneTestCase {
     BKDConfig config = new BKDConfig(1, 1, bytesPerDim, BKDConfig.DEFAULT_MAX_POINTS_IN_LEAF_NODE);
     Point[] points = createRandomPoints(config, maxDoc, new int[1], isDocIdIncremental);
     DummyPointsReader reader = new DummyPointsReader(points);
-    MutablePointsReaderUtils.sort(config, maxDoc, reader, 0, points.length);
+    MutablePointTreeReaderUtils.sort(config, maxDoc, reader, 0, points.length);
     Arrays.sort(
         points,
         new Comparator<Point>() {
@@ -91,7 +91,7 @@ public class TestMutablePointsReaderUtils extends LuceneTestCase {
     Point[] points = createRandomPoints(config, maxDoc, commonPrefixLengths, false);
     DummyPointsReader reader = new DummyPointsReader(points);
     final int sortedDim = random().nextInt(config.numIndexDims);
-    MutablePointsReaderUtils.sortByDim(
+    MutablePointTreeReaderUtils.sortByDim(
         config,
         sortedDim,
         commonPrefixLengths,
@@ -145,7 +145,7 @@ public class TestMutablePointsReaderUtils extends LuceneTestCase {
     final int splitDim = random().nextInt(config.numIndexDims);
     DummyPointsReader reader = new DummyPointsReader(points);
     final int pivot = TestUtil.nextInt(random(), 0, points.length - 1);
-    MutablePointsReaderUtils.partition(
+    MutablePointTreeReaderUtils.partition(
         config,
         maxDoc,
         splitDim,
@@ -303,7 +303,7 @@ public class TestMutablePointsReaderUtils extends LuceneTestCase {
     }
   }
 
-  private static class DummyPointsReader extends MutablePointValues {
+  private static class DummyPointsReader extends MutablePointTree {
 
     private final Point[] points;
 
@@ -337,51 +337,6 @@ public class TestMutablePointsReaderUtils extends LuceneTestCase {
     }
 
     @Override
-    public void intersect(IntersectVisitor visitor) throws IOException {
-      throw new UnsupportedOperationException();
-    }
-
-    @Override
-    public long estimatePointCount(IntersectVisitor visitor) {
-      throw new UnsupportedOperationException();
-    }
-
-    @Override
-    public byte[] getMinPackedValue() throws IOException {
-      throw new UnsupportedOperationException();
-    }
-
-    @Override
-    public byte[] getMaxPackedValue() throws IOException {
-      throw new UnsupportedOperationException();
-    }
-
-    @Override
-    public int getNumDimensions() throws IOException {
-      throw new UnsupportedOperationException();
-    }
-
-    @Override
-    public int getNumIndexDimensions() throws IOException {
-      throw new UnsupportedOperationException();
-    }
-
-    @Override
-    public int getBytesPerDimension() throws IOException {
-      throw new UnsupportedOperationException();
-    }
-
-    @Override
-    public long size() {
-      throw new UnsupportedOperationException();
-    }
-
-    @Override
-    public int getDocCount() {
-      throw new UnsupportedOperationException();
-    }
-
-    @Override
     public void save(int i, int j) {
       if (temp == null) {
         temp = new Point[points.length];
@@ -395,5 +350,15 @@ public class TestMutablePointsReaderUtils extends LuceneTestCase {
         System.arraycopy(temp, i, points, i, j - i);
       }
     }
+
+    @Override
+    public long size() {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public void visitDocValues(PointValues.IntersectVisitor visitor) {
+      throw new UnsupportedOperationException();
+    }
   }
 }
diff --git a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java
index 5d67edb..cbc2114 100644
--- a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java
+++ b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java
@@ -1672,18 +1672,61 @@ public class MemoryIndex {
       }
 
       @Override
-      public void intersect(IntersectVisitor visitor) throws IOException {
-        BytesRef[] values = info.pointValues;
+      public PointTree getPointTree() {
+        return new PointTree() {
+          @Override
+          public PointTree clone() {
+            return this;
+          }
 
-        visitor.grow(info.pointValuesCount);
-        for (int i = 0; i < info.pointValuesCount; i++) {
-          visitor.visit(0, values[i].bytes);
-        }
-      }
+          @Override
+          public boolean moveToChild() {
+            return false;
+          }
 
-      @Override
-      public long estimatePointCount(IntersectVisitor visitor) {
-        return 1L;
+          @Override
+          public boolean moveToSibling() {
+            return false;
+          }
+
+          @Override
+          public boolean moveToParent() {
+            return false;
+          }
+
+          @Override
+          public byte[] getMinPackedValue() {
+            return info.minPackedValue;
+          }
+
+          @Override
+          public byte[] getMaxPackedValue() {
+            return info.maxPackedValue;
+          }
+
+          @Override
+          public long size() {
+            return info.pointValuesCount;
+          }
+
+          @Override
+          public void visitDocIDs(IntersectVisitor visitor) throws IOException {
+            visitor.grow(info.pointValuesCount);
+            for (int i = 0; i < info.pointValuesCount; i++) {
+              visitor.visit(0);
+            }
+          }
+
+          @Override
+          public void visitDocValues(IntersectVisitor visitor) throws IOException {
+            BytesRef[] values = info.pointValues;
+
+            visitor.grow(info.pointValuesCount);
+            for (int i = 0; i < info.pointValuesCount; i++) {
+              visitor.visit(0, values[i].bytes);
+            }
+          }
+        };
       }
 
       @Override
diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/document/FloatPointNearestNeighbor.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/document/FloatPointNearestNeighbor.java
index 675e410..0e0922f 100644
--- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/document/FloatPointNearestNeighbor.java
+++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/document/FloatPointNearestNeighbor.java
@@ -24,14 +24,13 @@ import java.util.PriorityQueue;
 import org.apache.lucene.document.FloatPoint;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.PointValues;
+import org.apache.lucene.index.PointValues.PointTree;
 import org.apache.lucene.search.FieldDoc;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TopFieldDocs;
 import org.apache.lucene.search.TotalHits;
 import org.apache.lucene.util.Bits;
-import org.apache.lucene.util.BytesRef;
-import org.apache.lucene.util.bkd.BKDReader;
 
 /**
  * KNN search on top of N dimensional indexed float points.
@@ -44,12 +43,12 @@ public class FloatPointNearestNeighbor {
     final int readerIndex;
     final byte[] minPacked;
     final byte[] maxPacked;
-    final BKDReader.IndexTree index;
+    final PointTree index;
     /** The closest possible distance^2 of all points in this cell */
     final double distanceSquared;
 
     Cell(
-        BKDReader.IndexTree index,
+        PointTree index,
         int readerIndex,
         byte[] minPacked,
         byte[] maxPacked,
@@ -70,10 +69,8 @@ public class FloatPointNearestNeighbor {
     public String toString() {
       return "Cell(readerIndex="
           + readerIndex
-          + " nodeID="
-          + index.getNodeID()
-          + " isLeaf="
-          + index.isLeafNode()
+          + " "
+          + index.toString()
           + " distanceSquared="
           + distanceSquared
           + ")";
@@ -176,7 +173,7 @@ public class FloatPointNearestNeighbor {
   }
 
   private static NearestHit[] nearest(
-      List<BKDReader> readers,
+      List<PointValues> readers,
       List<Bits> liveDocs,
       List<Integer> docBases,
       final int topN,
@@ -201,31 +198,17 @@ public class FloatPointNearestNeighbor {
     PriorityQueue<Cell> cellQueue = new PriorityQueue<>();
 
     NearestVisitor visitor = new NearestVisitor(hitQueue, topN, origin);
-    List<BKDReader.IntersectState> states = new ArrayList<>();
 
     // Add root cell for each reader into the queue:
-    int bytesPerDim = -1;
-
     for (int i = 0; i < readers.size(); ++i) {
-      BKDReader reader = readers.get(i);
-      if (bytesPerDim == -1) {
-        bytesPerDim = reader.getBytesPerDimension();
-      } else if (bytesPerDim != reader.getBytesPerDimension()) {
-        throw new IllegalStateException(
-            "bytesPerDim changed from "
-                + bytesPerDim
-                + " to "
-                + reader.getBytesPerDimension()
-                + " across readers");
-      }
+      PointValues reader = readers.get(i);
       byte[] minPackedValue = reader.getMinPackedValue();
       byte[] maxPackedValue = reader.getMaxPackedValue();
-      BKDReader.IntersectState state = reader.getIntersectState(visitor);
-      states.add(state);
+      PointTree indexTree = reader.getPointTree();
 
       cellQueue.offer(
           new Cell(
-              state.index,
+              indexTree,
               i,
               reader.getMinPackedValue(),
               reader.getMaxPackedValue(),
@@ -240,57 +223,47 @@ public class FloatPointNearestNeighbor {
         break;
       }
 
-      BKDReader reader = readers.get(cell.readerIndex);
-      if (cell.index.isLeafNode()) {
+      if (cell.index.moveToChild() == false) {
         // System.out.println("    leaf");
         // Leaf block: visit all points and possibly collect them:
         visitor.curDocBase = docBases.get(cell.readerIndex);
         visitor.curLiveDocs = liveDocs.get(cell.readerIndex);
-        reader.visitLeafBlockValues(cell.index, states.get(cell.readerIndex));
+        cell.index.visitDocValues(visitor);
+        // reader.visitLeafBlockValues(cell.index, states.get(cell.readerIndex));
 
         // assert hitQueue.peek().distanceSquared >= cell.distanceSquared;
         // System.out.println("    now " + hitQueue.size() + " hits");
       } else {
-        // System.out.println("    non-leaf");
-        // Non-leaf block: split into two cells and put them back into the queue:
-
-        BytesRef splitValue = BytesRef.deepCopyOf(cell.index.getSplitDimValue());
-        int splitDim = cell.index.getSplitDim();
 
         // we must clone the index so that we we can recurse left and right "concurrently":
-        BKDReader.IndexTree newIndex = cell.index.clone();
-        byte[] splitPackedValue = cell.maxPacked.clone();
-        System.arraycopy(
-            splitValue.bytes,
-            splitValue.offset,
-            splitPackedValue,
-            splitDim * bytesPerDim,
-            bytesPerDim);
-
-        cell.index.pushLeft();
+        PointTree newIndex = cell.index.clone();
+
         double distanceLeft =
-            pointToRectangleDistanceSquared(cell.minPacked, splitPackedValue, origin);
+            pointToRectangleDistanceSquared(
+                newIndex.getMinPackedValue(), newIndex.getMaxPackedValue(), origin);
         if (distanceLeft <= visitor.bottomNearestDistanceSquared) {
           cellQueue.offer(
               new Cell(
-                  cell.index, cell.readerIndex, cell.minPacked, splitPackedValue, distanceLeft));
+                  newIndex,
+                  cell.readerIndex,
+                  newIndex.getMinPackedValue(),
+                  newIndex.getMaxPackedValue(),
+                  distanceLeft));
         }
 
-        splitPackedValue = cell.minPacked.clone();
-        System.arraycopy(
-            splitValue.bytes,
-            splitValue.offset,
-            splitPackedValue,
-            splitDim * bytesPerDim,
-            bytesPerDim);
-
-        newIndex.pushRight();
-        double distanceRight =
-            pointToRectangleDistanceSquared(splitPackedValue, cell.maxPacked, origin);
-        if (distanceRight <= visitor.bottomNearestDistanceSquared) {
-          cellQueue.offer(
-              new Cell(
-                  newIndex, cell.readerIndex, splitPackedValue, cell.maxPacked, distanceRight));
+        if (cell.index.moveToSibling()) {
+          double distanceRight =
+              pointToRectangleDistanceSquared(
+                  cell.index.getMinPackedValue(), cell.index.getMaxPackedValue(), origin);
+          if (distanceRight <= visitor.bottomNearestDistanceSquared) {
+            cellQueue.offer(
+                new Cell(
+                    cell.index,
+                    cell.readerIndex,
+                    cell.index.getMinPackedValue(),
+                    cell.index.getMaxPackedValue(),
+                    distanceRight));
+          }
         }
       }
     }
@@ -335,19 +308,15 @@ public class FloatPointNearestNeighbor {
     if (searcher == null) {
       throw new IllegalArgumentException("searcher must not be null");
     }
-    List<BKDReader> readers = new ArrayList<>();
+    List<PointValues> readers = new ArrayList<>();
     List<Integer> docBases = new ArrayList<>();
     List<Bits> liveDocs = new ArrayList<>();
     int totalHits = 0;
     for (LeafReaderContext leaf : searcher.getIndexReader().leaves()) {
       PointValues points = leaf.reader().getPointValues(field);
       if (points != null) {
-        if (points instanceof BKDReader == false) {
-          throw new IllegalArgumentException(
-              "can only run on Lucene60PointsReader points implementation, but got " + points);
-        }
         totalHits += points.getDocCount();
-        readers.add((BKDReader) points);
+        readers.add(points);
         docBases.add(leaf.docBase);
         liveDocs.add(leaf.reader().getLiveDocs());
       }
diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/LatLonPointPrototypeQueries.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/LatLonPointPrototypeQueries.java
index e7f1eb6..38cf543 100644
--- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/LatLonPointPrototypeQueries.java
+++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/LatLonPointPrototypeQueries.java
@@ -33,7 +33,6 @@ import org.apache.lucene.search.TopFieldDocs;
 import org.apache.lucene.search.TotalHits;
 import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.SloppyMath;
-import org.apache.lucene.util.bkd.BKDReader;
 
 /**
  * Holder class for prototype sandboxed queries
@@ -89,20 +88,15 @@ public class LatLonPointPrototypeQueries {
     if (searcher == null) {
       throw new IllegalArgumentException("searcher must not be null");
     }
-    List<BKDReader> readers = new ArrayList<>();
+    List<PointValues> readers = new ArrayList<>();
     List<Integer> docBases = new ArrayList<>();
     List<Bits> liveDocs = new ArrayList<>();
     int totalHits = 0;
     for (LeafReaderContext leaf : searcher.getIndexReader().leaves()) {
       PointValues points = leaf.reader().getPointValues(field);
       if (points != null) {
-        if (points instanceof BKDReader == false) {
-          throw new IllegalArgumentException(
-              "can only run on Lucene60PointsReader points implementation, but got " + points);
-        }
         totalHits += points.getDocCount();
-        BKDReader reader = (BKDReader) points;
-        readers.add(reader);
+        readers.add(points);
         docBases.add(leaf.docBase);
         liveDocs.add(leaf.reader().getLiveDocs());
       }
diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/NearestNeighbor.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/NearestNeighbor.java
index 1059b44..a3fe984 100644
--- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/NearestNeighbor.java
+++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/NearestNeighbor.java
@@ -20,19 +20,16 @@ import static org.apache.lucene.geo.GeoEncodingUtils.decodeLatitude;
 import static org.apache.lucene.geo.GeoEncodingUtils.decodeLongitude;
 
 import java.io.IOException;
-import java.util.ArrayList;
 import java.util.Comparator;
 import java.util.List;
 import java.util.PriorityQueue;
 import org.apache.lucene.geo.Rectangle;
+import org.apache.lucene.index.PointValues;
 import org.apache.lucene.index.PointValues.IntersectVisitor;
+import org.apache.lucene.index.PointValues.PointTree;
 import org.apache.lucene.index.PointValues.Relation;
 import org.apache.lucene.util.Bits;
-import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.SloppyMath;
-import org.apache.lucene.util.bkd.BKDReader;
-import org.apache.lucene.util.bkd.BKDReader.IndexTree;
-import org.apache.lucene.util.bkd.BKDReader.IntersectState;
 
 /**
  * KNN search on top of 2D lat/lon indexed points.
@@ -45,7 +42,7 @@ class NearestNeighbor {
     final int readerIndex;
     final byte[] minPacked;
     final byte[] maxPacked;
-    final IndexTree index;
+    final PointTree index;
 
     /**
      * The closest distance from a point in this cell to the query point, computed as a sort key
@@ -55,7 +52,7 @@ class NearestNeighbor {
     final double distanceSortKey;
 
     public Cell(
-        IndexTree index,
+        PointTree index,
         int readerIndex,
         byte[] minPacked,
         byte[] maxPacked,
@@ -80,10 +77,8 @@ class NearestNeighbor {
       double maxLon = decodeLongitude(maxPacked, Integer.BYTES);
       return "Cell(readerIndex="
           + readerIndex
-          + " nodeID="
-          + index.getNodeID()
-          + " isLeaf="
-          + index.isLeafNode()
+          + " "
+          + index.toString()
           + " lat="
           + minLat
           + " TO "
@@ -247,7 +242,7 @@ class NearestNeighbor {
   public static NearestHit[] nearest(
       double pointLat,
       double pointLon,
-      List<BKDReader> readers,
+      List<PointValues> readers,
       List<Bits> liveDocs,
       List<Integer> docBases,
       final int n)
@@ -278,31 +273,17 @@ class NearestNeighbor {
     PriorityQueue<Cell> cellQueue = new PriorityQueue<>();
 
     NearestVisitor visitor = new NearestVisitor(hitQueue, n, pointLat, pointLon);
-    List<BKDReader.IntersectState> states = new ArrayList<>();
 
     // Add root cell for each reader into the queue:
-    int bytesPerDim = -1;
-
     for (int i = 0; i < readers.size(); i++) {
-      BKDReader reader = readers.get(i);
-      if (bytesPerDim == -1) {
-        bytesPerDim = reader.getBytesPerDimension();
-      } else if (bytesPerDim != reader.getBytesPerDimension()) {
-        throw new IllegalStateException(
-            "bytesPerDim changed from "
-                + bytesPerDim
-                + " to "
-                + reader.getBytesPerDimension()
-                + " across readers");
-      }
+      PointValues reader = readers.get(i);
       byte[] minPackedValue = reader.getMinPackedValue();
       byte[] maxPackedValue = reader.getMaxPackedValue();
-      IntersectState state = reader.getIntersectState(visitor);
-      states.add(state);
+      PointTree indexTree = reader.getPointTree();
 
       cellQueue.offer(
           new Cell(
-              state.index,
+              indexTree,
               i,
               reader.getMinPackedValue(),
               reader.getMaxPackedValue(),
@@ -312,64 +293,53 @@ class NearestNeighbor {
     while (cellQueue.size() > 0) {
       Cell cell = cellQueue.poll();
       // System.out.println("  visit " + cell);
+      if (visitor.compare(cell.minPacked, cell.maxPacked) == Relation.CELL_OUTSIDE_QUERY) {
+        continue;
+      }
 
       // TODO: if we replace approxBestDistance with actualBestDistance, we can put an opto here to
       // break once this "best" cell is fully outside of the hitQueue bottom's radius:
-      BKDReader reader = readers.get(cell.readerIndex);
 
-      if (cell.index.isLeafNode()) {
+      if (cell.index.moveToChild() == false) {
         // System.out.println("    leaf");
         // Leaf block: visit all points and possibly collect them:
         visitor.curDocBase = docBases.get(cell.readerIndex);
         visitor.curLiveDocs = liveDocs.get(cell.readerIndex);
-        reader.visitLeafBlockValues(cell.index, states.get(cell.readerIndex));
+        cell.index.visitDocValues(visitor);
         // System.out.println("    now " + hitQueue.size() + " hits");
       } else {
         // System.out.println("    non-leaf");
         // Non-leaf block: split into two cells and put them back into the queue:
 
-        if (visitor.compare(cell.minPacked, cell.maxPacked) == Relation.CELL_OUTSIDE_QUERY) {
-          continue;
-        }
-
-        BytesRef splitValue = BytesRef.deepCopyOf(cell.index.getSplitDimValue());
-        int splitDim = cell.index.getSplitDim();
+        // we must clone the index so that we can recurse left and right "concurrently":
+        PointTree newIndex = cell.index.clone();
 
-        // we must clone the index so that we we can recurse left and right "concurrently":
-        IndexTree newIndex = cell.index.clone();
-        byte[] splitPackedValue = cell.maxPacked.clone();
-        System.arraycopy(
-            splitValue.bytes,
-            splitValue.offset,
-            splitPackedValue,
-            splitDim * bytesPerDim,
-            bytesPerDim);
-
-        cell.index.pushLeft();
-        cellQueue.offer(
-            new Cell(
-                cell.index,
-                cell.readerIndex,
-                cell.minPacked,
-                splitPackedValue,
-                approxBestDistance(cell.minPacked, splitPackedValue, pointLat, pointLon)));
-
-        splitPackedValue = cell.minPacked.clone();
-        System.arraycopy(
-            splitValue.bytes,
-            splitValue.offset,
-            splitPackedValue,
-            splitDim * bytesPerDim,
-            bytesPerDim);
-
-        newIndex.pushRight();
         cellQueue.offer(
             new Cell(
                 newIndex,
                 cell.readerIndex,
-                splitPackedValue,
-                cell.maxPacked,
-                approxBestDistance(splitPackedValue, cell.maxPacked, pointLat, pointLon)));
+                newIndex.getMinPackedValue(),
+                newIndex.getMaxPackedValue(),
+                approxBestDistance(
+                    newIndex.getMinPackedValue(),
+                    newIndex.getMaxPackedValue(),
+                    pointLat,
+                    pointLon)));
+
+        // TODO: we are assuming a binary tree
+        if (cell.index.moveToSibling()) {
+          cellQueue.offer(
+              new Cell(
+                  cell.index,
+                  cell.readerIndex,
+                  cell.index.getMinPackedValue(),
+                  cell.index.getMaxPackedValue(),
+                  approxBestDistance(
+                      cell.index.getMinPackedValue(),
+                      cell.index.getMaxPackedValue(),
+                      pointLat,
+                      pointLon)));
+        }
       }
     }
 
diff --git a/lucene/test-framework/src/java/org/apache/lucene/codecs/cranky/CrankyPointsFormat.java b/lucene/test-framework/src/java/org/apache/lucene/codecs/cranky/CrankyPointsFormat.java
index 56a556f..f368250 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/codecs/cranky/CrankyPointsFormat.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/codecs/cranky/CrankyPointsFormat.java
@@ -123,19 +123,66 @@ class CrankyPointsFormat extends PointsFormat {
       return new PointValues() {
 
         @Override
-        public void intersect(IntersectVisitor visitor) throws IOException {
-          if (random.nextInt(100) == 0) {
-            throw new IOException("Fake IOException");
-          }
-          delegate.intersect(visitor);
-          if (random.nextInt(100) == 0) {
-            throw new IOException("Fake IOException");
-          }
-        }
-
-        @Override
-        public long estimatePointCount(IntersectVisitor visitor) {
-          return delegate.estimatePointCount(visitor);
+        public PointTree getPointTree() throws IOException {
+          PointTree pointTree = delegate.getPointTree();
+          return new PointTree() {
+            @Override
+            public PointTree clone() {
+              return pointTree.clone();
+            }
+
+            @Override
+            public boolean moveToChild() throws IOException {
+              return pointTree.moveToChild();
+            }
+
+            @Override
+            public boolean moveToSibling() throws IOException {
+              return pointTree.moveToSibling();
+            }
+
+            @Override
+            public boolean moveToParent() throws IOException {
+              return pointTree.moveToParent();
+            }
+
+            @Override
+            public byte[] getMinPackedValue() {
+              return pointTree.getMinPackedValue();
+            }
+
+            @Override
+            public byte[] getMaxPackedValue() {
+              return pointTree.getMaxPackedValue();
+            }
+
+            @Override
+            public long size() {
+              return pointTree.size();
+            }
+
+            @Override
+            public void visitDocIDs(IntersectVisitor visitor) throws IOException {
+              if (random.nextInt(100) == 0) {
+                throw new IOException("Fake IOException");
+              }
+              pointTree.visitDocIDs(visitor);
+              if (random.nextInt(100) == 0) {
+                throw new IOException("Fake IOException");
+              }
+            }
+
+            @Override
+            public void visitDocValues(IntersectVisitor visitor) throws IOException {
+              if (random.nextInt(100) == 0) {
+                throw new IOException("Fake IOException");
+              }
+              pointTree.visitDocValues(visitor);
+              if (random.nextInt(100) == 0) {
+                throw new IOException("Fake IOException");
+              }
+            }
+          };
         }
 
         @Override
diff --git a/lucene/test-framework/src/java/org/apache/lucene/index/AssertingLeafReader.java b/lucene/test-framework/src/java/org/apache/lucene/index/AssertingLeafReader.java
index 636603f..168b241 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/index/AssertingLeafReader.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/index/AssertingLeafReader.java
@@ -1091,22 +1091,9 @@ public class AssertingLeafReader extends FilterLeafReader {
     }
 
     @Override
-    public void intersect(IntersectVisitor visitor) throws IOException {
+    public PointTree getPointTree() throws IOException {
       assertThread("Points", creationThread);
-      in.intersect(
-          new AssertingIntersectVisitor(
-              in.getNumDimensions(),
-              in.getNumIndexDimensions(),
-              in.getBytesPerDimension(),
-              visitor));
-    }
-
-    @Override
-    public long estimatePointCount(IntersectVisitor visitor) {
-      assertThread("Points", creationThread);
-      long cost = in.estimatePointCount(visitor);
-      assert cost >= 0;
-      return cost;
+      return new AssertingPointTree(in, in.getPointTree());
     }
 
     @Override
@@ -1152,6 +1139,80 @@ public class AssertingLeafReader extends FilterLeafReader {
     }
   }
 
+  /** Validates that we don't call moveToChild() or clone() after having called moveToParent() */
+  static class AssertingPointTree implements PointValues.PointTree {
+
+    final PointValues pointValues;
+    final PointValues.PointTree in;
+    private boolean moveToParent;
+
+    AssertingPointTree(PointValues pointValues, PointValues.PointTree in) {
+      this.pointValues = pointValues;
+      this.in = in;
+    }
+
+    @Override
+    public PointValues.PointTree clone() {
+      assert moveToParent == false : "calling clone() after calling moveToParent()";
+      return new AssertingPointTree(pointValues, in.clone());
+    }
+
+    @Override
+    public boolean moveToChild() throws IOException {
+      assert moveToParent == false : "calling moveToChild() after calling moveToParent()";
+      return in.moveToChild();
+    }
+
+    @Override
+    public boolean moveToSibling() throws IOException {
+      moveToParent = false;
+      return in.moveToSibling();
+    }
+
+    @Override
+    public boolean moveToParent() throws IOException {
+      moveToParent = true;
+      return in.moveToParent();
+    }
+
+    @Override
+    public byte[] getMinPackedValue() {
+      return in.getMinPackedValue();
+    }
+
+    @Override
+    public byte[] getMaxPackedValue() {
+      return in.getMaxPackedValue();
+    }
+
+    @Override
+    public long size() {
+      final long size = in.size();
+      assert size > 0;
+      return size;
+    }
+
+    @Override
+    public void visitDocIDs(IntersectVisitor visitor) throws IOException {
+      in.visitDocIDs(
+          new AssertingIntersectVisitor(
+              pointValues.getNumDimensions(),
+              pointValues.getNumIndexDimensions(),
+              pointValues.getBytesPerDimension(),
+              visitor));
+    }
+
+    @Override
+    public void visitDocValues(IntersectVisitor visitor) throws IOException {
+      in.visitDocValues(
+          new AssertingIntersectVisitor(
+              pointValues.getNumDimensions(),
+              pointValues.getNumIndexDimensions(),
+              pointValues.getBytesPerDimension(),
+              visitor));
+    }
+  }
+
   /**
    * Validates in the 1D case that all points are visited in order, and point values are in bounds
    * of the last cell checked
@@ -1189,7 +1250,7 @@ public class AssertingLeafReader extends FilterLeafReader {
 
       // This method, not filtering each hit, should only be invoked when the cell is inside the
       // query shape:
-      assert lastCompareResult == Relation.CELL_INSIDE_QUERY;
+      assert lastCompareResult == null || lastCompareResult == Relation.CELL_INSIDE_QUERY;
       in.visit(docID);
     }
 
@@ -1199,28 +1260,32 @@ public class AssertingLeafReader extends FilterLeafReader {
 
       // This method, to filter each doc's value, should only be invoked when the cell crosses the
       // query shape:
-      assert lastCompareResult == PointValues.Relation.CELL_CROSSES_QUERY;
-
-      // This doc's packed value should be contained in the last cell passed to compare:
-      for (int dim = 0; dim < numIndexDims; dim++) {
-        assert Arrays.compareUnsigned(
-                    lastMinPackedValue,
-                    dim * bytesPerDim,
-                    dim * bytesPerDim + bytesPerDim,
-                    packedValue,
-                    dim * bytesPerDim,
-                    dim * bytesPerDim + bytesPerDim)
-                <= 0
-            : "dim=" + dim + " of " + numDataDims + " value=" + new BytesRef(packedValue);
-        assert Arrays.compareUnsigned(
-                    lastMaxPackedValue,
-                    dim * bytesPerDim,
-                    dim * bytesPerDim + bytesPerDim,
-                    packedValue,
-                    dim * bytesPerDim,
-                    dim * bytesPerDim + bytesPerDim)
-                >= 0
-            : "dim=" + dim + " of " + numDataDims + " value=" + new BytesRef(packedValue);
+      assert lastCompareResult == null
+          || lastCompareResult == PointValues.Relation.CELL_CROSSES_QUERY;
+
+      if (lastCompareResult != null) {
+        // This doc's packed value should be contained in the last cell passed to compare:
+        for (int dim = 0; dim < numIndexDims; dim++) {
+          assert Arrays.compareUnsigned(
+                      lastMinPackedValue,
+                      dim * bytesPerDim,
+                      dim * bytesPerDim + bytesPerDim,
+                      packedValue,
+                      dim * bytesPerDim,
+                      dim * bytesPerDim + bytesPerDim)
+                  <= 0
+              : "dim=" + dim + " of " + numDataDims + " value=" + new BytesRef(packedValue);
+          assert Arrays.compareUnsigned(
+                      lastMaxPackedValue,
+                      dim * bytesPerDim,
+                      dim * bytesPerDim + bytesPerDim,
+                      packedValue,
+                      dim * bytesPerDim,
+                      dim * bytesPerDim + bytesPerDim)
+                  >= 0
+              : "dim=" + dim + " of " + numDataDims + " value=" + new BytesRef(packedValue);
+        }
+        lastCompareResult = null;
       }
 
       // TODO: we should assert that this "matches" whatever relation the last call to compare had
diff --git a/lucene/test-framework/src/java/org/apache/lucene/index/BasePointsFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/index/BasePointsFormatTestCase.java
index bd7a757..c6df558 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/index/BasePointsFormatTestCase.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/index/BasePointsFormatTestCase.java
@@ -1185,20 +1185,33 @@ public abstract class BasePointsFormatTestCase extends BaseIndexFileFormatTestCa
   }
 
   public void testDocCountEdgeCases() {
+    IntersectVisitor visitor =
+        new IntersectVisitor() {
+          @Override
+          public void visit(int docID) {}
+
+          @Override
+          public void visit(int docID, byte[] packedValue) {}
+
+          @Override
+          public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
+            return Relation.CELL_INSIDE_QUERY;
+          }
+        };
     PointValues values = getPointValues(Long.MAX_VALUE, 1, Long.MAX_VALUE);
-    long docs = values.estimateDocCount(null);
+    long docs = values.estimateDocCount(visitor);
     assertEquals(1, docs);
     values = getPointValues(Long.MAX_VALUE, 1, 1);
-    docs = values.estimateDocCount(null);
+    docs = values.estimateDocCount(visitor);
     assertEquals(1, docs);
     values = getPointValues(Long.MAX_VALUE, Integer.MAX_VALUE, Long.MAX_VALUE);
-    docs = values.estimateDocCount(null);
+    docs = values.estimateDocCount(visitor);
     assertEquals(Integer.MAX_VALUE, docs);
     values = getPointValues(Long.MAX_VALUE, Integer.MAX_VALUE, Long.MAX_VALUE / 2);
-    docs = values.estimateDocCount(null);
+    docs = values.estimateDocCount(visitor);
     assertEquals(Integer.MAX_VALUE, docs);
     values = getPointValues(Long.MAX_VALUE, Integer.MAX_VALUE, 1);
-    docs = values.estimateDocCount(null);
+    docs = values.estimateDocCount(visitor);
     assertEquals(1, docs);
   }
 
@@ -1209,7 +1222,20 @@ public abstract class BasePointsFormatTestCase extends BaseIndexFileFormatTestCa
       int docCount = TestUtil.nextInt(random(), 1, maxDoc);
       long estimatedPointCount = TestUtil.nextLong(random(), 0, size);
       PointValues values = getPointValues(size, docCount, estimatedPointCount);
-      long docs = values.estimateDocCount(null);
+      long docs =
+          values.estimateDocCount(
+              new IntersectVisitor() {
+                @Override
+                public void visit(int docID) {}
+
+                @Override
+                public void visit(int docID, byte[] packedValue) {}
+
+                @Override
+                public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
+                  return Relation.CELL_INSIDE_QUERY;
+                }
+              });
       assertTrue(docs <= estimatedPointCount);
       assertTrue(docs <= maxDoc);
       assertTrue(docs >= estimatedPointCount / (size / docCount));
@@ -1219,13 +1245,54 @@ public abstract class BasePointsFormatTestCase extends BaseIndexFileFormatTestCa
   private PointValues getPointValues(long size, int docCount, long estimatedPointCount) {
     return new PointValues() {
       @Override
-      public void intersect(IntersectVisitor visitor) {
-        throw new UnsupportedOperationException();
-      }
+      public PointTree getPointTree() {
+        return new PointTree() {
 
-      @Override
-      public long estimatePointCount(IntersectVisitor visitor) {
-        return estimatedPointCount;
+          @Override
+          public PointTree clone() {
+            throw new UnsupportedOperationException();
+          }
+
+          @Override
+          public boolean moveToChild() {
+            return false;
+          }
+
+          @Override
+          public boolean moveToSibling() {
+            return false;
+          }
+
+          @Override
+          public boolean moveToParent() {
+            return false;
+          }
+
+          @Override
+          public byte[] getMinPackedValue() {
+            return new byte[0];
+          }
+
+          @Override
+          public byte[] getMaxPackedValue() {
+            return new byte[0];
+          }
+
+          @Override
+          public long size() {
+            return estimatedPointCount;
+          }
+
+          @Override
+          public void visitDocIDs(IntersectVisitor visitor) {
+            throw new UnsupportedOperationException();
+          }
+
+          @Override
+          public void visitDocValues(IntersectVisitor visitor) {
+            throw new UnsupportedOperationException();
+          }
+        };
       }
 
       @Override
diff --git a/lucene/test-framework/src/java/org/apache/lucene/index/RandomCodec.java b/lucene/test-framework/src/java/org/apache/lucene/index/RandomCodec.java
index 1bd2656..8f71767 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/index/RandomCodec.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/index/RandomCodec.java
@@ -106,7 +106,7 @@ public class RandomCodec extends AssertingCodec {
               @Override
               public void writeField(FieldInfo fieldInfo, PointsReader reader) throws IOException {
 
-                PointValues values = reader.getValues(fieldInfo.name);
+                PointValues.PointTree values = reader.getValues(fieldInfo.name).getPointTree();
 
                 BKDConfig config =
                     new BKDConfig(
@@ -124,7 +124,7 @@ public class RandomCodec extends AssertingCodec {
                         maxMBSortInHeap,
                         values.size(),
                         bkdSplitRandomSeed ^ fieldInfo.name.hashCode())) {
-                  values.intersect(
+                  values.visitDocValues(
                       new IntersectVisitor() {
                         @Override
                         public void visit(int docID) {