You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by td...@apache.org on 2016/11/29 07:07:20 UTC

spark git commit: [SPARK-18339][SPARK-18513][SQL] Don't push down current_timestamp for filters in StructuredStreaming and persist batch and watermark timestamps to offset log.

Repository: spark
Updated Branches:
  refs/heads/master e2318ede0 -> 3c0beea47


[SPARK-18339][SPARK-18513][SQL] Don't push down current_timestamp for filters in StructuredStreaming and persist batch and watermark timestamps to offset log.

## What changes were proposed in this pull request?

For the following workflow:
1. I have a column called time which is at minute level precision in a Streaming DataFrame
2. I want to perform groupBy time, count
3. Then I want my MemorySink to only have the last 30 minutes of counts and I perform this by
.where('time >= current_timestamp().cast("long") - 30 * 60)
What happens is that the `filter` gets pushed down below the aggregation, so it is applied to the aggregation's source data instead of to the aggregation's result (which is where I actually want to filter).
I guess the main issue here is that `current_timestamp` is non-deterministic in the streaming context, so a filter that uses it shouldn't be pushed down.
Whether this requires us to store the `current_timestamp` for each trigger of the streaming job is something to discuss.

Furthermore, we want to persist current batch timestamp and watermark timestamp to the offset log so that these values are consistent across multiple executions of the same batch.

brkyvz zsxwing tdas

## How was this patch tested?

A test was added to StreamingAggregationSuite ensuring the above use case is handled. The test injects a stream of time values (in seconds) into a query that runs in complete mode and only outputs the (count) aggregation results for the past 10 seconds.

Author: Tyson Condie <tc...@gmail.com>

Closes #15949 from tcondie/SPARK-18339.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3c0beea4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3c0beea4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3c0beea4

Branch: refs/heads/master
Commit: 3c0beea4752d39ee630a107316f40aff4a1b4ae7
Parents: e2318ed
Author: Tyson Condie <tc...@gmail.com>
Authored: Mon Nov 28 23:07:17 2016 -0800
Committer: Tathagata Das <ta...@gmail.com>
Committed: Mon Nov 28 23:07:17 2016 -0800

----------------------------------------------------------------------
 .../expressions/datetimeExpressions.scala       |  33 +++++-
 .../streaming/IncrementalExecution.scala        |  19 +++-
 .../execution/streaming/StreamExecution.scala   |  67 ++++++++++---
 .../execution/streaming/StreamProgress.scala    |   4 +-
 .../spark/sql/execution/streaming/memory.scala  |   4 +
 .../StreamExecutionMetadataSuite.scala          |  35 +++++++
 .../streaming/StreamingAggregationSuite.scala   | 100 +++++++++++++++++++
 .../sql/streaming/StreamingQuerySuite.scala     |   4 +-
 .../spark/sql/streaming/WatermarkSuite.scala    |  40 +++++---
 9 files changed, 273 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3c0beea4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index 1db1d19..ef1ac36 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -17,14 +17,14 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.sql.Timestamp
 import java.text.SimpleDateFormat
 import java.util.{Calendar, Locale, TimeZone}
 
 import scala.util.Try
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback,
-  ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -72,6 +72,35 @@ case class CurrentTimestamp() extends LeafExpression with CodegenFallback {
 }
 
/**
+ * Expression representing the current batch time, which is used by StreamExecution to
+ * 1. prevent the optimizer from pushing this expression below a stateful operator
+ * 2. allow IncrementalExecution to substitute this expression with a Literal(timestamp)
+ * `timestampMs` is in milliseconds; only TimestampType and DateType are supported.
+ * There is no code generation since this expression should be replaced with a literal.
+ */
+case class CurrentBatchTimestamp(timestampMs: Long, dataType: DataType)
+  extends LeafExpression with Nondeterministic with CodegenFallback {
+
+  override def nullable: Boolean = false
+
+  override def prettyName: String = "current_batch_timestamp"
+
+  override protected def initializeInternal(partitionIndex: Int): Unit = {}
+
+  /**
+   * Returns the value of [[toLiteral]], so evaluating this expression directly
+   * (e.g., select(current_date())) agrees with the literal substituted for it.
+   */
+  override protected def evalInternal(input: InternalRow): Any = toLiteral.value
+
+  def toLiteral: Literal = dataType match { // other data types throw scala.MatchError
+    case _: TimestampType =>
+      Literal(DateTimeUtils.fromJavaTimestamp(new Timestamp(timestampMs)), TimestampType)
+    case _: DateType => Literal(DateTimeUtils.millisToDays(timestampMs), DateType)
+  }
+}
+
+/**
  * Adds a number of days to startdate.
  */
 @ExpressionDescription(

http://git-wip-us.apache.org/repos/asf/spark/blob/3c0beea4/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index e9d072f..6ab6fa6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.sql.execution.streaming
 
-import org.apache.spark.sql.{InternalOutputModes, SparkSession}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.{CurrentBatchTimestamp, Literal}
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode}
@@ -34,7 +36,7 @@ class IncrementalExecution(
     val checkpointLocation: String,
     val currentBatchId: Long,
     val currentEventTimeWatermark: Long)
-  extends QueryExecution(sparkSession, logicalPlan) {
+  extends QueryExecution(sparkSession, logicalPlan) with Logging {
 
   // TODO: make this always part of planning.
   val stateStrategy =
@@ -50,6 +52,19 @@ class IncrementalExecution(
       stateStrategy)
 
   /**
+   * See [SPARK-18339]. Runs the optimizer on `withCachedData` first and only then
+   * replaces each CurrentBatchTimestamp with its literal, so during optimization the
+   * timestamp is still a nondeterministic expression rather than a pushable constant.
+   */
+  override lazy val optimizedPlan: LogicalPlan = {
+    sparkSession.sessionState.optimizer.execute(withCachedData) transformAllExpressions {
+      case ts @ CurrentBatchTimestamp(timestamp, _) =>
+        logInfo(s"Current batch timestamp = $timestamp")
+        ts.toLiteral
+    }
+  }
+
+  /**
    * Records the current id for a given stateful operator in the query plan as the `state`
    * preparation walks the query plan.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/3c0beea4/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 3ca6fea..21664d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -25,11 +25,13 @@ import scala.collection.mutable.ArrayBuffer
 import scala.util.control.NonFatal
 
 import org.apache.hadoop.fs.Path
+import org.json4s.NoTypeHints
+import org.json4s.jackson.Serialization
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp}
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.{QueryExecution, SparkPlan}
@@ -92,8 +94,8 @@ class StreamExecution(
   /** The current batchId or -1 if execution has not yet been initialized. */
   private var currentBatchId: Long = -1
 
-  /** The current eventTime watermark, used to bound the lateness of data that will processed. */
-  private var currentEventTimeWatermark: Long = 0
+  /** Stream execution metadata */
+  private var streamExecutionMetadata = StreamExecutionMetadata()
 
   /** All stream sources present in the query plan. */
   private val sources =
@@ -251,7 +253,7 @@ class StreamExecution(
           this,
           s"Query $name terminated with exception: ${e.getMessage}",
           e,
-          Some(committedOffsets.toOffsetSeq(sources)))
+          Some(committedOffsets.toOffsetSeq(sources, streamExecutionMetadata.json)))
         logError(s"Query $name terminated with error", e)
         // Rethrow the fatal errors to allow the user using `Thread.UncaughtExceptionHandler` to
         // handle them
@@ -288,7 +290,9 @@ class StreamExecution(
         logInfo(s"Resuming streaming query, starting with batch $batchId")
         currentBatchId = batchId
         availableOffsets = nextOffsets.toStreamProgress(sources)
-        logDebug(s"Found possibly uncommitted offsets $availableOffsets")
+        streamExecutionMetadata = StreamExecutionMetadata(nextOffsets.metadata.getOrElse("{}"))
+        logDebug(s"Found possibly unprocessed offsets $availableOffsets " +
+          s"at batch timestamp ${streamExecutionMetadata.batchTimestampMs}")
 
         offsetLog.get(batchId - 1).foreach {
           case lastOffsets =>
@@ -344,10 +348,14 @@ class StreamExecution(
       }
     }
     if (hasNewData) {
+      // Current batch timestamp in milliseconds
+      streamExecutionMetadata.batchTimestampMs = triggerClock.getTimeMillis()
       reportTimeTaken(OFFSET_WAL_WRITE_LATENCY) {
-        assert(offsetLog.add(currentBatchId, availableOffsets.toOffsetSeq(sources)),
+        assert(offsetLog.add(currentBatchId,
+          availableOffsets.toOffsetSeq(sources, streamExecutionMetadata.json)),
           s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId")
-        logInfo(s"Committed offsets for batch $currentBatchId.")
+        logInfo(s"Committed offsets for batch $currentBatchId. " +
+          s"Metadata ${streamExecutionMetadata.toString}")
 
         // NOTE: The following code is correct because runBatches() processes exactly one
         // batch at a time. If we add pipeline parallelism (multiple batches in flight at
@@ -422,6 +430,12 @@ class StreamExecution(
     val replacementMap = AttributeMap(replacements)
     val triggerLogicalPlan = withNewSources transformAllExpressions {
       case a: Attribute if replacementMap.contains(a) => replacementMap(a)
+      case ct: CurrentTimestamp =>
+        CurrentBatchTimestamp(streamExecutionMetadata.batchTimestampMs,
+          ct.dataType)
+      case cd: CurrentDate =>
+        CurrentBatchTimestamp(streamExecutionMetadata.batchTimestampMs,
+          cd.dataType)
     }
 
     val executedPlan = reportTimeTaken(OPTIMIZER_LATENCY) {
@@ -431,7 +445,7 @@ class StreamExecution(
         outputMode,
         checkpointFile("state"),
         currentBatchId,
-        currentEventTimeWatermark)
+        streamExecutionMetadata.batchWatermarkMs)
       lastExecution.executedPlan // Force the lazy generation of execution plan
     }
 
@@ -447,11 +461,12 @@ class StreamExecution(
         logTrace(s"Maximum observed eventTime: ${e.maxEventTime.value}")
         (e.maxEventTime.value / 1000) - e.delay.milliseconds()
     }.headOption.foreach { newWatermark =>
-      if (newWatermark > currentEventTimeWatermark) {
+      if (newWatermark > streamExecutionMetadata.batchWatermarkMs) {
         logInfo(s"Updating eventTime watermark to: $newWatermark ms")
-        currentEventTimeWatermark = newWatermark
+        streamExecutionMetadata.batchWatermarkMs = newWatermark
       } else {
-        logTrace(s"Event time didn't move: $newWatermark < $currentEventTimeWatermark")
+        logTrace(s"Event time didn't move: $newWatermark < " +
+          s"$streamExecutionMetadata.currentEventTimeWatermark")
       }
 
       if (newWatermark != 0) {
@@ -713,7 +728,7 @@ class StreamExecution(
     }.toArray
     val sinkStatus = SinkStatus(
       sink.toString,
-      committedOffsets.toOffsetSeq(sources).toString)
+      committedOffsets.toOffsetSeq(sources, streamExecutionMetadata.json).toString)
 
     currentStatus =
       StreamingQueryStatus(
@@ -741,6 +756,34 @@ object StreamExecution {
 }
 
/**
+ * Contains metadata associated with a stream execution. This information is
+ * persisted as JSON to the offset log via the OffsetSeq metadata field. Current
+ * information contained in this object includes:
+ *
+ * @param batchWatermarkMs The current eventTime watermark, used to bound the
+ * lateness of data that will be processed. Time unit: milliseconds
+ * @param batchTimestampMs The current batch processing timestamp.
+ * Time unit: milliseconds. Both fields are mutable and default to 0.
+ */
+case class StreamExecutionMetadata(
+    var batchWatermarkMs: Long = 0,
+    var batchTimestampMs: Long = 0) {
+  private implicit val formats = StreamExecutionMetadata.formats
+
+  /**
+   * JSON string representation of this object, readable by `StreamExecutionMetadata(json)`.
+   */
+  def json: String = Serialization.write(this)
+}
+
+object StreamExecutionMetadata {
+  private implicit val formats = Serialization.formats(NoTypeHints)
+  /** Parses JSON produced by `json`; fields missing from the input keep their defaults. */
+  def apply(json: String): StreamExecutionMetadata =
+    Serialization.read[StreamExecutionMetadata](json)
+}
+
+/**
  * A special thread to run the stream query. Some codes require to run in the StreamExecutionThread
  * and will use `classOf[StreamExecutionThread]` to check.
  */

http://git-wip-us.apache.org/repos/asf/spark/blob/3c0beea4/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala
index 05a6547..21b8750 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala
@@ -26,8 +26,8 @@ class StreamProgress(
     val baseMap: immutable.Map[Source, Offset] = new immutable.HashMap[Source, Offset])
   extends scala.collection.immutable.Map[Source, Offset] {
 
-  def toOffsetSeq(source: Seq[Source]): OffsetSeq = {
-    OffsetSeq(source.map(get))
+  def toOffsetSeq(source: Seq[Source], metadata: String): OffsetSeq = {
+    OffsetSeq(source.map(get), Some(metadata)) // one Option[Offset] per source, in order
   }
 
   override def toString: String =

http://git-wip-us.apache.org/repos/asf/spark/blob/3c0beea4/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 582b548..adf6963 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -206,6 +206,10 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi
     }
   }
 
+  def clear(): Unit = { // drops every batch recorded by this sink
+    batches.clear()
+  }
+
   override def toString(): String = "MemorySink"
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/3c0beea4/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamExecutionMetadataSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamExecutionMetadataSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamExecutionMetadataSuite.scala
new file mode 100644
index 0000000..c7139c5
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamExecutionMetadataSuite.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.apache.spark.sql.execution.streaming.StreamExecutionMetadata
+
+class StreamExecutionMetadataSuite extends StreamTest {
+  // Fields missing from the JSON input are expected to fall back to the defaults (0).
+  test("stream execution metadata") {
+    assert(StreamExecutionMetadata(0, 0) ===
+      StreamExecutionMetadata("""{}"""))
+    assert(StreamExecutionMetadata(1, 0) ===
+      StreamExecutionMetadata("""{"batchWatermarkMs":1}"""))
+    assert(StreamExecutionMetadata(0, 2) ===
+      StreamExecutionMetadata("""{"batchTimestampMs":2}"""))
+    assert(StreamExecutionMetadata(1, 2) ===
+      StreamExecutionMetadata(
+        """{"batchWatermarkMs":1,"batchTimestampMs":2}"""))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/3c0beea4/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index e59b549..fbe560e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -17,11 +17,14 @@
 
 package org.apache.spark.sql.streaming
 
+import java.util.TimeZone
+
 import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.InternalOutputModes._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.state.StateStore
@@ -235,4 +238,101 @@ class StreamingAggregationSuite extends StreamTest with BeforeAndAfterAll {
       CheckLastBatch(("a", 30), ("b", 3), ("c", 1))
     )
   }
+
+  test("prune results by current_time, complete mode") {
+    import testImplicits._
+    val clock = new StreamManualClock
+    val inputData = MemoryStream[Long]
+    val aggregated =
+      inputData.toDF()
+        .groupBy($"value")
+        .agg(count("*"))
+        .where('value >= current_timestamp().cast("long") - 10L)
+    // current_timestamp() is replaced by the batch timestamp taken from the manual clock
+    testStream(aggregated, Complete)(
+      StartStream(ProcessingTime("10 seconds"), triggerClock = clock),
+
+      // advance clock to 10 seconds, all keys retained
+      AddData(inputData, 0L, 5L, 5L, 10L),
+      AdvanceManualClock(10 * 1000),
+      CheckLastBatch((0L, 1), (5L, 2), (10L, 1)),
+
+      // advance clock to 20 seconds, should retain keys >= 10
+      AddData(inputData, 15L, 15L, 20L),
+      AdvanceManualClock(10 * 1000),
+      CheckLastBatch((10L, 1), (15L, 2), (20L, 1)),
+
+      // advance clock to 30 seconds, should retain keys >= 20
+      AddData(inputData, 0L, 85L),
+      AdvanceManualClock(10 * 1000),
+      CheckLastBatch((20L, 1), (85L, 1)),
+
+      // bounce stream and ensure correct batch timestamp is used
+      // i.e., we don't take it from the clock, which is at 90 seconds.
+      StopStream,
+      AssertOnQuery { q => // clear the sink
+        q.sink.asInstanceOf[MemorySink].clear()
+        // advance by a minute, i.e., to 90 seconds total
+        clock.advance(60 * 1000L)
+        true
+      },
+      StartStream(ProcessingTime("10 seconds"), triggerClock = clock),
+      CheckLastBatch((20L, 1), (85L, 1)),
+      AssertOnQuery { q =>
+        clock.getTimeMillis() == 90000L
+      },
+
+      // advance clock to 100 seconds, should retain keys >= 90
+      AddData(inputData, 85L, 90L, 100L, 105L),
+      AdvanceManualClock(10 * 1000),
+      CheckLastBatch((90L, 1), (100L, 1), (105L, 1))
+    )
+  }
+
+  test("prune results by current_date, complete mode") {
+    import testImplicits._
+    val clock = new StreamManualClock
+    val tz = TimeZone.getDefault.getID
+    val inputData = MemoryStream[Long]
+    val aggregated =
+      inputData.toDF()
+        .select(to_utc_timestamp(from_unixtime('value * DateTimeUtils.SECONDS_PER_DAY), tz))
+        .toDF("value")
+        .groupBy($"value")
+        .agg(count("*"))
+        .where($"value".cast("date") >= date_sub(current_date(), 10))
+        .select(($"value".cast("long") / DateTimeUtils.SECONDS_PER_DAY).cast("long"), $"count(1)")
+    testStream(aggregated, Complete)(
+      StartStream(ProcessingTime("10 day"), triggerClock = clock),
+      // advance clock to 10 days, should retain all keys
+      AddData(inputData, 0L, 5L, 5L, 10L),
+      AdvanceManualClock(DateTimeUtils.MILLIS_PER_DAY * 10),
+      CheckLastBatch((0L, 1), (5L, 2), (10L, 1)),
+      // advance clock to 20 days, should retain keys >= 10
+      AddData(inputData, 15L, 15L, 20L),
+      AdvanceManualClock(DateTimeUtils.MILLIS_PER_DAY * 10),
+      CheckLastBatch((10L, 1), (15L, 2), (20L, 1)),
+      // advance clock to 30 days, should retain keys >= 20
+      AddData(inputData, 85L),
+      AdvanceManualClock(DateTimeUtils.MILLIS_PER_DAY * 10),
+      CheckLastBatch((20L, 1), (85L, 1)),
+      // After the restart below, the last batch must match the one computed before StopStream.
+      // bounce stream and ensure correct batch timestamp is used
+      // i.e., we don't take it from the clock, which is at 90 days.
+      StopStream,
+      AssertOnQuery { q => // clear the sink
+        q.sink.asInstanceOf[MemorySink].clear()
+        // advance by 60 days, i.e., to 90 days total
+        clock.advance(DateTimeUtils.MILLIS_PER_DAY * 60)
+        true
+      },
+      StartStream(ProcessingTime("10 day"), triggerClock = clock),
+      CheckLastBatch((20L, 1), (85L, 1)),
+
+      // advance clock to 100 days, should retain keys >= 90
+      AddData(inputData, 85L, 90L, 100L, 105L),
+      AdvanceManualClock(DateTimeUtils.MILLIS_PER_DAY * 10),
+      CheckLastBatch((90L, 1), (100L, 1), (105L, 1))
+    )
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3c0beea4/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index e2e66d6..8ecb33c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -103,8 +103,8 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging {
       TestAwaitTermination(ExpectException[SparkException], timeoutMs = 2000),
       TestAwaitTermination(ExpectException[SparkException], timeoutMs = 10),
       AssertOnQuery(
-        q =>
-          q.exception.get.startOffset.get === q.committedOffsets.toOffsetSeq(Seq(inputData)),
+        q => q.exception.get.startOffset.get.offsets ===
+          q.committedOffsets.toOffsetSeq(Seq(inputData), "{}").offsets,
         "incorrect start offset on exception")
     )
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/3c0beea4/sql/core/src/test/scala/org/apache/spark/sql/streaming/WatermarkSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/WatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/WatermarkSuite.scala
index 3617ec0..3e9488c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/WatermarkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/WatermarkSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.{AnalysisException, Row}
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.functions.{count, window}
 
@@ -96,27 +96,41 @@ class WatermarkSuite extends StreamTest with BeforeAndAfter with Logging {
     )
   }
 
-  ignore("recovery") {
+  test("recovery") {
     val inputData = MemoryStream[Int]
+    val df = inputData.toDF()
+      .withColumn("eventTime", $"value".cast("timestamp"))
+      .withWatermark("eventTime", "10 seconds")
+      .groupBy(window($"eventTime", "5 seconds") as 'window)
+      .agg(count("*") as 'count)
+      .select($"window".getField("start").cast("long").as[Long], $"count".as[Long])
 
-    val windowedAggregation = inputData.toDF()
-        .withColumn("eventTime", $"value".cast("timestamp"))
-        .withWatermark("eventTime", "10 seconds")
-        .groupBy(window($"eventTime", "5 seconds") as 'window)
-        .agg(count("*") as 'count)
-        .select($"window".getField("start").cast("long").as[Long], $"count".as[Long])
-
-    testStream(windowedAggregation)(
+    testStream(df)(
       AddData(inputData, 10, 11, 12, 13, 14, 15),
-      CheckAnswer(),
+      CheckLastBatch(),
       AddData(inputData, 25), // Advance watermark to 15 seconds
       StopStream,
       StartStream(),
-      CheckAnswer(),
+      CheckLastBatch(),
       AddData(inputData, 25), // Evict items less than previous watermark.
+      CheckLastBatch((10, 5)),
       StopStream,
+      AssertOnQuery { q => // clear the sink
+        q.sink.asInstanceOf[MemorySink].clear()
+        true
+      },
       StartStream(),
-      CheckAnswer((10, 5))
+      CheckLastBatch((10, 5)), // Recompute last batch and re-evict timestamp 10
+      AddData(inputData, 30), // Advance watermark to 20 seconds
+      CheckLastBatch(),
+      StopStream,
+      StartStream(), // Watermark should still be 15 seconds (restored from the offset log)
+      AddData(inputData, 17),
+      CheckLastBatch(), // We still do not see next batch
+      AddData(inputData, 30), // Advance watermark to 20 seconds
+      CheckLastBatch(),
+      AddData(inputData, 30), // Evict items less than previous watermark.
+      CheckLastBatch((15, 2)) // Ensure we see next window
     )
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org