You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kyuubi.apache.org by ch...@apache.org on 2023/02/23 09:04:40 UTC

[kyuubi] branch master updated: [KYUUBI #4402] [ARROW] Make arrow-based query metrics trackable in SQL UI

This is an automated email from the ASF dual-hosted git repository.

chengpan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/kyuubi.git


The following commit(s) were added to refs/heads/master by this push:
     new 3016f431b [KYUUBI #4402] [ARROW] Make arrow-based query metrics trackable in SQL UI
3016f431b is described below

commit 3016f431bf0d4b881d8a1979aa13b23432a58468
Author: Fu Chen <cf...@gmail.com>
AuthorDate: Thu Feb 23 17:04:29 2023 +0800

    [KYUUBI #4402] [ARROW] Make arrow-based query metrics trackable in SQL UI
    
    ### _Why are the changes needed?_
    
    Currently, the SQL metrics are missing from the SQL UI tab. This is because we mistakenly bound the wrong QueryExecution in [PR-4392](https://github.com/apache/kyuubi/pull/4392): before this PR, it was `resultDF.queryExecution` that was bound to `SQLExecution.withNewExecutionId()`, but the Dataset that is actually executed is `resultDF.select(cols: _*)`. This PR passes the correct QueryExecution, `resultDF.select(cols: _*).queryExecution`, to solve this problem.
    
    ```sql
    set kyuubi.operation.result.format=arrow;
    select 1;
    ```
    
    Before this PR:
    
    ![截屏2023-02-23 下午1 47 34](https://user-images.githubusercontent.com/8537877/220832155-4277ccf7-1cfe-40db-a6e5-e1ed4a6d2e29.png)
    
    After this PR:
    
    ![截屏2023-02-23 下午2 07 23](https://user-images.githubusercontent.com/8537877/220832184-b9871b4b-f408-42ac-91ca-30e5cd503b24.png)
    
    ### _How was this patch tested?_
    - [ ] Add some test cases that check the changes thoroughly including negative and positive cases if possible
    
    - [ ] Add screenshots for manual tests if appropriate
    
    - [x] [Run test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before making a pull request
    
    Closes #4402 from cfmcgrady/arrow-metrics.
    
    Closes #4402
    
    e0cde3b1 [Fu Chen] fix style
    b35cbfdc [Fu Chen] fix
    542414ef [Fu Chen] make arrow-based query metrics trackable in SQL UI
    
    Authored-by: Fu Chen <cf...@gmail.com>
    Signed-off-by: Cheng Pan <ch...@apache.org>
---
 .../engine/spark/operation/ExecuteStatement.scala  | 79 ++++++++++++----------
 .../operation/SparkArrowbasedOperationSuite.scala  | 51 ++++++++++++--
 2 files changed, 88 insertions(+), 42 deletions(-)

diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala
index 6ebcce377..fa517a8b1 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala
@@ -21,6 +21,7 @@ import java.util.concurrent.RejectedExecutionException
 
 import scala.collection.JavaConverters._
 
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.kyuubi.SparkDatasetHelper
@@ -75,29 +76,6 @@ class ExecuteStatement(
     resultDF.take(maxRows)
   }
 
-  protected def collectAsIterator(resultDF: DataFrame): FetchIterator[_] = {
-    val resultMaxRows = spark.conf.getOption(OPERATION_RESULT_MAX_ROWS.key).map(_.toInt)
-      .getOrElse(session.sessionManager.getConf.get(OPERATION_RESULT_MAX_ROWS))
-    if (incrementalCollect) {
-      if (resultMaxRows > 0) {
-        warn(s"Ignore ${OPERATION_RESULT_MAX_ROWS.key} on incremental collect mode.")
-      }
-      info("Execute in incremental collect mode")
-      new IterableFetchIterator[Any](new Iterable[Any] {
-        override def iterator: Iterator[Any] = incrementalCollectResult(resultDF)
-      })
-    } else {
-      val internalArray = if (resultMaxRows <= 0) {
-        info("Execute in full collect mode")
-        fullCollectResult(resultDF)
-      } else {
-        info(s"Execute with max result rows[$resultMaxRows]")
-        takeResult(resultDF, resultMaxRows)
-      }
-      new ArrayFetchIterator(internalArray)
-    }
-  }
-
   protected def executeStatement(): Unit = withLocalProperties {
     try {
       setState(OperationState.RUNNING)
@@ -163,14 +141,33 @@ class ExecuteStatement(
     }
   }
 
-  def convertComplexType(df: DataFrame): DataFrame = {
-    SparkDatasetHelper.convertTopLevelComplexTypeToHiveString(df, timestampAsString)
-  }
-
   override def getResultSetMetadataHints(): Seq[String] =
     Seq(
       s"__kyuubi_operation_result_format__=$resultFormat",
       s"__kyuubi_operation_result_arrow_timestampAsString__=$timestampAsString")
+
+  private def collectAsIterator(resultDF: DataFrame): FetchIterator[_] = {
+    val resultMaxRows = spark.conf.getOption(OPERATION_RESULT_MAX_ROWS.key).map(_.toInt)
+      .getOrElse(session.sessionManager.getConf.get(OPERATION_RESULT_MAX_ROWS))
+    if (incrementalCollect) {
+      if (resultMaxRows > 0) {
+        warn(s"Ignore ${OPERATION_RESULT_MAX_ROWS.key} on incremental collect mode.")
+      }
+      info("Execute in incremental collect mode")
+      new IterableFetchIterator[Any](new Iterable[Any] {
+        override def iterator: Iterator[Any] = incrementalCollectResult(resultDF)
+      })
+    } else {
+      val internalArray = if (resultMaxRows <= 0) {
+        info("Execute in full collect mode")
+        fullCollectResult(resultDF)
+      } else {
+        info(s"Execute with max result rows[$resultMaxRows]")
+        takeResult(resultDF, resultMaxRows)
+      }
+      new ArrayFetchIterator(internalArray)
+    }
+  }
 }
 
 class ArrowBasedExecuteStatement(
@@ -182,30 +179,42 @@ class ArrowBasedExecuteStatement(
   extends ExecuteStatement(session, statement, shouldRunAsync, queryTimeout, incrementalCollect) {
 
   override protected def incrementalCollectResult(resultDF: DataFrame): Iterator[Any] = {
-    SparkDatasetHelper.toArrowBatchRdd(convertComplexType(resultDF)).toLocalIterator
+    collectAsArrow(convertComplexType(resultDF)) { rdd =>
+      rdd.toLocalIterator
+    }
   }
 
   override protected def fullCollectResult(resultDF: DataFrame): Array[_] = {
-    SparkDatasetHelper.toArrowBatchRdd(convertComplexType(resultDF)).collect()
+    collectAsArrow(convertComplexType(resultDF)) { rdd =>
+      rdd.collect()
+    }
   }
 
   override protected def takeResult(resultDF: DataFrame, maxRows: Int): Array[_] = {
     // this will introduce shuffle and hurt performance
     val limitedResult = resultDF.limit(maxRows)
-    SparkDatasetHelper.toArrowBatchRdd(convertComplexType(limitedResult)).collect()
+    collectAsArrow(convertComplexType(limitedResult)) { rdd =>
+      rdd.collect()
+    }
   }
 
   /**
-   * assign a new execution id for arrow-based operation.
+   * refer to org.apache.spark.sql.Dataset#withAction(), assign a new execution id for arrow-based
+   * operation, so that we can track the arrow-based queries on the UI tab.
    */
-  override protected def collectAsIterator(resultDF: DataFrame): FetchIterator[_] = {
-    SQLExecution.withNewExecutionId(resultDF.queryExecution, Some("collectAsArrow")) {
-      resultDF.queryExecution.executedPlan.resetMetrics()
-      super.collectAsIterator(resultDF)
+  private def collectAsArrow[T](df: DataFrame)(action: RDD[Array[Byte]] => T): T = {
+    SQLExecution.withNewExecutionId(df.queryExecution, Some("collectAsArrow")) {
+      df.queryExecution.executedPlan.resetMetrics()
+      action(SparkDatasetHelper.toArrowBatchRdd(df))
     }
   }
 
   override protected def isArrowBasedOperation: Boolean = true
 
   override val resultFormat = "arrow"
+
+  private def convertComplexType(df: DataFrame): DataFrame = {
+    SparkDatasetHelper.convertTopLevelComplexTypeToHiveString(df, timestampAsString)
+  }
+
 }
diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala
index 30cdeca5a..ae6237bb5 100644
--- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala
+++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala
@@ -103,22 +103,41 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
     withJdbcStatement() { statement =>
       // since all the new sessions have their owner listener bus, we should register the listener
       // in the current session.
-      SparkSQLEngine.currentEngine.get
-        .backendService
-        .sessionManager
-        .allSessions()
-        .foreach(_.asInstanceOf[SparkSessionImpl].spark.listenerManager.register(listener))
+      registerListener(listener)
 
       val result = statement.executeQuery("select 1 as c1")
       assert(result.next())
       assert(result.getInt("c1") == 1)
     }
-
     KyuubiSparkContextHelper.waitListenerBus(spark)
-    spark.listenerManager.unregister(listener)
+    unregisterListener(listener)
     assert(plan.isInstanceOf[Project])
   }
 
+  test("arrow-based query metrics") {
+    var queryExecution: QueryExecution = null
+
+    val listener = new QueryExecutionListener {
+      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+        queryExecution = qe
+      }
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+    }
+    withJdbcStatement() { statement =>
+      registerListener(listener)
+      val result = statement.executeQuery("select 1 as c1")
+      assert(result.next())
+      assert(result.getInt("c1") == 1)
+    }
+
+    KyuubiSparkContextHelper.waitListenerBus(spark)
+    unregisterListener(listener)
+
+    val metrics = queryExecution.executedPlan.collectLeaves().head.metrics
+    assert(metrics.contains("numOutputRows"))
+    assert(metrics("numOutputRows").value === 1)
+  }
+
   private def checkResultSetFormat(statement: Statement, expectFormat: String): Unit = {
     val query =
       s"""
@@ -140,4 +159,22 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
     assert(resultSet.next())
     assert(resultSet.getString("col") === expect)
   }
+
+  private def registerListener(listener: QueryExecutionListener): Unit = {
+    // since all the new sessions have their owner listener bus, we should register the listener
+    // in the current session.
+    SparkSQLEngine.currentEngine.get
+      .backendService
+      .sessionManager
+      .allSessions()
+      .foreach(_.asInstanceOf[SparkSessionImpl].spark.listenerManager.register(listener))
+  }
+
+  private def unregisterListener(listener: QueryExecutionListener): Unit = {
+    SparkSQLEngine.currentEngine.get
+      .backendService
+      .sessionManager
+      .allSessions()
+      .foreach(_.asInstanceOf[SparkSessionImpl].spark.listenerManager.unregister(listener))
+  }
 }