You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by iv...@apache.org on 2019/02/07 07:19:27 UTC

[lucene-solr] branch branch_8x updated: LUCENE-8673: Use radix partitioning when merging dimensional points instead of sorting all dimensions before hand.

This is an automated email from the ASF dual-hosted git repository.

ivera pushed a commit to branch branch_8x
in repository https://gitbox.apache.org/repos/asf/lucene-solr.git


The following commit(s) were added to refs/heads/branch_8x by this push:
     new a843b3a  LUCENE-8673: Use radix partitioning when merging dimensional points instead of sorting all dimensions before hand.
a843b3a is described below

commit a843b3aa35fba2bab82caa2fbdcbcefa61fd40cd
Author: iverase <iv...@apache.org>
AuthorDate: Thu Feb 7 08:12:13 2019 +0100

    LUCENE-8673: Use radix partitioning when merging dimensional points instead of sorting all dimensions before hand.
---
 lucene/CHANGES.txt                                 |   8 +
 .../codecs/simpletext/SimpleTextBKDWriter.java     | 492 ++++-------------
 .../codecs/simpletext/SimpleTextPointsWriter.java  |   4 +-
 .../codecs/lucene60/Lucene60PointsWriter.java      |  10 +-
 .../apache/lucene/util/bkd/BKDRadixSelector.java   | 311 +++++++++++
 .../java/org/apache/lucene/util/bkd/BKDWriter.java | 610 +++++----------------
 .../apache/lucene/util/bkd/HeapPointReader.java    |  53 +-
 .../apache/lucene/util/bkd/HeapPointWriter.java    | 118 ++--
 .../apache/lucene/util/bkd/OfflinePointReader.java | 211 ++-----
 .../apache/lucene/util/bkd/OfflinePointWriter.java |  81 +--
 .../org/apache/lucene/util/bkd/PointReader.java    |  50 +-
 .../org/apache/lucene/util/bkd/PointWriter.java    |  18 +-
 .../apache/lucene/util/bkd/Test2BBKDPoints.java    |   4 +-
 .../test/org/apache/lucene/util/bkd/TestBKD.java   |  28 +-
 .../lucene/util/bkd/TestBKDRadixSelector.java      | 309 +++++++++++
 .../java/org/apache/lucene/index/RandomCodec.java  |  10 +-
 16 files changed, 1060 insertions(+), 1257 deletions(-)

diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 453779e..2504d79 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -3,6 +3,14 @@ Lucene Change Log
 For more information on past and future Lucene versions, please see:
 http://s.apache.org/luceneversions
 
+
+======================= Lucene 8.1.0 =======================
+
+Improvements
+
+* LUCENE-8673: Use radix partitioning when merging dimensional points instead
+  of sorting all dimensions beforehand. (Ignacio Vera, Adrien Grand)
+
 ======================= Lucene 8.0.0 =======================
 
 API Changes
diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDWriter.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDWriter.java
index f45209c..a6276ea 100644
--- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDWriter.java
+++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextBKDWriter.java
@@ -20,7 +20,6 @@ import java.io.Closeable;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Comparator;
 import java.util.List;
 import java.util.function.IntFunction;
 
@@ -36,18 +35,15 @@ import org.apache.lucene.store.TrackingDirectoryWrapper;
 import org.apache.lucene.util.ArrayUtil;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.BytesRefBuilder;
-import org.apache.lucene.util.BytesRefComparator;
 import org.apache.lucene.util.FixedBitSet;
 import org.apache.lucene.util.FutureArrays;
 import org.apache.lucene.util.IOUtils;
-import org.apache.lucene.util.LongBitSet;
 import org.apache.lucene.util.MSBRadixSorter;
 import org.apache.lucene.util.NumericUtils;
-import org.apache.lucene.util.OfflineSorter;
+import org.apache.lucene.util.bkd.BKDRadixSelector;
 import org.apache.lucene.util.bkd.BKDWriter;
 import org.apache.lucene.util.bkd.HeapPointWriter;
 import org.apache.lucene.util.bkd.MutablePointsReaderUtils;
-import org.apache.lucene.util.bkd.OfflinePointReader;
 import org.apache.lucene.util.bkd.OfflinePointWriter;
 import org.apache.lucene.util.bkd.PointReader;
 import org.apache.lucene.util.bkd.PointWriter;
@@ -148,32 +144,15 @@ final class SimpleTextBKDWriter implements Closeable {
 
   protected long pointCount;
 
-  /** true if we have so many values that we must write ords using long (8 bytes) instead of int (4 bytes) */
-  protected final boolean longOrds;
-
   /** An upper bound on how many points the caller will add (includes deletions) */
   private final long totalPointCount;
 
-  /** True if every document has at most one value.  We specialize this case by not bothering to store the ord since it's redundant with docID.  */
-  protected final boolean singleValuePerDoc;
-
-  /** How much heap OfflineSorter is allowed to use */
-  protected final OfflineSorter.BufferSize offlineSorterBufferMB;
-
-  /** How much heap OfflineSorter is allowed to use */
-  protected final int offlineSorterMaxTempFiles;
 
   private final int maxDoc;
 
-  public SimpleTextBKDWriter(int maxDoc, Directory tempDir, String tempFileNamePrefix, int numDataDims, int numIndexDims, int bytesPerDim,
-                             int maxPointsInLeafNode, double maxMBSortInHeap, long totalPointCount, boolean singleValuePerDoc) throws IOException {
-    this(maxDoc, tempDir, tempFileNamePrefix, numDataDims, numIndexDims, bytesPerDim, maxPointsInLeafNode, maxMBSortInHeap, totalPointCount, singleValuePerDoc,
-         totalPointCount > Integer.MAX_VALUE, Math.max(1, (long) maxMBSortInHeap), OfflineSorter.MAX_TEMPFILES);
-  }
 
-  private SimpleTextBKDWriter(int maxDoc, Directory tempDir, String tempFileNamePrefix, int numDataDims, int numIndexDims, int bytesPerDim,
-                              int maxPointsInLeafNode, double maxMBSortInHeap, long totalPointCount,
-                              boolean singleValuePerDoc, boolean longOrds, long offlineSorterBufferMB, int offlineSorterMaxTempFiles) throws IOException {
+  public SimpleTextBKDWriter(int maxDoc, Directory tempDir, String tempFileNamePrefix, int numDataDims, int numIndexDims, int bytesPerDim,
+                              int maxPointsInLeafNode, double maxMBSortInHeap, long totalPointCount) throws IOException {
     verifyParams(numDataDims, numIndexDims, maxPointsInLeafNode, maxMBSortInHeap, totalPointCount);
     // We use tracking dir to deal with removing files on exception, so each place that
     // creates temp files doesn't need crazy try/finally/sucess logic:
@@ -185,8 +164,6 @@ final class SimpleTextBKDWriter implements Closeable {
     this.bytesPerDim = bytesPerDim;
     this.totalPointCount = totalPointCount;
     this.maxDoc = maxDoc;
-    this.offlineSorterBufferMB = OfflineSorter.BufferSize.megabytes(offlineSorterBufferMB);
-    this.offlineSorterMaxTempFiles = offlineSorterMaxTempFiles;
     docsSeen = new FixedBitSet(maxDoc);
     packedBytesLength = numDataDims * bytesPerDim;
     packedIndexBytesLength = numIndexDims * bytesPerDim;
@@ -199,21 +176,8 @@ final class SimpleTextBKDWriter implements Closeable {
     minPackedValue = new byte[packedIndexBytesLength];
     maxPackedValue = new byte[packedIndexBytesLength];
 
-    // If we may have more than 1+Integer.MAX_VALUE values, then we must encode ords with long (8 bytes), else we can use int (4 bytes).
-    this.longOrds = longOrds;
-
-    this.singleValuePerDoc = singleValuePerDoc;
-
-    // dimensional values (numDims * bytesPerDim) + ord (int or long) + docID (int)
-    if (singleValuePerDoc) {
-      // Lucene only supports up to 2.1 docs, so we better not need longOrds in this case:
-      assert longOrds == false;
-      bytesPerDoc = packedBytesLength + Integer.BYTES;
-    } else if (longOrds) {
-      bytesPerDoc = packedBytesLength + Long.BYTES + Integer.BYTES;
-    } else {
-      bytesPerDoc = packedBytesLength + Integer.BYTES + Integer.BYTES;
-    }
+    // dimensional values (numDims * bytesPerDim) + docID (int)
+    bytesPerDoc = packedBytesLength + Integer.BYTES;
 
     // As we recurse, we compute temporary partitions of the data, halving the
     // number of points at each recursion.  Once there are few enough points,
@@ -221,10 +185,10 @@ final class SimpleTextBKDWriter implements Closeable {
     // time in the recursion, we hold the number of points at that level, plus
     // all recursive halves (i.e. 16 + 8 + 4 + 2) so the memory usage is 2X
     // what that level would consume, so we multiply by 0.5 to convert from
-    // bytes to points here.  Each dimension has its own sorted partition, so
-    // we must divide by numDims as wel.
+    // bytes to points here.  In addition, the radix partitioning may sort in memory
+    // up to double this size, so we multiply by another 0.5.
 
-    maxPointsSortInHeap = (int) (0.5 * (maxMBSortInHeap * 1024 * 1024) / (bytesPerDoc * numDataDims));
+    maxPointsSortInHeap = (int) (0.25 * (maxMBSortInHeap * 1024 * 1024) / (bytesPerDoc * numDataDims));
 
     // Finally, we must be able to hold at least the leaf node in heap during build:
     if (maxPointsSortInHeap < maxPointsInLeafNode) {
@@ -232,7 +196,7 @@ final class SimpleTextBKDWriter implements Closeable {
     }
 
     // We write first maxPointsSortInHeap in heap, then cutover to offline for additional points:
-    heapPointWriter = new HeapPointWriter(16, maxPointsSortInHeap, packedBytesLength, longOrds, singleValuePerDoc);
+    heapPointWriter = new HeapPointWriter(16, maxPointsSortInHeap, packedBytesLength);
 
     this.maxMBSortInHeap = maxMBSortInHeap;
   }
@@ -264,13 +228,14 @@ final class SimpleTextBKDWriter implements Closeable {
   private void spillToOffline() throws IOException {
 
     // For each .add we just append to this input file, then in .finish we sort this input and resursively build the tree:
-    offlinePointWriter = new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, longOrds, "spill", 0, singleValuePerDoc);
+    offlinePointWriter = new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, "spill", 0);
     tempInput = offlinePointWriter.out;
     PointReader reader = heapPointWriter.getReader(0, pointCount);
     for(int i=0;i<pointCount;i++) {
       boolean hasNext = reader.next();
       assert hasNext;
-      offlinePointWriter.append(reader.packedValue(), i, heapPointWriter.docIDs[i]);
+      reader.packedValue(scratchBytesRef1);
+      offlinePointWriter.append(scratchBytesRef1, heapPointWriter.docIDs[i]);
     }
 
     heapPointWriter = null;
@@ -285,10 +250,10 @@ final class SimpleTextBKDWriter implements Closeable {
       if (offlinePointWriter == null) {
         spillToOffline();
       }
-      offlinePointWriter.append(packedValue, pointCount, docID);
+      offlinePointWriter.append(packedValue, docID);
     } else {
       // Not too many points added yet, continue using heap:
-      heapPointWriter.append(packedValue, pointCount, docID);
+      heapPointWriter.append(packedValue, docID);
     }
 
     // TODO: we could specialize for the 1D case:
@@ -613,7 +578,7 @@ final class SimpleTextBKDWriter implements Closeable {
 
   /** Sort the heap writer by the specified dim */
   private void sortHeapPointWriter(final HeapPointWriter writer, int dim) {
-    final int pointCount = Math.toIntExact(this.pointCount);
+    final int pointCount = Math.toIntExact(writer.count());
     // Tie-break by docID:
 
     // No need to tie break on ord, for the case where the same doc has the same value in a given dimension indexed more than once: it
@@ -637,131 +602,12 @@ final class SimpleTextBKDWriter implements Closeable {
 
       @Override
       protected void swap(int i, int j) {
-        int docID = writer.docIDs[i];
-        writer.docIDs[i] = writer.docIDs[j];
-        writer.docIDs[j] = docID;
-
-        if (singleValuePerDoc == false) {
-          if (longOrds) {
-            long ord = writer.ordsLong[i];
-            writer.ordsLong[i] = writer.ordsLong[j];
-            writer.ordsLong[j] = ord;
-          } else {
-            int ord = writer.ords[i];
-            writer.ords[i] = writer.ords[j];
-            writer.ords[j] = ord;
-          }
-        }
-
-        byte[] blockI = writer.blocks.get(i / writer.valuesPerBlock);
-        int indexI = (i % writer.valuesPerBlock) * packedBytesLength;
-        byte[] blockJ = writer.blocks.get(j / writer.valuesPerBlock);
-        int indexJ = (j % writer.valuesPerBlock) * packedBytesLength;
-
-        // scratch1 = values[i]
-        System.arraycopy(blockI, indexI, scratch1, 0, packedBytesLength);
-        // values[i] = values[j]
-        System.arraycopy(blockJ, indexJ, blockI, indexI, packedBytesLength);
-        // values[j] = scratch1
-        System.arraycopy(scratch1, 0, blockJ, indexJ, packedBytesLength);
+        writer.swap(i, j);
       }
 
     }.sort(0, pointCount);
   }
 
-  private PointWriter sort(int dim) throws IOException {
-    assert dim >= 0 && dim < numDataDims;
-
-    if (heapPointWriter != null) {
-
-      assert tempInput == null;
-
-      // We never spilled the incoming points to disk, so now we sort in heap:
-      HeapPointWriter sorted;
-
-      if (dim == 0) {
-        // First dim can re-use the current heap writer
-        sorted = heapPointWriter;
-      } else {
-        // Subsequent dims need a private copy
-        sorted = new HeapPointWriter((int) pointCount, (int) pointCount, packedBytesLength, longOrds, singleValuePerDoc);
-        sorted.copyFrom(heapPointWriter);
-      }
-
-      //long t0 = System.nanoTime();
-      sortHeapPointWriter(sorted, dim);
-      //long t1 = System.nanoTime();
-      //System.out.println("BKD: sort took " + ((t1-t0)/1000000.0) + " msec");
-
-      sorted.close();
-      return sorted;
-    } else {
-
-      // Offline sort:
-      assert tempInput != null;
-
-      final int offset = bytesPerDim * dim;
-
-      Comparator<BytesRef> cmp;
-      if (dim == numDataDims - 1) {
-        // in that case the bytes for the dimension and for the doc id are contiguous,
-        // so we don't need a branch
-        cmp = new BytesRefComparator(bytesPerDim + Integer.BYTES) {
-          @Override
-          protected int byteAt(BytesRef ref, int i) {
-            return ref.bytes[ref.offset + offset + i] & 0xff;
-          }
-        };
-      } else {
-        cmp = new BytesRefComparator(bytesPerDim + Integer.BYTES) {
-          @Override
-          protected int byteAt(BytesRef ref, int i) {
-            if (i < bytesPerDim) {
-              return ref.bytes[ref.offset + offset + i] & 0xff;
-            } else {
-              return ref.bytes[ref.offset + packedBytesLength + i - bytesPerDim] & 0xff;
-            }
-          }
-        };
-      }
-
-      OfflineSorter sorter = new OfflineSorter(tempDir, tempFileNamePrefix + "_bkd" + dim, cmp, offlineSorterBufferMB, offlineSorterMaxTempFiles, bytesPerDoc, null, 0) {
-
-          /** We write/read fixed-byte-width file that {@link OfflinePointReader} can read. */
-          @Override
-          protected ByteSequencesWriter getWriter(IndexOutput out, long count) {
-            return new ByteSequencesWriter(out) {
-              @Override
-              public void write(byte[] bytes, int off, int len) throws IOException {
-                assert len == bytesPerDoc: "len=" + len + " bytesPerDoc=" + bytesPerDoc;
-                out.writeBytes(bytes, off, len);
-              }
-            };
-          }
-
-          /** We write/read fixed-byte-width file that {@link OfflinePointReader} can read. */
-          @Override
-          protected ByteSequencesReader getReader(ChecksumIndexInput in, String name) throws IOException {
-            return new ByteSequencesReader(in, name) {
-              final BytesRef scratch = new BytesRef(new byte[bytesPerDoc]);
-              @Override
-              public BytesRef next() throws IOException {
-                if (in.getFilePointer() >= end) {
-                  return null;
-                }
-                in.readBytes(scratch.bytes, 0, bytesPerDoc);
-                return scratch;
-              }
-            };
-          }
-        };
-
-      String name = sorter.sort(tempInput.getName());
-
-      return new OfflinePointWriter(tempDir, name, packedBytesLength, pointCount, longOrds, singleValuePerDoc);
-    }
-  }
-
   private void checkMaxLeafNodeCount(int numLeaves) {
     if ((1+bytesPerDim) * (long) numLeaves > ArrayUtil.MAX_ARRAY_LENGTH) {
       throw new IllegalStateException("too many nodes; increase maxPointsInLeafNode (currently " + maxPointsInLeafNode + ") and reindex");
@@ -779,25 +625,21 @@ final class SimpleTextBKDWriter implements Closeable {
       throw new IllegalStateException("already finished");
     }
 
+    PointWriter data;
+
     if (offlinePointWriter != null) {
       offlinePointWriter.close();
+      data = offlinePointWriter;
+      tempInput = null;
+    } else {
+      data = heapPointWriter;
+      heapPointWriter = null;
     }
 
     if (pointCount == 0) {
       throw new IllegalStateException("must index at least one point");
     }
 
-    LongBitSet ordBitSet;
-    if (numDataDims > 1) {
-      if (singleValuePerDoc) {
-        ordBitSet = new LongBitSet(maxDoc);
-      } else {
-        ordBitSet = new LongBitSet(pointCount);
-      }
-    } else {
-      ordBitSet = null;
-    }
-
     long countPerLeaf = pointCount;
     long innerNodeCount = 1;
 
@@ -822,39 +664,17 @@ final class SimpleTextBKDWriter implements Closeable {
     // Make sure the math above "worked":
     assert pointCount / numLeaves <= maxPointsInLeafNode: "pointCount=" + pointCount + " numLeaves=" + numLeaves + " maxPointsInLeafNode=" + maxPointsInLeafNode;
 
-    // Sort all docs once by each dimension:
-    PathSlice[] sortedPointWriters = new PathSlice[numDataDims];
-
-    // This is only used on exception; on normal code paths we close all files we opened:
-    List<Closeable> toCloseHeroically = new ArrayList<>();
+    //We re-use the selector so we do not need to create an object every time.
+    BKDRadixSelector radixSelector = new BKDRadixSelector(numDataDims, bytesPerDim, maxPointsSortInHeap, tempDir, tempFileNamePrefix);
 
     boolean success = false;
     try {
-      //long t0 = System.nanoTime();
-      for(int dim=0;dim<numDataDims;dim++) {
-        sortedPointWriters[dim] = new PathSlice(sort(dim), 0, pointCount);
-      }
-      //long t1 = System.nanoTime();
-      //System.out.println("sort time: " + ((t1-t0)/1000000.0) + " msec");
 
-      if (tempInput != null) {
-        tempDir.deleteFile(tempInput.getName());
-        tempInput = null;
-      } else {
-        assert heapPointWriter != null;
-        heapPointWriter = null;
-      }
 
-      build(1, numLeaves, sortedPointWriters,
-            ordBitSet, out,
-            minPackedValue, maxPackedValue,
-            splitPackedValues,
-            leafBlockFPs,
-            toCloseHeroically);
+      build(1, numLeaves, data, out,
+          radixSelector, minPackedValue, maxPackedValue,
+            splitPackedValues, leafBlockFPs);
 
-      for(PathSlice slice : sortedPointWriters) {
-        slice.writer.destroy();
-      }
 
       // If no exception, we should have cleaned everything up:
       assert tempDir.getCreatedFiles().isEmpty();
@@ -865,7 +685,6 @@ final class SimpleTextBKDWriter implements Closeable {
     } finally {
       if (success == false) {
         IOUtils.deleteFilesIgnoringExceptions(tempDir, tempDir.getCreatedFiles());
-        IOUtils.closeWhileHandlingException(toCloseHeroically);
       }
     }
 
@@ -1002,24 +821,6 @@ final class SimpleTextBKDWriter implements Closeable {
     }
   }
 
-  /** Sliced reference to points in an OfflineSorter.ByteSequencesWriter file. */
-  private static final class PathSlice {
-    final PointWriter writer;
-    final long start;
-    final long count;
-
-    public PathSlice(PointWriter writer, long start, long count) {
-      this.writer = writer;
-      this.start = start;
-      this.count = count;
-    }
-
-    @Override
-    public String toString() {
-      return "PathSlice(start=" + start + " count=" + count + " writer=" + writer + ")";
-    }
-  }
-
   /** Called on exception, to check whether the checksum is also corrupt in this source, and add that
    *  information (checksum matched or didn't) as a suppressed exception. */
   private Error verifyChecksum(Throwable priorException, PointWriter writer) throws IOException {
@@ -1040,31 +841,6 @@ final class SimpleTextBKDWriter implements Closeable {
     throw IOUtils.rethrowAlways(priorException);
   }
 
-  /** Marks bits for the ords (points) that belong in the right sub tree (those docs that have values >= the splitValue). */
-  private byte[] markRightTree(long rightCount, int splitDim, PathSlice source, LongBitSet ordBitSet) throws IOException {
-
-    // Now we mark ords that fall into the right half, so we can partition on all other dims that are not the split dim:
-
-    // Read the split value, then mark all ords in the right tree (larger than the split value):
-
-    // TODO: find a way to also checksum this reader?  If we changed to markLeftTree, and scanned the final chunk, it could work?
-    try (PointReader reader = source.writer.getReader(source.start + source.count - rightCount, rightCount)) {
-      boolean result = reader.next();
-      assert result;
-      System.arraycopy(reader.packedValue(), splitDim*bytesPerDim, scratch1, 0, bytesPerDim);
-      if (numDataDims > 1) {
-        assert ordBitSet.get(reader.ord()) == false;
-        ordBitSet.set(reader.ord());
-        // Subtract 1 from rightCount because we already did the first value above (so we could record the split value):
-        reader.markOrds(rightCount-1, ordBitSet);
-      }
-    } catch (Throwable t) {
-      throw verifyChecksum(t, source.writer);
-    }
-
-    return scratch1;
-  }
-
   /** Called only in assert */
   private boolean valueInBounds(BytesRef packedValue, byte[] minPackedValue, byte[] maxPackedValue) {
     for(int dim=0;dim<numIndexDims;dim++) {
@@ -1096,19 +872,21 @@ final class SimpleTextBKDWriter implements Closeable {
   }
 
   /** Pull a partition back into heap once the point count is low enough while recursing. */
-  private PathSlice switchToHeap(PathSlice source, List<Closeable> toCloseHeroically) throws IOException {
-    int count = Math.toIntExact(source.count);
+  private HeapPointWriter switchToHeap(PointWriter source) throws IOException {
+    int count = Math.toIntExact(source.count());
     // Not inside the try because we don't want to close it here:
-    PointReader reader = source.writer.getSharedReader(source.start, source.count, toCloseHeroically);
-    try (PointWriter writer = new HeapPointWriter(count, count, packedBytesLength, longOrds, singleValuePerDoc)) {
+
+    try (PointReader reader = source.getReader(0, count);
+        HeapPointWriter writer = new HeapPointWriter(count, count, packedBytesLength)) {
       for(int i=0;i<count;i++) {
         boolean hasNext = reader.next();
         assert hasNext;
-        writer.append(reader.packedValue(), reader.ord(), reader.docID());
+        reader.packedValue(scratchBytesRef1);
+        writer.append(scratchBytesRef1, reader.docID());
       }
-      return new PathSlice(writer, 0, count);
+      return writer;
     } catch (Throwable t) {
-      throw verifyChecksum(t, source.writer);
+      throw verifyChecksum(t, source);
     }
   }
 
@@ -1239,69 +1017,50 @@ final class SimpleTextBKDWriter implements Closeable {
 
   /** The array (sized numDims) of PathSlice describe the cell we have currently recursed to. */
   private void build(int nodeID, int leafNodeOffset,
-                     PathSlice[] slices,
-                     LongBitSet ordBitSet,
+                     PointWriter data,
                      IndexOutput out,
+                     BKDRadixSelector radixSelector,
                      byte[] minPackedValue, byte[] maxPackedValue,
                      byte[] splitPackedValues,
-                     long[] leafBlockFPs,
-                     List<Closeable> toCloseHeroically) throws IOException {
-
-    for(PathSlice slice : slices) {
-      assert slice.count == slices[0].count;
-    }
-
-    if (numDataDims == 1 && slices[0].writer instanceof OfflinePointWriter && slices[0].count <= maxPointsSortInHeap) {
-      // Special case for 1D, to cutover to heap once we recurse deeply enough:
-      slices[0] = switchToHeap(slices[0], toCloseHeroically);
-    }
+                     long[] leafBlockFPs) throws IOException {
 
     if (nodeID >= leafNodeOffset) {
 
       // Leaf node: write block
       // We can write the block in any order so by default we write it sorted by the dimension that has the
       // least number of unique bytes at commonPrefixLengths[dim], which makes compression more efficient
-      int sortedDim = 0;
-      int sortedDimCardinality = Integer.MAX_VALUE;
-
-      for (int dim=0;dim<numDataDims;dim++) {
-        if (slices[dim].writer instanceof HeapPointWriter == false) {
-          // Adversarial cases can cause this, e.g. very lopsided data, all equal points, such that we started
-          // offline, but then kept splitting only in one dimension, and so never had to rewrite into heap writer
-          slices[dim] = switchToHeap(slices[dim], toCloseHeroically);
-        }
 
-        PathSlice source = slices[dim];
+      if (data instanceof HeapPointWriter == false) {
+        // Adversarial cases can cause this, e.g. very lopsided data, all equal points, such that we started
+        // offline, but then kept splitting only in one dimension, and so never had to rewrite into heap writer
+        data = switchToHeap(data);
+      }
 
-        HeapPointWriter heapSource = (HeapPointWriter) source.writer;
+      // We ensured that maxPointsSortInHeap was >= maxPointsInLeafNode, so we better be in heap at this point:
+      HeapPointWriter heapSource = (HeapPointWriter) data;
 
-        // Find common prefix by comparing first and last values, already sorted in this dimension:
-        heapSource.readPackedValue(Math.toIntExact(source.start), scratch1);
-        heapSource.readPackedValue(Math.toIntExact(source.start + source.count - 1), scratch2);
+      // We store the common prefix in scratch1
+      computeCommonPrefixLength(heapSource, scratch1);
 
-        int offset = dim * bytesPerDim;
-        commonPrefixLengths[dim] = bytesPerDim;
-        for(int j=0;j<bytesPerDim;j++) {
-          if (scratch1[offset+j] != scratch2[offset+j]) {
-            commonPrefixLengths[dim] = j;
-            break;
-          }
+      int sortedDim = 0;
+      int sortedDimCardinality = Integer.MAX_VALUE;
+      FixedBitSet[] usedBytes = new FixedBitSet[numDataDims];
+      for (int dim = 0; dim < numDataDims; ++dim) {
+        if (commonPrefixLengths[dim] < bytesPerDim) {
+          usedBytes[dim] = new FixedBitSet(256);
         }
-
+      }
+      //Find the dimension to compress
+      for (int dim = 0; dim < numDataDims; dim++) {
         int prefix = commonPrefixLengths[dim];
         if (prefix < bytesPerDim) {
-          int cardinality = 1;
-          byte previous = scratch1[offset + prefix];
-          for (long i = 1; i < source.count; ++i) {
-            heapSource.readPackedValue(Math.toIntExact(source.start + i), scratch2);
-            byte b = scratch2[offset + prefix];
-            assert Byte.toUnsignedInt(previous) <= Byte.toUnsignedInt(b);
-            if (b != previous) {
-              cardinality++;
-              previous = b;
-            }
+          int offset = dim * bytesPerDim;
+          for (int i = 0; i < heapSource.count(); ++i) {
+            heapSource.getPackedValueSlice(i, scratchBytesRef1);
+            int bucket = scratchBytesRef1.bytes[scratchBytesRef1.offset + offset + prefix] & 0xff;
+            usedBytes[dim].set(bucket);
           }
-          assert cardinality <= 256;
+          int cardinality = usedBytes[dim].cardinality();
           if (cardinality < sortedDimCardinality) {
             sortedDim = dim;
             sortedDimCardinality = cardinality;
@@ -1309,10 +1068,7 @@ final class SimpleTextBKDWriter implements Closeable {
         }
       }
 
-      PathSlice source = slices[sortedDim];
-
-      // We ensured that maxPointsSortInHeap was >= maxPointsInLeafNode, so we better be in heap at this point:
-      HeapPointWriter heapSource = (HeapPointWriter) source.writer;
+      sortHeapPointWriter(heapSource, sortedDim);
 
       // Save the block file pointer:
       leafBlockFPs[nodeID - leafNodeOffset] = out.getFilePointer();
@@ -1320,9 +1076,9 @@ final class SimpleTextBKDWriter implements Closeable {
 
       // Write docIDs first, as their own chunk, so that at intersect time we can add all docIDs w/o
       // loading the values:
-      int count = Math.toIntExact(source.count);
+      int count = Math.toIntExact(heapSource.count());
       assert count > 0: "nodeID=" + nodeID + " leafNodeOffset=" + leafNodeOffset;
-      writeLeafBlockDocs(out, heapSource.docIDs, Math.toIntExact(source.start), count);
+      writeLeafBlockDocs(out, heapSource.docIDs, 0, count);
 
       // TODO: minor opto: we don't really have to write the actual common prefixes, because BKDReader on recursing can regenerate it for us
       // from the index, much like how terms dict does so from the FST:
@@ -1337,12 +1093,12 @@ final class SimpleTextBKDWriter implements Closeable {
 
         @Override
         public BytesRef apply(int i) {
-          heapSource.getPackedValueSlice(Math.toIntExact(source.start + i), scratch);
+          heapSource.getPackedValueSlice(i, scratch);
           return scratch;
         }
       };
       assert valuesInOrderAndBounds(count, sortedDim, minPackedValue, maxPackedValue, packedValues,
-          heapSource.docIDs, Math.toIntExact(source.start));
+          heapSource.docIDs, Math.toIntExact(0));
       writeLeafBlockPackedValues(out, commonPrefixLengths, count, sortedDim, packedValues);
 
     } else {
@@ -1355,91 +1111,67 @@ final class SimpleTextBKDWriter implements Closeable {
         splitDim = 0;
       }
 
-      PathSlice source = slices[splitDim];
 
-      assert nodeID < splitPackedValues.length: "nodeID=" + nodeID + " splitValues.length=" + splitPackedValues.length;
+      assert nodeID < splitPackedValues.length : "nodeID=" + nodeID + " splitValues.length=" + splitPackedValues.length;
 
       // How many points will be in the left tree:
-      long rightCount = source.count / 2;
-      long leftCount = source.count - rightCount;
+      long rightCount = data.count() / 2;
+      long leftCount = data.count() - rightCount;
+
+      PointWriter leftPointWriter;
+      PointWriter rightPointWriter;
+      byte[] splitValue;
+
+      try (PointWriter leftPointWriter2 = getPointWriter(leftCount, "left" + splitDim);
+           PointWriter rightPointWriter2 = getPointWriter(rightCount, "right" + splitDim)) {
+        splitValue = radixSelector.select(data, leftPointWriter2, rightPointWriter2, 0, data.count(),  leftCount, splitDim);
+        leftPointWriter = leftPointWriter2;
+        rightPointWriter = rightPointWriter2;
+      } catch (Throwable t) {
+        throw verifyChecksum(t, data);
+      }
 
-      byte[] splitValue = markRightTree(rightCount, splitDim, source, ordBitSet);
-      int address = nodeID * (1+bytesPerDim);
+      int address = nodeID * (1 + bytesPerDim);
       splitPackedValues[address] = (byte) splitDim;
       System.arraycopy(splitValue, 0, splitPackedValues, address + 1, bytesPerDim);
 
-      // Partition all PathSlice that are not the split dim into sorted left and right sets, so we can recurse:
-
-      PathSlice[] leftSlices = new PathSlice[numDataDims];
-      PathSlice[] rightSlices = new PathSlice[numDataDims];
-
       byte[] minSplitPackedValue = new byte[packedIndexBytesLength];
       System.arraycopy(minPackedValue, 0, minSplitPackedValue, 0, packedIndexBytesLength);
 
       byte[] maxSplitPackedValue = new byte[packedIndexBytesLength];
       System.arraycopy(maxPackedValue, 0, maxSplitPackedValue, 0, packedIndexBytesLength);
 
-      // When we are on this dim, below, we clear the ordBitSet:
-      int dimToClear;
-      if (numDataDims - 1 == splitDim) {
-        dimToClear = numDataDims - 2;
-      } else {
-        dimToClear = numDataDims - 1;
-      }
+      System.arraycopy(splitValue, 0, minSplitPackedValue, splitDim * bytesPerDim, bytesPerDim);
+      System.arraycopy(splitValue, 0, maxSplitPackedValue, splitDim * bytesPerDim, bytesPerDim);
 
-      for(int dim=0;dim<numDataDims;dim++) {
 
-        if (dim == splitDim) {
-          // No need to partition on this dim since it's a simple slice of the incoming already sorted slice, and we
-          // will re-use its shared reader when visiting it as we recurse:
-          leftSlices[dim] = new PathSlice(source.writer, source.start, leftCount);
-          rightSlices[dim] = new PathSlice(source.writer, source.start + leftCount, rightCount);
-          System.arraycopy(splitValue, 0, minSplitPackedValue, dim*bytesPerDim, bytesPerDim);
-          System.arraycopy(splitValue, 0, maxSplitPackedValue, dim*bytesPerDim, bytesPerDim);
-          continue;
-        }
-
-        // Not inside the try because we don't want to close this one now, so that after recursion is done,
-        // we will have done a singel full sweep of the file:
-        PointReader reader = slices[dim].writer.getSharedReader(slices[dim].start, slices[dim].count, toCloseHeroically);
-
-        try (PointWriter leftPointWriter = getPointWriter(leftCount, "left" + dim);
-             PointWriter rightPointWriter = getPointWriter(source.count - leftCount, "right" + dim)) {
-
-          long nextRightCount = reader.split(source.count, ordBitSet, leftPointWriter, rightPointWriter, dim == dimToClear);
-          if (rightCount != nextRightCount) {
-            throw new IllegalStateException("wrong number of points in split: expected=" + rightCount + " but actual=" + nextRightCount + " in dim " + dim);
-          }
-
-          leftSlices[dim] = new PathSlice(leftPointWriter, 0, leftCount);
-          rightSlices[dim] = new PathSlice(rightPointWriter, 0, rightCount);
-        } catch (Throwable t) {
-          throw verifyChecksum(t, slices[dim].writer);
-        }
-      }
 
       // Recurse on left tree:
-      build(2*nodeID, leafNodeOffset, leftSlices,
-            ordBitSet, out,
-            minPackedValue, maxSplitPackedValue,
-            splitPackedValues, leafBlockFPs, toCloseHeroically);
-      for(int dim=0;dim<numDataDims;dim++) {
-        // Don't destroy the dim we split on because we just re-used what our caller above gave us for that dim:
-        if (dim != splitDim) {
-          leftSlices[dim].writer.destroy();
-        }
-      }
+      build(2*nodeID, leafNodeOffset, leftPointWriter, out, radixSelector,
+            minPackedValue, maxSplitPackedValue, splitPackedValues, leafBlockFPs);
 
       // TODO: we could "tail recurse" here?  have our parent discard its refs as we recurse right?
       // Recurse on right tree:
-      build(2*nodeID+1, leafNodeOffset, rightSlices,
-            ordBitSet, out,
-            minSplitPackedValue, maxPackedValue,
-            splitPackedValues, leafBlockFPs, toCloseHeroically);
-      for(int dim=0;dim<numDataDims;dim++) {
-        // Don't destroy the dim we split on because we just re-used what our caller above gave us for that dim:
-        if (dim != splitDim) {
-          rightSlices[dim].writer.destroy();
+      build(2*nodeID+1, leafNodeOffset, rightPointWriter, out, radixSelector,
+            minSplitPackedValue, maxPackedValue, splitPackedValues, leafBlockFPs);
+    }
+  }
+
+  private void computeCommonPrefixLength(HeapPointWriter heapPointWriter, byte[] commonPrefix) {
+    Arrays.fill(commonPrefixLengths, bytesPerDim);
+    scratchBytesRef1.length = packedBytesLength;
+    heapPointWriter.getPackedValueSlice(0, scratchBytesRef1);
+    for (int dim = 0; dim < numDataDims; dim++) {
+      System.arraycopy(scratchBytesRef1.bytes, scratchBytesRef1.offset + dim * bytesPerDim, commonPrefix, dim * bytesPerDim, bytesPerDim);
+    }
+    for (int i = 1; i < heapPointWriter.count(); i++) {
+      heapPointWriter.getPackedValueSlice(i, scratchBytesRef1);
+      for (int dim = 0; dim < numDataDims; dim++) {
+        if (commonPrefixLengths[dim] != 0) {
+          int j = FutureArrays.mismatch(commonPrefix, dim * bytesPerDim, dim * bytesPerDim + commonPrefixLengths[dim], scratchBytesRef1.bytes, scratchBytesRef1.offset + dim * bytesPerDim, scratchBytesRef1.offset + dim * bytesPerDim + commonPrefixLengths[dim]);
+          if (j != -1) {
+            commonPrefixLengths[dim] = j;
+          }
         }
       }
     }
@@ -1483,9 +1215,9 @@ final class SimpleTextBKDWriter implements Closeable {
   PointWriter getPointWriter(long count, String desc) throws IOException {
     if (count <= maxPointsSortInHeap) {
       int size = Math.toIntExact(count);
-      return new HeapPointWriter(size, size, packedBytesLength, longOrds, singleValuePerDoc);
+      return new HeapPointWriter(size, size, packedBytesLength);
     } else {
-      return new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, longOrds, desc, count, singleValuePerDoc);
+      return new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, desc, count);
     }
   }
 
diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextPointsWriter.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextPointsWriter.java
index 2da74d6..b44da0c 100644
--- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextPointsWriter.java
+++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextPointsWriter.java
@@ -71,7 +71,6 @@ class SimpleTextPointsWriter extends PointsWriter {
   public void writeField(FieldInfo fieldInfo, PointsReader reader) throws IOException {
 
     PointValues values = reader.getValues(fieldInfo.name);
-    boolean singleValuePerDoc = values.size() == values.getDocCount();
 
     // We use our own fork of the BKDWriter to customize how it writes the index and blocks to disk:
     try (SimpleTextBKDWriter writer = new SimpleTextBKDWriter(writeState.segmentInfo.maxDoc(),
@@ -82,8 +81,7 @@ class SimpleTextPointsWriter extends PointsWriter {
                                                               fieldInfo.getPointNumBytes(),
                                                               SimpleTextBKDWriter.DEFAULT_MAX_POINTS_IN_LEAF_NODE,
                                                               SimpleTextBKDWriter.DEFAULT_MAX_MB_SORT_IN_HEAP,
-                                                              values.size(),
-                                                              singleValuePerDoc)) {
+                                                              values.size())) {
 
       values.intersect(new IntersectVisitor() {
           @Override
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene60/Lucene60PointsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene60/Lucene60PointsWriter.java
index fddf08c..d9e6305 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene60/Lucene60PointsWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene60/Lucene60PointsWriter.java
@@ -89,7 +89,6 @@ public class Lucene60PointsWriter extends PointsWriter implements Closeable {
   public void writeField(FieldInfo fieldInfo, PointsReader reader) throws IOException {
 
     PointValues values = reader.getValues(fieldInfo.name);
-    boolean singleValuePerDoc = values.size() == values.getDocCount();
 
     try (BKDWriter writer = new BKDWriter(writeState.segmentInfo.maxDoc(),
                                           writeState.directory,
@@ -99,8 +98,7 @@ public class Lucene60PointsWriter extends PointsWriter implements Closeable {
                                           fieldInfo.getPointNumBytes(),
                                           maxPointsInLeafNode,
                                           maxMBSortInHeap,
-                                          values.size(),
-                                          singleValuePerDoc)) {
+                                          values.size())) {
 
       if (values instanceof MutablePointValues) {
         final long fp = writer.writeField(dataOut, fieldInfo.name, (MutablePointValues) values);
@@ -156,8 +154,6 @@ public class Lucene60PointsWriter extends PointsWriter implements Closeable {
       if (fieldInfo.getPointDataDimensionCount() != 0) {
         if (fieldInfo.getPointDataDimensionCount() == 1) {
 
-          boolean singleValuePerDoc = true;
-
           // Worst case total maximum size (if none of the points are deleted):
           long totMaxSize = 0;
           for(int i=0;i<mergeState.pointsReaders.length;i++) {
@@ -169,7 +165,6 @@ public class Lucene60PointsWriter extends PointsWriter implements Closeable {
                 PointValues values = reader.getValues(fieldInfo.name);
                 if (values != null) {
                   totMaxSize += values.size();
-                  singleValuePerDoc &= values.size() == values.getDocCount();
                 }
               }
             }
@@ -187,8 +182,7 @@ public class Lucene60PointsWriter extends PointsWriter implements Closeable {
                                                 fieldInfo.getPointNumBytes(),
                                                 maxPointsInLeafNode,
                                                 maxMBSortInHeap,
-                                                totMaxSize,
-                                                singleValuePerDoc)) {
+                                                totMaxSize)) {
             List<BKDReader> bkdReaders = new ArrayList<>();
             List<MergeState.DocMap> docMaps = new ArrayList<>();
             for(int i=0;i<mergeState.pointsReaders.length;i++) {
diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/BKDRadixSelector.java b/lucene/core/src/java/org/apache/lucene/util/bkd/BKDRadixSelector.java
new file mode 100644
index 0000000..8d6c852
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/util/bkd/BKDRadixSelector.java
@@ -0,0 +1,311 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.util.bkd;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.FutureArrays;
+import org.apache.lucene.util.RadixSelector;
+
+
+/**
+ *
+ * Offline Radix selector for BKD tree.
+ *
+ *  @lucene.internal
+ * */
+public final class BKDRadixSelector {
+  //size of the histogram
+  private static final int HISTOGRAM_SIZE = 256;
+  //size of the offline buffer: 8 KB
+  private static final int MAX_SIZE_OFFLINE_BUFFER = 1024 * 8;
+  // we store one histogram per recursion level
+  private final long[][] histogram;
+  //bytes per dimension
+  private final int bytesPerDim;
+  // number of bytes to be sorted: bytesPerDim + Integer.BYTES
+  private final int bytesSorted;
+  //data dimensions size
+  private final int packedBytesLength;
+  //threshold (number of points) at or below which we switch to sorting on heap
+  private final int maxPointsSortInHeap;
+  //reusable buffer
+  private final byte[] offlineBuffer;
+  //holder for partition points
+  private final int[] partitionBucket;
+  // scratch object to move bytes around
+  private final BytesRef bytesRef1 = new BytesRef();
+  // scratch object to move bytes around
+  private final BytesRef bytesRef2 = new BytesRef();
+  //Directory to create new Offline writer
+  private final Directory tempDir;
+  // prefix for temp files
+  private final String tempFileNamePrefix;
+
+  /**
+   * Sole constructor.
+   */
+  public BKDRadixSelector(int numDim, int bytesPerDim, int maxPointsSortInHeap, Directory tempDir, String tempFileNamePrefix) {
+    this.bytesPerDim = bytesPerDim;
+    this.packedBytesLength = numDim * bytesPerDim;
+    this.bytesSorted = bytesPerDim + Integer.BYTES;
+    this.maxPointsSortInHeap = 2 * maxPointsSortInHeap;
+    int numberOfPointsOffline  = MAX_SIZE_OFFLINE_BUFFER / (packedBytesLength + Integer.BYTES);
+    this.offlineBuffer = new byte[numberOfPointsOffline * (packedBytesLength + Integer.BYTES)];
+    this.partitionBucket = new int[bytesSorted];
+    this.histogram = new long[bytesSorted][HISTOGRAM_SIZE];
+    this.bytesRef1.length = numDim * bytesPerDim;
+    this.tempDir = tempDir;
+    this.tempFileNamePrefix = tempFileNamePrefix;
+  }
+
+  /**
+   *
+   * Method to partition the input data. It returns the value of the dimension where
+   * the split happens. The method destroys the original writer.
+   *
+   */
+  public byte[] select(PointWriter points, PointWriter left, PointWriter right, long from, long to, long partitionPoint, int dim) throws IOException {
+    checkArgs(from, to, partitionPoint);
+
+    //If we are on heap then we just select on heap
+    if (points instanceof HeapPointWriter) {
+      return heapSelect((HeapPointWriter) points, left, right, dim, Math.toIntExact(from), Math.toIntExact(to),  Math.toIntExact(partitionPoint), 0);
+    }
+
+    //reset histogram
+    for (int i = 0; i < bytesSorted; i++) {
+      Arrays.fill(histogram[i], 0);
+    }
+    OfflinePointWriter offlinePointWriter = (OfflinePointWriter) points;
+
+    //find the common prefix; this also sets the histogram values for the common prefix bytes
+    int commonPrefix = findCommonPrefix(offlinePointWriter, from, to, dim);
+
+    //if all values are equal we just partition the data
+    if (commonPrefix ==  bytesSorted) {
+      partition(offlinePointWriter, left,  right, null, from, to, dim, commonPrefix - 1, partitionPoint);
+      return partitionPointFromCommonPrefix();
+    }
+    //let's rock'n'roll
+    return buildHistogramAndPartition(offlinePointWriter, left, right, from, to, partitionPoint, 0, commonPrefix, dim);
+  }
+
+  void checkArgs(long from, long to, long partitionPoint) {
+    if (partitionPoint < from) {
+      throw new IllegalArgumentException("partitionPoint must be >= from");
+    }
+    if (partitionPoint >= to) {
+      throw new IllegalArgumentException("partitionPoint must be < to");
+    }
+  }
+
+  private int findCommonPrefix(OfflinePointWriter points, long from, long to, int dim) throws IOException{
+    //find common prefix
+    byte[] commonPrefix = new byte[bytesSorted];
+    int commonPrefixPosition = bytesSorted;
+    try (OfflinePointReader reader = points.getReader(from, to - from, offlineBuffer)) {
+      reader.next();
+      reader.packedValueWithDocId(bytesRef1);
+      // copy dimension
+      System.arraycopy(bytesRef1.bytes, bytesRef1.offset + dim * bytesPerDim, commonPrefix, 0, bytesPerDim);
+      // copy docID
+      System.arraycopy(bytesRef1.bytes, bytesRef1.offset + packedBytesLength, commonPrefix, bytesPerDim, Integer.BYTES);
+      for (long i = from + 1; i< to; i++) {
+        reader.next();
+        reader.packedValueWithDocId(bytesRef1);
+        int startIndex =  dim * bytesPerDim;
+        int endIndex  = (commonPrefixPosition > bytesPerDim) ? startIndex + bytesPerDim :  startIndex + commonPrefixPosition;
+        int j = FutureArrays.mismatch(commonPrefix, 0, endIndex - startIndex, bytesRef1.bytes, bytesRef1.offset + startIndex, bytesRef1.offset + endIndex);
+        if (j == 0) {
+          return 0;
+        } else if (j == -1) {
+          if (commonPrefixPosition > bytesPerDim) {
+            //tie-break on docID
+            int k = FutureArrays.mismatch(commonPrefix, bytesPerDim, commonPrefixPosition, bytesRef1.bytes, bytesRef1.offset + packedBytesLength, bytesRef1.offset + packedBytesLength + commonPrefixPosition - bytesPerDim );
+            if (k != -1) {
+              commonPrefixPosition = bytesPerDim + k;
+            }
+          }
+        } else {
+          commonPrefixPosition = j;
+        }
+      }
+    }
+
+    //build histogram up to the common prefix
+    for (int i = 0; i < commonPrefixPosition; i++) {
+      partitionBucket[i] = commonPrefix[i] & 0xff;
+      histogram[i][partitionBucket[i]] = to - from;
+    }
+    return commonPrefixPosition;
+  }
+
+  private byte[] buildHistogramAndPartition(OfflinePointWriter points, PointWriter left, PointWriter right,
+                                            long from, long to, long partitionPoint, int iteration,  int commonPrefix, int dim) throws IOException {
+
+    long leftCount = 0;
+    long rightCount = 0;
+    //build histogram at the commonPrefix byte
+    try (OfflinePointReader reader = points.getReader(from, to - from, offlineBuffer)) {
+      while (reader.next()) {
+        reader.packedValueWithDocId(bytesRef1);
+        int bucket;
+        if (commonPrefix < bytesPerDim) {
+          bucket = bytesRef1.bytes[bytesRef1.offset + dim * bytesPerDim + commonPrefix] & 0xff;
+        } else {
+          bucket = bytesRef1.bytes[bytesRef1.offset + packedBytesLength + commonPrefix - bytesPerDim] & 0xff;
+        }
+        histogram[commonPrefix][bucket]++;
+      }
+    }
+    //Count left points and record the partition point
+    for(int i = 0; i < HISTOGRAM_SIZE; i++) {
+      long size = histogram[commonPrefix][i];
+      if (leftCount + size > partitionPoint - from) {
+        partitionBucket[commonPrefix] = i;
+        break;
+      }
+      leftCount += size;
+    }
+    //Count right points
+    for(int i = partitionBucket[commonPrefix] + 1; i < HISTOGRAM_SIZE; i++) {
+      rightCount += histogram[commonPrefix][i];
+    }
+
+    long delta = histogram[commonPrefix][partitionBucket[commonPrefix]];
+    assert leftCount + rightCount + delta == to - from;
+
+    //special case when we have a lot of points that are equal
+    if (commonPrefix == bytesSorted - 1) {
+      long tieBreakCount =(partitionPoint - from - leftCount);
+      partition(points, left,  right, null, from, to, dim, commonPrefix, tieBreakCount);
+      return partitionPointFromCommonPrefix();
+    }
+
+    //create the delta points writer
+    PointWriter deltaPoints;
+    if (delta <= maxPointsSortInHeap) {
+      deltaPoints =  new HeapPointWriter(Math.toIntExact(delta), Math.toIntExact(delta), packedBytesLength);
+    } else {
+      deltaPoints = new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, "delta" + iteration, delta);
+    }
+    //divide the points. This actually destroys the current writer
+    partition(points, left, right, deltaPoints, from, to, dim, commonPrefix, 0);
+    //close delta point writer
+    deltaPoints.close();
+
+    long newPartitionPoint = partitionPoint - from - leftCount;
+
+    if (deltaPoints instanceof HeapPointWriter) {
+      return heapSelect((HeapPointWriter) deltaPoints, left, right, dim, 0, (int) deltaPoints.count(), Math.toIntExact(newPartitionPoint), ++commonPrefix);
+    } else {
+      return buildHistogramAndPartition((OfflinePointWriter) deltaPoints, left, right, 0, deltaPoints.count(), newPartitionPoint, ++iteration, ++commonPrefix, dim);
+    }
+  }
+
+  private void partition(OfflinePointWriter points, PointWriter left, PointWriter right, PointWriter deltaPoints,
+                           long from, long to, int dim, int bytePosition, long numDocsTiebreak) throws IOException {
+    assert bytePosition == bytesSorted -1 || deltaPoints != null;
+    long tiebreakCounter = 0;
+    try (OfflinePointReader reader = points.getReader(from, to - from, offlineBuffer)) {
+      while (reader.next()) {
+        reader.packedValueWithDocId(bytesRef1);
+        reader.packedValue(bytesRef2);
+        int docID = reader.docID();
+        int bucket;
+        if (bytePosition < bytesPerDim) {
+          bucket = bytesRef1.bytes[bytesRef1.offset + dim * bytesPerDim + bytePosition] & 0xff;
+        } else {
+          bucket = bytesRef1.bytes[bytesRef1.offset + packedBytesLength + bytePosition - bytesPerDim] & 0xff;
+        }
+        //int bucket = getBucket(bytesRef1, dim, thisCommonPrefix);
+        if (bucket < this.partitionBucket[bytePosition]) {
+          // to the left side
+          left.append(bytesRef2, docID);
+        } else if (bucket > this.partitionBucket[bytePosition]) {
+          // to the right side
+          right.append(bytesRef2, docID);
+        } else {
+          if (bytePosition == bytesSorted - 1) {
+            if (tiebreakCounter < numDocsTiebreak) {
+              left.append(bytesRef2, docID);
+              tiebreakCounter++;
+            } else {
+              right.append(bytesRef2, docID);
+            }
+          } else {
+            deltaPoints.append(bytesRef2, docID);
+          }
+        }
+      }
+    }
+    //Delete original file
+    points.destroy();
+  }
+
+  private byte[] partitionPointFromCommonPrefix() {
+    byte[] partition = new byte[bytesPerDim];
+    for (int i = 0; i < bytesPerDim; i++) {
+      partition[i] = (byte)partitionBucket[i];
+    }
+    return partition;
+  }
+
+  private byte[] heapSelect(HeapPointWriter points, PointWriter left, PointWriter right, int dim, int from, int to, int partitionPoint, int commonPrefix) throws IOException {
+    final int offset = dim * bytesPerDim + commonPrefix;
+    new RadixSelector(bytesSorted - commonPrefix) {
+
+      @Override
+      protected void swap(int i, int j) {
+        points.swap(i, j);
+      }
+
+      @Override
+      protected int byteAt(int i, int k) {
+        assert k >= 0;
+        if (k + commonPrefix < bytesPerDim) {
+          // dim bytes
+          int block = i / points.valuesPerBlock;
+          int index = i % points.valuesPerBlock;
+          return points.blocks.get(block)[index * packedBytesLength + offset + k] & 0xff;
+        } else {
+          // doc id
+          int s = 3 - (k + commonPrefix - bytesPerDim);
+          return (points.docIDs[i] >>> (s * 8)) & 0xff;
+        }
+      }
+    }.select(from, to, partitionPoint);
+
+    for (int i = from; i < to; i++) {
+      points.getPackedValueSlice(i, bytesRef1);
+      int docID = points.docIDs[i];
+      if (i < partitionPoint) {
+        left.append(bytesRef1, docID);
+      } else {
+        right.append(bytesRef1, docID);
+      }
+    }
+    byte[] partition = new byte[bytesPerDim];
+    points.getPackedValueSlice(partitionPoint, bytesRef1);
+    System.arraycopy(bytesRef1.bytes, bytesRef1.offset + dim * bytesPerDim, partition, 0, bytesPerDim);
+    return partition;
+  }
+}
diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/BKDWriter.java b/lucene/core/src/java/org/apache/lucene/util/bkd/BKDWriter.java
index 676896f..8b66de4 100644
--- a/lucene/core/src/java/org/apache/lucene/util/bkd/BKDWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/util/bkd/BKDWriter.java
@@ -20,7 +20,6 @@ import java.io.Closeable;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Comparator;
 import java.util.List;
 import java.util.function.IntFunction;
 
@@ -40,14 +39,11 @@ import org.apache.lucene.store.TrackingDirectoryWrapper;
 import org.apache.lucene.util.ArrayUtil;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.BytesRefBuilder;
-import org.apache.lucene.util.BytesRefComparator;
 import org.apache.lucene.util.FixedBitSet;
 import org.apache.lucene.util.FutureArrays;
 import org.apache.lucene.util.IOUtils;
-import org.apache.lucene.util.LongBitSet;
 import org.apache.lucene.util.MSBRadixSorter;
 import org.apache.lucene.util.NumericUtils;
-import org.apache.lucene.util.OfflineSorter;
 import org.apache.lucene.util.PriorityQueue;
 
 // TODO
@@ -60,7 +56,8 @@ import org.apache.lucene.util.PriorityQueue;
 //     per leaf, and you can reduce that by putting more points per leaf
 //   - we could use threads while building; the higher nodes are very parallelizable
 
-/** Recursively builds a block KD-tree to assign all incoming points in N-dim space to smaller
+/**
+ *  Recursively builds a block KD-tree to assign all incoming points in N-dim space to smaller
  *  and smaller N-dim rectangles (cells) until the number of points in a given
  *  rectangle is &lt;= <code>maxPointsInLeafNode</code>.  The tree is
  *  fully balanced, which means the leaf nodes will have between 50% and 100% of
@@ -69,14 +66,13 @@ import org.apache.lucene.util.PriorityQueue;
  *
  *  <p>The number of dimensions can be 1 to 8, but every byte[] value is fixed length.
  *
- *  <p>
- *  See <a href="https://www.cs.duke.edu/~pankaj/publications/papers/bkd-sstd.pdf">this paper</a> for details.
- *
- *  <p>This consumes heap during writing: it allocates a <code>LongBitSet(numPoints)</code>,
- *  and then uses up to the specified {@code maxMBSortInHeap} heap space for writing.
+ *  <p>This consumes heap during writing: it allocates a <code>Long[numLeaves]</code>,
+ *  a <code>byte[numLeaves*(1+bytesPerDim)]</code> and then uses up to the specified
+ *  {@code maxMBSortInHeap} heap space for writing.
  *
  *  <p>
- *  <b>NOTE</b>: This can write at most Integer.MAX_VALUE * <code>maxPointsInLeafNode</code> total points.
+ *  <b>NOTE</b>: This can write at most Integer.MAX_VALUE * <code>maxPointsInLeafNode</code> / (1+bytesPerDim)
+ *  total points.
  *
  * @lucene.experimental */
 
@@ -144,32 +140,13 @@ public class BKDWriter implements Closeable {
 
   protected long pointCount;
 
-  /** true if we have so many values that we must write ords using long (8 bytes) instead of int (4 bytes) */
-  protected final boolean longOrds;
-
   /** An upper bound on how many points the caller will add (includes deletions) */
   private final long totalPointCount;
 
-  /** True if every document has at most one value.  We specialize this case by not bothering to store the ord since it's redundant with docID.  */
-  protected final boolean singleValuePerDoc;
-
-  /** How much heap OfflineSorter is allowed to use */
-  protected final OfflineSorter.BufferSize offlineSorterBufferMB;
-
-  /** How much heap OfflineSorter is allowed to use */
-  protected final int offlineSorterMaxTempFiles;
-
   private final int maxDoc;
 
   public BKDWriter(int maxDoc, Directory tempDir, String tempFileNamePrefix, int numDataDims, int numIndexDims, int bytesPerDim,
-                   int maxPointsInLeafNode, double maxMBSortInHeap, long totalPointCount, boolean singleValuePerDoc) throws IOException {
-    this(maxDoc, tempDir, tempFileNamePrefix, numDataDims, numIndexDims, bytesPerDim, maxPointsInLeafNode, maxMBSortInHeap, totalPointCount, singleValuePerDoc,
-         totalPointCount > Integer.MAX_VALUE, Math.max(1, (long) maxMBSortInHeap), OfflineSorter.MAX_TEMPFILES);
-  }
-
-  protected BKDWriter(int maxDoc, Directory tempDir, String tempFileNamePrefix, int numDataDims, int numIndexDims, int bytesPerDim,
-                      int maxPointsInLeafNode, double maxMBSortInHeap, long totalPointCount,
-                      boolean singleValuePerDoc, boolean longOrds, long offlineSorterBufferMB, int offlineSorterMaxTempFiles) throws IOException {
+                      int maxPointsInLeafNode, double maxMBSortInHeap, long totalPointCount) throws IOException {
     verifyParams(numDataDims, numIndexDims, maxPointsInLeafNode, maxMBSortInHeap, totalPointCount);
     // We use tracking dir to deal with removing files on exception, so each place that
     // creates temp files doesn't need crazy try/finally/sucess logic:
@@ -181,8 +158,6 @@ public class BKDWriter implements Closeable {
     this.bytesPerDim = bytesPerDim;
     this.totalPointCount = totalPointCount;
     this.maxDoc = maxDoc;
-    this.offlineSorterBufferMB = OfflineSorter.BufferSize.megabytes(offlineSorterBufferMB);
-    this.offlineSorterMaxTempFiles = offlineSorterMaxTempFiles;
     docsSeen = new FixedBitSet(maxDoc);
     packedBytesLength = numDataDims * bytesPerDim;
     packedIndexBytesLength = numIndexDims * bytesPerDim;
@@ -195,21 +170,9 @@ public class BKDWriter implements Closeable {
     minPackedValue = new byte[packedIndexBytesLength];
     maxPackedValue = new byte[packedIndexBytesLength];
 
-    // If we may have more than 1+Integer.MAX_VALUE values, then we must encode ords with long (8 bytes), else we can use int (4 bytes).
-    this.longOrds = longOrds;
-
-    this.singleValuePerDoc = singleValuePerDoc;
+    // dimensional values (numDims * bytesPerDim) + docID (int)
+    bytesPerDoc = packedBytesLength + Integer.BYTES;
 
-    // dimensional values (numDims * bytesPerDim) + ord (int or long) + docID (int)
-    if (singleValuePerDoc) {
-      // Lucene only supports up to 2.1 docs, so we better not need longOrds in this case:
-      assert longOrds == false;
-      bytesPerDoc = packedBytesLength + Integer.BYTES;
-    } else if (longOrds) {
-      bytesPerDoc = packedBytesLength + Long.BYTES + Integer.BYTES;
-    } else {
-      bytesPerDoc = packedBytesLength + Integer.BYTES + Integer.BYTES;
-    }
 
     // As we recurse, we compute temporary partitions of the data, halving the
     // number of points at each recursion.  Once there are few enough points,
@@ -217,10 +180,10 @@ public class BKDWriter implements Closeable {
     // time in the recursion, we hold the number of points at that level, plus
     // all recursive halves (i.e. 16 + 8 + 4 + 2) so the memory usage is 2X
     // what that level would consume, so we multiply by 0.5 to convert from
-    // bytes to points here.  Each dimension has its own sorted partition, so
-    // we must divide by numDims as wel.
+    // bytes to points here.  In addition, the radix partitioning may sort in memory
+    // up to double this size, so we multiply by another 0.5.
 
-    maxPointsSortInHeap = (int) (0.5 * (maxMBSortInHeap * 1024 * 1024) / (bytesPerDoc * numDataDims));
+    maxPointsSortInHeap = (int) (0.25 * (maxMBSortInHeap * 1024 * 1024) / (bytesPerDoc));
 
     // Finally, we must be able to hold at least the leaf node in heap during build:
     if (maxPointsSortInHeap < maxPointsInLeafNode) {
@@ -228,7 +191,7 @@ public class BKDWriter implements Closeable {
     }
 
     // We write first maxPointsSortInHeap in heap, then cutover to offline for additional points:
-    heapPointWriter = new HeapPointWriter(16, maxPointsSortInHeap, packedBytesLength, longOrds, singleValuePerDoc);
+    heapPointWriter = new HeapPointWriter(16, maxPointsSortInHeap, packedBytesLength);
 
     this.maxMBSortInHeap = maxMBSortInHeap;
   }
@@ -260,15 +223,13 @@ public class BKDWriter implements Closeable {
   private void spillToOffline() throws IOException {
 
     // For each .add we just append to this input file, then in .finish we sort this input and resursively build the tree:
-    offlinePointWriter = new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, longOrds, "spill", 0, singleValuePerDoc);
+    offlinePointWriter = new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, "spill", 0);
     tempInput = offlinePointWriter.out;
-    PointReader reader = heapPointWriter.getReader(0, pointCount);
+    scratchBytesRef1.length = packedBytesLength;
     for(int i=0;i<pointCount;i++) {
-      boolean hasNext = reader.next();
-      assert hasNext;
-      offlinePointWriter.append(reader.packedValue(), i, heapPointWriter.docIDs[i]);
+      heapPointWriter.getPackedValueSlice(i, scratchBytesRef1);
+      offlinePointWriter.append(scratchBytesRef1, heapPointWriter.docIDs[i]);
     }
-
     heapPointWriter = null;
   }
 
@@ -281,10 +242,10 @@ public class BKDWriter implements Closeable {
       if (offlinePointWriter == null) {
         spillToOffline();
       }
-      offlinePointWriter.append(packedValue, pointCount, docID);
+      offlinePointWriter.append(packedValue, docID);
     } else {
       // Not too many points added yet, continue using heap:
-      heapPointWriter.append(packedValue, pointCount, docID);
+      heapPointWriter.append(packedValue, docID);
     }
 
     // TODO: we could specialize for the 1D case:
@@ -762,57 +723,28 @@ public class BKDWriter implements Closeable {
   // encoding and not have our own ByteSequencesReader/Writer
 
   /** Sort the heap writer by the specified dim */
-  private void sortHeapPointWriter(final HeapPointWriter writer, int pointCount, int dim) {
+  private void sortHeapPointWriter(final HeapPointWriter writer, int pointCount, int dim, int commonPrefixLength) {
     // Tie-break by docID:
-
-    // No need to tie break on ord, for the case where the same doc has the same value in a given dimension indexed more than once: it
-    // can't matter at search time since we don't write ords into the index:
-    new MSBRadixSorter(bytesPerDim + Integer.BYTES) {
+    new MSBRadixSorter(bytesPerDim + Integer.BYTES - commonPrefixLength) {
 
       @Override
       protected int byteAt(int i, int k) {
         assert k >= 0;
-        if (k < bytesPerDim) {
+        if (k + commonPrefixLength < bytesPerDim) {
           // dim bytes
           int block = i / writer.valuesPerBlock;
           int index = i % writer.valuesPerBlock;
-          return writer.blocks.get(block)[index * packedBytesLength + dim * bytesPerDim + k] & 0xff;
+          return writer.blocks.get(block)[index * packedBytesLength + dim * bytesPerDim + k + commonPrefixLength] & 0xff;
         } else {
           // doc id
-          int s = 3 - (k - bytesPerDim);
+          int s = 3 - (k + commonPrefixLength - bytesPerDim);
           return (writer.docIDs[i] >>> (s * 8)) & 0xff;
         }
       }
 
       @Override
       protected void swap(int i, int j) {
-        int docID = writer.docIDs[i];
-        writer.docIDs[i] = writer.docIDs[j];
-        writer.docIDs[j] = docID;
-
-        if (singleValuePerDoc == false) {
-          if (longOrds) {
-            long ord = writer.ordsLong[i];
-            writer.ordsLong[i] = writer.ordsLong[j];
-            writer.ordsLong[j] = ord;
-          } else {
-            int ord = writer.ords[i];
-            writer.ords[i] = writer.ords[j];
-            writer.ords[j] = ord;
-          }
-        }
-
-        byte[] blockI = writer.blocks.get(i / writer.valuesPerBlock);
-        int indexI = (i % writer.valuesPerBlock) * packedBytesLength;
-        byte[] blockJ = writer.blocks.get(j / writer.valuesPerBlock);
-        int indexJ = (j % writer.valuesPerBlock) * packedBytesLength;
-
-        // scratch1 = values[i]
-        System.arraycopy(blockI, indexI, scratch1, 0, packedBytesLength);
-        // values[i] = values[j]
-        System.arraycopy(blockJ, indexJ, blockI, indexI, packedBytesLength);
-        // values[j] = scratch1
-        System.arraycopy(scratch1, 0, blockJ, indexJ, packedBytesLength);
+        writer.swap(i, j);
       }
 
     }.sort(0, pointCount);
@@ -836,134 +768,6 @@ public class BKDWriter implements Closeable {
   }
   */
 
-  //return a new point writer sort by the provided dimension from input data
-  private PointWriter sort(int dim) throws IOException {
-    assert dim >= 0 && dim < numDataDims;
-
-    if (heapPointWriter != null) {
-      assert tempInput == null;
-      // We never spilled the incoming points to disk, so now we sort in heap:
-      HeapPointWriter sorted = heapPointWriter;
-      //long t0 = System.nanoTime();
-      sortHeapPointWriter(sorted, Math.toIntExact(this.pointCount), dim);
-      //long t1 = System.nanoTime();
-      //System.out.println("BKD: sort took " + ((t1-t0)/1000000.0) + " msec");
-      sorted.close();
-      heapPointWriter = null;
-      return sorted;
-    } else {
-      // Offline sort:
-      assert tempInput != null;
-      OfflinePointWriter sorted = sortOffLine(dim, tempInput.getName(), 0, pointCount);
-      tempDir.deleteFile(tempInput.getName());
-      tempInput = null;
-      return sorted;
-    }
-  }
-
-  //return a new point writer sort by the provided dimension from start to start + pointCount
-  private PointWriter sort(int dim, PointWriter writer, final long start, final long pointCount) throws IOException {
-    assert dim >= 0 && dim < numDataDims;
-
-    if (writer instanceof HeapPointWriter) {
-      HeapPointWriter heapPointWriter = createHeapPointWriterCopy((HeapPointWriter) writer, start, pointCount);
-      sortHeapPointWriter(heapPointWriter, Math.toIntExact(pointCount), dim);
-      return heapPointWriter;
-    } else {
-      OfflinePointWriter offlinePointWriter = (OfflinePointWriter) writer;
-      return sortOffLine(dim, offlinePointWriter.name, start, pointCount);
-    }
-  }
-
-  // sort a given file on a given dimension for start to start + point count
-  private OfflinePointWriter sortOffLine(int dim, String inputName, final long start, final long pointCount) throws IOException {
-
-    final int offset = bytesPerDim * dim;
-
-    Comparator<BytesRef> cmp;
-    if (dim == numDataDims - 1) {
-      // in that case the bytes for the dimension and for the doc id are contiguous,
-      // so we don't need a branch
-      cmp = new BytesRefComparator(bytesPerDim + Integer.BYTES) {
-        @Override
-        protected int byteAt(BytesRef ref, int i) {
-          return ref.bytes[ref.offset + offset + i] & 0xff;
-        }
-      };
-    } else {
-      cmp = new BytesRefComparator(bytesPerDim + Integer.BYTES) {
-        @Override
-        protected int byteAt(BytesRef ref, int i) {
-          if (i < bytesPerDim) {
-            return ref.bytes[ref.offset + offset + i] & 0xff;
-          } else {
-            return ref.bytes[ref.offset + packedBytesLength + i - bytesPerDim] & 0xff;
-          }
-        }
-      };
-    }
-
-    OfflineSorter sorter = new OfflineSorter(tempDir, tempFileNamePrefix + "_bkd" + dim, cmp, offlineSorterBufferMB, offlineSorterMaxTempFiles, bytesPerDoc, null, 0) {
-      /**
-       * We write/read fixed-byte-width file that {@link OfflinePointReader} can read.
-       */
-      @Override
-      protected ByteSequencesWriter getWriter(IndexOutput out, long count) {
-        return new ByteSequencesWriter(out) {
-          @Override
-          public void write(byte[] bytes, int off, int len) throws IOException {
-            assert len == bytesPerDoc : "len=" + len + " bytesPerDoc=" + bytesPerDoc;
-            out.writeBytes(bytes, off, len);
-          }
-        };
-      }
-
-      /**
-       * We write/read fixed-byte-width file that {@link OfflinePointReader} can read.
-       */
-      @Override
-      protected ByteSequencesReader getReader(ChecksumIndexInput in, String name) throws IOException {
-        //This allows to read only a subset of the original file
-        long startPointer = (name.equals(inputName)) ? bytesPerDoc * start : in.getFilePointer();
-        long endPointer = (name.equals(inputName)) ? startPointer + bytesPerDoc * pointCount : Long.MAX_VALUE;
-        in.seek(startPointer);
-        return new ByteSequencesReader(in, name) {
-          final BytesRef scratch = new BytesRef(new byte[bytesPerDoc]);
-
-          @Override
-          public BytesRef next() throws IOException {
-            if (in.getFilePointer() >= end) {
-              return null;
-            } else if (in.getFilePointer() >= endPointer) {
-              in.seek(end);
-              return null;
-            }
-            in.readBytes(scratch.bytes, 0, bytesPerDoc);
-            return scratch;
-          }
-        };
-      }
-    };
-
-    String name = sorter.sort(inputName);
-    return new OfflinePointWriter(tempDir, name, packedBytesLength, pointCount, longOrds, singleValuePerDoc);
-  }
-
-  private HeapPointWriter createHeapPointWriterCopy(HeapPointWriter writer, long start, long count) throws IOException {
-    //TODO: Can we do this faster?
-    int size = Math.toIntExact(count);
-    try (HeapPointWriter pointWriter = new HeapPointWriter(size, size, packedBytesLength, longOrds, singleValuePerDoc);
-         PointReader reader = writer.getReader(start, count)) {
-     for (long i =0; i < count; i++) {
-       reader.next();
-       pointWriter.append(reader.packedValue(), reader.ord(), reader.docID());
-     }
-     return pointWriter;
-    } catch (Throwable t) {
-      throw verifyChecksum(t, writer);
-    }
-  }
-
   private void checkMaxLeafNodeCount(int numLeaves) {
     if ((1+bytesPerDim) * (long) numLeaves > ArrayUtil.MAX_ARRAY_LENGTH) {
       throw new IllegalStateException("too many nodes; increase maxPointsInLeafNode (currently " + maxPointsInLeafNode + ") and reindex");
@@ -981,25 +785,20 @@ public class BKDWriter implements Closeable {
       throw new IllegalStateException("already finished");
     }
 
+    PointWriter writer;
     if (offlinePointWriter != null) {
       offlinePointWriter.close();
+      writer = offlinePointWriter;
+      tempInput = null;
+    } else {
+      writer = heapPointWriter;
+      heapPointWriter = null;
     }
 
     if (pointCount == 0) {
       throw new IllegalStateException("must index at least one point");
     }
 
-    LongBitSet ordBitSet;
-    if (numIndexDims > 1) {
-      if (singleValuePerDoc) {
-        ordBitSet = new LongBitSet(maxDoc);
-      } else {
-        ordBitSet = new LongBitSet(pointCount);
-      }
-    } else {
-      ordBitSet = null;
-    }
-
     long countPerLeaf = pointCount;
     long innerNodeCount = 1;
 
@@ -1024,23 +823,19 @@ public class BKDWriter implements Closeable {
     // Make sure the math above "worked":
     assert pointCount / numLeaves <= maxPointsInLeafNode: "pointCount=" + pointCount + " numLeaves=" + numLeaves + " maxPointsInLeafNode=" + maxPointsInLeafNode;
 
-    // Slices are created as they are needed
-    PathSlice[] sortedPointWriters = new PathSlice[numIndexDims];
-
-    // This is only used on exception; on normal code paths we close all files we opened:
-    List<Closeable> toCloseHeroically = new ArrayList<>();
+    //We re-use the selector so we do not need to create an object every time.
+    BKDRadixSelector radixSelector = new BKDRadixSelector(numDataDims, bytesPerDim, maxPointsSortInHeap, tempDir, tempFileNamePrefix);
 
     boolean success = false;
     try {
 
       final int[] parentSplits = new int[numIndexDims];
-      build(1, numLeaves, sortedPointWriters,
-            ordBitSet, out,
+      build(1, numLeaves, writer,
+             out, radixSelector,
             minPackedValue, maxPackedValue,
             parentSplits,
             splitPackedValues,
-            leafBlockFPs,
-            toCloseHeroically);
+            leafBlockFPs);
       assert Arrays.equals(parentSplits, new int[numIndexDims]);
 
       // If no exception, we should have cleaned everything up:
@@ -1052,7 +847,6 @@ public class BKDWriter implements Closeable {
     } finally {
       if (success == false) {
         IOUtils.deleteFilesIgnoringExceptions(tempDir, tempDir.getCreatedFiles());
-        IOUtils.closeWhileHandlingException(toCloseHeroically);
       }
     }
 
@@ -1247,7 +1041,6 @@ public class BKDWriter implements Closeable {
   }
 
   private long getLeftMostLeafBlockFP(long[] leafBlockFPs, int nodeID) {
-    int nodeIDIn = nodeID;
     // TODO: can we do this cheaper, e.g. a closed form solution instead of while loop?  Or
     // change the recursion while packing the index to return this left-most leaf block FP
     // from each recursion instead?
@@ -1404,24 +1197,6 @@ public class BKDWriter implements Closeable {
     }
   }
 
-  /** Sliced reference to points in an OfflineSorter.ByteSequencesWriter file. */
-  private static final class PathSlice {
-    final PointWriter writer;
-    final long start;
-    final long count;
-
-    public PathSlice(PointWriter writer, long start, long count) {
-      this.writer = writer;
-      this.start = start;
-      this.count = count;
-    }
-
-    @Override
-    public String toString() {
-      return "PathSlice(start=" + start + " count=" + count + " writer=" + writer + ")";
-    }
-  }
-
   /** Called on exception, to check whether the checksum is also corrupt in this source, and add that
    *  information (checksum matched or didn't) as a suppressed exception. */
   private Error verifyChecksum(Throwable priorException, PointWriter writer) throws IOException {
@@ -1434,8 +1209,10 @@ public class BKDWriter implements Closeable {
     if (writer instanceof OfflinePointWriter) {
       // We are reading from a temp file; go verify the checksum:
       String tempFileName = ((OfflinePointWriter) writer).name;
-      try (ChecksumIndexInput in = tempDir.openChecksumInput(tempFileName, IOContext.READONCE)) {
-        CodecUtil.checkFooter(in, priorException);
+      if (tempDir.getCreatedFiles().contains(tempFileName)) {
+        try (ChecksumIndexInput in = tempDir.openChecksumInput(tempFileName, IOContext.READONCE)) {
+          CodecUtil.checkFooter(in, priorException);
+        }
       }
     }
     
@@ -1443,31 +1220,6 @@ public class BKDWriter implements Closeable {
     throw IOUtils.rethrowAlways(priorException);
   }
 
-  /** Marks bits for the ords (points) that belong in the right sub tree (those docs that have values >= the splitValue). */
-  private byte[] markRightTree(long rightCount, int splitDim, PathSlice source, LongBitSet ordBitSet) throws IOException {
-
-    // Now we mark ords that fall into the right half, so we can partition on all other dims that are not the split dim:
-
-    // Read the split value, then mark all ords in the right tree (larger than the split value):
-
-    // TODO: find a way to also checksum this reader?  If we changed to markLeftTree, and scanned the final chunk, it could work?
-    try (PointReader reader = source.writer.getReader(source.start + source.count - rightCount, rightCount)) {
-      boolean result = reader.next();
-      assert result: "rightCount=" + rightCount + " source.count=" + source.count + " source.writer=" + source.writer;
-      System.arraycopy(reader.packedValue(), splitDim*bytesPerDim, scratch1, 0, bytesPerDim);
-      if (numIndexDims > 1 && ordBitSet != null) {
-        assert ordBitSet.get(reader.ord()) == false;
-        ordBitSet.set(reader.ord());
-        // Subtract 1 from rightCount because we already did the first value above (so we could record the split value):
-        reader.markOrds(rightCount-1, ordBitSet);
-      }
-    } catch (Throwable t) {
-      throw verifyChecksum(t, source.writer);
-    }
-
-    return scratch1;
-  }
-
   /** Called only in assert */
   private boolean valueInBounds(BytesRef packedValue, byte[] minPackedValue, byte[] maxPackedValue) {
     for(int dim=0;dim<numIndexDims;dim++) {
@@ -1521,19 +1273,22 @@ public class BKDWriter implements Closeable {
   }
 
   /** Pull a partition back into heap once the point count is low enough while recursing. */
-  private PathSlice switchToHeap(PathSlice source, List<Closeable> toCloseHeroically) throws IOException {
-    int count = Math.toIntExact(source.count);
+  private HeapPointWriter switchToHeap(PointWriter source) throws IOException {
+    int count = Math.toIntExact(source.count());
     // Not inside the try because we don't want to close it here:
-    PointReader reader = source.writer.getSharedReader(source.start, source.count, toCloseHeroically);
-    try (PointWriter writer = new HeapPointWriter(count, count, packedBytesLength, longOrds, singleValuePerDoc)) {
+
+    try (PointReader reader = source.getReader(0, source.count());
+        HeapPointWriter writer = new HeapPointWriter(count, count, packedBytesLength)) {
       for(int i=0;i<count;i++) {
         boolean hasNext = reader.next();
         assert hasNext;
-        writer.append(reader.packedValue(), reader.ord(), reader.docID());
+        reader.packedValue(scratchBytesRef1);
+        writer.append(scratchBytesRef1, reader.docID());
       }
-      return new PathSlice(writer, 0, count);
+      source.destroy();
+      return writer;
     } catch (Throwable t) {
-      throw verifyChecksum(t, source.writer);
+      throw verifyChecksum(t, source);
     }
   }
 
@@ -1677,71 +1432,54 @@ public class BKDWriter implements Closeable {
     }
   }
 
-  /** The array (sized numDims) of PathSlice describe the cell we have currently recursed to.
+  /** The point writer contains the data that is going to be split using radix selection.
   /*  This method is used when we are merging previously written segments, in the numDims > 1 case. */
   private void build(int nodeID, int leafNodeOffset,
-                     PathSlice[] slices,
-                     LongBitSet ordBitSet,
+                     PointWriter points,
                      IndexOutput out,
+                     BKDRadixSelector radixSelector,
                      byte[] minPackedValue, byte[] maxPackedValue,
                      int[] parentSplits,
                      byte[] splitPackedValues,
-                     long[] leafBlockFPs,
-                     List<Closeable> toCloseHeroically) throws IOException {
+                     long[] leafBlockFPs) throws IOException {
 
     if (nodeID >= leafNodeOffset) {
 
       // Leaf node: write block
       // We can write the block in any order so by default we write it sorted by the dimension that has the
       // least number of unique bytes at commonPrefixLengths[dim], which makes compression more efficient
-      int sortedDim = 0;
-      int sortedDimCardinality = Integer.MAX_VALUE;
 
-      for (int dim=0;dim<numIndexDims;dim++) {
-        //create a slice if it does not exist
-        boolean created = false;
-        if (slices[dim] == null) {
-          createPathSlice(slices, dim);
-          created = true;
-        }
-        if (slices[dim].writer instanceof HeapPointWriter == false) {
-          // Adversarial cases can cause this, e.g. very lopsided data, all equal points, such that we started
-          // offline, but then kept splitting only in one dimension, and so never had to rewrite into heap writer
-          PathSlice slice = slices[dim];
-          slices[dim] = switchToHeap(slices[dim], toCloseHeroically);
-          if (created) {
-            slice.writer.destroy();
-          }
-        }
-
-        PathSlice source = slices[dim];
+      if (points instanceof HeapPointWriter == false) {
+        // Adversarial cases can cause this, e.g. very lopsided data, all equal points, such that we started
+        // offline, but then kept splitting only in one dimension, and so never had to rewrite into heap writer
+        points = switchToHeap(points);
+      }
 
-        HeapPointWriter heapSource = (HeapPointWriter) source.writer;
+      // We ensured that maxPointsSortInHeap was >= maxPointsInLeafNode, so we better be in heap at this point:
+      HeapPointWriter heapSource = (HeapPointWriter) points;
 
-        // Find common prefix by comparing first and last values, already sorted in this dimension:
-        heapSource.readPackedValue(Math.toIntExact(source.start), scratch1);
-        heapSource.readPackedValue(Math.toIntExact(source.start + source.count - 1), scratch2);
+      // We store the common prefix in scratch1
+      computeCommonPrefixLength(heapSource, scratch1);
 
-        int offset = dim * bytesPerDim;
-        commonPrefixLengths[dim] = FutureArrays.mismatch(scratch1, offset, offset + bytesPerDim, scratch2, offset, offset + bytesPerDim);
-        if (commonPrefixLengths[dim] == -1) {
-          commonPrefixLengths[dim] = bytesPerDim;
+      int sortedDim = 0;
+      int sortedDimCardinality = Integer.MAX_VALUE;
+      FixedBitSet[] usedBytes = new FixedBitSet[numDataDims];
+      for (int dim = 0; dim < numDataDims; ++dim) {
+        if (commonPrefixLengths[dim] < bytesPerDim) {
+          usedBytes[dim] = new FixedBitSet(256);
         }
-
+      }
+      //Find the dimension to compress
+      for (int dim = 0; dim < numDataDims; dim++) {
         int prefix = commonPrefixLengths[dim];
         if (prefix < bytesPerDim) {
-          int cardinality = 1;
-          byte previous = scratch1[offset + prefix];
-          for (long i = 1; i < source.count; ++i) {
-            heapSource.readPackedValue(Math.toIntExact(source.start + i), scratch2);
-            byte b = scratch2[offset + prefix];
-            assert Byte.toUnsignedInt(previous) <= Byte.toUnsignedInt(b);
-            if (b != previous) {
-              cardinality++;
-              previous = b;
-            }
+          int offset = dim * bytesPerDim;
+          for (int i = 0; i < heapSource.count(); ++i) {
+            heapSource.getPackedValueSlice(i, scratchBytesRef1);
+            int bucket = scratchBytesRef1.bytes[scratchBytesRef1.offset + offset + prefix] & 0xff;
+            usedBytes[dim].set(bucket);
           }
-          assert cardinality <= 256;
+          int cardinality =usedBytes[dim].cardinality();
           if (cardinality < sortedDimCardinality) {
             sortedDim = dim;
             sortedDimCardinality = cardinality;
@@ -1749,44 +1487,8 @@ public class BKDWriter implements Closeable {
         }
       }
 
-      PathSlice dataDimPathSlice = null;
-
-      if (numDataDims != numIndexDims) {
-        HeapPointWriter heapSource = (HeapPointWriter) slices[0].writer;
-        int from = (int) slices[0].start;
-        int to = from + (int) slices[0].count;
-        Arrays.fill(commonPrefixLengths, numIndexDims, numDataDims, bytesPerDim);
-        heapSource.readPackedValue(from, scratch1);
-        for (int i = from + 1; i < to; ++i) {
-          heapSource.readPackedValue(i, scratch2);
-          for (int dim = numIndexDims; dim < numDataDims; dim++) {
-            final int offset = dim * bytesPerDim;
-            int dimensionPrefixLength = commonPrefixLengths[dim];
-            commonPrefixLengths[dim] = FutureArrays.mismatch(scratch1, offset, offset + dimensionPrefixLength,
-                scratch2, offset, offset + dimensionPrefixLength);
-            if (commonPrefixLengths[dim] == -1) {
-                commonPrefixLengths[dim] = dimensionPrefixLength;
-            }
-          }
-        }
-        //handle case when all index dimensions contain the same value but not the data dimensions
-        if (commonPrefixLengths[sortedDim] == bytesPerDim) {
-          for (int dim = numIndexDims; dim < numDataDims; ++dim) {
-            if (commonPrefixLengths[dim] != bytesPerDim) {
-              sortedDim = dim;
-              //create a new slice in memory
-              dataDimPathSlice = switchToHeap(slices[0], toCloseHeroically);
-              sortHeapPointWriter((HeapPointWriter) dataDimPathSlice.writer, (int) dataDimPathSlice.count, sortedDim);
-              break;
-            }
-          }
-        }
-      }
-
-      PathSlice source = (dataDimPathSlice != null) ? dataDimPathSlice : slices[sortedDim];
-
-      // We ensured that maxPointsSortInHeap was >= maxPointsInLeafNode, so we better be in heap at this point:
-      HeapPointWriter heapSource = (HeapPointWriter) source.writer;
+      // sort the chosen dimension
+      sortHeapPointWriter(heapSource, Math.toIntExact(heapSource.count()), sortedDim, commonPrefixLengths[sortedDim]);
 
       // Save the block file pointer:
       leafBlockFPs[nodeID - leafNodeOffset] = out.getFilePointer();
@@ -1794,9 +1496,9 @@ public class BKDWriter implements Closeable {
 
       // Write docIDs first, as their own chunk, so that at intersect time we can add all docIDs w/o
       // loading the values:
-      int count = Math.toIntExact(source.count);
+      int count = Math.toIntExact(heapSource.count());
       assert count > 0: "nodeID=" + nodeID + " leafNodeOffset=" + leafNodeOffset;
-      writeLeafBlockDocs(out, heapSource.docIDs, Math.toIntExact(source.start), count);
+      writeLeafBlockDocs(out, heapSource.docIDs, Math.toIntExact(0), count);
 
       // TODO: minor opto: we don't really have to write the actual common prefixes, because BKDReader on recursing can regenerate it for us
       // from the index, much like how terms dict does so from the FST:
@@ -1814,12 +1516,12 @@ public class BKDWriter implements Closeable {
 
         @Override
         public BytesRef apply(int i) {
-          heapSource.getPackedValueSlice(Math.toIntExact(source.start + i), scratch);
+          heapSource.getPackedValueSlice(Math.toIntExact(i), scratch);
           return scratch;
         }
       };
       assert valuesInOrderAndBounds(count, sortedDim, minPackedValue, maxPackedValue, packedValues,
-          heapSource.docIDs, Math.toIntExact(source.start));
+          heapSource.docIDs, Math.toIntExact(0));
       writeLeafBlockPackedValues(out, commonPrefixLengths, count, sortedDim, packedValues);
 
     } else {
@@ -1832,125 +1534,71 @@ public class BKDWriter implements Closeable {
         splitDim = 0;
       }
 
-      //We delete the created path slices at the same level
-      boolean deleteSplitDim = false;
-      if (slices[splitDim] == null) {
-        createPathSlice(slices, splitDim);
-        deleteSplitDim = true;
-      }
-      PathSlice source = slices[splitDim];
 
-      assert nodeID < splitPackedValues.length: "nodeID=" + nodeID + " splitValues.length=" + splitPackedValues.length;
+      assert nodeID < splitPackedValues.length : "nodeID=" + nodeID + " splitValues.length=" + splitPackedValues.length;
 
       // How many points will be in the left tree:
-      long rightCount = source.count / 2;
-      long leftCount = source.count - rightCount;
-
-      // When we are on this dim, below, we clear the ordBitSet:
-      int dimToClear = numIndexDims - 1;
-      while (dimToClear >= 0) {
-        if (slices[dimToClear] != null && splitDim != dimToClear) {
-          break;
-        }
-        dimToClear--;
-      }
-
-      byte[] splitValue = (dimToClear == -1) ? markRightTree(rightCount, splitDim, source, null) : markRightTree(rightCount, splitDim, source, ordBitSet);
-      int address = nodeID * (1+bytesPerDim);
+      long rightCount = points.count() / 2;
+      long leftCount = points.count() - rightCount;
+
+      PointWriter leftPointWriter;
+      PointWriter rightPointWriter;
+      byte[] splitValue;
+      try (PointWriter tempLeftPointWriter = getPointWriter(leftCount, "left" + splitDim);
+           PointWriter tempRightPointWriter = getPointWriter(rightCount, "right" + splitDim)) {
+        splitValue = radixSelector.select(points, tempLeftPointWriter, tempRightPointWriter, 0, points.count(),  leftCount, splitDim);
+        leftPointWriter = tempLeftPointWriter;
+        rightPointWriter = tempRightPointWriter;
+      } catch (Throwable t) {
+        throw verifyChecksum(t, points);
+      }
+
+      int address = nodeID * (1 + bytesPerDim);
       splitPackedValues[address] = (byte) splitDim;
       System.arraycopy(splitValue, 0, splitPackedValues, address + 1, bytesPerDim);
 
-      // Partition all PathSlice that are not the split dim into sorted left and right sets, so we can recurse:
-
-      PathSlice[] leftSlices = new PathSlice[numIndexDims];
-      PathSlice[] rightSlices = new PathSlice[numIndexDims];
-
       byte[] minSplitPackedValue = new byte[packedIndexBytesLength];
       System.arraycopy(minPackedValue, 0, minSplitPackedValue, 0, packedIndexBytesLength);
 
       byte[] maxSplitPackedValue = new byte[packedIndexBytesLength];
       System.arraycopy(maxPackedValue, 0, maxSplitPackedValue, 0, packedIndexBytesLength);
 
-
-      for(int dim=0;dim<numIndexDims;dim++) {
-        if (slices[dim] == null) {
-          continue;
-        }
-        if (dim == splitDim) {
-          // No need to partition on this dim since it's a simple slice of the incoming already sorted slice, and we
-          // will re-use its shared reader when visiting it as we recurse:
-          leftSlices[dim] = new PathSlice(source.writer, source.start, leftCount);
-          rightSlices[dim] = new PathSlice(source.writer, source.start + leftCount, rightCount);
-          System.arraycopy(splitValue, 0, minSplitPackedValue, dim*bytesPerDim, bytesPerDim);
-          System.arraycopy(splitValue, 0, maxSplitPackedValue, dim*bytesPerDim, bytesPerDim);
-          continue;
-        }
-
-        // Not inside the try because we don't want to close this one now, so that after recursion is done,
-        // we will have done a singel full sweep of the file:
-        PointReader reader = slices[dim].writer.getSharedReader(slices[dim].start, slices[dim].count, toCloseHeroically);
-
-        try (PointWriter leftPointWriter = getPointWriter(leftCount, "left" + dim);
-             PointWriter rightPointWriter = getPointWriter(source.count - leftCount, "right" + dim)) {
-
-          long nextRightCount = reader.split(source.count, ordBitSet, leftPointWriter, rightPointWriter, dim == dimToClear);
-          if (rightCount != nextRightCount) {
-            throw new IllegalStateException("wrong number of points in split: expected=" + rightCount + " but actual=" + nextRightCount);
-          }
-
-          leftSlices[dim] = new PathSlice(leftPointWriter, 0, leftCount);
-          rightSlices[dim] = new PathSlice(rightPointWriter, 0, rightCount);
-        } catch (Throwable t) {
-          throw verifyChecksum(t, slices[dim].writer);
-        }
-      }
+      System.arraycopy(splitValue, 0, minSplitPackedValue, splitDim * bytesPerDim, bytesPerDim);
+      System.arraycopy(splitValue, 0, maxSplitPackedValue, splitDim * bytesPerDim, bytesPerDim);
 
       parentSplits[splitDim]++;
       // Recurse on left tree:
-      build(2*nodeID, leafNodeOffset, leftSlices,
-            ordBitSet, out,
-            minPackedValue, maxSplitPackedValue, parentSplits,
-            splitPackedValues, leafBlockFPs, toCloseHeroically);
-      for(int dim=0;dim<numIndexDims;dim++) {
-        // Don't destroy the dim we split on because we just re-used what our caller above gave us for that dim:
-        if (dim != splitDim && slices[dim] != null) {
-          leftSlices[dim].writer.destroy();
-        }
-      }
+      build(2 * nodeID, leafNodeOffset, leftPointWriter,
+          out, radixSelector, minPackedValue, maxSplitPackedValue,
+          parentSplits, splitPackedValues, leafBlockFPs);
 
-      // TODO: we could "tail recurse" here?  have our parent discard its refs as we recurse right?
       // Recurse on right tree:
-      build(2*nodeID+1, leafNodeOffset, rightSlices,
-            ordBitSet, out,
-            minSplitPackedValue, maxPackedValue, parentSplits,
-            splitPackedValues, leafBlockFPs, toCloseHeroically);
-      for(int dim=0;dim<numIndexDims;dim++) {
-        // Don't destroy the dim we split on because we just re-used what our caller above gave us for that dim:
-        if (dim != splitDim && slices[dim] != null) {
-          rightSlices[dim].writer.destroy();
-        }
-      }
+      build(2 * nodeID + 1, leafNodeOffset, rightPointWriter,
+          out, radixSelector, minSplitPackedValue, maxPackedValue
+          , parentSplits, splitPackedValues, leafBlockFPs);
+
       parentSplits[splitDim]--;
-      if (deleteSplitDim) {
-        slices[splitDim].writer.destroy();
-      }
     }
   }
 
-  private void createPathSlice(PathSlice[] slices, int dim) throws IOException{
-    assert slices[dim] == null;
-    PathSlice current = null;
-    for(PathSlice slice : slices) {
-      if (slice != null) {
-        current = slice;
-        break;
+  private void computeCommonPrefixLength(HeapPointWriter heapPointWriter, byte[] commonPrefix) {
+    Arrays.fill(commonPrefixLengths, bytesPerDim);
+    scratchBytesRef1.length = packedBytesLength;
+    heapPointWriter.getPackedValueSlice(0, scratchBytesRef1);
+    for (int dim = 0; dim < numDataDims; dim++) {
+      System.arraycopy(scratchBytesRef1.bytes, scratchBytesRef1.offset + dim * bytesPerDim, commonPrefix, dim * bytesPerDim, bytesPerDim);
+    }
+    for (int i = 1; i < heapPointWriter.count(); i++) {
+      heapPointWriter.getPackedValueSlice(i, scratchBytesRef1);
+      for (int dim = 0; dim < numDataDims; dim++) {
+        if (commonPrefixLengths[dim] != 0) {
+          int j = FutureArrays.mismatch(commonPrefix, dim * bytesPerDim, dim * bytesPerDim + commonPrefixLengths[dim], scratchBytesRef1.bytes, scratchBytesRef1.offset + dim * bytesPerDim, scratchBytesRef1.offset + dim * bytesPerDim + commonPrefixLengths[dim]);
+          if (j != -1) {
+            commonPrefixLengths[dim] = j;
+          }
+        }
       }
     }
-    if (current == null) {
-      slices[dim] =  new PathSlice(sort(dim), 0, pointCount);
-    } else {
-      slices[dim] =  new PathSlice(sort(dim, current.writer, current.start, current.count), 0, current.count);
-    }
   }
 
   // only called from assert
@@ -1991,9 +1639,9 @@ public class BKDWriter implements Closeable {
   PointWriter getPointWriter(long count, String desc) throws IOException {
     if (count <= maxPointsSortInHeap) {
       int size = Math.toIntExact(count);
-      return new HeapPointWriter(size, size, packedBytesLength, longOrds, singleValuePerDoc);
+      return new HeapPointWriter(size, size, packedBytesLength);
     } else {
-      return new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, longOrds, desc, count, singleValuePerDoc);
+      return new OfflinePointWriter(tempDir, tempFileNamePrefix, packedBytesLength, desc, count);
     }
   }
 
diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/HeapPointReader.java b/lucene/core/src/java/org/apache/lucene/util/bkd/HeapPointReader.java
index 99182cb..9e3b263 100644
--- a/lucene/core/src/java/org/apache/lucene/util/bkd/HeapPointReader.java
+++ b/lucene/core/src/java/org/apache/lucene/util/bkd/HeapPointReader.java
@@ -18,47 +18,28 @@ package org.apache.lucene.util.bkd;
 
 import java.util.List;
 
-/** Utility class to read buffered points from in-heap arrays.
+import org.apache.lucene.util.BytesRef;
+
+/**
+ * Utility class to read buffered points from in-heap arrays.
  *
- * @lucene.internal */
+ * @lucene.internal
+ * */
 public final class HeapPointReader extends PointReader {
   private int curRead;
   final List<byte[]> blocks;
   final int valuesPerBlock;
   final int packedBytesLength;
-  final long[] ordsLong;
-  final int[] ords;
   final int[] docIDs;
   final int end;
-  final byte[] scratch;
-  final boolean singleValuePerDoc;
 
-  public HeapPointReader(List<byte[]> blocks, int valuesPerBlock, int packedBytesLength, int[] ords, long[] ordsLong, int[] docIDs, int start, int end, boolean singleValuePerDoc) {
+  public HeapPointReader(List<byte[]> blocks, int valuesPerBlock, int packedBytesLength, int[] docIDs, int start, int end) {
     this.blocks = blocks;
     this.valuesPerBlock = valuesPerBlock;
-    this.singleValuePerDoc = singleValuePerDoc;
-    this.ords = ords;
-    this.ordsLong = ordsLong;
     this.docIDs = docIDs;
     curRead = start-1;
     this.end = end;
     this.packedBytesLength = packedBytesLength;
-    scratch = new byte[packedBytesLength];
-  }
-
-  void writePackedValue(int index, byte[] bytes) {
-    int block = index / valuesPerBlock;
-    int blockIndex = index % valuesPerBlock;
-    while (blocks.size() <= block) {
-      blocks.add(new byte[valuesPerBlock*packedBytesLength]);
-    }
-    System.arraycopy(bytes, 0, blocks.get(blockIndex), blockIndex * packedBytesLength, packedBytesLength);
-  }
-
-  void readPackedValue(int index, byte[] bytes) {
-    int block = index / valuesPerBlock;
-    int blockIndex = index % valuesPerBlock;
-    System.arraycopy(blocks.get(block), blockIndex * packedBytesLength, bytes, 0, packedBytesLength);
   }
 
   @Override
@@ -68,9 +49,12 @@ public final class HeapPointReader extends PointReader {
   }
 
   @Override
-  public byte[] packedValue() {
-    readPackedValue(curRead, scratch);
-    return scratch;
+  public void packedValue(BytesRef bytesRef) {
+    int block = curRead / valuesPerBlock;
+    int blockIndex = curRead % valuesPerBlock;
+    bytesRef.bytes = blocks.get(block);
+    bytesRef.offset = blockIndex * packedBytesLength;
+    bytesRef.length = packedBytesLength;
   }
 
   @Override
@@ -79,17 +63,6 @@ public final class HeapPointReader extends PointReader {
   }
 
   @Override
-  public long ord() {
-    if (singleValuePerDoc) {
-      return docIDs[curRead];
-    } else if (ordsLong != null) {
-      return ordsLong[curRead];
-    } else {
-      return ords[curRead];
-    }
-  }
-
-  @Override
   public void close() {
   }
 }
diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/HeapPointWriter.java b/lucene/core/src/java/org/apache/lucene/util/bkd/HeapPointWriter.java
index eb1d48b..0e4ad78 100644
--- a/lucene/core/src/java/org/apache/lucene/util/bkd/HeapPointWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/util/bkd/HeapPointWriter.java
@@ -16,46 +16,36 @@
  */
 package org.apache.lucene.util.bkd;
 
-import java.io.Closeable;
 import java.util.ArrayList;
 import java.util.List;
 
 import org.apache.lucene.util.ArrayUtil;
 import org.apache.lucene.util.BytesRef;
 
-/** Utility class to write new points into in-heap arrays.
+/**
+ * Utility class to write new points into in-heap arrays.
  *
- *  @lucene.internal */
+ *  @lucene.internal
+ *  */
 public final class HeapPointWriter implements PointWriter {
   public int[] docIDs;
-  public long[] ordsLong;
-  public int[] ords;
   private int nextWrite;
   private boolean closed;
   final int maxSize;
   public final int valuesPerBlock;
   final int packedBytesLength;
-  final boolean singleValuePerDoc;
   // NOTE: can't use ByteBlockPool because we need random-write access when sorting in heap
   public final List<byte[]> blocks = new ArrayList<>();
+  private byte[] scratch;
 
-  public HeapPointWriter(int initSize, int maxSize, int packedBytesLength, boolean longOrds, boolean singleValuePerDoc) {
+
+  public HeapPointWriter(int initSize, int maxSize, int packedBytesLength) {
     docIDs = new int[initSize];
     this.maxSize = maxSize;
     this.packedBytesLength = packedBytesLength;
-    this.singleValuePerDoc = singleValuePerDoc;
-    if (singleValuePerDoc) {
-      this.ordsLong = null;
-      this.ords = null;
-    } else {
-      if (longOrds) {
-        this.ordsLong = new long[initSize];
-      } else {
-        this.ords = new int[initSize];
-      }
-    }
     // 4K per page, unless each value is > 4K:
     valuesPerBlock = Math.max(1, 4096/packedBytesLength);
+    scratch = new byte[packedBytesLength];
   }
 
   public void copyFrom(HeapPointWriter other) {
@@ -63,36 +53,19 @@ public final class HeapPointWriter implements PointWriter {
       throw new IllegalStateException("docIDs.length=" + docIDs.length + " other.nextWrite=" + other.nextWrite);
     }
     System.arraycopy(other.docIDs, 0, docIDs, 0, other.nextWrite);
-    if (singleValuePerDoc == false) {
-      if (other.ords != null) {
-        assert this.ords != null;
-        System.arraycopy(other.ords, 0, ords, 0, other.nextWrite);
-      } else {
-        assert this.ordsLong != null;
-        System.arraycopy(other.ordsLong, 0, ordsLong, 0, other.nextWrite);
-      }
-    }
-
     for(byte[] block : other.blocks) {
       blocks.add(block.clone());
     }
     nextWrite = other.nextWrite;
   }
 
-  public void readPackedValue(int index, byte[] bytes) {
-    assert bytes.length == packedBytesLength;
-    int block = index / valuesPerBlock;
-    int blockIndex = index % valuesPerBlock;
-    System.arraycopy(blocks.get(block), blockIndex * packedBytesLength, bytes, 0, packedBytesLength);
-  }
-
   /** Returns a reference, in <code>result</code>, to the byte[] slice holding this value */
   public void getPackedValueSlice(int index, BytesRef result) {
     int block = index / valuesPerBlock;
     int blockIndex = index % valuesPerBlock;
     result.bytes = blocks.get(block);
     result.offset = blockIndex * packedBytesLength;
-    assert result.length == packedBytesLength;
+    result.length = packedBytesLength;
   }
 
   void writePackedValue(int index, byte[] bytes) {
@@ -108,45 +81,76 @@ public final class HeapPointWriter implements PointWriter {
     System.arraycopy(bytes, 0, blocks.get(block), blockIndex * packedBytesLength, packedBytesLength);
   }
 
+  void writePackedValue(int index, BytesRef bytes) {
+    assert bytes.length == packedBytesLength;
+    int block = index / valuesPerBlock;
+    int blockIndex = index % valuesPerBlock;
+    //System.out.println("writePackedValue: index=" + index + " bytes.length=" + bytes.length + " block=" + block + " blockIndex=" + blockIndex + " valuesPerBlock=" + valuesPerBlock);
+    while (blocks.size() <= block) {
+      // If this is the last block, only allocate as large as necessary for maxSize:
+      int valuesInBlock = Math.min(valuesPerBlock, maxSize - (blocks.size() * valuesPerBlock));
+      blocks.add(new byte[valuesInBlock*packedBytesLength]);
+    }
+    System.arraycopy(bytes.bytes, bytes.offset, blocks.get(block), blockIndex * packedBytesLength, packedBytesLength);
+  }
+
   @Override
-  public void append(byte[] packedValue, long ord, int docID) {
+  public void append(byte[] packedValue, int docID) {
     assert closed == false;
     assert packedValue.length == packedBytesLength;
     if (docIDs.length == nextWrite) {
       int nextSize = Math.min(maxSize, ArrayUtil.oversize(nextWrite+1, Integer.BYTES));
       assert nextSize > nextWrite: "nextSize=" + nextSize + " vs nextWrite=" + nextWrite;
       docIDs = ArrayUtil.growExact(docIDs, nextSize);
-      if (singleValuePerDoc == false) {
-        if (ordsLong != null) {
-          ordsLong = ArrayUtil.growExact(ordsLong, nextSize);
-        } else {
-          ords = ArrayUtil.growExact(ords, nextSize);
-        }
-      }
     }
     writePackedValue(nextWrite, packedValue);
-    if (singleValuePerDoc == false) {
-      if (ordsLong != null) {
-        ordsLong[nextWrite] = ord;
-      } else {
-        assert ord <= Integer.MAX_VALUE;
-        ords[nextWrite] = (int) ord;
-      }
+    docIDs[nextWrite] = docID;
+    nextWrite++;
+  }
+
+  @Override
+  public void append(BytesRef packedValue, int docID) {
+    assert closed == false;
+    assert packedValue.length == packedBytesLength;
+    if (docIDs.length == nextWrite) {
+      int nextSize = Math.min(maxSize, ArrayUtil.oversize(nextWrite+1, Integer.BYTES));
+      assert nextSize > nextWrite: "nextSize=" + nextSize + " vs nextWrite=" + nextWrite;
+      docIDs = ArrayUtil.growExact(docIDs, nextSize);
     }
+    writePackedValue(nextWrite, packedValue);
     docIDs[nextWrite] = docID;
     nextWrite++;
   }
 
+  public void swap(int i, int j) {
+    int docID = docIDs[i];
+    docIDs[i] = docIDs[j];
+    docIDs[j] = docID;
+
+
+    byte[] blockI = blocks.get(i / valuesPerBlock);
+    int indexI = (i % valuesPerBlock) * packedBytesLength;
+    byte[] blockJ = blocks.get(j / valuesPerBlock);
+    int indexJ = (j % valuesPerBlock) * packedBytesLength;
+
+    // scratch = values[i]
+    System.arraycopy(blockI, indexI, scratch, 0, packedBytesLength);
+    // values[i] = values[j]
+    System.arraycopy(blockJ, indexJ, blockI, indexI, packedBytesLength);
+    // values[j] = scratch
+    System.arraycopy(scratch, 0, blockJ, indexJ, packedBytesLength);
+  }
+
   @Override
-  public PointReader getReader(long start, long length) {
-    assert start + length <= docIDs.length: "start=" + start + " length=" + length + " docIDs.length=" + docIDs.length;
-    assert start + length <= nextWrite: "start=" + start + " length=" + length + " nextWrite=" + nextWrite;
-    return new HeapPointReader(blocks, valuesPerBlock, packedBytesLength, ords, ordsLong, docIDs, (int) start, Math.toIntExact(start+length), singleValuePerDoc);
+  public long count() {
+    return nextWrite;
   }
 
   @Override
-  public PointReader getSharedReader(long start, long length, List<Closeable> toCloseHeroically) {
-    return new HeapPointReader(blocks, valuesPerBlock, packedBytesLength, ords, ordsLong, docIDs, (int) start, nextWrite, singleValuePerDoc);
+  public PointReader getReader(long start, long length) {
+    assert start + length <= docIDs.length: "start=" + start + " length=" + length + " docIDs.length=" + docIDs.length;
+    assert start + length <= nextWrite: "start=" + start + " length=" + length + " nextWrite=" + nextWrite;
+    return new HeapPointReader(blocks, valuesPerBlock, packedBytesLength, docIDs, (int) start, Math.toIntExact(start+length));
   }
 
   @Override
diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/OfflinePointReader.java b/lucene/core/src/java/org/apache/lucene/util/bkd/OfflinePointReader.java
index 2861d59..86afc79 100644
--- a/lucene/core/src/java/org/apache/lucene/util/bkd/OfflinePointReader.java
+++ b/lucene/core/src/java/org/apache/lucene/util/bkd/OfflinePointReader.java
@@ -24,44 +24,42 @@ import org.apache.lucene.store.ChecksumIndexInput;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.store.IOContext;
 import org.apache.lucene.store.IndexInput;
-import org.apache.lucene.store.IndexOutput;
-import org.apache.lucene.util.LongBitSet;
+import org.apache.lucene.util.BytesRef;
 
-/** Reads points from disk in a fixed-with format, previously written with {@link OfflinePointWriter}.
+/**
+ * Reads points from disk in a fixed-width format, previously written with {@link OfflinePointWriter}.
  *
- * @lucene.internal */
+ * @lucene.internal
+ * */
 public final class OfflinePointReader extends PointReader {
+
   long countLeft;
   final IndexInput in;
-  private final byte[] packedValue;
-  final boolean singleValuePerDoc;
+  byte[] onHeapBuffer;
+  private int offset;
   final int bytesPerDoc;
-  private long ord;
-  private int docID;
-  // true if ords are written as long (8 bytes), else 4 bytes
-  private boolean longOrds;
   private boolean checked;
-
+  private final int packedValueLength;
+  private int pointsInBuffer;
+  private final int maxPointOnHeap;
   // File name we are reading
   final String name;
 
-  public OfflinePointReader(Directory tempDir, String tempFileName, int packedBytesLength, long start, long length,
-                     boolean longOrds, boolean singleValuePerDoc) throws IOException {
-    this.singleValuePerDoc = singleValuePerDoc;
-    int bytesPerDoc = packedBytesLength + Integer.BYTES;
-    if (singleValuePerDoc == false) {
-      if (longOrds) {
-        bytesPerDoc += Long.BYTES;
-      } else {
-        bytesPerDoc += Integer.BYTES;
-      }
-    }
-    this.bytesPerDoc = bytesPerDoc;
+  public OfflinePointReader(Directory tempDir, String tempFileName, int packedBytesLength, long start, long length, byte[] reusableBuffer) throws IOException {
+    this.bytesPerDoc = packedBytesLength + Integer.BYTES;
+    this.packedValueLength = packedBytesLength;
 
     if ((start + length) * bytesPerDoc + CodecUtil.footerLength() > tempDir.fileLength(tempFileName)) {
       throw new IllegalArgumentException("requested slice is beyond the length of this file: start=" + start + " length=" + length + " bytesPerDoc=" + bytesPerDoc + " fileLength=" + tempDir.fileLength(tempFileName) + " tempFileName=" + tempFileName);
     }
+    if (reusableBuffer == null) {
+      throw new IllegalArgumentException("[reusableBuffer] cannot be null");
+    }
+    if (reusableBuffer.length < bytesPerDoc) {
+      throw new IllegalArgumentException("Length of [reusableBuffer] must be bigger than " + bytesPerDoc);
+    }
 
+    this.maxPointOnHeap =  reusableBuffer.length / bytesPerDoc;
     // Best-effort checksumming:
     if (start == 0 && length*bytesPerDoc == tempDir.fileLength(tempFileName) - CodecUtil.footerLength()) {
       // If we are going to read the entire file, e.g. because BKDWriter is now
@@ -74,55 +72,63 @@ public final class OfflinePointReader extends PointReader {
       // at another level of the BKDWriter recursion
       in = tempDir.openInput(tempFileName, IOContext.READONCE);
     }
+
     name = tempFileName;
 
     long seekFP = start * bytesPerDoc;
     in.seek(seekFP);
     countLeft = length;
-    packedValue = new byte[packedBytesLength];
-    this.longOrds = longOrds;
+    this.onHeapBuffer = reusableBuffer;
   }
 
   @Override
   public boolean next() throws IOException {
-    if (countLeft >= 0) {
-      if (countLeft == 0) {
-        return false;
+    if (this.pointsInBuffer == 0) {
+      if (countLeft >= 0) {
+        if (countLeft == 0) {
+          return false;
+        }
       }
-      countLeft--;
-    }
-    try {
-      in.readBytes(packedValue, 0, packedValue.length);
-    } catch (EOFException eofe) {
-      assert countLeft == -1;
-      return false;
-    }
-    docID = in.readInt();
-    if (singleValuePerDoc == false) {
-      if (longOrds) {
-        ord = in.readLong();
-      } else {
-        ord = in.readInt();
+      try {
+        if (countLeft > maxPointOnHeap) {
+          in.readBytes(onHeapBuffer, 0, maxPointOnHeap * bytesPerDoc);
+          pointsInBuffer = maxPointOnHeap - 1;
+          countLeft -= maxPointOnHeap;
+        } else {
+          in.readBytes(onHeapBuffer, 0, (int) countLeft * bytesPerDoc);
+          pointsInBuffer = Math.toIntExact(countLeft - 1);
+          countLeft = 0;
+        }
+        this.offset = 0;
+      } catch (EOFException eofe) {
+        assert countLeft == -1;
+        return false;
       }
     } else {
-      ord = docID;
+      this.pointsInBuffer--;
+      this.offset += bytesPerDoc;
     }
     return true;
   }
 
   @Override
-  public byte[] packedValue() {
-    return packedValue;
+  public void packedValue(BytesRef bytesRef) {
+    bytesRef.bytes = onHeapBuffer;
+    bytesRef.offset = offset;
+    bytesRef.length = packedValueLength;
   }
 
-  @Override
-  public long ord() {
-    return ord;
+  protected void packedValueWithDocId(BytesRef bytesRef) {
+    bytesRef.bytes = onHeapBuffer;
+    bytesRef.offset = offset;
+    bytesRef.length = bytesPerDoc;
   }
 
   @Override
   public int docID() {
-    return docID;
+    int position = this.offset + packedValueLength;
+    return ((onHeapBuffer[position++] & 0xFF) << 24) | ((onHeapBuffer[position++] & 0xFF) << 16)
+        | ((onHeapBuffer[position++] & 0xFF) <<  8) |  (onHeapBuffer[position++] & 0xFF);
   }
 
   @Override
@@ -137,112 +143,5 @@ public final class OfflinePointReader extends PointReader {
       in.close();
     }
   }
-
-  @Override
-  public void markOrds(long count, LongBitSet ordBitSet) throws IOException {
-    if (countLeft < count) {
-      throw new IllegalStateException("only " + countLeft + " points remain, but " + count + " were requested");
-    }
-    long fp = in.getFilePointer() + packedValue.length;
-    if (singleValuePerDoc == false) {
-      fp += Integer.BYTES;
-    }
-    for(long i=0;i<count;i++) {
-      in.seek(fp);
-      long ord;
-      if (longOrds) {
-        ord = in.readLong();
-      } else {
-        ord = in.readInt();
-      }
-      assert ordBitSet.get(ord) == false: "ord=" + ord + " i=" + i + " was seen twice from " + this;
-      ordBitSet.set(ord);
-      fp += bytesPerDoc;
-    }
-  }
-
-  @Override
-  public long split(long count, LongBitSet rightTree, PointWriter left, PointWriter right, boolean doClearBits) throws IOException {
-
-    if (left instanceof OfflinePointWriter == false ||
-        right instanceof OfflinePointWriter == false) {
-      return super.split(count, rightTree, left, right, doClearBits);
-    }
-
-    // We specialize the offline -> offline split since the default impl
-    // is somewhat wasteful otherwise (e.g. decoding docID when we don't
-    // need to)
-
-    int packedBytesLength = packedValue.length;
-
-    int bytesPerDoc = packedBytesLength + Integer.BYTES;
-    if (singleValuePerDoc == false) {
-      if (longOrds) {
-        bytesPerDoc += Long.BYTES;
-      } else {
-        bytesPerDoc += Integer.BYTES;
-      }
-    }
-
-    long rightCount = 0;
-
-    IndexOutput rightOut = ((OfflinePointWriter) right).out;
-    IndexOutput leftOut = ((OfflinePointWriter) left).out;
-
-    assert count <= countLeft: "count=" + count + " countLeft=" + countLeft;
-
-    countLeft -= count;
-
-    long countStart = count;
-
-    byte[] buffer = new byte[bytesPerDoc];
-    while (count > 0) {
-      in.readBytes(buffer, 0, buffer.length);
-
-      long ord;
-      if (longOrds) {
-        // A long ord, after the docID:
-        ord = readLong(buffer, packedBytesLength+Integer.BYTES);
-      } else if (singleValuePerDoc) {
-        // docID is the ord:
-        ord = readInt(buffer, packedBytesLength);
-      } else {
-        // An int ord, after the docID:
-        ord = readInt(buffer, packedBytesLength+Integer.BYTES);
-      }
-
-      if (rightTree.get(ord)) {
-        rightOut.writeBytes(buffer, 0, bytesPerDoc);
-        if (doClearBits) {
-          rightTree.clear(ord);
-        }
-        rightCount++;
-      } else {
-        leftOut.writeBytes(buffer, 0, bytesPerDoc);
-      }
-
-      count--;
-    }
-
-    ((OfflinePointWriter) right).count = rightCount;
-    ((OfflinePointWriter) left).count = countStart-rightCount;
-
-    return rightCount;
-  }
-
-  // Poached from ByteArrayDataInput:
-  private static long readLong(byte[] bytes, int pos) {
-    final int i1 = ((bytes[pos++] & 0xff) << 24) | ((bytes[pos++] & 0xff) << 16) |
-      ((bytes[pos++] & 0xff) << 8) | (bytes[pos++] & 0xff);
-    final int i2 = ((bytes[pos++] & 0xff) << 24) | ((bytes[pos++] & 0xff) << 16) |
-      ((bytes[pos++] & 0xff) << 8) | (bytes[pos++] & 0xff);
-    return (((long)i1) << 32) | (i2 & 0xFFFFFFFFL);
-  }
-
-  // Poached from ByteArrayDataInput:
-  private static int readInt(byte[] bytes, int pos) {
-    return ((bytes[pos++] & 0xFF) << 24) | ((bytes[pos++] & 0xFF) << 16)
-      | ((bytes[pos++] & 0xFF) <<  8) |  (bytes[pos++] & 0xFF);
-  }
 }
 
diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/OfflinePointWriter.java b/lucene/core/src/java/org/apache/lucene/util/bkd/OfflinePointWriter.java
index 7e615a6..5479b53 100644
--- a/lucene/core/src/java/org/apache/lucene/util/bkd/OfflinePointWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/util/bkd/OfflinePointWriter.java
@@ -16,104 +16,79 @@
  */
 package org.apache.lucene.util.bkd;
 
-import java.io.Closeable;
 import java.io.IOException;
-import java.util.List;
 
 import org.apache.lucene.codecs.CodecUtil;
-import org.apache.lucene.store.ChecksumIndexInput;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.store.IOContext;
 import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.util.BytesRef;
 
-/** Writes points to disk in a fixed-with format.
+/**
+ * Writes points to disk in a fixed-width format.
  *
- * @lucene.internal */
+ * @lucene.internal
+ * */
 public final class OfflinePointWriter implements PointWriter {
 
   final Directory tempDir;
   public final IndexOutput out;
   public final String name;
   final int packedBytesLength;
-  final boolean singleValuePerDoc;
   long count;
   private boolean closed;
-  // true if ords are written as long (8 bytes), else 4 bytes
-  private boolean longOrds;
-  private OfflinePointReader sharedReader;
-  private long nextSharedRead;
   final long expectedCount;
 
   /** Create a new writer with an unknown number of incoming points */
   public OfflinePointWriter(Directory tempDir, String tempFileNamePrefix, int packedBytesLength,
-                            boolean longOrds, String desc, long expectedCount, boolean singleValuePerDoc) throws IOException {
+                            String desc, long expectedCount) throws IOException {
     this.out = tempDir.createTempOutput(tempFileNamePrefix, "bkd_" + desc, IOContext.DEFAULT);
     this.name = out.getName();
     this.tempDir = tempDir;
     this.packedBytesLength = packedBytesLength;
-    this.longOrds = longOrds;
-    this.singleValuePerDoc = singleValuePerDoc;
-    this.expectedCount = expectedCount;
-  }
 
-  /** Initializes on an already written/closed file, just so consumers can use {@link #getReader} to read the file. */
-  public OfflinePointWriter(Directory tempDir, String name, int packedBytesLength, long count, boolean longOrds, boolean singleValuePerDoc) {
-    this.out = null;
-    this.name = name;
-    this.tempDir = tempDir;
-    this.packedBytesLength = packedBytesLength;
-    this.count = count;
-    closed = true;
-    this.longOrds = longOrds;
-    this.singleValuePerDoc = singleValuePerDoc;
-    this.expectedCount = 0;
+    this.expectedCount = expectedCount;
   }
     
   @Override
-  public void append(byte[] packedValue, long ord, int docID) throws IOException {
+  public void append(byte[] packedValue, int docID) throws IOException {
     assert packedValue.length == packedBytesLength;
     out.writeBytes(packedValue, 0, packedValue.length);
     out.writeInt(docID);
-    if (singleValuePerDoc == false) {
-      if (longOrds) {
-        out.writeLong(ord);
-      } else {
-        assert ord <= Integer.MAX_VALUE;
-        out.writeInt((int) ord);
-      }
-    }
+    count++;
+    assert expectedCount == 0 || count <= expectedCount;
+  }
+
+  @Override
+  public void append(BytesRef packedValue, int docID) throws IOException {
+    assert packedValue.length == packedBytesLength;
+    out.writeBytes(packedValue.bytes, packedValue.offset, packedValue.length);
+    out.writeInt(docID);
     count++;
     assert expectedCount == 0 || count <= expectedCount;
   }
 
   @Override
   public PointReader getReader(long start, long length) throws IOException {
+    byte[] buffer  = new byte[packedBytesLength + Integer.BYTES];
+    return getReader(start, length,  buffer);
+  }
+
+  protected OfflinePointReader getReader(long start, long length, byte[] reusableBuffer) throws IOException {
     assert closed;
     assert start + length <= count: "start=" + start + " length=" + length + " count=" + count;
     assert expectedCount == 0 || count == expectedCount;
-    return new OfflinePointReader(tempDir, name, packedBytesLength, start, length, longOrds, singleValuePerDoc);
+    return new OfflinePointReader(tempDir, name, packedBytesLength, start, length, reusableBuffer);
   }
 
   @Override
-  public PointReader getSharedReader(long start, long length, List<Closeable> toCloseHeroically) throws IOException {
-    if (sharedReader == null) {
-      assert start == 0;
-      assert length <= count;
-      sharedReader = new OfflinePointReader(tempDir, name, packedBytesLength, 0, count, longOrds, singleValuePerDoc);
-      toCloseHeroically.add(sharedReader);
-      // Make sure the OfflinePointReader intends to verify its checksum:
-      assert sharedReader.in instanceof ChecksumIndexInput;
-    } else {
-      assert start == nextSharedRead: "start=" + start + " length=" + length + " nextSharedRead=" + nextSharedRead;
-    }
-    nextSharedRead += length;
-    return sharedReader;
+  public long count() {
+    return count;
   }
 
   @Override
   public void close() throws IOException {
     if (closed == false) {
-      assert sharedReader == null;
       try {
         CodecUtil.writeFooter(out);
       } finally {
@@ -125,12 +100,6 @@ public final class OfflinePointWriter implements PointWriter {
 
   @Override
   public void destroy() throws IOException {
-    if (sharedReader != null) {
-      // At this point, the shared reader should have done a full sweep of the file:
-      assert nextSharedRead == count;
-      sharedReader.close();
-      sharedReader = null;
-    }
     tempDir.deleteFile(name);
   }
 
diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/PointReader.java b/lucene/core/src/java/org/apache/lucene/util/bkd/PointReader.java
index 0c31275..c0eaff8 100644
--- a/lucene/core/src/java/org/apache/lucene/util/bkd/PointReader.java
+++ b/lucene/core/src/java/org/apache/lucene/util/bkd/PointReader.java
@@ -20,62 +20,24 @@ package org.apache.lucene.util.bkd;
 import java.io.Closeable;
 import java.io.IOException;
 
-import org.apache.lucene.util.LongBitSet;
+import org.apache.lucene.util.BytesRef;
 
 /** One pass iterator through all points previously written with a
- *  {@link PointWriter}, abstracting away whether points a read
+ *  {@link PointWriter}, abstracting away whether points are read
  *  from (offline) disk or simple arrays in heap.
  *
- * @lucene.internal */
+ * @lucene.internal
+ * */
 public abstract class PointReader implements Closeable {
 
   /** Returns false once iteration is done, else true. */
   public abstract boolean next() throws IOException;
 
-  /** Returns the packed byte[] value */
-  public abstract byte[] packedValue();
-
-  /** Point ordinal */
-  public abstract long ord();
+  /** Sets the packed value in the provided {@link BytesRef} */
+  public abstract void packedValue(BytesRef bytesRef);
 
   /** DocID for this point */
   public abstract int docID();
 
-  /** Iterates through the next {@code count} ords, marking them in the provided {@code ordBitSet}. */
-  public void markOrds(long count, LongBitSet ordBitSet) throws IOException {
-    for(int i=0;i<count;i++) {
-      boolean result = next();
-      if (result == false) {
-        throw new IllegalStateException("did not see enough points from reader=" + this);
-      }
-      assert ordBitSet.get(ord()) == false: "ord=" + ord() + " was seen twice from " + this;
-      ordBitSet.set(ord());
-    }
-  }
-
-  /** Splits this reader into left and right partitions */
-  public long split(long count, LongBitSet rightTree, PointWriter left, PointWriter right, boolean doClearBits) throws IOException {
-
-    // Partition this source according to how the splitDim split the values:
-    long rightCount = 0;
-    for (long i=0;i<count;i++) {
-      boolean result = next();
-      assert result;
-      byte[] packedValue = packedValue();
-      long ord = ord();
-      int docID = docID();
-      if (rightTree.get(ord)) {
-        right.append(packedValue, ord, docID);
-        rightCount++;
-        if (doClearBits) {
-          rightTree.clear(ord);
-        }
-      } else {
-        left.append(packedValue, ord, docID);
-      }
-    }
-
-    return rightCount;
-  }
 }
 
diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/PointWriter.java b/lucene/core/src/java/org/apache/lucene/util/bkd/PointWriter.java
index 0222d0e..194c826 100644
--- a/lucene/core/src/java/org/apache/lucene/util/bkd/PointWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/util/bkd/PointWriter.java
@@ -19,24 +19,30 @@ package org.apache.lucene.util.bkd;
 
 import java.io.Closeable;
 import java.io.IOException;
-import java.util.List;
+
+import org.apache.lucene.util.BytesRef;
 
 /** Appends many points, and then at the end provides a {@link PointReader} to iterate
  *  those points.  This abstracts away whether we write to disk, or use simple arrays
  *  in heap.
  *
- *  @lucene.internal */
+ *  @lucene.internal
+ *  */
 public interface PointWriter extends Closeable {
-  /** Add a new point */
-  void append(byte[] packedValue, long ord, int docID) throws IOException;
+  /** Add a new point from a byte array */
+  void append(byte[] packedValue, int docID) throws IOException;
+
+  /** Add a new point from a {@link BytesRef} */
+  void append(BytesRef packedValue, int docID) throws IOException;
 
   /** Returns a {@link PointReader} iterator to step through all previously added points */
   PointReader getReader(long startPoint, long length) throws IOException;
 
-  /** Returns the single shared reader, used at multiple times during the recursion, to read previously added points */
-  PointReader getSharedReader(long startPoint, long length, List<Closeable> toCloseHeroically) throws IOException;
+  /** Returns the number of points in this writer */
+  long count();
 
   /** Removes any temp files behind this writer */
   void destroy() throws IOException;
+
 }
 
diff --git a/lucene/core/src/test/org/apache/lucene/util/bkd/Test2BBKDPoints.java b/lucene/core/src/test/org/apache/lucene/util/bkd/Test2BBKDPoints.java
index 0d57bf8..570f95a 100644
--- a/lucene/core/src/test/org/apache/lucene/util/bkd/Test2BBKDPoints.java
+++ b/lucene/core/src/test/org/apache/lucene/util/bkd/Test2BBKDPoints.java
@@ -42,7 +42,7 @@ public class Test2BBKDPoints extends LuceneTestCase {
     final int numDocs = (Integer.MAX_VALUE / 26) + 100;
 
     BKDWriter w = new BKDWriter(numDocs, dir, "_0", 1, 1, Long.BYTES,
-                                BKDWriter.DEFAULT_MAX_POINTS_IN_LEAF_NODE, BKDWriter.DEFAULT_MAX_MB_SORT_IN_HEAP, 26L * numDocs, false);
+                                BKDWriter.DEFAULT_MAX_POINTS_IN_LEAF_NODE, BKDWriter.DEFAULT_MAX_MB_SORT_IN_HEAP, 26L * numDocs);
     int counter = 0;
     byte[] packedBytes = new byte[Long.BYTES];
     for (int docID = 0; docID < numDocs; docID++) {
@@ -79,7 +79,7 @@ public class Test2BBKDPoints extends LuceneTestCase {
     final int numDocs = (Integer.MAX_VALUE / 26) + 100;
 
     BKDWriter w = new BKDWriter(numDocs, dir, "_0", 2, 2, Long.BYTES,
-                                BKDWriter.DEFAULT_MAX_POINTS_IN_LEAF_NODE, BKDWriter.DEFAULT_MAX_MB_SORT_IN_HEAP, 26L * numDocs, false);
+                                BKDWriter.DEFAULT_MAX_POINTS_IN_LEAF_NODE, BKDWriter.DEFAULT_MAX_MB_SORT_IN_HEAP, 26L * numDocs);
     int counter = 0;
     byte[] packedBytes = new byte[2*Long.BYTES];
     for (int docID = 0; docID < numDocs; docID++) {
diff --git a/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKD.java b/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKD.java
index a01c927..01d05a0 100644
--- a/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKD.java
+++ b/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKD.java
@@ -48,7 +48,7 @@ public class TestBKD extends LuceneTestCase {
 
   public void testBasicInts1D() throws Exception {
     try (Directory dir = getDirectory(100)) {
-      BKDWriter w = new BKDWriter(100, dir, "tmp", 1, 1, 4, 2, 1.0f, 100, true);
+      BKDWriter w = new BKDWriter(100, dir, "tmp", 1, 1, 4, 2, 1.0f, 100);
       byte[] scratch = new byte[4];
       for(int docID=0;docID<100;docID++) {
         NumericUtils.intToSortableBytes(docID, scratch, 0);
@@ -124,7 +124,7 @@ public class TestBKD extends LuceneTestCase {
       int numIndexDims = TestUtil.nextInt(random(), 1, numDims);
       int maxPointsInLeafNode = TestUtil.nextInt(random(), 50, 100);
       float maxMB = (float) 3.0 + (3*random().nextFloat());
-      BKDWriter w = new BKDWriter(numDocs, dir, "tmp", numDims, numIndexDims, 4, maxPointsInLeafNode, maxMB, numDocs, true);
+      BKDWriter w = new BKDWriter(numDocs, dir, "tmp", numDims, numIndexDims, 4, maxPointsInLeafNode, maxMB, numDocs);
 
       if (VERBOSE) {
         System.out.println("TEST: numDims=" + numDims + " numIndexDims=" + numIndexDims + " numDocs=" + numDocs);
@@ -265,7 +265,7 @@ public class TestBKD extends LuceneTestCase {
       int numDims = TestUtil.nextInt(random(), 1, 5);
       int maxPointsInLeafNode = TestUtil.nextInt(random(), 50, 100);
       float maxMB = (float) 3.0 + (3*random().nextFloat());
-      BKDWriter w = new BKDWriter(numDocs, dir, "tmp", numDims, numDims, numBytesPerDim, maxPointsInLeafNode, maxMB, numDocs, true);
+      BKDWriter w = new BKDWriter(numDocs, dir, "tmp", numDims, numDims, numBytesPerDim, maxPointsInLeafNode, maxMB, numDocs);
       BigInteger[][] docs = new BigInteger[numDocs][];
 
       byte[] scratch = new byte[numBytesPerDim*numDims];
@@ -441,7 +441,7 @@ public class TestBKD extends LuceneTestCase {
   public void testTooLittleHeap() throws Exception { 
     try (Directory dir = getDirectory(0)) {
       IllegalArgumentException expected = expectThrows(IllegalArgumentException.class, () -> {
-        new BKDWriter(1, dir, "bkd", 1, 1, 16, 1000000, 0.001, 0, true);
+        new BKDWriter(1, dir, "bkd", 1, 1, 16, 1000000, 0.001, 0);
       });
       assertTrue(expected.getMessage().contains("either increase maxMBSortInHeap or decrease maxPointsInLeafNode"));
     }
@@ -668,7 +668,7 @@ public class TestBKD extends LuceneTestCase {
     List<MergeState.DocMap> docMaps = null;
     int seg = 0;
 
-    BKDWriter w = new BKDWriter(numValues, dir, "_" + seg, numDataDims, numIndexDims, numBytesPerDim, maxPointsInLeafNode, maxMB, docValues.length, false);
+    BKDWriter w = new BKDWriter(numValues, dir, "_" + seg, numDataDims, numIndexDims, numBytesPerDim, maxPointsInLeafNode, maxMB, docValues.length);
     IndexOutput out = dir.createOutput("bkd", IOContext.DEFAULT);
     IndexInput in = null;
 
@@ -728,7 +728,7 @@ public class TestBKD extends LuceneTestCase {
           seg++;
           maxPointsInLeafNode = TestUtil.nextInt(random(), 50, 1000);
           maxMB = (float) 3.0 + (3*random().nextDouble());
-          w = new BKDWriter(numValues, dir, "_" + seg, numDataDims, numIndexDims, numBytesPerDim, maxPointsInLeafNode, maxMB, docValues.length, false);
+          w = new BKDWriter(numValues, dir, "_" + seg, numDataDims, numIndexDims, numBytesPerDim, maxPointsInLeafNode, maxMB, docValues.length);
           lastDocIDBase = docID;
         }
       }
@@ -749,7 +749,7 @@ public class TestBKD extends LuceneTestCase {
         out.close();
         in = dir.openInput("bkd", IOContext.DEFAULT);
         seg++;
-        w = new BKDWriter(numValues, dir, "_" + seg, numDataDims, numIndexDims, numBytesPerDim, maxPointsInLeafNode, maxMB, docValues.length, false);
+        w = new BKDWriter(numValues, dir, "_" + seg, numDataDims, numIndexDims, numBytesPerDim, maxPointsInLeafNode, maxMB, docValues.length);
         List<BKDReader> readers = new ArrayList<>();
         for(long fp : toMerge) {
           in.seek(fp);
@@ -924,7 +924,7 @@ public class TestBKD extends LuceneTestCase {
         @Override
         public IndexOutput createTempOutput(String prefix, String suffix, IOContext context) throws IOException {
           IndexOutput out = in.createTempOutput(prefix, suffix, context);
-          if (corrupted == false && prefix.equals("_0_bkd1") && suffix.equals("sort")) {
+          if (corrupted == false && prefix.equals("_0") && suffix.equals("bkd_left0")) {
             corrupted = true;
             return new CorruptingIndexOutput(dir0, 22, out);
           } else {
@@ -1008,7 +1008,7 @@ public class TestBKD extends LuceneTestCase {
   public void testTieBreakOrder() throws Exception {
     try (Directory dir = newDirectory()) {
       int numDocs = 10000;
-      BKDWriter w = new BKDWriter(numDocs+1, dir, "tmp", 1, 1, Integer.BYTES, 2, 0.01f, numDocs, true);
+      BKDWriter w = new BKDWriter(numDocs+1, dir, "tmp", 1, 1, Integer.BYTES, 2, 0.01f, numDocs);
       for(int i=0;i<numDocs;i++) {
         w.add(new byte[Integer.BYTES], i);
       }
@@ -1046,11 +1046,7 @@ public class TestBKD extends LuceneTestCase {
   public void test2DLongOrdsOffline() throws Exception {
     try (Directory dir = newDirectory()) {
       int numDocs = 100000;
-      boolean singleValuePerDoc = false;
-      boolean longOrds = true;
-      int offlineSorterMaxTempFiles = TestUtil.nextInt(random(), 2, 20);
-      BKDWriter w = new BKDWriter(numDocs+1, dir, "tmp", 2, 2, Integer.BYTES, 2, 0.01f, numDocs,
-                                  singleValuePerDoc, longOrds, 1, offlineSorterMaxTempFiles);
+      BKDWriter w = new BKDWriter(numDocs+1, dir, "tmp", 2, 2, Integer.BYTES, 2, 0.01f, numDocs);
       byte[] buffer = new byte[2*Integer.BYTES];
       for(int i=0;i<numDocs;i++) {
         random().nextBytes(buffer);
@@ -1101,7 +1097,7 @@ public class TestBKD extends LuceneTestCase {
 
     Directory dir = newFSDirectory(createTempDir());
     int numDocs = 100000;
-    BKDWriter w = new BKDWriter(numDocs+1, dir, "tmp", numDims, numIndexDims, bytesPerDim, 32, 1f, numDocs, true);
+    BKDWriter w = new BKDWriter(numDocs+1, dir, "tmp", numDims, numIndexDims, bytesPerDim, 32, 1f, numDocs);
     byte[] tmp = new byte[bytesUsed];
     byte[] buffer = new byte[numDims * bytesPerDim];
     for(int i=0;i<numDocs;i++) {
@@ -1159,7 +1155,7 @@ public class TestBKD extends LuceneTestCase {
     random().nextBytes(uniquePointValue);
 
     BKDWriter w = new BKDWriter(numValues, dir, "_temp", 1, 1, numBytesPerDim, maxPointsInLeafNode,
-        BKDWriter.DEFAULT_MAX_MB_SORT_IN_HEAP, numValues, true);
+        BKDWriter.DEFAULT_MAX_MB_SORT_IN_HEAP, numValues);
     for (int i = 0; i < numValues; ++i) {
       if (i == numValues / 2) {
         w.add(uniquePointValue, i);
diff --git a/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKDRadixSelector.java b/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKDRadixSelector.java
new file mode 100644
index 0000000..ca61b02
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/util/bkd/TestBKDRadixSelector.java
@@ -0,0 +1,309 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.util.bkd;
+
+
+import java.io.IOException;
+import java.util.Arrays;
+
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.FutureArrays;
+import org.apache.lucene.util.LuceneTestCase;
+import org.apache.lucene.util.NumericUtils;
+import org.apache.lucene.util.TestUtil;
+
+public class TestBKDRadixSelector extends LuceneTestCase {
+
+  public void testBasic() throws IOException {
+    int values = 4;
+    Directory dir = getDirectory(values);
+    int middle = 2;
+    int dimensions =1;
+    int bytesPerDimensions = Integer.BYTES;
+    int packedLength = dimensions * bytesPerDimensions;
+    PointWriter points = getRandomPointWriter(dir, values, packedLength);
+    byte[] bytes = new byte[Integer.BYTES];
+    NumericUtils.intToSortableBytes(1, bytes, 0);
+    points.append(bytes, 0);
+    NumericUtils.intToSortableBytes(2, bytes, 0);
+    points.append(bytes, 1);
+    NumericUtils.intToSortableBytes(3, bytes, 0);
+    points.append(bytes, 2);
+    NumericUtils.intToSortableBytes(4, bytes, 0);
+    points.append(bytes, 3);
+    points.close();
+    verify(dir, points, dimensions, 0, values, middle, packedLength, bytesPerDimensions, 0);
+    dir.close();
+  }
+
+  public void testRandomBinaryTiny() throws Exception {
+    doTestRandomBinary(10);
+  }
+
+  public void testRandomBinaryMedium() throws Exception {
+    doTestRandomBinary(25000);
+  }
+
+  @Nightly
+  public void testRandomBinaryBig() throws Exception {
+    doTestRandomBinary(500000);
+  }
+
+  private void doTestRandomBinary(int count) throws IOException {
+    int values = TestUtil.nextInt(random(), count, count*2);
+    Directory dir = getDirectory(values);
+    int start;
+    int end;
+    if (random().nextBoolean()) {
+      start = 0;
+      end = values;
+    } else  {
+      start = TestUtil.nextInt(random(), 0, values -3);
+      end = TestUtil.nextInt(random(), start  + 2, values);
+    }
+    int partitionPoint = TestUtil.nextInt(random(), start + 1, end - 1);
+    int sortedOnHeap = random().nextInt(5000);
+    int dimensions =  TestUtil.nextInt(random(), 1, 8);
+    int bytesPerDimensions = TestUtil.nextInt(random(), 2, 30);
+    int packedLength = dimensions * bytesPerDimensions;
+    PointWriter points = getRandomPointWriter(dir, values, packedLength);
+    byte[] value = new byte[packedLength];
+    for (int i =0; i < values; i++) {
+      random().nextBytes(value);
+      points.append(value, i);
+    }
+    points.close();
+    verify(dir, points, dimensions, start, end, partitionPoint, packedLength, bytesPerDimensions, sortedOnHeap);
+    dir.close();
+  }
+
+  public void testRandomAllDimensionsEquals() throws IOException {
+    int values =  TestUtil.nextInt(random(), 15000, 20000);
+    Directory dir = getDirectory(values);
+    int partitionPoint = random().nextInt(values);
+    int sortedOnHeap = random().nextInt(5000);
+    int dimensions =  TestUtil.nextInt(random(), 1, 8);
+    int bytesPerDimensions = TestUtil.nextInt(random(), 2, 30);
+    int packedLength = dimensions * bytesPerDimensions;
+    PointWriter points = getRandomPointWriter(dir, values, packedLength);
+    byte[] value = new byte[packedLength];
+    random().nextBytes(value);
+    for (int i =0; i < values; i++) {
+      if (random().nextBoolean()) {
+        points.append(value, i);
+      } else {
+        points.append(value, random().nextInt(values));
+      }
+    }
+    points.close();
+    verify(dir, points, dimensions, 0, values, partitionPoint, packedLength, bytesPerDimensions, sortedOnHeap);
+    dir.close();
+  }
+
+  public void testRandomLastByteTwoValues() throws IOException {
+    int values = random().nextInt(15000) + 1;
+    Directory dir = getDirectory(values);
+    int partitionPoint = random().nextInt(values);
+    int sortedOnHeap = random().nextInt(5000);
+    int dimensions =  TestUtil.nextInt(random(), 1, 8);
+    int bytesPerDimensions = TestUtil.nextInt(random(), 2, 30);
+    int packedLength = dimensions * bytesPerDimensions;
+    PointWriter points = getRandomPointWriter(dir, values, packedLength);
+    byte[] value = new byte[packedLength];
+    random().nextBytes(value);
+    for (int i =0; i < values; i++) {
+      if (random().nextBoolean()) {
+        points.append(value, 1);
+      } else {
+        points.append(value, 2);
+      }
+    }
+    points.close();
+    verify(dir, points, dimensions, 0, values, partitionPoint, packedLength, bytesPerDimensions, sortedOnHeap);
+    dir.close();
+  }
+
+  public void testRandomAllDocsEquals() throws IOException {
+    int values = random().nextInt(15000) + 1;
+    Directory dir = getDirectory(values);
+    int partitionPoint = random().nextInt(values);
+    int sortedOnHeap = random().nextInt(5000);
+    int dimensions =  TestUtil.nextInt(random(), 1, 8);
+    int bytesPerDimensions = TestUtil.nextInt(random(), 2, 30);
+    int packedLength = dimensions * bytesPerDimensions;
+    PointWriter points = getRandomPointWriter(dir, values, packedLength);
+    byte[] value = new byte[packedLength];
+    random().nextBytes(value);
+    for (int i =0; i < values; i++) {
+      points.append(value, 0);
+    }
+    points.close();
+    verify(dir, points, dimensions, 0, values, partitionPoint, packedLength, bytesPerDimensions, sortedOnHeap);
+    dir.close();
+  }
+
+  public void testRandomFewDifferentValues() throws IOException {
+    int values = atLeast(15000);
+    Directory dir = getDirectory(values);
+    int partitionPoint = random().nextInt(values);
+    int sortedOnHeap = random().nextInt(5000);
+    int dimensions =  TestUtil.nextInt(random(), 1, 8);
+    int bytesPerDimensions = TestUtil.nextInt(random(), 2, 30);
+    int packedLength = dimensions * bytesPerDimensions;
+    PointWriter points = getRandomPointWriter(dir, values, packedLength);
+    int numberValues = random().nextInt(8) + 2;
+    byte[][] differentValues = new byte[numberValues][packedLength];
+    for (int i =0; i < numberValues; i++) {
+      random().nextBytes(differentValues[i]);
+    }
+    for (int i =0; i < values; i++) {
+      points.append(differentValues[random().nextInt(numberValues)], i);
+    }
+    points.close();
+    verify(dir, points, dimensions, 0, values, partitionPoint, packedLength, bytesPerDimensions, sortedOnHeap);
+    dir.close();
+  }
+
+  private void verify(Directory dir, PointWriter points, int dimensions, long start, long end, long middle, int packedLength, int bytesPerDimensions, int sortedOnHeap) throws IOException{
+    for (int splitDim =0; splitDim < dimensions; splitDim++) {
+      PointWriter copy = copyPoints(dir, points, packedLength);
+      PointWriter leftPointWriter = getRandomPointWriter(dir, middle - start, packedLength);
+      PointWriter rightPointWriter = getRandomPointWriter(dir, end - middle, packedLength);
+      BKDRadixSelector radixSelector = new BKDRadixSelector(dimensions, bytesPerDimensions, sortedOnHeap, dir, "test");
+      byte[] partitionPoint = radixSelector.select(copy, leftPointWriter, rightPointWriter, start, end, middle, splitDim);
+      leftPointWriter.close();
+      rightPointWriter.close();
+      byte[] max = getMax(leftPointWriter, middle - start, bytesPerDimensions, splitDim);
+      byte[] min = getMin(rightPointWriter, end - middle, bytesPerDimensions, splitDim);
+      int cmp = FutureArrays.compareUnsigned(max, 0, bytesPerDimensions, min, 0, bytesPerDimensions);
+      assertTrue(cmp <= 0);
+      if (cmp == 0) {
+        int maxDocID = getMaxDocId(leftPointWriter, middle - start, bytesPerDimensions, splitDim, partitionPoint);
+        int minDocId = getMinDocId(rightPointWriter, end - middle, bytesPerDimensions, splitDim, partitionPoint);
+        assertTrue(minDocId >= maxDocID);
+      }
+      assertTrue(Arrays.equals(partitionPoint, min));
+      leftPointWriter.destroy();
+      rightPointWriter.destroy();
+    }
+    points.destroy();
+  }
+
+  private PointWriter copyPoints(Directory dir, PointWriter points, int packedLength) throws IOException {
+    BytesRef bytesRef = new BytesRef();
+
+    try (PointWriter copy  = getRandomPointWriter(dir, points.count(), packedLength);
+         PointReader reader = points.getReader(0, points.count())) {
+      while (reader.next()) {
+        reader.packedValue(bytesRef);
+        copy.append(bytesRef, reader.docID());
+      }
+      return copy;
+    }
+  }
+
+  private PointWriter getRandomPointWriter(Directory dir, long numPoints, int packedBytesLength) throws IOException {
+    if (numPoints < 4096 && random().nextBoolean()) {
+      return new HeapPointWriter(Math.toIntExact(numPoints), Math.toIntExact(numPoints), packedBytesLength);
+    } else {
+      return new OfflinePointWriter(dir, "test", packedBytesLength, "data", numPoints);
+    }
+  }
+
+  private Directory getDirectory(int numPoints) {
+    Directory dir;
+    if (numPoints > 100000) {
+      dir = newFSDirectory(createTempDir("TestBKDTRadixSelector"));
+    } else {
+      dir = newDirectory();
+    }
+    return dir;
+  }
+
+  private byte[] getMin(PointWriter p, long size, int bytesPerDimension, int dimension) throws  IOException {
+    byte[] min = new byte[bytesPerDimension];
+    Arrays.fill(min, (byte) 0xff);
+    try (PointReader reader = p.getReader(0, size)) {
+      byte[] value = new byte[bytesPerDimension];
+      BytesRef packedValue = new BytesRef();
+      while (reader.next()) {
+        reader.packedValue(packedValue);
+        System.arraycopy(packedValue.bytes, packedValue.offset + dimension * bytesPerDimension, value, 0, bytesPerDimension);
+        if (FutureArrays.compareUnsigned(min, 0, bytesPerDimension, value, 0, bytesPerDimension) > 0) {
+          System.arraycopy(value, 0, min, 0, bytesPerDimension);
+        }
+      }
+    }
+    return min;
+  }
+
+  private int getMinDocId(PointWriter p, long size, int bytesPerDimension, int dimension, byte[] partitionPoint) throws  IOException {
+   int docID = Integer.MAX_VALUE;
+    try (PointReader reader = p.getReader(0, size)) {
+      BytesRef packedValue = new BytesRef();
+      while (reader.next()) {
+        reader.packedValue(packedValue);
+        int offset = dimension * bytesPerDimension;
+        if (FutureArrays.compareUnsigned(packedValue.bytes, packedValue.offset + offset, packedValue.offset + offset + bytesPerDimension, partitionPoint, 0, bytesPerDimension) == 0) {
+          int newDocID = reader.docID();
+          if (newDocID < docID) {
+            docID = newDocID;
+          }
+        }
+      }
+    }
+    return docID;
+  }
+
+  private byte[] getMax(PointWriter p, long size, int bytesPerDimension, int dimension) throws  IOException {
+    byte[] max = new byte[bytesPerDimension];
+    Arrays.fill(max, (byte) 0);
+    try (PointReader reader = p.getReader(0, size)) {
+      byte[] value = new byte[bytesPerDimension];
+      BytesRef packedValue = new BytesRef();
+      while (reader.next()) {
+        reader.packedValue(packedValue);
+        System.arraycopy(packedValue.bytes, packedValue.offset + dimension * bytesPerDimension, value, 0, bytesPerDimension);
+        if (FutureArrays.compareUnsigned(max, 0, bytesPerDimension, value, 0, bytesPerDimension) < 0) {
+          System.arraycopy(value, 0, max, 0, bytesPerDimension);
+        }
+      }
+    }
+    return max;
+  }
+
+  private int getMaxDocId(PointWriter p, long size, int bytesPerDimension, int dimension, byte[] partitionPoint) throws  IOException {
+    int docID = Integer.MIN_VALUE;
+    try (PointReader reader = p.getReader(0, size)) {
+      BytesRef packedValue = new BytesRef();
+      while (reader.next()) {
+        reader.packedValue(packedValue);
+        int offset = dimension * bytesPerDimension;
+        if (FutureArrays.compareUnsigned(packedValue.bytes, packedValue.offset + offset, packedValue.offset + offset + bytesPerDimension, partitionPoint, 0, bytesPerDimension) == 0) {
+          int newDocID = reader.docID();
+          if (newDocID > docID) {
+            docID = newDocID;
+          }
+        }
+      }
+    }
+    return docID;
+  }
+
+}
diff --git a/lucene/test-framework/src/java/org/apache/lucene/index/RandomCodec.java b/lucene/test-framework/src/java/org/apache/lucene/index/RandomCodec.java
index b5a728b..ec3c323 100644
--- a/lucene/test-framework/src/java/org/apache/lucene/index/RandomCodec.java
+++ b/lucene/test-framework/src/java/org/apache/lucene/index/RandomCodec.java
@@ -106,7 +106,6 @@ public class RandomCodec extends AssertingCodec {
           public void writeField(FieldInfo fieldInfo, PointsReader reader) throws IOException {
 
             PointValues values = reader.getValues(fieldInfo.name);
-            boolean singleValuePerDoc = values.size() == values.getDocCount();
 
             try (BKDWriter writer = new RandomlySplittingBKDWriter(writeState.segmentInfo.maxDoc(),
                                                                    writeState.directory,
@@ -117,7 +116,6 @@ public class RandomCodec extends AssertingCodec {
                                                                    maxPointsInLeafNode,
                                                                    maxMBSortInHeap,
                                                                    values.size(),
-                                                                   singleValuePerDoc,
                                                                    bkdSplitRandomSeed ^ fieldInfo.name.hashCode())) {
                 values.intersect(new IntersectVisitor() {
                     @Override
@@ -262,12 +260,8 @@ public class RandomCodec extends AssertingCodec {
 
     public RandomlySplittingBKDWriter(int maxDoc, Directory tempDir, String tempFileNamePrefix, int numDataDims, int numIndexDims,
                                       int bytesPerDim, int maxPointsInLeafNode, double maxMBSortInHeap,
-                                      long totalPointCount, boolean singleValuePerDoc, int randomSeed) throws IOException {
-      super(maxDoc, tempDir, tempFileNamePrefix, numDataDims, numIndexDims, bytesPerDim, maxPointsInLeafNode, maxMBSortInHeap, totalPointCount,
-            getRandomSingleValuePerDoc(singleValuePerDoc, randomSeed),
-            getRandomLongOrds(totalPointCount, singleValuePerDoc, randomSeed),
-            getRandomOfflineSorterBufferMB(randomSeed),
-            getRandomOfflineSorterMaxTempFiles(randomSeed));
+                                      long totalPointCount, int randomSeed) throws IOException {
+      super(maxDoc, tempDir, tempFileNamePrefix, numDataDims, numIndexDims, bytesPerDim, maxPointsInLeafNode, maxMBSortInHeap, totalPointCount);
       this.random = new Random(randomSeed);
     }