You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by iv...@apache.org on 2020/05/11 17:16:19 UTC

[lucene-solr] branch branch_8x updated: LUCENE-9358: remove unnecessary tree rotation for the one dimensional case (#1481)

This is an automated email from the ASF dual-hosted git repository.

ivera pushed a commit to branch branch_8x
in repository https://gitbox.apache.org/repos/asf/lucene-solr.git


The following commit(s) were added to refs/heads/branch_8x by this push:
     new b77d693  LUCENE-9358: remove unnecessary tree rotation for the one dimensional case (#1481)
b77d693 is described below

commit b77d693c807e6afb9063bbf67dafc63abf28f121
Author: Ignacio Vera <iv...@apache.org>
AuthorDate: Mon May 11 18:39:15 2020 +0200

    LUCENE-9358: remove unnecessary tree rotation for the one dimensional case (#1481)
    
    Change the way the multi-dimensional BKD tree builder generates the intermediate tree representation to be equal to the one dimensional case to avoid unnecessary tree and leaves rotation
---
 lucene/CHANGES.txt                                 |   3 +
 .../java/org/apache/lucene/util/bkd/BKDWriter.java | 310 ++++++++++++---------
 2 files changed, 177 insertions(+), 136 deletions(-)

diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index e74b468..63fcd87 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -143,6 +143,9 @@ Other
 * LUCENE-9338: Refactors SimpleBindings to improve type safety and cycle detection (Alan Woodward,
   Adrien Grand)
 
+* LUCENE-9358: Change the way the multi-dimensional BKD tree builder generates the intermediate tree representation to be
+  equal to the one dimensional case to avoid unnecessary tree and leaves rotation. (Ignacio Vera)
+
 Build
 
 * Upgrade forbiddenapis to version 3.0.  (Uwe Schindler)
diff --git a/lucene/core/src/java/org/apache/lucene/util/bkd/BKDWriter.java b/lucene/core/src/java/org/apache/lucene/util/bkd/BKDWriter.java
index 8118837..16ba5db 100644
--- a/lucene/core/src/java/org/apache/lucene/util/bkd/BKDWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/util/bkd/BKDWriter.java
@@ -71,7 +71,7 @@ import org.apache.lucene.util.PriorityQueue;
  *  {@code maxMBSortInHeap} heap space for writing.
  *
  *  <p>
- *  <b>NOTE</b>: This can write at most Integer.MAX_VALUE * <code>maxPointsInLeafNode</code> / (1+bytesPerDim)
+ *  <b>NOTE</b>: This can write at most Integer.MAX_VALUE * <code>maxPointsInLeafNode</code> / bytesPerDim
  *  total points.
  *
  * @lucene.experimental */
@@ -375,6 +375,22 @@ public class BKDWriter implements Closeable {
     }
   }
 
+  /** flat representation of a kd-tree */
+  private interface BKDTreeLeafNodes {
+    /** number of leaf nodes */
+    int numLeaves();
+    /** pointer to the leaf node previously written. Leaves are ordered from
+     * left to right, so the leaf at {@code index} 0 is the leftmost leaf and
+     * the leaf at {@code numLeaves() - 1} is the rightmost leaf */
+    long getLeafLP(int index);
+    /** split value between two leaves. The split value at position n separates the
+     *  leaves at n and (n + 1). */
+    BytesRef getSplitValue(int index);
+    /** split dimension between two leaves. The split dimension at position n corresponds to the
+     *  split between the leaves at n and (n + 1). */
+    int getSplitDimension(int index);
+  }
+
   /** Write a field from a {@link MutablePointValues}. This way of writing
    *  points is faster than regular writes with {@link BKDWriter#add} since
    *  there is opportunity for reordering points before writing them to
@@ -427,10 +443,12 @@ public class BKDWriter implements Closeable {
     pointCount = values.size();
 
     final int numLeaves = Math.toIntExact((pointCount + maxPointsInLeafNode - 1) / maxPointsInLeafNode);
+    final int numSplits = numLeaves - 1;
 
     checkMaxLeafNodeCount(numLeaves);
 
-    final byte[] splitPackedValues = new byte[numLeaves * (bytesPerDim + 1)];
+    final byte[] splitPackedValues = new byte[numSplits * bytesPerDim];
+    final byte[] splitDimensionValues = new byte[numSplits];
     final long[] leafBlockFPs = new long[numLeaves];
 
     // compute the min/max for this slice
@@ -440,14 +458,40 @@ public class BKDWriter implements Closeable {
     }
 
     final int[] parentSplits = new int[numIndexDims];
-    build(1, 0, numLeaves, values, 0, Math.toIntExact(pointCount), out,
+    build(0, numLeaves, values, 0, Math.toIntExact(pointCount), out,
             minPackedValue.clone(), maxPackedValue.clone(), parentSplits,
-            splitPackedValues, leafBlockFPs,
+            splitPackedValues, splitDimensionValues, leafBlockFPs,
             new int[maxPointsInLeafNode]);
     assert Arrays.equals(parentSplits, new int[numIndexDims]);
 
+    scratchBytesRef1.length = bytesPerDim;
+    scratchBytesRef1.bytes = splitPackedValues;
+
+    BKDTreeLeafNodes leafNodes  = new BKDTreeLeafNodes() {
+      @Override
+      public long getLeafLP(int index) {
+        return leafBlockFPs[index];
+      }
+
+      @Override
+      public BytesRef getSplitValue(int index) {
+        scratchBytesRef1.offset = index * bytesPerDim;
+        return scratchBytesRef1;
+      }
+
+      @Override
+      public int getSplitDimension(int index) {
+        return splitDimensionValues[index] & 0xff;
+      }
+
+      @Override
+      public int numLeaves() {
+        return leafBlockFPs.length;
+      }
+    };
+
     long indexFP = out.getFilePointer();
-    writeIndex(out, maxPointsInLeafNode, leafBlockFPs, splitPackedValues);
+    writeIndex(out, maxPointsInLeafNode, leafNodes);
     return indexFP;
   }
 
@@ -601,17 +645,32 @@ public class BKDWriter implements Closeable {
 
       long indexFP = out.getFilePointer();
 
-      int numInnerNodes = leafBlockStartValues.size();
+      scratchBytesRef1.length = bytesPerDim;
+      scratchBytesRef1.offset = 0;
+      assert leafBlockStartValues.size() + 1 == leafBlockFPs.size();
+      BKDTreeLeafNodes leafNodes = new BKDTreeLeafNodes() {
+        @Override
+        public long getLeafLP(int index) {
+          return leafBlockFPs.get(index);
+        }
 
-      //System.out.println("BKDW: now rotate numInnerNodes=" + numInnerNodes + " leafBlockStarts=" + leafBlockStartValues.size());
+        @Override
+        public BytesRef getSplitValue(int index) {
+          scratchBytesRef1.bytes = leafBlockStartValues.get(index);
+          return scratchBytesRef1;
+        }
 
-      byte[] index = new byte[(1+numInnerNodes) * (1+bytesPerDim)];
-      rotateToTree(1, 0, numInnerNodes, index, leafBlockStartValues);
-      long[] arr = new long[leafBlockFPs.size()];
-      for(int i=0;i<leafBlockFPs.size();i++) {
-        arr[i] = leafBlockFPs.get(i);
-      }
-      writeIndex(out, maxPointsInLeafNode, arr, index);
+        @Override
+        public int getSplitDimension(int index) {
+          return 0;
+        }
+
+        @Override
+        public int numLeaves() {
+          return leafBlockFPs.size();
+        }
+      };
+      writeIndex(out, maxPointsInLeafNode, leafNodes);
       return indexFP;
     }
 
@@ -663,28 +722,6 @@ public class BKDWriter implements Closeable {
     }
   }
 
-  private void rotateToTree(int nodeID, int offset, int numNodes, byte[] index, List<byte[]> leafBlockStartValues) {
-    if (numNodes == 1) {
-      // Leaf index node
-      System.arraycopy(leafBlockStartValues.get(offset), 0, index, nodeID*(1+bytesPerDim)+1, bytesPerDim);
-    } else if (numNodes > 1) {
-      // Internal index node
-      // numNodes + 1 is the number of leaves
-      // -1 because there is one less inner node
-      int leftHalf = getNumLeftLeafNodes(numNodes + 1) - 1;
-      int rootOffset = offset + leftHalf;
-
-      System.arraycopy(leafBlockStartValues.get(rootOffset), 0, index, nodeID*(1+bytesPerDim)+1, bytesPerDim);
-
-      // Recurse left
-      rotateToTree(2*nodeID, offset, leftHalf, index, leafBlockStartValues);
-      // Recurse right
-      rotateToTree(2*nodeID+1, rootOffset+1, numNodes-leftHalf-1, index, leafBlockStartValues);
-    } else {
-      assert numNodes == 0;
-    }
-  }
-
   private int getNumLeftLeafNodes(int numLeaves) {
     assert numLeaves > 1: "getNumLeftLeaveNodes() called with " + numLeaves;
     // return the level that can be filled with this number of leaves
@@ -724,7 +761,7 @@ public class BKDWriter implements Closeable {
   */
 
   private void checkMaxLeafNodeCount(int numLeaves) {
-    if ((1+bytesPerDim) * (long) numLeaves > ArrayUtil.MAX_ARRAY_LENGTH) {
+    if (bytesPerDim * (long) numLeaves > ArrayUtil.MAX_ARRAY_LENGTH) {
       throw new IllegalStateException("too many nodes; increase maxPointsInLeafNode (currently " + maxPointsInLeafNode + ") and reindex");
     }
   }
@@ -755,6 +792,7 @@ public class BKDWriter implements Closeable {
 
 
     final int numLeaves = Math.toIntExact((pointCount + maxPointsInLeafNode - 1) / maxPointsInLeafNode);
+    final int numSplits = numLeaves - 1;
 
     checkMaxLeafNodeCount(numLeaves);
 
@@ -762,7 +800,8 @@ public class BKDWriter implements Closeable {
     // step of the recursion to recompute the split dim:
 
     // Indexed by nodeID, but first (root) nodeID is 1.  We do 1+ because the lead byte at each recursion says which dim we split on.
-    byte[] splitPackedValues = new byte[Math.toIntExact(numLeaves*(1+bytesPerDim))];
+    byte[] splitPackedValues = new byte[Math.toIntExact(numSplits*bytesPerDim)];
+    byte[] splitDimensionValues = new byte[numSplits];
 
     // +1 because leaf count is power of 2 (e.g. 8), and innerNodeCount is power of 2 minus 1 (e.g. 7)
     long[] leafBlockFPs = new long[numLeaves];
@@ -777,11 +816,12 @@ public class BKDWriter implements Closeable {
     try {
 
       final int[] parentSplits = new int[numIndexDims];
-      build(1, 0, numLeaves, points,
+      build(0, numLeaves, points,
               out, radixSelector,
               minPackedValue.clone(), maxPackedValue.clone(),
               parentSplits,
               splitPackedValues,
+              splitDimensionValues,
               leafBlockFPs,
               new int[maxPointsInLeafNode]);
       assert Arrays.equals(parentSplits, new int[numIndexDims]);
@@ -798,33 +838,39 @@ public class BKDWriter implements Closeable {
       }
     }
 
-    //System.out.println("Total nodes: " + innerNodeCount);
+    scratchBytesRef1.bytes = splitPackedValues;
+    scratchBytesRef1.length = bytesPerDim;
+    BKDTreeLeafNodes leafNodes  = new BKDTreeLeafNodes() {
+      @Override
+      public long getLeafLP(int index) {
+        return leafBlockFPs[index];
+      }
+
+      @Override
+      public BytesRef getSplitValue(int index) {
+        scratchBytesRef1.offset = index * bytesPerDim;
+        return scratchBytesRef1;
+      }
+
+      @Override
+      public int getSplitDimension(int index) {
+        return splitDimensionValues[index] & 0xff;
+      }
+
+      @Override
+      public int numLeaves() {
+        return leafBlockFPs.length;
+      }
+    };
 
     // Write index:
     long indexFP = out.getFilePointer();
-    writeIndex(out, maxPointsInLeafNode, leafBlockFPs, splitPackedValues);
+    writeIndex(out, maxPointsInLeafNode, leafNodes);
     return indexFP;
   }
 
   /** Packs the two arrays, representing a semi-balanced binary tree, into a compact byte[] structure. */
-  private byte[] packIndex(long[] leafBlockFPs, byte[] splitPackedValues) throws IOException {
-
-    int numLeaves = leafBlockFPs.length;
-    // Possibly rotate the leaf block FPs, if the index not fully balanced binary tree.
-    // In this case the leaf nodes may straddle the two bottom
-    // levels of the binary tree:
-    int lastFullLevel = 31 - Integer.numberOfLeadingZeros(numLeaves);
-    int leavesFullLevel = 1 << lastFullLevel;
-    int leavesPartialLevel = 2 * (numLeaves - leavesFullLevel);
-    if (leavesPartialLevel != 0) {
-      // Last level is partially filled, so we must rotate the leaf FPs to match.  We do this here, after loading
-      // at read-time, so that we can still delta code them on disk at write:
-      long[] newLeafBlockFPs = new long[numLeaves];
-      System.arraycopy(leafBlockFPs, leavesPartialLevel, newLeafBlockFPs, 0, numLeaves - leavesPartialLevel);
-      System.arraycopy(leafBlockFPs, 0, newLeafBlockFPs, numLeaves - leavesPartialLevel, leavesPartialLevel);
-      leafBlockFPs = newLeafBlockFPs;
-    }
-
+  private byte[] packIndex(BKDTreeLeafNodes leafNodes) throws IOException {
     /** Reused while packing the index */
     RAMOutputStream writeBuffer = new RAMOutputStream();
 
@@ -832,7 +878,8 @@ public class BKDWriter implements Closeable {
     List<byte[]> blocks = new ArrayList<>();
     byte[] lastSplitValues = new byte[bytesPerDim * numIndexDims];
     //System.out.println("\npack index");
-    int totalSize = recursePackIndex(writeBuffer, leafBlockFPs, splitPackedValues, 0l, blocks, 1, lastSplitValues, new boolean[numIndexDims], false);
+    int totalSize = recursePackIndex(writeBuffer, leafNodes, 0l, blocks, lastSplitValues, new boolean[numIndexDims], false,
+            0, leafNodes.numLeaves());
 
     // Compact the byte[] blocks into single byte index:
     byte[] index = new byte[totalSize];
@@ -859,45 +906,43 @@ public class BKDWriter implements Closeable {
   /**
    * lastSplitValues is per-dimension split value previously seen; we use this to prefix-code the split byte[] on each inner node
    */
-  private int recursePackIndex(RAMOutputStream writeBuffer, long[] leafBlockFPs, byte[] splitPackedValues, long minBlockFP, List<byte[]> blocks,
-                               int nodeID, byte[] lastSplitValues, boolean[] negativeDeltas, boolean isLeft) throws IOException {
-    if (nodeID >= leafBlockFPs.length) {
-      int leafID = nodeID - leafBlockFPs.length;
-      //System.out.println("recursePack leaf nodeID=" + nodeID);
-
-      // In the unbalanced case it's possible the left most node only has one child:
-      if (leafID < leafBlockFPs.length) {
-        long delta = leafBlockFPs[leafID] - minBlockFP;
-        if (isLeft) {
-          assert delta == 0;
-          return 0;
-        } else {
-          assert nodeID == 1 || delta > 0: "nodeID=" + nodeID;
-          writeBuffer.writeVLong(delta);
-          return appendBlock(writeBuffer, blocks);
-        }
-      } else {
+  private int recursePackIndex(RAMOutputStream writeBuffer, BKDTreeLeafNodes leafNodes, long minBlockFP, List<byte[]> blocks,
+                               byte[] lastSplitValues, boolean[] negativeDeltas, boolean isLeft, int leavesOffset, int numLeaves) throws IOException {
+    if (numLeaves == 1) {
+      if (isLeft) {
+        assert leafNodes.getLeafLP(leavesOffset) - minBlockFP == 0;
         return 0;
+      } else {
+        long delta = leafNodes.getLeafLP(leavesOffset) - minBlockFP;
+        assert leafNodes.numLeaves() == numLeaves || delta > 0 : "expected delta > 0; got numLeaves =" + numLeaves + " and delta=" + delta;
+        writeBuffer.writeVLong(delta);
+        return appendBlock(writeBuffer, blocks);
       }
     } else {
       long leftBlockFP;
-      if (isLeft == false) {
-        leftBlockFP = getLeftMostLeafBlockFP(leafBlockFPs, nodeID);
-        long delta = leftBlockFP - minBlockFP;
-        assert nodeID == 1 || delta > 0 : "expected nodeID=1 or delta > 0; got nodeID=" + nodeID + " and delta=" + delta;
-        writeBuffer.writeVLong(delta);
-      } else {
+      if (isLeft) {
         // The left tree's left most leaf block FP is always the minimal FP:
+        assert leafNodes.getLeafLP(leavesOffset) == minBlockFP;
         leftBlockFP = minBlockFP;
+      } else {
+        leftBlockFP = leafNodes.getLeafLP(leavesOffset);
+        long delta = leftBlockFP - minBlockFP;
+        assert leafNodes.numLeaves() == numLeaves || delta > 0 : "expected delta > 0; got numLeaves =" + numLeaves + " and delta=" + delta;
+        writeBuffer.writeVLong(delta);
       }
 
-      int address = nodeID * (1+bytesPerDim);
-      int splitDim = splitPackedValues[address++] & 0xff;
+      int numLeftLeafNodes = getNumLeftLeafNodes(numLeaves);
+      final int rightOffset = leavesOffset + numLeftLeafNodes;
+      final int splitOffset = rightOffset - 1;
+
+      int splitDim = leafNodes.getSplitDimension(splitOffset);
+      BytesRef splitValue = leafNodes.getSplitValue(splitOffset);
+      int address = splitValue.offset;
 
       //System.out.println("recursePack inner nodeID=" + nodeID + " splitDim=" + splitDim + " splitValue=" + new BytesRef(splitPackedValues, address, bytesPerDim));
 
       // find common prefix with last split value in this dim:
-      int prefix = FutureArrays.mismatch(splitPackedValues, address, address + bytesPerDim, lastSplitValues,
+      int prefix = FutureArrays.mismatch(splitValue.bytes, address, address + bytesPerDim, lastSplitValues,
               splitDim * bytesPerDim, splitDim * bytesPerDim + bytesPerDim);
       if (prefix == -1) {
         prefix = bytesPerDim;
@@ -908,7 +953,7 @@ public class BKDWriter implements Closeable {
       int firstDiffByteDelta;
       if (prefix < bytesPerDim) {
         //System.out.println("  delta byte cur=" + Integer.toHexString(splitPackedValues[address+prefix]&0xFF) + " prev=" + Integer.toHexString(lastSplitValues[splitDim * bytesPerDim + prefix]&0xFF) + " negated?=" + negativeDeltas[splitDim]);
-        firstDiffByteDelta = (splitPackedValues[address+prefix]&0xFF) - (lastSplitValues[splitDim * bytesPerDim + prefix]&0xFF);
+        firstDiffByteDelta = (splitValue.bytes[address+prefix]&0xFF) - (lastSplitValues[splitDim * bytesPerDim + prefix]&0xFF);
         if (negativeDeltas[splitDim]) {
           firstDiffByteDelta = -firstDiffByteDelta;
         }
@@ -930,7 +975,7 @@ public class BKDWriter implements Closeable {
       int suffix = bytesPerDim - prefix;
       byte[] savSplitValue = new byte[suffix];
       if (suffix > 1) {
-        writeBuffer.writeBytes(splitPackedValues, address+prefix+1, suffix-1);
+        writeBuffer.writeBytes(splitValue.bytes, address+prefix+1, suffix-1);
       }
 
       byte[] cmp = lastSplitValues.clone();
@@ -938,7 +983,7 @@ public class BKDWriter implements Closeable {
       System.arraycopy(lastSplitValues, splitDim * bytesPerDim + prefix, savSplitValue, 0, suffix);
 
       // copy our split value into lastSplitValues for our children to prefix-code against
-      System.arraycopy(splitPackedValues, address+prefix, lastSplitValues, splitDim * bytesPerDim + prefix, suffix);
+      System.arraycopy(splitValue.bytes, address+prefix, lastSplitValues, splitDim * bytesPerDim + prefix, suffix);
 
       int numBytes = appendBlock(writeBuffer, blocks);
 
@@ -950,13 +995,16 @@ public class BKDWriter implements Closeable {
       boolean savNegativeDelta = negativeDeltas[splitDim];
       negativeDeltas[splitDim] = true;
 
-      int leftNumBytes = recursePackIndex(writeBuffer, leafBlockFPs, splitPackedValues, leftBlockFP, blocks, 2*nodeID, lastSplitValues, negativeDeltas, true);
 
-      if (nodeID * 2 < leafBlockFPs.length) {
+      int leftNumBytes = recursePackIndex(writeBuffer, leafNodes, leftBlockFP, blocks, lastSplitValues, negativeDeltas, true,
+              leavesOffset, numLeftLeafNodes);
+
+      if (numLeftLeafNodes != 1) {
         writeBuffer.writeVInt(leftNumBytes);
       } else {
         assert leftNumBytes == 0: "leftNumBytes=" + leftNumBytes;
       }
+
       int numBytes2 = Math.toIntExact(writeBuffer.getFilePointer());
       byte[] bytes2 = new byte[numBytes2];
       writeBuffer.writeTo(bytes2, 0);
@@ -965,7 +1013,8 @@ public class BKDWriter implements Closeable {
       blocks.set(idxSav, bytes2);
 
       negativeDeltas[splitDim] = false;
-      int rightNumBytes = recursePackIndex(writeBuffer, leafBlockFPs, splitPackedValues, leftBlockFP, blocks, 2*nodeID+1, lastSplitValues, negativeDeltas, false);
+      int rightNumBytes = recursePackIndex(writeBuffer,  leafNodes, leftBlockFP, blocks, lastSplitValues, negativeDeltas, false,
+              rightOffset, numLeaves - numLeftLeafNodes);
 
       negativeDeltas[splitDim] = savNegativeDelta;
 
@@ -974,31 +1023,13 @@ public class BKDWriter implements Closeable {
 
       assert Arrays.equals(lastSplitValues, cmp);
 
-      return numBytes + numBytes2 + leftNumBytes + rightNumBytes;
-    }
-  }
-
-  private long getLeftMostLeafBlockFP(long[] leafBlockFPs, int nodeID) {
-    // TODO: can we do this cheaper, e.g. a closed form solution instead of while loop?  Or
-    // change the recursion while packing the index to return this left-most leaf block FP
-    // from each recursion instead?
-    //
-    // Still, the overall cost here is minor: this method's cost is O(log(N)), and while writing
-    // we call it O(N) times (N = number of leaf blocks)
-    while (nodeID < leafBlockFPs.length) {
-      nodeID *= 2;
+      return numBytes + bytes2.length + leftNumBytes + rightNumBytes;
     }
-    int leafID = nodeID - leafBlockFPs.length;
-    long result = leafBlockFPs[leafID];
-    if (result < 0) {
-      throw new AssertionError(result + " for leaf " + leafID);
-    }
-    return result;
   }
 
-  private void writeIndex(IndexOutput out, int countPerLeaf, long[] leafBlockFPs, byte[] splitPackedValues) throws IOException {
-    byte[] packedIndex = packIndex(leafBlockFPs, splitPackedValues);
-    writeIndex(out, countPerLeaf, leafBlockFPs.length, packedIndex);
+  private void writeIndex(IndexOutput out, int countPerLeaf, BKDTreeLeafNodes leafNodes) throws IOException {
+    byte[] packedIndex = packIndex(leafNodes);
+    writeIndex(out, countPerLeaf, leafNodes.numLeaves(), packedIndex);
   }
 
   private void writeIndex(IndexOutput out, int countPerLeaf, int numLeaves, byte[] packedIndex) throws IOException {
@@ -1292,12 +1323,13 @@ public class BKDWriter implements Closeable {
 
   /* Recursively reorders the provided reader and writes the bkd-tree on the fly; this method is used
    * when we are writing a new segment directly from IndexWriter's indexing buffer (MutablePointsReader). */
-  private void build(int nodeID, int leavesOffset, int numLeaves,
+  private void build(int leavesOffset, int numLeaves,
                      MutablePointValues reader, int from, int to,
                      IndexOutput out,
                      byte[] minPackedValue, byte[] maxPackedValue,
                      int[] parentSplits,
                      byte[] splitPackedValues,
+                     byte[] splitDimensionValues,
                      long[] leafBlockFPs,
                      int[] spareDocIds) throws IOException {
 
@@ -1418,7 +1450,7 @@ public class BKDWriter implements Closeable {
         // for dimensions > 2 we recompute the bounds for the current inner node to help the algorithm choose best
         // split dimensions. Because it is an expensive operation, the frequency we recompute the bounds is given
         // by SPLITS_BEFORE_EXACT_BOUNDS.
-        if (nodeID > 1 && numIndexDims > 2 && Arrays.stream(parentSplits).sum() % SPLITS_BEFORE_EXACT_BOUNDS == 0) {
+        if (numLeaves != leafBlockFPs.length && numIndexDims > 2 && Arrays.stream(parentSplits).sum() % SPLITS_BEFORE_EXACT_BOUNDS == 0) {
           computePackedValueBounds(reader, from, to, minPackedValue, maxPackedValue, scratchBytesRef1);
         }
         splitDim = split(minPackedValue, maxPackedValue, parentSplits);
@@ -1439,11 +1471,13 @@ public class BKDWriter implements Closeable {
       MutablePointsReaderUtils.partition(numDataDims, numIndexDims, maxDoc, splitDim, bytesPerDim, commonPrefixLen,
               reader, from, to, mid, scratchBytesRef1, scratchBytesRef2);
 
+      final int rightOffset = leavesOffset + numLeftLeafNodes;
+      final int splitOffset = rightOffset - 1;
       // set the split value
-      final int address = nodeID * (1+bytesPerDim);
-      splitPackedValues[address] = (byte) splitDim;
+      final int address = splitOffset * bytesPerDim;
+      splitDimensionValues[splitOffset] = (byte) splitDim;
       reader.getValue(mid, scratchBytesRef1);
-      System.arraycopy(scratchBytesRef1.bytes, scratchBytesRef1.offset + splitDim * bytesPerDim, splitPackedValues, address + 1, bytesPerDim);
+      System.arraycopy(scratchBytesRef1.bytes, scratchBytesRef1.offset + splitDim * bytesPerDim, splitPackedValues, address, bytesPerDim);
 
       byte[] minSplitPackedValue = ArrayUtil.copyOfSubArray(minPackedValue, 0, packedIndexBytesLength);
       byte[] maxSplitPackedValue = ArrayUtil.copyOfSubArray(maxPackedValue, 0, packedIndexBytesLength);
@@ -1454,12 +1488,12 @@ public class BKDWriter implements Closeable {
 
       // recurse
       parentSplits[splitDim]++;
-      build(nodeID * 2, leavesOffset, numLeftLeafNodes, reader, from, mid, out,
+      build(leavesOffset, numLeftLeafNodes, reader, from, mid, out,
               minPackedValue, maxSplitPackedValue, parentSplits,
-              splitPackedValues, leafBlockFPs, spareDocIds);
-      build(nodeID * 2 + 1, leavesOffset + numLeftLeafNodes, numLeaves - numLeftLeafNodes, reader, mid, to, out,
+              splitPackedValues, splitDimensionValues, leafBlockFPs, spareDocIds);
+      build(rightOffset, numLeaves - numLeftLeafNodes, reader, mid, to, out,
               minSplitPackedValue, maxPackedValue, parentSplits,
-              splitPackedValues, leafBlockFPs, spareDocIds);
+              splitPackedValues, splitDimensionValues, leafBlockFPs, spareDocIds);
       parentSplits[splitDim]--;
     }
   }
@@ -1489,13 +1523,14 @@ public class BKDWriter implements Closeable {
 
   /** The point writer contains the data that is going to be splitted using radix selection.
    /*  This method is used when we are merging previously written segments, in the numDims > 1 case. */
-  private void build(int nodeID, int leavesOffset, int numLeaves,
+  private void build(int leavesOffset, int numLeaves,
                      BKDRadixSelector.PathSlice points,
                      IndexOutput out,
                      BKDRadixSelector radixSelector,
                      byte[] minPackedValue, byte[] maxPackedValue,
                      int[] parentSplits,
                      byte[] splitPackedValues,
+                     byte[] splitDimensionValues,
                      long[] leafBlockFPs,
                      int[] spareDocIds) throws IOException {
 
@@ -1556,7 +1591,7 @@ public class BKDWriter implements Closeable {
       // Write docIDs first, as their own chunk, so that at intersect time we can add all docIDs w/o
       // loading the values:
       int count = to - from;
-      assert count > 0: "nodeID=" + nodeID + " leavesOffset=" + leavesOffset;
+      assert count > 0: "numLeaves=" + numLeaves + " leavesOffset=" + leavesOffset;
       assert count <= spareDocIds.length : "count=" + count + " > length=" + spareDocIds.length;
       // Write doc IDs
       int[] docIDs = spareDocIds;
@@ -1599,13 +1634,13 @@ public class BKDWriter implements Closeable {
         // for dimensions > 2 we recompute the bounds for the current inner node to help the algorithm choose best
         // split dimensions. Because it is an expensive operation, the frequency we recompute the bounds is given
         // by SPLITS_BEFORE_EXACT_BOUNDS.
-        if (nodeID > 1 && numIndexDims > 2 && Arrays.stream(parentSplits).sum() % SPLITS_BEFORE_EXACT_BOUNDS == 0) {
+        if (numLeaves != leafBlockFPs.length && numIndexDims > 2 && Arrays.stream(parentSplits).sum() % SPLITS_BEFORE_EXACT_BOUNDS == 0) {
           computePackedValueBounds(points, minPackedValue, maxPackedValue);
         }
         splitDim = split(minPackedValue, maxPackedValue, parentSplits);
       }
 
-      assert nodeID < splitPackedValues.length : "nodeID=" + nodeID + " splitValues.length=" + splitPackedValues.length;
+      assert numLeaves <= leafBlockFPs.length : "numLeaves=" + numLeaves + " leafBlockFPs.length=" + leafBlockFPs.length;
 
       // How many leaves will be in the left tree:
       int numLeftLeafNodes = getNumLeftLeafNodes(numLeaves);
@@ -1623,9 +1658,12 @@ public class BKDWriter implements Closeable {
 
       byte[] splitValue = radixSelector.select(points, slices, points.start, points.start + points.count,  points.start + leftCount, splitDim, commonPrefixLen);
 
-      int address = nodeID * (1 + bytesPerDim);
-      splitPackedValues[address] = (byte) splitDim;
-      System.arraycopy(splitValue, 0, splitPackedValues, address + 1, bytesPerDim);
+      final int rightOffset = leavesOffset + numLeftLeafNodes;
+      final int splitValueOffset = rightOffset - 1;
+
+      splitDimensionValues[splitValueOffset] = (byte) splitDim;
+      int address = splitValueOffset * bytesPerDim;
+      System.arraycopy(splitValue, 0, splitPackedValues, address, bytesPerDim);
 
       byte[] minSplitPackedValue = new byte[packedIndexBytesLength];
       System.arraycopy(minPackedValue, 0, minSplitPackedValue, 0, packedIndexBytesLength);
@@ -1638,14 +1676,14 @@ public class BKDWriter implements Closeable {
 
       parentSplits[splitDim]++;
       // Recurse on left tree:
-      build(2 * nodeID, leavesOffset, numLeftLeafNodes, slices[0],
+      build(leavesOffset, numLeftLeafNodes, slices[0],
               out, radixSelector, minPackedValue, maxSplitPackedValue,
-              parentSplits, splitPackedValues, leafBlockFPs, spareDocIds);
+              parentSplits, splitPackedValues, splitDimensionValues, leafBlockFPs, spareDocIds);
 
       // Recurse on right tree:
-      build(2 * nodeID + 1, leavesOffset + numLeftLeafNodes, numLeaves - numLeftLeafNodes, slices[1],
-              out, radixSelector, minSplitPackedValue, maxPackedValue
-              , parentSplits, splitPackedValues, leafBlockFPs, spareDocIds);
+      build(rightOffset, numLeaves - numLeftLeafNodes, slices[1],
+              out, radixSelector, minSplitPackedValue, maxPackedValue,
+              parentSplits, splitPackedValues, splitDimensionValues, leafBlockFPs, spareDocIds);
 
       parentSplits[splitDim]--;
     }