You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by zs...@apache.org on 2017/02/08 19:34:04 UTC

spark git commit: [SPARK-19413][SS] MapGroupsWithState for arbitrary stateful operations for branch-2.1

Repository: spark
Updated Branches:
  refs/heads/branch-2.1 71b6eacf7 -> 502c927b8


[SPARK-19413][SS] MapGroupsWithState for arbitrary stateful operations for branch-2.1

This is a follow-up PR for merging #16758 into the Spark 2.1 branch.

## What changes were proposed in this pull request?

`mapGroupsWithState` is a new API for arbitrary stateful operations in Structured Streaming, similar to `DStream.mapWithState`

*Requirements*
- Users should be able to specify a function that can do the following:
  - Access the input rows corresponding to a key
  - Access the previous state corresponding to a key
  - Optionally, update or remove the state
  - Output any number of new rows (or none at all)

*Proposed API*
```
// ------------ New methods on KeyValueGroupedDataset ------------
class KeyValueGroupedDataset[K, V] {
	// Scala friendly
	def mapGroupsWithState[S: Encoder, U: Encoder](func: (K, Iterator[V], KeyedState[S]) => U)
        def flatMapGroupsWithState[S: Encoder, U: Encoder](func: (K, Iterator[V], KeyedState[S]) => Iterator[U])
	// Java friendly
       def mapGroupsWithState[S, U](func: MapGroupsWithStateFunction[K, V, S, U], stateEncoder: Encoder[S], resultEncoder: Encoder[U])
       def flatMapGroupsWithState[S, U](func: FlatMapGroupsWithStateFunction[K, V, S, U], stateEncoder: Encoder[S], resultEncoder: Encoder[U])
}

// ------------------- New Java-friendly function classes -------------------
public interface MapGroupsWithStateFunction<K, V, S, R> extends Serializable {
  R call(K key, Iterator<V> values, KeyedState<S> state) throws Exception;
}
public interface FlatMapGroupsWithStateFunction<K, V, S, R> extends Serializable {
  Iterator<R> call(K key, Iterator<V> values, KeyedState<S> state) throws Exception;
}

// ---------------------- Wrapper class for state data ----------------------
trait KeyedState[S] {
	def exists(): Boolean
  	def get(): S 			// throws NoSuchElementException if state does not exist
	def getOption(): Option[S]
	def update(newState: S): Unit
	def remove(): Unit		// exists() will be false after this
}
```

Key Semantics of the State class
- The state cannot be null; calling state.update(null) throws IllegalArgumentException.
- If state.remove() is called, then state.exists() will return false, and getOption will return None.
- After that, if state.update(newState) is called, then state.exists() will return true, and getOption will return Some(...).
- None of the operations are thread-safe. This is to avoid memory barriers.

*Usage*
```
val stateFunc = (word: String, words: Iterator[String], runningCount: KeyedState[Long]) => {
    val newCount = words.size + runningCount.getOption.getOrElse(0L)
    runningCount.update(newCount)
   (word, newCount)
}

dataset					                        // type is Dataset[String]
  .groupByKey[String](w => w)        	                // generates KeyValueGroupedDataset[String, String]
  .mapGroupsWithState[Long, (String, Long)](stateFunc)	// returns Dataset[(String, Long)]
```

## How was this patch tested?
New unit tests.

Author: Tathagata Das <ta...@gmail.com>

Closes #16850 from tdas/mapWithState-branch-2.1.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/502c927b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/502c927b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/502c927b

Branch: refs/heads/branch-2.1
Commit: 502c927b8c8a99ef2adf4e6e1d7a6d9232d45ef5
Parents: 71b6eac
Author: Tathagata Das <ta...@gmail.com>
Authored: Wed Feb 8 11:33:59 2017 -0800
Committer: Shixiong Zhu <sh...@databricks.com>
Committed: Wed Feb 8 11:33:59 2017 -0800

----------------------------------------------------------------------
 .../analysis/UnsupportedOperationChecker.scala  |  11 +-
 .../sql/catalyst/plans/logical/object.scala     |  49 +++
 .../analysis/UnsupportedOperationsSuite.scala   |  24 +-
 .../FlatMapGroupsWithStateFunction.java         |  38 +++
 .../function/MapGroupsWithStateFunction.java    |  38 +++
 .../spark/sql/KeyValueGroupedDataset.scala      | 113 +++++++
 .../scala/org/apache/spark/sql/KeyedState.scala | 142 ++++++++
 .../spark/sql/execution/SparkStrategies.scala   |  21 +-
 .../apache/spark/sql/execution/objects.scala    |  22 ++
 .../streaming/IncrementalExecution.scala        |  19 +-
 .../execution/streaming/KeyedStateImpl.scala    |  80 +++++
 .../execution/streaming/ProgressReporter.scala  |   2 +-
 .../execution/streaming/StatefulAggregate.scala | 237 -------------
 .../state/HDFSBackedStateStoreProvider.scala    |  19 ++
 .../execution/streaming/state/StateStore.scala  |   5 +
 .../sql/execution/streaming/state/package.scala |  11 +-
 .../execution/streaming/statefulOperators.scala | 323 ++++++++++++++++++
 .../org/apache/spark/sql/JavaDatasetSuite.java  |  32 ++
 .../sql/streaming/MapGroupsWithStateSuite.scala | 335 +++++++++++++++++++
 19 files changed, 1272 insertions(+), 249 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index f4d016c..d8aad42 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -46,8 +46,13 @@ object UnsupportedOperationChecker {
         "Queries without streaming sources cannot be executed with writeStream.start()")(plan)
     }
 
+    /** Collect all the streaming aggregates in a sub plan */
+    def collectStreamingAggregates(subplan: LogicalPlan): Seq[Aggregate] = {
+      subplan.collect { case a: Aggregate if a.isStreaming => a }
+    }
+
     // Disallow multiple streaming aggregations
-    val aggregates = plan.collect { case a@Aggregate(_, _, _) if a.isStreaming => a }
+    val aggregates = collectStreamingAggregates(plan)
 
     if (aggregates.size > 1) {
       throwError(
@@ -114,6 +119,10 @@ object UnsupportedOperationChecker {
         case _: InsertIntoTable =>
           throwError("InsertIntoTable is not supported with streaming DataFrames/Datasets")
 
+        case m: MapGroupsWithState if collectStreamingAggregates(m).nonEmpty =>
+          throwError("(map/flatMap)GroupsWithState is not supported after aggregation on a " +
+            "streaming DataFrame/Dataset")
+
         case Join(left, right, joinType, _) =>
 
           joinType match {

http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 0ab4c90..0be4823 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -313,6 +313,55 @@ case class MapGroups(
     outputObjAttr: Attribute,
     child: LogicalPlan) extends UnaryNode with ObjectProducer
 
+/** Internal class representing State */
+trait LogicalKeyedState[S]
+
+/** Factory for constructing new `MapGroupsWithState` nodes. */
+object MapGroupsWithState {
+  def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder](
+      func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
+      groupingAttributes: Seq[Attribute],
+      dataAttributes: Seq[Attribute],
+      child: LogicalPlan): LogicalPlan = {
+    val mapped = new MapGroupsWithState(
+      func,
+      UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
+      UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes),
+      groupingAttributes,
+      dataAttributes,
+      CatalystSerde.generateObjAttr[U],
+      encoderFor[S].resolveAndBind().deserializer,
+      encoderFor[S].namedExpressions,
+      child)
+    CatalystSerde.serialize[U](mapped)
+  }
+}
+
+/**
+ * Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`,
+ * while using state data.
+ * Func is invoked with an object representation of the grouping key an iterator containing the
+ * object representation of all the rows with that key.
+ *
+ * @param keyDeserializer used to extract the key object for each group.
+ * @param valueDeserializer used to extract the items in the iterator from an input row.
+ * @param groupingAttributes used to group the data
+ * @param dataAttributes used to read the data
+ * @param outputObjAttr used to define the output object
+ * @param stateDeserializer used to deserialize state before calling `func`
+ * @param stateSerializer used to serialize updated state after calling `func`
+ */
+case class MapGroupsWithState(
+    func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
+    keyDeserializer: Expression,
+    valueDeserializer: Expression,
+    groupingAttributes: Seq[Attribute],
+    dataAttributes: Seq[Attribute],
+    outputObjAttr: Attribute,
+    stateDeserializer: Expression,
+    stateSerializer: Seq[NamedExpression],
+    child: LogicalPlan) extends UnaryNode with ObjectProducer
+
 /** Factory for constructing new `FlatMapGroupsInR` nodes. */
 object FlatMapGroupsInR {
   def apply(

http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
index dcdb1ae..3b756e8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
@@ -22,13 +22,13 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal, NamedExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.Count
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.logical.{MapGroupsWithState, _}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.streaming.OutputMode
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{IntegerType, LongType}
 
 /** A dummy command for testing unsupported operations. */
 case class DummyCommand() extends Command
@@ -111,6 +111,24 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
     outputMode = Complete,
     expectedMsgs = Seq("distinct aggregation"))
 
+  // MapGroupsWithState: Not supported after a streaming aggregation
+  val att = new AttributeReference(name = "a", dataType = LongType)()
+  assertSupportedInBatchPlan(
+    "mapGroupsWithState - mapGroupsWithState on batch relation",
+    MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), batchRelation))
+
+  assertSupportedInStreamingPlan(
+    "mapGroupsWithState - mapGroupsWithState on streaming relation before aggregation",
+    MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), streamRelation),
+    outputMode = Append)
+
+  assertNotSupportedInStreamingPlan(
+    "mapGroupsWithState - mapGroupsWithState on streaming relation after aggregation",
+    MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att),
+      Aggregate(Nil, aggExprs("c"), streamRelation)),
+    outputMode = Complete,
+    expectedMsgs = Seq("(map/flatMap)GroupsWithState"))
+
   // Inner joins: Stream-stream not supported
   testBinaryOperationInStreamingPlan(
     "inner join",

http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java
new file mode 100644
index 0000000..2570c8d
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+import org.apache.spark.annotation.Experimental;
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.KeyedState;
+
+/**
+ * ::Experimental::
+ * Base interface for a map function used in
+ * {@link org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroupsWithState(FlatMapGroupsWithStateFunction, Encoder, Encoder)}.
+ * @since 2.1.1
+ */
+@Experimental
+@InterfaceStability.Evolving
+public interface FlatMapGroupsWithStateFunction<K, V, S, R> extends Serializable {
+  Iterator<R> call(K key, Iterator<V> values, KeyedState<S> state) throws Exception;
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java
new file mode 100644
index 0000000..614d392
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+import org.apache.spark.annotation.Experimental;
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.KeyedState;
+
+/**
+ * ::Experimental::
+ * Base interface for a map function used in
+ * {@link org.apache.spark.sql.KeyValueGroupedDataset#mapGroupsWithState(MapGroupsWithStateFunction, Encoder, Encoder)}
+ * @since 2.1.1
+ */
+@Experimental
+@InterfaceStability.Evolving
+public interface MapGroupsWithStateFunction<K, V, S, R> extends Serializable {
+  R call(K key, Iterator<V> values, KeyedState<S> state) throws Exception;
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 395d709..94e689a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -219,6 +219,119 @@ class KeyValueGroupedDataset[K, V] private[sql](
   }
 
   /**
+   * ::Experimental::
+   * (Scala-specific)
+   * Applies the given function to each group of data, while maintaining a user-defined per-group
+   * state. The result Dataset will represent the objects returned by the function.
+   * For a static batch Dataset, the function will be invoked once per group. For a streaming
+   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+   * updates to each group's state will be saved across invocations.
+   * See [[KeyedState]] for more details.
+   *
+   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   *
+   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 2.1.1
+   */
+  @Experimental
+  @InterfaceStability.Evolving
+  def mapGroupsWithState[S: Encoder, U: Encoder](
+      func: (K, Iterator[V], KeyedState[S]) => U): Dataset[U] = {
+    flatMapGroupsWithState[S, U](
+      (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s)))
+  }
+
+  /**
+   * ::Experimental::
+   * (Java-specific)
+   * Applies the given function to each group of data, while maintaining a user-defined per-group
+   * state. The result Dataset will represent the objects returned by the function.
+   * For a static batch Dataset, the function will be invoked once per group. For a streaming
+   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+   * updates to each group's state will be saved across invocations.
+   * See [[KeyedState]] for more details.
+   *
+   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func          Function to be called on every group.
+   * @param stateEncoder  Encoder for the state type.
+   * @param outputEncoder Encoder for the output type.
+   *
+   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 2.1.1
+   */
+  @Experimental
+  @InterfaceStability.Evolving
+  def mapGroupsWithState[S, U](
+      func: MapGroupsWithStateFunction[K, V, S, U],
+      stateEncoder: Encoder[S],
+      outputEncoder: Encoder[U]): Dataset[U] = {
+    flatMapGroupsWithState[S, U](
+      (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func.call(key, it.asJava, s))
+    )(stateEncoder, outputEncoder)
+  }
+
+  /**
+   * ::Experimental::
+   * (Scala-specific)
+   * Applies the given function to each group of data, while maintaining a user-defined per-group
+   * state. The result Dataset will represent the objects returned by the function.
+   * For a static batch Dataset, the function will be invoked once per group. For a streaming
+   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+   * updates to each group's state will be saved across invocations.
+   * See [[KeyedState]] for more details.
+   *
+   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   *
+   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 2.1.1
+   */
+  @Experimental
+  @InterfaceStability.Evolving
+  def flatMapGroupsWithState[S: Encoder, U: Encoder](
+      func: (K, Iterator[V], KeyedState[S]) => Iterator[U]): Dataset[U] = {
+    Dataset[U](
+      sparkSession,
+      MapGroupsWithState[K, V, S, U](
+        func.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]],
+        groupingAttributes,
+        dataAttributes,
+        logicalPlan))
+  }
+
+  /**
+   * ::Experimental::
+   * (Java-specific)
+   * Applies the given function to each group of data, while maintaining a user-defined per-group
+   * state. The result Dataset will represent the objects returned by the function.
+   * For a static batch Dataset, the function will be invoked once per group. For a streaming
+   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+   * updates to each group's state will be saved across invocations.
+   * See [[KeyedState]] for more details.
+   *
+   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func          Function to be called on every group.
+   * @param stateEncoder  Encoder for the state type.
+   * @param outputEncoder Encoder for the output type.
+   *
+   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 2.1.1
+   */
+  @Experimental
+  @InterfaceStability.Evolving
+  def flatMapGroupsWithState[S, U](
+      func: FlatMapGroupsWithStateFunction[K, V, S, U],
+      stateEncoder: Encoder[S],
+      outputEncoder: Encoder[U]): Dataset[U] = {
+    flatMapGroupsWithState[S, U](
+      (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala
+    )(stateEncoder, outputEncoder)
+  }
+
+  /**
    * (Scala-specific)
    * Reduces the elements of each group of data using the specified binary function.
    * The given function must be commutative and associative or the result may be non-deterministic.

http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala
new file mode 100644
index 0000000..6864b6f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.lang.IllegalArgumentException
+
+import org.apache.spark.annotation.{Experimental, InterfaceStability}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
+
+/**
+ * :: Experimental ::
+ *
+ * Wrapper class for interacting with keyed state data in `mapGroupsWithState` and
+ * `flatMapGroupsWithState` operations on
+ * [[KeyValueGroupedDataset]].
+ *
+ * Detail description on `[map/flatMap]GroupsWithState` operation
+ * ------------------------------------------------------------
+ * Both, `mapGroupsWithState` and `flatMapGroupsWithState` in [[KeyValueGroupedDataset]]
+ * will invoke the user-given function on each group (defined by the grouping function in
+ * `Dataset.groupByKey()`) while maintaining user-defined per-group state between invocations.
+ * For a static batch Dataset, the function will be invoked once per group. For a streaming
+ * Dataset, the function will be invoked for each group repeatedly in every trigger.
+ * That is, in every batch of the [[streaming.StreamingQuery StreamingQuery]],
+ * the function will be invoked once for each group that has data in the batch.
+ *
+ * The function is invoked with following parameters.
+ *  - The key of the group.
+ *  - An iterator containing all the values for this key.
+ *  - A user-defined state object set by previous invocations of the given function.
+ * In case of a batch Dataset, there is only one invocation and state object will be empty as
+ * there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState`
+ * is equivalent to `[map/flatMap]Groups`.
+ *
+ * Important points to note about the function.
+ *  - In a trigger, the function will be called only the groups present in the batch. So do not
+ *    assume that the function will be called in every trigger for every group that has state.
+ *  - There is no guaranteed ordering of values in the iterator in the function, neither with
+ *    batch, nor with streaming Datasets.
+ *  - All the data will be shuffled before applying the function.
+ *
+ * Important points to note about using KeyedState.
+ *  - The value of the state cannot be null. So updating state with null will throw
+ *    `IllegalArgumentException`.
+ *  - Operations on `KeyedState` are not thread-safe. This is to avoid memory barriers.
+ *  - If `remove()` is called, then `exists()` will return `false`,
+ *    `get()` will throw `NoSuchElementException` and `getOption()` will return `None`
+ *  - After that, if `update(newState)` is called, then `exists()` will again return `true`,
+ *    `get()` and `getOption()`will return the updated value.
+ *
+ * Scala example of using KeyedState in `mapGroupsWithState`:
+ * {{{
+ * /* A mapping function that maintains an integer state for string keys and returns a string. */
+ * def mappingFunction(key: String, value: Iterator[Int], state: KeyedState[Int]): String = {
+ *   // Check if state exists
+ *   if (state.exists) {
+ *     val existingState = state.get  // Get the existing state
+ *     val shouldRemove = ...         // Decide whether to remove the state
+ *     if (shouldRemove) {
+ *       state.remove()     // Remove the state
+ *     } else {
+ *       val newState = ...
+ *       state.update(newState)    // Set the new state
+ *     }
+ *   } else {
+ *     val initialState = ...
+ *     state.update(initialState)  // Set the initial state
+ *   }
+ *   ... // return something
+ * }
+ *
+ * }}}
+ *
+ * Java example of using `KeyedState`:
+ * {{{
+ * /* A mapping function that maintains an integer state for string keys and returns a string. */
+ * MapGroupsWithStateFunction<String, Integer, Integer, String> mappingFunction =
+ *    new MapGroupsWithStateFunction<String, Integer, Integer, String>() {
+ *
+ *      @Override
+ *      public String call(String key, Iterator<Integer> value, KeyedState<Integer> state) {
+ *        if (state.exists()) {
+ *          int existingState = state.get(); // Get the existing state
+ *          boolean shouldRemove = ...; // Decide whether to remove the state
+ *          if (shouldRemove) {
+ *            state.remove(); // Remove the state
+ *          } else {
+ *            int newState = ...;
+ *            state.update(newState); // Set the new state
+ *          }
+ *        } else {
+ *          int initialState = ...; // Set the initial state
+ *          state.update(initialState);
+ *        }
+ *        ... // return something
+ *      }
+ *    };
+ * }}}
+ *
+ * @tparam S User-defined type of the state to be stored for each key. Must be encodable into
+ *           Spark SQL types (see [[Encoder]] for more details).
+ * @since 2.1.1
+ */
+@Experimental
+@InterfaceStability.Evolving
+trait KeyedState[S] extends LogicalKeyedState[S] {
+
+  /** Whether state exists or not. */
+  def exists: Boolean
+
+  /** Get the state value if it exists, or throw NoSuchElementException. */
+  @throws[NoSuchElementException]("when state does not exist")
+  def get: S
+
+  /** Get the state value as a scala Option. */
+  def getOption: Option[S]
+
+  /**
+   * Update the value of the state. Note that `null` is not a valid value, and it throws
+   * IllegalArgumentException.
+   */
+  @throws[IllegalArgumentException]("when updating with null")
+  def update(newState: S): Unit
+
+  /** Remove this keyed state. */
+  def remove(): Unit
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index ba82ec1..adea358 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, EventTimeWatermark, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, EventTimeWatermark, LogicalPlan, MapGroupsWithState}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
@@ -324,6 +324,23 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     }
   }
 
+  /**
+   * Strategy to convert MapGroupsWithState logical operator to physical operator
+   * in streaming plans. Conversion for batch plans is handled by [[BasicOperators]].
+   */
+  object MapGroupsWithStateStrategy extends Strategy {
+    override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+      case MapGroupsWithState(
+          f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateDeser, stateSer, child) =>
+        val execPlan = MapGroupsWithStateExec(
+          f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateDeser, stateSer,
+          planLater(child))
+        execPlan :: Nil
+      case _ =>
+        Nil
+    }
+  }
+
   // Can we automate these 'pass through' operations?
   object BasicOperators extends Strategy {
     def numPartitions: Int = self.numPartitions
@@ -365,6 +382,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil
       case logical.MapGroups(f, key, value, grouping, data, objAttr, child) =>
         execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
+      case logical.MapGroupsWithState(f, key, value, grouping, data, output, _, _, child) =>
+        execution.MapGroupsExec(f, key, value, grouping, data, output, planLater(child)) :: Nil
       case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
         execution.CoGroupExec(
           f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr,

http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index fde3b2a..199ba5c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -30,6 +30,8 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
+import org.apache.spark.sql.execution.streaming.KeyedStateImpl
 import org.apache.spark.sql.types.{DataType, ObjectType, StructType}
 
 
@@ -144,6 +146,11 @@ object ObjectOperator {
     (i: InternalRow) => proj(i).get(0, deserializer.dataType)
   }
 
+  def deserializeRowToObject(deserializer: Expression): InternalRow => Any = {
+    val safeProjection = GenerateSafeProjection.generate(Seq(deserializer))
+    row => safeProjection(row).get(0, deserializer.dataType) // the single projected column
+  }
+
   def serializeObjectToRow(serializer: Seq[Expression]): Any => UnsafeRow = {
     val proj = GenerateUnsafeProjection.generate(serializer)
     val objType = serializer.head.collect { case b: BoundReference => b.dataType }.head
@@ -344,6 +351,21 @@ case class MapGroupsExec(
   }
 }
 
+object MapGroupsExec {
+  /** Runs a keyed-state function in batch mode: every invocation gets a fresh, empty state. */
+  def apply(
+      func: (Any, Iterator[Any], LogicalKeyedState[Any]) => TraversableOnce[Any],
+      keyDeserializer: Expression, valueDeserializer: Expression,
+      groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute],
+      outputObjAttr: Attribute, child: SparkPlan): MapGroupsExec = {
+    // The KeyedStateImpl is created per call and never read back, so updates are dropped.
+    val funcWithEmptyState: (Any, Iterator[Any]) => TraversableOnce[Any] =
+      (key, values) => func(key, values, new KeyedStateImpl[Any](None))
+    new MapGroupsExec(funcWithEmptyState, keyDeserializer, valueDeserializer,
+      groupingAttributes, dataAttributes, outputObjAttr, child)
+  }
+}
+
 /**
  * Groups the input rows together and calls the R function with each group and an iterator
  * containing all elements in the group.

http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 6ab6fa6..5c4cbfa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.streaming
 
+import java.util.concurrent.atomic.AtomicInteger
+
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.{CurrentBatchTimestamp, Literal}
 import org.apache.spark.sql.SparkSession
@@ -39,8 +41,9 @@ class IncrementalExecution(
   extends QueryExecution(sparkSession, logicalPlan) with Logging {
 
   // TODO: make this always part of planning.
-  val stateStrategy =
+  val streamingExtraStrategies =
     sparkSession.sessionState.planner.StatefulAggregationStrategy +:
+    sparkSession.sessionState.planner.MapGroupsWithStateStrategy +:
     sparkSession.sessionState.planner.StreamingRelationStrategy +:
     sparkSession.sessionState.experimentalMethods.extraStrategies
 
@@ -49,7 +52,7 @@ class IncrementalExecution(
     new SparkPlanner(
       sparkSession.sparkContext,
       sparkSession.sessionState.conf,
-      stateStrategy)
+      streamingExtraStrategies)
 
   /**
    * See [SPARK-18339]
@@ -68,7 +71,7 @@ class IncrementalExecution(
    * Records the current id for a given stateful operator in the query plan as the `state`
    * preparation walks the query plan.
    */
-  private var operatorId = 0
+  private val operatorId = new AtomicInteger(0)
 
   /** Locates save/restore pairs surrounding aggregation. */
   val state = new Rule[SparkPlan] {
@@ -77,8 +80,8 @@ class IncrementalExecution(
       case StateStoreSaveExec(keys, None, None, None,
              UnaryExecNode(agg,
                StateStoreRestoreExec(keys2, None, child))) =>
-        val stateId = OperatorStateId(checkpointLocation, operatorId, currentBatchId)
-        operatorId += 1
+        val stateId =
+          OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)
 
         StateStoreSaveExec(
           keys,
@@ -90,6 +93,12 @@ class IncrementalExecution(
               keys,
               Some(stateId),
               child) :: Nil))
+      case MapGroupsWithStateExec(
+             f, kDeser, vDeser, group, data, output, None, stateDeser, stateSer, child) =>
+        val stateId =
+          OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)
+        MapGroupsWithStateExec(
+          f, kDeser, vDeser, group, data, output, Some(stateId), stateDeser, stateSer, child)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala
new file mode 100644
index 0000000..eee7ec4
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.sql.KeyedState
+
+/**
+ * Internal implementation of the [[KeyedState]] interface. Methods are not thread-safe.
+ *
+ * @param optionalValue initial state value, or None when the state is not yet defined
+ */
+private[sql] class KeyedStateImpl[S](optionalValue: Option[S]) extends KeyedState[S] {
+  private var value: S = optionalValue.getOrElse(null.asInstanceOf[S])
+  private var defined: Boolean = optionalValue.isDefined
+  // whether value has been updated (but not removed)
+  private var updated: Boolean = false
+  private var removed: Boolean = false // whether value has been removed
+
+  // ========= Public API =========
+  override def exists: Boolean = defined
+
+  override def get: S = {
+    if (defined) {
+      value
+    } else {
+      throw new NoSuchElementException("State is either not defined or has already been removed")
+    }
+  }
+
+  override def getOption: Option[S] = if (defined) Some(value) else None
+
+  /** Replaces the state with `newValue` and marks it updated; a null value is rejected. */
+  override def update(newValue: S): Unit = {
+    if (newValue == null) {
+      throw new IllegalArgumentException("'null' is not a valid state value")
+    }
+    value = newValue
+    defined = true
+    updated = true
+    removed = false
+  }
+
+  /** Marks the state as removed and undefined, clearing any earlier update flag. */
+  override def remove(): Unit = {
+    defined = false
+    updated = false
+    removed = true
+  }
+
+  override def toString: String = {
+    s"KeyedState(${getOption.map(_.toString).getOrElse("<undefined>")})"
+  }
+
+  // ========= Internal API =========
+
+  /** Whether `remove()` was called after the last `update()`, if any. */
+  def isRemoved: Boolean = {
+    removed
+  }
+
+  /** Whether `update()` was called after the last `remove()`, if any. */
+  def isUpdated: Boolean = {
+    updated
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
index 1f74fff..693933f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
@@ -186,7 +186,7 @@ trait ProgressReporter extends Logging {
     // lastExecution could belong to one of the previous triggers if `!hasNewData`.
     // Walking the plan again should be inexpensive.
     val stateNodes = lastExecution.executedPlan.collect {
-      case p if p.isInstanceOf[StateStoreSaveExec] => p
+      case p if p.isInstanceOf[StateStoreWriter] => p
     }
     stateNodes.map { node =>
       val numRowsUpdated = if (hasNewData) {

http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala
deleted file mode 100644
index d4ccced..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala
+++ /dev/null
@@ -1,237 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.errors._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate}
-import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
-import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
-import org.apache.spark.sql.execution
-import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.execution.streaming.state._
-import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.streaming.OutputMode
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.TaskContext
-
-
-/** Used to identify the state store for a given operator. */
-case class OperatorStateId(
-    checkpointLocation: String,
-    operatorId: Long,
-    batchId: Long)
-
-/**
- * An operator that saves or restores state from the [[StateStore]].  The [[OperatorStateId]] should
- * be filled in by `prepareForExecution` in [[IncrementalExecution]].
- */
-trait StatefulOperator extends SparkPlan {
-  def stateId: Option[OperatorStateId]
-
-  protected def getStateId: OperatorStateId = attachTree(this) {
-    stateId.getOrElse {
-      throw new IllegalStateException("State location not present for execution")
-    }
-  }
-}
-
-/**
- * For each input tuple, the key is calculated and the value from the [[StateStore]] is added
- * to the stream (in addition to the input tuple) if present.
- */
-case class StateStoreRestoreExec(
-    keyExpressions: Seq[Attribute],
-    stateId: Option[OperatorStateId],
-    child: SparkPlan)
-  extends execution.UnaryExecNode with StatefulOperator {
-
-  override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
-
-  override protected def doExecute(): RDD[InternalRow] = {
-    val numOutputRows = longMetric("numOutputRows")
-
-    child.execute().mapPartitionsWithStateStore(
-      getStateId.checkpointLocation,
-      operatorId = getStateId.operatorId,
-      storeVersion = getStateId.batchId,
-      keyExpressions.toStructType,
-      child.output.toStructType,
-      sqlContext.sessionState,
-      Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
-        val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
-        iter.flatMap { row =>
-          val key = getKey(row)
-          val savedState = store.get(key)
-          numOutputRows += 1
-          row +: savedState.toSeq
-        }
-    }
-  }
-
-  override def output: Seq[Attribute] = child.output
-
-  override def outputPartitioning: Partitioning = child.outputPartitioning
-}
-
-/**
- * For each input tuple, the key is calculated and the tuple is `put` into the [[StateStore]].
- */
-case class StateStoreSaveExec(
-    keyExpressions: Seq[Attribute],
-    stateId: Option[OperatorStateId] = None,
-    outputMode: Option[OutputMode] = None,
-    eventTimeWatermark: Option[Long] = None,
-    child: SparkPlan)
-  extends execution.UnaryExecNode with StatefulOperator {
-
-  override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
-    "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"),
-    "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows"))
-
-  /** Generate a predicate that matches data older than the watermark */
-  private lazy val watermarkPredicate: Option[Predicate] = {
-    val optionalWatermarkAttribute =
-      keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey))
-
-    optionalWatermarkAttribute.map { watermarkAttribute =>
-      // If we are evicting based on a window, use the end of the window.  Otherwise just
-      // use the attribute itself.
-      val evictionExpression =
-        if (watermarkAttribute.dataType.isInstanceOf[StructType]) {
-          LessThanOrEqual(
-            GetStructField(watermarkAttribute, 1),
-            Literal(eventTimeWatermark.get * 1000))
-        } else {
-          LessThanOrEqual(
-            watermarkAttribute,
-            Literal(eventTimeWatermark.get * 1000))
-        }
-
-      logInfo(s"Filtering state store on: $evictionExpression")
-      newPredicate(evictionExpression, keyExpressions)
-    }
-  }
-
-  override protected def doExecute(): RDD[InternalRow] = {
-    metrics // force lazy init at driver
-    assert(outputMode.nonEmpty,
-      "Incorrect planning in IncrementalExecution, outputMode has not been set")
-
-    child.execute().mapPartitionsWithStateStore(
-      getStateId.checkpointLocation,
-      operatorId = getStateId.operatorId,
-      storeVersion = getStateId.batchId,
-      keyExpressions.toStructType,
-      child.output.toStructType,
-      sqlContext.sessionState,
-      Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
-        val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
-        val numOutputRows = longMetric("numOutputRows")
-        val numTotalStateRows = longMetric("numTotalStateRows")
-        val numUpdatedStateRows = longMetric("numUpdatedStateRows")
-
-        // Abort the state store in case of error
-        TaskContext.get().addTaskCompletionListener(_ => {
-          if (!store.hasCommitted) {
-            store.abort()
-          }
-        })
-
-        outputMode match {
-          // Update and output all rows in the StateStore.
-          case Some(Complete) =>
-            while (iter.hasNext) {
-              val row = iter.next().asInstanceOf[UnsafeRow]
-              val key = getKey(row)
-              store.put(key.copy(), row.copy())
-              numUpdatedStateRows += 1
-            }
-            store.commit()
-            numTotalStateRows += store.numKeys()
-            store.iterator().map { case (k, v) =>
-              numOutputRows += 1
-              v.asInstanceOf[InternalRow]
-            }
-
-          // Update and output only rows being evicted from the StateStore
-          case Some(Append) =>
-            while (iter.hasNext) {
-              val row = iter.next().asInstanceOf[UnsafeRow]
-              val key = getKey(row)
-              store.put(key.copy(), row.copy())
-              numUpdatedStateRows += 1
-            }
-
-            // Assumption: Append mode can be done only when watermark has been specified
-            store.remove(watermarkPredicate.get.eval)
-            store.commit()
-
-            numTotalStateRows += store.numKeys()
-            store.updates().filter(_.isInstanceOf[ValueRemoved]).map { removed =>
-              numOutputRows += 1
-              removed.value.asInstanceOf[InternalRow]
-            }
-
-          // Update and output modified rows from the StateStore.
-          case Some(Update) =>
-
-            new Iterator[InternalRow] {
-
-              // Filter late date using watermark if specified
-              private[this] val baseIterator = watermarkPredicate match {
-                case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row))
-                case None => iter
-              }
-
-              override def hasNext: Boolean = {
-                if (!baseIterator.hasNext) {
-                  // Remove old aggregates if watermark specified
-                  if (watermarkPredicate.nonEmpty) store.remove(watermarkPredicate.get.eval)
-                  store.commit()
-                  numTotalStateRows += store.numKeys()
-                  false
-                } else {
-                  true
-                }
-              }
-
-              override def next(): InternalRow = {
-                val row = baseIterator.next().asInstanceOf[UnsafeRow]
-                val key = getKey(row)
-                store.put(key.copy(), row.copy())
-                numOutputRows += 1
-                numUpdatedStateRows += 1
-                row
-              }
-            }
-
-          case _ => throw new UnsupportedOperationException(s"Invalid output mode: $outputMode")
-        }
-    }
-  }
-
-  override def output: Seq[Attribute] = child.output
-
-  override def outputPartitioning: Partitioning = child.outputPartitioning
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index 1279b71..61eb601 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -147,6 +147,25 @@ private[state] class HDFSBackedStateStoreProvider(
       }
     }
 
+    /** Remove a single key if present, recording the removal in the delta file. */
+    override def remove(key: UnsafeRow): Unit = {
+      verify(state == UPDATING, "Cannot remove after already committed or aborted")
+      if (mapToUpdate.containsKey(key)) {
+        val removedValue = mapToUpdate.remove(key)
+        val removal = ValueRemoved(key, removedValue)
+        allUpdates.get(key) match {
+          case null | ValueUpdated(_, _) =>
+            // Key existed in the previous version (possibly updated since): record its removal.
+            allUpdates.put(key, removal)
+          case ValueAdded(_, _) =>
+            // Key was added in this version, so it must not appear in the updates at all.
+            allUpdates.remove(key)
+          case ValueRemoved(_, _) => // Removal already recorded; nothing to change.
+        }
+        writeToDeltaFile(tempDeltaFileStream, removal)
+      }
+    }
+
     /** Commit all the updates that have been made to the store, and return the new version. */
     override def commit(): Long = {
       verify(state == UPDATING, "Cannot commit after already committed or aborted")

http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index e61d95a..dcb24b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -59,6 +59,11 @@ trait StateStore {
   def remove(condition: UnsafeRow => Boolean): Unit
 
   /**
+   * Remove a single key. (HDFSBackedStateStoreProvider treats an absent key as a no-op.)
+   */
+  def remove(key: UnsafeRow): Unit
+
+  /**
    * Commit all the updates that have been made to the store, and return the new version.
    */
   def commit(): Long

http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
index 1b56c08..589042a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming
 
 import scala.reflect.ClassTag
 
+import org.apache.spark.TaskContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.internal.SessionState
@@ -59,10 +60,18 @@ package object state {
         sessionState: SessionState,
         storeCoordinator: Option[StateStoreCoordinatorRef])(
         storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {
+
       val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction)
+      val wrappedF = (store: StateStore, iter: Iterator[T]) => {
+        // Abort the state store in case of error
+        TaskContext.get().addTaskCompletionListener(_ => {
+          if (!store.hasCommitted) store.abort()
+        })
+        cleanedF(store, iter)
+      }
       new StateStoreRDD(
         dataRDD,
-        cleanedF,
+        wrappedF,
         checkpointLocation,
         operatorId,
         storeVersion,

http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
new file mode 100644
index 0000000..1292452
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -0,0 +1,323 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate}
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalKeyedState}
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
+import org.apache.spark.sql.execution
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.execution.streaming.state._
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.CompletionIterator
+
+
+/**
+ * Identifies the state store of one stateful operator: the checkpoint root, the operator's index
+ * among the query's stateful operators, and the batch id (used as the store version).
+ */
+case class OperatorStateId(checkpointLocation: String, operatorId: Long, batchId: Long)
+
+/**
+ * A physical operator that reads and/or writes a [[StateStore]]. Its [[OperatorStateId]] starts
+ * out unset and is assigned by `prepareForExecution` in [[IncrementalExecution]].
+ */
+trait StatefulOperator extends SparkPlan {
+  /** Location of this operator's state; None until assigned during incremental planning. */
+  def stateId: Option[OperatorStateId]
+
+  /** Returns the assigned state id, failing with IllegalStateException if it was never set. */
+  protected def getStateId: OperatorStateId = attachTree(this) {
+    stateId.getOrElse(throw new IllegalStateException("State location not present for execution"))
+  }
+}
+
+/** A [[StatefulOperator]] that reads from its StateStore and reports an output row count. */
+trait StateStoreReader extends StatefulOperator {
+  override lazy val metrics =
+    Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+}
+
+/** A [[StatefulOperator]] that writes to its StateStore; tracks output and state-row counts. */
+trait StateStoreWriter extends StatefulOperator {
+  override lazy val metrics = Map(
+    ("numOutputRows", SQLMetrics.createMetric(sparkContext, "number of output rows")),
+    ("numTotalStateRows", SQLMetrics.createMetric(sparkContext, "number of total state rows")),
+    ("numUpdatedStateRows", SQLMetrics.createMetric(sparkContext, "number of updated state rows")))
+}
+
+/**
+ * For each input row, computes its key and emits the row followed by the value stored for that
+ * key in the [[StateStore]], if any. `numOutputRows` counts every emitted row.
+ */
+case class StateStoreRestoreExec(
+    keyExpressions: Seq[Attribute],
+    stateId: Option[OperatorStateId],
+    child: SparkPlan)
+  extends execution.UnaryExecNode with StateStoreReader {
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    val numOutputRows = longMetric("numOutputRows")
+
+    child.execute().mapPartitionsWithStateStore(
+      getStateId.checkpointLocation,
+      operatorId = getStateId.operatorId,
+      storeVersion = getStateId.batchId,
+      keyExpressions.toStructType,
+      child.output.toStructType,
+      sqlContext.sessionState,
+      Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
+        val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
+        iter.flatMap { row =>
+          // Input row plus the restored state (0 or 1 rows); count all of them, not just 1.
+          val outputRows = row +: store.get(getKey(row)).toSeq
+          numOutputRows += outputRows.size
+          outputRows
+        }
+    }
+  }
+
+  override def output: Seq[Attribute] = child.output
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+}
+
+/**
+ * For each input tuple, the key is calculated and the tuple is `put` into the [[StateStore]].
+ */
+case class StateStoreSaveExec(
+    keyExpressions: Seq[Attribute],
+    stateId: Option[OperatorStateId] = None,
+    outputMode: Option[OutputMode] = None,
+    eventTimeWatermark: Option[Long] = None,
+    child: SparkPlan)
+  extends execution.UnaryExecNode with StateStoreWriter {
+
+  /** Generate a predicate that matches data older than the watermark */
+  private lazy val watermarkPredicate: Option[Predicate] = {
+    val optionalWatermarkAttribute =
+      keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey))
+
+    optionalWatermarkAttribute.map { watermarkAttribute =>
+      // If we are evicting based on a window, use the end of the window.  Otherwise just
+      // use the attribute itself.
+      val evictionExpression =
+        if (watermarkAttribute.dataType.isInstanceOf[StructType]) {
+          LessThanOrEqual(
+            GetStructField(watermarkAttribute, 1),
+            Literal(eventTimeWatermark.get * 1000))
+        } else {
+          LessThanOrEqual(
+            watermarkAttribute,
+            Literal(eventTimeWatermark.get * 1000))
+        }
+
+      logInfo(s"Filtering state store on: $evictionExpression")
+      newPredicate(evictionExpression, keyExpressions)
+    }
+  }
+
+  /**
+   * Applies the partition's input rows to its state store according to `outputMode`:
+   *  - Complete: put every input row into the store, commit, then emit every stored row.
+   *  - Append: put every input row, remove the rows matched by the watermark predicate,
+   *    commit, then emit only the removed rows. Assumes a watermark has been specified.
+   *  - Update: lazily emit every input row that is not filtered out as late data, putting
+   *    each one into the store as it is emitted; once the input is exhausted, remove old
+   *    rows (if a watermark is specified) and commit the store exactly once.
+   * Any other mode fails with an UnsupportedOperationException when the partition runs.
+   */
+  override protected def doExecute(): RDD[InternalRow] = {
+    metrics // force lazy init at driver
+    assert(outputMode.nonEmpty,
+      "Incorrect planning in IncrementalExecution, outputMode has not been set")
+
+    child.execute().mapPartitionsWithStateStore(
+      getStateId.checkpointLocation,
+      operatorId = getStateId.operatorId,
+      storeVersion = getStateId.batchId,
+      keyExpressions.toStructType,
+      child.output.toStructType,
+      sqlContext.sessionState,
+      Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
+        val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
+        val numOutputRows = longMetric("numOutputRows")
+        val numTotalStateRows = longMetric("numTotalStateRows")
+        val numUpdatedStateRows = longMetric("numUpdatedStateRows")
+
+        outputMode match {
+          // Update and output all rows in the StateStore.
+          case Some(Complete) =>
+            while (iter.hasNext) {
+              val row = iter.next().asInstanceOf[UnsafeRow]
+              val key = getKey(row)
+              store.put(key.copy(), row.copy())
+              numUpdatedStateRows += 1
+            }
+            store.commit()
+            numTotalStateRows += store.numKeys()
+            store.iterator().map { case (_, v) =>
+              numOutputRows += 1
+              v.asInstanceOf[InternalRow]
+            }
+
+          // Update and output only rows being evicted from the StateStore
+          case Some(Append) =>
+            while (iter.hasNext) {
+              val row = iter.next().asInstanceOf[UnsafeRow]
+              val key = getKey(row)
+              store.put(key.copy(), row.copy())
+              numUpdatedStateRows += 1
+            }
+
+            // Assumption: Append mode can be done only when watermark has been specified
+            store.remove(watermarkPredicate.get.eval _)
+            store.commit()
+
+            numTotalStateRows += store.numKeys()
+            store.updates().filter(_.isInstanceOf[ValueRemoved]).map { removed =>
+              numOutputRows += 1
+              removed.value.asInstanceOf[InternalRow]
+            }
+
+          // Update and output modified rows from the StateStore.
+          case Some(Update) =>
+
+            new Iterator[InternalRow] {
+
+              // Filter late data using watermark if specified
+              private[this] val baseIterator = watermarkPredicate match {
+                case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row))
+                case None => iter
+              }
+
+              // Set once the final eviction + commit has completed. The Iterator contract allows
+              // hasNext to be called again after it returned false; without this guard such a
+              // call would commit the store a second time and double-count numTotalStateRows.
+              private[this] var committed = false
+
+              override def hasNext: Boolean = {
+                if (committed) {
+                  false
+                } else if (baseIterator.hasNext) {
+                  true
+                } else {
+                  // Remove old aggregates if watermark specified
+                  if (watermarkPredicate.nonEmpty) store.remove(watermarkPredicate.get.eval _)
+                  store.commit()
+                  numTotalStateRows += store.numKeys()
+                  committed = true
+                  false
+                }
+              }
+
+              override def next(): InternalRow = {
+                val row = baseIterator.next().asInstanceOf[UnsafeRow]
+                val key = getKey(row)
+                store.put(key.copy(), row.copy())
+                numOutputRows += 1
+                numUpdatedStateRows += 1
+                row
+              }
+            }
+
+          case _ => throw new UnsupportedOperationException(s"Invalid output mode: $outputMode")
+        }
+    }
+  }
+
+  /** The output attributes are exactly the child's output attributes. */
+  override def output: Seq[Attribute] = {
+    child.output
+  }
+
+  /** Partitioning is inherited unchanged from the child plan. */
+  override def outputPartitioning: Partitioning = {
+    child.outputPartitioning
+  }
+}
+
+
+/**
+ * Physical operator for executing streaming mapGroupsWithState.
+ *
+ * The child is required to be clustered and sorted by `groupingAttributes`. For every group in a
+ * partition, `func` is called with the deserialized key, an iterator of deserialized values and a
+ * [[KeyedStateImpl]] wrapping the state currently stored for that key (if any). When the rows
+ * produced for a group have been fully consumed, the state change (remove or update) is written
+ * back to the state store; when all groups have been consumed, the store is committed.
+ *
+ * @param func function applied to each group: (key, values, state) => output objects
+ * @param keyDeserializer expression converting a grouping row into the key object
+ * @param valueDeserializer expression converting a child row into a value object
+ * @param groupingAttributes attributes the input is grouped by; also the state store key schema
+ * @param dataAttributes attributes that `valueDeserializer` is bound against
+ * @param outputObjAttr attribute whose data type is used to wrap each output object into a row
+ * @param stateId identifies this operator's state; presumably assigned during incremental
+ *                planning (TODO confirm)
+ * @param stateDeserializer expression converting a stored state row into the state object
+ * @param stateSerializer expressions converting the state object into a row for storage
+ * @param child the child plan
+ */
+case class MapGroupsWithStateExec(
+    func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
+    keyDeserializer: Expression,
+    valueDeserializer: Expression,
+    groupingAttributes: Seq[Attribute],
+    dataAttributes: Seq[Attribute],
+    outputObjAttr: Attribute,
+    stateId: Option[OperatorStateId],
+    stateDeserializer: Expression,
+    stateSerializer: Seq[NamedExpression],
+    child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter {
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  /** Distribute by grouping attributes */
+  override def requiredChildDistribution: Seq[Distribution] =
+    ClusteredDistribution(groupingAttributes) :: Nil
+
+  /** Ordering needed for using GroupedIterator */
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+    Seq(groupingAttributes.map(SortOrder(_, Ascending)))
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    metrics // force lazy init at driver
+
+    // The values kept in the store are serialized state rows (see `outputStateObj` below),
+    // not child rows, so describe them with the state serializer's schema.
+    val stateSchema = stateSerializer.map(_.toAttribute).toStructType
+
+    child.execute().mapPartitionsWithStateStore[InternalRow](
+      getStateId.checkpointLocation,
+      operatorId = getStateId.operatorId,
+      storeVersion = getStateId.batchId,
+      groupingAttributes.toStructType,
+      stateSchema,
+      sqlContext.sessionState,
+      Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
+        val numTotalStateRows = longMetric("numTotalStateRows")
+        val numUpdatedStateRows = longMetric("numUpdatedStateRows")
+        val numOutputRows = longMetric("numOutputRows")
+
+        // Generate an iterator that returns the rows grouped by the grouping attributes
+        val groupedIter = GroupedIterator(iter, groupingAttributes, child.output)
+
+        // Converters to and from objects and rows
+        val getKeyObj = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
+        val getValueObj = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
+        val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
+        val getStateObj = ObjectOperator.deserializeRowToObject(stateDeserializer)
+        val outputStateObj = ObjectOperator.serializeObjectToRow(stateSerializer)
+
+        // For every group, get the key, values and corresponding state and call the function,
+        // and return an iterator of rows
+        val allRowsIterator = groupedIter.flatMap { case (keyRow, valueRowIter) =>
+
+          val key = keyRow.asInstanceOf[UnsafeRow]
+          val keyObj = getKeyObj(keyRow)                         // convert key to objects
+          val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects
+          val stateObjOption = store.get(key).map(getStateObj)   // get existing state if any
+          val wrappedState = new KeyedStateImpl(stateObjOption)
+          val mappedIterator = func(keyObj, valueObjIter, wrappedState).map { obj =>
+            numOutputRows += 1
+            getOutputRow(obj) // convert back to rows
+          }
+
+          // Return an iterator of the rows generated for this key, such that when it is fully
+          // consumed, the updated state value will be saved
+          CompletionIterator[InternalRow, Iterator[InternalRow]](
+            mappedIterator, {
+              // When the iterator is consumed, then write changes to state
+              if (wrappedState.isRemoved) {
+                store.remove(key)
+                numUpdatedStateRows += 1
+              } else if (wrappedState.isUpdated) {
+                // Copy both rows before handing them to the store: they come from generated
+                // projections, which may reuse their output row for the next group.
+                store.put(key.copy(), outputStateObj(wrappedState.get).copy())
+                numUpdatedStateRows += 1
+              }
+            })
+        }
+
+        // Return an iterator of all the rows generated by all the keys, such that when fully
+        // consumed, all the state updates will be committed by the state store
+        CompletionIterator[InternalRow, Iterator[InternalRow]](allRowsIterator, {
+          store.commit()
+          numTotalStateRows += store.numKeys()
+        })
+      }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 8304b72..5ef4e88 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -225,6 +225,38 @@ public class JavaDatasetSuite implements Serializable {
 
     Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped.collectAsList()));
 
+    Dataset<String> mapped2 = grouped.mapGroupsWithState(
+      new MapGroupsWithStateFunction<Integer, String, Long, String>() {
+        @Override
+        public String call(Integer key, Iterator<String> values, KeyedState<Long> s) throws Exception {
+          StringBuilder sb = new StringBuilder(key.toString());
+          while (values.hasNext()) {
+            sb.append(values.next());
+          }
+          return sb.toString();
+        }
+        },
+        Encoders.LONG(),
+        Encoders.STRING());
+
+    Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped2.collectAsList()));
+
+    Dataset<String> flatMapped2 = grouped.flatMapGroupsWithState(
+      new FlatMapGroupsWithStateFunction<Integer, String, Long, String>() {
+        @Override
+        public Iterator<String> call(Integer key, Iterator<String> values, KeyedState<Long> s) {
+          StringBuilder sb = new StringBuilder(key.toString());
+          while (values.hasNext()) {
+            sb.append(values.next());
+          }
+          return Collections.singletonList(sb.toString()).iterator();
+        }
+      },
+      Encoders.LONG(),
+      Encoders.STRING());
+
+    Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped2.collectAsList()));
+
     Dataset<Tuple2<Integer, String>> reduced = grouped.reduceGroups(new ReduceFunction<String>() {
       @Override
       public String call(String v1, String v2) throws Exception {

http://git-wip-us.apache.org/repos/asf/spark/blob/502c927b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala
new file mode 100644
index 0000000..0524898
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala
@@ -0,0 +1,335 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.KeyedState
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
+import org.apache.spark.sql.execution.streaming.{KeyedStateImpl, MemoryStream}
+import org.apache.spark.sql.execution.streaming.state.StateStore
+
+/**
+ * Custom case-class state type, used to check that user-defined types can be kept as state.
+ *
+ * @param count running number of values seen so far for a key
+ */
+final case class RunningCount(count: Long)
+
+/**
+ * Tests for `mapGroupsWithState` / `flatMapGroupsWithState`: the [[KeyedState]] implementation,
+ * streaming and batch execution, state-store metrics, and recovery after a task failure.
+ */
+class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll {
+
+  import testImplicits._
+
+  override def afterAll(): Unit = {
+    super.afterAll()
+    StateStore.stop()
+  }
+
+  test("KeyedState - get, exists, update, remove") {
+    var state: KeyedStateImpl[String] = null
+
+    def testState(
+        expectedData: Option[String],
+        shouldBeUpdated: Boolean = false,
+        shouldBeRemoved: Boolean = false): Unit = {
+      if (expectedData.isDefined) {
+        assert(state.exists)
+        assert(state.get === expectedData.get)
+      } else {
+        assert(!state.exists)
+        intercept[NoSuchElementException] {
+          state.get
+        }
+      }
+      assert(state.getOption === expectedData)
+      assert(state.isUpdated === shouldBeUpdated)
+      assert(state.isRemoved === shouldBeRemoved)
+    }
+
+    // Updating empty state
+    state = new KeyedStateImpl[String](None)
+    testState(None)
+    state.update("")
+    testState(Some(""), shouldBeUpdated = true)
+
+    // Updating existing state
+    state = new KeyedStateImpl[String](Some("2"))
+    testState(Some("2"))
+    state.update("3")
+    testState(Some("3"), shouldBeUpdated = true)
+
+    // Removing state
+    state.remove()
+    testState(None, shouldBeRemoved = true, shouldBeUpdated = false)
+    state.remove()      // should still be callable
+    state.update("4")
+    testState(Some("4"), shouldBeRemoved = false, shouldBeUpdated = true)
+
+    // Updating with null throws an exception
+    intercept[IllegalArgumentException] {
+      state.update(null)
+    }
+  }
+
+  test("KeyedState - primitive type") {
+    var intState = new KeyedStateImpl[Int](None)
+    intercept[NoSuchElementException] {
+      intState.get
+    }
+    assert(intState.getOption === None)
+
+    intState = new KeyedStateImpl[Int](Some(10))
+    assert(intState.get == 10)
+    intState.update(0)
+    assert(intState.get == 0)
+    intState.remove()
+    intercept[NoSuchElementException] {
+      intState.get
+    }
+  }
+
+  test("flatMapGroupsWithState - streaming") {
+    // Function to maintain running count up to 2, and then remove the count
+    // Returns the data and the count if state is defined, otherwise does not return anything
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+
+      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+      if (count == 3) {
+        state.remove()
+        Iterator.empty
+      } else {
+        state.update(RunningCount(count))
+        Iterator((key, count.toString))
+      }
+    }
+
+    val inputData = MemoryStream[String]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .flatMapGroupsWithState(stateFunc) // State: RunningCount, Out: (Str, Str)
+
+    testStream(result, Append)(
+      AddData(inputData, "a"),
+      CheckLastBatch(("a", "1")),
+      assertNumStateRows(total = 1, updated = 1),
+      AddData(inputData, "a", "b"),
+      CheckLastBatch(("a", "2"), ("b", "1")),
+      assertNumStateRows(total = 2, updated = 2),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a
+      CheckLastBatch(("b", "2")),
+      assertNumStateRows(total = 1, updated = 2),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1
+      CheckLastBatch(("a", "1"), ("c", "1")),
+      assertNumStateRows(total = 3, updated = 2)
+    )
+  }
+
+  test("flatMapGroupsWithState - streaming + func returns iterator that updates state lazily") {
+    // Function to maintain running count up to 2, and then remove the count
+    // Returns the data and the count if state is defined, otherwise does not return anything
+    // Additionally, it updates state lazily as the returned iterator gets consumed
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+      values.flatMap { _ =>
+        val count = state.getOption.map(_.count).getOrElse(0L) + 1
+        if (count == 3) {
+          state.remove()
+          None
+        } else {
+          state.update(RunningCount(count))
+          Some((key, count.toString))
+        }
+      }
+    }
+
+    val inputData = MemoryStream[String]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .flatMapGroupsWithState(stateFunc) // State: RunningCount, Out: (Str, Str)
+
+    testStream(result, Append)(
+      AddData(inputData, "a", "a", "b"),
+      CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a
+      CheckLastBatch(("b", "2")),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1
+      CheckLastBatch(("a", "1"), ("c", "1"))
+    )
+  }
+
+  test("flatMapGroupsWithState - batch") {
+    // Function that returns the number of values for each key; in a batch query there is
+    // never any existing state, so it fails if state.exists is true
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+      if (state.exists) throw new IllegalArgumentException("state.exists should be false")
+      Iterator((key, values.size))
+    }
+    checkAnswer(
+      Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc).toDF,
+      Seq(("a", 2), ("b", 1)).toDF)
+  }
+
+  test("mapGroupsWithState - streaming") {
+    // Function to maintain running count up to 2, and then remove the count
+    // Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+
+      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+      if (count == 3) {
+        state.remove()
+        (key, "-1")
+      } else {
+        state.update(RunningCount(count))
+        (key, count.toString)
+      }
+    }
+
+    val inputData = MemoryStream[String]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .mapGroupsWithState(stateFunc) // State: RunningCount, Out: (Str, Str)
+
+    testStream(result, Append)(
+      AddData(inputData, "a"),
+      CheckLastBatch(("a", "1")),
+      assertNumStateRows(total = 1, updated = 1),
+      AddData(inputData, "a", "b"),
+      CheckLastBatch(("a", "2"), ("b", "1")),
+      assertNumStateRows(total = 2, updated = 2),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1
+      CheckLastBatch(("a", "-1"), ("b", "2")),
+      assertNumStateRows(total = 1, updated = 2),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1
+      CheckLastBatch(("a", "1"), ("c", "1")),
+      assertNumStateRows(total = 3, updated = 2)
+    )
+  }
+
+  test("mapGroupsWithState - streaming + aggregation") {
+    // Function to maintain running count up to 2, and then remove the count
+    // Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+
+      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+      if (count == 3) {
+        state.remove()
+        (key, "-1")
+      } else {
+        state.update(RunningCount(count))
+        (key, count.toString)
+      }
+    }
+
+    val inputData = MemoryStream[String]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .mapGroupsWithState(stateFunc) // State: RunningCount, Out: (Str, Str)
+        .groupByKey(_._1)
+        .count()
+
+    testStream(result, Complete)(
+      AddData(inputData, "a"),
+      CheckLastBatch(("a", 1)),
+      AddData(inputData, "a", "b"),
+      // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1
+      CheckLastBatch(("a", 2), ("b", 1)),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "b"),
+      // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2");
+      // so increment a and b by 1
+      CheckLastBatch(("a", 3), ("b", 2)),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "c"),
+      // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1");
+      // so increment a and c by 1
+      CheckLastBatch(("a", 4), ("b", 2), ("c", 1))
+    )
+  }
+
+  test("mapGroupsWithState - batch") {
+    // In a batch query there is never any existing state, so fail if state.exists is true
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+      if (state.exists) throw new IllegalArgumentException("state.exists should be false")
+      (key, values.size)
+    }
+
+    checkAnswer(
+      spark.createDataset(Seq("a", "a", "b"))
+        .groupByKey(x => x)
+        .mapGroupsWithState(stateFunc)
+        .toDF,
+      spark.createDataset(Seq(("a", 2), ("b", 1))).toDF)
+  }
+
+  testQuietly("StateStore.abort on task failure handling") {
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+      if (MapGroupsWithStateSuite.failInTask) throw new Exception("expected failure")
+      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+      state.update(RunningCount(count))
+      (key, count)
+    }
+
+    val inputData = MemoryStream[String]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .mapGroupsWithState(stateFunc) // State: RunningCount, Out: (Str, Long)
+
+    def setFailInTask(value: Boolean): AssertOnQuery = AssertOnQuery { q =>
+      MapGroupsWithStateSuite.failInTask = value
+      true
+    }
+
+    testStream(result, Append)(
+      setFailInTask(false),
+      AddData(inputData, "a"),
+      CheckLastBatch(("a", 1L)),
+      AddData(inputData, "a"),
+      CheckLastBatch(("a", 2L)),
+      setFailInTask(true),
+      AddData(inputData, "a"),
+      ExpectFailure[SparkException](),   // task should fail but should not increment count
+      setFailInTask(false),
+      StartStream(),
+      CheckLastBatch(("a", 3L))     // task should not fail, and should show correct count
+    )
+  }
+
+  /**
+   * Asserts the state operator metrics of the most recent progress that had input rows.
+   *
+   * @param total expected total number of rows in the state store
+   * @param updated expected number of state rows updated in that batch
+   */
+  private def assertNumStateRows(total: Long, updated: Long): AssertOnQuery = AssertOnQuery { q =>
+    val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.getOrElse {
+      fail("Expected at least one progress update with input rows, but found none")
+    }
+    assert(progressWithData.stateOperators(0).numRowsTotal === total, "incorrect total rows")
+    assert(progressWithData.stateOperators(0).numRowsUpdated === updated, "incorrect updates rows")
+    true
+  }
+}
+
+/** Mutable flag shared between the suite's test code and the state functions run in tasks. */
+object MapGroupsWithStateSuite {
+  // Written by the test thread and read inside task threads (which presumably share this JVM
+  // in the local test session); @volatile makes the latest write visible to those threads.
+  @volatile var failInTask = true
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org