Posted to commits@spark.apache.org by li...@apache.org on 2018/09/12 18:25:30 UTC
[2/7] spark git commit: [SPARK-24882][SQL] Revert [] improve data source v2 API from branch 2.4
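
For context, this revert restores the branch-2.4 reader API in which a source implements ReadSupport.createReader, the returned DataSourceReader reports a schema and plans a java.util.List of InputPartition[InternalRow], and each partition creates its own InputPartitionReader. The sketch below is not part of the commit; it uses hypothetical class names and is modeled on the SimpleDataSourceV2 test helper added later in this diff, only to summarize the restored shape in one place.

    import java.util.{Arrays => JArrays, List => JList}

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport}
    import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader}
    import org.apache.spark.sql.types.StructType

    // Hypothetical example source, mirroring SimpleDataSourceV2 in the added tests below.
    class SketchDataSourceV2 extends DataSourceV2 with ReadSupport {
      class Reader extends DataSourceReader {
        // The reader exposes its schema and plans the partitions to scan.
        override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")
        override def planInputPartitions(): JList[InputPartition[InternalRow]] =
          JArrays.asList(new RangePartition(0, 5), new RangePartition(5, 10))
      }
      override def createReader(options: DataSourceOptions): DataSourceReader = new Reader
    }

    // Each InputPartition creates its own InputPartitionReader; here one class plays both
    // roles, as the test helpers in this diff do.
    class RangePartition(start: Int, end: Int)
      extends InputPartition[InternalRow] with InputPartitionReader[InternalRow] {
      private var current = start - 1
      override def createPartitionReader(): InputPartitionReader[InternalRow] =
        new RangePartition(start, end)
      override def next(): Boolean = { current += 1; current < end }
      override def get(): InternalRow = InternalRow(current, -current)
      override def close(): Unit = {}
    }

The test changes that follow rewrite the streaming and batch suites against these interfaces (MicroBatchReader, ContinuousReader, DataSourceWriter) in place of the reverted ReadSupport/ScanConfig/PartitionReaderFactory API.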
http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala
deleted file mode 100644
index 5884380..0000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala
+++ /dev/null
@@ -1,151 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming.sources
-
-import java.io.ByteArrayOutputStream
-
-import org.apache.spark.sql.execution.streaming.MemoryStream
-import org.apache.spark.sql.streaming.{StreamTest, Trigger}
-
-class ConsoleWriteSupportSuite extends StreamTest {
- import testImplicits._
-
- test("microbatch - default") {
- val input = MemoryStream[Int]
-
- val captured = new ByteArrayOutputStream()
- Console.withOut(captured) {
- val query = input.toDF().writeStream.format("console").start()
- try {
- input.addData(1, 2, 3)
- query.processAllAvailable()
- input.addData(4, 5, 6)
- query.processAllAvailable()
- input.addData()
- query.processAllAvailable()
- } finally {
- query.stop()
- }
- }
-
- assert(captured.toString() ==
- """-------------------------------------------
- |Batch: 0
- |-------------------------------------------
- |+-----+
- ||value|
- |+-----+
- || 1|
- || 2|
- || 3|
- |+-----+
- |
- |-------------------------------------------
- |Batch: 1
- |-------------------------------------------
- |+-----+
- ||value|
- |+-----+
- || 4|
- || 5|
- || 6|
- |+-----+
- |
- |-------------------------------------------
- |Batch: 2
- |-------------------------------------------
- |+-----+
- ||value|
- |+-----+
- |+-----+
- |
- |""".stripMargin)
- }
-
- test("microbatch - with numRows") {
- val input = MemoryStream[Int]
-
- val captured = new ByteArrayOutputStream()
- Console.withOut(captured) {
- val query = input.toDF().writeStream.format("console").option("NUMROWS", 2).start()
- try {
- input.addData(1, 2, 3)
- query.processAllAvailable()
- } finally {
- query.stop()
- }
- }
-
- assert(captured.toString() ==
- """-------------------------------------------
- |Batch: 0
- |-------------------------------------------
- |+-----+
- ||value|
- |+-----+
- || 1|
- || 2|
- |+-----+
- |only showing top 2 rows
- |
- |""".stripMargin)
- }
-
- test("microbatch - truncation") {
- val input = MemoryStream[String]
-
- val captured = new ByteArrayOutputStream()
- Console.withOut(captured) {
- val query = input.toDF().writeStream.format("console").option("TRUNCATE", true).start()
- try {
- input.addData("123456789012345678901234567890")
- query.processAllAvailable()
- } finally {
- query.stop()
- }
- }
-
- assert(captured.toString() ==
- """-------------------------------------------
- |Batch: 0
- |-------------------------------------------
- |+--------------------+
- || value|
- |+--------------------+
- ||12345678901234567...|
- |+--------------------+
- |
- |""".stripMargin)
- }
-
- test("continuous - default") {
- val captured = new ByteArrayOutputStream()
- Console.withOut(captured) {
- val input = spark.readStream
- .format("rate")
- .option("numPartitions", "1")
- .option("rowsPerSecond", "5")
- .load()
- .select('value)
-
- val query = input.writeStream.format("console").trigger(Trigger.Continuous(200)).start()
- assert(query.isActive)
- query.stop()
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala
new file mode 100644
index 0000000..55acf2b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import java.io.ByteArrayOutputStream
+
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.streaming.{StreamTest, Trigger}
+
+class ConsoleWriterSuite extends StreamTest {
+ import testImplicits._
+
+ test("microbatch - default") {
+ val input = MemoryStream[Int]
+
+ val captured = new ByteArrayOutputStream()
+ Console.withOut(captured) {
+ val query = input.toDF().writeStream.format("console").start()
+ try {
+ input.addData(1, 2, 3)
+ query.processAllAvailable()
+ input.addData(4, 5, 6)
+ query.processAllAvailable()
+ input.addData()
+ query.processAllAvailable()
+ } finally {
+ query.stop()
+ }
+ }
+
+ assert(captured.toString() ==
+ """-------------------------------------------
+ |Batch: 0
+ |-------------------------------------------
+ |+-----+
+ ||value|
+ |+-----+
+ || 1|
+ || 2|
+ || 3|
+ |+-----+
+ |
+ |-------------------------------------------
+ |Batch: 1
+ |-------------------------------------------
+ |+-----+
+ ||value|
+ |+-----+
+ || 4|
+ || 5|
+ || 6|
+ |+-----+
+ |
+ |-------------------------------------------
+ |Batch: 2
+ |-------------------------------------------
+ |+-----+
+ ||value|
+ |+-----+
+ |+-----+
+ |
+ |""".stripMargin)
+ }
+
+ test("microbatch - with numRows") {
+ val input = MemoryStream[Int]
+
+ val captured = new ByteArrayOutputStream()
+ Console.withOut(captured) {
+ val query = input.toDF().writeStream.format("console").option("NUMROWS", 2).start()
+ try {
+ input.addData(1, 2, 3)
+ query.processAllAvailable()
+ } finally {
+ query.stop()
+ }
+ }
+
+ assert(captured.toString() ==
+ """-------------------------------------------
+ |Batch: 0
+ |-------------------------------------------
+ |+-----+
+ ||value|
+ |+-----+
+ || 1|
+ || 2|
+ |+-----+
+ |only showing top 2 rows
+ |
+ |""".stripMargin)
+ }
+
+ test("microbatch - truncation") {
+ val input = MemoryStream[String]
+
+ val captured = new ByteArrayOutputStream()
+ Console.withOut(captured) {
+ val query = input.toDF().writeStream.format("console").option("TRUNCATE", true).start()
+ try {
+ input.addData("123456789012345678901234567890")
+ query.processAllAvailable()
+ } finally {
+ query.stop()
+ }
+ }
+
+ assert(captured.toString() ==
+ """-------------------------------------------
+ |Batch: 0
+ |-------------------------------------------
+ |+--------------------+
+ || value|
+ |+--------------------+
+ ||12345678901234567...|
+ |+--------------------+
+ |
+ |""".stripMargin)
+ }
+
+ test("continuous - default") {
+ val captured = new ByteArrayOutputStream()
+ Console.withOut(captured) {
+ val input = spark.readStream
+ .format("rate")
+ .option("numPartitions", "1")
+ .option("rowsPerSecond", "5")
+ .load()
+ .select('value)
+
+ val query = input.writeStream.format("console").trigger(Trigger.Continuous(200)).start()
+ assert(query.isActive)
+ query.stop()
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala
index dd74af8..5ca13b8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala
@@ -17,18 +17,20 @@
package org.apache.spark.sql.execution.streaming.sources
+import java.util.Optional
import java.util.concurrent.TimeUnit
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous._
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, MicroBatchReadSupportProvider}
+import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport}
import org.apache.spark.sql.sources.v2.reader.streaming.Offset
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.util.ManualClock
@@ -41,7 +43,7 @@ class RateSourceSuite extends StreamTest {
override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {
assert(query.nonEmpty)
val rateSource = query.get.logicalPlan.collect {
- case StreamingExecutionRelation(source: RateStreamMicroBatchReadSupport, _) => source
+ case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source
}.head
rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds))
@@ -54,10 +56,10 @@ class RateSourceSuite extends StreamTest {
test("microbatch in registry") {
withTempDir { temp =>
DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match {
- case ds: MicroBatchReadSupportProvider =>
- val readSupport = ds.createMicroBatchReadSupport(
- temp.getCanonicalPath, DataSourceOptions.empty())
- assert(readSupport.isInstanceOf[RateStreamMicroBatchReadSupport])
+ case ds: MicroBatchReadSupport =>
+ val reader = ds.createMicroBatchReader(
+ Optional.empty(), temp.getCanonicalPath, DataSourceOptions.empty())
+ assert(reader.isInstanceOf[RateStreamMicroBatchReader])
case _ =>
throw new IllegalStateException("Could not find read support for rate")
}
@@ -67,7 +69,7 @@ class RateSourceSuite extends StreamTest {
test("compatible with old path in registry") {
DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider",
spark.sqlContext.conf).newInstance() match {
- case ds: MicroBatchReadSupportProvider =>
+ case ds: MicroBatchReadSupport =>
assert(ds.isInstanceOf[RateStreamProvider])
case _ =>
throw new IllegalStateException("Could not find read support for rate")
@@ -139,19 +141,30 @@ class RateSourceSuite extends StreamTest {
)
}
+ test("microbatch - set offset") {
+ withTempDir { temp =>
+ val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty(), temp.getCanonicalPath)
+ val startOffset = LongOffset(0L)
+ val endOffset = LongOffset(1L)
+ reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
+ assert(reader.getStartOffset() == startOffset)
+ assert(reader.getEndOffset() == endOffset)
+ }
+ }
+
test("microbatch - infer offsets") {
withTempDir { temp =>
- val readSupport = new RateStreamMicroBatchReadSupport(
+ val reader = new RateStreamMicroBatchReader(
new DataSourceOptions(
Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava),
temp.getCanonicalPath)
- readSupport.clock.asInstanceOf[ManualClock].advance(100000)
- val startOffset = readSupport.initialOffset()
- startOffset match {
+ reader.clock.asInstanceOf[ManualClock].advance(100000)
+ reader.setOffsetRange(Optional.empty(), Optional.empty())
+ reader.getStartOffset() match {
case r: LongOffset => assert(r.offset === 0L)
case _ => throw new IllegalStateException("unexpected offset type")
}
- readSupport.latestOffset() match {
+ reader.getEndOffset() match {
case r: LongOffset => assert(r.offset >= 100)
case _ => throw new IllegalStateException("unexpected offset type")
}
@@ -160,16 +173,15 @@ class RateSourceSuite extends StreamTest {
test("microbatch - predetermined batch size") {
withTempDir { temp =>
- val readSupport = new RateStreamMicroBatchReadSupport(
+ val reader = new RateStreamMicroBatchReader(
new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava),
temp.getCanonicalPath)
val startOffset = LongOffset(0L)
val endOffset = LongOffset(1L)
- val config = readSupport.newScanConfigBuilder(startOffset, endOffset).build()
- val tasks = readSupport.planInputPartitions(config)
- val readerFactory = readSupport.createReaderFactory(config)
+ reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
+ val tasks = reader.planInputPartitions()
assert(tasks.size == 1)
- val dataReader = readerFactory.createReader(tasks(0))
+ val dataReader = tasks.get(0).createPartitionReader()
val data = ArrayBuffer[InternalRow]()
while (dataReader.next()) {
data.append(dataReader.get())
@@ -180,25 +192,24 @@ class RateSourceSuite extends StreamTest {
test("microbatch - data read") {
withTempDir { temp =>
- val readSupport = new RateStreamMicroBatchReadSupport(
+ val reader = new RateStreamMicroBatchReader(
new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava),
temp.getCanonicalPath)
val startOffset = LongOffset(0L)
val endOffset = LongOffset(1L)
- val config = readSupport.newScanConfigBuilder(startOffset, endOffset).build()
- val tasks = readSupport.planInputPartitions(config)
- val readerFactory = readSupport.createReaderFactory(config)
+ reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
+ val tasks = reader.planInputPartitions()
assert(tasks.size == 11)
- val readData = tasks
- .map(readerFactory.createReader)
+ val readData = tasks.asScala
+ .map(_.createPartitionReader())
.flatMap { reader =>
val buf = scala.collection.mutable.ListBuffer[InternalRow]()
while (reader.next()) buf.append(reader.get())
buf
}
- assert(readData.map(_.getLong(1)).sorted === 0.until(33).toArray)
+ assert(readData.map(_.getLong(1)).sorted == Range(0, 33))
}
}
@@ -309,44 +320,41 @@ class RateSourceSuite extends StreamTest {
}
test("user-specified schema given") {
- val exception = intercept[UnsupportedOperationException] {
+ val exception = intercept[AnalysisException] {
spark.readStream
.format("rate")
.schema(spark.range(1).schema)
.load()
}
assert(exception.getMessage.contains(
- "rate source does not support user-specified schema"))
+ "rate source does not support a user-specified schema"))
}
test("continuous in registry") {
DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match {
- case ds: ContinuousReadSupportProvider =>
- val readSupport = ds.createContinuousReadSupport(
- "", DataSourceOptions.empty())
- assert(readSupport.isInstanceOf[RateStreamContinuousReadSupport])
+ case ds: ContinuousReadSupport =>
+ val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty())
+ assert(reader.isInstanceOf[RateStreamContinuousReader])
case _ =>
throw new IllegalStateException("Could not find read support for continuous rate")
}
}
test("continuous data") {
- val readSupport = new RateStreamContinuousReadSupport(
+ val reader = new RateStreamContinuousReader(
new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava))
- val config = readSupport.newScanConfigBuilder(readSupport.initialOffset).build()
- val tasks = readSupport.planInputPartitions(config)
- val readerFactory = readSupport.createContinuousReaderFactory(config)
+ reader.setStartOffset(Optional.empty())
+ val tasks = reader.planInputPartitions()
assert(tasks.size == 2)
val data = scala.collection.mutable.ListBuffer[InternalRow]()
- tasks.foreach {
+ tasks.asScala.foreach {
case t: RateStreamContinuousInputPartition =>
- val startTimeMs = readSupport.initialOffset()
+ val startTimeMs = reader.getStartOffset()
.asInstanceOf[RateStreamOffset]
.partitionToValueAndRunTimeMs(t.partitionIndex)
.runTimeMs
- val r = readerFactory.createReader(t)
- .asInstanceOf[RateStreamContinuousPartitionReader]
+ val r = t.createPartitionReader().asInstanceOf[RateStreamContinuousInputPartitionReader]
for (rowIndex <- 0 to 9) {
r.next()
data.append(r.get())
http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala
index 409156e..48e5cf7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala
@@ -21,6 +21,7 @@ import java.net.{InetSocketAddress, SocketException}
import java.nio.ByteBuffer
import java.nio.channels.ServerSocketChannel
import java.sql.Timestamp
+import java.util.Optional
import java.util.concurrent.LinkedBlockingQueue
import scala.collection.JavaConverters._
@@ -33,8 +34,8 @@ import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous._
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupportProvider}
-import org.apache.spark.sql.sources.v2.reader.streaming.Offset
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport}
+import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@@ -48,9 +49,14 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
serverThread.join()
serverThread = null
}
+ if (batchReader != null) {
+ batchReader.stop()
+ batchReader = null
+ }
}
private var serverThread: ServerThread = null
+ private var batchReader: MicroBatchReader = null
case class AddSocketData(data: String*) extends AddData {
override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {
@@ -59,7 +65,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
"Cannot add data when there is no query for finding the active socket source")
val sources = query.get.logicalPlan.collect {
- case StreamingExecutionRelation(source: TextSocketMicroBatchReadSupport, _) => source
+ case StreamingExecutionRelation(source: TextSocketMicroBatchReader, _) => source
}
if (sources.isEmpty) {
throw new Exception(
@@ -85,7 +91,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
test("backward compatibility with old path") {
DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.TextSocketSourceProvider",
spark.sqlContext.conf).newInstance() match {
- case ds: MicroBatchReadSupportProvider =>
+ case ds: MicroBatchReadSupport =>
assert(ds.isInstanceOf[TextSocketSourceProvider])
case _ =>
throw new IllegalStateException("Could not find socket source")
@@ -175,16 +181,16 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
test("params not given") {
val provider = new TextSocketSourceProvider
intercept[AnalysisException] {
- provider.createMicroBatchReadSupport(
- "", new DataSourceOptions(Map.empty[String, String].asJava))
+ provider.createMicroBatchReader(Optional.empty(), "",
+ new DataSourceOptions(Map.empty[String, String].asJava))
}
intercept[AnalysisException] {
- provider.createMicroBatchReadSupport(
- "", new DataSourceOptions(Map("host" -> "localhost").asJava))
+ provider.createMicroBatchReader(Optional.empty(), "",
+ new DataSourceOptions(Map("host" -> "localhost").asJava))
}
intercept[AnalysisException] {
- provider.createMicroBatchReadSupport(
- "", new DataSourceOptions(Map("port" -> "1234").asJava))
+ provider.createMicroBatchReader(Optional.empty(), "",
+ new DataSourceOptions(Map("port" -> "1234").asJava))
}
}
@@ -193,7 +199,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
val params = Map("host" -> "localhost", "port" -> "1234", "includeTimestamp" -> "fasle")
intercept[AnalysisException] {
val a = new DataSourceOptions(params.asJava)
- provider.createMicroBatchReadSupport("", a)
+ provider.createMicroBatchReader(Optional.empty(), "", a)
}
}
@@ -203,12 +209,12 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
StructField("name", StringType) ::
StructField("area", StringType) :: Nil)
val params = Map("host" -> "localhost", "port" -> "1234")
- val exception = intercept[UnsupportedOperationException] {
- provider.createMicroBatchReadSupport(
- userSpecifiedSchema, "", new DataSourceOptions(params.asJava))
+ val exception = intercept[AnalysisException] {
+ provider.createMicroBatchReader(
+ Optional.of(userSpecifiedSchema), "", new DataSourceOptions(params.asJava))
}
assert(exception.getMessage.contains(
- "socket source does not support user-specified schema"))
+ "socket source does not support a user-specified schema"))
}
test("input row metrics") {
@@ -299,27 +305,25 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
serverThread = new ServerThread()
serverThread.start()
- val readSupport = new TextSocketContinuousReadSupport(
+ val reader = new TextSocketContinuousReader(
new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost",
"port" -> serverThread.port.toString).asJava))
-
- val scanConfig = readSupport.newScanConfigBuilder(readSupport.initialOffset()).build()
- val tasks = readSupport.planInputPartitions(scanConfig)
+ reader.setStartOffset(Optional.empty())
+ val tasks = reader.planInputPartitions()
assert(tasks.size == 2)
val numRecords = 10
val data = scala.collection.mutable.ListBuffer[Int]()
val offsets = scala.collection.mutable.ListBuffer[Int]()
- val readerFactory = readSupport.createContinuousReaderFactory(scanConfig)
import org.scalatest.time.SpanSugar._
failAfter(5 seconds) {
// inject rows, read and check the data and offsets
for (i <- 0 until numRecords) {
serverThread.enqueue(i.toString)
}
- tasks.foreach {
+ tasks.asScala.foreach {
case t: TextSocketContinuousInputPartition =>
- val r = readerFactory.createReader(t).asInstanceOf[TextSocketContinuousPartitionReader]
+ val r = t.createPartitionReader().asInstanceOf[TextSocketContinuousInputPartitionReader]
for (i <- 0 until numRecords / 2) {
r.next()
offsets.append(r.getOffset().asInstanceOf[ContinuousRecordPartitionOffset].offset)
@@ -335,15 +339,16 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
data.clear()
case _ => throw new IllegalStateException("Unexpected task type")
}
- assert(readSupport.startOffset.offsets == List(3, 3))
- readSupport.commit(TextSocketOffset(List(5, 5)))
- assert(readSupport.startOffset.offsets == List(5, 5))
+ assert(reader.getStartOffset.asInstanceOf[TextSocketOffset].offsets == List(3, 3))
+ reader.commit(TextSocketOffset(List(5, 5)))
+ assert(reader.getStartOffset.asInstanceOf[TextSocketOffset].offsets == List(5, 5))
}
def commitOffset(partition: Int, offset: Int): Unit = {
- val offsetsToCommit = readSupport.startOffset.offsets.updated(partition, offset)
- readSupport.commit(TextSocketOffset(offsetsToCommit))
- assert(readSupport.startOffset.offsets == offsetsToCommit)
+ val offsetsToCommit = reader.getStartOffset.asInstanceOf[TextSocketOffset]
+ .offsets.updated(partition, offset)
+ reader.commit(TextSocketOffset(offsetsToCommit))
+ assert(reader.getStartOffset.asInstanceOf[TextSocketOffset].offsets == offsetsToCommit)
}
}
@@ -351,13 +356,14 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
serverThread = new ServerThread()
serverThread.start()
- val readSupport = new TextSocketContinuousReadSupport(
+ val reader = new TextSocketContinuousReader(
new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost",
"port" -> serverThread.port.toString).asJava))
-
- readSupport.startOffset = TextSocketOffset(List(5, 5))
+ reader.setStartOffset(Optional.of(TextSocketOffset(List(5, 5))))
+ // ok to commit same offset
+ reader.setStartOffset(Optional.of(TextSocketOffset(List(5, 5))))
assertThrows[IllegalStateException] {
- readSupport.commit(TextSocketOffset(List(6, 6)))
+ reader.commit(TextSocketOffset(List(6, 6)))
}
}
@@ -365,12 +371,12 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
serverThread = new ServerThread()
serverThread.start()
- val readSupport = new TextSocketContinuousReadSupport(
+ val reader = new TextSocketContinuousReader(
new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost",
"includeTimestamp" -> "true",
"port" -> serverThread.port.toString).asJava))
- val scanConfig = readSupport.newScanConfigBuilder(readSupport.initialOffset()).build()
- val tasks = readSupport.planInputPartitions(scanConfig)
+ reader.setStartOffset(Optional.empty())
+ val tasks = reader.planInputPartitions()
assert(tasks.size == 2)
val numRecords = 4
@@ -378,10 +384,9 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
for (i <- 0 until numRecords) {
serverThread.enqueue(i.toString)
}
- val readerFactory = readSupport.createContinuousReaderFactory(scanConfig)
- tasks.foreach {
+ tasks.asScala.foreach {
case t: TextSocketContinuousInputPartition =>
- val r = readerFactory.createReader(t).asInstanceOf[TextSocketContinuousPartitionReader]
+ val r = t.createPartitionReader().asInstanceOf[TextSocketContinuousInputPartitionReader]
for (i <- 0 until numRecords / 2) {
r.next()
assert(r.get().get(0, TextSocketReader.SCHEMA_TIMESTAMP)
http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
index f6c3e0c..12beca2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.sources.v2
+import java.util.{ArrayList, List => JList}
+
import test.org.apache.spark.sql.sources.v2._
import org.apache.spark.SparkException
@@ -36,21 +38,6 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
class DataSourceV2Suite extends QueryTest with SharedSQLContext {
import testImplicits._
- private def getScanConfig(query: DataFrame): AdvancedScanConfigBuilder = {
- query.queryExecution.executedPlan.collect {
- case d: DataSourceV2ScanExec =>
- d.scanConfig.asInstanceOf[AdvancedScanConfigBuilder]
- }.head
- }
-
- private def getJavaScanConfig(
- query: DataFrame): JavaAdvancedDataSourceV2.AdvancedScanConfigBuilder = {
- query.queryExecution.executedPlan.collect {
- case d: DataSourceV2ScanExec =>
- d.scanConfig.asInstanceOf[JavaAdvancedDataSourceV2.AdvancedScanConfigBuilder]
- }.head
- }
-
test("simplest implementation") {
Seq(classOf[SimpleDataSourceV2], classOf[JavaSimpleDataSourceV2]).foreach { cls =>
withClue(cls.getName) {
@@ -63,6 +50,18 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
}
test("advanced implementation") {
+ def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = {
+ query.queryExecution.executedPlan.collect {
+ case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader]
+ }.head
+ }
+
+ def getJavaReader(query: DataFrame): JavaAdvancedDataSourceV2#Reader = {
+ query.queryExecution.executedPlan.collect {
+ case d: DataSourceV2ScanExec => d.reader.asInstanceOf[JavaAdvancedDataSourceV2#Reader]
+ }.head
+ }
+
Seq(classOf[AdvancedDataSourceV2], classOf[JavaAdvancedDataSourceV2]).foreach { cls =>
withClue(cls.getName) {
val df = spark.read.format(cls.getName).load()
@@ -71,58 +70,58 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
val q1 = df.select('j)
checkAnswer(q1, (0 until 10).map(i => Row(-i)))
if (cls == classOf[AdvancedDataSourceV2]) {
- val config = getScanConfig(q1)
- assert(config.filters.isEmpty)
- assert(config.requiredSchema.fieldNames === Seq("j"))
+ val reader = getReader(q1)
+ assert(reader.filters.isEmpty)
+ assert(reader.requiredSchema.fieldNames === Seq("j"))
} else {
- val config = getJavaScanConfig(q1)
- assert(config.filters.isEmpty)
- assert(config.requiredSchema.fieldNames === Seq("j"))
+ val reader = getJavaReader(q1)
+ assert(reader.filters.isEmpty)
+ assert(reader.requiredSchema.fieldNames === Seq("j"))
}
val q2 = df.filter('i > 3)
checkAnswer(q2, (4 until 10).map(i => Row(i, -i)))
if (cls == classOf[AdvancedDataSourceV2]) {
- val config = getScanConfig(q2)
- assert(config.filters.flatMap(_.references).toSet == Set("i"))
- assert(config.requiredSchema.fieldNames === Seq("i", "j"))
+ val reader = getReader(q2)
+ assert(reader.filters.flatMap(_.references).toSet == Set("i"))
+ assert(reader.requiredSchema.fieldNames === Seq("i", "j"))
} else {
- val config = getJavaScanConfig(q2)
- assert(config.filters.flatMap(_.references).toSet == Set("i"))
- assert(config.requiredSchema.fieldNames === Seq("i", "j"))
+ val reader = getJavaReader(q2)
+ assert(reader.filters.flatMap(_.references).toSet == Set("i"))
+ assert(reader.requiredSchema.fieldNames === Seq("i", "j"))
}
val q3 = df.select('i).filter('i > 6)
checkAnswer(q3, (7 until 10).map(i => Row(i)))
if (cls == classOf[AdvancedDataSourceV2]) {
- val config = getScanConfig(q3)
- assert(config.filters.flatMap(_.references).toSet == Set("i"))
- assert(config.requiredSchema.fieldNames === Seq("i"))
+ val reader = getReader(q3)
+ assert(reader.filters.flatMap(_.references).toSet == Set("i"))
+ assert(reader.requiredSchema.fieldNames === Seq("i"))
} else {
- val config = getJavaScanConfig(q3)
- assert(config.filters.flatMap(_.references).toSet == Set("i"))
- assert(config.requiredSchema.fieldNames === Seq("i"))
+ val reader = getJavaReader(q3)
+ assert(reader.filters.flatMap(_.references).toSet == Set("i"))
+ assert(reader.requiredSchema.fieldNames === Seq("i"))
}
val q4 = df.select('j).filter('j < -10)
checkAnswer(q4, Nil)
if (cls == classOf[AdvancedDataSourceV2]) {
- val config = getScanConfig(q4)
+ val reader = getReader(q4)
// 'j < 10 is not supported by the testing data source.
- assert(config.filters.isEmpty)
- assert(config.requiredSchema.fieldNames === Seq("j"))
+ assert(reader.filters.isEmpty)
+ assert(reader.requiredSchema.fieldNames === Seq("j"))
} else {
- val config = getJavaScanConfig(q4)
+ val reader = getJavaReader(q4)
// 'j < 10 is not supported by the testing data source.
- assert(config.filters.isEmpty)
- assert(config.requiredSchema.fieldNames === Seq("j"))
+ assert(reader.filters.isEmpty)
+ assert(reader.requiredSchema.fieldNames === Seq("j"))
}
}
}
}
test("columnar batch scan implementation") {
- Seq(classOf[ColumnarDataSourceV2], classOf[JavaColumnarDataSourceV2]).foreach { cls =>
+ Seq(classOf[BatchDataSourceV2], classOf[JavaBatchDataSourceV2]).foreach { cls =>
withClue(cls.getName) {
val df = spark.read.format(cls.getName).load()
checkAnswer(df, (0 until 90).map(i => Row(i, -i)))
@@ -154,25 +153,25 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
val df = spark.read.format(cls.getName).load()
checkAnswer(df, Seq(Row(1, 4), Row(1, 4), Row(3, 6), Row(2, 6), Row(4, 2), Row(4, 2)))
- val groupByColA = df.groupBy('i).agg(sum('j))
+ val groupByColA = df.groupBy('a).agg(sum('b))
checkAnswer(groupByColA, Seq(Row(1, 8), Row(2, 6), Row(3, 6), Row(4, 4)))
assert(groupByColA.queryExecution.executedPlan.collectFirst {
case e: ShuffleExchangeExec => e
}.isEmpty)
- val groupByColAB = df.groupBy('i, 'j).agg(count("*"))
+ val groupByColAB = df.groupBy('a, 'b).agg(count("*"))
checkAnswer(groupByColAB, Seq(Row(1, 4, 2), Row(2, 6, 1), Row(3, 6, 1), Row(4, 2, 2)))
assert(groupByColAB.queryExecution.executedPlan.collectFirst {
case e: ShuffleExchangeExec => e
}.isEmpty)
- val groupByColB = df.groupBy('j).agg(sum('i))
+ val groupByColB = df.groupBy('b).agg(sum('a))
checkAnswer(groupByColB, Seq(Row(2, 8), Row(4, 2), Row(6, 5)))
assert(groupByColB.queryExecution.executedPlan.collectFirst {
case e: ShuffleExchangeExec => e
}.isDefined)
- val groupByAPlusB = df.groupBy('i + 'j).agg(count("*"))
+ val groupByAPlusB = df.groupBy('a + 'b).agg(count("*"))
checkAnswer(groupByAPlusB, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1)))
assert(groupByAPlusB.queryExecution.executedPlan.collectFirst {
case e: ShuffleExchangeExec => e
@@ -273,30 +272,36 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
}
test("SPARK-23301: column pruning with arbitrary expressions") {
+ def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = {
+ query.queryExecution.executedPlan.collect {
+ case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader]
+ }.head
+ }
+
val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load()
val q1 = df.select('i + 1)
checkAnswer(q1, (1 until 11).map(i => Row(i)))
- val config1 = getScanConfig(q1)
- assert(config1.requiredSchema.fieldNames === Seq("i"))
+ val reader1 = getReader(q1)
+ assert(reader1.requiredSchema.fieldNames === Seq("i"))
val q2 = df.select(lit(1))
checkAnswer(q2, (0 until 10).map(i => Row(1)))
- val config2 = getScanConfig(q2)
- assert(config2.requiredSchema.isEmpty)
+ val reader2 = getReader(q2)
+ assert(reader2.requiredSchema.isEmpty)
// 'j === 1 can't be pushed down, but we should still be able do column pruning
val q3 = df.filter('j === -1).select('j * 2)
checkAnswer(q3, Row(-2))
- val config3 = getScanConfig(q3)
- assert(config3.filters.isEmpty)
- assert(config3.requiredSchema.fieldNames === Seq("j"))
+ val reader3 = getReader(q3)
+ assert(reader3.filters.isEmpty)
+ assert(reader3.requiredSchema.fieldNames === Seq("j"))
// column pruning should work with other operators.
val q4 = df.sort('i).limit(1).select('i + 1)
checkAnswer(q4, Row(1))
- val config4 = getScanConfig(q4)
- assert(config4.requiredSchema.fieldNames === Seq("i"))
+ val reader4 = getReader(q4)
+ assert(reader4.requiredSchema.fieldNames === Seq("i"))
}
test("SPARK-23315: get output from canonicalized data source v2 related plans") {
@@ -319,291 +324,240 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
}
}
+class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport {
-case class RangeInputPartition(start: Int, end: Int) extends InputPartition
-
-case class NoopScanConfigBuilder(readSchema: StructType) extends ScanConfigBuilder with ScanConfig {
- override def build(): ScanConfig = this
-}
-
-object SimpleReaderFactory extends PartitionReaderFactory {
- override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
- val RangeInputPartition(start, end) = partition
- new PartitionReader[InternalRow] {
- private var current = start - 1
-
- override def next(): Boolean = {
- current += 1
- current < end
- }
-
- override def get(): InternalRow = InternalRow(current, -current)
+ class Reader extends DataSourceReader {
+ override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")
- override def close(): Unit = {}
+ override def planInputPartitions(): JList[InputPartition[InternalRow]] = {
+ java.util.Arrays.asList(new SimpleInputPartition(0, 5))
}
}
+
+ override def createReader(options: DataSourceOptions): DataSourceReader = new Reader
}
-abstract class SimpleReadSupport extends BatchReadSupport {
- override def fullSchema(): StructType = new StructType().add("i", "int").add("j", "int")
+// This class is used by pyspark tests. If this class is modified/moved, make sure pyspark
+// tests still pass.
+class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport {
- override def newScanConfigBuilder(): ScanConfigBuilder = {
- NoopScanConfigBuilder(fullSchema())
- }
+ class Reader extends DataSourceReader {
+ override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")
- override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = {
- SimpleReaderFactory
+ override def planInputPartitions(): JList[InputPartition[InternalRow]] = {
+ java.util.Arrays.asList(new SimpleInputPartition(0, 5), new SimpleInputPartition(5, 10))
+ }
}
+
+ override def createReader(options: DataSourceOptions): DataSourceReader = new Reader
}
+class SimpleInputPartition(start: Int, end: Int)
+ extends InputPartition[InternalRow]
+ with InputPartitionReader[InternalRow] {
+ private var current = start - 1
-class SimpleSinglePartitionSource extends DataSourceV2 with BatchReadSupportProvider {
+ override def createPartitionReader(): InputPartitionReader[InternalRow] =
+ new SimpleInputPartition(start, end)
- class ReadSupport extends SimpleReadSupport {
- override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
- Array(RangeInputPartition(0, 5))
- }
+ override def next(): Boolean = {
+ current += 1
+ current < end
}
- override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = {
- new ReadSupport
- }
+ override def get(): InternalRow = InternalRow(current, -current)
+
+ override def close(): Unit = {}
}
-// This class is used by pyspark tests. If this class is modified/moved, make sure pyspark
-// tests still pass.
-class SimpleDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider {
+class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport {
- class ReadSupport extends SimpleReadSupport {
- override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
- Array(RangeInputPartition(0, 5), RangeInputPartition(5, 10))
- }
- }
+ class Reader extends DataSourceReader
+ with SupportsPushDownRequiredColumns with SupportsPushDownFilters {
- override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = {
- new ReadSupport
- }
-}
+ var requiredSchema = new StructType().add("i", "int").add("j", "int")
+ var filters = Array.empty[Filter]
+ override def pruneColumns(requiredSchema: StructType): Unit = {
+ this.requiredSchema = requiredSchema
+ }
-class AdvancedDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider {
+ override def pushFilters(filters: Array[Filter]): Array[Filter] = {
+ val (supported, unsupported) = filters.partition {
+ case GreaterThan("i", _: Int) => true
+ case _ => false
+ }
+ this.filters = supported
+ unsupported
+ }
- class ReadSupport extends SimpleReadSupport {
- override def newScanConfigBuilder(): ScanConfigBuilder = new AdvancedScanConfigBuilder()
+ override def pushedFilters(): Array[Filter] = filters
- override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
- val filters = config.asInstanceOf[AdvancedScanConfigBuilder].filters
+ override def readSchema(): StructType = {
+ requiredSchema
+ }
+ override def planInputPartitions(): JList[InputPartition[InternalRow]] = {
val lowerBound = filters.collectFirst {
case GreaterThan("i", v: Int) => v
}
- val res = scala.collection.mutable.ArrayBuffer.empty[InputPartition]
+ val res = new ArrayList[InputPartition[InternalRow]]
if (lowerBound.isEmpty) {
- res.append(RangeInputPartition(0, 5))
- res.append(RangeInputPartition(5, 10))
+ res.add(new AdvancedInputPartition(0, 5, requiredSchema))
+ res.add(new AdvancedInputPartition(5, 10, requiredSchema))
} else if (lowerBound.get < 4) {
- res.append(RangeInputPartition(lowerBound.get + 1, 5))
- res.append(RangeInputPartition(5, 10))
+ res.add(new AdvancedInputPartition(lowerBound.get + 1, 5, requiredSchema))
+ res.add(new AdvancedInputPartition(5, 10, requiredSchema))
} else if (lowerBound.get < 9) {
- res.append(RangeInputPartition(lowerBound.get + 1, 10))
+ res.add(new AdvancedInputPartition(lowerBound.get + 1, 10, requiredSchema))
}
- res.toArray
- }
-
- override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = {
- val requiredSchema = config.asInstanceOf[AdvancedScanConfigBuilder].requiredSchema
- new AdvancedReaderFactory(requiredSchema)
+ res
}
}
- override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = {
- new ReadSupport
- }
+ override def createReader(options: DataSourceOptions): DataSourceReader = new Reader
}
-class AdvancedScanConfigBuilder extends ScanConfigBuilder with ScanConfig
- with SupportsPushDownRequiredColumns with SupportsPushDownFilters {
+class AdvancedInputPartition(start: Int, end: Int, requiredSchema: StructType)
+ extends InputPartition[InternalRow] with InputPartitionReader[InternalRow] {
- var requiredSchema = new StructType().add("i", "int").add("j", "int")
- var filters = Array.empty[Filter]
+ private var current = start - 1
- override def pruneColumns(requiredSchema: StructType): Unit = {
- this.requiredSchema = requiredSchema
+ override def createPartitionReader(): InputPartitionReader[InternalRow] = {
+ new AdvancedInputPartition(start, end, requiredSchema)
}
- override def readSchema(): StructType = requiredSchema
+ override def close(): Unit = {}
- override def pushFilters(filters: Array[Filter]): Array[Filter] = {
- val (supported, unsupported) = filters.partition {
- case GreaterThan("i", _: Int) => true
- case _ => false
- }
- this.filters = supported
- unsupported
+ override def next(): Boolean = {
+ current += 1
+ current < end
}
- override def pushedFilters(): Array[Filter] = filters
-
- override def build(): ScanConfig = this
-}
-
-class AdvancedReaderFactory(requiredSchema: StructType) extends PartitionReaderFactory {
- override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
- val RangeInputPartition(start, end) = partition
- new PartitionReader[InternalRow] {
- private var current = start - 1
-
- override def next(): Boolean = {
- current += 1
- current < end
- }
-
- override def get(): InternalRow = {
- val values = requiredSchema.map(_.name).map {
- case "i" => current
- case "j" => -current
- }
- InternalRow.fromSeq(values)
- }
-
- override def close(): Unit = {}
+ override def get(): InternalRow = {
+ val values = requiredSchema.map(_.name).map {
+ case "i" => current
+ case "j" => -current
}
+ InternalRow.fromSeq(values)
}
}
-class SchemaRequiredDataSource extends DataSourceV2 with BatchReadSupportProvider {
+class SchemaRequiredDataSource extends DataSourceV2 with ReadSupport {
- class ReadSupport(val schema: StructType) extends SimpleReadSupport {
- override def fullSchema(): StructType = schema
-
- override def planInputPartitions(config: ScanConfig): Array[InputPartition] =
- Array.empty
+ class Reader(val readSchema: StructType) extends DataSourceReader {
+ override def planInputPartitions(): JList[InputPartition[InternalRow]] =
+ java.util.Collections.emptyList()
}
- override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = {
+ override def createReader(options: DataSourceOptions): DataSourceReader = {
throw new IllegalArgumentException("requires a user-supplied schema")
}
- override def createBatchReadSupport(
- schema: StructType, options: DataSourceOptions): BatchReadSupport = {
- new ReadSupport(schema)
+ override def createReader(schema: StructType, options: DataSourceOptions): DataSourceReader = {
+ new Reader(schema)
}
}
-class ColumnarDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider {
+class BatchDataSourceV2 extends DataSourceV2 with ReadSupport {
- class ReadSupport extends SimpleReadSupport {
- override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
- Array(RangeInputPartition(0, 50), RangeInputPartition(50, 90))
- }
+ class Reader extends DataSourceReader with SupportsScanColumnarBatch {
+ override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")
- override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = {
- ColumnarReaderFactory
+ override def planBatchInputPartitions(): JList[InputPartition[ColumnarBatch]] = {
+ java.util.Arrays.asList(
+ new BatchInputPartitionReader(0, 50), new BatchInputPartitionReader(50, 90))
}
}
- override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = {
- new ReadSupport
- }
+ override def createReader(options: DataSourceOptions): DataSourceReader = new Reader
}
-object ColumnarReaderFactory extends PartitionReaderFactory {
- private final val BATCH_SIZE = 20
+class BatchInputPartitionReader(start: Int, end: Int)
+ extends InputPartition[ColumnarBatch] with InputPartitionReader[ColumnarBatch] {
- override def supportColumnarReads(partition: InputPartition): Boolean = true
+ private final val BATCH_SIZE = 20
+ private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType)
+ private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType)
+ private lazy val batch = new ColumnarBatch(Array(i, j))
- override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
- throw new UnsupportedOperationException
- }
+ private var current = start
- override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = {
- val RangeInputPartition(start, end) = partition
- new PartitionReader[ColumnarBatch] {
- private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType)
- private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType)
- private lazy val batch = new ColumnarBatch(Array(i, j))
-
- private var current = start
-
- override def next(): Boolean = {
- i.reset()
- j.reset()
-
- var count = 0
- while (current < end && count < BATCH_SIZE) {
- i.putInt(count, current)
- j.putInt(count, -current)
- current += 1
- count += 1
- }
+ override def createPartitionReader(): InputPartitionReader[ColumnarBatch] = this
- if (count == 0) {
- false
- } else {
- batch.setNumRows(count)
- true
- }
- }
+ override def next(): Boolean = {
+ i.reset()
+ j.reset()
- override def get(): ColumnarBatch = batch
+ var count = 0
+ while (current < end && count < BATCH_SIZE) {
+ i.putInt(count, current)
+ j.putInt(count, -current)
+ current += 1
+ count += 1
+ }
- override def close(): Unit = batch.close()
+ if (count == 0) {
+ false
+ } else {
+ batch.setNumRows(count)
+ true
}
}
+
+ override def get(): ColumnarBatch = {
+ batch
+ }
+
+ override def close(): Unit = batch.close()
}
+class PartitionAwareDataSource extends DataSourceV2 with ReadSupport {
-class PartitionAwareDataSource extends DataSourceV2 with BatchReadSupportProvider {
+ class Reader extends DataSourceReader with SupportsReportPartitioning {
+ override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int")
- class ReadSupport extends SimpleReadSupport with SupportsReportPartitioning {
- override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
+ override def planInputPartitions(): JList[InputPartition[InternalRow]] = {
// Note that we don't have same value of column `a` across partitions.
- Array(
- SpecificInputPartition(Array(1, 1, 3), Array(4, 4, 6)),
- SpecificInputPartition(Array(2, 4, 4), Array(6, 2, 2)))
- }
-
- override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = {
- SpecificReaderFactory
+ java.util.Arrays.asList(
+ new SpecificInputPartitionReader(Array(1, 1, 3), Array(4, 4, 6)),
+ new SpecificInputPartitionReader(Array(2, 4, 4), Array(6, 2, 2)))
}
- override def outputPartitioning(config: ScanConfig): Partitioning = new MyPartitioning
+ override def outputPartitioning(): Partitioning = new MyPartitioning
}
class MyPartitioning extends Partitioning {
override def numPartitions(): Int = 2
override def satisfy(distribution: Distribution): Boolean = distribution match {
- case c: ClusteredDistribution => c.clusteredColumns.contains("i")
+ case c: ClusteredDistribution => c.clusteredColumns.contains("a")
case _ => false
}
}
- override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = {
- new ReadSupport
- }
+ override def createReader(options: DataSourceOptions): DataSourceReader = new Reader
}
-case class SpecificInputPartition(i: Array[Int], j: Array[Int]) extends InputPartition
+class SpecificInputPartitionReader(i: Array[Int], j: Array[Int])
+ extends InputPartition[InternalRow]
+ with InputPartitionReader[InternalRow] {
+ assert(i.length == j.length)
-object SpecificReaderFactory extends PartitionReaderFactory {
- override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
- val p = partition.asInstanceOf[SpecificInputPartition]
- new PartitionReader[InternalRow] {
- private var current = -1
+ private var current = -1
- override def next(): Boolean = {
- current += 1
- current < p.i.length
- }
+ override def createPartitionReader(): InputPartitionReader[InternalRow] = this
- override def get(): InternalRow = InternalRow(p.i(current), p.j(current))
-
- override def close(): Unit = {}
- }
+ override def next(): Boolean = {
+ current += 1
+ current < i.length
}
+
+ override def get(): InternalRow = InternalRow(i(current), j(current))
+
+ override def close(): Unit = {}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
index 952241b..e1b8e9c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
@@ -18,36 +18,34 @@
package org.apache.spark.sql.sources.v2
import java.io.{BufferedReader, InputStreamReader, IOException}
-import java.util.Optional
+import java.util.{Collections, List => JList, Optional}
import scala.collection.JavaConverters._
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path}
import org.apache.spark.SparkContext
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.sources.v2.reader._
+import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader}
import org.apache.spark.sql.sources.v2.writer._
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.SerializableConfiguration
/**
* A HDFS based transactional writable data source.
- * Each task writes data to `target/_temporary/queryId/$jobId-$partitionId-$attemptNumber`.
- * Each job moves files from `target/_temporary/queryId/` to `target`.
+ * Each task writes data to `target/_temporary/jobId/$jobId-$partitionId-$attemptNumber`.
+ * Each job moves files from `target/_temporary/jobId/` to `target`.
*/
-class SimpleWritableDataSource extends DataSourceV2
- with BatchReadSupportProvider with BatchWriteSupportProvider {
+class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteSupport {
private val schema = new StructType().add("i", "long").add("j", "long")
- class ReadSupport(path: String, conf: Configuration) extends SimpleReadSupport {
+ class Reader(path: String, conf: Configuration) extends DataSourceReader {
+ override def readSchema(): StructType = schema
- override def fullSchema(): StructType = schema
-
- override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
+ override def planInputPartitions(): JList[InputPartition[InternalRow]] = {
val dataPath = new Path(path)
val fs = dataPath.getFileSystem(conf)
if (fs.exists(dataPath)) {
@@ -55,23 +53,21 @@ class SimpleWritableDataSource extends DataSourceV2
val name = status.getPath.getName
name.startsWith("_") || name.startsWith(".")
}.map { f =>
- CSVInputPartitionReader(f.getPath.toUri.toString)
- }.toArray
+ val serializableConf = new SerializableConfiguration(conf)
+ new SimpleCSVInputPartitionReader(
+ f.getPath.toUri.toString,
+ serializableConf): InputPartition[InternalRow]
+ }.toList.asJava
} else {
- Array.empty
+ Collections.emptyList()
}
}
-
- override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = {
- val serializableConf = new SerializableConfiguration(conf)
- new CSVReaderFactory(serializableConf)
- }
}
- class WritSupport(queryId: String, path: String, conf: Configuration) extends BatchWriteSupport {
- override def createBatchWriterFactory(): DataWriterFactory = {
+ class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceWriter {
+ override def createWriterFactory(): DataWriterFactory[InternalRow] = {
SimpleCounter.resetCounter
- new CSVDataWriterFactory(path, queryId, new SerializableConfiguration(conf))
+ new CSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf))
}
override def onDataWriterCommit(message: WriterCommitMessage): Unit = {
@@ -80,7 +76,7 @@ class SimpleWritableDataSource extends DataSourceV2
override def commit(messages: Array[WriterCommitMessage]): Unit = {
val finalPath = new Path(path)
- val jobPath = new Path(new Path(finalPath, "_temporary"), queryId)
+ val jobPath = new Path(new Path(finalPath, "_temporary"), jobId)
val fs = jobPath.getFileSystem(conf)
try {
for (file <- fs.listStatus(jobPath).map(_.getPath)) {
@@ -95,23 +91,23 @@ class SimpleWritableDataSource extends DataSourceV2
}
override def abort(messages: Array[WriterCommitMessage]): Unit = {
- val jobPath = new Path(new Path(path, "_temporary"), queryId)
+ val jobPath = new Path(new Path(path, "_temporary"), jobId)
val fs = jobPath.getFileSystem(conf)
fs.delete(jobPath, true)
}
}
- override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = {
+ override def createReader(options: DataSourceOptions): DataSourceReader = {
val path = new Path(options.get("path").get())
val conf = SparkContext.getActive.get.hadoopConfiguration
- new ReadSupport(path.toUri.toString, conf)
+ new Reader(path.toUri.toString, conf)
}
- override def createBatchWriteSupport(
- queryId: String,
+ override def createWriter(
+ jobId: String,
schema: StructType,
mode: SaveMode,
- options: DataSourceOptions): Optional[BatchWriteSupport] = {
+ options: DataSourceOptions): Optional[DataSourceWriter] = {
assert(DataType.equalsStructurally(schema.asNullable, this.schema.asNullable))
assert(!SparkContext.getActive.get.conf.getBoolean("spark.speculation", false))
@@ -134,42 +130,39 @@ class SimpleWritableDataSource extends DataSourceV2
}
val pathStr = path.toUri.toString
- Optional.of(new WritSupport(queryId, pathStr, conf))
+ Optional.of(new Writer(jobId, pathStr, conf))
}
}
-case class CSVInputPartitionReader(path: String) extends InputPartition
+class SimpleCSVInputPartitionReader(path: String, conf: SerializableConfiguration)
+ extends InputPartition[InternalRow] with InputPartitionReader[InternalRow] {
-class CSVReaderFactory(conf: SerializableConfiguration)
- extends PartitionReaderFactory {
+ @transient private var lines: Iterator[String] = _
+ @transient private var currentLine: String = _
+ @transient private var inputStream: FSDataInputStream = _
- override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
- val path = partition.asInstanceOf[CSVInputPartitionReader].path
+ override def createPartitionReader(): InputPartitionReader[InternalRow] = {
val filePath = new Path(path)
val fs = filePath.getFileSystem(conf.value)
+ inputStream = fs.open(filePath)
+ lines = new BufferedReader(new InputStreamReader(inputStream))
+ .lines().iterator().asScala
+ this
+ }
- new PartitionReader[InternalRow] {
- private val inputStream = fs.open(filePath)
- private val lines = new BufferedReader(new InputStreamReader(inputStream))
- .lines().iterator().asScala
-
- private var currentLine: String = _
-
- override def next(): Boolean = {
- if (lines.hasNext) {
- currentLine = lines.next()
- true
- } else {
- false
- }
- }
+ override def next(): Boolean = {
+ if (lines.hasNext) {
+ currentLine = lines.next()
+ true
+ } else {
+ false
+ }
+ }
- override def get(): InternalRow = InternalRow(currentLine.split(",").map(_.trim.toLong): _*)
+ override def get(): InternalRow = InternalRow(currentLine.split(",").map(_.trim.toLong): _*)
- override def close(): Unit = {
- inputStream.close()
- }
- }
+ override def close(): Unit = {
+ inputStream.close()
}
}
@@ -190,11 +183,12 @@ private[v2] object SimpleCounter {
}
class CSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration)
- extends DataWriterFactory {
+ extends DataWriterFactory[InternalRow] {
- override def createWriter(
+ override def createDataWriter(
partitionId: Int,
- taskId: Long): DataWriter[InternalRow] = {
+ taskId: Long,
+ epochId: Long): DataWriter[InternalRow] = {
val jobPath = new Path(new Path(path, "_temporary"), jobId)
val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId")
val fs = filePath.getFileSystem(conf.value)
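
For context, here is a minimal standalone sketch of the reader-side contract this revert restores (the planInputPartitions / createPartitionReader shape visible in the re-added lines above). EchoReader and EchoPartition are illustrative names only, not part of this patch:

import java.util.{List => JList}

import scala.collection.JavaConverters._

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader}
import org.apache.spark.sql.types.StructType

// A partition that is also its own reader, emitting one long per row.
class EchoPartition(from: Long, until: Long)
  extends InputPartition[InternalRow] with InputPartitionReader[InternalRow] {
  private var current = from - 1
  override def createPartitionReader(): InputPartitionReader[InternalRow] = this
  override def next(): Boolean = { current += 1; current < until }
  override def get(): InternalRow = InternalRow(current)
  override def close(): Unit = ()
}

// The reader plans a java.util.List of partitions, mirroring the Reader class in the test above.
class EchoReader extends DataSourceReader {
  override def readSchema(): StructType = new StructType().add("value", "long")
  override def planInputPartitions(): JList[InputPartition[InternalRow]] =
    Seq[InputPartition[InternalRow]](new EchoPartition(0, 5), new EchoPartition(5, 10)).asJava
}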
http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index 491dc34..35644c5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -686,7 +686,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
plan
.collect {
case r: StreamingExecutionRelation => r.source
- case r: StreamingDataSourceV2Relation => r.readSupport
+ case r: StreamingDataSourceV2Relation => r.reader
}
.zipWithIndex
.find(_._1 == source)
http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
index fe77a1b..0f15cd6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
@@ -299,9 +299,9 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
try {
val input = new MemoryStream[Int](0, sqlContext) {
@volatile var numTriggers = 0
- override def latestOffset(): OffsetV2 = {
+ override def getEndOffset: OffsetV2 = {
numTriggers += 1
- super.latestOffset()
+ super.getEndOffset
}
}
val clock = new StreamManualClock()
http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index 1dd8175..0278e2a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.streaming
+import java.{util => ju}
+import java.util.Optional
import java.util.concurrent.CountDownLatch
import scala.collection.mutable
@@ -30,12 +32,13 @@ import org.scalatest.mockito.MockitoSugar
import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Literal, Rand, Randn, Shuffle, Uuid}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.sources.v2.reader.{InputPartition, ScanConfig}
+import org.apache.spark.sql.sources.v2.reader.InputPartition
import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2}
import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock}
import org.apache.spark.sql.types.StructType
@@ -212,17 +215,25 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
private def dataAdded: Boolean = currentOffset.offset != -1
- // latestOffset should take 50 ms the first time it is called after data is added
- override def latestOffset(): OffsetV2 = synchronized {
- if (dataAdded) clock.waitTillTime(1050)
- super.latestOffset()
+ // setOffsetRange should take 50 ms the first time it is called after data is added
+ override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = {
+ synchronized {
+ if (dataAdded) clock.waitTillTime(1050)
+ super.setOffsetRange(start, end)
+ }
+ }
+
+ // getEndOffset should take 100 ms the first time it is called after data is added
+ override def getEndOffset(): OffsetV2 = synchronized {
+ if (dataAdded) clock.waitTillTime(1150)
+ super.getEndOffset()
}
// getBatch should take 100 ms the first time it is called
- override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
+ override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = {
synchronized {
- clock.waitTillTime(1150)
- super.planInputPartitions(config)
+ clock.waitTillTime(1350)
+ super.planInputPartitions()
}
}
}
@@ -263,26 +274,34 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
AssertOnQuery(_.status.message === "Waiting for next trigger"),
AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0),
- // Test status and progress when `latestOffset` is being called
+ // Test status and progress when setOffsetRange is being called
AddData(inputData, 1, 2),
- AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on `latestOffset`
+ AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on setOffsetRange
AssertStreamExecThreadIsWaitingForTime(1050),
AssertOnQuery(_.status.isDataAvailable === false),
AssertOnQuery(_.status.isTriggerActive === true),
AssertOnQuery(_.status.message.startsWith("Getting offsets from")),
AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0),
- AdvanceManualClock(50), // time = 1050 to unblock `latestOffset`
+ AdvanceManualClock(50), // time = 1050 to unblock setOffsetRange
AssertClockTime(1050),
- // will block on `planInputPartitions` that needs 1350
- AssertStreamExecThreadIsWaitingForTime(1150),
+ AssertStreamExecThreadIsWaitingForTime(1150), // will block on getEndOffset that needs 1150
+ AssertOnQuery(_.status.isDataAvailable === false),
+ AssertOnQuery(_.status.isTriggerActive === true),
+ AssertOnQuery(_.status.message.startsWith("Getting offsets from")),
+ AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0),
+
+ AdvanceManualClock(100), // time = 1150 to unblock getEndOffset
+ AssertClockTime(1150),
+ // will block on planInputPartitions that needs 1350
+ AssertStreamExecThreadIsWaitingForTime(1350),
AssertOnQuery(_.status.isDataAvailable === true),
AssertOnQuery(_.status.isTriggerActive === true),
AssertOnQuery(_.status.message === "Processing new data"),
AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0),
- AdvanceManualClock(100), // time = 1150 to unblock `planInputPartitions`
- AssertClockTime(1150),
+ AdvanceManualClock(200), // time = 1350 to unblock planInputPartitions
+ AssertClockTime(1350),
AssertStreamExecThreadIsWaitingForTime(1500), // will block on map task that needs 1500
AssertOnQuery(_.status.isDataAvailable === true),
AssertOnQuery(_.status.isTriggerActive === true),
@@ -290,7 +309,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0),
// Test status and progress while batch processing has completed
- AdvanceManualClock(350), // time = 1500 to unblock map task
+ AdvanceManualClock(150), // time = 1500 to unblock map task
AssertClockTime(1500),
CheckAnswer(2),
AssertStreamExecThreadIsWaitingForTime(2000), // will block until the next trigger
@@ -310,10 +329,11 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
assert(progress.numInputRows === 2)
assert(progress.processedRowsPerSecond === 4.0)
- assert(progress.durationMs.get("latestOffset") === 50)
- assert(progress.durationMs.get("queryPlanning") === 100)
+ assert(progress.durationMs.get("setOffsetRange") === 50)
+ assert(progress.durationMs.get("getEndOffset") === 100)
+ assert(progress.durationMs.get("queryPlanning") === 200)
assert(progress.durationMs.get("walCommit") === 0)
- assert(progress.durationMs.get("addBatch") === 350)
+ assert(progress.durationMs.get("addBatch") === 150)
assert(progress.durationMs.get("triggerExecution") === 500)
assert(progress.sources.length === 1)
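
For reference, the manual-clock waits in the rewritten test account exactly for the asserted durations: setOffsetRange blocks from t=1000 to t=1050 (50 ms), getEndOffset from 1050 to 1150 (100 ms), planInputPartitions, reported as queryPlanning, from 1150 to 1350 (200 ms), and the map task, reported as addBatch, from 1350 to 1500 (150 ms), so triggerExecution totals 50 + 100 + 200 + 150 = 500 ms.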
http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
index d6819ea..4f19881 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
@@ -22,15 +22,16 @@ import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue}
import org.mockito.Mockito._
import org.scalatest.mockito.MockitoSugar
-import org.apache.spark.{SparkEnv, TaskContext}
-import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.{SparkEnv, SparkFunSuite, TaskContext}
+import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.execution.streaming.continuous._
-import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReader, ContinuousReadSupport, PartitionOffset}
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport
+import org.apache.spark.sql.sources.v2.reader.InputPartition
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, PartitionOffset}
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.streaming.StreamTest
-import org.apache.spark.sql.types.{DataType, IntegerType, StructType}
+import org.apache.spark.sql.types.{DataType, IntegerType}
class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar {
case class LongPartitionOffset(offset: Long) extends PartitionOffset
@@ -43,8 +44,8 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar {
override def beforeEach(): Unit = {
super.beforeEach()
epochEndpoint = EpochCoordinatorRef.create(
- mock[StreamingWriteSupport],
- mock[ContinuousReadSupport],
+ mock[StreamWriter],
+ mock[ContinuousReader],
mock[ContinuousExecution],
coordinatorId,
startEpoch,
@@ -72,26 +73,26 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar {
*/
private def setup(): (BlockingQueue[UnsafeRow], ContinuousQueuedDataReader) = {
val queue = new ArrayBlockingQueue[UnsafeRow](1024)
- val partitionReader = new ContinuousPartitionReader[InternalRow] {
- var index = -1
- var curr: UnsafeRow = _
-
- override def next() = {
- curr = queue.take()
- index += 1
- true
- }
+ val factory = new InputPartition[InternalRow] {
+ override def createPartitionReader() = new ContinuousInputPartitionReader[InternalRow] {
+ var index = -1
+ var curr: UnsafeRow = _
+
+ override def next() = {
+ curr = queue.take()
+ index += 1
+ true
+ }
- override def get = curr
+ override def get = curr
- override def getOffset = LongPartitionOffset(index)
+ override def getOffset = LongPartitionOffset(index)
- override def close() = {}
+ override def close() = {}
+ }
}
val reader = new ContinuousQueuedDataReader(
- 0,
- partitionReader,
- new StructType().add("i", "int"),
+ new ContinuousDataSourceRDDPartition(0, factory),
mockContext,
dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize,
epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs)
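
For context, the ContinuousInputPartitionReader restored here is the batch InputPartitionReader plus a getOffset method that reports how far the partition has read. A minimal standalone sketch (CountingPartition and CountingOffset are illustrative names, not part of this patch):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.reader.InputPartition
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, PartitionOffset}

case class CountingOffset(offset: Long) extends PartitionOffset

// A partition whose reader never runs out of data and tracks its own offset.
class CountingPartition extends InputPartition[InternalRow] {
  override def createPartitionReader(): ContinuousInputPartitionReader[InternalRow] =
    new ContinuousInputPartitionReader[InternalRow] {
      private var index = -1L
      override def next(): Boolean = { index += 1; true }
      override def get(): InternalRow = InternalRow(index)
      override def getOffset: PartitionOffset = CountingOffset(index)
      override def close(): Unit = ()
    }
}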
http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
index 3d21bc6..4980b0c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
@@ -41,7 +41,7 @@ class ContinuousSuiteBase extends StreamTest {
case s: ContinuousExecution =>
assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized")
val reader = s.lastExecution.executedPlan.collectFirst {
- case DataSourceV2ScanExec(_, _, _, _, r: RateStreamContinuousReadSupport, _) => r
+ case DataSourceV2ScanExec(_, _, _, _, r: RateStreamContinuousReader) => r
}.get
val deltaMs = numTriggers * 1000 + 300
http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala
index 3c973d8..82836dc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala
@@ -27,9 +27,9 @@ import org.apache.spark._
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.sql.LocalSparkSession
import org.apache.spark.sql.execution.streaming.continuous._
-import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset}
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset}
import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.test.TestSparkSession
class EpochCoordinatorSuite
@@ -40,20 +40,20 @@ class EpochCoordinatorSuite
private var epochCoordinator: RpcEndpointRef = _
- private var writeSupport: StreamingWriteSupport = _
+ private var writer: StreamWriter = _
private var query: ContinuousExecution = _
private var orderVerifier: InOrder = _
override def beforeEach(): Unit = {
- val reader = mock[ContinuousReadSupport]
- writeSupport = mock[StreamingWriteSupport]
+ val reader = mock[ContinuousReader]
+ writer = mock[StreamWriter]
query = mock[ContinuousExecution]
- orderVerifier = inOrder(writeSupport, query)
+ orderVerifier = inOrder(writer, query)
spark = new TestSparkSession()
epochCoordinator
- = EpochCoordinatorRef.create(writeSupport, reader, query, "test", 1, spark, SparkEnv.get)
+ = EpochCoordinatorRef.create(writer, reader, query, "test", 1, spark, SparkEnv.get)
}
test("single epoch") {
@@ -209,12 +209,12 @@ class EpochCoordinatorSuite
}
private def verifyCommit(epoch: Long): Unit = {
- orderVerifier.verify(writeSupport).commit(eqTo(epoch), any())
+ orderVerifier.verify(writer).commit(eqTo(epoch), any())
orderVerifier.verify(query).commit(epoch)
}
private def verifyNoCommitFor(epoch: Long): Unit = {
- verify(writeSupport, never()).commit(eqTo(epoch), any())
+ verify(writer, never()).commit(eqTo(epoch), any())
verify(query, never()).commit(epoch)
}
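
For readers less familiar with the Mockito pattern used in verifyCommit above: inOrder builds a verifier that only passes if the recorded calls happened in the given sequence. A minimal standalone sketch outside Spark (Writer, Query and InOrderExample are hypothetical names):

import org.mockito.Mockito.{inOrder, mock}

trait Writer { def commit(epoch: Long): Unit }
trait Query { def commit(epoch: Long): Unit }

object InOrderExample extends App {
  val writer = mock(classOf[Writer])
  val query  = mock(classOf[Query])
  val order  = inOrder(writer, query)

  // Record the calls in the order the coordinator is expected to make them.
  writer.commit(1L)
  query.commit(1L)

  // Both verifications pass only because writer.commit(1) preceded query.commit(1).
  order.verify(writer).commit(1L)
  order.verify(query).commit(1L)
}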