From f48742683adf8ed18b0d25a724a13c66b3fc12e9 Mon Sep 17 00:00:00 2001
From: Matei Zaharia <matei@eecs.berkeley.edu>
Date: Sun, 6 May 2012 20:14:40 -0700
Subject: [PATCH] Made caches dataset-aware so that they won't cyclically evict
 partitions from the same dataset.

---
 .../main/scala/spark/BoundedMemoryCache.scala | 51 +++++++++------
 core/src/main/scala/spark/Cache.scala         | 25 +++++---
 core/src/main/scala/spark/CacheTracker.scala  | 62 ++++++++++++-------
 .../main/scala/spark/DiskSpillingCache.scala  | 36 +++++------
 .../main/scala/spark/SerializingCache.scala   |  8 +--
 .../main/scala/spark/SoftReferenceCache.scala |  9 ++-
 .../main/scala/spark/WeakReferenceCache.scala | 14 -----
 .../spark/broadcast/BitTorrentBroadcast.scala |  8 +--
 .../spark/broadcast/ChainedBroadcast.scala    |  8 +--
 .../scala/spark/broadcast/DfsBroadcast.scala  |  6 +-
 .../scala/spark/broadcast/TreeBroadcast.scala |  8 +--
 11 files changed, 132 insertions(+), 103 deletions(-)
 delete mode 100644 core/src/main/scala/spark/WeakReferenceCache.scala

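For context on the motivation: with a purely per-key LRU cache, a dataset slightly larger than the cache evicts its own earlier partitions to admit its later ones, so every later pass over it misses on every partition. Below is a minimal, self-contained sketch of that failure mode (hypothetical demo code, not part of this patch), using an access-ordered LinkedHashMap the way BoundedMemoryCache does:

import java.util.LinkedHashMap

// Hypothetical demo of cyclic eviction under a plain per-key LRU policy.
object CyclicEvictionDemo {
  def main(args: Array[String]) {
    val maxEntries = 3
    // Access-ordered map that evicts its least recently used entry once it holds more than maxEntries
    val lru = new LinkedHashMap[(Int, Int), String](32, 0.75f, true) {
      override def removeEldestEntry(eldest: java.util.Map.Entry[(Int, Int), String]) =
        size() > maxEntries
    }
    val rddId = 0
    // First pass over an RDD with 4 partitions: caching partition 3 evicts partition 0
    for (split <- 0 until 4) lru.put((rddId, split), "data for split " + split)
    // Second pass: partition 0 misses; recomputing and re-caching it evicts partition 1,
    // which then also misses, and so on -- every pass recomputes every partition.
    println(lru.containsKey((rddId, 0)))  // prints false
  }
}

The patch below avoids this by making put refuse to evict entries of the same dataset, so the partitions that did fit stay cached across passes.
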
diff --git a/core/src/main/scala/spark/BoundedMemoryCache.scala b/core/src/main/scala/spark/BoundedMemoryCache.scala
index e8e50ac360..c49be803e4 100644
--- a/core/src/main/scala/spark/BoundedMemoryCache.scala
+++ b/core/src/main/scala/spark/BoundedMemoryCache.scala
@@ -14,14 +14,14 @@ class BoundedMemoryCache extends Cache with Logging {
   logInfo("BoundedMemoryCache.maxBytes = " + maxBytes)
 
   private var currentBytes = 0L
-  private val map = new LinkedHashMap[Any, Entry](32, 0.75f, true)
+  private val map = new LinkedHashMap[(Any, Int), Entry](32, 0.75f, true)
 
   // An entry in our map; stores a cached object and its size in bytes
   class Entry(val value: Any, val size: Long) {}
 
-  override def get(key: Any): Any = {
+  override def get(datasetId: Any, partition: Int): Any = {
     synchronized {
-      val entry = map.get(key)
+      val entry = map.get((datasetId, partition))
       if (entry != null) {
         entry.value
       } else {
@@ -30,7 +30,8 @@ class BoundedMemoryCache extends Cache with Logging {
     }
   }
 
-  override def put(key: Any, value: Any) {
+  override def put(datasetId: Any, partition: Int, value: Any): Boolean = {
+    val key = (datasetId, partition)
     logInfo("Asked to add key " + key)
     val startTime = System.currentTimeMillis
     val size = SizeEstimator.estimate(value.asInstanceOf[AnyRef])
@@ -38,11 +39,16 @@ class BoundedMemoryCache extends Cache with Logging {
     logInfo("Estimated size for key %s is %d".format(key, size))
     logInfo("Size estimation for key %s took %d ms".format(key, timeTaken))
     synchronized {
-      ensureFreeSpace(size)
-      logInfo("Adding key " + key)
-      map.put(key, new Entry(value, size))
-      currentBytes += size
-      logInfo("Number of entries is now " + map.size)
+      if (ensureFreeSpace(datasetId, size)) {
+        logInfo("Adding key " + key)
+        map.put(key, new Entry(value, size))
+        currentBytes += size
+        logInfo("Number of entries is now " + map.size)
+        return true
+      } else {
+        logInfo("Didn't add key " + key + " because we would have evicted part of same dataset")
+        return false
+      }
     }
   }
 
@@ -53,23 +59,32 @@ class BoundedMemoryCache extends Cache with Logging {
   }
 
   /**
-   * Remove least recently used entries from the map until at least space bytes are free. Assumes
+   * Remove least recently used entries from the map until at least space bytes are free, in order
+   * to make space for a partition from the given dataset ID. If this cannot be done without
+   * evicting other data from the same dataset, returns false; otherwise, returns true. Assumes
    * that a lock is held on the BoundedMemoryCache.
    */
-  private def ensureFreeSpace(space: Long) {
-    logInfo("ensureFreeSpace(%d) called with curBytes=%d, maxBytes=%d".format(
-      space, currentBytes, maxBytes))
-    val iter = map.entrySet.iterator
+  private def ensureFreeSpace(datasetId: Any, space: Long): Boolean = {
+    logInfo("ensureFreeSpace(%s, %d) called with curBytes=%d, maxBytes=%d".format(
+      datasetId, space, currentBytes, maxBytes))
+    val iter = map.entrySet.iterator   // Will give entries in LRU order
     while (maxBytes - currentBytes < space && iter.hasNext) {
       val mapEntry = iter.next()
-      dropEntry(mapEntry.getKey, mapEntry.getValue)
+      val (entryDatasetId, entryPartition) = mapEntry.getKey
+      if (entryDatasetId == datasetId) {
+        // Cannot make space without removing part of the same dataset, or a more recently used one
+        return false
+      }
+      reportEntryDropped(entryDatasetId, entryPartition, mapEntry.getValue)
       currentBytes -= mapEntry.getValue.size
       iter.remove()
     }
+    return true
   }
 
-  protected def dropEntry(key: Any, entry: Entry) {
-    logInfo("Dropping key %s of size %d to make space".format(key, entry.size))
-    SparkEnv.get.cacheTracker.dropEntry(key)
+  protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) {
+    logInfo("Dropping key (%s, %d) of size %d to make space".format(
+      datasetId, partition, entry.size))
+    SparkEnv.get.cacheTracker.dropEntry(datasetId, partition)
   }
 }
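
A self-contained distillation of the rule ensureFreeSpace now enforces, counting entries rather than estimated bytes and leaving out logging and the CacheTracker callback (hypothetical sketch, not the class above):

import java.util.LinkedHashMap

// Hypothetical, simplified version of the dataset-aware LRU policy in BoundedMemoryCache.
class DatasetAwareLru(maxEntries: Int) {
  private val map = new LinkedHashMap[(Any, Int), Any](32, 0.75f, true)  // access order = LRU order

  def get(datasetId: Any, partition: Int): Any = synchronized {
    map.get((datasetId, partition))
  }

  def put(datasetId: Any, partition: Int, value: Any): Boolean = synchronized {
    var canAdd = true
    val iter = map.entrySet.iterator  // least recently used entries first
    while (canAdd && map.size >= maxEntries && iter.hasNext) {
      val (entryDatasetId, _) = iter.next().getKey
      if (entryDatasetId == datasetId) {
        canAdd = false   // freeing space would evict part of the same dataset: refuse the put
      } else {
        iter.remove()    // evict an LRU entry belonging to some other dataset
      }
    }
    if (canAdd) {
      map.put((datasetId, partition), value)
    }
    canAdd
  }
}

When put returns false, the caller in CacheTracker.getOrCompute simply hands back the computed array without caching it, which is the behaviour the commit message is after.
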
diff --git a/core/src/main/scala/spark/Cache.scala b/core/src/main/scala/spark/Cache.scala
index 696fff4e5e..263761bb95 100644
--- a/core/src/main/scala/spark/Cache.scala
+++ b/core/src/main/scala/spark/Cache.scala
@@ -1,10 +1,12 @@
 package spark
 
-import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.atomic.AtomicInteger
 
 /**
  * An interface for caches in Spark, to allow for multiple implementations. Caches are used to store
- * both partitions of cached RDDs and broadcast variables on Spark executors.
+ * both partitions of cached RDDs and broadcast variables on Spark executors. Caches are also aware
+ * of which entries are part of the same dataset (for example, partitions in the same RDD). The key
+ * for each value in a cache is a (datasetId, partition) pair.
  *
  * A single Cache instance gets created on each machine and is shared by all caches (i.e. both the 
  * RDD split cache and the broadcast variable cache), to enable global replacement policies. 
@@ -17,19 +19,26 @@ import java.util.concurrent.atomic.AtomicLong
  * keys that are unique across modules.
  */
 abstract class Cache {
-  private val nextKeySpaceId = new AtomicLong(0)
+  private val nextKeySpaceId = new AtomicInteger(0)
   private def newKeySpaceId() = nextKeySpaceId.getAndIncrement()
 
   def newKeySpace() = new KeySpace(this, newKeySpaceId())
 
-  def get(key: Any): Any
-  def put(key: Any, value: Any): Unit
+  // Get the value for a given (datasetId, partition), or null if it is not found.
+  def get(datasetId: Any, partition: Int): Any
+
+  // Attempt to put a value in the cache; returns false if this was not successful (e.g. because
+  // the cache replacement policy forbids it).
+  def put(datasetId: Any, partition: Int, value: Any): Boolean
 }
 
 /**
  * A key namespace in a Cache.
  */
-class KeySpace(cache: Cache, id: Long) {
-  def get(key: Any): Any = cache.get((id, key))
-  def put(key: Any, value: Any): Unit = cache.put((id, key), value)
+class KeySpace(cache: Cache, val keySpaceId: Int) {
+  def get(datasetId: Any, partition: Int): Any =
+    cache.get((keySpaceId, datasetId), partition)
+
+  def put(datasetId: Any, partition: Int, value: Any): Boolean =
+    cache.put((keySpaceId, datasetId), partition, value)
 }
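
As for what the keySpaceId plumbing buys: every module (the RDD cache, each broadcast implementation) gets its own KeySpace over the one shared Cache, and the pair (keySpaceId, datasetId) becomes the dataset id the replacement policy sees. A hedged usage sketch follows (REPL-style; it assumes the spark classes are on the classpath and glosses over SparkEnv setup):

// Illustration only: two modules sharing one cache through separate key spaces.
val sharedCache: Cache = new BoundedMemoryCache
val rddCache = sharedCache.newKeySpace()    // keySpaceId 0, used here for RDD partitions
val otherCache = sharedCache.newKeySpace()  // keySpaceId 1, for some other module

// The same (datasetId, partition) pair from different modules can never collide, because
// the shared cache actually sees ((keySpaceId, datasetId), partition):
rddCache.put(42, 3, Array(1, 2, 3))     // stored under key ((0, 42), 3)
otherCache.put(42, 3, "unrelated")      // stored under key ((1, 42), 3)
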
diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala
index 5b6eed743f..c399748af3 100644
--- a/core/src/main/scala/spark/CacheTracker.scala
+++ b/core/src/main/scala/spark/CacheTracker.scala
@@ -106,51 +106,67 @@ class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging {
   
   // Gets or computes an RDD split
   def getOrCompute[T](rdd: RDD[T], split: Split)(implicit m: ClassManifest[T]): Iterator[T] = {
-    val key = (rdd.id, split.index)
-    logInfo("CachedRDD partition key is " + key)
-    val cachedVal = cache.get(key)
+    logInfo("Looking for RDD partition %d:%d".format(rdd.id, split.index))
+    val cachedVal = cache.get(rdd.id, split.index)
     if (cachedVal != null) {
       // Split is in cache, so just return its values
       logInfo("Found partition in cache!")
       return cachedVal.asInstanceOf[Array[T]].iterator
     } else {
       // Mark the split as loading (unless someone else marks it first)
+      val key = (rdd.id, split.index)
       loading.synchronized {
-        if (loading.contains(key)) {
-          while (loading.contains(key)) {
-            try {loading.wait()} catch {case _ =>}
-          }
-          return cache.get(key).asInstanceOf[Array[T]].iterator
-        } else {
-          loading.add(key)
+        while (loading.contains(key)) {
+          // Someone else is loading it; let's wait for them
+          try { loading.wait() } catch { case _ => }
+        }
+        // See whether someone else has successfully loaded it. The main way this could still
+        // be missing is if the cache's eviction policy refused to make space when another
+        // thread loaded this same RDD partition. That case is unlikely, because two threads
+        // rarely work on the same RDD partition, but one downside of the current code is that
+        // threads wait serially if that does happen.
+        val cachedVal = cache.get(rdd.id, split.index)
+        if (cachedVal != null) {
+          return cachedVal.asInstanceOf[Array[T]].iterator
         }
+        // Nobody's loading it and it's not in the cache; let's load it ourselves
+        loading.add(key)
       }
       // If we got here, we have to load the split
       // Tell the master that we're doing so
       val host = System.getProperty("spark.hostname", Utils.localHostName)
-      val future = trackerActor !! AddedToCache(rdd.id, split.index, host)
       // TODO: fetch any remote copy of the split that may be available
-      // TODO: also register a listener for when it unloads
       logInfo("Computing partition " + split)
-      val array = rdd.compute(split).toArray(m)
-      cache.put(key, array)
-      loading.synchronized {
-        loading.remove(key)
-        loading.notifyAll()
+      var array: Array[T] = null
+      var putSuccessful: Boolean = false
+      try {
+        array = rdd.compute(split).toArray(m)
+        putSuccessful = cache.put(rdd.id, split.index, array)
+      } finally {
+        // Tell other threads that we've finished our attempt to load the key (whether or not
+        // we actually succeeded in putting it in the cache)
+        loading.synchronized {
+          loading.remove(key)
+          loading.notifyAll()
+        }
+      }
+      if (putSuccessful) {
+        // Tell the master that we added the entry. Don't return until it replies so it can
+        // properly schedule future tasks that use this RDD.
+        trackerActor !? AddedToCache(rdd.id, split.index, host)
       }
-      future.apply() // Wait for the reply from the cache tracker
       return array.iterator
     }
   }
 
-  // Reports that an entry has been dropped from the cache
-  def dropEntry(key: Any) {
-    key match {
-      case (keySpaceId: Long, (rddId: Int, partition: Int)) =>
+  // Called by the Cache to report that an entry has been dropped from it
+  def dropEntry(datasetId: Any, partition: Int) {
+    datasetId match {
+      case (cache.keySpaceId, rddId: Int) =>
         val host = System.getProperty("spark.hostname", Utils.localHostName)
         trackerActor !! DroppedFromCache(rddId, partition, host)
-      case _ =>
-        logWarning("Unknown key format: %s".format(key))
+      case _ =>
+        // Not an RDD partition (e.g. a broadcast value); nothing to report to the master
     }
   }
 
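The reworked getOrCompute is essentially a "compute each partition once, wait otherwise" pattern around the shared loading set, with a re-check after waking up because another thread's put may have been refused. A stripped-down, self-contained sketch of that pattern (hypothetical helper class, not Spark code):

import scala.collection.mutable.{HashMap, HashSet}

// Hypothetical distillation of the coordination used by CacheTracker.getOrCompute.
class ComputeOnce[K, V] {
  private val cache = new HashMap[K, V]   // stand-in for the real, possibly-refusing Cache
  private val loading = new HashSet[K]    // keys some thread is currently computing

  def getOrCompute(key: K)(compute: => V): V = {
    val cached = loading.synchronized {
      while (loading.contains(key)) {
        try { loading.wait() } catch { case _: InterruptedException => }
      }
      // Re-check after waking up: the other thread may have cached the value, or its put
      // may have been refused by the replacement policy, in which case we recompute it.
      val existing = cache.synchronized { cache.get(key) }
      if (existing.isEmpty) {
        loading.add(key)   // nobody has it and nobody is computing it, so we will
      }
      existing
    }
    cached match {
      case Some(value) => value
      case None =>
        try {
          val value = compute
          cache.synchronized { cache.put(key, value) }  // may silently "fail" in the real Cache
          value
        } finally {
          // Wake up waiters whether or not the value actually made it into the cache
          loading.synchronized {
            loading.remove(key)
            loading.notifyAll()
          }
        }
    }
  }
}

Compared with the first-cut code this replaces, the important differences are that waiters re-check the cache instead of assuming the loader succeeded, and that the loading flag is cleared in a finally block so a failed compute does not leave other threads waiting forever.
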
diff --git a/core/src/main/scala/spark/DiskSpillingCache.scala b/core/src/main/scala/spark/DiskSpillingCache.scala
index 157e071c7f..e4d0f991aa 100644
--- a/core/src/main/scala/spark/DiskSpillingCache.scala
+++ b/core/src/main/scala/spark/DiskSpillingCache.scala
@@ -9,31 +9,31 @@ import java.util.UUID
 // TODO: cache into a separate directory using Utils.createTempDir
 // TODO: clean up disk cache afterwards
 class DiskSpillingCache extends BoundedMemoryCache {
-  private val diskMap = new LinkedHashMap[Any, File](32, 0.75f, true)
+  private val diskMap = new LinkedHashMap[(Any, Int), File](32, 0.75f, true)
 
-  override def get(key: Any): Any = {
+  override def get(datasetId: Any, partition: Int): Any = {
     synchronized {
       val ser = SparkEnv.get.serializer.newInstance()
-      super.get(key) match {
+      super.get(datasetId, partition) match {
         case bytes: Any => // found in memory
           ser.deserialize(bytes.asInstanceOf[Array[Byte]])
 
-        case _ => diskMap.get(key) match {
+        case _ => diskMap.get((datasetId, partition)) match {
           case file: Any => // found on disk
             try {
               val startTime = System.currentTimeMillis
               val bytes = new Array[Byte](file.length.toInt)
               new FileInputStream(file).read(bytes)
               val timeTaken = System.currentTimeMillis - startTime
-              logInfo("Reading key %s of size %d bytes from disk took %d ms".format(
-                key, file.length, timeTaken))
-              super.put(key, bytes)
+              logInfo("Reading key (%s, %d) of size %d bytes from disk took %d ms".format(
+                datasetId, partition, file.length, timeTaken))
+              super.put(datasetId, partition, bytes)
               ser.deserialize(bytes.asInstanceOf[Array[Byte]])
             } catch {
               case e: IOException =>
-                logWarning("Failed to read key %s from disk at %s: %s".format(
-                  key, file.getPath(), e.getMessage()))
-                diskMap.remove(key) // remove dead entry
+                logWarning("Failed to read key (%s, %d) from disk at %s: %s".format(
+                  datasetId, partition, file.getPath(), e.getMessage()))
+                diskMap.remove((datasetId, partition)) // remove dead entry
                 null
             }
 
@@ -44,18 +44,18 @@ class DiskSpillingCache extends BoundedMemoryCache {
     }
   }
 
-  override def put(key: Any, value: Any) {
+  override def put(datasetId: Any, partition: Int, value: Any): Boolean = {
     var ser = SparkEnv.get.serializer.newInstance()
-    super.put(key, ser.serialize(value))
+    super.put(datasetId, partition, ser.serialize(value))
   }
 
   /**
    * Spill the given entry to disk. Assumes that a lock is held on the
    * DiskSpillingCache.  Assumes that entry.value is a byte array.
    */
-  override protected def dropEntry(key: Any, entry: Entry) {
-    logInfo("Spilling key %s of size %d to make space".format(
-      key, entry.size))
+  override protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) {
+    logInfo("Spilling key (%s, %d) of size %d to make space".format(
+      datasetId, partition, entry.size))
     val cacheDir = System.getProperty(
       "spark.diskSpillingCache.cacheDir",
       System.getProperty("java.io.tmpdir"))
@@ -64,11 +64,11 @@ class DiskSpillingCache extends BoundedMemoryCache {
       val stream = new FileOutputStream(file)
       stream.write(entry.value.asInstanceOf[Array[Byte]])
       stream.close()
-      diskMap.put(key, file)
+      diskMap.put((datasetId, partition), file)
     } catch {
       case e: IOException =>
-        logWarning("Failed to spill key %s to disk at %s: %s".format(
-          key, file.getPath(), e.getMessage()))
+        logWarning("Failed to spill key (%s, %d) to disk at %s: %s".format(
+          datasetId, partition, file.getPath(), e.getMessage()))
         // Do nothing and let the entry be discarded
     }
   }
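
The disk tier that DiskSpillingCache layers on top of the memory cache boils down to: on eviction, write the already-serialized bytes to a file recorded under the same (datasetId, partition) key; on a memory miss, read them back and re-insert them. A minimal stand-alone sketch of just that mechanism (hypothetical class; no serializer, logging, or memory tier):

import java.io.{File, FileInputStream, FileOutputStream}
import java.util.UUID

// Hypothetical distillation of the disk tier behind DiskSpillingCache.
class DiskTier(dir: File) {
  private val files = new java.util.HashMap[(Any, Int), File]

  // Called when a serialized entry is dropped from memory ("spilled")
  def spill(datasetId: Any, partition: Int, bytes: Array[Byte]) {
    val file = new File(dir, "spill-" + UUID.randomUUID.toString)
    val out = new FileOutputStream(file)
    try { out.write(bytes) } finally { out.close() }
    files.put((datasetId, partition), file)
  }

  // Called on an in-memory miss; returns null if this key was never spilled
  def read(datasetId: Any, partition: Int): Array[Byte] = {
    val file = files.get((datasetId, partition))
    if (file == null) {
      null
    } else {
      val bytes = new Array[Byte](file.length.toInt)
      val in = new FileInputStream(file)
      try { in.read(bytes) } finally { in.close() }  // single read, as in the class above
      bytes
    }
  }
}
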
diff --git a/core/src/main/scala/spark/SerializingCache.scala b/core/src/main/scala/spark/SerializingCache.scala
index a74922ec4c..f6964905c7 100644
--- a/core/src/main/scala/spark/SerializingCache.scala
+++ b/core/src/main/scala/spark/SerializingCache.scala
@@ -9,13 +9,13 @@ import java.io._
 class SerializingCache extends Cache with Logging {
   val bmc = new BoundedMemoryCache
 
-  override def put(key: Any, value: Any) {
+  override def put(datasetId: Any, partition: Int, value: Any): Boolean = {
     val ser = SparkEnv.get.serializer.newInstance()
-    bmc.put(key, ser.serialize(value))
+    bmc.put(datasetId, partition, ser.serialize(value))
   }
 
-  override def get(key: Any): Any = {
-    val bytes = bmc.get(key)
+  override def get(datasetId: Any, partition: Int): Any = {
+    val bytes = bmc.get(datasetId, partition)
     if (bytes != null) {
       val ser = SparkEnv.get.serializer.newInstance()
       return ser.deserialize(bytes.asInstanceOf[Array[Byte]])
diff --git a/core/src/main/scala/spark/SoftReferenceCache.scala b/core/src/main/scala/spark/SoftReferenceCache.scala
index e84aa57efa..c507df928b 100644
--- a/core/src/main/scala/spark/SoftReferenceCache.scala
+++ b/core/src/main/scala/spark/SoftReferenceCache.scala
@@ -8,6 +8,11 @@ import com.google.common.collect.MapMaker
 class SoftReferenceCache extends Cache {
   val map = new MapMaker().softValues().makeMap[Any, Any]()
 
-  override def get(key: Any): Any = map.get(key)
-  override def put(key: Any, value: Any) = map.put(key, value)
+  override def get(datasetId: Any, partition: Int): Any =
+    map.get((datasetId, partition))
+
+  override def put(datasetId: Any, partition: Int, value: Any): Boolean = {
+    map.put((datasetId, partition), value)
+    return true
+  }
 }
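
SoftReferenceCache is the one implementation where put can always return true: eviction is delegated to the garbage collector, which clears soft references under memory pressure, so there is no replacement decision to make and no dataset to protect. The same idea without the Guava dependency, as a hypothetical sketch:

import java.lang.ref.SoftReference
import scala.collection.mutable.HashMap

// Hypothetical soft-reference cache using the new (datasetId, partition) keys; the JVM
// clears SoftReferences under memory pressure, so put never has to refuse an entry.
class SimpleSoftCache {
  private val map = new HashMap[(Any, Int), SoftReference[Any]]

  def get(datasetId: Any, partition: Int): Any = synchronized {
    map.get((datasetId, partition)) match {
      case Some(ref) => ref.get   // may be null if the GC has already cleared the referent
      case None => null
    }
  }

  def put(datasetId: Any, partition: Int, value: Any): Boolean = synchronized {
    map.put((datasetId, partition), new SoftReference[Any](value))
    true
  }
}

Unlike Guava's MapMaker map, this sketch never removes entries whose referents were cleared; it is only meant to show the shape of the API.
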
diff --git a/core/src/main/scala/spark/WeakReferenceCache.scala b/core/src/main/scala/spark/WeakReferenceCache.scala
deleted file mode 100644
index ddca065454..0000000000
--- a/core/src/main/scala/spark/WeakReferenceCache.scala
+++ /dev/null
@@ -1,14 +0,0 @@
-package spark
-
-import com.google.common.collect.MapMaker
-
-/**
- * An implementation of Cache that uses weak references.
- */
-class WeakReferenceCache extends Cache {
-  val map = new MapMaker().weakValues().makeMap[Any, Any]()
-
-  override def get(key: Any): Any = map.get(key)
-  override def put(key: Any, value: Any) = map.put(key, value)
-}
-
diff --git a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
index 6960339bf8..5a873dca3d 100644
--- a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala
@@ -16,7 +16,7 @@ extends Broadcast[T] with Logging with Serializable {
   def value = value_
 
   BitTorrentBroadcast.synchronized {
-    BitTorrentBroadcast.values.put(uuid, value_)
+    BitTorrentBroadcast.values.put(uuid, 0, value_)
   }
 
   @transient var arrayOfBlocks: Array[BroadcastBlock] = null
@@ -130,7 +130,7 @@ extends Broadcast[T] with Logging with Serializable {
   private def readObject(in: ObjectInputStream): Unit = {
     in.defaultReadObject
     BitTorrentBroadcast.synchronized {
-      val cachedVal = BitTorrentBroadcast.values.get(uuid)
+      val cachedVal = BitTorrentBroadcast.values.get(uuid, 0)
 
       if (cachedVal != null) {
         value_ = cachedVal.asInstanceOf[T]
@@ -152,12 +152,12 @@ extends Broadcast[T] with Logging with Serializable {
         // If does not succeed, then get from HDFS copy
         if (receptionSucceeded) {
           value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
-          BitTorrentBroadcast.values.put(uuid, value_)
+          BitTorrentBroadcast.values.put(uuid, 0, value_)
         }  else {
           // TODO: This part won't work, cause HDFS writing is turned OFF
           val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
           value_ = fileIn.readObject.asInstanceOf[T]
-          BitTorrentBroadcast.values.put(uuid, value_)
+          BitTorrentBroadcast.values.put(uuid, 0, value_)
           fileIn.close()
         }
 
diff --git a/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala b/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala
index e33ef78e8a..64da650142 100644
--- a/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala
@@ -15,7 +15,7 @@ extends Broadcast[T] with Logging with Serializable {
   def value = value_
 
   ChainedBroadcast.synchronized {
-    ChainedBroadcast.values.put(uuid, value_)
+    ChainedBroadcast.values.put(uuid, 0, value_)
   }
 
   @transient var arrayOfBlocks: Array[BroadcastBlock] = null
@@ -101,7 +101,7 @@ extends Broadcast[T] with Logging with Serializable {
   private def readObject(in: ObjectInputStream): Unit = {
     in.defaultReadObject
     ChainedBroadcast.synchronized {
-      val cachedVal = ChainedBroadcast.values.get(uuid)
+      val cachedVal = ChainedBroadcast.values.get(uuid, 0)
       if (cachedVal != null) {
         value_ = cachedVal.asInstanceOf[T]
       } else {
@@ -121,11 +121,11 @@ extends Broadcast[T] with Logging with Serializable {
         // If does not succeed, then get from HDFS copy
         if (receptionSucceeded) {
           value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
-          ChainedBroadcast.values.put(uuid, value_)
+          ChainedBroadcast.values.put(uuid, 0, value_)
         }  else {
           val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
           value_ = fileIn.readObject.asInstanceOf[T]
-          ChainedBroadcast.values.put(uuid, value_)
+          ChainedBroadcast.values.put(uuid, 0, value_)
           fileIn.close()
         }
 
diff --git a/core/src/main/scala/spark/broadcast/DfsBroadcast.scala b/core/src/main/scala/spark/broadcast/DfsBroadcast.scala
index 076f18afac..b053e2b62e 100644
--- a/core/src/main/scala/spark/broadcast/DfsBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/DfsBroadcast.scala
@@ -17,7 +17,7 @@ extends Broadcast[T] with Logging with Serializable {
   def value = value_
 
   DfsBroadcast.synchronized { 
-    DfsBroadcast.values.put(uuid, value_) 
+    DfsBroadcast.values.put(uuid, 0, value_) 
   }
 
   if (!isLocal) { 
@@ -34,7 +34,7 @@ extends Broadcast[T] with Logging with Serializable {
   private def readObject(in: ObjectInputStream): Unit = {
     in.defaultReadObject
     DfsBroadcast.synchronized {
-      val cachedVal = DfsBroadcast.values.get(uuid)
+      val cachedVal = DfsBroadcast.values.get(uuid, 0)
       if (cachedVal != null) {
         value_ = cachedVal.asInstanceOf[T]
       } else {
@@ -43,7 +43,7 @@ extends Broadcast[T] with Logging with Serializable {
         
         val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
         value_ = fileIn.readObject.asInstanceOf[T]
-        DfsBroadcast.values.put(uuid, value_)
+        DfsBroadcast.values.put(uuid, 0, value_)
         fileIn.close
         
         val time = (System.nanoTime - start) / 1e9
diff --git a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala
index 945d8cd8a4..374389def5 100644
--- a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala
@@ -15,7 +15,7 @@ extends Broadcast[T] with Logging with Serializable {
   def value = value_
 
   TreeBroadcast.synchronized {
-    TreeBroadcast.values.put(uuid, value_)
+    TreeBroadcast.values.put(uuid, 0, value_)
   }
 
   @transient var arrayOfBlocks: Array[BroadcastBlock] = null
@@ -104,7 +104,7 @@ extends Broadcast[T] with Logging with Serializable {
   private def readObject(in: ObjectInputStream): Unit = {
     in.defaultReadObject
     TreeBroadcast.synchronized {
-      val cachedVal = TreeBroadcast.values.get(uuid)
+      val cachedVal = TreeBroadcast.values.get(uuid, 0)
       if (cachedVal != null) {
         value_ = cachedVal.asInstanceOf[T]
       } else {
@@ -124,11 +124,11 @@ extends Broadcast[T] with Logging with Serializable {
         // If does not succeed, then get from HDFS copy
         if (receptionSucceeded) {
           value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
-          TreeBroadcast.values.put(uuid, value_)
+          TreeBroadcast.values.put(uuid, 0, value_)
         }  else {
           val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid))
           value_ = fileIn.readObject.asInstanceOf[T]
-          TreeBroadcast.values.put(uuid, value_)
+          TreeBroadcast.values.put(uuid, 0, value_)
           fileIn.close()
         }
 
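All four broadcast implementations use the new API the same way: the broadcast's UUID acts as the dataset id and, since a broadcast is a single unpartitioned value, 0 serves as a fixed placeholder partition index. A tiny illustration of the convention (hypothetical, REPL-style; the real code obtains its KeySpace from the shared executor cache rather than constructing one):

import java.util.UUID

// Illustration only: the (uuid, 0) convention used by the broadcast classes above.
val values = (new BoundedMemoryCache).newKeySpace()
val uuid = UUID.randomUUID                     // identifies one broadcast variable

values.put(uuid, 0, "some broadcast payload")  // write path (constructor / readObject)
val cached = values.get(uuid, 0)               // read path; null if not cached locally yet

Because every broadcast gets its own UUID, each one is a one-entry dataset of its own, so the new same-dataset rule essentially never causes a broadcast put to be refused.
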
-- 
GitLab