You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by td...@apache.org on 2021/07/21 05:49:23 UTC
[spark] branch master updated: [SPARK-36132][SS][SQL] Support
initial state for batch mode of flatMapGroupsWithState
This is an automated email from the ASF dual-hosted git repository.
tdas pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new efcce23 [SPARK-36132][SS][SQL] Support initial state for batch mode of flatMapGroupsWithState
efcce23 is described below
commit efcce23b913ce0de961ac261050e3d6dbf261f6e
Author: Rahul Mahadev <ra...@databricks.com>
AuthorDate: Wed Jul 21 01:48:58 2021 -0400
[SPARK-36132][SS][SQL] Support initial state for batch mode of flatMapGroupsWithState
### What changes were proposed in this pull request?
Adding support for accepting an initial state with flatMapGroupsWithState in batch mode.
### Why are the changes needed?
SPARK-35897 added support for accepting an initial state for streaming queries using flatMapGroupsWithState. The code flow for batch queries is separate from the one for streaming queries, so batch support required a separate PR.
### Does this PR introduce _any_ user-facing change?
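Yes. As discussed above, flatMapGroupsWithState in batch mode can now accept an initial state; previously this was rejected by `UnsupportedOperationChecker` (the removed tests expected an `AnalysisException`, not an `UnsupportedOperationException`).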
### How was this patch tested?
Added relevant unit tests in `FlatMapGroupsWithStateSuite` and modified the existing tests in `JavaDatasetSuite`.
Closes #33336 from rahulsmahadev/flatMapGroupsWithStateBatch.
Authored-by: Rahul Mahadev <ra...@databricks.com>
Signed-off-by: Tathagata Das <ta...@gmail.com>
---
.../analysis/UnsupportedOperationChecker.scala | 6 --
.../spark/sql/execution/SparkStrategies.scala | 11 +++-
.../streaming/FlatMapGroupsWithStateExec.scala | 71 +++++++++++++++++++++-
.../org/apache/spark/sql/JavaDatasetSuite.java | 18 +-----
.../streaming/FlatMapGroupsWithStateSuite.scala | 52 ++++++++++++++++
5 files changed, 130 insertions(+), 28 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index 13c7f75..321725d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -37,12 +37,6 @@ object UnsupportedOperationChecker extends Logging {
case p if p.isStreaming =>
throwError("Queries with streaming sources must be executed with writeStream.start()")(p)
- case f: FlatMapGroupsWithState =>
- if (f.hasInitialState) {
- throwError("Initial state is not supported in [flatMap|map]GroupsWithState" +
- " operation on a batch DataFrame/Dataset")(f)
- }
-
case _ =>
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 6d10fa8..7624b15 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -690,9 +690,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.MapGroups(f, key, value, grouping, data, objAttr, child) =>
execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
case logical.FlatMapGroupsWithState(
- f, key, value, grouping, data, output, _, _, _, timeout, _, _, _, _, _, child) =>
- execution.MapGroupsExec(
- f, key, value, grouping, data, output, timeout, planLater(child)) :: Nil
+ f, keyDeserializer, valueDeserializer, grouping, data, output, stateEncoder, outputMode,
+ isFlatMapGroupsWithState, timeout, hasInitialState, initialStateGroupAttrs,
+ initialStateDataAttrs, initialStateDeserializer, initialState, child) =>
+ FlatMapGroupsWithStateExec.generateSparkPlanForBatchQueries(
+ f, keyDeserializer, valueDeserializer, initialStateDeserializer, grouping,
+ initialStateGroupAttrs, data, initialStateDataAttrs, output, timeout,
+ hasInitialState, planLater(initialState), planLater(child)
+ ) :: Nil
case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
execution.CoGroupExec(
f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index 03694d4..a00a622 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -309,9 +309,7 @@ case class FlatMapGroupsWithStateExec(
var foundInitialStateForKey = false
initialStateRowIter.foreach { initialStateRow =>
if (foundInitialStateForKey) {
- throw new IllegalArgumentException("The initial state provided contained " +
- "multiple rows(state) with the same key. Make sure to de-duplicate the " +
- "initial state before passing it.")
+ FlatMapGroupsWithStateExec.foundDuplicateInitialKeyException()
}
foundInitialStateForKey = true
val initStateObj = getStateObj.get(initialStateRow)
@@ -403,3 +401,70 @@ case class FlatMapGroupsWithStateExec(
copy(child = newLeft, initialState = newRight)
}
+object FlatMapGroupsWithStateExec {
+
+ def foundDuplicateInitialKeyException(): Exception = {
+ throw new IllegalArgumentException("The initial state provided contained " +
+ "multiple rows(state) with the same key. Make sure to de-duplicate the " +
+ "initial state before passing it.")
+ }
+
+ /**
+ * Plans the logical FlatMapGroupsWithState operator for batch queries.
+ * If the initial state is provided, we create an instance of CoGroupExec; if the initial
+ * state is not provided, we create an instance of MapGroupsExec.
+ */
+ // scalastyle:off argcount
+ def generateSparkPlanForBatchQueries(
+ userFunc: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any],
+ keyDeserializer: Expression,
+ valueDeserializer: Expression,
+ initialStateDeserializer: Expression,
+ groupingAttributes: Seq[Attribute],
+ initialStateGroupAttrs: Seq[Attribute],
+ dataAttributes: Seq[Attribute],
+ initialStateDataAttrs: Seq[Attribute],
+ outputObjAttr: Attribute,
+ timeoutConf: GroupStateTimeout,
+ hasInitialState: Boolean,
+ initialState: SparkPlan,
+ child: SparkPlan): SparkPlan = {
+ if (hasInitialState) {
+ val watermarkPresent = child.output.exists {
+ case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true
+ case _ => false
+ }
+ val func = (keyRow: Any, values: Iterator[Any], states: Iterator[Any]) => {
+ // Check if there is only one state for every key.
+ var foundInitialStateForKey = false
+ val optionalStates = states.map { stateValue =>
+ if (foundInitialStateForKey) {
+ foundDuplicateInitialKeyException()
+ }
+ foundInitialStateForKey = true
+ stateValue
+ }.toArray
+
+ // Create group state object
+ val groupState = GroupStateImpl.createForStreaming(
+ optionalStates.headOption,
+ System.currentTimeMillis,
+ GroupStateImpl.NO_TIMESTAMP,
+ timeoutConf,
+ hasTimedOut = false,
+ watermarkPresent)
+
+ // Call user function with the state and values for this key
+ userFunc(keyRow, values, groupState)
+ }
+ CoGroupExec(
+ func, keyDeserializer, valueDeserializer, initialStateDeserializer, groupingAttributes,
+ initialStateGroupAttrs, dataAttributes, initialStateDataAttrs, outputObjAttr,
+ child, initialState)
+ } else {
+ MapGroupsExec(
+ userFunc, keyDeserializer, valueDeserializer, groupingAttributes,
+ dataAttributes, outputObjAttr, timeoutConf, child)
+ }
+ }
+}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 0500c52..28439f2 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -196,14 +196,7 @@ public class JavaDatasetSuite implements Serializable {
GroupStateTimeout.NoTimeout(),
kvInitStateMappedDS);
- Assert.assertThrows(
- "Initial state is not supported in [flatMap|map]GroupsWithState " +
- "operation on a batch DataFrame/Dataset",
- AnalysisException.class,
- () -> {
- flatMapped2.collectAsList();
- }
- );
+ Assert.assertEquals(asSet("1a", "2", "3foobar"), toSet(flatMapped2.collectAsList()));
Dataset<String> mapped2 = grouped.mapGroupsWithState(
(MapGroupsWithStateFunction<Integer, String, Long, String>) (key, values, s) -> {
StringBuilder sb = new StringBuilder(key.toString());
@@ -216,14 +209,7 @@ public class JavaDatasetSuite implements Serializable {
Encoders.STRING(),
GroupStateTimeout.NoTimeout(),
kvInitStateMappedDS);
- Assert.assertThrows(
- "Initial state is not supported in [flatMap|map]GroupsWithState " +
- "operation on a batch DataFrame/Dataset",
- AnalysisException.class,
- () -> {
- mapped2.collectAsList();
- }
- );
+ Assert.assertEquals(asSet("1a", "2", "3foobar"), toSet(mapped2.collectAsList()));
}
@Test
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index 152dd16..d34b2b8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -1284,6 +1284,12 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 }
assertCannotGetWatermark { state.getCurrentWatermarkMs() }
assert(!state.hasTimedOut)
+ if (key.contains("EventTime")) {
+ state.setTimeoutTimestamp(0, "1 hour")
+ }
+ if (key.contains("ProcessingTime")) {
+ state.setTimeoutDuration("1 hour")
+ }
val count = state.getOption.map(_.count).getOrElse(0L) + valList.size
// We need to check if not explicitly calling update will still save the init state or not
if (!key.contains("NoUpdate")) {
@@ -1413,6 +1419,52 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
)
}
+ Seq(NoTimeout(), EventTimeTimeout(), ProcessingTimeTimeout()).foreach { timeout =>
+ test(s"flatMapGroupsWithState - initial state - batch mode - timeout ${timeout}") {
+ // We will test them on different shuffle partition configurations to make sure the
+ // grouping by key will still work. With a higher number of shuffle partitions it's
+ // possible that all keys end up on different partitions.
+ val initialState = Seq(
+ (s"keyInStateAndData-1-$timeout", new RunningCount(1)),
+ ("keyInStateAndData-2", new RunningCount(2)),
+ ("keyNoUpdate", new RunningCount(2)), // state.update will not be called
+ ("keyOnlyInState-1", new RunningCount(1))
+ ).toDS().groupByKey(x => x._1).mapValues(_._2)
+
+ val inputData = Seq(
+ ("keyOnlyInData"), ("keyInStateAndData-2")
+ )
+ val result = inputData.toDS().groupByKey(x => x)
+ .flatMapGroupsWithState(
+ Update, timeout, initialState)(flatMapGroupsWithStateFunc)
+
+ val expected = Seq(
+ ("keyOnlyInState-1", Seq[String](), "1"),
+ ("keyNoUpdate", Seq[String](), "2"), // update will not be called
+ ("keyInStateAndData-2", Seq[String]("keyInStateAndData-2"), "3"), // inc by 1
+ (s"keyInStateAndData-1-$timeout", Seq[String](), "1"),
+ ("keyOnlyInData", Seq[String]("keyOnlyInData"), "1") // inc by 1
+ ).toDF()
+ checkAnswer(result.toDF(), expected)
+ }
+ }
+
+ testQuietly("flatMapGroupsWithState - initial state - batch mode - duplicate state") {
+ val initialState = Seq(
+ ("a", new RunningCount(1)),
+ ("a", new RunningCount(2))
+ ).toDS().groupByKey(x => x._1).mapValues(_._2)
+
+ val e = intercept[SparkException] {
+ Seq("a", "b").toDS().groupByKey(x => x)
+ .flatMapGroupsWithState(Update, NoTimeout(), initialState)(flatMapGroupsWithStateFunc)
+ .show()
+ }
+ assert(e.getMessage.contains(
+ "The initial state provided contained multiple rows(state) with the same key." +
+ " Make sure to de-duplicate the initial state before passing it."))
+ }
+
testQuietly("flatMapGroupsWithState - initial state - streaming initial state") {
val initialStateData = MemoryStream[(String, RunningCount)]
initialStateData.addData(("a", new RunningCount(1)))
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org