You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by td...@apache.org on 2016/12/16 08:42:39 UTC

spark git commit: [SPARK-18850][SS] Make StreamExecution and progress classes serializable

Repository: spark
Updated Branches:
  refs/heads/master 78062b852 -> d7f3058e1


[SPARK-18850][SS] Make StreamExecution and progress classes serializable

## What changes were proposed in this pull request?

This PR adds StreamingQueryWrapper to make StreamExecution serializable, and also makes the progress classes serializable, because it is too easy for a StreamingQuery to get captured with normal usage. If a StreamingQueryWrapper gets captured in a closure but none of its methods are called, it should not fail the Spark tasks. However, if its methods are called in executors, they will throw an `IllegalStateException` with a clearer error message.

## How was this patch tested?

`test("StreamingQuery should be Serializable but cannot be used in executors")`
`test("progress classes should be Serializable")`

Author: Shixiong Zhu <sh...@databricks.com>

Closes #16272 from zsxwing/SPARK-18850.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d7f3058e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d7f3058e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d7f3058e

Branch: refs/heads/master
Commit: d7f3058e17571d76a8b4c8932de6de81ce8d2e78
Parents: 78062b8
Author: Shixiong Zhu <sh...@databricks.com>
Authored: Fri Dec 16 00:42:39 2016 -0800
Committer: Tathagata Das <ta...@gmail.com>
Committed: Fri Dec 16 00:42:39 2016 -0800

----------------------------------------------------------------------
 .../execution/streaming/ProgressReporter.scala  |   4 +-
 .../streaming/StreamingQueryWrapper.scala       | 107 +++++++++++++++++++
 .../sql/streaming/StreamingQueryManager.scala   |   8 +-
 .../sql/streaming/StreamingQueryStatus.scala    |   6 +-
 .../apache/spark/sql/streaming/progress.scala   |   8 +-
 .../sql/streaming/FileStreamSourceSuite.scala   |   6 +-
 .../spark/sql/streaming/StreamSuite.scala       |   4 +-
 .../apache/spark/sql/streaming/StreamTest.scala |   3 +-
 .../streaming/StreamingQueryManagerSuite.scala  |   5 +-
 .../StreamingQueryStatusAndProgressSuite.scala  |  52 +++++++--
 .../sql/streaming/StreamingQuerySuite.scala     |  44 +++++++-
 .../test/DataStreamReaderWriterSuite.scala      |   4 +-
 12 files changed, 222 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d7f3058e/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
index e40135f..2386f33 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
@@ -159,8 +159,8 @@ trait ProgressReporter extends Logging {
       name = name,
       timestamp = formatTimestamp(currentTriggerStartTimestamp),
       batchId = currentBatchId,
-      durationMs = currentDurationsMs.toMap.mapValues(long2Long).asJava,
-      eventTime = executionStats.eventTimeStats.asJava,
+      durationMs = new java.util.HashMap(currentDurationsMs.toMap.mapValues(long2Long).asJava),
+      eventTime = new java.util.HashMap(executionStats.eventTimeStats.asJava),
       stateOperators = executionStats.stateOperators.toArray,
       sources = sourceProgress.toArray,
       sink = sinkProgress)

http://git-wip-us.apache.org/repos/asf/spark/blob/d7f3058e/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala
new file mode 100644
index 0000000..020c9cb
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import java.util.UUID
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamingQueryProgress, StreamingQueryStatus}
+
+/**
+ * Wraps a non-serializable StreamExecution so the query itself is serializable, as it is easy for
+ * a query to get captured in a closure with normal usage. Capturing it is safe, but it must not be
+ * used in executors: calling any of its methods there throws `IllegalStateException`.
+ */
+class StreamingQueryWrapper(@transient private val _streamingQuery: StreamExecution)
+  extends StreamingQuery with Serializable {
+
+  def streamingQuery: StreamExecution = {
+    // `_streamingQuery` is @transient, so it is null once deserialized (e.g. in an executor).
+    if (_streamingQuery == null) {
+      throw new IllegalStateException("StreamingQuery cannot be used in executors")
+    }
+    _streamingQuery
+  }
+
+  override def name: String = {
+    streamingQuery.name
+  }
+
+  override def id: UUID = {
+    streamingQuery.id
+  }
+
+  override def runId: UUID = {
+    streamingQuery.runId
+  }
+
+  override def awaitTermination(): Unit = {
+    streamingQuery.awaitTermination()
+  }
+
+  override def awaitTermination(timeoutMs: Long): Boolean = {
+    streamingQuery.awaitTermination(timeoutMs)
+  }
+
+  override def stop(): Unit = {
+    streamingQuery.stop()
+  }
+
+  override def processAllAvailable(): Unit = {
+    streamingQuery.processAllAvailable()
+  }
+
+  override def isActive: Boolean = {
+    streamingQuery.isActive
+  }
+
+  override def lastProgress: StreamingQueryProgress = {
+    streamingQuery.lastProgress
+  }
+
+  override def explain(): Unit = {
+    streamingQuery.explain()
+  }
+
+  override def explain(extended: Boolean): Unit = {
+    streamingQuery.explain(extended)
+  }
+
+  /**
+   * Called from Python. Python cannot use "explain" directly because it writes its output in the
+   * JVM process, where the Python process may not see it; this returns the explain text instead.
+   */
+  def explainInternal(extended: Boolean): String = {
+    streamingQuery.explainInternal(extended)
+  }
+
+  override def sparkSession: SparkSession = {
+    streamingQuery.sparkSession
+  }
+
+  override def recentProgress: Array[StreamingQueryProgress] = {
+    streamingQuery.recentProgress
+  }
+
+  override def status: StreamingQueryStatus = {
+    streamingQuery.status
+  }
+
+  override def exception: Option[StreamingQueryException] = {
+    streamingQuery.exception
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/d7f3058e/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
index 6ebd706..8c26ee2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
@@ -193,7 +193,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) {
       useTempCheckpointLocation: Boolean,
       recoverFromCheckpointLocation: Boolean,
       trigger: Trigger,
-      triggerClock: Clock): StreamExecution = {
+      triggerClock: Clock): StreamingQueryWrapper = {
     val checkpointLocation = userSpecifiedCheckpointLocation.map { userSpecified =>
       new Path(userSpecified).toUri.toString
     }.orElse {
@@ -229,7 +229,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) {
       UnsupportedOperationChecker.checkForStreaming(analyzedPlan, outputMode)
     }
 
-    new StreamExecution(
+    new StreamingQueryWrapper(new StreamExecution(
       sparkSession,
       userSpecifiedName.orNull,
       checkpointLocation,
@@ -237,7 +237,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) {
       sink,
       trigger,
       triggerClock,
-      outputMode)
+      outputMode))
   }
 
   /**
@@ -301,7 +301,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) {
       // As it's provided by the user and can run arbitrary codes, we must not hold any lock here.
       // Otherwise, it's easy to cause dead-lock, or block too long if the user codes take a long
       // time to finish.
-      query.start()
+      query.streamingQuery.start()
     } catch {
       case e: Throwable =>
         activeQueriesLock.synchronized {

http://git-wip-us.apache.org/repos/asf/spark/blob/d7f3058e/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala
index 44befa0..c2befa6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala
@@ -22,7 +22,10 @@ import org.json4s.JsonAST.JValue
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
+import org.apache.spark.annotation.Experimental
+
 /**
+ * :: Experimental ::
  * Reports information about the instantaneous status of a streaming query.
  *
  * @param message A human readable description of what the stream is currently doing.
@@ -32,10 +35,11 @@ import org.json4s.jackson.JsonMethods._
  *
  * @since 2.1.0
  */
+@Experimental
 class StreamingQueryStatus protected[sql](
     val message: String,
     val isDataAvailable: Boolean,
-    val isTriggerActive: Boolean) {
+    val isTriggerActive: Boolean) extends Serializable {
 
   /** The compact JSON representation of this status. */
   def json: String = compact(render(jsonValue))

http://git-wip-us.apache.org/repos/asf/spark/blob/d7f3058e/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
index e219cfd..bea0b9e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
@@ -38,7 +38,7 @@ import org.apache.spark.annotation.Experimental
 @Experimental
 class StateOperatorProgress private[sql](
     val numRowsTotal: Long,
-    val numRowsUpdated: Long) {
+    val numRowsUpdated: Long) extends Serializable {
 
   /** The compact JSON representation of this progress. */
   def json: String = compact(render(jsonValue))
@@ -90,7 +90,7 @@ class StreamingQueryProgress private[sql](
   val eventTime: ju.Map[String, String],
   val stateOperators: Array[StateOperatorProgress],
   val sources: Array[SourceProgress],
-  val sink: SinkProgress) {
+  val sink: SinkProgress) extends Serializable {
 
   /** The aggregate (across all sources) number of records processed in a trigger. */
   def numInputRows: Long = sources.map(_.numInputRows).sum
@@ -157,7 +157,7 @@ class SourceProgress protected[sql](
   val endOffset: String,
   val numInputRows: Long,
   val inputRowsPerSecond: Double,
-  val processedRowsPerSecond: Double) {
+  val processedRowsPerSecond: Double) extends Serializable {
 
   /** The compact JSON representation of this progress. */
   def json: String = compact(render(jsonValue))
@@ -197,7 +197,7 @@ class SourceProgress protected[sql](
  */
 @Experimental
 class SinkProgress protected[sql](
-    val description: String) {
+    val description: String) extends Serializable {
 
   /** The compact JSON representation of this progress. */
   def json: String = compact(render(jsonValue))

http://git-wip-us.apache.org/repos/asf/spark/blob/d7f3058e/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
index b96ccb4..cbcc983 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
@@ -746,7 +746,8 @@ class FileStreamSourceSuite extends FileStreamSourceTest {
         .format("memory")
         .queryName("file_data")
         .start()
-        .asInstanceOf[StreamExecution]
+        .asInstanceOf[StreamingQueryWrapper]
+        .streamingQuery
       q.processAllAvailable()
       val memorySink = q.sink.asInstanceOf[MemorySink]
       val fileSource = q.logicalPlan.collect {
@@ -836,7 +837,8 @@ class FileStreamSourceSuite extends FileStreamSourceTest {
       df.explain()
 
       val q = df.writeStream.queryName("file_explain").format("memory").start()
-        .asInstanceOf[StreamExecution]
+        .asInstanceOf[StreamingQueryWrapper]
+        .streamingQuery
       try {
         assert("No physical plan. Waiting for data." === q.explainInternal(false))
         assert("No physical plan. Waiting for data." === q.explainInternal(true))

http://git-wip-us.apache.org/repos/asf/spark/blob/d7f3058e/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index 6bdf479..4a64054 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -24,7 +24,6 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.sources.StreamSourceProvider
 import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
-import org.apache.spark.util.ManualClock
 
 class StreamSuite extends StreamTest {
 
@@ -278,7 +277,8 @@ class StreamSuite extends StreamTest {
     // Test `explain` not throwing errors
     df.explain()
     val q = df.writeStream.queryName("memory_explain").format("memory").start()
-      .asInstanceOf[StreamExecution]
+      .asInstanceOf[StreamingQueryWrapper]
+      .streamingQuery
     try {
       assert("No physical plan. Waiting for data." === q.explainInternal(false))
       assert("No physical plan. Waiting for data." === q.explainInternal(true))

http://git-wip-us.apache.org/repos/asf/spark/blob/d7f3058e/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index 10f267e..6fbbbb1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -355,7 +355,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
                   outputMode,
                   trigger = trigger,
                   triggerClock = triggerClock)
-                .asInstanceOf[StreamExecution]
+                .asInstanceOf[StreamingQueryWrapper]
+                .streamingQuery
             currentStream.microBatchThread.setUncaughtExceptionHandler(
               new UncaughtExceptionHandler {
                 override def uncaughtException(t: Thread, e: Throwable): Unit = {

http://git-wip-us.apache.org/repos/asf/spark/blob/d7f3058e/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala
index 1742a54..8e16fd4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala
@@ -244,7 +244,7 @@ class StreamingQueryManagerSuite extends StreamTest with BeforeAndAfter {
     failAfter(streamingTimeout) {
       val queries = withClue("Error starting queries") {
         datasets.zipWithIndex.map { case (ds, i) =>
-          @volatile var query: StreamExecution = null
+          var query: StreamingQuery = null
           try {
             val df = ds.toDF
             val metadataRoot =
@@ -256,7 +256,6 @@ class StreamingQueryManagerSuite extends StreamTest with BeforeAndAfter {
                 .option("checkpointLocation", metadataRoot)
                 .outputMode("append")
                 .start()
-                .asInstanceOf[StreamExecution]
           } catch {
             case NonFatal(e) =>
               if (query != null) query.stop()
@@ -304,7 +303,7 @@ class StreamingQueryManagerSuite extends StreamTest with BeforeAndAfter {
       Thread.sleep(stopAfter.toMillis)
       if (withError) {
         logDebug(s"Terminating query ${queryToStop.name} with error")
-        queryToStop.asInstanceOf[StreamExecution].logicalPlan.collect {
+        queryToStop.asInstanceOf[StreamingQueryWrapper].streamingQuery.logicalPlan.collect {
           case StreamingExecutionRelation(source, _) =>
             source.asInstanceOf[MemoryStream[Int]].addData(0)
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/d7f3058e/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala
index c970743..34bf398 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala
@@ -24,11 +24,12 @@ import scala.collection.JavaConverters._
 import org.json4s._
 import org.json4s.jackson.JsonMethods._
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.streaming.StreamingQueryStatusAndProgressSuite._
 
 
-class StreamingQueryStatusAndProgressSuite extends SparkFunSuite {
+class StreamingQueryStatusAndProgressSuite extends StreamTest {
 
   test("StreamingQueryProgress - prettyJson") {
     val json1 = testProgress1.prettyJson
@@ -128,6 +129,42 @@ class StreamingQueryStatusAndProgressSuite extends SparkFunSuite {
   test("StreamingQueryStatus - toString") {
     assert(testStatus.toString === testStatus.prettyJson)
   }
+
+  test("progress classes should be Serializable") {
+    import testImplicits._
+
+    val inputData = MemoryStream[Int]
+
+    val query = inputData.toDS()
+      .groupBy($"value")
+      .agg(count("*"))
+      .writeStream
+      .queryName("progress_serializable_test")
+      .format("memory")
+      .outputMode("complete")
+      .start()
+    try {
+      inputData.addData(1, 2, 3)
+      query.processAllAvailable()
+
+      val progress = query.recentProgress
+
+      // Require a progress with sources, state operators and a sink so every class is covered
+      assert(progress.exists { p =>
+        p.sources.size >= 1 && p.stateOperators.size >= 1 && p.sink != null
+      })
+
+      val array = spark.sparkContext.parallelize(progress).collect()
+      assert(array.length === progress.length)
+      array.zip(progress).foreach { case (p1, p2) =>
+        // `ne` ensures each element is a round-tripped copy, not the original reference
+        assert(p1 ne p2)
+        assert(p1.json === p2.json)
+      }
+    } finally {
+      query.stop()
+    }
+  }
 }
 
 object StreamingQueryStatusAndProgressSuite {
@@ -137,12 +174,12 @@ object StreamingQueryStatusAndProgressSuite {
     name = "myName",
     timestamp = "2016-12-05T20:54:20.827Z",
     batchId = 2L,
-    durationMs = Map("total" -> 0L).mapValues(long2Long).asJava,
-    eventTime = Map(
+    durationMs = new java.util.HashMap(Map("total" -> 0L).mapValues(long2Long).asJava),
+    eventTime = new java.util.HashMap(Map(
       "max" -> "2016-12-05T20:54:20.827Z",
       "min" -> "2016-12-05T20:54:20.827Z",
       "avg" -> "2016-12-05T20:54:20.827Z",
-      "watermark" -> "2016-12-05T20:54:20.827Z").asJava,
+      "watermark" -> "2016-12-05T20:54:20.827Z").asJava),
     stateOperators = Array(new StateOperatorProgress(numRowsTotal = 0, numRowsUpdated = 1)),
     sources = Array(
       new SourceProgress(
@@ -163,8 +200,9 @@ object StreamingQueryStatusAndProgressSuite {
     name = null, // should not be present in the json
     timestamp = "2016-12-05T20:54:20.827Z",
     batchId = 2L,
-    durationMs = Map("total" -> 0L).mapValues(long2Long).asJava,
-    eventTime = Map.empty[String, String].asJava,  // empty maps should be handled correctly
+    durationMs = new java.util.HashMap(Map("total" -> 0L).mapValues(long2Long).asJava),
+    // empty maps should be handled correctly
+    eventTime = new java.util.HashMap(Map.empty[String, String].asJava),
     stateOperators = Array(new StateOperatorProgress(numRowsTotal = 0, numRowsUpdated = 1)),
     sources = Array(
       new SourceProgress(

http://git-wip-us.apache.org/repos/asf/spark/blob/d7f3058e/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index b052bd9..6c4bb35 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -26,7 +26,7 @@ import org.scalatest.BeforeAndAfter
 import org.scalatest.concurrent.PatienceConfiguration.Timeout
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.SparkException
 import org.apache.spark.sql.execution.streaming._
@@ -439,6 +439,48 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging {
     }
   }
 
+  test("StreamingQuery should be Serializable but cannot be used in executors") {
+    def startQuery(ds: Dataset[Int], queryName: String): StreamingQuery = {
+      ds.writeStream
+        .queryName(queryName)
+        .format("memory")
+        .start()
+    }
+
+    val input = MemoryStream[Int]
+    val q1 = startQuery(input.toDS, "stream_serializable_test_1")
+    val q2 = startQuery(input.toDS.map { i =>
+      // Emulate a `StreamingQuery` getting captured unintentionally with normal usage.
+      // It should not fail the query.
+      q1
+      i
+    }, "stream_serializable_test_2")
+    val q3 = startQuery(input.toDS.map { i =>
+      // Emulate that `StreamingQuery` is used in executors. We should fail the query with a clear
+      // error message.
+      q1.explain()
+      i
+    }, "stream_serializable_test_3")
+    try {
+      input.addData(1)
+
+      // q2 should not fail: it captures `q1` in the closure but never calls its methods
+      q2.processAllAvailable()
+
+      // q3 calls a `StreamingQuery` method inside the closure, so it should fail
+      val e = intercept[StreamingQueryException] {
+        q3.processAllAvailable()
+      }
+      assert(e.getCause.isInstanceOf[SparkException])
+      assert(e.getCause.getCause.isInstanceOf[IllegalStateException])
+      assert(e.getMessage.contains("StreamingQuery cannot be used in executors"))
+    } finally {
+      q1.stop()
+      q2.stop()
+      q3.stop()
+    }
+  }
+
   /** Create a streaming DF that only execute one batch in which it returns the given static DF */
   private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = {
     require(!triggerDF.isStreaming)

http://git-wip-us.apache.org/repos/asf/spark/blob/d7f3058e/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
index f4a6290..acac0bf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
@@ -339,7 +339,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter {
       .start()
     q.stop()
 
-    assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(10000))
+    assert(q.asInstanceOf[StreamingQueryWrapper].streamingQuery.trigger == ProcessingTime(10000))
 
     q = df.writeStream
       .format("org.apache.spark.sql.streaming.test")
@@ -348,7 +348,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter {
       .start()
     q.stop()
 
-    assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(100000))
+    assert(q.asInstanceOf[StreamingQueryWrapper].streamingQuery.trigger == ProcessingTime(100000))
   }
 
   test("source metadataPath") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org