Posted to commits@spark.apache.org by td...@apache.org on 2016/03/23 20:48:12 UTC

spark git commit: [SPARK-13809][SQL] State store for streaming aggregations

Repository: spark
Updated Branches:
  refs/heads/master 0a64294fc -> 8c826880f


[SPARK-13809][SQL] State store for streaming aggregations

## What changes were proposed in this pull request?

In this PR, I am implementing a new abstraction for managing streaming state data: the State Store. It is a versioned key-value store for persisting the running aggregates of aggregation operations on streaming DataFrames. The motivation and design are discussed here:

https://docs.google.com/document/d/1-ncawFx8JS5Zyfq1HAEGBx56RDet9wfVp_hDM8ZL254/edit#
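
At a high level, a task gets the store for its (operator, partition) pair at a given version, applies a set of updates, and commits to produce the next version. A minimal sketch of the intended call sequence, using the API from this patch (the key/value rows and the helper functions `mergeNewInput`/`isExpired` are illustrative, not part of the patch):

```scala
// Sketch only: keySchema/valueSchema are StructTypes describing the UnsafeRow layouts.
val store = StateStore.get(storeId, keySchema, valueSchema, version, storeConf, hadoopConf)
store.update(key, oldValueOption => mergeNewInput(oldValueOption)) // upsert one running aggregate
store.remove(keyRow => isExpired(keyRow))                          // drop keys matching a condition
val newVersion = store.commit()                                    // atomically publish version + 1
val allData = store.iterator()                                     // full contents after the commit
val changes = store.updates()                                      // only this version's changes
```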

## How was this patch tested?
- [x] Unit tests
- [x] Cluster tests

**Coverage from unit tests**

![Unit test coverage (screenshot, 2016-03-21)](https://cloud.githubusercontent.com/assets/663212/13935872/fdc8ba86-ef76-11e5-93e8-9fa310472c7b.png)

## TODO
- [x] Fix updates() iterator to avoid duplicate updates for same key
- [x] Use Coordinator in ContinuousQueryManager
- [x] Plugging in hadoop conf and other confs
- [x] Unit tests
  - [x] StateStore object lifecycle and methods
  - [x] StateStoreCoordinator communication and logic
  - [x] StateStoreRDD fault-tolerance
  - [x] StateStoreRDD preferred location using StateStoreCoordinator
- [ ] Cluster tests
  - [ ] Whether preferred locations are set correctly
  - [ ] Whether recovery works correctly with distributed storage
  - [x] Basic performance tests
- [x] Docs

Author: Tathagata Das <ta...@gmail.com>

Closes #11645 from tdas/state-store.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8c826880
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8c826880
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8c826880

Branch: refs/heads/master
Commit: 8c826880f5eaa3221c4e9e7d3fece54e821a0b98
Parents: 0a64294
Author: Tathagata Das <ta...@gmail.com>
Authored: Wed Mar 23 12:48:05 2016 -0700
Committer: Tathagata Das <ta...@gmail.com>
Committed: Wed Mar 23 12:48:05 2016 -0700

----------------------------------------------------------------------
 .../spark/sql/ContinuousQueryManager.scala      |   3 +
 .../state/HDFSBackedStateStoreProvider.scala    | 584 +++++++++++++++++++
 .../execution/streaming/state/StateStore.scala  | 247 ++++++++
 .../streaming/state/StateStoreConf.scala        |  37 ++
 .../streaming/state/StateStoreCoordinator.scala | 146 +++++
 .../streaming/state/StateStoreRDD.scala         |  70 +++
 .../sql/execution/streaming/state/package.scala |  75 +++
 .../org/apache/spark/sql/internal/SQLConf.scala |  13 +
 .../state/StateStoreCoordinatorSuite.scala      | 123 ++++
 .../streaming/state/StateStoreRDDSuite.scala    | 192 ++++++
 .../streaming/state/StateStoreSuite.scala       | 562 ++++++++++++++++++
 11 files changed, 2052 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8c826880/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
index fa8219b..465feeb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.execution.streaming.{ContinuousQueryListenerBus, Sink, StreamExecution}
+import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef
 import org.apache.spark.sql.util.ContinuousQueryListener
 
 /**
@@ -33,6 +34,8 @@ import org.apache.spark.sql.util.ContinuousQueryListener
 @Experimental
 class ContinuousQueryManager(sqlContext: SQLContext) {
 
+  private[sql] val stateStoreCoordinator =
+    StateStoreCoordinatorRef.forDriver(sqlContext.sparkContext.env)
   private val listenerBus = new ContinuousQueryListenerBus(sqlContext.sparkContext.listenerBus)
   private val activeQueries = new mutable.HashMap[String, ContinuousQuery]
   private val activeQueriesLock = new Object

http://git-wip-us.apache.org/repos/asf/spark/blob/8c826880/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
new file mode 100644
index 0000000..ee015ba
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -0,0 +1,584 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.io.{DataInputStream, DataOutputStream, IOException}
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+import scala.util.Random
+import scala.util.control.NonFatal
+
+import com.google.common.io.ByteStreams
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileStatus, Path}
+
+import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.internal.Logging
+import org.apache.spark.io.LZ4CompressionCodec
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
+
+
+/**
+ * An implementation of [[StateStoreProvider]] and [[StateStore]] in which all the data is backed
+ * by files in an HDFS-compatible file system. All updates to the store have to be done in sets,
+ * transactionally, and each set of updates increments the store's version. These versions can
+ * be used to re-execute the updates (on retries of RDD operations) against the correct version
+ * of the store, and to regenerate the store version.
+ *
+ * Usage:
+ * To update the data in the state store, the following order of operations is needed.
+ *
+ * - val store = StateStore.get(operatorId, partitionId, version) // to get the right store
+ * - store.update(...)
+ * - store.remove(...)
+ * - store.commit()    // commits all the updates made, and returns the new version
+ * - store.iterator()  // key-value data after the last commit, as an iterator
+ * - store.updates()   // updates made in the last committed version, as an iterator
+ *
+ * Fault-tolerance model:
+ * - Every set of updates is written to a delta file before committing.
+ * - The state store is responsible for managing, collapsing, and cleaning up delta files.
+ * - Multiple attempts to commit the same version of updates may overwrite each other.
+ *   Consistency guarantees depend on whether multiple attempts have the same updates and
+ *   on the overwrite semantics of the underlying file system.
+ * - Background maintenance of files ensures that the last versions of the store are always
+ *   recoverable, so that re-executed RDD operations re-apply updates on the correct past
+ *   version of the store.
+ */
+private[state] class HDFSBackedStateStoreProvider(
+    val id: StateStoreId,
+    keySchema: StructType,
+    valueSchema: StructType,
+    storeConf: StateStoreConf,
+    hadoopConf: Configuration
+  ) extends StateStoreProvider with Logging {
+
+  type MapType = java.util.HashMap[UnsafeRow, UnsafeRow]
+
+  /** Implementation of the [[StateStore]] API, backed by an HDFS-compatible file system */
+  class HDFSBackedStateStore(val version: Long, mapToUpdate: MapType)
+    extends StateStore {
+
+    /** Trait and classes representing the internal state of the store */
+    trait STATE
+    case object UPDATING extends STATE
+    case object COMMITTED extends STATE
+    case object CANCELLED extends STATE
+
+    private val newVersion = version + 1
+    private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}")
+    private val tempDeltaFileStream = compressStream(fs.create(tempDeltaFile, true))
+
+    private val allUpdates = new java.util.HashMap[UnsafeRow, StoreUpdate]()
+
+    @volatile private var state: STATE = UPDATING
+    @volatile private var finalDeltaFile: Path = null
+
+    override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id
+
+    /**
+     * Update the value of a key using the value generated by the update function.
+     * @note Do not mutate the retrieved value row as it will unexpectedly affect the previous
+     *       versions of the store data.
+     */
+    override def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit = {
+      verify(state == UPDATING, "Cannot update after already committed or cancelled")
+      val oldValueOption = Option(mapToUpdate.get(key))
+      val value = updateFunc(oldValueOption)
+      mapToUpdate.put(key, value)
+
+      Option(allUpdates.get(key)) match {
+        case Some(ValueAdded(_, _)) =>
+          // Value did not exist in previous version and was added already, keep it marked as added
+          allUpdates.put(key, ValueAdded(key, value))
+        case Some(ValueUpdated(_, _)) | Some(KeyRemoved(_)) =>
+          // Value existed in prev version and updated/removed, mark it as updated
+          allUpdates.put(key, ValueUpdated(key, value))
+        case None =>
+          // There was no prior update, so mark this as added or updated according to its presence
+          // in previous version.
+          val update =
+            if (oldValueOption.nonEmpty) ValueUpdated(key, value) else ValueAdded(key, value)
+          allUpdates.put(key, update)
+      }
+      writeToDeltaFile(tempDeltaFileStream, ValueUpdated(key, value))
+    }
+
+    /** Remove keys that match the given condition */
+    override def remove(condition: UnsafeRow => Boolean): Unit = {
+      verify(state == UPDATING, "Cannot remove after already committed or cancelled")
+      val keyIter = mapToUpdate.keySet().iterator()
+      while (keyIter.hasNext) {
+        val key = keyIter.next
+        if (condition(key)) {
+          keyIter.remove()
+
+          Option(allUpdates.get(key)) match {
+            case Some(ValueUpdated(_, _)) | None =>
+              // Value existed in the previous version and may have been updated; mark it removed
+              allUpdates.put(key, KeyRemoved(key))
+            case Some(ValueAdded(_, _)) =>
+              // Value did not exist in previous version and was added, should not appear in updates
+              allUpdates.remove(key)
+            case Some(KeyRemoved(_)) =>
+              // Removal already recorded in the update map; nothing to change
+          }
+          writeToDeltaFile(tempDeltaFileStream, KeyRemoved(key))
+        }
+      }
+    }
+
+    /** Commit all the updates that have been made to the store, and return the new version. */
+    override def commit(): Long = {
+      verify(state == UPDATING, "Cannot commit again after already committed or cancelled")
+
+      try {
+        finalizeDeltaFile(tempDeltaFileStream)
+        finalDeltaFile = commitUpdates(newVersion, mapToUpdate, tempDeltaFile)
+        state = COMMITTED
+        logInfo(s"Committed version $newVersion for $this")
+        newVersion
+      } catch {
+        case NonFatal(e) =>
+          throw new IllegalStateException(
+            s"Error committing version $newVersion into ${HDFSBackedStateStoreProvider.this}", e)
+      }
+    }
+
+    /** Cancel all the updates made on this store. This store will not be usable any more. */
+    override def cancel(): Unit = {
+      state = CANCELLED
+      if (tempDeltaFileStream != null) {
+        tempDeltaFileStream.close()
+      }
+      if (tempDeltaFile != null && fs.exists(tempDeltaFile)) {
+        fs.delete(tempDeltaFile, true)
+      }
+      logInfo("Canceled ")
+    }
+
+    /**
+     * Get an iterator of all the store data. This can be called only after committing the
+     * updates.
+     */
+    override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = {
+      verify(state == COMMITTED, "Cannot get iterator of store data before comitting")
+      HDFSBackedStateStoreProvider.this.iterator(newVersion)
+    }
+
+    /**
+     * Get an iterator of all the updates made to the store in the current version.
+     * This can be called only after committing the updates.
+     */
+    override def updates(): Iterator[StoreUpdate] = {
+      verify(state == COMMITTED, "Cannot get iterator of updates before committing")
+      allUpdates.values().asScala.toIterator
+    }
+
+    /**
+     * Whether all updates have been committed
+     */
+    override def hasCommitted: Boolean = {
+      state == COMMITTED
+    }
+  }
+
+  /** Get the state store for making updates to create a new `version` of the store. */
+  override def getStore(version: Long): StateStore = synchronized {
+    require(version >= 0, "Version cannot be less than 0")
+    val newMap = new MapType()
+    if (version > 0) {
+      newMap.putAll(loadMap(version))
+    }
+    val store = new HDFSBackedStateStore(version, newMap)
+    logInfo(s"Retrieved version $version of $this for update")
+    store
+  }
+
+  /** Do maintenance on the backing data files, including creating snapshots and cleaning up old files */
+  override def doMaintenance(): Unit = {
+    try {
+      doSnapshot()
+      cleanup()
+    } catch {
+      case NonFatal(e) =>
+        logWarning(s"Error performing snapshot and cleaning up $this")
+    }
+  }
+
+  override def toString(): String = {
+    s"StateStore[id = (op=${id.operatorId},part=${id.partitionId}), dir = $baseDir]"
+  }
+
+  /* Internal classes and methods */
+
+  private val loadedMaps = new mutable.HashMap[Long, MapType]
+  private val baseDir =
+    new Path(id.checkpointLocation, s"${id.operatorId}/${id.partitionId.toString}")
+  private val fs = baseDir.getFileSystem(hadoopConf)
+  private val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
+
+  initialize()
+
+  private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean)
+
+  /** Commit a set of updates to the store with the given new version */
+  private def commitUpdates(newVersion: Long, map: MapType, tempDeltaFile: Path): Path = {
+    synchronized {
+      val finalDeltaFile = deltaFile(newVersion)
+      fs.rename(tempDeltaFile, finalDeltaFile)
+      loadedMaps.put(newVersion, map)
+      finalDeltaFile
+    }
+  }
+
+  /**
+   * Get iterator of all the data of the latest version of the store.
+   * Note that this will look up the files to determine the latest known version.
+   */
+  private[state] def latestIterator(): Iterator[(UnsafeRow, UnsafeRow)] = synchronized {
+    val versionsInFiles = fetchFiles().map(_.version).toSet
+    val versionsLoaded = loadedMaps.keySet
+    val allKnownVersions = versionsInFiles ++ versionsLoaded
+    if (allKnownVersions.nonEmpty) {
+      loadMap(allKnownVersions.max).entrySet().iterator().asScala.map { x =>
+        (x.getKey, x.getValue)
+      }
+    } else Iterator.empty
+  }
+
+  /** Get iterator of a specific version of the store */
+  private[state] def iterator(version: Long): Iterator[(UnsafeRow, UnsafeRow)] = synchronized {
+    loadMap(version).entrySet().iterator().asScala.map { x =>
+      (x.getKey, x.getValue)
+    }
+  }
+
+  /** Initialize the store provider */
+  private def initialize(): Unit = {
+    if (!fs.exists(baseDir)) {
+      fs.mkdirs(baseDir)
+    } else {
+      if (!fs.isDirectory(baseDir)) {
+        throw new IllegalStateException(
+          s"Cannot use ${id.checkpointLocation} for storing state data for $this as" +
+            s"$baseDir already exists and is not a directory")
+      }
+    }
+  }
+
+  /**
+   * Load the required version of the map data from the backing files: first check the in-memory
+   * cache, then a snapshot file, and otherwise recursively reconstruct the previous version and
+   * replay this version's delta file on top of it.
+   */
+  private def loadMap(version: Long): MapType = {
+    if (version <= 0) return new MapType
+    synchronized { loadedMaps.get(version) }.getOrElse {
+      val mapFromFile = readSnapshotFile(version).getOrElse {
+        val prevMap = loadMap(version - 1)
+        val newMap = new MapType(prevMap)  // copy constructor clones the previous version's map
+        updateFromDeltaFile(version, newMap)
+        newMap
+      }
+      loadedMaps.put(version, mapFromFile)
+      mapFromFile
+    }
+  }
+
+  private def writeToDeltaFile(output: DataOutputStream, update: StoreUpdate): Unit = {
+
+    def writeUpdate(key: UnsafeRow, value: UnsafeRow): Unit = {
+      val keyBytes = key.getBytes()
+      val valueBytes = value.getBytes()
+      output.writeInt(keyBytes.size)
+      output.write(keyBytes)
+      output.writeInt(valueBytes.size)
+      output.write(valueBytes)
+    }
+
+    def writeRemove(key: UnsafeRow): Unit = {
+      val keyBytes = key.getBytes()
+      output.writeInt(keyBytes.size)
+      output.write(keyBytes)
+      output.writeInt(-1)
+    }
+
+    update match {
+      case ValueAdded(key, value) =>
+        writeUpdate(key, value)
+      case ValueUpdated(key, value) =>
+        writeUpdate(key, value)
+      case KeyRemoved(key) =>
+        writeRemove(key)
+    }
+  }
+
+  private def finalizeDeltaFile(output: DataOutputStream): Unit = {
+    output.writeInt(-1)  // Write this magic number to signify end of file
+    output.close()
+  }
+
+  private def updateFromDeltaFile(version: Long, map: MapType): Unit = {
+    val fileToRead = deltaFile(version)
+    if (!fs.exists(fileToRead)) {
+      throw new IllegalStateException(
+        s"Error reading delta file $fileToRead of $this: $fileToRead does not exist")
+    }
+    var input: DataInputStream = null
+    try {
+      input = decompressStream(fs.open(fileToRead))
+      var eof = false
+
+      while (!eof) {
+        val keySize = input.readInt()
+        if (keySize == -1) {
+          eof = true
+        } else if (keySize < 0) {
+          throw new IOException(
+            s"Error reading delta file $fileToRead of $this: key size cannot be $keySize")
+        } else {
+          val keyRowBuffer = new Array[Byte](keySize)
+          ByteStreams.readFully(input, keyRowBuffer, 0, keySize)
+
+          val keyRow = new UnsafeRow(keySchema.fields.length)
+          keyRow.pointTo(keyRowBuffer, keySize)
+
+          val valueSize = input.readInt()
+          if (valueSize < 0) {
+            map.remove(keyRow)
+          } else {
+            val valueRowBuffer = new Array[Byte](valueSize)
+            ByteStreams.readFully(input, valueRowBuffer, 0, valueSize)
+            val valueRow = new UnsafeRow(valueSchema.fields.length)
+            valueRow.pointTo(valueRowBuffer, valueSize)
+            map.put(keyRow, valueRow)
+          }
+        }
+      }
+    } finally {
+      if (input != null) input.close()
+    }
+    logInfo(s"Read delta file for version $version of $this from $fileToRead")
+  }
+
+  private def writeSnapshotFile(version: Long, map: MapType): Unit = {
+    val fileToWrite = snapshotFile(version)
+    var output: DataOutputStream = null
+    Utils.tryWithSafeFinally {
+      output = compressStream(fs.create(fileToWrite, false))
+      val iter = map.entrySet().iterator()
+      while (iter.hasNext) {
+        val entry = iter.next()
+        val keyBytes = entry.getKey.getBytes()
+        val valueBytes = entry.getValue.getBytes()
+        output.writeInt(keyBytes.size)
+        output.write(keyBytes)
+        output.writeInt(valueBytes.size)
+        output.write(valueBytes)
+      }
+      output.writeInt(-1)
+    } {
+      if (output != null) output.close()
+    }
+    logInfo(s"Written snapshot file for version $version of $this at $fileToWrite")
+  }
+
+  private def readSnapshotFile(version: Long): Option[MapType] = {
+    val fileToRead = snapshotFile(version)
+    if (!fs.exists(fileToRead)) return None
+
+    val map = new MapType()
+    var input: DataInputStream = null
+
+    try {
+      input = decompressStream(fs.open(fileToRead))
+      var eof = false
+
+      while (!eof) {
+        val keySize = input.readInt()
+        if (keySize == -1) {
+          eof = true
+        } else if (keySize < 0) {
+          throw new IOException(
+            s"Error reading snapshot file $fileToRead of $this: key size cannot be $keySize")
+        } else {
+          val keyRowBuffer = new Array[Byte](keySize)
+          ByteStreams.readFully(input, keyRowBuffer, 0, keySize)
+
+          val keyRow = new UnsafeRow(keySchema.fields.length)
+          keyRow.pointTo(keyRowBuffer, keySize)
+
+          val valueSize = input.readInt()
+          if (valueSize < 0) {
+            throw new IOException(
+              s"Error reading snapshot file $fileToRead of $this: value size cannot be $valueSize")
+          } else {
+            val valueRowBuffer = new Array[Byte](valueSize)
+            ByteStreams.readFully(input, valueRowBuffer, 0, valueSize)
+            val valueRow = new UnsafeRow(valueSchema.fields.length)
+            valueRow.pointTo(valueRowBuffer, valueSize)
+            map.put(keyRow, valueRow)
+          }
+        }
+      }
+      logInfo(s"Read snapshot file for version $version of $this from $fileToRead")
+      Some(map)
+    } finally {
+      if (input != null) input.close()
+    }
+  }
+
+
+  /** Perform a snapshot of the store to allow delta files to be consolidated */
+  private def doSnapshot(): Unit = {
+    try {
+      val files = fetchFiles()
+      if (files.nonEmpty) {
+        val lastVersion = files.last.version
+        val deltaFilesForLastVersion =
+          filesForVersion(files, lastVersion).filterNot(_.isSnapshot)
+        synchronized { loadedMaps.get(lastVersion) } match {
+          case Some(map) =>
+            if (deltaFilesForLastVersion.size > storeConf.maxDeltasForSnapshot) {
+              writeSnapshotFile(lastVersion, map)
+            }
+          case None =>
+            // The last map is not loaded, probably some other instance is in charge
+        }
+
+      }
+    } catch {
+      case NonFatal(e) =>
+        logWarning(s"Error doing snapshots for $this", e)
+    }
+  }
+
+  /**
+   * Clean up old snapshots and delta files that are not needed any more. It ensures that the
+   * last few versions of the store can be recovered from the files, so that re-executed RDD
+   * operations can re-apply updates on the past versions of the store.
+   */
+  private[state] def cleanup(): Unit = {
+    try {
+      val files = fetchFiles()
+      if (files.nonEmpty) {
+        val earliestVersionToRetain = files.last.version - storeConf.minVersionsToRetain
+        if (earliestVersionToRetain > 0) {
+          val earliestFileToRetain = filesForVersion(files, earliestVersionToRetain).head
+          synchronized {
+            val mapsToRemove = loadedMaps.keys.filter(_ < earliestVersionToRetain).toSeq
+            mapsToRemove.foreach(loadedMaps.remove)
+          }
+          files.filter(_.version < earliestFileToRetain.version).foreach { f =>
+            fs.delete(f.path, true)
+          }
+          logInfo(s"Deleted files older than ${earliestFileToRetain.version} for $this")
+        }
+      }
+    } catch {
+      case NonFatal(e) =>
+        logWarning(s"Error cleaning up files for $this", e)
+    }
+  }
+
+  /** Files needed to recover the given version of the store */
+  private def filesForVersion(allFiles: Seq[StoreFile], version: Long): Seq[StoreFile] = {
+    require(version >= 0)
+    require(allFiles.exists(_.version == version))
+
+    val latestSnapshotFileBeforeVersion = allFiles
+      .filter(_.isSnapshot)
+      .takeWhile(_.version <= version)
+      .lastOption
+    val deltaBatchFiles = latestSnapshotFileBeforeVersion match {
+      case Some(snapshotFile) =>
+        val deltaFiles = allFiles.filter { file =>
+          file.version > snapshotFile.version && file.version <= version
+        }
+        verify(
+          deltaFiles.size == version - snapshotFile.version,
+          s"Unexpected list of delta files for version $version for $this: $deltaFiles"
+        )
+        deltaFiles
+
+      case None =>
+        allFiles.takeWhile(_.version <= version)
+    }
+    latestSnapshotFileBeforeVersion.toSeq ++ deltaBatchFiles
+  }
+
+  /** Fetch all the files that back the store */
+  private def fetchFiles(): Seq[StoreFile] = {
+    val files: Seq[FileStatus] = try {
+      fs.listStatus(baseDir)
+    } catch {
+      case _: java.io.FileNotFoundException =>
+        Seq.empty
+    }
+    val versionToFiles = new mutable.HashMap[Long, StoreFile]
+    files.foreach { status =>
+      val path = status.getPath
+      val nameParts = path.getName.split("\\.")
+      if (nameParts.size == 2) {
+        val version = nameParts(0).toLong
+        nameParts(1).toLowerCase match {
+          case "delta" =>
+            // Only track the delta file if a snapshot does not already exist for this version
+            if (!versionToFiles.contains(version)) {
+              versionToFiles.put(version, StoreFile(version, path, isSnapshot = false))
+            }
+          case "snapshot" =>
+            versionToFiles.put(version, StoreFile(version, path, isSnapshot = true))
+          case _ =>
+            logWarning(s"Could not identify file $path for $this")
+        }
+      }
+    }
+    val storeFiles = versionToFiles.values.toSeq.sortBy(_.version)
+    logDebug(s"Current set of files for $this: $storeFiles")
+    storeFiles
+  }
+
+  private def compressStream(outputStream: DataOutputStream): DataOutputStream = {
+    val compressed = new LZ4CompressionCodec(sparkConf).compressedOutputStream(outputStream)
+    new DataOutputStream(compressed)
+  }
+
+  private def decompressStream(inputStream: DataInputStream): DataInputStream = {
+    val compressed = new LZ4CompressionCodec(sparkConf).compressedInputStream(inputStream)
+    new DataInputStream(compressed)
+  }
+
+  private def deltaFile(version: Long): Path = {
+    new Path(baseDir, s"$version.delta")
+  }
+
+  private def snapshotFile(version: Long): Path = {
+    new Path(baseDir, s"$version.snapshot")
+  }
+
+  private def verify(condition: => Boolean, msg: String): Unit = {
+    if (!condition) {
+      throw new IllegalStateException(msg)
+    }
+  }
+}
+
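
For readers tracing the file format above: delta and snapshot files share a simple framing, reconstructed here from writeToDeltaFile, writeSnapshotFile, and finalizeDeltaFile. The stream is LZ4-compressed, and sizes are big-endian Ints (DataOutputStream semantics):

    [keySize: Int][key bytes][valueSize: Int][value bytes]   put or update
    [keySize: Int][key bytes][-1: Int]                       key removed (delta files only)
    [-1: Int]                                                end of stream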

http://git-wip-us.apache.org/repos/asf/spark/blob/8c826880/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
new file mode 100644
index 0000000..ca5c864
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -0,0 +1,247 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.util.Timer
+import java.util.concurrent.{ScheduledFuture, TimeUnit}
+
+import scala.collection.mutable
+import scala.util.control.NonFatal
+
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.ThreadUtils
+
+
+/** Unique identifier for a [[StateStore]] */
+case class StateStoreId(checkpointLocation: String, operatorId: Long, partitionId: Int)
+
+
+/**
+ * Base trait for a versioned key-value store used for streaming aggregations
+ */
+trait StateStore {
+
+  /** Unique identifier of the store */
+  def id: StateStoreId
+
+  /** Version of the data in this store before committing updates. */
+  def version: Long
+
+  /**
+   * Update the value of a key using the value generated by the update function.
+   * @note Do not mutate the retrieved value row as it will unexpectedly affect the previous
+   *       versions of the store data.
+   */
+  def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit
+
+  /**
+   * Remove keys that match the given condition.
+   */
+  def remove(condition: UnsafeRow => Boolean): Unit
+
+  /**
+   * Commit all the updates that have been made to the store, and return the new version.
+   */
+  def commit(): Long
+
+  /** Cancel all the updates that have been made to the store. */
+  def cancel(): Unit
+
+  /**
+   * Iterator of store data after a set of updates have been committed.
+   * This can be called only after commit() has been called in the current thread.
+   */
+  def iterator(): Iterator[(UnsafeRow, UnsafeRow)]
+
+  /**
+   * Iterator of the updates that have been committed.
+   * This can be called only after commit() has been called in the current thread.
+   */
+  def updates(): Iterator[StoreUpdate]
+
+  /**
+   * Whether all updates have been committed
+   */
+  def hasCommitted: Boolean
+}
+
+
+/** Trait representing a provider of a specific version of a [[StateStore]]. */
+trait StateStoreProvider {
+
+  /** Get the store with the existing version. */
+  def getStore(version: Long): StateStore
+
+  /** Optional method for providers to allow for background maintenance */
+  def doMaintenance(): Unit = { }
+}
+
+
+/** Trait representing updates made to a [[StateStore]]. */
+sealed trait StoreUpdate
+
+case class ValueAdded(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate
+
+case class ValueUpdated(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate
+
+case class KeyRemoved(key: UnsafeRow) extends StoreUpdate
+
+
+/**
+ * Companion object to [[StateStore]] that provides helper methods to create and retrieve stores
+ * by their unique ids. In addition, when a SparkContext is active (i.e. SparkEnv.get is not null),
+ * it also runs a periodic background task to do maintenance on the loaded stores. For each
+ * store, it uses the [[StateStoreCoordinator]] to check whether the currently loaded instance of
+ * the store is the active instance. Accordingly, it either keeps the store loaded and performs
+ * maintenance, or unloads it.
+ */
+private[state] object StateStore extends Logging {
+
+  val MAINTENANCE_INTERVAL_CONFIG = "spark.streaming.stateStore.maintenanceInterval"
+  val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60
+
+  private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]()
+  private val maintenanceTaskExecutor =
+    ThreadUtils.newDaemonSingleThreadScheduledExecutor("state-store-maintenance-task")
+
+  @volatile private var maintenanceTask: ScheduledFuture[_] = null
+  @volatile private var _coordRef: StateStoreCoordinatorRef = null
+
+  /** Get or create a store associated with the id. */
+  def get(
+      storeId: StateStoreId,
+      keySchema: StructType,
+      valueSchema: StructType,
+      version: Long,
+      storeConf: StateStoreConf,
+      hadoopConf: Configuration): StateStore = {
+    require(version >= 0)
+    val storeProvider = loadedProviders.synchronized {
+      startMaintenanceIfNeeded()
+      val provider = loadedProviders.getOrElseUpdate(
+        storeId,
+        new HDFSBackedStateStoreProvider(storeId, keySchema, valueSchema, storeConf, hadoopConf))
+      reportActiveStoreInstance(storeId)
+      provider
+    }
+    storeProvider.getStore(version)
+  }
+
+  /** Unload a state store provider */
+  def unload(storeId: StateStoreId): Unit = loadedProviders.synchronized {
+    loadedProviders.remove(storeId)
+  }
+
+  /** Whether a state store provider is loaded or not */
+  def isLoaded(storeId: StateStoreId): Boolean = loadedProviders.synchronized {
+    loadedProviders.contains(storeId)
+  }
+
+  /** Unload and stop all state store providers */
+  def stop(): Unit = loadedProviders.synchronized {
+    loadedProviders.clear()
+    _coordRef = null
+    if (maintenanceTask != null) {
+      maintenanceTask.cancel(false)
+      maintenanceTask = null
+    }
+    logInfo("StateStore stopped")
+  }
+
+  /** Start the periodic maintenance task, if not already started and if a SparkEnv is active */
+  private def startMaintenanceIfNeeded(): Unit = loadedProviders.synchronized {
+    val env = SparkEnv.get
+    if (maintenanceTask == null && env != null) {
+      val periodMs = env.conf.getTimeAsMs(
+        MAINTENANCE_INTERVAL_CONFIG, s"${MAINTENANCE_INTERVAL_DEFAULT_SECS}s")
+      val runnable = new Runnable {
+        override def run(): Unit = { doMaintenance() }
+      }
+      maintenanceTask = maintenanceTaskExecutor.scheduleAtFixedRate(
+        runnable, periodMs, periodMs, TimeUnit.MILLISECONDS)
+      logInfo("State Store maintenance task started")
+    }
+  }
+
+  /**
+   * Execute background maintenance task in all the loaded store providers if they are still
+   * the active instances according to the coordinator.
+   */
+  private def doMaintenance(): Unit = {
+    logDebug("Doing maintenance")
+    loadedProviders.synchronized { loadedProviders.toSeq }.foreach { case (id, provider) =>
+      try {
+        if (verifyIfStoreInstanceActive(id)) {
+          provider.doMaintenance()
+        } else {
+          unload(id)
+          logInfo(s"Unloaded $provider")
+        }
+      } catch {
+        case NonFatal(e) =>
+          logWarning(s"Error managing $provider")
+      }
+    }
+  }
+
+  private def reportActiveStoreInstance(storeId: StateStoreId): Unit = {
+    try {
+      val host = SparkEnv.get.blockManager.blockManagerId.host
+      val executorId = SparkEnv.get.blockManager.blockManagerId.executorId
+      coordinatorRef.foreach(_.reportActiveInstance(storeId, host, executorId))
+      logDebug(s"Reported that the loaded instance $storeId is active")
+    } catch {
+      case NonFatal(e) =>
+        logWarning(s"Error reporting active instance of $storeId")
+    }
+  }
+
+  private def verifyIfStoreInstanceActive(storeId: StateStoreId): Boolean = {
+    try {
+      val executorId = SparkEnv.get.blockManager.blockManagerId.executorId
+      val verified =
+        coordinatorRef.map(_.verifyIfInstanceActive(storeId, executorId)).getOrElse(false)
+      logDebug(s"Verifyied whether the loaded instance $storeId is active: $verified" )
+      verified
+    } catch {
+      case NonFatal(e) =>
+        logWarning(s"Error verifying active instance of $storeId")
+        false
+    }
+  }
+
+  private def coordinatorRef: Option[StateStoreCoordinatorRef] = synchronized {
+    val env = SparkEnv.get
+    if (env != null) {
+      if (_coordRef == null) {
+        _coordRef = StateStoreCoordinatorRef.forExecutor(env)
+      }
+      logDebug(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}")
+      Some(_coordRef)
+    } else {
+      _coordRef = null
+      None
+    }
+  }
+}
+
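
One operational note on the companion object above: the maintenance interval that drives snapshotting and cleanup comes from the Spark conf key defined in this file. A hedged example of tuning it (the key name and its 60s default are from the code above; the value here is illustrative):

    val conf = new SparkConf()
      .set("spark.streaming.stateStore.maintenanceInterval", "30s")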

http://git-wip-us.apache.org/repos/asf/spark/blob/8c826880/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
new file mode 100644
index 0000000..cca22a0
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.apache.spark.sql.internal.SQLConf
+
+/** A class that contains configuration parameters for [[StateStore]]s. */
+private[state] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable {
+
+  def this() = this(new SQLConf)
+
+  import SQLConf._
+
+  val maxDeltasForSnapshot = conf.getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT)
+
+  val minVersionsToRetain = conf.getConf(STATE_STORE_MIN_VERSIONS_TO_RETAIN)
+}
+
+private[state] object StateStoreConf {
+  val empty = new StateStoreConf()
+}
+

http://git-wip-us.apache.org/repos/asf/spark/blob/8c826880/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
new file mode 100644
index 0000000..5aa0636
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import scala.collection.mutable
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
+import org.apache.spark.scheduler.ExecutorCacheTaskLocation
+import org.apache.spark.util.RpcUtils
+
+/** Trait representing all messages to [[StateStoreCoordinator]] */
+private sealed trait StateStoreCoordinatorMessage extends Serializable
+
+/** Classes representing messages */
+private case class ReportActiveInstance(storeId: StateStoreId, host: String, executorId: String)
+  extends StateStoreCoordinatorMessage
+
+private case class VerifyIfInstanceActive(storeId: StateStoreId, executorId: String)
+  extends StateStoreCoordinatorMessage
+
+private case class GetLocation(storeId: StateStoreId)
+  extends StateStoreCoordinatorMessage
+
+private case class DeactivateInstances(storeRootLocation: String)
+  extends StateStoreCoordinatorMessage
+
+private object StopCoordinator
+  extends StateStoreCoordinatorMessage
+
+/** Helper object used to create reference to [[StateStoreCoordinator]]. */
+private[sql] object StateStoreCoordinatorRef extends Logging {
+
+  private val endpointName = "StateStoreCoordinator"
+
+  /**
+   * Create a reference to a [[StateStoreCoordinator]]. This can be called from the driver as
+   * well as the executors.
+   */
+  def forDriver(env: SparkEnv): StateStoreCoordinatorRef = synchronized {
+    try {
+      val coordinator = new StateStoreCoordinator(env.rpcEnv)
+      val coordinatorRef = env.rpcEnv.setupEndpoint(endpointName, coordinator)
+      logInfo("Registered StateStoreCoordinator endpoint")
+      new StateStoreCoordinatorRef(coordinatorRef)
+    } catch {
+      case e: IllegalArgumentException =>
+        val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName, env.conf, env.rpcEnv)
+        logDebug("Retrieved existing StateStoreCoordinator endpoint")
+        new StateStoreCoordinatorRef(rpcEndpointRef)
+    }
+  }
+
+  def forExecutor(env: SparkEnv): StateStoreCoordinatorRef = synchronized {
+    val rpcEndpointRef = RpcUtils.makeDriverRef(endpointName, env.conf, env.rpcEnv)
+    logDebug("Retrieved existing StateStoreCoordinator endpoint")
+    new StateStoreCoordinatorRef(rpcEndpointRef)
+  }
+}
+
+/**
+ * Reference to a [[StateStoreCoordinator]] that can be used to coordinate instances of
+ * [[StateStore]]s across all the executors, and to get their locations for job scheduling.
+ */
+private[sql] class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) {
+
+  private[state] def reportActiveInstance(
+      storeId: StateStoreId,
+      host: String,
+      executorId: String): Unit = {
+    rpcEndpointRef.send(ReportActiveInstance(storeId, host, executorId))
+  }
+
+  /** Verify whether the given executor has the active instance of a state store */
+  private[state] def verifyIfInstanceActive(storeId: StateStoreId, executorId: String): Boolean = {
+    rpcEndpointRef.askWithRetry[Boolean](VerifyIfInstanceActive(storeId, executorId))
+  }
+
+  /** Get the location of the state store */
+  private[state] def getLocation(storeId: StateStoreId): Option[String] = {
+    rpcEndpointRef.askWithRetry[Option[String]](GetLocation(storeId))
+  }
+
+  /** Deactivate all store instances under the given checkpoint location */
+  private[state] def deactivateInstances(storeRootLocation: String): Unit = {
+    rpcEndpointRef.askWithRetry[Boolean](DeactivateInstances(storeRootLocation))
+  }
+
+  private[state] def stop(): Unit = {
+    rpcEndpointRef.askWithRetry[Boolean](StopCoordinator)
+  }
+}
+
+
+/**
+ * Class for coordinating instances of [[StateStore]]s loaded in executors across the cluster,
+ * and for getting their locations for job scheduling.
+ */
+private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends RpcEndpoint {
+  private val instances = new mutable.HashMap[StateStoreId, ExecutorCacheTaskLocation]
+
+  override def receive: PartialFunction[Any, Unit] = {
+    case ReportActiveInstance(id, host, executorId) =>
+      instances.put(id, ExecutorCacheTaskLocation(host, executorId))
+  }
+
+  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+    case VerifyIfInstanceActive(id, execId) =>
+      val response = instances.get(id) match {
+        case Some(location) => location.executorId == execId
+        case None => false
+      }
+      context.reply(response)
+
+    case GetLocation(id) =>
+      context.reply(instances.get(id).map(_.toString))
+
+    case DeactivateInstances(loc) =>
+      val storeIdsToRemove =
+        instances.keys.filter(_.checkpointLocation == loc).toSeq
+      instances --= storeIdsToRemove
+      context.reply(true)
+
+    case StopCoordinator =>
+      stop() // Stop before replying to ensure that endpoint name has been deregistered
+      context.reply(true)
+  }
+}
+
+
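
A hedged sketch of the coordinator round trip implied by the endpoints above (these methods are private[state]; reportActiveInstance is a one-way send, so callers that immediately verify must poll, as the test suite below does with eventually):

    val coordRef = StateStoreCoordinatorRef.forDriver(sparkEnv)
    coordRef.reportActiveInstance(storeId, "host1", "exec1")  // fire-and-forget send
    coordRef.verifyIfInstanceActive(storeId, "exec1")         // blocking ask, eventually true
    coordRef.getLocation(storeId)                             // eventually Some(location string)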

http://git-wip-us.apache.org/repos/asf/spark/blob/8c826880/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
new file mode 100644
index 0000000..3318660
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.{Partition, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.{SerializableConfiguration, Utils}
+
+/**
+ * An RDD that allows computations to be executed against [[StateStore]]s. It
+ * uses the [[StateStoreCoordinator]] to determine the locations of loaded state stores,
+ * which are used as preferred locations.
+ */
+class StateStoreRDD[T: ClassTag, U: ClassTag](
+    dataRDD: RDD[T],
+    storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U],
+    checkpointLocation: String,
+    operatorId: Long,
+    storeVersion: Long,
+    keySchema: StructType,
+    valueSchema: StructType,
+    storeConf: StateStoreConf,
+    @transient private val storeCoordinator: Option[StateStoreCoordinatorRef])
+  extends RDD[U](dataRDD) {
+
+  // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
+  private val confBroadcast = dataRDD.context.broadcast(
+    new SerializableConfiguration(dataRDD.context.hadoopConfiguration))
+
+  override protected def getPartitions: Array[Partition] = dataRDD.partitions
+
+  override def getPreferredLocations(partition: Partition): Seq[String] = {
+    val storeId = StateStoreId(checkpointLocation, operatorId, partition.index)
+    storeCoordinator.flatMap(_.getLocation(storeId)).toSeq
+  }
+
+  override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = {
+    var store: StateStore = null
+
+    Utils.tryWithSafeFinally {
+      val storeId = StateStoreId(checkpointLocation, operatorId, partition.index)
+      store = StateStore.get(
+        storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value)
+      val inputIter = dataRDD.iterator(partition, ctxt)
+      val outputIter = storeUpdateFunction(store, inputIter)
+      assert(store.hasCommitted)
+      outputIter
+    } {
+      if (store != null) store.cancel()
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/8c826880/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
new file mode 100644
index 0000000..b249e37
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.types.StructType
+
+package object state {
+
+  implicit class StateStoreOps[T: ClassTag](dataRDD: RDD[T]) {
+
+    /** Map each partition of an RDD along with data in a [[StateStore]]. */
+    def mapPartitionWithStateStore[U: ClassTag](
+        storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U],
+        checkpointLocation: String,
+        operatorId: Long,
+        storeVersion: Long,
+        keySchema: StructType,
+        valueSchema: StructType
+      )(implicit sqlContext: SQLContext): StateStoreRDD[T, U] = {
+
+      mapPartitionWithStateStore(
+        storeUpdateFunction,
+        checkpointLocation,
+        operatorId,
+        storeVersion,
+        keySchema,
+        valueSchema,
+        new StateStoreConf(sqlContext.conf),
+        Some(sqlContext.streams.stateStoreCoordinator))
+    }
+
+    /** Map each partition of an RDD along with data in a [[StateStore]]. */
+    private[state] def mapPartitionWithStateStore[U: ClassTag](
+        storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U],
+        checkpointLocation: String,
+        operatorId: Long,
+        storeVersion: Long,
+        keySchema: StructType,
+        valueSchema: StructType,
+        storeConf: StateStoreConf,
+        storeCoordinator: Option[StateStoreCoordinatorRef]
+      ): StateStoreRDD[T, U] = {
+      val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction)
+      new StateStoreRDD(
+        dataRDD,
+        cleanedF,
+        checkpointLocation,
+        operatorId,
+        storeVersion,
+        keySchema,
+        valueSchema,
+        storeConf,
+        storeCoordinator)
+    }
+  }
+}
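
Putting the pieces together, a hedged sketch of driving a partition's data through a store via the implicit above (keyFor, mergeInput, and isExpired are illustrative helpers, not part of this patch; an implicit SQLContext must be in scope, and the function must commit before returning because StateStoreRDD asserts hasCommitted):

    val updatesRDD = dataRDD.mapPartitionWithStateStore(
      storeUpdateFunction = (store, iter) => {
        iter.foreach(row => store.update(keyFor(row), old => mergeInput(old, row)))
        store.remove(isExpired)
        store.commit()
        store.updates()  // emit this version's changes downstream
      },
      checkpointLocation = "/tmp/checkpoint",
      operatorId = 0L,
      storeVersion = currentVersion,
      keySchema = keySchema,
      valueSchema = valueSchema)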

http://git-wip-us.apache.org/repos/asf/spark/blob/8c826880/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index fd1d77f..863a876 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -524,6 +524,19 @@ object SQLConf {
     doc = "When true, the planner will try to find out duplicated exchanges and re-use them.",
     isPublic = false)
 
+  val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT = intConf(
+    "spark.sql.streaming.stateStore.minDeltasForSnapshot",
+    defaultValue = Some(10),
+    doc = "Minimum number of state store delta files that needs to be generated before they " +
+      "consolidated into snapshots.",
+    isPublic = false)
+
+  val STATE_STORE_MIN_VERSIONS_TO_RETAIN = intConf(
+    "spark.sql.streaming.stateStore.minBatchesToRetain",
+    defaultValue = Some(2),
+    doc = "Minimum number of versions of a state store's data to retain after cleaning.",
+    isPublic = false)
+
   val CHECKPOINT_LOCATION = stringConf("spark.sql.streaming.checkpointLocation",
     defaultValue = None,
     doc = "The default location for storing checkpoint data for continuously executing queries.",

http://git-wip-us.apache.org/repos/asf/spark/blob/8c826880/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala
new file mode 100644
index 0000000..c99c2f5
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite}
+import org.apache.spark.scheduler.ExecutorCacheTaskLocation
+
+class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext {
+
+  import StateStoreCoordinatorSuite._
+
+  test("report, verify, getLocation") {
+    withCoordinatorRef(sc) { coordinatorRef =>
+      val id = StateStoreId("x", 0, 0)
+
+      assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === false)
+      assert(coordinatorRef.getLocation(id) === None)
+
+      coordinatorRef.reportActiveInstance(id, "hostX", "exec1")
+      eventually(timeout(5 seconds)) {
+        assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === true)
+        assert(
+          coordinatorRef.getLocation(id) ===
+            Some(ExecutorCacheTaskLocation("hostX", "exec1").toString))
+      }
+
+      coordinatorRef.reportActiveInstance(id, "hostX", "exec2")
+
+      eventually(timeout(5 seconds)) {
+        assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === false)
+        assert(coordinatorRef.verifyIfInstanceActive(id, "exec2") === true)
+
+        assert(
+          coordinatorRef.getLocation(id) ===
+            Some(ExecutorCacheTaskLocation("hostX", "exec2").toString))
+      }
+    }
+  }
+
+  test("make inactive") {
+    withCoordinatorRef(sc) { coordinatorRef =>
+      val id1 = StateStoreId("x", 0, 0)
+      val id2 = StateStoreId("y", 1, 0)
+      val id3 = StateStoreId("x", 0, 1)
+      val host = "hostX"
+      val exec = "exec1"
+
+      coordinatorRef.reportActiveInstance(id1, host, exec)
+      coordinatorRef.reportActiveInstance(id2, host, exec)
+      coordinatorRef.reportActiveInstance(id3, host, exec)
+
+      eventually(timeout(5 seconds)) {
+        assert(coordinatorRef.verifyIfInstanceActive(id1, exec) === true)
+        assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === true)
+        assert(coordinatorRef.verifyIfInstanceActive(id3, exec) === true)
+
+      }
+
+      coordinatorRef.deactivateInstances("x")
+
+      assert(coordinatorRef.verifyIfInstanceActive(id1, exec) === false)
+      assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === true)
+      assert(coordinatorRef.verifyIfInstanceActive(id3, exec) === false)
+
+      assert(coordinatorRef.getLocation(id1) === None)
+      assert(
+        coordinatorRef.getLocation(id2) ===
+          Some(ExecutorCacheTaskLocation(host, exec).toString))
+      assert(coordinatorRef.getLocation(id3) === None)
+
+      coordinatorRef.deactivateInstances("y")
+      assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === false)
+      assert(coordinatorRef.getLocation(id2) === None)
+    }
+  }
+
+  test("multiple references have same underlying coordinator") {
+    withCoordinatorRef(sc) { coordRef1 =>
+      val coordRef2 = StateStoreCoordinatorRef.forDriver(sc.env)
+
+      val id = StateStoreId("x", 0, 0)
+
+      coordRef1.reportActiveInstance(id, "hostX", "exec1")
+
+      eventually(timeout(5 seconds)) {
+        assert(coordRef2.verifyIfInstanceActive(id, "exec1") === true)
+        assert(
+          coordRef2.getLocation(id) ===
+            Some(ExecutorCacheTaskLocation("hostX", "exec1").toString))
+      }
+    }
+  }
+}
+
+object StateStoreCoordinatorSuite {
+  def withCoordinatorRef(sc: SparkContext)(body: StateStoreCoordinatorRef => Unit): Unit = {
+    var coordinatorRef: StateStoreCoordinatorRef = null
+    try {
+      coordinatorRef = StateStoreCoordinatorRef.forDriver(sc.env)
+      body(coordinatorRef)
+    } finally {
+      if (coordinatorRef != null) coordinatorRef.stop()
+    }
+  }
+}

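For reference, the coordinator protocol exercised by this suite reduces to a
few calls on StateStoreCoordinatorRef. The sketch below is illustrative only,
not part of the patch; it assumes a live SparkContext sc, and the checkpoint
path is a placeholder.

    // Driver-side view of the coordinator API, mirroring the tests above.
    val coordinator = StateStoreCoordinatorRef.forDriver(sc.env)
    val id = StateStoreId("/checkpoint/path", 0, 0)  // (path, operator id, partition id)

    // Executors report ownership of a store instance; reporting is asynchronous,
    // which is why the tests above wrap the follow-up checks in eventually().
    coordinator.reportActiveInstance(id, "hostX", "exec1")
    coordinator.verifyIfInstanceActive(id, "exec1")  // true once the report is processed
    coordinator.getLocation(id)                      // Some(location string for hostX/exec1)

    // Deactivating a checkpoint location invalidates all of its store instances.
    coordinator.deactivateInstances("/checkpoint/path")
    coordinator.stop()
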
http://git-wip-us.apache.org/repos/asf/spark/blob/8c826880/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
new file mode 100644
index 0000000..24cec30
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.io.File
+import java.nio.file.Files
+
+import scala.util.Random
+
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
+
+import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.LocalSparkContext._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.scheduler.ExecutorCacheTaskLocation
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.util.quietly
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.apache.spark.util.Utils
+
+class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll {
+
+  private val sparkConf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName)
+  private val tempDir = Files.createTempDirectory("StateStoreRDDSuite").toString
+  private val keySchema = StructType(Seq(StructField("key", StringType, true)))
+  private val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
+
+  import StateStoreSuite._
+
+  after {
+    StateStore.stop()
+  }
+
+  override def afterAll(): Unit = {
+    super.afterAll()
+    Utils.deleteRecursively(new File(tempDir))
+  }
+
+  test("versioning and immutability") {
+    quietly {
+      withSpark(new SparkContext(sparkConf)) { sc =>
+        implicit val sqlContext = new SQLContext(sc)
+        val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
+        val opId = 0
+        val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
+          increment, path, opId, storeVersion = 0, keySchema, valueSchema)
+        assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+
+        // Generate next version of stores
+        val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore(
+          increment, path, opId, storeVersion = 1, keySchema, valueSchema)
+        assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
+
+        // Make sure the previous RDD still has the same data.
+        assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+      }
+    }
+  }
+
+  test("recovering from files") {
+    quietly {
+      val opId = 0
+      val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
+
+      def makeStoreRDD(
+          sc: SparkContext,
+          seq: Seq[String],
+          storeVersion: Int): RDD[(String, Int)] = {
+        implicit val sqlContext = new SQLContext(sc)
+        makeRDD(sc, Seq("a")).mapPartitionWithStateStore(
+          increment, path, opId, storeVersion, keySchema, valueSchema)
+      }
+
+      // Generate RDDs and state store data
+      withSpark(new SparkContext(sparkConf)) { sc =>
+        for (i <- 1 to 20) {
+          require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i))
+        }
+      }
+
+      // With a new context, try using the earlier state store data
+      withSpark(new SparkContext(sparkConf)) { sc =>
+        assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21))
+      }
+    }
+  }
+
+  test("preferred locations using StateStoreCoordinator") {
+    quietly {
+      val opId = 0
+      val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
+
+      withSpark(new SparkContext(sparkConf)) { sc =>
+        implicit val sqlContext = new SQLContext(sc)
+        val coordinatorRef = sqlContext.streams.stateStoreCoordinator
+        coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1")
+        coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2")
+        assert(
+          coordinatorRef.getLocation(StateStoreId(path, opId, 0)) ===
+            Some(ExecutorCacheTaskLocation("host1", "exec1").toString))
+
+        val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
+          increment, path, opId, storeVersion = 0, keySchema, valueSchema)
+        require(rdd.partitions.length === 2)
+
+        assert(
+          rdd.preferredLocations(rdd.partitions(0)) ===
+            Seq(ExecutorCacheTaskLocation("host1", "exec1").toString))
+
+        assert(
+          rdd.preferredLocations(rdd.partitions(1)) ===
+            Seq(ExecutorCacheTaskLocation("host2", "exec2").toString))
+
+        rdd.collect()
+      }
+    }
+  }
+
+  test("distributed test") {
+    quietly {
+      withSpark(new SparkContext(sparkConf.clone.setMaster("local-cluster[2, 1, 1024]"))) { sc =>
+        implicit val sqlContext = new SQLContext(sc)
+        val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
+        val opId = 0
+        val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
+          increment, path, opId, storeVersion = 0, keySchema, valueSchema)
+        assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+
+        // Generate next version of stores
+        val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore(
+          increment, path, opId, storeVersion = 1, keySchema, valueSchema)
+        assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
+
+        // Make sure the previous RDD still has the same data.
+        assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+      }
+    }
+  }
+
+  private def makeRDD(sc: SparkContext, seq: Seq[String]): RDD[String] = {
+    sc.makeRDD(seq, 2).groupBy(x => x).flatMap(_._2)
+  }
+
+  private val increment = (store: StateStore, iter: Iterator[String]) => {
+    iter.foreach { s =>
+      store.update(
+        stringToRow(s), oldRow => {
+          val oldValue = oldRow.map(rowToInt).getOrElse(0)
+          intToRow(oldValue + 1)
+        })
+    }
+    store.commit()
+    store.iterator().map(rowsToStringInt)
+  }
+}

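The recurring pattern in these tests, namely read store version N, fold a batch
of records into it, commit version N+1, and emit the updated rows, condenses to
the sketch below. This is illustrative only, not part of the patch:
checkpointPath is a placeholder, sqlContext must be implicit in scope, and
keySchema, valueSchema and the row helpers (stringToRow, rowToInt, intToRow,
rowsToStringInt) come from the suite above.

    implicit val sqlContext = new SQLContext(sc)
    val checkpointPath = "/tmp/state"  // placeholder checkpoint location

    // Fold a partition of strings into running counts, commit, emit the counts.
    val increment = (store: StateStore, rows: Iterator[String]) => {
      rows.foreach { s =>
        store.update(stringToRow(s), old => intToRow(old.map(rowToInt).getOrElse(0) + 1))
      }
      store.commit()
      store.iterator().map(rowsToStringInt)
    }

    // Batch 1 starts from version 0 and commits version 1 ...
    val batch1 = sc.makeRDD(Seq("a", "b", "a"), 2).groupBy(x => x).flatMap(_._2)
      .mapPartitionWithStateStore(increment, checkpointPath, 0, storeVersion = 0,
        keySchema, valueSchema)
    // ... and batch 2 reads version 1, so counts accumulate across batches.
    val batch2 = sc.makeRDD(Seq("a", "c"), 2).groupBy(x => x).flatMap(_._2)
      .mapPartitionWithStateStore(increment, checkpointPath, 0, storeVersion = 1,
        keySchema, valueSchema)
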
http://git-wip-us.apache.org/repos/asf/spark/blob/8c826880/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
new file mode 100644
index 0000000..22b2f4f
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -0,0 +1,562 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.io.File
+
+import scala.collection.mutable
+import scala.util.Random
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.scalatest.{BeforeAndAfter, PrivateMethodTester}
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite}
+import org.apache.spark.LocalSparkContext._
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.util.quietly
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.Utils
+
+class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester {
+  type MapType = mutable.HashMap[UnsafeRow, UnsafeRow]
+
+  import StateStoreCoordinatorSuite._
+  import StateStoreSuite._
+
+  private val tempDir = Utils.createTempDir().toString
+  private val keySchema = StructType(Seq(StructField("key", StringType, true)))
+  private val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
+
+  after {
+    StateStore.stop()
+  }
+
+  test("update, remove, commit, and all data iterator") {
+    val provider = newStoreProvider()
+
+    // Verify state before starting a new set of updates
+    assert(provider.latestIterator().isEmpty)
+
+    val store = provider.getStore(0)
+    assert(!store.hasCommitted)
+    intercept[IllegalStateException] {
+      store.iterator()
+    }
+    intercept[IllegalStateException] {
+      store.updates()
+    }
+
+    // Verify state after updating
+    update(store, "a", 1)
+    intercept[IllegalStateException] {
+      store.iterator()
+    }
+    intercept[IllegalStateException] {
+      store.updates()
+    }
+    assert(provider.latestIterator().isEmpty)
+
+    // Make updates, commit and then verify state
+    update(store, "b", 2)
+    update(store, "aa", 3)
+    remove(store, _.startsWith("a"))
+    assert(store.commit() === 1)
+
+    assert(store.hasCommitted)
+    assert(rowsToSet(store.iterator()) === Set("b" -> 2))
+    assert(rowsToSet(provider.latestIterator()) === Set("b" -> 2))
+    assert(fileExists(provider, version = 1, isSnapshot = false))
+
+    assert(getDataFromFiles(provider) === Set("b" -> 2))
+
+    // Trying to get newer versions should fail
+    intercept[Exception] {
+      provider.getStore(2)
+    }
+    intercept[Exception] {
+      getDataFromFiles(provider, 2)
+    }
+
+    // Updates to a reloaded store create a new version without changing the old version
+    val reloadedProvider = new HDFSBackedStateStoreProvider(
+      store.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration)
+    val reloadedStore = reloadedProvider.getStore(1)
+    update(reloadedStore, "c", 4)
+    assert(reloadedStore.commit() === 2)
+    assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4))
+    assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4))
+    assert(getDataFromFiles(provider, version = 1) === Set("b" -> 2))
+    assert(getDataFromFiles(provider, version = 2) === Set("b" -> 2, "c" -> 4))
+  }
+
+  test("updates iterator with all combos of updates and removes") {
+    val provider = newStoreProvider()
+    var currentVersion: Int = 0
+    def withStore(body: StateStore => Unit): Unit = {
+      val store = provider.getStore(currentVersion)
+      body(store)
+      currentVersion += 1
+    }
+
+    // New keys should be seen in updates as values added, even if they were updated multiple times
+    withStore { store =>
+      update(store, "a", 1)
+      update(store, "aa", 1)
+      update(store, "aa", 2)
+      store.commit()
+      assert(updatesToSet(store.updates()) === Set(Added("a", 1), Added("aa", 2)))
+      assert(rowsToSet(store.iterator()) === Set("a" -> 1, "aa" -> 2))
+    }
+
+    // Multiple updates to the same key should be collapsed in the updates as a single value update
+    // Keys that have not been updated should not appear in the updates
+    withStore { store =>
+      update(store, "a", 4)
+      update(store, "a", 6)
+      store.commit()
+      assert(updatesToSet(store.updates()) === Set(Updated("a", 6)))
+      assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2))
+    }
+
+    // Keys added, updated and finally removed before commit should not appear in updates
+    withStore { store =>
+      update(store, "b", 4)     // Added, finally removed
+      update(store, "bb", 5)    // Added, updated, finally removed
+      update(store, "bb", 6)
+      remove(store, _.startsWith("b"))
+      store.commit()
+      assert(updatesToSet(store.updates()) === Set.empty)
+      assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2))
+    }
+
+    // Removed data should be seen in updates as a key removed
+    // Removed, but re-added data should be seen in updates as a value update
+    withStore { store =>
+      remove(store, _.startsWith("a"))
+      update(store, "a", 10)
+      store.commit()
+      assert(updatesToSet(store.updates()) === Set(Updated("a", 10), Removed("aa")))
+      assert(rowsToSet(store.iterator()) === Set("a" -> 10))
+    }
+  }
+
+  test("cancel") {
+    val provider = newStoreProvider()
+    val store = provider.getStore(0)
+    update(store, "a", 1)
+    store.commit()
+    assert(rowsToSet(store.iterator()) === Set("a" -> 1))
+
+    // cancel() should not change the data in the files
+    val store1 = provider.getStore(1)
+    update(store1, "b", 1)
+    store1.cancel()
+    assert(getDataFromFiles(provider) === Set("a" -> 1))
+  }
+
+  test("getStore with unexpected versions") {
+    val provider = newStoreProvider()
+
+    intercept[IllegalArgumentException] {
+      provider.getStore(-1)
+    }
+
+    // Prepare some data in the store
+    val store = provider.getStore(0)
+    update(store, "a", 1)
+    assert(store.commit() === 1)
+    assert(rowsToSet(store.iterator()) === Set("a" -> 1))
+
+    intercept[IllegalStateException] {
+      provider.getStore(2)
+    }
+
+    // Update store version with some data
+    val store1 = provider.getStore(1)
+    update(store1, "b", 1)
+    assert(store1.commit() === 2)
+    assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1))
+    assert(getDataFromFiles(provider) === Set("a" -> 1, "b" -> 1))
+
+    // Overwrite the version with other data
+    val store2 = provider.getStore(1)
+    update(store2, "c", 1)
+    assert(store2.commit() === 2)
+    assert(rowsToSet(store2.iterator()) === Set("a" -> 1, "c" -> 1))
+    assert(getDataFromFiles(provider) === Set("a" -> 1, "c" -> 1))
+  }
+
+  test("snapshotting") {
+    val provider = newStoreProvider(minDeltasForSnapshot = 5)
+
+    var currentVersion = 0
+    def updateVersionTo(targetVersion: Int): Unit = {
+      for (i <- currentVersion + 1 to targetVersion) {
+        val store = provider.getStore(currentVersion)
+        update(store, "a", i)
+        store.commit()
+        currentVersion += 1
+      }
+      require(currentVersion === targetVersion)
+    }
+
+    updateVersionTo(2)
+    require(getDataFromFiles(provider) === Set("a" -> 2))
+    provider.doMaintenance()               // should not generate snapshot files
+    assert(getDataFromFiles(provider) === Set("a" -> 2))
+
+    for (i <- 1 to currentVersion) {
+      assert(fileExists(provider, i, isSnapshot = false))  // all delta files present
+      assert(!fileExists(provider, i, isSnapshot = true))  // no snapshot files present
+    }
+
+    // After version 6, snapshotting should generate one snapshot file
+    updateVersionTo(6)
+    require(getDataFromFiles(provider) === Set("a" -> 6), "store not updated correctly")
+    provider.doMaintenance()       // should generate snapshot files
+
+    val snapshotVersion = (0 to 6).find(version => fileExists(provider, version, isSnapshot = true))
+    assert(snapshotVersion.nonEmpty, "snapshot file not generated")
+    deleteFilesEarlierThanVersion(provider, snapshotVersion.get)
+    assert(
+      getDataFromFiles(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get),
+      "snapshotting messed up the data of the snapshotted version")
+    assert(
+      getDataFromFiles(provider) === Set("a" -> 6),
+      "snapshotting messed up the data of the final version")
+
+    // After version 20, snapshotting should generate newer snapshot files
+    updateVersionTo(20)
+    require(getDataFromFiles(provider) === Set("a" -> 20), "store not updated correctly")
+    provider.doMaintenance()       // do snapshot
+
+    val latestSnapshotVersion = (0 to 20).filter(version =>
+      fileExists(provider, version, isSnapshot = true)).lastOption
+    assert(latestSnapshotVersion.nonEmpty, "no snapshot file found")
+    assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated")
+
+    deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get)
+    assert(getDataFromFiles(provider) === Set("a" -> 20), "snapshotting messed up the data")
+  }
+
+  test("cleaning") {
+    val provider = newStoreProvider(minDeltasForSnapshot = 5)
+
+    for (i <- 1 to 20) {
+      val store = provider.getStore(i - 1)
+      update(store, "a", i)
+      store.commit()
+      provider.doMaintenance() // do cleanup
+    }
+    require(
+      rowsToSet(provider.latestIterator()) === Set("a" -> 20),
+      "store not updated correctly")
+
+    assert(!fileExists(provider, version = 1, isSnapshot = false)) // first file should be deleted
+
+    // last couple of versions should be retrievable
+    assert(getDataFromFiles(provider, 20) === Set("a" -> 20))
+    assert(getDataFromFiles(provider, 19) === Set("a" -> 19))
+  }
+
+  test("corrupted file handling") {
+    val provider = newStoreProvider(minDeltasForSnapshot = 5)
+    for (i <- 1 to 6) {
+      val store = provider.getStore(i - 1)
+      update(store, "a", i)
+      store.commit()
+      provider.doMaintenance() // do cleanup
+    }
+    val snapshotVersion = (0 to 10).find(version =>
+      fileExists(provider, version, isSnapshot = true)).getOrElse(fail("snapshot file not found"))
+
+    // Corrupt snapshot file and verify that it throws error
+    assert(getDataFromFiles(provider, snapshotVersion) === Set("a" -> snapshotVersion))
+    corruptFile(provider, snapshotVersion, isSnapshot = true)
+    intercept[Exception] {
+      getDataFromFiles(provider, snapshotVersion)
+    }
+
+    // Corrupt delta file and verify that it throws error
+    assert(getDataFromFiles(provider, snapshotVersion - 1) === Set("a" -> (snapshotVersion - 1)))
+    corruptFile(provider, snapshotVersion - 1, isSnapshot = false)
+    intercept[Exception] {
+      getDataFromFiles(provider, snapshotVersion - 1)
+    }
+
+    // Delete delta file and verify that it throws error
+    deleteFilesEarlierThanVersion(provider, snapshotVersion)
+    intercept[Exception] {
+      getDataFromFiles(provider, snapshotVersion - 1)
+    }
+  }
+
+  test("StateStore.get") {
+    quietly {
+      val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString
+      val storeId = StateStoreId(dir, 0, 0)
+      val storeConf = StateStoreConf.empty
+      val hadoopConf = new Configuration()
+
+      // Verify that trying to get incorrect versions throws errors
+      intercept[IllegalArgumentException] {
+        StateStore.get(storeId, keySchema, valueSchema, -1, storeConf, hadoopConf)
+      }
+      assert(!StateStore.isLoaded(storeId)) // version -1 should not attempt to load the store
+
+      intercept[IllegalStateException] {
+        StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf)
+      }
+
+      // Increase version of the store
+      val store0 = StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf)
+      assert(store0.version === 0)
+      update(store0, "a", 1)
+      store0.commit()
+
+      assert(StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf).version == 1)
+      assert(StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf).version == 0)
+
+      // Verify that you can remove the store and still reload and use it
+      StateStore.unload(storeId)
+      assert(!StateStore.isLoaded(storeId))
+
+      val store1 = StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf)
+      assert(StateStore.isLoaded(storeId))
+      update(store1, "a", 2)
+      assert(store1.commit() === 2)
+      assert(rowsToSet(store1.iterator()) === Set("a" -> 2))
+    }
+  }
+
+  test("maintenance") {
+    val conf = new SparkConf()
+      .setMaster("local")
+      .setAppName("test")
+      .set(StateStore.MAINTENANCE_INTERVAL_CONFIG, "10ms")
+      .set("spark.rpc.numRetries", "1")
+    val opId = 0
+    val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString
+    val storeId = StateStoreId(dir, opId, 0)
+    val storeConf = StateStoreConf.empty
+    val hadoopConf = new Configuration()
+    val provider = new HDFSBackedStateStoreProvider(
+      storeId, keySchema, valueSchema, storeConf, hadoopConf)
+
+    quietly {
+      withSpark(new SparkContext(conf)) { sc =>
+        withCoordinatorRef(sc) { coordinatorRef =>
+          for (i <- 1 to 20) {
+            val store = StateStore.get(
+              storeId, keySchema, valueSchema, i - 1, storeConf, hadoopConf)
+            update(store, "a", i)
+            store.commit()
+          }
+          eventually(timeout(10 seconds)) {
+            assert(coordinatorRef.getLocation(storeId).nonEmpty, "active instance was not reported")
+          }
+
+          // Background maintenance should clean up and generate snapshots
+          eventually(timeout(10 seconds)) {
+            // Earliest delta file should get cleaned up
+            assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted")
+
+            // Some snapshots should have been generated
+            val snapshotVersions = (0 to 20).filter { version =>
+              fileExists(provider, version, isSnapshot = true)
+            }
+            assert(snapshotVersions.nonEmpty, "no snapshot file found")
+          }
+
+          // If the driver deactivates all instances of the store, then this instance
+          // should be unloaded
+          coordinatorRef.deactivateInstances(dir)
+          eventually(timeout(10 seconds)) {
+            assert(!StateStore.isLoaded(storeId))
+          }
+
+          // Reload the store and verify
+          StateStore.get(storeId, keySchema, valueSchema, 20, storeConf, hadoopConf)
+          assert(StateStore.isLoaded(storeId))
+
+          // If some other executor loads the store, then this instance should be unloaded
+          coordinatorRef.reportActiveInstance(storeId, "other-host", "other-exec")
+          eventually(timeout(10 seconds)) {
+            assert(!StateStore.isLoaded(storeId))
+          }
+
+          // Reload the store and verify
+          StateStore.get(storeId, keySchema, valueSchema, 20, storeConf, hadoopConf)
+          assert(StateStore.isLoaded(storeId))
+        }
+      }
+
+      // Verify that the instance is unloaded when the SparkContext is stopped
+      require(SparkEnv.get === null)
+      eventually(timeout(10 seconds)) {
+        assert(!StateStore.isLoaded(storeId))
+      }
+    }
+  }
+
+  def getDataFromFiles(
+      provider: HDFSBackedStateStoreProvider,
+      version: Int = -1): Set[(String, Int)] = {
+    val reloadedProvider = new HDFSBackedStateStoreProvider(
+      provider.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration)
+    if (version < 0) {
+      reloadedProvider.latestIterator().map(rowsToStringInt).toSet
+    } else {
+      reloadedProvider.iterator(version).map(rowsToStringInt).toSet
+    }
+  }
+
+  def assertMap(
+      testMapOption: Option[MapType],
+      expectedMap: Map[String, Int]): Unit = {
+    assert(testMapOption.nonEmpty, "no map present")
+    val convertedMap = testMapOption.get.map(rowsToStringInt)
+    assert(convertedMap === expectedMap)
+  }
+
+  def fileExists(
+      provider: HDFSBackedStateStoreProvider,
+      version: Long,
+      isSnapshot: Boolean): Boolean = {
+    val method = PrivateMethod[Path]('baseDir)
+    val basePath = provider invokePrivate method()
+    val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta"
+    val filePath = new File(basePath.toString, fileName)
+    filePath.exists
+  }
+
+  def deleteFilesEarlierThanVersion(provider: HDFSBackedStateStoreProvider, version: Long): Unit = {
+    val method = PrivateMethod[Path]('baseDir)
+    val basePath = provider invokePrivate method()
+    for (v <- 0 until version.toInt) {
+      for (isSnapshot <- Seq(false, true)) {
+        val fileName = if (isSnapshot) s"$v.snapshot" else s"$v.delta"
+        val filePath = new File(basePath.toString, fileName)
+        if (filePath.exists) filePath.delete()
+      }
+    }
+  }
+
+  def corruptFile(
+      provider: HDFSBackedStateStoreProvider,
+      version: Long,
+      isSnapshot: Boolean): Unit = {
+    val method = PrivateMethod[Path]('baseDir)
+    val basePath = provider invokePrivate method()
+    val fileName = if (isSnapshot) s"$version.snapshot" else s"$version.delta"
+    val filePath = new File(basePath.toString, fileName)
+    filePath.delete()
+    filePath.createNewFile()
+  }
+
+  def storeLoaded(storeId: StateStoreId): Boolean = {
+    val method = PrivateMethod[mutable.HashMap[StateStoreId, StateStore]]('loadedStores)
+    val loadedStores = StateStore invokePrivate method()
+    loadedStores.contains(storeId)
+  }
+
+  def unloadStore(storeId: StateStoreId): Boolean = {
+    val method = PrivateMethod('remove)
+    StateStore invokePrivate method(storeId)
+  }
+
+  def newStoreProvider(
+      opId: Long = Random.nextLong,
+      partition: Int = 0,
+      minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get
+    ): HDFSBackedStateStoreProvider = {
+    val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString
+    val sqlConf = new SQLConf()
+    sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot)
+    new HDFSBackedStateStoreProvider(
+      StateStoreId(dir, opId, partition),
+      keySchema,
+      valueSchema,
+      new StateStoreConf(sqlConf),
+      new Configuration())
+  }
+
+  def remove(store: StateStore, condition: String => Boolean): Unit = {
+    store.remove(row => condition(rowToString(row)))
+  }
+
+  private def update(store: StateStore, key: String, value: Int): Unit = {
+    store.update(stringToRow(key), _ => intToRow(value))
+  }
+}
+
+private[state] object StateStoreSuite {
+
+  /** Trait and classes mirroring [[StoreUpdate]] for testing store updates iterator */
+  trait TestUpdate
+  case class Added(key: String, value: Int) extends TestUpdate
+  case class Updated(key: String, value: Int) extends TestUpdate
+  case class Removed(key: String) extends TestUpdate
+
+  val strProj = UnsafeProjection.create(Array[DataType](StringType))
+  val intProj = UnsafeProjection.create(Array[DataType](IntegerType))
+
+  def stringToRow(s: String): UnsafeRow = {
+    strProj.apply(new GenericInternalRow(Array[Any](UTF8String.fromString(s)))).copy()
+  }
+
+  def intToRow(i: Int): UnsafeRow = {
+    intProj.apply(new GenericInternalRow(Array[Any](i))).copy()
+  }
+
+  def rowToString(row: UnsafeRow): String = {
+    row.getUTF8String(0).toString
+  }
+
+  def rowToInt(row: UnsafeRow): Int = {
+    row.getInt(0)
+  }
+
+  def rowsToIntInt(row: (UnsafeRow, UnsafeRow)): (Int, Int) = {
+    (rowToInt(row._1), rowToInt(row._2))
+  }
+
+  def rowsToStringInt(row: (UnsafeRow, UnsafeRow)): (String, Int) = {
+    (rowToString(row._1), rowToInt(row._2))
+  }
+
+  def rowsToSet(iterator: Iterator[(UnsafeRow, UnsafeRow)]): Set[(String, Int)] = {
+    iterator.map(rowsToStringInt).toSet
+  }
+
+  def updatesToSet(iterator: Iterator[StoreUpdate]): Set[TestUpdate] = {
+    iterator.map {
+      case ValueAdded(key, value) => Added(rowToString(key), rowToInt(value))
+      case ValueUpdated(key, value) => Updated(rowToString(key), rowToInt(value))
+      case KeyRemoved(key) => Removed(rowToString(key))
+    }.toSet
+  }
+}

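Taken together, the suite above pins down the store lifecycle: load a store at
version N, mutate it, commit to produce version N+1, and reload any committed
version later from the delta/snapshot files. A minimal sketch follows, which is
illustrative only and not part of the patch; it reuses the suite's schemas and
row helpers, and the path is a placeholder.

    val storeId = StateStoreId("/tmp/state", 0, 0)
    val storeConf = StateStoreConf.empty
    val hadoopConf = new Configuration()

    // Load version 0 (empty), apply updates, and commit version 1.
    val store = StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf)
    store.update(stringToRow("a"), _ => intToRow(1))
    store.remove(row => rowToString(row).startsWith("z"))  // removal API; a no-op here
    assert(store.commit() === 1)                           // commit() returns the new version
    assert(rowsToSet(store.iterator()) === Set("a" -> 1))

    // Any committed version can be reloaded later, even in a fresh JVM, because
    // the provider replays the files under the checkpoint path.
    val reloaded = StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf)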

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org