You are viewing a plain-text version of this content. The link to its canonical version was not preserved in this plain-text rendering.
Posted to commits@lucene.apache.org by br...@apache.org on 2021/11/01 10:02:36 UTC

[lucene] branch main updated: LUCENE-10196: Improve IntroSorter with 3-ways partitioning.

This is an automated email from the ASF dual-hosted git repository.

broustant pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/lucene.git


The following commit(s) were added to refs/heads/main by this push:
     new 63b9e60  LUCENE-10196: Improve IntroSorter with 3-ways partitioning.
63b9e60 is described below

commit 63b9e603e6f53dae40ece03814a9aa613f6cc189
Author: Bruno Roustant <br...@gmail.com>
AuthorDate: Thu Oct 21 16:18:32 2021 +0200

    LUCENE-10196: Improve IntroSorter with 3-ways partitioning.
---
 lucene/CHANGES.txt                                 |   2 +
 .../java/org/apache/lucene/util/IntroSorter.java   | 132 +++++++++++++++------
 .../src/java/org/apache/lucene/util/Sorter.java    |  24 +++-
 .../org/apache/lucene/util/BaseSortTestCase.java   |  53 +++++----
 .../org/apache/lucene/util/SorterBenchmark.java    | 112 +++++++++++++++++
 5 files changed, 263 insertions(+), 60 deletions(-)

diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 886fe28..3b02df9 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -474,6 +474,8 @@ Optimizations
   postings in memory, and reduced a bit of RAM overhead in
   IndexWriter's internal postings book-keeping (mashudong)
 
+* LUCENE-10196: Improve IntroSorter with 3-ways partitioning. (Bruno Roustant)
+
 Bug Fixes
 ---------------------
 
diff --git a/lucene/core/src/java/org/apache/lucene/util/IntroSorter.java b/lucene/core/src/java/org/apache/lucene/util/IntroSorter.java
index 630c208..99f1f6d 100644
--- a/lucene/core/src/java/org/apache/lucene/util/IntroSorter.java
+++ b/lucene/core/src/java/org/apache/lucene/util/IntroSorter.java
@@ -20,66 +20,124 @@ package org.apache.lucene.util;
  * {@link Sorter} implementation based on a variant of the quicksort algorithm called <a
  * href="http://en.wikipedia.org/wiki/Introsort">introsort</a>: when the recursion level exceeds the
  * log of the length of the array to sort, it falls back to heapsort. This prevents quicksort from
- * running into its worst-case quadratic runtime. Small arrays are sorted with insertion sort.
+ * running into its worst-case quadratic runtime. Small ranges are sorted with insertion sort.
+ *
+ * <p>This sort algorithm is fast on most data shapes, especially with low cardinality. If the data
+ * to sort is known to be strictly ascending or descending, prefer {@link TimSorter}.
  *
  * @lucene.internal
  */
 public abstract class IntroSorter extends Sorter {
 
+  /** At or below this range size, the pivot is selected with a single median of three. */
+  private static final int SINGLE_MEDIAN_THRESHOLD = 40;
+
   /** Create a new {@link IntroSorter}. */
   public IntroSorter() {}
 
   @Override
   public final void sort(int from, int to) {
     checkRange(from, to);
-    quicksort(from, to, 2 * MathUtil.log(to - from, 2));
+    sort(from, to, 2 * MathUtil.log(to - from, 2));
   }
 
-  void quicksort(int from, int to, int maxDepth) {
-    if (to - from < BINARY_SORT_THRESHOLD) {
-      binarySort(from, to);
-      return;
-    } else if (--maxDepth < 0) {
-      heapSort(from, to);
-      return;
-    }
-
-    final int mid = (from + to) >>> 1;
-
-    if (compare(from, mid) > 0) {
-      swap(from, mid);
-    }
-
-    if (compare(mid, to - 1) > 0) {
-      swap(mid, to - 1);
-      if (compare(from, mid) > 0) {
-        swap(from, mid);
+  /**
+   * Sorts between from (inclusive) and to (exclusive) with intro sort.
+   *
+   * <p>Sorts small ranges with insertion sort. Falls back to heap sort to avoid the quadratic
+   * worst case. Selects the pivot with medians and partitions with the Bentley-McIlroy fast 3-way
+   * algorithm ("Engineering a Sort Function", Bentley and McIlroy).
+   */
+  void sort(int from, int to, int maxDepth) {
+    int size;
+
+    // Sort small ranges with insertion sort.
+    while ((size = to - from) > INSERTION_SORT_THRESHOLD) {
+
+      if (--maxDepth < 0) {
+        // Max recursion depth reached: fallback to heap sort.
+        heapSort(from, to);
+        return;
       }
-    }
-
-    int left = from + 1;
-    int right = to - 2;
 
-    setPivot(mid);
-    for (; ; ) {
-      while (comparePivot(right) < 0) {
-        --right;
+      // Pivot selection based on medians.
+      int last = to - 1;
+      int mid = (from + last) >>> 1;
+      int pivot;
+      if (size <= SINGLE_MEDIAN_THRESHOLD) {
+        // Select the pivot with a single median around the middle element.
+        // Do not take the median between [from, mid, last] because it hurts performance
+        // if the order is descending.
+        int range = size >> 2;
+        pivot = median(mid - range, mid, mid + range);
+      } else {
+        // Select the pivot with the median of medians.
+        int range = size >> 3;
+        int doubleRange = range << 1;
+        int medianFirst = median(from, from + range, from + doubleRange);
+        int medianMiddle = median(mid - range, mid, mid + range);
+        int medianLast = median(last - doubleRange, last - range, last);
+        pivot = median(medianFirst, medianMiddle, medianLast);
       }
 
-      while (left < right && comparePivot(left) >= 0) {
-        ++left;
+      // Bentley-McIlroy 3-way partitioning.
+      setPivot(pivot);
+      swap(from, pivot);
+      int i = from;
+      int j = to;
+      int p = from + 1;
+      int q = last;
+      while (true) {
+        int leftCmp, rightCmp;
+        while ((leftCmp = comparePivot(++i)) > 0) {}
+        while ((rightCmp = comparePivot(--j)) < 0) {}
+        if (i >= j) {
+          if (i == j && rightCmp == 0) {
+            swap(i, p);
+          }
+          break;
+        }
+        swap(i, j);
+        if (rightCmp == 0) {
+          swap(i, p++);
+        }
+        if (leftCmp == 0) {
+          swap(j, q--);
+        }
+      }
+      i = j + 1;
+      for (int k = from; k < p; ) {
+        swap(k++, j--);
+      }
+      for (int k = last; k > q; ) {
+        swap(k--, i++);
       }
 
-      if (left < right) {
-        swap(left, right);
-        --right;
+      // Recurse on the smaller partition; the loop replaces tail recursion on the larger one.
+      if (j - from < last - i) {
+        sort(from, j + 1, maxDepth);
+        from = i;
       } else {
-        break;
+        sort(i, to, maxDepth);
+        to = j + 1;
       }
     }
 
-    quicksort(from, left + 1, maxDepth);
-    quicksort(left + 1, to, maxDepth);
+    insertionSort(from, to);
+  }
+
+  /** Returns the index of the median element among three elements at provided indices. */
+  private int median(int i, int j, int k) {
+    if (compare(i, j) < 0) {
+      if (compare(j, k) <= 0) {
+        return j;
+      }
+      return compare(i, k) < 0 ? k : i;
+    }
+    if (compare(j, k) >= 0) {
+      return j;
+    }
+    return compare(i, k) < 0 ? i : k;
   }
 
   // Don't rely on the slow default impl of setPivot/comparePivot since
diff --git a/lucene/core/src/java/org/apache/lucene/util/Sorter.java b/lucene/core/src/java/org/apache/lucene/util/Sorter.java
index 9f0f8ac..f0d2fc9 100644
--- a/lucene/core/src/java/org/apache/lucene/util/Sorter.java
+++ b/lucene/core/src/java/org/apache/lucene/util/Sorter.java
@@ -27,6 +27,9 @@ public abstract class Sorter {
 
   static final int BINARY_SORT_THRESHOLD = 20;
 
+  /** At or below this size threshold, the sub-range is sorted using insertion sort. */
+  static final int INSERTION_SORT_THRESHOLD = 16;
+
   /** Sole constructor, used for inheritance. */
   protected Sorter() {}
 
@@ -190,7 +193,7 @@ public abstract class Sorter {
   /**
    * A binary sort implementation. This performs {@code O(n*log(n))} comparisons and {@code O(n^2)}
    * swaps. It is typically used by more sophisticated implementations as a fall-back when the
-   * numbers of items to sort has become less than {@value #BINARY_SORT_THRESHOLD}.
+   * number of items to sort has become less than {@value #BINARY_SORT_THRESHOLD}.
    */
   void binarySort(int from, int to) {
     binarySort(from, to, from + 1);
@@ -217,6 +220,25 @@ public abstract class Sorter {
   }
 
   /**
+   * Sorts between from (inclusive) and to (exclusive) with insertion sort. Runs in {@code O(n^2)}.
+   * It is typically used by more sophisticated implementations as a fall-back when the number of
+   * items to sort is at most {@value #INSERTION_SORT_THRESHOLD}.
+   */
+  void insertionSort(int from, int to) {
+    for (int i = from + 1; i < to; ) {
+      int current = i++;
+      int previous;
+      while (compare((previous = current - 1), current) > 0) {
+        swap(previous, current);
+        if (previous == from) {
+          break;
+        }
+        current = previous;
+      }
+    }
+  }
+
+  /**
    * Use heap sort to sort items between {@code from} inclusive and {@code to} exclusive. This runs
    * in {@code O(n*log(n))} and is used as a fall-back by {@link IntroSorter}.
    */
diff --git a/lucene/core/src/test/org/apache/lucene/util/BaseSortTestCase.java b/lucene/core/src/test/org/apache/lucene/util/BaseSortTestCase.java
index 93d18b7..0e05299 100644
--- a/lucene/core/src/test/org/apache/lucene/util/BaseSortTestCase.java
+++ b/lucene/core/src/test/org/apache/lucene/util/BaseSortTestCase.java
@@ -17,6 +17,7 @@
 package org.apache.lucene.util;
 
 import java.util.Arrays;
+import java.util.Random;
 
 public abstract class BaseSortTestCase extends LuceneTestCase {
 
@@ -32,7 +33,7 @@ public abstract class BaseSortTestCase extends LuceneTestCase {
 
     @Override
     public int compareTo(Entry other) {
-      return value < other.value ? -1 : value == other.value ? 0 : 1;
+      return Integer.compare(value, other.value);
     }
   }
 
@@ -68,70 +69,78 @@ public abstract class BaseSortTestCase extends LuceneTestCase {
   enum Strategy {
     RANDOM {
       @Override
-      public void set(Entry[] arr, int i) {
-        arr[i] = new Entry(random().nextInt(), i);
+      public void set(Entry[] arr, int i, Random random) {
+        arr[i] = new Entry(random.nextInt(), i);
       }
     },
     RANDOM_LOW_CARDINALITY {
       @Override
-      public void set(Entry[] arr, int i) {
-        arr[i] = new Entry(random().nextInt(6), i);
+      public void set(Entry[] arr, int i, Random random) {
+        arr[i] = new Entry(random.nextInt(6), i);
+      }
+    },
+    RANDOM_MEDIUM_CARDINALITY {
+      @Override
+      public void set(Entry[] arr, int i, Random random) {
+        arr[i] = new Entry(random.nextInt(arr.length / 2), i);
       }
     },
     ASCENDING {
       @Override
-      public void set(Entry[] arr, int i) {
+      public void set(Entry[] arr, int i, Random random) {
         arr[i] =
             i == 0
-                ? new Entry(random().nextInt(6), 0)
-                : new Entry(arr[i - 1].value + random().nextInt(6), i);
+                ? new Entry(random.nextInt(6), 0)
+                : new Entry(arr[i - 1].value + random.nextInt(6), i);
       }
     },
     DESCENDING {
       @Override
-      public void set(Entry[] arr, int i) {
+      public void set(Entry[] arr, int i, Random random) {
         arr[i] =
             i == 0
-                ? new Entry(random().nextInt(6), 0)
-                : new Entry(arr[i - 1].value - random().nextInt(6), i);
+                ? new Entry(random.nextInt(6), 0)
+                : new Entry(arr[i - 1].value - random.nextInt(6), i);
       }
     },
     STRICTLY_DESCENDING {
       @Override
-      public void set(Entry[] arr, int i) {
+      public void set(Entry[] arr, int i, Random random) {
         arr[i] =
             i == 0
-                ? new Entry(random().nextInt(6), 0)
-                : new Entry(arr[i - 1].value - TestUtil.nextInt(random(), 1, 5), i);
+                ? new Entry(random.nextInt(6), 0)
+                : new Entry(arr[i - 1].value - TestUtil.nextInt(random, 1, 5), i);
       }
     },
     ASCENDING_SEQUENCES {
       @Override
-      public void set(Entry[] arr, int i) {
+      public void set(Entry[] arr, int i, Random random) {
         arr[i] =
             i == 0
-                ? new Entry(random().nextInt(6), 0)
+                ? new Entry(random.nextInt(6), 0)
                 : new Entry(
-                    rarely() ? random().nextInt(1000) : arr[i - 1].value + random().nextInt(6), i);
+                    rarely(random) ? random.nextInt(1000) : arr[i - 1].value + random.nextInt(6),
+                    i);
       }
     },
     MOSTLY_ASCENDING {
       @Override
-      public void set(Entry[] arr, int i) {
+      public void set(Entry[] arr, int i, Random random) {
         arr[i] =
             i == 0
-                ? new Entry(random().nextInt(6), 0)
-                : new Entry(arr[i - 1].value + TestUtil.nextInt(random(), -8, 10), i);
+                ? new Entry(random.nextInt(6), 0)
+                : new Entry(arr[i - 1].value + TestUtil.nextInt(random, -8, 10), i);
       }
     };
 
-    public abstract void set(Entry[] arr, int i);
+    public abstract void set(Entry[] arr, int i, Random random);
   }
 
   public void test(Strategy strategy, int length) {
+    Random random = random();
     final Entry[] arr = new Entry[length];
     for (int i = 0; i < arr.length; ++i) {
-      strategy.set(arr, i);
+      strategy.set(arr, i, random);
     }
     test(arr);
   }
diff --git a/lucene/core/src/test/org/apache/lucene/util/SorterBenchmark.java b/lucene/core/src/test/org/apache/lucene/util/SorterBenchmark.java
new file mode 100644
index 0000000..9b62cb0
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/util/SorterBenchmark.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.util;
+
+import java.util.Locale;
+import java.util.Random;
+import org.apache.lucene.util.BaseSortTestCase.Entry;
+import org.apache.lucene.util.BaseSortTestCase.Strategy;
+
+/**
+ * Benchmark for {@link Sorter} implementations.
+ *
+ * <p>Run the static {@link #main(String[])} method to start the benchmark.
+ */
+public class SorterBenchmark {
+
+  private static final int ARRAY_LENGTH = 20000;
+  private static final int RUNS = 10;
+  private static final int LOOPS = 100;
+
+  private enum SorterFactory {
+    INTRO_SORTER(
+        "IntroSorter",
+        (arr, s) -> {
+          return new ArrayIntroSorter<>(arr, Entry::compareTo);
+        }),
+    TIM_SORTER(
+        "TimSorter",
+        (arr, s) -> {
+          return new ArrayTimSorter<>(arr, Entry::compareTo, arr.length / 64);
+        }),
+    MERGE_SORTER(
+        "MergeSorter",
+        (arr, s) -> {
+          return new ArrayInPlaceMergeSorter<>(arr, Entry::compareTo);
+        }),
+    ;
+    final String name;
+    final Builder builder;
+
+    SorterFactory(String name, Builder builder) {
+      this.name = name;
+      this.builder = builder;
+    }
+
+    interface Builder {
+      Sorter build(Entry[] arr, Strategy strategy);
+    }
+  }
+
+  public static void main(String[] args) throws Exception {
+    assert false : "Disable assertions to run the benchmark";
+    Random random = new Random(System.currentTimeMillis());
+    long seed = random.nextLong();
+
+    System.out.println("WARMUP");
+    benchmarkSorters(Strategy.RANDOM, random, seed);
+    System.out.println();
+
+    for (Strategy strategy : Strategy.values()) {
+      System.out.println(strategy);
+      benchmarkSorters(strategy, random, seed);
+    }
+  }
+
+  private static void benchmarkSorters(Strategy strategy, Random random, long seed) {
+    for (SorterFactory sorterFactory : SorterFactory.values()) {
+      System.out.printf(Locale.ROOT, "  %-12s...", sorterFactory.name);
+      random.setSeed(seed);
+      benchmarkSorter(strategy, sorterFactory, random);
+      System.out.println();
+    }
+  }
+
+  private static void benchmarkSorter(
+      Strategy strategy, SorterFactory sorterFactory, Random random) {
+    for (int i = 0; i < RUNS; i++) {
+      Entry[] original = createArray(strategy, random);
+      Entry[] clone = original.clone();
+      Sorter sorter = sorterFactory.builder.build(clone, strategy);
+      long startTimeNs = System.nanoTime();
+      for (int j = 0; j < LOOPS; j++) {
+        System.arraycopy(original, 0, clone, 0, original.length);
+        sorter.sort(0, clone.length);
+      }
+      long timeMs = (System.nanoTime() - startTimeNs) / 1000000;
+      System.out.printf(Locale.ROOT, "%5d", timeMs);
+    }
+  }
+
+  private static Entry[] createArray(Strategy strategy, Random random) {
+    Entry[] arr = new Entry[ARRAY_LENGTH];
+    for (int i = 0; i < arr.length; ++i) {
+      strategy.set(arr, i, random);
+    }
+    return arr;
+  }
+}