diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
index 02ba1c2eed0f7a46876c0366ea8dd5d8c85a4841..be2ae0b47336372be57b77fa60141233ec1d7d15 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
@@ -44,18 +44,6 @@ object StatefulNetworkWordCount {
 
     StreamingExamples.setStreamingLogLevels()
 
-    val updateFunc = (values: Seq[Int], state: Option[Int]) => {
-      val currentCount = values.sum
-
-      val previousCount = state.getOrElse(0)
-
-      Some(currentCount + previousCount)
-    }
-
-    val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => {
-      iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
-    }
-
     val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount")
     // Create the context with a 1 second batch size
     val ssc = new StreamingContext(sparkConf, Seconds(1))
@@ -71,9 +59,16 @@ object StatefulNetworkWordCount {
     val wordDstream = words.map(x => (x, 1))
 
     // Update the cumulative count using updateStateByKey
-    // This will give a Dstream made of state (which is the cumulative count of the words)
-    val stateDstream = wordDstream.updateStateByKey[Int](newUpdateFunc,
-      new HashPartitioner (ssc.sparkContext.defaultParallelism), true, initialRDD)
+    // This will give a DStream made of state (which is the cumulative count of the words)
+    val trackStateFunc = (batchTime: Time, word: String, one: Option[Int], state: State[Int]) => {
+      val sum = one.getOrElse(0) + state.getOption.getOrElse(0)
+      val output = (word, sum)
+      state.update(sum)
+      Some(output)
+    }
+
+    val stateDstream = wordDstream.trackStateByKey(
+      StateSpec.function(trackStateFunc).initialState(initialRDD))
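+    // The returned DStream also exposes the complete per-key state after each batch, e.g.
+    // (a sketch, not required for this example): stateDstream.stateSnapshots().print()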
     stateDstream.print()
     ssc.start()
     ssc.awaitTermination()
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala
new file mode 100644
index 0000000000000000000000000000000000000000..7dd1b72f804997db4aba35d6dbefd983c307c7cf
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+import scala.language.implicitConversions
+
+import org.apache.spark.annotation.Experimental
+
+/**
+ * :: Experimental ::
+ * Abstract class for getting and updating the tracked state in the `trackStateByKey` operation of
+ * a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
+ * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
+ *
+ * Scala example of using `State`:
+ * {{{
+ *    // A tracking function that maintains an integer state and returns a String
+ *    def trackStateFunc(data: Option[Int], state: State[Int]): Option[String] = {
+ *      // Check if state exists
+ *      if (state.exists) {
+ *        val existingState = state.get  // Get the existing state
+ *        val shouldRemove = ...         // Decide whether to remove the state
+ *        if (shouldRemove) {
+ *          state.remove()     // Remove the state
+ *        } else {
+ *          val newState = ...
+ *          state.update(newState)    // Set the new state
+ *        }
+ *      } else {
+ *        val initialState = ...
+ *        state.update(initialState)  // Set the initial state
+ *      }
+ *      ... // return something
+ *    }
+ *
+ * }}}
+ *
+ * Java example:
+ * {{{
+ *      TODO(@zsxwing)
+ * }}}
+ */
+@Experimental
+sealed abstract class State[S] {
+
+  /** Whether the state already exists */
+  def exists(): Boolean
+
+  /**
+   * Get the state if it exists, otherwise it will throw `java.util.NoSuchElementException`.
+   * Check with `exists()` whether the state exists or not before calling `get()`.
+   *
+   * @throws java.util.NoSuchElementException If the state does not exist.
+   */
+  def get(): S
+
+  /**
+   * Update the state with a new value.
+   *
+   * The state cannot be updated if it has already been removed (that is, `remove()` has already
+   * been called) or it is going to be removed due to timeout (that is, `isTimingOut()` is `true`).
+   *
+   * @throws java.lang.IllegalArgumentException If the state has already been removed, or is
+   *                                            going to be removed
+   */
+  def update(newState: S): Unit
+
+  /**
+   * Remove the state if it exists.
+   *
+   * The state cannot be removed if it has already been removed (that is, `remove()` has already
+   * been called) or it is going to be removed due to timeout (that is, `isTimingOut()` is `true`).
+   */
+  def remove(): Unit
+
+  /**
+   * Whether the state is timing out and going to be removed by the system after the current batch.
+   * This timeout can occur if a timeout duration has been specified in the
+   * [[org.apache.spark.streaming.StateSpec StateSpec]] and the key has not received any new data
+   * for that timeout duration.
+   */
+  def isTimingOut(): Boolean
+
+  /**
+   * Get the state as an [[scala.Option]]. It will be `Some(state)` if it exists, otherwise `None`.
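+   * For example, `state.getOption.getOrElse(defaultValue)` reads the state with a default
+   * (`defaultValue` here is just a placeholder for any value of type `S`).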
+   */
+  @inline final def getOption(): Option[S] = if (exists) Some(get()) else None
+
+  @inline final override def toString(): String = {
+    getOption.map { _.toString }.getOrElse("<state not set>")
+  }
+}
+
+/** Internal implementation of the [[State]] interface */
+private[streaming] class StateImpl[S] extends State[S] {
+
+  private var state: S = null.asInstanceOf[S]
+  private var defined: Boolean = false
+  private var timingOut: Boolean = false
+  private var updated: Boolean = false
+  private var removed: Boolean = false
+
+  // ========= Public API =========
+  override def exists(): Boolean = {
+    defined
+  }
+
+  override def get(): S = {
+    if (defined) {
+      state
+    } else {
+      throw new NoSuchElementException("State is not set")
+    }
+  }
+
+  override def update(newState: S): Unit = {
+    require(!removed, "Cannot update the state after it has been removed")
+    require(!timingOut, "Cannot update the state that is timing out")
+    state = newState
+    defined = true
+    updated = true
+  }
+
+  override def isTimingOut(): Boolean = {
+    timingOut
+  }
+
+  override def remove(): Unit = {
+    require(!timingOut, "Cannot remove the state that is timing out")
+    require(!removed, "Cannot remove the state that has already been removed")
+    defined = false
+    updated = false
+    removed = true
+  }
+
+  // ========= Internal API =========
+
+  /** Whether the state has been marked for removal */
+  def isRemoved(): Boolean = {
+    removed
+  }
+
+  /** Whether the state has been updated */
+  def isUpdated(): Boolean = {
+    updated
+  }
+
+  /**
+   * Update the internal data and flags in `this` to the given state option.
+   * This method allows `this` object to be reused across many state records.
+   */
+  def wrap(optionalState: Option[S]): Unit = {
+    optionalState match {
+      case Some(newState) =>
+        this.state = newState
+        defined = true
+
+      case None =>
+        this.state = null.asInstanceOf[S]
+        defined = false
+    }
+    timingOut = false
+    removed = false
+    updated = false
+  }
+
+  /**
+   * Update the internal data and flags in `this` to the given state that is going to be timed out.
+   * This method allows `this` object to be reused across many state records.
+   */
+  def wrapTimingOutState(newState: S): Unit = {
+    this.state = newState
+    defined = true
+    timingOut = true
+    removed = false
+    updated = false
+  }
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
new file mode 100644
index 0000000000000000000000000000000000000000..c9fe35e74c1c7b01976933ffacc0fffa6053ac0f
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaPairRDD
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.ClosureCleaner
+import org.apache.spark.{HashPartitioner, Partitioner}
+
+
+/**
+ * :: Experimental ::
+ * Abstract class representing all the specifications of the DStream transformation
+ * `trackStateByKey` on a
+ * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
+ * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
+ * Use [[org.apache.spark.streaming.StateSpec StateSpec.function()]] factory methods to create
+ * instances of this class.
+ *
+ * Example in Scala:
+ * {{{
+ *    def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = {
+ *      ...
+ *    }
+ *
+ *    val spec = StateSpec.function(trackingFunction).numPartitions(10)
+ *
+ *    val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec)
+ * }}}
+ *
+ * Example in Java:
+ * {{{
+ *    StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec =
+ *      StateSpec.<KeyType, ValueType, StateType, EmittedDataType>function(trackingFunction)
+ *               .numPartitions(10);
+ *
+ *    JavaDStream<EmittedDataType> emittedRecordDStream =
+ *      javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec);
+ * }}}
+ */
+@Experimental
+sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] extends Serializable {
+
+  /** Set the RDD containing the initial states that will be used by `trackStateByKey` */
+  def initialState(rdd: RDD[(KeyType, StateType)]): this.type
+
+  /** Set the RDD containing the initial states that will be used by `trackStateByKey` */
+  def initialState(javaPairRDD: JavaPairRDD[KeyType, StateType]): this.type
+
+  /**
+   * Set the number of partitions by which the state RDDs generated by `trackStateByKey`
+   * will be partitioned. Hash partitioning will be used.
+   */
+  def numPartitions(numPartitions: Int): this.type
+
+  /**
+   * Set the partitioner by which the state RDDs generated by `trackStateByKey` will be
+   * partitioned.
+   */
+  def partitioner(partitioner: Partitioner): this.type
+
+  /**
+   * Set the duration after which the state of an idle key will be removed. A key and its state
+   * are considered idle if the key has not received any data for at least the given duration.
+   * The state tracking function will be called one final time on the idle states that are going
+   * to be removed, with [[org.apache.spark.streaming.State State.isTimingOut()]] set
+   * to `true` in that call.
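+   *
+   * For example (a sketch; `trackingFunc` stands for any tracking function),
+   * `StateSpec.function(trackingFunc).timeout(Minutes(10))` removes the state of a key that has
+   * received no data for 10 minutes, after one final call in which `isTimingOut()` is `true`.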
+   */
+  def timeout(idleDuration: Duration): this.type
+}
+
+
+/**
+ * :: Experimental ::
+ * Builder object for creating instances of [[org.apache.spark.streaming.StateSpec StateSpec]]
+ * that are used for specifying the parameters of the DStream transformation
+ * `trackStateByKey` on a
+ * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
+ * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
+ *
+ * Example in Scala:
+ * {{{
+ *    def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = {
+ *      ...
+ *    }
+ *
+ *    val spec = StateSpec.function(trackingFunction).numPartitions(10)
+ *
+ *    val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec)
+ * }}}
+ *
+ * Example in Java:
+ * {{{
+ *    StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec =
+ *      StateSpec.<KeyType, ValueType, StateType, EmittedDataType>function(trackingFunction)
+ *               .numPartitions(10);
+ *
+ *    JavaDStream<EmittedDataType> emittedRecordDStream =
+ *      javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec);
+ * }}}
+ */
+@Experimental
+object StateSpec {
+  /**
+   * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
+   * of the `trackStateByKey` operation on a
+   * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
+   * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
+   * @param trackingFunction The function applied on every data item to manage the associated state
+   *                         and generate the emitted data
+   * @tparam KeyType      Class of the keys
+   * @tparam ValueType    Class of the values
+   * @tparam StateType    Class of the states data
+   * @tparam EmittedType  Class of the emitted data
+   */
+  def function[KeyType, ValueType, StateType, EmittedType](
+      trackingFunction: (Time, KeyType, Option[ValueType], State[StateType]) => Option[EmittedType]
+    ): StateSpec[KeyType, ValueType, StateType, EmittedType] = {
+    ClosureCleaner.clean(trackingFunction, checkSerializable = true)
+    new StateSpecImpl(trackingFunction)
+  }
+
+  /**
+   * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
+   * of the `trackStateByKey` operation on a
+   * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
+   * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
+   * @param trackingFunction The function applied on every data item to manage the associated state
+   *                         and generate the emitted data
+   * @tparam ValueType    Class of the values
+   * @tparam StateType    Class of the states data
+   * @tparam EmittedType  Class of the emitted data
+   */
+  def function[KeyType, ValueType, StateType, EmittedType](
+      trackingFunction: (Option[ValueType], State[StateType]) => EmittedType
+    ): StateSpec[KeyType, ValueType, StateType, EmittedType] = {
+    ClosureCleaner.clean(trackingFunction, checkSerializable = true)
+    val wrappedFunction =
+      (time: Time, key: Any, value: Option[ValueType], state: State[StateType]) => {
+        Some(trackingFunction(value, state))
+      }
+    new StateSpecImpl(wrappedFunction)
+  }
+}
+
+
+/** Internal implementation of [[org.apache.spark.streaming.StateSpec]] interface. */
+private[streaming]
+case class StateSpecImpl[K, V, S, T](
+    function: (Time, K, Option[V], State[S]) => Option[T]) extends StateSpec[K, V, S, T] {
+
+  require(function != null)
+
+  @volatile private var partitioner: Partitioner = null
+  @volatile private var initialStateRDD: RDD[(K, S)] = null
+  @volatile private var timeoutInterval: Duration = null
+
+  override def initialState(rdd: RDD[(K, S)]): this.type = {
+    this.initialStateRDD = rdd
+    this
+  }
+
+  override def initialState(javaPairRDD: JavaPairRDD[K, S]): this.type = {
+    this.initialStateRDD = javaPairRDD.rdd
+    this
+  }
+
+  override def numPartitions(numPartitions: Int): this.type = {
+    this.partitioner(new HashPartitioner(numPartitions))
+    this
+  }
+
+  override def partitioner(partitioner: Partitioner): this.type = {
+    this.partitioner = partitioner
+    this
+  }
+
+  override def timeout(interval: Duration): this.type = {
+    this.timeoutInterval = interval
+    this
+  }
+
+  // ================= Private Methods =================
+
+  private[streaming] def getFunction(): (Time, K, Option[V], State[S]) => Option[T] = function
+
+  private[streaming] def getInitialStateRDD(): Option[RDD[(K, S)]] = Option(initialStateRDD)
+
+  private[streaming] def getPartitioner(): Option[Partitioner] = Option(partitioner)
+
+  private[streaming] def getTimeoutInterval(): Option[Duration] = Option(timeoutInterval)
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
index 71bec96d46c8db48305f68f6d9020c844f4178f7..fb691eed27e327b68c22555450e53eed51ead7b2 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
@@ -24,19 +24,19 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.mapred.{JobConf, OutputFormat}
 import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
 
-import org.apache.spark.{HashPartitioner, Partitioner}
+import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
-import org.apache.spark.streaming.{Duration, Time}
 import org.apache.spark.streaming.StreamingContext.rddToFileName
+import org.apache.spark.streaming._
 import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf}
+import org.apache.spark.{HashPartitioner, Partitioner}
 
 /**
  * Extra functions available on DStream of (key, value) pairs through an implicit conversion.
  */
 class PairDStreamFunctions[K, V](self: DStream[(K, V)])
     (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K])
-  extends Serializable
-{
+  extends Serializable {
   private[streaming] def ssc = self.ssc
 
   private[streaming] def sparkContext = self.context.sparkContext
@@ -350,6 +350,44 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)])
     )
   }
 
+  /**
+   * :: Experimental ::
+   * Return a new DStream of data generated by combining the key-value data in `this` stream
+   * with a continuously updated per-key state. The user-provided state tracking function is
+   * applied on each keyed data item along with its corresponding state. The function can choose to
+   * update/remove the state and return transformed data, which forms the
+   * [[org.apache.spark.streaming.dstream.TrackStateDStream]].
+   *
+   * The specifications of this transformation are made through the
+   * [[org.apache.spark.streaming.StateSpec StateSpec]] class. Besides the tracking function, there
+   * are a number of optional parameters - initial state data, number of partitions, timeouts, etc.
+   * See [[org.apache.spark.streaming.StateSpec StateSpec]] for more details.
+   *
+   * Example of using `trackStateByKey`:
+   * {{{
+   *    def trackingFunction(data: Option[Int], wrappedState: State[Int]): String = {
+   *      // Check if state exists, accordingly update/remove state and return transformed data
+   *    }
+   *
+   *    val spec = StateSpec.function(trackingFunction).numPartitions(10)
+   *
+   *    val trackStateDStream = keyValueDStream.trackStateByKey[Int, String](spec)
+   * }}}
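+   *
+   * The returned [[org.apache.spark.streaming.dstream.TrackStateDStream]] additionally exposes
+   * `stateSnapshots()` to get a DStream of the complete per-key state after every batch.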
+   *
+   * @param spec          Specification of this transformation
+   * @tparam StateType    Class type of the state
+   * @tparam EmittedType  Class type of the transformed data returned by the tracking function
+   */
+  @Experimental
+  def trackStateByKey[StateType: ClassTag, EmittedType: ClassTag](
+      spec: StateSpec[K, V, StateType, EmittedType]
+    ): TrackStateDStream[K, V, StateType, EmittedType] = {
+    new TrackStateDStreamImpl[K, V, StateType, EmittedType](
+      self,
+      spec.asInstanceOf[StateSpecImpl[K, V, StateType, EmittedType]]
+    )
+  }
+
   /**
    * Return a new "state" DStream where the state for each key is updated by applying
    * the given function on the previous state of the key and the new values of each key.
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
new file mode 100644
index 0000000000000000000000000000000000000000..58d89c93bcbefecbd07c6af5d2134c4242cc5a26
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.dstream
+
+import scala.reflect.ClassTag
+
+import org.apache.spark._
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.rdd.{EmptyRDD, RDD}
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming._
+import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord}
+
+/**
+ * :: Experimental ::
+ * DStream representing the stream of records emitted by the tracking function in the
+ * `trackStateByKey` operation on a
+ * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
+ * Additionally, it gives access to the stream of state snapshots, that is, the state data of
+ * all keys after a batch has updated them.
+ *
+ * @tparam KeyType Class of the key
+ * @tparam ValueType Class of the values
+ * @tparam StateType Class of the state data
+ * @tparam EmittedType Class of the emitted records
+ */
+@Experimental
+sealed abstract class TrackStateDStream[KeyType, ValueType, StateType, EmittedType: ClassTag](
+    ssc: StreamingContext) extends DStream[EmittedType](ssc) {
+
+  /** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */
+  def stateSnapshots(): DStream[(KeyType, StateType)]
+}
+
+/** Internal implementation of the [[TrackStateDStream]] */
+private[streaming] class TrackStateDStreamImpl[
+    KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, EmittedType: ClassTag](
+    dataStream: DStream[(KeyType, ValueType)],
+    spec: StateSpecImpl[KeyType, ValueType, StateType, EmittedType])
+  extends TrackStateDStream[KeyType, ValueType, StateType, EmittedType](dataStream.context) {
+
+  private val internalStream =
+    new InternalTrackStateDStream[KeyType, ValueType, StateType, EmittedType](dataStream, spec)
+
+  override def slideDuration: Duration = internalStream.slideDuration
+
+  override def dependencies: List[DStream[_]] = List(internalStream)
+
+  override def compute(validTime: Time): Option[RDD[EmittedType]] = {
+    internalStream.getOrCompute(validTime).map { _.flatMap[EmittedType] { _.emittedRecords } }
+  }
+
+  /**
+   * Forward the checkpoint interval to the internal DStream that computes the state maps. This
+   * is to make sure that this DStream does not get checkpointed, only the internal stream.
+   */
+  override def checkpoint(checkpointInterval: Duration): DStream[EmittedType] = {
+    internalStream.checkpoint(checkpointInterval)
+    this
+  }
+
+  /** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */
+  def stateSnapshots(): DStream[(KeyType, StateType)] = {
+    internalStream.flatMap {
+      _.stateMap.getAll().map { case (k, s, _) => (k, s) }.toTraversable
+    }
+  }
+
+  def keyClass: Class[_] = implicitly[ClassTag[KeyType]].runtimeClass
+
+  def valueClass: Class[_] = implicitly[ClassTag[ValueType]].runtimeClass
+
+  def stateClass: Class[_] = implicitly[ClassTag[StateType]].runtimeClass
+
+  def emittedClass: Class[_] = implicitly[ClassTag[EmittedType]].runtimeClass
+}
+
+/**
+ * A DStream that allows per-key state to be maintained, and arbitrary records to be generated
+ * based on updates to the state. This is the main DStream that implements the `trackStateByKey`
+ * operation on DStreams.
+ *
+ * @param parent Parent (key, value) stream that is the source
+ * @param spec Specifications of the trackStateByKey operation
+ * @tparam K   Key type
+ * @tparam V   Value type
+ * @tparam S   Type of the state maintained
+ * @tparam E   Type of the emitted data
+ */
+private[streaming]
+class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
+    parent: DStream[(K, V)], spec: StateSpecImpl[K, V, S, E])
+  extends DStream[TrackStateRDDRecord[K, S, E]](parent.context) {
+
+  persist(StorageLevel.MEMORY_ONLY)
+
+  private val partitioner = spec.getPartitioner().getOrElse(
+    new HashPartitioner(ssc.sc.defaultParallelism))
+
+  private val trackingFunction = spec.getFunction()
+
+  override def slideDuration: Duration = parent.slideDuration
+
+  override def dependencies: List[DStream[_]] = List(parent)
+
+  /** Enable automatic checkpointing */
+  override val mustCheckpoint = true
+
+  /** Method that generates a RDD for the given time */
+  override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = {
+    // Get the previous state or create a new empty state RDD
+    val prevStateRDD = getOrCompute(validTime - slideDuration).getOrElse {
+      TrackStateRDD.createFromPairRDD[K, V, S, E](
+        spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
+        partitioner, validTime
+      )
+    }
+
+    // Compute the new state RDD with previous state RDD and partitioned data RDD
+    parent.getOrCompute(validTime).map { dataRDD =>
+      val partitionedDataRDD = dataRDD.partitionBy(partitioner)
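+      // Keys whose state was last updated before (validTime - timeout) are candidates for
+      // removal; the actual removal happens only when the TrackStateRDD does a full scan of
+      // its state map (currently triggered when the RDD is checkpointed).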
+      val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
+        (validTime - interval).milliseconds
+      }
+      new TrackStateRDD(
+        prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime)
+    }
+  }
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
new file mode 100644
index 0000000000000000000000000000000000000000..ed7cea26d06086a650571aa83f55b44473564f6e
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.rdd
+
+import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+
+import org.apache.spark.rdd.{MapPartitionsRDD, RDD}
+import org.apache.spark.streaming.{State, StateImpl, Time}
+import org.apache.spark.streaming.util.{EmptyStateMap, StateMap}
+import org.apache.spark.util.Utils
+import org.apache.spark._
+
+/**
+ * Record storing the keyed state of a partition of [[TrackStateRDD]]. Each record contains a
+ * [[StateMap]] and a sequence of records returned by the tracking function of `trackStateByKey`.
+ */
+private[streaming] case class TrackStateRDDRecord[K, S, T](
+    var stateMap: StateMap[K, S], var emittedRecords: Seq[T])
+
+/**
+ * Partition of the [[TrackStateRDD]], which depends on the corresponding partitions of the
+ * previous state RDD and of a partitioned keyed-data RDD.
+ */
+private[streaming] class TrackStateRDDPartition(
+    idx: Int,
+    @transient private var prevStateRDD: RDD[_],
+    @transient private var partitionedDataRDD: RDD[_]) extends Partition {
+
+  private[rdd] var previousSessionRDDPartition: Partition = null
+  private[rdd] var partitionedDataRDDPartition: Partition = null
+
+  override def index: Int = idx
+  override def hashCode(): Int = idx
+
+  @throws(classOf[IOException])
+  private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
+    // Update the reference to parent split at the time of task serialization
+    previousSessionRDDPartition = prevStateRDD.partitions(index)
+    partitionedDataRDDPartition = partitionedDataRDD.partitions(index)
+    oos.defaultWriteObject()
+  }
+}
+
+
+/**
+ * RDD storing the keyed-state of `trackStateByKey` and corresponding emitted records.
+ * Each partition of this RDD has a single record of type [[TrackStateRDDRecord]]. This contains a
+ * [[StateMap]] (containing the keyed-states) and the sequence of records returned by the tracking
+ * function of `trackStateByKey`.
+ * @param prevStateRDD The previous TrackStateRDD whose StateMap data is used to create `this` RDD
+ * @param partitionedDataRDD The partitioned data RDD which is used to update the previous
+ *                           StateMaps in the `prevStateRDD` to create `this` RDD
+ * @param trackingFunction The function that will be used to update state and return new data
+ * @param batchTime        The time of the batch to which this RDD belongs. Used to record the
+ *                         update time of the states
+ * @param timeoutThresholdTime The update-time threshold before which a key is considered idle
+ *                             and may be timed out during a full scan (if defined)
+ */
+private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
+    private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, T]],
+    private var partitionedDataRDD: RDD[(K, V)],
+    trackingFunction: (Time, K, Option[V], State[S]) => Option[T],
+    batchTime: Time, timeoutThresholdTime: Option[Long]
+  ) extends RDD[TrackStateRDDRecord[K, S, T]](
+    partitionedDataRDD.sparkContext,
+    List(
+      new OneToOneDependency[TrackStateRDDRecord[K, S, T]](prevStateRDD),
+      new OneToOneDependency(partitionedDataRDD))
+  ) {
+
+  @volatile private var doFullScan = false
+
+  require(prevStateRDD.partitioner.nonEmpty)
+  require(partitionedDataRDD.partitioner == prevStateRDD.partitioner)
+
+  override val partitioner = prevStateRDD.partitioner
+
+  override def checkpoint(): Unit = {
+    super.checkpoint()
+    doFullScan = true
+  }
+
+  override def compute(
+      partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, T]] = {
+
+    val stateRDDPartition = partition.asInstanceOf[TrackStateRDDPartition]
+    val prevStateRDDIterator = prevStateRDD.iterator(
+      stateRDDPartition.previousSessionRDDPartition, context)
+    val dataIterator = partitionedDataRDD.iterator(
+      stateRDDPartition.partitionedDataRDDPartition, context)
+
+    // Create a new state map by cloning the previous one (if it exists) or by creating an empty one
+    val newStateMap = if (prevStateRDDIterator.hasNext) {
+      prevStateRDDIterator.next().stateMap.copy()
+    } else {
+      new EmptyStateMap[K, S]()
+    }
+
+    val emittedRecords = new ArrayBuffer[T]
+    val wrappedState = new StateImpl[S]()
+
+    // Call the tracking function on each record in the data RDD partition, and accordingly
+    // update the states touched, and the data returned by the tracking function.
+    dataIterator.foreach { case (key, value) =>
+      wrappedState.wrap(newStateMap.get(key))
+      val emittedRecord = trackingFunction(batchTime, key, Some(value), wrappedState)
+      if (wrappedState.isRemoved) {
+        newStateMap.remove(key)
+      } else if (wrappedState.isUpdated) {
+        newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
+      }
+      emittedRecords ++= emittedRecord
+    }
+
+    // If the RDD is expected to be doing a full scan of all the data in the StateMap,
+    // then use this opportunity to filter out those keys that have timed out.
+    // For each of them call the tracking function.
+    if (doFullScan && timeoutThresholdTime.isDefined) {
+      newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
+        wrappedState.wrapTimingOutState(state)
+        val emittedRecord = trackingFunction(batchTime, key, None, wrappedState)
+        emittedRecords ++= emittedRecord
+        newStateMap.remove(key)
+      }
+    }
+
+    Iterator(TrackStateRDDRecord(newStateMap, emittedRecords))
+  }
+
+  override protected def getPartitions: Array[Partition] = {
+    Array.tabulate(prevStateRDD.partitions.length) { i =>
+      new TrackStateRDDPartition(i, prevStateRDD, partitionedDataRDD)
+    }
+  }
+
+  override def clearDependencies(): Unit = {
+    super.clearDependencies()
+    prevStateRDD = null
+    partitionedDataRDD = null
+  }
+
+  def setFullScan(): Unit = {
+    doFullScan = true
+  }
+}
+
+private[streaming] object TrackStateRDD {
+
+  def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
+      pairRDD: RDD[(K, S)],
+      partitioner: Partitioner,
+      updateTime: Time): TrackStateRDD[K, V, S, T] = {
+
+    val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator =>
+      val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
+      iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) }
+      Iterator(TrackStateRDDRecord(stateMap, Seq.empty[T]))
+    }, preservesPartitioning = true)
+
+    val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)
+
+    val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None
+
+    new TrackStateRDD[K, V, S, T](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
+  }
+}
+
+private[streaming] class EmittedRecordsRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
+    parent: TrackStateRDD[K, V, S, T]) extends RDD[T](parent) {
+  override protected def getPartitions: Array[Partition] = parent.partitions
+  override def compute(partition: Partition, context: TaskContext): Iterator[T] = {
+    // Use iterator() instead of compute() so that a cached or checkpointed parent is reused
+    parent.iterator(partition, context).flatMap { _.emittedRecords }
+  }
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
new file mode 100644
index 0000000000000000000000000000000000000000..ed622ef7bf7007c45b3f02ecf76091b8835785ec
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
@@ -0,0 +1,337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.util
+
+import java.io.{ObjectInputStream, ObjectOutputStream}
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.SparkConf
+import org.apache.spark.streaming.util.OpenHashMapBasedStateMap._
+import org.apache.spark.util.collection.OpenHashMap
+
+/** Internal interface for defining the map that keeps track of the keyed states. */
+private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Serializable {
+
+  /** Get the state for a key if it exists */
+  def get(key: K): Option[S]
+
+  /** Get all the keys and states whose updated time is older than the given threshold time */
+  def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)]
+
+  /** Get all the keys and states in this map. */
+  def getAll(): Iterator[(K, S, Long)]
+
+  /** Add or update state */
+  def put(key: K, state: S, updatedTime: Long): Unit
+
+  /** Remove a key */
+  def remove(key: K): Unit
+
+  /**
+   * Shallow copy `this` map to create a new state map.
+   * Updates to the new map should not mutate `this` map.
+   */
+  def copy(): StateMap[K, S]
+
+  def toDebugString(): String = toString()
+}
+
+/** Companion object for [[StateMap]], with utility methods */
+private[streaming] object StateMap {
+  def empty[K: ClassTag, S: ClassTag]: StateMap[K, S] = new EmptyStateMap[K, S]
+
+  def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = {
+    val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold",
+      DELTA_CHAIN_LENGTH_THRESHOLD)
+    new OpenHashMapBasedStateMap[K, S](64, deltaChainThreshold)
+  }
+}
+
+/** Implementation of StateMap interface representing an empty map */
+private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMap[K, S] {
+  override def put(key: K, session: S, updateTime: Long): Unit = {
+    throw new NotImplementedError("put() should not be called on an EmptyStateMap")
+  }
+  override def get(key: K): Option[S] = None
+  override def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] = Iterator.empty
+  override def getAll(): Iterator[(K, S, Long)] = Iterator.empty
+  override def copy(): StateMap[K, S] = this
+  override def remove(key: K): Unit = { }
+  override def toDebugString(): String = ""
+}
+
+/** Implementation of StateMap based on Spark's [[org.apache.spark.util.collection.OpenHashMap]] */
+private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
+    @transient @volatile var parentStateMap: StateMap[K, S],
+    initialCapacity: Int = 64,
+    deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD
+  ) extends StateMap[K, S] { self =>
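+  // Design note: each map stores only a delta of updates (in `deltaMap`) on top of a parent
+  // StateMap. get() first consults the delta and falls back to the parent chain; copy() simply
+  // creates a new empty delta with `this` as the parent. Chains at least `deltaChainThreshold`
+  // long are compacted into a single parent map when the map is serialized.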
+
+  def this(initialCapacity: Int, deltaChainThreshold: Int) = this(
+    new EmptyStateMap[K, S],
+    initialCapacity = initialCapacity,
+    deltaChainThreshold = deltaChainThreshold)
+
+  def this(deltaChainThreshold: Int) = this(
+    initialCapacity = 64, deltaChainThreshold = deltaChainThreshold)
+
+  def this() = this(DELTA_CHAIN_LENGTH_THRESHOLD)
+
+  @transient @volatile private var deltaMap =
+    new OpenHashMap[K, StateInfo[S]](initialCapacity)
+
+  /** Get the state data if it exists */
+  override def get(key: K): Option[S] = {
+    val stateInfo = deltaMap(key)
+    if (stateInfo != null) {
+      if (!stateInfo.deleted) {
+        Some(stateInfo.data)
+      } else {
+        None
+      }
+    } else {
+      parentStateMap.get(key)
+    }
+  }
+
+  /** Get all the keys and states whose updated time is older than the given threshold time */
+  override def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] = {
+    val oldStates = parentStateMap.getByTime(threshUpdatedTime).filter { case (key, value, _) =>
+      !deltaMap.contains(key)
+    }
+
+    val updatedStates = deltaMap.iterator.filter { case (_, stateInfo) =>
+      !stateInfo.deleted && stateInfo.updateTime < threshUpdatedTime
+    }.map { case (key, stateInfo) =>
+      (key, stateInfo.data, stateInfo.updateTime)
+    }
+    oldStates ++ updatedStates
+  }
+
+  /** Get all the keys and states in this map. */
+  override def getAll(): Iterator[(K, S, Long)] = {
+
+    val oldStates = parentStateMap.getAll().filter { case (key, _, _) =>
+      !deltaMap.contains(key)
+    }
+
+    val updatedStates = deltaMap.iterator.filter { ! _._2.deleted }.map { case (key, stateInfo) =>
+      (key, stateInfo.data, stateInfo.updateTime)
+    }
+    oldStates ++ updatedStates
+  }
+
+  /** Add or update state */
+  override def put(key: K, state: S, updateTime: Long): Unit = {
+    val stateInfo = deltaMap(key)
+    if (stateInfo != null) {
+      stateInfo.update(state, updateTime)
+    } else {
+      deltaMap.update(key, new StateInfo(state, updateTime))
+    }
+  }
+
+  /** Remove a state */
+  override def remove(key: K): Unit = {
+    val stateInfo = deltaMap(key)
+    if (stateInfo != null) {
+      stateInfo.markDeleted()
+    } else {
+      val newInfo = new StateInfo[S](deleted = true)
+      deltaMap.update(key, newInfo)
+    }
+  }
+
+  /**
+   * Shallow copy the map to create a new state map. Updates to the new map
+   * should not mutate `this` map.
+   */
+  override def copy(): StateMap[K, S] = {
+    new OpenHashMapBasedStateMap[K, S](this, deltaChainThreshold = deltaChainThreshold)
+  }
+
+  /** Whether the delta chain length is long enough that it should be compacted */
+  def shouldCompact: Boolean = {
+    deltaChainLength >= deltaChainThreshold
+  }
+
+  /** Length of the delta chains of this map */
+  def deltaChainLength: Int = parentStateMap match {
+    case map: OpenHashMapBasedStateMap[_, _] => map.deltaChainLength + 1
+    case _ => 0
+  }
+
+  /**
+   * Approximate number of keys in the map. This is an overestimation that is mainly used to
+   * reserve capacity in a new map at delta compaction time.
+   */
+  def approxSize: Int = deltaMap.size + {
+    parentStateMap match {
+      case s: OpenHashMapBasedStateMap[_, _] => s.approxSize
+      case _ => 0
+    }
+  }
+
+  /** Get all the data of this map as a string formatted as a tree based on the delta depth */
+  override def toDebugString(): String = {
+    val tabs = if (deltaChainLength > 0) {
+      ("    " * (deltaChainLength - 1)) + "+--- "
+    } else ""
+    parentStateMap.toDebugString() + "\n" + deltaMap.iterator.mkString(tabs, "\n" + tabs, "")
+  }
+
+  override def toString(): String = {
+    s"[${System.identityHashCode(this)}, ${System.identityHashCode(parentStateMap)}]"
+  }
+
+  /**
+   * Serialize the map data. Besides serialization, this method actually compacts the deltas
+   * (if needed) in a single pass over all the data in the map.
+   */
+  private def writeObject(outputStream: ObjectOutputStream): Unit = {
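+    // Serialized layout: [non-transient fields] [delta size] [delta (key, StateInfo) pairs]
+    // [approximate parent map size] [parent (key, state, updateTime) entries ...]
+    // [LimitMarker carrying the exact parent entry count]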
+    // Write all the non-transient fields, especially class tags, etc.
+    outputStream.defaultWriteObject()
+
+    // Write the data in the delta of this state map
+    outputStream.writeInt(deltaMap.size)
+    val deltaMapIterator = deltaMap.iterator
+    var deltaMapCount = 0
+    while (deltaMapIterator.hasNext) {
+      deltaMapCount += 1
+      val (key, stateInfo) = deltaMapIterator.next()
+      outputStream.writeObject(key)
+      outputStream.writeObject(stateInfo)
+    }
+    assert(deltaMapCount == deltaMap.size)
+
+    // Write the data in the parent state map while copying the data into a new parent map for
+    // compaction (if needed)
+    val doCompaction = shouldCompact
+    val newParentSessionStore = if (doCompaction) {
+      val initCapacity = if (approxSize > 0) approxSize else 64
+      new OpenHashMapBasedStateMap[K, S](initialCapacity = initCapacity, deltaChainThreshold)
+    } else { null }
+
+    val iterOfActiveSessions = parentStateMap.getAll()
+
+    var parentSessionCount = 0
+
+    // First write the approximate size of the data to be written, so that readObject can
+    // allocate an appropriately sized OpenHashMap.
+    outputStream.writeInt(approxSize)
+
+    while (iterOfActiveSessions.hasNext) {
+      parentSessionCount += 1
+
+      val (key, state, updateTime) = iterOfActiveSessions.next()
+      outputStream.writeObject(key)
+      outputStream.writeObject(state)
+      outputStream.writeLong(updateTime)
+
+      if (doCompaction) {
+        newParentSessionStore.deltaMap.update(
+          key, StateInfo(state, updateTime, deleted = false))
+      }
+    }
+
+    // Write the final limit marking object with the correct count of records written.
+    val limiterObj = new LimitMarker(parentSessionCount)
+    outputStream.writeObject(limiterObj)
+    if (doCompaction) {
+      parentStateMap = newParentSessionStore
+    }
+  }
+
+  /** Deserialize the map data. */
+  private def readObject(inputStream: ObjectInputStream): Unit = {
+
+    // Read the non-transient fields, especially class tags, etc.
+    inputStream.defaultReadObject()
+
+    // Read the data of the delta
+    val deltaMapSize = inputStream.readInt()
+    deltaMap = new OpenHashMap[K, StateInfo[S]]()
+    var deltaMapCount = 0
+    while (deltaMapCount < deltaMapSize) {
+      val key = inputStream.readObject().asInstanceOf[K]
+      val sessionInfo = inputStream.readObject().asInstanceOf[StateInfo[S]]
+      deltaMap.update(key, sessionInfo)
+      deltaMapCount += 1
+    }
+
+    // Read the data of the parent map. Keep reading records until the limiter is reached.
+    // First read the approximate number of records to expect and allocate a properly sized
+    // OpenHashMap
+    val parentSessionStoreSizeHint = inputStream.readInt()
+    val newParentSessionStore = new OpenHashMapBasedStateMap[K, S](
+      initialCapacity = parentSessionStoreSizeHint, deltaChainThreshold)
+
+    // Read the records until the limit marking object has been reached
+    var parentSessionLoopDone = false
+    while (!parentSessionLoopDone) {
+      val obj = inputStream.readObject()
+      if (obj.isInstanceOf[LimitMarker]) {
+        parentSessionLoopDone = true
+        val expectedCount = obj.asInstanceOf[LimitMarker].num
+        assert(expectedCount == newParentSessionStore.deltaMap.size)
+      } else {
+        val key = obj.asInstanceOf[K]
+        val state = inputStream.readObject().asInstanceOf[S]
+        val updateTime = inputStream.readLong()
+        newParentSessionStore.deltaMap.update(
+          key, StateInfo(state, updateTime, deleted = false))
+      }
+    }
+    parentStateMap = newParentSessionStore
+  }
+}
+
+/**
+ * Companion object of [[OpenHashMapBasedStateMap]] having associated helper
+ * classes and methods
+ */
+private[streaming] object OpenHashMapBasedStateMap {
+
+  /** Internal class to represent the state information */
+  case class StateInfo[S](
+      var data: S = null.asInstanceOf[S],
+      var updateTime: Long = -1,
+      var deleted: Boolean = false) {
+
+    def markDeleted(): Unit = {
+      deleted = true
+    }
+
+    def update(newData: S, newUpdateTime: Long): Unit = {
+      data = newData
+      updateTime = newUpdateTime
+      deleted = false
+    }
+  }
+
+  /**
+   * Internal class to represent a marker that demarcates the end of all state data in the
+   * serialized bytes.
+   */
+  class LimitMarker(val num: Int) extends Serializable
+
+  val DELTA_CHAIN_LENGTH_THRESHOLD = 20
+}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..48d3b41b66cbfec8c24854ca08d30d067530fa00
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
@@ -0,0 +1,314 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+import scala.collection.{immutable, mutable, Map}
+import scala.util.Random
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.streaming.util.{EmptyStateMap, OpenHashMapBasedStateMap, StateMap}
+import org.apache.spark.util.Utils
+
+class StateMapSuite extends SparkFunSuite {
+
+  test("EmptyStateMap") {
+    val map = new EmptyStateMap[Int, Int]
+    intercept[scala.NotImplementedError] {
+      map.put(1, 1, 1)
+    }
+    assert(map.get(1) === None)
+    assert(map.getByTime(10000).isEmpty)
+    assert(map.getAll().isEmpty)
+    map.remove(1)   // no exception
+    assert(map.copy().eq(map))
+  }
+
+  test("OpenHashMapBasedStateMap - put, get, getByTime, getAll, remove") {
+    val map = new OpenHashMapBasedStateMap[Int, Int]()
+
+    map.put(1, 100, 10)
+    assert(map.get(1) === Some(100))
+    assert(map.get(2) === None)
+    assert(map.getByTime(11).toSet === Set((1, 100, 10)))
+    assert(map.getByTime(10).toSet === Set.empty)
+    assert(map.getByTime(9).toSet === Set.empty)
+    assert(map.getAll().toSet === Set((1, 100, 10)))
+
+    map.put(2, 200, 20)
+    assert(map.getByTime(21).toSet === Set((1, 100, 10), (2, 200, 20)))
+    assert(map.getByTime(11).toSet === Set((1, 100, 10)))
+    assert(map.getByTime(10).toSet === Set.empty)
+    assert(map.getByTime(9).toSet === Set.empty)
+    assert(map.getAll().toSet === Set((1, 100, 10), (2, 200, 20)))
+
+    map.remove(1)
+    assert(map.get(1) === None)
+    assert(map.getAll().toSet === Set((2, 200, 20)))
+  }
+
+  test("OpenHashMapBasedStateMap - put, get, getByTime, getAll, remove with copy") {
+    val parentMap = new OpenHashMapBasedStateMap[Int, Int]()
+    parentMap.put(1, 100, 1)
+    parentMap.put(2, 200, 2)
+    parentMap.remove(1)
+
+    // Create child map and make changes
+    val map = parentMap.copy()
+    assert(map.get(1) === None)
+    assert(map.get(2) === Some(200))
+    assert(map.getByTime(10).toSet === Set((2, 200, 2)))
+    assert(map.getByTime(2).toSet === Set.empty)
+    assert(map.getAll().toSet === Set((2, 200, 2)))
+
+    // Add new items
+    map.put(3, 300, 3)
+    assert(map.get(3) === Some(300))
+    map.put(4, 400, 4)
+    assert(map.get(4) === Some(400))
+    assert(map.getByTime(10).toSet === Set((2, 200, 2), (3, 300, 3), (4, 400, 4)))
+    assert(map.getByTime(4).toSet === Set((2, 200, 2), (3, 300, 3)))
+    assert(map.getAll().toSet === Set((2, 200, 2), (3, 300, 3), (4, 400, 4)))
+    assert(parentMap.getAll().toSet === Set((2, 200, 2)))
+
+    // Remove items
+    map.remove(4)
+    assert(map.get(4) === None)       // item added in this map, then removed in this map
+    map.remove(2)
+    assert(map.get(2) === None)       // item removed in parent map, then added in this map
+    assert(map.getAll().toSet === Set((3, 300, 3)))
+    assert(parentMap.getAll().toSet === Set((2, 200, 2)))
+
+    // Update items
+    map.put(1, 1000, 100)
+    assert(map.get(1) === Some(1000)) // item removed in parent map, then added in this map
+    map.put(2, 2000, 200)
+    assert(map.get(2) === Some(2000)) // item added in parent map, then removed + added in this map
+    map.put(3, 3000, 300)
+    assert(map.get(3) === Some(3000)) // item added + updated in this map
+    map.put(4, 4000, 400)
+    assert(map.get(4) === Some(4000)) // item removed + updated in this map
+
+    assert(map.getAll().toSet ===
+      Set((1, 1000, 100), (2, 2000, 200), (3, 3000, 300), (4, 4000, 400)))
+    assert(parentMap.getAll().toSet === Set((2, 200, 2)))
+
+    map.remove(2)         // remove item present in parent map, so that it's not visible in child map
+
+    // Create child map and see availability of items
+    val childMap = map.copy()
+    assert(childMap.getAll().toSet === map.getAll().toSet)
+    assert(childMap.get(1) === Some(1000))  // item removed in grandparent, but added in parent map
+    assert(childMap.get(2) === None)        // item added in grandparent, but removed in parent map
+    assert(childMap.get(3) === Some(3000))  // item added and updated in parent map
+
+    childMap.put(2, 20000, 200)
+    assert(childMap.get(2) === Some(20000)) // item removed in parent map, then added in this map
+  }
+
+  test("OpenHashMapBasedStateMap - serializing and deserializing") {
+    val map1 = new OpenHashMapBasedStateMap[Int, Int]()
+    map1.put(1, 100, 1)
+    map1.put(2, 200, 2)
+
+    val map2 = map1.copy()
+    map2.put(3, 300, 3)
+    map2.put(4, 400, 4)
+
+    val map3 = map2.copy()
+    map3.put(3, 600, 3)
+    map3.remove(2)
+
+    // Do not test compaction
+    assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false)
+
+    val deser_map3 = Utils.deserialize[StateMap[Int, Int]](
+      Utils.serialize(map3), Thread.currentThread().getContextClassLoader)
+    assertMap(deser_map3, map3, 1, "Deserialized map not same as original map")
+  }
+
+  test("OpenHashMapBasedStateMap - serializing and deserializing with compaction") {
+    val targetDeltaLength = 10
+    val deltaChainThreshold = 5
+
+    var map = new OpenHashMapBasedStateMap[Int, Int](
+      deltaChainThreshold = deltaChainThreshold)
+
+    // Make large delta chain with length more than deltaChainThreshold
+    for(i <- 1 to targetDeltaLength) {
+      map.put(Random.nextInt(), Random.nextInt(), 1)
+      map = map.copy().asInstanceOf[OpenHashMapBasedStateMap[Int, Int]]
+    }
+    assert(map.deltaChainLength > deltaChainThreshold)
+    assert(map.shouldCompact === true)
+
+    val deser_map = Utils.deserialize[OpenHashMapBasedStateMap[Int, Int]](
+      Utils.serialize(map), Thread.currentThread().getContextClassLoader)
+    assert(deser_map.deltaChainLength < deltaChainThreshold)
+    assert(deser_map.shouldCompact === false)
+    assertMap(deser_map, map, 1, "Deserialized + compacted map not same as original map")
+  }
+
+  test("OpenHashMapBasedStateMap - all possible sequences of operations with copies ") {
+    /*
+     * This tests the map using all permutations of sequences of operations, across multiple map
+     * copies as well as between copies. It is to ensure complete coverage, though it is
+     * kind of hard to debug this. It is set up as follows.
+     *
+     * - For any key, there can be 2 types of update ops on a state map - put or remove
+     *
+     * - These operations are done on a test map in "sets". After each set, the map is "copied"
+     *   to create a new map, and the next set of operations are done on the new one. This tests
+     *   whether the map data persists correctly across copies.
+     *
+     * - Within each set, there are a number of operations to test whether the map correctly
+     *   updates and removes data without affecting the parent state map.
+     *
+     * - Overall this creates (numSets * numOpsPerSet) operations, each of which can be one of
+     *   2 types of operations. This leads to a total of [2 ^ (numSets * numOpsPerSet)] different
+     *   sequences of operations, which we will test with different keys.
+     *
+     * Example: With numSets = 2 and numOpsPerSet = 2, numTotalOps = 4. This means that
+     * 2 ^ 4 = 16 possible permutations need to be tested using 16 keys.
+     * _______________________________________________
+     * |         |      Set1       |     Set2        |
+     * |         |-----------------|-----------------|
+     * |         |   Op1    Op2   |c|   Op3    Op4   |
+     * |---------|----------------|o|----------------|
+     * | key 0   |   put    put   |p|   put    put   |
+     * | key 1   |   put    put   |y|   put    rem   |
+     * | key 2   |   put    put   | |   rem    put   |
+     * | key 3   |   put    put   |t|   rem    rem   |
+     * | key 4   |   put    rem   |h|   put    put   |
+     * | key 5   |   put    rem   |e|   put    rem   |
+     * | key 6   |   put    rem   | |   rem    put   |
+     * | key 7   |   put    rem   |s|   rem    rem   |
+     * | key 8   |   rem    put   |t|   put    put   |
+     * | key 9   |   rem    put   |a|   put    rem   |
+     * | key 10  |   rem    put   |t|   rem    put   |
+     * | key 11  |   rem    put   |e|   rem    rem   |
+     * | key 12  |   rem    rem   | |   put    put   |
+     * | key 13  |   rem    rem   |m|   put    rem   |
+     * | key 14  |   rem    rem   |a|   rem    put   |
+     * | key 15  |   rem    rem   |p|   rem    rem   |
+     * |_________|________________|_|________________|
+     */
+
+    val numTypeMapOps = 2   // 0 = put a new value, 1 = remove value
+    val numSets = 3
+    val numOpsPerSet = 3    // to test seq of ops like update -> remove -> update in same set
+    val numTotalOps = numOpsPerSet * numSets
+    val numKeys = math.pow(numTypeMapOps, numTotalOps).toInt  // to get all combinations of ops
+
+    val refMap = new mutable.HashMap[Int, (Int, Long)]()
+    var prevSetRefMap: immutable.Map[Int, (Int, Long)] = null
+
+    var stateMap: StateMap[Int, Int] = new OpenHashMapBasedStateMap[Int, Int]()
+    var prevSetStateMap: StateMap[Int, Int] = null
+
+    var time = 1L
+
+    for (setId <- 0 until numSets) {
+      for (opInSetId <- 0 until numOpsPerSet) {
+        val opId = setId * numOpsPerSet + opInSetId
+        for (keyId <- 0 until numKeys) {
+          time += 1
+          // Find the operation type that needs to be done
+          // This is similar to finding the nth bit value of a binary number
+          // E.g.  nth bit from the right of any binary number B is [ B / (2 ^ (n - 1)) ] % 2
+          val opCode =
+            (keyId / math.pow(numTypeMapOps, numTotalOps - opId - 1).toInt) % numTypeMapOps
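+          // E.g. keyId 5 = binary 000000101 => removes at ops 7 and 9 (1-based), puts elsewhere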
+          opCode match {
+            case 0 =>
+              val value = Random.nextInt()
+              stateMap.put(keyId, value, time)
+              refMap.put(keyId, (value, time))
+            case 1 =>
+              stateMap.remove(keyId)
+              refMap.remove(keyId)
+          }
+        }
+
+        // Test whether the current state map after all key updates is correct
+        assertMap(stateMap, refMap, time, "State map does not match reference map")
+
+        // Test whether the previous map before copy has not changed
+        if (prevSetStateMap != null && prevSetRefMap != null) {
+          assertMap(prevSetStateMap, prevSetRefMap, time,
+            "Parent state map somehow got modified, does not match corresponding reference map")
+        }
+      }
+
+      // Copy the map and remember the previous maps for future tests
+      prevSetStateMap = stateMap
+      prevSetRefMap = refMap.toMap
+      stateMap = stateMap.copy()
+
+      // Assert that the copied map has the same data
+      assertMap(stateMap, prevSetRefMap, time,
+        "State map does not match reference map after copying")
+    }
+    assertMap(stateMap, refMap.toMap, time, "Final state map does not match reference map")
+  }
+
+  // Assert that all the data and operations on a state map match those of a reference state map
+  private def assertMap(
+      mapToTest: StateMap[Int, Int],
+      refMapToTestWith: StateMap[Int, Int],
+      time: Long,
+      msg: String): Unit = {
+    withClue(msg) {
+      // Assert all the data is same as the reference map
+      assert(mapToTest.getAll().toSet === refMapToTestWith.getAll().toSet)
+
+      // Assert that get on every key returns the right value
+      for (keyId <- refMapToTestWith.getAll().map { _._1 }) {
+        assert(mapToTest.get(keyId) === refMapToTestWith.get(keyId))
+      }
+
+      // Assert that every time threshold returns the correct data
+      for (t <- 0L to (time + 1)) {
+        assert(mapToTest.getByTime(t).toSet ===  refMapToTestWith.getByTime(t).toSet)
+      }
+    }
+  }
+
+  // Assert that all the data and operations on a state map match those of a reference map
+  private def assertMap(
+      mapToTest: StateMap[Int, Int],
+      refMapToTestWith: Map[Int, (Int, Long)],
+      time: Long,
+      msg: String): Unit = {
+    withClue(msg) {
+      // Assert all the data is same as the reference map
+      assert(mapToTest.getAll().toSet ===
+        refMapToTestWith.iterator.map { x => (x._1, x._2._1, x._2._2) }.toSet)
+
+      // Assert that get on every key returns the right value
+      for (keyId <- refMapToTestWith.keys) {
+        assert(mapToTest.get(keyId) === refMapToTestWith.get(keyId).map { _._1 })
+      }
+
+      // Assert that every time threshold returns the correct data
+      for (t <- 0L to (time + 1)) {
+        val expectedRecords =
+          refMapToTestWith.iterator.filter { _._2._2 < t }.map { x => (x._1, x._2._1, x._2._2) }
+        assert(mapToTest.getByTime(t).toSet ===  expectedRecords.toSet)
+      }
+    }
+  }
+}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..e3072b44428402f87e70d2f84dcde191446ec5f8
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala
@@ -0,0 +1,494 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+import java.io.File
+
+import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
+import scala.reflect.ClassTag
+
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
+
+import org.apache.spark.streaming.dstream.{TrackStateDStream, TrackStateDStreamImpl}
+import org.apache.spark.util.{ManualClock, Utils}
+import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+
+class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter {
+
+  private var sc: SparkContext = null
+  private var ssc: StreamingContext = null
+  private var checkpointDir: File = null
+  private val batchDuration = Seconds(1)
+
+  before {
+    StreamingContext.getActive().foreach {
+      _.stop(stopSparkContext = false)
+    }
+    checkpointDir = Utils.createTempDir("checkpoint")
+
+    ssc = new StreamingContext(sc, batchDuration)
+    ssc.checkpoint(checkpointDir.toString)
+  }
+
+  after {
+    StreamingContext.getActive().foreach {
+      _.stop(stopSparkContext = false)
+    }
+  }
+
+  override def beforeAll(): Unit = {
+    val conf = new SparkConf().setMaster("local").setAppName("TrackStateByKeySuite")
+    conf.set("spark.streaming.clock", classOf[ManualClock].getName())
+    sc = new SparkContext(conf)
+  }
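+
+  // Stop the shared SparkContext created in beforeAll
+  override def afterAll(): Unit = {
+    if (sc != null) {
+      sc.stop()
+    }
+  }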
+
+  test("state - get, exists, update, remove, ") {
+    var state: StateImpl[Int] = null
+
+    def testState(
+        expectedData: Option[Int],
+        shouldBeUpdated: Boolean = false,
+        shouldBeRemoved: Boolean = false,
+        shouldBeTimingOut: Boolean = false
+      ): Unit = {
+      if (expectedData.isDefined) {
+        assert(state.exists)
+        assert(state.get() === expectedData.get)
+        assert(state.getOption() === expectedData)
+        assert(state.getOption.getOrElse(-1) === expectedData.get)
+      } else {
+        assert(!state.exists)
+        intercept[NoSuchElementException] {
+          state.get()
+        }
+        assert(state.getOption() === None)
+        assert(state.getOption.getOrElse(-1) === -1)
+      }
+
+      assert(state.isTimingOut() === shouldBeTimingOut)
+      if (shouldBeTimingOut) {
+        intercept[IllegalArgumentException] {
+          state.remove()
+        }
+        intercept[IllegalArgumentException] {
+          state.update(-1)
+        }
+      }
+
+      assert(state.isUpdated() === shouldBeUpdated)
+
+      assert(state.isRemoved() === shouldBeRemoved)
+      if (shouldBeRemoved) {
+        intercept[IllegalArgumentException] {
+          state.remove()
+        }
+        intercept[IllegalArgumentException] {
+          state.update(-1)
+        }
+      }
+    }
+
+    state = new StateImpl[Int]()
+    testState(None)
+
+    state.wrap(None)
+    testState(None)
+
+    state.wrap(Some(1))
+    testState(Some(1))
+
+    state.update(2)
+    testState(Some(2), shouldBeUpdated = true)
+
+    state = new StateImpl[Int]()
+    state.update(2)
+    testState(Some(2), shouldBeUpdated = true)
+
+    state.remove()
+    testState(None, shouldBeRemoved = true)
+
+    state.wrapTiminoutState(3)
+    testState(Some(3), shouldBeTimingOut = true)
+  }
+
+  test("trackStateByKey - basic operations with simple API") {
+    val inputData =
+      Seq(
+        Seq(),
+        Seq("a"),
+        Seq("a", "b"),
+        Seq("a", "b", "c"),
+        Seq("a", "b"),
+        Seq("a"),
+        Seq()
+      )
+
+    val outputData =
+      Seq(
+        Seq(),
+        Seq(1),
+        Seq(2, 1),
+        Seq(3, 2, 1),
+        Seq(4, 3),
+        Seq(5),
+        Seq()
+      )
+
+    val stateData =
+      Seq(
+        Seq(),
+        Seq(("a", 1)),
+        Seq(("a", 2), ("b", 1)),
+        Seq(("a", 3), ("b", 2), ("c", 1)),
+        Seq(("a", 4), ("b", 3), ("c", 1)),
+        Seq(("a", 5), ("b", 3), ("c", 1)),
+        Seq(("a", 5), ("b", 3), ("c", 1))
+      )
+
+    // state maintains running count, and updated count is returned
+    val trackStateFunc = (value: Option[Int], state: State[Int]) => {
+      val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
+      state.update(sum)
+      sum
+    }
+
+    testOperation[String, Int, Int](
+      inputData, StateSpec.function(trackStateFunc), outputData, stateData)
+  }
+
+  test("trackStateByKey - basic operations with advanced API") {
+    val inputData =
+      Seq(
+        Seq(),
+        Seq("a"),
+        Seq("a", "b"),
+        Seq("a", "b", "c"),
+        Seq("a", "b"),
+        Seq("a"),
+        Seq()
+      )
+
+    val outputData =
+      Seq(
+        Seq(),
+        Seq("aa"),
+        Seq("aa", "bb"),
+        Seq("aa", "bb", "cc"),
+        Seq("aa", "bb"),
+        Seq("aa"),
+        Seq()
+      )
+
+    val stateData =
+      Seq(
+        Seq(),
+        Seq(("a", 1)),
+        Seq(("a", 2), ("b", 1)),
+        Seq(("a", 3), ("b", 2), ("c", 1)),
+        Seq(("a", 4), ("b", 3), ("c", 1)),
+        Seq(("a", 5), ("b", 3), ("c", 1)),
+        Seq(("a", 5), ("b", 3), ("c", 1))
+      )
+
+    // state maintains running count, and the doubled key string is returned
+    val trackStateFunc = (batchTime: Time, key: String, value: Option[Int], state: State[Int]) => {
+      val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
+      state.update(sum)
+      Some(key * 2)
+    }
+
+    testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData)
+  }
+
+  test("trackStateByKey - type inferencing and class tags") {
+
+    // Simple track state function with value as Int, state as Double and emitted type as Long
+    val simpleFunc = (value: Option[Int], state: State[Double]) => {
+      0L
+    }
+
+    // Advanced track state function with key as String, value as Int, state as Double and
+    // emitted type as Long
+    val advancedFunc = (time: Time, key: String, value: Option[Int], state: State[Double]) => {
+      Some(0L)
+    }
+
+    def testTypes(dstream: TrackStateDStream[_, _, _, _]): Unit = {
+      val dstreamImpl = dstream.asInstanceOf[TrackStateDStreamImpl[_, _, _, _]]
+      assert(dstreamImpl.keyClass === classOf[String])
+      assert(dstreamImpl.valueClass === classOf[Int])
+      assert(dstreamImpl.stateClass === classOf[Double])
+      assert(dstreamImpl.emittedClass === classOf[Long])
+    }
+
+    val inputStream = new TestInputStream[(String, Int)](ssc, Seq.empty, numPartitions = 2)
+
+    // Defining StateSpec inline with trackStateByKey and simple function implicitly gets the types
+    val simpleFunctionStateStream1 = inputStream.trackStateByKey(
+      StateSpec.function(simpleFunc).numPartitions(1))
+    testTypes(simpleFunctionStateStream1)
+
+    // Separately defining StateSpec with simple function requires explicitly specifying types
+    val simpleFuncSpec = StateSpec.function[String, Int, Double, Long](simpleFunc)
+    val simpleFunctionStateStream2 = inputStream.trackStateByKey(simpleFuncSpec)
+    testTypes(simpleFunctionStateStream2)
+
+    // Separately defining StateSpec with advanced function implicitly gets the types
+    val advFuncSpec1 = StateSpec.function(advancedFunc)
+    val advFunctionStateStream1 = inputStream.trackStateByKey(advFuncSpec1)
+    testTypes(advFunctionStateStream1)
+
+    // Defining StateSpec inline with trackStateByKey and advanced func implicitly gets the types
+    val advFunctionStateStream2 = inputStream.trackStateByKey(
+      StateSpec.function(advancedFunc).numPartitions(1))
+    testTypes(advFunctionStateStream2)
+
+    // Separately defining StateSpec and calling trackStateByKey with explicit types
+    val advFuncSpec2 = StateSpec.function[String, Int, Double, Long](advancedFunc)
+    val advFunctionStateStream3 = inputStream.trackStateByKey[Double, Long](advFuncSpec2)
+    testTypes(advFunctionStateStream3)
+  }
+
+  test("trackStateByKey - states as emitted records") {
+    val inputData =
+      Seq(
+        Seq(),
+        Seq("a"),
+        Seq("a", "b"),
+        Seq("a", "b", "c"),
+        Seq("a", "b"),
+        Seq("a"),
+        Seq()
+      )
+
+    val outputData =
+      Seq(
+        Seq(),
+        Seq(("a", 1)),
+        Seq(("a", 2), ("b", 1)),
+        Seq(("a", 3), ("b", 2), ("c", 1)),
+        Seq(("a", 4), ("b", 3)),
+        Seq(("a", 5)),
+        Seq()
+      )
+
+    val stateData =
+      Seq(
+        Seq(),
+        Seq(("a", 1)),
+        Seq(("a", 2), ("b", 1)),
+        Seq(("a", 3), ("b", 2), ("c", 1)),
+        Seq(("a", 4), ("b", 3), ("c", 1)),
+        Seq(("a", 5), ("b", 3), ("c", 1)),
+        Seq(("a", 5), ("b", 3), ("c", 1))
+      )
+
+    val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
+      val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
+      val output = (key, sum)
+      state.update(sum)
+      Some(output)
+    }
+
+    testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData)
+  }
+
+  test("trackStateByKey - initial states, with nothing emitted") {
+
+    val initialState = Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0))
+
+    val inputData =
+      Seq(
+        Seq(),
+        Seq("a"),
+        Seq("a", "b"),
+        Seq("a", "b", "c"),
+        Seq("a", "b"),
+        Seq("a"),
+        Seq()
+      )
+
+    val outputData = Seq.fill(inputData.size)(Seq.empty[Int])
+
+    val stateData =
+      Seq(
+        Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0)),
+        Seq(("a", 6), ("b", 10), ("c", -20), ("d", 0)),
+        Seq(("a", 7), ("b", 11), ("c", -20), ("d", 0)),
+        Seq(("a", 8), ("b", 12), ("c", -19), ("d", 0)),
+        Seq(("a", 9), ("b", 13), ("c", -19), ("d", 0)),
+        Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0)),
+        Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0))
+      )
+
+    val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
+      val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
+      val output = (key, sum)
+      state.update(sum)
+      None.asInstanceOf[Option[Int]]
+    }
+
+    val trackStateSpec = StateSpec.function(trackStateFunc).initialState(sc.makeRDD(initialState))
+    testOperation(inputData, trackStateSpec, outputData, stateData)
+  }
+
+  test("trackStateByKey - state removing") {
+    val inputData =
+      Seq(
+        Seq(),
+        Seq("a"),
+        Seq("a", "b"), // a will be removed
+        Seq("a", "b", "c"), // b will be removed
+        Seq("a", "b", "c"), // a and c will be removed
+        Seq("a", "b"), // b will be removed
+        Seq("a"), // a will be removed
+        Seq()
+      )
+
+    // States that were removed
+    val outputData =
+      Seq(
+        Seq(),
+        Seq(),
+        Seq("a"),
+        Seq("b"),
+        Seq("a", "c"),
+        Seq("b"),
+        Seq("a"),
+        Seq()
+      )
+
+    val stateData =
+      Seq(
+        Seq(),
+        Seq(("a", 1)),
+        Seq(("b", 1)),
+        Seq(("a", 1), ("c", 1)),
+        Seq(("b", 1)),
+        Seq(("a", 1)),
+        Seq(),
+        Seq()
+      )
+
+    val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
+      if (state.exists) {
+        state.remove()
+        Some(key)
+      } else {
+        state.update(value.get)
+        None
+      }
+    }
+
+    testOperation(
+      inputData, StateSpec.function(trackStateFunc).numPartitions(1), outputData, stateData)
+  }
+
+  test("trackStateByKey - state timing out") {
+    val inputData =
+      Seq(
+        Seq("a", "b", "c"),
+        Seq("a", "b"),
+        Seq("a"),
+        Seq(), // c will time out
+        Seq(), // b will time out
+        Seq("a") // a will not time out
+      ) ++ Seq.fill(20)(Seq("a")) // a will continue to stay active
+
+    val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
+      if (value.isDefined) {
+        state.update(1)
+      }
+      if (state.isTimingOut) {
+        Some(key)
+      } else {
+        None
+      }
+    }
+
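+    // Keys that go without updates for longer than the 3 second timeout should be passed to the
+    // tracking function with isTimingOut = true and then have their state removed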
+    val (collectedOutputs, collectedStateSnapshots) = getOperationOutput(
+      inputData, StateSpec.function(trackStateFunc).timeout(Seconds(3)), 20)
+
+    // b and c should be emitted once each, when they were marked as expired
+    assert(collectedOutputs.flatten.sorted === Seq("b", "c"))
+
+    // States for a, b, c should be defined at one point of time
+    assert(collectedStateSnapshots.exists {
+      _.toSet == Set(("a", 1), ("b", 1), ("c", 1))
+    })
+
+    // Finally state should be defined only for a
+    assert(collectedStateSnapshots.last.toSet === Set(("a", 1)))
+  }
+
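+  /** Run trackStateByKey on the input and verify the emitted records and state snapshots */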
+  private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag](
+      input: Seq[Seq[K]],
+      trackStateSpec: StateSpec[K, Int, S, T],
+      expectedOutputs: Seq[Seq[T]],
+      expectedStateSnapshots: Seq[Seq[(K, S)]]
+    ): Unit = {
+    require(expectedOutputs.size == expectedStateSnapshots.size)
+
+    val (collectedOutputs, collectedStateSnapshots) =
+      getOperationOutput(input, trackStateSpec, expectedOutputs.size)
+    assert(expectedOutputs, collectedOutputs, "outputs")
+    assert(expectedStateSnapshots, collectedStateSnapshots, "state snapshots")
+  }
+
+  private def getOperationOutput[K: ClassTag, S: ClassTag, T: ClassTag](
+      input: Seq[Seq[K]],
+      trackStateSpec: StateSpec[K, Int, S, T],
+      numBatches: Int
+    ): (Seq[Seq[T]], Seq[Seq[(K, S)]]) = {
+
+    // Setup the stream computation
+    val inputStream = new TestInputStream(ssc, input, numPartitions = 2)
+    val trackStateStream = inputStream.map(x => (x, 1)).trackStateByKey(trackStateSpec)
+    val collectedOutputs = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]]
+    val outputStream = new TestOutputStream(trackStateStream, collectedOutputs)
+    val collectedStateSnapshots = new ArrayBuffer[Seq[(K, S)]] with SynchronizedBuffer[Seq[(K, S)]]
+    val stateSnapshotStream = new TestOutputStream(
+      trackStateStream.stateSnapshots(), collectedStateSnapshots)
+    outputStream.register()
+    stateSnapshotStream.register()
+
+    val batchCounter = new BatchCounter(ssc)
+    ssc.start()
+
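+    // Advance the manual clock so that all numBatches batches are generated and processed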
+    val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+    clock.advance(batchDuration.milliseconds * numBatches)
+
+    batchCounter.waitUntilBatchesCompleted(numBatches, 10000)
+    (collectedOutputs, collectedStateSnapshots)
+  }
+
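+  /** Assert that the collected per-batch data matches the expected data, ignoring order */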
+  private def assert[U](expected: Seq[Seq[U]], collected: Seq[Seq[U]], typ: String) {
+    val debugString = "\nExpected:\n" + expected.mkString("\n") +
+      "\nCollected:\n" + collected.mkString("\n")
+    assert(expected.size === collected.size,
+      s"number of collected $typ (${collected.size}) different from expected (${expected.size})" +
+        debugString)
+    expected.zip(collected).foreach { case (e, c) =>
+      assert(c.toSet === e.toSet,
+        s"collected $typ is different from expected $debugString"
+      )
+    }
+  }
+}
+
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..fc5f26607ef98f74e4cdfadbbdda8922580d4093
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.rdd
+
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.{State, Time}
+import org.apache.spark.{HashPartitioner, SparkConf, SparkContext, SparkFunSuite}
+
+class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
+
+  private val sc = new SparkContext(
+    new SparkConf().setMaster("local").setAppName("TrackStateRDDSuite"))
+
+  override def afterAll(): Unit = {
+    sc.stop()
+  }
+
+  test("creation from pair RDD") {
+    val data = Seq((1, "1"), (2, "2"), (3, "3"))
+    val partitioner = new HashPartitioner(10)
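+    // Each pair should become that key's state, stamped with the given time, with nothing emitted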
+    val rdd = TrackStateRDD.createFromPairRDD[Int, Int, String, Int](
+      sc.parallelize(data), partitioner, Time(123))
+    assertRDD[Int, Int, String, Int](rdd, data.map { x => (x._1, x._2, 123)}.toSet, Set.empty)
+    assert(rdd.partitions.size === partitioner.numPartitions)
+
+    assert(rdd.partitioner === Some(partitioner))
+  }
+
+  test("states generated by TrackStateRDD") {
+    val initStates = Seq(("k1", 0), ("k2", 0))
+    val initTime = 123
+    val initStateWithTime = initStates.map { x => (x._1, x._2, initTime) }.toSet
+    val partitioner = new HashPartitioner(2)
+    val initStateRDD = TrackStateRDD.createFromPairRDD[String, Int, Int, Int](
+      sc.parallelize(initStates), partitioner, Time(initTime)).persist()
+    assertRDD(initStateRDD, initStateWithTime, Set.empty)
+
+    val updateTime = 345
+
+    /**
+     * Test that the test state RDD, when operated with new data,
+     * creates a new state RDD with expected states
+     */
+    def testStateUpdates(
+        testStateRDD: TrackStateRDD[String, Int, Int, Int],
+        testData: Seq[(String, Int)],
+        expectedStates: Set[(String, Int, Int)]): TrackStateRDD[String, Int, Int, Int] = {
+
+      // Persist the test TrackStateRDD so that it is not recomputed during the next operation.
+      // This is to make sure that we only track which state keys are being touched in the next op.
+      testStateRDD.persist().count()
+
+      // To track which keys are being touched
+      TrackStateRDDSuite.touchedStateKeys.clear()
+
+      val trackingFunc = (time: Time, key: String, data: Option[Int], state: State[Int]) => {
+
+        // Track the key that has been touched
+        TrackStateRDDSuite.touchedStateKeys += key
+
+        // If the data is 0, do not do anything with the state
+        // else if the data is 1, increment the state if it exists, or set new state to 0
+        // else if the data is 2, remove the state if it exists
+        data match {
+          case Some(1) =>
+            if (state.exists()) { state.update(state.get + 1) }
+            else state.update(0)
+          case Some(2) =>
+            state.remove()
+          case _ =>
+        }
+        None.asInstanceOf[Option[Int]]  // Do not return anything, not being tested
+      }
+      val newDataRDD = sc.makeRDD(testData).partitionBy(testStateRDD.partitioner.get)
+
+      // Assert that the new state RDD has expected state data
+      val newStateRDD = assertOperation(
+        testStateRDD, newDataRDD, trackingFunc, updateTime, expectedStates, Set.empty)
+
+      // Assert that the function was called only for the keys present in the data
+      assert(TrackStateRDDSuite.touchedStateKeys.size === testData.size,
+        "More number of keys are being touched than that is expected")
+      assert(TrackStateRDDSuite.touchedStateKeys.toSet === testData.toMap.keys,
+        "Keys not in the data are being touched unexpectedly")
+
+      // Assert that the test RDD's data has not changed
+      assertRDD(initStateRDD, initStateWithTime, Set.empty)
+      newStateRDD
+    }
+
+    // Test no-op, no state should change
+    testStateUpdates(initStateRDD, Seq(), initStateWithTime)  // should not scan any state
+    testStateUpdates(
+      initStateRDD, Seq(("k1", 0)), initStateWithTime)        // should not update existing state
+    testStateUpdates(
+      initStateRDD, Seq(("k3", 0)), initStateWithTime)        // should not create new state
+
+    // Test creation of new state
+    val rdd1 = testStateUpdates(initStateRDD, Seq(("k3", 1)), // should create k3's state as 0
+      Set(("k1", 0, initTime), ("k2", 0, initTime), ("k3", 0, updateTime)))
+
+    val rdd2 = testStateUpdates(rdd1, Seq(("k4", 1)),         // should create k4's state as 0
+      Set(("k1", 0, initTime), ("k2", 0, initTime), ("k3", 0, updateTime), ("k4", 0, updateTime)))
+
+    // Test updating of state
+    val rdd3 = testStateUpdates(
+      initStateRDD, Seq(("k1", 1)),                   // should increment k1's state 0 -> 1
+      Set(("k1", 1, updateTime), ("k2", 0, initTime)))
+
+    val rdd4 = testStateUpdates(rdd3,
+      Seq(("x", 0), ("k2", 1), ("k2", 1), ("k3", 1)),  // should update k2, 0 -> 2 and create k3, 0
+      Set(("k1", 1, updateTime), ("k2", 2, updateTime), ("k3", 0, updateTime)))
+
+    val rdd5 = testStateUpdates(
+      rdd4, Seq(("k3", 1)),                           // should update k3's state 0 -> 2
+      Set(("k1", 1, updateTime), ("k2", 2, updateTime), ("k3", 1, updateTime)))
+
+    // Test removing of state
+    val rdd6 = testStateUpdates(                      // should remove k1's state
+      initStateRDD, Seq(("k1", 2)), Set(("k2", 0, initTime)))
+
+    val rdd7 = testStateUpdates(                      // should remove k2's state
+      rdd6, Seq(("k2", 2), ("k0", 2), ("k3", 1)), Set(("k3", 0, updateTime)))
+
+    val rdd8 = testStateUpdates(
+      rdd7, Seq(("k3", 2)), Set()                     //
+    )
+  }
+
+  /** Assert whether the `trackStateByKey` operation generates expected results */
+  private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
+      testStateRDD: TrackStateRDD[K, V, S, T],
+      newDataRDD: RDD[(K, V)],
+      trackStateFunc: (Time, K, Option[V], State[S]) => Option[T],
+      currentTime: Long,
+      expectedStates: Set[(K, S, Int)],
+      expectedEmittedRecords: Set[T],
+      doFullScan: Boolean = false
+    ): TrackStateRDD[K, V, S, T] = {
+
+    val partitionedNewDataRDD = if (newDataRDD.partitioner != testStateRDD.partitioner) {
+      newDataRDD.partitionBy(testStateRDD.partitioner.get)
+    } else {
+      newDataRDD
+    }
+
+    val newStateRDD = new TrackStateRDD[K, V, S, T](
+      testStateRDD, partitionedNewDataRDD, trackStateFunc, Time(currentTime), None)
+    if (doFullScan) newStateRDD.setFullScan()
+
+    // Persist to make sure that it gets computed only once and we can track precisely how many
+    // state keys the computation touches
+    newStateRDD.persist()
+    assertRDD(newStateRDD, expectedStates, expectedEmittedRecords)
+    newStateRDD
+  }
+
+  /** Assert whether the [[TrackStateRDD]] has the expected states and emitted records */
+  private def assertRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
+      trackStateRDD: TrackStateRDD[K, V, S, T],
+      expectedStates: Set[(K, S, Int)],
+      expectedEmittedRecords: Set[T]): Unit = {
+    val states = trackStateRDD.flatMap { _.stateMap.getAll() }.collect().toSet
+    val emittedRecords = trackStateRDD.flatMap { _.emittedRecords }.collect().toSet
+    assert(states === expectedStates, "states after track state operation were not as expected")
+    assert(emittedRecords === expectedEmittedRecords,
+      "emitted records after track state operation were not as expected")
+  }
+}
+
+object TrackStateRDDSuite {
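+  // In the companion object so the tracking function's closure does not capture the test suite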
+  private val touchedStateKeys = new ArrayBuffer[String]()
+}