You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2023/02/13 05:10:07 UTC

[spark] branch master updated: [SPARK-42034] QueryExecutionListener and Observation API do not work with `foreach` / `reduce` / `foreachPartition` action

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a1649ad2429 [SPARK-42034] QueryExecutionListener and Observation API do not work with `foreach` / `reduce` / `foreachPartition` action
a1649ad2429 is described below

commit a1649ad24298d988267acb8588d19848c7fb16c4
Author: 佘志铭 <sh...@corp.netease.com>
AuthorDate: Mon Feb 13 14:09:54 2023 +0900

    [SPARK-42034] QueryExecutionListener and Observation API do not work with `foreach` / `reduce` / `foreachPartition` action
    
    ### What changes were proposed in this pull request?
    
    Pass a name for the `foreach` / `reduce` / `foreachPartition` actions to `Dataset#withNewRDDExecutionId`, because the QueryExecutionListener and Observation APIs are triggered only when the action has a name.
    
    https://github.com/apache/spark/blob/84ddd409c11e4da769c5b1f496f2b61c3d928c07/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala#L181
    
    ### Why are the changes needed?
    
    The QueryExecutionListener and Observation APIs are triggered only when the action is given a name, so without this change they do not work with the `foreach` / `reduce` / `foreachPartition` actions.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added two unit tests.
    
    Closes #39976 from zzzzming95/SPARK-42034.
    
    Authored-by: 佘志铭 <sh...@corp.netease.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 10 ++++----
 .../scala/org/apache/spark/sql/DatasetSuite.scala  | 13 ++++++++++
 .../spark/sql/util/DataFrameCallbackSuite.scala    | 28 ++++++++++++++++++++++
 3 files changed, 46 insertions(+), 5 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 28177b90c7e..edcfad0c798 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1858,7 +1858,7 @@ class Dataset[T] private[sql](
    * @group action
    * @since 1.6.0
    */
-  def reduce(func: (T, T) => T): T = withNewRDDExecutionId {
+  def reduce(func: (T, T) => T): T = withNewRDDExecutionId("reduce") {
     rdd.reduce(func)
   }
 
@@ -3336,7 +3336,7 @@ class Dataset[T] private[sql](
    * @group action
    * @since 1.6.0
    */
-  def foreach(f: T => Unit): Unit = withNewRDDExecutionId {
+  def foreach(f: T => Unit): Unit = withNewRDDExecutionId("foreach") {
     rdd.foreach(f)
   }
 
@@ -3355,7 +3355,7 @@ class Dataset[T] private[sql](
    * @group action
    * @since 1.6.0
    */
-  def foreachPartition(f: Iterator[T] => Unit): Unit = withNewRDDExecutionId {
+  def foreachPartition(f: Iterator[T] => Unit): Unit = withNewRDDExecutionId("foreachPartition") {
     rdd.foreachPartition(f)
   }
 
@@ -4148,8 +4148,8 @@ class Dataset[T] private[sql](
    * them with an execution. Before performing the action, the metrics of the executed plan will be
    * reset.
    */
-  private def withNewRDDExecutionId[U](body: => U): U = {
-    SQLExecution.withNewExecutionId(rddQueryExecution) {
+  private def withNewRDDExecutionId[U](name: String)(body: => U): U = {
+    SQLExecution.withNewExecutionId(rddQueryExecution, Some(name)) {
       rddQueryExecution.executedPlan.resetMetrics()
       body
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 86e640a4fa8..263e361413c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -960,6 +960,19 @@ class DatasetSuite extends QueryTest
     observe(spark.range(1, 10, 1, 11), Map("percentile_approx_val" -> 5))
   }
 
+  test("observation on datasets when a DataSet trigger foreach action") {
+    def f(): Unit = {}
+
+    val namedObservation = Observation("named")
+    val observed_df = spark.range(100).observe(
+      namedObservation, percentile_approx($"id", lit(0.5), lit(100)).as("percentile_approx_val"))
+
+    observed_df.foreach(r => f)
+    val expected = Map("percentile_approx_val" -> 49)
+
+    assert(namedObservation.get === expected)
+  }
+
   test("sample with replacement") {
     val n = 100
     val data = sparkContext.parallelize(1 to n, 2).toDS()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index 2fc1f10d3ea..f046daacb91 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -96,6 +96,34 @@ class DataFrameCallbackSuite extends QueryTest
     spark.listenerManager.unregister(listener)
   }
 
+  test("execute callback functions when a DataSet trigger foreach action finished") {
+    val metrics = ArrayBuffer.empty[(String, QueryExecution, Long)]
+    val listener = new QueryExecutionListener {
+      // Only test successful case here, so no need to implement `onFailure`
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+
+      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+        metrics += ((funcName, qe, duration))
+      }
+    }
+    spark.listenerManager.register(listener)
+
+    def f(): Unit = {}
+
+    val df = Seq(1).toDF("i")
+
+    df.foreach(r => f)
+    df.reduce((x, y) => x)
+
+    sparkContext.listenerBus.waitUntilEmpty()
+    assert(metrics.length == 2)
+
+    assert(metrics(0)._1 == "foreach")
+    assert(metrics(1)._1 == "reduce")
+
+    spark.listenerManager.unregister(listener)
+  }
+
   test("get numRows metrics by callback") {
     val metrics = ArrayBuffer.empty[Long]
     val listener = new QueryExecutionListener {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org