You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by br...@apache.org on 2021/11/17 10:51:40 UTC

[lucene] branch branch_9x updated: LUCENE-10225: Improve IntroSelector with 3-way partitioning.

This is an automated email from the ASF dual-hosted git repository.

broustant pushed a commit to branch branch_9x
in repository https://gitbox.apache.org/repos/asf/lucene.git


The following commit(s) were added to refs/heads/branch_9x by this push:
     new 02a63f6  LUCENE-10225: Improve IntroSelector with 3-way partitioning.
02a63f6 is described below

commit 02a63f688c88f0c7549844c34b940ad9ff83058b
Author: Bruno Roustant <33...@users.noreply.github.com>
AuthorDate: Wed Nov 17 10:38:27 2021 +0100

    LUCENE-10225: Improve IntroSelector with 3-way partitioning.
---
 lucene/CHANGES.txt                                 |   3 +-
 .../java/org/apache/lucene/util/IntroSelector.java | 288 ++++++++++++---------
 .../java/org/apache/lucene/util/IntroSorter.java   |  15 +-
 .../src/java/org/apache/lucene/util/MathUtil.java  |   5 +-
 .../org/apache/lucene/util/SelectorBenchmark.java  | 126 +++++++++
 .../org/apache/lucene/util/TestIntroSelector.java  |  34 ++-
 6 files changed, 317 insertions(+), 154 deletions(-)

diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index bf2f255..5336f7d 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -25,7 +25,8 @@ Improvements
 
 Optimizations
 ---------------------
-(No changes)
+
+* LUCENE-10225: Improve IntroSelector with 3-way partitioning. (Bruno Roustant, Adrien Grand)
 
 Bug Fixes
 ---------------------
diff --git a/lucene/core/src/java/org/apache/lucene/util/IntroSelector.java b/lucene/core/src/java/org/apache/lucene/util/IntroSelector.java
index 6d00fbc..2ade7ab 100644
--- a/lucene/core/src/java/org/apache/lucene/util/IntroSelector.java
+++ b/lucene/core/src/java/org/apache/lucene/util/IntroSelector.java
@@ -17,173 +17,200 @@
 package org.apache.lucene.util;
 
 import java.util.Comparator;
+import java.util.SplittableRandom;
 
 /**
- * Implementation of the quick select algorithm.
+ * Adaptive selection algorithm based on the introspective quick select algorithm. The quick select
+ * algorithm uses an interpolation variant of Tukey's ninther median-of-medians for pivot selection,
+ * and Bentley-McIlroy 3-way partitioning. As introspective protection, it shuffles the sub-range
+ * once if the max recursion depth is exceeded.
  *
- * <p>It uses the median of the first, middle and last values as a pivot and falls back to a median
- * of medians when the number of recursion levels exceeds {@code 2 lg(n)}, as a consequence it runs
- * in linear time on average.
+ * <p>This selection algorithm is fast on most data shapes, especially on nearly sorted data, or
+ * when k is close to the boundaries. It runs in linear time on average.
  *
  * @lucene.internal
  */
 public abstract class IntroSelector extends Selector {
 
+  // This selector is used repeatedly by the radix selector for sub-ranges of fewer than
+  // 100 entries. This means this selector is also optimized to be fast on small ranges.
+  // It uses a variant of median-of-medians and 3-way partitioning, and finishes the
+  // last tiny range (3 entries or less) with a very specialized sort.
+
+  private SplittableRandom random;
+
   @Override
   public final void select(int from, int to, int k) {
     checkArgs(from, to, k);
-    final int maxDepth = 2 * MathUtil.log(to - from, 2);
-    quickSelect(from, to, k, maxDepth);
+    select(from, to, k, 2 * MathUtil.log(to - from, 2));
   }
 
-  int slowSelect(int from, int to, int k) {
-    return medianOfMediansSelect(from, to - 1, k);
-  }
+  // Visible for testing.
+  void select(int from, int to, int k, int maxDepth) {
+    // This code is inspired by IntroSorter#sort, adapted to loop on a single partition.
+
+    // For efficiency, we must enter the loop with at least 4 entries to be able to skip
+    // some boundary tests during the 3-way partitioning.
+    int size;
+    while ((size = to - from) > 3) {
 
-  int medianOfMediansSelect(int left, int right, int k) {
-    do {
-      // Defensive check, this is also checked in the calling
-      // method. Including here so this method can be used
-      // as a self contained quickSelect implementation.
-      if (left == right) {
-        return left;
+      if (--maxDepth == -1) {
+        // Max recursion depth exceeded: shuffle (only once) and continue.
+        shuffle(from, to);
       }
-      int pivotIndex = pivot(left, right);
-      pivotIndex = partition(left, right, k, pivotIndex);
-      if (k == pivotIndex) {
-        return k;
-      } else if (k < pivotIndex) {
-        right = pivotIndex - 1;
+
+      // Pivot selection based on medians.
+      int last = to - 1;
+      int mid = (from + last) >>> 1;
+      int pivot;
+      if (size <= IntroSorter.SINGLE_MEDIAN_THRESHOLD) {
+        // Select the pivot with a single median around the middle element.
+        // Do not take the median between [from, mid, last] because it hurts performance
+        // if the order is descending in conjunction with the 3-way partitioning.
+        int range = size >> 2;
+        pivot = median(mid - range, mid, mid + range);
       } else {
-        left = pivotIndex + 1;
+        // Select the pivot with a variant of Tukey's ninther median of medians.
+        // If k is close to the boundaries, select either the lowest or highest median (this variant
+        // is inspired by interpolation search).
+        int range = size >> 3;
+        int doubleRange = range << 1;
+        int medianFirst = median(from, from + range, from + doubleRange);
+        int medianMiddle = median(mid - range, mid, mid + range);
+        int medianLast = median(last - doubleRange, last - range, last);
+        if (k - from < range) {
+          // k is close to 'from': select the lowest median.
+          pivot = min(medianFirst, medianMiddle, medianLast);
+        } else if (to - k <= range) {
+          // k is close to 'to': select the highest median.
+          pivot = max(medianFirst, medianMiddle, medianLast);
+        } else {
+          // Otherwise select the median of medians.
+          pivot = median(medianFirst, medianMiddle, medianLast);
+        }
       }
-    } while (left != right);
-    return left;
-  }
 
-  private int partition(int left, int right, int k, int pivotIndex) {
-    setPivot(pivotIndex);
-    swap(pivotIndex, right);
-    int storeIndex = left;
-    for (int i = left; i < right; i++) {
-      if (comparePivot(i) > 0) {
-        swap(storeIndex, i);
-        storeIndex++;
+      // Bentley-McIlroy 3-way partitioning.
+      setPivot(pivot);
+      swap(from, pivot);
+      int i = from;
+      int j = to;
+      int p = from + 1;
+      int q = last;
+      while (true) {
+        int leftCmp, rightCmp;
+        while ((leftCmp = comparePivot(++i)) > 0) {}
+        while ((rightCmp = comparePivot(--j)) < 0) {}
+        if (i >= j) {
+          if (i == j && rightCmp == 0) {
+            swap(i, p);
+          }
+          break;
+        }
+        swap(i, j);
+        if (rightCmp == 0) {
+          swap(i, p++);
+        }
+        if (leftCmp == 0) {
+          swap(j, q--);
+        }
       }
-    }
-    int storeIndexEq = storeIndex;
-    for (int i = storeIndex; i < right; i++) {
-      if (comparePivot(i) == 0) {
-        swap(storeIndexEq, i);
-        storeIndexEq++;
+      i = j + 1;
+      for (int l = from; l < p; ) {
+        swap(l++, j--);
+      }
+      for (int l = last; l > q; ) {
+        swap(l--, i++);
       }
-    }
-    swap(right, storeIndexEq);
-    if (k < storeIndex) {
-      return storeIndex;
-    } else if (k <= storeIndexEq) {
-      return k;
-    }
-    return storeIndexEq;
-  }
 
-  private int pivot(int left, int right) {
-    if (right - left < 5) {
-      int pivotIndex = partition5(left, right);
-      return pivotIndex;
+      // Select the partition containing the k-th element.
+      if (k <= j) {
+        to = j + 1;
+      } else if (k >= i) {
+        from = i;
+      } else {
+        return;
+      }
     }
 
-    for (int i = left; i <= right; i = i + 5) {
-      int subRight = i + 4;
-      if (subRight > right) {
-        subRight = right;
-      }
-      int median5 = partition5(i, subRight);
-      swap(median5, left + ((i - left) / 5));
+    // Sort the final tiny range (3 entries or less) with a very specialized sort.
+    switch (size) {
+      case 2:
+        if (compare(from, from + 1) > 0) {
+          swap(from, from + 1);
+        }
+        break;
+      case 3:
+        sort3(from);
+        break;
     }
-    int mid = ((right - left) / 10) + left + 1;
-    int to = left + ((right - left) / 5);
-    return medianOfMediansSelect(left, to, mid);
   }
 
-  // selects the median of a group of at most five elements,
-  // implemented using insertion sort. Efficient due to
-  // bounded nature of data set.
-  private int partition5(int left, int right) {
-    int i = left + 1;
-    while (i <= right) {
-      int j = i;
-      while (j > left && compare(j - 1, j) > 0) {
-        swap(j - 1, j);
-        j--;
-      }
-      i++;
+  /** Returns the index of the min element among three elements at provided indices. */
+  private int min(int i, int j, int k) {
+    if (compare(i, j) <= 0) {
+      return compare(i, k) <= 0 ? i : k;
     }
-    return (left + right) >>> 1;
+    return compare(j, k) <= 0 ? j : k;
   }
 
-  private void quickSelect(int from, int to, int k, int maxDepth) {
-    assert from <= k;
-    assert k < to;
-    if (to - from == 1) {
-      return;
-    }
-    if (--maxDepth < 0) {
-      slowSelect(from, to, k);
-      return;
+  /** Returns the index of the max element among three elements at provided indices. */
+  private int max(int i, int j, int k) {
+    if (compare(i, j) <= 0) {
+      return compare(j, k) < 0 ? k : j;
     }
+    return compare(i, k) < 0 ? k : i;
+  }
 
-    final int mid = (from + to) >>> 1;
-    // heuristic: we use the median of the values at from, to-1 and mid as a pivot
-    if (compare(from, to - 1) > 0) {
-      swap(from, to - 1);
-    }
-    if (compare(to - 1, mid) > 0) {
-      swap(to - 1, mid);
-      if (compare(from, to - 1) > 0) {
-        swap(from, to - 1);
+  /** Copy of {@code IntroSorter#median}. */
+  private int median(int i, int j, int k) {
+    if (compare(i, j) < 0) {
+      if (compare(j, k) <= 0) {
+        return j;
       }
+      return compare(i, k) < 0 ? k : i;
     }
-
-    setPivot(to - 1);
-
-    int left = from + 1;
-    int right = to - 2;
-
-    for (; ; ) {
-      while (comparePivot(left) > 0) {
-        ++left;
-      }
-
-      while (left < right && comparePivot(right) <= 0) {
-        --right;
-      }
-
-      if (left < right) {
-        swap(left, right);
-        --right;
-      } else {
-        break;
-      }
+    if (compare(j, k) >= 0) {
+      return j;
     }
-    swap(left, to - 1);
+    return compare(i, k) < 0 ? i : k;
+  }
 
-    if (left == k) {
-      return;
-    } else if (left < k) {
-      quickSelect(left + 1, to, k, maxDepth);
+  /**
+   * Sorts 3 entries starting at from (inclusive). This specialized method is more efficient than
+   * calling {@link Sorter#insertionSort(int, int)}.
+   */
+  private void sort3(int from) {
+    final int mid = from + 1;
+    final int last = from + 2;
+    if (compare(from, mid) <= 0) {
+      if (compare(mid, last) > 0) {
+        swap(mid, last);
+        if (compare(from, mid) > 0) {
+          swap(from, mid);
+        }
+      }
+    } else if (compare(mid, last) >= 0) {
+      swap(from, last);
     } else {
-      quickSelect(from, left, k, maxDepth);
+      swap(from, mid);
+      if (compare(mid, last) > 0) {
+        swap(mid, last);
+      }
     }
   }
 
   /**
-   * Compare entries found in slots <code>i</code> and <code>j</code>. The contract for the returned
-   * value is the same as {@link Comparator#compare(Object, Object)}.
+   * Shuffles the entries between from (inclusive) and to (exclusive) with Durstenfeld's algorithm.
    */
-  protected int compare(int i, int j) {
-    setPivot(i);
-    return comparePivot(j);
+  private void shuffle(int from, int to) {
+    if (this.random == null) {
+      this.random = new SplittableRandom();
+    }
+    SplittableRandom random = this.random;
+    for (int i = to - 1; i > from; i--) {
+      swap(i, random.nextInt(from, i + 1));
+    }
   }
 
   /**
@@ -197,4 +224,13 @@ public abstract class IntroSelector extends Selector {
    * compare(i, j)}.
    */
   protected abstract int comparePivot(int j);
+
+  /**
+   * Compare entries found in slots <code>i</code> and <code>j</code>. The contract for the returned
+   * value is the same as {@link Comparator#compare(Object, Object)}.
+   */
+  protected int compare(int i, int j) {
+    setPivot(i);
+    return comparePivot(j);
+  }
 }
diff --git a/lucene/core/src/java/org/apache/lucene/util/IntroSorter.java b/lucene/core/src/java/org/apache/lucene/util/IntroSorter.java
index 99f1f6d..dbd003f 100644
--- a/lucene/core/src/java/org/apache/lucene/util/IntroSorter.java
+++ b/lucene/core/src/java/org/apache/lucene/util/IntroSorter.java
@@ -20,7 +20,9 @@ package org.apache.lucene.util;
  * {@link Sorter} implementation based on a variant of the quicksort algorithm called <a
  * href="http://en.wikipedia.org/wiki/Introsort">introsort</a>: when the recursion level exceeds the
  * log of the length of the array to sort, it falls back to heapsort. This prevents quicksort from
- * running into its worst-case quadratic runtime. Small ranges are sorted with insertion sort.
+ * running into its worst-case quadratic runtime. Selects the pivot using Tukey's ninther
+ * median-of-medians, and partitions using Bentley-McIlroy 3-way partitioning. Small ranges are
+ * sorted with insertion sort.
  *
  * <p>This sort algorithm is fast on most data shapes, especially with low cardinality. If the data
  * to sort is known to be strictly ascending or descending, prefer {@link TimSorter}.
@@ -30,7 +32,7 @@ package org.apache.lucene.util;
 public abstract class IntroSorter extends Sorter {
 
   /** Below this size threshold, the partition selection is simplified to a single median. */
-  private static final int SINGLE_MEDIAN_THRESHOLD = 40;
+  static final int SINGLE_MEDIAN_THRESHOLD = 40;
 
   /** Create a new {@link IntroSorter}. */
   public IntroSorter() {}
@@ -49,13 +51,12 @@ public abstract class IntroSorter extends Sorter {
    * algorithm (Engineering a Sort Function, Bentley-McIlroy).
    */
   void sort(int from, int to, int maxDepth) {
-    int size;
-
     // Sort small ranges with insertion sort.
+    int size;
     while ((size = to - from) > INSERTION_SORT_THRESHOLD) {
 
       if (--maxDepth < 0) {
-        // Max recursion depth reached: fallback to heap sort.
+        // Max recursion depth exceeded: fallback to heap sort.
         heapSort(from, to);
         return;
       }
@@ -67,11 +68,11 @@ public abstract class IntroSorter extends Sorter {
       if (size <= SINGLE_MEDIAN_THRESHOLD) {
         // Select the pivot with a single median around the middle element.
         // Do not take the median between [from, mid, last] because it hurts performance
-        // if the order is descending.
+        // if the order is descending in conjunction with the 3-way partitioning.
         int range = size >> 2;
         pivot = median(mid - range, mid, mid + range);
       } else {
-        // Select the pivot with the median of medians.
+        // Select the pivot with Tukey's ninther median of medians.
         int range = size >> 3;
         int doubleRange = range << 1;
         int medianFirst = median(from, from + range, from + doubleRange);
diff --git a/lucene/core/src/java/org/apache/lucene/util/MathUtil.java b/lucene/core/src/java/org/apache/lucene/util/MathUtil.java
index 78c0676..9733835 100644
--- a/lucene/core/src/java/org/apache/lucene/util/MathUtil.java
+++ b/lucene/core/src/java/org/apache/lucene/util/MathUtil.java
@@ -30,7 +30,10 @@ public final class MathUtil {
    * @param base must be {@code > 1}
    */
   public static int log(long x, int base) {
-    if (base <= 1) {
+    if (base == 2) {
+      // This specialized method is 30x faster.
+      return x <= 0 ? 0 : 63 - Long.numberOfLeadingZeros(x);
+    } else if (base <= 1) {
       throw new IllegalArgumentException("base must be > 1");
     }
     int ret = 0;
diff --git a/lucene/core/src/test/org/apache/lucene/util/SelectorBenchmark.java b/lucene/core/src/test/org/apache/lucene/util/SelectorBenchmark.java
new file mode 100644
index 0000000..ec9f6a1
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/util/SelectorBenchmark.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.util;
+
+import java.util.Locale;
+import java.util.Random;
+import org.apache.lucene.util.BaseSortTestCase.Entry;
+import org.apache.lucene.util.BaseSortTestCase.Strategy;
+
+/**
+ * Benchmark for {@link Selector} implementations.
+ *
+ * <p>Run the static {@link #main(String[])} method to start the benchmark.
+ */
+public class SelectorBenchmark {
+
+  private static final int ARRAY_LENGTH = 20000;
+  private static final int RUNS = 10;
+  private static final int LOOPS = 800;
+
+  private enum SelectorFactory {
+    INTRO_SELECTOR(
+        "IntroSelector",
+        (arr, s) -> {
+          return new IntroSelector() {
+
+            Entry pivot;
+
+            @Override
+            protected void swap(int i, int j) {
+              ArrayUtil.swap(arr, i, j);
+            }
+
+            @Override
+            protected void setPivot(int i) {
+              pivot = arr[i];
+            }
+
+            @Override
+            protected int comparePivot(int j) {
+              return pivot.compareTo(arr[j]);
+            }
+          };
+        }),
+    ;
+    final String name;
+    final Builder builder;
+
+    SelectorFactory(String name, Builder builder) {
+      this.name = name;
+      this.builder = builder;
+    }
+
+    interface Builder {
+      Selector build(Entry[] arr, Strategy strategy);
+    }
+  }
+
+  public static void main(String[] args) throws Exception {
+    assert false : "Disable assertions to run the benchmark";
+    Random random = new Random(System.currentTimeMillis());
+    long seed = random.nextLong();
+
+    System.out.println("WARMUP");
+    benchmarkSelectors(Strategy.RANDOM, random, seed);
+    System.out.println();
+
+    for (Strategy strategy : Strategy.values()) {
+      System.out.println(strategy);
+      benchmarkSelectors(strategy, random, seed);
+    }
+  }
+
+  private static void benchmarkSelectors(Strategy strategy, Random random, long seed) {
+    for (SelectorFactory selectorFactory : SelectorFactory.values()) {
+      System.out.printf(Locale.ROOT, "  %-15s...", selectorFactory.name);
+      random.setSeed(seed);
+      benchmarkSelector(strategy, selectorFactory, random);
+      System.out.println();
+    }
+  }
+
+  private static void benchmarkSelector(
+      Strategy strategy, SelectorFactory selectorFactory, Random random) {
+    for (int i = 0; i < RUNS; i++) {
+      Entry[] original = createArray(strategy, random);
+      Entry[] clone = original.clone();
+      Selector selector = selectorFactory.builder.build(clone, strategy);
+      long startTimeNs = System.nanoTime();
+      int k = random.nextInt(clone.length);
+      int kIncrement = random.nextInt(clone.length / 14) * 2 + 1;
+      for (int j = 0; j < LOOPS; j++) {
+        System.arraycopy(original, 0, clone, 0, original.length);
+        selector.select(0, clone.length, k);
+        k += kIncrement;
+        if (k >= clone.length) {
+          k -= clone.length;
+        }
+      }
+      long timeMs = (System.nanoTime() - startTimeNs) / 1000000;
+      System.out.printf(Locale.ROOT, "%5d", timeMs);
+    }
+  }
+
+  private static Entry[] createArray(Strategy strategy, Random random) {
+    Entry[] arr = new Entry[ARRAY_LENGTH];
+    for (int i = 0; i < arr.length; ++i) {
+      strategy.set(arr, i, random);
+    }
+    return arr;
+  }
+}
diff --git a/lucene/core/src/test/org/apache/lucene/util/TestIntroSelector.java b/lucene/core/src/test/org/apache/lucene/util/TestIntroSelector.java
index f595362..ef8b409 100644
--- a/lucene/core/src/test/org/apache/lucene/util/TestIntroSelector.java
+++ b/lucene/core/src/test/org/apache/lucene/util/TestIntroSelector.java
@@ -17,30 +17,26 @@
 package org.apache.lucene.util;
 
 import java.util.Arrays;
+import java.util.Random;
 
 public class TestIntroSelector extends LuceneTestCase {
 
   public void testSelect() {
+    Random random = random();
     for (int iter = 0; iter < 100; ++iter) {
-      doTestSelect(false);
+      doTestSelect(random);
     }
   }
 
-  public void testSlowSelect() {
-    for (int iter = 0; iter < 100; ++iter) {
-      doTestSelect(true);
-    }
-  }
-
-  private void doTestSelect(boolean slow) {
-    final int from = random().nextInt(5);
-    final int to = from + TestUtil.nextInt(random(), 1, 10000);
-    final int max = random().nextBoolean() ? random().nextInt(100) : random().nextInt(100000);
-    Integer[] arr = new Integer[to + random().nextInt(5)];
+  private void doTestSelect(Random random) {
+    final int from = random.nextInt(5);
+    final int to = from + TestUtil.nextInt(random, 1, 10000);
+    final int max = random.nextBoolean() ? random.nextInt(100) : random.nextInt(100000);
+    Integer[] arr = new Integer[to + random.nextInt(5)];
     for (int i = 0; i < arr.length; ++i) {
-      arr[i] = TestUtil.nextInt(random(), 0, max);
+      arr[i] = TestUtil.nextInt(random, 0, max);
     }
-    final int k = TestUtil.nextInt(random(), from, to - 1);
+    final int k = TestUtil.nextInt(random, from, to - 1);
 
     Integer[] expected = arr.clone();
     Arrays.sort(expected, from, to);
@@ -66,10 +62,10 @@ public class TestIntroSelector extends LuceneTestCase {
             return pivot.compareTo(actual[j]);
           }
         };
-    if (slow) {
-      selector.slowSelect(from, to, k);
-    } else {
+    if (random.nextBoolean()) {
       selector.select(from, to, k);
+    } else {
+      selector.select(from, to, k, random.nextInt(3));
     }
 
     assertEquals(expected[k], actual[k]);
@@ -77,9 +73,9 @@ public class TestIntroSelector extends LuceneTestCase {
       if (i < from || i >= to) {
         assertSame(arr[i], actual[i]);
       } else if (i <= k) {
-        assertTrue(actual[i].intValue() <= actual[k].intValue());
+        assertTrue(actual[i] <= actual[k]);
       } else {
-        assertTrue(actual[i].intValue() >= actual[k].intValue());
+        assertTrue(actual[i] >= actual[k]);
       }
     }
   }