You are viewing a plain text version of this content. The canonical link for it was a hyperlink in the original page and was lost in this plain-text conversion.
Posted to commits@spark.apache.org by we...@apache.org on 2018/03/30 13:54:39 UTC

spark git commit: Roll forward "[SPARK-23096][SS] Migrate rate source to V2"

Repository: spark
Updated Branches:
  refs/heads/master b02e76cbf -> 5b5a36ed6


Roll forward "[SPARK-23096][SS] Migrate rate source to V2"

## What changes were proposed in this pull request?

Roll forward c68ec4e (#20688), which had previously been reverted.

There are two minor test changes required:

* An error that used to surface as a `TreeNodeException` wrapping an `ArithmeticException` is no longer wrapped; the tests now see the bare `ArithmeticException`.
* The test framework does not set the active Spark session early enough (I believe it is only set when a query is analyzed), but the new rate source reader requires a Spark session to be available before that point. I've added the required logic to SQLTestUtils.

## How was this patch tested?

Existing tests.

Author: Jose Torres <to...@gmail.com>
Author: jerryshao <ss...@hortonworks.com>

Closes #20922 from jose-torres/ratefix.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5b5a36ed
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5b5a36ed
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5b5a36ed

Branch: refs/heads/master
Commit: 5b5a36ed6d2bb0971edfeccddf0f280936d2275f
Parents: b02e76c
Author: Jose Torres <to...@gmail.com>
Authored: Fri Mar 30 21:54:26 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Fri Mar 30 21:54:26 2018 +0800

----------------------------------------------------------------------
 ....apache.spark.sql.sources.DataSourceRegister |   3 +-
 .../sql/execution/datasources/DataSource.scala  |   6 +-
 .../streaming/RateSourceProvider.scala          | 262 ---------------
 .../continuous/ContinuousRateStreamSource.scala |  25 +-
 .../sources/RateStreamMicroBatchReader.scala    | 222 ++++++++++++
 .../streaming/sources/RateStreamProvider.scala  | 125 +++++++
 .../streaming/sources/RateStreamSourceV2.scala  | 187 -----------
 .../execution/streaming/RateSourceSuite.scala   | 194 -----------
 .../execution/streaming/RateSourceV2Suite.scala | 191 -----------
 .../sources/RateStreamProviderSuite.scala       | 334 +++++++++++++++++++
 10 files changed, 705 insertions(+), 844 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5b5a36ed/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
----------------------------------------------------------------------
diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
index 1fe9c09..1b37905 100644
--- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
+++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -5,6 +5,5 @@ org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 org.apache.spark.sql.execution.datasources.text.TextFileFormat
 org.apache.spark.sql.execution.streaming.ConsoleSinkProvider
-org.apache.spark.sql.execution.streaming.RateSourceProvider
+org.apache.spark.sql.execution.streaming.sources.RateStreamProvider
 org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider
-org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2

http://git-wip-us.apache.org/repos/asf/spark/blob/5b5a36ed/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 31fa89b..b84ea76 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
 import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider
+import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.streaming.OutputMode
@@ -566,6 +566,7 @@ object DataSource extends Logging {
     val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat"
     val nativeOrc = classOf[OrcFileFormat].getCanonicalName
     val socket = classOf[TextSocketSourceProvider].getCanonicalName
+    val rate = classOf[RateStreamProvider].getCanonicalName
 
     Map(
       "org.apache.spark.sql.jdbc" -> jdbc,
@@ -587,7 +588,8 @@ object DataSource extends Logging {
       "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm,
       "org.apache.spark.ml.source.libsvm" -> libsvm,
       "com.databricks.spark.csv" -> csv,
-      "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket
+      "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket,
+      "org.apache.spark.sql.execution.streaming.RateSourceProvider" -> rate
     )
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/5b5a36ed/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala
deleted file mode 100644
index 649fbbf..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala
+++ /dev/null
@@ -1,262 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming
-
-import java.io._
-import java.nio.charset.StandardCharsets
-import java.util.Optional
-import java.util.concurrent.TimeUnit
-
-import org.apache.commons.io.IOUtils
-
-import org.apache.spark.internal.Logging
-import org.apache.spark.network.util.JavaUtils
-import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext}
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
-import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader
-import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
-import org.apache.spark.sql.sources.v2._
-import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader
-import org.apache.spark.sql.types._
-import org.apache.spark.util.{ManualClock, SystemClock}
-
-/**
- *  A source that generates increment long values with timestamps. Each generated row has two
- *  columns: a timestamp column for the generated time and an auto increment long column starting
- *  with 0L.
- *
- *  This source supports the following options:
- *  - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second.
- *  - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed
- *    becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer
- *    seconds.
- *  - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the
- *    generated rows. The source will try its best to reach `rowsPerSecond`, but the query may
- *    be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed.
- */
-class RateSourceProvider extends StreamSourceProvider with DataSourceRegister
-  with DataSourceV2 with ContinuousReadSupport {
-
-  override def sourceSchema(
-      sqlContext: SQLContext,
-      schema: Option[StructType],
-      providerName: String,
-      parameters: Map[String, String]): (String, StructType) = {
-    if (schema.nonEmpty) {
-      throw new AnalysisException("The rate source does not support a user-specified schema.")
-    }
-
-    (shortName(), RateSourceProvider.SCHEMA)
-  }
-
-  override def createSource(
-      sqlContext: SQLContext,
-      metadataPath: String,
-      schema: Option[StructType],
-      providerName: String,
-      parameters: Map[String, String]): Source = {
-    val params = CaseInsensitiveMap(parameters)
-
-    val rowsPerSecond = params.get("rowsPerSecond").map(_.toLong).getOrElse(1L)
-    if (rowsPerSecond <= 0) {
-      throw new IllegalArgumentException(
-        s"Invalid value '${params("rowsPerSecond")}'. The option 'rowsPerSecond' " +
-          "must be positive")
-    }
-
-    val rampUpTimeSeconds =
-      params.get("rampUpTime").map(JavaUtils.timeStringAsSec(_)).getOrElse(0L)
-    if (rampUpTimeSeconds < 0) {
-      throw new IllegalArgumentException(
-        s"Invalid value '${params("rampUpTime")}'. The option 'rampUpTime' " +
-          "must not be negative")
-    }
-
-    val numPartitions = params.get("numPartitions").map(_.toInt).getOrElse(
-      sqlContext.sparkContext.defaultParallelism)
-    if (numPartitions <= 0) {
-      throw new IllegalArgumentException(
-        s"Invalid value '${params("numPartitions")}'. The option 'numPartitions' " +
-          "must be positive")
-    }
-
-    new RateStreamSource(
-      sqlContext,
-      metadataPath,
-      rowsPerSecond,
-      rampUpTimeSeconds,
-      numPartitions,
-      params.get("useManualClock").map(_.toBoolean).getOrElse(false) // Only for testing
-    )
-  }
-
-  override def createContinuousReader(
-      schema: Optional[StructType],
-      checkpointLocation: String,
-      options: DataSourceOptions): ContinuousReader = {
-    new RateStreamContinuousReader(options)
-  }
-
-  override def shortName(): String = "rate"
-}
-
-object RateSourceProvider {
-  val SCHEMA =
-    StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil)
-
-  val VERSION = 1
-}
-
-class RateStreamSource(
-    sqlContext: SQLContext,
-    metadataPath: String,
-    rowsPerSecond: Long,
-    rampUpTimeSeconds: Long,
-    numPartitions: Int,
-    useManualClock: Boolean) extends Source with Logging {
-
-  import RateSourceProvider._
-  import RateStreamSource._
-
-  val clock = if (useManualClock) new ManualClock else new SystemClock
-
-  private val maxSeconds = Long.MaxValue / rowsPerSecond
-
-  if (rampUpTimeSeconds > maxSeconds) {
-    throw new ArithmeticException(
-      s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" +
-        s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.")
-  }
-
-  private val startTimeMs = {
-    val metadataLog =
-      new HDFSMetadataLog[LongOffset](sqlContext.sparkSession, metadataPath) {
-        override def serialize(metadata: LongOffset, out: OutputStream): Unit = {
-          val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8))
-          writer.write("v" + VERSION + "\n")
-          writer.write(metadata.json)
-          writer.flush
-        }
-
-        override def deserialize(in: InputStream): LongOffset = {
-          val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8))
-          // HDFSMetadataLog guarantees that it never creates a partial file.
-          assert(content.length != 0)
-          if (content(0) == 'v') {
-            val indexOfNewLine = content.indexOf("\n")
-            if (indexOfNewLine > 0) {
-              val version = parseVersion(content.substring(0, indexOfNewLine), VERSION)
-              LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1)))
-            } else {
-              throw new IllegalStateException(
-                s"Log file was malformed: failed to detect the log file version line.")
-            }
-          } else {
-            throw new IllegalStateException(
-              s"Log file was malformed: failed to detect the log file version line.")
-          }
-        }
-      }
-
-    metadataLog.get(0).getOrElse {
-      val offset = LongOffset(clock.getTimeMillis())
-      metadataLog.add(0, offset)
-      logInfo(s"Start time: $offset")
-      offset
-    }.offset
-  }
-
-  /** When the system time runs backward, "lastTimeMs" will make sure we are still monotonic. */
-  @volatile private var lastTimeMs = startTimeMs
-
-  override def schema: StructType = RateSourceProvider.SCHEMA
-
-  override def getOffset: Option[Offset] = {
-    val now = clock.getTimeMillis()
-    if (lastTimeMs < now) {
-      lastTimeMs = now
-    }
-    Some(LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - startTimeMs)))
-  }
-
-  override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
-    val startSeconds = start.flatMap(LongOffset.convert(_).map(_.offset)).getOrElse(0L)
-    val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L)
-    assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)")
-    if (endSeconds > maxSeconds) {
-      throw new ArithmeticException("Integer overflow. Max offset with " +
-        s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.")
-    }
-    // Fix "lastTimeMs" for recovery
-    if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs) {
-      lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + startTimeMs
-    }
-    val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds)
-    val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds)
-    logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " +
-      s"rangeStart: $rangeStart, rangeEnd: $rangeEnd")
-
-    if (rangeStart == rangeEnd) {
-      return sqlContext.internalCreateDataFrame(
-        sqlContext.sparkContext.emptyRDD, schema, isStreaming = true)
-    }
-
-    val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds)
-    val relativeMsPerValue =
-      TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart)
-
-    val rdd = sqlContext.sparkContext.range(rangeStart, rangeEnd, 1, numPartitions).map { v =>
-      val relative = math.round((v - rangeStart) * relativeMsPerValue)
-      InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v)
-    }
-    sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true)
-  }
-
-  override def stop(): Unit = {}
-
-  override def toString: String = s"RateSource[rowsPerSecond=$rowsPerSecond, " +
-    s"rampUpTimeSeconds=$rampUpTimeSeconds, numPartitions=$numPartitions]"
-}
-
-object RateStreamSource {
-
-  /** Calculate the end value we will emit at the time `seconds`. */
-  def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = {
-    // E.g., rampUpTimeSeconds = 4, rowsPerSecond = 10
-    // Then speedDeltaPerSecond = 2
-    //
-    // seconds   = 0 1 2  3  4  5  6
-    // speed     = 0 2 4  6  8 10 10 (speedDeltaPerSecond * seconds)
-    // end value = 0 2 6 12 20 30 40 (0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2
-    val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1)
-    if (seconds <= rampUpTimeSeconds) {
-      // Calculate "(0 + speedDeltaPerSecond * seconds) * (seconds + 1) / 2" in a special way to
-      // avoid overflow
-      if (seconds % 2 == 1) {
-        (seconds + 1) / 2 * speedDeltaPerSecond * seconds
-      } else {
-        seconds / 2 * speedDeltaPerSecond * (seconds + 1)
-      }
-    } else {
-      // rampUpPart is just a special case of the above formula: rampUpTimeSeconds == seconds
-      val rampUpPart = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds)
-      rampUpPart + (seconds - rampUpTimeSeconds) * rowsPerSecond
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/5b5a36ed/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
index 20d9006..2f0de26 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
@@ -24,8 +24,8 @@ import org.json4s.jackson.Serialization
 
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.streaming.{RateSourceProvider, RateStreamOffset, ValueRunTimeMsPair}
-import org.apache.spark.sql.execution.streaming.sources.RateStreamSourceV2
+import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair}
+import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider
 import org.apache.spark.sql.sources.v2.DataSourceOptions
 import org.apache.spark.sql.sources.v2.reader._
 import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset}
@@ -40,8 +40,8 @@ class RateStreamContinuousReader(options: DataSourceOptions)
 
   val creationTime = System.currentTimeMillis()
 
-  val numPartitions = options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt
-  val rowsPerSecond = options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong
+  val numPartitions = options.get(RateStreamProvider.NUM_PARTITIONS).orElse("5").toInt
+  val rowsPerSecond = options.get(RateStreamProvider.ROWS_PER_SECOND).orElse("6").toLong
   val perPartitionRate = rowsPerSecond.toDouble / numPartitions.toDouble
 
   override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = {
@@ -57,12 +57,12 @@ class RateStreamContinuousReader(options: DataSourceOptions)
     RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json))
   }
 
-  override def readSchema(): StructType = RateSourceProvider.SCHEMA
+  override def readSchema(): StructType = RateStreamProvider.SCHEMA
 
   private var offset: Offset = _
 
   override def setStartOffset(offset: java.util.Optional[Offset]): Unit = {
-    this.offset = offset.orElse(RateStreamSourceV2.createInitialOffset(numPartitions, creationTime))
+    this.offset = offset.orElse(createInitialOffset(numPartitions, creationTime))
   }
 
   override def getStartOffset(): Offset = offset
@@ -98,6 +98,19 @@ class RateStreamContinuousReader(options: DataSourceOptions)
   override def commit(end: Offset): Unit = {}
   override def stop(): Unit = {}
 
+  private def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = {
+    RateStreamOffset(
+      Range(0, numPartitions).map { i =>
+        // Note that the starting offset is exclusive, so we have to decrement the starting value
+        // by the increment that will later be applied. The first row output in each
+        // partition will have a value equal to the partition index.
+        (i,
+          ValueRunTimeMsPair(
+            (i - numPartitions).toLong,
+            creationTimeMs))
+      }.toMap)
+  }
+
 }
 
 case class RateStreamContinuousDataReaderFactory(

http://git-wip-us.apache.org/repos/asf/spark/blob/5b5a36ed/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala
new file mode 100644
index 0000000..6cf8520
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import java.io._
+import java.nio.charset.StandardCharsets
+import java.util.Optional
+import java.util.concurrent.TimeUnit
+
+import scala.collection.JavaConverters._
+
+import org.apache.commons.io.IOUtils
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.sources.v2.DataSourceOptions
+import org.apache.spark.sql.sources.v2.reader._
+import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.{ManualClock, SystemClock}
+
+class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String)
+  extends MicroBatchReader with Logging {
+  import RateStreamProvider._
+
+  private[sources] val clock = {
+    // The option to use a manual clock is provided only for unit testing purposes.
+    if (options.getBoolean("useManualClock", false)) new ManualClock else new SystemClock
+  }
+
+  private val rowsPerSecond =
+    options.get(ROWS_PER_SECOND).orElse("1").toLong
+
+  private val rampUpTimeSeconds =
+    Option(options.get(RAMP_UP_TIME).orElse(null.asInstanceOf[String]))
+      .map(JavaUtils.timeStringAsSec(_))
+      .getOrElse(0L)
+
+  private val maxSeconds = Long.MaxValue / rowsPerSecond
+
+  if (rampUpTimeSeconds > maxSeconds) {
+    throw new ArithmeticException(
+      s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" +
+        s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.")
+  }
+
+  private[sources] val creationTimeMs = {
+    val session = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession)
+    require(session.isDefined)
+
+    val metadataLog =
+      new HDFSMetadataLog[LongOffset](session.get, checkpointLocation) {
+        override def serialize(metadata: LongOffset, out: OutputStream): Unit = {
+          val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8))
+          writer.write("v" + VERSION + "\n")
+          writer.write(metadata.json)
+          writer.flush
+        }
+
+        override def deserialize(in: InputStream): LongOffset = {
+          val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8))
+          // HDFSMetadataLog guarantees that it never creates a partial file.
+          assert(content.length != 0)
+          if (content(0) == 'v') {
+            val indexOfNewLine = content.indexOf("\n")
+            if (indexOfNewLine > 0) {
+              parseVersion(content.substring(0, indexOfNewLine), VERSION)
+              LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1)))
+            } else {
+              throw new IllegalStateException(
+                s"Log file was malformed: failed to detect the log file version line.")
+            }
+          } else {
+            throw new IllegalStateException(
+              s"Log file was malformed: failed to detect the log file version line.")
+          }
+        }
+      }
+
+    metadataLog.get(0).getOrElse {
+      val offset = LongOffset(clock.getTimeMillis())
+      metadataLog.add(0, offset)
+      logInfo(s"Start time: $offset")
+      offset
+    }.offset
+  }
+
+  @volatile private var lastTimeMs: Long = creationTimeMs
+
+  private var start: LongOffset = _
+  private var end: LongOffset = _
+
+  override def readSchema(): StructType = SCHEMA
+
+  override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = {
+    this.start = start.orElse(LongOffset(0L)).asInstanceOf[LongOffset]
+    this.end = end.orElse {
+      val now = clock.getTimeMillis()
+      if (lastTimeMs < now) {
+        lastTimeMs = now
+      }
+      LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs))
+    }.asInstanceOf[LongOffset]
+  }
+
+  override def getStartOffset(): Offset = {
+    if (start == null) throw new IllegalStateException("start offset not set")
+    start
+  }
+  override def getEndOffset(): Offset = {
+    if (end == null) throw new IllegalStateException("end offset not set")
+    end
+  }
+
+  override def deserializeOffset(json: String): Offset = {
+    LongOffset(json.toLong)
+  }
+
+  override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = {
+    val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L)
+    val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L)
+    assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)")
+    if (endSeconds > maxSeconds) {
+      throw new ArithmeticException("Integer overflow. Max offset with " +
+        s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.")
+    }
+    // Fix "lastTimeMs" for recovery
+    if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs) {
+      lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs
+    }
+    val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds)
+    val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds)
+    logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " +
+      s"rangeStart: $rangeStart, rangeEnd: $rangeEnd")
+
+    if (rangeStart == rangeEnd) {
+      return List.empty.asJava
+    }
+
+    val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds)
+    val relativeMsPerValue =
+      TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart)
+    val numPartitions = {
+      val activeSession = SparkSession.getActiveSession
+      require(activeSession.isDefined)
+      Option(options.get(NUM_PARTITIONS).orElse(null.asInstanceOf[String]))
+        .map(_.toInt)
+        .getOrElse(activeSession.get.sparkContext.defaultParallelism)
+    }
+
+    (0 until numPartitions).map { p =>
+      new RateStreamMicroBatchDataReaderFactory(
+        p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue)
+        : DataReaderFactory[Row]
+    }.toList.asJava
+  }
+
+  override def commit(end: Offset): Unit = {}
+
+  override def stop(): Unit = {}
+
+  override def toString: String = s"MicroBatchRateSource[rowsPerSecond=$rowsPerSecond, " +
+    s"rampUpTimeSeconds=$rampUpTimeSeconds, " +
+    s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}"
+}
+
+class RateStreamMicroBatchDataReaderFactory(
+    partitionId: Int,
+    numPartitions: Int,
+    rangeStart: Long,
+    rangeEnd: Long,
+    localStartTimeMs: Long,
+    relativeMsPerValue: Double) extends DataReaderFactory[Row] {
+
+  override def createDataReader(): DataReader[Row] = new RateStreamMicroBatchDataReader(
+    partitionId, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue)
+}
+
+class RateStreamMicroBatchDataReader(
+    partitionId: Int,
+    numPartitions: Int,
+    rangeStart: Long,
+    rangeEnd: Long,
+    localStartTimeMs: Long,
+    relativeMsPerValue: Double) extends DataReader[Row] {
+  private var count = 0
+
+  override def next(): Boolean = {
+    rangeStart + partitionId + numPartitions * count < rangeEnd
+  }
+
+  override def get(): Row = {
+    val currValue = rangeStart + partitionId + numPartitions * count
+    count += 1
+    val relative = math.round((currValue - rangeStart) * relativeMsPerValue)
+    Row(
+      DateTimeUtils.toJavaTimestamp(
+        DateTimeUtils.fromMillis(relative + localStartTimeMs)),
+      currValue
+    )
+  }
+
+  override def close(): Unit = {}
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/5b5a36ed/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala
new file mode 100644
index 0000000..6bdd492
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import java.util.Optional
+
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader
+import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.sources.v2._
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader}
+import org.apache.spark.sql.types._
+
+/**
+ *  A source that generates incrementing long values with timestamps. Each generated row has two
+ *  columns: a timestamp column for the generated time and an auto-incrementing long column
+ *  starting with 0L.
+ *
+ *  This source supports the following options:
+ *  - `rowsPerSecond` (e.g. 100, default: 1): How many rows should be generated per second.
+ *  - `rampUpTime` (e.g. 5s, default: 0s): How long to ramp up before the generating speed
+ *    becomes `rowsPerSecond`. Using finer granularities than seconds will be truncated to integer
+ *    seconds.
+ *  - `numPartitions` (e.g. 10, default: Spark's default parallelism): The partition number for the
+ *    generated rows. The source will try its best to reach `rowsPerSecond`, but the query may
+ *    be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed.
+ *
+ *  Option values are validated, and a user-specified schema is rejected, whenever either the
+ *  micro-batch or the continuous reader is created.
+ */
+class RateStreamProvider extends DataSourceV2
+  with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister {
+  import RateStreamProvider._
+
+  override def createMicroBatchReader(
+      schema: Optional[StructType],
+      checkpointLocation: String,
+      options: DataSourceOptions): MicroBatchReader = {
+    validate(schema, options)
+    new RateStreamMicroBatchReader(options, checkpointLocation)
+  }
+
+  override def createContinuousReader(
+      schema: Optional[StructType],
+      checkpointLocation: String,
+      options: DataSourceOptions): ContinuousReader = {
+    // Apply the same checks as the micro-batch path so that callers going straight to the
+    // continuous reader cannot pass non-positive rates/partitions or a custom schema.
+    validate(schema, options)
+    new RateStreamContinuousReader(options)
+  }
+
+  override def shortName(): String = "rate"
+
+  /**
+   * Validates the options shared by the micro-batch and continuous readers.
+   *
+   * @throws IllegalArgumentException if `rowsPerSecond` or `numPartitions` is present and not
+   *                                  positive, or `rampUpTime` is present and negative. Values
+   *                                  that cannot be parsed fail with the parser's own exception.
+   * @throws AnalysisException if a user-specified schema is given.
+   */
+  private def validate(schema: Optional[StructType], options: DataSourceOptions): Unit = {
+    if (options.get(ROWS_PER_SECOND).isPresent) {
+      val rowsPerSecond = options.get(ROWS_PER_SECOND).get().toLong
+      if (rowsPerSecond <= 0) {
+        throw new IllegalArgumentException(
+          s"Invalid value '$rowsPerSecond'. The option 'rowsPerSecond' must be positive")
+      }
+    }
+
+    if (options.get(RAMP_UP_TIME).isPresent) {
+      val rampUpTimeSeconds =
+        JavaUtils.timeStringAsSec(options.get(RAMP_UP_TIME).get())
+      if (rampUpTimeSeconds < 0) {
+        throw new IllegalArgumentException(
+          s"Invalid value '$rampUpTimeSeconds'. The option 'rampUpTime' must not be negative")
+      }
+    }
+
+    if (options.get(NUM_PARTITIONS).isPresent) {
+      val numPartitions = options.get(NUM_PARTITIONS).get().toInt
+      if (numPartitions <= 0) {
+        throw new IllegalArgumentException(
+          s"Invalid value '$numPartitions'. The option 'numPartitions' must be positive")
+      }
+    }
+
+    if (schema.isPresent) {
+      throw new AnalysisException("The rate source does not support a user-specified schema.")
+    }
+  }
+}
+
+object RateStreamProvider {
+  /** Output schema of the rate source: a `timestamp` column and a `value` column. */
+  val SCHEMA =
+    StructType(StructField("timestamp", TimestampType) :: StructField("value", LongType) :: Nil)
+
+  val VERSION = 1
+
+  // Option keys read by the rate source.
+  val NUM_PARTITIONS = "numPartitions"
+  val ROWS_PER_SECOND = "rowsPerSecond"
+  val RAMP_UP_TIME = "rampUpTime"
+
+  /**
+   * Returns the end value the source will have emitted by second `seconds`.
+   *
+   * While ramping up, the speed grows by `rowsPerSecond / (rampUpTimeSeconds + 1)` (integer
+   * division) each second; after `rampUpTimeSeconds` it stays at `rowsPerSecond`. For example,
+   * with rampUpTimeSeconds = 4 and rowsPerSecond = 10 the per-second delta is 2:
+   *
+   *   seconds   = 0 1 2  3  4  5  6
+   *   speed     = 0 2 4  6  8 10 10
+   *   end value = 0 2 6 12 20 30 40
+   */
+  def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = {
+    val speedDeltaPerSecond = rowsPerSecond / (rampUpTimeSeconds + 1)
+
+    // Total emitted at ramp-up speeds 0, d, 2d, ..., t*d, i.e. d * t * (t + 1) / 2. For
+    // non-negative `t`, whichever of `t` and `t + 1` is even is halved first, so the division is
+    // exact and the intermediate product stays smaller (avoiding overflow).
+    def rampedTotal(t: Long): Long =
+      if (t % 2 == 1) (t + 1) / 2 * speedDeltaPerSecond * t
+      else t / 2 * speedDeltaPerSecond * (t + 1)
+
+    if (seconds > rampUpTimeSeconds) {
+      // Everything from the full ramp-up, then constant speed for the remaining seconds.
+      rampedTotal(rampUpTimeSeconds) + (seconds - rampUpTimeSeconds) * rowsPerSecond
+    } else {
+      rampedTotal(seconds)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/5b5a36ed/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala
deleted file mode 100644
index 4e2459b..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala
+++ /dev/null
@@ -1,187 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming.sources
-
-import java.util.Optional
-
-import scala.collection.JavaConverters._
-import scala.collection.mutable
-
-import org.json4s.DefaultFormats
-import org.json4s.jackson.Serialization
-
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair}
-import org.apache.spark.sql.sources.DataSourceRegister
-import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport}
-import org.apache.spark.sql.sources.v2.reader._
-import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
-import org.apache.spark.sql.types.{LongType, StructField, StructType, TimestampType}
-import org.apache.spark.util.{ManualClock, SystemClock}
-
-/**
- * This is a temporary register as we build out v2 migration. Microbatch read support should
- * be implemented in the same register as v1.
- *
- * Registered under the short name `ratev2`. The `schema` and `checkpointLocation` arguments are
- * ignored, and `options` are handed to [[RateStreamMicroBatchReader]] without any validation.
- */
-class RateSourceProviderV2 extends DataSourceV2 with MicroBatchReadSupport with DataSourceRegister {
-  override def createMicroBatchReader(
-      schema: Optional[StructType],
-      checkpointLocation: String,
-      options: DataSourceOptions): MicroBatchReader = {
-    new RateStreamMicroBatchReader(options)
-  }
-
-  override def shortName(): String = "ratev2"
-}
-
-/**
- * Micro-batch reader for the `ratev2` source. Partition i emits the values i, i + numPartitions,
- * i + 2 * numPartitions, ... (option `numPartitions`, default 5) at a combined rate of
- * `rowsPerSecond` (default 6). An offset records, per partition, the last emitted value and the
- * run time associated with it.
- */
-class RateStreamMicroBatchReader(options: DataSourceOptions)
-  extends MicroBatchReader {
-  implicit val defaultFormats: DefaultFormats = DefaultFormats
-
-  val clock = {
-    // The option to use a manual clock is provided only for unit testing purposes.
-    if (options.get("useManualClock").orElse("false").toBoolean) new ManualClock
-    else new SystemClock
-  }
-
-  private val numPartitions =
-    options.get(RateStreamSourceV2.NUM_PARTITIONS).orElse("5").toInt
-  private val rowsPerSecond =
-    options.get(RateStreamSourceV2.ROWS_PER_SECOND).orElse("6").toLong
-
-  // The interval (in milliseconds) between rows in each partition.
-  // e.g. if there are 4 global rows per second, and 2 partitions, each partition
-  // should output rows every (1000 * 2 / 4) = 500 ms.
-  // NOTE(review): this is truncating Long division. It is 0 whenever
-  // rowsPerSecond > 1000 * numPartitions, and `readInterval / msPerPartitionBetweenRows` in
-  // setOffsetRange would then throw an ArithmeticException (division by zero).
-  private val msPerPartitionBetweenRows = (1000 * numPartitions) / rowsPerSecond
-
-  override def readSchema(): StructType = {
-    StructType(
-      StructField("timestamp", TimestampType, false) ::
-      StructField("value", LongType, false) :: Nil)
-  }
-
-  val creationTimeMs = clock.getTimeMillis()
-
-  private var start: RateStreamOffset = _
-  private var end: RateStreamOffset = _
-
-  /**
-   * Sets the range to read. A missing `start` defaults to the initial offset (every partition at
-   * `creationTimeMs`). A missing `end` is inferred from the clock: each partition advances by the
-   * number of whole row intervals elapsed since its start run time.
-   */
-  override def setOffsetRange(
-      start: Optional[Offset],
-      end: Optional[Offset]): Unit = {
-    this.start = start.orElse(
-      RateStreamSourceV2.createInitialOffset(numPartitions, creationTimeMs))
-      .asInstanceOf[RateStreamOffset]
-
-    this.end = end.orElse {
-      val currentTime = clock.getTimeMillis()
-      RateStreamOffset(
-        this.start.partitionToValueAndRunTimeMs.map {
-          case startOffset @ (part, ValueRunTimeMsPair(currentVal, currentReadTime)) =>
-            // Calculate the number of rows we should advance in this partition (based on the
-            // current time), and output a corresponding offset.
-            val readInterval = currentTime - currentReadTime
-            val numNewRows = readInterval / msPerPartitionBetweenRows
-            if (numNewRows <= 0) {
-              startOffset
-            } else {
-              (part, ValueRunTimeMsPair(
-                  currentVal + (numNewRows * numPartitions),
-                  currentReadTime + (numNewRows * msPerPartitionBetweenRows)))
-            }
-        }
-      )
-    }.asInstanceOf[RateStreamOffset]
-  }
-
-  override def getStartOffset(): Offset = {
-    if (start == null) throw new IllegalStateException("start offset not set")
-    start
-  }
-  override def getEndOffset(): Offset = {
-    if (end == null) throw new IllegalStateException("end offset not set")
-    end
-  }
-
-  override def deserializeOffset(json: String): Offset = {
-    RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json))
-  }
-
-  /**
-   * Builds one factory per partition. Each factory holds every (timestampMs, value) pair in the
-   * partition's (start, end] value range, computed eagerly here.
-   */
-  override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] = {
-    val startMap = start.partitionToValueAndRunTimeMs
-    val endMap = end.partitionToValueAndRunTimeMs
-    endMap.keys.toSeq.map { part =>
-      val ValueRunTimeMsPair(endVal, _) = endMap(part)
-      val ValueRunTimeMsPair(startVal, startTimeMs) = startMap(part)
-
-      val packedRows = mutable.ListBuffer[(Long, Long)]()
-      var outVal = startVal + numPartitions
-      var outTimeMs = startTimeMs
-      while (outVal <= endVal) {
-        packedRows.append((outTimeMs, outVal))
-        outVal += numPartitions
-        outTimeMs += msPerPartitionBetweenRows
-      }
-
-      RateStreamBatchTask(packedRows).asInstanceOf[DataReaderFactory[Row]]
-    }.toList.asJava
-  }
-
-  override def commit(end: Offset): Unit = {}
-  override def stop(): Unit = {}
-}
-
-/** Factory for one partition's precomputed rows, given as (timestampMs, value) pairs. */
-case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends DataReaderFactory[Row] {
-  override def createDataReader(): DataReader[Row] = new RateStreamBatchReader(vals)
-}
-
-/**
- * Iterates over precomputed (timestampMs, value) pairs. `next()` advances to the next pair and
- * `get()` returns the current pair as a (timestamp, value) row.
- *
- * NOTE(review): `vals` is built as a ListBuffer in createDataReaderFactories, and indexed access
- * `vals(currentIndex)` on a ListBuffer is linear-time, so reading a partition is quadratic.
- */
-class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] {
-  private var currentIndex = -1
-
-  override def next(): Boolean = {
-    // Return true as long as the new index is in the seq.
-    currentIndex += 1
-    currentIndex < vals.size
-  }
-
-  override def get(): Row = {
-    Row(
-      DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromMillis(vals(currentIndex)._1)),
-      vals(currentIndex)._2)
-  }
-
-  override def close(): Unit = {}
-}
-
-object RateStreamSourceV2 {
-  // Option keys read by RateStreamMicroBatchReader.
-  val NUM_PARTITIONS = "numPartitions"
-  val ROWS_PER_SECOND = "rowsPerSecond"
-
-  /**
-   * Builds the initial offset: for each partition i in [0, numPartitions), the value
-   * `i - numPartitions` paired with run time `creationTimeMs`.
-   */
-  private[sql] def createInitialOffset(numPartitions: Int, creationTimeMs: Long) = {
-    RateStreamOffset(
-      Range(0, numPartitions).map { i =>
-        // Note that the starting offset is exclusive, so we have to decrement the starting value
-        // by the increment that will later be applied. The first row output in each
-        // partition will have a value equal to the partition index.
-        (i,
-          ValueRunTimeMsPair(
-            (i - numPartitions).toLong,
-            creationTimeMs))
-      }.toMap)
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/5b5a36ed/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala
deleted file mode 100644
index 03d0f63..0000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceSuite.scala
+++ /dev/null
@@ -1,194 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming
-
-import java.util.concurrent.TimeUnit
-
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.functions._
-import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest}
-import org.apache.spark.util.ManualClock
-
-/** End-to-end tests for the `rate` streaming source as implemented by [[RateStreamSource]]. */
-class RateSourceSuite extends StreamTest {
-
-  import testImplicits._
-
-  /**
-   * Advances the clock of the query's [[RateStreamSource]] by `seconds` and returns the source
-   * together with its latest offset. Only valid for streams created with `useManualClock`,
-   * because the clock is cast to [[ManualClock]].
-   */
-  case class AdvanceRateManualClock(seconds: Long) extends AddData {
-    override def addData(query: Option[StreamExecution]): (Source, Offset) = {
-      assert(query.nonEmpty)
-      val rateSource = query.get.logicalPlan.collect {
-        case StreamingExecutionRelation(source, _) if source.isInstanceOf[RateStreamSource] =>
-          source.asInstanceOf[RateStreamSource]
-      }.head
-      rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds))
-      (rateSource, rateSource.getOffset.get)
-    }
-  }
-
-  test("basic") {
-    val input = spark.readStream
-      .format("rate")
-      .option("rowsPerSecond", "10")
-      .option("useManualClock", "true")
-      .load()
-    testStream(input)(
-      AdvanceRateManualClock(seconds = 1),
-      CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*),
-      StopStream,
-      StartStream(),
-      // Advance 2 seconds because creating a new RateSource will also create a new ManualClock
-      AdvanceRateManualClock(seconds = 2),
-      CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*)
-    )
-  }
-
-  test("uniform distribution of event timestamps") {
-    val input = spark.readStream
-      .format("rate")
-      .option("rowsPerSecond", "1500")
-      .option("useManualClock", "true")
-      .load()
-      .as[(java.sql.Timestamp, Long)]
-      .map(v => (v._1.getTime, v._2))
-    // Row v is expected at v * (1000 / 1500) ms, rounded to the nearest millisecond.
-    val expectedAnswer = (0 until 1500).map { v =>
-      (math.round(v * (1000.0 / 1500)), v)
-    }
-    testStream(input)(
-      AdvanceRateManualClock(seconds = 1),
-      CheckLastBatch(expectedAnswer: _*)
-    )
-  }
-
-  test("valueAtSecond") {
-    import RateStreamSource._
-
-    assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0)
-    assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5)
-
-    assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 0)
-    assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 1)
-    assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 3)
-    assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 8)
-
-    assert(valueAtSecond(seconds = 0, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 0)
-    assert(valueAtSecond(seconds = 1, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 2)
-    assert(valueAtSecond(seconds = 2, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 6)
-    assert(valueAtSecond(seconds = 3, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 12)
-    assert(valueAtSecond(seconds = 4, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 20)
-    assert(valueAtSecond(seconds = 5, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 30)
-  }
-
-  test("rampUpTime") {
-    val input = spark.readStream
-      .format("rate")
-      .option("rowsPerSecond", "10")
-      .option("rampUpTime", "4s")
-      .option("useManualClock", "true")
-      .load()
-      .as[(java.sql.Timestamp, Long)]
-      .map(v => (v._1.getTime, v._2))
-    testStream(input)(
-      AdvanceRateManualClock(seconds = 1),
-      CheckLastBatch((0 until 2).map(v => v * 500 -> v): _*), // speed = 2
-      AdvanceRateManualClock(seconds = 1),
-      CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4
-      AdvanceRateManualClock(seconds = 1),
-      // At speed 6 rows are 1000 / 6 ms apart, so timestamps are rounded to whole milliseconds.
-      CheckLastBatch({
-        Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 -> 11)
-      }: _*), // speed = 6
-      AdvanceRateManualClock(seconds = 1),
-      CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8
-      AdvanceRateManualClock(seconds = 1),
-      // Now we should reach full speed
-      CheckLastBatch((20 until 30).map(v => 4000 + (v - 20) * 100 -> v): _*), // speed = 10
-      AdvanceRateManualClock(seconds = 1),
-      CheckLastBatch((30 until 40).map(v => 5000 + (v - 30) * 100 -> v): _*), // speed = 10
-      AdvanceRateManualClock(seconds = 1),
-      CheckLastBatch((40 until 50).map(v => 6000 + (v - 40) * 100 -> v): _*) // speed = 10
-    )
-  }
-
-  test("numPartitions") {
-    val input = spark.readStream
-      .format("rate")
-      .option("rowsPerSecond", "10")
-      .option("numPartitions", "6")
-      .option("useManualClock", "true")
-      .load()
-      .select(spark_partition_id())
-      .distinct()
-    testStream(input)(
-      AdvanceRateManualClock(1),
-      CheckLastBatch((0 until 6): _*)
-    )
-  }
-
-  testQuietly("overflow") {
-    // With rowsPerSecond = Long.MaxValue the query is expected to fail with an
-    // ArithmeticException whose message mentions both "overflow" and "rowsPerSecond".
-    val input = spark.readStream
-      .format("rate")
-      .option("rowsPerSecond", Long.MaxValue.toString)
-      .option("useManualClock", "true")
-      .load()
-      .select(spark_partition_id())
-      .distinct()
-    testStream(input)(
-      AdvanceRateManualClock(2),
-      ExpectFailure[ArithmeticException](t => {
-        Seq("overflow", "rowsPerSecond").foreach { msg =>
-          assert(t.getMessage.contains(msg))
-        }
-      })
-    )
-  }
-
-  testQuietly("illegal option values") {
-    // Starts a query with a single bad option and checks that it fails with an
-    // IllegalArgumentException carrying every expected message fragment.
-    def testIllegalOptionValue(
-        option: String,
-        value: String,
-        expectedMessages: Seq[String]): Unit = {
-      val e = intercept[StreamingQueryException] {
-        spark.readStream
-          .format("rate")
-          .option(option, value)
-          .load()
-          .writeStream
-          .format("console")
-          .start()
-          .awaitTermination()
-      }
-      assert(e.getCause.isInstanceOf[IllegalArgumentException])
-      for (msg <- expectedMessages) {
-        assert(e.getCause.getMessage.contains(msg))
-      }
-    }
-
-    testIllegalOptionValue("rowsPerSecond", "-1", Seq("-1", "rowsPerSecond", "positive"))
-    testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive"))
-  }
-
-  test("user-specified schema given") {
-    val exception = intercept[AnalysisException] {
-      spark.readStream
-        .format("rate")
-        .schema(spark.range(1).schema)
-        .load()
-    }
-    assert(exception.getMessage.contains(
-      "rate source does not support a user-specified schema"))
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/5b5a36ed/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala
deleted file mode 100644
index 983ba16..0000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala
+++ /dev/null
@@ -1,191 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming
-
-import java.util.Optional
-import java.util.concurrent.TimeUnit
-
-import scala.collection.JavaConverters._
-
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.execution.datasources.DataSource
-import org.apache.spark.sql.execution.streaming.continuous._
-import org.apache.spark.sql.execution.streaming.sources.{RateStreamBatchTask, RateStreamMicroBatchReader, RateStreamSourceV2}
-import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, MicroBatchReadSupport}
-import org.apache.spark.sql.sources.v2.DataSourceOptions
-import org.apache.spark.sql.streaming.StreamTest
-import org.apache.spark.util.ManualClock
-
-/**
- * Tests for the V2 rate readers: [[RateStreamMicroBatchReader]] (registered as `ratev2`) and
- * [[RateStreamContinuousReader]] (looked up through `rate`).
- */
-class RateSourceV2Suite extends StreamTest {
-  import testImplicits._
-
-  /**
-   * Advances the manual clock of the query's [[RateStreamMicroBatchReader]] by `seconds`,
-   * re-infers its offset range from the clock, and returns the reader with its new end offset.
-   */
-  case class AdvanceRateManualClock(seconds: Long) extends AddData {
-    override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {
-      assert(query.nonEmpty)
-      val rateSource = query.get.logicalPlan.collect {
-        case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source
-      }.head
-      rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds))
-      rateSource.setOffsetRange(Optional.empty(), Optional.empty())
-      (rateSource, rateSource.getEndOffset())
-    }
-  }
-
-  test("microbatch in registry") {
-    DataSource.lookupDataSource("ratev2", spark.sqlContext.conf).newInstance() match {
-      case ds: MicroBatchReadSupport =>
-        val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceOptions.empty())
-        assert(reader.isInstanceOf[RateStreamMicroBatchReader])
-      case _ =>
-        throw new IllegalStateException("Could not find v2 read support for rate")
-    }
-  }
-
-  test("basic microbatch execution") {
-    val input = spark.readStream
-      .format("rateV2")
-      .option("numPartitions", "1")
-      .option("rowsPerSecond", "10")
-      .option("useManualClock", "true")
-      .load()
-    testStream(input, useV2Sink = true)(
-      AdvanceRateManualClock(seconds = 1),
-      CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*),
-      StopStream,
-      StartStream(),
-      // Advance 2 seconds because creating a new RateSource will also create a new ManualClock
-      AdvanceRateManualClock(seconds = 2),
-      CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*)
-    )
-  }
-
-  test("microbatch - numPartitions propagated") {
-    val reader = new RateStreamMicroBatchReader(
-      new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava))
-    reader.setOffsetRange(Optional.empty(), Optional.empty())
-    val tasks = reader.createDataReaderFactories()
-    assert(tasks.size == 11)
-  }
-
-  test("microbatch - set offset") {
-    val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty())
-    val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000))))
-    val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 2000))))
-    reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
-    assert(reader.getStartOffset() == startOffset)
-    assert(reader.getEndOffset() == endOffset)
-  }
-
-  test("microbatch - infer offsets") {
-    val reader = new RateStreamMicroBatchReader(
-      new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "100").asJava))
-    reader.clock.waitTillTime(reader.clock.getTimeMillis() + 100)
-    reader.setOffsetRange(Optional.empty(), Optional.empty())
-    reader.getStartOffset() match {
-      case r: RateStreamOffset =>
-        assert(r.partitionToValueAndRunTimeMs(0).runTimeMs == reader.creationTimeMs)
-      case _ => throw new IllegalStateException("unexpected offset type")
-    }
-    reader.getEndOffset() match {
-      case r: RateStreamOffset =>
-        // End offset may be a bit beyond 100 ms/9 rows after creation if the wait lasted
-        // longer than 100ms. It should never be early.
-        assert(r.partitionToValueAndRunTimeMs(0).value >= 9)
-        assert(r.partitionToValueAndRunTimeMs(0).runTimeMs >= reader.creationTimeMs + 100)
-
-      case _ => throw new IllegalStateException("unexpected offset type")
-    }
-  }
-
-  test("microbatch - predetermined batch size") {
-    val reader = new RateStreamMicroBatchReader(
-      new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava))
-    val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000))))
-    val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000))))
-    reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
-    val tasks = reader.createDataReaderFactories()
-    assert(tasks.size == 1)
-    assert(tasks.get(0).asInstanceOf[RateStreamBatchTask].vals.size == 20)
-  }
-
-  test("microbatch - data read") {
-    val reader = new RateStreamMicroBatchReader(
-      new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava))
-    val startOffset = RateStreamSourceV2.createInitialOffset(11, reader.creationTimeMs)
-    val endOffset = RateStreamOffset(startOffset.partitionToValueAndRunTimeMs.toSeq.map {
-      case (part, ValueRunTimeMsPair(currentVal, currentReadTime)) =>
-        (part, ValueRunTimeMsPair(currentVal + 33, currentReadTime + 1000))
-    }.toMap)
-
-    reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
-    val tasks = reader.createDataReaderFactories()
-    assert(tasks.size == 11)
-
-    val readData = tasks.asScala
-      .map(_.createDataReader())
-      .flatMap { reader =>
-        val buf = scala.collection.mutable.ListBuffer[Row]()
-        while (reader.next()) buf.append(reader.get())
-        buf
-      }
-
-    assert(readData.map(_.getLong(1)).sorted == Range(0, 33))
-  }
-
-  test("continuous in registry") {
-    DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match {
-      case ds: ContinuousReadSupport =>
-        val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty())
-        assert(reader.isInstanceOf[RateStreamContinuousReader])
-      case _ =>
-        throw new IllegalStateException("Could not find v2 read support for rate")
-    }
-  }
-
-  test("continuous data") {
-    val reader = new RateStreamContinuousReader(
-      new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava))
-    reader.setStartOffset(Optional.empty())
-    val tasks = reader.createDataReaderFactories()
-    assert(tasks.size == 2)
-
-    val data = scala.collection.mutable.ListBuffer[Row]()
-    tasks.asScala.foreach {
-      case t: RateStreamContinuousDataReaderFactory =>
-        val startTimeMs = reader.getStartOffset()
-          .asInstanceOf[RateStreamOffset]
-          .partitionToValueAndRunTimeMs(t.partitionIndex)
-          .runTimeMs
-        val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader]
-        // Within a partition, values step by 2 (the partition count) and run times by 100 ms.
-        for (rowIndex <- 0 to 9) {
-          r.next()
-          data.append(r.get())
-          assert(r.getOffset() ==
-            RateStreamPartitionOffset(
-              t.partitionIndex,
-              t.partitionIndex + rowIndex * 2,
-              startTimeMs + (rowIndex + 1) * 100))
-        }
-        // The 10 rows should not have been emitted ahead of their scheduled wall-clock time.
-        assert(System.currentTimeMillis() >= startTimeMs + 1000)
-
-      case _ => throw new IllegalStateException("Unexpected task type")
-    }
-
-    assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20))
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/5b5a36ed/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala
new file mode 100644
index 0000000..ff14ec3
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala
@@ -0,0 +1,366 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import java.nio.file.Files
+import java.util.Optional
+import java.util.concurrent.TimeUnit
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.execution.datasources.DataSource
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.continuous._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport}
+import org.apache.spark.sql.sources.v2.reader.streaming.Offset
+import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.util.ManualClock
+
+/**
+ * Tests for the "rate" data source: registry lookup, the micro-batch reader
+ * [[RateStreamMicroBatchReader]], `RateStreamProvider.valueAtSecond`, option validation,
+ * and the continuous reader [[RateStreamContinuousReader]].
+ *
+ * Every directly constructed micro-batch reader gets its checkpoint location from `withTempDir`,
+ * so whatever the reader writes there is deleted when the test finishes instead of leaking
+ * into the working directory or the system temp directory.
+ */
+class RateSourceSuite extends StreamTest {
+
+  import testImplicits._
+
+  /**
+   * Stream action that advances the [[ManualClock]] of the running query's rate source by
+   * `seconds`, then returns that source together with a [[LongOffset]] holding the whole
+   * seconds elapsed since the source's `creationTimeMs`. The source must have been created
+   * with the "useManualClock" option, otherwise the cast to [[ManualClock]] fails.
+   */
+  case class AdvanceRateManualClock(seconds: Long) extends AddData {
+    override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {
+      assert(query.nonEmpty)
+      val rateSource = query.get.logicalPlan.collect {
+        case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source
+      }.headOption.getOrElse {
+        fail(s"No rate source found in the logical plan:\n${query.get.logicalPlan}")
+      }
+
+      rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds))
+      val offset = LongOffset(TimeUnit.MILLISECONDS.toSeconds(
+        rateSource.clock.getTimeMillis() - rateSource.creationTimeMs))
+      (rateSource, offset)
+    }
+  }
+
+  test("microbatch in registry") {
+    withTempDir { temp =>
+      DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match {
+        case ds: MicroBatchReadSupport =>
+          val reader = ds.createMicroBatchReader(
+            Optional.empty(), temp.getCanonicalPath, DataSourceOptions.empty())
+          assert(reader.isInstanceOf[RateStreamMicroBatchReader])
+        case _ =>
+          throw new IllegalStateException("Could not find read support for rate")
+      }
+    }
+  }
+
+  test("compatible with old path in registry") {
+    DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider",
+      spark.sqlContext.conf).newInstance() match {
+      case ds: MicroBatchReadSupport =>
+        assert(ds.isInstanceOf[RateStreamProvider])
+      case _ =>
+        throw new IllegalStateException("Could not find read support for rate")
+    }
+  }
+
+  test("microbatch - basic") {
+    val input = spark.readStream
+      .format("rate")
+      .option("rowsPerSecond", "10")
+      .option("useManualClock", "true")
+      .load()
+    testStream(input)(
+      AdvanceRateManualClock(seconds = 1),
+      CheckLastBatch((0 until 10).map(v => new java.sql.Timestamp(v * 100L) -> v): _*),
+      StopStream,
+      StartStream(),
+      // Advance 2 seconds because creating a new RateSource will also create a new ManualClock
+      AdvanceRateManualClock(seconds = 2),
+      CheckLastBatch((10 until 20).map(v => new java.sql.Timestamp(v * 100L) -> v): _*)
+    )
+  }
+
+  test("microbatch - uniform distribution of event timestamps") {
+    val input = spark.readStream
+      .format("rate")
+      .option("rowsPerSecond", "1500")
+      .option("useManualClock", "true")
+      .load()
+      .as[(java.sql.Timestamp, Long)]
+      .map(v => (v._1.getTime, v._2))
+    val expectedAnswer = (0 until 1500).map { v =>
+      (math.round(v * (1000.0 / 1500)), v)
+    }
+    testStream(input)(
+      AdvanceRateManualClock(seconds = 1),
+      CheckLastBatch(expectedAnswer: _*)
+    )
+  }
+
+  test("microbatch - set offset") {
+    withTempDir { temp =>
+      val reader =
+        new RateStreamMicroBatchReader(DataSourceOptions.empty(), temp.getCanonicalPath)
+      val startOffset = LongOffset(0L)
+      val endOffset = LongOffset(1L)
+      reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
+      assert(reader.getStartOffset() === startOffset)
+      assert(reader.getEndOffset() === endOffset)
+    }
+  }
+
+  test("microbatch - infer offsets") {
+    withTempDir { temp =>
+      val options = new DataSourceOptions(
+        Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava)
+      val reader = new RateStreamMicroBatchReader(options, temp.getCanonicalPath)
+      reader.clock.asInstanceOf[ManualClock].advance(100000)
+      reader.setOffsetRange(Optional.empty(), Optional.empty())
+      reader.getStartOffset() match {
+        case r: LongOffset => assert(r.offset === 0L)
+        case _ => throw new IllegalStateException("unexpected offset type")
+      }
+      reader.getEndOffset() match {
+        case r: LongOffset => assert(r.offset >= 100)
+        case _ => throw new IllegalStateException("unexpected offset type")
+      }
+    }
+  }
+
+  test("microbatch - predetermined batch size") {
+    withTempDir { temp =>
+      val options =
+        new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava)
+      val reader = new RateStreamMicroBatchReader(options, temp.getCanonicalPath)
+      val startOffset = LongOffset(0L)
+      val endOffset = LongOffset(1L)
+      reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
+      val tasks = reader.createDataReaderFactories()
+      assert(tasks.size === 1)
+      val dataReader = tasks.get(0).createDataReader()
+      val data = ArrayBuffer[Row]()
+      while (dataReader.next()) {
+        data.append(dataReader.get())
+      }
+      dataReader.close()
+      assert(data.size === 20)
+    }
+  }
+
+  test("microbatch - data read") {
+    withTempDir { temp =>
+      val options =
+        new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava)
+      val reader = new RateStreamMicroBatchReader(options, temp.getCanonicalPath)
+      val startOffset = LongOffset(0L)
+      val endOffset = LongOffset(1L)
+      reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
+      val tasks = reader.createDataReaderFactories()
+      assert(tasks.size === 11)
+
+      // Drain every partition; together the 11 partitions must cover values 0 until 33.
+      val readData = tasks.asScala
+        .map(_.createDataReader())
+        .flatMap { dataReader =>
+          val buf = scala.collection.mutable.ListBuffer[Row]()
+          while (dataReader.next()) buf.append(dataReader.get())
+          dataReader.close()
+          buf
+        }
+
+      assert(readData.map(_.getLong(1)).sorted == Range(0, 33))
+    }
+  }
+
+  test("valueAtSecond") {
+    import RateStreamProvider._
+
+    assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 0)
+    assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 0) === 5)
+
+    assert(valueAtSecond(seconds = 0, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 0)
+    assert(valueAtSecond(seconds = 1, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 1)
+    assert(valueAtSecond(seconds = 2, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 3)
+    assert(valueAtSecond(seconds = 3, rowsPerSecond = 5, rampUpTimeSeconds = 2) === 8)
+
+    assert(valueAtSecond(seconds = 0, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 0)
+    assert(valueAtSecond(seconds = 1, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 2)
+    assert(valueAtSecond(seconds = 2, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 6)
+    assert(valueAtSecond(seconds = 3, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 12)
+    assert(valueAtSecond(seconds = 4, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 20)
+    assert(valueAtSecond(seconds = 5, rowsPerSecond = 10, rampUpTimeSeconds = 4) === 30)
+  }
+
+  test("rampUpTime") {
+    val input = spark.readStream
+      .format("rate")
+      .option("rowsPerSecond", "10")
+      .option("rampUpTime", "4s")
+      .option("useManualClock", "true")
+      .load()
+      .as[(java.sql.Timestamp, Long)]
+      .map(v => (v._1.getTime, v._2))
+    testStream(input)(
+      AdvanceRateManualClock(seconds = 1),
+      CheckLastBatch((0 until 2).map(v => v * 500 -> v): _*), // speed = 2
+      AdvanceRateManualClock(seconds = 1),
+      CheckLastBatch((2 until 6).map(v => 1000 + (v - 2) * 250 -> v): _*), // speed = 4
+      AdvanceRateManualClock(seconds = 1),
+      CheckLastBatch({
+        Seq(2000 -> 6, 2167 -> 7, 2333 -> 8, 2500 -> 9, 2667 -> 10, 2833 -> 11)
+      }: _*), // speed = 6
+      AdvanceRateManualClock(seconds = 1),
+      CheckLastBatch((12 until 20).map(v => 3000 + (v - 12) * 125 -> v): _*), // speed = 8
+      AdvanceRateManualClock(seconds = 1),
+      // Now we should reach full speed
+      CheckLastBatch((20 until 30).map(v => 4000 + (v - 20) * 100 -> v): _*), // speed = 10
+      AdvanceRateManualClock(seconds = 1),
+      CheckLastBatch((30 until 40).map(v => 5000 + (v - 30) * 100 -> v): _*), // speed = 10
+      AdvanceRateManualClock(seconds = 1),
+      CheckLastBatch((40 until 50).map(v => 6000 + (v - 40) * 100 -> v): _*) // speed = 10
+    )
+  }
+
+  test("numPartitions") {
+    val input = spark.readStream
+      .format("rate")
+      .option("rowsPerSecond", "10")
+      .option("numPartitions", "6")
+      .option("useManualClock", "true")
+      .load()
+      .select(spark_partition_id())
+      .distinct()
+    testStream(input)(
+      AdvanceRateManualClock(1),
+      CheckLastBatch((0 until 6): _*)
+    )
+  }
+
+  testQuietly("overflow") {
+    val input = spark.readStream
+      .format("rate")
+      .option("rowsPerSecond", Long.MaxValue.toString)
+      .option("useManualClock", "true")
+      .load()
+      .select(spark_partition_id())
+      .distinct()
+    testStream(input)(
+      AdvanceRateManualClock(2),
+      ExpectFailure[ArithmeticException](t => {
+        Seq("overflow", "rowsPerSecond").foreach { msg =>
+          assert(t.getMessage.contains(msg))
+        }
+      })
+    )
+  }
+
+  testQuietly("illegal option values") {
+    def testIllegalOptionValue(
+        option: String,
+        value: String,
+        expectedMessages: Seq[String]): Unit = {
+      val e = intercept[IllegalArgumentException] {
+        spark.readStream
+          .format("rate")
+          .option(option, value)
+          .load()
+          .writeStream
+          .format("console")
+          .start()
+          .awaitTermination()
+      }
+      for (msg <- expectedMessages) {
+        assert(e.getMessage.contains(msg))
+      }
+    }
+
+    testIllegalOptionValue("rowsPerSecond", "-1", Seq("-1", "rowsPerSecond", "positive"))
+    testIllegalOptionValue("numPartitions", "-1", Seq("-1", "numPartitions", "positive"))
+  }
+
+  test("user-specified schema given") {
+    val exception = intercept[AnalysisException] {
+      spark.readStream
+        .format("rate")
+        .schema(spark.range(1).schema)
+        .load()
+    }
+    assert(exception.getMessage.contains(
+      "rate source does not support a user-specified schema"))
+  }
+
+  test("continuous in registry") {
+    DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match {
+      case ds: ContinuousReadSupport =>
+        val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty())
+        assert(reader.isInstanceOf[RateStreamContinuousReader])
+      case _ =>
+        throw new IllegalStateException("Could not find read support for continuous rate")
+    }
+  }
+
+  test("continuous data") {
+    val reader = new RateStreamContinuousReader(
+      new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava))
+    reader.setStartOffset(Optional.empty())
+    val tasks = reader.createDataReaderFactories()
+    assert(tasks.size === 2)
+
+    val data = scala.collection.mutable.ListBuffer[Row]()
+    tasks.asScala.foreach {
+      case t: RateStreamContinuousDataReaderFactory =>
+        val startTimeMs = reader.getStartOffset()
+          .asInstanceOf[RateStreamOffset]
+          .partitionToValueAndRunTimeMs(t.partitionIndex)
+          .runTimeMs
+        val r = t.createDataReader().asInstanceOf[RateStreamContinuousDataReader]
+        // With 20 rows/s over 2 partitions, partition p is expected to emit p, p + 2, p + 4, ...
+        // one row every 100 ms after its start time.
+        for (rowIndex <- 0 to 9) {
+          assert(r.next())
+          data.append(r.get())
+          assert(r.getOffset() ==
+            RateStreamPartitionOffset(
+              t.partitionIndex,
+              t.partitionIndex + rowIndex * 2,
+              startTimeMs + (rowIndex + 1) * 100))
+        }
+        r.close()
+        assert(System.currentTimeMillis() >= startTimeMs + 1000)
+
+      case _ => throw new IllegalStateException("Unexpected task type")
+    }
+
+    assert(data.map(_.getLong(1)).toSeq.sorted == Range(0, 20))
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org