Posted to commits@spark.apache.org by gu...@apache.org on 2023/06/21 06:35:48 UTC
[spark] branch master updated: [SPARK-43511][CONNECT][SS] Implemented MapGroupsWithState and FlatMapGroupsWithState APIs for Spark Connect
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 8a16aed9a17 [SPARK-43511][CONNECT][SS] Implemented MapGroupsWithState and FlatMapGroupsWithState APIs for Spark Connect
8a16aed9a17 is described below
commit 8a16aed9a17269b4c8111779229507e3c28ba945
Author: bogao007 <bo...@databricks.com>
AuthorDate: Wed Jun 21 15:35:34 2023 +0900
[SPARK-43511][CONNECT][SS] Implemented MapGroupsWithState and FlatMapGroupsWithState APIs for Spark Connect
### What changes were proposed in this pull request?
Implemented MapGroupsWithState and FlatMapGroupsWithState APIs for Spark Connect
### Why are the changes needed?
To support streaming state APIs in Spark Connect
### Does this PR introduce _any_ user-facing change?
yes
### How was this patch tested?
Added unit test
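For context, a minimal illustrative sketch of the new API surface from the Scala client (the `spark` session and the data are assumptions mirroring the tests in this commit):

    import java.sql.Timestamp
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}

    case class ClickEvent(id: String, timestamp: Timestamp)
    case class ClickState(id: String, count: Int)

    val session: SparkSession = spark // assumed Spark Connect session
    import session.implicits._

    // Keep a running per-key click count in user-defined state.
    val stateFunc =
      (key: String, values: Iterator[ClickEvent], state: GroupState[ClickState]) => {
        val current = state.getOption.getOrElse(ClickState(key, 0))
        val updated = ClickState(key, current.count + values.size)
        state.update(updated)
        updated
      }

    val counts = Seq(ClickEvent("a", new Timestamp(12345)), ClickEvent("b", new Timestamp(12346)))
      .toDS()
      .groupByKey(_.id)
      .mapGroupsWithState(GroupStateTimeout.NoTimeout)(stateFunc)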
Closes #41558 from bogao007/sc-state-api.
Authored-by: bogao007 <bo...@databricks.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
.../apache/spark/sql/KeyValueGroupedDataset.scala | 398 +++++++++++++++++++++
.../sql/KeyValueGroupedDatasetE2ETestSuite.scala | 107 ++++++
.../CheckConnectJvmClientCompatibility.scala | 6 -
.../FlatMapGroupsWithStateStreamingSuite.scala | 224 ++++++++++++
.../function/FlatMapGroupsWithStateFunction.java | 39 ++
.../java/function/MapGroupsWithStateFunction.java | 38 ++
.../main/protobuf/spark/connect/relations.proto | 16 +
.../apache/spark/sql/connect/common/UdfUtils.scala | 26 ++
.../apache/spark/sql/streaming/GroupState.scala | 336 +++++++++++++++++
.../sql/connect/planner/SparkConnectPlanner.scala | 92 ++++-
python/pyspark/sql/connect/proto/relations_pb2.py | 24 +-
python/pyspark/sql/connect/proto/relations_pb2.pyi | 84 ++++-
12 files changed, 1359 insertions(+), 31 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 7b2fa3b52be..20c130b83cb 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder
import org.apache.spark.sql.connect.common.UdfUtils
import org.apache.spark.sql.expressions.ScalarUserDefinedFunction
import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}
/**
* A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not
@@ -460,6 +461,356 @@ abstract class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable
cogroupSorted(other)(thisSortExprs: _*)(otherSortExprs: _*)(
UdfUtils.coGroupFunctionToScalaFunc(f))(encoder)
}
+
+ protected def flatMapGroupsWithStateHelper[S: Encoder, U: Encoder](
+ outputMode: Option[OutputMode],
+ timeoutConf: GroupStateTimeout,
+ initialState: Option[KeyValueGroupedDataset[K, S]],
+ isMapGroupWithState: Boolean)(
+ func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = {
+ throw new UnsupportedOperationException
+ }
+
+ /**
+ * (Scala-specific) Applies the given function to each group of data, while maintaining a
+ * user-defined per-group state. The result Dataset will represent the objects returned by the
+ * function. For a static batch Dataset, the function will be invoked once per group. For a
+ * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+ * and updates to each group's state will be saved across invocations. See
+ * [[org.apache.spark.sql.streaming.GroupState]] for more details.
+ *
+ * @tparam S
+ * The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U
+ * The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func
+ * Function to be called on every group.
+ *
+ * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+ * @since 3.5.0
+ */
+ def mapGroupsWithState[S: Encoder, U: Encoder](
+ func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = {
+ mapGroupsWithState(GroupStateTimeout.NoTimeout)(func)
+ }
+
+ /**
+ * (Scala-specific) Applies the given function to each group of data, while maintaining a
+ * user-defined per-group state. The result Dataset will represent the objects returned by the
+ * function. For a static batch Dataset, the function will be invoked once per group. For a
+ * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+ * and updates to each group's state will be saved across invocations. See
+ * [[org.apache.spark.sql.streaming.GroupState]] for more details.
+ *
+ * @tparam S
+ * The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U
+ * The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func
+ * Function to be called on every group.
+ * @param timeoutConf
+ * Timeout configuration for groups that do not receive data for a while.
+ *
+ * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+ * @since 3.5.0
+ */
+ def mapGroupsWithState[S: Encoder, U: Encoder](timeoutConf: GroupStateTimeout)(
+ func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = {
+ flatMapGroupsWithStateHelper(None, timeoutConf, None, isMapGroupWithState = true)(
+ UdfUtils.mapGroupsWithStateFuncToFlatMapAdaptor(func))
+ }
+
+ /**
+ * (Scala-specific) Applies the given function to each group of data, while maintaining a
+ * user-defined per-group state. The result Dataset will represent the objects returned by the
+ * function. For a static batch Dataset, the function will be invoked once per group. For a
+ * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+ * and updates to each group's state will be saved across invocations. See
+ * [[org.apache.spark.sql.streaming.GroupState]] for more details.
+ *
+ * @tparam S
+ * The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U
+ * The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func
+ * Function to be called on every group.
+ * @param timeoutConf
+ * Timeout Conf, see GroupStateTimeout for more details
+ * @param initialState
+ * The user provided state that will be initialized when the first batch of data is processed
+ * in the streaming query. The user defined function will be called on the state data even if
+ * there are no other values in the group. To convert a Dataset ds of type Dataset[(K, S)] to
+ * a KeyValueGroupedDataset[K, S] do {{{ds.groupByKey(x => x._1).mapValues(_._2)}}}
+ *
+ * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+ * @since 3.5.0
+ */
+ def mapGroupsWithState[S: Encoder, U: Encoder](
+ timeoutConf: GroupStateTimeout,
+ initialState: KeyValueGroupedDataset[K, S])(
+ func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = {
+ flatMapGroupsWithStateHelper(
+ None,
+ timeoutConf,
+ Some(initialState),
+ isMapGroupWithState = true)(UdfUtils.mapGroupsWithStateFuncToFlatMapAdaptor(func))
+ }
+
+ /**
+ * (Java-specific) Applies the given function to each group of data, while maintaining a
+ * user-defined per-group state. The result Dataset will represent the objects returned by the
+ * function. For a static batch Dataset, the function will be invoked once per group. For a
+ * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+ * and updates to each group's state will be saved across invocations. See `GroupState` for more
+ * details.
+ *
+ * @tparam S
+ * The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U
+ * The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func
+ * Function to be called on every group.
+ * @param stateEncoder
+ * Encoder for the state type.
+ * @param outputEncoder
+ * Encoder for the output type.
+ *
+ * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+ * @since 3.5.0
+ */
+ def mapGroupsWithState[S, U](
+ func: MapGroupsWithStateFunction[K, V, S, U],
+ stateEncoder: Encoder[S],
+ outputEncoder: Encoder[U]): Dataset[U] = {
+ mapGroupsWithState[S, U](UdfUtils.mapGroupsWithStateFuncToScalaFunc(func))(
+ stateEncoder,
+ outputEncoder)
+ }
+
+ /**
+ * (Java-specific) Applies the given function to each group of data, while maintaining a
+ * user-defined per-group state. The result Dataset will represent the objects returned by the
+ * function. For a static batch Dataset, the function will be invoked once per group. For a
+ * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+ * and updates to each group's state will be saved across invocations. See `GroupState` for more
+ * details.
+ *
+ * @tparam S
+ * The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U
+ * The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func
+ * Function to be called on every group.
+ * @param stateEncoder
+ * Encoder for the state type.
+ * @param outputEncoder
+ * Encoder for the output type.
+ * @param timeoutConf
+ * Timeout configuration for groups that do not receive data for a while.
+ *
+ * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+ * @since 3.5.0
+ */
+ def mapGroupsWithState[S, U](
+ func: MapGroupsWithStateFunction[K, V, S, U],
+ stateEncoder: Encoder[S],
+ outputEncoder: Encoder[U],
+ timeoutConf: GroupStateTimeout): Dataset[U] = {
+ mapGroupsWithState[S, U](timeoutConf)(UdfUtils.mapGroupsWithStateFuncToScalaFunc(func))(
+ stateEncoder,
+ outputEncoder)
+ }
+
+ /**
+ * (Java-specific) Applies the given function to each group of data, while maintaining a
+ * user-defined per-group state. The result Dataset will represent the objects returned by the
+ * function. For a static batch Dataset, the function will be invoked once per group. For a
+ * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+ * and updates to each group's state will be saved across invocations. See `GroupState` for more
+ * details.
+ *
+ * @tparam S
+ * The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U
+ * The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func
+ * Function to be called on every group.
+ * @param stateEncoder
+ * Encoder for the state type.
+ * @param outputEncoder
+ * Encoder for the output type.
+ * @param timeoutConf
+ * Timeout configuration for groups that do not receive data for a while.
+ * @param initialState
+ * The user provided state that will be initialized when the first batch of data is processed
+ * in the streaming query. The user defined function will be called on the state data even if
+ * there are no other values in the group.
+ *
+ * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+ * @since 3.5.0
+ */
+ def mapGroupsWithState[S, U](
+ func: MapGroupsWithStateFunction[K, V, S, U],
+ stateEncoder: Encoder[S],
+ outputEncoder: Encoder[U],
+ timeoutConf: GroupStateTimeout,
+ initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
+ mapGroupsWithState[S, U](timeoutConf, initialState)(
+ UdfUtils.mapGroupsWithStateFuncToScalaFunc(func))(stateEncoder, outputEncoder)
+ }
+
+ /**
+ * (Scala-specific) Applies the given function to each group of data, while maintaining a
+ * user-defined per-group state. The result Dataset will represent the objects returned by the
+ * function. For a static batch Dataset, the function will be invoked once per group. For a
+ * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+ * and updates to each group's state will be saved across invocations. See `GroupState` for more
+ * details.
+ *
+ * @tparam S
+ * The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U
+ * The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func
+ * Function to be called on every group.
+ * @param outputMode
+ * The output mode of the function.
+ * @param timeoutConf
+ * Timeout configuration for groups that do not receive data for a while.
+ *
+ * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+ * @since 3.5.0
+ */
+ def flatMapGroupsWithState[S: Encoder, U: Encoder](
+ outputMode: OutputMode,
+ timeoutConf: GroupStateTimeout)(
+ func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = {
+ flatMapGroupsWithStateHelper(
+ Some(outputMode),
+ timeoutConf,
+ None,
+ isMapGroupWithState = false)(func)
+ }
+
+ /**
+ * (Scala-specific) Applies the given function to each group of data, while maintaining a
+ * user-defined per-group state. The result Dataset will represent the objects returned by the
+ * function. For a static batch Dataset, the function will be invoked once per group. For a
+ * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+ * and updates to each group's state will be saved across invocations. See `GroupState` for more
+ * details.
+ *
+ * @tparam S
+ * The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U
+ * The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func
+ * Function to be called on every group.
+ * @param outputMode
+ * The output mode of the function.
+ * @param timeoutConf
+ * Timeout configuration for groups that do not receive data for a while.
+ * @param initialState
+ * The user provided state that will be initialized when the first batch of data is processed
+ * in the streaming query. The user defined function will be called on the state data even if
+ * there are no other values in the group. To convert a Dataset `ds` of type
+ * `Dataset[(K, S)]` to a `KeyValueGroupedDataset[K, S]`, use
+ * {{{ds.groupByKey(x => x._1).mapValues(_._2)}}} See [[Encoder]] for more details on what
+ * types are encodable to Spark SQL.
+ * @since 3.5.0
+ */
+ def flatMapGroupsWithState[S: Encoder, U: Encoder](
+ outputMode: OutputMode,
+ timeoutConf: GroupStateTimeout,
+ initialState: KeyValueGroupedDataset[K, S])(
+ func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = {
+ flatMapGroupsWithStateHelper(
+ Some(outputMode),
+ timeoutConf,
+ Some(initialState),
+ isMapGroupWithState = false)(func)
+ }
+
+ /**
+ * (Java-specific) Applies the given function to each group of data, while maintaining a
+ * user-defined per-group state. The result Dataset will represent the objects returned by the
+ * function. For a static batch Dataset, the function will be invoked once per group. For a
+ * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+ * and updates to each group's state will be saved across invocations. See `GroupState` for more
+ * details.
+ *
+ * @tparam S
+ * The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U
+ * The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func
+ * Function to be called on every group.
+ * @param outputMode
+ * The output mode of the function.
+ * @param stateEncoder
+ * Encoder for the state type.
+ * @param outputEncoder
+ * Encoder for the output type.
+ * @param timeoutConf
+ * Timeout configuration for groups that do not receive data for a while.
+ *
+ * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+ * @since 3.5.0
+ */
+ def flatMapGroupsWithState[S, U](
+ func: FlatMapGroupsWithStateFunction[K, V, S, U],
+ outputMode: OutputMode,
+ stateEncoder: Encoder[S],
+ outputEncoder: Encoder[U],
+ timeoutConf: GroupStateTimeout): Dataset[U] = {
+ val f = UdfUtils.flatMapGroupsWithStateFuncToScalaFunc(func)
+ flatMapGroupsWithState[S, U](outputMode, timeoutConf)(f)(stateEncoder, outputEncoder)
+ }
+
+ /**
+ * (Java-specific) Applies the given function to each group of data, while maintaining a
+ * user-defined per-group state. The result Dataset will represent the objects returned by the
+ * function. For a static batch Dataset, the function will be invoked once per group. For a
+ * streaming Dataset, the function will be invoked for each group repeatedly in every trigger,
+ * and updates to each group's state will be saved across invocations. See `GroupState` for more
+ * details.
+ *
+ * @tparam S
+ * The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U
+ * The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func
+ * Function to be called on every group.
+ * @param outputMode
+ * The output mode of the function.
+ * @param stateEncoder
+ * Encoder for the state type.
+ * @param outputEncoder
+ * Encoder for the output type.
+ * @param timeoutConf
+ * Timeout configuration for groups that do not receive data for a while.
+ * @param initialState
+ * The user provided state that will be initialized when the first batch of data is processed
+ * in the streaming query. The user defined function will be called on the state data even if
+ * there are no other values in the group. To convert a Dataset `ds` of type
+ * `Dataset[(K, S)]` to a `KeyValueGroupedDataset[K, S]`, use
+ * {{{ds.groupByKey(x => x._1).mapValues(_._2)}}}
+ *
+ * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+ * @since 3.5.0
+ */
+ def flatMapGroupsWithState[S, U](
+ func: FlatMapGroupsWithStateFunction[K, V, S, U],
+ outputMode: OutputMode,
+ stateEncoder: Encoder[S],
+ outputEncoder: Encoder[U],
+ timeoutConf: GroupStateTimeout,
+ initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
+ val f = UdfUtils.flatMapGroupsWithStateFuncToScalaFunc(func)
+ flatMapGroupsWithState[S, U](outputMode, timeoutConf, initialState)(f)(
+ stateEncoder,
+ outputEncoder)
+ }
}
/**
@@ -572,6 +923,53 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
agg(aggregator)
}
+ override protected def flatMapGroupsWithStateHelper[S: Encoder, U: Encoder](
+ outputMode: Option[OutputMode],
+ timeoutConf: GroupStateTimeout,
+ initialState: Option[KeyValueGroupedDataset[K, S]],
+ isMapGroupWithState: Boolean)(
+ func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = {
+ if (outputMode.isDefined && outputMode.get != OutputMode.Append &&
+ outputMode.get != OutputMode.Update) {
+ throw new IllegalArgumentException("The output mode of the function should be Append or Update")
+ }
+
+ if (initialState.isDefined) {
+ assert(initialState.get.isInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]])
+ }
+
+ val initialStateImpl = if (initialState.isDefined) {
+ initialState.get.asInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]]
+ } else {
+ null
+ }
+
+ val outputEncoder = encoderFor[U]
+ val nf = if (valueMapFunc == UdfUtils.identical()) {
+ func
+ } else {
+ UdfUtils.mapValuesAdaptor(func, valueMapFunc)
+ }
+
+ sparkSession.newDataset[U](outputEncoder) { builder =>
+ val groupMapBuilder = builder.getGroupMapBuilder
+ groupMapBuilder
+ .setInput(plan.getRoot)
+ .addAllGroupingExpressions(groupingExprs)
+ .setFunc(getUdf(nf, outputEncoder)(ivEncoder))
+ .setIsMapGroupsWithState(isMapGroupWithState)
+ .setOutputMode(if (outputMode.isEmpty) OutputMode.Update.toString
+ else outputMode.get.toString)
+ .setTimeoutConf(timeoutConf.toString)
+
+ if (initialStateImpl != null) {
+ groupMapBuilder
+ .addAllInitialGroupingExpressions(initialStateImpl.groupingExprs)
+ .setInitialInput(initialStateImpl.plan.getRoot)
+ }
+ }
+ }
+
private def getUdf[U: Encoder](nf: AnyRef, outputEncoder: AgnosticEncoder[U])(
inEncoders: AgnosticEncoder[_]*): proto.CommonInlineUserDefinedFunction = {
val inputEncoders = kEncoder +: inEncoders // Apply keyAs changes by setting kEncoder
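The Java-friendly overloads above take explicit encoders instead of context bounds. A hedged usage sketch, written from Scala via a SAM lambda (`grouped` is an assumed KeyValueGroupedDataset[String, ClickEvent]; the case classes are the ones from the tests below):

    import scala.collection.JavaConverters._
    import org.apache.spark.api.java.function.MapGroupsWithStateFunction
    import org.apache.spark.sql.Encoders

    // Map-style Java function: returns exactly one record per group.
    val javaFunc: MapGroupsWithStateFunction[String, ClickEvent, ClickState, ClickState] =
      (key, values, state) => {
        val current = state.getOption.getOrElse(ClickState(key, 0))
        ClickState(key, current.count + values.asScala.size)
      }

    val counts = grouped.mapGroupsWithState(
      javaFunc,
      Encoders.product[ClickState],  // state encoder
      Encoders.product[ClickState])  // output encoder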
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
index e7a77eed70d..404239f7e14 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
@@ -16,14 +16,21 @@
*/
package org.apache.spark.sql
+import java.sql.Timestamp
import java.util.Arrays
import io.grpc.StatusRuntimeException
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Append
import org.apache.spark.sql.connect.client.util.QueryTest
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}
import org.apache.spark.sql.types._
+case class ClickEvent(id: String, timestamp: Timestamp)
+
+case class ClickState(id: String, count: Int)
+
/**
* All tests in this class requires client UDF artifacts synced with the server.
*/
@@ -447,4 +454,104 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper {
checkDataset(keys, "1", "2", "10", "20")
}
+
+ test("flatMapGroupsWithState") {
+ val stateFunc =
+ (key: String, values: Iterator[ClickEvent], state: GroupState[ClickState]) => {
+ if (state.exists) throw new IllegalArgumentException("state.exists should be false")
+ Iterator(ClickState(key, values.size))
+ }
+
+ val session: SparkSession = spark
+ import session.implicits._
+ val values = Seq(
+ ClickEvent("a", new Timestamp(12345)),
+ ClickEvent("a", new Timestamp(12346)),
+ ClickEvent("a", new Timestamp(12347)),
+ ClickEvent("b", new Timestamp(12348)),
+ ClickEvent("b", new Timestamp(12349)),
+ ClickEvent("c", new Timestamp(12350)))
+ .toDS()
+ .groupByKey(_.id)
+ .flatMapGroupsWithState(Append, GroupStateTimeout.NoTimeout)(stateFunc)
+
+ checkDataset(values, ClickState("a", 3), ClickState("b", 2), ClickState("c", 1))
+ }
+
+ test("flatMapGroupsWithState - with initial state") {
+ val stateFunc =
+ (key: String, values: Iterator[ClickEvent], state: GroupState[ClickState]) => {
+ val currState = state.getOption.getOrElse(ClickState(key, 0))
+ Iterator(ClickState(key, currState.count + values.size))
+ }
+
+ val session: SparkSession = spark
+ import session.implicits._
+
+ val initialStateDS = Seq(ClickState("a", 2), ClickState("b", 1)).toDS()
+ val initialState = initialStateDS.groupByKey(_.id).mapValues(x => x)
+
+ val values = Seq(
+ ClickEvent("a", new Timestamp(12345)),
+ ClickEvent("a", new Timestamp(12346)),
+ ClickEvent("a", new Timestamp(12347)),
+ ClickEvent("b", new Timestamp(12348)),
+ ClickEvent("b", new Timestamp(12349)),
+ ClickEvent("c", new Timestamp(12350)))
+ .toDS()
+ .groupByKey(_.id)
+ .flatMapGroupsWithState(Append, GroupStateTimeout.NoTimeout, initialState)(stateFunc)
+
+ checkDataset(values, ClickState("a", 5), ClickState("b", 3), ClickState("c", 1))
+ }
+
+ test("mapGroupsWithState") {
+ val stateFunc =
+ (key: String, values: Iterator[ClickEvent], state: GroupState[ClickState]) => {
+ if (state.exists) throw new IllegalArgumentException("state.exists should be false")
+ ClickState(key, values.size)
+ }
+
+ val session: SparkSession = spark
+ import session.implicits._
+ val values = Seq(
+ ClickEvent("a", new Timestamp(12345)),
+ ClickEvent("a", new Timestamp(12346)),
+ ClickEvent("a", new Timestamp(12347)),
+ ClickEvent("b", new Timestamp(12348)),
+ ClickEvent("b", new Timestamp(12349)),
+ ClickEvent("c", new Timestamp(12350)))
+ .toDS()
+ .groupByKey(_.id)
+ .mapGroupsWithState(GroupStateTimeout.NoTimeout)(stateFunc)
+
+ checkDataset(values, ClickState("a", 3), ClickState("b", 2), ClickState("c", 1))
+ }
+
+ test("mapGroupsWithState - with initial state") {
+ val stateFunc =
+ (key: String, values: Iterator[ClickEvent], state: GroupState[ClickState]) => {
+ val currState = state.getOption.getOrElse(ClickState(key, 0))
+ ClickState(key, currState.count + values.size)
+ }
+
+ val session: SparkSession = spark
+ import session.implicits._
+
+ val initialStateDS = Seq(ClickState("a", 2), ClickState("b", 1)).toDS()
+ val initialState = initialStateDS.groupByKey(_.id).mapValues(x => x)
+
+ val values = Seq(
+ ClickEvent("a", new Timestamp(12345)),
+ ClickEvent("a", new Timestamp(12346)),
+ ClickEvent("a", new Timestamp(12347)),
+ ClickEvent("b", new Timestamp(12348)),
+ ClickEvent("b", new Timestamp(12349)),
+ ClickEvent("c", new Timestamp(12350)))
+ .toDS()
+ .groupByKey(_.id)
+ .mapGroupsWithState(GroupStateTimeout.NoTimeout, initialState)(stateFunc)
+
+ checkDataset(values, ClickState("a", 5), ClickState("b", 3), ClickState("c", 1))
+ }
}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 7a9a889706d..6b648fd152b 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -201,12 +201,6 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.udaf"),
// KeyValueGroupedDataset
- ProblemFilters.exclude[Problem](
- "org.apache.spark.sql.KeyValueGroupedDataset.mapGroupsWithState"
- ), // streaming
- ProblemFilters.exclude[Problem](
- "org.apache.spark.sql.KeyValueGroupedDataset.flatMapGroupsWithState"
- ), // streaming
ProblemFilters.exclude[Problem](
"org.apache.spark.sql.KeyValueGroupedDataset.queryExecution"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.KeyValueGroupedDataset.this"),
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala
new file mode 100644
index 00000000000..cdb6b9a2e9c
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.sql.Timestamp
+
+import org.scalatest.concurrent.Eventually.eventually
+import org.scalatest.concurrent.Futures.timeout
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.sql.{SparkSession, SQLHelper}
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Append
+import org.apache.spark.sql.connect.client.util.QueryTest
+import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}
+
+case class ClickEvent(id: String, timestamp: Timestamp)
+
+case class ClickState(id: String, count: Int)
+
+class FlatMapGroupsWithStateStreamingSuite extends QueryTest with SQLHelper {
+
+ val flatMapGroupsWithStateSchema: StructType = StructType(
+ Array(StructField("id", StringType), StructField("timestamp", TimestampType)))
+
+ val flatMapGroupsWithStateData: Seq[ClickEvent] = Seq(
+ ClickEvent("a", new Timestamp(12345)),
+ ClickEvent("a", new Timestamp(12346)),
+ ClickEvent("a", new Timestamp(12347)),
+ ClickEvent("b", new Timestamp(12348)),
+ ClickEvent("b", new Timestamp(12349)),
+ ClickEvent("c", new Timestamp(12350)))
+
+ val flatMapGroupsWithStateInitialStateData: Seq[ClickState] =
+ Seq(ClickState("a", 2), ClickState("b", 1))
+
+ test("flatMapGroupsWithState - streaming") {
+ val session: SparkSession = spark
+ import session.implicits._
+
+ val stateFunc =
+ (key: String, values: Iterator[ClickEvent], state: GroupState[ClickState]) => {
+ if (state.exists) throw new IllegalArgumentException("state.exists should be false")
+ Iterator(ClickState(key, values.size))
+ }
+ spark.sql("DROP TABLE IF EXISTS my_sink")
+
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ flatMapGroupsWithStateData.toDS().write.parquet(path)
+ val q = spark.readStream
+ .schema(flatMapGroupsWithStateSchema)
+ .parquet(path)
+ .as[ClickEvent]
+ .groupByKey(_.id)
+ .flatMapGroupsWithState(Append, GroupStateTimeout.NoTimeout)(stateFunc)
+ .writeStream
+ .format("memory")
+ .queryName("my_sink")
+ .start()
+
+ try {
+ q.processAllAvailable()
+ eventually(timeout(30.seconds)) {
+ checkDataset(
+ spark.table("my_sink").toDF().as[ClickState],
+ ClickState("c", 1),
+ ClickState("b", 2),
+ ClickState("a", 3))
+ }
+ } finally {
+ q.stop()
+ spark.sql("DROP TABLE IF EXISTS my_sink")
+ }
+ }
+ }
+
+ test("flatMapGroupsWithState - streaming - with initial state") {
+ val session: SparkSession = spark
+ import session.implicits._
+
+ val stateFunc =
+ (key: String, values: Iterator[ClickEvent], state: GroupState[ClickState]) => {
+ val currState = state.getOption.getOrElse(ClickState(key, 0))
+ Iterator(ClickState(key, currState.count + values.size))
+ }
+ val initialState = flatMapGroupsWithStateInitialStateData
+ .toDS()
+ .groupByKey(_.id)
+ .mapValues(x => x)
+ spark.sql("DROP TABLE IF EXISTS my_sink")
+
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ flatMapGroupsWithStateData.toDS().write.parquet(path)
+ val q = spark.readStream
+ .schema(flatMapGroupsWithStateSchema)
+ .parquet(path)
+ .as[ClickEvent]
+ .groupByKey(_.id)
+ .flatMapGroupsWithState(Append, GroupStateTimeout.NoTimeout, initialState)(stateFunc)
+ .writeStream
+ .format("memory")
+ .queryName("my_sink")
+ .start()
+
+ try {
+ q.processAllAvailable()
+ eventually(timeout(30.seconds)) {
+ checkDataset(
+ spark.table("my_sink").toDF().as[ClickState],
+ ClickState("c", 1),
+ ClickState("b", 3),
+ ClickState("a", 5))
+ }
+ } finally {
+ q.stop()
+ spark.sql("DROP TABLE IF EXISTS my_sink")
+ }
+ }
+ }
+
+ test("mapGroupsWithState - streaming") {
+ val session: SparkSession = spark
+ import session.implicits._
+
+ val stateFunc =
+ (key: String, values: Iterator[ClickEvent], state: GroupState[ClickState]) => {
+ if (state.exists) throw new IllegalArgumentException("state.exists should be false")
+ ClickState(key, values.size)
+ }
+ spark.sql("DROP TABLE IF EXISTS my_sink")
+
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ flatMapGroupsWithStateData.toDS().write.parquet(path)
+ val q = spark.readStream
+ .schema(flatMapGroupsWithStateSchema)
+ .parquet(path)
+ .as[ClickEvent]
+ .groupByKey(_.id)
+ .mapGroupsWithState(GroupStateTimeout.NoTimeout)(stateFunc)
+ .writeStream
+ .format("memory")
+ .queryName("my_sink")
+ .outputMode("update")
+ .start()
+
+ try {
+ q.processAllAvailable()
+ eventually(timeout(30.seconds)) {
+ checkDataset(
+ spark.table("my_sink").toDF().as[ClickState],
+ ClickState("c", 1),
+ ClickState("b", 2),
+ ClickState("a", 3))
+ }
+ } finally {
+ q.stop()
+ spark.sql("DROP TABLE IF EXISTS my_sink")
+ }
+ }
+ }
+
+ test("mapGroupsWithState - streaming - with initial state") {
+ val session: SparkSession = spark
+ import session.implicits._
+
+ val stateFunc =
+ (key: String, values: Iterator[ClickEvent], state: GroupState[ClickState]) => {
+ val currState = state.getOption.getOrElse(ClickState(key, 0))
+ ClickState(key, currState.count + values.size)
+ }
+ val initialState = flatMapGroupsWithStateInitialStateData
+ .toDS()
+ .groupByKey(_.id)
+ .mapValues(x => x)
+ spark.sql("DROP TABLE IF EXISTS my_sink")
+
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ flatMapGroupsWithStateData.toDS().write.parquet(path)
+ val q = spark.readStream
+ .schema(flatMapGroupsWithStateSchema)
+ .parquet(path)
+ .as[ClickEvent]
+ .groupByKey(_.id)
+ .mapGroupsWithState(GroupStateTimeout.NoTimeout, initialState)(stateFunc)
+ .writeStream
+ .format("memory")
+ .queryName("my_sink")
+ .outputMode("update")
+ .start()
+
+ try {
+ q.processAllAvailable()
+ eventually(timeout(30.seconds)) {
+ checkDataset(
+ spark.table("my_sink").toDF().as[ClickState],
+ ClickState("c", 1),
+ ClickState("b", 3),
+ ClickState("a", 5))
+ }
+ } finally {
+ q.stop()
+ spark.sql("DROP TABLE IF EXISTS my_sink")
+ }
+ }
+ }
+}
diff --git a/connector/connect/common/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java b/connector/connect/common/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java
new file mode 100644
index 00000000000..c917c8d28be
--- /dev/null
+++ b/connector/connect/common/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.annotation.Experimental;
+import org.apache.spark.sql.streaming.GroupState;
+
+/**
+ * ::Experimental::
+ * Base interface for a map function used in
+ * {@code org.apache.spark.sql.KeyValueGroupedDataset.flatMapGroupsWithState(
+ * FlatMapGroupsWithStateFunction, org.apache.spark.sql.streaming.OutputMode,
+ * org.apache.spark.sql.Encoder, org.apache.spark.sql.Encoder)}
+ * @since 3.5.0
+ */
+@Experimental
+@Evolving
+public interface FlatMapGroupsWithStateFunction<K, V, S, R> extends Serializable {
+ Iterator<R> call(K key, Iterator<V> values, GroupState<S> state) throws Exception;
+}
diff --git a/connector/connect/common/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java b/connector/connect/common/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java
new file mode 100644
index 00000000000..ae179ad7d27
--- /dev/null
+++ b/connector/connect/common/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.annotation.Experimental;
+import org.apache.spark.sql.streaming.GroupState;
+
+/**
+ * ::Experimental::
+ * Base interface for a map function used in
+ * {@code org.apache.spark.sql.KeyValueGroupedDataset.mapGroupsWithState(
+ * MapGroupsWithStateFunction, org.apache.spark.sql.Encoder, org.apache.spark.sql.Encoder)}
+ * @since 3.5.0
+ */
+@Experimental
+@Evolving
+public interface MapGroupsWithStateFunction<K, V, S, R> extends Serializable {
+ R call(K key, Iterator<V> values, GroupState<S> state) throws Exception;
+}
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 6347bd7bc56..ea432bb48fc 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -857,6 +857,22 @@ message GroupMap {
// (Optional) Expressions for sorting. Only used by Scala Sorted Group Map API.
repeated Expression sorting_expressions = 4;
+
+ // The following fields are only used by (Flat)MapGroupsWithState.
+ // (Optional) Input relation for initial State.
+ Relation initial_input = 5;
+
+ // (Optional) Expressions for grouping keys of the initial state input relation.
+ repeated Expression initial_grouping_expressions = 6;
+
+ // (Optional) True if MapGroupsWithState, false if FlatMapGroupsWithState.
+ optional bool is_map_groups_with_state = 7;
+
+ // (Optional) The output mode of the function.
+ optional string output_mode = 8;
+
+ // (Optional) Timeout configuration for groups that do not receive data for a while.
+ optional string timeout_conf = 9;
}
message CoGroupMap {
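To make the wire format concrete, a hypothetical sketch of how the Scala client populates these new GroupMap fields (mirroring the builder calls in KeyValueGroupedDatasetImpl above; `inputRelation` and `initialRelation` are placeholders, not part of this commit):

    // Hypothetical plan construction for mapGroupsWithState with initial state.
    val groupMap = proto.GroupMap
      .newBuilder()
      .setInput(inputRelation)                               // main input relation
      .setInitialInput(initialRelation)                      // optional initial state input
      .setIsMapGroupsWithState(true)                         // map (not flatMap) variant
      .setOutputMode(OutputMode.Update.toString)             // default for mapGroupsWithState
      .setTimeoutConf(GroupStateTimeout.NoTimeout.toString)  // parsed back server-side
      .build()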
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
index 06a6c74f268..883637ff86c 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala
@@ -20,6 +20,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.api.java.function._
import org.apache.spark.sql.Row
+import org.apache.spark.sql.streaming.GroupState
/**
* Util functions to help convert input functions between typed filter, map, flatMap,
@@ -95,6 +96,31 @@ private[sql] object UdfUtils extends Serializable {
}
}
+ def mapValuesAdaptor[K, V, S, U, IV](
+ f: (K, Iterator[V], GroupState[S]) => Iterator[U],
+ valueMapFunc: IV => V): (K, Iterator[IV], GroupState[S]) => Iterator[U] = {
+ (k: K, itr: Iterator[IV], s: GroupState[S]) =>
+ {
+ f(k, itr.map(v => valueMapFunc(v)), s)
+ }
+ }
+
+ def mapGroupsWithStateFuncToFlatMapAdaptor[K, V, S, U](
+ f: (K, Iterator[V], GroupState[S]) => U): (K, Iterator[V], GroupState[S]) => Iterator[U] = {
+ (k: K, itr: Iterator[V], s: GroupState[S]) => Iterator(f(k, itr, s))
+ }
+
+ def mapGroupsWithStateFuncToScalaFunc[K, V, S, U](
+ f: MapGroupsWithStateFunction[K, V, S, U]): (K, Iterator[V], GroupState[S]) => U = {
+ (key, data, groupState) => f.call(key, data.asJava, groupState)
+ }
+
+ def flatMapGroupsWithStateFuncToScalaFunc[K, V, S, U](
+ f: FlatMapGroupsWithStateFunction[K, V, S, U])
+ : (K, Iterator[V], GroupState[S]) => Iterator[U] = { (key, data, groupState) =>
+ f.call(key, data.asJava, groupState).asScala
+ }
+
def mapReduceFuncToScalaFunc[T](func: ReduceFunction[T]): (T, T) => T = func.call
def identical[T](): T => T = t => t
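A small illustration (with arbitrary types) of what the map-to-flatMap adaptor above does:

    // mapGroupsWithStateFuncToFlatMapAdaptor lifts a map-style function into the
    // flatMap shape, so both APIs can share a single code path.
    val mapFunc: (String, Iterator[Int], GroupState[Long]) => Long =
      (key, values, state) => values.sum.toLong
    val flatFunc = UdfUtils.mapGroupsWithStateFuncToFlatMapAdaptor(mapFunc)
    // Conceptually: flatFunc(k, vs, s) == Iterator(mapFunc(k, vs, s))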
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala
new file mode 100644
index 00000000000..bd418a89534
--- /dev/null
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala
@@ -0,0 +1,336 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.apache.spark.annotation.{Evolving, Experimental}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState
+
+/**
+ * :: Experimental ::
+ *
+ * Wrapper class for interacting with per-group state data in `mapGroupsWithState` and
+ * `flatMapGroupsWithState` operations on `KeyValueGroupedDataset`.
+ *
+ * Detailed description of the `[map/flatMap]GroupsWithState` operation
+ * -------------------------------------------------------------- Both `mapGroupsWithState` and
+ * `flatMapGroupsWithState` in `KeyValueGroupedDataset` will invoke the user-given function on
+ * each group (defined by the grouping function in `Dataset.groupByKey()`) while maintaining a
+ * user-defined per-group state between invocations. For a static batch Dataset, the function will
+ * be invoked once per group. For a streaming Dataset, the function will be invoked for each group
+ * repeatedly in every trigger. That is, in every batch of the `StreamingQuery`, the function will
+ * be invoked once for each group that has data in the trigger. Furthermore, if timeout is set,
+ * then the function will be invoked on timed-out groups (more detail below).
+ *
+ * The function is invoked with the following parameters.
+ * - The key of the group.
+ * - An iterator containing all the values for this group.
+ * - A user-defined state object set by previous invocations of the given function.
+ *
+ * In case of a batch Dataset, there is only one invocation and the state object will be empty as
+ * there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState` is
+ * equivalent to `[map/flatMap]Groups` and any updates to the state and/or timeouts have no
+ * effect.
+ *
+ * The major difference between `mapGroupsWithState` and `flatMapGroupsWithState` is that the
+ * former allows the function to return one and only one record, whereas the latter allows the
+ * function to return any number of records (including no records). Furthermore, the
+ * `flatMapGroupsWithState` is associated with an operation output mode, which can be either
+ * `Append` or `Update`. Semantically, this defines whether the output records of one trigger are
+ * effectively replacing the previously output records (from previous triggers) or are appending to
+ * the list of previously output records. Essentially, this defines how the Result Table (refer to
+ * the semantics in the programming guide) is updated, and allows us to reason about the semantics
+ * of later operations.
+ *
+ * Important points to note about the function (both mapGroupsWithState and
+ * flatMapGroupsWithState).
+ * - In a trigger, the function will be called only for the groups present in the batch. So do not
+ * assume that the function will be called in every trigger for every group that has state.
+ * - There is no guaranteed ordering of values in the iterator in the function, neither with
+ * batch, nor with streaming Datasets.
+ * - All the data will be shuffled before applying the function.
+ * - If timeout is set, then the function will also be called with no values. See more details
+ * on `GroupStateTimeout` below.
+ *
+ * Important points to note about using `GroupState`.
+ * - The value of the state cannot be null. So updating state with null will throw
+ * `IllegalArgumentException`.
+ * - Operations on `GroupState` are not thread-safe. This is to avoid memory barriers.
+ * - If `remove()` is called, then `exists()` will return `false`, `get()` will throw
+ * `NoSuchElementException` and `getOption()` will return `None`
+ * - After that, if `update(newState)` is called, then `exists()` will again return `true`,
+ * `get()` and `getOption()` will return the updated value.
+ *
+ * Important points to note about using `GroupStateTimeout`.
+ * - The timeout type is a global param across all the groups (set as `timeout` param in
+ * `[map|flatMap]GroupsWithState`), but the exact timeout duration/timestamp is configurable
+ * per group by calling `setTimeout...()` in `GroupState`.
+ * - Timeouts can be either based on processing time (i.e.
+ * `GroupStateTimeout.ProcessingTimeTimeout`) or event time (i.e.
+ * `GroupStateTimeout.EventTimeTimeout`).
+ * - With `ProcessingTimeTimeout`, the timeout duration can be set by calling
+ * `GroupState.setTimeoutDuration`. The timeout will occur when the clock has advanced by the
+ * set duration. Guarantees provided by this timeout with a duration of D ms are as follows:
+ * - Timeout will never occur before the clock time has advanced by D ms
+ * - Timeout will occur eventually when there is a trigger in the query (i.e. after D ms). So
+ * there is no strict upper bound on when the timeout would occur. For example, the trigger
+ * interval of the query will affect when the timeout actually occurs. If there is no data
+ * in the stream (for any group) for a while, then there will not be any trigger and timeout
+ * function call will not occur until there is data.
+ * - Since the processing time timeout is based on the clock time, it is affected by the
+ * variations in the system clock (i.e. time zone changes, clock skew, etc.).
+ * - With `EventTimeTimeout`, the user also has to specify the event time watermark in the query
+ * using `Dataset.withWatermark()`. With this setting, data that is older than the watermark
+ * is filtered out. The timeout can be set for a group by setting a timeout timestamp
+ * using `GroupState.setTimeoutTimestamp()`, and the timeout would occur when the watermark
+ * advances beyond the set timestamp. You can control the timeout delay by two parameters -
+ * (i) the watermark delay and (ii) an additional duration beyond the event timestamp (which is
+ * guaranteed to be newer than the watermark due to the filtering). Guarantees provided by this
+ * timeout are as follows:
+ * - Timeout will never occur before the watermark has exceeded the set timeout.
+ * - Similar to processing time timeouts, there is no strict upper bound on the delay when the
+ * timeout actually occurs. The watermark can advance only when there is data in the stream
+ * and the event time of the data has actually advanced.
+ * - When the timeout occurs for a group, the function is called for that group with no values,
+ * and `GroupState.hasTimedOut()` set to true.
+ * - The timeout is reset every time the function is called on a group, that is, when the group
+ * has new data, or the group has timed out. So the user has to set the timeout duration every
+ * time the function is called, otherwise, there will not be any timeout set.
+ *
+ * `[map/flatMap]GroupsWithState` can take a user defined initial state as an additional argument.
+ * This state will be applied when the first batch of the streaming query is processed. If there
+ * are no matching rows in the data for the keys present in the initial state, the state is still
+ * applied and the function will be invoked with the values being an empty iterator.
+ *
+ * Scala example of using GroupState in `mapGroupsWithState`:
+ * {{{
+ * // A mapping function that maintains an integer state for string keys and returns a string.
+ * // Additionally, it sets a timeout to remove the state if it has not received data for an hour.
+ * def mappingFunction(key: String, value: Iterator[Int], state: GroupState[Int]): String = {
+ *
+ * if (state.hasTimedOut) { // If called when timing out, remove the state
+ * state.remove()
+ *
+ * } else if (state.exists) { // If state exists, use it for processing
+ * val existingState = state.get // Get the existing state
+ * val shouldRemove = ... // Decide whether to remove the state
+ * if (shouldRemove) {
+ * state.remove() // Remove the state
+ *
+ * } else {
+ * val newState = ...
+ * state.update(newState) // Set the new state
+ * state.setTimeoutDuration("1 hour") // Set the timeout
+ * }
+ *
+ * } else {
+ * val initialState = ...
+ * state.update(initialState) // Set the initial state
+ * state.setTimeoutDuration("1 hour") // Set the timeout
+ * }
+ * ...
+ * // return something
+ * }
+ *
+ * dataset
+ * .groupByKey(...)
+ * .mapGroupsWithState(GroupStateTimeout.ProcessingTimeTimeout)(mappingFunction)
+ * }}}
+ *
+ * Java example of using `GroupState`:
+ * {{{
+ * // A mapping function that maintains an integer state for string keys and returns a string.
+ * // Additionally, it sets a timeout to remove the state if it has not received data for an hour.
+ * MapGroupsWithStateFunction<String, Integer, Integer, String> mappingFunction =
+ * new MapGroupsWithStateFunction<String, Integer, Integer, String>() {
+ *
+ * @Override
+ * public String call(String key, Iterator<Integer> value, GroupState<Integer> state) {
+ * if (state.hasTimedOut()) { // If called when timing out, remove the state
+ * state.remove();
+ *
+ * } else if (state.exists()) { // If state exists, use it for processing
+ * int existingState = state.get(); // Get the existing state
+ * boolean shouldRemove = ...; // Decide whether to remove the state
+ * if (shouldRemove) {
+ * state.remove(); // Remove the state
+ *
+ * } else {
+ * int newState = ...;
+ * state.update(newState); // Set the new state
+ * state.setTimeoutDuration("1 hour"); // Set the timeout
+ * }
+ *
+ * } else {
+ * int initialState = ...; // Set the initial state
+ * state.update(initialState);
+ * state.setTimeoutDuration("1 hour"); // Set the timeout
+ * }
+ * ...
+ * // return something
+ * }
+ * };
+ *
+ * dataset
+ * .groupByKey(...)
+ * .mapGroupsWithState(
+ * mappingFunction, Encoders.INT, Encoders.STRING, GroupStateTimeout.ProcessingTimeTimeout);
+ * }}}
+ *
+ * @tparam S
+ * User-defined type of the state to be stored for each group. Must be encodable into Spark SQL
+ * types (see `Encoder` for more details).
+ * @since 3.5.0
+ */
+@Experimental
+@Evolving
+trait GroupState[S] extends LogicalGroupState[S] {
+
+ /** Whether state exists or not. */
+ def exists: Boolean
+
+ /** Get the state value if it exists, or throw NoSuchElementException. */
+ @throws[NoSuchElementException]("when state does not exist")
+ def get: S
+
+ /** Get the state value as a scala Option. */
+ def getOption: Option[S]
+
+ /** Update the value of the state. */
+ def update(newState: S): Unit
+
+ /** Remove this state. */
+ def remove(): Unit
+
+ /**
+ * Whether the function has been called because the key has timed out.
+ * @note
+ * This can return true only when timeouts are enabled in `[map/flatMap]GroupsWithState`.
+ */
+ def hasTimedOut: Boolean
+
+ /**
+ * Set the timeout duration in ms for this key.
+ *
+ * @note
+ * [[GroupStateTimeout Processing time timeout]] must be enabled in
+ * `[map/flatMap]GroupsWithState` for calling this method.
+ * @note
+ * This method has no effect when used in a batch query.
+ */
+ @throws[IllegalArgumentException]("if 'durationMs' is not positive")
+ @throws[UnsupportedOperationException](
+ "if processing time timeout has not been enabled in [map|flatMap]GroupsWithState")
+ def setTimeoutDuration(durationMs: Long): Unit
+
+ /**
+ * Set the timeout duration for this key as a string. For example, "1 hour", "2 days", etc.
+ *
+ * @note
+ * [[GroupStateTimeout Processing time timeout]] must be enabled in
+ * `[map/flatMap]GroupsWithState` for calling this method.
+ * @note
+ * This method has no effect when used in a batch query.
+ */
+ @throws[IllegalArgumentException]("if 'duration' is not a valid duration")
+ @throws[UnsupportedOperationException](
+ "if processing time timeout has not been enabled in [map|flatMap]GroupsWithState")
+ def setTimeoutDuration(duration: String): Unit
+
+ /**
+ * Set the timeout timestamp for this key as milliseconds in epoch time. This timestamp cannot
+ * be older than the current watermark.
+ *
+ * @note
+ * [[GroupStateTimeout Event time timeout]] must be enabled in `[map/flatMap]GroupsWithState`
+ * for calling this method.
+ * @note
+ * This method has no effect when used in a batch query.
+ */
+ @throws[IllegalArgumentException](
+ "if 'timestampMs' is not positive or less than the current watermark in a streaming query")
+ @throws[UnsupportedOperationException](
+ "if event time timeout has not been enabled in [map|flatMap]GroupsWithState")
+ def setTimeoutTimestamp(timestampMs: Long): Unit
+
+ /**
+ * Set the timeout timestamp for this key as milliseconds in epoch time and an additional
+ * duration as a string (e.g. "1 hour", "2 days", etc.). The final timestamp (including the
+ * additional duration) cannot be older than the current watermark.
+ *
+ * @note
+ * [[GroupStateTimeout Event time timeout]] must be enabled in `[map/flatMap]GroupsWithState`
+ * for calling this method.
+ * @note
+ * This method has no side effect when used in a batch query.
+ */
+ @throws[IllegalArgumentException](
+ "if 'additionalDuration' is invalid or the final timeout timestamp is less than " +
+ "the current watermark in a streaming query")
+ @throws[UnsupportedOperationException](
+ "if event time timeout has not been enabled in [map|flatMap]GroupsWithState")
+ def setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit
+
+ /**
+ * Set the timeout timestamp for this key as a java.sql.Date. This timestamp cannot be older
+ * than the current watermark.
+ *
+ * @note
+ * [[GroupStateTimeout Event time timeout]] must be enabled in `[map/flatMap]GroupsWithState`
+ * for calling this method.
+ * @note
+ * This method has no side effect when used in a batch query.
+ */
+ @throws[UnsupportedOperationException](
+ "if event time timeout has not been enabled in [map|flatMap]GroupsWithState")
+ def setTimeoutTimestamp(timestamp: java.sql.Date): Unit
+
+ /**
+ * Set the timeout timestamp for this key as a java.sql.Date and an additional duration as a
+ * string (e.g. "1 hour", "2 days", etc.). The final timestamp (including the additional
+ * duration) cannot be older than the current watermark.
+ *
+ * @note
+ * [[GroupStateTimeout Event time timeout]] must be enabled in `[map/flatMap]GroupsWithState`
+ * for calling this method.
+ * @note
+ * This method has no side effect when used in a batch query.
+ */
+ @throws[IllegalArgumentException]("if 'additionalDuration' is invalid")
+ @throws[UnsupportedOperationException](
+ "if event time timeout has not been enabled in [map|flatMap]GroupsWithState")
+ def setTimeoutTimestamp(timestamp: java.sql.Date, additionalDuration: String): Unit
+
+ /**
+ * Get the current event time watermark as milliseconds in epoch time.
+ *
+ * @note
+ * In a streaming query, this can be called only when watermark is set before calling
+ * `[map/flatMap]GroupsWithState`. In a batch query, this method always returns -1.
+ */
+ @throws[UnsupportedOperationException](
+ "if watermark has not been set before in [map|flatMap]GroupsWithState")
+ def getCurrentWatermarkMs(): Long
+
+ /**
+ * Get the current processing time as milliseconds in epoch time.
+ * @note
+ * In a streaming query, this will return a constant value throughout the duration of a
+ * trigger, even if the trigger is re-executed.
+ */
+ def getCurrentProcessingTimeMs(): Long
+}
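The scaladoc examples above use ProcessingTimeTimeout; as a complementary, purely illustrative sketch of EventTimeTimeout (the Event/Session types and the `eventStream` Dataset are assumptions, not part of this commit):

    import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}

    case class Event(key: String, timestamp: java.sql.Timestamp)
    case class Session(key: String, count: Int)

    def sessionize(
        key: String,
        events: Iterator[Event],
        state: GroupState[Session]): Session = {
      if (state.hasTimedOut) {
        // The watermark passed the timeout timestamp: emit and drop the session.
        val closed = state.get
        state.remove()
        closed
      } else {
        val session = state.getOption.getOrElse(Session(key, 0))
        val updated = session.copy(count = session.count + events.size)
        state.update(updated)
        // Expire this key 10 minutes past the current event-time watermark.
        state.setTimeoutTimestamp(state.getCurrentWatermarkMs() + 10 * 60 * 1000)
        updated
      }
    }

    val sessions = eventStream // assumed Dataset[Event]
      .withWatermark("timestamp", "10 minutes") // required for EventTimeTimeout
      .groupByKey(_.key)
      .mapGroupsWithState(GroupStateTimeout.EventTimeTimeout)(sessionize)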
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index dc819fb4020..6ee252d1a58 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -47,7 +47,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeserializeToObject, Except, Intersect, LocalRelation, LogicalPlan, MapGroups, MapPartitions, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeserializeToObject, Except, FlatMapGroupsWithState, Intersect, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint}
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, UdfPacket}
@@ -64,11 +65,12 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRelation}
import org.apache.spark.sql.execution.python.{PythonForeachWriter, UserDefinedPythonFunction}
import org.apache.spark.sql.execution.stat.StatFunctions
+import org.apache.spark.sql.execution.streaming.GroupStateImpl.groupStateTimeoutFromString
import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper
import org.apache.spark.sql.expressions.ReduceAggregator
import org.apache.spark.sql.internal.{CatalogImpl, TypedAggUtils}
import org.apache.spark.sql.protobuf.{CatalystDataToProtobuf, ProtobufDataToCatalyst}
-import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryProgress, Trigger}
+import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StreamingQuery, StreamingQueryProgress, Trigger}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.storage.CacheId
@@ -570,16 +572,82 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
rel.getGroupingExpressionsList,
rel.getSortingExpressionsList)
- val mapped = new MapGroups(
- udf.function.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]],
- udf.inputDeserializer(ds.groupingAttributes),
- ds.valueDeserializer,
- ds.groupingAttributes,
- ds.dataAttributes,
- ds.sortOrder,
- udf.outputObjAttr,
- ds.analyzed)
- SerializeFromObject(udf.outputNamedExpression, mapped)
+ if (rel.hasIsMapGroupsWithState) {
+ val hasInitialState = !rel.getInitialGroupingExpressionsList.isEmpty && rel.hasInitialInput
+ val initialDs = if (hasInitialState) {
+ UntypedKeyValueGroupedDataset(
+ rel.getInitialInput,
+ rel.getInitialGroupingExpressionsList,
+ rel.getSortingExpressionsList)
+ } else {
+ UntypedKeyValueGroupedDataset(
+ rel.getInput,
+ rel.getGroupingExpressionsList,
+ rel.getSortingExpressionsList)
+ }
+ val timeoutConf = if (!rel.hasTimeoutConf) {
+ GroupStateTimeout.NoTimeout
+ } else {
+ groupStateTimeoutFromString(rel.getTimeoutConf)
+ }
+ val outputMode = if (!rel.hasOutputMode) {
+ OutputMode.Update
+ } else {
+ InternalOutputModes(rel.getOutputMode)
+ }
+
+ val flatMapGroupsWithState = if (hasInitialState) {
+ new FlatMapGroupsWithState(
+ udf.function
+ .asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]],
+ udf.inputDeserializer(ds.groupingAttributes),
+ ds.valueDeserializer,
+ ds.groupingAttributes,
+ ds.dataAttributes,
+ udf.outputObjAttr,
+ initialDs.vEncoder.asInstanceOf[ExpressionEncoder[Any]],
+ outputMode,
+ rel.getIsMapGroupsWithState,
+ timeoutConf,
+ hasInitialState,
+ initialDs.groupingAttributes,
+ initialDs.dataAttributes,
+ initialDs.valueDeserializer,
+ initialDs.analyzed,
+ ds.analyzed)
+ } else {
+ new FlatMapGroupsWithState(
+ udf.function
+ .asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]],
+ udf.inputDeserializer(ds.groupingAttributes),
+ ds.valueDeserializer,
+ ds.groupingAttributes,
+ ds.dataAttributes,
+ udf.outputObjAttr,
+ initialDs.vEncoder.asInstanceOf[ExpressionEncoder[Any]],
+ outputMode,
+ rel.getIsMapGroupsWithState,
+ timeoutConf,
+ hasInitialState,
+ ds.groupingAttributes,
+ ds.dataAttributes,
+ udf.inputDeserializer(ds.groupingAttributes),
+ LocalRelation(initialDs.vEncoder.schema.toAttributes), // placeholder: empty initial state
+ ds.analyzed)
+ }
+ SerializeFromObject(udf.outputNamedExpression, flatMapGroupsWithState)
+ } else {
+ val mapped = new MapGroups(
+ udf.function.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]],
+ udf.inputDeserializer(ds.groupingAttributes),
+ ds.valueDeserializer,
+ ds.groupingAttributes,
+ ds.dataAttributes,
+ ds.sortOrder,
+ udf.outputObjAttr,
+ ds.analyzed)
+ SerializeFromObject(udf.outputNamedExpression, mapped)
+ }
}
private def transformCoGroupMap(rel: proto.CoGroupMap): LogicalPlan = {
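
To make the new planner path above concrete, the following is an illustrative sketch of
the GroupMap relation a client would send for flatMapGroupsWithState with an initial
state. Field names follow the relations.proto changes in this patch; the builder inputs
(inputRelation, groupingExprs, serializedUdf, initialStateRelation,
initialGroupingExprs) are hypothetical placeholders, not code from this commit.

import scala.collection.JavaConverters._
import org.apache.spark.connect.proto
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}

val groupMap = proto.GroupMap
  .newBuilder()
  .setInput(inputRelation) // the relation being grouped
  .addAllGroupingExpressions(groupingExprs.asJava)
  .setFunc(serializedUdf) // the (K, Iterator[V], GroupState[S]) => Iterator[U] function
  .setInitialInput(initialStateRelation) // optional: makes hasInitialState true above
  .addAllInitialGroupingExpressions(initialGroupingExprs.asJava)
  .setIsMapGroupsWithState(false) // false selects the flatMap variant
  .setOutputMode(OutputMode.Update().toString) // parsed with InternalOutputModes
  .setTimeoutConf(GroupStateTimeout.ProcessingTimeTimeout().toString) // parsed with groupStateTimeoutFromString
  .build()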
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 7b1c55408be..20e0a13c5e4 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf3\x16\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+ b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf3\x16\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
)
@@ -886,17 +886,17 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_MAPPARTITIONS._serialized_start = 10801
_MAPPARTITIONS._serialized_end = 10982
_GROUPMAP._serialized_start = 10985
- _GROUPMAP._serialized_end = 11264
- _COGROUPMAP._serialized_start = 11267
- _COGROUPMAP._serialized_end = 11793
- _APPLYINPANDASWITHSTATE._serialized_start = 11796
- _APPLYINPANDASWITHSTATE._serialized_end = 12153
- _COLLECTMETRICS._serialized_start = 12156
- _COLLECTMETRICS._serialized_end = 12292
- _PARSE._serialized_start = 12295
- _PARSE._serialized_end = 12683
+ _GROUPMAP._serialized_end = 11620
+ _COGROUPMAP._serialized_start = 11623
+ _COGROUPMAP._serialized_end = 12149
+ _APPLYINPANDASWITHSTATE._serialized_start = 12152
+ _APPLYINPANDASWITHSTATE._serialized_end = 12509
+ _COLLECTMETRICS._serialized_start = 12512
+ _COLLECTMETRICS._serialized_end = 12648
+ _PARSE._serialized_start = 12651
+ _PARSE._serialized_end = 13039
_PARSE_OPTIONSENTRY._serialized_start = 3687
_PARSE_OPTIONSENTRY._serialized_end = 3745
- _PARSE_PARSEFORMAT._serialized_start = 12584
- _PARSE_PARSEFORMAT._serialized_end = 12672
+ _PARSE_PARSEFORMAT._serialized_start = 12940
+ _PARSE_PARSEFORMAT._serialized_end = 13028
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 69a4d6b9ccc..bd6460519a4 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -2986,6 +2986,11 @@ class GroupMap(google.protobuf.message.Message):
GROUPING_EXPRESSIONS_FIELD_NUMBER: builtins.int
FUNC_FIELD_NUMBER: builtins.int
SORTING_EXPRESSIONS_FIELD_NUMBER: builtins.int
+ INITIAL_INPUT_FIELD_NUMBER: builtins.int
+ INITIAL_GROUPING_EXPRESSIONS_FIELD_NUMBER: builtins.int
+ IS_MAP_GROUPS_WITH_STATE_FIELD_NUMBER: builtins.int
+ OUTPUT_MODE_FIELD_NUMBER: builtins.int
+ TIMEOUT_CONF_FIELD_NUMBER: builtins.int
@property
def input(self) -> global___Relation:
"""(Required) Input relation for Group Map API: apply, applyInPandas."""
@@ -3006,6 +3011,24 @@ class GroupMap(google.protobuf.message.Message):
pyspark.sql.connect.proto.expressions_pb2.Expression
]:
"""(Optional) Expressions for sorting. Only used by Scala Sorted Group Map API."""
+ @property
+ def initial_input(self) -> global___Relation:
+ """Below fields are only used by (Flat)MapGroupsWithState
+ (Optional) Input relation for initial State.
+ """
+ @property
+ def initial_grouping_expressions(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+ pyspark.sql.connect.proto.expressions_pb2.Expression
+ ]:
+ """(Optional) Expressions for grouping keys of the initial state input relation."""
+ is_map_groups_with_state: builtins.bool
+ """(Optional) True if MapGroupsWithState, false if FlatMapGroupsWithState."""
+ output_mode: builtins.str
+ """(Optional) The output mode of the function."""
+ timeout_conf: builtins.str
+ """(Optional) Timeout configuration for groups that do not receive data for a while."""
def __init__(
self,
*,
@@ -3020,23 +3043,82 @@ class GroupMap(google.protobuf.message.Message):
pyspark.sql.connect.proto.expressions_pb2.Expression
]
| None = ...,
+ initial_input: global___Relation | None = ...,
+ initial_grouping_expressions: collections.abc.Iterable[
+ pyspark.sql.connect.proto.expressions_pb2.Expression
+ ]
+ | None = ...,
+ is_map_groups_with_state: builtins.bool | None = ...,
+ output_mode: builtins.str | None = ...,
+ timeout_conf: builtins.str | None = ...,
) -> None: ...
def HasField(
- self, field_name: typing_extensions.Literal["func", b"func", "input", b"input"]
+ self,
+ field_name: typing_extensions.Literal[
+ "_is_map_groups_with_state",
+ b"_is_map_groups_with_state",
+ "_output_mode",
+ b"_output_mode",
+ "_timeout_conf",
+ b"_timeout_conf",
+ "func",
+ b"func",
+ "initial_input",
+ b"initial_input",
+ "input",
+ b"input",
+ "is_map_groups_with_state",
+ b"is_map_groups_with_state",
+ "output_mode",
+ b"output_mode",
+ "timeout_conf",
+ b"timeout_conf",
+ ],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
+ "_is_map_groups_with_state",
+ b"_is_map_groups_with_state",
+ "_output_mode",
+ b"_output_mode",
+ "_timeout_conf",
+ b"_timeout_conf",
"func",
b"func",
"grouping_expressions",
b"grouping_expressions",
+ "initial_grouping_expressions",
+ b"initial_grouping_expressions",
+ "initial_input",
+ b"initial_input",
"input",
b"input",
+ "is_map_groups_with_state",
+ b"is_map_groups_with_state",
+ "output_mode",
+ b"output_mode",
"sorting_expressions",
b"sorting_expressions",
+ "timeout_conf",
+ b"timeout_conf",
],
) -> None: ...
+ @typing.overload
+ def WhichOneof(
+ self,
+ oneof_group: typing_extensions.Literal[
+ "_is_map_groups_with_state", b"_is_map_groups_with_state"
+ ],
+ ) -> typing_extensions.Literal["is_map_groups_with_state"] | None: ...
+ @typing.overload
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_output_mode", b"_output_mode"]
+ ) -> typing_extensions.Literal["output_mode"] | None: ...
+ @typing.overload
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_timeout_conf", b"_timeout_conf"]
+ ) -> typing_extensions.Literal["timeout_conf"] | None: ...
global___GroupMap = GroupMap
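
Taken together with the client-side KeyValueGroupedDataset changes, this lets Connect
applications run stateful grouping end to end. An illustrative sketch of
mapGroupsWithState seeded with an initial state (`events` is a hypothetical
Dataset[(String, Long)] of (key, value) pairs; `spark.implicits._` is assumed in scope):

import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}

// Seed each key's running count from a static Dataset.
val initialState = Seq(("a", 1L), ("b", 2L)).toDS()
  .groupByKey(_._1)
  .mapValues(_._2)

val counts = events
  .groupByKey(_._1)
  .mapGroupsWithState(GroupStateTimeout.NoTimeout, initialState) {
    (key: String, values: Iterator[(String, Long)], state: GroupState[Long]) =>
      val count = state.getOption.getOrElse(0L) + values.size
      state.update(count)
      (key, count)
  }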