You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2023/02/27 06:27:31 UTC

[spark] branch master updated: [SPARK-42528][CORE] Optimize PercentileHeap

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0b8234db247 [SPARK-42528][CORE] Optimize PercentileHeap
0b8234db247 is described below

commit 0b8234db247dc3807ea70448ed069483c8b00687
Author: Alkis Evlogimenos <al...@databricks.com>
AuthorDate: Mon Feb 27 14:27:06 2023 +0800

    [SPARK-42528][CORE] Optimize PercentileHeap
    
    ### What changes were proposed in this pull request?
    
    Reimplement `PercentileHeap` such that:
    - the percentile value is always in the `largeHeap`, which speeds up `percentile` access
    - rebalance the heaps more efficiently by checking which heap should grow due to the new insertion and doing a rebalance based on target heap sizes
    - the heaps are Java `PriorityQueue`s *without* comparators. Comparator call overhead slows down `poll`/`offer` by more than 2x. Instead, the max-heap is implemented by negating values on `offer`/`poll`/`peek`.
    
    ### Why are the changes needed?
    
    `PercentileHeap` is heavyweight enough to cause scheduling delays if inserted inside the scheduler loop.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Added more extensive unit tests.
    
    Closes #40121 from alkis/faster-percentile-heap.
    
    Lead-authored-by: Alkis Evlogimenos <al...@databricks.com>
    Co-authored-by: Alkis Evlogimenos <al...@evlogimenos.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../spark/util/collection/PercentileHeap.scala     | 126 ++++++++-------------
 .../util/collection/PercentileHeapSuite.scala      |  98 ++++++++--------
 2 files changed, 96 insertions(+), 128 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/util/collection/PercentileHeap.scala b/core/src/main/scala/org/apache/spark/util/collection/PercentileHeap.scala
index d95bbd04031..06e5042c92f 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PercentileHeap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PercentileHeap.scala
@@ -17,100 +17,66 @@
 
 package org.apache.spark.util.collection
 
-import scala.collection.mutable.PriorityQueue
+import java.util.PriorityQueue
 
 /**
- * PercentileHeap is designed to be used to quickly track the percentile of a group of numbers
- * that may contain duplicates. Inserting a new number has O(log n) time complexity and
- * determining the percentile has O(1) time complexity.
- * The basic idea is to maintain two heaps: a smallerHalf and a largerHalf. The smallerHalf
- * stores the smaller half of all numbers while the largerHalf stores the larger half.
- * The sizes of two heaps need to match the percentage each time when a new number is inserted so
- * that the ratio of their sizes is percentage to (1 - percentage). Therefore each time when
- * percentile() is called we check if the sizes of two heaps match the percentage. If they do,
- * we should return the average of the two top values of heaps. Otherwise we return the top of the
- * heap which exceeds its percentage.
+ * PercentileHeap tracks the percentile of a collection of numbers.
+ *
+ * Insertion is O(log n), Lookup is O(1).
+ *
+ * The implementation keeps two heaps: a small heap (`smallHeap`) and a large heap (`largeHeap`).
+ * The small heap stores all the numbers below the percentile and the large heap stores the ones
+ * above the percentile. During insertion the relative sizes of the heaps are adjusted to match
+ * the target percentile.
  */
-private[spark] class PercentileHeap(percentage: Double = 0.5)(implicit val ord: Ordering[Double]) {
-  assert(percentage >= 0 && percentage <= 1)
+private[spark] class PercentileHeap(percentage: Double = 0.5) {
+  assert(percentage > 0 && percentage < 1)
 
-  /**
-   * Stores all the numbers less than the current percentile in a smallerHalf,
-   * i.e percentile is the maximum, at the root.
-   */
-  private[this] val smallerHalf = PriorityQueue.empty[Double](ord)
+  // This is a min-heap so it works out of the box.
+  private[this] val largeHeap = new PriorityQueue[Double]
+  // This is a max-heap. If we pass a comparator things get slower because of function call
+  // overhead (>2x slower on insert). Instead we negate values when we offer/poll/peek.
+  private[this] val smallHeap = new PriorityQueue[Double]
 
-  /**
-   * Stores all the numbers greater than the current percentile in a largerHalf,
-   * i.e percentile is the minimum, at the root.
-   */
-  private[this] val largerHalf = PriorityQueue.empty[Double](ord.reverse)
+  def isEmpty(): Boolean = smallHeap.isEmpty && largeHeap.isEmpty
 
-  def isEmpty(): Boolean = {
-    smallerHalf.isEmpty && largerHalf.isEmpty
-  }
+  def size(): Int = smallHeap.size + largeHeap.size
 
-  def size(): Int = {
-    smallerHalf.size + largerHalf.size
+  /**
+   * Returns percentile of the inserted elements as if the inserted elements were sorted and we
+   * returned `sorted(p)` where `p = (sorted.length * percentage).toInt` if number of elements
+   * is odd or `p == 0`, otherwise `(sorted(p-1) + sorted(p)) / 2`.
+   */
+  def percentile(): Double = {
+    if (isEmpty) throw new NoSuchElementException("empty")
+    if (size % 2 == 1 || smallHeap.isEmpty) {
+      largeHeap.peek
+    } else {
+      (largeHeap.peek + -smallHeap.peek) / 2.0
+    }
   }
 
-  // Exposed for testing.
-  def smallerSize(): Int = smallerHalf.size
-
   def insert(x: Double): Unit = {
-    // If both heaps are empty, we insert it to the heap that has larger percentage.
     if (isEmpty) {
-      if (percentage < 0.5) smallerHalf.enqueue(x) else largerHalf.enqueue(x)
+      largeHeap.offer(x)
     } else {
-      // If the number is larger than current percentile, it should be inserted into largerHalf,
-      // otherwise smallerHalf.
-      if (x > percentile) {
-        largerHalf.enqueue(x)
+      val p = largeHeap.peek
+      val growBot = ((size + 1) * percentage).toInt > smallHeap.size
+      if (growBot) {
+        if (x < p) {
+          smallHeap.offer(-x)
+        } else {
+          largeHeap.offer(x)
+          smallHeap.offer(-largeHeap.poll)
+        }
       } else {
-        smallerHalf.enqueue(x)
-      }
-    }
-    rebalance()
-  }
-
-  // Calculate the deviation between the ratio of smaller heap size to larger heap size and the
-  // expected ratio, which is percentage : (1 - percentage). Negative result means the smaller
-  // heap has too less elements, positive result means the smaller heap has too many elements.
-  private def calculateDeviation(smallerSize: Int, largerSize: Int): Double = {
-    smallerSize * (1 - percentage) - largerSize * percentage
-  }
-
-  private[this] def rebalance(): Unit = {
-    // If moving one value from heap to the other heap can fit the percentage better, then
-    // move it.
-    val currentDev = calculateDeviation(smallerHalf.size, largerHalf.size)
-    if (currentDev > 0) {
-      val newDev = calculateDeviation(smallerHalf.size - 1, largerHalf.size + 1)
-      if (math.abs(newDev) < currentDev) {
-        largerHalf.enqueue(smallerHalf.dequeue)
-      }
-    }
-    if (currentDev < 0) {
-      val newDev = calculateDeviation(smallerHalf.size + 1, largerHalf.size - 1)
-      if (math.abs(newDev) < -currentDev) {
-        smallerHalf.enqueue(largerHalf.dequeue())
+        if (x < p) {
+          smallHeap.offer(-x)
+          largeHeap.offer(-smallHeap.poll)
+        } else {
+          largeHeap.offer(x)
+        }
       }
     }
   }
-
-  def percentile: Double = {
-    if (isEmpty) {
-      throw new NoSuchElementException("PercentileHeap is empty.")
-    }
-    val dev = calculateDeviation(smallerHalf.size, largerHalf.size)
-    // If the deviation is very small, we take the average of the top elements from the two heaps
-    // as the percentile.
-    if (smallerHalf.nonEmpty && largerHalf.nonEmpty && math.abs(dev / size) < 0.01) {
-      (largerHalf.head + smallerHalf.head) / 2.0
-    } else if (dev < 0) {
-      largerHalf.head
-    } else {
-      smallerHalf.head
-    }
-  }
 }
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PercentileHeapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PercentileHeapSuite.scala
index 03d72a11552..d67aca6d4e7 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/PercentileHeapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/PercentileHeapSuite.scala
@@ -17,71 +17,73 @@
 
 package org.apache.spark.util.collection
 
-import java.util.NoSuchElementException
+import scala.util.Random
 
 import org.apache.spark.SparkFunSuite
 
 class PercentileHeapSuite extends SparkFunSuite {
 
-  test("If no numbers in PercentileHeap, NoSuchElementException is thrown.") {
-    val medianHeap = new PercentileHeap()
+  test("When PercentileHeap is empty, NoSuchElementException is thrown.") {
+    val medianHeap = new PercentileHeap(0.5)
     intercept[NoSuchElementException] {
       medianHeap.percentile
     }
   }
 
-  test("Median should be correct when size of PercentileHeap is even") {
-    val array = Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
-    val medianHeap = new PercentileHeap()
-    array.foreach(medianHeap.insert(_))
-    assert(medianHeap.size() === 10)
-    assert(medianHeap.smallerSize() === 5)
-    assert(medianHeap.percentile === 4.5)
+  private def percentile(nums: Seq[Int], percentage: Double): Double = {
+    val p = (nums.length * percentage).toInt
+    val sorted = nums.sorted.toIndexedSeq
+    if (nums.length % 2 == 1 || p == 0) {
+      sorted(p)
+    } else {
+      (sorted(p - 1) + sorted(p)) / 2.0
+    }
   }
 
-  test("Median should be correct when size of PercentileHeap is odd") {
-    val array = Array(0, 1, 2, 3, 4, 5, 6, 7, 8)
-    val medianHeap = new PercentileHeap()
-    array.foreach(medianHeap.insert(_))
-    assert(medianHeap.size() === 9)
-    assert(medianHeap.smallerSize() === 4)
-    assert(medianHeap.percentile === 4)
+  private def testPercentileFor(nums: Seq[Int], percentage: Double) = {
+    val h = new PercentileHeap(percentage)
+    Random.shuffle(nums).foreach(h.insert(_))
+    assert(h.size == nums.length)
+    assert(h.percentile == percentile(nums, percentage))
   }
 
-  test("Median should be correct though there are duplicated numbers inside.") {
-    val array = Array(0, 0, 1, 1, 2, 3, 4)
-    val medianHeap = new PercentileHeap()
-    array.foreach(medianHeap.insert(_))
-    assert(medianHeap.size === 7)
-    assert(medianHeap.smallerSize() === 3)
-    assert(medianHeap.percentile === 1)
-  }
+  private val tests = Seq(
+    0 until 1,
+    0 until 2,
+    0 until 11,
+    0 until 42,
+    0 until 100
+  )
 
-  test("Median should be correct when input data is skewed.") {
-    val medianHeap = new PercentileHeap()
-    (0 until 10).foreach(_ => medianHeap.insert(5))
-    assert(medianHeap.percentile === 5)
-    (0 until 100).foreach(_ => medianHeap.insert(10))
-    assert(medianHeap.percentile === 10)
-    (0 until 1000).foreach(_ => medianHeap.insert(0))
-    assert(medianHeap.percentile === 0)
+  for (t <- tests) {
+    for (p <- Seq(1, 50, 99)) {
+      test(s"$p% of ${t.mkString(",")}") {
+        testPercentileFor(t, p / 100d)
+      }
+    }
   }
 
-  test("Percentile should be correct when size of PercentileHeap is even") {
-    val array = Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
-    val percentileMap = new PercentileHeap(0.7)
-    array.foreach(percentileMap.insert(_))
-    assert(percentileMap.size() === 10)
-    assert(percentileMap.smallerSize() == 7)
-    assert(percentileMap.percentile === 6.5)
-  }
+  ignore("benchmark") {
+    val input: Seq[Int] = 0 until 1000
+    val numRuns = 1000
+
+    def kernel(): Long = {
+      val shuffled = Random.shuffle(input).toArray
+      val start = System.nanoTime()
+      val h = new PercentileHeap(0.95)
+      shuffled.foreach { x =>
+        h.insert(x)
+        for (_ <- 0 until h.size) h.percentile
+      }
+      System.nanoTime() - start
+    }
+    for (_ <- 0 until numRuns) kernel()  // warmup
 
-  test("Percentile should be correct when size of PercentileHeap is odd") {
-    val array = Array(0, 1, 2, 3, 4, 5, 6, 7, 8)
-    val percentileMap = new PercentileHeap(0.7)
-    array.foreach(percentileMap.insert(_))
-    assert(percentileMap.size() === 9)
-    assert(percentileMap.smallerSize() == 6)
-    assert(percentileMap.percentile === 6)
+    var elapsed: Long = 0
+    for (_ <- 0 until numRuns) elapsed += kernel()
+    val perOp = elapsed / (numRuns * input.length)
+    // scalastyle:off println
+    println(s"$perOp ns per op on heaps of size ${input.length}")
+    // scalastyle:on println
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org