Posted to commits@spark.apache.org by we...@apache.org on 2018/03/30 13:54:39 UTC
spark git commit: Roll forward "[SPARK-23096][SS] Migrate rate source to V2"
Repository: spark
Updated Branches:
refs/heads/master b02e76cbf -> 5b5a36ed6
Roll forward "[SPARK-23096][SS] Migrate rate source to V2"
## What changes were proposed in this pull request?
Roll forward c68ec4e (#20688).
There are two minor test changes required:
* An error that used to surface as TreeNodeException[ArithmeticException] is no longer wrapped and is now a plain ArithmeticException.
* The test framework simply does not set the active Spark session. (Or rather, it doesn't do so early enough; I think that only happens once a query is analyzed.) I've added the required logic to SQLTestUtils; a minimal sketch of the idea follows below.
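A minimal sketch of what such test-utility logic could look like (the helper name `withActiveSession` and its shape are illustrative assumptions, not the actual SQLTestUtils change; `SparkSession.setActiveSession`/`clearActiveSession` are the real APIs involved):

```scala
import org.apache.spark.sql.SparkSession

// Illustrative sketch only: make the session visible to code that calls
// SparkSession.getActiveSession before any query has been analyzed.
def withActiveSession[T](spark: SparkSession)(body: => T): T = {
  SparkSession.setActiveSession(spark)        // set the active session up front
  try body
  finally SparkSession.clearActiveSession()   // leave no stale session behind
}
```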
## How was this patch tested?
Existing tests.
Author: Jose Torres <to...@gmail.com>
Author: jerryshao <ss...@hortonworks.com>
Closes #20922 from jose-torres/ratefix.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5b5a36ed
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5b5a36ed
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5b5a36ed
Branch: refs/heads/master
Commit: 5b5a36ed6d2bb0971edfeccddf0f280936d2275f
Parents: b02e76c
Author: Jose Torres <to...@gmail.com>
Authored: Fri Mar 30 21:54:26 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Fri Mar 30 21:54:26 2018 +0800
----------------------------------------------------------------------
....apache.spark.sql.sources.DataSourceRegister | 3 +-
.../sql/execution/datasources/DataSource.scala | 6 +-
.../streaming/RateSourceProvider.scala | 262 ---------------
.../continuous/ContinuousRateStreamSource.scala | 25 +-
.../sources/RateStreamMicroBatchReader.scala | 222 ++++++++++++
.../streaming/sources/RateStreamProvider.scala | 125 +++++++
.../streaming/sources/RateStreamSourceV2.scala | 187 -----------
.../execution/streaming/RateSourceSuite.scala | 194 -----------
.../execution/streaming/RateSourceV2Suite.scala | 191 -----------
.../sources/RateStreamProviderSuite.scala | 334 +++++++++++++++++++
10 files changed, 705 insertions(+), 844 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/5b5a36ed/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
----------------------------------------------------------------------
diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
index 1fe9c09..1b37905 100644
--- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
+++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -5,6 +5,5 @@ org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
org.apache.spark.sql.execution.datasources.text.TextFileFormat
org.apache.spark.sql.execution.streaming.ConsoleSinkProvider
-org.apache.spark.sql.execution.streaming.RateSourceProvider
+org.apache.spark.sql.execution.streaming.sources.RateStreamProvider
org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider
-org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2
http://git-wip-us.apache.org/repos/asf/spark/blob/5b5a36ed/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 31fa89b..b84ea76 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider
+import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
import org.apache.spark.sql.streaming.OutputMode
@@ -566,6 +566,7 @@ object DataSource extends Logging {
val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat"
val nativeOrc = classOf[OrcFileFormat].getCanonicalName
val socket = classOf[TextSocketSourceProvider].getCanonicalName
+ val rate = classOf[RateStreamProvider].getCanonicalName
Map(
"org.apache.spark.sql.jdbc" -> jdbc,
@@ -587,7 +588,8 @@ object DataSource extends Logging {
"org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm,
"org.apache.spark.ml.source.libsvm" -> libsvm,
"com.databricks.spark.csv" -> csv,
- "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket
+ "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket,
+ "org.apache.spark.sql.execution.streaming.RateSourceProvider" -> rate
)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/5b5a36ed/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala
deleted file mode 100644
index 649fbbf..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala
+++ /dev/null
@@ -1,262 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming
-
-import java.io._
-import java.nio.charset.StandardCharsets
-import java.util.Optional
-import java.util.concurrent.TimeUnit
-
-import org.apache.commons.io.IOUtils
-
-import org.apache.spark.internal.Logging
-import org.apache.spark.network.util.JavaUtils
-import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext}
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
-import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader
-import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
-import org.apache.spark.sql.sources.v2._
-import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader
-import org.apache.spark.sql.types._
-import org.apache.spark.util.{ManualClock, SystemClock}
-
-/**
- * A source that generates increment long values with timestamps. Each generated row has two
- * columns: a timestamp column for the generated time and an auto increment long column starting
- * with 0L.
- *
- * This source supports the following options:
- * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second.
- * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed
- * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer
- * seconds.
- * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the
- * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may
- * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed.
- */
-class RateSourceProvider extends StreamSourceProvider with DataSourceRegister
- with DataSourceV2 with ContinuousReadSupport {
-
- override def sourceSchema(
- sqlContext: SQLContext,
- schema: Option[StructType],
- providerName: String,
- parameters: Map[String, String]): (String, StructType) = {
- if (schema.nonEmpty) {
- throw new AnalysisException("The rate source does not support a user-specified schema.")
- }
-
- (shortName(), RateSourceProvider.SCHEMA)
- }
-
- override def createSource(
- sqlContext: SQLContext,
- metadataPath: String,
- schema: Option[StructType],
- providerName: String,
- parameters: Map[String, String]): Source = {
- val params = CaseInsensitiveMap(parameters)
-
- val rowsPerSecond = params.get("rowsPerSecond").map(_.toLong).getOrElse(1L)
- if (rowsPerSecond <= 0) {
- throw new IllegalArgumentException(
- s"Invalid value '${params("rowsPerSecond")}'. The option 'rowsPerSecond' " +
- "must be positive")
- }
-
- val rampUpTimeSeconds =
- params.get("rampUpTime").map(JavaUtils.timeStringAsSec(_)).getOrElse(0L)
- if (rampUpTimeSeconds < 0) {
- throw new IllegalArgumentException(
- s"Invalid value '${params("rampUpTime")}'. The option 'rampUpTime' " +
- "must not be negative")
- }
-
- val numPartitions = params.get("numPartitions").map(_.toInt).getOrElse(
- sqlContext.sparkContext.defaultParallelism)
- if (numPartitions <= 0) {
- throw new IllegalArgumentException(
- s"Invalid value '${params("numPartitions")}'. The option 'numPartitions' " +
- "must be positive")
- }
-
- new RateStreamSource(
- sqlContext,
- metadataPath,
- rowsPerSecond,
- rampUpTimeSeconds,
- numPartitions,
- params.get("useManualClock").map(_.toBoolean).getOrElse(false) // Only for testing
- )
- }
-
- override def createContinuousReader(
- schema: Optional[StructType],
- checkpointLocation: String,
- options: DataSourceOptions): ContinuousReader = {
- new RateStreamContinuousReader(options)
- }
-
- override def shortName(): String = "rate"
-}
-
-object RateSourceProvider {
- val SCHEMA =
- StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil)
-
- val VERSION = 1
-}
-
-class RateStreamSource(
- sqlContext: SQLContext,
- metadataPath: String,
- rowsPerSecond: Long,
- rampUpTimeSeconds: Long,
- numPartitions: Int,
- useManualClock: Boolean) extends Source with Logging {
-
- import RateSourceProvider._
- import RateStreamSource._
-
- val clock = if (useManualClock) new ManualClock else new SystemClock
-
- private val maxSeconds = Long.MaxValue / rowsPerSecond
-
- if (rampUpTimeSeconds > maxSeconds) {
- throw new ArithmeticException(
- s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" +
- s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.")
- }
-
- private val startTimeMs = {
- val metadataLog =
- new HDFSMetadataLog[LongOffset](sqlContext.sparkSession, metadataPath) {
- override def serialize(metadata: LongOffset, out: OutputStream): Unit = {
- val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8))
- writer.write("v" + VERSION + "\n")
- writer.write(metadata.json)
- writer.flush
- }
-
- override def deserialize(in: InputStream): LongOffset = {
- val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8))
- // HDFSMetadataLog guarantees that it never creates a partial file.
- assert(content.length != 0)
- if (content(0) == 'v') {
- val indexOfNewLine = content.indexOf("\n")
- if (indexOfNewLine > 0) {
- val version = parseVersion(content.substring(0, indexOfNewLine), VERSION)
- LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1)))
- } else {
- throw new IllegalStateException(
- s"Log file was malformed: failed to detect the log file version line.")
- }
- } else {
- throw new IllegalStateException(
- s"Log file was malformed: failed to detect the log file version line.")
- }
- }
- }
-
- metadataLog.get(0).getOrElse {
- val offset = LongOffset(clock.getTimeMillis())
- metadataLog.add(0, offset)
- logInfo(s"Start time: $offset")
- offset
- }.offset
- }
-
- /** When the system time runs backward, "lastTimeMs" will make sure we are still monotonic. */
- @volatile private var lastTimeMs = startTimeMs
-
- override def schema: StructType = RateSourceProvider.SCHEMA
-
- override def getOffset: Option[Offset] = {
- val now = clock.getTimeMillis()
- if (lastTimeMs < now) {
- lastTimeMs = now
- }
- Some(LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - startTimeMs)))
- }
-
- override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
- val startSeconds = start.flatMap(LongOffset.convert(_).map(_.offset)).getOrElse(0L)
- val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L)
- assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)")
- if (endSeconds > maxSeconds) {
- throw new ArithmeticException("Integer overflow. Max offset with " +
- s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.")
- }
- // Fix "lastTimeMs" for recovery
- if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs) {
- lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs
- }
- val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds)
- val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds)
- logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " +
- s"rangeStart: $rangeStart, rangeEnd: $rangeEnd")
-
- if (rangeStart == rangeEnd) {
- return sqlContext.internalCreateDataFrame(
- sqlContext.sparkContext.emptyRDD, schema, isStreaming = true)
- }
-
- val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds)
- val relativeMsPerValue =
- TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart)
-
- val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map { v =>
- val relative = math.round((v - rangeStart) * relativeMsPerValue)
- InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v)
- }
- sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true)
- }
-
- override def stop(): Unit = {}
-
- override def toString: String = s"RateSource[rowsPerSecond=$rowsPerSecond, " +
- s"rampUpTimeSeconds=$rampUpTimeSeconds, numPartitions=$numPartitions]"
-}
-
-object RateStreamSource {
-
- /** Calculate the end value we will emit at the time `seconds`. */
- def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = {
- // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10
- // Then speedDeltaPerSecond = 2
- //
- // seconds = 0 1 2 3 4 5 6
- // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds)
- // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2
- val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1)
- if (seconds <= rampUpTimeSeconds) {
- // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to
- // avoid overflow
- if (seconds % 2 == 1) {
- (seconds + 1) / 2 * speedDeltaPerSecond * seconds
- } else {
- seconds / 2 * speedDeltaPerSecond * (seconds + 1)
- }
- } else {
- // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds
- val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds)
- rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/5b5a36ed/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
index 20d9006..2f0de26 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
@@ -24,8 +24,8 @@ import org.json4s.jackson.Serialization
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset, ValueRunTimeMsPair}
-import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2
+import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair}
+import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider
import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset}
@@ -40,8 +40,8 @@ class RateStreamContinuousReader(options: DataSourceOptions)
val creationTime = System.currentTimeMillis()
- val numPartitions = options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt
- val rowsPerSecond = options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong
+ val numPartitions = options.get(RateStreamProvider.NUM_PARTITIONS).orElse("5").toInt
+ val rowsPerSecond = options.get(RateStreamProvider.ROWS_PER_SECOND).orElse("6").toLong
val perPartitionRate = rowsPerSecond.toDouble / numPartitions.toDouble
override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = {
@@ -57,12 +57,12 @@ class RateStreamContinuousReader(options: DataSourceOptions)
RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json))
}
- override def readSchema(): StructType = RateSourceProvider.SCHEMA
+ override def readSchema(): StructType = RateStreamProvider.SCHEMA
private var offset: Offset = _
override def setStartOffset(offset: java.util.Optional[Offset]): Unit = {
- this.offset = offset.orElse(RateStreamSourceV2.createInitialOffset(numPartitions, creationTime))
+ this.offset = offset.orElse(createInitialOffset(numPartitions, creationTime))
}
override def getStartOffset(): Offset = offset
@@ -98,6 +98,19 @@ class RateStreamContinuousReader(options: DataSourceOptions)
override def commit(end: Offset): Unit = {}
override def stop(): Unit = {}
+ private def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = {
+ RateStreamOffset(
+ Range(0, numPartitions).map { i =>
+ // Note that the starting offset is exclusive, so we have to decrement the starting value
+ // by the increment that will later be applied. The first row output in each
+ // partition will have a value equal to the partition index.
+ (i,
+ ValueRunTimeMsPair(
+ (i - numPartitions).toLong,
+ creationTimeMs))
+ }.toMap)
+ }
+
}
case class RateStreamContinuousDataReaderFactory(
http://git-wip-us.apache.org/repos/asf/spark/blob/5b5a36ed/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala
new file mode 100644
index 0000000..6cf8520
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import java.io._
+import java.nio.charset.StandardCharsets
+import java.util.Optional
+import java.util.concurrent.TimeUnit
+
+import scala.collection.JavaConverters._
+
+import org.apache.commons.io.IOUtils
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.sources.v2.DataSourceOptions
+import org.apache.spark.sql.sources.v2.reader._
+import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.{ManualClock, SystemClock}
+
+class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String)
+ extends MicroBatchReader with Logging {
+ import RateStreamProvider._
+
+ private[sources] val clock = {
+ // The option to use a manual clock is provided only for unit testing purposes.
+ if (options.getBoolean("useManualClock", false)) new ManualClock else new SystemClock
+ }
+
+ private val rowsPerSecond =
+ options.get(ROWS_PER_SECOND).orElse("1").toLong
+
+ private val rampUpTimeSeconds =
+ Option(options.get(RAMP_UP_TIME).orElse(null.asInstanceOf[String]))
+ .map(JavaUtils.timeStringAsSec(_))
+ .getOrElse(0L)
+
+ private val maxSeconds = Long.MaxValue / rowsPerSecond
+
+ if (rampUpTimeSeconds > maxSeconds) {
+ throw new ArithmeticException(
+ s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" +
+ s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.")
+ }
+
+ private[sources] val creationTimeMs = {
+ val session = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession)
+ require(session.isDefined)
+
+ val metadataLog =
+ new HDFSMetadataLog[LongOffset](session.get, checkpointLocation) {
+ override def serialize(metadata: LongOffset, out: OutputStream): Unit = {
+ val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8))
+ writer.write("v" + VERSION + "\n")
+ writer.write(metadata.json)
+ writer.flush
+ }
+
+ override def deserialize(in: InputStream): LongOffset = {
+ val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8))
+ // HDFSMetadataLog guarantees that it never creates a partial file.
+ assert(content.length != 0)
+ if (content(0) == 'v') {
+ val indexOfNewLine = content.indexOf("\n")
+ if (indexOfNewLine > 0) {
+ parseVersion(content.substring(0, indexOfNewLine), VERSION)
+ LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1)))
+ } else {
+ throw new IllegalStateException(
+ s"Log file was malformed: failed to detect the log file version line.")
+ }
+ } else {
+ throw new IllegalStateException(
+ s"Log file was malformed: failed to detect the log file version line.")
+ }
+ }
+ }
+
+ metadataLog.get(0).getOrElse {
+ val offset = LongOffset(clock.getTimeMillis())
+ metadataLog.add(0, offset)
+ logInfo(s"Start time: $offset")
+ offset
+ }.offset
+ }
+
+ @volatile private var lastTimeMs: Long = creationTimeMs
+
+ private var start: LongOffset = _
+ private var end: LongOffset = _
+
+ override def readSchema(): StructType = SCHEMA
+
+ override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = {
+ this.start = start.orElse(LongOffset(0L)).asInstanceOf[LongOffset]
+ this.end = end.orElse {
+ val now = clock.getTimeMillis()
+ if (lastTimeMs < now) {
+ lastTimeMs = now
+ }
+ LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs))
+ }.asInstanceOf[LongOffset]
+ }
+
+ override def getStartOffset(): Offset = {
+ if (start == null) throw new IllegalStateException("start offset not set")
+ start
+ }
+ override def getEndOffset(): Offset = {
+ if (end == null) throw new IllegalStateException("end offset not set")
+ end
+ }
+
+ override def deserializeOffset(json: String): Offset = {
+ LongOffset(json.toLong)
+ }
+
+ override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = {
+ val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L)
+ val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L)
+ assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)")
+ if (endSeconds > maxSeconds) {
+ throw new ArithmeticException("Integer overflow. Max offset with " +
+ s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.")
+ }
+ // Fix "lastTimeMs" for recovery
+ if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs) {
+ lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs
+ }
+ val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds)
+ val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds)
+ logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " +
+ s"rangeStart: $rangeStart, rangeEnd: $rangeEnd")
+
+ if (rangeStart == rangeEnd) {
+ return List.empty.asJava
+ }
+
+ val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds)
+ val relativeMsPerValue =
+ TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart)
+ val numPartitions = {
+ val activeSession = SparkSession.getActiveSession
+ require(activeSession.isDefined)
+ Option(options.get(NUM_PARTITIONS).orElse(null.asInstanceOf[String]))
+ .map(_.toInt)
+ .getOrElse(activeSession.get.sparkContext.defaultParallelism)
+ }
+
+ (0 until numPartitions).map { p =>
+ new RateStreamMicroBatchDataReaderFactory(
+ p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue)
+ : DataReaderFactory[Row]
+ }.toList.asJava
+ }
+
+ override def commit(end: Offset): Unit = {}
+
+ override def stop(): Unit = {}
+
+ override def toString: String = s"MicroBatchRateSource[rowsPerSecond=$rowsPerSecond, " +
+ s"rampUpTimeSeconds=$rampUpTimeSeconds, " +
+ s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}"
+}
+
+class RateStreamMicroBatchDataReaderFactory(
+ partitionId: Int,
+ numPartitions: Int,
+ rangeStart: Long,
+ rangeEnd: Long,
+ localStartTimeMs: Long,
+ relativeMsPerValue: Double) extends DataReaderFactory[Row] {
+
+ override def createDataReader(): DataReader[Row] = new RateStreamMicroBatchDataReader(
+ partitionId, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue)
+}
+
+class RateStreamMicroBatchDataReader(
+ partitionId: Int,
+ numPartitions: Int,
+ rangeStart: Long,
+ rangeEnd: Long,
+ localStartTimeMs: Long,
+ relativeMsPerValue: Double) extends DataReader[Row] {
+ private var count = 0
+
+ override def next(): Boolean = {
+ rangeStart + partitionId + numPartitions * count < rangeEnd
+ }
+
+ override def get(): Row = {
+ val currValue = rangeStart + partitionId + numPartitions * count
+ count += 1
+ val relative = math.round((currValue - rangeStart) * relativeMsPerValue)
+ Row(
+ DateTimeUtils.toJavaTimestamp(
+ DateTimeUtils.fromMillis(relative + localStartTimeMs)),
+ currValue
+ )
+ }
+
+ override def close(): Unit = {}
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/5b5a36ed/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala
new file mode 100644
index 0000000..6bdd492
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import java.util.Optional
+
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader
+import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.sources.v2._
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader}
+import org.apache.spark.sql.types._
+
+/**
+ * A source that generates increment long values with timestamps. Each generated row has two
+ * columns: a timestamp column for the generated time and an auto increment long column starting
+ * with 0L.
+ *
+ * This source supports the following options:
+ * - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second.
+ * - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed
+ * becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer
+ * seconds.
+ * - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the
+ * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may
+ * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed.
+ */
+class RateStreamProvider extends DataSourceV2
+ with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister {
+ import RateStreamProvider._
+
+ override def createMicroBatchReader(
+ schema: Optional[StructType],
+ checkpointLocation: String,
+ options: DataSourceOptions): MicroBatchReader = {
+ if (options.get(ROWS_PER_SECOND).isPresent) {
+ val rowsPerSecond = options.get(ROWS_PER_SECOND).get().toLong
+ if (rowsPerSecond <= 0) {
+ throw new IllegalArgumentException(
+ s"Invalid value '$rowsPerSecond'. The option 'rowsPerSecond' must be positive")
+ }
+ }
+
+ if (options.get(RAMP_UP_TIME).isPresent) {
+ val rampUpTimeSeconds =
+ JavaUtils.timeStringAsSec(options.get(RAMP_UP_TIME).get())
+ if (rampUpTimeSeconds < 0) {
+ throw new IllegalArgumentException(
+ s"Invalid value '$rampUpTimeSeconds'. The option 'rampUpTime' must not be negative")
+ }
+ }
+
+ if (options.get(NUM_PARTITIONS).isPresent) {
+ val numPartitions = options.get(NUM_PARTITIONS).get().toInt
+ if (numPartitions <= 0) {
+ throw new IllegalArgumentException(
+ s"Invalid value '$numPartitions'. The option 'numPartitions' must be positive")
+ }
+ }
+
+ if (schema.isPresent) {
+ throw new AnalysisException("The rate source does not support a user-specified schema.")
+ }
+
+ new RateStreamMicroBatchReader(options, checkpointLocation)
+ }
+
+ override def createContinuousReader(
+ schema: Optional[StructType],
+ checkpointLocation: String,
+ options: DataSourceOptions): ContinuousReader = new RateStreamContinuousReader(options)
+
+ override def shortName(): String = "rate"
+}
+
+object RateStreamProvider {
+ val SCHEMA =
+ StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil)
+
+ val VERSION = 1
+
+ val NUM_PARTITIONS = "numPartitions"
+ val ROWS_PER_SECOND = "rowsPerSecond"
+ val RAMP_UP_TIME = "rampUpTime"
+
+ /** Calculate the end value we will emit at the time `seconds`. */
+ def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = {
+ // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10
+ // Then speedDeltaPerSecond = 2
+ //
+ // seconds = 0 1 2 3 4 5 6
+ // speed = 0 2 4 6 8 10 10 (speedDeltaPerSecond * seconds)
+ // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2
+ val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1)
+ if (seconds <= rampUpTimeSeconds) {
+ // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to
+ // avoid overflow
+ if (seconds % 2 == 1) {
+ (seconds + 1) / 2 * speedDeltaPerSecond * seconds
+ } else {
+ seconds / 2 * speedDeltaPerSecond * (seconds + 1)
+ }
+ } else {
+ // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds
+ val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds)
+ rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond
+ }
+ }
+}
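For reference, the options documented in the Scaladoc above are exercised through the standard readStream API; a minimal usage sketch, assuming a `SparkSession` named `spark` and arbitrary example values:

```scala
// Usage sketch for the rate source registered above under shortName "rate".
val rateDF = spark.readStream
  .format("rate")
  .option("rowsPerSecond", "100")   // target rows generated per second
  .option("rampUpTime", "5s")       // ramp up to rowsPerSecond over 5 seconds
  .option("numPartitions", "10")    // partitions for the generated rows
  .load()                           // schema: timestamp TIMESTAMP, value LONG
```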
http://git-wip-us.apache.org/repos/asf/spark/blob/5b5a36ed/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala
deleted file mode 100644
index 4e2459b..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala
+++ /dev/null
@@ -1,187 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming.sources
-
-import java.util.Optional
-
-import scala.collection.JavaConverters._
-import scala.collection.mutable
-
-import org.json4s.DefaultFormats
-import org.json4s.jackson.Serialization
-
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair}
-import org.apache.spark.sql.sources.DataSourceRegister
-import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport}
-import org.apache.spark.sql.sources.v2.reader._
-import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
-import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType}
-import org.apache.spark.util.{ManualClock, SystemClock}
-
-/**
- * This is a temporary register as we build out v2 migration. Microbatch read support should
- * be implemented in the same register as v1.
- */
-class RateSourceProviderV2 extends DataSourceV2 with MicroBatchReadSupport with DataSourceRegister {
- override def createMicroBatchReader(
- schema: Optional[StructType],
- checkpointLocation: String,
- options: DataSourceOptions): MicroBatchReader = {
- new RateStreamMicroBatchReader(options)
- }
-
- override def shortName(): String = "ratev2"
-}
-
-class RateStreamMicroBatchReader(options: DataSourceOptions)
- extends MicroBatchReader {
- implicit val defaultFormats: DefaultFormats = DefaultFormats
-
- val clock = {
- // The option to use a manual clock is provided only for unit testing purposes.
- if (options.get("useManualClock").orElse("false").toBoolean) new ManualClock
- else new SystemClock
- }
-
- private val numPartitions =
- options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt
- private val rowsPerSecond =
- options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong
-
- // The interval (in milliseconds) between rows in each partition.
- // e.g. if there are 4 global rows per second, and 2 partitions, each partition
- // should output rows every (1000 * 2 / 4) = 500 ms.
- private val msPerPartitionBetweenRows = (1000 * numPartitions) / rowsPerSecond
-
- override def readSchema(): StructType = {
- StructType(
- StructField("timestamp", TimestampType, false) ::
- StructField("value", LongType, false) :: Nil)
- }
-
- val creationTimeMs = clock.getTimeMillis()
-
- private var start: RateStreamOffset = _
- private var end: RateStreamOffset = _
-
- override def setOffsetRange(
- start: Optional[Offset],
- end: Optional[Offset]): Unit = {
- this.start = start.orElse(
- RateStreamSourceV2.createInitialOffset(numPartitions, creationTimeMs))
- .asInstanceOf[RateStreamOffset]
-
- this.end = end.orElse {
- val currentTime = clock.getTimeMillis()
- RateStreamOffset(
- this.start.partitionToValueAndRunTimeMs.map {
- case startOffset @ (part, ValueRunTimeMsPair(currentVal, currentReadTime)) =>
- // Calculate the number of rows we should advance in this partition (based on the
- // current time), and output a corresponding offset.
- val readInterval = currentTime - currentReadTime
- val numNewRows = readInterval / msPerPartitionBetweenRows
- if (numNewRows <= 0) {
- startOffset
- } else {
- (part, ValueRunTimeMsPair(
- currentVal + (numNewRows * numPartitions),
- currentReadTime + (numNewRows * msPerPartitionBetweenRows)))
- }
- }
- )
- }.asInstanceOf[RateStreamOffset]
- }
-
- override def getStartOffset(): Offset = {
- if (start == null) throw new IllegalStateException("start offset not set")
- start
- }
- override def getEndOffset(): Offset = {
- if (end == null) throw new IllegalStateException("end offset not set")
- end
- }
-
- override def deserializeOffset(json: String): Offset = {
- RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json))
- }
-
- override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = {
- val startMap = start.partitionToValueAndRunTimeMs
- val endMap = end.partitionToValueAndRunTimeMs
- endMap.keys.toSeq.map { part =>
- val ValueRunTimeMsPair(endVal, _) = endMap(part)
- val ValueRunTimeMsPair(startVal, startTimeMs) = startMap(part)
-
- val packedRows = mutable.ListBuffer[(Long, Long)]()
- var outVal = startVal + numPartitions
- var outTimeMs = startTimeMs
- while (outVal <= endVal) {
- packedRows.append((outTimeMs, outVal))
- outVal += numPartitions
- outTimeMs += msPerPartitionBetweenRows
- }
-
- RateStreamBatchTask(packedRows).asInstanceOf[DataReaderFactory[Row]]
- }.toList.asJava
- }
-
- override def commit(end: Offset): Unit = {}
- override def stop(): Unit = {}
-}
-
-case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends DataReaderFactory[Row] {
- override def createDataReader(): DataReader[Row] = new RateStreamBatchReader(vals)
-}
-
-class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] {
- private var currentIndex = -1
-
- override def next(): Boolean = {
- // Return true as long as the new index is in the seq.
- currentIndex += 1
- currentIndex < vals.size
- }
-
- override def get(): Row = {
- Row(
- DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromMillis(vals(currentIndex)._1)),
- vals(currentIndex)._2)
- }
-
- override def close(): Unit = {}
-}
-
-object RateStreamSourceV2 {
- val NUM_PARTITIONS = "numPartitions"
- val ROWS_PER_SECOND = "rowsPerSecond"
-
- private[sql] def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = {
- RateStreamOffset(
- Range(0, numPartitions).map { i =>
- // Note that the starting offset is exclusive, so we have to decrement the starting value
- // by the increment that will later be applied. The first row output in each
- // partition will have a value equal to the partition index.
- (i,
- ValueRunTimeMsPair(
- (i - numPartitions).toLong,
- creationTimeMs))
- }.toMap)
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/5b5a36ed/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala
deleted file mode 100644
index 03d0f63..0000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala
+++ /dev/null
@@ -1,194 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming
-
-import java.util.concurrent.TimeUnit
-
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.functions._
-import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest}
-import org.apache.spark.util.ManualClock
-
-class RateSourceSuite extends StreamTest {
-
- import testImplicits._
-
- case class AdvanceRateManualClock(seconds: Long) extends AddData {
- override def addData(query: Option[StreamExecution]): (Source, Offset) = {
- assert(query.nonEmpty)
- val rateSource = query.get.logicalPlan.collect {
- case StreamingExecutionRelation(source, _) if source.isInstanceOf[RateStreamSource] =>
- source.asInstanceOf[RateStreamSource]
- }.head
- rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds))
- (rateSource, rateSource.getOffset.get)
- }
- }
-
- test("basic") {
- val input = spark.readStream
- .format("rate")
- .option("rowsPerSecond", "10")
- .option("useManualClock", "true")
- .load()
- testStream(input)(
- AdvanceRateManualClock(seconds = 1),
- CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*),
- StopStream,
- StartStream(),
- // Advance 2 seconds because creating a new RateSource will also create a new ManualClock
- AdvanceRateManualClock(seconds = 2),
- CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*)
- )
- }
-
- test("uniform distribution of event timestamps") {
- val input = spark.readStream
- .format("rate")
- .option("rowsPerSecond", "1500")
- .option("useManualClock", "true")
- .load()
- .as[(java.sql.Timestamp, Long)]
- .map(v => (v._1.getTime, v._2))
- val expectedAnswer = (0 until 1500).map { v =>
- (math.round(v * (1000.0 / 1500)), v)
- }
- testStream(input)(
- AdvanceRateManualClock(seconds = 1),
- CheckLastBatch(expectedAnswer: _*)
- )
- }
-
- test("valueAtSecond") {
- import RateStreamSource._
-
- assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0)
- assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5)
-
- assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 0)
- assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 1)
- assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 3)
- assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 8)
-
- assert(valueAtSecond(seconds = 0, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 0)
- assert(valueAtSecond(seconds = 1, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 2)
- assert(valueAtSecond(seconds = 2, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 6)
- assert(valueAtSecond(seconds = 3, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 12)
- assert(valueAtSecond(seconds = 4, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 20)
- assert(valueAtSecond(seconds = 5, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 30)
- }
-
- test("rampUpTime") {
- val input = spark.readStream
- .format("rate")
- .option("rowsPerSecond", "10")
- .option("rampUpTime", "4s")
- .option("useManualClock", "true")
- .load()
- .as[(java.sql.Timestamp, Long)]
- .map(v => (v._1.getTime, v._2))
- testStream(input)(
- AdvanceRateManualClock(seconds = 1),
- CheckLastBatch((0 until 2).map(v => v * 500 -> v): _*), // speed = 2
- AdvanceRateManualClock(seconds = 1),
- CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4
- AdvanceRateManualClock(seconds = 1),
- CheckLastBatch({
- Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 -> 11)
- }: _*), // speed = 6
- AdvanceRateManualClock(seconds = 1),
- CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8
- AdvanceRateManualClock(seconds = 1),
- // Now we should reach full speed
- CheckLastBatch((20 until 30).map(v => 4000 + (v - 20) * 100 -> v): _*), // speed = 10
- AdvanceRateManualClock(seconds = 1),
- CheckLastBatch((30 until 40).map(v => 5000 + (v - 30) * 100 -> v): _*), // speed = 10
- AdvanceRateManualClock(seconds = 1),
- CheckLastBatch((40 until 50).map(v => 6000 + (v - 40) * 100 -> v): _*) // speed = 10
- )
- }
-
- test("numPartitions") {
- val input = spark.readStream
- .format("rate")
- .option("rowsPerSecond", "10")
- .option("numPartitions", "6")
- .option("useManualClock", "true")
- .load()
- .select(spark_partition_id())
- .distinct()
- testStream(input)(
- AdvanceRateManualClock(1),
- CheckLastBatch((0 until 6): _*)
- )
- }
-
- testQuietly("overflow") {
- val input = spark.readStream
- .format("rate")
- .option("rowsPerSecond", Long.MaxValue.toString)
- .option("useManualClock", "true")
- .load()
- .select(spark_partition_id())
- .distinct()
- testStream(input)(
- AdvanceRateManualClock(2),
- ExpectFailure[ArithmeticException](t => {
- Seq("overflow", "rowsPerSecond").foreach { msg =>
- assert(t.getMessage.contains(msg))
- }
- })
- )
- }
-
- testQuietly("illegal option values") {
- def testIllegalOptionValue(
- option: String,
- value: String,
- expectedMessages: Seq[String]): Unit = {
- val e = intercept[StreamingQueryException] {
- spark.readStream
- .format("rate")
- .option(option, value)
- .load()
- .writeStream
- .format("console")
- .start()
- .awaitTermination()
- }
- assert(e.getCause.isInstanceOf[IllegalArgumentException])
- for (msg <- expectedMessages) {
- assert(e.getCause.getMessage.contains(msg))
- }
- }
-
- testIllegalOptionValue("rowsPerSecond", "-1", Seq("-1", "rowsPerSecond", "positive"))
- testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive"))
- }
-
- test("user-specified schema given") {
- val exception = intercept[AnalysisException] {
- spark.readStream
- .format("rate")
- .schema(spark.range(1).schema)
- .load()
- }
- assert(exception.getMessage.contains(
- "rate source does not support a user-specified schema"))
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/5b5a36ed/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala
deleted file mode 100644
index 983ba16..0000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala
+++ /dev/null
@@ -1,191 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming
-
-import java.util.Optional
-import java.util.concurrent.TimeUnit
-
-import scala.collection.JavaConverters._
-
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.execution.datasources.DataSource
-import org.apache.spark.sql.execution.streaming.continuous._
-import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2}
-import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, MicroBatchReadSupport}
-import org.apache.spark.sql.sources.v2.DataSourceOptions
-import org.apache.spark.sql.streaming.StreamTest
-import org.apache.spark.util.ManualClock
-
-class RateSourceV2Suite extends StreamTest {
- import testImplicits._
-
- case class AdvanceRateManualClock(seconds: Long) extends AddData {
- override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {
- assert(query.nonEmpty)
- val rateSource = query.get.logicalPlan.collect {
- case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source
- }.head
- rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds))
- rateSource.setOffsetRange(Optional.empty(), Optional.empty())
- (rateSource, rateSource.getEndOffset())
- }
- }
-
- test("microbatch in registry") {
- DataSource.lookupDataSource("ratev2", spark.sqlContext.conf).newInstance() match {
- case ds: MicroBatchReadSupport =>
- val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceOptions.empty())
- assert(reader.isInstanceOf[RateStreamMicroBatchReader])
- case _ =>
- throw new IllegalStateException("Could not find v2 read support for rate")
- }
- }
-
- test("basic microbatch execution") {
- val input = spark.readStream
- .format("rateV2")
- .option("numPartitions", "1")
- .option("rowsPerSecond", "10")
- .option("useManualClock", "true")
- .load()
- testStream(input, useV2Sink = true)(
- AdvanceRateManualClock(seconds = 1),
- CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*),
- StopStream,
- StartStream(),
- // Advance 2 seconds because creating a new RateSource will also create a new ManualClock
- AdvanceRateManualClock(seconds = 2),
- CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*)
- )
- }
-
- test("microbatch - numPartitions propagated") {
- val reader = new RateStreamMicroBatchReader(
- new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava))
- reader.setOffsetRange(Optional.empty(), Optional.empty())
- val tasks = reader.createDataReaderFactories()
- assert(tasks.size == 11)
- }
-
- test("microbatch - set offset") {
- val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty())
- val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000))))
- val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 2000))))
- reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
- assert(reader.getStartOffset() == startOffset)
- assert(reader.getEndOffset() == endOffset)
- }
-
- test("microbatch - infer offsets") {
- val reader = new RateStreamMicroBatchReader(
- new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "100").asJava))
- reader.clock.waitTillTime(reader.clock.getTimeMillis() + 100)
- reader.setOffsetRange(Optional.empty(), Optional.empty())
- reader.getStartOffset() match {
- case r: RateStreamOffset =>
- assert(r.partitionToValueAndRunTimeMs(0).runTimeMs == reader.creationTimeMs)
- case _ => throw new IllegalStateException("unexpected offset type")
- }
- reader.getEndOffset() match {
- case r: RateStreamOffset =>
- // End offset may be a bit beyond 100 ms/9 rows after creation if the wait lasted
- // longer than 100ms. It should never be early.
- assert(r.partitionToValueAndRunTimeMs(0).value >= 9)
- assert(r.partitionToValueAndRunTimeMs(0).runTimeMs >= reader.creationTimeMs + 100)
-
- case _ => throw new IllegalStateException("unexpected offset type")
- }
- }
-
- test("microbatch - predetermined batch size") {
- val reader = new RateStreamMicroBatchReader(
- new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava))
- val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000))))
- val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000))))
- reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
- val tasks = reader.createDataReaderFactories()
- assert(tasks.size == 1)
- assert(tasks.get(0).asInstanceOf[RateStreamBatchTask].vals.size == 20)
- }
-
- test("microbatch - data read") {
- val reader = new RateStreamMicroBatchReader(
- new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava))
- val startOffset = RateStreamSourceV2.createInitialOffset(11, reader.creationTimeMs)
- val endOffset = RateStreamOffset(startOffset.partitionToValueAndRunTimeMs.toSeq.map {
- case (part, ValueRunTimeMsPair(currentVal, currentReadTime)) =>
- (part, ValueRunTimeMsPair(currentVal + 33, currentReadTime + 1000))
- }.toMap)
-
- reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
- val tasks = reader.createDataReaderFactories()
- assert(tasks.size == 11)
-
- val readData = tasks.asScala
- .map(_.createDataReader())
- .flatMap { reader =>
- val buf = scala.collection.mutable.ListBuffer[Row]()
- while (reader.next()) buf.append(reader.get())
- buf
- }
-
- assert(readData.map(_.getLong(1)).sorted == Range(0, 33))
- }
-
- test("continuous in registry") {
- DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match {
- case ds: ContinuousReadSupport =>
- val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty())
- assert(reader.isInstanceOf[RateStreamContinuousReader])
- case _ =>
- throw new IllegalStateException("Could not find v2 read support for rate")
- }
- }
-
- test("continuous data") {
- val reader = new RateStreamContinuousReader(
- new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava))
- reader.setStartOffset(Optional.empty())
- val tasks = reader.createDataReaderFactories()
- assert(tasks.size == 2)
-
- val data = scala.collection.mutable.ListBuffer[Row]()
- tasks.asScala.foreach {
- case t: RateStreamContinuousDataReaderFactory =>
- val startTimeMs = reader.getStartOffset()
- .asInstanceOf[RateStreamOffset]
- .partitionToValueAndRunTimeMs(t.partitionIndex)
- .runTimeMs
- val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader]
- for (rowIndex <- 0 to 9) {
- r.next()
- data.append(r.get())
- assert(r.getOffset() ==
- RateStreamPartitionOffset(
- t.partitionIndex,
- t.partitionIndex + rowIndex * 2,
- startTimeMs + (rowIndex + 1) * 100))
- }
- assert(System.currentTimeMillis() >= startTimeMs + 1000)
-
- case _ => throw new IllegalStateException("Unexpected task type")
- }
-
- assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20))
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/5b5a36ed/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala
new file mode 100644
index 0000000..ff14ec3
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala
@@ -0,0 +1,334 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import java.nio.file.Files
+import java.util.Optional
+import java.util.concurrent.TimeUnit
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.execution.datasources.DataSource
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.continuous._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport}
+import org.apache.spark.sql.sources.v2.reader.streaming.Offset
+import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.util.ManualClock
+
+class RateSourceSuite extends StreamTest {
+
+ import testImplicits._
+
+ case class AdvanceRateManualClock(seconds: Long) extends AddData {
+ override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {
+ assert(query.nonEmpty)
+ val rateSource = query.get.logicalPlan.collect {
+ case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source
+ }.head
+
+ rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds))
+ val offset = LongOffset(TimeUnit.MILLISECONDS.toSeconds(
+ rateSource.clock.getTimeMillis() - rateSource.creationTimeMs))
+ (rateSource, offset)
+ }
+ }
+
+ test("microbatch in registry") {
+ DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match {
+ case ds: MicroBatchReadSupport =>
+ val reader = ds.createMicroBatchReader(Optional.empty(), "dummy", DataSourceOptions.empty())
+ assert(reader.isInstanceOf[RateStreamMicroBatchReader])
+ case _ =>
+ throw new IllegalStateException("Could not find read support for rate")
+ }
+ }
+
+ test("compatible with old path in registry") {
+ DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider",
+ spark.sqlContext.conf).newInstance() match {
+ case ds: MicroBatchReadSupport =>
+ assert(ds.isInstanceOf[RateStreamProvider])
+ case _ =>
+ throw new IllegalStateException("Could not find read support for rate")
+ }
+ }
+
+ test("microbatch - basic") {
+ val input = spark.readStream
+ .format("rate")
+ .option("rowsPerSecond", "10")
+ .option("useManualClock", "true")
+ .load()
+ testStream(input)(
+ AdvanceRateManualClock(seconds = 1),
+ CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*),
+ StopStream,
+ StartStream(),
+ // Advance 2 seconds because creating a new RateSource will also create a new ManualClock
+ AdvanceRateManualClock(seconds = 2),
+ CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*)
+ )
+ }
+
+ test("microbatch - uniform distribution of event timestamps") {
+ val input = spark.readStream
+ .format("rate")
+ .option("rowsPerSecond", "1500")
+ .option("useManualClock", "true")
+ .load()
+ .as[(java.sql.Timestamp, Long)]
+ .map(v => (v._1.getTime, v._2))
+ val expectedAnswer = (0 until 1500).map { v =>
+ (math.round(v * (1000.0 / 1500)), v)
+ }
+ testStream(input)(
+ AdvanceRateManualClock(seconds = 1),
+ CheckLastBatch(expectedAnswer: _*)
+ )
+ }
+
+ test("microbatch - set offset") {
+ val temp = Files.createTempDirectory("dummy").toString
+ val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty(), temp)
+ val startOffset = LongOffset(0L)
+ val endOffset = LongOffset(1L)
+ reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
+ assert(reader.getStartOffset() == startOffset)
+ assert(reader.getEndOffset() == endOffset)
+ }
+
+ test("microbatch - infer offsets") {
+ val tempFolder = Files.createTempDirectory("dummy").toString
+ val reader = new RateStreamMicroBatchReader(
+ new DataSourceOptions(
+ Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava),
+ tempFolder)
+ reader.clock.asInstanceOf[ManualClock].advance(100000)
+ reader.setOffsetRange(Optional.empty(), Optional.empty())
+ reader.getStartOffset() match {
+ case r: LongOffset => assert(r.offset === 0L)
+ case _ => throw new IllegalStateException("unexpected offset type")
+ }
+ reader.getEndOffset() match {
+ case r: LongOffset => assert(r.offset >= 100)
+ case _ => throw new IllegalStateException("unexpected offset type")
+ }
+ }
+
+ test("microbatch - predetermined batch size") {
+ val temp = Files.createTempDirectory("dummy").toString
+ val reader = new RateStreamMicroBatchReader(
+ new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava), temp)
+ val startOffset = LongOffset(0L)
+ val endOffset = LongOffset(1L)
+ reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
+ val tasks = reader.createDataReaderFactories()
+ assert(tasks.size == 1)
+ val dataReader = tasks.get(0).createDataReader()
+ val data = ArrayBuffer[Row]()
+ while (dataReader.next()) {
+ data.append(dataReader.get())
+ }
+ assert(data.size === 20)
+ }
+
+ test("microbatch - data read") {
+ val temp = Files.createTempDirectory("dummy").toString
+ val reader = new RateStreamMicroBatchReader(
+ new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava), temp)
+ val startOffset = LongOffset(0L)
+ val endOffset = LongOffset(1L)
+ reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
+ val tasks = reader.createDataReaderFactories()
+ assert(tasks.size == 11)
+
+ val readData = tasks.asScala
+ .map(_.createDataReader())
+ .flatMap { reader =>
+ val buf = scala.collection.mutable.ListBuffer[Row]()
+ while (reader.next()) buf.append(reader.get())
+ buf
+ }
+
+ assert(readData.map(_.getLong(1)).sorted == Range(0, 33))
+ }
+
+ test("valueAtSecond") {
+ import RateStreamProvider._
+
+ assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0)
+ assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5)
+
+ assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 0)
+ assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 1)
+ assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 3)
+ assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 8)
+
+ assert(valueAtSecond(seconds = 0, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 0)
+ assert(valueAtSecond(seconds = 1, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 2)
+ assert(valueAtSecond(seconds = 2, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 6)
+ assert(valueAtSecond(seconds = 3, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 12)
+ assert(valueAtSecond(seconds = 4, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 20)
+ assert(valueAtSecond(seconds = 5, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 30)
+ }
+
+ test("rampUpTime") {
+ val input = spark.readStream
+ .format("rate")
+ .option("rowsPerSecond", "10")
+ .option("rampUpTime", "4s")
+ .option("useManualClock", "true")
+ .load()
+ .as[(java.sql.Timestamp, Long)]
+ .map(v => (v._1.getTime, v._2))
+ testStream(input)(
+ AdvanceRateManualClock(seconds = 1),
+ CheckLastBatch((0 until 2).map(v => v * 500 -> v): _*), // speed = 2
+ AdvanceRateManualClock(seconds = 1),
+ CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4
+ AdvanceRateManualClock(seconds = 1),
+ CheckLastBatch({
+ Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 -> 11)
+ }: _*), // speed = 6
+ AdvanceRateManualClock(seconds = 1),
+ CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8
+ AdvanceRateManualClock(seconds = 1),
+ // Now we should reach full speed
+ CheckLastBatch((20 until 30).map(v => 4000 + (v - 20) * 100 -> v): _*), // speed = 10
+ AdvanceRateManualClock(seconds = 1),
+ CheckLastBatch((30 until 40).map(v => 5000 + (v - 30) * 100 -> v): _*), // speed = 10
+ AdvanceRateManualClock(seconds = 1),
+ CheckLastBatch((40 until 50).map(v => 6000 + (v - 40) * 100 -> v): _*) // speed = 10
+ )
+ }
+
+ test("numPartitions") {
+ val input = spark.readStream
+ .format("rate")
+ .option("rowsPerSecond", "10")
+ .option("numPartitions", "6")
+ .option("useManualClock", "true")
+ .load()
+ .select(spark_partition_id())
+ .distinct()
+ testStream(input)(
+ AdvanceRateManualClock(1),
+ CheckLastBatch((0 until 6): _*)
+ )
+ }
+
+ testQuietly("overflow") {
+ val input = spark.readStream
+ .format("rate")
+ .option("rowsPerSecond", Long.MaxValue.toString)
+ .option("useManualClock", "true")
+ .load()
+ .select(spark_partition_id())
+ .distinct()
+ testStream(input)(
+ AdvanceRateManualClock(2),
+ ExpectFailure[ArithmeticException](t => {
+ Seq("overflow", "rowsPerSecond").foreach { msg =>
+ assert(t.getMessage.contains(msg))
+ }
+ })
+ )
+ }
+
+ testQuietly("illegal option values") {
+ def testIllegalOptionValue(
+ option: String,
+ value: String,
+ expectedMessages: Seq[String]): Unit = {
+ val e = intercept[IllegalArgumentException] {
+ spark.readStream
+ .format("rate")
+ .option(option, value)
+ .load()
+ .writeStream
+ .format("console")
+ .start()
+ .awaitTermination()
+ }
+ for (msg <- expectedMessages) {
+ assert(e.getMessage.contains(msg))
+ }
+ }
+
+ testIllegalOptionValue("rowsPerSecond", "-1", Seq("-1", "rowsPerSecond", "positive"))
+ testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive"))
+ }
+
+ test("user-specified schema given") {
+ val exception = intercept[AnalysisException] {
+ spark.readStream
+ .format("rate")
+ .schema(spark.range(1).schema)
+ .load()
+ }
+ assert(exception.getMessage.contains(
+ "rate source does not support a user-specified schema"))
+ }
+
+ test("continuous in registry") {
+ DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match {
+ case ds: ContinuousReadSupport =>
+ val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty())
+ assert(reader.isInstanceOf[RateStreamContinuousReader])
+ case _ =>
+ throw new IllegalStateException("Could not find read support for continuous rate")
+ }
+ }
+
+ test("continuous data") {
+ val reader = new RateStreamContinuousReader(
+ new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava))
+ reader.setStartOffset(Optional.empty())
+ val tasks = reader.createDataReaderFactories()
+ assert(tasks.size == 2)
+
+ val data = scala.collection.mutable.ListBuffer[Row]()
+ tasks.asScala.foreach {
+ case t: RateStreamContinuousDataReaderFactory =>
+ val startTimeMs = reader.getStartOffset()
+ .asInstanceOf[RateStreamOffset]
+ .partitionToValueAndRunTimeMs(t.partitionIndex)
+ .runTimeMs
+ val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader]
+ for (rowIndex <- 0 to 9) {
+ r.next()
+ data.append(r.get())
+ assert(r.getOffset() ==
+ RateStreamPartitionOffset(
+ t.partitionIndex,
+ t.partitionIndex + rowIndex * 2,
+ startTimeMs + (rowIndex + 1) * 100))
+ }
+ assert(System.currentTimeMillis() >= startTimeMs + 1000)
+
+ case _ => throw new IllegalStateException("Unexpected task type")
+ }
+
+ assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20))
+ }
+}
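
For readers following the migrated tests above, here is a minimal usage sketch of the rate source that this suite exercises. It is not part of the commit: the SparkSession setup, app name, and local master are illustrative assumptions; the format name "rate" and the rowsPerSecond, numPartitions, and rampUpTime options are the ones the tests above use, and the source emits rows of (timestamp: Timestamp, value: Long) as the .as[(java.sql.Timestamp, Long)] casts in the tests rely on.

import org.apache.spark.sql.SparkSession

object RateSourceExample {
  def main(args: Array[String]): Unit = {
    // Hypothetical local session purely for illustration.
    val spark = SparkSession.builder()
      .appName("rate-source-example")
      .master("local[2]")
      .getOrCreate()

    // Read from the rate source with the same options the tests exercise:
    // rowsPerSecond controls throughput, numPartitions the parallelism,
    // and rampUpTime how long the source takes to reach full speed.
    val input = spark.readStream
      .format("rate")
      .option("rowsPerSecond", "10")
      .option("numPartitions", "2")
      .option("rampUpTime", "4s")
      .load()

    // Each row carries an event timestamp and a monotonically increasing value.
    val query = input.writeStream
      .format("console")
      .start()

    query.awaitTermination()
  }
}
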