Posted to commits@spark.apache.org by zs...@apache.org on 2015/12/10 05:47:20 UTC

[1/2] spark git commit: [SPARK-12244][SPARK-12245][STREAMING] Rename trackStateByKey to mapWithState and change tracking function signature

Repository: spark
Updated Branches:
  refs/heads/master 2166c2a75 -> bd2cd4f53
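
The gist of the rename, as reflected in the test changes below: trackStateByKey becomes mapWithState,
and the "simple" form of the user-supplied function now receives the key as an explicit argument (the
"advanced" form additionally takes the batch Time as its first argument). The Scala sketch below is
distilled from the suites in this diff rather than taken verbatim from the commit; the running-count
logic and the names trackStateFunc/mappingFunc are those used in the tests.

    import org.apache.spark.streaming.State

    // Old simple form (trackStateByKey): the key is not visible to the function
    val trackStateFunc = (value: Option[Int], state: State[Int]) => {
      val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
      state.update(sum)   // keep a running count in the per-key state
      sum                 // mapped value returned for this batch
    }

    // New simple form (mapWithState): the key is passed to the function as well
    val mappingFunc = (key: String, value: Option[Int], state: State[Int]) => {
      val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
      state.update(sum)
      sum
    }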


http://git-wip-us.apache.org/repos/asf/spark/blob/bd2cd4f5/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java
----------------------------------------------------------------------
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java
deleted file mode 100644
index 89d0bb7..0000000
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java
+++ /dev/null
@@ -1,210 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.streaming;
-
-import java.io.Serializable;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-import java.util.Set;
-
-import scala.Tuple2;
-
-import com.google.common.base.Optional;
-import com.google.common.collect.Lists;
-import com.google.common.collect.Sets;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.streaming.api.java.JavaDStream;
-import org.apache.spark.util.ManualClock;
-import org.junit.Assert;
-import org.junit.Test;
-
-import org.apache.spark.HashPartitioner;
-import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.function.Function2;
-import org.apache.spark.api.java.function.Function4;
-import org.apache.spark.streaming.api.java.JavaPairDStream;
-import org.apache.spark.streaming.api.java.JavaTrackStateDStream;
-
-public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implements Serializable {
-
-  /**
-   * This test is only for testing the APIs. It's not necessary to run it.
-   */
-  public void testAPI() {
-    JavaPairRDD<String, Boolean> initialRDD = null;
-    JavaPairDStream<String, Integer> wordsDstream = null;
-
-    final Function4<Time, String, Optional<Integer>, State<Boolean>, Optional<Double>>
-        trackStateFunc =
-        new Function4<Time, String, Optional<Integer>, State<Boolean>, Optional<Double>>() {
-
-          @Override
-          public Optional<Double> call(
-              Time time, String word, Optional<Integer> one, State<Boolean> state) {
-            // Use all State's methods here
-            state.exists();
-            state.get();
-            state.isTimingOut();
-            state.remove();
-            state.update(true);
-            return Optional.of(2.0);
-          }
-        };
-
-    JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream =
-        wordsDstream.trackStateByKey(
-            StateSpec.function(trackStateFunc)
-                .initialState(initialRDD)
-                .numPartitions(10)
-                .partitioner(new HashPartitioner(10))
-                .timeout(Durations.seconds(10)));
-
-    JavaPairDStream<String, Boolean> emittedRecords = stateDstream.stateSnapshots();
-
-    final Function2<Optional<Integer>, State<Boolean>, Double> trackStateFunc2 =
-        new Function2<Optional<Integer>, State<Boolean>, Double>() {
-
-          @Override
-          public Double call(Optional<Integer> one, State<Boolean> state) {
-            // Use all State's methods here
-            state.exists();
-            state.get();
-            state.isTimingOut();
-            state.remove();
-            state.update(true);
-            return 2.0;
-          }
-        };
-
-    JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream2 =
-        wordsDstream.trackStateByKey(
-            StateSpec.<String, Integer, Boolean, Double>function(trackStateFunc2)
-                .initialState(initialRDD)
-                .numPartitions(10)
-                .partitioner(new HashPartitioner(10))
-                .timeout(Durations.seconds(10)));
-
-    JavaPairDStream<String, Boolean> emittedRecords2 = stateDstream2.stateSnapshots();
-  }
-
-  @Test
-  public void testBasicFunction() {
-    List<List<String>> inputData = Arrays.asList(
-        Collections.<String>emptyList(),
-        Arrays.asList("a"),
-        Arrays.asList("a", "b"),
-        Arrays.asList("a", "b", "c"),
-        Arrays.asList("a", "b"),
-        Arrays.asList("a"),
-        Collections.<String>emptyList()
-    );
-
-    List<Set<Integer>> outputData = Arrays.asList(
-        Collections.<Integer>emptySet(),
-        Sets.newHashSet(1),
-        Sets.newHashSet(2, 1),
-        Sets.newHashSet(3, 2, 1),
-        Sets.newHashSet(4, 3),
-        Sets.newHashSet(5),
-        Collections.<Integer>emptySet()
-    );
-
-    List<Set<Tuple2<String, Integer>>> stateData = Arrays.asList(
-        Collections.<Tuple2<String, Integer>>emptySet(),
-        Sets.newHashSet(new Tuple2<String, Integer>("a", 1)),
-        Sets.newHashSet(new Tuple2<String, Integer>("a", 2), new Tuple2<String, Integer>("b", 1)),
-        Sets.newHashSet(
-            new Tuple2<String, Integer>("a", 3),
-            new Tuple2<String, Integer>("b", 2),
-            new Tuple2<String, Integer>("c", 1)),
-        Sets.newHashSet(
-            new Tuple2<String, Integer>("a", 4),
-            new Tuple2<String, Integer>("b", 3),
-            new Tuple2<String, Integer>("c", 1)),
-        Sets.newHashSet(
-            new Tuple2<String, Integer>("a", 5),
-            new Tuple2<String, Integer>("b", 3),
-            new Tuple2<String, Integer>("c", 1)),
-        Sets.newHashSet(
-            new Tuple2<String, Integer>("a", 5),
-            new Tuple2<String, Integer>("b", 3),
-            new Tuple2<String, Integer>("c", 1))
-    );
-
-    Function2<Optional<Integer>, State<Integer>, Integer> trackStateFunc =
-        new Function2<Optional<Integer>, State<Integer>, Integer>() {
-
-          @Override
-          public Integer call(Optional<Integer> value, State<Integer> state) throws Exception {
-            int sum = value.or(0) + (state.exists() ? state.get() : 0);
-            state.update(sum);
-            return sum;
-          }
-        };
-    testOperation(
-        inputData,
-        StateSpec.<String, Integer, Integer, Integer>function(trackStateFunc),
-        outputData,
-        stateData);
-  }
-
-  private <K, S, T> void testOperation(
-      List<List<K>> input,
-      StateSpec<K, Integer, S, T> trackStateSpec,
-      List<Set<T>> expectedOutputs,
-      List<Set<Tuple2<K, S>>> expectedStateSnapshots) {
-    int numBatches = expectedOutputs.size();
-    JavaDStream<K> inputStream = JavaTestUtils.attachTestInputStream(ssc, input, 2);
-    JavaTrackStateDStream<K, Integer, S, T> trackeStateStream =
-        JavaPairDStream.fromJavaDStream(inputStream.map(new Function<K, Tuple2<K, Integer>>() {
-          @Override
-          public Tuple2<K, Integer> call(K x) throws Exception {
-            return new Tuple2<K, Integer>(x, 1);
-          }
-        })).trackStateByKey(trackStateSpec);
-
-    final List<Set<T>> collectedOutputs =
-        Collections.synchronizedList(Lists.<Set<T>>newArrayList());
-    trackeStateStream.foreachRDD(new Function<JavaRDD<T>, Void>() {
-      @Override
-      public Void call(JavaRDD<T> rdd) throws Exception {
-        collectedOutputs.add(Sets.newHashSet(rdd.collect()));
-        return null;
-      }
-    });
-    final List<Set<Tuple2<K, S>>> collectedStateSnapshots =
-        Collections.synchronizedList(Lists.<Set<Tuple2<K, S>>>newArrayList());
-    trackeStateStream.stateSnapshots().foreachRDD(new Function<JavaPairRDD<K, S>, Void>() {
-      @Override
-      public Void call(JavaPairRDD<K, S> rdd) throws Exception {
-        collectedStateSnapshots.add(Sets.newHashSet(rdd.collect()));
-        return null;
-      }
-    });
-    BatchCounter batchCounter = new BatchCounter(ssc.ssc());
-    ssc.start();
-    ((ManualClock) ssc.ssc().scheduler().clock())
-        .advance(ssc.ssc().progressListener().batchDuration() * numBatches + 1);
-    batchCounter.waitUntilBatchesCompleted(numBatches, 10000);
-
-    Assert.assertEquals(expectedOutputs, collectedOutputs);
-    Assert.assertEquals(expectedStateSnapshots, collectedStateSnapshots);
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/bd2cd4f5/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala
new file mode 100644
index 0000000..4b08085
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala
@@ -0,0 +1,581 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+import java.io.File
+
+import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
+import scala.reflect.ClassTag
+
+import org.scalatest.PrivateMethodTester._
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
+
+import org.apache.spark.streaming.dstream.{DStream, InternalMapWithStateDStream, MapWithStateDStream, MapWithStateDStreamImpl}
+import org.apache.spark.util.{ManualClock, Utils}
+import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+
+class MapWithStateSuite extends SparkFunSuite
+  with DStreamCheckpointTester with BeforeAndAfterAll with BeforeAndAfter {
+
+  private var sc: SparkContext = null
+  protected var checkpointDir: File = null
+  protected val batchDuration = Seconds(1)
+
+  before {
+    StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) }
+    checkpointDir = Utils.createTempDir("checkpoint")
+  }
+
+  after {
+    if (checkpointDir != null) {
+      Utils.deleteRecursively(checkpointDir)
+    }
+    StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) }
+  }
+
+  override def beforeAll(): Unit = {
+    val conf = new SparkConf().setMaster("local").setAppName("MapWithStateSuite")
+    conf.set("spark.streaming.clock", classOf[ManualClock].getName())
+    sc = new SparkContext(conf)
+  }
+
+  override def afterAll(): Unit = {
+    if (sc != null) {
+      sc.stop()
+    }
+  }
+
+  test("state - get, exists, update, remove, ") {
+    var state: StateImpl[Int] = null
+
+    def testState(
+        expectedData: Option[Int],
+        shouldBeUpdated: Boolean = false,
+        shouldBeRemoved: Boolean = false,
+        shouldBeTimingOut: Boolean = false
+      ): Unit = {
+      if (expectedData.isDefined) {
+        assert(state.exists)
+        assert(state.get() === expectedData.get)
+        assert(state.getOption() === expectedData)
+        assert(state.getOption.getOrElse(-1) === expectedData.get)
+      } else {
+        assert(!state.exists)
+        intercept[NoSuchElementException] {
+          state.get()
+        }
+        assert(state.getOption() === None)
+        assert(state.getOption.getOrElse(-1) === -1)
+      }
+
+      assert(state.isTimingOut() === shouldBeTimingOut)
+      if (shouldBeTimingOut) {
+        intercept[IllegalArgumentException] {
+          state.remove()
+        }
+        intercept[IllegalArgumentException] {
+          state.update(-1)
+        }
+      }
+
+      assert(state.isUpdated() === shouldBeUpdated)
+
+      assert(state.isRemoved() === shouldBeRemoved)
+      if (shouldBeRemoved) {
+        intercept[IllegalArgumentException] {
+          state.remove()
+        }
+        intercept[IllegalArgumentException] {
+          state.update(-1)
+        }
+      }
+    }
+
+    state = new StateImpl[Int]()
+    testState(None)
+
+    state.wrap(None)
+    testState(None)
+
+    state.wrap(Some(1))
+    testState(Some(1))
+
+    state.update(2)
+    testState(Some(2), shouldBeUpdated = true)
+
+    state = new StateImpl[Int]()
+    state.update(2)
+    testState(Some(2), shouldBeUpdated = true)
+
+    state.remove()
+    testState(None, shouldBeRemoved = true)
+
+    state.wrapTiminoutState(3)
+    testState(Some(3), shouldBeTimingOut = true)
+  }
+
+  test("mapWithState - basic operations with simple API") {
+    val inputData =
+      Seq(
+        Seq(),
+        Seq("a"),
+        Seq("a", "b"),
+        Seq("a", "b", "c"),
+        Seq("a", "b"),
+        Seq("a"),
+        Seq()
+      )
+
+    val outputData =
+      Seq(
+        Seq(),
+        Seq(1),
+        Seq(2, 1),
+        Seq(3, 2, 1),
+        Seq(4, 3),
+        Seq(5),
+        Seq()
+      )
+
+    val stateData =
+      Seq(
+        Seq(),
+        Seq(("a", 1)),
+        Seq(("a", 2), ("b", 1)),
+        Seq(("a", 3), ("b", 2), ("c", 1)),
+        Seq(("a", 4), ("b", 3), ("c", 1)),
+        Seq(("a", 5), ("b", 3), ("c", 1)),
+        Seq(("a", 5), ("b", 3), ("c", 1))
+      )
+
+    // state maintains running count, and updated count is returned
+    val mappingFunc = (key: String, value: Option[Int], state: State[Int]) => {
+      val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
+      state.update(sum)
+      sum
+    }
+
+    testOperation[String, Int, Int](
+      inputData, StateSpec.function(mappingFunc), outputData, stateData)
+  }
+
+  test("mapWithState - basic operations with advanced API") {
+    val inputData =
+      Seq(
+        Seq(),
+        Seq("a"),
+        Seq("a", "b"),
+        Seq("a", "b", "c"),
+        Seq("a", "b"),
+        Seq("a"),
+        Seq()
+      )
+
+    val outputData =
+      Seq(
+        Seq(),
+        Seq("aa"),
+        Seq("aa", "bb"),
+        Seq("aa", "bb", "cc"),
+        Seq("aa", "bb"),
+        Seq("aa"),
+        Seq()
+      )
+
+    val stateData =
+      Seq(
+        Seq(),
+        Seq(("a", 1)),
+        Seq(("a", 2), ("b", 1)),
+        Seq(("a", 3), ("b", 2), ("c", 1)),
+        Seq(("a", 4), ("b", 3), ("c", 1)),
+        Seq(("a", 5), ("b", 3), ("c", 1)),
+        Seq(("a", 5), ("b", 3), ("c", 1))
+      )
+
+    // state maintains running count; the doubled key string is returned
+    val mappingFunc = (batchTime: Time, key: String, value: Option[Int], state: State[Int]) => {
+      val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
+      state.update(sum)
+      Some(key * 2)
+    }
+
+    testOperation(inputData, StateSpec.function(mappingFunc), outputData, stateData)
+  }
+
+  test("mapWithState - type inferencing and class tags") {
+
+    // Simple mapping function with value as Int, state as Double and mapped type as Long
+    val simpleFunc = (key: String, value: Option[Int], state: State[Double]) => {
+      0L
+    }
+
+    // Advanced mapping function with key as String, value as Int, state as Double and
+    // mapped type as Long
+    val advancedFunc = (time: Time, key: String, value: Option[Int], state: State[Double]) => {
+      Some(0L)
+    }
+
+    def testTypes(dstream: MapWithStateDStream[_, _, _, _]): Unit = {
+      val dstreamImpl = dstream.asInstanceOf[MapWithStateDStreamImpl[_, _, _, _]]
+      assert(dstreamImpl.keyClass === classOf[String])
+      assert(dstreamImpl.valueClass === classOf[Int])
+      assert(dstreamImpl.stateClass === classOf[Double])
+      assert(dstreamImpl.mappedClass === classOf[Long])
+    }
+    val ssc = new StreamingContext(sc, batchDuration)
+    val inputStream = new TestInputStream[(String, Int)](ssc, Seq.empty, numPartitions = 2)
+
+    // Defining StateSpec inline with mapWithState and simple function implicitly gets the types
+    val simpleFunctionStateStream1 = inputStream.mapWithState(
+      StateSpec.function(simpleFunc).numPartitions(1))
+    testTypes(simpleFunctionStateStream1)
+
+    // Separately defining StateSpec with simple function requires explicitly specifying types
+    val simpleFuncSpec = StateSpec.function[String, Int, Double, Long](simpleFunc)
+    val simpleFunctionStateStream2 = inputStream.mapWithState(simpleFuncSpec)
+    testTypes(simpleFunctionStateStream2)
+
+    // Separately defining StateSpec with advanced function implicitly gets the types
+    val advFuncSpec1 = StateSpec.function(advancedFunc)
+    val advFunctionStateStream1 = inputStream.mapWithState(advFuncSpec1)
+    testTypes(advFunctionStateStream1)
+
+    // Defining StateSpec inline with mapWithState and advanced func implicitly gets the types
+    val advFunctionStateStream2 = inputStream.mapWithState(
+      StateSpec.function(advancedFunc).numPartitions(1))
+    testTypes(advFunctionStateStream2)
+
+    // Separately defining StateSpec with advanced function and explicit types also works
+    val advFuncSpec2 = StateSpec.function[String, Int, Double, Long](advancedFunc)
+    val advFunctionStateStream3 = inputStream.mapWithState[Double, Long](advFuncSpec2)
+    testTypes(advFunctionStateStream3)
+  }
+
+  test("mapWithState - states as mapped data") {
+    val inputData =
+      Seq(
+        Seq(),
+        Seq("a"),
+        Seq("a", "b"),
+        Seq("a", "b", "c"),
+        Seq("a", "b"),
+        Seq("a"),
+        Seq()
+      )
+
+    val outputData =
+      Seq(
+        Seq(),
+        Seq(("a", 1)),
+        Seq(("a", 2), ("b", 1)),
+        Seq(("a", 3), ("b", 2), ("c", 1)),
+        Seq(("a", 4), ("b", 3)),
+        Seq(("a", 5)),
+        Seq()
+      )
+
+    val stateData =
+      Seq(
+        Seq(),
+        Seq(("a", 1)),
+        Seq(("a", 2), ("b", 1)),
+        Seq(("a", 3), ("b", 2), ("c", 1)),
+        Seq(("a", 4), ("b", 3), ("c", 1)),
+        Seq(("a", 5), ("b", 3), ("c", 1)),
+        Seq(("a", 5), ("b", 3), ("c", 1))
+      )
+
+    val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
+      val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
+      val output = (key, sum)
+      state.update(sum)
+      Some(output)
+    }
+
+    testOperation(inputData, StateSpec.function(mappingFunc), outputData, stateData)
+  }
+
+  test("mapWithState - initial states, with nothing returned as from mapping function") {
+
+    val initialState = Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0))
+
+    val inputData =
+      Seq(
+        Seq(),
+        Seq("a"),
+        Seq("a", "b"),
+        Seq("a", "b", "c"),
+        Seq("a", "b"),
+        Seq("a"),
+        Seq()
+      )
+
+    val outputData = Seq.fill(inputData.size)(Seq.empty[Int])
+
+    val stateData =
+      Seq(
+        Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0)),
+        Seq(("a", 6), ("b", 10), ("c", -20), ("d", 0)),
+        Seq(("a", 7), ("b", 11), ("c", -20), ("d", 0)),
+        Seq(("a", 8), ("b", 12), ("c", -19), ("d", 0)),
+        Seq(("a", 9), ("b", 13), ("c", -19), ("d", 0)),
+        Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0)),
+        Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0))
+      )
+
+    val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
+      val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
+      val output = (key, sum)
+      state.update(sum)
+      None.asInstanceOf[Option[Int]]
+    }
+
+    val mapWithStateSpec = StateSpec.function(mappingFunc).initialState(sc.makeRDD(initialState))
+    testOperation(inputData, mapWithStateSpec, outputData, stateData)
+  }
+
+  test("mapWithState - state removing") {
+    val inputData =
+      Seq(
+        Seq(),
+        Seq("a"),
+        Seq("a", "b"), // a will be removed
+        Seq("a", "b", "c"), // b will be removed
+        Seq("a", "b", "c"), // a and c will be removed
+        Seq("a", "b"), // b will be removed
+        Seq("a"), // a will be removed
+        Seq()
+      )
+
+    // States that were removed
+    val outputData =
+      Seq(
+        Seq(),
+        Seq(),
+        Seq("a"),
+        Seq("b"),
+        Seq("a", "c"),
+        Seq("b"),
+        Seq("a"),
+        Seq()
+      )
+
+    val stateData =
+      Seq(
+        Seq(),
+        Seq(("a", 1)),
+        Seq(("b", 1)),
+        Seq(("a", 1), ("c", 1)),
+        Seq(("b", 1)),
+        Seq(("a", 1)),
+        Seq(),
+        Seq()
+      )
+
+    val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
+      if (state.exists) {
+        state.remove()
+        Some(key)
+      } else {
+        state.update(value.get)
+        None
+      }
+    }
+
+    testOperation(
+      inputData, StateSpec.function(mappingFunc).numPartitions(1), outputData, stateData)
+  }
+
+  test("mapWithState - state timing out") {
+    val inputData =
+      Seq(
+        Seq("a", "b", "c"),
+        Seq("a", "b"),
+        Seq("a"),
+        Seq(), // c will time out
+        Seq(), // b will time out
+        Seq("a") // a will not time out
+      ) ++ Seq.fill(20)(Seq("a")) // a will continue to stay active
+
+    val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
+      if (value.isDefined) {
+        state.update(1)
+      }
+      if (state.isTimingOut) {
+        Some(key)
+      } else {
+        None
+      }
+    }
+
+    val (collectedOutputs, collectedStateSnapshots) = getOperationOutput(
+      inputData, StateSpec.function(mappingFunc).timeout(Seconds(3)), 20)
+
+    // b and c should be returned once each, when they were marked as expired
+    assert(collectedOutputs.flatten.sorted === Seq("b", "c"))
+
+    // States for a, b, and c should all be defined at some point in time
+    assert(collectedStateSnapshots.exists {
+      _.toSet == Set(("a", 1), ("b", 1), ("c", 1))
+    })
+
+    // Finally state should be defined only for a
+    assert(collectedStateSnapshots.last.toSet === Set(("a", 1)))
+  }
+
+  test("mapWithState - checkpoint durations") {
+    val privateMethod = PrivateMethod[InternalMapWithStateDStream[_, _, _, _]]('internalStream)
+
+    def testCheckpointDuration(
+        batchDuration: Duration,
+        expectedCheckpointDuration: Duration,
+        explicitCheckpointDuration: Option[Duration] = None
+      ): Unit = {
+      val ssc = new StreamingContext(sc, batchDuration)
+
+      try {
+        val inputStream = new TestInputStream(ssc, Seq.empty[Seq[Int]], 2).map(_ -> 1)
+        val dummyFunc = (key: Int, value: Option[Int], state: State[Int]) => 0
+        val mapWithStateStream = inputStream.mapWithState(StateSpec.function(dummyFunc))
+        val internalmapWithStateStream = mapWithStateStream invokePrivate privateMethod()
+
+        explicitCheckpointDuration.foreach { d =>
+          mapWithStateStream.checkpoint(d)
+        }
+        mapWithStateStream.register()
+        ssc.checkpoint(checkpointDir.toString)
+        ssc.start()  // should initialize all the checkpoint durations
+        assert(mapWithStateStream.checkpointDuration === null)
+        assert(internalmapWithStateStream.checkpointDuration === expectedCheckpointDuration)
+      } finally {
+        ssc.stop(stopSparkContext = false)
+      }
+    }
+
+    testCheckpointDuration(Milliseconds(100), Seconds(1))
+    testCheckpointDuration(Seconds(1), Seconds(10))
+    testCheckpointDuration(Seconds(10), Seconds(100))
+
+    testCheckpointDuration(Milliseconds(100), Seconds(2), Some(Seconds(2)))
+    testCheckpointDuration(Seconds(1), Seconds(2), Some(Seconds(2)))
+    testCheckpointDuration(Seconds(10), Seconds(20), Some(Seconds(20)))
+  }
+
+
+  test("mapWithState - driver failure recovery") {
+    val inputData =
+      Seq(
+        Seq(),
+        Seq("a"),
+        Seq("a", "b"),
+        Seq("a", "b", "c"),
+        Seq("a", "b"),
+        Seq("a"),
+        Seq()
+      )
+
+    val stateData =
+      Seq(
+        Seq(),
+        Seq(("a", 1)),
+        Seq(("a", 2), ("b", 1)),
+        Seq(("a", 3), ("b", 2), ("c", 1)),
+        Seq(("a", 4), ("b", 3), ("c", 1)),
+        Seq(("a", 5), ("b", 3), ("c", 1)),
+        Seq(("a", 5), ("b", 3), ("c", 1))
+      )
+
+    def operation(dstream: DStream[String]): DStream[(String, Int)] = {
+
+      val checkpointDuration = batchDuration * (stateData.size / 2)
+
+      val runningCount = (key: String, value: Option[Int], state: State[Int]) => {
+        state.update(state.getOption().getOrElse(0) + value.getOrElse(0))
+        state.get()
+      }
+
+      val mapWithStateStream = dstream.map { _ -> 1 }.mapWithState(
+        StateSpec.function(runningCount))
+      // Set the checkpoint interval to make sure there is at least one RDD checkpoint
+      mapWithStateStream.checkpoint(checkpointDuration)
+      mapWithStateStream.stateSnapshots()
+    }
+
+    testCheckpointedOperation(inputData, operation, stateData, inputData.size / 2,
+      batchDuration = batchDuration, stopSparkContextAfterTest = false)
+  }
+
+  private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag](
+      input: Seq[Seq[K]],
+      mapWithStateSpec: StateSpec[K, Int, S, T],
+      expectedOutputs: Seq[Seq[T]],
+      expectedStateSnapshots: Seq[Seq[(K, S)]]
+    ): Unit = {
+    require(expectedOutputs.size == expectedStateSnapshots.size)
+
+    val (collectedOutputs, collectedStateSnapshots) =
+      getOperationOutput(input, mapWithStateSpec, expectedOutputs.size)
+    assert(expectedOutputs, collectedOutputs, "outputs")
+    assert(expectedStateSnapshots, collectedStateSnapshots, "state snapshots")
+  }
+
+  private def getOperationOutput[K: ClassTag, S: ClassTag, T: ClassTag](
+      input: Seq[Seq[K]],
+      mapWithStateSpec: StateSpec[K, Int, S, T],
+      numBatches: Int
+    ): (Seq[Seq[T]], Seq[Seq[(K, S)]]) = {
+
+    // Setup the stream computation
+    val ssc = new StreamingContext(sc, Seconds(1))
+    val inputStream = new TestInputStream(ssc, input, numPartitions = 2)
+    val mapWithStateStream = inputStream.map(x => (x, 1)).mapWithState(mapWithStateSpec)
+    val collectedOutputs = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]]
+    val outputStream = new TestOutputStream(mapWithStateStream, collectedOutputs)
+    val collectedStateSnapshots = new ArrayBuffer[Seq[(K, S)]] with SynchronizedBuffer[Seq[(K, S)]]
+    val stateSnapshotStream = new TestOutputStream(
+      mapWithStateStream.stateSnapshots(), collectedStateSnapshots)
+    outputStream.register()
+    stateSnapshotStream.register()
+
+    val batchCounter = new BatchCounter(ssc)
+    ssc.checkpoint(checkpointDir.toString)
+    ssc.start()
+
+    val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+    clock.advance(batchDuration.milliseconds * numBatches)
+
+    batchCounter.waitUntilBatchesCompleted(numBatches, 10000)
+    ssc.stop(stopSparkContext = false)
+    (collectedOutputs, collectedStateSnapshots)
+  }
+
+  private def assert[U](expected: Seq[Seq[U]], collected: Seq[Seq[U]], typ: String) {
+    val debugString = "\nExpected:\n" + expected.mkString("\n") +
+      "\nCollected:\n" + collected.mkString("\n")
+    assert(expected.size === collected.size,
+      s"number of collected $typ (${collected.size}) different from expected (${expected.size})" +
+        debugString)
+    expected.zip(collected).foreach { case (e, c) =>
+      assert(c.toSet === e.toSet,
+        s"collected $typ is different from expected $debugString"
+      )
+    }
+  }
+}
+
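
For reference, a minimal sketch of how the pieces exercised by MapWithStateSuite above fit together in
user code; the word-count wiring and the names wireUp, words and counts are illustrative, not part of
this commit. Following the suite, the StateSpec gets explicit type parameters because it is defined
separately from the mapWithState call with the simple function form.

    import org.apache.spark.streaming.{Seconds, State, StateSpec}
    import org.apache.spark.streaming.dstream.DStream

    def wireUp(words: DStream[String]): Unit = {
      // Running count per key, as in the "basic operations with simple API" test
      val mappingFunc = (key: String, value: Option[Int], state: State[Int]) => {
        val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
        state.update(sum)
        sum
      }

      val spec = StateSpec.function[String, Int, Int, Int](mappingFunc)
        .numPartitions(2)           // same knob the suite exercises
        .timeout(Seconds(30))       // idle keys are eventually marked as timing out

      val counts = words.map(w => (w, 1)).mapWithState(spec)
      counts.print()                        // mapped values produced in each batch
      counts.stateSnapshots().print()       // snapshot of all (key, state) pairs per batch
    }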

http://git-wip-us.apache.org/repos/asf/spark/blob/bd2cd4f5/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala
deleted file mode 100644
index 1fc320d..0000000
--- a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala
+++ /dev/null
@@ -1,581 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.streaming
-
-import java.io.File
-
-import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
-import scala.reflect.ClassTag
-
-import org.scalatest.PrivateMethodTester._
-import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
-
-import org.apache.spark.streaming.dstream.{DStream, InternalTrackStateDStream, TrackStateDStream, TrackStateDStreamImpl}
-import org.apache.spark.util.{ManualClock, Utils}
-import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
-
-class TrackStateByKeySuite extends SparkFunSuite
-  with DStreamCheckpointTester with BeforeAndAfterAll with BeforeAndAfter {
-
-  private var sc: SparkContext = null
-  protected var checkpointDir: File = null
-  protected val batchDuration = Seconds(1)
-
-  before {
-    StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) }
-    checkpointDir = Utils.createTempDir("checkpoint")
-  }
-
-  after {
-    if (checkpointDir != null) {
-      Utils.deleteRecursively(checkpointDir)
-    }
-    StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) }
-  }
-
-  override def beforeAll(): Unit = {
-    val conf = new SparkConf().setMaster("local").setAppName("TrackStateByKeySuite")
-    conf.set("spark.streaming.clock", classOf[ManualClock].getName())
-    sc = new SparkContext(conf)
-  }
-
-  override def afterAll(): Unit = {
-    if (sc != null) {
-      sc.stop()
-    }
-  }
-
-  test("state - get, exists, update, remove, ") {
-    var state: StateImpl[Int] = null
-
-    def testState(
-        expectedData: Option[Int],
-        shouldBeUpdated: Boolean = false,
-        shouldBeRemoved: Boolean = false,
-        shouldBeTimingOut: Boolean = false
-      ): Unit = {
-      if (expectedData.isDefined) {
-        assert(state.exists)
-        assert(state.get() === expectedData.get)
-        assert(state.getOption() === expectedData)
-        assert(state.getOption.getOrElse(-1) === expectedData.get)
-      } else {
-        assert(!state.exists)
-        intercept[NoSuchElementException] {
-          state.get()
-        }
-        assert(state.getOption() === None)
-        assert(state.getOption.getOrElse(-1) === -1)
-      }
-
-      assert(state.isTimingOut() === shouldBeTimingOut)
-      if (shouldBeTimingOut) {
-        intercept[IllegalArgumentException] {
-          state.remove()
-        }
-        intercept[IllegalArgumentException] {
-          state.update(-1)
-        }
-      }
-
-      assert(state.isUpdated() === shouldBeUpdated)
-
-      assert(state.isRemoved() === shouldBeRemoved)
-      if (shouldBeRemoved) {
-        intercept[IllegalArgumentException] {
-          state.remove()
-        }
-        intercept[IllegalArgumentException] {
-          state.update(-1)
-        }
-      }
-    }
-
-    state = new StateImpl[Int]()
-    testState(None)
-
-    state.wrap(None)
-    testState(None)
-
-    state.wrap(Some(1))
-    testState(Some(1))
-
-    state.update(2)
-    testState(Some(2), shouldBeUpdated = true)
-
-    state = new StateImpl[Int]()
-    state.update(2)
-    testState(Some(2), shouldBeUpdated = true)
-
-    state.remove()
-    testState(None, shouldBeRemoved = true)
-
-    state.wrapTiminoutState(3)
-    testState(Some(3), shouldBeTimingOut = true)
-  }
-
-  test("trackStateByKey - basic operations with simple API") {
-    val inputData =
-      Seq(
-        Seq(),
-        Seq("a"),
-        Seq("a", "b"),
-        Seq("a", "b", "c"),
-        Seq("a", "b"),
-        Seq("a"),
-        Seq()
-      )
-
-    val outputData =
-      Seq(
-        Seq(),
-        Seq(1),
-        Seq(2, 1),
-        Seq(3, 2, 1),
-        Seq(4, 3),
-        Seq(5),
-        Seq()
-      )
-
-    val stateData =
-      Seq(
-        Seq(),
-        Seq(("a", 1)),
-        Seq(("a", 2), ("b", 1)),
-        Seq(("a", 3), ("b", 2), ("c", 1)),
-        Seq(("a", 4), ("b", 3), ("c", 1)),
-        Seq(("a", 5), ("b", 3), ("c", 1)),
-        Seq(("a", 5), ("b", 3), ("c", 1))
-      )
-
-    // state maintains running count, and updated count is returned
-    val trackStateFunc = (value: Option[Int], state: State[Int]) => {
-      val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
-      state.update(sum)
-      sum
-    }
-
-    testOperation[String, Int, Int](
-      inputData, StateSpec.function(trackStateFunc), outputData, stateData)
-  }
-
-  test("trackStateByKey - basic operations with advanced API") {
-    val inputData =
-      Seq(
-        Seq(),
-        Seq("a"),
-        Seq("a", "b"),
-        Seq("a", "b", "c"),
-        Seq("a", "b"),
-        Seq("a"),
-        Seq()
-      )
-
-    val outputData =
-      Seq(
-        Seq(),
-        Seq("aa"),
-        Seq("aa", "bb"),
-        Seq("aa", "bb", "cc"),
-        Seq("aa", "bb"),
-        Seq("aa"),
-        Seq()
-      )
-
-    val stateData =
-      Seq(
-        Seq(),
-        Seq(("a", 1)),
-        Seq(("a", 2), ("b", 1)),
-        Seq(("a", 3), ("b", 2), ("c", 1)),
-        Seq(("a", 4), ("b", 3), ("c", 1)),
-        Seq(("a", 5), ("b", 3), ("c", 1)),
-        Seq(("a", 5), ("b", 3), ("c", 1))
-      )
-
-    // state maintains running count, key string doubled and returned
-    val trackStateFunc = (batchTime: Time, key: String, value: Option[Int], state: State[Int]) => {
-      val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
-      state.update(sum)
-      Some(key * 2)
-    }
-
-    testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData)
-  }
-
-  test("trackStateByKey - type inferencing and class tags") {
-
-    // Simple track state function with value as Int, state as Double and emitted type as Double
-    val simpleFunc = (value: Option[Int], state: State[Double]) => {
-      0L
-    }
-
-    // Advanced track state function with key as String, value as Int, state as Double and
-    // emitted type as Double
-    val advancedFunc = (time: Time, key: String, value: Option[Int], state: State[Double]) => {
-      Some(0L)
-    }
-
-    def testTypes(dstream: TrackStateDStream[_, _, _, _]): Unit = {
-      val dstreamImpl = dstream.asInstanceOf[TrackStateDStreamImpl[_, _, _, _]]
-      assert(dstreamImpl.keyClass === classOf[String])
-      assert(dstreamImpl.valueClass === classOf[Int])
-      assert(dstreamImpl.stateClass === classOf[Double])
-      assert(dstreamImpl.emittedClass === classOf[Long])
-    }
-    val ssc = new StreamingContext(sc, batchDuration)
-    val inputStream = new TestInputStream[(String, Int)](ssc, Seq.empty, numPartitions = 2)
-
-    // Defining StateSpec inline with trackStateByKey and simple function implicitly gets the types
-    val simpleFunctionStateStream1 = inputStream.trackStateByKey(
-      StateSpec.function(simpleFunc).numPartitions(1))
-    testTypes(simpleFunctionStateStream1)
-
-    // Separately defining StateSpec with simple function requires explicitly specifying types
-    val simpleFuncSpec = StateSpec.function[String, Int, Double, Long](simpleFunc)
-    val simpleFunctionStateStream2 = inputStream.trackStateByKey(simpleFuncSpec)
-    testTypes(simpleFunctionStateStream2)
-
-    // Separately defining StateSpec with advanced function implicitly gets the types
-    val advFuncSpec1 = StateSpec.function(advancedFunc)
-    val advFunctionStateStream1 = inputStream.trackStateByKey(advFuncSpec1)
-    testTypes(advFunctionStateStream1)
-
-    // Defining StateSpec inline with trackStateByKey and advanced func implicitly gets the types
-    val advFunctionStateStream2 = inputStream.trackStateByKey(
-      StateSpec.function(simpleFunc).numPartitions(1))
-    testTypes(advFunctionStateStream2)
-
-    // Defining StateSpec inline with trackStateByKey and advanced func implicitly gets the types
-    val advFuncSpec2 = StateSpec.function[String, Int, Double, Long](advancedFunc)
-    val advFunctionStateStream3 = inputStream.trackStateByKey[Double, Long](advFuncSpec2)
-    testTypes(advFunctionStateStream3)
-  }
-
-  test("trackStateByKey - states as emitted records") {
-    val inputData =
-      Seq(
-        Seq(),
-        Seq("a"),
-        Seq("a", "b"),
-        Seq("a", "b", "c"),
-        Seq("a", "b"),
-        Seq("a"),
-        Seq()
-      )
-
-    val outputData =
-      Seq(
-        Seq(),
-        Seq(("a", 1)),
-        Seq(("a", 2), ("b", 1)),
-        Seq(("a", 3), ("b", 2), ("c", 1)),
-        Seq(("a", 4), ("b", 3)),
-        Seq(("a", 5)),
-        Seq()
-      )
-
-    val stateData =
-      Seq(
-        Seq(),
-        Seq(("a", 1)),
-        Seq(("a", 2), ("b", 1)),
-        Seq(("a", 3), ("b", 2), ("c", 1)),
-        Seq(("a", 4), ("b", 3), ("c", 1)),
-        Seq(("a", 5), ("b", 3), ("c", 1)),
-        Seq(("a", 5), ("b", 3), ("c", 1))
-      )
-
-    val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
-      val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
-      val output = (key, sum)
-      state.update(sum)
-      Some(output)
-    }
-
-    testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData)
-  }
-
-  test("trackStateByKey - initial states, with nothing emitted") {
-
-    val initialState = Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0))
-
-    val inputData =
-      Seq(
-        Seq(),
-        Seq("a"),
-        Seq("a", "b"),
-        Seq("a", "b", "c"),
-        Seq("a", "b"),
-        Seq("a"),
-        Seq()
-      )
-
-    val outputData = Seq.fill(inputData.size)(Seq.empty[Int])
-
-    val stateData =
-      Seq(
-        Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0)),
-        Seq(("a", 6), ("b", 10), ("c", -20), ("d", 0)),
-        Seq(("a", 7), ("b", 11), ("c", -20), ("d", 0)),
-        Seq(("a", 8), ("b", 12), ("c", -19), ("d", 0)),
-        Seq(("a", 9), ("b", 13), ("c", -19), ("d", 0)),
-        Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0)),
-        Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0))
-      )
-
-    val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
-      val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
-      val output = (key, sum)
-      state.update(sum)
-      None.asInstanceOf[Option[Int]]
-    }
-
-    val trackStateSpec = StateSpec.function(trackStateFunc).initialState(sc.makeRDD(initialState))
-    testOperation(inputData, trackStateSpec, outputData, stateData)
-  }
-
-  test("trackStateByKey - state removing") {
-    val inputData =
-      Seq(
-        Seq(),
-        Seq("a"),
-        Seq("a", "b"), // a will be removed
-        Seq("a", "b", "c"), // b will be removed
-        Seq("a", "b", "c"), // a and c will be removed
-        Seq("a", "b"), // b will be removed
-        Seq("a"), // a will be removed
-        Seq()
-      )
-
-    // States that were removed
-    val outputData =
-      Seq(
-        Seq(),
-        Seq(),
-        Seq("a"),
-        Seq("b"),
-        Seq("a", "c"),
-        Seq("b"),
-        Seq("a"),
-        Seq()
-      )
-
-    val stateData =
-      Seq(
-        Seq(),
-        Seq(("a", 1)),
-        Seq(("b", 1)),
-        Seq(("a", 1), ("c", 1)),
-        Seq(("b", 1)),
-        Seq(("a", 1)),
-        Seq(),
-        Seq()
-      )
-
-    val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
-      if (state.exists) {
-        state.remove()
-        Some(key)
-      } else {
-        state.update(value.get)
-        None
-      }
-    }
-
-    testOperation(
-      inputData, StateSpec.function(trackStateFunc).numPartitions(1), outputData, stateData)
-  }
-
-  test("trackStateByKey - state timing out") {
-    val inputData =
-      Seq(
-        Seq("a", "b", "c"),
-        Seq("a", "b"),
-        Seq("a"),
-        Seq(), // c will time out
-        Seq(), // b will time out
-        Seq("a") // a will not time out
-      ) ++ Seq.fill(20)(Seq("a")) // a will continue to stay active
-
-    val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
-      if (value.isDefined) {
-        state.update(1)
-      }
-      if (state.isTimingOut) {
-        Some(key)
-      } else {
-        None
-      }
-    }
-
-    val (collectedOutputs, collectedStateSnapshots) = getOperationOutput(
-      inputData, StateSpec.function(trackStateFunc).timeout(Seconds(3)), 20)
-
-    // b and c should be emitted once each, when they were marked as expired
-    assert(collectedOutputs.flatten.sorted === Seq("b", "c"))
-
-    // States for a, b, c should be defined at one point of time
-    assert(collectedStateSnapshots.exists {
-      _.toSet == Set(("a", 1), ("b", 1), ("c", 1))
-    })
-
-    // Finally state should be defined only for a
-    assert(collectedStateSnapshots.last.toSet === Set(("a", 1)))
-  }
-
-  test("trackStateByKey - checkpoint durations") {
-    val privateMethod = PrivateMethod[InternalTrackStateDStream[_, _, _, _]]('internalStream)
-
-    def testCheckpointDuration(
-        batchDuration: Duration,
-        expectedCheckpointDuration: Duration,
-        explicitCheckpointDuration: Option[Duration] = None
-      ): Unit = {
-      val ssc = new StreamingContext(sc, batchDuration)
-
-      try {
-        val inputStream = new TestInputStream(ssc, Seq.empty[Seq[Int]], 2).map(_ -> 1)
-        val dummyFunc = (value: Option[Int], state: State[Int]) => 0
-        val trackStateStream = inputStream.trackStateByKey(StateSpec.function(dummyFunc))
-        val internalTrackStateStream = trackStateStream invokePrivate privateMethod()
-
-        explicitCheckpointDuration.foreach { d =>
-          trackStateStream.checkpoint(d)
-        }
-        trackStateStream.register()
-        ssc.checkpoint(checkpointDir.toString)
-        ssc.start()  // should initialize all the checkpoint durations
-        assert(trackStateStream.checkpointDuration === null)
-        assert(internalTrackStateStream.checkpointDuration === expectedCheckpointDuration)
-      } finally {
-        ssc.stop(stopSparkContext = false)
-      }
-    }
-
-    testCheckpointDuration(Milliseconds(100), Seconds(1))
-    testCheckpointDuration(Seconds(1), Seconds(10))
-    testCheckpointDuration(Seconds(10), Seconds(100))
-
-    testCheckpointDuration(Milliseconds(100), Seconds(2), Some(Seconds(2)))
-    testCheckpointDuration(Seconds(1), Seconds(2), Some(Seconds(2)))
-    testCheckpointDuration(Seconds(10), Seconds(20), Some(Seconds(20)))
-  }
-
-
-  test("trackStateByKey - driver failure recovery") {
-    val inputData =
-      Seq(
-        Seq(),
-        Seq("a"),
-        Seq("a", "b"),
-        Seq("a", "b", "c"),
-        Seq("a", "b"),
-        Seq("a"),
-        Seq()
-      )
-
-    val stateData =
-      Seq(
-        Seq(),
-        Seq(("a", 1)),
-        Seq(("a", 2), ("b", 1)),
-        Seq(("a", 3), ("b", 2), ("c", 1)),
-        Seq(("a", 4), ("b", 3), ("c", 1)),
-        Seq(("a", 5), ("b", 3), ("c", 1)),
-        Seq(("a", 5), ("b", 3), ("c", 1))
-      )
-
-    def operation(dstream: DStream[String]): DStream[(String, Int)] = {
-
-      val checkpointDuration = batchDuration * (stateData.size / 2)
-
-      val runningCount = (value: Option[Int], state: State[Int]) => {
-        state.update(state.getOption().getOrElse(0) + value.getOrElse(0))
-        state.get()
-      }
-
-      val trackStateStream = dstream.map { _ -> 1 }.trackStateByKey(
-        StateSpec.function(runningCount))
-      // Set internval make sure there is one RDD checkpointing
-      trackStateStream.checkpoint(checkpointDuration)
-      trackStateStream.stateSnapshots()
-    }
-
-    testCheckpointedOperation(inputData, operation, stateData, inputData.size / 2,
-      batchDuration = batchDuration, stopSparkContextAfterTest = false)
-  }
-
-  private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag](
-      input: Seq[Seq[K]],
-      trackStateSpec: StateSpec[K, Int, S, T],
-      expectedOutputs: Seq[Seq[T]],
-      expectedStateSnapshots: Seq[Seq[(K, S)]]
-    ): Unit = {
-    require(expectedOutputs.size == expectedStateSnapshots.size)
-
-    val (collectedOutputs, collectedStateSnapshots) =
-      getOperationOutput(input, trackStateSpec, expectedOutputs.size)
-    assert(expectedOutputs, collectedOutputs, "outputs")
-    assert(expectedStateSnapshots, collectedStateSnapshots, "state snapshots")
-  }
-
-  private def getOperationOutput[K: ClassTag, S: ClassTag, T: ClassTag](
-      input: Seq[Seq[K]],
-      trackStateSpec: StateSpec[K, Int, S, T],
-      numBatches: Int
-    ): (Seq[Seq[T]], Seq[Seq[(K, S)]]) = {
-
-    // Setup the stream computation
-    val ssc = new StreamingContext(sc, Seconds(1))
-    val inputStream = new TestInputStream(ssc, input, numPartitions = 2)
-    val trackeStateStream = inputStream.map(x => (x, 1)).trackStateByKey(trackStateSpec)
-    val collectedOutputs = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]]
-    val outputStream = new TestOutputStream(trackeStateStream, collectedOutputs)
-    val collectedStateSnapshots = new ArrayBuffer[Seq[(K, S)]] with SynchronizedBuffer[Seq[(K, S)]]
-    val stateSnapshotStream = new TestOutputStream(
-      trackeStateStream.stateSnapshots(), collectedStateSnapshots)
-    outputStream.register()
-    stateSnapshotStream.register()
-
-    val batchCounter = new BatchCounter(ssc)
-    ssc.checkpoint(checkpointDir.toString)
-    ssc.start()
-
-    val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
-    clock.advance(batchDuration.milliseconds * numBatches)
-
-    batchCounter.waitUntilBatchesCompleted(numBatches, 10000)
-    ssc.stop(stopSparkContext = false)
-    (collectedOutputs, collectedStateSnapshots)
-  }
-
-  private def assert[U](expected: Seq[Seq[U]], collected: Seq[Seq[U]], typ: String) {
-    val debugString = "\nExpected:\n" + expected.mkString("\n") +
-      "\nCollected:\n" + collected.mkString("\n")
-    assert(expected.size === collected.size,
-      s"number of collected $typ (${collected.size}) different from expected (${expected.size})" +
-        debugString)
-    expected.zip(collected).foreach { case (c, e) =>
-      assert(c.toSet === e.toSet,
-        s"collected $typ is different from expected $debugString"
-      )
-    }
-  }
-}
-
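
One point that both timeout tests above rely on: with the advanced form of the function, the batch Time
is passed in, and the function is invoked one final time with value = None for a key whose state is
timing out; at that point the state can no longer be updated or removed. A rough sketch along the lines
of those tests, with illustrative names:

    import org.apache.spark.streaming.{Seconds, State, StateSpec, Time}

    val timeoutAwareFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
      if (state.isTimingOut()) {
        Some(key)    // report the key whose state is about to be dropped
      } else {
        state.update(value.getOrElse(0) + state.getOption.getOrElse(0))
        None
      }
    }

    // Keys idle longer than the timeout are handed to the function one last time, then dropped
    val spec = StateSpec.function(timeoutAwareFunc).timeout(Seconds(3))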

http://git-wip-us.apache.org/repos/asf/spark/blob/bd2cd4f5/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala
new file mode 100644
index 0000000..aa95bd3
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala
@@ -0,0 +1,389 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.rdd
+
+import java.io.File
+
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.util.OpenHashMapBasedStateMap
+import org.apache.spark.streaming.{State, Time}
+import org.apache.spark.util.Utils
+
+class MapWithStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with BeforeAndAfterAll {
+
+  private var sc: SparkContext = null
+  private var checkpointDir: File = _
+
+  override def beforeAll(): Unit = {
+    sc = new SparkContext(
+      new SparkConf().setMaster("local").setAppName("MapWithStateRDDSuite"))
+    checkpointDir = Utils.createTempDir()
+    sc.setCheckpointDir(checkpointDir.toString)
+  }
+
+  override def afterAll(): Unit = {
+    if (sc != null) {
+      sc.stop()
+    }
+    Utils.deleteRecursively(checkpointDir)
+  }
+
+  override def sparkContext: SparkContext = sc
+
+  test("creation from pair RDD") {
+    val data = Seq((1, "1"), (2, "2"), (3, "3"))
+    val partitioner = new HashPartitioner(10)
+    val rdd = MapWithStateRDD.createFromPairRDD[Int, Int, String, Int](
+      sc.parallelize(data), partitioner, Time(123))
+    assertRDD[Int, Int, String, Int](rdd, data.map { x => (x._1, x._2, 123)}.toSet, Set.empty)
+    assert(rdd.partitions.size === partitioner.numPartitions)
+
+    assert(rdd.partitioner === Some(partitioner))
+  }
+
+  test("updating state and generating mapped data in MapWithStateRDDRecord") {
+
+    val initialTime = 1000L
+    val updatedTime = 2000L
+    val thresholdTime = 1500L
+    @volatile var functionCalled = false
+
+    /**
+     * Assert that applying the given data to a prior record generates a correct updated record,
+     * with the correct state map and mapped data
+     */
+    def assertRecordUpdate(
+        initStates: Iterable[Int],
+        data: Iterable[String],
+        expectedStates: Iterable[(Int, Long)],
+        timeoutThreshold: Option[Long] = None,
+        removeTimedoutData: Boolean = false,
+        expectedOutput: Iterable[Int] = None,
+        expectedTimingOutStates: Iterable[Int] = None,
+        expectedRemovedStates: Iterable[Int] = None
+      ): Unit = {
+      val initialStateMap = new OpenHashMapBasedStateMap[String, Int]()
+      initStates.foreach { s => initialStateMap.put("key", s, initialTime) }
+      functionCalled = false
+      val record = MapWithStateRDDRecord[String, Int, Int](initialStateMap, Seq.empty)
+      val dataIterator = data.map { v => ("key", v) }.iterator
+      val removedStates = new ArrayBuffer[Int]
+      val timingOutStates = new ArrayBuffer[Int]
+      /**
+       * Mapping function that updates/removes state based on instructions in the data, and
+       * returns the state (when instructed or when the state is timing out).
+       */
+      def testFunc(t: Time, key: String, data: Option[String], state: State[Int]): Option[Int] = {
+        functionCalled = true
+
+        assert(t.milliseconds === updatedTime, "mapping func called with wrong time")
+
+        data match {
+          case Some("noop") =>
+            None
+          case Some("get-state") =>
+            Some(state.getOption().getOrElse(-1))
+          case Some("update-state") =>
+            if (state.exists) state.update(state.get + 1) else state.update(0)
+            None
+          case Some("remove-state") =>
+            removedStates += state.get()
+            state.remove()
+            None
+          case None =>
+            assert(state.isTimingOut() === true, "State is not timing out when data = None")
+            timingOutStates += state.get()
+            None
+          case _ =>
+            fail("Unexpected test data")
+        }
+      }
+
+      val updatedRecord = MapWithStateRDDRecord.updateRecordWithData[String, String, Int, Int](
+        Some(record), dataIterator, testFunc,
+        Time(updatedTime), timeoutThreshold, removeTimedoutData)
+
+      val updatedStateData = updatedRecord.stateMap.getAll().map { x => (x._2, x._3) }
+      assert(updatedStateData.toSet === expectedStates.toSet,
+        "states do not match after updating the MapWithStateRDDRecord")
+
+      assert(updatedRecord.mappedData.toSet === expectedOutput.toSet,
+        "mapped data do not match after updating the MapWithStateRDDRecord")
+
+      assert(timingOutStates.toSet === expectedTimingOutStates.toSet, "timing out states do not " +
+        "match those that were expected to do so while updating the MapWithStateRDDRecord")
+
+      assert(removedStates.toSet === expectedRemovedStates.toSet, "removed states do not " +
+        "match those that were expected to do so while updating the MapWithStateRDDRecord")
+
+    }
+
+    // No data: no state should be changed, and the function should not be called
+    assertRecordUpdate(initStates = Nil, data = None, expectedStates = Nil)
+    assert(functionCalled === false)
+    assertRecordUpdate(initStates = Seq(0), data = None, expectedStates = Seq((0, initialTime)))
+    assert(functionCalled === false)
+
+    // Data present, function should be called irrespective of whether state exists
+    assertRecordUpdate(initStates = Seq(0), data = Seq("noop"),
+      expectedStates = Seq((0, initialTime)))
+    assert(functionCalled === true)
+    assertRecordUpdate(initStates = None, data = Some("noop"), expectedStates = None)
+    assert(functionCalled === true)
+
+    // Function called with right state data
+    assertRecordUpdate(initStates = None, data = Seq("get-state"),
+      expectedStates = None, expectedOutput = Seq(-1))
+    assertRecordUpdate(initStates = Seq(123), data = Seq("get-state"),
+      expectedStates = Seq((123, initialTime)), expectedOutput = Seq(123))
+
+    // Update state and timestamp, when timeout not present
+    assertRecordUpdate(initStates = Nil, data = Seq("update-state"),
+      expectedStates = Seq((0, updatedTime)))
+    assertRecordUpdate(initStates = Seq(0), data = Seq("update-state"),
+      expectedStates = Seq((1, updatedTime)))
+
+    // Remove state
+    assertRecordUpdate(initStates = Seq(345), data = Seq("remove-state"),
+      expectedStates = Nil, expectedRemovedStates = Seq(345))
+
+    // State strictly older than timeout threshold should be timed out
+    assertRecordUpdate(initStates = Seq(123), data = Nil,
+      timeoutThreshold = Some(initialTime), removeTimedoutData = true,
+      expectedStates = Seq((123, initialTime)), expectedTimingOutStates = Nil)
+
+    assertRecordUpdate(initStates = Seq(123), data = Nil,
+      timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
+      expectedStates = Nil, expectedTimingOutStates = Seq(123))
+
+    // State should not be timed out after it has received data
+    assertRecordUpdate(initStates = Seq(123), data = Seq("noop"),
+      timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
+      expectedStates = Seq((123, updatedTime)), expectedTimingOutStates = Nil)
+    assertRecordUpdate(initStates = Seq(123), data = Seq("remove-state"),
+      timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
+      expectedStates = Nil, expectedTimingOutStates = Nil, expectedRemovedStates = Seq(123))
+
+  }
+
+  test("states generated by MapWithStateRDD") {
+    val initStates = Seq(("k1", 0), ("k2", 0))
+    val initTime = 123
+    val initStateWthTime = initStates.map { x => (x._1, x._2, initTime) }.toSet
+    val partitioner = new HashPartitioner(2)
+    val initStateRDD = MapWithStateRDD.createFromPairRDD[String, Int, Int, Int](
+      sc.parallelize(initStates), partitioner, Time(initTime)).persist()
+    assertRDD(initStateRDD, initStateWthTime, Set.empty)
+
+    val updateTime = 345
+
+    /**
+     * Test that the test state RDD, when operated with new data,
+     * creates a new state RDD with expected states
+     */
+    def testStateUpdates(
+        testStateRDD: MapWithStateRDD[String, Int, Int, Int],
+        testData: Seq[(String, Int)],
+        expectedStates: Set[(String, Int, Int)]): MapWithStateRDD[String, Int, Int, Int] = {
+
+      // Persist the test MapWithStateRDD so that it's not recomputed while doing the next operation.
+      // This is to make sure that we only track which state keys are being touched in the next op.
+      testStateRDD.persist().count()
+
+      // To track which keys are being touched
+      MapWithStateRDDSuite.touchedStateKeys.clear()
+
+      val mappingFunction = (time: Time, key: String, data: Option[Int], state: State[Int]) => {
+
+        // Track the key that has been touched
+        MapWithStateRDDSuite.touchedStateKeys += key
+
+        // If the data is 0, do not do anything with the state
+        // else if the data is 1, increment the state if it exists, or set new state to 0
+        // else if the data is 2, remove the state if it exists
+        data match {
+          case Some(1) =>
+            if (state.exists()) { state.update(state.get + 1) }
+            else state.update(0)
+          case Some(2) =>
+            state.remove()
+          case _ =>
+        }
+        None.asInstanceOf[Option[Int]]  // Do not return anything, not being tested
+      }
+      val newDataRDD = sc.makeRDD(testData).partitionBy(testStateRDD.partitioner.get)
+
+      // Assert that the new state RDD has expected state data
+      val newStateRDD = assertOperation(
+        testStateRDD, newDataRDD, mappingFunction, updateTime, expectedStates, Set.empty)
+
+      // Assert that the function was called only for the keys present in the data
+      assert(MapWithStateRDDSuite.touchedStateKeys.size === testData.size,
+        "More number of keys are being touched than that is expected")
+      assert(MapWithStateRDDSuite.touchedStateKeys.toSet === testData.toMap.keys,
+        "Keys not in the data are being touched unexpectedly")
+
+      // Assert that the test RDD's data has not changed
+      assertRDD(initStateRDD, initStateWthTime, Set.empty)
+      newStateRDD
+    }
+
+    // Test no-op, no state should change
+    testStateUpdates(initStateRDD, Seq(), initStateWthTime)   // should not scan any state
+    testStateUpdates(
+      initStateRDD, Seq(("k1", 0)), initStateWthTime)         // should not update existing state
+    testStateUpdates(
+      initStateRDD, Seq(("k3", 0)), initStateWthTime)         // should not create new state
+
+    // Test creation of new state
+    val rdd1 = testStateUpdates(initStateRDD, Seq(("k3", 1)), // should create k3's state as 0
+      Set(("k1", 0, initTime), ("k2", 0, initTime), ("k3", 0, updateTime)))
+
+    val rdd2 = testStateUpdates(rdd1, Seq(("k4", 1)),         // should create k4's state as 0
+      Set(("k1", 0, initTime), ("k2", 0, initTime), ("k3", 0, updateTime), ("k4", 0, updateTime)))
+
+    // Test updating of state
+    val rdd3 = testStateUpdates(
+      initStateRDD, Seq(("k1", 1)),                   // should increment k1's state 0 -> 1
+      Set(("k1", 1, updateTime), ("k2", 0, initTime)))
+
+    val rdd4 = testStateUpdates(rdd3,
+      Seq(("x", 0), ("k2", 1), ("k2", 1), ("k3", 1)),  // should update k2, 0 -> 2 and create k3, 0
+      Set(("k1", 1, updateTime), ("k2", 2, updateTime), ("k3", 0, updateTime)))
+
+    val rdd5 = testStateUpdates(
+      rdd4, Seq(("k3", 1)),                           // should update k3's state 0 -> 2
+      Set(("k1", 1, updateTime), ("k2", 2, updateTime), ("k3", 1, updateTime)))
+
+    // Test removing of state
+    val rdd6 = testStateUpdates(                      // should remove k1's state
+      initStateRDD, Seq(("k1", 2)), Set(("k2", 0, initTime)))
+
+    val rdd7 = testStateUpdates(                      // should remove k2's state
+      rdd6, Seq(("k2", 2), ("k0", 2), ("k3", 1)), Set(("k3", 0, updateTime)))
+
+    val rdd8 = testStateUpdates(                      // should remove k3's state
+      rdd7, Seq(("k3", 2)), Set())
+  }
+
+  test("checkpointing") {
+    /**
+     * This tests whether the MapWithStateRDD correctly truncates any references to its parent RDDs
+     * - the data RDD and the parent MapWithStateRDD.
+     */
+    def rddCollectFunc(rdd: RDD[MapWithStateRDDRecord[Int, Int, Int]])
+      : Set[(List[(Int, Int, Long)], List[Int])] = {
+      rdd.map { record => (record.stateMap.getAll().toList, record.mappedData.toList) }
+         .collect.toSet
+    }
+
+    /** Generate MapWithStateRDD with data RDD having a long lineage */
+    def makeStateRDDWithLongLineageDataRDD(longLineageRDD: RDD[Int])
+      : MapWithStateRDD[Int, Int, Int, Int] = {
+      MapWithStateRDD.createFromPairRDD(longLineageRDD.map { _ -> 1}, partitioner, Time(0))
+    }
+
+    testRDD(
+      makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _)
+    testRDDPartitions(
+      makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _)
+
+    /** Generate MapWithStateRDD with parent state RDD having a long lineage */
+    def makeStateRDDWithLongLineageParenttateRDD(
+        longLineageRDD: RDD[Int]): MapWithStateRDD[Int, Int, Int, Int] = {
+
+      // Create a MapWithStateRDD that has a long lineage using the data RDD with a long lineage
+      val stateRDDWithLongLineage = makeStateRDDWithLongLineageDataRDD(longLineageRDD)
+
+      // Create a new MapWithStateRDD, with the long-lineage MapWithStateRDD as the parent
+      new MapWithStateRDD[Int, Int, Int, Int](
+        stateRDDWithLongLineage,
+        stateRDDWithLongLineage.sparkContext.emptyRDD[(Int, Int)].partitionBy(partitioner),
+        (time: Time, key: Int, value: Option[Int], state: State[Int]) => None,
+        Time(10),
+        None
+      )
+    }
+
+    testRDD(
+      makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _)
+    testRDDPartitions(
+      makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _)
+  }
+
+  test("checkpointing empty state RDD") {
+    val emptyStateRDD = MapWithStateRDD.createFromPairRDD[Int, Int, Int, Int](
+      sc.emptyRDD[(Int, Int)], new HashPartitioner(10), Time(0))
+    emptyStateRDD.checkpoint()
+    assert(emptyStateRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty)
+    val cpRDD = sc.checkpointFile[MapWithStateRDDRecord[Int, Int, Int]](
+      emptyStateRDD.getCheckpointFile.get)
+    assert(cpRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty)
+  }
+
+  /** Assert whether the `mapWithState` operation generates expected results */
+  private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
+      testStateRDD: MapWithStateRDD[K, V, S, T],
+      newDataRDD: RDD[(K, V)],
+      mappingFunction: (Time, K, Option[V], State[S]) => Option[T],
+      currentTime: Long,
+      expectedStates: Set[(K, S, Int)],
+      expectedMappedData: Set[T],
+      doFullScan: Boolean = false
+    ): MapWithStateRDD[K, V, S, T] = {
+
+    val partitionedNewDataRDD = if (newDataRDD.partitioner != testStateRDD.partitioner) {
+      newDataRDD.partitionBy(testStateRDD.partitioner.get)
+    } else {
+      newDataRDD
+    }
+
+    val newStateRDD = new MapWithStateRDD[K, V, S, T](
+      testStateRDD, newDataRDD, mappingFunction, Time(currentTime), None)
+    if (doFullScan) newStateRDD.setFullScan()
+
+    // Persist to make sure that it gets computed only once, so we can track precisely how many
+    // state keys the computation touched
+    newStateRDD.persist().count()
+    assertRDD(newStateRDD, expectedStates, expectedMappedData)
+    newStateRDD
+  }
+
+  /** Assert whether the [[MapWithStateRDD]] has the expected state and mapped data */
+  private def assertRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
+      stateRDD: MapWithStateRDD[K, V, S, T],
+      expectedStates: Set[(K, S, Int)],
+      expectedMappedData: Set[T]): Unit = {
+    val states = stateRDD.flatMap { _.stateMap.getAll() }.collect().toSet
+    val mappedData = stateRDD.flatMap { _.mappedData }.collect().toSet
+    assert(states === expectedStates,
+      "states after mapWithState operation were not as expected")
+    assert(mappedData === expectedMappedData,
+      "mapped data after mapWithState operation were not as expected")
+  }
+}
+
+object MapWithStateRDDSuite {
+  private val touchedStateKeys = new ArrayBuffer[String]()
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/bd2cd4f5/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
deleted file mode 100644
index 3b2d43f..0000000
--- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
+++ /dev/null
@@ -1,389 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.streaming.rdd
-
-import java.io.File
-
-import scala.collection.mutable.ArrayBuffer
-import scala.reflect.ClassTag
-
-import org.scalatest.BeforeAndAfterAll
-
-import org.apache.spark._
-import org.apache.spark.rdd.RDD
-import org.apache.spark.streaming.util.OpenHashMapBasedStateMap
-import org.apache.spark.streaming.{State, Time}
-import org.apache.spark.util.Utils
-
-class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with BeforeAndAfterAll {
-
-  private var sc: SparkContext = null
-  private var checkpointDir: File = _
-
-  override def beforeAll(): Unit = {
-    sc = new SparkContext(
-      new SparkConf().setMaster("local").setAppName("TrackStateRDDSuite"))
-    checkpointDir = Utils.createTempDir()
-    sc.setCheckpointDir(checkpointDir.toString)
-  }
-
-  override def afterAll(): Unit = {
-    if (sc != null) {
-      sc.stop()
-    }
-    Utils.deleteRecursively(checkpointDir)
-  }
-
-  override def sparkContext: SparkContext = sc
-
-  test("creation from pair RDD") {
-    val data = Seq((1, "1"), (2, "2"), (3, "3"))
-    val partitioner = new HashPartitioner(10)
-    val rdd = TrackStateRDD.createFromPairRDD[Int, Int, String, Int](
-      sc.parallelize(data), partitioner, Time(123))
-    assertRDD[Int, Int, String, Int](rdd, data.map { x => (x._1, x._2, 123)}.toSet, Set.empty)
-    assert(rdd.partitions.size === partitioner.numPartitions)
-
-    assert(rdd.partitioner === Some(partitioner))
-  }
-
-  test("updating state and generating emitted data in TrackStateRecord") {
-
-    val initialTime = 1000L
-    val updatedTime = 2000L
-    val thresholdTime = 1500L
-    @volatile var functionCalled = false
-
-    /**
-     * Assert that applying given data on a prior record generates correct updated record, with
-     * correct state map and emitted data
-     */
-    def assertRecordUpdate(
-        initStates: Iterable[Int],
-        data: Iterable[String],
-        expectedStates: Iterable[(Int, Long)],
-        timeoutThreshold: Option[Long] = None,
-        removeTimedoutData: Boolean = false,
-        expectedOutput: Iterable[Int] = None,
-        expectedTimingOutStates: Iterable[Int] = None,
-        expectedRemovedStates: Iterable[Int] = None
-      ): Unit = {
-      val initialStateMap = new OpenHashMapBasedStateMap[String, Int]()
-      initStates.foreach { s => initialStateMap.put("key", s, initialTime) }
-      functionCalled = false
-      val record = TrackStateRDDRecord[String, Int, Int](initialStateMap, Seq.empty)
-      val dataIterator = data.map { v => ("key", v) }.iterator
-      val removedStates = new ArrayBuffer[Int]
-      val timingOutStates = new ArrayBuffer[Int]
-      /**
-       * Tracking function that updates/removes state based on instructions in the data, and
-       * return state (when instructed or when state is timing out).
-       */
-      def testFunc(t: Time, key: String, data: Option[String], state: State[Int]): Option[Int] = {
-        functionCalled = true
-
-        assert(t.milliseconds === updatedTime, "tracking func called with wrong time")
-
-        data match {
-          case Some("noop") =>
-            None
-          case Some("get-state") =>
-            Some(state.getOption().getOrElse(-1))
-          case Some("update-state") =>
-            if (state.exists) state.update(state.get + 1) else state.update(0)
-            None
-          case Some("remove-state") =>
-            removedStates += state.get()
-            state.remove()
-            None
-          case None =>
-            assert(state.isTimingOut() === true, "State is not timing out when data = None")
-            timingOutStates += state.get()
-            None
-          case _ =>
-            fail("Unexpected test data")
-        }
-      }
-
-      val updatedRecord = TrackStateRDDRecord.updateRecordWithData[String, String, Int, Int](
-        Some(record), dataIterator, testFunc,
-        Time(updatedTime), timeoutThreshold, removeTimedoutData)
-
-      val updatedStateData = updatedRecord.stateMap.getAll().map { x => (x._2, x._3) }
-      assert(updatedStateData.toSet === expectedStates.toSet,
-        "states do not match after updating the TrackStateRecord")
-
-      assert(updatedRecord.emittedRecords.toSet === expectedOutput.toSet,
-        "emitted data do not match after updating the TrackStateRecord")
-
-      assert(timingOutStates.toSet === expectedTimingOutStates.toSet, "timing out states do not " +
-        "match those that were expected to do so while updating the TrackStateRecord")
-
-      assert(removedStates.toSet === expectedRemovedStates.toSet, "removed states do not " +
-        "match those that were expected to do so while updating the TrackStateRecord")
-
-    }
-
-    // No data, no state should be changed, function should not be called,
-    assertRecordUpdate(initStates = Nil, data = None, expectedStates = Nil)
-    assert(functionCalled === false)
-    assertRecordUpdate(initStates = Seq(0), data = None, expectedStates = Seq((0, initialTime)))
-    assert(functionCalled === false)
-
-    // Data present, function should be called irrespective of whether state exists
-    assertRecordUpdate(initStates = Seq(0), data = Seq("noop"),
-      expectedStates = Seq((0, initialTime)))
-    assert(functionCalled === true)
-    assertRecordUpdate(initStates = None, data = Some("noop"), expectedStates = None)
-    assert(functionCalled === true)
-
-    // Function called with right state data
-    assertRecordUpdate(initStates = None, data = Seq("get-state"),
-      expectedStates = None, expectedOutput = Seq(-1))
-    assertRecordUpdate(initStates = Seq(123), data = Seq("get-state"),
-      expectedStates = Seq((123, initialTime)), expectedOutput = Seq(123))
-
-    // Update state and timestamp, when timeout not present
-    assertRecordUpdate(initStates = Nil, data = Seq("update-state"),
-      expectedStates = Seq((0, updatedTime)))
-    assertRecordUpdate(initStates = Seq(0), data = Seq("update-state"),
-      expectedStates = Seq((1, updatedTime)))
-
-    // Remove state
-    assertRecordUpdate(initStates = Seq(345), data = Seq("remove-state"),
-      expectedStates = Nil, expectedRemovedStates = Seq(345))
-
-    // State strictly older than timeout threshold should be timed out
-    assertRecordUpdate(initStates = Seq(123), data = Nil,
-      timeoutThreshold = Some(initialTime), removeTimedoutData = true,
-      expectedStates = Seq((123, initialTime)), expectedTimingOutStates = Nil)
-
-    assertRecordUpdate(initStates = Seq(123), data = Nil,
-      timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
-      expectedStates = Nil, expectedTimingOutStates = Seq(123))
-
-    // State should not be timed out after it has received data
-    assertRecordUpdate(initStates = Seq(123), data = Seq("noop"),
-      timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
-      expectedStates = Seq((123, updatedTime)), expectedTimingOutStates = Nil)
-    assertRecordUpdate(initStates = Seq(123), data = Seq("remove-state"),
-      timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
-      expectedStates = Nil, expectedTimingOutStates = Nil, expectedRemovedStates = Seq(123))
-
-  }
-
-  test("states generated by TrackStateRDD") {
-    val initStates = Seq(("k1", 0), ("k2", 0))
-    val initTime = 123
-    val initStateWthTime = initStates.map { x => (x._1, x._2, initTime) }.toSet
-    val partitioner = new HashPartitioner(2)
-    val initStateRDD = TrackStateRDD.createFromPairRDD[String, Int, Int, Int](
-      sc.parallelize(initStates), partitioner, Time(initTime)).persist()
-    assertRDD(initStateRDD, initStateWthTime, Set.empty)
-
-    val updateTime = 345
-
-    /**
-     * Test that the test state RDD, when operated with new data,
-     * creates a new state RDD with expected states
-     */
-    def testStateUpdates(
-        testStateRDD: TrackStateRDD[String, Int, Int, Int],
-        testData: Seq[(String, Int)],
-        expectedStates: Set[(String, Int, Int)]): TrackStateRDD[String, Int, Int, Int] = {
-
-      // Persist the test TrackStateRDD so that its not recomputed while doing the next operation.
-      // This is to make sure that we only track which state keys are being touched in the next op.
-      testStateRDD.persist().count()
-
-      // To track which keys are being touched
-      TrackStateRDDSuite.touchedStateKeys.clear()
-
-      val trackingFunc = (time: Time, key: String, data: Option[Int], state: State[Int]) => {
-
-        // Track the key that has been touched
-        TrackStateRDDSuite.touchedStateKeys += key
-
-        // If the data is 0, do not do anything with the state
-        // else if the data is 1, increment the state if it exists, or set new state to 0
-        // else if the data is 2, remove the state if it exists
-        data match {
-          case Some(1) =>
-            if (state.exists()) { state.update(state.get + 1) }
-            else state.update(0)
-          case Some(2) =>
-            state.remove()
-          case _ =>
-        }
-        None.asInstanceOf[Option[Int]]  // Do not return anything, not being tested
-      }
-      val newDataRDD = sc.makeRDD(testData).partitionBy(testStateRDD.partitioner.get)
-
-      // Assert that the new state RDD has expected state data
-      val newStateRDD = assertOperation(
-        testStateRDD, newDataRDD, trackingFunc, updateTime, expectedStates, Set.empty)
-
-      // Assert that the function was called only for the keys present in the data
-      assert(TrackStateRDDSuite.touchedStateKeys.size === testData.size,
-        "More number of keys are being touched than that is expected")
-      assert(TrackStateRDDSuite.touchedStateKeys.toSet === testData.toMap.keys,
-        "Keys not in the data are being touched unexpectedly")
-
-      // Assert that the test RDD's data has not changed
-      assertRDD(initStateRDD, initStateWthTime, Set.empty)
-      newStateRDD
-    }
-
-    // Test no-op, no state should change
-    testStateUpdates(initStateRDD, Seq(), initStateWthTime)   // should not scan any state
-    testStateUpdates(
-      initStateRDD, Seq(("k1", 0)), initStateWthTime)         // should not update existing state
-    testStateUpdates(
-      initStateRDD, Seq(("k3", 0)), initStateWthTime)         // should not create new state
-
-    // Test creation of new state
-    val rdd1 = testStateUpdates(initStateRDD, Seq(("k3", 1)), // should create k3's state as 0
-      Set(("k1", 0, initTime), ("k2", 0, initTime), ("k3", 0, updateTime)))
-
-    val rdd2 = testStateUpdates(rdd1, Seq(("k4", 1)),         // should create k4's state as 0
-      Set(("k1", 0, initTime), ("k2", 0, initTime), ("k3", 0, updateTime), ("k4", 0, updateTime)))
-
-    // Test updating of state
-    val rdd3 = testStateUpdates(
-      initStateRDD, Seq(("k1", 1)),                   // should increment k1's state 0 -> 1
-      Set(("k1", 1, updateTime), ("k2", 0, initTime)))
-
-    val rdd4 = testStateUpdates(rdd3,
-      Seq(("x", 0), ("k2", 1), ("k2", 1), ("k3", 1)),  // should update k2, 0 -> 2 and create k3, 0
-      Set(("k1", 1, updateTime), ("k2", 2, updateTime), ("k3", 0, updateTime)))
-
-    val rdd5 = testStateUpdates(
-      rdd4, Seq(("k3", 1)),                           // should update k3's state 0 -> 2
-      Set(("k1", 1, updateTime), ("k2", 2, updateTime), ("k3", 1, updateTime)))
-
-    // Test removing of state
-    val rdd6 = testStateUpdates(                      // should remove k1's state
-      initStateRDD, Seq(("k1", 2)), Set(("k2", 0, initTime)))
-
-    val rdd7 = testStateUpdates(                      // should remove k2's state
-      rdd6, Seq(("k2", 2), ("k0", 2), ("k3", 1)), Set(("k3", 0, updateTime)))
-
-    val rdd8 = testStateUpdates(                      // should remove k3's state
-      rdd7, Seq(("k3", 2)), Set())
-  }
-
-  test("checkpointing") {
-    /**
-     * This tests whether the TrackStateRDD correctly truncates any references to its parent RDDs -
-     * the data RDD and the parent TrackStateRDD.
-     */
-    def rddCollectFunc(rdd: RDD[TrackStateRDDRecord[Int, Int, Int]])
-      : Set[(List[(Int, Int, Long)], List[Int])] = {
-      rdd.map { record => (record.stateMap.getAll().toList, record.emittedRecords.toList) }
-         .collect.toSet
-    }
-
-    /** Generate TrackStateRDD with data RDD having a long lineage */
-    def makeStateRDDWithLongLineageDataRDD(longLineageRDD: RDD[Int])
-      : TrackStateRDD[Int, Int, Int, Int] = {
-      TrackStateRDD.createFromPairRDD(longLineageRDD.map { _ -> 1}, partitioner, Time(0))
-    }
-
-    testRDD(
-      makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _)
-    testRDDPartitions(
-      makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _)
-
-    /** Generate TrackStateRDD with parent state RDD having a long lineage */
-    def makeStateRDDWithLongLineageParenttateRDD(
-        longLineageRDD: RDD[Int]): TrackStateRDD[Int, Int, Int, Int] = {
-
-      // Create a TrackStateRDD that has a long lineage using the data RDD with a long lineage
-      val stateRDDWithLongLineage = makeStateRDDWithLongLineageDataRDD(longLineageRDD)
-
-      // Create a new TrackStateRDD, with the lineage lineage TrackStateRDD as the parent
-      new TrackStateRDD[Int, Int, Int, Int](
-        stateRDDWithLongLineage,
-        stateRDDWithLongLineage.sparkContext.emptyRDD[(Int, Int)].partitionBy(partitioner),
-        (time: Time, key: Int, value: Option[Int], state: State[Int]) => None,
-        Time(10),
-        None
-      )
-    }
-
-    testRDD(
-      makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _)
-    testRDDPartitions(
-      makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _)
-  }
-
-  test("checkpointing empty state RDD") {
-    val emptyStateRDD = TrackStateRDD.createFromPairRDD[Int, Int, Int, Int](
-      sc.emptyRDD[(Int, Int)], new HashPartitioner(10), Time(0))
-    emptyStateRDD.checkpoint()
-    assert(emptyStateRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty)
-    val cpRDD = sc.checkpointFile[TrackStateRDDRecord[Int, Int, Int]](
-      emptyStateRDD.getCheckpointFile.get)
-    assert(cpRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty)
-  }
-
-  /** Assert whether the `trackStateByKey` operation generates expected results */
-  private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
-      testStateRDD: TrackStateRDD[K, V, S, T],
-      newDataRDD: RDD[(K, V)],
-      trackStateFunc: (Time, K, Option[V], State[S]) => Option[T],
-      currentTime: Long,
-      expectedStates: Set[(K, S, Int)],
-      expectedEmittedRecords: Set[T],
-      doFullScan: Boolean = false
-    ): TrackStateRDD[K, V, S, T] = {
-
-    val partitionedNewDataRDD = if (newDataRDD.partitioner != testStateRDD.partitioner) {
-      newDataRDD.partitionBy(testStateRDD.partitioner.get)
-    } else {
-      newDataRDD
-    }
-
-    val newStateRDD = new TrackStateRDD[K, V, S, T](
-      testStateRDD, newDataRDD, trackStateFunc, Time(currentTime), None)
-    if (doFullScan) newStateRDD.setFullScan()
-
-    // Persist to make sure that it gets computed only once and we can track precisely how many
-    // state keys the computing touched
-    newStateRDD.persist().count()
-    assertRDD(newStateRDD, expectedStates, expectedEmittedRecords)
-    newStateRDD
-  }
-
-  /** Assert whether the [[TrackStateRDD]] has the expected state ad emitted records */
-  private def assertRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
-      trackStateRDD: TrackStateRDD[K, V, S, T],
-      expectedStates: Set[(K, S, Int)],
-      expectedEmittedRecords: Set[T]): Unit = {
-    val states = trackStateRDD.flatMap { _.stateMap.getAll() }.collect().toSet
-    val emittedRecords = trackStateRDD.flatMap { _.emittedRecords }.collect().toSet
-    assert(states === expectedStates,
-      "states after track state operation were not as expected")
-    assert(emittedRecords === expectedEmittedRecords,
-      "emitted records after track state operation were not as expected")
-  }
-}
-
-object TrackStateRDDSuite {
-  private val touchedStateKeys = new ArrayBuffer[String]()
-}




[2/2] spark git commit: [SPARK-12244][SPARK-12245][STREAMING] Rename trackStateByKey to mapWithState and change tracking function signature

Posted by zs...@apache.org.
[SPARK-12244][SPARK-12245][STREAMING] Rename trackStateByKey to mapWithState and change tracking function signature

SPARK-12244:

Based on feedback from early users and personal experience attempting to explain it, the name trackStateByKey had two problems:
- "trackState" is a completely new term that gives little intuition about what the operation does.
- The resultant data stream of objects returned by the function is referred to in the docs as the "emitted" data, for lack of a better term.
"mapWithState" makes sense because the API is essentially a mapping function of the form (Key, Value) => T, with State as an additional parameter. The resultant data stream is simply the "mapped data", so both problems are solved.

SPARK-12245:

From initial experience, not having the key available in the function makes it hard to return mapped records, because the full information of each record is not there; the user is effectively restricted to something like mapValues() instead of map(). So the key is added as a parameter.
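
To make the change concrete, here is a minimal sketch of how a word-count style mapping function reads under the new API (the `wordDstream` and the counting logic are illustrative assumptions, not part of this commit; the updated StatefulNetworkWordCount example in the diff below shows the real version):

    // Old API: no key argument, Option return value, results called "emitted" data
    //   val trackingFunc = (value: Option[Int], state: State[Int]) => { ... }
    //   val emitted = wordDstream.trackStateByKey(StateSpec.function(trackingFunc))

    // New API: the key is passed to the function, and the return value is the mapped record
    val mappingFunc = (word: String, count: Option[Int], state: State[Int]) => {
      val sum = count.getOrElse(0) + state.getOption.getOrElse(0)
      state.update(sum)   // update the per-key running count
      (word, sum)         // mapped record produced for this element
    }

    val mappedDStream = wordDstream.mapWithState(StateSpec.function(mappingFunc))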

Author: Tathagata Das <ta...@gmail.com>

Closes #10224 from tdas/rename.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bd2cd4f5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bd2cd4f5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bd2cd4f5

Branch: refs/heads/master
Commit: bd2cd4f53d1ca10f4896bd39b0e180d4929867a2
Parents: 2166c2a
Author: Tathagata Das <ta...@gmail.com>
Authored: Wed Dec 9 20:47:15 2015 -0800
Committer: Shixiong Zhu <sh...@databricks.com>
Committed: Wed Dec 9 20:47:15 2015 -0800

----------------------------------------------------------------------
 .../streaming/JavaStatefulNetworkWordCount.java |  16 +-
 .../streaming/StatefulNetworkWordCount.scala    |  12 +-
 .../apache/spark/streaming/Java8APISuite.java   |  18 +-
 .../org/apache/spark/streaming/State.scala      |  20 +-
 .../org/apache/spark/streaming/StateSpec.scala  | 160 ++---
 .../api/java/JavaMapWithStateDStream.scala      |  44 ++
 .../streaming/api/java/JavaPairDStream.scala    |  50 +-
 .../api/java/JavaTrackStateDStream.scala        |  44 --
 .../streaming/dstream/MapWithStateDStream.scala | 170 ++++++
 .../dstream/PairDStreamFunctions.scala          |  41 +-
 .../streaming/dstream/TrackStateDStream.scala   | 171 ------
 .../spark/streaming/rdd/MapWithStateRDD.scala   | 223 +++++++
 .../spark/streaming/rdd/TrackStateRDD.scala     | 228 --------
 .../spark/streaming/JavaMapWithStateSuite.java  | 210 +++++++
 .../streaming/JavaTrackStateByKeySuite.java     | 210 -------
 .../spark/streaming/MapWithStateSuite.scala     | 581 +++++++++++++++++++
 .../spark/streaming/TrackStateByKeySuite.scala  | 581 -------------------
 .../streaming/rdd/MapWithStateRDDSuite.scala    | 389 +++++++++++++
 .../streaming/rdd/TrackStateRDDSuite.scala      | 389 -------------
 19 files changed, 1782 insertions(+), 1775 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bd2cd4f5/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
index c400e42..14997c6 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
@@ -65,7 +65,7 @@ public class JavaStatefulNetworkWordCount {
     JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1));
     ssc.checkpoint(".");
 
-    // Initial RDD input to trackStateByKey
+    // Initial state RDD input to mapWithState
     @SuppressWarnings("unchecked")
     List<Tuple2<String, Integer>> tuples = Arrays.asList(new Tuple2<String, Integer>("hello", 1),
             new Tuple2<String, Integer>("world", 1));
@@ -90,21 +90,21 @@ public class JavaStatefulNetworkWordCount {
         });
 
     // Update the cumulative count function
-    final Function4<Time, String, Optional<Integer>, State<Integer>, Optional<Tuple2<String, Integer>>> trackStateFunc =
-        new Function4<Time, String, Optional<Integer>, State<Integer>, Optional<Tuple2<String, Integer>>>() {
+    final Function3<String, Optional<Integer>, State<Integer>, Tuple2<String, Integer>> mappingFunc =
+        new Function3<String, Optional<Integer>, State<Integer>, Tuple2<String, Integer>>() {
 
           @Override
-          public Optional<Tuple2<String, Integer>> call(Time time, String word, Optional<Integer> one, State<Integer> state) {
+          public Tuple2<String, Integer> call(String word, Optional<Integer> one, State<Integer> state) {
             int sum = one.or(0) + (state.exists() ? state.get() : 0);
             Tuple2<String, Integer> output = new Tuple2<String, Integer>(word, sum);
             state.update(sum);
-            return Optional.of(output);
+            return output;
           }
         };
 
-    // This will give a Dstream made of state (which is the cumulative count of the words)
-    JavaTrackStateDStream<String, Integer, Integer, Tuple2<String, Integer>> stateDstream =
-        wordsDstream.trackStateByKey(StateSpec.function(trackStateFunc).initialState(initialRDD));
+    // DStream made of cumulative counts that get updated in every batch
+    JavaMapWithStateDStream<String, Integer, Integer, Tuple2<String, Integer>> stateDstream =
+        wordsDstream.mapWithState(StateSpec.function(mappingFunc).initialState(initialRDD));
 
     stateDstream.print();
     ssc.start();

http://git-wip-us.apache.org/repos/asf/spark/blob/bd2cd4f5/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
index a4f847f..2dce182 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
@@ -49,7 +49,7 @@ object StatefulNetworkWordCount {
     val ssc = new StreamingContext(sparkConf, Seconds(1))
     ssc.checkpoint(".")
 
-    // Initial RDD input to trackStateByKey
+    // Initial state RDD for mapWithState operation
     val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1)))
 
     // Create a ReceiverInputDStream on target ip:port and count the
@@ -58,17 +58,17 @@ object StatefulNetworkWordCount {
     val words = lines.flatMap(_.split(" "))
     val wordDstream = words.map(x => (x, 1))
 
-    // Update the cumulative count using updateStateByKey
+    // Update the cumulative count using mapWithState
     // This will give a DStream made of state (which is the cumulative count of the words)
-    val trackStateFunc = (batchTime: Time, word: String, one: Option[Int], state: State[Int]) => {
+    val mappingFunc = (word: String, one: Option[Int], state: State[Int]) => {
       val sum = one.getOrElse(0) + state.getOption.getOrElse(0)
       val output = (word, sum)
       state.update(sum)
-      Some(output)
+      output
     }
 
-    val stateDstream = wordDstream.trackStateByKey(
-      StateSpec.function(trackStateFunc).initialState(initialRDD))
+    val stateDstream = wordDstream.mapWithState(
+      StateSpec.function(mappingFunc).initialState(initialRDD))
     stateDstream.print()
     ssc.start()
     ssc.awaitTermination()

http://git-wip-us.apache.org/repos/asf/spark/blob/bd2cd4f5/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java
----------------------------------------------------------------------
diff --git a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java
index 4eee97b..89e0c7f 100644
--- a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java
+++ b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java
@@ -32,12 +32,10 @@ import org.apache.spark.Accumulator;
 import org.apache.spark.HashPartitioner;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.function.Function2;
-import org.apache.spark.api.java.function.Function4;
 import org.apache.spark.api.java.function.PairFunction;
 import org.apache.spark.streaming.api.java.JavaDStream;
 import org.apache.spark.streaming.api.java.JavaPairDStream;
-import org.apache.spark.streaming.api.java.JavaTrackStateDStream;
+import org.apache.spark.streaming.api.java.JavaMapWithStateDStream;
 
 /**
  * Most of these tests replicate org.apache.spark.streaming.JavaAPISuite using java 8
@@ -863,12 +861,12 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ
   /**
    * This test is only for testing the APIs. It's not necessary to run it.
    */
-  public void testTrackStateByAPI() {
+  public void testMapWithStateAPI() {
     JavaPairRDD<String, Boolean> initialRDD = null;
     JavaPairDStream<String, Integer> wordsDstream = null;
 
-    JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream =
-        wordsDstream.trackStateByKey(
+    JavaMapWithStateDStream<String, Integer, Boolean, Double> stateDstream =
+        wordsDstream.mapWithState(
             StateSpec.<String, Integer, Boolean, Double> function((time, key, value, state) -> {
               // Use all State's methods here
               state.exists();
@@ -884,9 +882,9 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ
 
     JavaPairDStream<String, Boolean> emittedRecords = stateDstream.stateSnapshots();
 
-    JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream2 =
-        wordsDstream.trackStateByKey(
-            StateSpec.<String, Integer, Boolean, Double>function((value, state) -> {
+    JavaMapWithStateDStream<String, Integer, Boolean, Double> stateDstream2 =
+        wordsDstream.mapWithState(
+            StateSpec.<String, Integer, Boolean, Double>function((key, value, state) -> {
               state.exists();
               state.get();
               state.isTimingOut();
@@ -898,6 +896,6 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ
                 .partitioner(new HashPartitioner(10))
                 .timeout(Durations.seconds(10)));
 
-    JavaPairDStream<String, Boolean> emittedRecords2 = stateDstream2.stateSnapshots();
+    JavaPairDStream<String, Boolean> mappedDStream = stateDstream2.stateSnapshots();
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/bd2cd4f5/streaming/src/main/scala/org/apache/spark/streaming/State.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala
index 604e64f..b47bdda 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/State.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala
@@ -23,14 +23,14 @@ import org.apache.spark.annotation.Experimental
 
 /**
  * :: Experimental ::
- * Abstract class for getting and updating the tracked state in the `trackStateByKey` operation of
- * a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
- * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
+ * Abstract class for getting and updating the state in mapping function used in the `mapWithState`
+ * operation of a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala)
+ * or a [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
  *
  * Scala example of using `State`:
  * {{{
- *    // A tracking function that maintains an integer state and return a String
- *    def trackStateFunc(data: Option[Int], state: State[Int]): Option[String] = {
+ *    // A mapping function that maintains an integer state and returns a String
+ *    def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
  *      // Check if state exists
  *      if (state.exists) {
  *        val existingState = state.get  // Get the existing state
@@ -52,12 +52,12 @@ import org.apache.spark.annotation.Experimental
  *
  * Java example of using `State`:
  * {{{
- *    // A tracking function that maintains an integer state and return a String
- *   Function2<Optional<Integer>, State<Integer>, Optional<String>> trackStateFunc =
- *       new Function2<Optional<Integer>, State<Integer>, Optional<String>>() {
+ *    // A mapping function that maintains an integer state and returns a String
+ *    Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction =
+ *       new Function3<String, Optional<Integer>, State<Integer>, String>() {
  *
  *         @Override
- *         public Optional<String> call(Optional<Integer> one, State<Integer> state) {
+ *         public String call(String key, Optional<Integer> value, State<Integer> state) {
  *           if (state.exists()) {
  *             int existingState = state.get(); // Get the existing state
  *             boolean shouldRemove = ...; // Decide whether to remove the state
@@ -75,6 +75,8 @@ import org.apache.spark.annotation.Experimental
  *         }
  *       };
  * }}}
+ *
+ * @tparam S Class of the state
  */
 @Experimental
 sealed abstract class State[S] {

http://git-wip-us.apache.org/repos/asf/spark/blob/bd2cd4f5/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
index bea5b9d..9f6f952 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
@@ -20,7 +20,7 @@ package org.apache.spark.streaming
 import com.google.common.base.Optional
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.{JavaPairRDD, JavaUtils}
-import org.apache.spark.api.java.function.{Function2 => JFunction2, Function4 => JFunction4}
+import org.apache.spark.api.java.function.{Function3 => JFunction3, Function4 => JFunction4}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.ClosureCleaner
 import org.apache.spark.{HashPartitioner, Partitioner}
@@ -28,7 +28,7 @@ import org.apache.spark.{HashPartitioner, Partitioner}
 /**
  * :: Experimental ::
  * Abstract class representing all the specifications of the DStream transformation
- * `trackStateByKey` operation of a
+ * `mapWithState` operation of a
  * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
  * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
  * Use the [[org.apache.spark.streaming.StateSpec StateSpec.apply()]] or
@@ -37,50 +37,63 @@ import org.apache.spark.{HashPartitioner, Partitioner}
  *
  * Example in Scala:
  * {{{
- *    def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = {
- *      ...
+ *    // A mapping function that maintains an integer state and returns a String
+ *    def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
+ *      // Use state.exists(), state.get(), state.update() and state.remove()
+ *      // to manage state, and return the necessary string
  *    }
  *
- *    val spec = StateSpec.function(trackingFunction).numPartitions(10)
+ *    val spec = StateSpec.function(mappingFunction).numPartitions(10)
  *
- *    val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec)
+ *    val mapWithStateDStream = keyValueDStream.mapWithState[StateType, MappedType](spec)
  * }}}
  *
  * Example in Java:
  * {{{
- *    StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec =
- *      StateSpec.<KeyType, ValueType, StateType, EmittedDataType>function(trackingFunction)
- *                    .numPartition(10);
+ *   // A mapping function that maintains an integer state and returns a string
+ *   Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction =
+ *       new Function3<String, Optional<Integer>, State<Integer>, String>() {
+ *           @Override
+ *           public String call(String key, Optional<Integer> value, State<Integer> state) {
+ *               // Use state.exists(), state.get(), state.update() and state.remove()
+ *               // to manage state, and return the necessary string
+ *           }
+ *       };
  *
- *    JavaTrackStateDStream<KeyType, ValueType, StateType, EmittedType> emittedRecordDStream =
- *      javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec);
+ *    JavaMapWithStateDStream<String, Integer, Integer, String> mapWithStateDStream =
+ *        keyValueDStream.mapWithState(StateSpec.function(mappingFunction));
  * }}}
+ *
+ * @tparam KeyType    Class of the state key
+ * @tparam ValueType  Class of the state value
+ * @tparam StateType  Class of the state data
+ * @tparam MappedType Class of the mapped elements
  */
 @Experimental
-sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] extends Serializable {
+sealed abstract class StateSpec[KeyType, ValueType, StateType, MappedType] extends Serializable {
 
-  /** Set the RDD containing the initial states that will be used by `trackStateByKey` */
+  /** Set the RDD containing the initial states that will be used by `mapWithState` */
   def initialState(rdd: RDD[(KeyType, StateType)]): this.type
 
-  /** Set the RDD containing the initial states that will be used by `trackStateByKey` */
+  /** Set the RDD containing the initial states that will be used by `mapWithState` */
   def initialState(javaPairRDD: JavaPairRDD[KeyType, StateType]): this.type
 
   /**
-   * Set the number of partitions by which the state RDDs generated by `trackStateByKey`
+   * Set the number of partitions by which the state RDDs generated by `mapWithState`
    * will be partitioned. Hash partitioning will be used.
    */
   def numPartitions(numPartitions: Int): this.type
 
   /**
-   * Set the partitioner by which the state RDDs generated by `trackStateByKey` will be
+   * Set the partitioner by which the state RDDs generated by `mapWithState` will be
    * be partitioned.
    */
   def partitioner(partitioner: Partitioner): this.type
 
   /**
    * Set the duration after which the state of an idle key will be removed. A key and its state is
-   * considered idle if it has not received any data for at least the given duration. The state
-   * tracking function will be called one final time on the idle states that are going to be
+   * considered idle if it has not received any data for at least the given duration. The
+   * mapping function will be called one final time on the idle states that are going to be
    * removed; [[org.apache.spark.streaming.State State.isTimingOut()]] set
    * to `true` in that call.
    */
@@ -91,115 +104,124 @@ sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] exte
 /**
  * :: Experimental ::
  * Builder object for creating instances of [[org.apache.spark.streaming.StateSpec StateSpec]]
- * that is used for specifying the parameters of the DStream transformation `trackStateByKey`
+ * that is used for specifying the parameters of the DStream transformation `mapWithState`
  * that is used for specifying the parameters of the DStream transformation
- * `trackStateByKey` operation of a
+ * `mapWithState` operation of a
  * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
  * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
  *
  * Example in Scala:
  * {{{
- *    def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = {
- *      ...
+ *    // A mapping function that maintains an integer state and returns a String
+ *    def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
+ *      // Use state.exists(), state.get(), state.update() and state.remove()
+ *      // to manage state, and return the necessary string
  *    }
  *
- *    val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](
- *        StateSpec.function(trackingFunction).numPartitions(10))
+ *    val spec = StateSpec.function(mappingFunction).numPartitions(10)
+ *
+ *    val mapWithStateDStream = keyValueDStream.mapWithState[StateType, MappedType](spec)
  * }}}
  *
  * Example in Java:
  * {{{
- *    StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec =
- *      StateSpec.<KeyType, ValueType, StateType, EmittedDataType>function(trackingFunction)
- *                    .numPartition(10);
+ *   // A mapping function that maintains an integer state and returns a string
+ *   Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction =
+ *       new Function3<String, Optional<Integer>, State<Integer>, String>() {
+ *           @Override
+ *           public String call(String key, Optional<Integer> value, State<Integer> state) {
+ *               // Use state.exists(), state.get(), state.update() and state.remove()
+ *               // to manage state, and return the necessary string
+ *           }
+ *       };
  *
- *    JavaTrackStateDStream<KeyType, ValueType, StateType, EmittedType> emittedRecordDStream =
- *      javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec);
- * }}}
+ *    JavaMapWithStateDStream<String, Integer, Integer, String> mapWithStateDStream =
+ *        keyValueDStream.mapWithState(StateSpec.function(mappingFunction));
+ * }}}
  */
 @Experimental
 object StateSpec {
   /**
    * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
-   * of the `trackStateByKey` operation on a
+   * of the `mapWithState` operation on a
    * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
    *
-   * @param trackingFunction The function applied on every data item to manage the associated state
-   *                         and generate the emitted data
+   * @param mappingFunction The function applied on every data item to manage the associated state
+   *                         and generate the mapped data
    * @tparam KeyType      Class of the keys
    * @tparam ValueType    Class of the values
    * @tparam StateType    Class of the states data
-   * @tparam EmittedType  Class of the emitted data
+   * @tparam MappedType   Class of the mapped data
    */
-  def function[KeyType, ValueType, StateType, EmittedType](
-      trackingFunction: (Time, KeyType, Option[ValueType], State[StateType]) => Option[EmittedType]
-    ): StateSpec[KeyType, ValueType, StateType, EmittedType] = {
-    ClosureCleaner.clean(trackingFunction, checkSerializable = true)
-    new StateSpecImpl(trackingFunction)
+  def function[KeyType, ValueType, StateType, MappedType](
+      mappingFunction: (Time, KeyType, Option[ValueType], State[StateType]) => Option[MappedType]
+    ): StateSpec[KeyType, ValueType, StateType, MappedType] = {
+    ClosureCleaner.clean(mappingFunction, checkSerializable = true)
+    new StateSpecImpl(mappingFunction)
   }
 
   /**
    * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
-   * of the `trackStateByKey` operation on a
+   * of the `mapWithState` operation on a
    * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
    *
-   * @param trackingFunction The function applied on every data item to manage the associated state
-   *                         and generate the emitted data
+   * @param mappingFunction The function applied on every data item to manage the associated state
+   *                         and generate the mapped data
    * @tparam ValueType    Class of the values
    * @tparam StateType    Class of the states data
-   * @tparam EmittedType  Class of the emitted data
+   * @tparam MappedType   Class of the mapped data
    */
-  def function[KeyType, ValueType, StateType, EmittedType](
-      trackingFunction: (Option[ValueType], State[StateType]) => EmittedType
-    ): StateSpec[KeyType, ValueType, StateType, EmittedType] = {
-    ClosureCleaner.clean(trackingFunction, checkSerializable = true)
+  def function[KeyType, ValueType, StateType, MappedType](
+      mappingFunction: (KeyType, Option[ValueType], State[StateType]) => MappedType
+    ): StateSpec[KeyType, ValueType, StateType, MappedType] = {
+    ClosureCleaner.clean(mappingFunction, checkSerializable = true)
     val wrappedFunction =
-      (time: Time, key: Any, value: Option[ValueType], state: State[StateType]) => {
-        Some(trackingFunction(value, state))
+      (time: Time, key: KeyType, value: Option[ValueType], state: State[StateType]) => {
+        Some(mappingFunction(key, value, state))
       }
     new StateSpecImpl(wrappedFunction)
   }
 
   /**
    * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all
-   * the specifications of the `trackStateByKey` operation on a
+   * the specifications of the `mapWithState` operation on a
    * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]].
    *
-   * @param javaTrackingFunction The function applied on every data item to manage the associated
-   *                             state and generate the emitted data
+   * @param mappingFunction The function applied on every data item to manage the associated
+   *                        state and generate the mapped data
    * @tparam KeyType      Class of the keys
    * @tparam ValueType    Class of the values
    * @tparam StateType    Class of the states data
-   * @tparam EmittedType  Class of the emitted data
+   * @tparam MappedType   Class of the mapped data
    */
-  def function[KeyType, ValueType, StateType, EmittedType](javaTrackingFunction:
-      JFunction4[Time, KeyType, Optional[ValueType], State[StateType], Optional[EmittedType]]):
-    StateSpec[KeyType, ValueType, StateType, EmittedType] = {
-    val trackingFunc = (time: Time, k: KeyType, v: Option[ValueType], s: State[StateType]) => {
-      val t = javaTrackingFunction.call(time, k, JavaUtils.optionToOptional(v), s)
+  def function[KeyType, ValueType, StateType, MappedType](mappingFunction:
+      JFunction4[Time, KeyType, Optional[ValueType], State[StateType], Optional[MappedType]]):
+    StateSpec[KeyType, ValueType, StateType, MappedType] = {
+    val wrappedFunc = (time: Time, k: KeyType, v: Option[ValueType], s: State[StateType]) => {
+      val t = mappingFunction.call(time, k, JavaUtils.optionToOptional(v), s)
       Option(t.orNull)
     }
-    StateSpec.function(trackingFunc)
+    StateSpec.function(wrappedFunc)
   }
 
   /**
    * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
-   * of the `trackStateByKey` operation on a
+   * of the `mapWithState` operation on a
    * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]].
    *
-   * @param javaTrackingFunction The function applied on every data item to manage the associated
-   *                             state and generate the emitted data
+   * @param mappingFunction The function applied on every data item to manage the associated
+   *                        state and generate the mapped data
    * @tparam ValueType    Class of the values
    * @tparam StateType    Class of the states data
-   * @tparam EmittedType  Class of the emitted data
+   * @tparam MappedType   Class of the mapped data
    */
-  def function[KeyType, ValueType, StateType, EmittedType](
-      javaTrackingFunction: JFunction2[Optional[ValueType], State[StateType], EmittedType]):
-    StateSpec[KeyType, ValueType, StateType, EmittedType] = {
-    val trackingFunc = (v: Option[ValueType], s: State[StateType]) => {
-      javaTrackingFunction.call(Optional.fromNullable(v.get), s)
+  def function[KeyType, ValueType, StateType, MappedType](
+      mappingFunction: JFunction3[KeyType, Optional[ValueType], State[StateType], MappedType]):
+    StateSpec[KeyType, ValueType, StateType, MappedType] = {
+    val wrappedFunc = (k: KeyType, v: Option[ValueType], s: State[StateType]) => {
+      mappingFunction.call(k, Optional.fromNullable(v.get), s)
     }
-    StateSpec.function(trackingFunc)
+    StateSpec.function(wrappedFunc)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/bd2cd4f5/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaMapWithStateDStream.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaMapWithStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaMapWithStateDStream.scala
new file mode 100644
index 0000000..16c0d6f
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaMapWithStateDStream.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.api.java
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaSparkContext
+import org.apache.spark.streaming.dstream.MapWithStateDStream
+
+/**
+ * :: Experimental ::
+ * DStream representing the stream of data generated by the `mapWithState` operation on a
+ * [[JavaPairDStream]]. It also gives access to the
+ * stream of state snapshots, that is, the state data of all keys after a batch has updated them.
+ *
+ * @tparam KeyType Class of the keys
+ * @tparam ValueType Class of the values
+ * @tparam StateType Class of the state data
+ * @tparam MappedType Class of the mapped data
+ */
+@Experimental
+class JavaMapWithStateDStream[KeyType, ValueType, StateType, MappedType] private[streaming](
+    dstream: MapWithStateDStream[KeyType, ValueType, StateType, MappedType])
+  extends JavaDStream[MappedType](dstream)(JavaSparkContext.fakeClassTag) {
+
+  def stateSnapshots(): JavaPairDStream[KeyType, StateType] =
+    new JavaPairDStream(dstream.stateSnapshots())(
+      JavaSparkContext.fakeClassTag,
+      JavaSparkContext.fakeClassTag)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/bd2cd4f5/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
index 70e32b3..42ddd63 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
@@ -430,42 +430,36 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
 
   /**
    * :: Experimental ::
-   * Return a new [[JavaDStream]] of data generated by combining the key-value data in `this` stream
-   * with a continuously updated per-key state. The user-provided state tracking function is
-   * applied on each keyed data item along with its corresponding state. The function can choose to
-   * update/remove the state and return a transformed data, which forms the
-   * [[JavaTrackStateDStream]].
+   * Return a [[JavaMapWithStateDStream]] by applying a function to every key-value element of
+   * `this` stream, while maintaining some state data for each unique key. The mapping function
+   * and other specifications (e.g. partitioners, timeouts, initial state data, etc.) of this
+   * transformation can be specified using the [[StateSpec]] class. The state data is accessible
+   * as a parameter of type [[State]] in the mapping function.
    *
-   * The specifications of this transformation is made through the
-   * [[org.apache.spark.streaming.StateSpec StateSpec]] class. Besides the tracking function, there
-   * are a number of optional parameters - initial state data, number of partitions, timeouts, etc.
-   * See the [[org.apache.spark.streaming.StateSpec StateSpec]] for more details.
-   *
-   * Example of using `trackStateByKey`:
+   * Example of using `mapWithState`:
    * {{{
-   *   // A tracking function that maintains an integer state and return a String
-   *   Function2<Optional<Integer>, State<Integer>, Optional<String>> trackStateFunc =
-   *       new Function2<Optional<Integer>, State<Integer>, Optional<String>>() {
-   *
-   *         @Override
-   *         public Optional<String> call(Optional<Integer> one, State<Integer> state) {
-   *           // Check if state exists, accordingly update/remove state and return transformed data
-   *         }
+   *   // A mapping function that maintains an integer state and returns a String
+   *   Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction =
+   *       new Function3<String, Optional<Integer>, State<Integer>, String>() {
+   *           @Override
+   *           public String call(String key, Optional<Integer> value, State<Integer> state) {
+   *               // Use state.exists(), state.get(), state.update() and state.remove()
+   *               // to manage state, and return the necessary string
+   *           }
    *       };
    *
-   *    JavaTrackStateDStream<Integer, Integer, Integer, String> trackStateDStream =
-   *        keyValueDStream.<Integer, String>trackStateByKey(
-   *                 StateSpec.function(trackStateFunc).numPartitions(10));
-   * }}}
+   *    JavaMapWithStateDStream<String, Integer, Integer, String> mapWithStateDStream =
+   *        keyValueDStream.mapWithState(StateSpec.function(mappingFunction));
+   * }}}
    *
    * @param spec          Specification of this transformation
-   * @tparam StateType    Class type of the state
-   * @tparam EmittedType  Class type of the tranformed data return by the tracking function
+   * @tparam StateType    Class type of the state data
+   * @tparam MappedType   Class type of the mapped data
    */
   @Experimental
-  def trackStateByKey[StateType, EmittedType](spec: StateSpec[K, V, StateType, EmittedType]):
-    JavaTrackStateDStream[K, V, StateType, EmittedType] = {
-    new JavaTrackStateDStream(dstream.trackStateByKey(spec)(
+  def mapWithState[StateType, MappedType](spec: StateSpec[K, V, StateType, MappedType]):
+    JavaMapWithStateDStream[K, V, StateType, MappedType] = {
+    new JavaMapWithStateDStream(dstream.mapWithState(spec)(
       JavaSparkContext.fakeClassTag,
       JavaSparkContext.fakeClassTag))
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/bd2cd4f5/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala
deleted file mode 100644
index f459930..0000000
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.streaming.api.java
-
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.api.java.JavaSparkContext
-import org.apache.spark.streaming.dstream.TrackStateDStream
-
-/**
- * :: Experimental ::
- * [[JavaDStream]] representing the stream of records emitted by the tracking function in the
- * `trackStateByKey` operation on a [[JavaPairDStream]]. Additionally, it also gives access to the
- * stream of state snapshots, that is, the state data of all keys after a batch has updated them.
- *
- * @tparam KeyType Class of the state key
- * @tparam ValueType Class of the state value
- * @tparam StateType Class of the state
- * @tparam EmittedType Class of the emitted records
- */
-@Experimental
-class JavaTrackStateDStream[KeyType, ValueType, StateType, EmittedType](
-    dstream: TrackStateDStream[KeyType, ValueType, StateType, EmittedType])
-  extends JavaDStream[EmittedType](dstream)(JavaSparkContext.fakeClassTag) {
-
-  def stateSnapshots(): JavaPairDStream[KeyType, StateType] =
-    new JavaPairDStream(dstream.stateSnapshots())(
-      JavaSparkContext.fakeClassTag,
-      JavaSparkContext.fakeClassTag)
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/bd2cd4f5/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala
new file mode 100644
index 0000000..706465d
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.dstream
+
+import scala.reflect.ClassTag
+
+import org.apache.spark._
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.rdd.{EmptyRDD, RDD}
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming._
+import org.apache.spark.streaming.rdd.{MapWithStateRDD, MapWithStateRDDRecord}
+import org.apache.spark.streaming.dstream.InternalMapWithStateDStream._
+
+/**
+ * :: Experimental ::
+ * DStream representing the stream of data generated by the `mapWithState` operation on a
+ * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
+ * It also gives access to the stream of state snapshots, that is, the state data of
+ * all keys after a batch has updated them.
+ *
+ * @tparam KeyType Class of the key
+ * @tparam ValueType Class of the value
+ * @tparam StateType Class of the state data
+ * @tparam MappedType Class of the mapped data
+ */
+@Experimental
+sealed abstract class MapWithStateDStream[KeyType, ValueType, StateType, MappedType: ClassTag](
+    ssc: StreamingContext) extends DStream[MappedType](ssc) {
+
+  /** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */
+  def stateSnapshots(): DStream[(KeyType, StateType)]
+}
+
+/** Internal implementation of the [[MapWithStateDStream]] */
+private[streaming] class MapWithStateDStreamImpl[
+    KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, MappedType: ClassTag](
+    dataStream: DStream[(KeyType, ValueType)],
+    spec: StateSpecImpl[KeyType, ValueType, StateType, MappedType])
+  extends MapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream.context) {
+
+  private val internalStream =
+    new InternalMapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream, spec)
+
+  override def slideDuration: Duration = internalStream.slideDuration
+
+  override def dependencies: List[DStream[_]] = List(internalStream)
+
+  override def compute(validTime: Time): Option[RDD[MappedType]] = {
+    internalStream.getOrCompute(validTime).map { _.flatMap[MappedType] { _.mappedData } }
+  }
+
+  /**
+   * Forward the checkpoint interval to the internal DStream that computes the state maps. This
+   * is to make sure that this DStream does not get checkpointed, only the internal stream.
+   */
+  override def checkpoint(checkpointInterval: Duration): DStream[MappedType] = {
+    internalStream.checkpoint(checkpointInterval)
+    this
+  }
+
+  /** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */
+  def stateSnapshots(): DStream[(KeyType, StateType)] = {
+    internalStream.flatMap {
+      _.stateMap.getAll().map { case (k, s, _) => (k, s) }.toTraversable }
+  }
+
+  def keyClass: Class[_] = implicitly[ClassTag[KeyType]].runtimeClass
+
+  def valueClass: Class[_] = implicitly[ClassTag[ValueType]].runtimeClass
+
+  def stateClass: Class[_] = implicitly[ClassTag[StateType]].runtimeClass
+
+  def mappedClass: Class[_] = implicitly[ClassTag[MappedType]].runtimeClass
+}
+
+/**
+ * A DStream that allows per-key state to be maintained, and arbitrary records to be generated
+ * based on updates to the state. This is the main DStream that implements the `mapWithState`
+ * operation on DStreams.
+ *
+ * @param parent Parent (key, value) stream that is the source
+ * @param spec Specifications of the mapWithState operation
+ * @tparam K   Key type
+ * @tparam V   Value type
+ * @tparam S   Type of the state maintained
+ * @tparam E   Type of the mapped data
+ */
+private[streaming]
+class InternalMapWithStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
+    parent: DStream[(K, V)], spec: StateSpecImpl[K, V, S, E])
+  extends DStream[MapWithStateRDDRecord[K, S, E]](parent.context) {
+
+  persist(StorageLevel.MEMORY_ONLY)
+
+  private val partitioner = spec.getPartitioner().getOrElse(
+    new HashPartitioner(ssc.sc.defaultParallelism))
+
+  private val mappingFunction = spec.getFunction()
+
+  override def slideDuration: Duration = parent.slideDuration
+
+  override def dependencies: List[DStream[_]] = List(parent)
+
+  /** Enable automatic checkpointing */
+  override val mustCheckpoint = true
+
+  /** Override the default checkpoint duration */
+  override def initialize(time: Time): Unit = {
+    if (checkpointDuration == null) {
+      checkpointDuration = slideDuration * DEFAULT_CHECKPOINT_DURATION_MULTIPLIER
+    }
+    super.initialize(time)
+  }
+
+  /** Method that generates an RDD for the given time */
+  override def compute(validTime: Time): Option[RDD[MapWithStateRDDRecord[K, S, E]]] = {
+    // Get the previous state or create a new empty state RDD
+    val prevStateRDD = getOrCompute(validTime - slideDuration) match {
+      case Some(rdd) =>
+        if (rdd.partitioner != Some(partitioner)) {
+          // If the RDD is not partitioned the right way, let us repartition it using the
+          // partition index as the key. This is to ensure that state RDD is always partitioned
+          // before creating another state RDD using it
+          MapWithStateRDD.createFromRDD[K, V, S, E](
+            rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime)
+        } else {
+          rdd
+        }
+      case None =>
+        MapWithStateRDD.createFromPairRDD[K, V, S, E](
+          spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
+          partitioner,
+          validTime
+        )
+    }
+
+
+    // Compute the new state RDD with previous state RDD and partitioned data RDD
+    // Even if there is no data RDD, use an empty one to create a new state RDD
+    val dataRDD = parent.getOrCompute(validTime).getOrElse {
+      context.sparkContext.emptyRDD[(K, V)]
+    }
+    val partitionedDataRDD = dataRDD.partitionBy(partitioner)
+    val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
+      (validTime - interval).milliseconds
+    }
+    Some(new MapWithStateRDD(
+      prevStateRDD, partitionedDataRDD, mappingFunction, validTime, timeoutThresholdTime))
+  }
+}
+
+private[streaming] object InternalMapWithStateDStream {
+  private val DEFAULT_CHECKPOINT_DURATION_MULTIPLIER = 10
+}
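
As a rough end-to-end illustration of the classes above: the mapped stream carries the values
returned by the mapping function, while `stateSnapshots()` re-exposes the full per-key state after
every batch, and a checkpoint directory is needed because the internal state stream sets
`mustCheckpoint`. A minimal sketch, assuming 1-second batches, a socket text source on
localhost:9999 and a /tmp checkpoint directory (none of which are part of this patch):

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, State, StateSpec, StreamingContext}

    object MapWithStateSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setAppName("mapWithState-sketch").setMaster("local[2]")
        val ssc = new StreamingContext(conf, Seconds(1))
        ssc.checkpoint("/tmp/mapWithState-sketch")   // required: the internal state stream checkpoints

        // Running count per word; the mapped record is (word, countSoFar)
        val mappingFunc = (word: String, one: Option[Int], state: State[Int]) => {
          val sum = one.getOrElse(0) + state.getOption().getOrElse(0)
          state.update(sum)
          (word, sum)
        }

        val words = ssc.socketTextStream("localhost", 9999).flatMap(_.split(" ")).map(word => (word, 1))
        val stateStream = words.mapWithState(StateSpec.function(mappingFunc))

        stateStream.print()                  // records returned by the mapping function
        stateStream.stateSnapshots().print() // full (word, count) state after every batch

        ssc.start()
        ssc.awaitTermination()
      }
    }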

http://git-wip-us.apache.org/repos/asf/spark/blob/bd2cd4f5/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
index 2762309..a64a1fe 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
@@ -352,39 +352,36 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)])
 
   /**
    * :: Experimental ::
-   * Return a new DStream of data generated by combining the key-value data in `this` stream
-   * with a continuously updated per-key state. The user-provided state tracking function is
-   * applied on each keyed data item along with its corresponding state. The function can choose to
-   * update/remove the state and return a transformed data, which forms the
-   * [[org.apache.spark.streaming.dstream.TrackStateDStream]].
+   * Return a [[MapWithStateDStream]] by applying a function to every key-value element of
+   * `this` stream, while maintaining some state data for each unique key. The mapping function
+   * and other specifications (e.g. partitioners, timeouts, initial state data, etc.) of this
+   * transformation can be specified using the [[StateSpec]] class. The state data is accessible
+   * as a parameter of type [[State]] in the mapping function.
    *
-   * The specifications of this transformation is made through the
-   * [[org.apache.spark.streaming.StateSpec StateSpec]] class. Besides the tracking function, there
-   * are a number of optional parameters - initial state data, number of partitions, timeouts, etc.
-   * See the [[org.apache.spark.streaming.StateSpec StateSpec spec docs]] for more details.
-   *
-   * Example of using `trackStateByKey`:
+   * Example of using `mapWithState`:
    * {{{
-   *    def trackingFunction(data: Option[Int], wrappedState: State[Int]): String = {
-   *      // Check if state exists, accordingly update/remove state and return transformed data
+   *    // A mapping function that maintains an integer state and returns a String
+   *    def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
+   *      // Use state.exists(), state.get(), state.update() and state.remove()
+   *      // to manage state, and return the necessary string
    *    }
    *
-   *    val spec = StateSpec.function(trackingFunction).numPartitions(10)
+   *    val spec = StateSpec.function(mappingFunction).numPartitions(10)
    *
-   *    val trackStateDStream = keyValueDStream.trackStateByKey[Int, String](spec)
+   *    val mapWithStateDStream = keyValueDStream.mapWithState[StateType, MappedType](spec)
    * }}}
    *
    * @param spec          Specification of this transformation
-   * @tparam StateType    Class type of the state
-   * @tparam EmittedType  Class type of the tranformed data return by the tracking function
+   * @tparam StateType    Class type of the state data
+   * @tparam MappedType   Class type of the mapped data
    */
   @Experimental
-  def trackStateByKey[StateType: ClassTag, EmittedType: ClassTag](
-      spec: StateSpec[K, V, StateType, EmittedType]
-    ): TrackStateDStream[K, V, StateType, EmittedType] = {
-    new TrackStateDStreamImpl[K, V, StateType, EmittedType](
+  def mapWithState[StateType: ClassTag, MappedType: ClassTag](
+      spec: StateSpec[K, V, StateType, MappedType]
+    ): MapWithStateDStream[K, V, StateType, MappedType] = {
+    new MapWithStateDStreamImpl[K, V, StateType, MappedType](
       self,
-      spec.asInstanceOf[StateSpecImpl[K, V, StateType, EmittedType]]
+      spec.asInstanceOf[StateSpecImpl[K, V, StateType, MappedType]]
     )
   }
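
The `StateSpec` builder referenced above also carries the optional knobs: initial state,
partitioning and an idle timeout. When a key times out, the mapping function is invoked one last
time with the value set to `None` and `state.isTimingOut()` returning true, and the state may not
be updated in that call. A hedged sketch of wiring these together (the initial counts, the
30-minute timeout and the local SparkContext are made-up values for illustration):

    import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
    import org.apache.spark.streaming.{Minutes, State, StateSpec}

    val sc = new SparkContext(new SparkConf().setAppName("spec-sketch").setMaster("local[2]"))

    // Pre-existing counts used to seed the state
    val initialRDD = sc.parallelize(Seq(("hello", 1), ("world", 1)))

    val mappingFunc = (word: String, one: Option[Int], state: State[Int]) => {
      if (state.isTimingOut()) {
        (word, state.get(), "timed out")   // final call for an idle key; the state cannot be updated here
      } else {
        val sum = one.getOrElse(0) + state.getOption().getOrElse(0)
        state.update(sum)
        (word, sum, "updated")
      }
    }

    val spec = StateSpec.function(mappingFunc)
      .initialState(initialRDD)
      .partitioner(new HashPartitioner(10))   // or simply .numPartitions(10)
      .timeout(Minutes(30))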
 

http://git-wip-us.apache.org/repos/asf/spark/blob/bd2cd4f5/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
deleted file mode 100644
index ea62134..0000000
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
+++ /dev/null
@@ -1,171 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.streaming.dstream
-
-import scala.reflect.ClassTag
-
-import org.apache.spark._
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.rdd.{EmptyRDD, RDD}
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming._
-import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord}
-import org.apache.spark.streaming.dstream.InternalTrackStateDStream._
-
-/**
- * :: Experimental ::
- * DStream representing the stream of records emitted by the tracking function in the
- * `trackStateByKey` operation on a
- * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
- * Additionally, it also gives access to the stream of state snapshots, that is, the state data of
- * all keys after a batch has updated them.
- *
- * @tparam KeyType Class of the state key
- * @tparam ValueType Class of the state value
- * @tparam StateType Class of the state data
- * @tparam EmittedType Class of the emitted records
- */
-@Experimental
-sealed abstract class TrackStateDStream[KeyType, ValueType, StateType, EmittedType: ClassTag](
-    ssc: StreamingContext) extends DStream[EmittedType](ssc) {
-
-  /** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */
-  def stateSnapshots(): DStream[(KeyType, StateType)]
-}
-
-/** Internal implementation of the [[TrackStateDStream]] */
-private[streaming] class TrackStateDStreamImpl[
-    KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, EmittedType: ClassTag](
-    dataStream: DStream[(KeyType, ValueType)],
-    spec: StateSpecImpl[KeyType, ValueType, StateType, EmittedType])
-  extends TrackStateDStream[KeyType, ValueType, StateType, EmittedType](dataStream.context) {
-
-  private val internalStream =
-    new InternalTrackStateDStream[KeyType, ValueType, StateType, EmittedType](dataStream, spec)
-
-  override def slideDuration: Duration = internalStream.slideDuration
-
-  override def dependencies: List[DStream[_]] = List(internalStream)
-
-  override def compute(validTime: Time): Option[RDD[EmittedType]] = {
-    internalStream.getOrCompute(validTime).map { _.flatMap[EmittedType] { _.emittedRecords } }
-  }
-
-  /**
-   * Forward the checkpoint interval to the internal DStream that computes the state maps. This
-   * to make sure that this DStream does not get checkpointed, only the internal stream.
-   */
-  override def checkpoint(checkpointInterval: Duration): DStream[EmittedType] = {
-    internalStream.checkpoint(checkpointInterval)
-    this
-  }
-
-  /** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */
-  def stateSnapshots(): DStream[(KeyType, StateType)] = {
-    internalStream.flatMap {
-      _.stateMap.getAll().map { case (k, s, _) => (k, s) }.toTraversable }
-  }
-
-  def keyClass: Class[_] = implicitly[ClassTag[KeyType]].runtimeClass
-
-  def valueClass: Class[_] = implicitly[ClassTag[ValueType]].runtimeClass
-
-  def stateClass: Class[_] = implicitly[ClassTag[StateType]].runtimeClass
-
-  def emittedClass: Class[_] = implicitly[ClassTag[EmittedType]].runtimeClass
-}
-
-/**
- * A DStream that allows per-key state to be maintains, and arbitrary records to be generated
- * based on updates to the state. This is the main DStream that implements the `trackStateByKey`
- * operation on DStreams.
- *
- * @param parent Parent (key, value) stream that is the source
- * @param spec Specifications of the trackStateByKey operation
- * @tparam K   Key type
- * @tparam V   Value type
- * @tparam S   Type of the state maintained
- * @tparam E   Type of the emitted data
- */
-private[streaming]
-class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
-    parent: DStream[(K, V)], spec: StateSpecImpl[K, V, S, E])
-  extends DStream[TrackStateRDDRecord[K, S, E]](parent.context) {
-
-  persist(StorageLevel.MEMORY_ONLY)
-
-  private val partitioner = spec.getPartitioner().getOrElse(
-    new HashPartitioner(ssc.sc.defaultParallelism))
-
-  private val trackingFunction = spec.getFunction()
-
-  override def slideDuration: Duration = parent.slideDuration
-
-  override def dependencies: List[DStream[_]] = List(parent)
-
-  /** Enable automatic checkpointing */
-  override val mustCheckpoint = true
-
-  /** Override the default checkpoint duration */
-  override def initialize(time: Time): Unit = {
-    if (checkpointDuration == null) {
-      checkpointDuration = slideDuration * DEFAULT_CHECKPOINT_DURATION_MULTIPLIER
-    }
-    super.initialize(time)
-  }
-
-  /** Method that generates a RDD for the given time */
-  override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = {
-    // Get the previous state or create a new empty state RDD
-    val prevStateRDD = getOrCompute(validTime - slideDuration) match {
-      case Some(rdd) =>
-        if (rdd.partitioner != Some(partitioner)) {
-          // If the RDD is not partitioned the right way, let us repartition it using the
-          // partition index as the key. This is to ensure that state RDD is always partitioned
-          // before creating another state RDD using it
-          TrackStateRDD.createFromRDD[K, V, S, E](
-            rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime)
-        } else {
-          rdd
-        }
-      case None =>
-        TrackStateRDD.createFromPairRDD[K, V, S, E](
-          spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
-          partitioner,
-          validTime
-        )
-    }
-
-
-    // Compute the new state RDD with previous state RDD and partitioned data RDD
-    // Even if there is no data RDD, use an empty one to create a new state RDD
-    val dataRDD = parent.getOrCompute(validTime).getOrElse {
-      context.sparkContext.emptyRDD[(K, V)]
-    }
-    val partitionedDataRDD = dataRDD.partitionBy(partitioner)
-    val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
-      (validTime - interval).milliseconds
-    }
-    Some(new TrackStateRDD(
-      prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime))
-  }
-}
-
-private[streaming] object InternalTrackStateDStream {
-  private val DEFAULT_CHECKPOINT_DURATION_MULTIPLIER = 10
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/bd2cd4f5/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala
new file mode 100644
index 0000000..ed95171
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.rdd
+
+import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+
+import org.apache.spark.rdd.{MapPartitionsRDD, RDD}
+import org.apache.spark.streaming.{Time, StateImpl, State}
+import org.apache.spark.streaming.util.{EmptyStateMap, StateMap}
+import org.apache.spark.util.Utils
+import org.apache.spark._
+
+/**
+ * Record storing the keyed state of a [[MapWithStateRDD]] partition. Each record contains a
+ * [[StateMap]] and a sequence of records returned by the mapping function of `mapWithState`.
+ */
+private[streaming] case class MapWithStateRDDRecord[K, S, E](
+    var stateMap: StateMap[K, S], var mappedData: Seq[E])
+
+private[streaming] object MapWithStateRDDRecord {
+  def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
+    prevRecord: Option[MapWithStateRDDRecord[K, S, E]],
+    dataIterator: Iterator[(K, V)],
+    mappingFunction: (Time, K, Option[V], State[S]) => Option[E],
+    batchTime: Time,
+    timeoutThresholdTime: Option[Long],
+    removeTimedoutData: Boolean
+  ): MapWithStateRDDRecord[K, S, E] = {
+    // Create a new state map by cloning the previous one (if it exists) or by creating an empty one
+    val newStateMap = prevRecord.map { _.stateMap.copy() }.getOrElse { new EmptyStateMap[K, S]() }
+
+    val mappedData = new ArrayBuffer[E]
+    val wrappedState = new StateImpl[S]()
+
+    // Call the mapping function on each record in the data iterator, and accordingly
+    // update the states touched, and collect the data returned by the mapping function
+    dataIterator.foreach { case (key, value) =>
+      wrappedState.wrap(newStateMap.get(key))
+      val returned = mappingFunction(batchTime, key, Some(value), wrappedState)
+      if (wrappedState.isRemoved) {
+        newStateMap.remove(key)
+      } else if (wrappedState.isUpdated || timeoutThresholdTime.isDefined) {
+        newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
+      }
+      mappedData ++= returned
+    }
+
+    // Get the timed out state records, call the mapping function on each and collect the
+    // data returned
+    if (removeTimedoutData && timeoutThresholdTime.isDefined) {
+      newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
+        wrappedState.wrapTiminoutState(state)
+        val returned = mappingFunction(batchTime, key, None, wrappedState)
+        mappedData ++= returned
+        newStateMap.remove(key)
+      }
+    }
+
+    MapWithStateRDDRecord(newStateMap, mappedData)
+  }
+}
+
+/**
+ * Partition of the [[MapWithStateRDD]], which depends on the corresponding partitions of the
+ * previous state RDD and of the partitioned keyed-data RDD
+ */
+private[streaming] class MapWithStateRDDPartition(
+    idx: Int,
+    @transient private var prevStateRDD: RDD[_],
+    @transient private var partitionedDataRDD: RDD[_]) extends Partition {
+
+  private[rdd] var previousSessionRDDPartition: Partition = null
+  private[rdd] var partitionedDataRDDPartition: Partition = null
+
+  override def index: Int = idx
+  override def hashCode(): Int = idx
+
+  @throws(classOf[IOException])
+  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
+    // Update the reference to parent split at the time of task serialization
+    previousSessionRDDPartition = prevStateRDD.partitions(index)
+    partitionedDataRDDPartition = partitionedDataRDD.partitions(index)
+    oos.defaultWriteObject()
+  }
+}
+
+
+/**
+ * RDD storing the keyed states of the `mapWithState` operation and the corresponding mapped data.
+ * Each partition of this RDD has a single record of type [[MapWithStateRDDRecord]]. This contains
+ * a [[StateMap]] (containing the keyed states) and the sequence of records returned by the mapping
+ * function of `mapWithState`.
+ * @param prevStateRDD The previous MapWithStateRDD on whose StateMap data `this` RDD
+ *                     will be created
+ * @param partitionedDataRDD The partitioned data RDD which is used to update the previous StateMaps
+ *                           in the `prevStateRDD` to create `this` RDD
+ * @param mappingFunction  The function that will be used to update state and return new data
+ * @param batchTime        The time of the batch to which this RDD belongs. Used to record the
+ *                         update time of the state entries
+ * @param timeoutThresholdTime The threshold time used to decide which keys have timed out
+ */
+private[streaming] class MapWithStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
+    private var prevStateRDD: RDD[MapWithStateRDDRecord[K, S, E]],
+    private var partitionedDataRDD: RDD[(K, V)],
+    mappingFunction: (Time, K, Option[V], State[S]) => Option[E],
+    batchTime: Time,
+    timeoutThresholdTime: Option[Long]
+  ) extends RDD[MapWithStateRDDRecord[K, S, E]](
+    partitionedDataRDD.sparkContext,
+    List(
+      new OneToOneDependency[MapWithStateRDDRecord[K, S, E]](prevStateRDD),
+      new OneToOneDependency(partitionedDataRDD))
+  ) {
+
+  @volatile private var doFullScan = false
+
+  require(prevStateRDD.partitioner.nonEmpty)
+  require(partitionedDataRDD.partitioner == prevStateRDD.partitioner)
+
+  override val partitioner = prevStateRDD.partitioner
+
+  override def checkpoint(): Unit = {
+    super.checkpoint()
+    doFullScan = true
+  }
+
+  override def compute(
+      partition: Partition, context: TaskContext): Iterator[MapWithStateRDDRecord[K, S, E]] = {
+
+    val stateRDDPartition = partition.asInstanceOf[MapWithStateRDDPartition]
+    val prevStateRDDIterator = prevStateRDD.iterator(
+      stateRDDPartition.previousSessionRDDPartition, context)
+    val dataIterator = partitionedDataRDD.iterator(
+      stateRDDPartition.partitionedDataRDDPartition, context)
+
+    val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None
+    val newRecord = MapWithStateRDDRecord.updateRecordWithData(
+      prevRecord,
+      dataIterator,
+      mappingFunction,
+      batchTime,
+      timeoutThresholdTime,
+      removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled
+    )
+    Iterator(newRecord)
+  }
+
+  override protected def getPartitions: Array[Partition] = {
+    Array.tabulate(prevStateRDD.partitions.length) { i =>
+      new MapWithStateRDDPartition(i, prevStateRDD, partitionedDataRDD)}
+  }
+
+  override def clearDependencies(): Unit = {
+    super.clearDependencies()
+    prevStateRDD = null
+    partitionedDataRDD = null
+  }
+
+  def setFullScan(): Unit = {
+    doFullScan = true
+  }
+}
+
+private[streaming] object MapWithStateRDD {
+
+  def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
+      pairRDD: RDD[(K, S)],
+      partitioner: Partitioner,
+      updateTime: Time): MapWithStateRDD[K, V, S, E] = {
+
+    val stateRDD = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator =>
+      val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
+      iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) }
+      Iterator(MapWithStateRDDRecord(stateMap, Seq.empty[E]))
+    }, preservesPartitioning = true)
+
+    val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)
+
+    val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None
+
+    new MapWithStateRDD[K, V, S, E](
+      stateRDD, emptyDataRDD, noOpFunc, updateTime, None)
+  }
+
+  def createFromRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
+      rdd: RDD[(K, S, Long)],
+      partitioner: Partitioner,
+      updateTime: Time): MapWithStateRDD[K, V, S, E] = {
+
+    val pairRDD = rdd.map { x => (x._1, (x._2, x._3)) }
+    val stateRDD = pairRDD.partitionBy(partitioner).mapPartitions({ iterator =>
+      val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
+      iterator.foreach { case (key, (state, updateTime)) =>
+        stateMap.put(key, state, updateTime)
+      }
+      Iterator(MapWithStateRDDRecord(stateMap, Seq.empty[E]))
+    }, preservesPartitioning = true)
+
+    val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)
+
+    val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None
+
+    new MapWithStateRDD[K, V, S, E](
+      stateRDD, emptyDataRDD, noOpFunc, updateTime, None)
+  }
+}
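
To make the bookkeeping in `MapWithStateRDDRecord.updateRecordWithData` above easier to follow,
here is a simplified, Spark-free model of the same per-partition loop. Plain Scala collections
stand in for `StateMap` and the RDD machinery, the timeout pass is omitted, and names such as
`SimpleState` and `updateRecord` are illustrative only:

    // Toy stand-in for State[S]: records whether the mapping function
    // updated or removed the state for the current key.
    final class SimpleState[S](initial: Option[S]) {
      private var value: Option[S] = initial
      var removed = false
      var updated = false
      def exists: Boolean = value.isDefined
      def get(): S = value.get
      def getOption(): Option[S] = value
      def update(s: S): Unit = { value = Some(s); updated = true }
      def remove(): Unit = { value = None; removed = true }
    }

    // Mirrors the outcomes handled per record above: refresh the state,
    // drop it, or leave it untouched, while collecting the mapped data.
    def updateRecord[K, V, S, E](
        prevState: Map[K, S],
        data: Seq[(K, V)],
        f: (K, Option[V], SimpleState[S]) => Option[E]): (Map[K, S], Seq[E]) = {
      var stateMap = prevState
      val mapped = Seq.newBuilder[E]
      data.foreach { case (key, value) =>
        val state = new SimpleState[S](stateMap.get(key))
        mapped ++= f(key, Some(value), state)
        if (state.removed) stateMap -= key
        else if (state.updated) stateMap += key -> state.get()
      }
      (stateMap, mapped.result())
    }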

http://git-wip-us.apache.org/repos/asf/spark/blob/bd2cd4f5/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
deleted file mode 100644
index 30aafcf..0000000
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
+++ /dev/null
@@ -1,228 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.streaming.rdd
-
-import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
-
-import scala.collection.mutable.ArrayBuffer
-import scala.reflect.ClassTag
-
-import org.apache.spark.rdd.{MapPartitionsRDD, RDD}
-import org.apache.spark.streaming.{Time, StateImpl, State}
-import org.apache.spark.streaming.util.{EmptyStateMap, StateMap}
-import org.apache.spark.util.Utils
-import org.apache.spark._
-
-/**
- * Record storing the keyed-state [[TrackStateRDD]]. Each record contains a [[StateMap]] and a
- * sequence of records returned by the tracking function of `trackStateByKey`.
- */
-private[streaming] case class TrackStateRDDRecord[K, S, E](
-    var stateMap: StateMap[K, S], var emittedRecords: Seq[E])
-
-private[streaming] object TrackStateRDDRecord {
-  def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
-    prevRecord: Option[TrackStateRDDRecord[K, S, E]],
-    dataIterator: Iterator[(K, V)],
-    updateFunction: (Time, K, Option[V], State[S]) => Option[E],
-    batchTime: Time,
-    timeoutThresholdTime: Option[Long],
-    removeTimedoutData: Boolean
-  ): TrackStateRDDRecord[K, S, E] = {
-    // Create a new state map by cloning the previous one (if it exists) or by creating an empty one
-    val newStateMap = prevRecord.map { _.stateMap.copy() }. getOrElse { new EmptyStateMap[K, S]() }
-
-    val emittedRecords = new ArrayBuffer[E]
-    val wrappedState = new StateImpl[S]()
-
-    // Call the tracking function on each record in the data iterator, and accordingly
-    // update the states touched, and collect the data returned by the tracking function
-    dataIterator.foreach { case (key, value) =>
-      wrappedState.wrap(newStateMap.get(key))
-      val emittedRecord = updateFunction(batchTime, key, Some(value), wrappedState)
-      if (wrappedState.isRemoved) {
-        newStateMap.remove(key)
-      } else if (wrappedState.isUpdated || timeoutThresholdTime.isDefined) {
-        newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
-      }
-      emittedRecords ++= emittedRecord
-    }
-
-    // Get the timed out state records, call the tracking function on each and collect the
-    // data returned
-    if (removeTimedoutData && timeoutThresholdTime.isDefined) {
-      newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
-        wrappedState.wrapTiminoutState(state)
-        val emittedRecord = updateFunction(batchTime, key, None, wrappedState)
-        emittedRecords ++= emittedRecord
-        newStateMap.remove(key)
-      }
-    }
-
-    TrackStateRDDRecord(newStateMap, emittedRecords)
-  }
-}
-
-/**
- * Partition of the [[TrackStateRDD]], which depends on corresponding partitions of prev state
- * RDD, and a partitioned keyed-data RDD
- */
-private[streaming] class TrackStateRDDPartition(
-    idx: Int,
-    @transient private var prevStateRDD: RDD[_],
-    @transient private var partitionedDataRDD: RDD[_]) extends Partition {
-
-  private[rdd] var previousSessionRDDPartition: Partition = null
-  private[rdd] var partitionedDataRDDPartition: Partition = null
-
-  override def index: Int = idx
-  override def hashCode(): Int = idx
-
-  @throws(classOf[IOException])
-  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
-    // Update the reference to parent split at the time of task serialization
-    previousSessionRDDPartition = prevStateRDD.partitions(index)
-    partitionedDataRDDPartition = partitionedDataRDD.partitions(index)
-    oos.defaultWriteObject()
-  }
-}
-
-
-/**
- * RDD storing the keyed-state of `trackStateByKey` and corresponding emitted records.
- * Each partition of this RDD has a single record of type [[TrackStateRDDRecord]]. This contains a
- * [[StateMap]] (containing the keyed-states) and the sequence of records returned by the tracking
- * function of  `trackStateByKey`.
- * @param prevStateRDD The previous TrackStateRDD on whose StateMap data `this` RDD will be created
- * @param partitionedDataRDD The partitioned data RDD which is used update the previous StateMaps
- *                           in the `prevStateRDD` to create `this` RDD
- * @param trackingFunction The function that will be used to update state and return new data
- * @param batchTime        The time of the batch to which this RDD belongs to. Use to update
- * @param timeoutThresholdTime The time to indicate which keys are timeout
- */
-private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
-    private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, E]],
-    private var partitionedDataRDD: RDD[(K, V)],
-    trackingFunction: (Time, K, Option[V], State[S]) => Option[E],
-    batchTime: Time,
-    timeoutThresholdTime: Option[Long]
-  ) extends RDD[TrackStateRDDRecord[K, S, E]](
-    partitionedDataRDD.sparkContext,
-    List(
-      new OneToOneDependency[TrackStateRDDRecord[K, S, E]](prevStateRDD),
-      new OneToOneDependency(partitionedDataRDD))
-  ) {
-
-  @volatile private var doFullScan = false
-
-  require(prevStateRDD.partitioner.nonEmpty)
-  require(partitionedDataRDD.partitioner == prevStateRDD.partitioner)
-
-  override val partitioner = prevStateRDD.partitioner
-
-  override def checkpoint(): Unit = {
-    super.checkpoint()
-    doFullScan = true
-  }
-
-  override def compute(
-      partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, E]] = {
-
-    val stateRDDPartition = partition.asInstanceOf[TrackStateRDDPartition]
-    val prevStateRDDIterator = prevStateRDD.iterator(
-      stateRDDPartition.previousSessionRDDPartition, context)
-    val dataIterator = partitionedDataRDD.iterator(
-      stateRDDPartition.partitionedDataRDDPartition, context)
-
-    val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None
-    val newRecord = TrackStateRDDRecord.updateRecordWithData(
-      prevRecord,
-      dataIterator,
-      trackingFunction,
-      batchTime,
-      timeoutThresholdTime,
-      removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled
-    )
-    Iterator(newRecord)
-  }
-
-  override protected def getPartitions: Array[Partition] = {
-    Array.tabulate(prevStateRDD.partitions.length) { i =>
-      new TrackStateRDDPartition(i, prevStateRDD, partitionedDataRDD)}
-  }
-
-  override def clearDependencies(): Unit = {
-    super.clearDependencies()
-    prevStateRDD = null
-    partitionedDataRDD = null
-  }
-
-  def setFullScan(): Unit = {
-    doFullScan = true
-  }
-}
-
-private[streaming] object TrackStateRDD {
-
-  def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
-      pairRDD: RDD[(K, S)],
-      partitioner: Partitioner,
-      updateTime: Time): TrackStateRDD[K, V, S, E] = {
-
-    val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator =>
-      val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
-      iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) }
-      Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E]))
-    }, preservesPartitioning = true)
-
-    val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)
-
-    val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None
-
-    new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
-  }
-
-  def createFromRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
-      rdd: RDD[(K, S, Long)],
-      partitioner: Partitioner,
-      updateTime: Time): TrackStateRDD[K, V, S, E] = {
-
-    val pairRDD = rdd.map { x => (x._1, (x._2, x._3)) }
-    val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions({ iterator =>
-      val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
-      iterator.foreach { case (key, (state, updateTime)) =>
-        stateMap.put(key, state, updateTime)
-      }
-      Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E]))
-    }, preservesPartitioning = true)
-
-    val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)
-
-    val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None
-
-    new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
-  }
-}
-
-private[streaming] class EmittedRecordsRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
-    parent: TrackStateRDD[K, V, S, T]) extends RDD[T](parent) {
-  override protected def getPartitions: Array[Partition] = parent.partitions
-  override def compute(partition: Partition, context: TaskContext): Iterator[T] = {
-    parent.compute(partition, context).flatMap { _.emittedRecords }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/bd2cd4f5/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java
----------------------------------------------------------------------
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java
new file mode 100644
index 0000000..bc4bc2e
--- /dev/null
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Set;
+
+import scala.Tuple2;
+
+import com.google.common.base.Optional;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.streaming.api.java.JavaDStream;
+import org.apache.spark.util.ManualClock;
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.spark.HashPartitioner;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.function.Function3;
+import org.apache.spark.api.java.function.Function4;
+import org.apache.spark.streaming.api.java.JavaPairDStream;
+import org.apache.spark.streaming.api.java.JavaMapWithStateDStream;
+
+public class JavaMapWithStateSuite extends LocalJavaStreamingContext implements Serializable {
+
+  /**
+   * This test is only for testing the APIs. It's not necessary to run it.
+   */
+  public void testAPI() {
+    JavaPairRDD<String, Boolean> initialRDD = null;
+    JavaPairDStream<String, Integer> wordsDstream = null;
+
+    final Function4<Time, String, Optional<Integer>, State<Boolean>, Optional<Double>>
+        mappingFunc =
+        new Function4<Time, String, Optional<Integer>, State<Boolean>, Optional<Double>>() {
+
+          @Override
+          public Optional<Double> call(
+              Time time, String word, Optional<Integer> one, State<Boolean> state) {
+            // Use all State's methods here
+            state.exists();
+            state.get();
+            state.isTimingOut();
+            state.remove();
+            state.update(true);
+            return Optional.of(2.0);
+          }
+        };
+
+    JavaMapWithStateDStream<String, Integer, Boolean, Double> stateDstream =
+        wordsDstream.mapWithState(
+            StateSpec.function(mappingFunc)
+                .initialState(initialRDD)
+                .numPartitions(10)
+                .partitioner(new HashPartitioner(10))
+                .timeout(Durations.seconds(10)));
+
+    JavaPairDStream<String, Boolean> stateSnapshots = stateDstream.stateSnapshots();
+
+    final Function3<String, Optional<Integer>, State<Boolean>, Double> mappingFunc2 =
+        new Function3<String, Optional<Integer>, State<Boolean>, Double>() {
+
+          @Override
+          public Double call(String key, Optional<Integer> one, State<Boolean> state) {
+            // Use all State's methods here
+            state.exists();
+            state.get();
+            state.isTimingOut();
+            state.remove();
+            state.update(true);
+            return 2.0;
+          }
+        };
+
+    JavaMapWithStateDStream<String, Integer, Boolean, Double> stateDstream2 =
+        wordsDstream.mapWithState(
+            StateSpec.<String, Integer, Boolean, Double>function(mappingFunc2)
+                .initialState(initialRDD)
+                .numPartitions(10)
+                .partitioner(new HashPartitioner(10))
+                .timeout(Durations.seconds(10)));
+
+    JavaPairDStream<String, Boolean> stateSnapshots2 = stateDstream2.stateSnapshots();
+  }
+
+  @Test
+  public void testBasicFunction() {
+    List<List<String>> inputData = Arrays.asList(
+        Collections.<String>emptyList(),
+        Arrays.asList("a"),
+        Arrays.asList("a", "b"),
+        Arrays.asList("a", "b", "c"),
+        Arrays.asList("a", "b"),
+        Arrays.asList("a"),
+        Collections.<String>emptyList()
+    );
+
+    List<Set<Integer>> outputData = Arrays.asList(
+        Collections.<Integer>emptySet(),
+        Sets.newHashSet(1),
+        Sets.newHashSet(2, 1),
+        Sets.newHashSet(3, 2, 1),
+        Sets.newHashSet(4, 3),
+        Sets.newHashSet(5),
+        Collections.<Integer>emptySet()
+    );
+
+    List<Set<Tuple2<String, Integer>>> stateData = Arrays.asList(
+        Collections.<Tuple2<String, Integer>>emptySet(),
+        Sets.newHashSet(new Tuple2<String, Integer>("a", 1)),
+        Sets.newHashSet(new Tuple2<String, Integer>("a", 2), new Tuple2<String, Integer>("b", 1)),
+        Sets.newHashSet(
+            new Tuple2<String, Integer>("a", 3),
+            new Tuple2<String, Integer>("b", 2),
+            new Tuple2<String, Integer>("c", 1)),
+        Sets.newHashSet(
+            new Tuple2<String, Integer>("a", 4),
+            new Tuple2<String, Integer>("b", 3),
+            new Tuple2<String, Integer>("c", 1)),
+        Sets.newHashSet(
+            new Tuple2<String, Integer>("a", 5),
+            new Tuple2<String, Integer>("b", 3),
+            new Tuple2<String, Integer>("c", 1)),
+        Sets.newHashSet(
+            new Tuple2<String, Integer>("a", 5),
+            new Tuple2<String, Integer>("b", 3),
+            new Tuple2<String, Integer>("c", 1))
+    );
+
+    Function3<String, Optional<Integer>, State<Integer>, Integer> mappingFunc =
+        new Function3<String, Optional<Integer>, State<Integer>, Integer>() {
+
+          @Override
+          public Integer call(String key, Optional<Integer> value, State<Integer> state) throws Exception {
+            int sum = value.or(0) + (state.exists() ? state.get() : 0);
+            state.update(sum);
+            return sum;
+          }
+        };
+    testOperation(
+        inputData,
+        StateSpec.<String, Integer, Integer, Integer>function(mappingFunc),
+        outputData,
+        stateData);
+  }
+
+  private <K, S, T> void testOperation(
+      List<List<K>> input,
+      StateSpec<K, Integer, S, T> mapWithStateSpec,
+      List<Set<T>> expectedOutputs,
+      List<Set<Tuple2<K, S>>> expectedStateSnapshots) {
+    int numBatches = expectedOutputs.size();
+    JavaDStream<K> inputStream = JavaTestUtils.attachTestInputStream(ssc, input, 2);
+    JavaMapWithStateDStream<K, Integer, S, T> mapWithStateDStream =
+        JavaPairDStream.fromJavaDStream(inputStream.map(new Function<K, Tuple2<K, Integer>>() {
+          @Override
+          public Tuple2<K, Integer> call(K x) throws Exception {
+            return new Tuple2<K, Integer>(x, 1);
+          }
+        })).mapWithState(mapWithStateSpec);
+
+    final List<Set<T>> collectedOutputs =
+        Collections.synchronizedList(Lists.<Set<T>>newArrayList());
+    mapWithStateDStream.foreachRDD(new Function<JavaRDD<T>, Void>() {
+      @Override
+      public Void call(JavaRDD<T> rdd) throws Exception {
+        collectedOutputs.add(Sets.newHashSet(rdd.collect()));
+        return null;
+      }
+    });
+    final List<Set<Tuple2<K, S>>> collectedStateSnapshots =
+        Collections.synchronizedList(Lists.<Set<Tuple2<K, S>>>newArrayList());
+    mapWithStateDStream.stateSnapshots().foreachRDD(new Function<JavaPairRDD<K, S>, Void>() {
+      @Override
+      public Void call(JavaPairRDD<K, S> rdd) throws Exception {
+        collectedStateSnapshots.add(Sets.newHashSet(rdd.collect()));
+        return null;
+      }
+    });
+    BatchCounter batchCounter = new BatchCounter(ssc.ssc());
+    ssc.start();
+    ((ManualClock) ssc.ssc().scheduler().clock())
+        .advance(ssc.ssc().progressListener().batchDuration() * numBatches + 1);
+    batchCounter.waitUntilBatchesCompleted(numBatches, 10000);
+
+    Assert.assertEquals(expectedOutputs, collectedOutputs);
+    Assert.assertEquals(expectedStateSnapshots, collectedStateSnapshots);
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org