Posted to commits@spark.apache.org by td...@apache.org on 2021/07/21 05:49:23 UTC

[spark] branch master updated: [SPARK-36132][SS][SQL] Support initial state for batch mode of flatMapGroupsWithState

This is an automated email from the ASF dual-hosted git repository.

tdas pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new efcce23  [SPARK-36132][SS][SQL] Support initial state for batch mode of flatMapGroupsWithState
efcce23 is described below

commit efcce23b913ce0de961ac261050e3d6dbf261f6e
Author: Rahul Mahadev <ra...@databricks.com>
AuthorDate: Wed Jul 21 01:48:58 2021 -0400

    [SPARK-36132][SS][SQL] Support initial state for batch mode of flatMapGroupsWithState
    
    ### What changes were proposed in this pull request?
    Adding support for accepting an initial state with flatMapGroupsWithState in batch mode.
    
    ### Why are the changes needed?
    SPARK-35897 added support for accepting an initial state for streaming queries using flatMapGroupsWithState. The code flow is separate for batch and streaming, so batch support required a separate PR.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. As discussed above, flatMapGroupsWithState in batch mode can now accept an initial state; previously this failed with an AnalysisException raised by UnsupportedOperationChecker.
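    
    For illustration, a minimal sketch of the newly supported batch usage (hedged: `RunningCount` and the state-updating function are modeled on the tests in this commit, and a SparkSession `spark` with its implicits is assumed to be in scope):
    
        import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}
        import spark.implicits._
    
        case class RunningCount(count: Long)
    
        val initialState = Seq(("a", RunningCount(1)))
          .toDS()
          .groupByKey(x => x._1)
          .mapValues(_._2)
    
        // On a batch Dataset this previously failed at analysis time; it now
        // runs, seeding key "a" with an initial count of 1.
        val result = Seq("a", "b").toDS()
          .groupByKey(x => x)
          .flatMapGroupsWithState(
            OutputMode.Update(), GroupStateTimeout.NoTimeout(), initialState) {
              (key: String, values: Iterator[String], state: GroupState[RunningCount]) =>
                val count = state.getOption.map(_.count).getOrElse(0L) + values.size
                state.update(RunningCount(count))
                Iterator((key, count.toString))
          }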
    
    ### How was this patch tested?
    
    Added relevant unit tests in FlatMapGroupsWithStateSuite and modified the tests in `JavaDatasetSuite`.
    
    Closes #33336 from rahulsmahadev/flatMapGroupsWithStateBatch.
    
    Authored-by: Rahul Mahadev <ra...@databricks.com>
    Signed-off-by: Tathagata Das <ta...@gmail.com>
---
 .../analysis/UnsupportedOperationChecker.scala     |  6 --
 .../spark/sql/execution/SparkStrategies.scala      | 11 +++-
 .../streaming/FlatMapGroupsWithStateExec.scala     | 71 +++++++++++++++++++++-
 .../org/apache/spark/sql/JavaDatasetSuite.java     | 18 +-----
 .../streaming/FlatMapGroupsWithStateSuite.scala    | 52 ++++++++++++++++
 5 files changed, 130 insertions(+), 28 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index 13c7f75..321725d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -37,12 +37,6 @@ object UnsupportedOperationChecker extends Logging {
       case p if p.isStreaming =>
         throwError("Queries with streaming sources must be executed with writeStream.start()")(p)
 
-      case f: FlatMapGroupsWithState =>
-        if (f.hasInitialState) {
-          throwError("Initial state is not supported in [flatMap|map]GroupsWithState" +
-            " operation on a batch DataFrame/Dataset")(f)
-        }
-
       case _ =>
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 6d10fa8..7624b15 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -690,9 +690,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.MapGroups(f, key, value, grouping, data, objAttr, child) =>
         execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
       case logical.FlatMapGroupsWithState(
-          f, key, value, grouping, data, output, _, _, _, timeout, _, _, _, _, _, child) =>
-        execution.MapGroupsExec(
-          f, key, value, grouping, data, output, timeout, planLater(child)) :: Nil
+          f, keyDeserializer, valueDeserializer, grouping, data, output, stateEncoder, outputMode,
+          isFlatMapGroupsWithState, timeout, hasInitialState, initialStateGroupAttrs,
+          initialStateDataAttrs, initialStateDeserializer, initialState, child) =>
+        FlatMapGroupsWithStateExec.generateSparkPlanForBatchQueries(
+          f, keyDeserializer, valueDeserializer, initialStateDeserializer, grouping,
+          initialStateGroupAttrs, data, initialStateDataAttrs, output, timeout,
+          hasInitialState, planLater(initialState), planLater(child)
+        ) :: Nil
       case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
         execution.CoGroupExec(
           f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index 03694d4..a00a622 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -309,9 +309,7 @@ case class FlatMapGroupsWithStateExec(
           var foundInitialStateForKey = false
           initialStateRowIter.foreach { initialStateRow =>
             if (foundInitialStateForKey) {
-              throw new IllegalArgumentException("The initial state provided contained " +
-                "multiple rows(state) with the same key. Make sure to de-duplicate the " +
-                "initial state before passing it.")
+              FlatMapGroupsWithStateExec.foundDuplicateInitialKeyException()
             }
             foundInitialStateForKey = true
             val initStateObj = getStateObj.get(initialStateRow)
@@ -403,3 +401,70 @@ case class FlatMapGroupsWithStateExec(
     copy(child = newLeft, initialState = newRight)
 }
 
+object FlatMapGroupsWithStateExec {
+
+  def foundDuplicateInitialKeyException(): Exception = {
+    throw new IllegalArgumentException("The initial state provided contained " +
+      "multiple rows(state) with the same key. Make sure to de-duplicate the " +
+      "initial state before passing it.")
+  }
+
+  /**
+   * Plan logical flatMapGroupsWithState for batch queries.
+   * If the initial state is provided, we create an instance of CoGroupExec; if the initial
+   * state is not provided, we create an instance of MapGroupsExec.
+   */
+  // scalastyle:off argcount
+  def generateSparkPlanForBatchQueries(
+      userFunc: (Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any],
+      keyDeserializer: Expression,
+      valueDeserializer: Expression,
+      initialStateDeserializer: Expression,
+      groupingAttributes: Seq[Attribute],
+      initialStateGroupAttrs: Seq[Attribute],
+      dataAttributes: Seq[Attribute],
+      initialStateDataAttrs: Seq[Attribute],
+      outputObjAttr: Attribute,
+      timeoutConf: GroupStateTimeout,
+      hasInitialState: Boolean,
+      initialState: SparkPlan,
+      child: SparkPlan): SparkPlan = {
+    if (hasInitialState) {
+      val watermarkPresent = child.output.exists {
+        case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true
+        case _ => false
+      }
+      val func = (keyRow: Any, values: Iterator[Any], states: Iterator[Any]) => {
+        // Check if there is only one state for every key.
+        var foundInitialStateForKey = false
+        val optionalStates = states.map { stateValue =>
+          if (foundInitialStateForKey) {
+            foundDuplicateInitialKeyException()
+          }
+          foundInitialStateForKey = true
+          stateValue
+        }.toArray
+
+        // Create group state object
+        val groupState = GroupStateImpl.createForStreaming(
+          optionalStates.headOption,
+          System.currentTimeMillis,
+          GroupStateImpl.NO_TIMESTAMP,
+          timeoutConf,
+          hasTimedOut = false,
+          watermarkPresent)
+
+        // Call user function with the state and values for this key
+        userFunc(keyRow, values, groupState)
+      }
+      CoGroupExec(
+        func, keyDeserializer, valueDeserializer, initialStateDeserializer, groupingAttributes,
+        initialStateGroupAttrs, dataAttributes, initialStateDataAttrs, outputObjAttr,
+        child, initialState)
+    } else {
+      MapGroupsExec(
+        userFunc, keyDeserializer, valueDeserializer, groupingAttributes,
+        dataAttributes, outputObjAttr, timeoutConf, child)
+    }
+  }
+}
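
A note on the planning path added above: with an initial state, the batch query is planned as a CoGroupExec that pairs each key's input rows with its initial-state rows, and the wrapper function enforces at most one state row per key. Below is a minimal, self-contained sketch of that per-key contract using plain Scala collections and hypothetical names; it is not code from this commit:

    object BatchInitialStateSketch {
      // Pairs one key's values with its (at most one) initial-state row and
      // fails fast on duplicates, mirroring foundDuplicateInitialKeyException().
      def applyPerKey[K, V, S, O](key: K, values: Iterator[V], states: Iterator[S])(
          userFunc: (K, Iterator[V], Option[S]) => Iterator[O]): Iterator[O] = {
        var foundInitialStateForKey = false
        val optionalStates = states.map { s =>
          if (foundInitialStateForKey) {
            throw new IllegalArgumentException(
              "The initial state provided contained multiple rows(state) with the same key.")
          }
          foundInitialStateForKey = true
          s
        }.toList
        userFunc(key, values, optionalStates.headOption)
      }

      def main(args: Array[String]): Unit = {
        // Key "a" arrives with one initial count of 1 and two new values.
        val out = applyPerKey("a", Iterator("x", "y"), Iterator(1L)) { (k, vs, st) =>
          Iterator(k -> (st.getOrElse(0L) + vs.size))
        }
        println(out.toList) // List((a,3))
      }
    }
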
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 0500c52..28439f2 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -196,14 +196,7 @@ public class JavaDatasetSuite implements Serializable {
       GroupStateTimeout.NoTimeout(),
       kvInitStateMappedDS);
 
-    Assert.assertThrows(
-      "Initial state is not supported in [flatMap|map]GroupsWithState " +
-              "operation on a batch DataFrame/Dataset",
-      AnalysisException.class,
-      () -> {
-        flatMapped2.collectAsList();
-      }
-    );
+    Assert.assertEquals(asSet("1a", "2", "3foobar"), toSet(flatMapped2.collectAsList()));
     Dataset<String> mapped2 = grouped.mapGroupsWithState(
       (MapGroupsWithStateFunction<Integer, String, Long, String>) (key, values, s) -> {
         StringBuilder sb = new StringBuilder(key.toString());
@@ -216,14 +209,7 @@ public class JavaDatasetSuite implements Serializable {
       Encoders.STRING(),
       GroupStateTimeout.NoTimeout(),
       kvInitStateMappedDS);
-    Assert.assertThrows(
-      "Initial state is not supported in [flatMap|map]GroupsWithState " +
-              "operation on a batch DataFrame/Dataset",
-      AnalysisException.class,
-      () -> {
-        mapped2.collectAsList();
-      }
-    );
+    Assert.assertEquals(asSet("1a", "2", "3foobar"), toSet(mapped2.collectAsList()));
   }
 
   @Test
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index 152dd16..d34b2b8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -1284,6 +1284,12 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
       assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 }
       assertCannotGetWatermark { state.getCurrentWatermarkMs() }
       assert(!state.hasTimedOut)
+      if (key.contains("EventTime")) {
+        state.setTimeoutTimestamp(0, "1 hour")
+      }
+      if (key.contains("ProcessingTime")) {
+        state.setTimeoutDuration("1 hour")
+      }
       val count = state.getOption.map(_.count).getOrElse(0L) + valList.size
       // We need to check if not explicitly calling update will still save the init state or not
       if (!key.contains("NoUpdate")) {
@@ -1413,6 +1419,52 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
     )
   }
 
+  Seq(NoTimeout(), EventTimeTimeout(), ProcessingTimeTimeout()).foreach { timeout =>
+    test(s"flatMapGroupsWithState - initial state - batch mode - timeout ${timeout}") {
+      // We will test this on different shuffle partition configurations to make sure
+      // grouping by key still works. With a higher number of shuffle partitions it is
+      // possible that all keys end up on different partitions.
+      val initialState = Seq(
+        (s"keyInStateAndData-1-$timeout", new RunningCount(1)),
+        ("keyInStateAndData-2", new RunningCount(2)),
+        ("keyNoUpdate", new RunningCount(2)), // state.update will not be called
+        ("keyOnlyInState-1", new RunningCount(1))
+      ).toDS().groupByKey(x => x._1).mapValues(_._2)
+
+      val inputData = Seq(
+        ("keyOnlyInData"), ("keyInStateAndData-2")
+      )
+      val result = inputData.toDS().groupByKey(x => x)
+        .flatMapGroupsWithState(
+          Update, timeout, initialState)(flatMapGroupsWithStateFunc)
+
+      val expected = Seq(
+        ("keyOnlyInState-1", Seq[String](), "1"),
+        ("keyNoUpdate", Seq[String](), "2"), // update will not be called
+        ("keyInStateAndData-2", Seq[String]("keyInStateAndData-2"), "3"), // inc by 1
+        (s"keyInStateAndData-1-$timeout", Seq[String](), "1"),
+        ("keyOnlyInData", Seq[String]("keyOnlyInData"), "1") // inc by 1
+      ).toDF()
+      checkAnswer(result.toDF(), expected)
+    }
+  }
+
+  testQuietly("flatMapGroupsWithState - initial state - batch mode - duplicate state") {
+    val initialState = Seq(
+      ("a", new RunningCount(1)),
+      ("a", new RunningCount(2))
+    ).toDS().groupByKey(x => x._1).mapValues(_._2)
+
+    val e = intercept[SparkException] {
+      Seq("a", "b").toDS().groupByKey(x => x)
+        .flatMapGroupsWithState(Update, NoTimeout(), initialState)(flatMapGroupsWithStateFunc)
+        .show()
+    }
+    assert(e.getMessage.contains(
+      "The initial state provided contained multiple rows(state) with the same key." +
+        " Make sure to de-duplicate the initial state before passing it."))
+  }
+
   testQuietly("flatMapGroupsWithState - initial state - streaming initial state") {
     val initialStateData = MemoryStream[(String, RunningCount)]
     initialStateData.addData(("a", new RunningCount(1)))

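One way for users to avoid the duplicate-state error exercised in the new test is to de-duplicate the initial-state Dataset before passing it in. A hedged sketch (the "_1" column name comes from the tuple schema, `spark.implicits._` and an encoder for `RunningCount` are assumed in scope, and dropDuplicates keeps an arbitrary row per key):

    import spark.implicits._

    val dedupedInitialState = Seq(
        ("a", new RunningCount(1)),
        ("a", new RunningCount(2)))
      .toDS()
      .dropDuplicates("_1")   // keep one state row per key
      .groupByKey(_._1)
      .mapValues(_._2)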