You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2023/06/21 06:35:48 UTC

[spark] branch master updated: [SPARK-43511][CONNECT][SS] Implemented MapGroupsWithState and FlatMapGroupsWithState APIs for Spark Connect

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8a16aed9a17 [SPARK-43511][CONNECT][SS] Implemented MapGroupsWithState and FlatMapGroupsWithState APIs for Spark Connect
8a16aed9a17 is described below

commit 8a16aed9a17269b4c8111779229507e3c28ba945
Author: bogao007 <bo...@databricks.com>
AuthorDate: Wed Jun 21 15:35:34 2023 +0900

    [SPARK-43511][CONNECT][SS] Implemented MapGroupsWithState and FlatMapGroupsWithState APIs for Spark Connect
    
    ### What changes were proposed in this pull request?
    
    Implemented MapGroupsWithState and FlatMapGroupsWithState APIs for Spark Connect
    
    ### Why are the changes needed?
    
    To support streaming state APIs in Spark Connect
    
    ### Does this PR introduce _any_ user-facing change?
    
    yes
    
    ### How was this patch tested?
    
    Added unit test
    
    Closes #41558 from bogao007/sc-state-api.
    
    Authored-by: bogao007 <bo...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 .../apache/spark/sql/KeyValueGroupedDataset.scala  | 398 +++++++++++++++++++++
 .../sql/KeyValueGroupedDatasetE2ETestSuite.scala   | 107 ++++++
 .../CheckConnectJvmClientCompatibility.scala       |   6 -
 .../FlatMapGroupsWithStateStreamingSuite.scala     | 224 ++++++++++++
 .../function/FlatMapGroupsWithStateFunction.java   |  39 ++
 .../java/function/MapGroupsWithStateFunction.java  |  38 ++
 .../main/protobuf/spark/connect/relations.proto    |  16 +
 .../apache/spark/sql/connect/common/UdfUtils.scala |  26 ++
 .../apache/spark/sql/streaming/GroupState.scala    | 336 +++++++++++++++++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  92 ++++-
 python/pyspark/sql/connect/proto/relations_pb2.py  |  24 +-
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  84 ++++-
 12 files changed, 1359 insertions(+), 31 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 7b2fa3b52be..20c130b83cb 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder
 import org.apache.spark.sql.connect.common.UdfUtils
 import org.apache.spark.sql.expressions.ScalarUserDefinedFunction
 import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}
 
 /**
  * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not
@@ -460,6 +461,356 @@ abstract class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable
     cogroupSorted(other)(thisSortExprs: _*)(otherSortExprs: _*)(
       UdfUtils.coGroupFunctionToScalaFunc(f))(encoder)
   }
+
+  protected def flatMapGroupsWithStateHelper[S: Encoder, U: Encoder](
+      outputMode: Option[OutputMode],
+      timeoutConf: GroupStateTimeout,
+      initialState: Option[KeyValueGroupedDataset[K, S]],
+      isMapGroupWithState: Boolean)(
+      func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = {
+    throw new UnsupportedOperationException
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each group of data, while maintaining a
+   * user-defined per-group state. The result Dataset will represent the objects returned by the
+   * function. For a static batch Dataset, the function will be invoked once per group. For a
+   * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+   * and updates to each group's state will be saved across invocations. See
+   * [[org.apache.spark.sql.streaming.GroupState]] for more details.
+   *
+   * @tparam S
+   *   The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U
+   *   The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func
+   *   Function to be called on every group.
+   *
+   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 3.5.0
+   */
+  def mapGroupsWithState[S: Encoder, U: Encoder](
+      func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = {
+    mapGroupsWithState(GroupStateTimeout.NoTimeout)(func)
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each group of data, while maintaining a
+   * user-defined per-group state. The result Dataset will represent the objects returned by the
+   * function. For a static batch Dataset, the function will be invoked once per group. For a
+   * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+   * and updates to each group's state will be saved across invocations. See
+   * [[org.apache.spark.sql.streaming.GroupState]] for more details.
+   *
+   * @tparam S
+   *   The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U
+   *   The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func
+   *   Function to be called on every group.
+   * @param timeoutConf
+   *   Timeout configuration for groups that do not receive data for a while.
+   *
+   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 3.5.0
+   */
+  def mapGroupsWithState[S: Encoder, U: Encoder](timeoutConf: GroupStateTimeout)(
+      func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = {
+    flatMapGroupsWithStateHelper(None, timeoutConf, None, isMapGroupWithState = true)(
+      UdfUtils.mapGroupsWithStateFuncToFlatMapAdaptor(func))
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each group of data, while maintaining a
+   * user-defined per-group state. The result Dataset will represent the objects returned by the
+   * function. For a static batch Dataset, the function will be invoked once per group. For a
+   * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+   * and updates to each group's state will be saved across invocations. See
+   * [[org.apache.spark.sql.streaming.GroupState]] for more details.
+   *
+   * @tparam S
+   *   The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U
+   *   The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func
+   *   Function to be called on every group.
+   * @param timeoutConf
+   *   Timeout configuration; see `GroupStateTimeout` for more details.
+   * @param initialState
+   *   The user provided state that will be initialized when the first batch of data is processed
+   *   in the streaming query. The user defined function will be called on the state data even if
+   *   there are no other values in the group. To convert a Dataset ds of type Dataset[(K, S)] to
+   *   a KeyValueGroupedDataset[K, S] do {{{ds.groupByKey(x => x._1).mapValues(_._2)}}}
+   *
+   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 3.5.0
+   */
+  def mapGroupsWithState[S: Encoder, U: Encoder](
+      timeoutConf: GroupStateTimeout,
+      initialState: KeyValueGroupedDataset[K, S])(
+      func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = {
+    flatMapGroupsWithStateHelper(
+      None,
+      timeoutConf,
+      Some(initialState),
+      isMapGroupWithState = true)(UdfUtils.mapGroupsWithStateFuncToFlatMapAdaptor(func))
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each group of data, while maintaining a
+   * user-defined per-group state. The result Dataset will represent the objects returned by the
+   * function. For a static batch Dataset, the function will be invoked once per group. For a
+   * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+   * and updates to each group's state will be saved across invocations. See `GroupState` for more
+   * details.
+   *
+   * @tparam S
+   *   The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U
+   *   The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func
+   *   Function to be called on every group.
+   * @param stateEncoder
+   *   Encoder for the state type.
+   * @param outputEncoder
+   *   Encoder for the output type.
+   *
+   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 3.5.0
+   */
+  def mapGroupsWithState[S, U](
+      func: MapGroupsWithStateFunction[K, V, S, U],
+      stateEncoder: Encoder[S],
+      outputEncoder: Encoder[U]): Dataset[U] = {
+    mapGroupsWithState[S, U](UdfUtils.mapGroupsWithStateFuncToScalaFunc(func))(
+      stateEncoder,
+      outputEncoder)
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each group of data, while maintaining a
+   * user-defined per-group state. The result Dataset will represent the objects returned by the
+   * function. For a static batch Dataset, the function will be invoked once per group. For a
+   * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+   * and updates to each group's state will be saved across invocations. See `GroupState` for more
+   * details.
+   *
+   * @tparam S
+   *   The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U
+   *   The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func
+   *   Function to be called on every group.
+   * @param stateEncoder
+   *   Encoder for the state type.
+   * @param outputEncoder
+   *   Encoder for the output type.
+   * @param timeoutConf
+   *   Timeout configuration for groups that do not receive data for a while.
+   *
+   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 3.5.0
+   */
+  def mapGroupsWithState[S, U](
+      func: MapGroupsWithStateFunction[K, V, S, U],
+      stateEncoder: Encoder[S],
+      outputEncoder: Encoder[U],
+      timeoutConf: GroupStateTimeout): Dataset[U] = {
+    mapGroupsWithState[S, U](timeoutConf)(UdfUtils.mapGroupsWithStateFuncToScalaFunc(func))(
+      stateEncoder,
+      outputEncoder)
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each group of data, while maintaining a
+   * user-defined per-group state. The result Dataset will represent the objects returned by the
+   * function. For a static batch Dataset, the function will be invoked once per group. For a
+   * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+   * and updates to each group's state will be saved across invocations. See `GroupState` for more
+   * details.
+   *
+   * @tparam S
+   *   The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U
+   *   The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func
+   *   Function to be called on every group.
+   * @param stateEncoder
+   *   Encoder for the state type.
+   * @param outputEncoder
+   *   Encoder for the output type.
+   * @param timeoutConf
+   *   Timeout configuration for groups that do not receive data for a while.
+   * @param initialState
+   *   The user provided state that will be initialized when the first batch of data is processed
+   *   in the streaming query. The user defined function will be called on the state data even if
+   *   there are no other values in the group.
+   *
+   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 3.5.0
+   */
+  def mapGroupsWithState[S, U](
+      func: MapGroupsWithStateFunction[K, V, S, U],
+      stateEncoder: Encoder[S],
+      outputEncoder: Encoder[U],
+      timeoutConf: GroupStateTimeout,
+      initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
+    mapGroupsWithState[S, U](timeoutConf, initialState)(
+      UdfUtils.mapGroupsWithStateFuncToScalaFunc(func))(stateEncoder, outputEncoder)
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each group of data, while maintaining a
+   * user-defined per-group state. The result Dataset will represent the objects returned by the
+   * function. For a static batch Dataset, the function will be invoked once per group. For a
+   * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+   * and updates to each group's state will be saved across invocations. See `GroupState` for more
+   * details.
+   *
+   * @tparam S
+   *   The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U
+   *   The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func
+   *   Function to be called on every group.
+   * @param outputMode
+   *   The output mode of the function.
+   * @param timeoutConf
+   *   Timeout configuration for groups that do not receive data for a while.
+   *
+   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 3.5.0
+   */
+  def flatMapGroupsWithState[S: Encoder, U: Encoder](
+      outputMode: OutputMode,
+      timeoutConf: GroupStateTimeout)(
+      func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = {
+    flatMapGroupsWithStateHelper(
+      Some(outputMode),
+      timeoutConf,
+      None,
+      isMapGroupWithState = false)(func)
+  }
+
+  /**
+   * (Scala-specific) Applies the given function to each group of data, while maintaining a
+   * user-defined per-group state. The result Dataset will represent the objects returned by the
+   * function. For a static batch Dataset, the function will be invoked once per group. For a
+   * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+   * and updates to each group's state will be saved across invocations. See `GroupState` for more
+   * details.
+   *
+   * @tparam S
+   *   The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U
+   *   The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func
+   *   Function to be called on every group.
+   * @param outputMode
+   *   The output mode of the function.
+   * @param timeoutConf
+   *   Timeout configuration for groups that do not receive data for a while.
+   * @param initialState
+   *   The user provided state that will be initialized when the first batch of data is processed
+   *   in the streaming query. The user defined function will be called on the state data even if
+   *   there are no other values in the group. To convert a Dataset `ds` of type
+   *   `Dataset[(K, S)]` to a `KeyValueGroupedDataset[K, S]`, use
+   *   {{{ds.groupByKey(x => x._1).mapValues(_._2)}}} See [[Encoder]] for more details on what
+   *   types are encodable to Spark SQL.
+   * @since 3.5.0
+   */
+  def flatMapGroupsWithState[S: Encoder, U: Encoder](
+      outputMode: OutputMode,
+      timeoutConf: GroupStateTimeout,
+      initialState: KeyValueGroupedDataset[K, S])(
+      func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = {
+    flatMapGroupsWithStateHelper(
+      Some(outputMode),
+      timeoutConf,
+      Some(initialState),
+      isMapGroupWithState = false)(func)
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each group of data, while maintaining a
+   * user-defined per-group state. The result Dataset will represent the objects returned by the
+   * function. For a static batch Dataset, the function will be invoked once per group. For a
+   * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+   * and updates to each group's state will be saved across invocations. See `GroupState` for more
+   * details.
+   *
+   * @tparam S
+   *   The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U
+   *   The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func
+   *   Function to be called on every group.
+   * @param outputMode
+   *   The output mode of the function.
+   * @param stateEncoder
+   *   Encoder for the state type.
+   * @param outputEncoder
+   *   Encoder for the output type.
+   * @param timeoutConf
+   *   Timeout configuration for groups that do not receive data for a while.
+   *
+   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 3.5.0
+   */
+  def flatMapGroupsWithState[S, U](
+      func: FlatMapGroupsWithStateFunction[K, V, S, U],
+      outputMode: OutputMode,
+      stateEncoder: Encoder[S],
+      outputEncoder: Encoder[U],
+      timeoutConf: GroupStateTimeout): Dataset[U] = {
+    val f = UdfUtils.flatMapGroupsWithStateFuncToScalaFunc(func)
+    flatMapGroupsWithState[S, U](outputMode, timeoutConf)(f)(stateEncoder, outputEncoder)
+  }
+
+  /**
+   * (Java-specific) Applies the given function to each group of data, while maintaining a
+   * user-defined per-group state. The result Dataset will represent the objects returned by the
+   * function. For a static batch Dataset, the function will be invoked once per group. For a
+   * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+   * and updates to each group's state will be saved across invocations. See `GroupState` for more
+   * details.
+   *
+   * @tparam S
+   *   The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U
+   *   The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func
+   *   Function to be called on every group.
+   * @param outputMode
+   *   The output mode of the function.
+   * @param stateEncoder
+   *   Encoder for the state type.
+   * @param outputEncoder
+   *   Encoder for the output type.
+   * @param timeoutConf
+   *   Timeout configuration for groups that do not receive data for a while.
+   * @param initialState
+   *   The user provided state that will be initialized when the first batch of data is processed
+   *   in the streaming query. The user defined function will be called on the state data even if
+   *   there are no other values in the group. To convert a Dataset `ds` of type
+   *   `Dataset[(K, S)]` to a `KeyValueGroupedDataset[K, S]`, use
+   *   {{{ds.groupByKey(x => x._1).mapValues(_._2)}}}
+   *
+   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 3.5.0
+   */
+  def flatMapGroupsWithState[S, U](
+      func: FlatMapGroupsWithStateFunction[K, V, S, U],
+      outputMode: OutputMode,
+      stateEncoder: Encoder[S],
+      outputEncoder: Encoder[U],
+      timeoutConf: GroupStateTimeout,
+      initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
+    val f = UdfUtils.flatMapGroupsWithStateFuncToScalaFunc(func)
+    flatMapGroupsWithState[S, U](outputMode, timeoutConf, initialState)(f)(
+      stateEncoder,
+      outputEncoder)
+  }
 }
 
 /**
@@ -572,6 +923,53 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
     agg(aggregator)
   }
 
+  override protected def flatMapGroupsWithStateHelper[S: Encoder, U: Encoder](
+      outputMode: Option[OutputMode],
+      timeoutConf: GroupStateTimeout,
+      initialState: Option[KeyValueGroupedDataset[K, S]],
+      isMapGroupWithState: Boolean)(
+      func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = {
+    if (outputMode.isDefined && outputMode.get != OutputMode.Append &&
+      outputMode.get != OutputMode.Update) {
+      throw new IllegalArgumentException("The output mode of function should be append or update")
+    }
+
+    if (initialState.isDefined) {
+      assert(initialState.get.isInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]])
+    }
+
+    val initialStateImpl = if (initialState.isDefined) {
+      initialState.get.asInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]]
+    } else {
+      null
+    }
+
+    val outputEncoder = encoderFor[U]
+    val nf = if (valueMapFunc == UdfUtils.identical()) {
+      func
+    } else {
+      UdfUtils.mapValuesAdaptor(func, valueMapFunc)
+    }
+
+    sparkSession.newDataset[U](outputEncoder) { builder =>
+      val groupMapBuilder = builder.getGroupMapBuilder
+      groupMapBuilder
+        .setInput(plan.getRoot)
+        .addAllGroupingExpressions(groupingExprs)
+        .setFunc(getUdf(nf, outputEncoder)(ivEncoder))
+        .setIsMapGroupsWithState(isMapGroupWithState)
+        .setOutputMode(if (outputMode.isEmpty) OutputMode.Update.toString
+        else outputMode.get.toString)
+        .setTimeoutConf(timeoutConf.toString)
+
+      if (initialStateImpl != null) {
+        groupMapBuilder
+          .addAllInitialGroupingExpressions(initialStateImpl.groupingExprs)
+          .setInitialInput(initialStateImpl.plan.getRoot)
+      }
+    }
+  }
+
   private def getUdf[U: Encoder](nf: AnyRef, outputEncoder: AgnosticEncoder[U])(
       inEncoders: AgnosticEncoder[_]*): proto.CommonInlineUserDefinedFunction = {
     val inputEncoders = kEncoder +: inEncoders // Apply keyAs changes by setting kEncoder
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
index e7a77eed70d..404239f7e14 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
@@ -16,14 +16,21 @@
  */
 package org.apache.spark.sql
 
+import java.sql.Timestamp
 import java.util.Arrays
 
 import io.grpc.StatusRuntimeException
 
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Append
 import org.apache.spark.sql.connect.client.util.QueryTest
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}
 import org.apache.spark.sql.types._
 
+case class ClickEvent(id: String, timestamp: Timestamp)
+
+case class ClickState(id: String, count: Int)
+
 /**
  * All tests in this class requires client UDF artifacts synced with the server.
  */
@@ -447,4 +454,104 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper {
 
     checkDataset(keys, "1", "2", "10", "20")
   }
+
+  test("flatMapGroupsWithState") {
+    val stateFunc =
+      (key: String, values: Iterator[ClickEvent], state: GroupState[ClickState]) => {
+        if (state.exists) throw new IllegalArgumentException("state.exists should be false")
+        Iterator(ClickState(key, values.size))
+      }
+
+    val session: SparkSession = spark
+    import session.implicits._
+    val values = Seq(
+      ClickEvent("a", new Timestamp(12345)),
+      ClickEvent("a", new Timestamp(12346)),
+      ClickEvent("a", new Timestamp(12347)),
+      ClickEvent("b", new Timestamp(12348)),
+      ClickEvent("b", new Timestamp(12349)),
+      ClickEvent("c", new Timestamp(12350)))
+      .toDS()
+      .groupByKey(_.id)
+      .flatMapGroupsWithState(Append, GroupStateTimeout.NoTimeout)(stateFunc)
+
+    checkDataset(values, ClickState("a", 3), ClickState("b", 2), ClickState("c", 1))
+  }
+
+  test("flatMapGroupsWithState - with initial state") {
+    val stateFunc =
+      (key: String, values: Iterator[ClickEvent], state: GroupState[ClickState]) => {
+        val currState = state.getOption.getOrElse(ClickState(key, 0))
+        Iterator(ClickState(key, currState.count + values.size))
+      }
+
+    val session: SparkSession = spark
+    import session.implicits._
+
+    val initialStateDS = Seq(ClickState("a", 2), ClickState("b", 1)).toDS()
+    val initialState = initialStateDS.groupByKey(_.id).mapValues(x => x)
+
+    val values = Seq(
+      ClickEvent("a", new Timestamp(12345)),
+      ClickEvent("a", new Timestamp(12346)),
+      ClickEvent("a", new Timestamp(12347)),
+      ClickEvent("b", new Timestamp(12348)),
+      ClickEvent("b", new Timestamp(12349)),
+      ClickEvent("c", new Timestamp(12350)))
+      .toDS()
+      .groupByKey(_.id)
+      .flatMapGroupsWithState(Append, GroupStateTimeout.NoTimeout, initialState)(stateFunc)
+
+    checkDataset(values, ClickState("a", 5), ClickState("b", 3), ClickState("c", 1))
+  }
+
+  test("mapGroupsWithState") {
+    val stateFunc =
+      (key: String, values: Iterator[ClickEvent], state: GroupState[ClickState]) => {
+        if (state.exists) throw new IllegalArgumentException("state.exists should be false")
+        ClickState(key, values.size)
+      }
+
+    val session: SparkSession = spark
+    import session.implicits._
+    val values = Seq(
+      ClickEvent("a", new Timestamp(12345)),
+      ClickEvent("a", new Timestamp(12346)),
+      ClickEvent("a", new Timestamp(12347)),
+      ClickEvent("b", new Timestamp(12348)),
+      ClickEvent("b", new Timestamp(12349)),
+      ClickEvent("c", new Timestamp(12350)))
+      .toDS()
+      .groupByKey(_.id)
+      .mapGroupsWithState(GroupStateTimeout.NoTimeout)(stateFunc)
+
+    checkDataset(values, ClickState("a", 3), ClickState("b", 2), ClickState("c", 1))
+  }
+
+  test("mapGroupsWithState - with initial state") {
+    val stateFunc =
+      (key: String, values: Iterator[ClickEvent], state: GroupState[ClickState]) => {
+        val currState = state.getOption.getOrElse(ClickState(key, 0))
+        ClickState(key, currState.count + values.size)
+      }
+
+    val session: SparkSession = spark
+    import session.implicits._
+
+    val initialStateDS = Seq(ClickState("a", 2), ClickState("b", 1)).toDS()
+    val initialState = initialStateDS.groupByKey(_.id).mapValues(x => x)
+
+    val values = Seq(
+      ClickEvent("a", new Timestamp(12345)),
+      ClickEvent("a", new Timestamp(12346)),
+      ClickEvent("a", new Timestamp(12347)),
+      ClickEvent("b", new Timestamp(12348)),
+      ClickEvent("b", new Timestamp(12349)),
+      ClickEvent("c", new Timestamp(12350)))
+      .toDS()
+      .groupByKey(_.id)
+      .mapGroupsWithState(GroupStateTimeout.NoTimeout, initialState)(stateFunc)
+
+    checkDataset(values, ClickState("a", 5), ClickState("b", 3), ClickState("c", 1))
+  }
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 7a9a889706d..6b648fd152b 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -201,12 +201,6 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.udaf"),
 
       // KeyValueGroupedDataset
-      ProblemFilters.exclude[Problem](
-        "org.apache.spark.sql.KeyValueGroupedDataset.mapGroupsWithState"
-      ), // streaming
-      ProblemFilters.exclude[Problem](
-        "org.apache.spark.sql.KeyValueGroupedDataset.flatMapGroupsWithState"
-      ), // streaming
       ProblemFilters.exclude[Problem](
         "org.apache.spark.sql.KeyValueGroupedDataset.queryExecution"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.KeyValueGroupedDataset.this"),
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala
new file mode 100644
index 00000000000..cdb6b9a2e9c
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.sql.Timestamp
+
+import org.scalatest.concurrent.Eventually.eventually
+import org.scalatest.concurrent.Futures.timeout
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.sql.{SparkSession, SQLHelper}
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Append
+import org.apache.spark.sql.connect.client.util.QueryTest
+import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}
+
+case class ClickEvent(id: String, timestamp: Timestamp)
+
+case class ClickState(id: String, count: Int)
+
+class FlatMapGroupsWithStateStreamingSuite extends QueryTest with SQLHelper {
+
+  val flatMapGroupsWithStateSchema: StructType = StructType(
+    Array(StructField("id", StringType), StructField("timestamp", TimestampType)))
+
+  val flatMapGroupsWithStateData: Seq[ClickEvent] = Seq(
+    ClickEvent("a", new Timestamp(12345)),
+    ClickEvent("a", new Timestamp(12346)),
+    ClickEvent("a", new Timestamp(12347)),
+    ClickEvent("b", new Timestamp(12348)),
+    ClickEvent("b", new Timestamp(12349)),
+    ClickEvent("c", new Timestamp(12350)))
+
+  val flatMapGroupsWithStateInitialStateData: Seq[ClickState] =
+    Seq(ClickState("a", 2), ClickState("b", 1))
+
+  test("flatMapGroupsWithState - streaming") {
+    val session: SparkSession = spark
+    import session.implicits._
+
+    val stateFunc =
+      (key: String, values: Iterator[ClickEvent], state: GroupState[ClickState]) => {
+        if (state.exists) throw new IllegalArgumentException("state.exists should be false")
+        Iterator(ClickState(key, values.size))
+      }
+    spark.sql("DROP TABLE IF EXISTS my_sink")
+
+    withTempPath { dir =>
+      val path = dir.getCanonicalPath
+      flatMapGroupsWithStateData.toDS().write.parquet(path)
+      val q = spark.readStream
+        .schema(flatMapGroupsWithStateSchema)
+        .parquet(path)
+        .as[ClickEvent]
+        .groupByKey(_.id)
+        .flatMapGroupsWithState(Append, GroupStateTimeout.NoTimeout)(stateFunc)
+        .writeStream
+        .format("memory")
+        .queryName("my_sink")
+        .start()
+
+      try {
+        q.processAllAvailable()
+        eventually(timeout(30.seconds)) {
+          checkDataset(
+            spark.table("my_sink").toDF().as[ClickState],
+            ClickState("c", 1),
+            ClickState("b", 2),
+            ClickState("a", 3))
+        }
+      } finally {
+        q.stop()
+        spark.sql("DROP TABLE IF EXISTS my_sink")
+      }
+    }
+  }
+
+  test("flatMapGroupsWithState - streaming - with initial state") {
+    val session: SparkSession = spark
+    import session.implicits._
+
+    val stateFunc =
+      (key: String, values: Iterator[ClickEvent], state: GroupState[ClickState]) => {
+        val currState = state.getOption.getOrElse(ClickState(key, 0))
+        Iterator(ClickState(key, currState.count + values.size))
+      }
+    val initialState = flatMapGroupsWithStateInitialStateData
+      .toDS()
+      .groupByKey(_.id)
+      .mapValues(x => x)
+    spark.sql("DROP TABLE IF EXISTS my_sink")
+
+    withTempPath { dir =>
+      val path = dir.getCanonicalPath
+      flatMapGroupsWithStateData.toDS().write.parquet(path)
+      val q = spark.readStream
+        .schema(flatMapGroupsWithStateSchema)
+        .parquet(path)
+        .as[ClickEvent]
+        .groupByKey(_.id)
+        .flatMapGroupsWithState(Append, GroupStateTimeout.NoTimeout, initialState)(stateFunc)
+        .writeStream
+        .format("memory")
+        .queryName("my_sink")
+        .start()
+
+      try {
+        q.processAllAvailable()
+        eventually(timeout(30.seconds)) {
+          checkDataset(
+            spark.table("my_sink").toDF().as[ClickState],
+            ClickState("c", 1),
+            ClickState("b", 3),
+            ClickState("a", 5))
+        }
+      } finally {
+        q.stop()
+        spark.sql("DROP TABLE IF EXISTS my_sink")
+      }
+    }
+  }
+
+  test("mapGroupsWithState - streaming") {
+    val session: SparkSession = spark
+    import session.implicits._
+
+    val stateFunc =
+      (key: String, values: Iterator[ClickEvent], state: GroupState[ClickState]) => {
+        if (state.exists) throw new IllegalArgumentException("state.exists should be false")
+        ClickState(key, values.size)
+      }
+    spark.sql("DROP TABLE IF EXISTS my_sink")
+
+    withTempPath { dir =>
+      val path = dir.getCanonicalPath
+      flatMapGroupsWithStateData.toDS().write.parquet(path)
+      val q = spark.readStream
+        .schema(flatMapGroupsWithStateSchema)
+        .parquet(path)
+        .as[ClickEvent]
+        .groupByKey(_.id)
+        .mapGroupsWithState(GroupStateTimeout.NoTimeout)(stateFunc)
+        .writeStream
+        .format("memory")
+        .queryName("my_sink")
+        .outputMode("update")
+        .start()
+
+      try {
+        q.processAllAvailable()
+        eventually(timeout(30.seconds)) {
+          checkDataset(
+            spark.table("my_sink").toDF().as[ClickState],
+            ClickState("c", 1),
+            ClickState("b", 2),
+            ClickState("a", 3))
+        }
+      } finally {
+        q.stop()
+        spark.sql("DROP TABLE IF EXISTS my_sink")
+      }
+    }
+  }
+
+  test("mapGroupsWithState - streaming - with initial state") {
+    val session: SparkSession = spark
+    import session.implicits._
+
+    val stateFunc =
+      (key: String, values: Iterator[ClickEvent], state: GroupState[ClickState]) => {
+        val currState = state.getOption.getOrElse(ClickState(key, 0))
+        ClickState(key, currState.count + values.size)
+      }
+    val initialState = flatMapGroupsWithStateInitialStateData
+      .toDS()
+      .groupByKey(_.id)
+      .mapValues(x => x)
+    spark.sql("DROP TABLE IF EXISTS my_sink")
+
+    withTempPath { dir =>
+      val path = dir.getCanonicalPath
+      flatMapGroupsWithStateData.toDS().write.parquet(path)
+      val q = spark.readStream
+        .schema(flatMapGroupsWithStateSchema)
+        .parquet(path)
+        .as[ClickEvent]
+        .groupByKey(_.id)
+        .mapGroupsWithState(GroupStateTimeout.NoTimeout, initialState)(stateFunc)
+        .writeStream
+        .format("memory")
+        .queryName("my_sink")
+        .outputMode("update")
+        .start()
+
+      try {
+        q.processAllAvailable()
+        eventually(timeout(30.seconds)) {
+          checkDataset(
+            spark.table("my_sink").toDF().as[ClickState],
+            ClickState("c", 1),
+            ClickState("b", 3),
+            ClickState("a", 5))
+        }
+      } finally {
+        q.stop()
+        spark.sql("DROP TABLE IF EXISTS my_sink")
+      }
+    }
+  }
+}
diff --git a/connector/connect/common/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java b/connector/connect/common/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java
new file mode 100644
index 00000000000..c917c8d28be
--- /dev/null
+++ b/connector/connect/common/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.annotation.Experimental;
+import org.apache.spark.sql.streaming.GroupState;
+
+/**
+ * ::Experimental::
+ * Base interface for a map function used in
+ * {@code org.apache.spark.sql.KeyValueGroupedDataset.flatMapGroupsWithState(
+ * FlatMapGroupsWithStateFunction, org.apache.spark.sql.streaming.OutputMode,
+ * org.apache.spark.sql.Encoder, org.apache.spark.sql.Encoder)}
+ * @since 3.5.0
+ */
+@Experimental
+@Evolving
+public interface FlatMapGroupsWithStateFunction<K, V, S, R> extends Serializable {
+    Iterator<R> call(K key, Iterator<V> values, GroupState<S> state) throws Exception;
+}
diff --git a/connector/connect/common/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java b/connector/connect/common/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java
new file mode 100644
index 00000000000..ae179ad7d27
--- /dev/null
+++ b/connector/connect/common/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.annotation.Experimental;
+import org.apache.spark.sql.streaming.GroupState;
+
+/**
+ * ::Experimental::
+ * Base interface for a map function used in
+ * {@code org.apache.spark.sql.KeyValueGroupedDataset.mapGroupsWithState(
+ * MapGroupsWithStateFunction, org.apache.spark.sql.Encoder, org.apache.spark.sql.Encoder)}
+ * @since 3.5.0
+ */
+@Experimental
+@Evolving
+public interface MapGroupsWithStateFunction<K, V, S, R> extends Serializable {
+    R call(K key, Iterator<V> values, GroupState<S> state) throws Exception;
+}
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 6347bd7bc56..ea432bb48fc 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -857,6 +857,22 @@ message GroupMap {
 
   // (Optional) Expressions for sorting. Only used by Scala Sorted Group Map API.
   repeated Expression sorting_expressions = 4;
+
+  // Below fields are only used by (Flat)MapGroupsWithState
+  // (Optional) Input relation for initial State.
+  Relation initial_input = 5;
+
+  // (Optional) Expressions for grouping keys of the initial state input relation.
+  repeated Expression initial_grouping_expressions = 6;
+
+  // (Optional) True if MapGroupsWithState, false if FlatMapGroupsWithState.
+  optional bool is_map_groups_with_state = 7;
+
+  // (Optional) The output mode of the function.
+  optional string output_mode = 8;
+
+  // (Optional) Timeout configuration for groups that do not receive data for a while.
+  optional string timeout_conf = 9;
 }
 
 message CoGroupMap {
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
index 06a6c74f268..883637ff86c 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
@@ -20,6 +20,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.api.java.function._
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.streaming.GroupState
 
 /**
  * Util functions to help convert input functions between typed filter, map, flatMap,
@@ -95,6 +96,31 @@ private[sql] object UdfUtils extends Serializable {
       }
   }
 
+  def mapValuesAdaptor[K, V, S, U, IV](
+      f: (K, Iterator[V], GroupState[S]) => Iterator[U],
+      valueMapFunc: IV => V): (K, Iterator[IV], GroupState[S]) => Iterator[U] = {
+    (k: K, itr: Iterator[IV], s: GroupState[S]) =>
+      {
+        f(k, itr.map(v => valueMapFunc(v)), s)
+      }
+  }
+
+  def mapGroupsWithStateFuncToFlatMapAdaptor[K, V, S, U](
+      f: (K, Iterator[V], GroupState[S]) => U): (K, Iterator[V], GroupState[S]) => Iterator[U] = {
+    (k: K, itr: Iterator[V], s: GroupState[S]) => Iterator(f(k, itr, s))
+  }
+
+  def mapGroupsWithStateFuncToScalaFunc[K, V, S, U](
+      f: MapGroupsWithStateFunction[K, V, S, U]): (K, Iterator[V], GroupState[S]) => U = {
+    (key, data, groupState) => f.call(key, data.asJava, groupState)
+  }
+
+  def flatMapGroupsWithStateFuncToScalaFunc[K, V, S, U](
+      f: FlatMapGroupsWithStateFunction[K, V, S, U])
+      : (K, Iterator[V], GroupState[S]) => Iterator[U] = { (key, data, groupState) =>
+    f.call(key, data.asJava, groupState).asScala
+  }
+
   def mapReduceFuncToScalaFunc[T](func: ReduceFunction[T]): (T, T) => T = func.call
 
   def identical[T](): T => T = t => t
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala
new file mode 100644
index 00000000000..bd418a89534
--- /dev/null
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala
@@ -0,0 +1,336 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.apache.spark.annotation.{Evolving, Experimental}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState
+
+/**
+ * :: Experimental ::
+ *
+ * Wrapper class for interacting with per-group state data in `mapGroupsWithState` and
+ * `flatMapGroupsWithState` operations on `KeyValueGroupedDataset`.
+ *
+ * Detailed description of the `[map/flatMap]GroupsWithState` operation:
+ * both `mapGroupsWithState` and
+ * `flatMapGroupsWithState` in `KeyValueGroupedDataset` will invoke the user-given function on
+ * each group (defined by the grouping function in `Dataset.groupByKey()`) while maintaining a
+ * user-defined per-group state between invocations. For a static batch Dataset, the function will
+ * be invoked once per group. For a streaming Dataset, the function will be invoked for each group
+ * repeatedly in every trigger. That is, in every batch of the `StreamingQuery`, the function will
+ * be invoked once for each group that has data in the trigger. Furthermore, if timeout is set,
+ * then the function will be invoked on timed-out groups (more detail below).
+ *
+ * The function is invoked with the following parameters.
+ *   - The key of the group.
+ *   - An iterator containing all the values for this group.
+ *   - A user-defined state object set by previous invocations of the given function.
+ *
+ * In case of a batch Dataset, there is only one invocation and the state object will be empty as
+ * there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState` is
+ * equivalent to `[map/flatMap]Groups` and any updates to the state and/or timeouts have no
+ * effect.
+ *
+ * The major difference between `mapGroupsWithState` and `flatMapGroupsWithState` is that the
+ * former allows the function to return one and only one record, whereas the latter allows the
+ * function to return any number of records (including no records). Furthermore, the
+ * `flatMapGroupsWithState` is associated with an operation output mode, which can be either
+ * `Append` or `Update`. Semantically, this defines whether the output records of one trigger is
+ * effectively replacing the previously output records (from previous triggers) or is appending to
+ * the list of previously output records. Essentially, this defines how the Result Table (refer to
+ * the semantics in the programming guide) is updated, and allows us to reason about the semantics
+ * of later operations.
+ *
+ * Important points to note about the function (both mapGroupsWithState and
+ * flatMapGroupsWithState).
+ *   - In a trigger, the function will be called only for the groups present in the batch. So do not
+ *     assume that the function will be called in every trigger for every group that has state.
+ *   - There is no guaranteed ordering of values in the iterator in the function, neither with
+ *     batch, nor with streaming Datasets.
+ *   - All the data will be shuffled before applying the function.
+ *   - If timeout is set, then the function will also be called with no values. See more details
+ *     on `GroupStateTimeout` below.
+ *
+ * Important points to note about using `GroupState`.
+ *   - The value of the state cannot be null. So updating state with null will throw
+ *     `IllegalArgumentException`.
+ *   - Operations on `GroupState` are not thread-safe. This is to avoid memory barriers.
+ *   - If `remove()` is called, then `exists()` will return `false`, `get()` will throw
+ *     `NoSuchElementException` and `getOption()` will return `None`.
+ *   - After that, if `update(newState)` is called, then `exists()` will again return `true`,
+ *     `get()` and `getOption()` will return the updated value.
+ *
+ * Important points to note about using `GroupStateTimeout`.
+ *   - The timeout type is a global param across all the groups (set as `timeout` param in
+ *     `[map|flatMap]GroupsWithState`), but the exact timeout duration/timestamp is configurable
+ *     per group by calling `setTimeout...()` in `GroupState`.
+ *   - Timeouts can be either based on processing time (i.e.
+ *     `GroupStateTimeout.ProcessingTimeTimeout`) or event time (i.e.
+ *     `GroupStateTimeout.EventTimeTimeout`).
+ *   - With `ProcessingTimeTimeout`, the timeout duration can be set by calling
+ *     `GroupState.setTimeoutDuration`. The timeout will occur when the clock has advanced by the
+ *     set duration. Guarantees provided by this timeout with a duration of D ms are as follows:
+ *     - Timeout will never occur before the clock time has advanced by D ms
+ *     - Timeout will occur eventually when there is a trigger in the query (i.e. after D ms). So
+ *       there is no strict upper bound on when the timeout would occur. For example, the trigger
+ *       interval of the query will affect when the timeout actually occurs. If there is no data
+ *       in the stream (for any group) for a while, then there will not be any trigger and timeout
+ *       function call will not occur until there is data.
+ *     - Since the processing time timeout is based on the clock time, it is affected by the
+ *       variations in the system clock (i.e. time zone changes, clock skew, etc.).
+ *   - With `EventTimeTimeout`, the user also has to specify the event time watermark in the query
+ *     using `Dataset.withWatermark()`. With this setting, data that is older than the watermark
+ *     is filtered out. The timeout can be set for a group by setting a timeout timestamp
+ *     using `GroupState.setTimeoutTimestamp()`, and the timeout would occur when the watermark
+ *     advances beyond the set timestamp. You can control the timeout delay by two parameters -
+ *     (i) watermark delay and (ii) an additional duration beyond the timestamp in the event
+ *     (which is guaranteed to be newer than watermark due to the filtering). Guarantees provided
+ *     by this timeout are as follows:
+ *     - Timeout will never occur before the watermark has exceeded the set timeout.
+ *     - Similar to processing time timeouts, there is no strict upper bound on the delay when the
+ *       timeout actually occurs. The watermark can advance only when there is data in the stream
+ *       and the event time of the data has actually advanced.
+ *   - When the timeout occurs for a group, the function is called for that group with no values,
+ *     and `GroupState.hasTimedOut()` set to true.
+ *   - The timeout is reset every time the function is called on a group, that is, when the group
+ *     has new data, or the group has timed out. So the user has to set the timeout duration every
+ *     time the function is called, otherwise, there will not be any timeout set.
+ *
+ * `[map/flatMap]GroupsWithState` can take a user defined initial state as an additional argument.
+ * This state will be applied when the first batch of the streaming query is processed. If there
+ * are no matching rows in the data for the keys present in the initial state, the state is still
+ * applied and the function will be invoked with the values being an empty iterator.
+ *
+ * Scala example of using GroupState in `mapGroupsWithState`:
+ * {{{
+ * // A mapping function that maintains an integer state for string keys and returns a string.
+ * // Additionally, it sets a timeout to remove the state if it has not received data for an hour.
+ * def mappingFunction(key: String, value: Iterator[Int], state: GroupState[Int]): String = {
+ *
+ *   if (state.hasTimedOut) {                // If called when timing out, remove the state
+ *     state.remove()
+ *
+ *   } else if (state.exists) {              // If state exists, use it for processing
+ *     val existingState = state.get         // Get the existing state
+ *     val shouldRemove = ...                // Decide whether to remove the state
+ *     if (shouldRemove) {
+ *       state.remove()                      // Remove the state
+ *
+ *     } else {
+ *       val newState = ...
+ *       state.update(newState)              // Set the new state
+ *       state.setTimeoutDuration("1 hour")  // Set the timeout
+ *     }
+ *
+ *   } else {
+ *     val initialState = ...
+ *     state.update(initialState)            // Set the initial state
+ *     state.setTimeoutDuration("1 hour")    // Set the timeout
+ *   }
+ *   ...
+ *   // return something
+ * }
+ *
+ * dataset
+ *   .groupByKey(...)
+ *   .mapGroupsWithState(GroupStateTimeout.ProcessingTimeTimeout)(mappingFunction)
+ * }}}
+ *
+ * Java example of using `GroupState`:
+ * {{{
+ * // A mapping function that maintains an integer state for string keys and returns a string.
+ * // Additionally, it sets a timeout to remove the state if it has not received data for an hour.
+ * MapGroupsWithStateFunction<String, Integer, Integer, String> mappingFunction =
+ *    new MapGroupsWithStateFunction<String, Integer, Integer, String>() {
+ *
+ *      @Override
+ *      public String call(String key, Iterator<Integer> value, GroupState<Integer> state) {
+ *        if (state.hasTimedOut()) {            // If called when timing out, remove the state
+ *          state.remove();
+ *
+ *        } else if (state.exists()) {            // If state exists, use it for processing
+ *          int existingState = state.get();      // Get the existing state
+ *          boolean shouldRemove = ...;           // Decide whether to remove the state
+ *          if (shouldRemove) {
+ *            state.remove();                     // Remove the state
+ *
+ *          } else {
+ *            int newState = ...;
+ *            state.update(newState);             // Set the new state
+ *            state.setTimeoutDuration("1 hour"); // Set the timeout
+ *          }
+ *
+ *        } else {
+ *          int initialState = ...;               // Set the initial state
+ *          state.update(initialState);
+ *          state.setTimeoutDuration("1 hour");   // Set the timeout
+ *        }
+ *        ...
+ *         // return something
+ *      }
+ *    };
+ *
+ * dataset
+ *     .groupByKey(...)
+ *     .mapGroupsWithState(
+ *         mappingFunction, Encoders.INT, Encoders.STRING, GroupStateTimeout.ProcessingTimeTimeout);
+ * }}}
+ *
+ * @tparam S
+ *   User-defined type of the state to be stored for each group. Must be encodable into Spark SQL
+ *   types (see `Encoder` for more details).
+ * @since 3.5.0
+ */
+@Experimental
+@Evolving
+trait GroupState[S] extends LogicalGroupState[S] {
+
+  /** Whether state exists or not. */
+  def exists: Boolean
+
+  /** Get the state value if it exists, or throw NoSuchElementException. */
+  @throws[NoSuchElementException]("when state does not exist")
+  def get: S
+
+  /** Get the state value as a scala Option. */
+  def getOption: Option[S]
+
+  /** Update the value of the state. */
+  def update(newState: S): Unit
+
+  /** Remove this state. */
+  def remove(): Unit
+
+  /**
+   * Whether the function has been called because the key has timed out.
+   * @note
+   *   This can return true only when timeouts are enabled in `[map/flatMap]GroupsWithState`.
+   */
+  def hasTimedOut: Boolean
+
+  /**
+   * Set the timeout duration in ms for this key.
+   *
+   * @note
+   *   [[GroupStateTimeout Processing time timeout]] must be enabled in
+   *   `[map/flatMap]GroupsWithState` for calling this method.
+   * @note
+   *   This method has no effect when used in a batch query.
+   */
+  @throws[IllegalArgumentException]("if 'durationMs' is not positive")
+  @throws[UnsupportedOperationException](
+    "if processing time timeout has not been enabled in [map|flatMap]GroupsWithState")
+  def setTimeoutDuration(durationMs: Long): Unit
+
+  /**
+   * Set the timeout duration for this key as a string. For example, "1 hour", "2 days", etc.
+   *
+   * @note
+   *   [[GroupStateTimeout Processing time timeout]] must be enabled in
+   *   `[map/flatMap]GroupsWithState` for calling this method.
+   * @note
+   *   This method has no effect when used in a batch query.
+   */
+  @throws[IllegalArgumentException]("if 'duration' is not a valid duration")
+  @throws[UnsupportedOperationException](
+    "if processing time timeout has not been enabled in [map|flatMap]GroupsWithState")
+  def setTimeoutDuration(duration: String): Unit
+
+  /**
+   * Set the timeout timestamp for this key as milliseconds in epoch time. This timestamp cannot
+   * be older than the current watermark.
+   *
+   * @note
+   *   [[GroupStateTimeout Event time timeout]] must be enabled in `[map/flatMap]GroupsWithState`
+   *   for calling this method.
+   * @note
+   *   This method has no effect when used in a batch query.
+   */
+  @throws[IllegalArgumentException](
+    "if 'timestampMs' is not positive or less than the current watermark in a streaming query")
+  @throws[UnsupportedOperationException](
+    "if event time timeout has not been enabled in [map|flatMap]GroupsWithState")
+  def setTimeoutTimestamp(timestampMs: Long): Unit
+
+  /**
+   * Set the timeout timestamp for this key as milliseconds in epoch time and an additional
+   * duration as a string (e.g. "1 hour", "2 days", etc.). The final timestamp (including the
+   * additional duration) cannot be older than the current watermark.
+   *
+   * @note
+   *   [[GroupStateTimeout Event time timeout]] must be enabled in `[map/flatMap]GroupsWithState`
+   *   for calling this method.
+   * @note
+   *   This method has no side effect when used in a batch query.
+   */
+  @throws[IllegalArgumentException](
+    "if 'additionalDuration' is invalid or the final timeout timestamp is less than " +
+      "the current watermark in a streaming query")
+  @throws[UnsupportedOperationException](
+    "if event time timeout has not been enabled in [map|flatMap]GroupsWithState")
+  def setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit
+
+  /**
+   * Set the timeout timestamp for this key as a java.sql.Date. This timestamp cannot be older
+   * than the current watermark.
+   *
+   * @note
+   *   [[GroupStateTimeout Event time timeout]] must be enabled in `[map/flatMap]GroupsWithState`
+   *   for calling this method.
+   * @note
+   *   This method has no side effect when used in a batch query.
+   */
+  @throws[UnsupportedOperationException](
+    "if event time timeout has not been enabled in [map|flatMap]GroupsWithState")
+  def setTimeoutTimestamp(timestamp: java.sql.Date): Unit
+
+  /**
+   * Set the timeout timestamp for this key as a java.sql.Date and an additional duration as a
+   * string (e.g. "1 hour", "2 days", etc.). The final timestamp (including the additional
+   * duration) cannot be older than the current watermark.
+   *
+   * @note
+   *   [[GroupStateTimeout Event time timeout]] must be enabled in `[map/flatMap]GroupsWithState`
+   *   for calling this method.
+   * @note
+   *   This method has no side effect when used in a batch query.
+   */
+  @throws[IllegalArgumentException]("if 'additionalDuration' is invalid")
+  @throws[UnsupportedOperationException](
+    "if event time timeout has not been enabled in [map|flatMap]GroupsWithState")
+  def setTimeoutTimestamp(timestamp: java.sql.Date, additionalDuration: String): Unit
+
+  /**
+   * Get the current event time watermark as milliseconds in epoch time.
+   *
+   * @note
+   *   In a streaming query, this can be called only when watermark is set before calling
+   *   `[map/flatMap]GroupsWithState`. In a batch query, this method always returns -1.
+   */
+  @throws[UnsupportedOperationException](
+    "if watermark has not been set before in [map|flatMap]GroupsWithState")
+  def getCurrentWatermarkMs(): Long
+
+  /**
+   * Get the current processing time as milliseconds in epoch time.
+   * @note
+   *   In a streaming query, this will return a constant value throughout the duration of a
+   *   trigger, even if the trigger is re-executed.
+   */
+  def getCurrentProcessingTimeMs(): Long
+}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index dc819fb4020..6ee252d1a58 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -47,7 +47,8 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
 import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeserializeToObject, Except, Intersect, LocalRelation, LogicalPlan, MapGroups, MapPartitions, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeserializeToObject, Except, FlatMapGroupsWithState, Intersect, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint}
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
 import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, UdfPacket}
@@ -64,11 +65,12 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRelation}
 import org.apache.spark.sql.execution.python.{PythonForeachWriter, UserDefinedPythonFunction}
 import org.apache.spark.sql.execution.stat.StatFunctions
+import org.apache.spark.sql.execution.streaming.GroupStateImpl.groupStateTimeoutFromString
 import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper
 import org.apache.spark.sql.expressions.ReduceAggregator
 import org.apache.spark.sql.internal.{CatalogImpl, TypedAggUtils}
 import org.apache.spark.sql.protobuf.{CatalystDataToProtobuf, ProtobufDataToCatalyst}
-import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryProgress, Trigger}
+import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StreamingQuery, StreamingQueryProgress, Trigger}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.storage.CacheId
@@ -570,16 +572,82 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
       rel.getGroupingExpressionsList,
       rel.getSortingExpressionsList)
 
-    val mapped = new MapGroups(
-      udf.function.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]],
-      udf.inputDeserializer(ds.groupingAttributes),
-      ds.valueDeserializer,
-      ds.groupingAttributes,
-      ds.dataAttributes,
-      ds.sortOrder,
-      udf.outputObjAttr,
-      ds.analyzed)
-    SerializeFromObject(udf.outputNamedExpression, mapped)
+    if (rel.hasIsMapGroupsWithState) {
+      val hasInitialState = !rel.getInitialGroupingExpressionsList.isEmpty && rel.hasInitialInput
+      val initialDs = if (hasInitialState) {
+        UntypedKeyValueGroupedDataset(
+          rel.getInitialInput,
+          rel.getInitialGroupingExpressionsList,
+          rel.getSortingExpressionsList)
+      } else {
+        UntypedKeyValueGroupedDataset(
+          rel.getInput,
+          rel.getGroupingExpressionsList,
+          rel.getSortingExpressionsList)
+      }
+      val timeoutConf = if (!rel.hasTimeoutConf) {
+        GroupStateTimeout.NoTimeout
+      } else {
+        groupStateTimeoutFromString(rel.getTimeoutConf)
+      }
+      val outputMode = if (!rel.hasOutputMode) {
+        OutputMode.Update
+      } else {
+        InternalOutputModes(rel.getOutputMode)
+      }
+
+      val flatMapGroupsWithState = if (hasInitialState) {
+        new FlatMapGroupsWithState(
+          udf.function
+            .asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]],
+          udf.inputDeserializer(ds.groupingAttributes),
+          ds.valueDeserializer,
+          ds.groupingAttributes,
+          ds.dataAttributes,
+          udf.outputObjAttr,
+          initialDs.vEncoder.asInstanceOf[ExpressionEncoder[Any]],
+          outputMode,
+          rel.getIsMapGroupsWithState,
+          timeoutConf,
+          hasInitialState,
+          initialDs.groupingAttributes,
+          initialDs.dataAttributes,
+          initialDs.valueDeserializer,
+          initialDs.analyzed,
+          ds.analyzed)
+      } else {
+        new FlatMapGroupsWithState(
+          udf.function
+            .asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]],
+          udf.inputDeserializer(ds.groupingAttributes),
+          ds.valueDeserializer,
+          ds.groupingAttributes,
+          ds.dataAttributes,
+          udf.outputObjAttr,
+          initialDs.vEncoder.asInstanceOf[ExpressionEncoder[Any]],
+          outputMode,
+          rel.getIsMapGroupsWithState,
+          timeoutConf,
+          hasInitialState,
+          ds.groupingAttributes,
+          ds.dataAttributes,
+          udf.inputDeserializer(ds.groupingAttributes),
+          LocalRelation(initialDs.vEncoder.schema.toAttributes), // empty data set
+          ds.analyzed)
+      }
+      SerializeFromObject(udf.outputNamedExpression, flatMapGroupsWithState)
+    } else {
+      val mapped = new MapGroups(
+        udf.function.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]],
+        udf.inputDeserializer(ds.groupingAttributes),
+        ds.valueDeserializer,
+        ds.groupingAttributes,
+        ds.dataAttributes,
+        ds.sortOrder,
+        udf.outputObjAttr,
+        ds.analyzed)
+      SerializeFromObject(udf.outputNamedExpression, mapped)
+    }
   }
 
   private def transformCoGroupMap(rel: proto.CoGroupMap): LogicalPlan = {
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 7b1c55408be..20e0a13c5e4 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf3\x16\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf3\x16\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
 )
 
 
@@ -886,17 +886,17 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _MAPPARTITIONS._serialized_start = 10801
     _MAPPARTITIONS._serialized_end = 10982
     _GROUPMAP._serialized_start = 10985
-    _GROUPMAP._serialized_end = 11264
-    _COGROUPMAP._serialized_start = 11267
-    _COGROUPMAP._serialized_end = 11793
-    _APPLYINPANDASWITHSTATE._serialized_start = 11796
-    _APPLYINPANDASWITHSTATE._serialized_end = 12153
-    _COLLECTMETRICS._serialized_start = 12156
-    _COLLECTMETRICS._serialized_end = 12292
-    _PARSE._serialized_start = 12295
-    _PARSE._serialized_end = 12683
+    _GROUPMAP._serialized_end = 11620
+    _COGROUPMAP._serialized_start = 11623
+    _COGROUPMAP._serialized_end = 12149
+    _APPLYINPANDASWITHSTATE._serialized_start = 12152
+    _APPLYINPANDASWITHSTATE._serialized_end = 12509
+    _COLLECTMETRICS._serialized_start = 12512
+    _COLLECTMETRICS._serialized_end = 12648
+    _PARSE._serialized_start = 12651
+    _PARSE._serialized_end = 13039
     _PARSE_OPTIONSENTRY._serialized_start = 3687
     _PARSE_OPTIONSENTRY._serialized_end = 3745
-    _PARSE_PARSEFORMAT._serialized_start = 12584
-    _PARSE_PARSEFORMAT._serialized_end = 12672
+    _PARSE_PARSEFORMAT._serialized_start = 12940
+    _PARSE_PARSEFORMAT._serialized_end = 13028
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 69a4d6b9ccc..bd6460519a4 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -2986,6 +2986,11 @@ class GroupMap(google.protobuf.message.Message):
     GROUPING_EXPRESSIONS_FIELD_NUMBER: builtins.int
     FUNC_FIELD_NUMBER: builtins.int
     SORTING_EXPRESSIONS_FIELD_NUMBER: builtins.int
+    INITIAL_INPUT_FIELD_NUMBER: builtins.int
+    INITIAL_GROUPING_EXPRESSIONS_FIELD_NUMBER: builtins.int
+    IS_MAP_GROUPS_WITH_STATE_FIELD_NUMBER: builtins.int
+    OUTPUT_MODE_FIELD_NUMBER: builtins.int
+    TIMEOUT_CONF_FIELD_NUMBER: builtins.int
     @property
     def input(self) -> global___Relation:
         """(Required) Input relation for Group Map API: apply, applyInPandas."""
@@ -3006,6 +3011,24 @@ class GroupMap(google.protobuf.message.Message):
         pyspark.sql.connect.proto.expressions_pb2.Expression
     ]:
         """(Optional) Expressions for sorting. Only used by Scala Sorted Group Map API."""
+    @property
+    def initial_input(self) -> global___Relation:
+        """Below fields are only used by (Flat)MapGroupsWithState
+        (Optional) Input relation for initial State.
+        """
+    @property
+    def initial_grouping_expressions(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+        pyspark.sql.connect.proto.expressions_pb2.Expression
+    ]:
+        """(Optional) Expressions for grouping keys of the initial state input relation."""
+    is_map_groups_with_state: builtins.bool
+    """(Optional) True if MapGroupsWithState, false if FlatMapGroupsWithState."""
+    output_mode: builtins.str
+    """(Optional) The output mode of the function."""
+    timeout_conf: builtins.str
+    """(Optional) Timeout configuration for groups that do not receive data for a while."""
     def __init__(
         self,
         *,
@@ -3020,23 +3043,82 @@ class GroupMap(google.protobuf.message.Message):
             pyspark.sql.connect.proto.expressions_pb2.Expression
         ]
         | None = ...,
+        initial_input: global___Relation | None = ...,
+        initial_grouping_expressions: collections.abc.Iterable[
+            pyspark.sql.connect.proto.expressions_pb2.Expression
+        ]
+        | None = ...,
+        is_map_groups_with_state: builtins.bool | None = ...,
+        output_mode: builtins.str | None = ...,
+        timeout_conf: builtins.str | None = ...,
     ) -> None: ...
     def HasField(
-        self, field_name: typing_extensions.Literal["func", b"func", "input", b"input"]
+        self,
+        field_name: typing_extensions.Literal[
+            "_is_map_groups_with_state",
+            b"_is_map_groups_with_state",
+            "_output_mode",
+            b"_output_mode",
+            "_timeout_conf",
+            b"_timeout_conf",
+            "func",
+            b"func",
+            "initial_input",
+            b"initial_input",
+            "input",
+            b"input",
+            "is_map_groups_with_state",
+            b"is_map_groups_with_state",
+            "output_mode",
+            b"output_mode",
+            "timeout_conf",
+            b"timeout_conf",
+        ],
     ) -> builtins.bool: ...
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
+            "_is_map_groups_with_state",
+            b"_is_map_groups_with_state",
+            "_output_mode",
+            b"_output_mode",
+            "_timeout_conf",
+            b"_timeout_conf",
             "func",
             b"func",
             "grouping_expressions",
             b"grouping_expressions",
+            "initial_grouping_expressions",
+            b"initial_grouping_expressions",
+            "initial_input",
+            b"initial_input",
             "input",
             b"input",
+            "is_map_groups_with_state",
+            b"is_map_groups_with_state",
+            "output_mode",
+            b"output_mode",
             "sorting_expressions",
             b"sorting_expressions",
+            "timeout_conf",
+            b"timeout_conf",
         ],
     ) -> None: ...
+    @typing.overload
+    def WhichOneof(
+        self,
+        oneof_group: typing_extensions.Literal[
+            "_is_map_groups_with_state", b"_is_map_groups_with_state"
+        ],
+    ) -> typing_extensions.Literal["is_map_groups_with_state"] | None: ...
+    @typing.overload
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_output_mode", b"_output_mode"]
+    ) -> typing_extensions.Literal["output_mode"] | None: ...
+    @typing.overload
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_timeout_conf", b"_timeout_conf"]
+    ) -> typing_extensions.Literal["timeout_conf"] | None: ...
 
 global___GroupMap = GroupMap
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org