You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kyuubi.apache.org by ch...@apache.org on 2023/02/22 15:00:41 UTC

[kyuubi] branch master updated: [KYUUBI #4392] [ARROW] Assign a new execution id for arrow-based result

This is an automated email from the ASF dual-hosted git repository.

chengpan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/kyuubi.git


The following commit(s) were added to refs/heads/master by this push:
     new f0acff315 [KYUUBI #4392] [ARROW] Assign a new execution id for arrow-based result
f0acff315 is described below

commit f0acff315c68356f5138f99fdf0438e565ce5f06
Author: Fu Chen <cf...@gmail.com>
AuthorDate: Wed Feb 22 23:00:30 2023 +0800

    [KYUUBI #4392] [ARROW] Assign a new execution id for arrow-based result
    
    ### _Why are the changes needed?_
    
    Assign a new execution ID to arrow-based results, so that arrow-based queries can be tracked in the UI tab.
    
    ```sql
    set kyuubi.operation.result.format=arrow;
    select 1;
    ```
    
    Before this PR:
    
    ![Screenshot 2023-02-21 5:23:08 PM](https://user-images.githubusercontent.com/8537877/220303920-fbaf978b-ead7-4708-9094-bcc84e8fb47c.png)
    
    ![Screenshot 2023-02-21 5:23:19 PM](https://user-images.githubusercontent.com/8537877/220303966-cb8dfeae-cd10-4c4f-add6-2650619fc5f9.png)
    
    After this PR:
    ![Screenshot 2023-02-22 10:21:53 AM](https://user-images.githubusercontent.com/8537877/220504608-f67a5f70-8c64-4e3b-89c2-c2ea54676217.png)
    
    ![Screenshot 2023-02-21 5:20:50 PM](https://user-images.githubusercontent.com/8537877/220304021-9b845f44-96c3-41f2-a48a-a428f8c4823f.png)
    
    ### _How was this patch tested?_
    - [ ] Add some test cases that check the changes thoroughly including negative and positive cases if possible
    
    - [ ] Add screenshots for manual tests if appropriate
    
    - [x] [Run test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before making a pull request
    
    Closes #4392 from cfmcgrady/arrow-execution-id-2.
    
    Closes #4392
    
    481118a4 [Fu Chen] enable ut
    c90674ee [Fu Chen] address comment
    6cc7af44 [Fu Chen] address comment
    3f8a3ab8 [Fu Chen] fix ut
    223a2469 [Fu Chen] add KyuubiSparkContextHelper
    bb7b28f5 [Fu Chen] fix style
    879a1502 [Fu Chen] unnecessary changes
    a2b04f83 [Fu Chen] fix
    
    Authored-by: Fu Chen <cf...@gmail.com>
    Signed-off-by: Cheng Pan <ch...@apache.org>
---
 .../engine/spark/operation/ExecuteStatement.scala  | 114 ++++++++++++++-------
 .../engine/spark/operation/SparkOperation.scala    |  16 +--
 .../spark/operation/SparkSQLOperationManager.scala |  19 +++-
 .../apache/kyuubi/engine/spark/schema/RowSet.scala |  21 ++--
 .../operation/SparkArrowbasedOperationSuite.scala  |  36 ++++++-
 .../kyuubi/engine/spark/schema/RowSetSuite.scala   |   9 +-
 .../apache/spark/KyuubiSparkContextHelper.scala    |  30 ++++++
 .../apache/kyuubi/operation/SparkQueryTests.scala  |   4 +-
 8 files changed, 178 insertions(+), 71 deletions(-)

diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala
index 2b90525c1..6ebcce377 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala
@@ -22,13 +22,14 @@ import java.util.concurrent.RejectedExecutionException
 import scala.collection.JavaConverters._
 
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.kyuubi.SparkDatasetHelper
 import org.apache.spark.sql.types._
 
 import org.apache.kyuubi.{KyuubiSQLException, Logging}
 import org.apache.kyuubi.config.KyuubiConf.OPERATION_RESULT_MAX_ROWS
 import org.apache.kyuubi.engine.spark.KyuubiSparkUtil._
-import org.apache.kyuubi.operation.{ArrayFetchIterator, IterableFetchIterator, OperationState}
+import org.apache.kyuubi.operation.{ArrayFetchIterator, FetchIterator, IterableFetchIterator, OperationState}
 import org.apache.kyuubi.operation.log.OperationLog
 import org.apache.kyuubi.session.Session
 
@@ -62,49 +63,49 @@ class ExecuteStatement(
     OperationLog.removeCurrentOperationLog()
   }
 
-  private def executeStatement(): Unit = withLocalProperties {
+  protected def incrementalCollectResult(resultDF: DataFrame): Iterator[Any] = {
+    resultDF.toLocalIterator().asScala
+  }
+
+  protected def fullCollectResult(resultDF: DataFrame): Array[_] = {
+    resultDF.collect()
+  }
+
+  protected def takeResult(resultDF: DataFrame, maxRows: Int): Array[_] = {
+    resultDF.take(maxRows)
+  }
+
+  protected def collectAsIterator(resultDF: DataFrame): FetchIterator[_] = {
+    val resultMaxRows = spark.conf.getOption(OPERATION_RESULT_MAX_ROWS.key).map(_.toInt)
+      .getOrElse(session.sessionManager.getConf.get(OPERATION_RESULT_MAX_ROWS))
+    if (incrementalCollect) {
+      if (resultMaxRows > 0) {
+        warn(s"Ignore ${OPERATION_RESULT_MAX_ROWS.key} on incremental collect mode.")
+      }
+      info("Execute in incremental collect mode")
+      new IterableFetchIterator[Any](new Iterable[Any] {
+        override def iterator: Iterator[Any] = incrementalCollectResult(resultDF)
+      })
+    } else {
+      val internalArray = if (resultMaxRows <= 0) {
+        info("Execute in full collect mode")
+        fullCollectResult(resultDF)
+      } else {
+        info(s"Execute with max result rows[$resultMaxRows]")
+        takeResult(resultDF, resultMaxRows)
+      }
+      new ArrayFetchIterator(internalArray)
+    }
+  }
+
+  protected def executeStatement(): Unit = withLocalProperties {
     try {
       setState(OperationState.RUNNING)
       info(diagnostics)
       Thread.currentThread().setContextClassLoader(spark.sharedState.jarClassLoader)
       addOperationListener()
       result = spark.sql(statement)
-
-      val resultMaxRows = spark.conf.getOption(OPERATION_RESULT_MAX_ROWS.key).map(_.toInt)
-        .getOrElse(session.sessionManager.getConf.get(OPERATION_RESULT_MAX_ROWS))
-      iter = if (incrementalCollect) {
-        if (resultMaxRows > 0) {
-          warn(s"Ignore ${OPERATION_RESULT_MAX_ROWS.key} on incremental collect mode.")
-        }
-        info("Execute in incremental collect mode")
-        def internalIterator(): Iterator[Any] = if (arrowEnabled) {
-          SparkDatasetHelper.toArrowBatchRdd(convertComplexType(result)).toLocalIterator
-        } else {
-          result.toLocalIterator().asScala
-        }
-        new IterableFetchIterator[Any](new Iterable[Any] {
-          override def iterator: Iterator[Any] = internalIterator()
-        })
-      } else {
-        val internalArray = if (resultMaxRows <= 0) {
-          info("Execute in full collect mode")
-          if (arrowEnabled) {
-            SparkDatasetHelper.toArrowBatchRdd(convertComplexType(result)).collect()
-          } else {
-            result.collect()
-          }
-        } else {
-          info(s"Execute with max result rows[$resultMaxRows]")
-          if (arrowEnabled) {
-            // this will introduce shuffle and hurt performance
-            val limitedResult = result.limit(resultMaxRows)
-            SparkDatasetHelper.toArrowBatchRdd(convertComplexType(limitedResult)).collect()
-          } else {
-            result.take(resultMaxRows)
-          }
-        }
-        new ArrayFetchIterator(internalArray)
-      }
+      iter = collectAsIterator(result)
       setCompiledStateIfNeeded()
       setState(OperationState.FINISHED)
     } catch {
@@ -171,3 +172,40 @@ class ExecuteStatement(
       s"__kyuubi_operation_result_format__=$resultFormat",
       s"__kyuubi_operation_result_arrow_timestampAsString__=$timestampAsString")
 }
+
+class ArrowBasedExecuteStatement(
+    session: Session,
+    override val statement: String,
+    override val shouldRunAsync: Boolean,
+    queryTimeout: Long,
+    incrementalCollect: Boolean)
+  extends ExecuteStatement(session, statement, shouldRunAsync, queryTimeout, incrementalCollect) {
+
+  override protected def incrementalCollectResult(resultDF: DataFrame): Iterator[Any] = {
+    SparkDatasetHelper.toArrowBatchRdd(convertComplexType(resultDF)).toLocalIterator
+  }
+
+  override protected def fullCollectResult(resultDF: DataFrame): Array[_] = {
+    SparkDatasetHelper.toArrowBatchRdd(convertComplexType(resultDF)).collect()
+  }
+
+  override protected def takeResult(resultDF: DataFrame, maxRows: Int): Array[_] = {
+    // this will introduce shuffle and hurt performance
+    val limitedResult = resultDF.limit(maxRows)
+    SparkDatasetHelper.toArrowBatchRdd(convertComplexType(limitedResult)).collect()
+  }
+
+  /**
+   * assign a new execution id for arrow-based operation.
+   */
+  override protected def collectAsIterator(resultDF: DataFrame): FetchIterator[_] = {
+    SQLExecution.withNewExecutionId(resultDF.queryExecution, Some("collectAsArrow")) {
+      resultDF.queryExecution.executedPlan.resetMetrics()
+      super.collectAsIterator(resultDF)
+    }
+  }
+
+  override protected def isArrowBasedOperation: Boolean = true
+
+  override val resultFormat = "arrow"
+}
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkOperation.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkOperation.scala
index a6a7fc896..eb58407d4 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkOperation.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkOperation.scala
@@ -245,7 +245,7 @@ abstract class SparkOperation(session: Session)
           case FETCH_FIRST => iter.fetchAbsolute(0);
         }
         resultRowSet =
-          if (arrowEnabled) {
+          if (isArrowBasedOperation) {
             if (iter.hasNext) {
               val taken = iter.next().asInstanceOf[Array[Byte]]
               RowSet.toTRowSet(taken, getProtocolVersion)
@@ -257,8 +257,7 @@ abstract class SparkOperation(session: Session)
             RowSet.toTRowSet(
               taken.toSeq.asInstanceOf[Seq[Row]],
               resultSchema,
-              getProtocolVersion,
-              timeZone)
+              getProtocolVersion)
           }
         resultRowSet.setStartRowOffset(iter.getPosition)
       } catch onError(cancel = true)
@@ -268,16 +267,9 @@ abstract class SparkOperation(session: Session)
 
   override def shouldRunAsync: Boolean = false
 
-  protected def arrowEnabled: Boolean = {
-    resultFormat.equalsIgnoreCase("arrow") &&
-    // TODO: (fchen) make all operation support arrow
-    getClass.getCanonicalName == classOf[ExecuteStatement].getCanonicalName
-  }
+  protected def isArrowBasedOperation: Boolean = false
 
-  protected def resultFormat: String = {
-    // TODO: respect the config of the operation ExecuteStatement, if it was set.
-    spark.conf.get("kyuubi.operation.result.format", "thrift")
-  }
+  protected def resultFormat: String = "thrift"
 
   protected def timestampAsString: Boolean = {
     spark.conf.get("kyuubi.operation.result.arrow.timestampAsString", "false").toBoolean
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkSQLOperationManager.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkSQLOperationManager.scala
index 5c5ed0f98..4743f147c 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkSQLOperationManager.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkSQLOperationManager.scala
@@ -82,7 +82,24 @@ class SparkSQLOperationManager private (name: String) extends OperationManager(n
             case NoneMode =>
               val incrementalCollect = spark.conf.getOption(OPERATION_INCREMENTAL_COLLECT.key)
                 .map(_.toBoolean).getOrElse(operationIncrementalCollectDefault)
-              new ExecuteStatement(session, statement, runAsync, queryTimeout, incrementalCollect)
+              // TODO: respect the config of the operation ExecuteStatement, if it was set.
+              val resultFormat = spark.conf.get("kyuubi.operation.result.format", "thrift")
+              resultFormat.toLowerCase match {
+                case "arrow" =>
+                  new ArrowBasedExecuteStatement(
+                    session,
+                    statement,
+                    runAsync,
+                    queryTimeout,
+                    incrementalCollect)
+                case _ =>
+                  new ExecuteStatement(
+                    session,
+                    statement,
+                    runAsync,
+                    queryTimeout,
+                    incrementalCollect)
+              }
             case mode =>
               new PlanOnlyStatement(session, statement, mode)
           }
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/RowSet.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/RowSet.scala
index 7be70403d..4f935ce49 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/RowSet.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/schema/RowSet.scala
@@ -18,7 +18,6 @@
 package org.apache.kyuubi.engine.spark.schema
 
 import java.nio.ByteBuffer
-import java.time.ZoneId
 
 import scala.collection.JavaConverters._
 
@@ -61,16 +60,15 @@ object RowSet {
   def toTRowSet(
       rows: Seq[Row],
       schema: StructType,
-      protocolVersion: TProtocolVersion,
-      timeZone: ZoneId): TRowSet = {
+      protocolVersion: TProtocolVersion): TRowSet = {
     if (protocolVersion.getValue < TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) {
-      toRowBasedSet(rows, schema, timeZone)
+      toRowBasedSet(rows, schema)
     } else {
-      toColumnBasedSet(rows, schema, timeZone)
+      toColumnBasedSet(rows, schema)
     }
   }
 
-  def toRowBasedSet(rows: Seq[Row], schema: StructType, timeZone: ZoneId): TRowSet = {
+  def toRowBasedSet(rows: Seq[Row], schema: StructType): TRowSet = {
     val rowSize = rows.length
     val tRows = new java.util.ArrayList[TRow](rowSize)
     var i = 0
@@ -80,7 +78,7 @@ object RowSet {
       var j = 0
       val columnSize = row.length
       while (j < columnSize) {
-        val columnValue = toTColumnValue(j, row, schema, timeZone)
+        val columnValue = toTColumnValue(j, row, schema)
         tRow.addToColVals(columnValue)
         j += 1
       }
@@ -90,21 +88,21 @@ object RowSet {
     new TRowSet(0, tRows)
   }
 
-  def toColumnBasedSet(rows: Seq[Row], schema: StructType, timeZone: ZoneId): TRowSet = {
+  def toColumnBasedSet(rows: Seq[Row], schema: StructType): TRowSet = {
     val rowSize = rows.length
     val tRowSet = new TRowSet(0, new java.util.ArrayList[TRow](rowSize))
     var i = 0
     val columnSize = schema.length
     while (i < columnSize) {
       val field = schema(i)
-      val tColumn = toTColumn(rows, i, field.dataType, timeZone)
+      val tColumn = toTColumn(rows, i, field.dataType)
       tRowSet.addToColumns(tColumn)
       i += 1
     }
     tRowSet
   }
 
-  private def toTColumn(rows: Seq[Row], ordinal: Int, typ: DataType, timeZone: ZoneId): TColumn = {
+  private def toTColumn(rows: Seq[Row], ordinal: Int, typ: DataType): TColumn = {
     val nulls = new java.util.BitSet()
     typ match {
       case BooleanType =>
@@ -186,8 +184,7 @@ object RowSet {
   private def toTColumnValue(
       ordinal: Int,
       row: Row,
-      types: StructType,
-      timeZone: ZoneId): TColumnValue = {
+      types: StructType): TColumnValue = {
     types(ordinal).dataType match {
       case BooleanType =>
         val boolValue = new TBoolValue
diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala
index 60cc52891..30cdeca5a 100644
--- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala
+++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala
@@ -19,8 +19,14 @@ package org.apache.kyuubi.engine.spark.operation
 
 import java.sql.Statement
 
+import org.apache.spark.KyuubiSparkContextHelper
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.util.QueryExecutionListener
+
 import org.apache.kyuubi.config.KyuubiConf
-import org.apache.kyuubi.engine.spark.WithSparkSQLEngine
+import org.apache.kyuubi.engine.spark.{SparkSQLEngine, WithSparkSQLEngine}
+import org.apache.kyuubi.engine.spark.session.SparkSessionImpl
 import org.apache.kyuubi.operation.SparkDataTypeTests
 
 class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTypeTests {
@@ -85,6 +91,34 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
     }
   }
 
+  test("assign a new execution id for arrow-based result") {
+    var plan: LogicalPlan = null
+
+    val listener = new QueryExecutionListener {
+      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+        plan = qe.analyzed
+      }
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+    }
+    withJdbcStatement() { statement =>
+      // since each new session has its own listener bus, we should register the listener
+      // in the current session.
+      SparkSQLEngine.currentEngine.get
+        .backendService
+        .sessionManager
+        .allSessions()
+        .foreach(_.asInstanceOf[SparkSessionImpl].spark.listenerManager.register(listener))
+
+      val result = statement.executeQuery("select 1 as c1")
+      assert(result.next())
+      assert(result.getInt("c1") == 1)
+    }
+
+    KyuubiSparkContextHelper.waitListenerBus(spark)
+    spark.listenerManager.unregister(listener)
+    assert(plan.isInstanceOf[Project])
+  }
+
   private def checkResultSetFormat(statement: Statement, expectFormat: String): Unit = {
     val query =
       s"""
diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/schema/RowSetSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/schema/RowSetSuite.scala
index a999563ea..5d2ba4a0d 100644
--- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/schema/RowSetSuite.scala
+++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/schema/RowSetSuite.scala
@@ -20,7 +20,7 @@ package org.apache.kyuubi.engine.spark.schema
 import java.nio.ByteBuffer
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
-import java.time.{Instant, LocalDate, ZoneId}
+import java.time.{Instant, LocalDate}
 
 import scala.collection.JavaConverters._
 
@@ -96,10 +96,9 @@ class RowSetSuite extends KyuubiFunSuite {
     .add("q", "timestamp")
 
   private val rows: Seq[Row] = (0 to 10).map(genRow) ++ Seq(Row.fromSeq(Seq.fill(17)(null)))
-  private val zoneId: ZoneId = ZoneId.systemDefault()
 
   test("column based set") {
-    val tRowSet = RowSet.toColumnBasedSet(rows, schema, zoneId)
+    val tRowSet = RowSet.toColumnBasedSet(rows, schema)
     assert(tRowSet.getColumns.size() === schema.size)
     assert(tRowSet.getRowsSize === 0)
 
@@ -204,7 +203,7 @@ class RowSetSuite extends KyuubiFunSuite {
   }
 
   test("row based set") {
-    val tRowSet = RowSet.toRowBasedSet(rows, schema, zoneId)
+    val tRowSet = RowSet.toRowBasedSet(rows, schema)
     assert(tRowSet.getColumnCount === 0)
     assert(tRowSet.getRowsSize === rows.size)
     val iter = tRowSet.getRowsIterator
@@ -250,7 +249,7 @@ class RowSetSuite extends KyuubiFunSuite {
 
   test("to row set") {
     TProtocolVersion.values().foreach { proto =>
-      val set = RowSet.toTRowSet(rows, schema, proto, zoneId)
+      val set = RowSet.toTRowSet(rows, schema, proto)
       if (proto.getValue < TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) {
         assert(!set.isSetColumns, proto.toString)
         assert(set.isSetRows, proto.toString)
diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/KyuubiSparkContextHelper.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/KyuubiSparkContextHelper.scala
new file mode 100644
index 000000000..8293123ea
--- /dev/null
+++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/KyuubiSparkContextHelper.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.apache.spark.sql.SparkSession
+
+/**
+ * A place to invoke non-public APIs of [[SparkContext]], for test only.
+ */
+object KyuubiSparkContextHelper {
+
+  def waitListenerBus(spark: SparkSession): Unit = {
+    spark.sparkContext.listenerBus.waitUntilEmpty()
+  }
+}
diff --git a/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/SparkQueryTests.scala b/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/SparkQueryTests.scala
index e297e6281..a42b05473 100644
--- a/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/SparkQueryTests.scala
+++ b/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/SparkQueryTests.scala
@@ -433,13 +433,13 @@ trait SparkQueryTests extends SparkDataTypeTests with HiveJDBCTestHelper {
         expectedFormat = "thrift")
       checkStatusAndResultSetFormatHint(
         sql = "set kyuubi.operation.result.format=arrow",
-        expectedFormat = "arrow")
+        expectedFormat = "thrift")
       checkStatusAndResultSetFormatHint(
         sql = "SELECT 1",
         expectedFormat = "arrow")
       checkStatusAndResultSetFormatHint(
         sql = "set kyuubi.operation.result.format=thrift",
-        expectedFormat = "thrift")
+        expectedFormat = "arrow")
       checkStatusAndResultSetFormatHint(
         sql = "set kyuubi.operation.result.format",
         expectedFormat = "thrift")