You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by zs...@apache.org on 2017/02/08 19:34:04 UTC
spark git commit: [SPARK-19413][SS] MapGroupsWithState for arbitrary
stateful operations for branch-2.1
Repository: spark
Updated Branches:
refs/heads/branch-2.1 71b6eacf7 -> 502c927b8
[SPARK-19413][SS] MapGroupsWithState for arbitrary stateful operations for branch-2.1
This is a follow-up PR for merging #16758 into the Spark 2.1 branch.
## What changes were proposed in this pull request?
`mapGroupsWithState` is a new API for arbitrary stateful operations in Structured Streaming, similar to `DStream.mapWithState`
*Requirements*
- Users should be able to specify a function that can do the following
- Access the input row corresponding to a key
- Access the previous state corresponding to a key
- Optionally, update or remove the state
- Output any number of new rows (or none at all)
*Proposed API*
```
// ------------ New methods on KeyValueGroupedDataset ------------
class KeyValueGroupedDataset[K, V] {
// Scala friendly
def mapGroupsWithState[S: Encoder, U: Encoder](func: (K, Iterator[V], KeyedState[S]) => U)
def flatMapGroupsWithState[S: Encoder, U: Encoder](func: (K, Iterator[V], KeyedState[S]) => Iterator[U])
// Java friendly
def mapGroupsWithState[S, U](func: MapGroupsWithStateFunction[K, V, S, U], stateEncoder: Encoder[S], resultEncoder: Encoder[U])
def flatMapGroupsWithState[S, U](func: FlatMapGroupsWithStateFunction[K, V, S, U], stateEncoder: Encoder[S], resultEncoder: Encoder[U])
}
// ------------------- New Java-friendly function classes -------------------
public interface MapGroupsWithStateFunction<K, V, S, R> extends Serializable {
R call(K key, Iterator<V> values, KeyedState<S> state) throws Exception;
}
public interface FlatMapGroupsWithStateFunction<K, V, S, R> extends Serializable {
Iterator<R> call(K key, Iterator<V> values, KeyedState<S> state) throws Exception;
}
// ---------------------- Wrapper class for state data ----------------------
trait KeyedState[S] {
def exists(): Boolean
def get(): S // throws Exception if the state does not exist
def getOption(): Option[S]
def update(newState: S): Unit
def remove(): Unit // exists() will be false after this
}
```
Key Semantics of the State class
- The state cannot be null; calling state.update(null) throws IllegalArgumentException.
- If state.remove() is called, then state.exists() will return false, and getOption will return None.
- After that, if state.update(newState) is called, then state.exists() will return true, and getOption will return Some(...).
- None of the operations are thread-safe. This is to avoid memory barriers.
*Usage*
```
val stateFunc = (word: String, words: Iterator[String], runningCount: KeyedState[Long]) => {
val newCount = words.size + runningCount.getOption.getOrElse(0L)
runningCount.update(newCount)
(word, newCount)
}
dataset // type is Dataset[String]
.groupByKey[String](w => w) // generates KeyValueGroupedDataset[String, String]
.mapGroupsWithState[Long, (String, Long)](stateFunc) // returns Dataset[(String, Long)]
```
## How was this patch tested?
New unit tests.
Author: Tathagata Das <ta...@gmail.com>
Closes #16850 from tdas/mapWithState-branch-2.1.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/502c927b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/502c927b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/502c927b
Branch: refs/heads/branch-2.1
Commit: 502c927b8c8a99ef2adf4e6e1d7a6d9232d45ef5
Parents: 71b6eac
Author: Tathagata Das <ta...@gmail.com>
Authored: Wed Feb 8 11:33:59 2017 -0800
Committer: Shixiong Zhu <sh...@databricks.com>
Committed: Wed Feb 8 11:33:59 2017 -0800
----------------------------------------------------------------------
.../analysis/UnsupportedOperationChecker.scala | 11 +-
.../sql/catalyst/plans/logical/object.scala | 49 +++
.../analysis/UnsupportedOperationsSuite.scala | 24 +-
.../FlatMapGroupsWithStateFunction.java | 38 +++
.../function/MapGroupsWithStateFunction.java | 38 +++
.../spark/sql/KeyValueGroupedDataset.scala | 113 +++++++
.../scala/org/apache/spark/sql/KeyedState.scala | 142 ++++++++
.../spark/sql/execution/SparkStrategies.scala | 21 +-
.../apache/spark/sql/execution/objects.scala | 22 ++
.../streaming/IncrementalExecution.scala | 19 +-
.../execution/streaming/KeyedStateImpl.scala | 80 +++++
.../execution/streaming/ProgressReporter.scala | 2 +-
.../execution/streaming/StatefulAggregate.scala | 237 -------------
.../state/HDFSBackedStateStoreProvider.scala | 19 ++
.../execution/streaming/state/StateStore.scala | 5 +
.../sql/execution/streaming/state/package.scala | 11 +-
.../execution/streaming/statefulOperators.scala | 323 ++++++++++++++++++
.../org/apache/spark/sql/JavaDatasetSuite.java | 32 ++
.../sql/streaming/MapGroupsWithStateSuite.scala | 335 +++++++++++++++++++
19 files changed, 1272 insertions(+), 249 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index f4d016c..d8aad42 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -46,8 +46,13 @@ object UnsupportedOperationChecker {
"Queries without streaming sources cannot be executed with writeStream.start()")(plan)
}
+ /** Collect all the streaming aggregates in a sub plan */
+ def collectStreamingAggregates(subplan: LogicalPlan): Seq[Aggregate] = {
+ subplan.collect { case a: Aggregate if a.isStreaming => a }
+ }
+
// Disallow multiple streaming aggregations
- val aggregates = plan.collect { case a@Aggregate(_, _, _) if a.isStreaming => a }
+ val aggregates = collectStreamingAggregates(plan)
if (aggregates.size > 1) {
throwError(
@@ -114,6 +119,10 @@ object UnsupportedOperationChecker {
case _: InsertIntoTable =>
throwError("InsertIntoTable is not supported with streaming DataFrames/Datasets")
+ case m: MapGroupsWithState if collectStreamingAggregates(m).nonEmpty =>
+ throwError("(map/flatMap)GroupsWithState is not supported after aggregation on a " +
+ "streaming DataFrame/Dataset")
+
case Join(left, right, joinType, _) =>
joinType match {
http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 0ab4c90..0be4823 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -313,6 +313,55 @@ case class MapGroups(
outputObjAttr: Attribute,
child: LogicalPlan) extends UnaryNode with ObjectProducer
+/** Internal class representing State */
+trait LogicalKeyedState[S]
+
+/** Factory for constructing new `MapGroupsWithState` nodes. */
+object MapGroupsWithState {
+ def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder](
+ func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
+ groupingAttributes: Seq[Attribute],
+ dataAttributes: Seq[Attribute],
+ child: LogicalPlan): LogicalPlan = {
+ val mapped = new MapGroupsWithState(
+ func,
+ UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
+ UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes),
+ groupingAttributes,
+ dataAttributes,
+ CatalystSerde.generateObjAttr[U],
+ encoderFor[S].resolveAndBind().deserializer,
+ encoderFor[S].namedExpressions,
+ child)
+ CatalystSerde.serialize[U](mapped)
+ }
+}
+
+/**
+ * Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`,
+ * while using state data.
+ * Func is invoked with an object representation of the grouping key and an iterator containing
+ * the object representation of all the rows with that key.
+ *
+ * @param keyDeserializer used to extract the key object for each group.
+ * @param valueDeserializer used to extract the items in the iterator from an input row.
+ * @param groupingAttributes used to group the data
+ * @param dataAttributes used to read the data
+ * @param outputObjAttr used to define the output object
+ * @param stateDeserializer used to deserialize state before calling `func`
+ * @param stateSerializer used to serialize updated state after calling `func`
+ */
+case class MapGroupsWithState(
+ func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
+ keyDeserializer: Expression,
+ valueDeserializer: Expression,
+ groupingAttributes: Seq[Attribute],
+ dataAttributes: Seq[Attribute],
+ outputObjAttr: Attribute,
+ stateDeserializer: Expression,
+ stateSerializer: Seq[NamedExpression],
+ child: LogicalPlan) extends UnaryNode with ObjectProducer
+
/** Factory for constructing new `FlatMapGroupsInR` nodes. */
object FlatMapGroupsInR {
def apply(
http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
index dcdb1ae..3b756e8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
@@ -22,13 +22,13 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.logical.{MapGroupsWithState, _}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.streaming.OutputMode
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{IntegerType, LongType}
/** A dummy command for testing unsupported operations. */
case class DummyCommand() extends Command
@@ -111,6 +111,24 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
outputMode = Complete,
expectedMsgs = Seq("distinct aggregation"))
+ // MapGroupsWithState: Not supported after a streaming aggregation
+ val att = new AttributeReference(name = "a", dataType = LongType)()
+ assertSupportedInBatchPlan(
+ "mapGroupsWithState - mapGroupsWithState on batch relation",
+ MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), batchRelation))
+
+ assertSupportedInStreamingPlan(
+ "mapGroupsWithState - mapGroupsWithState on streaming relation before aggregation",
+ MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), streamRelation),
+ outputMode = Append)
+
+ assertNotSupportedInStreamingPlan(
+ "mapGroupsWithState - mapGroupsWithState on streaming relation after aggregation",
+ MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att),
+ Aggregate(Nil, aggExprs("c"), streamRelation)),
+ outputMode = Complete,
+ expectedMsgs = Seq("(map/flatMap)GroupsWithState"))
+
// Inner joins: Stream-stream not supported
testBinaryOperationInStreamingPlan(
"inner join",
http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java
new file mode 100644
index 0000000..2570c8d
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+import org.apache.spark.annotation.Experimental;
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.KeyedState;
+
+/**
+ * ::Experimental::
+ * Base interface for a map function used in
+ * {@link org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroupsWithState(FlatMapGroupsWithStateFunction, Encoder, Encoder)}.
+ * @since 2.1.1
+ */
+@Experimental
+@InterfaceStability.Evolving
+public interface FlatMapGroupsWithStateFunction<K, V, S, R> extends Serializable {
+ Iterator<R> call(K key, Iterator<V> values, KeyedState<S> state) throws Exception;
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java
new file mode 100644
index 0000000..614d392
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+import org.apache.spark.annotation.Experimental;
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.KeyedState;
+
+/**
+ * ::Experimental::
+ * Base interface for a map function used in
+ * {@link org.apache.spark.sql.KeyValueGroupedDataset#mapGroupsWithState(MapGroupsWithStateFunction, Encoder, Encoder)}
+ * @since 2.1.1
+ */
+@Experimental
+@InterfaceStability.Evolving
+public interface MapGroupsWithStateFunction<K, V, S, R> extends Serializable {
+ R call(K key, Iterator<V> values, KeyedState<S> state) throws Exception;
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 395d709..94e689a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -219,6 +219,119 @@ class KeyValueGroupedDataset[K, V] private[sql](
}
/**
+ * ::Experimental::
+ * (Scala-specific)
+ * Applies the given function to each group of data, while maintaining a user-defined per-group
+ * state. The result Dataset will represent the objects returned by the function.
+ * For a static batch Dataset, the function will be invoked once per group. For a streaming
+ * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+ * updates to each group's state will be saved across invocations.
+ * See [[KeyedState]] for more details.
+ *
+ * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+ *
+ * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+ * @since 2.1.1
+ */
+ @Experimental
+ @InterfaceStability.Evolving
+ def mapGroupsWithState[S: Encoder, U: Encoder](
+ func: (K, Iterator[V], KeyedState[S]) => U): Dataset[U] = {
+ flatMapGroupsWithState[S, U](
+ (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s)))
+ }
+
+ /**
+ * ::Experimental::
+ * (Java-specific)
+ * Applies the given function to each group of data, while maintaining a user-defined per-group
+ * state. The result Dataset will represent the objects returned by the function.
+ * For a static batch Dataset, the function will be invoked once per group. For a streaming
+ * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+ * updates to each group's state will be saved across invocations.
+ * See [[KeyedState]] for more details.
+ *
+ * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func Function to be called on every group.
+ * @param stateEncoder Encoder for the state type.
+ * @param outputEncoder Encoder for the output type.
+ *
+ * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+ * @since 2.1.1
+ */
+ @Experimental
+ @InterfaceStability.Evolving
+ def mapGroupsWithState[S, U](
+ func: MapGroupsWithStateFunction[K, V, S, U],
+ stateEncoder: Encoder[S],
+ outputEncoder: Encoder[U]): Dataset[U] = {
+ flatMapGroupsWithState[S, U](
+ (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func.call(key, it.asJava, s))
+ )(stateEncoder, outputEncoder)
+ }
+
+ /**
+ * ::Experimental::
+ * (Scala-specific)
+ * Applies the given function to each group of data, while maintaining a user-defined per-group
+ * state. The result Dataset will represent the objects returned by the function.
+ * For a static batch Dataset, the function will be invoked once per group. For a streaming
+ * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+ * updates to each group's state will be saved across invocations.
+ * See [[KeyedState]] for more details.
+ *
+ * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+ *
+ * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+ * @since 2.1.1
+ */
+ @Experimental
+ @InterfaceStability.Evolving
+ def flatMapGroupsWithState[S: Encoder, U: Encoder](
+ func: (K, Iterator[V], KeyedState[S]) => Iterator[U]): Dataset[U] = {
+ Dataset[U](
+ sparkSession,
+ MapGroupsWithState[K, V, S, U](
+ func.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]],
+ groupingAttributes,
+ dataAttributes,
+ logicalPlan))
+ }
+
+ /**
+ * ::Experimental::
+ * (Java-specific)
+ * Applies the given function to each group of data, while maintaining a user-defined per-group
+ * state. The result Dataset will represent the objects returned by the function.
+ * For a static batch Dataset, the function will be invoked once per group. For a streaming
+ * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+ * updates to each group's state will be saved across invocations.
+ * See [[KeyedState]] for more details.
+ *
+ * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+ * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+ * @param func Function to be called on every group.
+ * @param stateEncoder Encoder for the state type.
+ * @param outputEncoder Encoder for the output type.
+ *
+ * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+ * @since 2.1.1
+ */
+ @Experimental
+ @InterfaceStability.Evolving
+ def flatMapGroupsWithState[S, U](
+ func: FlatMapGroupsWithStateFunction[K, V, S, U],
+ stateEncoder: Encoder[S],
+ outputEncoder: Encoder[U]): Dataset[U] = {
+ flatMapGroupsWithState[S, U](
+ (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala
+ )(stateEncoder, outputEncoder)
+ }
+
+ /**
* (Scala-specific)
* Reduces the elements of each group of data using the specified binary function.
* The given function must be commutative and associative or the result may be non-deterministic.
http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala
new file mode 100644
index 0000000..6864b6f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.lang.IllegalArgumentException
+
+import org.apache.spark.annotation.{Experimental, InterfaceStability}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
+
+/**
+ * :: Experimental ::
+ *
+ * Wrapper class for interacting with keyed state data in `mapGroupsWithState` and
+ * `flatMapGroupsWithState` operations on
+ * [[KeyValueGroupedDataset]].
+ *
+ * Detailed description of `[map/flatMap]GroupsWithState` operation
+ * ------------------------------------------------------------
+ * Both, `mapGroupsWithState` and `flatMapGroupsWithState` in [[KeyValueGroupedDataset]]
+ * will invoke the user-given function on each group (defined by the grouping function in
+ * `Dataset.groupByKey()`) while maintaining user-defined per-group state between invocations.
+ * For a static batch Dataset, the function will be invoked once per group. For a streaming
+ * Dataset, the function will be invoked for each group repeatedly in every trigger.
+ * That is, in every batch of the [[streaming.StreamingQuery StreamingQuery]],
+ * the function will be invoked once for each group that has data in the batch.
+ *
+ * The function is invoked with following parameters.
+ * - The key of the group.
+ * - An iterator containing all the values for this key.
+ * - A user-defined state object set by previous invocations of the given function.
+ * In case of a batch Dataset, there is only one invocation and state object will be empty as
+ * there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState`
+ * is equivalent to `[map/flatMap]Groups`.
+ *
+ * Important points to note about the function.
+ * - In a trigger, the function will be called only for the groups present in the batch. So do
+ * not assume that the function will be called in every trigger for every group that has state.
+ * - There is no guaranteed ordering of values in the iterator in the function, neither with
+ * batch, nor with streaming Datasets.
+ * - All the data will be shuffled before applying the function.
+ *
+ * Important points to note about using KeyedState.
+ * - The value of the state cannot be null. So updating state with null will throw
+ * `IllegalArgumentException`.
+ * - Operations on `KeyedState` are not thread-safe. This is to avoid memory barriers.
+ * - If `remove()` is called, then `exists()` will return `false`,
+ * `get()` will throw `NoSuchElementException` and `getOption()` will return `None`
+ * - After that, if `update(newState)` is called, then `exists()` will again return `true`,
+ * `get()` and `getOption()` will return the updated value.
+ *
+ * Scala example of using KeyedState in `mapGroupsWithState`:
+ * {{{
+ * /* A mapping function that maintains an integer state for string keys and returns a string. */
+ * def mappingFunction(key: String, value: Iterator[Int], state: KeyedState[Int]): String = {
+ * // Check if state exists
+ * if (state.exists) {
+ * val existingState = state.get // Get the existing state
+ * val shouldRemove = ... // Decide whether to remove the state
+ * if (shouldRemove) {
+ * state.remove() // Remove the state
+ * } else {
+ * val newState = ...
+ * state.update(newState) // Set the new state
+ * }
+ * } else {
+ * val initialState = ...
+ * state.update(initialState) // Set the initial state
+ * }
+ * ... // return something
+ * }
+ *
+ * }}}
+ *
+ * Java example of using `KeyedState`:
+ * {{{
+ * /* A mapping function that maintains an integer state for string keys and returns a string. */
+ * MapGroupsWithStateFunction<String, Integer, Integer, String> mappingFunction =
+ * new MapGroupsWithStateFunction<String, Integer, Integer, String>() {
+ *
+ * @Override
+ * public String call(String key, Iterator<Integer> value, KeyedState<Integer> state) {
+ * if (state.exists()) {
+ * int existingState = state.get(); // Get the existing state
+ * boolean shouldRemove = ...; // Decide whether to remove the state
+ * if (shouldRemove) {
+ * state.remove(); // Remove the state
+ * } else {
+ * int newState = ...;
+ * state.update(newState); // Set the new state
+ * }
+ * } else {
+ * int initialState = ...; // Set the initial state
+ * state.update(initialState);
+ * }
+ * ... // return something
+ * }
+ * };
+ * }}}
+ *
+ * @tparam S User-defined type of the state to be stored for each key. Must be encodable into
+ * Spark SQL types (see [[Encoder]] for more details).
+ * @since 2.1.1
+ */
+@Experimental
+@InterfaceStability.Evolving
+trait KeyedState[S] extends LogicalKeyedState[S] {
+
+ /** Whether state exists or not. */
+ def exists: Boolean
+
+ /** Get the state value if it exists, or throw NoSuchElementException. */
+ @throws[NoSuchElementException]("when state does not exist")
+ def get: S
+
+ /** Get the state value as a scala Option. */
+ def getOption: Option[S]
+
+ /**
+ * Update the value of the state. Note that `null` is not a valid value, and it throws
+ * IllegalArgumentException.
+ */
+ @throws[IllegalArgumentException]("when updating with null")
+ def update(newState: S): Unit
+
+ /** Remove this keyed state. */
+ def remove(): Unit
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index ba82ec1..adea358 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, EventTimeWatermark, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, EventTimeWatermark, LogicalPlan, MapGroupsWithState}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
@@ -324,6 +324,23 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}
+ /**
+ * Strategy to convert MapGroupsWithState logical operator to physical operator
+ * in streaming plans. Conversion for batch plans is handled by [[BasicOperators]].
+ */
+ object MapGroupsWithStateStrategy extends Strategy {
+ override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case MapGroupsWithState(
+ f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateDeser, stateSer, child) =>
+ val execPlan = MapGroupsWithStateExec(
+ f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateDeser, stateSer,
+ planLater(child))
+ execPlan :: Nil
+ case _ =>
+ Nil
+ }
+ }
+
// Can we automate these 'pass through' operations?
object BasicOperators extends Strategy {
def numPartitions: Int = self.numPartitions
@@ -365,6 +382,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil
case logical.MapGroups(f, key, value, grouping, data, objAttr, child) =>
execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
+ case logical.MapGroupsWithState(f, key, value, grouping, data, output, _, _, child) =>
+ execution.MapGroupsExec(f, key, value, grouping, data, output, planLater(child)) :: Nil
case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
execution.CoGroupExec(
f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr,
http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index fde3b2a..199ba5c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -30,6 +30,8 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
+import org.apache.spark.sql.execution.streaming.KeyedStateImpl
import org.apache.spark.sql.types.{DataType, ObjectType, StructType}
@@ -144,6 +146,11 @@ object ObjectOperator {
(i: InternalRow) => proj(i).get(0, deserializer.dataType)
}
+ def deserializeRowToObject(deserializer: Expression): InternalRow => Any = {
+ val proj = GenerateSafeProjection.generate(deserializer :: Nil)
+ (i: InternalRow) => proj(i).get(0, deserializer.dataType)
+ }
+
def serializeObjectToRow(serializer: Seq[Expression]): Any => UnsafeRow = {
val proj = GenerateUnsafeProjection.generate(serializer)
val objType = serializer.head.collect { case b: BoundReference => b.dataType }.head
@@ -344,6 +351,21 @@ case class MapGroupsExec(
}
}
+object MapGroupsExec {
+ def apply(
+ func: (Any, Iterator[Any], LogicalKeyedState[Any]) => TraversableOnce[Any],
+ keyDeserializer: Expression,
+ valueDeserializer: Expression,
+ groupingAttributes: Seq[Attribute],
+ dataAttributes: Seq[Attribute],
+ outputObjAttr: Attribute,
+ child: SparkPlan): MapGroupsExec = {
+ val f = (key: Any, values: Iterator[Any]) => func(key, values, new KeyedStateImpl[Any](None))
+ new MapGroupsExec(f, keyDeserializer, valueDeserializer,
+ groupingAttributes, dataAttributes, outputObjAttr, child)
+ }
+}
+
/**
* Groups the input rows together and calls the R function with each group and an iterator
* containing all elements in the group.
http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 6ab6fa6..5c4cbfa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.streaming
+import java.util.concurrent.atomic.AtomicInteger
+
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{CurrentBatchTimestamp, Literal}
import org.apache.spark.sql.SparkSession
@@ -39,8 +41,9 @@ class IncrementalExecution(
extends QueryExecution(sparkSession, logicalPlan) with Logging {
// TODO: make this always part of planning.
- val stateStrategy =
+ val streamingExtraStrategies =
sparkSession.sessionState.planner.StatefulAggregationStrategy +:
+ sparkSession.sessionState.planner.MapGroupsWithStateStrategy +:
sparkSession.sessionState.planner.StreamingRelationStrategy +:
sparkSession.sessionState.experimentalMethods.extraStrategies
@@ -49,7 +52,7 @@ class IncrementalExecution(
new SparkPlanner(
sparkSession.sparkContext,
sparkSession.sessionState.conf,
- stateStrategy)
+ streamingExtraStrategies)
/**
* See [SPARK-18339]
@@ -68,7 +71,7 @@ class IncrementalExecution(
* Records the current id for a given stateful operator in the query plan as the `state`
* preparation walks the query plan.
*/
- private var operatorId = 0
+ private val operatorId = new AtomicInteger(0)
/** Locates save/restore pairs surrounding aggregation. */
val state = new Rule[SparkPlan] {
@@ -77,8 +80,8 @@ class IncrementalExecution(
case StateStoreSaveExec(keys, None, None, None,
UnaryExecNode(agg,
StateStoreRestoreExec(keys2, None, child))) =>
- val stateId = OperatorStateId(checkpointLocation, operatorId, currentBatchId)
- operatorId += 1
+ val stateId =
+ OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)
StateStoreSaveExec(
keys,
@@ -90,6 +93,12 @@ class IncrementalExecution(
keys,
Some(stateId),
child) :: Nil))
+ case MapGroupsWithStateExec(
+ f, kDeser, vDeser, group, data, output, None, stateDeser, stateSer, child) =>
+ val stateId =
+ OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)
+ MapGroupsWithStateExec(
+ f, kDeser, vDeser, group, data, output, Some(stateId), stateDeser, stateSer, child)
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala
new file mode 100644
index 0000000..eee7ec4
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.sql.KeyedState
+
+/** Internal implementation of the [[KeyedState]] interface. Methods are not thread-safe. */
+private[sql] class KeyedStateImpl[S](optionalValue: Option[S]) extends KeyedState[S] {
+ private var value: S = optionalValue.getOrElse(null.asInstanceOf[S])
+ private var defined: Boolean = optionalValue.isDefined
+ private var updated: Boolean = false
+ // `updated` above tracks whether the value has been updated (but not removed)
+ private var removed: Boolean = false // whether value has been removed
+
+ // ========= Public API =========
+ override def exists: Boolean = defined
+
+ override def get: S = {
+ if (defined) {
+ value
+ } else {
+ throw new NoSuchElementException("State is either not defined or has already been removed")
+ }
+ }
+
+ override def getOption: Option[S] = {
+ if (defined) {
+ Some(value)
+ } else {
+ None
+ }
+ }
+
+ override def update(newValue: S): Unit = {
+ if (newValue == null) {
+ throw new IllegalArgumentException("'null' is not a valid state value")
+ }
+ value = newValue
+ defined = true
+ updated = true
+ removed = false
+ }
+
+ override def remove(): Unit = {
+ defined = false
+ updated = false
+ removed = true
+ }
+
+ override def toString: String = {
+ s"KeyedState(${getOption.map(_.toString).getOrElse("<undefined>")})"
+ }
+
+ // ========= Internal API =========
+
+ /** Whether the state has been marked for removal */
+ def isRemoved: Boolean = {
+ removed
+ }
+
+ /** Whether the state has been updated (and not subsequently removed) */
+ def isUpdated: Boolean = {
+ updated
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
index 1f74fff..693933f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
@@ -186,7 +186,7 @@ trait ProgressReporter extends Logging {
// lastExecution could belong to one of the previous triggers if `!hasNewData`.
// Walking the plan again should be inexpensive.
val stateNodes = lastExecution.executedPlan.collect {
- case p if p.isInstanceOf[StateStoreSaveExec] => p
+ case p if p.isInstanceOf[StateStoreWriter] => p
}
stateNodes.map { node =>
val numRowsUpdated = if (hasNewData) {
http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala
deleted file mode 100644
index d4ccced..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala
+++ /dev/null
@@ -1,237 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.errors._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate}
-import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
-import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
-import org.apache.spark.sql.execution
-import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.execution.streaming.state._
-import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.streaming.OutputMode
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.TaskContext
-
-
-/** Used to identify the state store for a given operator. */
-case class OperatorStateId(
- checkpointLocation: String,
- operatorId: Long,
- batchId: Long)
-
-/**
- * An operator that saves or restores state from the [[StateStore]]. The [[OperatorStateId]] should
- * be filled in by `prepareForExecution` in [[IncrementalExecution]].
- */
-trait StatefulOperator extends SparkPlan {
- def stateId: Option[OperatorStateId]
-
- protected def getStateId: OperatorStateId = attachTree(this) {
- stateId.getOrElse {
- throw new IllegalStateException("State location not present for execution")
- }
- }
-}
-
-/**
- * For each input tuple, the key is calculated and the value from the [[StateStore]] is added
- * to the stream (in addition to the input tuple) if present.
- */
-case class StateStoreRestoreExec(
- keyExpressions: Seq[Attribute],
- stateId: Option[OperatorStateId],
- child: SparkPlan)
- extends execution.UnaryExecNode with StatefulOperator {
-
- override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
-
- override protected def doExecute(): RDD[InternalRow] = {
- val numOutputRows = longMetric("numOutputRows")
-
- child.execute().mapPartitionsWithStateStore(
- getStateId.checkpointLocation,
- operatorId = getStateId.operatorId,
- storeVersion = getStateId.batchId,
- keyExpressions.toStructType,
- child.output.toStructType,
- sqlContext.sessionState,
- Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
- val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
- iter.flatMap { row =>
- val key = getKey(row)
- val savedState = store.get(key)
- numOutputRows += 1
- row +: savedState.toSeq
- }
- }
- }
-
- override def output: Seq[Attribute] = child.output
-
- override def outputPartitioning: Partitioning = child.outputPartitioning
-}
-
-/**
- * For each input tuple, the key is calculated and the tuple is `put` into the [[StateStore]].
- */
-case class StateStoreSaveExec(
- keyExpressions: Seq[Attribute],
- stateId: Option[OperatorStateId] = None,
- outputMode: Option[OutputMode] = None,
- eventTimeWatermark: Option[Long] = None,
- child: SparkPlan)
- extends execution.UnaryExecNode with StatefulOperator {
-
- override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
- "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"),
- "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows"))
-
- /** Generate a predicate that matches data older than the watermark */
- private lazy val watermarkPredicate: Option[Predicate] = {
- val optionalWatermarkAttribute =
- keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey))
-
- optionalWatermarkAttribute.map { watermarkAttribute =>
- // If we are evicting based on a window, use the end of the window. Otherwise just
- // use the attribute itself.
- val evictionExpression =
- if (watermarkAttribute.dataType.isInstanceOf[StructType]) {
- LessThanOrEqual(
- GetStructField(watermarkAttribute, 1),
- Literal(eventTimeWatermark.get * 1000))
- } else {
- LessThanOrEqual(
- watermarkAttribute,
- Literal(eventTimeWatermark.get * 1000))
- }
-
- logInfo(s"Filtering state store on: $evictionExpression")
- newPredicate(evictionExpression, keyExpressions)
- }
- }
-
- override protected def doExecute(): RDD[InternalRow] = {
- metrics // force lazy init at driver
- assert(outputMode.nonEmpty,
- "Incorrect planning in IncrementalExecution, outputMode has not been set")
-
- child.execute().mapPartitionsWithStateStore(
- getStateId.checkpointLocation,
- operatorId = getStateId.operatorId,
- storeVersion = getStateId.batchId,
- keyExpressions.toStructType,
- child.output.toStructType,
- sqlContext.sessionState,
- Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
- val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
- val numOutputRows = longMetric("numOutputRows")
- val numTotalStateRows = longMetric("numTotalStateRows")
- val numUpdatedStateRows = longMetric("numUpdatedStateRows")
-
- // Abort the state store in case of error
- TaskContext.get().addTaskCompletionListener(_ => {
- if (!store.hasCommitted) {
- store.abort()
- }
- })
-
- outputMode match {
- // Update and output all rows in the StateStore.
- case Some(Complete) =>
- while (iter.hasNext) {
- val row = iter.next().asInstanceOf[UnsafeRow]
- val key = getKey(row)
- store.put(key.copy(), row.copy())
- numUpdatedStateRows += 1
- }
- store.commit()
- numTotalStateRows += store.numKeys()
- store.iterator().map { case (k, v) =>
- numOutputRows += 1
- v.asInstanceOf[InternalRow]
- }
-
- // Update and output only rows being evicted from the StateStore
- case Some(Append) =>
- while (iter.hasNext) {
- val row = iter.next().asInstanceOf[UnsafeRow]
- val key = getKey(row)
- store.put(key.copy(), row.copy())
- numUpdatedStateRows += 1
- }
-
- // Assumption: Append mode can be done only when watermark has been specified
- store.remove(watermarkPredicate.get.eval)
- store.commit()
-
- numTotalStateRows += store.numKeys()
- store.updates().filter(_.isInstanceOf[ValueRemoved]).map { removed =>
- numOutputRows += 1
- removed.value.asInstanceOf[InternalRow]
- }
-
- // Update and output modified rows from the StateStore.
- case Some(Update) =>
-
- new Iterator[InternalRow] {
-
- // Filter late date using watermark if specified
- private[this] val baseIterator = watermarkPredicate match {
- case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row))
- case None => iter
- }
-
- override def hasNext: Boolean = {
- if (!baseIterator.hasNext) {
- // Remove old aggregates if watermark specified
- if (watermarkPredicate.nonEmpty) store.remove(watermarkPredicate.get.eval)
- store.commit()
- numTotalStateRows += store.numKeys()
- false
- } else {
- true
- }
- }
-
- override def next(): InternalRow = {
- val row = baseIterator.next().asInstanceOf[UnsafeRow]
- val key = getKey(row)
- store.put(key.copy(), row.copy())
- numOutputRows += 1
- numUpdatedStateRows += 1
- row
- }
- }
-
- case _ => throw new UnsupportedOperationException(s"Invalid output mode: $outputMode")
- }
- }
- }
-
- override def output: Seq[Attribute] = child.output
-
- override def outputPartitioning: Partitioning = child.outputPartitioning
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index 1279b71..61eb601 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -147,6 +147,25 @@ private[state] class HDFSBackedStateStoreProvider(
}
}
+ /** Remove a single key, if present, recording the removal in the updates and delta file. */
+ override def remove(key: UnsafeRow): Unit = {
+ verify(state == UPDATING, "Cannot remove after already committed or aborted")
+ if (mapToUpdate.containsKey(key)) {
+ val value = mapToUpdate.remove(key)
+ Option(allUpdates.get(key)) match {
+ case Some(ValueUpdated(_, _)) | None =>
+ // Value existed in previous version and maybe was updated, mark removed
+ allUpdates.put(key, ValueRemoved(key, value))
+ case Some(ValueAdded(_, _)) =>
+ // Value did not exist in previous version and was added, should not appear in updates
+ allUpdates.remove(key)
+ case Some(ValueRemoved(_, _)) =>
+ // Removal is already recorded in the update map, nothing to change
+ }
+ writeToDeltaFile(tempDeltaFileStream, ValueRemoved(key, value))
+ }
+ }
+
/** Commit all the updates that have been made to the store, and return the new version. */
override def commit(): Long = {
verify(state == UPDATING, "Cannot commit after already committed or aborted")
http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index e61d95a..dcb24b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -59,6 +59,11 @@ trait StateStore {
def remove(condition: UnsafeRow => Boolean): Unit
/**
+ * Remove a single key.
+ */
+ def remove(key: UnsafeRow): Unit
+
+ /**
* Commit all the updates that have been made to the store, and return the new version.
*/
def commit(): Long
http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
index 1b56c08..589042a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming
import scala.reflect.ClassTag
+import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.internal.SessionState
@@ -59,10 +60,18 @@ package object state {
sessionState: SessionState,
storeCoordinator: Option[StateStoreCoordinatorRef])(
storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {
+
val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction)
+ val wrappedF = (store: StateStore, iter: Iterator[T]) => {
+ // Abort the state store in case of error
+ TaskContext.get().addTaskCompletionListener(_ => {
+ if (!store.hasCommitted) store.abort()
+ })
+ cleanedF(store, iter)
+ }
new StateStoreRDD(
dataRDD,
- cleanedF,
+ wrappedF,
checkpointLocation,
operatorId,
storeVersion,
http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
new file mode 100644
index 0000000..1292452
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -0,0 +1,323 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate}
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalKeyedState}
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
+import org.apache.spark.sql.execution
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.execution.streaming.state._
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.CompletionIterator
+
+
+/** Used to identify the state store for a given operator. */
+case class OperatorStateId(
+ checkpointLocation: String,
+ operatorId: Long,
+ batchId: Long)
+
+/**
+ * An operator that reads or writes state from the [[StateStore]]. The [[OperatorStateId]] should
+ * be filled in by `prepareForExecution` in [[IncrementalExecution]].
+ */
+trait StatefulOperator extends SparkPlan {
+ def stateId: Option[OperatorStateId]
+
+ protected def getStateId: OperatorStateId = attachTree(this) {
+ stateId.getOrElse {
+ throw new IllegalStateException("State location not present for execution")
+ }
+ }
+}
+
+/** An operator that reads from a StateStore. */
+trait StateStoreReader extends StatefulOperator {
+ override lazy val metrics = Map(
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+}
+
+/** An operator that writes to a StateStore. */
+trait StateStoreWriter extends StatefulOperator {
+ override lazy val metrics = Map(
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+ "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"),
+ "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows"))
+}
+
+/**
+ * For each input tuple, the key is calculated and the value from the [[StateStore]] is added
+ * to the stream (in addition to the input tuple) if present.
+ */
+case class StateStoreRestoreExec(
+ keyExpressions: Seq[Attribute],
+ stateId: Option[OperatorStateId],
+ child: SparkPlan)
+ extends execution.UnaryExecNode with StateStoreReader {
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ val numOutputRows = longMetric("numOutputRows")
+
+ child.execute().mapPartitionsWithStateStore(
+ getStateId.checkpointLocation,
+ operatorId = getStateId.operatorId,
+ storeVersion = getStateId.batchId,
+ keyExpressions.toStructType,
+ child.output.toStructType,
+ sqlContext.sessionState,
+ Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
+ val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
+ iter.flatMap { row =>
+ val key = getKey(row)
+ val savedState = store.get(key)
+ numOutputRows += 1
+ row +: savedState.toSeq
+ }
+ }
+ }
+
+ override def output: Seq[Attribute] = child.output
+
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+}
+
+/**
+ * For each input tuple, the key is calculated and the tuple is `put` into the [[StateStore]].
+ */
+case class StateStoreSaveExec(
+ keyExpressions: Seq[Attribute],
+ stateId: Option[OperatorStateId] = None,
+ outputMode: Option[OutputMode] = None,
+ eventTimeWatermark: Option[Long] = None,
+ child: SparkPlan)
+ extends execution.UnaryExecNode with StateStoreWriter {
+
+ /** Generate a predicate that matches data older than the watermark */
+ private lazy val watermarkPredicate: Option[Predicate] = {
+ val optionalWatermarkAttribute =
+ keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey))
+
+ optionalWatermarkAttribute.map { watermarkAttribute =>
+ // If we are evicting based on a window, use the end of the window. Otherwise just
+ // use the attribute itself.
+ val evictionExpression =
+ if (watermarkAttribute.dataType.isInstanceOf[StructType]) {
+ LessThanOrEqual(
+ GetStructField(watermarkAttribute, 1),
+ Literal(eventTimeWatermark.get * 1000))
+ } else {
+ LessThanOrEqual(
+ watermarkAttribute,
+ Literal(eventTimeWatermark.get * 1000))
+ }
+
+ logInfo(s"Filtering state store on: $evictionExpression")
+ newPredicate(evictionExpression, keyExpressions)
+ }
+ }
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ metrics // force lazy init at driver
+ assert(outputMode.nonEmpty,
+ "Incorrect planning in IncrementalExecution, outputMode has not been set")
+
+ child.execute().mapPartitionsWithStateStore(
+ getStateId.checkpointLocation,
+ operatorId = getStateId.operatorId,
+ storeVersion = getStateId.batchId,
+ keyExpressions.toStructType,
+ child.output.toStructType,
+ sqlContext.sessionState,
+ Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
+ val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
+ val numOutputRows = longMetric("numOutputRows")
+ val numTotalStateRows = longMetric("numTotalStateRows")
+ val numUpdatedStateRows = longMetric("numUpdatedStateRows")
+
+ outputMode match {
+ // Update and output all rows in the StateStore.
+ case Some(Complete) =>
+ while (iter.hasNext) {
+ val row = iter.next().asInstanceOf[UnsafeRow]
+ val key = getKey(row)
+ store.put(key.copy(), row.copy())
+ numUpdatedStateRows += 1
+ }
+ store.commit()
+ numTotalStateRows += store.numKeys()
+ store.iterator().map { case (k, v) =>
+ numOutputRows += 1
+ v.asInstanceOf[InternalRow]
+ }
+
+ // Update and output only rows being evicted from the StateStore
+ case Some(Append) =>
+ while (iter.hasNext) {
+ val row = iter.next().asInstanceOf[UnsafeRow]
+ val key = getKey(row)
+ store.put(key.copy(), row.copy())
+ numUpdatedStateRows += 1
+ }
+
+ // Assumption: Append mode can be done only when watermark has been specified
+ store.remove(watermarkPredicate.get.eval _)
+ store.commit()
+
+ numTotalStateRows += store.numKeys()
+ store.updates().filter(_.isInstanceOf[ValueRemoved]).map { removed =>
+ numOutputRows += 1
+ removed.value.asInstanceOf[InternalRow]
+ }
+
+ // Update and output modified rows from the StateStore.
+ case Some(Update) =>
+
+ new Iterator[InternalRow] {
+
+ // Filter late data using watermark if specified
+ private[this] val baseIterator = watermarkPredicate match {
+ case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row))
+ case None => iter
+ }
+
+ override def hasNext: Boolean = {
+ if (!baseIterator.hasNext) {
+ // Remove old aggregates if watermark specified
+ if (watermarkPredicate.nonEmpty) store.remove(watermarkPredicate.get.eval _)
+ store.commit()
+ numTotalStateRows += store.numKeys()
+ false
+ } else {
+ true
+ }
+ }
+
+ override def next(): InternalRow = {
+ val row = baseIterator.next().asInstanceOf[UnsafeRow]
+ val key = getKey(row)
+ store.put(key.copy(), row.copy())
+ numOutputRows += 1
+ numUpdatedStateRows += 1
+ row
+ }
+ }
+
+ case _ => throw new UnsupportedOperationException(s"Invalid output mode: $outputMode")
+ }
+ }
+ }
+
+ override def output: Seq[Attribute] = child.output
+
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+}
+
+
+/** Physical operator for executing streaming mapGroupsWithState. */
+case class MapGroupsWithStateExec(
+ func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
+ keyDeserializer: Expression,
+ valueDeserializer: Expression,
+ groupingAttributes: Seq[Attribute],
+ dataAttributes: Seq[Attribute],
+ outputObjAttr: Attribute,
+ stateId: Option[OperatorStateId],
+ stateDeserializer: Expression,
+ stateSerializer: Seq[NamedExpression],
+ child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter {
+
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+
+ /** Distribute by grouping attributes */
+ override def requiredChildDistribution: Seq[Distribution] =
+ ClusteredDistribution(groupingAttributes) :: Nil
+
+ /** Ordering needed for using GroupedIterator */
+ override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+ Seq(groupingAttributes.map(SortOrder(_, Ascending)))
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitionsWithStateStore[InternalRow](
+ getStateId.checkpointLocation,
+ operatorId = getStateId.operatorId,
+ storeVersion = getStateId.batchId,
+ groupingAttributes.toStructType,
+ child.output.toStructType,
+ sqlContext.sessionState,
+ Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
+ val numTotalStateRows = longMetric("numTotalStateRows")
+ val numUpdatedStateRows = longMetric("numUpdatedStateRows")
+ val numOutputRows = longMetric("numOutputRows")
+
+ // Generate an iterator that returns the rows grouped by the grouping attributes
+ val groupedIter = GroupedIterator(iter, groupingAttributes, child.output)
+
+ // Converters to and from object and rows
+ val getKeyObj = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
+ val getValueObj = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
+ val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
+ val getStateObj =
+ ObjectOperator.deserializeRowToObject(stateDeserializer)
+ val outputStateObj = ObjectOperator.serializeObjectToRow(stateSerializer)
+
+ // For every group, get the key, values and corresponding state and call the function,
+ // and return an iterator of rows
+ val allRowsIterator = groupedIter.flatMap { case (keyRow, valueRowIter) =>
+
+ val key = keyRow.asInstanceOf[UnsafeRow]
+ val keyObj = getKeyObj(keyRow) // convert key to objects
+ val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects
+ val stateObjOption = store.get(key).map(getStateObj) // get existing state if any
+ val wrappedState = new KeyedStateImpl(stateObjOption)
+ val mappedIterator = func(keyObj, valueObjIter, wrappedState).map { obj =>
+ numOutputRows += 1
+ getOutputRow(obj) // convert back to rows
+ }
+
+ // Return an iterator of rows generated for this key,
+ // such that when fully consumed, the updated state value will be saved
+ CompletionIterator[InternalRow, Iterator[InternalRow]](
+ mappedIterator, {
+ // On completion, write state changes; copy rows since their buffers may be reused
+ if (wrappedState.isRemoved) {
+ store.remove(key.copy())
+ numUpdatedStateRows += 1
+ } else if (wrappedState.isUpdated) {
+ store.put(key.copy(), outputStateObj(wrappedState.get).copy())
+ numUpdatedStateRows += 1
+ }
+ })
+ }
+
+ // Return an iterator of all the rows generated by all the keys, such that when fully
+ // consumed, all the state updates will be committed by the state store
+ CompletionIterator[InternalRow, Iterator[InternalRow]](allRowsIterator, {
+ store.commit()
+ numTotalStateRows += store.numKeys()
+ })
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 8304b72..5ef4e88 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -225,6 +225,38 @@ public class JavaDatasetSuite implements Serializable {
Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped.collectAsList()));
+ Dataset<String> mapped2 = grouped.mapGroupsWithState(
+ new MapGroupsWithStateFunction<Integer, String, Long, String>() {
+ @Override
+ public String call(Integer key, Iterator<String> values, KeyedState<Long> s) throws Exception {
+ StringBuilder sb = new StringBuilder(key.toString());
+ while (values.hasNext()) {
+ sb.append(values.next());
+ }
+ return sb.toString();
+ }
+ },
+ Encoders.LONG(),
+ Encoders.STRING());
+
+ Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped2.collectAsList()));
+
+ Dataset<String> flatMapped2 = grouped.flatMapGroupsWithState(
+ new FlatMapGroupsWithStateFunction<Integer, String, Long, String>() {
+ @Override
+ public Iterator<String> call(Integer key, Iterator<String> values, KeyedState<Long> s) {
+ StringBuilder sb = new StringBuilder(key.toString());
+ while (values.hasNext()) {
+ sb.append(values.next());
+ }
+ return Collections.singletonList(sb.toString()).iterator();
+ }
+ },
+ Encoders.LONG(),
+ Encoders.STRING());
+
+ Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped2.collectAsList()));
+
Dataset<Tuple2<Integer, String>> reduced = grouped.reduceGroups(new ReduceFunction<String>() {
@Override
public String call(String v1, String v2) throws Exception {
http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala
new file mode 100644
index 0000000..0524898
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala
@@ -0,0 +1,335 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.KeyedState
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
+import org.apache.spark.sql.execution.streaming.{KeyedStateImpl, MemoryStream}
+import org.apache.spark.sql.execution.streaming.state.StateStore
+
+/** Class to check custom state types */
+case class RunningCount(count: Long)
+
+class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll {
+
+ import testImplicits._
+
+ override def afterAll(): Unit = {
+ super.afterAll()
+ StateStore.stop()
+ }
+
+ test("KeyedState - get, exists, update, remove") {
+ var state: KeyedStateImpl[String] = null
+
+ def testState(
+ expectedData: Option[String],
+ shouldBeUpdated: Boolean = false,
+ shouldBeRemoved: Boolean = false): Unit = {
+ if (expectedData.isDefined) {
+ assert(state.exists)
+ assert(state.get === expectedData.get)
+ } else {
+ assert(!state.exists)
+ intercept[NoSuchElementException] {
+ state.get
+ }
+ }
+ assert(state.getOption === expectedData)
+ assert(state.isUpdated === shouldBeUpdated)
+ assert(state.isRemoved === shouldBeRemoved)
+ }
+
+ // Updating empty state
+ state = new KeyedStateImpl[String](None)
+ testState(None)
+ state.update("")
+ testState(Some(""), shouldBeUpdated = true)
+
+ // Updating existing state
+ state = new KeyedStateImpl[String](Some("2"))
+ testState(Some("2"))
+ state.update("3")
+ testState(Some("3"), shouldBeUpdated = true)
+
+ // Removing state
+ state.remove()
+ testState(None, shouldBeRemoved = true, shouldBeUpdated = false)
+ state.remove() // should be still callable
+ state.update("4")
+ testState(Some("4"), shouldBeRemoved = false, shouldBeUpdated = true)
+
+ // Updating with null throws an exception
+ intercept[IllegalArgumentException] {
+ state.update(null)
+ }
+ }
+
+ test("KeyedState - primitive type") {
+ var intState = new KeyedStateImpl[Int](None)
+ intercept[NoSuchElementException] {
+ intState.get
+ }
+ assert(intState.getOption === None)
+
+ intState = new KeyedStateImpl[Int](Some(10))
+ assert(intState.get == 10)
+ intState.update(0)
+ assert(intState.get == 0)
+ intState.remove()
+ intercept[NoSuchElementException] {
+ intState.get
+ }
+ }
+
+ test("flatMapGroupsWithState - streaming") {
+ // Function to maintain running count up to 2, and then remove the count
+ // Returns the data and the count if state is defined, otherwise does not return anything
+ val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+
+ val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+ if (count == 3) {
+ state.remove()
+ Iterator.empty
+ } else {
+ state.update(RunningCount(count))
+ Iterator((key, count.toString))
+ }
+ }
+
+ val inputData = MemoryStream[String]
+ val result =
+ inputData.toDS()
+ .groupByKey(x => x)
+ .flatMapGroupsWithState(stateFunc) // State: RunningCount, Out: (Str, Str)
+
+ testStream(result, Append)(
+ AddData(inputData, "a"),
+ CheckLastBatch(("a", "1")),
+ assertNumStateRows(total = 1, updated = 1),
+ AddData(inputData, "a", "b"),
+ CheckLastBatch(("a", "2"), ("b", "1")),
+ assertNumStateRows(total = 2, updated = 2),
+ StopStream,
+ StartStream(),
+ AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a
+ CheckLastBatch(("b", "2")),
+ assertNumStateRows(total = 1, updated = 2),
+ StopStream,
+ StartStream(),
+ AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1
+ CheckLastBatch(("a", "1"), ("c", "1")),
+ assertNumStateRows(total = 3, updated = 2)
+ )
+ }
+
+ test("flatMapGroupsWithState - streaming + func returns iterator that updates state lazily") {
+ // Function to maintain running count up to 2, and then remove the count
+ // Returns the data and the count if state is defined, otherwise does not return anything
+ // Additionally, it updates state lazily as the returned iterator gets consumed
+ val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+ values.flatMap { _ =>
+ val count = state.getOption.map(_.count).getOrElse(0L) + 1
+ if (count == 3) {
+ state.remove()
+ None
+ } else {
+ state.update(RunningCount(count))
+ Some((key, count.toString))
+ }
+ }
+ }
+
+ val inputData = MemoryStream[String]
+ val result =
+ inputData.toDS()
+ .groupByKey(x => x)
+ .flatMapGroupsWithState(stateFunc) // State: RunningCount, Out: (Str, Str)
+
+ testStream(result, Append)(
+ AddData(inputData, "a", "a", "b"),
+ CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")),
+ StopStream,
+ StartStream(),
+ AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a
+ CheckLastBatch(("b", "2")),
+ StopStream,
+ StartStream(),
+ AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1
+ CheckLastBatch(("a", "1"), ("c", "1"))
+ )
+ }
+
+ test("flatMapGroupsWithState - batch") {
+ // Function that returns the number of values for each key, and fails if any state exists
+ val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+ if (state.exists) throw new IllegalArgumentException("state.exists should be false")
+ Iterator((key, values.size))
+ }
+ checkAnswer(
+ Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc).toDF,
+ Seq(("a", 2), ("b", 1)).toDF)
+ }
+
+ test("mapGroupsWithState - streaming") {
+ // Function to maintain running count up to 2, and then remove the count
+ // Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
+ val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+
+ val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+ if (count == 3) {
+ state.remove()
+ (key, "-1")
+ } else {
+ state.update(RunningCount(count))
+ (key, count.toString)
+ }
+ }
+
+ val inputData = MemoryStream[String]
+ val result =
+ inputData.toDS()
+ .groupByKey(x => x)
+ .mapGroupsWithState(stateFunc) // Types = State: RunningCount, Out: (Str, Str)
+
+ testStream(result, Append)(
+ AddData(inputData, "a"),
+ CheckLastBatch(("a", "1")),
+ assertNumStateRows(total = 1, updated = 1),
+ AddData(inputData, "a", "b"),
+ CheckLastBatch(("a", "2"), ("b", "1")),
+ assertNumStateRows(total = 2, updated = 2),
+ StopStream,
+ StartStream(),
+ AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1
+ CheckLastBatch(("a", "-1"), ("b", "2")),
+ assertNumStateRows(total = 1, updated = 2),
+ StopStream,
+ StartStream(),
+ AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1
+ CheckLastBatch(("a", "1"), ("c", "1")),
+ assertNumStateRows(total = 3, updated = 2)
+ )
+ }
+
+ test("mapGroupsWithState - streaming + aggregation") {
+ // Function to maintain running count up to 2, and then remove the count
+ // Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
+ val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+
+ val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+ if (count == 3) {
+ state.remove()
+ (key, "-1")
+ } else {
+ state.update(RunningCount(count))
+ (key, count.toString)
+ }
+ }
+
+ val inputData = MemoryStream[String]
+ val result =
+ inputData.toDS()
+ .groupByKey(x => x)
+ .mapGroupsWithState(stateFunc) // Types = State: RunningCount, Out: (Str, Str)
+ .groupByKey(_._1)
+ .count()
+
+ testStream(result, Complete)(
+ AddData(inputData, "a"),
+ CheckLastBatch(("a", 1)),
+ AddData(inputData, "a", "b"),
+ // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1
+ CheckLastBatch(("a", 2), ("b", 1)),
+ StopStream,
+ StartStream(),
+ AddData(inputData, "a", "b"),
+ // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ;
+ // so increment a and b by 1
+ CheckLastBatch(("a", 3), ("b", 2)),
+ StopStream,
+ StartStream(),
+ AddData(inputData, "a", "c"),
+ // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ;
+ // so increment a and c by 1
+ CheckLastBatch(("a", 4), ("b", 2), ("c", 1))
+ )
+ }
+
+ test("mapGroupsWithState - batch") {
+ val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+ if (state.exists) throw new IllegalArgumentException("state.exists should be false")
+ (key, values.size)
+ }
+
+ checkAnswer(
+ spark.createDataset(Seq("a", "a", "b"))
+ .groupByKey(x => x)
+ .mapGroupsWithState(stateFunc)
+ .toDF,
+ spark.createDataset(Seq(("a", 2), ("b", 1))).toDF)
+ }
+
+ testQuietly("StateStore.abort on task failure handling") {
+ val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+ if (MapGroupsWithStateSuite.failInTask) throw new Exception("expected failure")
+ val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+ state.update(RunningCount(count))
+ (key, count)
+ }
+
+ val inputData = MemoryStream[String]
+ val result =
+ inputData.toDS()
+ .groupByKey(x => x)
+ .mapGroupsWithState(stateFunc) // Types = State: RunningCount, Out: (Str, Long)
+
+ def setFailInTask(value: Boolean): AssertOnQuery = AssertOnQuery { q =>
+ MapGroupsWithStateSuite.failInTask = value
+ true
+ }
+
+ testStream(result, Append)(
+ setFailInTask(false),
+ AddData(inputData, "a"),
+ CheckLastBatch(("a", 1L)),
+ AddData(inputData, "a"),
+ CheckLastBatch(("a", 2L)),
+ setFailInTask(true),
+ AddData(inputData, "a"),
+ ExpectFailure[SparkException](), // task should fail but should not increment count
+ setFailInTask(false),
+ StartStream(),
+ CheckLastBatch(("a", 3L)) // task should not fail, and should show correct count
+ )
+ }
+
+ private def assertNumStateRows(total: Long, updated: Long): AssertOnQuery = AssertOnQuery { q =>
+ val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get
+ assert(progressWithData.stateOperators(0).numRowsTotal === total, "incorrect total rows")
+ assert(progressWithData.stateOperators(0).numRowsUpdated === updated, "incorrect updates rows")
+ true
+ }
+}
+
+object MapGroupsWithStateSuite {
+ var failInTask = true
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org