You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2023/02/27 06:27:31 UTC
[spark] branch master updated: [SPARK-42528][CORE] Optimize PercentileHeap
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 0b8234db247 [SPARK-42528][CORE] Optimize PercentileHeap
0b8234db247 is described below
commit 0b8234db247dc3807ea70448ed069483c8b00687
Author: Alkis Evlogimenos <al...@databricks.com>
AuthorDate: Mon Feb 27 14:27:06 2023 +0800
[SPARK-42528][CORE] Optimize PercentileHeap
### What changes were proposed in this pull request?
Reimplement `PercentileHeap` such that:
- the percentile value is always kept at the top of `largeHeap`, which speeds up `percentile` access
- rebalance the heaps more efficiently by determining which heap should grow as a result of the new insertion and rebalancing based on the target heap sizes
- the heaps are Java `PriorityQueue`s *without* comparators, because comparator call overhead slows down `poll`/`offer` by more than 2x. Instead, the max-heap is implemented by negating values on `offer`/`poll`/`peek`.
### Why are the changes needed?
`PercentileHeap` is heavyweight enough to cause scheduling delays when values are inserted into it inside the scheduler loop.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Added more extensive unit tests.
Closes #40121 from alkis/faster-percentile-heap.
Lead-authored-by: Alkis Evlogimenos <al...@databricks.com>
Co-authored-by: Alkis Evlogimenos <al...@evlogimenos.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../spark/util/collection/PercentileHeap.scala | 126 ++++++++-------------
.../util/collection/PercentileHeapSuite.scala | 98 ++++++++--------
2 files changed, 96 insertions(+), 128 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PercentileHeap.scala b/core/src/main/scala/org/apache/spark/util/collection/PercentileHeap.scala
index d95bbd04031..06e5042c92f 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PercentileHeap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PercentileHeap.scala
@@ -17,100 +17,66 @@
package org.apache.spark.util.collection
-import scala.collection.mutable.PriorityQueue
+import java.util.PriorityQueue
/**
- * PercentileHeap is designed to be used to quickly track the percentile of a group of numbers
- * that may contain duplicates. Inserting a new number has O(log n) time complexity and
- * determining the percentile has O(1) time complexity.
- * The basic idea is to maintain two heaps: a smallerHalf and a largerHalf. The smallerHalf
- * stores the smaller half of all numbers while the largerHalf stores the larger half.
- * The sizes of two heaps need to match the percentage each time when a new number is inserted so
- * that the ratio of their sizes is percentage to (1 - percentage). Therefore each time when
- * percentile() is called we check if the sizes of two heaps match the percentage. If they do,
- * we should return the average of the two top values of heaps. Otherwise we return the top of the
- * heap which exceeds its percentage.
+ * PercentileHeap tracks the percentile of a collection of numbers.
+ *
+ * Insertion is O(log n), Lookup is O(1).
+ *
+ * The implementation keeps two heaps: a small heap (`smallHeap`) and a large heap (`largeHeap`).
+ * The small heap stores all the numbers below the percentile and the large heap stores the ones
+ * above the percentile. During insertion the relative sizes of the heaps are adjusted to match
+ * the target percentile.
*/
-private[spark] class PercentileHeap(percentage: Double = 0.5)(implicit val ord: Ordering[Double]) {
- assert(percentage >= 0 && percentage <= 1)
+private[spark] class PercentileHeap(percentage: Double = 0.5) {
+ assert(percentage > 0 && percentage < 1)
- /**
- * Stores all the numbers less than the current percentile in a smallerHalf,
- * i.e percentile is the maximum, at the root.
- */
- private[this] val smallerHalf = PriorityQueue.empty[Double](ord)
+ // This is a min-heap so it works out of the box.
+ private[this] val largeHeap = new PriorityQueue[Double]
+ // This is a max-heap. If we pass a comparator things get slower because of function call
+ // overhead (>2x slower on insert). Instead we negate values when we offer/poll/peek.
+ private[this] val smallHeap = new PriorityQueue[Double]
- /**
- * Stores all the numbers greater than the current percentile in a largerHalf,
- * i.e percentile is the minimum, at the root.
- */
- private[this] val largerHalf = PriorityQueue.empty[Double](ord.reverse)
+ def isEmpty(): Boolean = smallHeap.isEmpty && largeHeap.isEmpty
- def isEmpty(): Boolean = {
- smallerHalf.isEmpty && largerHalf.isEmpty
- }
+ def size(): Int = smallHeap.size + largeHeap.size
- def size(): Int = {
- smallerHalf.size + largerHalf.size
+ /**
+ * Returns percentile of the inserted elements as if the inserted elements were sorted and we
+ * returned `sorted(p)` where `p = (sorted.length * percentage).toInt` if number of elements
+ * is odd, otherwise `(sorted(p-1) + sorted(p)) / 2` if number of elements is even.
+ */
+ def percentile(): Double = {
+ if (isEmpty) throw new NoSuchElementException("empty")
+ if (size % 2 == 1 || smallHeap.isEmpty) {
+ largeHeap.peek
+ } else {
+ (largeHeap.peek + -smallHeap.peek) / 2.0
+ }
}
- // Exposed for testing.
- def smallerSize(): Int = smallerHalf.size
-
def insert(x: Double): Unit = {
- // If both heaps are empty, we insert it to the heap that has larger percentage.
if (isEmpty) {
- if (percentage < 0.5) smallerHalf.enqueue(x) else largerHalf.enqueue(x)
+ largeHeap.offer(x)
} else {
- // If the number is larger than current percentile, it should be inserted into largerHalf,
- // otherwise smallerHalf.
- if (x > percentile) {
- largerHalf.enqueue(x)
+ val p = largeHeap.peek
+ val growBot = ((size + 1) * percentage).toInt > smallHeap.size
+ if (growBot) {
+ if (x < p) {
+ smallHeap.offer(-x)
+ } else {
+ largeHeap.offer(x)
+ smallHeap.offer(-largeHeap.poll)
+ }
} else {
- smallerHalf.enqueue(x)
- }
- }
- rebalance()
- }
-
- // Calculate the deviation between the ratio of smaller heap size to larger heap size and the
- // expected ratio, which is percentage : (1 - percentage). Negative result means the smaller
- // heap has too less elements, positive result means the smaller heap has too many elements.
- private def calculateDeviation(smallerSize: Int, largerSize: Int): Double = {
- smallerSize * (1 - percentage) - largerSize * percentage
- }
-
- private[this] def rebalance(): Unit = {
- // If moving one value from heap to the other heap can fit the percentage better, then
- // move it.
- val currentDev = calculateDeviation(smallerHalf.size, largerHalf.size)
- if (currentDev > 0) {
- val newDev = calculateDeviation(smallerHalf.size - 1, largerHalf.size + 1)
- if (math.abs(newDev) < currentDev) {
- largerHalf.enqueue(smallerHalf.dequeue)
- }
- }
- if (currentDev < 0) {
- val newDev = calculateDeviation(smallerHalf.size + 1, largerHalf.size - 1)
- if (math.abs(newDev) < -currentDev) {
- smallerHalf.enqueue(largerHalf.dequeue())
+ if (x < p) {
+ smallHeap.offer(-x)
+ largeHeap.offer(-smallHeap.poll)
+ } else {
+ largeHeap.offer(x)
+ }
}
}
}
-
- def percentile: Double = {
- if (isEmpty) {
- throw new NoSuchElementException("PercentileHeap is empty.")
- }
- val dev = calculateDeviation(smallerHalf.size, largerHalf.size)
- // If the deviation is very small, we take the average of the top elements from the two heaps
- // as the percentile.
- if (smallerHalf.nonEmpty && largerHalf.nonEmpty && math.abs(dev / size) < 0.01) {
- (largerHalf.head + smallerHalf.head) / 2.0
- } else if (dev < 0) {
- largerHalf.head
- } else {
- smallerHalf.head
- }
- }
}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PercentileHeapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PercentileHeapSuite.scala
index 03d72a11552..d67aca6d4e7 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/PercentileHeapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/PercentileHeapSuite.scala
@@ -17,71 +17,73 @@
package org.apache.spark.util.collection
-import java.util.NoSuchElementException
+import scala.util.Random
import org.apache.spark.SparkFunSuite
class PercentileHeapSuite extends SparkFunSuite {
- test("If no numbers in PercentileHeap, NoSuchElementException is thrown.") {
- val medianHeap = new PercentileHeap()
+ test("When PercentileHeap is empty, NoSuchElementException is thrown.") {
+ val medianHeap = new PercentileHeap(0.5)
intercept[NoSuchElementException] {
medianHeap.percentile
}
}
- test("Median should be correct when size of PercentileHeap is even") {
- val array = Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
- val medianHeap = new PercentileHeap()
- array.foreach(medianHeap.insert(_))
- assert(medianHeap.size() === 10)
- assert(medianHeap.smallerSize() === 5)
- assert(medianHeap.percentile === 4.5)
+ private def percentile(nums: Seq[Int], percentage: Double): Double = {
+ val p = (nums.length * percentage).toInt
+ val sorted = nums.sorted.toIndexedSeq
+ if (nums.length % 2 == 1 || p == 0) {
+ sorted(p)
+ } else {
+ (sorted(p - 1) + sorted(p)) / 2.0
+ }
}
- test("Median should be correct when size of PercentileHeap is odd") {
- val array = Array(0, 1, 2, 3, 4, 5, 6, 7, 8)
- val medianHeap = new PercentileHeap()
- array.foreach(medianHeap.insert(_))
- assert(medianHeap.size() === 9)
- assert(medianHeap.smallerSize() === 4)
- assert(medianHeap.percentile === 4)
+ private def testPercentileFor(nums: Seq[Int], percentage: Double) = {
+ val h = new PercentileHeap(percentage)
+ Random.shuffle(nums).foreach(h.insert(_))
+ assert(h.size == nums.length)
+ assert(h.percentile == percentile(nums, percentage))
}
- test("Median should be correct though there are duplicated numbers inside.") {
- val array = Array(0, 0, 1, 1, 2, 3, 4)
- val medianHeap = new PercentileHeap()
- array.foreach(medianHeap.insert(_))
- assert(medianHeap.size === 7)
- assert(medianHeap.smallerSize() === 3)
- assert(medianHeap.percentile === 1)
- }
+ private val tests = Seq(
+ 0 until 1,
+ 0 until 2,
+ 0 until 11,
+ 0 until 42,
+ 0 until 100
+ )
- test("Median should be correct when input data is skewed.") {
- val medianHeap = new PercentileHeap()
- (0 until 10).foreach(_ => medianHeap.insert(5))
- assert(medianHeap.percentile === 5)
- (0 until 100).foreach(_ => medianHeap.insert(10))
- assert(medianHeap.percentile === 10)
- (0 until 1000).foreach(_ => medianHeap.insert(0))
- assert(medianHeap.percentile === 0)
+ for (t <- tests) {
+ for (p <- Seq(1, 50, 99)) {
+ test(s"$p% of ${t.mkString(",")}") {
+ testPercentileFor(t, p / 100d)
+ }
+ }
}
- test("Percentile should be correct when size of PercentileHeap is even") {
- val array = Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
- val percentileMap = new PercentileHeap(0.7)
- array.foreach(percentileMap.insert(_))
- assert(percentileMap.size() === 10)
- assert(percentileMap.smallerSize() == 7)
- assert(percentileMap.percentile === 6.5)
- }
+ ignore("benchmark") {
+ val input: Seq[Int] = 0 until 1000
+ val numRuns = 1000
+
+ def kernel(): Long = {
+ val shuffled = Random.shuffle(input).toArray
+ val start = System.nanoTime()
+ val h = new PercentileHeap(0.95)
+ shuffled.foreach { x =>
+ h.insert(x)
+ for (_ <- 0 until h.size) h.percentile
+ }
+ System.nanoTime() - start
+ }
+ for (_ <- 0 until numRuns) kernel() // warmup
- test("Percentile should be correct when size of PercentileHeap is odd") {
- val array = Array(0, 1, 2, 3, 4, 5, 6, 7, 8)
- val percentileMap = new PercentileHeap(0.7)
- array.foreach(percentileMap.insert(_))
- assert(percentileMap.size() === 9)
- assert(percentileMap.smallerSize() == 6)
- assert(percentileMap.percentile === 6)
+ var elapsed: Long = 0
+ for (_ <- 0 until numRuns) elapsed += kernel()
+ val perOp = elapsed / (numRuns * input.length)
+ // scalastyle:off println
+ println(s"$perOp ns per op on heaps of size ${input.length}")
+ // scalastyle:on println
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org