diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
index 34287c3e00908952ce3b65718c662233409932b4..3f139ad138c88754179e02b5d36a0a30a938a254 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
@@ -59,7 +59,7 @@ private[streaming] object StateMap {
   def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = {
     val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold",
       DELTA_CHAIN_LENGTH_THRESHOLD)
-    new OpenHashMapBasedStateMap[K, S](64, deltaChainThreshold)
+    new OpenHashMapBasedStateMap[K, S](deltaChainThreshold)
   }
 }
 
@@ -79,7 +79,7 @@ private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMa
 /** Implementation of StateMap based on Spark's [[org.apache.spark.util.collection.OpenHashMap]] */
 private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
     @transient @volatile var parentStateMap: StateMap[K, S],
-    initialCapacity: Int = 64,
+    initialCapacity: Int = DEFAULT_INITIAL_CAPACITY,
     deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD
   ) extends StateMap[K, S] { self =>
 
@@ -89,12 +89,14 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
     deltaChainThreshold = deltaChainThreshold)
 
   def this(deltaChainThreshold: Int) = this(
-    initialCapacity = 64, deltaChainThreshold = deltaChainThreshold)
+    initialCapacity = DEFAULT_INITIAL_CAPACITY, deltaChainThreshold = deltaChainThreshold)
 
   def this() = this(DELTA_CHAIN_LENGTH_THRESHOLD)
 
-  @transient @volatile private var deltaMap =
-    new OpenHashMap[K, StateInfo[S]](initialCapacity)
+  require(initialCapacity >= 1, "Invalid initial capacity")
+  require(deltaChainThreshold >= 1, "Invalid delta chain threshold")
+
+  @transient @volatile private var deltaMap = new OpenHashMap[K, StateInfo[S]](initialCapacity)
 
   /** Get the session data if it exists */
   override def get(key: K): Option[S] = {
@@ -284,9 +286,10 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
     // Read the data of the parent map. Keep reading records, until the limiter is reached
     // First read the approximate number of records to expect and allocate properly size
     // OpenHashMap
-    val parentSessionStoreSizeHint = inputStream.readInt()
+    val parentStateMapSizeHint = inputStream.readInt()
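+    // The size hint may be zero for an empty parent map, so enforce a minimum initial capacity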
+    val newStateMapInitialCapacity = math.max(parentStateMapSizeHint, DEFAULT_INITIAL_CAPACITY)
     val newParentSessionStore = new OpenHashMapBasedStateMap[K, S](
-      initialCapacity = parentSessionStoreSizeHint, deltaChainThreshold)
+      initialCapacity = newStateMapInitialCapacity, deltaChainThreshold)
 
     // Read the records until the limit marking object has been reached
     var parentSessionLoopDone = false
@@ -338,4 +341,6 @@ private[streaming] object OpenHashMapBasedStateMap {
   class LimitMarker(val num: Int) extends Serializable
 
   val DELTA_CHAIN_LENGTH_THRESHOLD = 20
+
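+  /** Default initial capacity of an OpenHashMapBasedStateMap */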
+  val DEFAULT_INITIAL_CAPACITY = 64
 }
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
index 48d3b41b66cbfec8c24854ca08d30d067530fa00..c4a01eaea739ed37895c7effa8413d7d9212897f 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
@@ -122,23 +122,27 @@ class StateMapSuite extends SparkFunSuite {
 
   test("OpenHashMapBasedStateMap - serializing and deserializing") {
     val map1 = new OpenHashMapBasedStateMap[Int, Int]()
+    testSerialization(map1, "error deserializing and serialized empty map")
+
     map1.put(1, 100, 1)
     map1.put(2, 200, 2)
+    testSerialization(map1, "error deserializing and serialized map with data + no delta")
 
     val map2 = map1.copy()
+    // Make sure compaction is not triggered in this test
+    assert(map2.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false)
+    testSerialization(map2, "error deserializing and serialized map with 1 delta + no new data")
+
     map2.put(3, 300, 3)
     map2.put(4, 400, 4)
+    testSerialization(map2, "error deserializing and serialized map with 1 delta + new data")
 
     val map3 = map2.copy()
+    assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false)
+    testSerialization(map3, "error deserializing and serialized map with 2 delta + no new data")
     map3.put(3, 600, 3)
     map3.remove(2)
-
-    // Do not test compaction
-    assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false)
-
-    val deser_map3 = Utils.deserialize[StateMap[Int, Int]](
-      Utils.serialize(map3), Thread.currentThread().getContextClassLoader)
-    assertMap(deser_map3, map3, 1, "Deserialized map not same as original map")
+    testSerialization(map3, "error deserializing and serialized map with 2 delta + new data")
   }
 
   test("OpenHashMapBasedStateMap - serializing and deserializing with compaction") {
@@ -156,11 +160,9 @@ class StateMapSuite extends SparkFunSuite {
     assert(map.deltaChainLength > deltaChainThreshold)
     assert(map.shouldCompact === true)
 
-    val deser_map = Utils.deserialize[OpenHashMapBasedStateMap[Int, Int]](
-      Utils.serialize(map), Thread.currentThread().getContextClassLoader)
+    val deser_map = testSerialization(map, "Deserialized + compacted map not same as original map")
     assert(deser_map.deltaChainLength < deltaChainThreshold)
     assert(deser_map.shouldCompact === false)
-    assertMap(deser_map, map, 1, "Deserialized + compacted map not same as original map")
   }
 
   test("OpenHashMapBasedStateMap - all possible sequences of operations with copies ") {
@@ -265,6 +267,14 @@ class StateMapSuite extends SparkFunSuite {
     assertMap(stateMap, refMap.toMap, time, "Final state map does not match reference map")
   }
 
+  private def testSerialization[MapType <: StateMap[Int, Int]](
+      map: MapType, msg: String): MapType = {
+    val deserMap = Utils.deserialize[MapType](
+      Utils.serialize(map), Thread.currentThread().getContextClassLoader)
+    assertMap(deserMap, map, 1, msg)
+    deserMap
+  }
+
   // Assert whether all the data and operations on a state map matches that of a reference state map
   private def assertMap(
       mapToTest: StateMap[Int, Int],
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
index 0feb3af1abb0fec7ee25b454ba0b0eab3e613a99..3b2d43f2ce5816c890aa4bae9c56d521869f9ac9 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
@@ -332,6 +332,16 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef
       makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _)
   }
 
+  test("checkpointing empty state RDD") {
+    val emptyStateRDD = TrackStateRDD.createFromPairRDD[Int, Int, Int, Int](
+      sc.emptyRDD[(Int, Int)], new HashPartitioner(10), Time(0))
+    emptyStateRDD.checkpoint()
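+    // Force evaluation (which also writes the checkpoint) and verify the RDD has no state data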
+    assert(emptyStateRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty)
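+    // Read the records back from the checkpoint file and verify that they also have no state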
+    val cpRDD = sc.checkpointFile[TrackStateRDDRecord[Int, Int, Int]](
+      emptyStateRDD.getCheckpointFile.get)
+    assert(cpRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty)
+  }
+
   /** Assert whether the `trackStateByKey` operation generates expected results */
   private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
       testStateRDD: TrackStateRDD[K, V, S, T],