Posted to commits@spark.apache.org by td...@apache.org on 2018/04/23 20:19:22 UTC
spark git commit: [SPARK-23004][SS] Ensure StateStore.commit is called only once in a streaming aggregation task
Repository: spark
Updated Branches:
refs/heads/master 448d248f8 -> 770add81c
[SPARK-23004][SS] Ensure StateStore.commit is called only once in a streaming aggregation task
## What changes were proposed in this pull request?
A structured streaming query with a streaming aggregation can throw the following error in rare cases.
```
java.lang.IllegalStateException: Cannot commit after already committed or aborted
at org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider.org$apache$spark$sql$execution$streaming$state$HDFSBackedStateStoreProvider$$verify(HDFSBackedStateStoreProvider.scala:643)
at org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider$HDFSBackedStateStore.commit(HDFSBackedStateStoreProvider.scala:135)
at org.apache.spark.sql.execution.streaming.StateStoreSaveExec$$anonfun$doExecute$3$$anon$2$$anonfun$hasNext$2.apply$mcV$sp(statefulOperators.scala:359)
at org.apache.spark.sql.execution.streaming.StateStoreWriter$class.timeTakenMs(statefulOperators.scala:102)
at org.apache.spark.sql.execution.streaming.StateStoreSaveExec.timeTakenMs(statefulOperators.scala:251)
at org.apache.spark.sql.execution.streaming.StateStoreSaveExec$$anonfun$doExecute$3$$anon$2.hasNext(statefulOperators.scala:359)
at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.processInputs(ObjectAggregationIterator.scala:188)
at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.<init>(ObjectAggregationIterator.scala:78)
at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec$$anonfun$doExecute$1$$anonfun$2.apply(ObjectHashAggregateExec.scala:114)
at org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec$$anonfun$doExecute$1$$anonfun$2.apply(ObjectHashAggregateExec.scala:105)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndexInternal$1$$anonfun$apply$24.apply(RDD.scala:830)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsWithIndexInternal$1$$anonfun$apply$24.apply(RDD.scala:830)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:42)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:336)
```
This can happen when all of the following conditions are hit at once.
- Streaming aggregation with an aggregation function that is a subclass of [`TypedImperativeAggregate`](https://github.com/apache/spark/blob/76b8b840ddc951ee6203f9cccd2c2b9671c1b5e8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala#L473) (for example, `collect_set`, `collect_list`, `percentile`, etc.).
- Query running in `update` mode
- After the shuffle, a post-shuffle partition has exactly 128 records (the default threshold at which `ObjectHashAggregateExec` falls back to sort-based aggregation); see the sketch below.
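
For illustration, a query that meets all three conditions might look like the following sketch (assumes an existing `SparkSession` named `spark`; the memory sink, the query name, and the hard-coded 128 are illustrative, not taken from this PR):
```
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.functions.collect_list

// Sketch only: assumes an existing SparkSession `spark`.
import spark.implicits._
implicit val sqlContext = spark.sqlContext

// One shuffle partition, so a single post-shuffle partition
// receives all 128 records.
spark.conf.set("spark.sql.shuffle.partitions", "1")

val input = MemoryStream[Int]
val df = input.toDF().toDF("value")
  .selectExpr("value as group", "value")
  .groupBy("group")
  .agg(collect_list("value"))   // a TypedImperativeAggregate subclass

val query = df.writeStream
  .outputMode("update")         // condition: update mode
  .format("memory")
  .queryName("spark23004_repro")
  .start()

input.addData(1 to 128)         // condition: exactly 128 records
query.processAllAvailable()
query.stop()
```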
This causes `StateStore.commit` to be called twice. See the [JIRA](https://issues.apache.org/jira/browse/SPARK-23004) for a more detailed explanation. The solution is to use `NextIterator` or `CompletionIterator`, each of which has a flag that prevents its "onCompletion" logic from running more than once. In this PR, I chose to implement the fix using `NextIterator`.
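
To make the fix concrete, here is a simplified sketch of the `NextIterator` pattern (the real class is `org.apache.spark.util.NextIterator` in spark-core; this is a trimmed-down illustration, not its exact source). The key point is that all completion work, such as `store.commit()`, goes into `close()`, and `closeIfNeeded()` uses a `closed` flag so it runs at most once no matter how many times `hasNext` is called after exhaustion:
```
// Simplified sketch of the NextIterator pattern; not Spark's exact source.
abstract class NextIteratorSketch[U] extends Iterator[U] {
  private var gotNext = false
  private var nextValue: U = _
  private var closed = false
  protected var finished = false

  /** Produce the next value, or set `finished = true` and return null. */
  protected def getNext(): U

  /** One-time completion logic, e.g. committing a state store. */
  protected def close(): Unit

  def closeIfNeeded(): Unit = {
    if (!closed) {
      closed = true  // flip the flag first, so re-entry is a no-op
      close()
    }
  }

  override def hasNext: Boolean = {
    if (!finished && !gotNext) {
      nextValue = getNext()
      if (finished) closeIfNeeded()  // completion runs here, exactly once
      gotNext = true
    }
    !finished
  }

  override def next(): U = {
    if (!hasNext) throw new NoSuchElementException("End of stream")
    gotNext = false
    nextValue
  }
}
```
In the diff below, the completion work (watermark-based cleanup, `store.commit()`, and metric updates) moves from the exhaustion branch of `hasNext` into `close()`, so the double commit cannot recur.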
## How was this patch tested?
Added a unit test that I have confirmed fails without the fix.
Author: Tathagata Das <ta...@gmail.com>
Closes #21124 from tdas/SPARK-23004.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/770add81
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/770add81
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/770add81
Branch: refs/heads/master
Commit: 770add81c3474e754867d7105031a5eaf27159bd
Parents: 448d248
Author: Tathagata Das <ta...@gmail.com>
Authored: Mon Apr 23 13:20:32 2018 -0700
Committer: Tathagata Das <ta...@gmail.com>
Committed: Mon Apr 23 13:20:32 2018 -0700
----------------------------------------------------------------------
.../execution/streaming/statefulOperators.scala | 40 ++++++++++----------
.../streaming/StreamingAggregationSuite.scala | 25 ++++++++++++
2 files changed, 44 insertions(+), 21 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/770add81/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index b9b07a2..c9354ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -340,37 +340,35 @@ case class StateStoreSaveExec(
// Update and output modified rows from the StateStore.
case Some(Update) =>
- val updatesStartTimeNs = System.nanoTime
-
- new Iterator[InternalRow] {
-
+ new NextIterator[InternalRow] {
// Filter late date using watermark if specified
private[this] val baseIterator = watermarkPredicateForData match {
case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row))
case None => iter
}
+ private val updatesStartTimeNs = System.nanoTime
- override def hasNext: Boolean = {
- if (!baseIterator.hasNext) {
- allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
-
- // Remove old aggregates if watermark specified
- allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) }
- commitTimeMs += timeTakenMs { store.commit() }
- setStoreMetrics(store)
- false
+ override protected def getNext(): InternalRow = {
+ if (baseIterator.hasNext) {
+ val row = baseIterator.next().asInstanceOf[UnsafeRow]
+ val key = getKey(row)
+ store.put(key, row)
+ numOutputRows += 1
+ numUpdatedStateRows += 1
+ row
} else {
- true
+ finished = true
+ null
}
}
- override def next(): InternalRow = {
- val row = baseIterator.next().asInstanceOf[UnsafeRow]
- val key = getKey(row)
- store.put(key, row)
- numOutputRows += 1
- numUpdatedStateRows += 1
- row
+ override protected def close(): Unit = {
+ allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
+
+ // Remove old aggregates if watermark specified
+ allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) }
+ commitTimeMs += timeTakenMs { store.commit() }
+ setStoreMetrics(store)
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/770add81/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index 1cae8cb..382da13 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -536,6 +536,31 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
)
}
+ test("SPARK-23004: Ensure that TypedImperativeAggregate functions do not throw errors") {
+ // See the JIRA SPARK-23004 for more details. In short, this test reproduces the error
+ // by ensuring the following.
+ // - A streaming query with a streaming aggregation.
+ // - Aggregation function 'collect_list' that is a subclass of TypedImperativeAggregate.
+ // - Post shuffle partition has exactly 128 records (i.e. the threshold at which
+ // ObjectHashAggregateExec falls back to sort-based aggregation). This is done by having a
+ // micro-batch with 128 records that shuffle to a single partition.
+ // This test throws the exact error reported in SPARK-23004 without the corresponding fix.
+ withSQLConf("spark.sql.shuffle.partitions" -> "1") {
+ val input = MemoryStream[Int]
+ val df = input.toDF().toDF("value")
+ .selectExpr("value as group", "value")
+ .groupBy("group")
+ .agg(collect_list("value"))
+ testStream(df, outputMode = OutputMode.Update)(
+ AddData(input, (1 to spark.sqlContext.conf.objectAggSortBasedFallbackThreshold): _*),
+ AssertOnQuery { q =>
+ q.processAllAvailable()
+ true
+ }
+ )
+ }
+ }
+
/** Add blocks of data to the `BlockRDDBackedSource`. */
case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData {
override def addData(query: Option[StreamExecution]): (Source, Offset) = {
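
A note on the new test: `spark.sqlContext.conf.objectAggSortBasedFallbackThreshold` reads the SQL config `spark.sql.objectHashAggregate.sortBased.fallbackThreshold` (default 128), so the test keeps tracking the threshold even if the default changes. A hedged variant that pins the threshold explicitly instead:
```
// Sketch: pin the fallback threshold rather than relying on the default.
withSQLConf(
    "spark.sql.shuffle.partitions" -> "1",
    "spark.sql.objectHashAggregate.sortBased.fallbackThreshold" -> "128") {
  // ... same streaming aggregation and testStream(...) body as above ...
}
```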