Posted to commits@spark.apache.org by td...@apache.org on 2018/04/23 20:19:22 UTC

spark git commit: [SPARK-23004][SS] Ensure StateStore.commit is called only once in a streaming aggregation task

Repository: spark
Updated Branches:
  refs/heads/master 448d248f8 -> 770add81c


[SPARK-23004][SS] Ensure StateStore.commit is called only once in a streaming aggregation task

## What changes were proposed in this pull request?

A structured streaming query with a streaming aggregation can throw the following error in rare cases. 

```
java.lang.IllegalStateException: Cannot commit after already committed or aborted
	at org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider.org$apache$spark$sql$execution$streaming$state$HDFSBackedStateStoreProvider$$verify(HDFSBackedStateStoreProvider.scala:643)
	at org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider$HDFSBackedStateStore.commit(HDFSBackedStateStoreProvider.scala:135)
	at org.apache.spark.sql.execution.streaming.StateStoreSaveExec$$anonfun$doExecute$3$$anon$2$$anonfun$hasNext$2.apply$mcV$sp(statefulOperators.scala:359)
	at org.apache.spark.sql.execution.streaming.StateStoreWriter$class.timeTakenMs(statefulOperators.scala:102)
	at org.apache.spark.sql.execution.streaming.StateStoreSaveExec.timeTakenMs(statefulOperators.scala:251)
	at org.apache.spark.sql.execution.streaming.StateStoreSaveExec$$anonfun$doExecute$3$$anon$2.hasNext(statefulOperators.scala:359)
	at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.processInputs(ObjectAggregationIterator.scala:188)
	at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.<init>(ObjectAggregationIterator.scala:78)
	at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec$$anonfun$doExecute$1$$anonfun$2.apply(ObjectHashAggregateExec.scala:114)
	at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec$$anonfun$doExecute$1$$anonfun$2.apply(ObjectHashAggregateExec.scala:105)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndexInternal$1$$anonfun$apply$24.apply(RDD.scala:830)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndexInternal$1$$anonfun$apply$24.apply(RDD.scala:830)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:42)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:336)
```

This can happen when the following conditions are accidentally hit together (a hedged reproduction sketch follows the list).
 - Streaming aggregation with an aggregation function that is a subclass of [`TypedImperativeAggregate`](https://github.com/apache/spark/blob/76b8b840ddc951ee6203f9cccd2c2b9671c1b5e8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala#L473) (for example, `collect_set`, `collect_list`, `percentile`, etc.).
 - Query running in `update` mode.
 - After the shuffle, a partition has exactly 128 records (the default threshold at which `ObjectHashAggregateExec` falls back to sort-based aggregation).
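
As an illustration only (not part of this patch), the snippet below sketches a standalone query that hits all three conditions, modeled on the regression test added further down in this commit. The session setup, the console sink, the app name, and the hardcoded 128 (the default sort-based fallback threshold referenced by the test) are assumptions made for the sketch.

```scala
// Hypothetical standalone reproduction of SPARK-23004, modeled on the test added in this PR.
// Assumes Spark 2.3-era APIs; names like "spark-23004-repro" are made up for illustration.
import org.apache.spark.sql.{SQLContext, SparkSession}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.functions.collect_list
import org.apache.spark.sql.streaming.OutputMode

val spark = SparkSession.builder().master("local[2]").appName("spark-23004-repro").getOrCreate()
import spark.implicits._
implicit val sqlContext: SQLContext = spark.sqlContext

// Force all records into a single post-shuffle partition.
spark.conf.set("spark.sql.shuffle.partitions", "1")

val input = MemoryStream[Int]
val df = input.toDF().toDF("value")
  .selectExpr("value as group", "value")
  .groupBy("group")
  .agg(collect_list("value"))        // collect_list is a TypedImperativeAggregate subclass

val query = df.writeStream
  .outputMode(OutputMode.Update())   // update mode is one of the required conditions
  .format("console")
  .start()

// Exactly 128 records: the default ObjectHashAggregateExec sort-based fallback threshold.
input.addData(1 to 128)
query.processAllAvailable()          // without the fix this can fail with IllegalStateException
query.stop()
```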

This causes `StateStore.commit` to be called twice. See the [JIRA](https://issues.apache.org/jira/browse/SPARK-23004) for a more detailed explanation. The solution is to use `NextIterator` or `CompletionIterator`, each of which has a flag that prevents its completion code from running more than once. In this PR, I chose to implement it using `NextIterator`.
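
As a rough illustration only (simplified, and not Spark's actual `org.apache.spark.util.NextIterator`), the sketch below shows the guard the fix relies on: a `completed` flag ensures the `close()` hook, which in the patched operator holds `store.commit()` and the metric updates, runs at most once even if `hasNext` is called again after the iterator is exhausted.

```scala
// Simplified sketch of the NextIterator pattern used by the fix; an illustration of the
// one-shot completion guard, not Spark's real implementation.
abstract class GuardedIterator[T] extends Iterator[T] {
  private var completed = false   // ensures close() runs at most once
  private var gotNext = false
  private var nextValue: T = _

  protected var finished = false
  protected def getNext(): T      // returns the next element, or sets finished = true
  protected def close(): Unit     // e.g. store.commit() and metric updates

  private def closeIfNeeded(): Unit = {
    if (!completed) {
      completed = true
      close()                     // runs exactly once, even if hasNext is called repeatedly
    }
  }

  override def hasNext: Boolean = {
    if (!finished && !gotNext) {
      nextValue = getNext()
      if (finished) closeIfNeeded()
      gotNext = true
    }
    !finished
  }

  override def next(): T = {
    if (!hasNext) throw new NoSuchElementException("End of stream")
    gotNext = false
    nextValue
  }
}
```

With the old plain `Iterator`, every extra `hasNext` call after exhaustion re-ran the commit block; with this pattern, repeated `hasNext` calls after exhaustion are harmless because the flag short-circuits the completion logic.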

## How was this patch tested?

Added a unit test that I have confirmed fails without the fix.

Author: Tathagata Das <ta...@gmail.com>

Closes #21124 from tdas/SPARK-23004.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/770add81
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/770add81
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/770add81

Branch: refs/heads/master
Commit: 770add81c3474e754867d7105031a5eaf27159bd
Parents: 448d248
Author: Tathagata Das <ta...@gmail.com>
Authored: Mon Apr 23 13:20:32 2018 -0700
Committer: Tathagata Das <ta...@gmail.com>
Committed: Mon Apr 23 13:20:32 2018 -0700

----------------------------------------------------------------------
 .../execution/streaming/statefulOperators.scala | 40 ++++++++++----------
 .../streaming/StreamingAggregationSuite.scala   | 25 ++++++++++++
 2 files changed, 44 insertions(+), 21 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/770add81/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index b9b07a2..c9354ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -340,37 +340,35 @@ case class StateStoreSaveExec(
           // Update and output modified rows from the StateStore.
           case Some(Update) =>
 
-            val updatesStartTimeNs = System.nanoTime
-
-            new Iterator[InternalRow] {
-
+            new NextIterator[InternalRow] {
               // Filter late date using watermark if specified
               private[this] val baseIterator = watermarkPredicateForData match {
                 case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row))
                 case None => iter
               }
+              private val updatesStartTimeNs = System.nanoTime
 
-              override def hasNext: Boolean = {
-                if (!baseIterator.hasNext) {
-                  allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
-
-                  // Remove old aggregates if watermark specified
-                  allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) }
-                  commitTimeMs += timeTakenMs { store.commit() }
-                  setStoreMetrics(store)
-                  false
+              override protected def getNext(): InternalRow = {
+                if (baseIterator.hasNext) {
+                  val row = baseIterator.next().asInstanceOf[UnsafeRow]
+                  val key = getKey(row)
+                  store.put(key, row)
+                  numOutputRows += 1
+                  numUpdatedStateRows += 1
+                  row
                 } else {
-                  true
+                  finished = true
+                  null
                 }
               }
 
-              override def next(): InternalRow = {
-                val row = baseIterator.next().asInstanceOf[UnsafeRow]
-                val key = getKey(row)
-                store.put(key, row)
-                numOutputRows += 1
-                numUpdatedStateRows += 1
-                row
+              override protected def close(): Unit = {
+                allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
+
+                // Remove old aggregates if watermark specified
+                allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) }
+                commitTimeMs += timeTakenMs { store.commit() }
+                setStoreMetrics(store)
               }
             }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/770add81/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index 1cae8cb..382da13 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -536,6 +536,31 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
 
+  test("SPARK-23004: Ensure that TypedImperativeAggregate functions do not throw errors") {
+    // See the JIRA SPARK-23004 for more details. In short, this test reproduces the error
+    // by ensuring the following.
+    // - A streaming query with a streaming aggregation.
+    // - Aggregation function 'collect_list' that is a subclass of TypedImperativeAggregate.
+    // - Post shuffle partition has exactly 128 records (i.e. the threshold at which
+    //   ObjectHashAggregateExec falls back to sort-based aggregation). This is done by having a
+    //   micro-batch with 128 records that shuffle to a single partition.
+    // This test throws the exact error reported in SPARK-23004 without the corresponding fix.
+    withSQLConf("spark.sql.shuffle.partitions" -> "1") {
+      val input = MemoryStream[Int]
+      val df = input.toDF().toDF("value")
+        .selectExpr("value as group", "value")
+        .groupBy("group")
+        .agg(collect_list("value"))
+      testStream(df, outputMode = OutputMode.Update)(
+        AddData(input, (1 to spark.sqlContext.conf.objectAggSortBasedFallbackThreshold): _*),
+        AssertOnQuery { q =>
+          q.processAllAvailable()
+          true
+        }
+      )
+    }
+  }
+
   /** Add blocks of data to the `BlockRDDBackedSource`. */
   case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData {
     override def addData(query: Option[StreamExecution]): (Source, Offset) = {

