diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
index c400e4237abe3438f849dac1f8aad51c0797762b..14997c64d505ec99df1b77701249fea361167872 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
@@ -65,7 +65,7 @@ public class JavaStatefulNetworkWordCount {
     JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1));
     ssc.checkpoint(".");
 
-    // Initial RDD input to trackStateByKey
+    // Initial state RDD input to mapWithState
     @SuppressWarnings("unchecked")
     List<Tuple2<String, Integer>> tuples = Arrays.asList(new Tuple2<String, Integer>("hello", 1),
             new Tuple2<String, Integer>("world", 1));
@@ -90,21 +90,21 @@ public class JavaStatefulNetworkWordCount {
         });
 
     // Update the cumulative count function
-    final Function4<Time, String, Optional<Integer>, State<Integer>, Optional<Tuple2<String, Integer>>> trackStateFunc =
-        new Function4<Time, String, Optional<Integer>, State<Integer>, Optional<Tuple2<String, Integer>>>() {
+    final Function3<String, Optional<Integer>, State<Integer>, Tuple2<String, Integer>> mappingFunc =
+        new Function3<String, Optional<Integer>, State<Integer>, Tuple2<String, Integer>>() {
 
           @Override
-          public Optional<Tuple2<String, Integer>> call(Time time, String word, Optional<Integer> one, State<Integer> state) {
+          public Tuple2<String, Integer> call(String word, Optional<Integer> one, State<Integer> state) {
             int sum = one.or(0) + (state.exists() ? state.get() : 0);
             Tuple2<String, Integer> output = new Tuple2<String, Integer>(word, sum);
             state.update(sum);
-            return Optional.of(output);
+            return output;
           }
         };
 
-    // This will give a Dstream made of state (which is the cumulative count of the words)
-    JavaTrackStateDStream<String, Integer, Integer, Tuple2<String, Integer>> stateDstream =
-        wordsDstream.trackStateByKey(StateSpec.function(trackStateFunc).initialState(initialRDD));
+    // DStream made of get cumulative counts that get updated in every batch
+    JavaMapWithStateDStream<String, Integer, Integer, Tuple2<String, Integer>> stateDstream =
+        wordsDstream.mapWithState(StateSpec.function(mappingFunc).initialState(initialRDD));
 
     stateDstream.print();
     ssc.start();
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
index a4f847f118b2cd7b27ed671fc759dc534003b404..2dce1820d9734db0ebe20526006728c3243690ef 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
@@ -49,7 +49,7 @@ object StatefulNetworkWordCount {
     val ssc = new StreamingContext(sparkConf, Seconds(1))
     ssc.checkpoint(".")
 
-    // Initial RDD input to trackStateByKey
+    // Initial state RDD for mapWithState operation
     val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1)))
 
     // Create a ReceiverInputDStream on target ip:port and count the
@@ -58,17 +58,17 @@ object StatefulNetworkWordCount {
     val words = lines.flatMap(_.split(" "))
     val wordDstream = words.map(x => (x, 1))
 
-    // Update the cumulative count using updateStateByKey
+    // Update the cumulative count using mapWithState
     // This will give a DStream made of state (which is the cumulative count of the words)
-    val trackStateFunc = (batchTime: Time, word: String, one: Option[Int], state: State[Int]) => {
+    val mappingFunc = (word: String, one: Option[Int], state: State[Int]) => {
       val sum = one.getOrElse(0) + state.getOption.getOrElse(0)
       val output = (word, sum)
       state.update(sum)
-      Some(output)
+      output
     }
 
-    val stateDstream = wordDstream.trackStateByKey(
-      StateSpec.function(trackStateFunc).initialState(initialRDD))
+    val stateDstream = wordDstream.mapWithState(
+      StateSpec.function(mappingFunc).initialState(initialRDD))
     stateDstream.print()
     ssc.start()
     ssc.awaitTermination()
diff --git a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java
index 4eee97bc89613e581dc66add0c7b5c01a4376d1c..89e0c7fdf7eecf42b016426679e5441fe5e0681d 100644
--- a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java
+++ b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java
@@ -32,12 +32,10 @@ import org.apache.spark.Accumulator;
 import org.apache.spark.HashPartitioner;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.function.Function2;
-import org.apache.spark.api.java.function.Function4;
 import org.apache.spark.api.java.function.PairFunction;
 import org.apache.spark.streaming.api.java.JavaDStream;
 import org.apache.spark.streaming.api.java.JavaPairDStream;
-import org.apache.spark.streaming.api.java.JavaTrackStateDStream;
+import org.apache.spark.streaming.api.java.JavaMapWithStateDStream;
 
 /**
  * Most of these tests replicate org.apache.spark.streaming.JavaAPISuite using java 8
@@ -863,12 +861,12 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ
   /**
    * This test is only for testing the APIs. It's not necessary to run it.
    */
-  public void testTrackStateByAPI() {
+  public void testMapWithStateAPI() {
     JavaPairRDD<String, Boolean> initialRDD = null;
     JavaPairDStream<String, Integer> wordsDstream = null;
 
-    JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream =
-        wordsDstream.trackStateByKey(
+    JavaMapWithStateDStream<String, Integer, Boolean, Double> stateDstream =
+        wordsDstream.mapWithState(
             StateSpec.<String, Integer, Boolean, Double> function((time, key, value, state) -> {
               // Use all State's methods here
               state.exists();
@@ -884,9 +882,9 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ
 
     JavaPairDStream<String, Boolean> emittedRecords = stateDstream.stateSnapshots();
 
-    JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream2 =
-        wordsDstream.trackStateByKey(
-            StateSpec.<String, Integer, Boolean, Double>function((value, state) -> {
+    JavaMapWithStateDStream<String, Integer, Boolean, Double> stateDstream2 =
+        wordsDstream.mapWithState(
+            StateSpec.<String, Integer, Boolean, Double>function((key, value, state) -> {
               state.exists();
               state.get();
               state.isTimingOut();
@@ -898,6 +896,6 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ
                 .partitioner(new HashPartitioner(10))
                 .timeout(Durations.seconds(10)));
 
-    JavaPairDStream<String, Boolean> emittedRecords2 = stateDstream2.stateSnapshots();
+    JavaPairDStream<String, Boolean> mappedDStream = stateDstream2.stateSnapshots();
   }
 }
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala
index 604e64fc61630faef16062a254983bd7fabd251f..b47bdda2c2137f0953a1af23b83197e7e6078aaf 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/State.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala
@@ -23,14 +23,14 @@ import org.apache.spark.annotation.Experimental
 
 /**
  * :: Experimental ::
- * Abstract class for getting and updating the tracked state in the `trackStateByKey` operation of
- * a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
- * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
+ * Abstract class for getting and updating the state in mapping function used in the `mapWithState`
+ * operation of a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala)
+ * or a [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
  *
  * Scala example of using `State`:
  * {{{
- *    // A tracking function that maintains an integer state and return a String
- *    def trackStateFunc(data: Option[Int], state: State[Int]): Option[String] = {
+ *    // A mapping function that maintains an integer state and returns a String
+ *    def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
  *      // Check if state exists
  *      if (state.exists) {
  *        val existingState = state.get  // Get the existing state
@@ -52,12 +52,12 @@ import org.apache.spark.annotation.Experimental
  *
  * Java example of using `State`:
  * {{{
- *    // A tracking function that maintains an integer state and return a String
- *   Function2<Optional<Integer>, State<Integer>, Optional<String>> trackStateFunc =
- *       new Function2<Optional<Integer>, State<Integer>, Optional<String>>() {
+ *    // A mapping function that maintains an integer state and returns a String
+ *    Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction =
+ *       new Function3<String, Optional<Integer>, State<Integer>, String>() {
  *
  *         @Override
- *         public Optional<String> call(Optional<Integer> one, State<Integer> state) {
+ *         public String call(String key, Optional<Integer> value, State<Integer> state) {
  *           if (state.exists()) {
  *             int existingState = state.get(); // Get the existing state
  *             boolean shouldRemove = ...; // Decide whether to remove the state
@@ -75,6 +75,8 @@ import org.apache.spark.annotation.Experimental
  *         }
  *       };
  * }}}
+ *
+ * @tparam S Class of the state
  */
 @Experimental
 sealed abstract class State[S] {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
index bea5b9df20b530c3117c01ffee64dbf356ca6b1b..9f6f95223f6194766bfb87b29d6a96d4a374fb4f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
@@ -20,7 +20,7 @@ package org.apache.spark.streaming
 import com.google.common.base.Optional
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.{JavaPairRDD, JavaUtils}
-import org.apache.spark.api.java.function.{Function2 => JFunction2, Function4 => JFunction4}
+import org.apache.spark.api.java.function.{Function3 => JFunction3, Function4 => JFunction4}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.ClosureCleaner
 import org.apache.spark.{HashPartitioner, Partitioner}
@@ -28,7 +28,7 @@ import org.apache.spark.{HashPartitioner, Partitioner}
 /**
  * :: Experimental ::
  * Abstract class representing all the specifications of the DStream transformation
- * `trackStateByKey` operation of a
+ * `mapWithState` operation of a
  * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
  * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
  * Use the [[org.apache.spark.streaming.StateSpec StateSpec.apply()]] or
@@ -37,50 +37,63 @@ import org.apache.spark.{HashPartitioner, Partitioner}
  *
  * Example in Scala:
  * {{{
- *    def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = {
- *      ...
+ *    // A mapping function that maintains an integer state and return a String
+ *    def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
+ *      // Use state.exists(), state.get(), state.update() and state.remove()
+ *      // to manage state, and return the necessary string
  *    }
  *
- *    val spec = StateSpec.function(trackingFunction).numPartitions(10)
+ *    val spec = StateSpec.function(mappingFunction).numPartitions(10)
  *
- *    val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec)
+ *    val mapWithStateDStream = keyValueDStream.mapWithState[StateType, MappedType](spec)
  * }}}
  *
  * Example in Java:
  * {{{
- *    StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec =
- *      StateSpec.<KeyType, ValueType, StateType, EmittedDataType>function(trackingFunction)
- *                    .numPartition(10);
+ *   // A mapping function that maintains an integer state and return a string
+ *   Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction =
+ *       new Function3<String, Optional<Integer>, State<Integer>, String>() {
+ *           @Override
+ *           public Optional<String> call(Optional<Integer> value, State<Integer> state) {
+ *               // Use state.exists(), state.get(), state.update() and state.remove()
+ *               // to manage state, and return the necessary string
+ *           }
+ *       };
  *
- *    JavaTrackStateDStream<KeyType, ValueType, StateType, EmittedType> emittedRecordDStream =
- *      javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec);
+ *    JavaMapWithStateDStream<String, Integer, Integer, String> mapWithStateDStream =
+ *        keyValueDStream.mapWithState(StateSpec.function(mappingFunc));
  * }}}
+ *
+ * @tparam KeyType    Class of the state key
+ * @tparam ValueType  Class of the state value
+ * @tparam StateType  Class of the state data
+ * @tparam MappedType Class of the mapped elements
  */
 @Experimental
-sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] extends Serializable {
+sealed abstract class StateSpec[KeyType, ValueType, StateType, MappedType] extends Serializable {
 
-  /** Set the RDD containing the initial states that will be used by `trackStateByKey` */
+  /** Set the RDD containing the initial states that will be used by `mapWithState` */
   def initialState(rdd: RDD[(KeyType, StateType)]): this.type
 
-  /** Set the RDD containing the initial states that will be used by `trackStateByKey` */
+  /** Set the RDD containing the initial states that will be used by `mapWithState` */
   def initialState(javaPairRDD: JavaPairRDD[KeyType, StateType]): this.type
 
   /**
-   * Set the number of partitions by which the state RDDs generated by `trackStateByKey`
+   * Set the number of partitions by which the state RDDs generated by `mapWithState`
    * will be partitioned. Hash partitioning will be used.
    */
   def numPartitions(numPartitions: Int): this.type
 
   /**
-   * Set the partitioner by which the state RDDs generated by `trackStateByKey` will be
+   * Set the partitioner by which the state RDDs generated by `mapWithState` will be
    * be partitioned.
    */
   def partitioner(partitioner: Partitioner): this.type
 
   /**
    * Set the duration after which the state of an idle key will be removed. A key and its state is
-   * considered idle if it has not received any data for at least the given duration. The state
-   * tracking function will be called one final time on the idle states that are going to be
+   * considered idle if it has not received any data for at least the given duration. The
+   * mapping function will be called one final time on the idle states that are going to be
    * removed; [[org.apache.spark.streaming.State State.isTimingOut()]] set
    * to `true` in that call.
    */
@@ -91,115 +104,124 @@ sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] exte
 /**
  * :: Experimental ::
  * Builder object for creating instances of [[org.apache.spark.streaming.StateSpec StateSpec]]
- * that is used for specifying the parameters of the DStream transformation `trackStateByKey`
+ * that is used for specifying the parameters of the DStream transformation `mapWithState`
  * that is used for specifying the parameters of the DStream transformation
- * `trackStateByKey` operation of a
+ * `mapWithState` operation of a
  * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
  * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
  *
  * Example in Scala:
  * {{{
- *    def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = {
- *      ...
+ *    // A mapping function that maintains an integer state and return a String
+ *    def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
+ *      // Use state.exists(), state.get(), state.update() and state.remove()
+ *      // to manage state, and return the necessary string
  *    }
  *
- *    val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](
- *        StateSpec.function(trackingFunction).numPartitions(10))
+ *    val spec = StateSpec.function(mappingFunction).numPartitions(10)
+ *
+ *    val mapWithStateDStream = keyValueDStream.mapWithState[StateType, MappedType](spec)
  * }}}
  *
  * Example in Java:
  * {{{
- *    StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec =
- *      StateSpec.<KeyType, ValueType, StateType, EmittedDataType>function(trackingFunction)
- *                    .numPartition(10);
+ *   // A mapping function that maintains an integer state and return a string
+ *   Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction =
+ *       new Function3<String, Optional<Integer>, State<Integer>, String>() {
+ *           @Override
+ *           public Optional<String> call(Optional<Integer> value, State<Integer> state) {
+ *               // Use state.exists(), state.get(), state.update() and state.remove()
+ *               // to manage state, and return the necessary string
+ *           }
+ *       };
  *
- *    JavaTrackStateDStream<KeyType, ValueType, StateType, EmittedType> emittedRecordDStream =
- *      javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec);
- * }}}
+ *    JavaMapWithStateDStream<String, Integer, Integer, String> mapWithStateDStream =
+ *        keyValueDStream.mapWithState(StateSpec.function(mappingFunc));
+ *}}}
  */
 @Experimental
 object StateSpec {
   /**
    * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
-   * of the `trackStateByKey` operation on a
+   * of the `mapWithState` operation on a
    * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
    *
-   * @param trackingFunction The function applied on every data item to manage the associated state
-   *                         and generate the emitted data
+   * @param mappingFunction The function applied on every data item to manage the associated state
+   *                         and generate the mapped data
    * @tparam KeyType      Class of the keys
    * @tparam ValueType    Class of the values
    * @tparam StateType    Class of the states data
-   * @tparam EmittedType  Class of the emitted data
+   * @tparam MappedType   Class of the mapped data
    */
-  def function[KeyType, ValueType, StateType, EmittedType](
-      trackingFunction: (Time, KeyType, Option[ValueType], State[StateType]) => Option[EmittedType]
-    ): StateSpec[KeyType, ValueType, StateType, EmittedType] = {
-    ClosureCleaner.clean(trackingFunction, checkSerializable = true)
-    new StateSpecImpl(trackingFunction)
+  def function[KeyType, ValueType, StateType, MappedType](
+      mappingFunction: (Time, KeyType, Option[ValueType], State[StateType]) => Option[MappedType]
+    ): StateSpec[KeyType, ValueType, StateType, MappedType] = {
+    ClosureCleaner.clean(mappingFunction, checkSerializable = true)
+    new StateSpecImpl(mappingFunction)
   }
 
   /**
    * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
-   * of the `trackStateByKey` operation on a
+   * of the `mapWithState` operation on a
    * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
    *
-   * @param trackingFunction The function applied on every data item to manage the associated state
-   *                         and generate the emitted data
+   * @param mappingFunction The function applied on every data item to manage the associated state
+   *                         and generate the mapped data
    * @tparam ValueType    Class of the values
    * @tparam StateType    Class of the states data
-   * @tparam EmittedType  Class of the emitted data
+   * @tparam MappedType   Class of the mapped data
    */
-  def function[KeyType, ValueType, StateType, EmittedType](
-      trackingFunction: (Option[ValueType], State[StateType]) => EmittedType
-    ): StateSpec[KeyType, ValueType, StateType, EmittedType] = {
-    ClosureCleaner.clean(trackingFunction, checkSerializable = true)
+  def function[KeyType, ValueType, StateType, MappedType](
+      mappingFunction: (KeyType, Option[ValueType], State[StateType]) => MappedType
+    ): StateSpec[KeyType, ValueType, StateType, MappedType] = {
+    ClosureCleaner.clean(mappingFunction, checkSerializable = true)
     val wrappedFunction =
-      (time: Time, key: Any, value: Option[ValueType], state: State[StateType]) => {
-        Some(trackingFunction(value, state))
+      (time: Time, key: KeyType, value: Option[ValueType], state: State[StateType]) => {
+        Some(mappingFunction(key, value, state))
       }
     new StateSpecImpl(wrappedFunction)
   }
 
   /**
    * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all
-   * the specifications of the `trackStateByKey` operation on a
+   * the specifications of the `mapWithState` operation on a
    * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]].
    *
-   * @param javaTrackingFunction The function applied on every data item to manage the associated
-   *                             state and generate the emitted data
+   * @param mappingFunction The function applied on every data item to manage the associated
+   *                        state and generate the mapped data
    * @tparam KeyType      Class of the keys
    * @tparam ValueType    Class of the values
    * @tparam StateType    Class of the states data
-   * @tparam EmittedType  Class of the emitted data
+   * @tparam MappedType   Class of the mapped data
    */
-  def function[KeyType, ValueType, StateType, EmittedType](javaTrackingFunction:
-      JFunction4[Time, KeyType, Optional[ValueType], State[StateType], Optional[EmittedType]]):
-    StateSpec[KeyType, ValueType, StateType, EmittedType] = {
-    val trackingFunc = (time: Time, k: KeyType, v: Option[ValueType], s: State[StateType]) => {
-      val t = javaTrackingFunction.call(time, k, JavaUtils.optionToOptional(v), s)
+  def function[KeyType, ValueType, StateType, MappedType](mappingFunction:
+      JFunction4[Time, KeyType, Optional[ValueType], State[StateType], Optional[MappedType]]):
+    StateSpec[KeyType, ValueType, StateType, MappedType] = {
+    val wrappedFunc = (time: Time, k: KeyType, v: Option[ValueType], s: State[StateType]) => {
+      val t = mappingFunction.call(time, k, JavaUtils.optionToOptional(v), s)
       Option(t.orNull)
     }
-    StateSpec.function(trackingFunc)
+    StateSpec.function(wrappedFunc)
   }
 
   /**
    * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
-   * of the `trackStateByKey` operation on a
+   * of the `mapWithState` operation on a
    * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]].
    *
-   * @param javaTrackingFunction The function applied on every data item to manage the associated
-   *                             state and generate the emitted data
+   * @param mappingFunction The function applied on every data item to manage the associated
+   *                        state and generate the mapped data
    * @tparam ValueType    Class of the values
    * @tparam StateType    Class of the states data
-   * @tparam EmittedType  Class of the emitted data
+   * @tparam MappedType   Class of the mapped data
    */
-  def function[KeyType, ValueType, StateType, EmittedType](
-      javaTrackingFunction: JFunction2[Optional[ValueType], State[StateType], EmittedType]):
-    StateSpec[KeyType, ValueType, StateType, EmittedType] = {
-    val trackingFunc = (v: Option[ValueType], s: State[StateType]) => {
-      javaTrackingFunction.call(Optional.fromNullable(v.get), s)
+  def function[KeyType, ValueType, StateType, MappedType](
+      mappingFunction: JFunction3[KeyType, Optional[ValueType], State[StateType], MappedType]):
+    StateSpec[KeyType, ValueType, StateType, MappedType] = {
+    val wrappedFunc = (k: KeyType, v: Option[ValueType], s: State[StateType]) => {
+      mappingFunction.call(k, Optional.fromNullable(v.get), s)
     }
-    StateSpec.function(trackingFunc)
+    StateSpec.function(wrappedFunc)
   }
 }
 
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaMapWithStateDStream.scala
similarity index 66%
rename from streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala
rename to streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaMapWithStateDStream.scala
index f459930d0660b782bde4062a0a69f44b82f665ac..16c0d6fff8229699d99853ac9b7cc07cb15c1f7e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaMapWithStateDStream.scala
@@ -19,23 +19,23 @@ package org.apache.spark.streaming.api.java
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaSparkContext
-import org.apache.spark.streaming.dstream.TrackStateDStream
+import org.apache.spark.streaming.dstream.MapWithStateDStream
 
 /**
  * :: Experimental ::
- * [[JavaDStream]] representing the stream of records emitted by the tracking function in the
- * `trackStateByKey` operation on a [[JavaPairDStream]]. Additionally, it also gives access to the
+ * DStream representing the stream of data generated by `mapWithState` operation on a
+ * [[JavaPairDStream]]. Additionally, it also gives access to the
  * stream of state snapshots, that is, the state data of all keys after a batch has updated them.
  *
- * @tparam KeyType Class of the state key
- * @tparam ValueType Class of the state value
- * @tparam StateType Class of the state
- * @tparam EmittedType Class of the emitted records
+ * @tparam KeyType Class of the keys
+ * @tparam ValueType Class of the values
+ * @tparam StateType Class of the state data
+ * @tparam MappedType Class of the mapped data
  */
 @Experimental
-class JavaTrackStateDStream[KeyType, ValueType, StateType, EmittedType](
-    dstream: TrackStateDStream[KeyType, ValueType, StateType, EmittedType])
-  extends JavaDStream[EmittedType](dstream)(JavaSparkContext.fakeClassTag) {
+class JavaMapWithStateDStream[KeyType, ValueType, StateType, MappedType] private[streaming](
+    dstream: MapWithStateDStream[KeyType, ValueType, StateType, MappedType])
+  extends JavaDStream[MappedType](dstream)(JavaSparkContext.fakeClassTag) {
 
   def stateSnapshots(): JavaPairDStream[KeyType, StateType] =
     new JavaPairDStream(dstream.stateSnapshots())(
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
index 70e32b383e4580de0d91e7808c581230ec788af3..42ddd63f0f06c30963b512415c35298b180632f3 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
@@ -430,42 +430,36 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
 
   /**
    * :: Experimental ::
-   * Return a new [[JavaDStream]] of data generated by combining the key-value data in `this` stream
-   * with a continuously updated per-key state. The user-provided state tracking function is
-   * applied on each keyed data item along with its corresponding state. The function can choose to
-   * update/remove the state and return a transformed data, which forms the
-   * [[JavaTrackStateDStream]].
+   * Return a [[JavaMapWithStateDStream]] by applying a function to every key-value element of
+   * `this` stream, while maintaining some state data for each unique key. The mapping function
+   * and other specification (e.g. partitioners, timeouts, initial state data, etc.) of this
+   * transformation can be specified using [[StateSpec]] class. The state data is accessible in
+   * as a parameter of type [[State]] in the mapping function.
    *
-   * The specifications of this transformation is made through the
-   * [[org.apache.spark.streaming.StateSpec StateSpec]] class. Besides the tracking function, there
-   * are a number of optional parameters - initial state data, number of partitions, timeouts, etc.
-   * See the [[org.apache.spark.streaming.StateSpec StateSpec]] for more details.
-   *
-   * Example of using `trackStateByKey`:
+   * Example of using `mapWithState`:
    * {{{
-   *   // A tracking function that maintains an integer state and return a String
-   *   Function2<Optional<Integer>, State<Integer>, Optional<String>> trackStateFunc =
-   *       new Function2<Optional<Integer>, State<Integer>, Optional<String>>() {
-   *
-   *         @Override
-   *         public Optional<String> call(Optional<Integer> one, State<Integer> state) {
-   *           // Check if state exists, accordingly update/remove state and return transformed data
-   *         }
+   *   // A mapping function that maintains an integer state and return a string
+   *   Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction =
+   *       new Function3<String, Optional<Integer>, State<Integer>, String>() {
+   *           @Override
+   *           public Optional<String> call(Optional<Integer> value, State<Integer> state) {
+   *               // Use state.exists(), state.get(), state.update() and state.remove()
+   *               // to manage state, and return the necessary string
+   *           }
    *       };
    *
-   *    JavaTrackStateDStream<Integer, Integer, Integer, String> trackStateDStream =
-   *        keyValueDStream.<Integer, String>trackStateByKey(
-   *                 StateSpec.function(trackStateFunc).numPartitions(10));
-   * }}}
+   *    JavaMapWithStateDStream<String, Integer, Integer, String> mapWithStateDStream =
+   *        keyValueDStream.mapWithState(StateSpec.function(mappingFunc));
+   *}}}
    *
    * @param spec          Specification of this transformation
-   * @tparam StateType    Class type of the state
-   * @tparam EmittedType  Class type of the tranformed data return by the tracking function
+   * @tparam StateType    Class type of the state data
+   * @tparam MappedType   Class type of the mapped data
    */
   @Experimental
-  def trackStateByKey[StateType, EmittedType](spec: StateSpec[K, V, StateType, EmittedType]):
-    JavaTrackStateDStream[K, V, StateType, EmittedType] = {
-    new JavaTrackStateDStream(dstream.trackStateByKey(spec)(
+  def mapWithState[StateType, MappedType](spec: StateSpec[K, V, StateType, MappedType]):
+    JavaMapWithStateDStream[K, V, StateType, MappedType] = {
+    new JavaMapWithStateDStream(dstream.mapWithState(spec)(
       JavaSparkContext.fakeClassTag,
       JavaSparkContext.fakeClassTag))
   }
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala
similarity index 72%
rename from streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
rename to streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala
index ea6213420e7abfa0ff7f1a48fe25b17d2dd91868..706465d4e25d76307e71aa1790a90e8c073d61e6 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala
@@ -24,53 +24,52 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.{EmptyRDD, RDD}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming._
-import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord}
-import org.apache.spark.streaming.dstream.InternalTrackStateDStream._
+import org.apache.spark.streaming.rdd.{MapWithStateRDD, MapWithStateRDDRecord}
+import org.apache.spark.streaming.dstream.InternalMapWithStateDStream._
 
 /**
  * :: Experimental ::
- * DStream representing the stream of records emitted by the tracking function in the
- * `trackStateByKey` operation on a
+ * DStream representing the stream of data generated by `mapWithState` operation on a
  * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
  * Additionally, it also gives access to the stream of state snapshots, that is, the state data of
  * all keys after a batch has updated them.
  *
- * @tparam KeyType Class of the state key
- * @tparam ValueType Class of the state value
+ * @tparam KeyType Class of the key
+ * @tparam ValueType Class of the value
  * @tparam StateType Class of the state data
- * @tparam EmittedType Class of the emitted records
+ * @tparam MappedType Class of the mapped data
  */
 @Experimental
-sealed abstract class TrackStateDStream[KeyType, ValueType, StateType, EmittedType: ClassTag](
-    ssc: StreamingContext) extends DStream[EmittedType](ssc) {
+sealed abstract class MapWithStateDStream[KeyType, ValueType, StateType, MappedType: ClassTag](
+    ssc: StreamingContext) extends DStream[MappedType](ssc) {
 
   /** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */
   def stateSnapshots(): DStream[(KeyType, StateType)]
 }
 
-/** Internal implementation of the [[TrackStateDStream]] */
-private[streaming] class TrackStateDStreamImpl[
-    KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, EmittedType: ClassTag](
+/** Internal implementation of the [[MapWithStateDStream]] */
+private[streaming] class MapWithStateDStreamImpl[
+    KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, MappedType: ClassTag](
     dataStream: DStream[(KeyType, ValueType)],
-    spec: StateSpecImpl[KeyType, ValueType, StateType, EmittedType])
-  extends TrackStateDStream[KeyType, ValueType, StateType, EmittedType](dataStream.context) {
+    spec: StateSpecImpl[KeyType, ValueType, StateType, MappedType])
+  extends MapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream.context) {
 
   private val internalStream =
-    new InternalTrackStateDStream[KeyType, ValueType, StateType, EmittedType](dataStream, spec)
+    new InternalMapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream, spec)
 
   override def slideDuration: Duration = internalStream.slideDuration
 
   override def dependencies: List[DStream[_]] = List(internalStream)
 
-  override def compute(validTime: Time): Option[RDD[EmittedType]] = {
-    internalStream.getOrCompute(validTime).map { _.flatMap[EmittedType] { _.emittedRecords } }
+  override def compute(validTime: Time): Option[RDD[MappedType]] = {
+    internalStream.getOrCompute(validTime).map { _.flatMap[MappedType] { _.mappedData } }
   }
 
   /**
    * Forward the checkpoint interval to the internal DStream that computes the state maps. This
    * to make sure that this DStream does not get checkpointed, only the internal stream.
    */
-  override def checkpoint(checkpointInterval: Duration): DStream[EmittedType] = {
+  override def checkpoint(checkpointInterval: Duration): DStream[MappedType] = {
     internalStream.checkpoint(checkpointInterval)
     this
   }
@@ -87,32 +86,32 @@ private[streaming] class TrackStateDStreamImpl[
 
   def stateClass: Class[_] = implicitly[ClassTag[StateType]].runtimeClass
 
-  def emittedClass: Class[_] = implicitly[ClassTag[EmittedType]].runtimeClass
+  def mappedClass: Class[_] = implicitly[ClassTag[MappedType]].runtimeClass
 }
 
 /**
  * A DStream that allows per-key state to be maintains, and arbitrary records to be generated
- * based on updates to the state. This is the main DStream that implements the `trackStateByKey`
+ * based on updates to the state. This is the main DStream that implements the `mapWithState`
  * operation on DStreams.
  *
  * @param parent Parent (key, value) stream that is the source
- * @param spec Specifications of the trackStateByKey operation
+ * @param spec Specifications of the mapWithState operation
  * @tparam K   Key type
  * @tparam V   Value type
  * @tparam S   Type of the state maintained
- * @tparam E   Type of the emitted data
+ * @tparam E   Type of the mapped data
  */
 private[streaming]
-class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
+class InternalMapWithStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
     parent: DStream[(K, V)], spec: StateSpecImpl[K, V, S, E])
-  extends DStream[TrackStateRDDRecord[K, S, E]](parent.context) {
+  extends DStream[MapWithStateRDDRecord[K, S, E]](parent.context) {
 
   persist(StorageLevel.MEMORY_ONLY)
 
   private val partitioner = spec.getPartitioner().getOrElse(
     new HashPartitioner(ssc.sc.defaultParallelism))
 
-  private val trackingFunction = spec.getFunction()
+  private val mappingFunction = spec.getFunction()
 
   override def slideDuration: Duration = parent.slideDuration
 
@@ -130,7 +129,7 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT
   }
 
   /** Method that generates a RDD for the given time */
-  override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = {
+  override def compute(validTime: Time): Option[RDD[MapWithStateRDDRecord[K, S, E]]] = {
     // Get the previous state or create a new empty state RDD
     val prevStateRDD = getOrCompute(validTime - slideDuration) match {
       case Some(rdd) =>
@@ -138,13 +137,13 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT
           // If the RDD is not partitioned the right way, let us repartition it using the
           // partition index as the key. This is to ensure that state RDD is always partitioned
           // before creating another state RDD using it
-          TrackStateRDD.createFromRDD[K, V, S, E](
+          MapWithStateRDD.createFromRDD[K, V, S, E](
             rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime)
         } else {
           rdd
         }
       case None =>
-        TrackStateRDD.createFromPairRDD[K, V, S, E](
+        MapWithStateRDD.createFromPairRDD[K, V, S, E](
           spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
           partitioner,
           validTime
@@ -161,11 +160,11 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT
     val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
       (validTime - interval).milliseconds
     }
-    Some(new TrackStateRDD(
-      prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime))
+    Some(new MapWithStateRDD(
+      prevStateRDD, partitionedDataRDD, mappingFunction, validTime, timeoutThresholdTime))
   }
 }
 
-private[streaming] object InternalTrackStateDStream {
+private[streaming] object InternalMapWithStateDStream {
   private val DEFAULT_CHECKPOINT_DURATION_MULTIPLIER = 10
 }
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
index 2762309134eb191d732f1c1bdc8544f5da702623..a64a1fe93f40dcb2b292e3b2227f1bbaffbd506b 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
@@ -352,39 +352,36 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)])
 
   /**
    * :: Experimental ::
-   * Return a new DStream of data generated by combining the key-value data in `this` stream
-   * with a continuously updated per-key state. The user-provided state tracking function is
-   * applied on each keyed data item along with its corresponding state. The function can choose to
-   * update/remove the state and return a transformed data, which forms the
-   * [[org.apache.spark.streaming.dstream.TrackStateDStream]].
+   * Return a [[MapWithStateDStream]] by applying a function to every key-value element of
+   * `this` stream, while maintaining some state data for each unique key. The mapping function
+   * and other specification (e.g. partitioners, timeouts, initial state data, etc.) of this
+   * transformation can be specified using [[StateSpec]] class. The state data is accessible in
+   * as a parameter of type [[State]] in the mapping function.
    *
-   * The specifications of this transformation is made through the
-   * [[org.apache.spark.streaming.StateSpec StateSpec]] class. Besides the tracking function, there
-   * are a number of optional parameters - initial state data, number of partitions, timeouts, etc.
-   * See the [[org.apache.spark.streaming.StateSpec StateSpec spec docs]] for more details.
-   *
-   * Example of using `trackStateByKey`:
+   * Example of using `mapWithState`:
    * {{{
-   *    def trackingFunction(data: Option[Int], wrappedState: State[Int]): String = {
-   *      // Check if state exists, accordingly update/remove state and return transformed data
+   *    // A mapping function that maintains an integer state and return a String
+   *    def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
+   *      // Use state.exists(), state.get(), state.update() and state.remove()
+   *      // to manage state, and return the necessary string
    *    }
    *
-   *    val spec = StateSpec.function(trackingFunction).numPartitions(10)
+   *    val spec = StateSpec.function(mappingFunction).numPartitions(10)
    *
-   *    val trackStateDStream = keyValueDStream.trackStateByKey[Int, String](spec)
+   *    val mapWithStateDStream = keyValueDStream.mapWithState[StateType, MappedType](spec)
    * }}}
    *
    * @param spec          Specification of this transformation
-   * @tparam StateType    Class type of the state
-   * @tparam EmittedType  Class type of the tranformed data return by the tracking function
+   * @tparam StateType    Class type of the state data
+   * @tparam MappedType   Class type of the mapped data
    */
   @Experimental
-  def trackStateByKey[StateType: ClassTag, EmittedType: ClassTag](
-      spec: StateSpec[K, V, StateType, EmittedType]
-    ): TrackStateDStream[K, V, StateType, EmittedType] = {
-    new TrackStateDStreamImpl[K, V, StateType, EmittedType](
+  def mapWithState[StateType: ClassTag, MappedType: ClassTag](
+      spec: StateSpec[K, V, StateType, MappedType]
+    ): MapWithStateDStream[K, V, StateType, MappedType] = {
+    new MapWithStateDStreamImpl[K, V, StateType, MappedType](
       self,
-      spec.asInstanceOf[StateSpecImpl[K, V, StateType, EmittedType]]
+      spec.asInstanceOf[StateSpecImpl[K, V, StateType, MappedType]]
     )
   }
 
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala
similarity index 64%
rename from streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
rename to streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala
index 30aafcf1460e3acd305c0eddade52e199f29fcbc..ed95171f73ee117be341d7df6b0978e74c67b783 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala
@@ -29,60 +29,60 @@ import org.apache.spark.util.Utils
 import org.apache.spark._
 
 /**
- * Record storing the keyed-state [[TrackStateRDD]]. Each record contains a [[StateMap]] and a
- * sequence of records returned by the tracking function of `trackStateByKey`.
+ * Record storing the keyed-state [[MapWithStateRDD]]. Each record contains a [[StateMap]] and a
+ * sequence of records returned by the mapping function of `mapWithState`.
  */
-private[streaming] case class TrackStateRDDRecord[K, S, E](
-    var stateMap: StateMap[K, S], var emittedRecords: Seq[E])
+private[streaming] case class MapWithStateRDDRecord[K, S, E](
+    var stateMap: StateMap[K, S], var mappedData: Seq[E])
 
-private[streaming] object TrackStateRDDRecord {
+private[streaming] object MapWithStateRDDRecord {
   def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
-    prevRecord: Option[TrackStateRDDRecord[K, S, E]],
+    prevRecord: Option[MapWithStateRDDRecord[K, S, E]],
     dataIterator: Iterator[(K, V)],
-    updateFunction: (Time, K, Option[V], State[S]) => Option[E],
+    mappingFunction: (Time, K, Option[V], State[S]) => Option[E],
     batchTime: Time,
     timeoutThresholdTime: Option[Long],
     removeTimedoutData: Boolean
-  ): TrackStateRDDRecord[K, S, E] = {
+  ): MapWithStateRDDRecord[K, S, E] = {
     // Create a new state map by cloning the previous one (if it exists) or by creating an empty one
     val newStateMap = prevRecord.map { _.stateMap.copy() }. getOrElse { new EmptyStateMap[K, S]() }
 
-    val emittedRecords = new ArrayBuffer[E]
+    val mappedData = new ArrayBuffer[E]
     val wrappedState = new StateImpl[S]()
 
-    // Call the tracking function on each record in the data iterator, and accordingly
-    // update the states touched, and collect the data returned by the tracking function
+    // Call the mapping function on each record in the data iterator, and accordingly
+    // update the states touched, and collect the data returned by the mapping function
     dataIterator.foreach { case (key, value) =>
       wrappedState.wrap(newStateMap.get(key))
-      val emittedRecord = updateFunction(batchTime, key, Some(value), wrappedState)
+      val returned = mappingFunction(batchTime, key, Some(value), wrappedState)
       if (wrappedState.isRemoved) {
         newStateMap.remove(key)
       } else if (wrappedState.isUpdated || timeoutThresholdTime.isDefined) {
         newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
       }
-      emittedRecords ++= emittedRecord
+      mappedData ++= returned
     }
 
-    // Get the timed out state records, call the tracking function on each and collect the
+    // Get the timed out state records, call the mapping function on each and collect the
     // data returned
     if (removeTimedoutData && timeoutThresholdTime.isDefined) {
       newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
         wrappedState.wrapTiminoutState(state)
-        val emittedRecord = updateFunction(batchTime, key, None, wrappedState)
-        emittedRecords ++= emittedRecord
+        val returned = mappingFunction(batchTime, key, None, wrappedState)
+        mappedData ++= returned
         newStateMap.remove(key)
       }
     }
 
-    TrackStateRDDRecord(newStateMap, emittedRecords)
+    MapWithStateRDDRecord(newStateMap, mappedData)
   }
 }
 
 /**
- * Partition of the [[TrackStateRDD]], which depends on corresponding partitions of prev state
+ * Partition of the [[MapWithStateRDD]], which depends on corresponding partitions of prev state
  * RDD, and a partitioned keyed-data RDD
  */
-private[streaming] class TrackStateRDDPartition(
+private[streaming] class MapWithStateRDDPartition(
     idx: Int,
     @transient private var prevStateRDD: RDD[_],
     @transient private var partitionedDataRDD: RDD[_]) extends Partition {
@@ -104,27 +104,28 @@ private[streaming] class TrackStateRDDPartition(
 
 
 /**
- * RDD storing the keyed-state of `trackStateByKey` and corresponding emitted records.
- * Each partition of this RDD has a single record of type [[TrackStateRDDRecord]]. This contains a
- * [[StateMap]] (containing the keyed-states) and the sequence of records returned by the tracking
- * function of  `trackStateByKey`.
- * @param prevStateRDD The previous TrackStateRDD on whose StateMap data `this` RDD will be created
+ * RDD storing the keyed states of `mapWithState` operation and corresponding mapped data.
+ * Each partition of this RDD has a single record of type [[MapWithStateRDDRecord]]. This contains a
+ * [[StateMap]] (containing the keyed-states) and the sequence of records returned by the mapping
+ * function of  `mapWithState`.
+ * @param prevStateRDD The previous MapWithStateRDD on whose StateMap data `this` RDD
+  *                    will be created
  * @param partitionedDataRDD The partitioned data RDD which is used update the previous StateMaps
  *                           in the `prevStateRDD` to create `this` RDD
- * @param trackingFunction The function that will be used to update state and return new data
+ * @param mappingFunction  The function that will be used to update state and return new data
  * @param batchTime        The time of the batch to which this RDD belongs to. Use to update
  * @param timeoutThresholdTime The time to indicate which keys are timeout
  */
-private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
-    private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, E]],
+private[streaming] class MapWithStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
+    private var prevStateRDD: RDD[MapWithStateRDDRecord[K, S, E]],
     private var partitionedDataRDD: RDD[(K, V)],
-    trackingFunction: (Time, K, Option[V], State[S]) => Option[E],
+    mappingFunction: (Time, K, Option[V], State[S]) => Option[E],
     batchTime: Time,
     timeoutThresholdTime: Option[Long]
-  ) extends RDD[TrackStateRDDRecord[K, S, E]](
+  ) extends RDD[MapWithStateRDDRecord[K, S, E]](
     partitionedDataRDD.sparkContext,
     List(
-      new OneToOneDependency[TrackStateRDDRecord[K, S, E]](prevStateRDD),
+      new OneToOneDependency[MapWithStateRDDRecord[K, S, E]](prevStateRDD),
       new OneToOneDependency(partitionedDataRDD))
   ) {
 
@@ -141,19 +142,19 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E:
   }
 
   override def compute(
-      partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, E]] = {
+      partition: Partition, context: TaskContext): Iterator[MapWithStateRDDRecord[K, S, E]] = {
 
-    val stateRDDPartition = partition.asInstanceOf[TrackStateRDDPartition]
+    val stateRDDPartition = partition.asInstanceOf[MapWithStateRDDPartition]
     val prevStateRDDIterator = prevStateRDD.iterator(
       stateRDDPartition.previousSessionRDDPartition, context)
     val dataIterator = partitionedDataRDD.iterator(
       stateRDDPartition.partitionedDataRDDPartition, context)
 
     val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None
-    val newRecord = TrackStateRDDRecord.updateRecordWithData(
+    val newRecord = MapWithStateRDDRecord.updateRecordWithData(
       prevRecord,
       dataIterator,
-      trackingFunction,
+      mappingFunction,
       batchTime,
       timeoutThresholdTime,
       removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled
@@ -163,7 +164,7 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E:
 
   override protected def getPartitions: Array[Partition] = {
     Array.tabulate(prevStateRDD.partitions.length) { i =>
-      new TrackStateRDDPartition(i, prevStateRDD, partitionedDataRDD)}
+      new MapWithStateRDDPartition(i, prevStateRDD, partitionedDataRDD)}
   }
 
   override def clearDependencies(): Unit = {
@@ -177,52 +178,46 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E:
   }
 }
 
-private[streaming] object TrackStateRDD {
+private[streaming] object MapWithStateRDD {
 
   def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
       pairRDD: RDD[(K, S)],
       partitioner: Partitioner,
-      updateTime: Time): TrackStateRDD[K, V, S, E] = {
+      updateTime: Time): MapWithStateRDD[K, V, S, E] = {
 
-    val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator =>
+    val stateRDD = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator =>
       val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
       iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) }
-      Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E]))
+      Iterator(MapWithStateRDDRecord(stateMap, Seq.empty[E]))
     }, preservesPartitioning = true)
 
     val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)
 
     val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None
 
-    new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
+    new MapWithStateRDD[K, V, S, E](
+      stateRDD, emptyDataRDD, noOpFunc, updateTime, None)
   }
 
   def createFromRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
       rdd: RDD[(K, S, Long)],
       partitioner: Partitioner,
-      updateTime: Time): TrackStateRDD[K, V, S, E] = {
+      updateTime: Time): MapWithStateRDD[K, V, S, E] = {
 
     val pairRDD = rdd.map { x => (x._1, (x._2, x._3)) }
-    val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions({ iterator =>
+    val stateRDD = pairRDD.partitionBy(partitioner).mapPartitions({ iterator =>
       val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
       iterator.foreach { case (key, (state, updateTime)) =>
         stateMap.put(key, state, updateTime)
       }
-      Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E]))
+      Iterator(MapWithStateRDDRecord(stateMap, Seq.empty[E]))
     }, preservesPartitioning = true)
 
     val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)
 
     val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None
 
-    new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
-  }
-}
-
-private[streaming] class EmittedRecordsRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
-    parent: TrackStateRDD[K, V, S, T]) extends RDD[T](parent) {
-  override protected def getPartitions: Array[Partition] = parent.partitions
-  override def compute(partition: Partition, context: TaskContext): Iterator[T] = {
-    parent.compute(partition, context).flatMap { _.emittedRecords }
+    new MapWithStateRDD[K, V, S, E](
+      stateRDD, emptyDataRDD, noOpFunc, updateTime, None)
   }
 }
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java
similarity index 80%
rename from streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java
rename to streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java
index 89d0bb7b617e46b31f2f1a012dea5ea391cb2845..bc4bc2eb42231d5bcd5eb4d28b656f1189829c3b 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java
@@ -37,12 +37,12 @@ import org.junit.Test;
 
 import org.apache.spark.HashPartitioner;
 import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.Function3;
 import org.apache.spark.api.java.function.Function4;
 import org.apache.spark.streaming.api.java.JavaPairDStream;
-import org.apache.spark.streaming.api.java.JavaTrackStateDStream;
+import org.apache.spark.streaming.api.java.JavaMapWithStateDStream;
 
-public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implements Serializable {
+public class JavaMapWithStateSuite extends LocalJavaStreamingContext implements Serializable {
 
   /**
    * This test is only for testing the APIs. It's not necessary to run it.
@@ -52,7 +52,7 @@ public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implemen
     JavaPairDStream<String, Integer> wordsDstream = null;
 
     final Function4<Time, String, Optional<Integer>, State<Boolean>, Optional<Double>>
-        trackStateFunc =
+        mappingFunc =
         new Function4<Time, String, Optional<Integer>, State<Boolean>, Optional<Double>>() {
 
           @Override
@@ -68,21 +68,21 @@ public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implemen
           }
         };
 
-    JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream =
-        wordsDstream.trackStateByKey(
-            StateSpec.function(trackStateFunc)
+    JavaMapWithStateDStream<String, Integer, Boolean, Double> stateDstream =
+        wordsDstream.mapWithState(
+            StateSpec.function(mappingFunc)
                 .initialState(initialRDD)
                 .numPartitions(10)
                 .partitioner(new HashPartitioner(10))
                 .timeout(Durations.seconds(10)));
 
-    JavaPairDStream<String, Boolean> emittedRecords = stateDstream.stateSnapshots();
+    JavaPairDStream<String, Boolean> stateSnapshots = stateDstream.stateSnapshots();
 
-    final Function2<Optional<Integer>, State<Boolean>, Double> trackStateFunc2 =
-        new Function2<Optional<Integer>, State<Boolean>, Double>() {
+    final Function3<String, Optional<Integer>, State<Boolean>, Double> mappingFunc2 =
+        new Function3<String, Optional<Integer>, State<Boolean>, Double>() {
 
           @Override
-          public Double call(Optional<Integer> one, State<Boolean> state) {
+          public Double call(String key, Optional<Integer> one, State<Boolean> state) {
             // Use all State's methods here
             state.exists();
             state.get();
@@ -93,15 +93,15 @@ public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implemen
           }
         };
 
-    JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream2 =
-        wordsDstream.trackStateByKey(
-            StateSpec.<String, Integer, Boolean, Double>function(trackStateFunc2)
+    JavaMapWithStateDStream<String, Integer, Boolean, Double> stateDstream2 =
+        wordsDstream.mapWithState(
+            StateSpec.<String, Integer, Boolean, Double>function(mappingFunc2)
                 .initialState(initialRDD)
                 .numPartitions(10)
                 .partitioner(new HashPartitioner(10))
                 .timeout(Durations.seconds(10)));
 
-    JavaPairDStream<String, Boolean> emittedRecords2 = stateDstream2.stateSnapshots();
+    JavaPairDStream<String, Boolean> stateSnapshots2 = stateDstream2.stateSnapshots();
   }
 
   @Test
@@ -148,11 +148,11 @@ public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implemen
             new Tuple2<String, Integer>("c", 1))
     );
 
-    Function2<Optional<Integer>, State<Integer>, Integer> trackStateFunc =
-        new Function2<Optional<Integer>, State<Integer>, Integer>() {
+    Function3<String, Optional<Integer>, State<Integer>, Integer> mappingFunc =
+        new Function3<String, Optional<Integer>, State<Integer>, Integer>() {
 
           @Override
-          public Integer call(Optional<Integer> value, State<Integer> state) throws Exception {
+          public Integer call(String key, Optional<Integer> value, State<Integer> state) throws Exception {
             int sum = value.or(0) + (state.exists() ? state.get() : 0);
             state.update(sum);
             return sum;
@@ -160,29 +160,29 @@ public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implemen
         };
     testOperation(
         inputData,
-        StateSpec.<String, Integer, Integer, Integer>function(trackStateFunc),
+        StateSpec.<String, Integer, Integer, Integer>function(mappingFunc),
         outputData,
         stateData);
   }
 
   private <K, S, T> void testOperation(
       List<List<K>> input,
-      StateSpec<K, Integer, S, T> trackStateSpec,
+      StateSpec<K, Integer, S, T> mapWithStateSpec,
       List<Set<T>> expectedOutputs,
       List<Set<Tuple2<K, S>>> expectedStateSnapshots) {
     int numBatches = expectedOutputs.size();
     JavaDStream<K> inputStream = JavaTestUtils.attachTestInputStream(ssc, input, 2);
-    JavaTrackStateDStream<K, Integer, S, T> trackeStateStream =
+    JavaMapWithStateDStream<K, Integer, S, T> mapWithStateDStream =
         JavaPairDStream.fromJavaDStream(inputStream.map(new Function<K, Tuple2<K, Integer>>() {
           @Override
           public Tuple2<K, Integer> call(K x) throws Exception {
             return new Tuple2<K, Integer>(x, 1);
           }
-        })).trackStateByKey(trackStateSpec);
+        })).mapWithState(mapWithStateSpec);
 
     final List<Set<T>> collectedOutputs =
         Collections.synchronizedList(Lists.<Set<T>>newArrayList());
-    trackeStateStream.foreachRDD(new Function<JavaRDD<T>, Void>() {
+    mapWithStateDStream.foreachRDD(new Function<JavaRDD<T>, Void>() {
       @Override
       public Void call(JavaRDD<T> rdd) throws Exception {
         collectedOutputs.add(Sets.newHashSet(rdd.collect()));
@@ -191,7 +191,7 @@ public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implemen
     });
     final List<Set<Tuple2<K, S>>> collectedStateSnapshots =
         Collections.synchronizedList(Lists.<Set<Tuple2<K, S>>>newArrayList());
-    trackeStateStream.stateSnapshots().foreachRDD(new Function<JavaPairRDD<K, S>, Void>() {
+    mapWithStateDStream.stateSnapshots().foreachRDD(new Function<JavaPairRDD<K, S>, Void>() {
       @Override
       public Void call(JavaPairRDD<K, S> rdd) throws Exception {
         collectedStateSnapshots.add(Sets.newHashSet(rdd.collect()));
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala
similarity index 77%
rename from streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala
rename to streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala
index 1fc320d31b18b186a3e4d2e21aa1a4732fe60d82..4b08085e09b1f05b487b1fbd8a56b789150f25ca 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala
@@ -25,11 +25,11 @@ import scala.reflect.ClassTag
 import org.scalatest.PrivateMethodTester._
 import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
 
-import org.apache.spark.streaming.dstream.{DStream, InternalTrackStateDStream, TrackStateDStream, TrackStateDStreamImpl}
+import org.apache.spark.streaming.dstream.{DStream, InternalMapWithStateDStream, MapWithStateDStream, MapWithStateDStreamImpl}
 import org.apache.spark.util.{ManualClock, Utils}
 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
 
-class TrackStateByKeySuite extends SparkFunSuite
+class MapWithStateSuite extends SparkFunSuite
   with DStreamCheckpointTester with BeforeAndAfterAll with BeforeAndAfter {
 
   private var sc: SparkContext = null
@@ -49,7 +49,7 @@ class TrackStateByKeySuite extends SparkFunSuite
   }
 
   override def beforeAll(): Unit = {
-    val conf = new SparkConf().setMaster("local").setAppName("TrackStateByKeySuite")
+    val conf = new SparkConf().setMaster("local").setAppName("MapWithStateSuite")
     conf.set("spark.streaming.clock", classOf[ManualClock].getName())
     sc = new SparkContext(conf)
   }
@@ -129,7 +129,7 @@ class TrackStateByKeySuite extends SparkFunSuite
     testState(Some(3), shouldBeTimingOut = true)
   }
 
-  test("trackStateByKey - basic operations with simple API") {
+  test("mapWithState - basic operations with simple API") {
     val inputData =
       Seq(
         Seq(),
@@ -164,17 +164,17 @@ class TrackStateByKeySuite extends SparkFunSuite
       )
 
     // state maintains running count, and updated count is returned
-    val trackStateFunc = (value: Option[Int], state: State[Int]) => {
+    val mappingFunc = (key: String, value: Option[Int], state: State[Int]) => {
       val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
       state.update(sum)
       sum
     }
 
     testOperation[String, Int, Int](
-      inputData, StateSpec.function(trackStateFunc), outputData, stateData)
+      inputData, StateSpec.function(mappingFunc), outputData, stateData)
   }
 
-  test("trackStateByKey - basic operations with advanced API") {
+  test("mapWithState - basic operations with advanced API") {
     val inputData =
       Seq(
         Seq(),
@@ -209,65 +209,65 @@ class TrackStateByKeySuite extends SparkFunSuite
       )
 
     // state maintains running count, key string doubled and returned
-    val trackStateFunc = (batchTime: Time, key: String, value: Option[Int], state: State[Int]) => {
+    val mappingFunc = (batchTime: Time, key: String, value: Option[Int], state: State[Int]) => {
       val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
       state.update(sum)
       Some(key * 2)
     }
 
-    testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData)
+    testOperation(inputData, StateSpec.function(mappingFunc), outputData, stateData)
   }
 
-  test("trackStateByKey - type inferencing and class tags") {
+  test("mapWithState - type inferencing and class tags") {
 
-    // Simple track state function with value as Int, state as Double and emitted type as Double
-    val simpleFunc = (value: Option[Int], state: State[Double]) => {
+    // Simple track state function with value as Int, state as Double and mapped type as Double
+    val simpleFunc = (key: String, value: Option[Int], state: State[Double]) => {
       0L
     }
 
     // Advanced track state function with key as String, value as Int, state as Double and
-    // emitted type as Double
+    // mapped type as Double
     val advancedFunc = (time: Time, key: String, value: Option[Int], state: State[Double]) => {
       Some(0L)
     }
 
-    def testTypes(dstream: TrackStateDStream[_, _, _, _]): Unit = {
-      val dstreamImpl = dstream.asInstanceOf[TrackStateDStreamImpl[_, _, _, _]]
+    def testTypes(dstream: MapWithStateDStream[_, _, _, _]): Unit = {
+      val dstreamImpl = dstream.asInstanceOf[MapWithStateDStreamImpl[_, _, _, _]]
       assert(dstreamImpl.keyClass === classOf[String])
       assert(dstreamImpl.valueClass === classOf[Int])
       assert(dstreamImpl.stateClass === classOf[Double])
-      assert(dstreamImpl.emittedClass === classOf[Long])
+      assert(dstreamImpl.mappedClass === classOf[Long])
     }
     val ssc = new StreamingContext(sc, batchDuration)
     val inputStream = new TestInputStream[(String, Int)](ssc, Seq.empty, numPartitions = 2)
 
-    // Defining StateSpec inline with trackStateByKey and simple function implicitly gets the types
-    val simpleFunctionStateStream1 = inputStream.trackStateByKey(
+    // Defining StateSpec inline with mapWithState and simple function implicitly gets the types
+    val simpleFunctionStateStream1 = inputStream.mapWithState(
       StateSpec.function(simpleFunc).numPartitions(1))
     testTypes(simpleFunctionStateStream1)
 
     // Separately defining StateSpec with simple function requires explicitly specifying types
     val simpleFuncSpec = StateSpec.function[String, Int, Double, Long](simpleFunc)
-    val simpleFunctionStateStream2 = inputStream.trackStateByKey(simpleFuncSpec)
+    val simpleFunctionStateStream2 = inputStream.mapWithState(simpleFuncSpec)
     testTypes(simpleFunctionStateStream2)
 
     // Separately defining StateSpec with advanced function implicitly gets the types
     val advFuncSpec1 = StateSpec.function(advancedFunc)
-    val advFunctionStateStream1 = inputStream.trackStateByKey(advFuncSpec1)
+    val advFunctionStateStream1 = inputStream.mapWithState(advFuncSpec1)
     testTypes(advFunctionStateStream1)
 
-    // Defining StateSpec inline with trackStateByKey and advanced func implicitly gets the types
-    val advFunctionStateStream2 = inputStream.trackStateByKey(
+    // Defining StateSpec inline with mapWithState and advanced func implicitly gets the types
+    val advFunctionStateStream2 = inputStream.mapWithState(
       StateSpec.function(simpleFunc).numPartitions(1))
     testTypes(advFunctionStateStream2)
 
-    // Defining StateSpec inline with trackStateByKey and advanced func implicitly gets the types
+    // Defining StateSpec inline with mapWithState and advanced func implicitly gets the types
     val advFuncSpec2 = StateSpec.function[String, Int, Double, Long](advancedFunc)
-    val advFunctionStateStream3 = inputStream.trackStateByKey[Double, Long](advFuncSpec2)
+    val advFunctionStateStream3 = inputStream.mapWithState[Double, Long](advFuncSpec2)
     testTypes(advFunctionStateStream3)
   }
 
-  test("trackStateByKey - states as emitted records") {
+  test("mapWithState - states as mapped data") {
     val inputData =
       Seq(
         Seq(),
@@ -301,17 +301,17 @@ class TrackStateByKeySuite extends SparkFunSuite
         Seq(("a", 5), ("b", 3), ("c", 1))
       )
 
-    val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
+    val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
       val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
       val output = (key, sum)
       state.update(sum)
       Some(output)
     }
 
-    testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData)
+    testOperation(inputData, StateSpec.function(mappingFunc), outputData, stateData)
   }
 
-  test("trackStateByKey - initial states, with nothing emitted") {
+  test("mapWithState - initial states, with nothing returned as from mapping function") {
 
     val initialState = Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0))
 
@@ -339,18 +339,18 @@ class TrackStateByKeySuite extends SparkFunSuite
         Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0))
       )
 
-    val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
+    val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
       val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
       val output = (key, sum)
       state.update(sum)
       None.asInstanceOf[Option[Int]]
     }
 
-    val trackStateSpec = StateSpec.function(trackStateFunc).initialState(sc.makeRDD(initialState))
-    testOperation(inputData, trackStateSpec, outputData, stateData)
+    val mapWithStateSpec = StateSpec.function(mappingFunc).initialState(sc.makeRDD(initialState))
+    testOperation(inputData, mapWithStateSpec, outputData, stateData)
   }
 
-  test("trackStateByKey - state removing") {
+  test("mapWithState - state removing") {
     val inputData =
       Seq(
         Seq(),
@@ -388,7 +388,7 @@ class TrackStateByKeySuite extends SparkFunSuite
         Seq()
       )
 
-    val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
+    val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
       if (state.exists) {
         state.remove()
         Some(key)
@@ -399,10 +399,10 @@ class TrackStateByKeySuite extends SparkFunSuite
     }
 
     testOperation(
-      inputData, StateSpec.function(trackStateFunc).numPartitions(1), outputData, stateData)
+      inputData, StateSpec.function(mappingFunc).numPartitions(1), outputData, stateData)
   }
 
-  test("trackStateByKey - state timing out") {
+  test("mapWithState - state timing out") {
     val inputData =
       Seq(
         Seq("a", "b", "c"),
@@ -413,7 +413,7 @@ class TrackStateByKeySuite extends SparkFunSuite
         Seq("a") // a will not time out
       ) ++ Seq.fill(20)(Seq("a")) // a will continue to stay active
 
-    val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
+    val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
       if (value.isDefined) {
         state.update(1)
       }
@@ -425,9 +425,9 @@ class TrackStateByKeySuite extends SparkFunSuite
     }
 
     val (collectedOutputs, collectedStateSnapshots) = getOperationOutput(
-      inputData, StateSpec.function(trackStateFunc).timeout(Seconds(3)), 20)
+      inputData, StateSpec.function(mappingFunc).timeout(Seconds(3)), 20)
 
-    // b and c should be emitted once each, when they were marked as expired
+    // b and c should be returned once each, when they were marked as expired
     assert(collectedOutputs.flatten.sorted === Seq("b", "c"))
 
     // States for a, b, c should be defined at one point of time
@@ -439,8 +439,8 @@ class TrackStateByKeySuite extends SparkFunSuite
     assert(collectedStateSnapshots.last.toSet === Set(("a", 1)))
   }
 
-  test("trackStateByKey - checkpoint durations") {
-    val privateMethod = PrivateMethod[InternalTrackStateDStream[_, _, _, _]]('internalStream)
+  test("mapWithState - checkpoint durations") {
+    val privateMethod = PrivateMethod[InternalMapWithStateDStream[_, _, _, _]]('internalStream)
 
     def testCheckpointDuration(
         batchDuration: Duration,
@@ -451,18 +451,18 @@ class TrackStateByKeySuite extends SparkFunSuite
 
       try {
         val inputStream = new TestInputStream(ssc, Seq.empty[Seq[Int]], 2).map(_ -> 1)
-        val dummyFunc = (value: Option[Int], state: State[Int]) => 0
-        val trackStateStream = inputStream.trackStateByKey(StateSpec.function(dummyFunc))
-        val internalTrackStateStream = trackStateStream invokePrivate privateMethod()
+        val dummyFunc = (key: Int, value: Option[Int], state: State[Int]) => 0
+        val mapWithStateStream = inputStream.mapWithState(StateSpec.function(dummyFunc))
+        val internalmapWithStateStream = mapWithStateStream invokePrivate privateMethod()
 
         explicitCheckpointDuration.foreach { d =>
-          trackStateStream.checkpoint(d)
+          mapWithStateStream.checkpoint(d)
         }
-        trackStateStream.register()
+        mapWithStateStream.register()
         ssc.checkpoint(checkpointDir.toString)
         ssc.start()  // should initialize all the checkpoint durations
-        assert(trackStateStream.checkpointDuration === null)
-        assert(internalTrackStateStream.checkpointDuration === expectedCheckpointDuration)
+        assert(mapWithStateStream.checkpointDuration === null)
+        assert(internalmapWithStateStream.checkpointDuration === expectedCheckpointDuration)
       } finally {
         ssc.stop(stopSparkContext = false)
       }
@@ -478,7 +478,7 @@ class TrackStateByKeySuite extends SparkFunSuite
   }
 
 
-  test("trackStateByKey - driver failure recovery") {
+  test("mapWithState - driver failure recovery") {
     val inputData =
       Seq(
         Seq(),
@@ -505,16 +505,16 @@ class TrackStateByKeySuite extends SparkFunSuite
 
       val checkpointDuration = batchDuration * (stateData.size / 2)
 
-      val runningCount = (value: Option[Int], state: State[Int]) => {
+      val runningCount = (key: String, value: Option[Int], state: State[Int]) => {
         state.update(state.getOption().getOrElse(0) + value.getOrElse(0))
         state.get()
       }
 
-      val trackStateStream = dstream.map { _ -> 1 }.trackStateByKey(
+      val mapWithStateStream = dstream.map { _ -> 1 }.mapWithState(
         StateSpec.function(runningCount))
       // Set internval make sure there is one RDD checkpointing
-      trackStateStream.checkpoint(checkpointDuration)
-      trackStateStream.stateSnapshots()
+      mapWithStateStream.checkpoint(checkpointDuration)
+      mapWithStateStream.stateSnapshots()
     }
 
     testCheckpointedOperation(inputData, operation, stateData, inputData.size / 2,
@@ -523,28 +523,28 @@ class TrackStateByKeySuite extends SparkFunSuite
 
   private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag](
       input: Seq[Seq[K]],
-      trackStateSpec: StateSpec[K, Int, S, T],
+      mapWithStateSpec: StateSpec[K, Int, S, T],
       expectedOutputs: Seq[Seq[T]],
       expectedStateSnapshots: Seq[Seq[(K, S)]]
     ): Unit = {
     require(expectedOutputs.size == expectedStateSnapshots.size)
 
     val (collectedOutputs, collectedStateSnapshots) =
-      getOperationOutput(input, trackStateSpec, expectedOutputs.size)
+      getOperationOutput(input, mapWithStateSpec, expectedOutputs.size)
     assert(expectedOutputs, collectedOutputs, "outputs")
     assert(expectedStateSnapshots, collectedStateSnapshots, "state snapshots")
   }
 
   private def getOperationOutput[K: ClassTag, S: ClassTag, T: ClassTag](
       input: Seq[Seq[K]],
-      trackStateSpec: StateSpec[K, Int, S, T],
+      mapWithStateSpec: StateSpec[K, Int, S, T],
       numBatches: Int
     ): (Seq[Seq[T]], Seq[Seq[(K, S)]]) = {
 
     // Setup the stream computation
     val ssc = new StreamingContext(sc, Seconds(1))
     val inputStream = new TestInputStream(ssc, input, numPartitions = 2)
-    val trackeStateStream = inputStream.map(x => (x, 1)).trackStateByKey(trackStateSpec)
+    val trackeStateStream = inputStream.map(x => (x, 1)).mapWithState(mapWithStateSpec)
     val collectedOutputs = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]]
     val outputStream = new TestOutputStream(trackeStateStream, collectedOutputs)
     val collectedStateSnapshots = new ArrayBuffer[Seq[(K, S)]] with SynchronizedBuffer[Seq[(K, S)]]
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala
similarity index 76%
rename from streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
rename to streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala
index 3b2d43f2ce5816c890aa4bae9c56d521869f9ac9..aa95bd33dda9f381d3ce993ce94f1bfc9e7eb2cc 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala
@@ -30,14 +30,14 @@ import org.apache.spark.streaming.util.OpenHashMapBasedStateMap
 import org.apache.spark.streaming.{State, Time}
 import org.apache.spark.util.Utils
 
-class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with BeforeAndAfterAll {
+class MapWithStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with BeforeAndAfterAll {
 
   private var sc: SparkContext = null
   private var checkpointDir: File = _
 
   override def beforeAll(): Unit = {
     sc = new SparkContext(
-      new SparkConf().setMaster("local").setAppName("TrackStateRDDSuite"))
+      new SparkConf().setMaster("local").setAppName("MapWithStateRDDSuite"))
     checkpointDir = Utils.createTempDir()
     sc.setCheckpointDir(checkpointDir.toString)
   }
@@ -54,7 +54,7 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef
   test("creation from pair RDD") {
     val data = Seq((1, "1"), (2, "2"), (3, "3"))
     val partitioner = new HashPartitioner(10)
-    val rdd = TrackStateRDD.createFromPairRDD[Int, Int, String, Int](
+    val rdd = MapWithStateRDD.createFromPairRDD[Int, Int, String, Int](
       sc.parallelize(data), partitioner, Time(123))
     assertRDD[Int, Int, String, Int](rdd, data.map { x => (x._1, x._2, 123)}.toSet, Set.empty)
     assert(rdd.partitions.size === partitioner.numPartitions)
@@ -62,7 +62,7 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef
     assert(rdd.partitioner === Some(partitioner))
   }
 
-  test("updating state and generating emitted data in TrackStateRecord") {
+  test("updating state and generating mapped data in MapWithStateRDDRecord") {
 
     val initialTime = 1000L
     val updatedTime = 2000L
@@ -71,7 +71,7 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef
 
     /**
      * Assert that applying given data on a prior record generates correct updated record, with
-     * correct state map and emitted data
+     * correct state map and mapped data
      */
     def assertRecordUpdate(
         initStates: Iterable[Int],
@@ -86,18 +86,18 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef
       val initialStateMap = new OpenHashMapBasedStateMap[String, Int]()
       initStates.foreach { s => initialStateMap.put("key", s, initialTime) }
       functionCalled = false
-      val record = TrackStateRDDRecord[String, Int, Int](initialStateMap, Seq.empty)
+      val record = MapWithStateRDDRecord[String, Int, Int](initialStateMap, Seq.empty)
       val dataIterator = data.map { v => ("key", v) }.iterator
       val removedStates = new ArrayBuffer[Int]
       val timingOutStates = new ArrayBuffer[Int]
       /**
-       * Tracking function that updates/removes state based on instructions in the data, and
+       * Mapping function that updates/removes state based on instructions in the data, and
        * return state (when instructed or when state is timing out).
        */
       def testFunc(t: Time, key: String, data: Option[String], state: State[Int]): Option[Int] = {
         functionCalled = true
 
-        assert(t.milliseconds === updatedTime, "tracking func called with wrong time")
+        assert(t.milliseconds === updatedTime, "mapping func called with wrong time")
 
         data match {
           case Some("noop") =>
@@ -120,22 +120,22 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef
         }
       }
 
-      val updatedRecord = TrackStateRDDRecord.updateRecordWithData[String, String, Int, Int](
+      val updatedRecord = MapWithStateRDDRecord.updateRecordWithData[String, String, Int, Int](
         Some(record), dataIterator, testFunc,
         Time(updatedTime), timeoutThreshold, removeTimedoutData)
 
       val updatedStateData = updatedRecord.stateMap.getAll().map { x => (x._2, x._3) }
       assert(updatedStateData.toSet === expectedStates.toSet,
-        "states do not match after updating the TrackStateRecord")
+        "states do not match after updating the MapWithStateRDDRecord")
 
-      assert(updatedRecord.emittedRecords.toSet === expectedOutput.toSet,
-        "emitted data do not match after updating the TrackStateRecord")
+      assert(updatedRecord.mappedData.toSet === expectedOutput.toSet,
+        "mapped data do not match after updating the MapWithStateRDDRecord")
 
       assert(timingOutStates.toSet === expectedTimingOutStates.toSet, "timing out states do not " +
-        "match those that were expected to do so while updating the TrackStateRecord")
+        "match those that were expected to do so while updating the MapWithStateRDDRecord")
 
       assert(removedStates.toSet === expectedRemovedStates.toSet, "removed states do not " +
-        "match those that were expected to do so while updating the TrackStateRecord")
+        "match those that were expected to do so while updating the MapWithStateRDDRecord")
 
     }
 
@@ -187,12 +187,12 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef
 
   }
 
-  test("states generated by TrackStateRDD") {
+  test("states generated by MapWithStateRDD") {
     val initStates = Seq(("k1", 0), ("k2", 0))
     val initTime = 123
     val initStateWthTime = initStates.map { x => (x._1, x._2, initTime) }.toSet
     val partitioner = new HashPartitioner(2)
-    val initStateRDD = TrackStateRDD.createFromPairRDD[String, Int, Int, Int](
+    val initStateRDD = MapWithStateRDD.createFromPairRDD[String, Int, Int, Int](
       sc.parallelize(initStates), partitioner, Time(initTime)).persist()
     assertRDD(initStateRDD, initStateWthTime, Set.empty)
 
@@ -203,21 +203,21 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef
      * creates a new state RDD with expected states
      */
     def testStateUpdates(
-        testStateRDD: TrackStateRDD[String, Int, Int, Int],
+        testStateRDD: MapWithStateRDD[String, Int, Int, Int],
         testData: Seq[(String, Int)],
-        expectedStates: Set[(String, Int, Int)]): TrackStateRDD[String, Int, Int, Int] = {
+        expectedStates: Set[(String, Int, Int)]): MapWithStateRDD[String, Int, Int, Int] = {
 
-      // Persist the test TrackStateRDD so that its not recomputed while doing the next operation.
-      // This is to make sure that we only track which state keys are being touched in the next op.
+      // Persist the test MapWithStateRDD so that its not recomputed while doing the next operation.
+      // This is to make sure that we only touch which state keys are being touched in the next op.
       testStateRDD.persist().count()
 
       // To track which keys are being touched
-      TrackStateRDDSuite.touchedStateKeys.clear()
+      MapWithStateRDDSuite.touchedStateKeys.clear()
 
-      val trackingFunc = (time: Time, key: String, data: Option[Int], state: State[Int]) => {
+      val mappingFunction = (time: Time, key: String, data: Option[Int], state: State[Int]) => {
 
         // Track the key that has been touched
-        TrackStateRDDSuite.touchedStateKeys += key
+        MapWithStateRDDSuite.touchedStateKeys += key
 
         // If the data is 0, do not do anything with the state
         // else if the data is 1, increment the state if it exists, or set new state to 0
@@ -236,12 +236,12 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef
 
       // Assert that the new state RDD has expected state data
       val newStateRDD = assertOperation(
-        testStateRDD, newDataRDD, trackingFunc, updateTime, expectedStates, Set.empty)
+        testStateRDD, newDataRDD, mappingFunction, updateTime, expectedStates, Set.empty)
 
       // Assert that the function was called only for the keys present in the data
-      assert(TrackStateRDDSuite.touchedStateKeys.size === testData.size,
+      assert(MapWithStateRDDSuite.touchedStateKeys.size === testData.size,
         "More number of keys are being touched than that is expected")
-      assert(TrackStateRDDSuite.touchedStateKeys.toSet === testData.toMap.keys,
+      assert(MapWithStateRDDSuite.touchedStateKeys.toSet === testData.toMap.keys,
         "Keys not in the data are being touched unexpectedly")
 
       // Assert that the test RDD's data has not changed
@@ -289,19 +289,19 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef
 
   test("checkpointing") {
     /**
-     * This tests whether the TrackStateRDD correctly truncates any references to its parent RDDs -
-     * the data RDD and the parent TrackStateRDD.
+     * This tests whether the MapWithStateRDD correctly truncates any references to its parent RDDs
+     * - the data RDD and the parent MapWithStateRDD.
      */
-    def rddCollectFunc(rdd: RDD[TrackStateRDDRecord[Int, Int, Int]])
+    def rddCollectFunc(rdd: RDD[MapWithStateRDDRecord[Int, Int, Int]])
       : Set[(List[(Int, Int, Long)], List[Int])] = {
-      rdd.map { record => (record.stateMap.getAll().toList, record.emittedRecords.toList) }
+      rdd.map { record => (record.stateMap.getAll().toList, record.mappedData.toList) }
          .collect.toSet
     }
 
-    /** Generate TrackStateRDD with data RDD having a long lineage */
+    /** Generate MapWithStateRDD with data RDD having a long lineage */
     def makeStateRDDWithLongLineageDataRDD(longLineageRDD: RDD[Int])
-      : TrackStateRDD[Int, Int, Int, Int] = {
-      TrackStateRDD.createFromPairRDD(longLineageRDD.map { _ -> 1}, partitioner, Time(0))
+      : MapWithStateRDD[Int, Int, Int, Int] = {
+      MapWithStateRDD.createFromPairRDD(longLineageRDD.map { _ -> 1}, partitioner, Time(0))
     }
 
     testRDD(
@@ -309,15 +309,15 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef
     testRDDPartitions(
       makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _)
 
-    /** Generate TrackStateRDD with parent state RDD having a long lineage */
+    /** Generate MapWithStateRDD with parent state RDD having a long lineage */
     def makeStateRDDWithLongLineageParenttateRDD(
-        longLineageRDD: RDD[Int]): TrackStateRDD[Int, Int, Int, Int] = {
+        longLineageRDD: RDD[Int]): MapWithStateRDD[Int, Int, Int, Int] = {
 
-      // Create a TrackStateRDD that has a long lineage using the data RDD with a long lineage
+      // Create a MapWithStateRDD that has a long lineage using the data RDD with a long lineage
       val stateRDDWithLongLineage = makeStateRDDWithLongLineageDataRDD(longLineageRDD)
 
-      // Create a new TrackStateRDD, with the lineage lineage TrackStateRDD as the parent
-      new TrackStateRDD[Int, Int, Int, Int](
+      // Create a new MapWithStateRDD, with the lineage lineage MapWithStateRDD as the parent
+      new MapWithStateRDD[Int, Int, Int, Int](
         stateRDDWithLongLineage,
         stateRDDWithLongLineage.sparkContext.emptyRDD[(Int, Int)].partitionBy(partitioner),
         (time: Time, key: Int, value: Option[Int], state: State[Int]) => None,
@@ -333,25 +333,25 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef
   }
 
   test("checkpointing empty state RDD") {
-    val emptyStateRDD = TrackStateRDD.createFromPairRDD[Int, Int, Int, Int](
+    val emptyStateRDD = MapWithStateRDD.createFromPairRDD[Int, Int, Int, Int](
       sc.emptyRDD[(Int, Int)], new HashPartitioner(10), Time(0))
     emptyStateRDD.checkpoint()
     assert(emptyStateRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty)
-    val cpRDD = sc.checkpointFile[TrackStateRDDRecord[Int, Int, Int]](
+    val cpRDD = sc.checkpointFile[MapWithStateRDDRecord[Int, Int, Int]](
       emptyStateRDD.getCheckpointFile.get)
     assert(cpRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty)
   }
 
-  /** Assert whether the `trackStateByKey` operation generates expected results */
+  /** Assert whether the `mapWithState` operation generates expected results */
   private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
-      testStateRDD: TrackStateRDD[K, V, S, T],
+      testStateRDD: MapWithStateRDD[K, V, S, T],
       newDataRDD: RDD[(K, V)],
-      trackStateFunc: (Time, K, Option[V], State[S]) => Option[T],
+      mappingFunction: (Time, K, Option[V], State[S]) => Option[T],
       currentTime: Long,
       expectedStates: Set[(K, S, Int)],
-      expectedEmittedRecords: Set[T],
+      expectedMappedData: Set[T],
       doFullScan: Boolean = false
-    ): TrackStateRDD[K, V, S, T] = {
+    ): MapWithStateRDD[K, V, S, T] = {
 
     val partitionedNewDataRDD = if (newDataRDD.partitioner != testStateRDD.partitioner) {
       newDataRDD.partitionBy(testStateRDD.partitioner.get)
@@ -359,31 +359,31 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef
       newDataRDD
     }
 
-    val newStateRDD = new TrackStateRDD[K, V, S, T](
-      testStateRDD, newDataRDD, trackStateFunc, Time(currentTime), None)
+    val newStateRDD = new MapWithStateRDD[K, V, S, T](
+      testStateRDD, newDataRDD, mappingFunction, Time(currentTime), None)
     if (doFullScan) newStateRDD.setFullScan()
 
     // Persist to make sure that it gets computed only once and we can track precisely how many
     // state keys the computing touched
     newStateRDD.persist().count()
-    assertRDD(newStateRDD, expectedStates, expectedEmittedRecords)
+    assertRDD(newStateRDD, expectedStates, expectedMappedData)
     newStateRDD
   }
 
-  /** Assert whether the [[TrackStateRDD]] has the expected state ad emitted records */
+  /** Assert whether the [[MapWithStateRDD]] has the expected state and mapped data */
   private def assertRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
-      trackStateRDD: TrackStateRDD[K, V, S, T],
+      stateRDD: MapWithStateRDD[K, V, S, T],
       expectedStates: Set[(K, S, Int)],
-      expectedEmittedRecords: Set[T]): Unit = {
-    val states = trackStateRDD.flatMap { _.stateMap.getAll() }.collect().toSet
-    val emittedRecords = trackStateRDD.flatMap { _.emittedRecords }.collect().toSet
+      expectedMappedData: Set[T]): Unit = {
+    val states = stateRDD.flatMap { _.stateMap.getAll() }.collect().toSet
+    val mappedData = stateRDD.flatMap { _.mappedData }.collect().toSet
     assert(states === expectedStates,
-      "states after track state operation were not as expected")
-    assert(emittedRecords === expectedEmittedRecords,
-      "emitted records after track state operation were not as expected")
+      "states after mapWithState operation were not as expected")
+    assert(mappedData === expectedMappedData,
+      "mapped data after mapWithState operation were not as expected")
   }
 }
 
-object TrackStateRDDSuite {
+object MapWithStateRDDSuite {
   private val touchedStateKeys = new ArrayBuffer[String]()
 }