You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by br...@apache.org on 2021/11/17 10:51:40 UTC
[lucene] branch branch_9x updated: LUCENE-10225: Improve IntroSelector with 3-way partitioning.
This is an automated email from the ASF dual-hosted git repository.
broustant pushed a commit to branch branch_9x
in repository https://gitbox.apache.org/repos/asf/lucene.git
The following commit(s) were added to refs/heads/branch_9x by this push:
new 02a63f6 LUCENE-10225: Improve IntroSelector with 3-way partitioning.
02a63f6 is described below
commit 02a63f688c88f0c7549844c34b940ad9ff83058b
Author: Bruno Roustant <33...@users.noreply.github.com>
AuthorDate: Wed Nov 17 10:38:27 2021 +0100
LUCENE-10225: Improve IntroSelector with 3-way partitioning.
---
lucene/CHANGES.txt | 3 +-
.../java/org/apache/lucene/util/IntroSelector.java | 288 ++++++++++++---------
.../java/org/apache/lucene/util/IntroSorter.java | 15 +-
.../src/java/org/apache/lucene/util/MathUtil.java | 5 +-
.../org/apache/lucene/util/SelectorBenchmark.java | 126 +++++++++
.../org/apache/lucene/util/TestIntroSelector.java | 34 ++-
6 files changed, 317 insertions(+), 154 deletions(-)
diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index bf2f255..5336f7d 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -25,7 +25,8 @@ Improvements
Optimizations
---------------------
-(No changes)
+
+* LUCENE-10225: Improve IntroSelector with 3-way partitioning. (Bruno Roustant, Adrien Grand)
Bug Fixes
---------------------
diff --git a/lucene/core/src/java/org/apache/lucene/util/IntroSelector.java b/lucene/core/src/java/org/apache/lucene/util/IntroSelector.java
index 6d00fbc..2ade7ab 100644
--- a/lucene/core/src/java/org/apache/lucene/util/IntroSelector.java
+++ b/lucene/core/src/java/org/apache/lucene/util/IntroSelector.java
@@ -17,173 +17,200 @@
package org.apache.lucene.util;
import java.util.Comparator;
+import java.util.SplittableRandom;
/**
- * Implementation of the quick select algorithm.
+ * Adaptive selection algorithm based on the introspective quick select algorithm. The quick select
+ * algorithm uses an interpolation variant of Tukey's ninther median-of-medians for pivot, and
+ * Bentley-McIlroy 3-way partitioning. For the introspective protection, it shuffles the sub-range
+ * if the max recursive depth is exceeded.
*
- * <p>It uses the median of the first, middle and last values as a pivot and falls back to a median
- * of medians when the number of recursion levels exceeds {@code 2 lg(n)}, as a consequence it runs
- * in linear time on average.
+ * <p>This selection algorithm is fast on most data shapes, especially on nearly sorted data, or
+ * when k is close to the boundaries. It runs in linear time on average.
*
* @lucene.internal
*/
public abstract class IntroSelector extends Selector {
+ // This selector is used repeatedly by the radix selector for sub-ranges of less than
+ // 100 entries. This means this selector is also optimized to be fast on small ranges.
+ // It uses the variant of medians-of-medians and 3-way partitioning, and finishes the
+ // last tiny range (3 entries or less) with a very specialized sort.
+
+ private SplittableRandom random;
+
@Override
public final void select(int from, int to, int k) {
checkArgs(from, to, k);
- final int maxDepth = 2 * MathUtil.log(to - from, 2);
- quickSelect(from, to, k, maxDepth);
+ select(from, to, k, 2 * MathUtil.log(to - from, 2));
}
- int slowSelect(int from, int to, int k) {
- return medianOfMediansSelect(from, to - 1, k);
- }
+ // Visible for testing.
+ void select(int from, int to, int k, int maxDepth) {
+ // This code is inspired by IntroSorter#sort, adapted to loop on a single partition.
+
+ // For efficiency, we must enter the loop with at least 4 entries to be able to skip
+ // some boundary tests during the 3-way partitioning.
+ int size;
+ while ((size = to - from) > 3) {
- int medianOfMediansSelect(int left, int right, int k) {
- do {
- // Defensive check, this is also checked in the calling
- // method. Including here so this method can be used
- // as a self contained quickSelect implementation.
- if (left == right) {
- return left;
+ if (--maxDepth == -1) {
+ // Max recursion depth exceeded: shuffle (only once) and continue.
+ shuffle(from, to);
}
- int pivotIndex = pivot(left, right);
- pivotIndex = partition(left, right, k, pivotIndex);
- if (k == pivotIndex) {
- return k;
- } else if (k < pivotIndex) {
- right = pivotIndex - 1;
+
+ // Pivot selection based on medians.
+ int last = to - 1;
+ int mid = (from + last) >>> 1;
+ int pivot;
+ if (size <= IntroSorter.SINGLE_MEDIAN_THRESHOLD) {
+ // Select the pivot with a single median around the middle element.
+ // Do not take the median between [from, mid, last] because it hurts performance
+ // if the order is descending in conjunction with the 3-way partitioning.
+ int range = size >> 2;
+ pivot = median(mid - range, mid, mid + range);
} else {
- left = pivotIndex + 1;
+ // Select the pivot with a variant of Tukey's ninther median of medians.
+ // If k is close to the boundaries, select either the lowest or highest median (this variant
+ // is inspired by interpolation search).
+ int range = size >> 3;
+ int doubleRange = range << 1;
+ int medianFirst = median(from, from + range, from + doubleRange);
+ int medianMiddle = median(mid - range, mid, mid + range);
+ int medianLast = median(last - doubleRange, last - range, last);
+ if (k - from < range) {
+ // k is close to 'from': select the lowest median.
+ pivot = min(medianFirst, medianMiddle, medianLast);
+ } else if (to - k <= range) {
+ // k is close to 'to': select the highest median.
+ pivot = max(medianFirst, medianMiddle, medianLast);
+ } else {
+ // Otherwise select the median of medians.
+ pivot = median(medianFirst, medianMiddle, medianLast);
+ }
}
- } while (left != right);
- return left;
- }
- private int partition(int left, int right, int k, int pivotIndex) {
- setPivot(pivotIndex);
- swap(pivotIndex, right);
- int storeIndex = left;
- for (int i = left; i < right; i++) {
- if (comparePivot(i) > 0) {
- swap(storeIndex, i);
- storeIndex++;
+ // Bentley-McIlroy 3-way partitioning.
+ setPivot(pivot);
+ swap(from, pivot);
+ int i = from;
+ int j = to;
+ int p = from + 1;
+ int q = last;
+ while (true) {
+ int leftCmp, rightCmp;
+ while ((leftCmp = comparePivot(++i)) > 0) {}
+ while ((rightCmp = comparePivot(--j)) < 0) {}
+ if (i >= j) {
+ if (i == j && rightCmp == 0) {
+ swap(i, p);
+ }
+ break;
+ }
+ swap(i, j);
+ if (rightCmp == 0) {
+ swap(i, p++);
+ }
+ if (leftCmp == 0) {
+ swap(j, q--);
+ }
}
- }
- int storeIndexEq = storeIndex;
- for (int i = storeIndex; i < right; i++) {
- if (comparePivot(i) == 0) {
- swap(storeIndexEq, i);
- storeIndexEq++;
+ i = j + 1;
+ for (int l = from; l < p; ) {
+ swap(l++, j--);
+ }
+ for (int l = last; l > q; ) {
+ swap(l--, i++);
}
- }
- swap(right, storeIndexEq);
- if (k < storeIndex) {
- return storeIndex;
- } else if (k <= storeIndexEq) {
- return k;
- }
- return storeIndexEq;
- }
- private int pivot(int left, int right) {
- if (right - left < 5) {
- int pivotIndex = partition5(left, right);
- return pivotIndex;
+ // Select the partition containing the k-th element.
+ if (k <= j) {
+ to = j + 1;
+ } else if (k >= i) {
+ from = i;
+ } else {
+ return;
+ }
}
- for (int i = left; i <= right; i = i + 5) {
- int subRight = i + 4;
- if (subRight > right) {
- subRight = right;
- }
- int median5 = partition5(i, subRight);
- swap(median5, left + ((i - left) / 5));
+ // Sort the final tiny range (3 entries or less) with a very specialized sort.
+ switch (size) {
+ case 2:
+ if (compare(from, from + 1) > 0) {
+ swap(from, from + 1);
+ }
+ break;
+ case 3:
+ sort3(from);
+ break;
}
- int mid = ((right - left) / 10) + left + 1;
- int to = left + ((right - left) / 5);
- return medianOfMediansSelect(left, to, mid);
}
- // selects the median of a group of at most five elements,
- // implemented using insertion sort. Efficient due to
- // bounded nature of data set.
- private int partition5(int left, int right) {
- int i = left + 1;
- while (i <= right) {
- int j = i;
- while (j > left && compare(j - 1, j) > 0) {
- swap(j - 1, j);
- j--;
- }
- i++;
+ /** Returns the index of the min element among three elements at provided indices. */
+ private int min(int i, int j, int k) {
+ if (compare(i, j) <= 0) {
+ return compare(i, k) <= 0 ? i : k;
}
- return (left + right) >>> 1;
+ return compare(j, k) <= 0 ? j : k;
}
- private void quickSelect(int from, int to, int k, int maxDepth) {
- assert from <= k;
- assert k < to;
- if (to - from == 1) {
- return;
- }
- if (--maxDepth < 0) {
- slowSelect(from, to, k);
- return;
+ /** Returns the index of the max element among three elements at provided indices. */
+ private int max(int i, int j, int k) {
+ if (compare(i, j) <= 0) {
+ return compare(j, k) < 0 ? k : j;
}
+ return compare(i, k) < 0 ? k : i;
+ }
- final int mid = (from + to) >>> 1;
- // heuristic: we use the median of the values at from, to-1 and mid as a pivot
- if (compare(from, to - 1) > 0) {
- swap(from, to - 1);
- }
- if (compare(to - 1, mid) > 0) {
- swap(to - 1, mid);
- if (compare(from, to - 1) > 0) {
- swap(from, to - 1);
+ /** Copy of {@code IntroSorter#median}. */
+ private int median(int i, int j, int k) {
+ if (compare(i, j) < 0) {
+ if (compare(j, k) <= 0) {
+ return j;
}
+ return compare(i, k) < 0 ? k : i;
}
-
- setPivot(to - 1);
-
- int left = from + 1;
- int right = to - 2;
-
- for (; ; ) {
- while (comparePivot(left) > 0) {
- ++left;
- }
-
- while (left < right && comparePivot(right) <= 0) {
- --right;
- }
-
- if (left < right) {
- swap(left, right);
- --right;
- } else {
- break;
- }
+ if (compare(j, k) >= 0) {
+ return j;
}
- swap(left, to - 1);
+ return compare(i, k) < 0 ? i : k;
+ }
- if (left == k) {
- return;
- } else if (left < k) {
- quickSelect(left + 1, to, k, maxDepth);
+ /**
+ * Sorts 3 entries starting at from (inclusive). This specialized method is more efficient than
+ * calling {@link Sorter#insertionSort(int, int)}.
+ */
+ private void sort3(int from) {
+ final int mid = from + 1;
+ final int last = from + 2;
+ if (compare(from, mid) <= 0) {
+ if (compare(mid, last) > 0) {
+ swap(mid, last);
+ if (compare(from, mid) > 0) {
+ swap(from, mid);
+ }
+ }
+ } else if (compare(mid, last) >= 0) {
+ swap(from, last);
} else {
- quickSelect(from, left, k, maxDepth);
+ swap(from, mid);
+ if (compare(mid, last) > 0) {
+ swap(mid, last);
+ }
}
}
/**
- * Compare entries found in slots <code>i</code> and <code>j</code>. The contract for the returned
- * value is the same as {@link Comparator#compare(Object, Object)}.
+ * Shuffles the entries between from (inclusive) and to (exclusive) with Durstenfeld's algorithm.
*/
- protected int compare(int i, int j) {
- setPivot(i);
- return comparePivot(j);
+ private void shuffle(int from, int to) {
+ if (this.random == null) {
+ this.random = new SplittableRandom();
+ }
+ SplittableRandom random = this.random;
+ for (int i = to - 1; i > from; i--) {
+ swap(i, random.nextInt(from, i + 1));
+ }
}
/**
@@ -197,4 +224,13 @@ public abstract class IntroSelector extends Selector {
* compare(i, j)}.
*/
protected abstract int comparePivot(int j);
+
+ /**
+ * Compare entries found in slots <code>i</code> and <code>j</code>. The contract for the returned
+ * value is the same as {@link Comparator#compare(Object, Object)}.
+ */
+ protected int compare(int i, int j) {
+ setPivot(i);
+ return comparePivot(j);
+ }
}
diff --git a/lucene/core/src/java/org/apache/lucene/util/IntroSorter.java b/lucene/core/src/java/org/apache/lucene/util/IntroSorter.java
index 99f1f6d..dbd003f 100644
--- a/lucene/core/src/java/org/apache/lucene/util/IntroSorter.java
+++ b/lucene/core/src/java/org/apache/lucene/util/IntroSorter.java
@@ -20,7 +20,9 @@ package org.apache.lucene.util;
* {@link Sorter} implementation based on a variant of the quicksort algorithm called <a
* href="http://en.wikipedia.org/wiki/Introsort">introsort</a>: when the recursion level exceeds the
* log of the length of the array to sort, it falls back to heapsort. This prevents quicksort from
- * running into its worst-case quadratic runtime. Small ranges are sorted with insertion sort.
+ * running into its worst-case quadratic runtime. Selects the pivot using Tukey's ninther
+ * median-of-medians, and partitions using Bentley-McIlroy 3-way partitioning. Small ranges are
+ * sorted with insertion sort.
*
* <p>This sort algorithm is fast on most data shapes, especially with low cardinality. If the data
* to sort is known to be strictly ascending or descending, prefer {@link TimSorter}.
@@ -30,7 +32,7 @@ package org.apache.lucene.util;
public abstract class IntroSorter extends Sorter {
/** Below this size threshold, the partition selection is simplified to a single median. */
- private static final int SINGLE_MEDIAN_THRESHOLD = 40;
+ static final int SINGLE_MEDIAN_THRESHOLD = 40;
/** Create a new {@link IntroSorter}. */
public IntroSorter() {}
@@ -49,13 +51,12 @@ public abstract class IntroSorter extends Sorter {
* algorithm (Engineering a Sort Function, Bentley-McIlroy).
*/
void sort(int from, int to, int maxDepth) {
- int size;
-
// Sort small ranges with insertion sort.
+ int size;
while ((size = to - from) > INSERTION_SORT_THRESHOLD) {
if (--maxDepth < 0) {
- // Max recursion depth reached: fallback to heap sort.
+ // Max recursion depth exceeded: fallback to heap sort.
heapSort(from, to);
return;
}
@@ -67,11 +68,11 @@ public abstract class IntroSorter extends Sorter {
if (size <= SINGLE_MEDIAN_THRESHOLD) {
// Select the pivot with a single median around the middle element.
// Do not take the median between [from, mid, last] because it hurts performance
- // if the order is descending.
+ // if the order is descending in conjunction with the 3-way partitioning.
int range = size >> 2;
pivot = median(mid - range, mid, mid + range);
} else {
- // Select the pivot with the median of medians.
+ // Select the pivot with Tukey's ninther median of medians.
int range = size >> 3;
int doubleRange = range << 1;
int medianFirst = median(from, from + range, from + doubleRange);
diff --git a/lucene/core/src/java/org/apache/lucene/util/MathUtil.java b/lucene/core/src/java/org/apache/lucene/util/MathUtil.java
index 78c0676..9733835 100644
--- a/lucene/core/src/java/org/apache/lucene/util/MathUtil.java
+++ b/lucene/core/src/java/org/apache/lucene/util/MathUtil.java
@@ -30,7 +30,10 @@ public final class MathUtil {
* @param base must be {@code > 1}
*/
public static int log(long x, int base) {
- if (base <= 1) {
+ if (base == 2) {
+ // This specialized method is 30x faster.
+ return x <= 0 ? 0 : 63 - Long.numberOfLeadingZeros(x);
+ } else if (base <= 1) {
throw new IllegalArgumentException("base must be > 1");
}
int ret = 0;
diff --git a/lucene/core/src/test/org/apache/lucene/util/SelectorBenchmark.java b/lucene/core/src/test/org/apache/lucene/util/SelectorBenchmark.java
new file mode 100644
index 0000000..ec9f6a1
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/util/SelectorBenchmark.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.util;
+
+import java.util.Locale;
+import java.util.Random;
+import org.apache.lucene.util.BaseSortTestCase.Entry;
+import org.apache.lucene.util.BaseSortTestCase.Strategy;
+
+/**
+ * Benchmark for {@link Selector} implementations.
+ *
+ * <p>Run the static {@link #main(String[])} method to start the benchmark.
+ */
+public class SelectorBenchmark {
+
+ private static final int ARRAY_LENGTH = 20000;
+ private static final int RUNS = 10;
+ private static final int LOOPS = 800;
+
+ private enum SelectorFactory {
+ INTRO_SELECTOR(
+ "IntroSelector",
+ (arr, s) -> {
+ return new IntroSelector() {
+
+ Entry pivot;
+
+ @Override
+ protected void swap(int i, int j) {
+ ArrayUtil.swap(arr, i, j);
+ }
+
+ @Override
+ protected void setPivot(int i) {
+ pivot = arr[i];
+ }
+
+ @Override
+ protected int comparePivot(int j) {
+ return pivot.compareTo(arr[j]);
+ }
+ };
+ }),
+ ;
+ final String name;
+ final Builder builder;
+
+ SelectorFactory(String name, Builder builder) {
+ this.name = name;
+ this.builder = builder;
+ }
+
+ interface Builder {
+ Selector build(Entry[] arr, Strategy strategy);
+ }
+ }
+
+ public static void main(String[] args) throws Exception {
+ assert false : "Disable assertions to run the benchmark";
+ Random random = new Random(System.currentTimeMillis());
+ long seed = random.nextLong();
+
+ System.out.println("WARMUP");
+ benchmarkSelectors(Strategy.RANDOM, random, seed);
+ System.out.println();
+
+ for (Strategy strategy : Strategy.values()) {
+ System.out.println(strategy);
+ benchmarkSelectors(strategy, random, seed);
+ }
+ }
+
+ private static void benchmarkSelectors(Strategy strategy, Random random, long seed) {
+ for (SelectorFactory selectorFactory : SelectorFactory.values()) {
+ System.out.printf(Locale.ROOT, " %-15s...", selectorFactory.name);
+ random.setSeed(seed);
+ benchmarkSelector(strategy, selectorFactory, random);
+ System.out.println();
+ }
+ }
+
+ private static void benchmarkSelector(
+ Strategy strategy, SelectorFactory selectorFactory, Random random) {
+ for (int i = 0; i < RUNS; i++) {
+ Entry[] original = createArray(strategy, random);
+ Entry[] clone = original.clone();
+ Selector selector = selectorFactory.builder.build(clone, strategy);
+ long startTimeNs = System.nanoTime();
+ int k = random.nextInt(clone.length);
+ int kIncrement = random.nextInt(clone.length / 14) * 2 + 1;
+ for (int j = 0; j < LOOPS; j++) {
+ System.arraycopy(original, 0, clone, 0, original.length);
+ selector.select(0, clone.length, k);
+ k += kIncrement;
+ if (k >= clone.length) {
+ k -= clone.length;
+ }
+ }
+ long timeMs = (System.nanoTime() - startTimeNs) / 1000000;
+ System.out.printf(Locale.ROOT, "%5d", timeMs);
+ }
+ }
+
+ private static Entry[] createArray(Strategy strategy, Random random) {
+ Entry[] arr = new Entry[ARRAY_LENGTH];
+ for (int i = 0; i < arr.length; ++i) {
+ strategy.set(arr, i, random);
+ }
+ return arr;
+ }
+}
diff --git a/lucene/core/src/test/org/apache/lucene/util/TestIntroSelector.java b/lucene/core/src/test/org/apache/lucene/util/TestIntroSelector.java
index f595362..ef8b409 100644
--- a/lucene/core/src/test/org/apache/lucene/util/TestIntroSelector.java
+++ b/lucene/core/src/test/org/apache/lucene/util/TestIntroSelector.java
@@ -17,30 +17,26 @@
package org.apache.lucene.util;
import java.util.Arrays;
+import java.util.Random;
public class TestIntroSelector extends LuceneTestCase {
public void testSelect() {
+ Random random = random();
for (int iter = 0; iter < 100; ++iter) {
- doTestSelect(false);
+ doTestSelect(random);
}
}
- public void testSlowSelect() {
- for (int iter = 0; iter < 100; ++iter) {
- doTestSelect(true);
- }
- }
-
- private void doTestSelect(boolean slow) {
- final int from = random().nextInt(5);
- final int to = from + TestUtil.nextInt(random(), 1, 10000);
- final int max = random().nextBoolean() ? random().nextInt(100) : random().nextInt(100000);
- Integer[] arr = new Integer[to + random().nextInt(5)];
+ private void doTestSelect(Random random) {
+ final int from = random.nextInt(5);
+ final int to = from + TestUtil.nextInt(random, 1, 10000);
+ final int max = random.nextBoolean() ? random.nextInt(100) : random.nextInt(100000);
+ Integer[] arr = new Integer[to + random.nextInt(5)];
for (int i = 0; i < arr.length; ++i) {
- arr[i] = TestUtil.nextInt(random(), 0, max);
+ arr[i] = TestUtil.nextInt(random, 0, max);
}
- final int k = TestUtil.nextInt(random(), from, to - 1);
+ final int k = TestUtil.nextInt(random, from, to - 1);
Integer[] expected = arr.clone();
Arrays.sort(expected, from, to);
@@ -66,10 +62,10 @@ public class TestIntroSelector extends LuceneTestCase {
return pivot.compareTo(actual[j]);
}
};
- if (slow) {
- selector.slowSelect(from, to, k);
- } else {
+ if (random.nextBoolean()) {
selector.select(from, to, k);
+ } else {
+ selector.select(from, to, k, random.nextInt(3));
}
assertEquals(expected[k], actual[k]);
@@ -77,9 +73,9 @@ public class TestIntroSelector extends LuceneTestCase {
if (i < from || i >= to) {
assertSame(arr[i], actual[i]);
} else if (i <= k) {
- assertTrue(actual[i].intValue() <= actual[k].intValue());
+ assertTrue(actual[i] <= actual[k]);
} else {
- assertTrue(actual[i].intValue() >= actual[k].intValue());
+ assertTrue(actual[i] >= actual[k]);
}
}
}