You are viewing a plain text version of this content. The canonical link for it was provided as a hyperlink ("here") that is not preserved in this plain-text version.
Posted to commits@spark.apache.org by hv...@apache.org on 2020/02/13 23:01:41 UTC

[spark] branch branch-3.0 updated: [SPARK-30798][SQL] Scope Session.active in QueryExecution

This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 4db64ed  [SPARK-30798][SQL] Scope Session.active in QueryExecution
4db64ed is described below

commit 4db64ed37c601eb62aa3939d13f4f0e15bc1e4a9
Author: Ali Afroozeh <al...@databricks.com>
AuthorDate: Thu Feb 13 23:58:55 2020 +0100

    [SPARK-30798][SQL] Scope Session.active in QueryExecution
    
    ### What changes were proposed in this pull request?
    
    This PR scopes `SparkSession.active` to prevent problems when processing queries that may involve different Spark sessions (with different configs). A new method, `withActive`, is introduced on `SparkSession`: it sets the session as the active session while a block of code runs, and restores the previously active Spark session after the block completes.
    
    ### Why are the changes needed?
    `SparkSession.active` is a thread-local variable that points to the current thread's Spark session. Note that the `SQLConf.get` method depends on `SparkSession.active`. In the current implementation, `SparkSession.active` can point to a different session than the one processing the query, which causes various problems. Most of these problems arise because part of the query processing is done using the configurations of a different session. For example, when creating a DataFrame using [...]
    
    ### Does this PR introduce any user-facing change?
    The `withActive` method is introduced on `SparkSession`. It is declared `private[sql]`, so it is not part of the public API.
    
    ### How was this patch tested?
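    Unit tests. This commit adds a test to `DatasetSuite` checking that `SparkSession.active` is the same instance after dataset operations.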
    
    Closes #27387 from dbaliafroozeh/UseWithActiveSessionInQueryExecution.
    
    Authored-by: Ali Afroozeh <al...@databricks.com>
    Signed-off-by: herman <he...@databricks.com>
    (cherry picked from commit e2d3983de78f5c80fac066b7ee8bedd0987110dd)
    Signed-off-by: herman <he...@databricks.com>
---
 .../org/apache/spark/sql/DataFrameWriter.scala     |  2 +-
 .../org/apache/spark/sql/DataFrameWriterV2.scala   |  2 +-
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 36 ++++++++++++----------
 .../apache/spark/sql/KeyValueGroupedDataset.scala  |  5 +--
 .../scala/org/apache/spark/sql/SparkSession.scala  | 30 ++++++++++++------
 .../spark/sql/execution/QueryExecution.scala       | 16 +++++-----
 .../apache/spark/sql/execution/SQLExecution.scala  |  4 +--
 .../execution/streaming/MicroBatchExecution.scala  |  4 +--
 .../streaming/continuous/ContinuousExecution.scala |  2 +-
 .../apache/spark/sql/internal/CatalogImpl.scala    |  2 +-
 .../scala/org/apache/spark/sql/DatasetSuite.scala  | 10 ++++++
 .../org/apache/spark/sql/SQLQueryTestSuite.scala   |  2 +-
 .../execution/ui/SQLAppStatusListenerSuite.scala   |  2 +-
 .../SparkExecuteStatementOperation.scala           |  2 +-
 .../sql/hive/thriftserver/SparkSQLDriver.scala     |  2 +-
 .../sql/hive/execution/HiveComparisonTest.scala    |  3 +-
 .../org/apache/spark/sql/hive/test/TestHive.scala  |  2 +-
 17 files changed, 74 insertions(+), 52 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 4557219..fff1f4b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -896,7 +896,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
   private def runCommand(session: SparkSession, name: String)(command: LogicalPlan): Unit = {
     val qe = session.sessionState.executePlan(command)
     // call `QueryExecution.toRDD` to trigger the execution of commands.
-    SQLExecution.withNewExecutionId(session, qe, Some(name))(qe.toRdd)
+    SQLExecution.withNewExecutionId(qe, Some(name))(qe.toRdd)
   }
 
   private def lookupV2Provider(): Option[TableProvider] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
index f5dd761..cf6bde5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
@@ -226,7 +226,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
   private def runCommand(name: String)(command: LogicalPlan): Unit = {
     val qe = sparkSession.sessionState.executePlan(command)
     // call `QueryExecution.toRDD` to trigger the execution of commands.
-    SQLExecution.withNewExecutionId(sparkSession, qe, Some(name))(qe.toRdd)
+    SQLExecution.withNewExecutionId(qe, Some(name))(qe.toRdd)
   }
 
   private def internalReplace(orCreate: Boolean): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index a1c33f9..42f3535 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -82,18 +82,19 @@ private[sql] object Dataset {
     dataset
   }
 
-  def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = {
-    val qe = sparkSession.sessionState.executePlan(logicalPlan)
-    qe.assertAnalyzed()
-    new Dataset[Row](sparkSession, qe, RowEncoder(qe.analyzed.schema))
+  def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame =
+    sparkSession.withActive {
+      val qe = sparkSession.sessionState.executePlan(logicalPlan)
+      qe.assertAnalyzed()
+      new Dataset[Row](qe, RowEncoder(qe.analyzed.schema))
   }
 
   /** A variant of ofRows that allows passing in a tracker so we can track query parsing time. */
   def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan, tracker: QueryPlanningTracker)
-    : DataFrame = {
+    : DataFrame = sparkSession.withActive {
     val qe = new QueryExecution(sparkSession, logicalPlan, tracker)
     qe.assertAnalyzed()
-    new Dataset[Row](sparkSession, qe, RowEncoder(qe.analyzed.schema))
+    new Dataset[Row](qe, RowEncoder(qe.analyzed.schema))
   }
 }
 
@@ -185,13 +186,12 @@ private[sql] object Dataset {
  */
 @Stable
 class Dataset[T] private[sql](
-    @transient private val _sparkSession: SparkSession,
     @DeveloperApi @Unstable @transient val queryExecution: QueryExecution,
     @DeveloperApi @Unstable @transient val encoder: Encoder[T])
   extends Serializable {
 
   @transient lazy val sparkSession: SparkSession = {
-    if (_sparkSession == null) {
+    if (queryExecution == null || queryExecution.sparkSession == null) {
       throw new SparkException(
       "Dataset transformations and actions can only be invoked by the driver, not inside of" +
         " other Dataset transformations; for example, dataset1.map(x => dataset2.values.count()" +
@@ -199,7 +199,7 @@ class Dataset[T] private[sql](
         "performed inside of the dataset1.map transformation. For more information," +
         " see SPARK-28702.")
     }
-    _sparkSession
+    queryExecution.sparkSession
   }
 
   // A globally unique id of this Dataset.
@@ -211,7 +211,7 @@ class Dataset[T] private[sql](
   // you wrap it with `withNewExecutionId` if this actions doesn't call other action.
 
   def this(sparkSession: SparkSession, logicalPlan: LogicalPlan, encoder: Encoder[T]) = {
-    this(sparkSession, sparkSession.sessionState.executePlan(logicalPlan), encoder)
+    this(sparkSession.sessionState.executePlan(logicalPlan), encoder)
   }
 
   def this(sqlContext: SQLContext, logicalPlan: LogicalPlan, encoder: Encoder[T]) = {
@@ -445,7 +445,7 @@ class Dataset[T] private[sql](
    */
   // This is declared with parentheses to prevent the Scala compiler from treating
   // `ds.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
-  def toDF(): DataFrame = new Dataset[Row](sparkSession, queryExecution, RowEncoder(schema))
+  def toDF(): DataFrame = new Dataset[Row](queryExecution, RowEncoder(schema))
 
   /**
    * Returns a new Dataset where each record has been mapped on to the specified type. The
@@ -503,7 +503,9 @@ class Dataset[T] private[sql](
    * @group basic
    * @since 1.6.0
    */
-  def schema: StructType = queryExecution.analyzed.schema
+  def schema: StructType = sparkSession.withActive {
+    queryExecution.analyzed.schema
+  }
 
   /**
    * Prints the schema to the console in a nice tree format.
@@ -539,7 +541,7 @@ class Dataset[T] private[sql](
    * @group basic
    * @since 3.0.0
    */
-  def explain(mode: String): Unit = {
+  def explain(mode: String): Unit = sparkSession.withActive {
     // Because temporary views are resolved during analysis when we create a Dataset, and
     // `ExplainCommand` analyzes input query plan and resolves temporary views again. Using
     // `ExplainCommand` here will probably output different query plans, compared to the results
@@ -1502,7 +1504,7 @@ class Dataset[T] private[sql](
     val namedColumns =
       columns.map(_.withInputType(exprEnc, logicalPlan.output).named)
     val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan))
-    new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders))
+    new Dataset(execution, ExpressionEncoder.tuple(encoders))
   }
 
   /**
@@ -3472,7 +3474,7 @@ class Dataset[T] private[sql](
    * an execution.
    */
   private def withNewExecutionId[U](body: => U): U = {
-    SQLExecution.withNewExecutionId(sparkSession, queryExecution)(body)
+    SQLExecution.withNewExecutionId(queryExecution)(body)
   }
 
   /**
@@ -3481,7 +3483,7 @@ class Dataset[T] private[sql](
    * reset.
    */
   private def withNewRDDExecutionId[U](body: => U): U = {
-    SQLExecution.withNewExecutionId(sparkSession, rddQueryExecution) {
+    SQLExecution.withNewExecutionId(rddQueryExecution) {
       rddQueryExecution.executedPlan.resetMetrics()
       body
     }
@@ -3492,7 +3494,7 @@ class Dataset[T] private[sql](
    * user-registered callback functions.
    */
   private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U) = {
-    SQLExecution.withNewExecutionId(sparkSession, qe, Some(name)) {
+    SQLExecution.withNewExecutionId(qe, Some(name)) {
       qe.executedPlan.resetMetrics()
       action(qe.executedPlan)
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 89cc973..76ee297 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -449,10 +449,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
     val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan)
     val execution = new QueryExecution(sparkSession, aggregate)
 
-    new Dataset(
-      sparkSession,
-      execution,
-      ExpressionEncoder.tuple(kExprEnc +: encoders))
+    new Dataset(execution, ExpressionEncoder.tuple(kExprEnc +: encoders))
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index abefb34..1fb97fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -293,8 +293,7 @@ class SparkSession private(
    *
    * @since 2.0.0
    */
-  def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
-    SparkSession.setActiveSession(this)
+  def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = withActive {
     val encoder = Encoders.product[A]
     Dataset.ofRows(self, ExternalRDD(rdd, self)(encoder))
   }
@@ -304,8 +303,7 @@ class SparkSession private(
    *
    * @since 2.0.0
    */
-  def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
-    SparkSession.setActiveSession(this)
+  def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = withActive {
     val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
     val attributeSeq = schema.toAttributes
     Dataset.ofRows(self, LocalRelation.fromProduct(attributeSeq, data))
@@ -343,7 +341,7 @@ class SparkSession private(
    * @since 2.0.0
    */
   @DeveloperApi
-  def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = {
+  def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = withActive {
     // TODO: use MutableProjection when rowRDD is another DataFrame and the applied
     // schema differs from the existing schema on any field data type.
     val encoder = RowEncoder(schema)
@@ -373,7 +371,7 @@ class SparkSession private(
    * @since 2.0.0
    */
   @DeveloperApi
-  def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = {
+  def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = withActive {
     Dataset.ofRows(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala))
   }
 
@@ -385,7 +383,7 @@ class SparkSession private(
    *
    * @since 2.0.0
    */
-  def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = {
+  def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = withActive {
     val attributeSeq: Seq[AttributeReference] = getSchema(beanClass)
     val className = beanClass.getName
     val rowRdd = rdd.mapPartitions { iter =>
@@ -414,7 +412,7 @@ class SparkSession private(
    *          SELECT * queries will return the columns in an undefined order.
    * @since 1.6.0
    */
-  def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = {
+  def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = withActive {
     val attrSeq = getSchema(beanClass)
     val rows = SQLContext.beansToRows(data.asScala.iterator, beanClass, attrSeq)
     Dataset.ofRows(self, LocalRelation(attrSeq, rows.toSeq))
@@ -599,7 +597,7 @@ class SparkSession private(
    *
    * @since 2.0.0
    */
-  def sql(sqlText: String): DataFrame = {
+  def sql(sqlText: String): DataFrame = withActive {
     val tracker = new QueryPlanningTracker
     val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
       sessionState.sqlParser.parsePlan(sqlText)
@@ -751,6 +749,20 @@ class SparkSession private(
     }
   }
 
+  /**
+   * Execute a block of code with this session set as the active session, and restore the
+   * previous session on completion.
+   */
+  private[sql] def withActive[T](block: => T): T = {
+    // Use the active session thread local directly to make sure we get the session that is actually
+    // set and not the default session. This prevents us from promoting the default session to the
+    // active session once we are done.
+    val old = SparkSession.activeThreadSession.get()
+    SparkSession.setActiveSession(this)
+    try block finally {
+      SparkSession.setActiveSession(old)
+    }
+  }
 }
 
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 38ef666..53b6b5d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -63,13 +63,12 @@ class QueryExecution(
     }
   }
 
-  lazy val analyzed: LogicalPlan = tracker.measurePhase(QueryPlanningTracker.ANALYSIS) {
-    SparkSession.setActiveSession(sparkSession)
+  lazy val analyzed: LogicalPlan = executePhase(QueryPlanningTracker.ANALYSIS) {
     // We can't clone `logical` here, which will reset the `_analyzed` flag.
     sparkSession.sessionState.analyzer.executeAndCheck(logical, tracker)
   }
 
-  lazy val withCachedData: LogicalPlan = {
+  lazy val withCachedData: LogicalPlan = sparkSession.withActive {
     assertAnalyzed()
     assertSupported()
     // clone the plan to avoid sharing the plan instance between different stages like analyzing,
@@ -77,20 +76,20 @@ class QueryExecution(
     sparkSession.sharedState.cacheManager.useCachedData(analyzed.clone())
   }
 
-  lazy val optimizedPlan: LogicalPlan = tracker.measurePhase(QueryPlanningTracker.OPTIMIZATION) {
+  lazy val optimizedPlan: LogicalPlan = executePhase(QueryPlanningTracker.OPTIMIZATION) {
     // clone the plan to avoid sharing the plan instance between different stages like analyzing,
     // optimizing and planning.
     sparkSession.sessionState.optimizer.executeAndTrack(withCachedData.clone(), tracker)
   }
 
-  lazy val sparkPlan: SparkPlan = tracker.measurePhase(QueryPlanningTracker.PLANNING) {
+  lazy val sparkPlan: SparkPlan = executePhase(QueryPlanningTracker.PLANNING) {
     // Clone the logical plan here, in case the planner rules change the states of the logical plan.
     QueryExecution.createSparkPlan(sparkSession, planner, optimizedPlan.clone())
   }
 
   // executedPlan should not be used to initialize any SparkPlan. It should be
   // only used for execution.
-  lazy val executedPlan: SparkPlan = tracker.measurePhase(QueryPlanningTracker.PLANNING) {
+  lazy val executedPlan: SparkPlan = executePhase(QueryPlanningTracker.PLANNING) {
     // clone the plan to avoid sharing the plan instance between different stages like analyzing,
     // optimizing and planning.
     QueryExecution.prepareForExecution(preparations, sparkPlan.clone())
@@ -116,6 +115,10 @@ class QueryExecution(
     QueryExecution.preparations(sparkSession)
   }
 
+  private def executePhase[T](phase: String)(block: => T): T = sparkSession.withActive {
+    tracker.measurePhase(phase)(block)
+  }
+
   def simpleString: String = simpleString(false)
 
   def simpleString(formatted: Boolean): String = withRedaction {
@@ -305,7 +308,6 @@ object QueryExecution {
       sparkSession: SparkSession,
       planner: SparkPlanner,
       plan: LogicalPlan): SparkPlan = {
-    SparkSession.setActiveSession(sparkSession)
     // TODO: We use next(), i.e. take the first plan returned by the planner, here for now,
     //       but we will implement to choose the best plan.
     planner.plan(ReturnAnswer(plan)).next()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index 9f17781..59c503e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -62,9 +62,9 @@ object SQLExecution {
    * we can connect them with an execution.
    */
   def withNewExecutionId[T](
-      sparkSession: SparkSession,
       queryExecution: QueryExecution,
-      name: Option[String] = None)(body: => T): T = {
+      name: Option[String] = None)(body: => T): T = queryExecution.sparkSession.withActive {
+    val sparkSession = queryExecution.sparkSession
     val sc = sparkSession.sparkContext
     val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY)
     val executionId = SQLExecution.nextExecutionId
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
index 83bc347..45a2ce1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
@@ -563,11 +563,11 @@ class MicroBatchExecution(
     }
 
     val nextBatch =
-      new Dataset(sparkSessionToRunBatch, lastExecution, RowEncoder(lastExecution.analyzed.schema))
+      new Dataset(lastExecution, RowEncoder(lastExecution.analyzed.schema))
 
     val batchSinkProgress: Option[StreamWriterCommitProgress] =
       reportTimeTaken("addBatch") {
-      SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) {
+      SQLExecution.withNewExecutionId(lastExecution) {
         sink match {
           case s: Sink => s.addBatch(currentBatchId, nextBatch)
           case _: SupportsWrite =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
index a9b724a..a109c21 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
@@ -252,7 +252,7 @@ class ContinuousExecution(
 
       updateStatusMessage("Running")
       reportTimeTaken("runContinuous") {
-        SQLExecution.withNewExecutionId(sparkSessionForQuery, lastExecution) {
+        SQLExecution.withNewExecutionId(lastExecution) {
           lastExecution.executedPlan.execute()
         }
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
index 3740b56..d3ef03e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
@@ -520,7 +520,7 @@ private[sql] object CatalogImpl {
     val encoded = data.map(d => enc.toRow(d).copy())
     val plan = new LocalRelation(enc.schema.toAttributes, encoded)
     val queryExecution = sparkSession.sessionState.executePlan(plan)
-    new Dataset[T](sparkSession, queryExecution, enc)
+    new Dataset[T](queryExecution, enc)
   }
 
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 233d678..b0bd612 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1899,6 +1899,16 @@ class DatasetSuite extends QueryTest
     val e = intercept[AnalysisException](spark.range(1).tail(-1))
     e.getMessage.contains("tail expression must be equal to or greater than 0")
   }
+
+  test("SparkSession.active should be the same instance after dataset operations") {
+    val active = SparkSession.getActiveSession.get
+    val clone = active.cloneSession()
+    val ds = new Dataset(clone, spark.range(10).queryExecution.logical, Encoders.INT)
+
+    ds.queryExecution.analyzed
+
+    assert(active eq SparkSession.getActiveSession.get)
+  }
 }
 
 object AssertExecutionId {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
index da4727f..8328591 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
@@ -511,7 +511,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession {
     val df = session.sql(sql)
     val schema = df.schema.catalogString
     // Get answer, but also get rid of the #1234 expression ids that show up in explain plans
-    val answer = SQLExecution.withNewExecutionId(session, df.queryExecution, Some(sql)) {
+    val answer = SQLExecution.withNewExecutionId(df.queryExecution, Some(sql)) {
       hiveResultString(df.queryExecution.executedPlan).map(replaceNotIncludedMsg)
     }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
index 55b551d..9f4a335 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
@@ -506,7 +506,7 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils
       override lazy val executedPlan = physicalPlan
     }
 
-    SQLExecution.withNewExecutionId(spark, dummyQueryExecution) {
+    SQLExecution.withNewExecutionId(dummyQueryExecution) {
       physicalPlan.execute().collect()
     }
 
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
index 76d0784..cf0e5eb 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
@@ -295,7 +295,7 @@ private[hive] class SparkExecuteStatementOperation(
           resultList.get.iterator
         }
       }
-      dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray
+      dataTypes = result.schema.fields.map(_.dataType)
     } catch {
       // Actually do need to catch Throwable as some failures don't inherit from Exception and
       // HiveServer will silently swallow them.
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala
index 362ac362..12fba0e 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala
@@ -61,7 +61,7 @@ private[hive] class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlCont
     try {
       context.sparkContext.setJobDescription(command)
       val execution = context.sessionState.executePlan(context.sql(command).logicalPlan)
-      hiveResponse = SQLExecution.withNewExecutionId(context.sparkSession, execution) {
+      hiveResponse = SQLExecution.withNewExecutionId(execution) {
         hiveResultString(execution.executedPlan)
       }
       tableSchema = getResultSetSchema(execution)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
index 28e1db9..8b1f4c9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -346,8 +346,7 @@ abstract class HiveComparisonTest
         val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) =>
           val query = new TestHiveQueryExecution(queryString.replace("../../data", testDataPath))
           def getResult(): Seq[String] = {
-            SQLExecution.withNewExecutionId(
-              query.sparkSession, query)(hiveResultString(query.executedPlan))
+            SQLExecution.withNewExecutionId(query)(hiveResultString(query.executedPlan))
           }
           try { (query, prepareAnswer(query, getResult())) } catch {
             case e: Throwable =>
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
index cc4592a..222244a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -501,7 +501,7 @@ private[hive] class TestHiveSparkSession(
       // has already set the execution id.
       if (sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) == null) {
         // We don't actually have a `QueryExecution` here, use a fake one instead.
-        SQLExecution.withNewExecutionId(this, new QueryExecution(this, OneRowRelation())) {
+        SQLExecution.withNewExecutionId(new QueryExecution(this, OneRowRelation())) {
           createCmds.foreach(_())
         }
       } else {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org