Posted to commits@spark.apache.org by td...@apache.org on 2018/01/18 06:36:35 UTC
spark git commit: [SPARK-23052][SS] Migrate ConsoleSink to data source V2 API.
Repository: spark
Updated Branches:
refs/heads/master 39d244d92 -> 1c76a91e5
[SPARK-23052][SS] Migrate ConsoleSink to the data source V2 API.
## What changes were proposed in this pull request?
Migrate ConsoleSink to the data source V2 API.
Note that this also adds a previously missing piece in DataStreamWriter that is required to specify a data source V2 writer.
Note also that I've removed the "Rerun batch" part of the sink, because as far as I can tell it could never actually have happened: a MicroBatchExecution object commits each batch only once in its lifetime, and a new MicroBatchExecution object would get a new ConsoleSink object that doesn't know it's retrying a batch. So I think this was an anti-feature rather than a weakness in the V2 API.
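For context, a minimal sketch of the user-facing path this migration preserves. The `numRows` and `truncate` options are the ones the new ConsoleWriter reads; the rate source and local master here are illustrative assumptions, not part of this patch:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[2]")            // illustrative
  .appName("console-sink-demo")
  .getOrCreate()

// Any streaming source works; rate is just convenient for a demo.
val df = spark.readStream
  .format("rate")
  .option("rowsPerSecond", "5")
  .load()

// "console" now resolves to ConsoleSinkProvider (a DataSourceV2 with
// MicroBatchWriteSupport) instead of the old Sink-based ConsoleSink.
val query = df.writeStream
  .format("console")
  .option("numRows", "10")       // rows shown per batch, default 20
  .option("truncate", "false")   // don't truncate wide columns, default true
  .start()

query.awaitTermination(10000L)   // run briefly, then stop
query.stop()
```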
## How was this patch tested?
New unit tests (ConsoleWriterSuite and StreamingDataSourceV2Suite).
Author: Jose Torres <jo...@databricks.com>
Closes #20243 from jose-torres/console-sink.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1c76a91e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1c76a91e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1c76a91e
Branch: refs/heads/master
Commit: 1c76a91e5fae11dcb66c453889e587b48039fdc9
Parents: 39d244d
Author: Jose Torres <jo...@databricks.com>
Authored: Wed Jan 17 22:36:29 2018 -0800
Committer: Tathagata Das <ta...@gmail.com>
Committed: Wed Jan 17 22:36:29 2018 -0800
----------------------------------------------------------------------
.../streaming/MicroBatchExecution.scala | 7 +-
.../spark/sql/execution/streaming/console.scala | 62 ++---
.../continuous/ContinuousExecution.scala | 9 +-
.../streaming/sources/ConsoleWriter.scala | 64 +++++
.../sources/PackedRowWriterFactory.scala | 60 +++++
.../spark/sql/streaming/DataStreamWriter.scala | 16 +-
....apache.spark.sql.sources.DataSourceRegister | 8 +
.../streaming/sources/ConsoleWriterSuite.scala | 135 ++++++++++
.../sources/StreamingDataSourceV2Suite.scala | 249 +++++++++++++++++++
.../test/DataStreamReaderWriterSuite.scala | 25 --
10 files changed, 551 insertions(+), 84 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/1c76a91e/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
index 70407f0..7c38045 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
@@ -91,11 +91,14 @@ class MicroBatchExecution(
nextSourceId += 1
StreamingExecutionRelation(reader, output)(sparkSession)
})
- case s @ StreamingRelationV2(_, _, _, output, v1Relation) =>
+ case s @ StreamingRelationV2(_, sourceName, _, output, v1Relation) =>
v2ToExecutionRelationMap.getOrElseUpdate(s, {
// Materialize source to avoid creating it in every batch
val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId"
- assert(v1Relation.isDefined, "v2 execution didn't match but v1 was unavailable")
+ if (v1Relation.isEmpty) {
+ throw new UnsupportedOperationException(
+ s"Data source $sourceName does not support microbatch processing.")
+ }
val source = v1Relation.get.dataSource.createSource(metadataPath)
nextSourceId += 1
StreamingExecutionRelation(source, output)(sparkSession)
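The user-visible effect of this hunk: a V2 source with no V1 fallback now fails a microbatch query with a clear error instead of tripping an assert. A hedged sketch, using a hypothetical continuous-only format name (the failure surfaces after start, as exercised by testPostCreationNegativeCase in the new suite):

```scala
// Hypothetical: a source registered as "my-continuous-only-source" that
// implements ContinuousReadSupport but not MicroBatchReadSupport and
// provides no V1 fallback relation.
val query = spark.readStream
  .format("my-continuous-only-source") // hypothetical name
  .load()
  .writeStream
  .format("console")
  .start() // default trigger => microbatch execution

// The query then fails with:
//   UnsupportedOperationException:
//     Data source my-continuous-only-source does not support microbatch processing.
```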
http://git-wip-us.apache.org/repos/asf/spark/blob/1c76a91e/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
index 71eaabe..9482037 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
@@ -17,58 +17,36 @@
package org.apache.spark.sql.execution.streaming
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext}
-import org.apache.spark.sql.execution.SQLExecution
-import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, StreamSinkProvider}
-import org.apache.spark.sql.streaming.OutputMode
-import org.apache.spark.sql.types.StructType
-
-class ConsoleSink(options: Map[String, String]) extends Sink with Logging {
- // Number of rows to display, by default 20 rows
- private val numRowsToShow = options.get("numRows").map(_.toInt).getOrElse(20)
-
- // Truncate the displayed data if it is too long, by default it is true
- private val isTruncated = options.get("truncate").map(_.toBoolean).getOrElse(true)
+import java.util.Optional
- // Track the batch id
- private var lastBatchId = -1L
-
- override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized {
- val batchIdStr = if (batchId <= lastBatchId) {
- s"Rerun batch: $batchId"
- } else {
- lastBatchId = batchId
- s"Batch: $batchId"
- }
-
- // scalastyle:off println
- println("-------------------------------------------")
- println(batchIdStr)
- println("-------------------------------------------")
- // scalastyle:off println
- data.sparkSession.createDataFrame(
- data.sparkSession.sparkContext.parallelize(data.collect()), data.schema)
- .show(numRowsToShow, isTruncated)
- }
+import scala.collection.JavaConverters._
- override def toString(): String = s"ConsoleSink[numRows=$numRowsToShow, truncate=$isTruncated]"
-}
+import org.apache.spark.sql._
+import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter
+import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister}
+import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options}
+import org.apache.spark.sql.sources.v2.streaming.MicroBatchWriteSupport
+import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.StructType
case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame)
extends BaseRelation {
override def schema: StructType = data.schema
}
-class ConsoleSinkProvider extends StreamSinkProvider
+class ConsoleSinkProvider extends DataSourceV2
+ with MicroBatchWriteSupport
with DataSourceRegister
with CreatableRelationProvider {
- def createSink(
- sqlContext: SQLContext,
- parameters: Map[String, String],
- partitionColumns: Seq[String],
- outputMode: OutputMode): Sink = {
- new ConsoleSink(parameters)
+
+ override def createMicroBatchWriter(
+ queryId: String,
+ epochId: Long,
+ schema: StructType,
+ mode: OutputMode,
+ options: DataSourceV2Options): Optional[DataSourceV2Writer] = {
+ Optional.of(new ConsoleWriter(epochId, schema, options))
}
def createRelation(
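For illustration, a hedged sketch of constructing the new writer directly — roughly what the microbatch engine does once per batch. The option map and schema here are assumptions, and ConsoleWriter additionally asserts that an active SparkSession exists:

```scala
import scala.collection.JavaConverters._

import org.apache.spark.sql.execution.streaming.ConsoleSinkProvider
import org.apache.spark.sql.sources.v2.DataSourceV2Options
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType

val provider = new ConsoleSinkProvider
val writer = provider.createMicroBatchWriter(
  "demo-query",                          // queryId
  0L,                                    // epochId, printed as "Batch: 0"
  new StructType().add("value", "int"),  // schema of the incoming rows
  OutputMode.Append(),
  new DataSourceV2Options(Map("numRows" -> "5").asJava))
// writer.get is a ConsoleWriter that shows at most 5 rows for this batch.
```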
http://git-wip-us.apache.org/repos/asf/spark/blob/1c76a91e/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
index c050722..462e7d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
@@ -54,16 +54,13 @@ class ContinuousExecution(
sparkSession, name, checkpointRoot, analyzedPlan, sink,
trigger, triggerClock, outputMode, deleteCheckpointOnStop) {
- @volatile protected var continuousSources: Seq[ContinuousReader] = _
+ @volatile protected var continuousSources: Seq[ContinuousReader] = Seq()
override protected def sources: Seq[BaseStreamingSource] = continuousSources
// For use only in test harnesses.
private[sql] var currentEpochCoordinatorId: String = _
- override lazy val logicalPlan: LogicalPlan = {
- assert(queryExecutionThread eq Thread.currentThread,
- "logicalPlan must be initialized in StreamExecutionThread " +
- s"but the current thread was ${Thread.currentThread}")
+ override val logicalPlan: LogicalPlan = {
val toExecutionRelationMap = MutableMap[StreamingRelationV2, ContinuousExecutionRelation]()
analyzedPlan.transform {
case r @ StreamingRelationV2(
@@ -72,7 +69,7 @@ class ContinuousExecution(
ContinuousExecutionRelation(source, extraReaderOptions, output)(sparkSession)
})
case StreamingRelationV2(_, sourceName, _, _, _) =>
- throw new AnalysisException(
+ throw new UnsupportedOperationException(
s"Data source $sourceName does not support continuous processing.")
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/1c76a91e/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala
new file mode 100644
index 0000000..3619799
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.sources.v2.DataSourceV2Options
+import org.apache.spark.sql.sources.v2.writer.{DataSourceV2Writer, DataWriterFactory, WriterCommitMessage}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * A [[DataSourceV2Writer]] that collects results to the driver and prints them in the console.
+ * Generated by [[org.apache.spark.sql.execution.streaming.ConsoleSinkProvider]].
+ *
+ * This sink should not be used for production, as it requires sending all rows to the driver
+ * and does not support recovery.
+ */
+class ConsoleWriter(batchId: Long, schema: StructType, options: DataSourceV2Options)
+ extends DataSourceV2Writer with Logging {
+ // Number of rows to display, by default 20 rows
+ private val numRowsToShow = options.getInt("numRows", 20)
+
+ // Truncate the displayed data if it is too long, by default it is true
+ private val isTruncated = options.getBoolean("truncate", true)
+
+ assert(SparkSession.getActiveSession.isDefined)
+ private val spark = SparkSession.getActiveSession.get
+
+ override def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory
+
+ override def commit(messages: Array[WriterCommitMessage]): Unit = synchronized {
+ val batch = messages.collect {
+ case PackedRowCommitMessage(rows) => rows
+ }.flatten
+
+ // scalastyle:off println
+ println("-------------------------------------------")
+ println(s"Batch: $batchId")
+ println("-------------------------------------------")
+ // scalastyle:off println
+ spark.createDataFrame(
+ spark.sparkContext.parallelize(batch), schema)
+ .show(numRowsToShow, isTruncated)
+ }
+
+ override def abort(messages: Array[WriterCommitMessage]): Unit = {}
+
+ override def toString(): String = s"ConsoleWriter[numRows=$numRowsToShow, truncate=$isTruncated]"
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/1c76a91e/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala
new file mode 100644
index 0000000..9282ba0
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import scala.collection.mutable
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}
+
+/**
+ * A simple [[DataWriterFactory]] whose tasks just pack rows into the commit message for delivery
+ * to a [[org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer]] on the driver.
+ *
+ * Note that, because it sends all rows to the driver, this factory will generally be unsuitable
+ * for production-quality sinks. It's intended for use in tests.
+ */
+case object PackedRowWriterFactory extends DataWriterFactory[Row] {
+ def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = {
+ new PackedRowDataWriter()
+ }
+}
+
+/**
+ * Commit message for a [[PackedRowDataWriter]], containing all the rows written in the most
+ * recent interval.
+ */
+case class PackedRowCommitMessage(rows: Array[Row]) extends WriterCommitMessage
+
+/**
+ * A simple [[DataWriter]] that just sends all the rows it's received as a commit message.
+ */
+class PackedRowDataWriter() extends DataWriter[Row] with Logging {
+ private val data = mutable.Buffer[Row]()
+
+ override def write(row: Row): Unit = data.append(row)
+
+ override def commit(): PackedRowCommitMessage = {
+ val msg = PackedRowCommitMessage(data.toArray)
+ data.clear()
+ msg
+ }
+
+ override def abort(): Unit = data.clear()
+}
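A quick sketch of the lifecycle these two pieces implement together (the names match the file above; driving the writer by hand like this is purely illustrative):

```scala
import org.apache.spark.sql.Row

// Each write task gets its own DataWriter from the factory...
val writer = PackedRowWriterFactory.createDataWriter(partitionId = 0, attemptNumber = 0)

// ...buffers the rows of its partition...
Seq(Row(1), Row(2), Row(3)).foreach(writer.write)

// ...and packs them into the commit message that travels back to the driver,
// where ConsoleWriter.commit() flattens all partitions' messages and prints them.
val msg = writer.commit() // PackedRowCommitMessage(Array(Row(1), Row(2), Row(3)))
```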
http://git-wip-us.apache.org/repos/asf/spark/blob/1c76a91e/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index b5b4a05..d24f0dd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2}
-import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport
+import org.apache.spark.sql.sources.v2.streaming.{ContinuousWriteSupport, MicroBatchWriteSupport}
/**
* Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems,
@@ -280,14 +280,12 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
useTempCheckpointLocation = true,
trigger = trigger)
} else {
- val sink = trigger match {
- case _: ContinuousTrigger =>
- val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
- ds.newInstance() match {
- case w: ContinuousWriteSupport => w
- case _ => throw new AnalysisException(
- s"Data source $source does not support continuous writing")
- }
+ val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
+ val sink = (ds.newInstance(), trigger) match {
+ case (w: ContinuousWriteSupport, _: ContinuousTrigger) => w
+ case (_, _: ContinuousTrigger) => throw new UnsupportedOperationException(
+ s"Data source $source does not support continuous writing")
+ case (w: MicroBatchWriteSupport, _) => w
case _ =>
val ds = DataSource(
df.sparkSession,
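With the rewritten (sink, trigger) match above, pairing a continuous trigger with a sink that only offers MicroBatchWriteSupport now fails fast at start(). A hedged sketch, assuming the `spark` session and rate source from the earlier example (console is microbatch-only as of this patch):

```scala
import org.apache.spark.sql.streaming.Trigger

// Throws UnsupportedOperationException:
//   "Data source console does not support continuous writing"
spark.readStream
  .format("rate")
  .load()
  .writeStream
  .format("console")
  .trigger(Trigger.Continuous("1 second"))
  .start()
```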
http://git-wip-us.apache.org/repos/asf/spark/blob/1c76a91e/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
index c6973bf..a0b25b4 100644
--- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
+++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -5,3 +5,11 @@ org.apache.spark.sql.sources.FakeSourceFour
org.apache.fakesource.FakeExternalSourceOne
org.apache.fakesource.FakeExternalSourceTwo
org.apache.fakesource.FakeExternalSourceThree
+org.apache.spark.sql.streaming.sources.FakeReadMicroBatchOnly
+org.apache.spark.sql.streaming.sources.FakeReadContinuousOnly
+org.apache.spark.sql.streaming.sources.FakeReadBothModes
+org.apache.spark.sql.streaming.sources.FakeReadNeitherMode
+org.apache.spark.sql.streaming.sources.FakeWriteMicroBatchOnly
+org.apache.spark.sql.streaming.sources.FakeWriteContinuousOnly
+org.apache.spark.sql.streaming.sources.FakeWriteBothModes
+org.apache.spark.sql.streaming.sources.FakeWriteNeitherMode
http://git-wip-us.apache.org/repos/asf/spark/blob/1c76a91e/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala
new file mode 100644
index 0000000..60ffee9
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import java.io.ByteArrayOutputStream
+
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.streaming.StreamTest
+
+class ConsoleWriterSuite extends StreamTest {
+ import testImplicits._
+
+ test("console") {
+ val input = MemoryStream[Int]
+
+ val captured = new ByteArrayOutputStream()
+ Console.withOut(captured) {
+ val query = input.toDF().writeStream.format("console").start()
+ try {
+ input.addData(1, 2, 3)
+ query.processAllAvailable()
+ input.addData(4, 5, 6)
+ query.processAllAvailable()
+ input.addData()
+ query.processAllAvailable()
+ } finally {
+ query.stop()
+ }
+ }
+
+ assert(captured.toString() ==
+ """-------------------------------------------
+ |Batch: 0
+ |-------------------------------------------
+ |+-----+
+ ||value|
+ |+-----+
+ || 1|
+ || 2|
+ || 3|
+ |+-----+
+ |
+ |-------------------------------------------
+ |Batch: 1
+ |-------------------------------------------
+ |+-----+
+ ||value|
+ |+-----+
+ || 4|
+ || 5|
+ || 6|
+ |+-----+
+ |
+ |-------------------------------------------
+ |Batch: 2
+ |-------------------------------------------
+ |+-----+
+ ||value|
+ |+-----+
+ |+-----+
+ |
+ |""".stripMargin)
+ }
+
+ test("console with numRows") {
+ val input = MemoryStream[Int]
+
+ val captured = new ByteArrayOutputStream()
+ Console.withOut(captured) {
+ val query = input.toDF().writeStream.format("console").option("NUMROWS", 2).start()
+ try {
+ input.addData(1, 2, 3)
+ query.processAllAvailable()
+ } finally {
+ query.stop()
+ }
+ }
+
+ assert(captured.toString() ==
+ """-------------------------------------------
+ |Batch: 0
+ |-------------------------------------------
+ |+-----+
+ ||value|
+ |+-----+
+ || 1|
+ || 2|
+ |+-----+
+ |only showing top 2 rows
+ |
+ |""".stripMargin)
+ }
+
+ test("console with truncation") {
+ val input = MemoryStream[String]
+
+ val captured = new ByteArrayOutputStream()
+ Console.withOut(captured) {
+ val query = input.toDF().writeStream.format("console").option("TRUNCATE", true).start()
+ try {
+ input.addData("123456789012345678901234567890")
+ query.processAllAvailable()
+ } finally {
+ query.stop()
+ }
+ }
+
+ assert(captured.toString() ==
+ """-------------------------------------------
+ |Batch: 0
+ |-------------------------------------------
+ |+--------------------+
+ || value|
+ |+--------------------+
+ ||12345678901234567...|
+ |+--------------------+
+ |
+ |""".stripMargin)
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/1c76a91e/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala
new file mode 100644
index 0000000..f152174
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala
@@ -0,0 +1,249 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming.sources
+
+import java.util.Optional
+
+import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.execution.datasources.DataSource
+import org.apache.spark.sql.execution.streaming.{LongOffset, RateStreamOffset}
+import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
+import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options}
+import org.apache.spark.sql.sources.v2.reader.ReadTask
+import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport, MicroBatchReadSupport, MicroBatchWriteSupport}
+import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousReader, MicroBatchReader, Offset, PartitionOffset}
+import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter
+import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer
+import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
+
+case class FakeReader() extends MicroBatchReader with ContinuousReader {
+ def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = {}
+ def getStartOffset: Offset = RateStreamOffset(Map())
+ def getEndOffset: Offset = RateStreamOffset(Map())
+ def deserializeOffset(json: String): Offset = RateStreamOffset(Map())
+ def commit(end: Offset): Unit = {}
+ def readSchema(): StructType = StructType(Seq())
+ def stop(): Unit = {}
+ def mergeOffsets(offsets: Array[PartitionOffset]): Offset = RateStreamOffset(Map())
+ def setOffset(start: Optional[Offset]): Unit = {}
+
+ def createReadTasks(): java.util.ArrayList[ReadTask[Row]] = {
+ throw new IllegalStateException("fake source - cannot actually read")
+ }
+}
+
+trait FakeMicroBatchReadSupport extends MicroBatchReadSupport {
+ override def createMicroBatchReader(
+ schema: Optional[StructType],
+ checkpointLocation: String,
+ options: DataSourceV2Options): MicroBatchReader = FakeReader()
+}
+
+trait FakeContinuousReadSupport extends ContinuousReadSupport {
+ override def createContinuousReader(
+ schema: Optional[StructType],
+ checkpointLocation: String,
+ options: DataSourceV2Options): ContinuousReader = FakeReader()
+}
+
+trait FakeMicroBatchWriteSupport extends MicroBatchWriteSupport {
+ def createMicroBatchWriter(
+ queryId: String,
+ epochId: Long,
+ schema: StructType,
+ mode: OutputMode,
+ options: DataSourceV2Options): Optional[DataSourceV2Writer] = {
+ throw new IllegalStateException("fake sink - cannot actually write")
+ }
+}
+
+trait FakeContinuousWriteSupport extends ContinuousWriteSupport {
+ def createContinuousWriter(
+ queryId: String,
+ schema: StructType,
+ mode: OutputMode,
+ options: DataSourceV2Options): Optional[ContinuousWriter] = {
+ throw new IllegalStateException("fake sink - cannot actually write")
+ }
+}
+
+class FakeReadMicroBatchOnly extends DataSourceRegister with FakeMicroBatchReadSupport {
+ override def shortName(): String = "fake-read-microbatch-only"
+}
+
+class FakeReadContinuousOnly extends DataSourceRegister with FakeContinuousReadSupport {
+ override def shortName(): String = "fake-read-continuous-only"
+}
+
+class FakeReadBothModes extends DataSourceRegister
+ with FakeMicroBatchReadSupport with FakeContinuousReadSupport {
+ override def shortName(): String = "fake-read-microbatch-continuous"
+}
+
+class FakeReadNeitherMode extends DataSourceRegister {
+ override def shortName(): String = "fake-read-neither-mode"
+}
+
+class FakeWriteMicroBatchOnly extends DataSourceRegister with FakeMicroBatchWriteSupport {
+ override def shortName(): String = "fake-write-microbatch-only"
+}
+
+class FakeWriteContinuousOnly extends DataSourceRegister with FakeContinuousWriteSupport {
+ override def shortName(): String = "fake-write-continuous-only"
+}
+
+class FakeWriteBothModes extends DataSourceRegister
+ with FakeMicroBatchWriteSupport with FakeContinuousWriteSupport {
+ override def shortName(): String = "fake-write-microbatch-continuous"
+}
+
+class FakeWriteNeitherMode extends DataSourceRegister {
+ override def shortName(): String = "fake-write-neither-mode"
+}
+
+class StreamingDataSourceV2Suite extends StreamTest {
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ val fakeCheckpoint = Utils.createTempDir()
+ spark.conf.set("spark.sql.streaming.checkpointLocation", fakeCheckpoint.getCanonicalPath)
+ }
+
+ val readFormats = Seq(
+ "fake-read-microbatch-only",
+ "fake-read-continuous-only",
+ "fake-read-microbatch-continuous",
+ "fake-read-neither-mode")
+ val writeFormats = Seq(
+ "fake-write-microbatch-only",
+ "fake-write-continuous-only",
+ "fake-write-microbatch-continuous",
+ "fake-write-neither-mode")
+ val triggers = Seq(
+ Trigger.Once(),
+ Trigger.ProcessingTime(1000),
+ Trigger.Continuous(1000))
+
+ private def testPositiveCase(readFormat: String, writeFormat: String, trigger: Trigger) = {
+ val query = spark.readStream
+ .format(readFormat)
+ .load()
+ .writeStream
+ .format(writeFormat)
+ .trigger(trigger)
+ .start()
+ query.stop()
+ }
+
+ private def testNegativeCase(
+ readFormat: String,
+ writeFormat: String,
+ trigger: Trigger,
+ errorMsg: String) = {
+ val ex = intercept[UnsupportedOperationException] {
+ testPositiveCase(readFormat, writeFormat, trigger)
+ }
+ assert(ex.getMessage.contains(errorMsg))
+ }
+
+ private def testPostCreationNegativeCase(
+ readFormat: String,
+ writeFormat: String,
+ trigger: Trigger,
+ errorMsg: String) = {
+ val query = spark.readStream
+ .format(readFormat)
+ .load()
+ .writeStream
+ .format(writeFormat)
+ .trigger(trigger)
+ .start()
+
+ eventually(timeout(streamingTimeout)) {
+ assert(query.exception.isDefined)
+ assert(query.exception.get.cause != null)
+ assert(query.exception.get.cause.getMessage.contains(errorMsg))
+ }
+ }
+
+ // Get a list of (read, write, trigger) tuples for test cases.
+ val cases = readFormats.flatMap { read =>
+ writeFormats.flatMap { write =>
+ triggers.map(t => (write, t))
+ }.map {
+ case (write, t) => (read, write, t)
+ }
+ }
+
+ for ((read, write, trigger) <- cases) {
+ testQuietly(s"stream with read format $read, write format $write, trigger $trigger") {
+ val readSource = DataSource.lookupDataSource(read, spark.sqlContext.conf).newInstance()
+ val writeSource = DataSource.lookupDataSource(write, spark.sqlContext.conf).newInstance()
+ (readSource, writeSource, trigger) match {
+ // Valid microbatch queries.
+ case (_: MicroBatchReadSupport, _: MicroBatchWriteSupport, t)
+ if !t.isInstanceOf[ContinuousTrigger] =>
+ testPositiveCase(read, write, trigger)
+
+ // Valid continuous queries.
+ case (_: ContinuousReadSupport, _: ContinuousWriteSupport, _: ContinuousTrigger) =>
+ testPositiveCase(read, write, trigger)
+
+ // Invalid - can't read at all
+ case (r, _, _)
+ if !r.isInstanceOf[MicroBatchReadSupport]
+ && !r.isInstanceOf[ContinuousReadSupport] =>
+ testNegativeCase(read, write, trigger,
+ s"Data source $read does not support streamed reading")
+
+ // Invalid - trigger is continuous but writer is not
+ case (_, w, _: ContinuousTrigger) if !w.isInstanceOf[ContinuousWriteSupport] =>
+ testNegativeCase(read, write, trigger,
+ s"Data source $write does not support continuous writing")
+
+ // Invalid - can't write at all
+ case (_, w, _)
+ if !w.isInstanceOf[MicroBatchWriteSupport]
+ && !w.isInstanceOf[ContinuousWriteSupport] =>
+ testNegativeCase(read, write, trigger,
+ s"Data source $write does not support streamed writing")
+
+ // Invalid - trigger and writer are continuous but reader is not
+ case (r, _: ContinuousWriteSupport, _: ContinuousTrigger)
+ if !r.isInstanceOf[ContinuousReadSupport] =>
+ testNegativeCase(read, write, trigger,
+ s"Data source $read does not support continuous processing")
+
+ // Invalid - trigger is microbatch but writer is not
+ case (_, w, t)
+ if !w.isInstanceOf[MicroBatchWriteSupport] && !t.isInstanceOf[ContinuousTrigger] =>
+ testNegativeCase(read, write, trigger,
+ s"Data source $write does not support streamed writing")
+
+ // Invalid - trigger and writer are microbatch but reader is not
+ case (r, _, t)
+ if !r.isInstanceOf[MicroBatchReadSupport] && !t.isInstanceOf[ContinuousTrigger] =>
+ testPostCreationNegativeCase(read, write, trigger,
+ s"Data source $read does not support microbatch processing")
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/1c76a91e/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
index aa163d2..8212fb9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
@@ -422,21 +422,6 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter {
}
}
- test("ConsoleSink can be correctly loaded") {
- LastOptions.clear()
- val df = spark.readStream
- .format("org.apache.spark.sql.streaming.test")
- .load()
-
- val sq = df.writeStream
- .format("console")
- .option("checkpointLocation", newMetadataDir)
- .trigger(ProcessingTime(2.seconds))
- .start()
-
- sq.awaitTermination(2000L)
- }
-
test("prevent all column partitioning") {
withTempDir { dir =>
val path = dir.getCanonicalPath
@@ -450,16 +435,6 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter {
}
}
- test("ConsoleSink should not require checkpointLocation") {
- LastOptions.clear()
- val df = spark.readStream
- .format("org.apache.spark.sql.streaming.test")
- .load()
-
- val sq = df.writeStream.format("console").start()
- sq.stop()
- }
-
private def testMemorySinkCheckpointRecovery(chkLoc: String, provideInWriter: Boolean): Unit = {
import testImplicits._
val ms = new MemoryStream[Int](0, sqlContext)