Posted to commits@spark.apache.org by ma...@apache.org on 2016/04/02 00:15:23 UTC

[1/2] spark git commit: [SPARK-14255][SQL] Streaming Aggregation

Repository: spark
Updated Branches:
  refs/heads/master 0b7d4966c -> 0fc4aaa71


http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
index 85db051..6be94eb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
@@ -33,7 +33,7 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.util.quietly
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{CompletionIterator, Utils}
 
 class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll {
 
@@ -54,62 +54,93 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
   }
 
   test("versioning and immutability") {
-    quietly {
-      withSpark(new SparkContext(sparkConf)) { sc =>
-        implicit val sqlContet = new SQLContext(sc)
-        val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
-        val increment = (store: StateStore, iter: Iterator[String]) => {
-          iter.foreach { s =>
-            store.update(
-              stringToRow(s), oldRow => {
-                val oldValue = oldRow.map(rowToInt).getOrElse(0)
-                intToRow(oldValue + 1)
-              })
-          }
-          store.commit()
-          store.iterator().map(rowsToStringInt)
-        }
-        val opId = 0
-        val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
-          increment, path, opId, storeVersion = 0, keySchema, valueSchema)
-        assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+    withSpark(new SparkContext(sparkConf)) { sc =>
+      val sqlContext = new SQLContext(sc)
+      val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
+      val opId = 0
+      val rdd1 =
+        makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore(
+            sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(
+            increment)
+      assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+
+      // Generate next version of stores
+      val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore(
+        sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment)
+      assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
+
+      // Make sure the previous RDD still has the same data.
+      assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+    }
+  }
 
-        // Generate next version of stores
-        val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore(
-          increment, path, opId, storeVersion = 1, keySchema, valueSchema)
-        assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
+  test("recovering from files") {
+    val opId = 0
+    val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
+
+    def makeStoreRDD(
+        sc: SparkContext,
+        seq: Seq[String],
+        storeVersion: Int): RDD[(String, Int)] = {
+      implicit val sqlContext = new SQLContext(sc)
+      makeRDD(sc, Seq("a")).mapPartitionsWithStateStore(
+        sqlContext, path, opId, storeVersion, keySchema, valueSchema)(increment)
+    }
 
-        // Make sure the previous RDD still has the same data.
-        assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+    // Generate RDDs and state store data
+    withSpark(new SparkContext(sparkConf)) { sc =>
+      for (i <- 1 to 20) {
+        require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i))
       }
     }
+
+    // With a new context, try using the earlier state store data
+    withSpark(new SparkContext(sparkConf)) { sc =>
+      assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21))
+    }
   }
 
-  test("recovering from files") {
-    quietly {
-      val opId = 0
+  test("usage with iterators - only gets and only puts") {
+    withSpark(new SparkContext(sparkConf)) { sc =>
+      implicit val sqlContext = new SQLContext(sc)
       val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
+      val opId = 0
 
-      def makeStoreRDD(
-          sc: SparkContext,
-          seq: Seq[String],
-          storeVersion: Int): RDD[(String, Int)] = {
-        implicit val sqlContext = new SQLContext(sc)
-        makeRDD(sc, Seq("a")).mapPartitionWithStateStore(
-          increment, path, opId, storeVersion, keySchema, valueSchema)
+      // Returns an iterator of the incremented value made into the store
+      def iteratorOfPuts(store: StateStore, iter: Iterator[String]): Iterator[(String, Int)] = {
+        val resIterator = iter.map { s =>
+          val key = stringToRow(s)
+          val oldValue = store.get(key).map(rowToInt).getOrElse(0)
+          val newValue = oldValue + 1
+          store.put(key, intToRow(newValue))
+          (s, newValue)
+        }
+        CompletionIterator[(String, Int), Iterator[(String, Int)]](resIterator, {
+          store.commit()
+        })
       }
 
-      // Generate RDDs and state store data
-      withSpark(new SparkContext(sparkConf)) { sc =>
-        for (i <- 1 to 20) {
-          require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i))
+      def iteratorOfGets(
+          store: StateStore,
+          iter: Iterator[String]): Iterator[(String, Option[Int])] = {
+        iter.map { s =>
+          val key = stringToRow(s)
+          val value = store.get(key).map(rowToInt)
+          (s, value)
         }
       }
 
-      // With a new context, try using the earlier state store data
-      withSpark(new SparkContext(sparkConf)) { sc =>
-        assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21))
-      }
+      val rddOfGets1 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionsWithStateStore(
+        sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets)
+      assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None))
+
+      val rddOfPuts = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore(
+        sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfPuts)
+      assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1))
+
+      val rddOfGets2 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionsWithStateStore(
+        sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(iteratorOfGets)
+      assert(rddOfGets2.collect().toSet === Set("a" -> Some(2), "b" -> Some(1), "c" -> None))
     }
   }
 
@@ -128,8 +159,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
           coordinatorRef.getLocation(StateStoreId(path, opId, 0)) ===
             Some(ExecutorCacheTaskLocation("host1", "exec1").toString))
 
-        val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
-          increment, path, opId, storeVersion = 0, keySchema, valueSchema)
+        val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore(
+          sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment)
         require(rdd.partitions.length === 2)
 
         assert(
@@ -148,27 +179,16 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
   test("distributed test") {
     quietly {
       withSpark(new SparkContext(sparkConf.setMaster("local-cluster[2, 1, 1024]"))) { sc =>
-        implicit val sqlContet = new SQLContext(sc)
+        implicit val sqlContext = new SQLContext(sc)
         val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
-        val increment = (store: StateStore, iter: Iterator[String]) => {
-          iter.foreach { s =>
-            store.update(
-              stringToRow(s), oldRow => {
-                val oldValue = oldRow.map(rowToInt).getOrElse(0)
-                intToRow(oldValue + 1)
-              })
-          }
-          store.commit()
-          store.iterator().map(rowsToStringInt)
-        }
         val opId = 0
-        val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
-          increment, path, opId, storeVersion = 0, keySchema, valueSchema)
+        val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore(
+          sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment)
         assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
 
         // Generate next version of stores
-        val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore(
-          increment, path, opId, storeVersion = 1, keySchema, valueSchema)
+        val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore(
+          sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment)
         assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
 
         // Make sure the previous RDD still has the same data.
@@ -183,11 +203,9 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
 
   private val increment = (store: StateStore, iter: Iterator[String]) => {
     iter.foreach { s =>
-      store.update(
-        stringToRow(s), oldRow => {
-          val oldValue = oldRow.map(rowToInt).getOrElse(0)
-          intToRow(oldValue + 1)
-        })
+      val key = stringToRow(s)
+      val oldValue = store.get(key).map(rowToInt).getOrElse(0)
+      store.put(key, intToRow(oldValue + 1))
     }
     store.commit()
     store.iterator().map(rowsToStringInt)

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index 22b2f4f..0e5936d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -51,7 +51,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     StateStore.stop()
   }
 
-  test("update, remove, commit, and all data iterator") {
+  test("get, put, remove, commit, and all data iterator") {
     val provider = newStoreProvider()
 
     // Verify state before starting a new set of updates
@@ -67,7 +67,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     }
 
     // Verify state after updating
-    update(store, "a", 1)
+    put(store, "a", 1)
     intercept[IllegalStateException] {
       store.iterator()
     }
@@ -77,8 +77,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     assert(provider.latestIterator().isEmpty)
 
     // Make updates, commit and then verify state
-    update(store, "b", 2)
-    update(store, "aa", 3)
+    put(store, "b", 2)
+    put(store, "aa", 3)
     remove(store, _.startsWith("a"))
     assert(store.commit() === 1)
 
@@ -101,7 +101,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     val reloadedProvider = new HDFSBackedStateStoreProvider(
       store.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration)
     val reloadedStore = reloadedProvider.getStore(1)
-    update(reloadedStore, "c", 4)
+    put(reloadedStore, "c", 4)
     assert(reloadedStore.commit() === 2)
     assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4))
     assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4))
@@ -112,6 +112,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
   test("updates iterator with all combos of updates and removes") {
     val provider = newStoreProvider()
     var currentVersion: Int = 0
+
     def withStore(body: StateStore => Unit): Unit = {
       val store = provider.getStore(currentVersion)
       body(store)
@@ -120,9 +121,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
     // New data should be seen in updates as value added, even if they had multiple updates
     withStore { store =>
-      update(store, "a", 1)
-      update(store, "aa", 1)
-      update(store, "aa", 2)
+      put(store, "a", 1)
+      put(store, "aa", 1)
+      put(store, "aa", 2)
       store.commit()
       assert(updatesToSet(store.updates()) === Set(Added("a", 1), Added("aa", 2)))
       assert(rowsToSet(store.iterator()) === Set("a" -> 1, "aa" -> 2))
@@ -131,8 +132,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     // Multiple updates to same key should be collapsed in the updates as a single value update
     // Keys that have not been updated should not appear in the updates
     withStore { store =>
-      update(store, "a", 4)
-      update(store, "a", 6)
+      put(store, "a", 4)
+      put(store, "a", 6)
       store.commit()
       assert(updatesToSet(store.updates()) === Set(Updated("a", 6)))
       assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2))
@@ -140,9 +141,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
     // Keys added, updated and finally removed before commit should not appear in updates
     withStore { store =>
-      update(store, "b", 4)     // Added, finally removed
-      update(store, "bb", 5)    // Added, updated, finally removed
-      update(store, "bb", 6)
+      put(store, "b", 4)     // Added, finally removed
+      put(store, "bb", 5)    // Added, updated, finally removed
+      put(store, "bb", 6)
       remove(store, _.startsWith("b"))
       store.commit()
       assert(updatesToSet(store.updates()) === Set.empty)
@@ -153,7 +154,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     // Removed, but re-added data should be seen in updates as a value update
     withStore { store =>
       remove(store, _.startsWith("a"))
-      update(store, "a", 10)
+      put(store, "a", 10)
       store.commit()
       assert(updatesToSet(store.updates()) === Set(Updated("a", 10), Removed("aa")))
       assert(rowsToSet(store.iterator()) === Set("a" -> 10))
@@ -163,14 +164,14 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
   test("cancel") {
     val provider = newStoreProvider()
     val store = provider.getStore(0)
-    update(store, "a", 1)
+    put(store, "a", 1)
     store.commit()
     assert(rowsToSet(store.iterator()) === Set("a" -> 1))
 
     // cancelUpdates should not change the data in the files
     val store1 = provider.getStore(1)
-    update(store1, "b", 1)
-    store1.cancel()
+    put(store1, "b", 1)
+    store1.abort()
     assert(getDataFromFiles(provider) === Set("a" -> 1))
   }
 
@@ -183,7 +184,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
     // Prepare some data in the store
     val store = provider.getStore(0)
-    update(store, "a", 1)
+    put(store, "a", 1)
     assert(store.commit() === 1)
     assert(rowsToSet(store.iterator()) === Set("a" -> 1))
 
@@ -193,14 +194,14 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
     // Update store version with some data
     val store1 = provider.getStore(1)
-    update(store1, "b", 1)
+    put(store1, "b", 1)
     assert(store1.commit() === 2)
     assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1))
     assert(getDataFromFiles(provider) === Set("a" -> 1, "b" -> 1))
 
     // Overwrite the version with other data
     val store2 = provider.getStore(1)
-    update(store2, "c", 1)
+    put(store2, "c", 1)
     assert(store2.commit() === 2)
     assert(rowsToSet(store2.iterator()) === Set("a" -> 1, "c" -> 1))
     assert(getDataFromFiles(provider) === Set("a" -> 1, "c" -> 1))
@@ -213,7 +214,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     def updateVersionTo(targetVersion: Int): Unit = {
       for (i <- currentVersion + 1 to targetVersion) {
         val store = provider.getStore(currentVersion)
-        update(store, "a", i)
+        put(store, "a", i)
         store.commit()
         currentVersion += 1
       }
@@ -264,7 +265,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
     for (i <- 1 to 20) {
       val store = provider.getStore(i - 1)
-      update(store, "a", i)
+      put(store, "a", i)
       store.commit()
       provider.doMaintenance() // do cleanup
     }
@@ -284,7 +285,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     val provider = newStoreProvider(minDeltasForSnapshot = 5)
     for (i <- 1 to 6) {
       val store = provider.getStore(i - 1)
-      update(store, "a", i)
+      put(store, "a", i)
       store.commit()
       provider.doMaintenance() // do cleanup
     }
@@ -333,7 +334,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
       // Increase version of the store
       val store0 = StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf)
       assert(store0.version === 0)
-      update(store0, "a", 1)
+      put(store0, "a", 1)
       store0.commit()
 
       assert(StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf).version == 1)
@@ -345,7 +346,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
       val store1 = StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf)
       assert(StateStore.isLoaded(storeId))
-      update(store1, "a", 2)
+      put(store1, "a", 2)
       assert(store1.commit() === 2)
       assert(rowsToSet(store1.iterator()) === Set("a" -> 2))
     }
@@ -371,7 +372,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
           for (i <- 1 to 20) {
             val store = StateStore.get(
               storeId, keySchema, valueSchema, i - 1, storeConf, hadoopConf)
-            update(store, "a", i)
+            put(store, "a", i)
             store.commit()
           }
           eventually(timeout(10 seconds)) {
@@ -507,8 +508,12 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     store.remove(row => condition(rowToString(row)))
   }
 
-  private def update(store: StateStore, key: String, value: Int): Unit = {
-    store.update(stringToRow(key), _ => intToRow(value))
+  private def put(store: StateStore, key: String, value: Int): Unit = {
+    store.put(stringToRow(key), intToRow(value))
+  }
+
+  private def get(store: StateStore, key: String): Option[Int] = {
+    store.get(stringToRow(key)).map(rowToInt)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
new file mode 100644
index 0000000..b63ce89
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{Encoder, StreamTest, SumOf, TypedColumn}
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
+
+object FailureSinglton {
+  var firstTime = true
+}
+
+class StreamingAggregationSuite extends StreamTest with SharedSQLContext {
+
+  import testImplicits._
+
+  test("simple count") {
+    val inputData = MemoryStream[Int]
+
+    val aggregated =
+      inputData.toDF()
+        .groupBy($"value")
+        .agg(count("*"))
+        .as[(Int, Long)]
+
+    testStream(aggregated)(
+      AddData(inputData, 3),
+      CheckLastBatch((3, 1)),
+      AddData(inputData, 3, 2),
+      CheckLastBatch((3, 2), (2, 1)),
+      StopStream,
+      StartStream,
+      AddData(inputData, 3, 2, 1),
+      CheckLastBatch((3, 3), (2, 2), (1, 1)),
+      // By default we run in new tuple mode.
+      AddData(inputData, 4, 4, 4, 4),
+      CheckLastBatch((4, 4))
+    )
+  }
+
+  test("multiple keys") {
+    val inputData = MemoryStream[Int]
+
+    val aggregated =
+      inputData.toDF()
+        .groupBy($"value", $"value" + 1)
+        .agg(count("*"))
+        .as[(Int, Int, Long)]
+
+    testStream(aggregated)(
+      AddData(inputData, 1, 2),
+      CheckLastBatch((1, 2, 1), (2, 3, 1)),
+      AddData(inputData, 1, 2),
+      CheckLastBatch((1, 2, 2), (2, 3, 2))
+    )
+  }
+
+  test("multiple aggregations") {
+    val inputData = MemoryStream[Int]
+
+    val aggregated =
+      inputData.toDF()
+        .groupBy($"value")
+        .agg(count("*") as 'count)
+        .groupBy($"value" % 2)
+        .agg(sum($"count"))
+        .as[(Int, Long)]
+
+    testStream(aggregated)(
+      AddData(inputData, 1, 2, 3, 4),
+      CheckLastBatch((0, 2), (1, 2)),
+      AddData(inputData, 1, 3, 5),
+      CheckLastBatch((1, 5))
+    )
+  }
+
+  testQuietly("midbatch failure") {
+    val inputData = MemoryStream[Int]
+    FailureSinglton.firstTime = true
+    val aggregated =
+      inputData.toDS()
+          .map { i =>
+            if (i == 4 && FailureSinglton.firstTime) {
+              FailureSinglton.firstTime = false
+              sys.error("injected failure")
+            }
+
+            i
+          }
+          .groupBy($"value")
+          .agg(count("*"))
+          .as[(Int, Long)]
+
+    testStream(aggregated)(
+      StartStream,
+      AddData(inputData, 1, 2, 3, 4),
+      ExpectFailure[SparkException](),
+      StartStream,
+      CheckLastBatch((1, 1), (2, 1), (3, 1), (4, 1))
+    )
+  }
+
+  test("typed aggregators") {
+    def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] =
+      new SumOf(f).toColumn
+
+    val inputData = MemoryStream[(String, Int)]
+    val aggregated = inputData.toDS().groupByKey(_._1).agg(sum(_._2))
+
+    testStream(aggregated)(
+      AddData(inputData, ("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)),
+      CheckLastBatch(("a", 30), ("b", 3), ("c", 1))
+    )
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
index 2bdb428..ff40c36 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
@@ -77,8 +77,9 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx)
   /**
    * Planner that takes into account Hive-specific strategies.
    */
-  override lazy val planner: SparkPlanner = {
-    new SparkPlanner(ctx.sparkContext, conf, experimentalMethods) with HiveStrategies {
+  override def planner: SparkPlanner = {
+    new SparkPlanner(ctx.sparkContext, conf, experimentalMethods.extraStrategies)
+      with HiveStrategies {
       override val hiveContext = ctx
 
       override def strategies: Seq[Strategy] = {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


[2/2] spark git commit: [SPARK-14255][SQL] Streaming Aggregation

Posted by ma...@apache.org.
[SPARK-14255][SQL] Streaming Aggregation

This PR adds the ability to perform aggregations inside of a `ContinuousQuery`.  In order to implement this feature, the planning of aggregation has been augmented with a new `StatefulAggregationStrategy`.  Unlike batch aggregation, stateful aggregation uses the `StateStore` (introduced in #11645) to persist the results of partial aggregation across different invocations.  The resulting physical plan performs the aggregation using the following progression (a small example query is sketched after this list):
   - Partial Aggregation
   - Shuffle
   - Partial Merge (now there is at most 1 tuple per group)
   - StateStoreRestore (now there is 1 tuple from this batch + optionally one from the previous)
   - Partial Merge (now there is at most 1 tuple per group)
   - StateStoreSave (saves the tuple for the next batch)
   - Complete (output the current result of the aggregation)
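
As a minimal sketch of the kind of query this strategy plans, adapted from the `StreamingAggregationSuite` added in this patch (it assumes that suite's `StreamTest`/`SharedSQLContext` harness and its `testImplicits._` and `functions._` imports):

```scala
// A streaming count per key; between batches the running counts are restored
// and saved by the StateStoreRestore/StateStoreSave operators described above.
val inputData = MemoryStream[Int]

val aggregated =
  inputData.toDF()
    .groupBy($"value")
    .agg(count("*"))
    .as[(Int, Long)]

testStream(aggregated)(
  AddData(inputData, 3),
  CheckLastBatch((3, 1)),         // first batch: the count of 3 is 1
  AddData(inputData, 3, 2),
  CheckLastBatch((3, 2), (2, 1))  // state from the previous batch is merged with the new data
)
```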

The following refactoring was also performed to allow us to plug into existing code:
 - The get/put implementation is taken from #12013 (a sketch of the new contract follows this list)
 - The logic for breaking down and de-duping the physical execution of aggregation has been moved into a new pattern `PhysicalAggregation`
 - The `AttributeReference` used to identify the result of an `AggregateFunction` has been moved into the `AggregateExpression` container.  This change moves the reference into the same object as the other intermediate references used in aggregation and eliminates the need to pass around a `Map[(AggregateFunction, Boolean), Attribute]`.  Further cleanup (using a different aggregation container for logical/physical plans) is deferred to a followup.
 - Some planning logic is moved from the `SessionState` into the `QueryExecution` to make it easier to override in the streaming case.
 - The ability to write a `StreamTest` that checks only the output of the last batch has been added to simulate the future addition of output modes.
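
For reference, a minimal sketch of the new `StateStore` get/put contract, copied from the updated `increment` helper in `StateStoreRDDSuite` below (`stringToRow`, `rowToInt`, `intToRow`, and `rowsToStringInt` are private test utilities, not part of the store API):

```scala
// Per-partition function passed to mapPartitionsWithStateStore: read the old
// count with get(), write the incremented count with put(), commit the new
// store version, and return the committed contents of the store.
val increment = (store: StateStore, iter: Iterator[String]) => {
  iter.foreach { s =>
    val key = stringToRow(s)
    val oldValue = store.get(key).map(rowToInt).getOrElse(0)
    store.put(key, intToRow(oldValue + 1))
  }
  store.commit()
  store.iterator().map(rowsToStringInt)
}
```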

Author: Michael Armbrust <mi...@databricks.com>

Closes #12048 from marmbrus/statefulAgg.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0fc4aaa7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0fc4aaa7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0fc4aaa7

Branch: refs/heads/master
Commit: 0fc4aaa71c5b4531b3a7c8ac71d62ea8e66b6f0c
Parents: 0b7d496
Author: Michael Armbrust <mi...@databricks.com>
Authored: Fri Apr 1 15:15:16 2016 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Fri Apr 1 15:15:16 2016 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  |   9 +-
 .../sql/catalyst/analysis/CheckAnalysis.scala   |   2 +-
 .../spark/sql/catalyst/errors/package.scala     |   7 +-
 .../expressions/aggregate/interfaces.scala      |  37 ++++-
 .../catalyst/expressions/namedExpressions.scala |   2 +-
 .../sql/catalyst/optimizer/Optimizer.scala      |  14 +-
 .../spark/sql/catalyst/planning/patterns.scala  |  73 +++++++++
 .../spark/sql/catalyst/plans/PlanTest.scala     |   3 +
 .../spark/sql/execution/QueryExecution.scala    |  24 ++-
 .../apache/spark/sql/execution/SparkPlan.scala  |   7 +
 .../spark/sql/execution/SparkPlanner.scala      |   4 +-
 .../spark/sql/execution/SparkStrategies.scala   |  92 ++++-------
 .../org/apache/spark/sql/execution/Window.scala |   2 +-
 .../aggregate/TungstenAggregationIterator.scala |   4 +-
 .../spark/sql/execution/aggregate/utils.scala   | 121 ++++++++++++---
 .../streaming/IncrementalExecution.scala        |  72 +++++++++
 .../execution/streaming/StatefulAggregate.scala | 119 +++++++++++++++
 .../execution/streaming/StreamExecution.scala   |  12 +-
 .../spark/sql/execution/streaming/memory.scala  |   4 +-
 .../state/HDFSBackedStateStoreProvider.scala    |  36 +++--
 .../execution/streaming/state/StateStore.scala  |  19 ++-
 .../streaming/state/StateStoreConf.scala        |   4 +-
 .../streaming/state/StateStoreRDD.scala         |  17 +--
 .../sql/execution/streaming/state/package.scala |  21 ++-
 .../apache/spark/sql/execution/subquery.scala   |  11 +-
 .../spark/sql/expressions/Aggregator.scala      |   4 +-
 .../spark/sql/internal/SessionState.scala       |  16 +-
 .../scala/org/apache/spark/sql/StreamTest.scala |  36 +++--
 .../spark/sql/execution/SparkPlanTest.scala     |  10 +-
 .../streaming/state/StateStoreRDDSuite.scala    | 152 +++++++++++--------
 .../streaming/state/StateStoreSuite.scala       |  61 ++++----
 .../streaming/StreamingAggregationSuite.scala   | 132 ++++++++++++++++
 .../spark/sql/hive/HiveSessionState.scala       |   5 +-
 33 files changed, 827 insertions(+), 305 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index d82ee3a..05e2b9a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -336,6 +336,11 @@ class Analyzer(
                 Last(ifExpr(expr), Literal(true))
               case a: AggregateFunction =>
                 a.withNewChildren(a.children.map(ifExpr))
+            }.transform {
+              // We are duplicating aggregates that are now computing a different value for each
+              // pivot value.
+              // TODO: Don't construct the physical container until after analysis.
+              case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId)
             }
             if (filteredAggregate.fastEquals(aggregate)) {
               throw new AnalysisException(
@@ -1153,11 +1158,11 @@ class Analyzer(
 
           // Extract Windowed AggregateExpression
           case we @ WindowExpression(
-              AggregateExpression(function, mode, isDistinct),
+              ae @ AggregateExpression(function, _, _, _),
               spec: WindowSpecDefinition) =>
             val newChildren = function.children.map(extractExpr)
             val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
-            val newAgg = AggregateExpression(newFunction, mode, isDistinct)
+            val newAgg = ae.copy(aggregateFunction = newFunction)
             seenWindowAggregates += newAgg
             WindowExpression(newAgg, spec)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 1d1e892..4880502 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -76,7 +76,7 @@ trait CheckAnalysis {
           case g: GroupingID =>
             failAnalysis(s"grouping_id() can only be used with GroupingSets/Cube/Rollup")
 
-          case w @ WindowExpression(AggregateExpression(_, _, true), _) =>
+          case w @ WindowExpression(AggregateExpression(_, _, true, _), _) =>
             failAnalysis(s"Distinct window functions are not supported: $w")
 
           case w @ WindowExpression(_: OffsetWindowFunction, WindowSpecDefinition(_, order,

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala
index 0d44d1d..0420b4b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala
@@ -25,15 +25,18 @@ import org.apache.spark.sql.catalyst.trees.TreeNode
 package object errors {
 
   class TreeNodeException[TreeType <: TreeNode[_]](
-      tree: TreeType, msg: String, cause: Throwable)
+      @transient val tree: TreeType,
+      msg: String,
+      cause: Throwable)
     extends Exception(msg, cause) {
 
+    val treeString = tree.toString
+
     // Yes, this is the same as a default parameter, but... those don't seem to work with SBT
     // external project dependencies for some reason.
     def this(tree: TreeType, msg: String) = this(tree, msg, null)
 
     override def getMessage: String = {
-      val treeString = tree.toString
       s"${super.getMessage}, tree:${if (treeString contains "\n") "\n" else " "}$tree"
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index ff3064a..d31ccf9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.types._
@@ -66,6 +67,19 @@ private[sql] case object NoOp extends Expression with Unevaluable {
   override def children: Seq[Expression] = Nil
 }
 
+object AggregateExpression {
+  def apply(
+      aggregateFunction: AggregateFunction,
+      mode: AggregateMode,
+      isDistinct: Boolean): AggregateExpression = {
+    AggregateExpression(
+      aggregateFunction,
+      mode,
+      isDistinct,
+      NamedExpression.newExprId)
+  }
+}
+
 /**
  * A container for an [[AggregateFunction]] with its [[AggregateMode]] and a field
  * (`isDistinct`) indicating if DISTINCT keyword is specified for this function.
@@ -73,10 +87,31 @@ private[sql] case object NoOp extends Expression with Unevaluable {
 private[sql] case class AggregateExpression(
     aggregateFunction: AggregateFunction,
     mode: AggregateMode,
-    isDistinct: Boolean)
+    isDistinct: Boolean,
+    resultId: ExprId)
   extends Expression
   with Unevaluable {
 
+  lazy val resultAttribute: Attribute = if (aggregateFunction.resolved) {
+    AttributeReference(
+      aggregateFunction.toString,
+      aggregateFunction.dataType,
+      aggregateFunction.nullable)(exprId = resultId)
+  } else {
+    // This is a bit of a hack.  Really we should not be constructing this container and reasoning
+    // about datatypes / aggregation mode until after we have finished analysis and made it to
+    // planning.
+    UnresolvedAttribute(aggregateFunction.toString)
+  }
+
+  // We compute the same thing regardless of our final result.
+  override lazy val canonicalized: Expression =
+    AggregateExpression(
+      aggregateFunction.canonicalized.asInstanceOf[AggregateFunction],
+      mode,
+      isDistinct,
+      ExprId(0))
+
   override def children: Seq[Expression] = aggregateFunction :: Nil
   override def dataType: DataType = aggregateFunction.dataType
   override def foldable: Boolean = false

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 262582c..2307122 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -329,7 +329,7 @@ case class PrettyAttribute(
   override def withName(newName: String): Attribute = throw new UnsupportedOperationException
   override def qualifier: Option[String] = throw new UnsupportedOperationException
   override def exprId: ExprId = throw new UnsupportedOperationException
-  override def nullable: Boolean = throw new UnsupportedOperationException
+  override def nullable: Boolean = true
 }
 
 object VirtualColumn {

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index a7a948e..326933e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -534,7 +534,7 @@ object NullPropagation extends Rule[LogicalPlan] {
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case q: LogicalPlan => q transformExpressionsUp {
-      case e @ AggregateExpression(Count(exprs), _, _) if !exprs.exists(nonNullLiteral) =>
+      case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) =>
         Cast(Literal(0L), e.dataType)
       case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
       case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
@@ -547,9 +547,9 @@ object NullPropagation extends Rule[LogicalPlan] {
         Literal.create(null, e.dataType)
       case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
       case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
-      case e @ AggregateExpression(Count(exprs), mode, false) if !exprs.exists(_.nullable) =>
+      case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) =>
         // This rule should be only triggered when isDistinct field is false.
-        AggregateExpression(Count(Literal(1)), mode, isDistinct = false)
+        ae.copy(aggregateFunction = Count(Literal(1)))
 
       // For Coalesce, remove null literals.
       case e @ Coalesce(children) =>
@@ -1225,13 +1225,13 @@ object DecimalAggregates extends Rule[LogicalPlan] {
   private val MAX_DOUBLE_DIGITS = 15
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-    case AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), mode, isDistinct)
+    case ae @ AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), _, _, _)
       if prec + 10 <= MAX_LONG_DIGITS =>
-      MakeDecimal(AggregateExpression(Sum(UnscaledValue(e)), mode, isDistinct), prec + 10, scale)
+      MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale)
 
-    case AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), mode, isDistinct)
+    case ae @ AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), _, _, _)
       if prec + 4 <= MAX_DOUBLE_DIGITS =>
-      val newAggExpr = AggregateExpression(Average(UnscaledValue(e)), mode, isDistinct)
+      val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
       Cast(
         Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
         DecimalType(prec + 4, scale + 4))

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 9c92707..28d2c44 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types.IntegerType
@@ -216,3 +217,75 @@ object IntegerIndex {
     case _ => None
   }
 }
+
+/**
+ * An extractor used when planning the physical execution of an aggregation. Compared with a logical
+ * aggregation, the following transformations are performed:
+ *  - Unnamed grouping expressions are named so that they can be referred to across phases of
+ *    aggregation
+ *  - Aggregations that appear multiple times are deduplicated.
+ *  - The computation of the aggregations themselves is separated from the final result. For example,
+ *    the `count` in `count + 1` will be split into an [[AggregateExpression]] and a final
+ *    computation that computes `count.resultAttribute + 1`.
+ */
+object PhysicalAggregation {
+  // groupingExpressions, aggregateExpressions, resultExpressions, child
+  type ReturnType =
+    (Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan)
+
+  def unapply(a: Any): Option[ReturnType] = a match {
+    case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
+      // A single aggregate expression might appear multiple times in resultExpressions.
+      // In order to avoid evaluating an individual aggregate function multiple times, we'll
+      // build a set of the distinct aggregate expressions and build a function which can
+      // be used to re-write expressions so that they reference the single copy of the
+      // aggregate function which actually gets computed.
+      val aggregateExpressions = resultExpressions.flatMap { expr =>
+        expr.collect {
+          case agg: AggregateExpression => agg
+        }
+      }.distinct
+
+      val namedGroupingExpressions = groupingExpressions.map {
+        case ne: NamedExpression => ne -> ne
+        // If the expression is not a NamedExpressions, we add an alias.
+        // So, when we generate the result of the operator, the Aggregate Operator
+        // can directly get the Seq of attributes representing the grouping expressions.
+        case other =>
+          val withAlias = Alias(other, other.toString)()
+          other -> withAlias
+      }
+      val groupExpressionMap = namedGroupingExpressions.toMap
+
+      // The original `resultExpressions` are a set of expressions which may reference
+      // aggregate expressions, grouping column values, and constants. When aggregate operator
+      // emits output rows, we will use `resultExpressions` to generate an output projection
+      // which takes the grouping columns and final aggregate result buffer as input.
+      // Thus, we must re-write the result expressions so that their attributes match up with
+      // the attributes of the final result projection's input row:
+      val rewrittenResultExpressions = resultExpressions.map { expr =>
+        expr.transformDown {
+          case ae: AggregateExpression =>
+            // The final aggregation buffer's attributes will be `finalAggregationAttributes`,
+            // so replace each aggregate expression by its corresponding attribute in the set:
+            ae.resultAttribute
+          case expression =>
+            // Since we're using `namedGroupingAttributes` to extract the grouping key
+            // columns, we need to replace grouping key expressions with their corresponding
+            // attributes. We do not rely on the equality check at here since attributes may
+            // differ cosmetically. Instead, we use semanticEquals.
+            groupExpressionMap.collectFirst {
+              case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+            }.getOrElse(expression)
+        }.asInstanceOf[NamedExpression]
+      }
+
+      Some((
+        namedGroupingExpressions.map(_._2),
+        aggregateExpressions,
+        rewrittenResultExpressions,
+        child))
+
+    case _ => None
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index aa5d433..7191936 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample}
 import org.apache.spark.sql.catalyst.util._
 
@@ -38,6 +39,8 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
         AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
       case a: Alias =>
         Alias(a.child, a.name)(exprId = ExprId(0))
+      case ae: AggregateExpression =>
+        ae.copy(resultId = ExprId(0))
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 912b84a..4843553 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -21,6 +21,8 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{AnalysisException, SQLContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
 
 /**
  * The primary workflow for executing relational queries using Spark.  Designed to allow easy
@@ -31,6 +33,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
  */
 class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
 
+  // TODO: Move the planner an optimizer into here from SessionState.
+  protected def planner = sqlContext.sessionState.planner
+
   def assertAnalyzed(): Unit = try sqlContext.sessionState.analyzer.checkAnalysis(analyzed) catch {
     case e: AnalysisException =>
       val ae = new AnalysisException(e.message, e.line, e.startPosition, Some(analyzed))
@@ -49,16 +54,31 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
 
   lazy val sparkPlan: SparkPlan = {
     SQLContext.setActive(sqlContext)
-    sqlContext.sessionState.planner.plan(ReturnAnswer(optimizedPlan)).next()
+    planner.plan(ReturnAnswer(optimizedPlan)).next()
   }
 
   // executedPlan should not be used to initialize any SparkPlan. It should be
   // only used for execution.
-  lazy val executedPlan: SparkPlan = sqlContext.sessionState.prepareForExecution.execute(sparkPlan)
+  lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan)
 
   /** Internal version of the RDD. Avoids copies and has no schema */
   lazy val toRdd: RDD[InternalRow] = executedPlan.execute()
 
+  /**
+   * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal
+   * row format conversions as needed.
+   */
+  protected def prepareForExecution(plan: SparkPlan): SparkPlan = {
+    preparations.foldLeft(plan) { case (sp, rule) => rule.apply(sp) }
+  }
+
+  /** A sequence of rules that will be applied in order to the physical plan before execution. */
+  protected def preparations: Seq[Rule[SparkPlan]] = Seq(
+    PlanSubqueries(sqlContext),
+    EnsureRequirements(sqlContext.conf),
+    CollapseCodegenStages(sqlContext.conf),
+    ReuseExchange(sqlContext.conf))
+
   protected def stringOrError[A](f: => A): String =
     try f.toString catch { case e: Throwable => e.toString }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 010ed7f..b1b3d4a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -379,6 +379,13 @@ private[sql] trait LeafNode extends SparkPlan {
   override def producedAttributes: AttributeSet = outputSet
 }
 
+object UnaryNode {
+  def unapply(a: Any): Option[(SparkPlan, SparkPlan)] = a match {
+    case s: SparkPlan if s.children.size == 1 => Some((s, s.children.head))
+    case _ => None
+  }
+}
+
 private[sql] trait UnaryNode extends SparkPlan {
   def child: SparkPlan
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index 9da2c74..ac8072f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -26,13 +26,13 @@ import org.apache.spark.sql.internal.SQLConf
 class SparkPlanner(
     val sparkContext: SparkContext,
     val conf: SQLConf,
-    val experimentalMethods: ExperimentalMethods)
+    val extraStrategies: Seq[Strategy])
   extends SparkStrategies {
 
   def numPartitions: Int = conf.numShufflePartitions
 
   def strategies: Seq[Strategy] =
-    experimentalMethods.extraStrategies ++ (
+      extraStrategies ++ (
       FileSourceStrategy ::
       DataSourceStrategy ::
       DDLStrategy ::

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 7a2e2b7..5bcc172 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution
 import org.apache.spark.sql.Strategy
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan}
@@ -204,28 +203,32 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   }
 
   /**
+   * Used to plan aggregation queries that are computed incrementally as part of a
+   * [[org.apache.spark.sql.ContinuousQuery]]. Currently this rule is injected into the planner
+   * on-demand, only when planning in a [[org.apache.spark.sql.execution.streaming.StreamExecution]]
+   */
+  object StatefulAggregationStrategy extends Strategy {
+    override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+      case PhysicalAggregation(
+        namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) =>
+
+        aggregate.Utils.planStreamingAggregation(
+          namedGroupingExpressions,
+          aggregateExpressions,
+          rewrittenResultExpressions,
+          planLater(child))
+
+      case _ => Nil
+    }
+  }
+
+  /**
    * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface.
    */
   object Aggregation extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
-        // A single aggregate expression might appear multiple times in resultExpressions.
-        // In order to avoid evaluating an individual aggregate function multiple times, we'll
-        // build a set of the distinct aggregate expressions and build a function which can
-        // be used to re-write expressions so that they reference the single copy of the
-        // aggregate function which actually gets computed.
-        val aggregateExpressions = resultExpressions.flatMap { expr =>
-          expr.collect {
-            case agg: AggregateExpression => agg
-          }
-        }.distinct
-        // For those distinct aggregate expressions, we create a map from the
-        // aggregate function to the corresponding attribute of the function.
-        val aggregateFunctionToAttribute = aggregateExpressions.map { agg =>
-          val aggregateFunction = agg.aggregateFunction
-          val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
-          (aggregateFunction, agg.isDistinct) -> attribute
-        }.toMap
+      case PhysicalAggregation(
+          groupingExpressions, aggregateExpressions, resultExpressions, child) =>
 
         val (functionsWithDistinct, functionsWithoutDistinct) =
           aggregateExpressions.partition(_.isDistinct)
@@ -233,41 +236,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           // This is a sanity check. We should not reach here when we have multiple distinct
           // column sets. Our MultipleDistinctRewriter should take care of this case.
           sys.error("You hit a query analyzer bug. Please report your query to " +
-            "Spark user mailing list.")
-        }
-
-        val namedGroupingExpressions = groupingExpressions.map {
-          case ne: NamedExpression => ne -> ne
-          // If the expression is not a NamedExpressions, we add an alias.
-          // So, when we generate the result of the operator, the Aggregate Operator
-          // can directly get the Seq of attributes representing the grouping expressions.
-          case other =>
-            val withAlias = Alias(other, other.toString)()
-            other -> withAlias
-        }
-        val groupExpressionMap = namedGroupingExpressions.toMap
-
-        // The original `resultExpressions` are a set of expressions which may reference
-        // aggregate expressions, grouping column values, and constants. When aggregate operator
-        // emits output rows, we will use `resultExpressions` to generate an output projection
-        // which takes the grouping columns and final aggregate result buffer as input.
-        // Thus, we must re-write the result expressions so that their attributes match up with
-        // the attributes of the final result projection's input row:
-        val rewrittenResultExpressions = resultExpressions.map { expr =>
-          expr.transformDown {
-            case AggregateExpression(aggregateFunction, _, isDistinct) =>
-              // The final aggregation buffer's attributes will be `finalAggregationAttributes`,
-              // so replace each aggregate expression by its corresponding attribute in the set:
-              aggregateFunctionToAttribute(aggregateFunction, isDistinct)
-            case expression =>
-              // Since we're using `namedGroupingAttributes` to extract the grouping key
-              // columns, we need to replace grouping key expressions with their corresponding
-              // attributes. We do not rely on the equality check at here since attributes may
-              // differ cosmetically. Instead, we use semanticEquals.
-              groupExpressionMap.collectFirst {
-                case (expr, ne) if expr semanticEquals expression => ne.toAttribute
-              }.getOrElse(expression)
-          }.asInstanceOf[NamedExpression]
+              "Spark user mailing list.")
         }
 
         val aggregateOperator =
@@ -277,26 +246,23 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
                 "aggregate functions which don't support partial aggregation.")
             } else {
               aggregate.Utils.planAggregateWithoutPartial(
-                namedGroupingExpressions.map(_._2),
+                groupingExpressions,
                 aggregateExpressions,
-                aggregateFunctionToAttribute,
-                rewrittenResultExpressions,
+                resultExpressions,
                 planLater(child))
             }
           } else if (functionsWithDistinct.isEmpty) {
             aggregate.Utils.planAggregateWithoutDistinct(
-              namedGroupingExpressions.map(_._2),
+              groupingExpressions,
               aggregateExpressions,
-              aggregateFunctionToAttribute,
-              rewrittenResultExpressions,
+              resultExpressions,
               planLater(child))
           } else {
             aggregate.Utils.planAggregateWithOneDistinct(
-              namedGroupingExpressions.map(_._2),
+              groupingExpressions,
               functionsWithDistinct,
               functionsWithoutDistinct,
-              aggregateFunctionToAttribute,
-              rewrittenResultExpressions,
+              resultExpressions,
               planLater(child))
           }
 

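For context, StatefulAggregationStrategy is not part of the default strategy list; it is injected
only when a streaming batch is planned (see IncrementalExecution below). A minimal sketch of that
wiring, assuming a `sqlContext` and an analyzed streaming aggregation plan `analyzedPlan` are in
scope:

    // Sketch only: build a planner that also knows about stateful aggregation.
    val stateStrategies = sqlContext.sessionState.planner.StatefulAggregationStrategy :: Nil
    val statefulPlanner = new SparkPlanner(sqlContext.sparkContext, sqlContext.conf, stateStrategies)

    // For non-aggregate plans the strategy returns Nil, so the remaining strategies still apply.
    val physicalPlan = statefulPlanner.plan(analyzedPlan).next()
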
http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index 270c09a..7acf020 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -177,7 +177,7 @@ case class Window(
         case e @ WindowExpression(function, spec) =>
           val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
           function match {
-            case AggregateExpression(f, _, _) => collect("AGGREGATE", frame, e, f)
+            case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f)
             case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f)
             case f: OffsetWindowFunction => collect("OFFSET", frame, e, f)
             case f => sys.error(s"Unsupported window function: $f")

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 213bca9..ce504e2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -242,9 +242,9 @@ class TungstenAggregationIterator(
     // Basically the value of the KVIterator returned by externalSorter
     // will be just aggregation buffer, so we rewrite the aggregateExpressions to reflect it.
     val newExpressions = aggregateExpressions.map {
-      case agg @ AggregateExpression(_, Partial, _) =>
+      case agg @ AggregateExpression(_, Partial, _, _) =>
         agg.copy(mode = PartialMerge)
-      case agg @ AggregateExpression(_, Complete, _) =>
+      case agg @ AggregateExpression(_, Complete, _, _) =>
         agg.copy(mode = Final)
       case other => other
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 1e113cc..4682949 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.streaming.{StateStoreRestore, StateStoreSave}
 
 /**
  * Utility functions used by the query planner to convert our plan to new aggregation code path.
@@ -29,15 +30,11 @@ object Utils {
   def planAggregateWithoutPartial(
       groupingExpressions: Seq[NamedExpression],
       aggregateExpressions: Seq[AggregateExpression],
-      aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute],
       resultExpressions: Seq[NamedExpression],
       child: SparkPlan): Seq[SparkPlan] = {
 
     val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
-    val completeAggregateAttributes = completeAggregateExpressions.map {
-      expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
-    }
-
+    val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute)
     SortBasedAggregate(
       requiredChildDistributionExpressions = Some(groupingExpressions),
       groupingExpressions = groupingExpressions,
@@ -83,7 +80,6 @@ object Utils {
   def planAggregateWithoutDistinct(
       groupingExpressions: Seq[NamedExpression],
       aggregateExpressions: Seq[AggregateExpression],
-      aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute],
       resultExpressions: Seq[NamedExpression],
       child: SparkPlan): Seq[SparkPlan] = {
     // Check if we can use TungstenAggregate.
@@ -111,9 +107,7 @@ object Utils {
     val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final))
     // The attributes of the final aggregation buffer, which is presented as input to the result
     // projection:
-    val finalAggregateAttributes = finalAggregateExpressions.map {
-      expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
-    }
+    val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute)
 
     val finalAggregate = createAggregate(
         requiredChildDistributionExpressions = Some(groupingAttributes),
@@ -131,7 +125,6 @@ object Utils {
       groupingExpressions: Seq[NamedExpression],
       functionsWithDistinct: Seq[AggregateExpression],
       functionsWithoutDistinct: Seq[AggregateExpression],
-      aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute],
       resultExpressions: Seq[NamedExpression],
       child: SparkPlan): Seq[SparkPlan] = {
 
@@ -151,9 +144,7 @@ object Utils {
     // 1. Create an Aggregate Operator for partial aggregations.
     val partialAggregate: SparkPlan = {
       val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
-      val aggregateAttributes = aggregateExpressions.map {
-        expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
-      }
+      val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
       // We will group by the original grouping expression, plus an additional expression for the
       // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping
       // expressions will be [key, value].
@@ -169,9 +160,7 @@ object Utils {
     // 2. Create an Aggregate Operator for partial merge aggregations.
     val partialMergeAggregate: SparkPlan = {
       val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
-      val aggregateAttributes = aggregateExpressions.map {
-        expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
-      }
+      val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
       createAggregate(
         requiredChildDistributionExpressions =
           Some(groupingAttributes ++ distinctAttributes),
@@ -190,7 +179,7 @@ object Utils {
       // The children of an AggregateFunction with DISTINCT keyword have already
       // been evaluated. Here, we need to replace the original children
       // with AttributeReferences.
-      case agg @ AggregateExpression(aggregateFunction, mode, true) =>
+      case agg @ AggregateExpression(aggregateFunction, mode, true, _) =>
         aggregateFunction.transformDown(distinctColumnAttributeLookup)
           .asInstanceOf[AggregateFunction]
     }
@@ -199,9 +188,7 @@ object Utils {
       val mergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
       // The attributes of the final aggregation buffer, which is presented as input to the result
       // projection:
-      val mergeAggregateAttributes = mergeAggregateExpressions.map {
-        expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
-      }
+      val mergeAggregateAttributes = mergeAggregateExpressions.map(_.resultAttribute)
       val (distinctAggregateExpressions, distinctAggregateAttributes) =
         rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) =>
           // We rewrite the aggregate function to a non-distinct aggregation because
@@ -211,7 +198,7 @@ object Utils {
           val expr = AggregateExpression(func, Partial, isDistinct = true)
           // Use the original AggregateExpression from functionsWithDistinct to look up
           // its result attribute.
-          val attr = aggregateFunctionToAttribute(functionsWithDistinct(i).aggregateFunction, true)
+          val attr = functionsWithDistinct(i).resultAttribute
           (expr, attr)
       }.unzip
 
@@ -232,9 +219,7 @@ object Utils {
       val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
       // The attributes of the final aggregation buffer, which is presented as input to the result
       // projection:
-      val finalAggregateAttributes = finalAggregateExpressions.map {
-        expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
-      }
+      val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute)
 
       val (distinctAggregateExpressions, distinctAggregateAttributes) =
         rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) =>
@@ -245,7 +230,7 @@ object Utils {
           val expr = AggregateExpression(func, Final, isDistinct = true)
           // Use the original AggregateExpression from functionsWithDistinct to look up
           // its result attribute.
-          val attr = aggregateFunctionToAttribute(functionsWithDistinct(i).aggregateFunction, true)
+          val attr = functionsWithDistinct(i).resultAttribute
           (expr, attr)
       }.unzip
 
@@ -261,4 +246,90 @@ object Utils {
 
     finalAndCompleteAggregate :: Nil
   }
+
+  /**
+   * Plans a streaming aggregation using the following progression:
+   *  - Partial Aggregation
+   *  - Shuffle
+   *  - Partial Merge (now there is at most 1 tuple per group)
+   *  - StateStoreRestore (now there is 1 tuple from this batch + optionally one from the previous)
+   *  - PartialMerge (now there is at most 1 tuple per group)
+   *  - StateStoreSave (saves the tuple for the next batch)
+   *  - Complete (output the current result of the aggregation)
+   */
+  def planStreamingAggregation(
+      groupingExpressions: Seq[NamedExpression],
+      functionsWithoutDistinct: Seq[AggregateExpression],
+      resultExpressions: Seq[NamedExpression],
+      child: SparkPlan): Seq[SparkPlan] = {
+
+    val groupingAttributes = groupingExpressions.map(_.toAttribute)
+
+    val partialAggregate: SparkPlan = {
+      val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
+      val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
+      // The output of the partial aggregation is the grouping attributes followed by
+      // the partial aggregation buffers, which are merged by the PartialMerge
+      // aggregation that follows the shuffle.
+      createAggregate(
+        groupingExpressions = groupingExpressions,
+        aggregateExpressions = aggregateExpressions,
+        aggregateAttributes = aggregateAttributes,
+        resultExpressions = groupingAttributes ++
+            aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
+        child = child)
+    }
+
+    val partialMerged1: SparkPlan = {
+      val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
+      val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
+      createAggregate(
+        requiredChildDistributionExpressions =
+            Some(groupingAttributes),
+        groupingExpressions = groupingAttributes,
+        aggregateExpressions = aggregateExpressions,
+        aggregateAttributes = aggregateAttributes,
+        initialInputBufferOffset = groupingAttributes.length,
+        resultExpressions = groupingAttributes ++
+            aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
+        child = partialAggregate)
+    }
+
+    val restored = StateStoreRestore(groupingAttributes, None, partialMerged1)
+
+    val partialMerged2: SparkPlan = {
+      val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
+      val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
+      createAggregate(
+        requiredChildDistributionExpressions =
+            Some(groupingAttributes),
+        groupingExpressions = groupingAttributes,
+        aggregateExpressions = aggregateExpressions,
+        aggregateAttributes = aggregateAttributes,
+        initialInputBufferOffset = groupingAttributes.length,
+        resultExpressions = groupingAttributes ++
+            aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
+        child = restored)
+    }
+
+    val saved = StateStoreSave(groupingAttributes, None, partialMerged2)
+
+    val finalAndCompleteAggregate: SparkPlan = {
+      val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
+      // The attributes of the final aggregation buffer, which is presented as input to the result
+      // projection:
+      val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute)
+
+      createAggregate(
+        requiredChildDistributionExpressions = Some(groupingAttributes),
+        groupingExpressions = groupingAttributes,
+        aggregateExpressions = finalAggregateExpressions,
+        aggregateAttributes = finalAggregateAttributes,
+        initialInputBufferOffset = groupingAttributes.length,
+        resultExpressions = resultExpressions,
+        child = saved)
+    }
+
+    finalAndCompleteAggregate :: Nil
+  }
 }

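To make the progression in the scaladoc above concrete, this is roughly the operator tree that
planStreamingAggregation produces for a streaming `df.groupBy("k").count()` (a sketch, not the
exact explain output):

    // Aggregate(Final)                    <- finalAndCompleteAggregate: emits (k, count)
    // +- StateStoreSave [k]               <- persists one row per key for the next batch
    //    +- Aggregate(PartialMerge)       <- partialMerged2: merges restored and current state
    //       +- StateStoreRestore [k]      <- re-injects the previous batch's row per key, if any
    //          +- Aggregate(PartialMerge) <- partialMerged1: one row per key after the shuffle
    //             +- Aggregate(Partial)   <- partialAggregate: map-side partial aggregation
    //                +- <child>
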
http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
new file mode 100644
index 0000000..aaced49
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -0,0 +1,72 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryNode}
+
+/**
+ * A variant of [[QueryExecution]] that allows the given [[LogicalPlan]] to be executed
+ * incrementally, possibly preserving state in between each execution.
+ */
+class IncrementalExecution(
+    ctx: SQLContext,
+    logicalPlan: LogicalPlan,
+    checkpointLocation: String,
+    currentBatchId: Long) extends QueryExecution(ctx, logicalPlan) {
+
+  // TODO: make this always part of planning.
+  val stateStrategy = sqlContext.sessionState.planner.StatefulAggregationStrategy :: Nil
+
+  // Modified planner with stateful operations.
+  override def planner: SparkPlanner =
+    new SparkPlanner(
+      sqlContext.sparkContext,
+      sqlContext.conf,
+      stateStrategy)
+
+  /**
+   * Records the current id for a given stateful operator as the `state`
+   * preparation rule walks the query plan.
+   */
+  private var operatorId = 0
+
+  /** Locates save/restore pairs surrounding aggregation. */
+  val state = new Rule[SparkPlan] {
+    override def apply(plan: SparkPlan): SparkPlan = plan transform {
+      case StateStoreSave(keys, None,
+             UnaryNode(agg,
+               StateStoreRestore(keys2, None, child))) =>
+        val stateId = OperatorStateId(checkpointLocation, operatorId, currentBatchId - 1)
+        operatorId += 1
+
+        StateStoreSave(
+          keys,
+          Some(stateId),
+          agg.withNewChildren(
+            StateStoreRestore(
+              keys,
+              Some(stateId),
+              child) :: Nil))
+    }
+  }
+
+  override def preparations: Seq[Rule[SparkPlan]] = state +: super.preparations
+}

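A sketch of how these pieces fit together for a single batch; `streamingAggPlan` (an analyzed
streaming aggregation plan) and `checkpointDir` are assumed to be in scope:

    import org.apache.spark.sql.execution.streaming.{IncrementalExecution, StateStoreSave}

    // Sketch only: plan batch 1 incrementally.
    val execution = new IncrementalExecution(sqlContext, streamingAggPlan, checkpointDir, currentBatchId = 1L)

    // After preparation, each save/restore pair shares an OperatorStateId pointing at
    // checkpointDir and at the state written by the previous batch (currentBatchId - 1).
    val stateIds = execution.executedPlan.collect { case s: StateStoreSave => s.stateId }
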
http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala
new file mode 100644
index 0000000..5957747
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.execution
+import org.apache.spark.sql.execution.streaming.state._
+import org.apache.spark.sql.execution.SparkPlan
+
+/** Used to identify the state store for a given operator. */
+case class OperatorStateId(
+    checkpointLocation: String,
+    operatorId: Long,
+    batchId: Long)
+
+/**
+ * An operator that saves or restores state from the [[StateStore]].  The [[OperatorStateId]] should
+ * be filled in by `prepareForExecution` in [[IncrementalExecution]].
+ */
+trait StatefulOperator extends SparkPlan {
+  def stateId: Option[OperatorStateId]
+
+  protected def getStateId: OperatorStateId = attachTree(this) {
+    stateId.getOrElse {
+      throw new IllegalStateException("State location not present for execution")
+    }
+  }
+}
+
+/**
+ * For each input tuple, the key is calculated and the value from the [[StateStore]] is added
+ * to the stream (in addition to the input tuple) if present.
+ */
+case class StateStoreRestore(
+    keyExpressions: Seq[Attribute],
+    stateId: Option[OperatorStateId],
+    child: SparkPlan) extends execution.UnaryNode with StatefulOperator {
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    child.execute().mapPartitionsWithStateStore(
+      getStateId.checkpointLocation,
+      operatorId = getStateId.operatorId,
+      storeVersion = getStateId.batchId,
+      keyExpressions.toStructType,
+      child.output.toStructType,
+      new StateStoreConf(sqlContext.conf),
+      Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
+        val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
+        iter.flatMap { row =>
+          val key = getKey(row)
+          val savedState = store.get(key)
+          row +: savedState.toSeq
+        }
+    }
+  }
+  override def output: Seq[Attribute] = child.output
+}
+
+/**
+ * For each input tuple, the key is calculated and the tuple is `put` into the [[StateStore]].
+ */
+case class StateStoreSave(
+    keyExpressions: Seq[Attribute],
+    stateId: Option[OperatorStateId],
+    child: SparkPlan) extends execution.UnaryNode with StatefulOperator {
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    child.execute().mapPartitionsWithStateStore(
+      getStateId.checkpointLocation,
+      operatorId = getStateId.operatorId,
+      storeVersion = getStateId.batchId,
+      keyExpressions.toStructType,
+      child.output.toStructType,
+      new StateStoreConf(sqlContext.conf),
+      Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
+        new Iterator[InternalRow] {
+          private[this] val baseIterator = iter
+          private[this] val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
+
+          override def hasNext: Boolean = {
+            if (!baseIterator.hasNext) {
+              store.commit()
+              false
+            } else {
+              true
+            }
+          }
+
+          override def next(): InternalRow = {
+            val row = baseIterator.next().asInstanceOf[UnsafeRow]
+            val key = getKey(row)
+            store.put(key.copy(), row.copy())
+            row
+          }
+        }
+    }
+  }
+
+  override def output: Seq[Attribute] = child.output
+}

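Taken together, the two operators thread a key's aggregation buffer from one batch to the next.
An illustrative timeline for a single grouping key k (the counts are made up):

    // batch 0: StateStoreRestore finds no saved row for k
    //          -> the PartialMerge above it sees only this batch's partial counts
    //          -> StateStoreSave writes (k, 2) for use by the next batch
    // batch 1: StateStoreRestore emits this batch's row for k plus the saved (k, 2)
    //          -> the PartialMerge above it folds them into, say, (k, 5)
    //          -> StateStoreSave overwrites the state with (k, 5), and so on
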
http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index c4e410d..511e30c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -27,6 +27,7 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.util._
@@ -272,6 +273,8 @@ class StreamExecution(
   private def runBatch(): Unit = {
     val startTime = System.nanoTime()
 
+    // TODO: Move this to IncrementalExecution.
+
     // Request unprocessed data from all sources.
     val newData = availableOffsets.flatMap {
       case (source, available) if committedOffsets.get(source).map(_ < available).getOrElse(true) =>
@@ -305,13 +308,14 @@ class StreamExecution(
     }
 
     val optimizerStart = System.nanoTime()
-
-    lastExecution = new QueryExecution(sqlContext, newPlan)
-    val executedPlan = lastExecution.executedPlan
+    lastExecution =
+        new IncrementalExecution(sqlContext, newPlan, checkpointFile("state"), currentBatchId)
+    lastExecution.executedPlan
     val optimizerTime = (System.nanoTime() - optimizerStart).toDouble / 1000000
     logDebug(s"Optimized batch in ${optimizerTime}ms")
 
-    val nextBatch = Dataset.ofRows(sqlContext, newPlan)
+    val nextBatch =
+      new Dataset(sqlContext, lastExecution, RowEncoder(lastExecution.analyzed.schema))
     sink.addBatch(currentBatchId - 1, nextBatch)
 
     awaitBatchLock.synchronized {

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 0f91e59..7d97f81 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -108,7 +108,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
  * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
  * tests and does not provide durability.
  */
-class MemorySink(schema: StructType) extends Sink with Logging {
+class MemorySink(val schema: StructType) extends Sink with Logging {
   /** An ordered list of batches that have been written to this [[Sink]]. */
   private val batches = new ArrayBuffer[Array[Row]]()
 
@@ -117,6 +117,8 @@ class MemorySink(schema: StructType) extends Sink with Logging {
     batches.flatten
   }
 
+  def lastBatch: Seq[Row] = batches.last
+
   def toDebugString: String = synchronized {
     batches.zipWithIndex.map { case (b, i) =>
       val dataStr = try b.mkString(" ") catch {

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index ee015ba..998eb82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -81,7 +81,7 @@ private[state] class HDFSBackedStateStoreProvider(
     trait STATE
     case object UPDATING extends STATE
     case object COMMITTED extends STATE
-    case object CANCELLED extends STATE
+    case object ABORTED extends STATE
 
     private val newVersion = version + 1
     private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}")
@@ -94,15 +94,14 @@ private[state] class HDFSBackedStateStoreProvider(
 
     override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id
 
-    /**
-     * Update the value of a key using the value generated by the update function.
-     * @note Do not mutate the retrieved value row as it will unexpectedly affect the previous
-     *       versions of the store data.
-     */
-    override def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit = {
-      verify(state == UPDATING, "Cannot update after already committed or cancelled")
-      val oldValueOption = Option(mapToUpdate.get(key))
-      val value = updateFunc(oldValueOption)
+    override def get(key: UnsafeRow): Option[UnsafeRow] = {
+      Option(mapToUpdate.get(key))
+    }
+
+    override def put(key: UnsafeRow, value: UnsafeRow): Unit = {
+      verify(state == UPDATING, "Cannot put after already committed or cancelled")
+
+      val isNewKey = !mapToUpdate.containsKey(key)
       mapToUpdate.put(key, value)
 
       Option(allUpdates.get(key)) match {
@@ -115,8 +114,7 @@ private[state] class HDFSBackedStateStoreProvider(
         case None =>
           // There was no prior update, so mark this as added or updated according to its presence
           // in previous version.
-          val update =
-            if (oldValueOption.nonEmpty) ValueUpdated(key, value) else ValueAdded(key, value)
+          val update = if (isNewKey) ValueAdded(key, value) else ValueUpdated(key, value)
           allUpdates.put(key, update)
       }
       writeToDeltaFile(tempDeltaFileStream, ValueUpdated(key, value))
@@ -148,7 +146,7 @@ private[state] class HDFSBackedStateStoreProvider(
 
     /** Commit all the updates that have been made to the store, and return the new version. */
     override def commit(): Long = {
-      verify(state == UPDATING, "Cannot commit again after already committed or cancelled")
+      verify(state == UPDATING, "Cannot commit after already committed or cancelled")
 
       try {
         finalizeDeltaFile(tempDeltaFileStream)
@@ -164,8 +162,8 @@ private[state] class HDFSBackedStateStoreProvider(
     }
 
     /** Cancel all the updates made on this store. This store will not be usable any more. */
-    override def cancel(): Unit = {
-      state = CANCELLED
+    override def abort(): Unit = {
+      state = ABORTED
       if (tempDeltaFileStream != null) {
         tempDeltaFileStream.close()
       }
@@ -176,8 +174,8 @@ private[state] class HDFSBackedStateStoreProvider(
     }
 
     /**
-     * Get an iterator of all the store data. This can be called only after committing the
-     * updates.
+     * Get an iterator of all the store data.
+     * This can be called only after committing all the updates made in the current thread.
      */
     override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = {
       verify(state == COMMITTED, "Cannot get iterator of store data before committing")
@@ -186,7 +184,7 @@ private[state] class HDFSBackedStateStoreProvider(
 
     /**
      * Get an iterator of all the updates made to the store in the current version.
-     * This can be called only after committing the updates.
+     * This can be called only after committing all the updates made in the current thread.
      */
     override def updates(): Iterator[StoreUpdate] = {
       verify(state == COMMITTED, "Cannot get iterator of updates before committing")
@@ -196,7 +194,7 @@ private[state] class HDFSBackedStateStoreProvider(
     /**
      * Whether all updates have been committed
      */
-    override def hasCommitted: Boolean = {
+    override private[state] def hasCommitted: Boolean = {
       state == COMMITTED
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index ca5c864..d60e618 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -47,12 +47,11 @@ trait StateStore {
   /** Version of the data in this store before committing updates. */
   def version: Long
 
-  /**
-   * Update the value of a key using the value generated by the update function.
-   * @note Do not mutate the retrieved value row as it will unexpectedly affect the previous
-   *       versions of the store data.
-   */
-  def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit
+  /** Get the current value of a key. */
+  def get(key: UnsafeRow): Option[UnsafeRow]
+
+  /** Put a new value for a key. */
+  def put(key: UnsafeRow, value: UnsafeRow)
 
   /**
    * Remove keys that match the following condition.
@@ -65,24 +64,24 @@ trait StateStore {
   def commit(): Long
 
   /** Cancel all the updates that have been made to the store. */
-  def cancel(): Unit
+  def abort(): Unit
 
   /**
    * Iterator of store data after a set of updates have been committed.
-   * This can be called only after commitUpdates() has been called in the current thread.
+   * This can be called only after committing all the updates made in the current thread.
    */
   def iterator(): Iterator[(UnsafeRow, UnsafeRow)]
 
   /**
    * Iterator of the updates that have been committed.
-   * This can be called only after commitUpdates() has been called in the current thread.
+   * This can be called only after committing all the updates made in the current thread.
    */
   def updates(): Iterator[StoreUpdate]
 
   /**
    * Whether all updates have been committed
    */
-  def hasCommitted: Boolean
+  private[state] def hasCommitted: Boolean
 }
 
 

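A minimal sketch of a storeUpdateFunction written against the new get/put API, keeping a running
count per key. `rowToCount` and `countToRow` are assumed helpers that convert between a Long and
its UnsafeRow encoding; they are not part of this patch:

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    import org.apache.spark.sql.execution.streaming.state.StateStore

    // Sketch only: increment a counter for every key seen in this partition.
    def increment(store: StateStore, keys: Iterator[UnsafeRow]): Iterator[(UnsafeRow, UnsafeRow)] = {
      keys.foreach { key =>
        val current = store.get(key).map(rowToCount).getOrElse(0L)  // assumed helper
        store.put(key.copy(), countToRow(current + 1L))             // assumed helper
      }
      store.commit()    // the caller is now responsible for committing (see StateStoreRDD below)
      store.iterator()  // valid only after the commit
    }
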
http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
index cca22a0..f0f1f3a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.state
 import org.apache.spark.sql.internal.SQLConf
 
 /** A class that contains configuration parameters for [[StateStore]]s. */
-private[state] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable {
+private[streaming] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable {
 
   def this() = this(new SQLConf)
 
@@ -31,7 +31,7 @@ private[state] class StateStoreConf(@transient private val conf: SQLConf) extend
   val minVersionsToRetain = conf.getConf(STATE_STORE_MIN_VERSIONS_TO_RETAIN)
 }
 
-private[state] object StateStoreConf {
+private[streaming] object StateStoreConf {
   val empty = new StateStoreConf()
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
index 3318660..df3d82c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
@@ -54,17 +54,10 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
 
   override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = {
     var store: StateStore = null
-
-    Utils.tryWithSafeFinally {
-      val storeId = StateStoreId(checkpointLocation, operatorId, partition.index)
-      store = StateStore.get(
-        storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value)
-      val inputIter = dataRDD.iterator(partition, ctxt)
-      val outputIter = storeUpdateFunction(store, inputIter)
-      assert(store.hasCommitted)
-      outputIter
-    } {
-      if (store != null) store.cancel()
-    }
+    val storeId = StateStoreId(checkpointLocation, operatorId, partition.index)
+    store = StateStore.get(
+      storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value)
+    val inputIter = dataRDD.iterator(partition, ctxt)
+    storeUpdateFunction(store, inputIter)
   }
 }

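Since compute() no longer wraps the update function in tryWithSafeFinally or calls
cancel()/commit() itself, committing is left entirely to the storeUpdateFunction. StateStoreSave
above commits once its input iterator is exhausted; another way to do it lazily, sketched here
under the assumption that the update function has already produced an output iterator, is a
CompletionIterator:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.util.CompletionIterator

    // Sketch only: commit the store once the output iterator has been fully consumed.
    // `updateStateAndProduceOutput` is an assumed helper, not part of this patch.
    val output: Iterator[InternalRow] = updateStateAndProduceOutput(store, inputIter)
    CompletionIterator[InternalRow, Iterator[InternalRow]](output, store.commit())
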
http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
index b249e37..9b6d091 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
@@ -28,37 +28,36 @@ package object state {
   implicit class StateStoreOps[T: ClassTag](dataRDD: RDD[T]) {
 
     /** Map each partition of a RDD along with data in a [[StateStore]]. */
-    def mapPartitionWithStateStore[U: ClassTag](
-        storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U],
+    def mapPartitionsWithStateStore[U: ClassTag](
+        sqlContext: SQLContext,
         checkpointLocation: String,
         operatorId: Long,
         storeVersion: Long,
         keySchema: StructType,
-        valueSchema: StructType
-      )(implicit sqlContext: SQLContext): StateStoreRDD[T, U] = {
+        valueSchema: StructType)(
+        storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {
 
-      mapPartitionWithStateStore(
-        storeUpdateFunction,
+      mapPartitionsWithStateStore(
         checkpointLocation,
         operatorId,
         storeVersion,
         keySchema,
         valueSchema,
         new StateStoreConf(sqlContext.conf),
-        Some(sqlContext.streams.stateStoreCoordinator))
+        Some(sqlContext.streams.stateStoreCoordinator))(
+        storeUpdateFunction)
     }
 
     /** Map each partition of a RDD along with data in a [[StateStore]]. */
-    private[state] def mapPartitionWithStateStore[U: ClassTag](
-        storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U],
+    private[streaming] def mapPartitionsWithStateStore[U: ClassTag](
         checkpointLocation: String,
         operatorId: Long,
         storeVersion: Long,
         keySchema: StructType,
         valueSchema: StructType,
         storeConf: StateStoreConf,
-        storeCoordinator: Option[StateStoreCoordinatorRef]
-      ): StateStoreRDD[T, U] = {
+        storeCoordinator: Option[StateStoreCoordinatorRef])(
+        storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {
       val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction)
       new StateStoreRDD(
         dataRDD,

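With the update function moved into its own (curried) parameter list, call sites can pass it
last. A sketch against the public overload, reusing the `increment` function from the StateStore
sketch above; `keyedRdd` (an RDD[UnsafeRow]), `checkpointDir` and the two schemas are assumed:

    import org.apache.spark.sql.execution.streaming.state._

    val updated = keyedRdd.mapPartitionsWithStateStore(
        sqlContext,
        checkpointDir,  // checkpoint location
        0L,             // operatorId
        0L,             // storeVersion: 0 starts from an empty store
        keySchema,
        valueSchema)(increment)
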
http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index 0d58070..4b3091b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -17,12 +17,12 @@
 
 package org.apache.spark.sql.execution
 
+import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.{expressions, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{ExprId, Literal, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.internal.SessionState
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -60,14 +60,13 @@ case class ScalarSubquery(
 }
 
 /**
- * Convert the subquery from logical plan into executed plan.
+ * Plans scalar subqueries that are present in the given [[SparkPlan]].
  */
-case class PlanSubqueries(sessionState: SessionState) extends Rule[SparkPlan] {
+case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] {
   def apply(plan: SparkPlan): SparkPlan = {
     plan.transformAllExpressions {
       case subquery: expressions.ScalarSubquery =>
-        val sparkPlan = sessionState.planner.plan(ReturnAnswer(subquery.query)).next()
-        val executedPlan = sessionState.prepareForExecution.execute(sparkPlan)
+        val executedPlan = new QueryExecution(sqlContext, subquery.plan).executedPlan
         ScalarSubquery(executedPlan, subquery.exprId)
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 844f305..9cb356f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -84,10 +84,10 @@ abstract class Aggregator[-I, B, O] extends Serializable {
       implicit bEncoder: Encoder[B],
       cEncoder: Encoder[O]): TypedColumn[I, O] = {
     val expr =
-      new AggregateExpression(
+      AggregateExpression(
         TypedAggregateExpression(this),
         Complete,
-        false)
+        isDistinct = false)
 
     new TypedColumn[I, O](expr, encoderFor[O])
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index f7fdfac..cd3d254 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -86,20 +86,8 @@ private[sql] class SessionState(ctx: SQLContext) {
   /**
    * Planner that converts optimized logical plans to physical plans.
    */
-  lazy val planner: SparkPlanner = new SparkPlanner(ctx.sparkContext, conf, experimentalMethods)
-
-  /**
-   * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal
-   * row format conversions as needed.
-   */
-  lazy val prepareForExecution = new RuleExecutor[SparkPlan] {
-    override val batches: Seq[Batch] = Seq(
-      Batch("Subquery", Once, PlanSubqueries(SessionState.this)),
-      Batch("Add exchange", Once, EnsureRequirements(conf)),
-      Batch("Whole stage codegen", Once, CollapseCodegenStages(conf)),
-      Batch("Reuse duplicated exchanges", Once, ReuseExchange(conf))
-    )
-  }
+  def planner: SparkPlanner =
+    new SparkPlanner(ctx.sparkContext, conf, experimentalMethods.extraStrategies)
 
   /**
    * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s

http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
index b5be7ef..550c3c6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
@@ -116,15 +116,30 @@ trait StreamTest extends QueryTest with Timeouts {
     def apply[A : Encoder](data: A*): CheckAnswerRows = {
       val encoder = encoderFor[A]
       val toExternalRow = RowEncoder(encoder.schema)
-      CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))))
+      CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), false)
     }
 
-    def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows)
+    def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, false)
   }
 
-  case class CheckAnswerRows(expectedAnswer: Seq[Row])
+  /**
+   * Checks to make sure that the most recent batch written to the sink matches the `expectedAnswer`.
+   * This operation automatically blocks until all added data has been processed.
+   */
+  object CheckLastBatch {
+    def apply[A : Encoder](data: A*): CheckAnswerRows = {
+      val encoder = encoderFor[A]
+      val toExternalRow = RowEncoder(encoder.schema)
+      CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), true)
+    }
+
+    def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, true)
+  }
+
+  case class CheckAnswerRows(expectedAnswer: Seq[Row], lastOnly: Boolean)
       extends StreamAction with StreamMustBeRunning {
-    override def toString: String = s"CheckAnswer: ${expectedAnswer.mkString(",")}"
+    override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}"
+    private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer"
   }
 
   /** Stops the stream.  It must currently be running. */
@@ -224,11 +239,8 @@ trait StreamTest extends QueryTest with Timeouts {
          """.stripMargin
 
     def verify(condition: => Boolean, message: String): Unit = {
-      try {
-        Assertions.assert(condition)
-      } catch {
-        case NonFatal(e) =>
-          failTest(message, e)
+      if (!condition) {
+        failTest(message)
       }
     }
 
@@ -351,7 +363,7 @@ trait StreamTest extends QueryTest with Timeouts {
           case a: AddData =>
             awaiting.put(a.source, a.addData())
 
-          case CheckAnswerRows(expectedAnswer) =>
+          case CheckAnswerRows(expectedAnswer, lastOnly) =>
             verify(currentStream != null, "stream not running")
 
             // Block until all data added has been processed
@@ -361,12 +373,12 @@ trait StreamTest extends QueryTest with Timeouts {
               }
             }
 
-            val allData = try sink.allData catch {
+            val sparkAnswer = try if (lastOnly) sink.lastBatch else sink.allData catch {
               case e: Exception =>
                 failTest("Exception while getting data from sink", e)
             }
 
-            QueryTest.sameRows(expectedAnswer, allData).foreach {
+            QueryTest.sameRows(expectedAnswer, sparkAnswer).foreach {
               error => failTest(error)
             }
         }

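A sketch of how the new CheckLastBatch action might be used in a StreamTest-based suite to
exercise streaming aggregation end to end (the suite wiring -- shared SQLContext, implicits,
testStream/AddData from StreamTest -- is assumed):

    import org.apache.spark.sql.execution.streaming.MemoryStream
    import org.apache.spark.sql.functions._

    val input = MemoryStream[Int]
    val counts = input.toDF()
      .groupBy("value")
      .agg(count("*"))
      .as[(Int, Long)]

    testStream(counts)(
      AddData(input, 3, 3, 2),
      CheckLastBatch((3, 2L), (2, 1L))  // contents of the most recent batch written to the sink
    )
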
http://git-wip-us.apache.org/repos/asf/spark/blob/0fc4aaa7/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index ed0d3f5..3831874 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -231,10 +231,8 @@ object SparkPlanTest {
   }
 
   private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = {
-    // A very simple resolver to make writing tests easier. In contrast to the real resolver
-    // this is always case sensitive and does not try to handle scoping or complex type resolution.
-    val resolvedPlan = sqlContext.sessionState.prepareForExecution.execute(
-      outputPlan transform {
+    val execution = new QueryExecution(sqlContext, null) {
+      override lazy val sparkPlan: SparkPlan = outputPlan transform {
         case plan: SparkPlan =>
           val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
           plan transformExpressions {
@@ -243,8 +241,8 @@ object SparkPlanTest {
                 sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
           }
       }
-    )
-    resolvedPlan.executeCollectPublic().toSeq
+    }
+    execution.executedPlan.executeCollectPublic().toSeq
   }
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org