Posted to commits@spark.apache.org by va...@apache.org on 2019/04/29 16:44:54 UTC

[spark] branch master updated: [SPARK-23014][SS] Fully remove V1 memory sink.

This is an automated email from the ASF dual-hosted git repository.

vanzin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new fb6b19a  [SPARK-23014][SS] Fully remove V1 memory sink.
fb6b19a is described below

commit fb6b19ab7c38aa0e1b2e208da86897bf3c07ae00
Author: Gabor Somogyi <ga...@gmail.com>
AuthorDate: Mon Apr 29 09:44:23 2019 -0700

    [SPARK-23014][SS] Fully remove V1 memory sink.
    
    ## What changes were proposed in this pull request?
    
    There is already a MemorySink V2, so V1 can be removed. In this PR I've removed it completely.
    What this PR contains:
    * V1 memory sink removal
    * V2 memory sink renamed to become the only implementation
    * Since DSv2 reports exceptions in a chained format (linking them via the cause field), I've made the Python side compliant (see the sketch after this list)
    * Adapted all the affected tests
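    
    A minimal sketch of how such a cause chain is walked, mirroring the assertExceptionMsg helper this patch adds to TestUtils (the standalone method shown here is illustrative only):
    
        def exceptionChainContainsMsg(exception: Throwable, msg: String): Boolean = {
          // Walk the cause chain until the message is found or the chain ends.
          var e = exception
          var contains = e.getMessage != null && e.getMessage.contains(msg)
          while (e.getCause != null && !contains) {
            e = e.getCause
            contains = e.getMessage != null && e.getMessage.contains(msg)
          }
          contains
        }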
    
    ## How was this patch tested?
    
    Existing unit tests.
    
    Closes #24403 from gaborgsomogyi/SPARK-23014.
    
    Authored-by: Gabor Somogyi <ga...@gmail.com>
    Signed-off-by: Marcelo Vanzin <va...@cloudera.com>
---
 .../main/scala/org/apache/spark/TestUtils.scala    | 14 ++++
 .../spark/sql/kafka010/KafkaContinuousTest.scala   |  1 -
 .../apache/spark/ml/recommendation/ALSSuite.scala  | 10 ++-
 .../scala/org/apache/spark/ml/util/MLTest.scala    | 10 +--
 python/pyspark/sql/streaming.py                    |  2 +-
 python/pyspark/sql/tests/test_streaming.py         | 12 ++-
 python/pyspark/sql/utils.py                        | 49 ++++++++----
 .../spark/sql/execution/SparkStrategies.scala      |  5 +-
 .../spark/sql/execution/streaming/memory.scala     | 92 +---------------------
 .../sources/{memoryV2.scala => memory.scala}       | 14 ++--
 .../spark/sql/streaming/DataStreamWriter.scala     | 12 +--
 .../sql/execution/streaming/MemorySinkSuite.scala  | 88 ++++++++++++++++-----
 .../execution/streaming/MemorySinkV2Suite.scala    | 66 ----------------
 .../sql/streaming/EventTimeWatermarkSuite.scala    |  1 +
 .../sql/streaming/FileStreamSourceSuite.scala      |  1 +
 .../apache/spark/sql/streaming/StreamSuite.scala   | 16 ++--
 .../apache/spark/sql/streaming/StreamTest.scala    | 16 ++--
 .../sql/streaming/StreamingAggregationSuite.scala  |  1 +
 .../streaming/StreamingQueryListenerSuite.scala    |  2 +-
 .../spark/sql/streaming/StreamingQuerySuite.scala  | 22 +++---
 .../ContinuousQueryStatusAndProgressSuite.scala    |  2 +-
 .../sql/streaming/continuous/ContinuousSuite.scala |  9 +--
 22 files changed, 181 insertions(+), 264 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala
index c2ebd38..c97b10e 100644
--- a/core/src/main/scala/org/apache/spark/TestUtils.scala
+++ b/core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -193,6 +193,20 @@ private[spark] object TestUtils {
   }
 
   /**
+   * Asserts that the exception chain contains the expected message. Please note this checks all
+   * exceptions in the cause chain.
+   */
+  def assertExceptionMsg(exception: Throwable, msg: String): Unit = {
+    var e = exception
+    var contains = e.getMessage.contains(msg)
+    while (e.getCause != null && !contains) {
+      e = e.getCause
+      contains = e.getMessage.contains(msg)
+    }
+    assert(contains, s"Exception tree doesn't contain the expected message: $msg")
+  }
+
+  /**
    * Test if a command is available.
    */
   def testCommandAvailable(command: String): Boolean = {
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala
index ad1c2c5..9ee8cbf 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala
@@ -30,7 +30,6 @@ import org.apache.spark.sql.test.TestSparkSession
 // Trait to configure StreamTest for kafka continuous execution tests.
 trait KafkaContinuousTest extends KafkaSourceTest {
   override val defaultTrigger = Trigger.Continuous(1000)
-  override val defaultUseV2Sink = true
 
   // We need more than the default local[2] to be able to schedule all partitions simultaneously.
   override protected def createSparkSession = new TestSparkSession(
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 6d0321c..5ba3928 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -695,12 +695,14 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging {
     withClue("transform should fail when ids exceed integer range. ") {
       val model = als.fit(df)
       def testTransformIdExceedsIntRange[A : Encoder](dataFrame: DataFrame): Unit = {
-        assert(intercept[SparkException] {
+        val e1 = intercept[SparkException] {
           model.transform(dataFrame).first
-        }.getMessage.contains(msg))
-        assert(intercept[StreamingQueryException] {
+        }
+        TestUtils.assertExceptionMsg(e1, msg)
+        val e2 = intercept[StreamingQueryException] {
           testTransformer[A](dataFrame, model, "prediction") { _ => }
-        }.getMessage.contains(msg))
+        }
+        TestUtils.assertExceptionMsg(e2, msg)
       }
       testTransformIdExceedsIntRange[(Long, Int)](df.select(df("user_big").as("user"),
         df("item")))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
index c23b6d8..8a0a48f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
@@ -21,7 +21,7 @@ import java.io.File
 
 import org.scalatest.Suite
 
-import org.apache.spark.{DebugFilesystem, SparkConf, SparkContext}
+import org.apache.spark.{DebugFilesystem, SparkConf, SparkContext, TestUtils}
 import org.apache.spark.internal.config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK
 import org.apache.spark.ml.{Model, PredictionModel, Transformer}
 import org.apache.spark.ml.linalg.Vector
@@ -129,21 +129,17 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
     expectedMessagePart : String,
     firstResultCol: String) {
 
-    def hasExpectedMessage(exception: Throwable): Boolean =
-      exception.getMessage.contains(expectedMessagePart) ||
-        (exception.getCause != null && exception.getCause.getMessage.contains(expectedMessagePart))
-
     withClue(s"""Expected message part "${expectedMessagePart}" is not found in DF test.""") {
       val exceptionOnDf = intercept[Throwable] {
         testTransformerOnDF(dataframe, transformer, firstResultCol)(_ => Unit)
       }
-      assert(hasExpectedMessage(exceptionOnDf))
+      TestUtils.assertExceptionMsg(exceptionOnDf, expectedMessagePart)
     }
     withClue(s"""Expected message part "${expectedMessagePart}" is not found in stream test.""") {
       val exceptionOnStreamData = intercept[Throwable] {
         testTransformerOnStreamData(dataframe, transformer, firstResultCol)(_ => Unit)
       }
-      assert(hasExpectedMessage(exceptionOnStreamData))
+      TestUtils.assertExceptionMsg(exceptionOnStreamData, expectedMessagePart)
     }
   }
 
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index fa25267..d15779b 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -186,7 +186,7 @@ class StreamingQuery(object):
             je = self._jsq.exception().get()
             msg = je.toString().split(': ', 1)[1]  # Drop the Java StreamingQueryException type info
             stackTrace = '\n\t at '.join(map(lambda x: x.toString(), je.getStackTrace()))
-            return StreamingQueryException(msg, stackTrace)
+            return StreamingQueryException(msg, stackTrace, je.getCause())
         else:
             return None
 
diff --git a/python/pyspark/sql/tests/test_streaming.py b/python/pyspark/sql/tests/test_streaming.py
index ac4b691..bbd3ddb 100644
--- a/python/pyspark/sql/tests/test_streaming.py
+++ b/python/pyspark/sql/tests/test_streaming.py
@@ -225,11 +225,19 @@ class StreamingTests(ReusedSQLTestCase):
             self.fail("bad udf should fail the query")
         except StreamingQueryException as e:
             # This is expected
-            self.assertTrue("ZeroDivisionError" in e.desc)
+            self._assert_exception_tree_contains_msg(e, "ZeroDivisionError")
         finally:
             sq.stop()
         self.assertTrue(type(sq.exception()) is StreamingQueryException)
-        self.assertTrue("ZeroDivisionError" in sq.exception().desc)
+        self._assert_exception_tree_contains_msg(sq.exception(), "ZeroDivisionError")
+
+    def _assert_exception_tree_contains_msg(self, exception, msg):
+        e = exception
+        contains = msg in e.desc
+        while e.cause is not None and not contains:
+            e = e.cause
+            contains = msg in e.desc
+        self.assertTrue(contains, "Exception tree doesn't contain the expected message: %s" % msg)
 
     def test_query_manager_await_termination(self):
         df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index 709d3a0..1c96e33 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -19,9 +19,10 @@ import py4j
 
 
 class CapturedException(Exception):
-    def __init__(self, desc, stackTrace):
+    def __init__(self, desc, stackTrace, cause=None):
         self.desc = desc
         self.stackTrace = stackTrace
+        self.cause = convert_exception(cause) if cause is not None else None
 
     def __str__(self):
         return repr(self.desc)
@@ -57,27 +58,41 @@ class QueryExecutionException(CapturedException):
     """
 
 
+class UnknownException(CapturedException):
+    """
+    None of the above exceptions.
+    """
+
+
+def convert_exception(e):
+    s = e.toString()
+    stackTrace = '\n\t at '.join(map(lambda x: x.toString(), e.getStackTrace()))
+    c = e.getCause()
+    if s.startswith('org.apache.spark.sql.AnalysisException: '):
+        return AnalysisException(s.split(': ', 1)[1], stackTrace, c)
+    if s.startswith('org.apache.spark.sql.catalyst.analysis'):
+        return AnalysisException(s.split(': ', 1)[1], stackTrace, c)
+    if s.startswith('org.apache.spark.sql.catalyst.parser.ParseException: '):
+        return ParseException(s.split(': ', 1)[1], stackTrace, c)
+    if s.startswith('org.apache.spark.sql.streaming.StreamingQueryException: '):
+        return StreamingQueryException(s.split(': ', 1)[1], stackTrace, c)
+    if s.startswith('org.apache.spark.sql.execution.QueryExecutionException: '):
+        return QueryExecutionException(s.split(': ', 1)[1], stackTrace, c)
+    if s.startswith('java.lang.IllegalArgumentException: '):
+        return IllegalArgumentException(s.split(': ', 1)[1], stackTrace, c)
+    return UnknownException(s, stackTrace, c)
+
+
 def capture_sql_exception(f):
     def deco(*a, **kw):
         try:
             return f(*a, **kw)
         except py4j.protocol.Py4JJavaError as e:
-            s = e.java_exception.toString()
-            stackTrace = '\n\t at '.join(map(lambda x: x.toString(),
-                                             e.java_exception.getStackTrace()))
-            if s.startswith('org.apache.spark.sql.AnalysisException: '):
-                raise AnalysisException(s.split(': ', 1)[1], stackTrace)
-            if s.startswith('org.apache.spark.sql.catalyst.analysis'):
-                raise AnalysisException(s.split(': ', 1)[1], stackTrace)
-            if s.startswith('org.apache.spark.sql.catalyst.parser.ParseException: '):
-                raise ParseException(s.split(': ', 1)[1], stackTrace)
-            if s.startswith('org.apache.spark.sql.streaming.StreamingQueryException: '):
-                raise StreamingQueryException(s.split(': ', 1)[1], stackTrace)
-            if s.startswith('org.apache.spark.sql.execution.QueryExecutionException: '):
-                raise QueryExecutionException(s.split(': ', 1)[1], stackTrace)
-            if s.startswith('java.lang.IllegalArgumentException: '):
-                raise IllegalArgumentException(s.split(': ', 1)[1], stackTrace)
-            raise
+            converted = convert_exception(e.java_exception)
+            if not isinstance(converted, UnknownException):
+                raise converted
+            else:
+                raise
     return deco
 
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index c0a28fa..831fc73 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
 import org.apache.spark.sql.execution.python._
 import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.execution.streaming.sources.MemoryPlanV2
+import org.apache.spark.sql.execution.streaming.sources.MemoryPlan
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery}
 import org.apache.spark.sql.types.StructType
@@ -624,9 +624,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case r: RunnableCommand => ExecutedCommandExec(r) :: Nil
 
       case MemoryPlan(sink, output) =>
-        val encoder = RowEncoder(sink.schema)
-        LocalTableScanExec(output, sink.allData.map(r => encoder.toRow(r).copy())) :: Nil
-      case MemoryPlanV2(sink, output) =>
         val encoder = RowEncoder(StructType.fromAttributes(output))
         LocalTableScanExec(output, sink.allData.map(r => encoder.toRow(r).copy())) :: Nil
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 0dcbdd3..6efde0a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -22,23 +22,19 @@ import java.util.concurrent.atomic.AtomicInteger
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.JavaConverters._
-import scala.collection.mutable.{ArrayBuffer, ListBuffer}
-import scala.util.control.NonFatal
+import scala.collection.mutable.ListBuffer
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.encoderFor
-import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
-import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
-import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
-import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.v2._
 import org.apache.spark.sql.sources.v2.reader._
 import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream, Offset => OffsetV2}
-import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
@@ -276,85 +272,3 @@ trait MemorySinkBase extends BaseStreamingSink {
   def dataSinceBatch(sinceBatchId: Long): Seq[Row]
   def latestBatchId: Option[Long]
 }
-
-/**
- * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
- * tests and does not provide durability.
- */
-class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink
-  with MemorySinkBase with Logging {
-
-  private case class AddedData(batchId: Long, data: Array[Row])
-
-  /** An order list of batches that have been written to this [[Sink]]. */
-  @GuardedBy("this")
-  private val batches = new ArrayBuffer[AddedData]()
-
-  /** Returns all rows that are stored in this [[Sink]]. */
-  def allData: Seq[Row] = synchronized {
-    batches.flatMap(_.data)
-  }
-
-  def latestBatchId: Option[Long] = synchronized {
-    batches.lastOption.map(_.batchId)
-  }
-
-  def latestBatchData: Seq[Row] = synchronized { batches.lastOption.toSeq.flatten(_.data) }
-
-  def dataSinceBatch(sinceBatchId: Long): Seq[Row] = synchronized {
-    batches.filter(_.batchId > sinceBatchId).flatMap(_.data)
-  }
-
-  def toDebugString: String = synchronized {
-    batches.map { case AddedData(batchId, data) =>
-      val dataStr = try data.mkString(" ") catch {
-        case NonFatal(e) => "[Error converting to string]"
-      }
-      s"$batchId: $dataStr"
-    }.mkString("\n")
-  }
-
-  override def addBatch(batchId: Long, data: DataFrame): Unit = {
-    val notCommitted = synchronized {
-      latestBatchId.isEmpty || batchId > latestBatchId.get
-    }
-    if (notCommitted) {
-      logDebug(s"Committing batch $batchId to $this")
-      outputMode match {
-        case Append | Update =>
-          val rows = AddedData(batchId, data.collect())
-          synchronized { batches += rows }
-
-        case Complete =>
-          val rows = AddedData(batchId, data.collect())
-          synchronized {
-            batches.clear()
-            batches += rows
-          }
-
-        case _ =>
-          throw new IllegalArgumentException(
-            s"Output mode $outputMode is not supported by MemorySink")
-      }
-    } else {
-      logDebug(s"Skipping already committed batch: $batchId")
-    }
-  }
-
-  def clear(): Unit = synchronized {
-    batches.clear()
-  }
-
-  override def toString(): String = "MemorySink"
-}
-
-/**
- * Used to query the data that has been written into a [[MemorySink]].
- */
-case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode {
-  def this(sink: MemorySink) = this(sink, sink.schema.toAttributes)
-
-  private val sizePerRow = EstimationUtils.getSizePerRow(sink.schema.toAttributes)
-
-  override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size)
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala
similarity index 93%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala
index 219e25c..9008c63 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala
@@ -43,9 +43,9 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
  * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
  * tests and does not provide durability.
  */
-class MemorySinkV2 extends Table with SupportsWrite with MemorySinkBase with Logging {
+class MemorySink extends Table with SupportsWrite with MemorySinkBase with Logging {
 
-  override def name(): String = "MemorySinkV2"
+  override def name(): String = "MemorySink"
 
   override def schema(): StructType = StructType(Nil)
 
@@ -69,7 +69,7 @@ class MemorySinkV2 extends Table with SupportsWrite with MemorySinkBase with Log
       }
 
       override def buildForStreaming(): StreamingWrite = {
-        new MemoryStreamingWrite(MemorySinkV2.this, inputSchema, needTruncate)
+        new MemoryStreamingWrite(MemorySink.this, inputSchema, needTruncate)
       }
     }
   }
@@ -130,14 +130,14 @@ class MemorySinkV2 extends Table with SupportsWrite with MemorySinkBase with Log
     batches.clear()
   }
 
-  override def toString(): String = "MemorySinkV2"
+  override def toString(): String = "MemorySink"
 }
 
 case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row])
   extends WriterCommitMessage {}
 
 class MemoryStreamingWrite(
-    val sink: MemorySinkV2, schema: StructType, needTruncate: Boolean)
+    val sink: MemorySink, schema: StructType, needTruncate: Boolean)
   extends StreamingWrite {
 
   override def createStreamingWriterFactory: MemoryWriterFactory = {
@@ -195,9 +195,9 @@ class MemoryDataWriter(partition: Int, schema: StructType)
 
 
 /**
- * Used to query the data that has been written into a [[MemorySinkV2]].
+ * Used to query the data that has been written into a [[MemorySink]].
  */
-case class MemoryPlanV2(sink: MemorySinkV2, override val output: Seq[Attribute]) extends LeafNode {
+case class MemoryPlan(sink: MemorySink, override val output: Seq[Attribute]) extends LeafNode {
   private val sizePerRow = EstimationUtils.getSizePerRow(output)
 
   override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index d2df3a5..2f12efe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -254,16 +254,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
       if (extraOptions.get("queryName").isEmpty) {
         throw new AnalysisException("queryName must be specified for memory sink")
       }
-      val (sink, resultDf) = trigger match {
-        case _: ContinuousTrigger =>
-          val s = new MemorySinkV2()
-          val r = Dataset.ofRows(df.sparkSession, new MemoryPlanV2(s, df.schema.toAttributes))
-          (s, r)
-        case _ =>
-          val s = new MemorySink(df.schema, outputMode)
-          val r = Dataset.ofRows(df.sparkSession, new MemoryPlan(s))
-          (s, r)
-      }
+      val sink = new MemorySink()
+      val resultDf = Dataset.ofRows(df.sparkSession, new MemoryPlan(sink, df.schema.toAttributes))
       val chkpointLoc = extraOptions.get("checkpointLocation")
       val recoverFromChkpoint = outputMode == OutputMode.Complete()
       val query = df.sparkSession.sessionState.streamingQueryManager.startQuery(
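
With the two code paths above collapsed into one, a memory-sink query is started the same way for micro-batch and continuous triggers. A minimal usage sketch, assuming a SparkSession `spark` and a streaming Dataset `df` whose plan is supported by the chosen trigger:

    import org.apache.spark.sql.streaming.Trigger

    // Both queries are now backed by the same DSv2-based MemorySink.
    // queryName is required for the memory sink and becomes the temp view name.
    val microBatchQuery = df.writeStream
      .format("memory")
      .queryName("mem_microbatch")
      .outputMode("append")
      .start()

    val continuousQuery = df.writeStream
      .format("memory")
      .queryName("mem_continuous")
      .trigger(Trigger.Continuous("1 second"))
      .start()

    // The sink's contents can be queried through the registered temp view.
    spark.sql("SELECT * FROM mem_microbatch").show()
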
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
index 3bc36ce..3ead91f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
@@ -22,6 +22,8 @@ import scala.language.implicitConversions
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.streaming.sources._
 import org.apache.spark.sql.streaming.{OutputMode, StreamTest}
 import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
 import org.apache.spark.util.Utils
@@ -36,7 +38,8 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
 
   test("directly add data in Append output mode") {
     implicit val schema = new StructType().add(new StructField("value", IntegerType))
-    val sink = new MemorySink(schema, OutputMode.Append)
+    val sink = new MemorySink
+    val addBatch = addBatchFunc(sink, false) _
 
     // Before adding data, check output
     assert(sink.latestBatchId === None)
@@ -44,25 +47,25 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
     checkAnswer(sink.allData, Seq.empty)
 
     // Add batch 0 and check outputs
-    sink.addBatch(0, 1 to 3)
+    addBatch(0, 1 to 3)
     assert(sink.latestBatchId === Some(0))
     checkAnswer(sink.latestBatchData, 1 to 3)
     checkAnswer(sink.allData, 1 to 3)
 
     // Add batch 1 and check outputs
-    sink.addBatch(1, 4 to 6)
+    addBatch(1, 4 to 6)
     assert(sink.latestBatchId === Some(1))
     checkAnswer(sink.latestBatchData, 4 to 6)
     checkAnswer(sink.allData, 1 to 6)     // new data should get appended to old data
 
     // Re-add batch 1 with different data, should not be added and outputs should not be changed
-    sink.addBatch(1, 7 to 9)
+    addBatch(1, 7 to 9)
     assert(sink.latestBatchId === Some(1))
     checkAnswer(sink.latestBatchData, 4 to 6)
     checkAnswer(sink.allData, 1 to 6)
 
     // Add batch 2 and check outputs
-    sink.addBatch(2, 7 to 9)
+    addBatch(2, 7 to 9)
     assert(sink.latestBatchId === Some(2))
     checkAnswer(sink.latestBatchData, 7 to 9)
     checkAnswer(sink.allData, 1 to 9)
@@ -70,7 +73,8 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
 
   test("directly add data in Update output mode") {
     implicit val schema = new StructType().add(new StructField("value", IntegerType))
-    val sink = new MemorySink(schema, OutputMode.Update)
+    val sink = new MemorySink
+    val addBatch = addBatchFunc(sink, false) _
 
     // Before adding data, check output
     assert(sink.latestBatchId === None)
@@ -78,25 +82,25 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
     checkAnswer(sink.allData, Seq.empty)
 
     // Add batch 0 and check outputs
-    sink.addBatch(0, 1 to 3)
+    addBatch(0, 1 to 3)
     assert(sink.latestBatchId === Some(0))
     checkAnswer(sink.latestBatchData, 1 to 3)
     checkAnswer(sink.allData, 1 to 3)
 
     // Add batch 1 and check outputs
-    sink.addBatch(1, 4 to 6)
+    addBatch(1, 4 to 6)
     assert(sink.latestBatchId === Some(1))
     checkAnswer(sink.latestBatchData, 4 to 6)
     checkAnswer(sink.allData, 1 to 6) // new data should get appended to old data
 
     // Re-add batch 1 with different data, should not be added and outputs should not be changed
-    sink.addBatch(1, 7 to 9)
+    addBatch(1, 7 to 9)
     assert(sink.latestBatchId === Some(1))
     checkAnswer(sink.latestBatchData, 4 to 6)
     checkAnswer(sink.allData, 1 to 6)
 
     // Add batch 2 and check outputs
-    sink.addBatch(2, 7 to 9)
+    addBatch(2, 7 to 9)
     assert(sink.latestBatchId === Some(2))
     checkAnswer(sink.latestBatchData, 7 to 9)
     checkAnswer(sink.allData, 1 to 9)
@@ -104,7 +108,8 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
 
   test("directly add data in Complete output mode") {
     implicit val schema = new StructType().add(new StructField("value", IntegerType))
-    val sink = new MemorySink(schema, OutputMode.Complete)
+    val sink = new MemorySink
+    val addBatch = addBatchFunc(sink, true) _
 
     // Before adding data, check output
     assert(sink.latestBatchId === None)
@@ -112,25 +117,25 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
     checkAnswer(sink.allData, Seq.empty)
 
     // Add batch 0 and check outputs
-    sink.addBatch(0, 1 to 3)
+    addBatch(0, 1 to 3)
     assert(sink.latestBatchId === Some(0))
     checkAnswer(sink.latestBatchData, 1 to 3)
     checkAnswer(sink.allData, 1 to 3)
 
     // Add batch 1 and check outputs
-    sink.addBatch(1, 4 to 6)
+    addBatch(1, 4 to 6)
     assert(sink.latestBatchId === Some(1))
     checkAnswer(sink.latestBatchData, 4 to 6)
     checkAnswer(sink.allData, 4 to 6)     // new data should replace old data
 
     // Re-add batch 1 with different data, should not be added and outputs should not be changed
-    sink.addBatch(1, 7 to 9)
+    addBatch(1, 7 to 9)
     assert(sink.latestBatchId === Some(1))
     checkAnswer(sink.latestBatchData, 4 to 6)
     checkAnswer(sink.allData, 4 to 6)
 
     // Add batch 2 and check outputs
-    sink.addBatch(2, 7 to 9)
+    addBatch(2, 7 to 9)
     assert(sink.latestBatchId === Some(2))
     checkAnswer(sink.latestBatchData, 7 to 9)
     checkAnswer(sink.allData, 7 to 9)
@@ -211,18 +216,19 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
 
   test("MemoryPlan statistics") {
     implicit val schema = new StructType().add(new StructField("value", IntegerType))
-    val sink = new MemorySink(schema, OutputMode.Append)
-    val plan = new MemoryPlan(sink)
+    val sink = new MemorySink
+    val addBatch = addBatchFunc(sink, false) _
+    val plan = new MemoryPlan(sink, schema.toAttributes)
 
     // Before adding data, check output
     checkAnswer(sink.allData, Seq.empty)
     assert(plan.stats.sizeInBytes === 0)
 
-    sink.addBatch(0, 1 to 3)
+    addBatch(0, 1 to 3)
     plan.invalidateStatsCache()
     assert(plan.stats.sizeInBytes === 36)
 
-    sink.addBatch(1, 4 to 6)
+    addBatch(1, 4 to 6)
     plan.invalidateStatsCache()
     assert(plan.stats.sizeInBytes === 72)
   }
@@ -285,6 +291,50 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
     }
   }
 
+  test("data writer") {
+    val partition = 1234
+    val writer = new MemoryDataWriter(
+      partition, new StructType().add("i", "int"))
+    writer.write(InternalRow(1))
+    writer.write(InternalRow(2))
+    writer.write(InternalRow(44))
+    val msg = writer.commit()
+    assert(msg.data.map(_.getInt(0)) == Seq(1, 2, 44))
+    assert(msg.partition == partition)
+
+    // Buffer should be cleared, so repeated commits should give empty.
+    assert(writer.commit().data.isEmpty)
+  }
+
+  test("streaming writer") {
+    val sink = new MemorySink
+    val write = new MemoryStreamingWrite(
+      sink, new StructType().add("i", "int"), needTruncate = false)
+    write.commit(0,
+      Array(
+        MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))),
+        MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))),
+        MemoryWriterCommitMessage(2, Seq(Row(6), Row(7)))
+      ))
+    assert(sink.latestBatchId.contains(0))
+    assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7))
+    write.commit(19,
+      Array(
+        MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))),
+        MemoryWriterCommitMessage(0, Seq(Row(33)))
+      ))
+    assert(sink.latestBatchId.contains(19))
+    assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33))
+
+    assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33))
+  }
+
+  private def addBatchFunc(sink: MemorySink, needTruncate: Boolean)(
+      batchId: Long,
+      vals: Seq[Int]): Unit = {
+    sink.write(batchId, needTruncate, vals.map(Row(_)).toArray)
+  }
+
   private def checkAnswer(rows: Seq[Row], expected: Seq[Int])(implicit schema: StructType): Unit = {
     checkAnswer(
       sqlContext.createDataFrame(sparkContext.makeRDD(rows), schema),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala
deleted file mode 100644
index a90acf8..0000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming
-
-import org.scalatest.BeforeAndAfter
-
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.streaming.sources._
-import org.apache.spark.sql.streaming.{OutputMode, StreamTest}
-import org.apache.spark.sql.types.StructType
-
-class MemorySinkV2Suite extends StreamTest with BeforeAndAfter {
-  test("data writer") {
-    val partition = 1234
-    val writer = new MemoryDataWriter(
-      partition, new StructType().add("i", "int"))
-    writer.write(InternalRow(1))
-    writer.write(InternalRow(2))
-    writer.write(InternalRow(44))
-    val msg = writer.commit()
-    assert(msg.data.map(_.getInt(0)) == Seq(1, 2, 44))
-    assert(msg.partition == partition)
-
-    // Buffer should be cleared, so repeated commits should give empty.
-    assert(writer.commit().data.isEmpty)
-  }
-
-  test("streaming writer") {
-    val sink = new MemorySinkV2
-    val write = new MemoryStreamingWrite(
-      sink, new StructType().add("i", "int"), needTruncate = false)
-    write.commit(0,
-      Array(
-        MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))),
-        MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))),
-        MemoryWriterCommitMessage(2, Seq(Row(6), Row(7)))
-      ))
-    assert(sink.latestBatchId.contains(0))
-    assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7))
-    write.commit(19,
-      Array(
-        MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))),
-        MemoryWriterCommitMessage(0, Seq(Row(33)))
-      ))
-    assert(sink.latestBatchId.contains(19))
-    assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33))
-
-    assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33))
-  }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
index 1ff9dec..4bf49ff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{AnalysisException, Dataset}
 import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
 import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.sources.MemorySink
 import org.apache.spark.sql.functions.{count, window}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.OutputMode._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
index 33b4c08..4b0bab1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.FileStreamSource.{FileEntry, SeenFilesMap}
+import org.apache.spark.sql.execution.streaming.sources.MemorySink
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.ExistsThrowsExceptionFileSystem._
 import org.apache.spark.sql.streaming.util.StreamManualClock
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index 659deb8..f229b08 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -29,7 +29,7 @@ import org.apache.commons.io.FileUtils
 import org.apache.hadoop.conf.Configuration
 import org.scalatest.time.SpanSugar._
 
-import org.apache.spark.{SparkConf, SparkContext, TaskContext}
+import org.apache.spark.{SparkConf, SparkContext, TaskContext, TestUtils}
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.plans.logical.Range
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream
+import org.apache.spark.sql.execution.streaming.sources.{ContinuousMemoryStream, MemorySink}
 import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -876,8 +876,8 @@ class StreamSuite extends StreamTest {
         query.awaitTermination()
       }
 
-      assert(e.getMessage.contains(providerClassName))
-      assert(e.getMessage.contains("instantiated"))
+      TestUtils.assertExceptionMsg(e, providerClassName)
+      TestUtils.assertExceptionMsg(e, "instantiated")
     }
   }
 
@@ -1083,15 +1083,15 @@ class StreamSuite extends StreamTest {
 
   test("SPARK-26379 Structured Streaming - Exception on adding current_timestamp " +
     " to Dataset - use v2 sink") {
-    testCurrentTimestampOnStreamingQuery(useV2Sink = true)
+    testCurrentTimestampOnStreamingQuery()
   }
 
   test("SPARK-26379 Structured Streaming - Exception on adding current_timestamp " +
     " to Dataset - use v1 sink") {
-    testCurrentTimestampOnStreamingQuery(useV2Sink = false)
+    testCurrentTimestampOnStreamingQuery()
   }
 
-  private def testCurrentTimestampOnStreamingQuery(useV2Sink: Boolean): Unit = {
+  private def testCurrentTimestampOnStreamingQuery(): Unit = {
     val input = MemoryStream[Int]
     val df = input.toDS().withColumn("cur_timestamp", lit(current_timestamp()))
 
@@ -1109,7 +1109,7 @@ class StreamSuite extends StreamTest {
 
     var lastTimestamp = System.currentTimeMillis()
     val currentDate = DateTimeUtils.millisToDays(lastTimestamp)
-    testStream(df, useV2Sink = useV2Sink) (
+    testStream(df) (
       AddData(input, 1),
       CheckLastBatch { rows: Seq[Row] =>
         lastTimestamp = assertBatchOutputAndUpdateLastTimestamp(rows, lastTimestamp, currentDate, 1)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index 6ff3c94..900098a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.streaming
 
-import java.lang.Thread.UncaughtExceptionHandler
-
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.language.experimental.macros
@@ -42,7 +40,7 @@ import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch}
-import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2
+import org.apache.spark.sql.execution.streaming.sources.MemorySink
 import org.apache.spark.sql.execution.streaming.state.StateStore
 import org.apache.spark.sql.streaming.StreamingQueryListener._
 import org.apache.spark.sql.test.SharedSQLContext
@@ -86,7 +84,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
   }
 
   protected val defaultTrigger = Trigger.ProcessingTime(0)
-  protected val defaultUseV2Sink = false
 
   /** How long to wait for an active stream to catch up when checking a result. */
   val streamingTimeout = 60.seconds
@@ -327,8 +324,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
    */
   def testStream(
       _stream: Dataset[_],
-      outputMode: OutputMode = OutputMode.Append,
-      useV2Sink: Boolean = defaultUseV2Sink)(actions: StreamAction*): Unit = synchronized {
+      outputMode: OutputMode = OutputMode.Append)(actions: StreamAction*): Unit = synchronized {
     import org.apache.spark.sql.streaming.util.StreamManualClock
 
     // `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently
@@ -341,7 +337,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     var currentStream: StreamExecution = null
     var lastStream: StreamExecution = null
     val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for
-    val sink = if (useV2Sink) new MemorySinkV2 else new MemorySink(stream.schema, outputMode)
+    val sink = new MemorySink
     val resetConfValues = mutable.Map[String, Option[String]]()
     val defaultCheckpointLocation =
       Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
@@ -391,10 +387,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
       }
 
     def testState = {
-      val sinkDebugString = sink match {
-        case s: MemorySink => s.toDebugString
-        case s: MemorySinkV2 => s.toDebugString
-      }
+      val sinkDebugString = sink.toDebugString
+
       s"""
          |== Progress ==
          |$testActions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index 81b22be..134e61e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.exchange.Exchange
 import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.sources.MemorySink
 import org.apache.spark.sql.execution.streaming.state.StreamingAggregationStateManager
 import org.apache.spark.sql.expressions.scalalang.typed
 import org.apache.spark.sql.functions._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
index 6b711f9..422223b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
@@ -179,7 +179,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
     val listeners = (1 to 5).map(_ => new EventCollector)
     try {
       listeners.foreach(listener => spark.streams.addListener(listener))
-      testStream(df, OutputMode.Append, useV2Sink = true)(
+      testStream(df, OutputMode.Append)(
         StartStream(Trigger.Continuous(1000)),
         StopStream,
         AssertOnQuery { query =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index 97a6ba8..13976ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -22,7 +22,7 @@ import java.util.concurrent.CountDownLatch
 
 import scala.collection.mutable
 
-import org.apache.commons.io.{FileUtils, IOUtils}
+import org.apache.commons.io.FileUtils
 import org.apache.commons.lang3.RandomStringUtils
 import org.apache.hadoop.fs.Path
 import org.scalactic.TolerantNumerics
@@ -30,13 +30,13 @@ import org.scalatest.BeforeAndAfter
 import org.scalatest.concurrent.PatienceConfiguration.Timeout
 import org.scalatest.mockito.MockitoSugar
 
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, TestUtils}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
 import org.apache.spark.sql.catalyst.expressions.{Literal, Rand, Randn, Shuffle, Uuid}
 import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter
+import org.apache.spark.sql.execution.streaming.sources.{MemorySink, TestForeachWriter}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.v2.reader.InputPartition
@@ -498,7 +498,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
   test("input row calculation with same V2 source used twice in self-union") {
     val streamInput = MemoryStream[Int]
 
-    testStream(streamInput.toDF().union(streamInput.toDF()), useV2Sink = true)(
+    testStream(streamInput.toDF().union(streamInput.toDF()))(
       AddData(streamInput, 1, 2, 3),
       CheckAnswer(1, 1, 2, 2, 3, 3),
       AssertOnQuery { q =>
@@ -519,7 +519,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
       // relation, which breaks exchange reuse, as the optimizer will remove Project from one side.
       // Here we manually add a useful Project, to trigger exchange reuse.
       val streamDF = memoryStream.toDF().select('value + 0 as "v")
-      testStream(streamDF.join(streamDF, "v"), useV2Sink = true)(
+      testStream(streamDF.join(streamDF, "v"))(
         AddData(memoryStream, 1, 2, 3),
         CheckAnswer(1, 2, 3),
         check
@@ -556,7 +556,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
     val streamInput1 = MemoryStream[Int]
     val streamInput2 = MemoryStream[Int]
 
-    testStream(streamInput1.toDF().union(streamInput2.toDF()), useV2Sink = true)(
+    testStream(streamInput1.toDF().union(streamInput2.toDF()))(
       AddData(streamInput1, 1, 2, 3),
       CheckLastBatch(1, 2, 3),
       AssertOnQuery { q =>
@@ -587,7 +587,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
     val streamInput = MemoryStream[Int]
     val staticInputDF = spark.createDataFrame(Seq(1 -> "1", 2 -> "2")).toDF("value", "anotherValue")
 
-    testStream(streamInput.toDF().join(staticInputDF, "value"), useV2Sink = true)(
+    testStream(streamInput.toDF().join(staticInputDF, "value"))(
       AddData(streamInput, 1, 2, 3),
       AssertOnQuery { q =>
         q.processAllAvailable()
@@ -609,7 +609,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
     val streamInput2 = MemoryStream[Int]
     val staticInputDF2 = staticInputDF.union(staticInputDF).cache()
 
-    testStream(streamInput2.toDF().join(staticInputDF2, "value"), useV2Sink = true)(
+    testStream(streamInput2.toDF().join(staticInputDF2, "value"))(
       AddData(streamInput2, 1, 2, 3),
       AssertOnQuery { q =>
         q.processAllAvailable()
@@ -717,8 +717,8 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
         q3.processAllAvailable()
       }
       assert(e.getCause.isInstanceOf[SparkException])
-      assert(e.getCause.getCause.isInstanceOf[IllegalStateException])
-      assert(e.getMessage.contains("StreamingQuery cannot be used in executors"))
+      assert(e.getCause.getCause.getCause.isInstanceOf[IllegalStateException])
+      TestUtils.assertExceptionMsg(e, "StreamingQuery cannot be used in executors")
     } finally {
       q1.stop()
       q2.stop()
@@ -912,7 +912,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
       AssertOnQuery(_.logicalPlan.toJSON.contains("StreamingDataSourceV2Relation"))
     )
 
-    testStream(df, useV2Sink = true)(
+    testStream(df)(
       StartStream(trigger = Trigger.Continuous(100)),
       AssertOnQuery(_.logicalPlan.toJSON.contains("StreamingDataSourceV2Relation"))
     )
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueryStatusAndProgressSuite.scala
index 10bea7f..59d6ac0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueryStatusAndProgressSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueryStatusAndProgressSuite.scala
@@ -34,7 +34,7 @@ class ContinuousQueryStatusAndProgressSuite extends ContinuousSuiteBase {
     }
 
     val trigger = Trigger.Continuous(100)
-    testStream(input.toDF(), useV2Sink = true)(
+    testStream(input.toDF())(
       StartStream(trigger),
       Execute(assertStatus),
       AddData(input, 0, 1, 2),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
index d2e489a..9840c7f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
@@ -57,7 +57,6 @@ class ContinuousSuiteBase extends StreamTest {
   protected val longContinuousTrigger = Trigger.Continuous("1 hour")
 
   override protected val defaultTrigger = Trigger.Continuous(100)
-  override protected val defaultUseV2Sink = true
 }
 
 class ContinuousSuite extends ContinuousSuiteBase {
@@ -239,7 +238,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase {
       .load()
       .select('value)
 
-    testStream(df, useV2Sink = true)(
+    testStream(df)(
       StartStream(longContinuousTrigger),
       AwaitEpoch(0),
       Execute(waitForRateSourceTriggers(_, 10)),
@@ -257,7 +256,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase {
       .load()
       .select('value)
 
-    testStream(df, useV2Sink = true)(
+    testStream(df)(
       StartStream(Trigger.Continuous(2012)),
       AwaitEpoch(0),
       Execute(waitForRateSourceTriggers(_, 10)),
@@ -274,7 +273,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase {
       .load()
       .select('value)
 
-    testStream(df, useV2Sink = true)(
+    testStream(df)(
       StartStream(Trigger.Continuous(1012)),
       AwaitEpoch(2),
       StopStream,
@@ -365,7 +364,7 @@ class ContinuousEpochBacklogSuite extends ContinuousSuiteBase {
         .load()
         .select('value)
 
-      testStream(df, useV2Sink = true)(
+      testStream(df)(
         StartStream(Trigger.Continuous(1)),
         ExpectFailure[IllegalStateException] { e =>
           e.getMessage.contains("queue has exceeded its maximum")

