diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index 34287c3e00908952ce3b65718c662233409932b4..3f139ad138c88754179e02b5d36a0a30a938a254 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -59,7 +59,7 @@ private[streaming] object StateMap { def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = { val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold", DELTA_CHAIN_LENGTH_THRESHOLD) - new OpenHashMapBasedStateMap[K, S](64, deltaChainThreshold) + new OpenHashMapBasedStateMap[K, S](deltaChainThreshold) } } @@ -79,7 +79,7 @@ private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMa /** Implementation of StateMap based on Spark's [[org.apache.spark.util.collection.OpenHashMap]] */ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( @transient @volatile var parentStateMap: StateMap[K, S], - initialCapacity: Int = 64, + initialCapacity: Int = DEFAULT_INITIAL_CAPACITY, deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD ) extends StateMap[K, S] { self => @@ -89,12 +89,14 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( deltaChainThreshold = deltaChainThreshold) def this(deltaChainThreshold: Int) = this( - initialCapacity = 64, deltaChainThreshold = deltaChainThreshold) + initialCapacity = DEFAULT_INITIAL_CAPACITY, deltaChainThreshold = deltaChainThreshold) def this() = this(DELTA_CHAIN_LENGTH_THRESHOLD) - @transient @volatile private var deltaMap = - new OpenHashMap[K, StateInfo[S]](initialCapacity) + require(initialCapacity >= 1, "Invalid initial capacity") + require(deltaChainThreshold >= 1, "Invalid delta chain threshold") + + @transient @volatile private var deltaMap = new OpenHashMap[K, StateInfo[S]](initialCapacity) /** Get the session data if it exists */ override def get(key: K): Option[S] = { @@ -284,9 +286,10 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( // Read the data of the parent map. Keep reading records, until the limiter is reached // First read the approximate number of records to expect and allocate properly size // OpenHashMap - val parentSessionStoreSizeHint = inputStream.readInt() + val parentStateMapSizeHint = inputStream.readInt() + val newStateMapInitialCapacity = math.max(parentStateMapSizeHint, DEFAULT_INITIAL_CAPACITY) val newParentSessionStore = new OpenHashMapBasedStateMap[K, S]( - initialCapacity = parentSessionStoreSizeHint, deltaChainThreshold) + initialCapacity = newStateMapInitialCapacity, deltaChainThreshold) // Read the records until the limit marking object has been reached var parentSessionLoopDone = false @@ -338,4 +341,6 @@ private[streaming] object OpenHashMapBasedStateMap { class LimitMarker(val num: Int) extends Serializable val DELTA_CHAIN_LENGTH_THRESHOLD = 20 + + val DEFAULT_INITIAL_CAPACITY = 64 } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala index 48d3b41b66cbfec8c24854ca08d30d067530fa00..c4a01eaea739ed37895c7effa8413d7d9212897f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -122,23 +122,27 @@ class StateMapSuite extends SparkFunSuite { test("OpenHashMapBasedStateMap - serializing and deserializing") { val map1 = new OpenHashMapBasedStateMap[Int, Int]() + testSerialization(map1, "error deserializing and serialized empty map") + map1.put(1, 100, 1) map1.put(2, 200, 2) + testSerialization(map1, "error deserializing and serialized map with data + no delta") val map2 = map1.copy() + // Do not test compaction + assert(map2.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) + testSerialization(map2, "error deserializing and serialized map with 1 delta + no new data") + map2.put(3, 300, 3) map2.put(4, 400, 4) + testSerialization(map2, "error deserializing and serialized map with 1 delta + new data") val map3 = map2.copy() + assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) + testSerialization(map3, "error deserializing and serialized map with 2 delta + no new data") map3.put(3, 600, 3) map3.remove(2) - - // Do not test compaction - assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) - - val deser_map3 = Utils.deserialize[StateMap[Int, Int]]( - Utils.serialize(map3), Thread.currentThread().getContextClassLoader) - assertMap(deser_map3, map3, 1, "Deserialized map not same as original map") + testSerialization(map3, "error deserializing and serialized map with 2 delta + new data") } test("OpenHashMapBasedStateMap - serializing and deserializing with compaction") { @@ -156,11 +160,9 @@ class StateMapSuite extends SparkFunSuite { assert(map.deltaChainLength > deltaChainThreshold) assert(map.shouldCompact === true) - val deser_map = Utils.deserialize[OpenHashMapBasedStateMap[Int, Int]]( - Utils.serialize(map), Thread.currentThread().getContextClassLoader) + val deser_map = testSerialization(map, "Deserialized + compacted map not same as original map") assert(deser_map.deltaChainLength < deltaChainThreshold) assert(deser_map.shouldCompact === false) - assertMap(deser_map, map, 1, "Deserialized + compacted map not same as original map") } test("OpenHashMapBasedStateMap - all possible sequences of operations with copies ") { @@ -265,6 +267,14 @@ class StateMapSuite extends SparkFunSuite { assertMap(stateMap, refMap.toMap, time, "Final state map does not match reference map") } + private def testSerialization[MapType <: StateMap[Int, Int]]( + map: MapType, msg: String): MapType = { + val deserMap = Utils.deserialize[MapType]( + Utils.serialize(map), Thread.currentThread().getContextClassLoader) + assertMap(deserMap, map, 1, msg) + deserMap + } + // Assert whether all the data and operations on a state map matches that of a reference state map private def assertMap( mapToTest: StateMap[Int, Int], diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala index 0feb3af1abb0fec7ee25b454ba0b0eab3e613a99..3b2d43f2ce5816c890aa4bae9c56d521869f9ac9 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala @@ -332,6 +332,16 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef makeStateRDDWithLongLineageParenttateRDD, reliableCheckpoint = true, rddCollectFunc _) } + test("checkpointing empty state RDD") { + val emptyStateRDD = TrackStateRDD.createFromPairRDD[Int, Int, Int, Int]( + sc.emptyRDD[(Int, Int)], new HashPartitioner(10), Time(0)) + emptyStateRDD.checkpoint() + assert(emptyStateRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty) + val cpRDD = sc.checkpointFile[TrackStateRDDRecord[Int, Int, Int]]( + emptyStateRDD.getCheckpointFile.get) + assert(cpRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty) + } + /** Assert whether the `trackStateByKey` operation generates expected results */ private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( testStateRDD: TrackStateRDD[K, V, S, T],