diff --git a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala
index 7084ff97d90d28e20785483ca8d450db73e7eced..4c18cb913442b71463bfbea3cb32c00da31d2bf9 100644
--- a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala
+++ b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala
@@ -11,6 +11,7 @@ import scala.xml.{XML,NodeSeq}
 import scala.collection.mutable.ArrayBuffer
 
 import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream}
+import java.nio.ByteBuffer
 
 object WikipediaPageRankStandalone {
   def main(args: Array[String]) {
@@ -118,23 +119,23 @@ class WPRSerializer extends spark.Serializer {
 }
 
 class WPRSerializerInstance extends SerializerInstance {
-  def serialize[T](t: T): Array[Byte] = {
+  def serialize[T](t: T): ByteBuffer = {
     throw new UnsupportedOperationException()
   }
 
-  def deserialize[T](bytes: Array[Byte]): T = {
+  def deserialize[T](bytes: ByteBuffer): T = {
     throw new UnsupportedOperationException()
   }
 
-  def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
+  def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
     throw new UnsupportedOperationException()
   }
 
-  def outputStream(s: OutputStream): SerializationStream = {
+  def serializeStream(s: OutputStream): SerializationStream = {
     new WPRSerializationStream(s)
   }
 
-  def inputStream(s: InputStream): DeserializationStream = {
+  def deserializeStream(s: InputStream): DeserializationStream = {
     new WPRDeserializationStream(s)
   }
 }
diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
new file mode 100644
index 0000000000000000000000000000000000000000..e00a0d80fa25a15e4bf884912613566acba5ab63
--- /dev/null
+++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
@@ -0,0 +1,70 @@
+package spark
+
+import java.io.EOFException
+import java.net.URL
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+
+import spark.storage.BlockException
+import spark.storage.BlockManagerId
+
+import it.unimi.dsi.fastutil.io.FastBufferedInputStream
+
+
+class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
+  def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) {
+    logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
+    val ser = SparkEnv.get.serializer.newInstance()
+    val blockManager = SparkEnv.get.blockManager
+    
+    val startTime = System.currentTimeMillis
+    val addresses = SparkEnv.get.mapOutputTracker.getServerAddresses(shuffleId)
+    logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
+      shuffleId, reduceId, System.currentTimeMillis - startTime))
+    
+    val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[Int]]
+    for ((address, index) <- addresses.zipWithIndex) {
+      splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += index
+    }
+
+    val blocksByAddress: Seq[(BlockManagerId, Seq[String])] = splitsByAddress.toSeq.map {
+      case (address, splits) =>
+        (address, splits.map(i => "shuffleid_%d_%d_%d".format(shuffleId, i, reduceId)))
+    }
+
+    try {
+      val blockOptions = blockManager.get(blocksByAddress)
+      logDebug("Fetching map output blocks for shuffle %d, reduce %d took %d ms".format(
+        shuffleId, reduceId, System.currentTimeMillis - startTime))
+      blockOptions.foreach(x => {
+        val (blockId, blockOption) = x 
+        blockOption match {
+          case Some(block) => {
+            val values = block.asInstanceOf[Iterator[Any]]
+            for(value <- values) {
+              val v = value.asInstanceOf[(K, V)]
+              func(v._1, v._2)
+            }
+          }
+          case None => {
+            throw new BlockException(blockId, "Did not get block " + blockId)         
+          }
+        }
+      })
+    } catch {
+      case be: BlockException => {
+        val regex = "shuffledid_([0-9]*)_([0-9]*)_([0-9]]*)".r
+        be.blockId match {
+          case regex(sId, mId, rId) => { 
+            val address = addresses(mId.toInt)
+            throw new FetchFailedException(address, sId.toInt, mId.toInt, rId.toInt, be)
+          }
+          case _ => {
+            throw be
+          }
+        }
+      }
+    }
+  }
+}
diff --git a/core/src/main/scala/spark/BoundedMemoryCache.scala b/core/src/main/scala/spark/BoundedMemoryCache.scala
index 1162e34ab03340c763e943b696a611ba9cb5d8d8..fa5dcee7bbf0c4cd66a1d2f0bd363799e4c9eaff 100644
--- a/core/src/main/scala/spark/BoundedMemoryCache.scala
+++ b/core/src/main/scala/spark/BoundedMemoryCache.scala
@@ -90,7 +90,8 @@ class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging {
 
   protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) {
     logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size))
-    SparkEnv.get.cacheTracker.dropEntry(datasetId, partition)
+    // TODO: remove BoundedMemoryCache
+    SparkEnv.get.cacheTracker.dropEntry(datasetId.asInstanceOf[(Int, Int)]._2, partition)
   }
 }
 
diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala
index 4867829c17ac6519dba55aed2a47af78f54fe85f..64b4af0ae20e327b90abd36df6ea9a33969a64ed 100644
--- a/core/src/main/scala/spark/CacheTracker.scala
+++ b/core/src/main/scala/spark/CacheTracker.scala
@@ -1,11 +1,17 @@
 package spark
 
-import scala.actors._
-import scala.actors.Actor._
-import scala.actors.remote._
+import akka.actor._
+import akka.actor.Actor
+import akka.actor.Actor._
+import akka.util.duration._
+
+import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashMap
 import scala.collection.mutable.HashSet
 
+import spark.storage.BlockManager
+import spark.storage.StorageLevel
+
 sealed trait CacheTrackerMessage
 case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
   extends CacheTrackerMessage
@@ -18,8 +24,8 @@ case object GetCacheStatus extends CacheTrackerMessage
 case object GetCacheLocations extends CacheTrackerMessage
 case object StopCacheTracker extends CacheTrackerMessage
 
-
-class CacheTrackerActor extends DaemonActor with Logging {
+class CacheTrackerActor extends Actor with Logging {
+  // TODO: Should probably store (String, CacheType) tuples
   private val locs = new HashMap[Int, Array[List[String]]]
 
   /**
@@ -28,109 +34,93 @@ class CacheTrackerActor extends DaemonActor with Logging {
   private val slaveCapacity = new HashMap[String, Long]
   private val slaveUsage = new HashMap[String, Long]
 
-  // TODO: Should probably store (String, CacheType) tuples
-
   private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L)
   private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L)
   private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host)
   
-  def act() {
-    val port = System.getProperty("spark.master.port").toInt
-    RemoteActor.alive(port)
-    RemoteActor.register('CacheTracker, self)
-    logInfo("Registered actor on port " + port)
-    
-    loop {
-      react {
-        case SlaveCacheStarted(host: String, size: Long) =>
-          logInfo("Started slave cache (size %s) on %s".format(
-            Utils.memoryBytesToString(size), host))
-          slaveCapacity.put(host, size)
-          slaveUsage.put(host, 0)
-          reply('OK)
-
-        case RegisterRDD(rddId: Int, numPartitions: Int) =>
-          logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions")
-          locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
-          reply('OK)
-        
-        case AddedToCache(rddId, partition, host, size) =>
-          if (size > 0) {
-            slaveUsage.put(host, getCacheUsage(host) + size)
-            logInfo("Cache entry added: (%s, %s) on %s (size added: %s, available: %s)".format(
-              rddId, partition, host, Utils.memoryBytesToString(size),
-              Utils.memoryBytesToString(getCacheAvailable(host))))
-          } else {
-            logInfo("Cache entry added: (%s, %s) on %s".format(rddId, partition, host))
-          }
-          locs(rddId)(partition) = host :: locs(rddId)(partition)
-          reply('OK)
-          
-        case DroppedFromCache(rddId, partition, host, size) =>
-          if (size > 0) {
-            logInfo("Cache entry removed: (%s, %s) on %s (size dropped: %s, available: %s)".format(
-              rddId, partition, host, Utils.memoryBytesToString(size),
-              Utils.memoryBytesToString(getCacheAvailable(host))))
-            slaveUsage.put(host, getCacheUsage(host) - size)
-
-            // Do a sanity check to make sure usage is greater than 0.
-            val usage = getCacheUsage(host)
-            if (usage < 0) {
-              logError("Cache usage on %s is negative (%d)".format(host, usage))
-            }
-          } else {
-            logInfo("Cache entry removed: (%s, %s) on %s".format(rddId, partition, host))
-          }
-          locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host)
-          reply('OK)
+  def receive = {
+    case SlaveCacheStarted(host: String, size: Long) =>
+      logInfo("Started slave cache (size %s) on %s".format(
+        Utils.memoryBytesToString(size), host))
+      slaveCapacity.put(host, size)
+      slaveUsage.put(host, 0)
+      self.reply(true)
 
-        case MemoryCacheLost(host) =>
-          logInfo("Memory cache lost on " + host)
-          // TODO: Drop host from the memory locations list of all RDDs
-        
-        case GetCacheLocations =>
-          logInfo("Asked for current cache locations")
-          reply(locs.map{case (rrdId, array) => (rrdId -> array.clone())})
-
-        case GetCacheStatus =>
-          val status = slaveCapacity.map { case (host,capacity) =>
-            (host, capacity, getCacheUsage(host))
-          }.toSeq
-          reply(status)
-
-        case StopCacheTracker =>
-          reply('OK)
-          exit()
+    case RegisterRDD(rddId: Int, numPartitions: Int) =>
+      logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions")
+      locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
+      self.reply(true)
+
+    case AddedToCache(rddId, partition, host, size) =>
+      slaveUsage.put(host, getCacheUsage(host) + size)
+      logInfo("Cache entry added: (%s, %s) on %s (size added: %s, available: %s)".format(
+        rddId, partition, host, Utils.memoryBytesToString(size),
+        Utils.memoryBytesToString(getCacheAvailable(host))))
+      locs(rddId)(partition) = host :: locs(rddId)(partition)
+      self.reply(true)
+
+    case DroppedFromCache(rddId, partition, host, size) =>
+      logInfo("Cache entry removed: (%s, %s) on %s (size dropped: %s, available: %s)".format(
+        rddId, partition, host, Utils.memoryBytesToString(size),
+        Utils.memoryBytesToString(getCacheAvailable(host))))
+      slaveUsage.put(host, getCacheUsage(host) - size)
+      // Do a sanity check to make sure usage is greater than 0.
+      val usage = getCacheUsage(host)
+      if (usage < 0) {
+        logError("Cache usage on %s is negative (%d)".format(host, usage))
       }
-    }
-  }
-}
+      locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host)
+      self.reply(true)
 
+    case MemoryCacheLost(host) =>
+      logInfo("Memory cache lost on " + host)
+      for ((id, locations) <- locs) {
+        for (i <- 0 until locations.length) {
+          locations(i) = locations(i).filterNot(_ == host)
+        }
+      }
+      self.reply(true)
 
-class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging {
-  // Tracker actor on the master, or remote reference to it on workers
-  var trackerActor: AbstractActor = null
+    case GetCacheLocations =>
+      logInfo("Asked for current cache locations")
+      self.reply(locs.map{case (rrdId, array) => (rrdId -> array.clone())})
 
-  val registeredRddIds = new HashSet[Int]
+    case GetCacheStatus =>
+      val status = slaveCapacity.map { case (host, capacity) =>
+        (host, capacity, getCacheUsage(host))
+      }.toSeq
+      self.reply(status)
 
-  // Stores map results for various splits locally
-  val cache = theCache.newKeySpace()
+    case StopCacheTracker =>
+      logInfo("CacheTrackerActor Server stopped!")
+      self.reply(true)
+      self.exit()
+  }
+}
 
+class CacheTracker(isMaster: Boolean, blockManager: BlockManager) extends Logging {
+  // Tracker actor on the master, or remote reference to it on workers
+  val ip: String = System.getProperty("spark.master.host", "localhost")
+  val port: Int = System.getProperty("spark.master.port", "7077").toInt
+  val aName: String = "CacheTracker"
+  
   if (isMaster) {
-    val tracker = new CacheTrackerActor
-    tracker.start()
-    trackerActor = tracker
+  }
+  
+  var trackerActor: ActorRef = if (isMaster) {
+    val actor = actorOf(new CacheTrackerActor)
+    remote.register(aName, actor)
+    actor.start()
+    logInfo("Registered CacheTrackerActor actor @ " + ip + ":" + port)
+    actor
   } else {
-    val host = System.getProperty("spark.master.host")
-    val port = System.getProperty("spark.master.port").toInt
-    trackerActor = RemoteActor.select(Node(host, port), 'CacheTracker)
+    remote.actorFor(aName, ip, port)
   }
 
-  // Report the cache being started.
-  trackerActor !? SlaveCacheStarted(Utils.getHost, cache.getCapacity)
+  val registeredRddIds = new HashSet[Int]
 
   // Remembers which splits are currently being loaded (on worker nodes)
-  val loading = new HashSet[(Int, Int)]
+  val loading = new HashSet[String]
   
   // Registers an RDD (on master only)
   def registerRDD(rddId: Int, numPartitions: Int) {
@@ -138,24 +128,33 @@ class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging {
       if (!registeredRddIds.contains(rddId)) {
         logInfo("Registering RDD ID " + rddId + " with cache")
         registeredRddIds += rddId
-        trackerActor !? RegisterRDD(rddId, numPartitions)
+        (trackerActor ? RegisterRDD(rddId, numPartitions)).as[Any] match {
+          case Some(true) =>
+            logInfo("CacheTracker registerRDD " + RegisterRDD(rddId, numPartitions) + " successfully.")
+          case Some(oops) =>
+            logError("CacheTracker registerRDD" + RegisterRDD(rddId, numPartitions) + " failed: " + oops)
+          case None => 
+            logError("CacheTracker registerRDD None. " + RegisterRDD(rddId, numPartitions))
+            throw new SparkException("Internal error: CacheTracker registerRDD None.")
+        }
       }
     }
   }
-
-  // Get a snapshot of the currently known locations
-  def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
-    (trackerActor !? GetCacheLocations) match {
-      case h: HashMap[_, _] => h.asInstanceOf[HashMap[Int, Array[List[String]]]]
-
-      case _ => throw new SparkException("Internal error: CacheTrackerActor did not reply with a HashMap")
+  
+  // For BlockManager.scala only
+  def cacheLost(host: String) {
+    (trackerActor ? MemoryCacheLost(host)).as[Any] match {
+       case Some(true) =>
+         logInfo("CacheTracker successfully removed entries on " + host)
+       case _ =>
+         logError("CacheTracker did not reply to MemoryCacheLost")
     }
   }
 
   // Get the usage status of slave caches. Each tuple in the returned sequence
   // is in the form of (host name, capacity, usage).
   def getCacheStatus(): Seq[(String, Long, Long)] = {
-    (trackerActor !? GetCacheStatus) match {
+    (trackerActor ? GetCacheStatus) match {
       case h: Seq[(String, Long, Long)] => h.asInstanceOf[Seq[(String, Long, Long)]]
 
       case _ =>
@@ -164,75 +163,94 @@ class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging {
     }
   }
   
+  // For BlockManager.scala only
+  def notifyTheCacheTrackerFromBlockManager(t: AddedToCache) {
+    (trackerActor ? t).as[Any] match {
+      case Some(true) =>
+        logInfo("CacheTracker notifyTheCacheTrackerFromBlockManager successfully.")
+      case Some(oops) =>
+        logError("CacheTracker notifyTheCacheTrackerFromBlockManager failed: " + oops)
+      case None => 
+        logError("CacheTracker notifyTheCacheTrackerFromBlockManager None.")
+    }
+  }
+  
+  // Get a snapshot of the currently known locations
+  def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
+    (trackerActor ? GetCacheLocations).as[Any] match {
+      case Some(h: HashMap[_, _]) =>
+        h.asInstanceOf[HashMap[Int, Array[List[String]]]]
+        
+      case _ => 
+        throw new SparkException("Internal error: CacheTrackerActor did not reply with a HashMap")
+    }
+  }
+  
   // Gets or computes an RDD split
-  def getOrCompute[T](rdd: RDD[T], split: Split)(implicit m: ClassManifest[T]): Iterator[T] = {
-    logInfo("Looking for RDD partition %d:%d".format(rdd.id, split.index))
-    val cachedVal = cache.get(rdd.id, split.index)
-    if (cachedVal != null) {
-      // Split is in cache, so just return its values
-      logInfo("Found partition in cache!")
-      return cachedVal.asInstanceOf[Array[T]].iterator
-    } else {
-      // Mark the split as loading (unless someone else marks it first)
-      val key = (rdd.id, split.index)
-      loading.synchronized {
-        while (loading.contains(key)) {
-          // Someone else is loading it; let's wait for them
-          try { loading.wait() } catch { case _ => }
-        }
-        // See whether someone else has successfully loaded it. The main way this would fail
-        // is for the RDD-level cache eviction policy if someone else has loaded the same RDD
-        // partition but we didn't want to make space for it. However, that case is unlikely
-        // because it's unlikely that two threads would work on the same RDD partition. One
-        // downside of the current code is that threads wait serially if this does happen.
-        val cachedVal = cache.get(rdd.id, split.index)
-        if (cachedVal != null) {
-          return cachedVal.asInstanceOf[Array[T]].iterator
-        }
-        // Nobody's loading it and it's not in the cache; let's load it ourselves
-        loading.add(key)
-      }
-      // If we got here, we have to load the split
-      // Tell the master that we're doing so
-
-      // TODO: fetch any remote copy of the split that may be available
-      logInfo("Computing partition " + split)
-      var array: Array[T] = null
-      var putResponse: CachePutResponse = null
-      try {
-        array = rdd.compute(split).toArray(m)
-        putResponse = cache.put(rdd.id, split.index, array)
-      } finally {
-        // Tell other threads that we've finished our attempt to load the key (whether or not
-        // we've actually succeeded to put it in the map)
+  def getOrCompute[T](rdd: RDD[T], split: Split, storageLevel: StorageLevel): Iterator[T] = {
+    val key = "rdd:%d:%d".format(rdd.id, split.index)
+    logInfo("Cache key is " + key)
+    blockManager.get(key) match {
+      case Some(cachedValues) =>
+        // Split is in cache, so just return its values
+        logInfo("Found partition in cache!")
+        return cachedValues.asInstanceOf[Iterator[T]]
+
+      case None =>
+        // Mark the split as loading (unless someone else marks it first)
         loading.synchronized {
-          loading.remove(key)
-          loading.notifyAll()
+          if (loading.contains(key)) {
+            logInfo("Loading contains " + key + ", waiting...")
+            while (loading.contains(key)) {
+              try {loading.wait()} catch {case _ =>}
+            }
+            logInfo("Loading no longer contains " + key + ", so returning cached result")
+            // See whether someone else has successfully loaded it. The main way this would fail
+            // is for the RDD-level cache eviction policy if someone else has loaded the same RDD
+            // partition but we didn't want to make space for it. However, that case is unlikely
+            // because it's unlikely that two threads would work on the same RDD partition. One
+            // downside of the current code is that threads wait serially if this does happen.
+            blockManager.get(key) match {
+              case Some(values) =>
+                return values.asInstanceOf[Iterator[T]]
+              case None =>
+                logInfo("Whoever was loading " + key + " failed; we'll try it ourselves")
+                loading.add(key)
+            }
+          } else {
+            loading.add(key)
+          }
         }
-      }
-
-      putResponse match {
-        case CachePutSuccess(size) => {
-          // Tell the master that we added the entry. Don't return until it
-          // replies so it can properly schedule future tasks that use this RDD.
-          trackerActor !? AddedToCache(rdd.id, split.index, Utils.getHost, size)
+        // If we got here, we have to load the split
+        // Tell the master that we're doing so
+        //val host = System.getProperty("spark.hostname", Utils.localHostName)
+        //val future = trackerActor !! AddedToCache(rdd.id, split.index, host)
+        // TODO: fetch any remote copy of the split that may be available
+        // TODO: also register a listener for when it unloads
+        logInfo("Computing partition " + split)
+        try {
+          val values = new ArrayBuffer[Any]
+          values ++= rdd.compute(split)
+          blockManager.put(key, values.iterator, storageLevel, false)
+          //future.apply() // Wait for the reply from the cache tracker
+          return values.iterator.asInstanceOf[Iterator[T]]
+        } finally {
+          loading.synchronized {
+            loading.remove(key)
+            loading.notifyAll()
+          }
         }
-        case _ => null
-      }
-      return array.iterator
     }
   }
 
   // Called by the Cache to report that an entry has been dropped from it
-  def dropEntry(datasetId: Any, partition: Int) {
-    datasetId match {
-      //TODO - do we really want to use '!!' when nobody checks returned future? '!' seems to enough here.
-      case (cache.keySpaceId, rddId: Int) => trackerActor !! DroppedFromCache(rddId, partition, Utils.getHost)
-    }
+  def dropEntry(rddId: Int, partition: Int) {
+    //TODO - do we really want to use '!!' when nobody checks returned future? '!' seems to enough here.
+    trackerActor !! DroppedFromCache(rddId, partition, Utils.localHostName())
   }
 
   def stop() {
-    trackerActor !? StopCacheTracker
+    trackerActor !! StopCacheTracker
     registeredRddIds.clear()
     trackerActor = null
   }
diff --git a/core/src/main/scala/spark/CoGroupedRDD.scala b/core/src/main/scala/spark/CoGroupedRDD.scala
index 93f453bc5e4341bcf74de43ee22d332cdeaf4e1a..3543c8afa8a081f201f630df60dcb6f915c01115 100644
--- a/core/src/main/scala/spark/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/CoGroupedRDD.scala
@@ -22,11 +22,12 @@ class CoGroupAggregator
     { (b1, b2) => b1 ++ b2 })
   with Serializable
 
-class CoGroupedRDD[K](rdds: Seq[RDD[(_, _)]], part: Partitioner)
+class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
   extends RDD[(K, Seq[Seq[_]])](rdds.head.context) with Logging {
   
   val aggr = new CoGroupAggregator
   
+  @transient
   override val dependencies = {
     val deps = new ArrayBuffer[Dependency[_]]
     for ((rdd, index) <- rdds.zipWithIndex) {
@@ -67,9 +68,10 @@ class CoGroupedRDD[K](rdds: Seq[RDD[(_, _)]], part: Partitioner)
   
   override def compute(s: Split): Iterator[(K, Seq[Seq[_]])] = {
     val split = s.asInstanceOf[CoGroupSplit]
+    val numRdds = split.deps.size
     val map = new HashMap[K, Seq[ArrayBuffer[Any]]]
     def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
-      map.getOrElseUpdate(k, Array.fill(rdds.size)(new ArrayBuffer[Any]))
+      map.getOrElseUpdate(k, Array.fill(numRdds)(new ArrayBuffer[Any]))
     }
     for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
       case NarrowCoGroupSplitDep(rdd, itsSplit) => {
diff --git a/core/src/main/scala/spark/DAGScheduler.scala b/core/src/main/scala/spark/DAGScheduler.scala
deleted file mode 100644
index 1b4af9d84c6d2159eb05084e2587ddef62a6bed1..0000000000000000000000000000000000000000
--- a/core/src/main/scala/spark/DAGScheduler.scala
+++ /dev/null
@@ -1,374 +0,0 @@
-package spark
-
-import java.util.concurrent.atomic.AtomicInteger
-
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue, Map}
-
-/**
- * A task created by the DAG scheduler. Knows its stage ID and map ouput tracker generation.
- */
-abstract class DAGTask[T](val runId: Int, val stageId: Int) extends Task[T] {
-  val gen = SparkEnv.get.mapOutputTracker.getGeneration
-  override def generation: Option[Long] = Some(gen)
-}
-
-/**
- * A completion event passed by the underlying task scheduler to the DAG scheduler.
- */
-case class CompletionEvent(
-    task: DAGTask[_],
-    reason: TaskEndReason,
-    result: Any,
-    accumUpdates: Map[Long, Any])
-
-/**
- * Various possible reasons why a DAG task ended. The underlying scheduler is supposed to retry
- * tasks several times for "ephemeral" failures, and only report back failures that require some
- * old stages to be resubmitted, such as shuffle map fetch failures.
- */
-sealed trait TaskEndReason
-case object Success extends TaskEndReason
-case class FetchFailed(serverUri: String, shuffleId: Int, mapId: Int, reduceId: Int) extends TaskEndReason
-case class ExceptionFailure(exception: Throwable) extends TaskEndReason
-case class OtherFailure(message: String) extends TaskEndReason
-
-/**
- * A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for 
- * each job, keeps track of which RDDs and stage outputs are materialized, and computes a minimal 
- * schedule to run the job. Subclasses only need to implement the code to send a task to the cluster
- * and to report fetch failures (the submitTasks method, and code to add CompletionEvents).
- */
-private trait DAGScheduler extends Scheduler with Logging {
-  // Must be implemented by subclasses to start running a set of tasks. The subclass should also
-  // attempt to run different sets of tasks in the order given by runId (lower values first).
-  def submitTasks(tasks: Seq[Task[_]], runId: Int): Unit
-
-  // Must be called by subclasses to report task completions or failures.
-  def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]) {
-    lock.synchronized {
-      val dagTask = task.asInstanceOf[DAGTask[_]]
-      eventQueues.get(dagTask.runId) match {
-        case Some(queue) =>
-          queue += CompletionEvent(dagTask, reason, result, accumUpdates)
-          lock.notifyAll()
-        case None =>
-          logInfo("Ignoring completion event for DAG job " + dagTask.runId + " because it's gone")
-      }
-    }
-  }
-
-  // The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
-  // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
-  // as more failure events come in
-  val RESUBMIT_TIMEOUT = 2000L
-
-  // The time, in millis, to wake up between polls of the completion queue in order to potentially
-  // resubmit failed stages
-  val POLL_TIMEOUT = 500L
-
-  private val lock = new Object          // Used for access to the entire DAGScheduler
-
-  private val eventQueues = new HashMap[Int, Queue[CompletionEvent]]   // Indexed by run ID
-
-  val nextRunId = new AtomicInteger(0)
-
-  val nextStageId = new AtomicInteger(0)
-
-  val idToStage = new HashMap[Int, Stage]
-
-  val shuffleToMapStage = new HashMap[Int, Stage]
-
-  var cacheLocs = new HashMap[Int, Array[List[String]]]
-
-  val env = SparkEnv.get
-  val cacheTracker = env.cacheTracker
-  val mapOutputTracker = env.mapOutputTracker
-
-  def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
-    cacheLocs(rdd.id)
-  }
-  
-  def updateCacheLocs() {
-    cacheLocs = cacheTracker.getLocationsSnapshot()
-  }
-
-  def getShuffleMapStage(shuf: ShuffleDependency[_,_,_]): Stage = {
-    shuffleToMapStage.get(shuf.shuffleId) match {
-      case Some(stage) => stage
-      case None =>
-        val stage = newStage(shuf.rdd, Some(shuf))
-        shuffleToMapStage(shuf.shuffleId) = stage
-        stage
-    }
-  }
-
-  def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]]): Stage = {
-    // Kind of ugly: need to register RDDs with the cache and map output tracker here
-    // since we can't do it in the RDD constructor because # of splits is unknown
-    cacheTracker.registerRDD(rdd.id, rdd.splits.size)
-    if (shuffleDep != None) {
-      mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
-    }
-    val id = nextStageId.getAndIncrement()
-    val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd))
-    idToStage(id) = stage
-    stage
-  }
-
-  def getParentStages(rdd: RDD[_]): List[Stage] = {
-    val parents = new HashSet[Stage]
-    val visited = new HashSet[RDD[_]]
-    def visit(r: RDD[_]) {
-      if (!visited(r)) {
-        visited += r
-        // Kind of ugly: need to register RDDs with the cache here since
-        // we can't do it in its constructor because # of splits is unknown
-        cacheTracker.registerRDD(r.id, r.splits.size)
-        for (dep <- r.dependencies) {
-          dep match {
-            case shufDep: ShuffleDependency[_,_,_] =>
-              parents += getShuffleMapStage(shufDep)
-            case _ =>
-              visit(dep.rdd)
-          }
-        }
-      }
-    }
-    visit(rdd)
-    parents.toList
-  }
-
-  def getMissingParentStages(stage: Stage): List[Stage] = {
-    val missing = new HashSet[Stage]
-    val visited = new HashSet[RDD[_]]
-    def visit(rdd: RDD[_]) {
-      if (!visited(rdd)) {
-        visited += rdd
-        val locs = getCacheLocs(rdd)
-        for (p <- 0 until rdd.splits.size) {
-          if (locs(p) == Nil) {
-            for (dep <- rdd.dependencies) {
-              dep match {
-                case shufDep: ShuffleDependency[_,_,_] =>
-                  val stage = getShuffleMapStage(shufDep)
-                  if (!stage.isAvailable) {
-                    missing += stage
-                  }
-                case narrowDep: NarrowDependency[_] =>
-                  visit(narrowDep.rdd)
-              }
-            }
-          }
-        }
-      }
-    }
-    visit(stage.rdd)
-    missing.toList
-  }
-
-  override def runJob[T, U](
-      finalRdd: RDD[T],
-      func: (TaskContext, Iterator[T]) => U,
-      partitions: Seq[Int],
-      allowLocal: Boolean)
-      (implicit m: ClassManifest[U]): Array[U] = {
-    lock.synchronized {
-      val runId = nextRunId.getAndIncrement()
-      
-      val outputParts = partitions.toArray
-      val numOutputParts: Int = partitions.size
-      val finalStage = newStage(finalRdd, None)
-      val results = new Array[U](numOutputParts)
-      val finished = new Array[Boolean](numOutputParts)
-      var numFinished = 0
-  
-      val waiting = new HashSet[Stage] // stages we need to run whose parents aren't done
-      val running = new HashSet[Stage] // stages we are running right now
-      val failed = new HashSet[Stage]  // stages that must be resubmitted due to fetch failures
-      val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] // missing tasks from each stage
-      var lastFetchFailureTime: Long = 0  // used to wait a bit to avoid repeated resubmits
-  
-      SparkEnv.set(env)
-  
-      updateCacheLocs()
-      
-      logInfo("Final stage: " + finalStage)
-      logInfo("Parents of final stage: " + finalStage.parents)
-      logInfo("Missing parents: " + getMissingParentStages(finalStage))
-  
-      // Optimization for short actions like first() and take() that can be computed locally
-      // without shipping tasks to the cluster.
-      if (allowLocal && finalStage.parents.size == 0 && numOutputParts == 1) {
-        logInfo("Computing the requested partition locally")
-        val split = finalRdd.splits(outputParts(0))
-        val taskContext = new TaskContext(finalStage.id, outputParts(0), 0)
-        return Array(func(taskContext, finalRdd.iterator(split)))
-      }
-
-      // Register the job ID so that we can get completion events for it
-      eventQueues(runId) = new Queue[CompletionEvent]
-  
-      def submitStage(stage: Stage) {
-        if (!waiting(stage) && !running(stage)) {
-          val missing = getMissingParentStages(stage)
-          if (missing == Nil) {
-            logInfo("Submitting " + stage + ", which has no missing parents")
-            submitMissingTasks(stage)
-            running += stage
-          } else {
-            for (parent <- missing) {
-              submitStage(parent)
-            }
-            waiting += stage
-          }
-        }
-      }
-  
-      def submitMissingTasks(stage: Stage) {
-        // Get our pending tasks and remember them in our pendingTasks entry
-        val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
-        var tasks = ArrayBuffer[Task[_]]()
-        if (stage == finalStage) {
-          for (id <- 0 until numOutputParts if (!finished(id))) {
-            val part = outputParts(id)
-            val locs = getPreferredLocs(finalRdd, part)
-            tasks += new ResultTask(runId, finalStage.id, finalRdd, func, part, locs, id)
-          }
-        } else {
-          for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
-            val locs = getPreferredLocs(stage.rdd, p)
-            tasks += new ShuffleMapTask(runId, stage.id, stage.rdd, stage.shuffleDep.get, p, locs)
-          }
-        }
-        myPending ++= tasks
-        submitTasks(tasks, runId)
-      }
-  
-      submitStage(finalStage)
-  
-      while (numFinished != numOutputParts) {
-        val eventOption = waitForEvent(runId, POLL_TIMEOUT)
-        val time = System.currentTimeMillis // TODO: use a pluggable clock for testability
-  
-        // If we got an event off the queue, mark the task done or react to a fetch failure
-        if (eventOption != None) {
-          val evt = eventOption.get
-          val stage = idToStage(evt.task.stageId)
-          pendingTasks(stage) -= evt.task
-          if (evt.reason == Success) {
-            // A task ended
-            logInfo("Completed " + evt.task)
-            Accumulators.add(evt.accumUpdates)
-            evt.task match {
-              case rt: ResultTask[_, _] =>
-                results(rt.outputId) = evt.result.asInstanceOf[U]
-                finished(rt.outputId) = true
-                numFinished += 1
-              case smt: ShuffleMapTask =>
-                val stage = idToStage(smt.stageId)
-                stage.addOutputLoc(smt.partition, evt.result.asInstanceOf[String])
-                if (running.contains(stage) && pendingTasks(stage).isEmpty) {
-                  logInfo(stage + " finished; looking for newly runnable stages")
-                  running -= stage
-                  if (stage.shuffleDep != None) {
-                    mapOutputTracker.registerMapOutputs(
-                      stage.shuffleDep.get.shuffleId,
-                      stage.outputLocs.map(_.head).toArray)
-                  }
-                  updateCacheLocs()
-                  val newlyRunnable = new ArrayBuffer[Stage]
-                  for (stage <- waiting if getMissingParentStages(stage) == Nil) {
-                    newlyRunnable += stage
-                  }
-                  waiting --= newlyRunnable
-                  running ++= newlyRunnable
-                  for (stage <- newlyRunnable) {
-                    submitMissingTasks(stage)
-                  }
-                }
-            }
-          } else {
-            evt.reason match {
-              case FetchFailed(serverUri, shuffleId, mapId, reduceId) =>
-                // Mark the stage that the reducer was in as unrunnable
-                val failedStage = idToStage(evt.task.stageId)
-                running -= failedStage
-                failed += failedStage
-                // TODO: Cancel running tasks in the stage
-                logInfo("Marking " + failedStage + " for resubmision due to a fetch failure")
-                // Mark the map whose fetch failed as broken in the map stage
-                val mapStage = shuffleToMapStage(shuffleId)
-                mapStage.removeOutputLoc(mapId, serverUri)
-                mapOutputTracker.unregisterMapOutput(shuffleId, mapId, serverUri)
-                logInfo("The failed fetch was from " + mapStage + "; marking it for resubmission")
-                failed += mapStage
-                // Remember that a fetch failed now; this is used to resubmit the broken
-                // stages later, after a small wait (to give other tasks the chance to fail)
-                lastFetchFailureTime = time
-                // TODO: If there are a lot of fetch failures on the same node, maybe mark all
-                // outputs on the node as dead.
-              case _ =>
-                // Non-fetch failure -- probably a bug in the job, so bail out
-                throw new SparkException("Task failed: " + evt.task + ", reason: " + evt.reason)
-                // TODO: Cancel all tasks that are still running
-            }
-          }
-        } // end if (evt != null)
-  
-        // If fetches have failed recently and we've waited for the right timeout,
-        // resubmit all the failed stages
-        if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
-          logInfo("Resubmitting failed stages")
-          updateCacheLocs()
-          for (stage <- failed) {
-            submitStage(stage)
-          }
-          failed.clear()
-        }
-      }
-  
-      eventQueues -= runId
-      return results
-    }
-  }
-
-  def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
-    // If the partition is cached, return the cache locations
-    val cached = getCacheLocs(rdd)(partition)
-    if (cached != Nil) {
-      return cached
-    }
-    // If the RDD has some placement preferences (as is the case for input RDDs), get those
-    val rddPrefs = rdd.preferredLocations(rdd.splits(partition)).toList
-    if (rddPrefs != Nil) {
-      return rddPrefs
-    }
-    // If the RDD has narrow dependencies, pick the first partition of the first narrow dep
-    // that has any placement preferences. Ideally we would choose based on transfer sizes,
-    // but this will do for now.
-    rdd.dependencies.foreach(_ match {
-      case n: NarrowDependency[_] =>
-        for (inPart <- n.getParents(partition)) {
-          val locs = getPreferredLocs(n.rdd, inPart)
-          if (locs != Nil)
-            return locs;
-        }
-      case _ =>
-    })
-    return Nil
-  }
-
-  // Assumes that lock is held on entrance, but will release it to wait for the next event.
-  def waitForEvent(runId: Int, timeout: Long): Option[CompletionEvent] = {
-    val endTime = System.currentTimeMillis() + timeout   // TODO: Use pluggable clock for testing
-    while (eventQueues(runId).isEmpty) {
-      val time = System.currentTimeMillis()
-      if (time >= endTime) {
-        return None
-      } else {
-        lock.wait(endTime - time)
-      }
-    }
-    return Some(eventQueues(runId).dequeue())
-  }
-}
diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala
index d93c84924a5038fb202157b907092591b1343ac8..c0ff94acc6266b3e25e1988d700680100affec24 100644
--- a/core/src/main/scala/spark/Dependency.scala
+++ b/core/src/main/scala/spark/Dependency.scala
@@ -8,7 +8,7 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd, false) {
 
 class ShuffleDependency[K, V, C](
     val shuffleId: Int,
-    rdd: RDD[(K, V)],
+    @transient rdd: RDD[(K, V)],
     val aggregator: Aggregator[K, V, C],
     val partitioner: Partitioner)
   extends Dependency(rdd, true)
diff --git a/core/src/main/scala/spark/DiskSpillingCache.scala b/core/src/main/scala/spark/DiskSpillingCache.scala
deleted file mode 100644
index e11466eb64eec01e923bde295867653d88bb7706..0000000000000000000000000000000000000000
--- a/core/src/main/scala/spark/DiskSpillingCache.scala
+++ /dev/null
@@ -1,75 +0,0 @@
-package spark
-
-import java.io.File
-import java.io.{FileOutputStream,FileInputStream}
-import java.io.IOException
-import java.util.LinkedHashMap
-import java.util.UUID
-
-// TODO: cache into a separate directory using Utils.createTempDir
-// TODO: clean up disk cache afterwards
-class DiskSpillingCache extends BoundedMemoryCache {
-  private val diskMap = new LinkedHashMap[(Any, Int), File](32, 0.75f, true)
-
-  override def get(datasetId: Any, partition: Int): Any = {
-    synchronized {
-      val ser = SparkEnv.get.serializer.newInstance()
-      super.get(datasetId, partition) match {
-        case bytes: Any => // found in memory
-          ser.deserialize(bytes.asInstanceOf[Array[Byte]])
-
-        case _ => diskMap.get((datasetId, partition)) match {
-          case file: Any => // found on disk
-            try {
-              val startTime = System.currentTimeMillis
-              val bytes = new Array[Byte](file.length.toInt)
-              new FileInputStream(file).read(bytes)
-              val timeTaken = System.currentTimeMillis - startTime
-              logInfo("Reading key (%s, %d) of size %d bytes from disk took %d ms".format(
-                datasetId, partition, file.length, timeTaken))
-              super.put(datasetId, partition, bytes)
-              ser.deserialize(bytes.asInstanceOf[Array[Byte]])
-            } catch {
-              case e: IOException =>
-                logWarning("Failed to read key (%s, %d) from disk at %s: %s".format(
-                  datasetId, partition, file.getPath(), e.getMessage()))
-                diskMap.remove((datasetId, partition)) // remove dead entry
-                null
-            }
-
-          case _ => // not found
-            null
-        }
-      }
-    }
-  }
-
-  override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = {
-    var ser = SparkEnv.get.serializer.newInstance()
-    super.put(datasetId, partition, ser.serialize(value))
-  }
-
-  /**
-   * Spill the given entry to disk. Assumes that a lock is held on the
-   * DiskSpillingCache.  Assumes that entry.value is a byte array.
-   */
-  override protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) {
-    logInfo("Spilling key (%s, %d) of size %d to make space".format(
-      datasetId, partition, entry.size))
-    val cacheDir = System.getProperty(
-      "spark.diskSpillingCache.cacheDir",
-      System.getProperty("java.io.tmpdir"))
-    val file = new File(cacheDir, "spark-dsc-" + UUID.randomUUID.toString)
-    try {
-      val stream = new FileOutputStream(file)
-      stream.write(entry.value.asInstanceOf[Array[Byte]])
-      stream.close()
-      diskMap.put((datasetId, partition), file)
-    } catch {
-      case e: IOException =>
-        logWarning("Failed to spill key (%s, %d) to disk at %s: %s".format(
-          datasetId, partition, file.getPath(), e.getMessage()))
-        // Do nothing and let the entry be discarded
-    }
-  }
-}
diff --git a/core/src/main/scala/spark/DoubleRDDFunctions.scala b/core/src/main/scala/spark/DoubleRDDFunctions.scala
new file mode 100644
index 0000000000000000000000000000000000000000..1fbf66b7ded3c2e16ed708159be075e12ea0e8e3
--- /dev/null
+++ b/core/src/main/scala/spark/DoubleRDDFunctions.scala
@@ -0,0 +1,39 @@
+package spark
+
+import spark.partial.BoundedDouble
+import spark.partial.MeanEvaluator
+import spark.partial.PartialResult
+import spark.partial.SumEvaluator
+
+import spark.util.StatCounter
+
+/**
+ * Extra functions available on RDDs of Doubles through an implicit conversion.
+ */
+class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
+  def sum(): Double = {
+    self.reduce(_ + _)
+  }
+
+  def stats(): StatCounter = {
+    self.mapPartitions(nums => Iterator(StatCounter(nums))).reduce((a, b) => a.merge(b))
+  }
+
+  def mean(): Double = stats().mean
+
+  def variance(): Double = stats().variance
+
+  def stdev(): Double = stats().stdev
+
+  def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
+    val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
+    val evaluator = new MeanEvaluator(self.splits.size, confidence)
+    self.context.runApproximateJob(self, processPartition, evaluator, timeout)
+  }
+
+  def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
+    val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
+    val evaluator = new SumEvaluator(self.splits.size, confidence)
+    self.context.runApproximateJob(self, processPartition, evaluator, timeout)
+  }
+}
diff --git a/core/src/main/scala/spark/Executor.scala b/core/src/main/scala/spark/Executor.scala
index c795b6c3519332a6ea3fe0a9193918a32ec69b99..af9eb9c878ede5fd39441c413bf72c56524b0b5f 100644
--- a/core/src/main/scala/spark/Executor.scala
+++ b/core/src/main/scala/spark/Executor.scala
@@ -10,9 +10,10 @@ import scala.collection.mutable.ArrayBuffer
 import com.google.protobuf.ByteString
 
 import org.apache.mesos._
-import org.apache.mesos.Protos._
+import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _}
 
 import spark.broadcast._
+import spark.scheduler._
 
 /**
  * The Mesos executor for Spark.
@@ -29,6 +30,9 @@ class Executor extends org.apache.mesos.Executor with Logging {
       executorInfo: ExecutorInfo,
       frameworkInfo: FrameworkInfo,
       slaveInfo: SlaveInfo) {
+    // Make sure the local hostname we report matches Mesos's name for this host
+    Utils.setCustomHostname(slaveInfo.getHostname())
+
     // Read spark.* system properties from executor arg
     val props = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray)
     for ((key, value) <- props) {
@@ -39,7 +43,7 @@ class Executor extends org.apache.mesos.Executor with Logging {
     RemoteActor.classLoader = getClass.getClassLoader
 
     // Initialize Spark environment (using system properties read above)
-    env = SparkEnv.createFromSystemProperties(false)
+    env = SparkEnv.createFromSystemProperties(false, false)
     SparkEnv.set(env)
     // Old stuff that isn't yet using env
     Broadcast.initialize(false)
@@ -57,11 +61,11 @@ class Executor extends org.apache.mesos.Executor with Logging {
 
   override def reregistered(d: ExecutorDriver, s: SlaveInfo) {}
   
-  override def launchTask(d: ExecutorDriver, task: TaskInfo) {
+  override def launchTask(d: ExecutorDriver, task: MTaskInfo) {
     threadPool.execute(new TaskRunner(task, d))
   }
 
-  class TaskRunner(info: TaskInfo, d: ExecutorDriver)
+  class TaskRunner(info: MTaskInfo, d: ExecutorDriver)
   extends Runnable {
     override def run() = {
       val tid = info.getTaskId.getValue
@@ -74,11 +78,11 @@ class Executor extends org.apache.mesos.Executor with Logging {
           .setState(TaskState.TASK_RUNNING)
           .build())
       try {
+        SparkEnv.set(env)
+        Thread.currentThread.setContextClassLoader(classLoader)
         Accumulators.clear
-        val task = ser.deserialize[Task[Any]](info.getData.toByteArray, classLoader)
-        for (gen <- task.generation) {// Update generation if any is set
-          env.mapOutputTracker.updateGeneration(gen)
-        }
+        val task = ser.deserialize[Task[Any]](info.getData.asReadOnlyByteBuffer, classLoader)
+        env.mapOutputTracker.updateGeneration(task.generation)
         val value = task.run(tid.toInt)
         val accumUpdates = Accumulators.values
         val result = new TaskResult(value, accumUpdates)
@@ -105,9 +109,11 @@ class Executor extends org.apache.mesos.Executor with Logging {
               .setData(ByteString.copyFrom(ser.serialize(reason)))
               .build())
 
-          // TODO: Handle errors in tasks less dramatically
+          // TODO: Should we exit the whole executor here? On the one hand, the failed task may
+          // have left some weird state around depending on when the exception was thrown, but on
+          // the other hand, maybe we could detect that when future tasks fail and exit then.
           logError("Exception in task ID " + tid, t)
-          System.exit(1)
+          //System.exit(1)
         }
       }
     }
diff --git a/core/src/main/scala/spark/FetchFailedException.scala b/core/src/main/scala/spark/FetchFailedException.scala
index a3c4e7873d7ac5b11468320008252c1d2b84a549..55512f4481af231aa13c7c4b629ccdcc6bd556b5 100644
--- a/core/src/main/scala/spark/FetchFailedException.scala
+++ b/core/src/main/scala/spark/FetchFailedException.scala
@@ -1,7 +1,9 @@
 package spark
 
+import spark.storage.BlockManagerId
+
 class FetchFailedException(
-    val serverUri: String,
+    val bmAddress: BlockManagerId,
     val shuffleId: Int,
     val mapId: Int,
     val reduceId: Int,
@@ -9,10 +11,10 @@ class FetchFailedException(
   extends Exception {
   
   override def getMessage(): String = 
-    "Fetch failed: %s %d %d %d".format(serverUri, shuffleId, mapId, reduceId)
+    "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId)
 
   override def getCause(): Throwable = cause
 
   def toTaskEndReason: TaskEndReason =
-    FetchFailed(serverUri, shuffleId, mapId, reduceId)
+    FetchFailed(bmAddress, shuffleId, mapId, reduceId)
 }
diff --git a/core/src/main/scala/spark/JavaSerializer.scala b/core/src/main/scala/spark/JavaSerializer.scala
index 80f615eeb0a942f183d63128a3adec119101fcbe..ec5c33d1df0f639289401f0c9d5891f9bc57d9be 100644
--- a/core/src/main/scala/spark/JavaSerializer.scala
+++ b/core/src/main/scala/spark/JavaSerializer.scala
@@ -1,6 +1,7 @@
 package spark
 
 import java.io._
+import java.nio.ByteBuffer
 
 class JavaSerializationStream(out: OutputStream) extends SerializationStream {
   val objOut = new ObjectOutputStream(out)
@@ -9,10 +10,11 @@ class JavaSerializationStream(out: OutputStream) extends SerializationStream {
   def close() { objOut.close() }
 }
 
-class JavaDeserializationStream(in: InputStream) extends DeserializationStream {
+class JavaDeserializationStream(in: InputStream, loader: ClassLoader)
+extends DeserializationStream {
   val objIn = new ObjectInputStream(in) {
     override def resolveClass(desc: ObjectStreamClass) =
-      Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader)
+      Class.forName(desc.getName, false, loader)
   }
 
   def readObject[T](): T = objIn.readObject().asInstanceOf[T]
@@ -20,35 +22,36 @@ class JavaDeserializationStream(in: InputStream) extends DeserializationStream {
 }
 
 class JavaSerializerInstance extends SerializerInstance {
-  def serialize[T](t: T): Array[Byte] = {
+  def serialize[T](t: T): ByteBuffer = {
     val bos = new ByteArrayOutputStream()
-    val out = outputStream(bos)
+    val out = serializeStream(bos)
     out.writeObject(t)
     out.close()
-    bos.toByteArray
+    ByteBuffer.wrap(bos.toByteArray)
   }
 
-  def deserialize[T](bytes: Array[Byte]): T = {
-    val bis = new ByteArrayInputStream(bytes)
-    val in = inputStream(bis)
+  def deserialize[T](bytes: ByteBuffer): T = {
+    val bis = new ByteArrayInputStream(bytes.array())
+    val in = deserializeStream(bis)
     in.readObject().asInstanceOf[T]
   }
 
-  def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
-    val bis = new ByteArrayInputStream(bytes)
-    val ois = new ObjectInputStream(bis) {
-      override def resolveClass(desc: ObjectStreamClass) =
-        Class.forName(desc.getName, false, loader)
-    }
-    return ois.readObject.asInstanceOf[T]
+  def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
+    val bis = new ByteArrayInputStream(bytes.array())
+    val in = deserializeStream(bis, loader)
+    in.readObject().asInstanceOf[T]
   }
 
-  def outputStream(s: OutputStream): SerializationStream = {
+  def serializeStream(s: OutputStream): SerializationStream = {
     new JavaSerializationStream(s)
   }
 
-  def inputStream(s: InputStream): DeserializationStream = {
-    new JavaDeserializationStream(s)
+  def deserializeStream(s: InputStream): DeserializationStream = {
+    new JavaDeserializationStream(s, currentThread.getContextClassLoader)
+  }
+
+  def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = {
+    new JavaDeserializationStream(s, loader)
   }
 }
 
diff --git a/core/src/main/scala/spark/Job.scala b/core/src/main/scala/spark/Job.scala
deleted file mode 100644
index b7b0361c62c34c0377737b0328fe131a35d772e7..0000000000000000000000000000000000000000
--- a/core/src/main/scala/spark/Job.scala
+++ /dev/null
@@ -1,16 +0,0 @@
-package spark
-
-import org.apache.mesos._
-import org.apache.mesos.Protos._
-
-/**
- * Class representing a parallel job in MesosScheduler. Schedules the job by implementing various 
- * callbacks.
- */
-abstract class Job(val runId: Int, val jobId: Int) {
-  def slaveOffer(s: Offer, availableCpus: Double): Option[TaskInfo]
-
-  def statusUpdate(t: TaskStatus): Unit
-
-  def error(message: String): Unit
-}
diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala
index 5693613d6d45804767aeeab09c8990cb43babf43..65d0532bd58dddaea498fd4d9169eecfc4dea470 100644
--- a/core/src/main/scala/spark/KryoSerializer.scala
+++ b/core/src/main/scala/spark/KryoSerializer.scala
@@ -12,6 +12,8 @@ import com.esotericsoftware.kryo.{Serializer => KSerializer}
 import com.esotericsoftware.kryo.serialize.ClassSerializer
 import de.javakaffee.kryoserializers.KryoReflectionFactorySupport
 
+import spark.storage._
+
 /**
  * Zig-zag encoder used to write object sizes to serialization streams.
  * Based on Kryo's integer encoder.
@@ -64,57 +66,90 @@ object ZigZag {
   }
 }
 
-class KryoSerializationStream(kryo: Kryo, buf: ByteBuffer, out: OutputStream)
+class KryoSerializationStream(kryo: Kryo, threadBuffer: ByteBuffer, out: OutputStream)
 extends SerializationStream {
   val channel = Channels.newChannel(out)
 
   def writeObject[T](t: T) {
-    kryo.writeClassAndObject(buf, t)
-    ZigZag.writeInt(buf.position(), out)
-    buf.flip()
-    channel.write(buf)
-    buf.clear()
+    kryo.writeClassAndObject(threadBuffer, t)
+    ZigZag.writeInt(threadBuffer.position(), out)
+    threadBuffer.flip()
+    channel.write(threadBuffer)
+    threadBuffer.clear()
   }
 
   def flush() { out.flush() }
   def close() { out.close() }
 }
 
-class KryoDeserializationStream(buf: ObjectBuffer, in: InputStream)
+class KryoDeserializationStream(objectBuffer: ObjectBuffer, in: InputStream)
 extends DeserializationStream {
   def readObject[T](): T = {
     val len = ZigZag.readInt(in)
-    buf.readClassAndObject(in, len).asInstanceOf[T]
+    objectBuffer.readClassAndObject(in, len).asInstanceOf[T]
   }
 
   def close() { in.close() }
 }
 
 class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
-  val buf = ks.threadBuf.get()
+  val kryo = ks.kryo
+  val threadBuffer = ks.threadBuffer.get()
+  val objectBuffer = ks.objectBuffer.get()
 
-  def serialize[T](t: T): Array[Byte] = {
-    buf.writeClassAndObject(t)
+  def serialize[T](t: T): ByteBuffer = {
+    // Write it to our thread-local scratch buffer first to figure out the size, then return a new
+    // ByteBuffer of the appropriate size
+    threadBuffer.clear()
+    kryo.writeClassAndObject(threadBuffer, t)
+    val newBuf = ByteBuffer.allocate(threadBuffer.position)
+    threadBuffer.flip()
+    newBuf.put(threadBuffer)
+    newBuf.flip()
+    newBuf
   }
 
-  def deserialize[T](bytes: Array[Byte]): T = {
-    buf.readClassAndObject(bytes).asInstanceOf[T]
+  def deserialize[T](bytes: ByteBuffer): T = {
+    kryo.readClassAndObject(bytes).asInstanceOf[T]
   }
 
-  def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
-    val oldClassLoader = ks.kryo.getClassLoader
-    ks.kryo.setClassLoader(loader)
-    val obj = buf.readClassAndObject(bytes).asInstanceOf[T]
-    ks.kryo.setClassLoader(oldClassLoader)
+  def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
+    val oldClassLoader = kryo.getClassLoader
+    kryo.setClassLoader(loader)
+    val obj = kryo.readClassAndObject(bytes).asInstanceOf[T]
+    kryo.setClassLoader(oldClassLoader)
     obj
   }
 
-  def outputStream(s: OutputStream): SerializationStream = {
-    new KryoSerializationStream(ks.kryo, ks.threadByteBuf.get(), s)
+  def serializeStream(s: OutputStream): SerializationStream = {
+    threadBuffer.clear()
+    new KryoSerializationStream(kryo, threadBuffer, s)
+  }
+
+  def deserializeStream(s: InputStream): DeserializationStream = {
+    new KryoDeserializationStream(objectBuffer, s)
   }
 
-  def inputStream(s: InputStream): DeserializationStream = {
-    new KryoDeserializationStream(buf, s)
+  override def serializeMany[T](iterator: Iterator[T]): ByteBuffer = {
+    threadBuffer.clear()
+    while (iterator.hasNext) {
+      val element = iterator.next()
+      // TODO: Do we also want to write the object's size? Doesn't seem necessary.
+      kryo.writeClassAndObject(threadBuffer, element)
+    }
+    val newBuf = ByteBuffer.allocate(threadBuffer.position)
+    threadBuffer.flip()
+    newBuf.put(threadBuffer)
+    newBuf.flip()
+    newBuf
+  }
+
+  override def deserializeMany(buffer: ByteBuffer): Iterator[Any] = {
+    buffer.rewind()
+    new Iterator[Any] {
+      override def hasNext: Boolean = buffer.remaining > 0
+      override def next(): Any = kryo.readClassAndObject(buffer)
+    }
   }
 }
 
@@ -126,20 +161,17 @@ trait KryoRegistrator {
 class KryoSerializer extends Serializer with Logging {
   val kryo = createKryo()
 
-  val bufferSize = 
-    System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024 
+  val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "32").toInt * 1024 * 1024 
 
-  val threadBuf = new ThreadLocal[ObjectBuffer] {
+  val objectBuffer = new ThreadLocal[ObjectBuffer] {
     override def initialValue = new ObjectBuffer(kryo, bufferSize)
   }
 
-  val threadByteBuf = new ThreadLocal[ByteBuffer] {
+  val threadBuffer = new ThreadLocal[ByteBuffer] {
     override def initialValue = ByteBuffer.allocate(bufferSize)
   }
 
   def createKryo(): Kryo = {
-    // This is used so we can serialize/deserialize objects without a zero-arg
-    // constructor.
     val kryo = new KryoReflectionFactorySupport()
 
     // Register some commonly used classes
@@ -148,14 +180,20 @@ class KryoSerializer extends Serializer with Logging {
       Array(1), Array(1.0), Array(1.0f), Array(1L), Array(""), Array(("", "")),
       Array(new java.lang.Object), Array(1.toByte), Array(true), Array('c'),
       // Specialized Tuple2s
-      ("", ""), (1, 1), (1.0, 1.0), (1L, 1L),
+      ("", ""), ("", 1), (1, 1), (1.0, 1.0), (1L, 1L),
       (1, 1.0), (1.0, 1), (1L, 1.0), (1.0, 1L), (1, 1L), (1L, 1),
       // Scala collections
       List(1), mutable.ArrayBuffer(1),
       // Options and Either
       Some(1), Left(1), Right(1),
       // Higher-dimensional tuples
-      (1, 1, 1), (1, 1, 1, 1), (1, 1, 1, 1, 1)
+      (1, 1, 1), (1, 1, 1, 1), (1, 1, 1, 1, 1),
+      None,
+      ByteBuffer.allocate(1),
+      StorageLevel.MEMORY_ONLY_DESER,
+      PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY_DESER),
+      GotBlock("1", ByteBuffer.allocate(1)),
+      GetBlock("1")
     )
     for (obj <- toRegister) {
       kryo.register(obj.getClass)
diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala
index 0d11ab9cbd836a5495f5392b942cb39ffd60e385..54bd57f6d3c94d2c17160f3ddaf38b1485f12e50 100644
--- a/core/src/main/scala/spark/Logging.scala
+++ b/core/src/main/scala/spark/Logging.scala
@@ -28,9 +28,11 @@ trait Logging {
   }
 
   // Log methods that take only a String
-  def logInfo(msg: => String) = if (log.isInfoEnabled) log.info(msg)
+  def logInfo(msg: => String) = if (log.isInfoEnabled /*&& msg.contains("job finished in")*/) log.info(msg)
 
   def logDebug(msg: => String) = if (log.isDebugEnabled) log.debug(msg)
+  
+  def logTrace(msg: => String) = if (log.isTraceEnabled) log.trace(msg)
 
   def logWarning(msg: => String) = if (log.isWarnEnabled) log.warn(msg)
 
@@ -43,6 +45,9 @@ trait Logging {
   def logDebug(msg: => String, throwable: Throwable) =
     if (log.isDebugEnabled) log.debug(msg)
 
+  def logTrace(msg: => String, throwable: Throwable) =
+    if (log.isTraceEnabled) log.trace(msg)
+
   def logWarning(msg: => String, throwable: Throwable) =
     if (log.isWarnEnabled) log.warn(msg, throwable)
 
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index a934c5a02fe30706ddb9d6ce7194743c91c40ca1..d938a6eb629867b0a45c9a4abbe24233a5947b5b 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -2,80 +2,80 @@ package spark
 
 import java.util.concurrent.ConcurrentHashMap
 
-import scala.actors._
-import scala.actors.Actor._
-import scala.actors.remote._
+import akka.actor._
+import akka.actor.Actor
+import akka.actor.Actor._
+import akka.util.duration._
+
 import scala.collection.mutable.HashSet
 
+import spark.storage.BlockManagerId
+
 sealed trait MapOutputTrackerMessage
 case class GetMapOutputLocations(shuffleId: Int) extends MapOutputTrackerMessage 
 case object StopMapOutputTracker extends MapOutputTrackerMessage
 
-class MapOutputTrackerActor(serverUris: ConcurrentHashMap[Int, Array[String]])
-extends DaemonActor with Logging {
-  def act() {
-    val port = System.getProperty("spark.master.port").toInt
-    RemoteActor.alive(port)
-    RemoteActor.register('MapOutputTracker, self)
-    logInfo("Registered actor on port " + port)
-    
-    loop {
-      react {
-        case GetMapOutputLocations(shuffleId: Int) =>
-          logInfo("Asked to get map output locations for shuffle " + shuffleId)
-          reply(serverUris.get(shuffleId))
-          
-        case StopMapOutputTracker =>
-          reply('OK)
-          exit()
-      }
-    }
+class MapOutputTrackerActor(bmAddresses: ConcurrentHashMap[Int, Array[BlockManagerId]]) 
+extends Actor with Logging {
+  def receive = {
+    case GetMapOutputLocations(shuffleId: Int) =>
+      logInfo("Asked to get map output locations for shuffle " + shuffleId)
+      self.reply(bmAddresses.get(shuffleId))
+
+    case StopMapOutputTracker =>
+      logInfo("MapOutputTrackerActor stopped!")
+      self.reply(true)
+      self.exit()
   }
 }
 
 class MapOutputTracker(isMaster: Boolean) extends Logging {
-  var trackerActor: AbstractActor = null
+  val ip: String = System.getProperty("spark.master.host", "localhost")
+  val port: Int = System.getProperty("spark.master.port", "7077").toInt
+  val aName: String = "MapOutputTracker"
 
-  private var serverUris = new ConcurrentHashMap[Int, Array[String]]
+  private var bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]]
 
   // Incremented every time a fetch fails so that client nodes know to clear
   // their cache of map output locations if this happens.
   private var generation: Long = 0
   private var generationLock = new java.lang.Object
-  
-  if (isMaster) {
-    val tracker = new MapOutputTrackerActor(serverUris)
-    tracker.start()
-    trackerActor = tracker
+
+  var trackerActor: ActorRef = if (isMaster) {
+    val actor = actorOf(new MapOutputTrackerActor(bmAddresses))
+    remote.register(aName, actor)
+    logInfo("Registered MapOutputTrackerActor actor @ " + ip + ":" + port)
+    actor
   } else {
-    val host = System.getProperty("spark.master.host")
-    val port = System.getProperty("spark.master.port").toInt
-    trackerActor = RemoteActor.select(Node(host, port), 'MapOutputTracker)
+    remote.actorFor(aName, ip, port)
   }
 
   def registerShuffle(shuffleId: Int, numMaps: Int) {
-    if (serverUris.get(shuffleId) != null) {
+    if (bmAddresses.get(shuffleId) != null) {
       throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
     }
-    serverUris.put(shuffleId, new Array[String](numMaps))
+    bmAddresses.put(shuffleId, new Array[BlockManagerId](numMaps))
   }
   
-  def registerMapOutput(shuffleId: Int, mapId: Int, serverUri: String) {
-    var array = serverUris.get(shuffleId)
+  def registerMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
+    var array = bmAddresses.get(shuffleId)
     array.synchronized {
-      array(mapId) = serverUri
+      array(mapId) = bmAddress
     }
   }
   
-  def registerMapOutputs(shuffleId: Int, locs: Array[String]) {
-    serverUris.put(shuffleId, Array[String]() ++ locs)
+  def registerMapOutputs(shuffleId: Int, locs: Array[BlockManagerId], changeGeneration: Boolean = false) {
+    bmAddresses.put(shuffleId, Array[BlockManagerId]() ++ locs)
+    if (changeGeneration) {
+      incrementGeneration()
+    }
   }
 
-  def unregisterMapOutput(shuffleId: Int, mapId: Int, serverUri: String) {
-    var array = serverUris.get(shuffleId)
+  def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
+    var array = bmAddresses.get(shuffleId)
     if (array != null) {
       array.synchronized {
-        if (array(mapId) == serverUri) {
+        if (array(mapId) == bmAddress) {
           array(mapId) = null
         }
       }
@@ -89,10 +89,10 @@ class MapOutputTracker(isMaster: Boolean) extends Logging {
   val fetching = new HashSet[Int]
   
   // Called on possibly remote nodes to get the server URIs for a given shuffle
-  def getServerUris(shuffleId: Int): Array[String] = {
-    val locs = serverUris.get(shuffleId)
+  def getServerAddresses(shuffleId: Int): Array[BlockManagerId] = {
+    val locs = bmAddresses.get(shuffleId)
     if (locs == null) {
-      logInfo("Don't have map outputs for " + shuffleId + ", fetching them")
+      logInfo("Don't have map outputs for shuffe " + shuffleId + ", fetching them")
       fetching.synchronized {
         if (fetching.contains(shuffleId)) {
           // Someone else is fetching it; wait for them to be done
@@ -103,15 +103,17 @@ class MapOutputTracker(isMaster: Boolean) extends Logging {
               case _ =>
             }
           }
-          return serverUris.get(shuffleId)
+          return bmAddresses.get(shuffleId)
         } else {
           fetching += shuffleId
         }
       }
       // We won the race to fetch the output locs; do so
       logInfo("Doing the fetch; tracker actor = " + trackerActor)
-      val fetched = (trackerActor !? GetMapOutputLocations(shuffleId)).asInstanceOf[Array[String]]
-      serverUris.put(shuffleId, fetched)
+      val fetched = (trackerActor ? GetMapOutputLocations(shuffleId)).as[Array[BlockManagerId]].get
+      
+      logInfo("Got the output locations")
+      bmAddresses.put(shuffleId, fetched)
       fetching.synchronized {
         fetching -= shuffleId
         fetching.notifyAll()
@@ -121,14 +123,10 @@ class MapOutputTracker(isMaster: Boolean) extends Logging {
       return locs
     }
   }
-  
-  def getMapOutputUri(serverUri: String, shuffleId: Int, mapId: Int, reduceId: Int): String = {
-    "%s/shuffle/%s/%s/%s".format(serverUri, shuffleId, mapId, reduceId)
-  }
 
   def stop() {
-    trackerActor !? StopMapOutputTracker
-    serverUris.clear()
+    trackerActor !! StopMapOutputTracker
+    bmAddresses.clear()
     trackerActor = null
   }
 
@@ -153,7 +151,7 @@ class MapOutputTracker(isMaster: Boolean) extends Logging {
     generationLock.synchronized {
       if (newGen > generation) {
         logInfo("Updating generation to " + newGen + " and clearing cache")
-        serverUris = new ConcurrentHashMap[Int, Array[String]]
+        bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]]
         generation = newGen
       }
     }
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index 8b63d1aba1eeff4fd9a0c1fc99f37a87d0a9a7ec..ff6764e0a21d6f84d4dd1f0b7581451c12969565 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -4,14 +4,14 @@ import java.io.EOFException
 import java.net.URL
 import java.io.ObjectInputStream
 import java.util.concurrent.atomic.AtomicLong
-import java.util.HashSet
-import java.util.Random
+import java.util.{HashMap => JHashMap}
 import java.util.Date
 import java.text.SimpleDateFormat
 
+import scala.collection.Map
 import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.Map
 import scala.collection.mutable.HashMap
+import scala.collection.JavaConversions._
 
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.io.BytesWritable
@@ -34,7 +34,9 @@ import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob}
 import org.apache.hadoop.mapreduce.TaskAttemptID
 import org.apache.hadoop.mapreduce.TaskAttemptContext
 
-import SparkContext._
+import spark.SparkContext._
+import spark.partial.BoundedDouble
+import spark.partial.PartialResult
 
 /**
  * Extra functions available on RDDs of (key, value) pairs through an implicit conversion.
@@ -43,19 +45,6 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
     self: RDD[(K, V)])
   extends Logging
   with Serializable {
-  
-  def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = {
-    def mergeMaps(m1: HashMap[K, V], m2: HashMap[K, V]): HashMap[K, V] = {
-      for ((k, v) <- m2) {
-        m1.get(k) match {
-          case None => m1(k) = v
-          case Some(w) => m1(k) = func(w, v)
-        }
-      }
-      return m1
-    }
-    self.map(pair => HashMap(pair)).reduce(mergeMaps)
-  }
 
   def combineByKey[C](createCombiner: V => C,
       mergeValue: (C, V) => C,
@@ -77,6 +66,39 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
   def reduceByKey(func: (V, V) => V, numSplits: Int): RDD[(K, V)] = {
     combineByKey[V]((v: V) => v, func, func, numSplits)
   }
+  
+  def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = {
+    def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = {
+      val map = new JHashMap[K, V]
+      for ((k, v) <- iter) {
+        val old = map.get(k)
+        map.put(k, if (old == null) v else func(old, v))
+      }
+      Iterator(map)
+    }
+
+    def mergeMaps(m1: JHashMap[K, V], m2: JHashMap[K, V]): JHashMap[K, V] = {
+      for ((k, v) <- m2) {
+        val old = m1.get(k)
+        m1.put(k, if (old == null) v else func(old, v))
+      }
+      return m1
+    }
+
+    self.mapPartitions(reducePartition).reduce(mergeMaps)
+  }
+
+  // Alias for backwards compatibility
+  def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = reduceByKeyLocally(func)
+
+  // TODO: This should probably be a distributed version
+  def countByKey(): Map[K, Long] = self.map(_._1).countByValue()
+
+  // TODO: This should probably be a distributed version
+  def countByKeyApprox(timeout: Long, confidence: Double = 0.95)
+      : PartialResult[Map[K, BoundedDouble]] = {
+    self.map(_._1).countByValueApprox(timeout, confidence)
+  }
 
   def groupByKey(numSplits: Int): RDD[(K, Seq[V])] = {
     def createCombiner(v: V) = ArrayBuffer(v)
diff --git a/core/src/main/scala/spark/ParallelShuffleFetcher.scala b/core/src/main/scala/spark/ParallelShuffleFetcher.scala
deleted file mode 100644
index 19eb288e8460e599b501091f75b178cea388501a..0000000000000000000000000000000000000000
--- a/core/src/main/scala/spark/ParallelShuffleFetcher.scala
+++ /dev/null
@@ -1,119 +0,0 @@
-package spark
-
-import java.io.ByteArrayInputStream
-import java.io.EOFException
-import java.net.URL
-import java.util.concurrent.LinkedBlockingQueue
-import java.util.concurrent.TimeUnit
-import java.util.concurrent.atomic.AtomicBoolean
-import java.util.concurrent.atomic.AtomicInteger
-import java.util.concurrent.atomic.AtomicReference
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-
-import it.unimi.dsi.fastutil.io.FastBufferedInputStream
-
-
-class ParallelShuffleFetcher extends ShuffleFetcher with Logging {
-  val parallelFetches = System.getProperty("spark.parallel.fetches", "3").toInt
-
-  def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) {
-    logInfo("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
-    
-    // Figure out a list of input IDs (mapper IDs) for each server
-    val ser = SparkEnv.get.serializer.newInstance()
-    val inputsByUri = new HashMap[String, ArrayBuffer[Int]]
-    val serverUris = SparkEnv.get.mapOutputTracker.getServerUris(shuffleId)
-    for ((serverUri, index) <- serverUris.zipWithIndex) {
-      inputsByUri.getOrElseUpdate(serverUri, ArrayBuffer()) += index
-    }
-    
-    // Randomize them and put them in a LinkedBlockingQueue
-    val serverQueue = new LinkedBlockingQueue[(String, ArrayBuffer[Int])]
-    for (pair <- Utils.randomize(inputsByUri)) {
-      serverQueue.put(pair)
-    }
-
-    // Create a queue to hold the fetched data
-    val resultQueue = new LinkedBlockingQueue[Array[Byte]]
-
-    // Atomic variables to communicate failures and # of fetches done
-    var failure = new AtomicReference[FetchFailedException](null)
-
-    // Start multiple threads to do the fetching (TODO: may be possible to do it asynchronously)
-    for (i <- 0 until parallelFetches) {
-      new Thread("Fetch thread " + i + " for reduce " + reduceId) {
-        override def run() {
-          while (true) {
-            val pair = serverQueue.poll()
-            if (pair == null)
-              return
-            val (serverUri, inputIds) = pair
-            //logInfo("Pulled out server URI " + serverUri)
-            for (i <- inputIds) {
-              if (failure.get != null)
-                return
-              val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, reduceId)
-              logInfo("Starting HTTP request for " + url)
-              try {
-                val conn = new URL(url).openConnection()
-                conn.connect()
-                val len = conn.getContentLength()
-                if (len == -1) {
-                  throw new SparkException("Content length was not specified by server")
-                }
-                val buf = new Array[Byte](len)
-                val in = new FastBufferedInputStream(conn.getInputStream())
-                var pos = 0
-                while (pos < len) {
-                  val n = in.read(buf, pos, len-pos)
-                  if (n == -1) {
-                    throw new SparkException("EOF before reading the expected " + len + " bytes")
-                  } else {
-                    pos += n
-                  }
-                }
-                // Done reading everything
-                resultQueue.put(buf)
-                in.close()
-              } catch {
-                case e: Exception =>
-                  logError("Fetch failed from " + url, e)
-                  failure.set(new FetchFailedException(serverUri, shuffleId, i, reduceId, e))
-                  return
-              }
-            }
-            //logInfo("Done with server URI " + serverUri)
-          }
-        }
-      }.start()
-    }
-
-    // Wait for results from the threads (either a failure or all servers done)
-    var resultsDone = 0
-    var totalResults = inputsByUri.map{case (uri, inputs) => inputs.size}.sum
-    while (failure.get == null && resultsDone < totalResults) {
-      try {
-        val result = resultQueue.poll(100, TimeUnit.MILLISECONDS)
-        if (result != null) {
-          //logInfo("Pulled out a result")
-          val in = ser.inputStream(new ByteArrayInputStream(result))
-            try {
-            while (true) {
-              val pair = in.readObject().asInstanceOf[(K, V)]
-              func(pair._1, pair._2)
-            }
-          } catch {
-            case e: EOFException => {} // TODO: cleaner way to detect EOF, such as a sentinel
-          }
-          resultsDone += 1
-          //logInfo("Results done = " + resultsDone)
-        }
-      } catch { case e: InterruptedException => {} }
-    }
-    if (failure.get != null) {
-      throw failure.get
-    }
-  }
-}
diff --git a/core/src/main/scala/spark/Partitioner.scala b/core/src/main/scala/spark/Partitioner.scala
index ac61fe3b54526da22a0d812a485da167651a686e..8f3f0f5e15beca7c662feebb9f084120c0cd553f 100644
--- a/core/src/main/scala/spark/Partitioner.scala
+++ b/core/src/main/scala/spark/Partitioner.scala
@@ -70,4 +70,3 @@ class RangePartitioner[K <% Ordered[K]: ClassManifest, V](
       false
   }
 }
-
diff --git a/core/src/main/scala/spark/PipedRDD.scala b/core/src/main/scala/spark/PipedRDD.scala
index 8a5de3d7e96055ca839b476e661b0a9ed10035ad..9e0a01b5f9fb0357ca5eb0f599ccc2e567aef83b 100644
--- a/core/src/main/scala/spark/PipedRDD.scala
+++ b/core/src/main/scala/spark/PipedRDD.scala
@@ -3,6 +3,7 @@ package spark
 import java.io.PrintWriter
 import java.util.StringTokenizer
 
+import scala.collection.Map
 import scala.collection.JavaConversions._
 import scala.collection.mutable.ArrayBuffer
 import scala.io.Source
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index fa53d9be2c045de5bd0ba15a0597fdeb75761b74..22dcc27bad5ea303bc2a649a8dc6fc511334f91f 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -4,11 +4,14 @@ import java.io.EOFException
 import java.net.URL
 import java.io.ObjectInputStream
 import java.util.concurrent.atomic.AtomicLong
-import java.util.HashSet
 import java.util.Random
 import java.util.Date
+import java.util.{HashMap => JHashMap}
 
 import scala.collection.mutable.ArrayBuffer
+import scala.collection.Map
+import scala.collection.mutable.HashMap
+import scala.collection.JavaConversions.mapAsScalaMap
 
 import org.apache.hadoop.io.BytesWritable
 import org.apache.hadoop.io.NullWritable
@@ -22,6 +25,14 @@ import org.apache.hadoop.mapred.OutputFormat
 import org.apache.hadoop.mapred.SequenceFileOutputFormat
 import org.apache.hadoop.mapred.TextOutputFormat
 
+import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
+
+import spark.partial.BoundedDouble
+import spark.partial.CountEvaluator
+import spark.partial.GroupedCountEvaluator
+import spark.partial.PartialResult
+import spark.storage.StorageLevel
+
 import SparkContext._
 
 /**
@@ -61,19 +72,32 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
   // Get a unique ID for this RDD
   val id = sc.newRddId()
   
-  // Variables relating to caching
-  private var shouldCache = false
+  // Variables relating to persistence
+  private var storageLevel: StorageLevel = StorageLevel.NONE
   
-  // Change this RDD's caching
-  def cache(): RDD[T] = {
-    shouldCache = true
+  // Change this RDD's storage level
+  def persist(newLevel: StorageLevel): RDD[T] = {
+    // TODO: Handle changes of StorageLevel
+    if (storageLevel != StorageLevel.NONE && newLevel != storageLevel) {
+      throw new UnsupportedOperationException(
+        "Cannot change storage level of an RDD after it was already assigned a level")
+    }
+    storageLevel = newLevel
     this
   }
+
+  // Turn on the default caching level for this RDD
+  def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY_DESER)
+  
+  // Turn on the default caching level for this RDD
+  def cache(): RDD[T] = persist()
+
+  def getStorageLevel = storageLevel
   
   // Read this RDD; will read from cache if applicable, or otherwise compute
   final def iterator(split: Split): Iterator[T] = {
-    if (shouldCache) {
-      SparkEnv.get.cacheTracker.getOrCompute[T](this, split)
+    if (storageLevel != StorageLevel.NONE) {
+      SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel)
     } else {
       compute(split)
     }
@@ -162,6 +186,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
     Array.concat(results: _*)
   }
 
+  def toArray(): Array[T] = collect()
+
   def reduce(f: (T, T) => T): T = {
     val cleanF = sc.clean(f)
     val reducePartition: Iterator[T] => Option[T] = iter => {
@@ -222,7 +248,67 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
     }).sum
   }
 
-  def toArray(): Array[T] = collect()
+  /**
+   * Approximate version of count() that returns a potentially incomplete result after a timeout.
+   */
+  def countApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
+    val countElements: (TaskContext, Iterator[T]) => Long = { (ctx, iter) =>
+      var result = 0L
+      while (iter.hasNext) {
+        result += 1L
+        iter.next
+      }
+      result
+    }
+    val evaluator = new CountEvaluator(splits.size, confidence)
+    sc.runApproximateJob(this, countElements, evaluator, timeout)
+  }
+
+  /**
+   * Count elements equal to each value, returning a map of (value, count) pairs. The final combine
+   * step happens locally on the master, equivalent to running a single reduce task.
+   *
+   * TODO: This should perhaps be distributed by default.
+   */
+  def countByValue(): Map[T, Long] = {
+    def countPartition(iter: Iterator[T]): Iterator[OLMap[T]] = {
+      val map = new OLMap[T]
+      while (iter.hasNext) {
+        val v = iter.next()
+        map.put(v, map.getLong(v) + 1L)
+      }
+      Iterator(map)
+    }
+    def mergeMaps(m1: OLMap[T], m2: OLMap[T]): OLMap[T] = {
+      val iter = m2.object2LongEntrySet.fastIterator()
+      while (iter.hasNext) {
+        val entry = iter.next()
+        m1.put(entry.getKey, m1.getLong(entry.getKey) + entry.getLongValue)
+      }
+      return m1
+    }
+    val myResult = mapPartitions(countPartition).reduce(mergeMaps)
+    myResult.asInstanceOf[java.util.Map[T, Long]]   // Will be wrapped as a Scala mutable Map
+  }
+
+  /**
+   * Approximate version of countByValue().
+   */
+  def countByValueApprox(
+      timeout: Long,
+      confidence: Double = 0.95
+      ): PartialResult[Map[T, BoundedDouble]] = {
+    val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) =>
+      val map = new OLMap[T]
+      while (iter.hasNext) {
+        val v = iter.next()
+        map.put(v, map.getLong(v) + 1L)
+      }
+      map
+    }
+    val evaluator = new GroupedCountEvaluator[T](splits.size, confidence)
+    sc.runApproximateJob(this, countPartition, evaluator, timeout)
+  }
   
   /**
    * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so
diff --git a/core/src/main/scala/spark/Scheduler.scala b/core/src/main/scala/spark/Scheduler.scala
deleted file mode 100644
index 6c7e569313b9f6a325b39c1606700715b90c56d9..0000000000000000000000000000000000000000
--- a/core/src/main/scala/spark/Scheduler.scala
+++ /dev/null
@@ -1,27 +0,0 @@
-package spark
-
-/**
- * Scheduler trait, implemented by both MesosScheduler and LocalScheduler.
- */
-private trait Scheduler {
-  def start()
-
-  // Wait for registration with Mesos.
-  def waitForRegister()
-
-  /**
-   * Run a function on some partitions of an RDD, returning an array of results. The allowLocal
-   * flag specifies whether the scheduler is allowed to run the job on the master machine rather
-   * than shipping it to the cluster, for actions that create short jobs such as first() and take().
-   */
-  def runJob[T, U: ClassManifest](
-      rdd: RDD[T],
-      func: (TaskContext, Iterator[T]) => U,
-      partitions: Seq[Int],
-      allowLocal: Boolean): Array[U]
-
-  def stop()
-
-  // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
-  def defaultParallelism(): Int
-}
diff --git a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
index b213ca9dcbde6c70ad6ef03ca4c2150a84a1390f..9da73c4b028c8f70a085f5ec22a5891516c575d7 100644
--- a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
+++ b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
@@ -44,7 +44,7 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla
       }
        // TODO: use something like WritableConverter to avoid reflection
     }
-    c.asInstanceOf[Class[ _ <: Writable]]
+    c.asInstanceOf[Class[_ <: Writable]]
   }
 
   def saveAsSequenceFile(path: String) {
diff --git a/core/src/main/scala/spark/Serializer.scala b/core/src/main/scala/spark/Serializer.scala
index 2429bbfeb927445e887359465d54a8c8aafcade8..61a70beaf1fd73566443f8cf7e05c2317eceafd4 100644
--- a/core/src/main/scala/spark/Serializer.scala
+++ b/core/src/main/scala/spark/Serializer.scala
@@ -1,6 +1,12 @@
 package spark
 
-import java.io.{InputStream, OutputStream}
+import java.io.{EOFException, InputStream, OutputStream}
+import java.nio.ByteBuffer
+import java.nio.channels.Channels
+
+import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+
+import spark.util.ByteBufferInputStream
 
 /**
  * A serializer. Because some serialization libraries are not thread safe, this class is used to 
@@ -14,11 +20,31 @@ trait Serializer {
  * An instance of the serializer, for use by one thread at a time.
  */
 trait SerializerInstance {
-  def serialize[T](t: T): Array[Byte]
-  def deserialize[T](bytes: Array[Byte]): T
-  def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T
-  def outputStream(s: OutputStream): SerializationStream
-  def inputStream(s: InputStream): DeserializationStream
+  def serialize[T](t: T): ByteBuffer
+
+  def deserialize[T](bytes: ByteBuffer): T
+
+  def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T
+
+  def serializeStream(s: OutputStream): SerializationStream
+
+  def deserializeStream(s: InputStream): DeserializationStream
+
+  def serializeMany[T](iterator: Iterator[T]): ByteBuffer = {
+    // Default implementation uses serializeStream
+    val stream = new FastByteArrayOutputStream()
+    serializeStream(stream).writeAll(iterator)
+    val buffer = ByteBuffer.allocate(stream.position.toInt)
+    buffer.put(stream.array, 0, stream.position.toInt)
+    buffer.flip()
+    buffer
+  }
+
+  def deserializeMany(buffer: ByteBuffer): Iterator[Any] = {
+    // Default implementation uses deserializeStream
+    buffer.rewind()
+    deserializeStream(new ByteBufferInputStream(buffer)).toIterator
+  }
 }
 
 /**
@@ -28,6 +54,13 @@ trait SerializationStream {
   def writeObject[T](t: T): Unit
   def flush(): Unit
   def close(): Unit
+
+  def writeAll[T](iter: Iterator[T]): SerializationStream = {
+    while (iter.hasNext) {
+      writeObject(iter.next())
+    }
+    this
+  }
 }
 
 /**
@@ -36,4 +69,45 @@ trait SerializationStream {
 trait DeserializationStream {
   def readObject[T](): T
   def close(): Unit
+
+  /**
+   * Read the elements of this stream through an iterator. This can only be called once, as
+   * reading each element will consume data from the input source.
+   */
+  def toIterator: Iterator[Any] = new Iterator[Any] {
+    var gotNext = false
+    var finished = false
+    var nextValue: Any = null
+
+    private def getNext() {
+      try {
+        nextValue = readObject[Any]()
+      } catch {
+        case eof: EOFException =>
+          finished = true
+      }
+      gotNext = true
+    }
+    
+    override def hasNext: Boolean = {
+      if (!gotNext) {
+        getNext()
+      }
+      if (finished) {
+        close()
+      }
+      !finished
+    }
+
+    override def next(): Any = {
+      if (!gotNext) {
+        getNext()
+      }
+      if (finished) {
+        throw new NoSuchElementException("End of stream")
+      }
+      gotNext = false
+      nextValue
+    }
+  }
 }
diff --git a/core/src/main/scala/spark/SerializingCache.scala b/core/src/main/scala/spark/SerializingCache.scala
deleted file mode 100644
index 3d192f24034a0f5a59a7247bf2850ba29efbbc80..0000000000000000000000000000000000000000
--- a/core/src/main/scala/spark/SerializingCache.scala
+++ /dev/null
@@ -1,26 +0,0 @@
-package spark
-
-import java.io._
-
-/**
- * Wrapper around a BoundedMemoryCache that stores serialized objects as byte arrays in order to 
- * reduce storage cost and GC overhead
- */
-class SerializingCache extends Cache with Logging {
-  val bmc = new BoundedMemoryCache
-
-  override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = {
-    val ser = SparkEnv.get.serializer.newInstance()
-    bmc.put(datasetId, partition, ser.serialize(value))
-  }
-
-  override def get(datasetId: Any, partition: Int): Any = {
-    val bytes = bmc.get(datasetId, partition)
-    if (bytes != null) {
-      val ser = SparkEnv.get.serializer.newInstance()
-      return ser.deserialize(bytes.asInstanceOf[Array[Byte]])
-    } else {
-      return null
-    }
-  }
-}
diff --git a/core/src/main/scala/spark/ShuffleMapTask.scala b/core/src/main/scala/spark/ShuffleMapTask.scala
deleted file mode 100644
index 5fc59af06c039f6d74638c63cea13ad824058e40..0000000000000000000000000000000000000000
--- a/core/src/main/scala/spark/ShuffleMapTask.scala
+++ /dev/null
@@ -1,56 +0,0 @@
-package spark
-
-import java.io.BufferedOutputStream
-import java.io.FileOutputStream
-import java.io.ObjectOutputStream
-import java.util.{HashMap => JHashMap}
-
-import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
-
-class ShuffleMapTask(
-    runId: Int,
-    stageId: Int,
-    rdd: RDD[_], 
-    dep: ShuffleDependency[_,_,_],
-    val partition: Int, 
-    locs: Seq[String])
-  extends DAGTask[String](runId, stageId)
-  with Logging {
-  
-  val split = rdd.splits(partition)
-
-  override def run (attemptId: Int): String = {
-    val numOutputSplits = dep.partitioner.numPartitions
-    val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]]
-    val partitioner = dep.partitioner.asInstanceOf[Partitioner]
-    val buckets = Array.tabulate(numOutputSplits)(_ => new JHashMap[Any, Any])
-    for (elem <- rdd.iterator(split)) {
-      val (k, v) = elem.asInstanceOf[(Any, Any)]
-      var bucketId = partitioner.getPartition(k)
-      val bucket = buckets(bucketId)
-      var existing = bucket.get(k)
-      if (existing == null) {
-        bucket.put(k, aggregator.createCombiner(v))
-      } else {
-        bucket.put(k, aggregator.mergeValue(existing, v))
-      }
-    }
-    val ser = SparkEnv.get.serializer.newInstance()
-    for (i <- 0 until numOutputSplits) {
-      val file = SparkEnv.get.shuffleManager.getOutputFile(dep.shuffleId, partition, i)
-      val out = ser.outputStream(new FastBufferedOutputStream(new FileOutputStream(file)))
-      val iter = buckets(i).entrySet().iterator()
-      while (iter.hasNext()) {
-        val entry = iter.next()
-        out.writeObject((entry.getKey, entry.getValue))
-      }
-      // TODO: have some kind of EOF marker
-      out.close()
-    }
-    return SparkEnv.get.shuffleManager.getServerUri
-  }
-
-  override def preferredLocations: Seq[String] = locs
-
-  override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition)
-}
diff --git a/core/src/main/scala/spark/ShuffledRDD.scala b/core/src/main/scala/spark/ShuffledRDD.scala
index 5efc8cf50b8ef27154c59a2bf00bd7a3d2220114..5434197ecad3330fb000b6c5a3238453e16a3b19 100644
--- a/core/src/main/scala/spark/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/ShuffledRDD.scala
@@ -8,7 +8,7 @@ class ShuffledRDDSplit(val idx: Int) extends Split {
 }
 
 class ShuffledRDD[K, V, C](
-    parent: RDD[(K, V)],
+    @transient parent: RDD[(K, V)],
     aggregator: Aggregator[K, V, C],
     part : Partitioner) 
   extends RDD[(K, C)](parent.context) {
diff --git a/core/src/main/scala/spark/SimpleShuffleFetcher.scala b/core/src/main/scala/spark/SimpleShuffleFetcher.scala
deleted file mode 100644
index 196c64cf1fb76758c9d1251dc296ddcb58d863cd..0000000000000000000000000000000000000000
--- a/core/src/main/scala/spark/SimpleShuffleFetcher.scala
+++ /dev/null
@@ -1,46 +0,0 @@
-package spark
-
-import java.io.EOFException
-import java.net.URL
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-
-import it.unimi.dsi.fastutil.io.FastBufferedInputStream
-
-class SimpleShuffleFetcher extends ShuffleFetcher with Logging {
-  def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) {
-    logInfo("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
-    val ser = SparkEnv.get.serializer.newInstance()
-    val splitsByUri = new HashMap[String, ArrayBuffer[Int]]
-    val serverUris = SparkEnv.get.mapOutputTracker.getServerUris(shuffleId)
-    for ((serverUri, index) <- serverUris.zipWithIndex) {
-      splitsByUri.getOrElseUpdate(serverUri, ArrayBuffer()) += index
-    }
-    for ((serverUri, inputIds) <- Utils.randomize(splitsByUri)) {
-      for (i <- inputIds) {
-        try {
-          val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, reduceId)
-          // TODO: multithreaded fetch
-          // TODO: would be nice to retry multiple times
-          val inputStream = ser.inputStream(
-              new FastBufferedInputStream(new URL(url).openStream()))
-          try {
-            while (true) {
-              val pair = inputStream.readObject().asInstanceOf[(K, V)]
-              func(pair._1, pair._2)
-            }
-          } finally {
-            inputStream.close()
-          }
-        } catch {
-          case e: EOFException => {} // We currently assume EOF means we read the whole thing
-          case other: Exception => {
-            logError("Fetch failed", other)
-            throw new FetchFailedException(serverUri, shuffleId, i, reduceId, other)
-          }
-        }
-      }
-    }
-  }
-}
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 6e019d6e7f10c345bb79a7452124384e46a8c12b..7a9a70fee0111475ab02993b921405b4bea63af9 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -3,6 +3,9 @@ package spark
 import java.io._
 import java.util.concurrent.atomic.AtomicInteger
 
+import akka.actor.Actor
+import akka.actor.Actor._
+
 import scala.actors.remote.RemoteActor
 import scala.collection.mutable.ArrayBuffer
 
@@ -32,6 +35,15 @@ import org.apache.mesos.MesosNativeLibrary
 
 import spark.broadcast._
 
+import spark.partial.ApproximateEvaluator
+import spark.partial.PartialResult
+
+import spark.scheduler.DAGScheduler
+import spark.scheduler.TaskScheduler
+import spark.scheduler.local.LocalScheduler
+import spark.scheduler.mesos.MesosScheduler
+import spark.scheduler.mesos.CoarseMesosScheduler
+
 class SparkContext(
     master: String,
     frameworkName: String,
@@ -54,14 +66,19 @@ class SparkContext(
   if (RemoteActor.classLoader == null) {
     RemoteActor.classLoader = getClass.getClassLoader
   }
+
+  remote.start(System.getProperty("spark.master.host"), 
+               System.getProperty("spark.master.port").toInt)
   
+  private val isLocal = master.startsWith("local") // TODO: better check for local
+
   // Create the Spark execution environment (cache, map output tracker, etc)
-  val env = SparkEnv.createFromSystemProperties(true)
+  val env = SparkEnv.createFromSystemProperties(true, isLocal) 
   SparkEnv.set(env)
   Broadcast.initialize(true)
 
   // Create and start the scheduler
-  private var scheduler: Scheduler = {
+  private var taskScheduler: TaskScheduler = {
     // Regular expression used for local[N] master format
     val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
     // Regular expression for local[N, maxRetries], used in tests with failing tasks
@@ -74,13 +91,17 @@ class SparkContext(
       case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
         new LocalScheduler(threads.toInt, maxFailures.toInt)
       case _ =>
-        MesosNativeLibrary.load()
-        new MesosScheduler(this, master, frameworkName)
+        System.loadLibrary("mesos")
+        if (System.getProperty("spark.mesos.coarse", "false") == "true") {
+          new CoarseMesosScheduler(this, master, frameworkName)
+        } else {
+          new MesosScheduler(this, master, frameworkName)
+        }
     }
   }
-  scheduler.start()
+  taskScheduler.start()
 
-  private val isLocal = scheduler.isInstanceOf[LocalScheduler]
+  private var dagScheduler = new DAGScheduler(taskScheduler)
 
   // Methods for creating RDDs
 
@@ -237,19 +258,21 @@ class SparkContext(
 
   // Stop the SparkContext
   def stop() {
-    scheduler.stop()
-    scheduler = null
+    dagScheduler.stop()
+    dagScheduler = null
+    taskScheduler = null
     // TODO: Broadcast.stop(), Cache.stop()?
     env.mapOutputTracker.stop()
     env.cacheTracker.stop()
     env.shuffleFetcher.stop()
     env.shuffleManager.stop()
+    env.connectionManager.stop()
     SparkEnv.set(null)
   }
 
-  // Wait for the scheduler to be registered
+  // Wait for the scheduler to be registered with the cluster manager
   def waitForRegister() {
-    scheduler.waitForRegister()
+    taskScheduler.waitForRegister()
   }
 
   // Get Spark's home location from either a value set through the constructor,
@@ -281,7 +304,7 @@ class SparkContext(
       ): Array[U] = {
     logInfo("Starting job...")
     val start = System.nanoTime
-    val result = scheduler.runJob(rdd, func, partitions, allowLocal)
+    val result = dagScheduler.runJob(rdd, func, partitions, allowLocal)
     logInfo("Job finished in " + (System.nanoTime - start) / 1e9 + " s")
     result
   }
@@ -306,6 +329,22 @@ class SparkContext(
     runJob(rdd, func, 0 until rdd.splits.size, false)
   }
 
+  /**
+   * Run a job that can return approximate results.
+   */
+  def runApproximateJob[T, U, R](
+      rdd: RDD[T],
+      func: (TaskContext, Iterator[T]) => U,
+      evaluator: ApproximateEvaluator[U, R],
+      timeout: Long
+      ): PartialResult[R] = {
+    logInfo("Starting job...")
+    val start = System.nanoTime
+    val result = dagScheduler.runApproximateJob(rdd, func, evaluator, timeout)
+    logInfo("Job finished in " + (System.nanoTime - start) / 1e9 + " s")
+    result
+  }
+
   // Clean a closure to make it ready to serialized and send to tasks
   // (removes unreferenced variables in $outer's, updates REPL variables)
   private[spark] def clean[F <: AnyRef](f: F): F = {
@@ -314,7 +353,7 @@ class SparkContext(
   }
 
   // Default level of parallelism to use when not given by user (e.g. for reduce tasks)
-  def defaultParallelism: Int = scheduler.defaultParallelism
+  def defaultParallelism: Int = taskScheduler.defaultParallelism
 
   // Default min number of splits for Hadoop RDDs when not given by user
   def defaultMinSplits: Int = math.min(defaultParallelism, 2)
@@ -349,15 +388,23 @@ object SparkContext {
   }
 
   // TODO: Add AccumulatorParams for other types, e.g. lists and strings
+
   implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) =
     new PairRDDFunctions(rdd)
-
-  implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest](rdd: RDD[(K, V)]) =
+  
+  implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest](
+      rdd: RDD[(K, V)]) =
     new SequenceFileRDDFunctions(rdd)
 
-  implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) =
+  implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
+      rdd: RDD[(K, V)]) =
     new OrderedRDDFunctions(rdd)
 
+  implicit def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = new DoubleRDDFunctions(rdd)
+
+  implicit def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) =
+    new DoubleRDDFunctions(rdd.map(x => num.toDouble(x)))
+
   // Implicit conversions to common Writable types, for saveAsSequenceFile
 
   implicit def intToIntWritable(i: Int) = new IntWritable(i)
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index cd752f8b6597e6feb97a1d1e582070dae745f628..897a5ef82d0913cf3d263d0d7db4e6986c4387d9 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -1,14 +1,26 @@
 package spark
 
+import akka.actor.Actor
+
+import spark.storage.BlockManager
+import spark.storage.BlockManagerMaster
+import spark.network.ConnectionManager
+
 class SparkEnv (
-  val cache: Cache,
-  val serializer: Serializer,
-  val closureSerializer: Serializer,
-  val cacheTracker: CacheTracker,
-  val mapOutputTracker: MapOutputTracker,
-  val shuffleFetcher: ShuffleFetcher,
-  val shuffleManager: ShuffleManager
-)
+    val cache: Cache,
+    val serializer: Serializer,
+    val closureSerializer: Serializer,
+    val cacheTracker: CacheTracker,
+    val mapOutputTracker: MapOutputTracker,
+    val shuffleFetcher: ShuffleFetcher,
+    val shuffleManager: ShuffleManager,
+    val blockManager: BlockManager,
+    val connectionManager: ConnectionManager
+  ) {
+
+  /** No-parameter constructor for unit tests. */
+  def this() = this(null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null)
+}
 
 object SparkEnv {
   private val env = new ThreadLocal[SparkEnv]
@@ -21,36 +33,55 @@ object SparkEnv {
     env.get()
   }
 
-  def createFromSystemProperties(isMaster: Boolean): SparkEnv = {
-    val cacheClass = System.getProperty("spark.cache.class", "spark.BoundedMemoryCache")
-    val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
-    
-    val serializerClass = System.getProperty("spark.serializer", "spark.JavaSerializer")
+  def createFromSystemProperties(isMaster: Boolean, isLocal: Boolean): SparkEnv = {
+    val serializerClass = System.getProperty("spark.serializer", "spark.KryoSerializer")
     val serializer = Class.forName(serializerClass).newInstance().asInstanceOf[Serializer]
+    
+    BlockManagerMaster.startBlockManagerMaster(isMaster, isLocal)
+    
+    var blockManager = new BlockManager(serializer)
+    
+    val connectionManager = blockManager.connectionManager 
+    
+    val shuffleManager = new ShuffleManager()
 
     val closureSerializerClass =
       System.getProperty("spark.closure.serializer", "spark.JavaSerializer")
     val closureSerializer =
       Class.forName(closureSerializerClass).newInstance().asInstanceOf[Serializer]
+    val cacheClass = System.getProperty("spark.cache.class", "spark.BoundedMemoryCache")
+    val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
 
-    val cacheTracker = new CacheTracker(isMaster, cache)
+    val cacheTracker = new CacheTracker(isMaster, blockManager)
+    blockManager.cacheTracker = cacheTracker
 
     val mapOutputTracker = new MapOutputTracker(isMaster)
 
     val shuffleFetcherClass = 
-      System.getProperty("spark.shuffle.fetcher", "spark.SimpleShuffleFetcher")
+      System.getProperty("spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
     val shuffleFetcher = 
       Class.forName(shuffleFetcherClass).newInstance().asInstanceOf[ShuffleFetcher]
 
-    val shuffleMgr = new ShuffleManager()
+    /*
+    if (System.getProperty("spark.stream.distributed", "false") == "true") {
+      val blockManagerClass = classOf[spark.storage.BlockManager].asInstanceOf[Class[_]] 
+      if (isLocal || !isMaster) { 
+        (new Thread() {
+          override def run() {
+            println("Wait started") 
+            Thread.sleep(60000)
+            println("Wait ended")
+            val receiverClass = Class.forName("spark.stream.TestStreamReceiver4")
+            val constructor = receiverClass.getConstructor(blockManagerClass)
+            val receiver = constructor.newInstance(blockManager)
+            receiver.asInstanceOf[Thread].start()
+          }
+        }).start()
+      }
+    }
+    */
 
-    new SparkEnv(
-      cache,
-      serializer,
-      closureSerializer,
-      cacheTracker,
-      mapOutputTracker,
-      shuffleFetcher,
-      shuffleMgr)
+    new SparkEnv(cache, serializer, closureSerializer, cacheTracker, mapOutputTracker, shuffleFetcher,
+        shuffleManager, blockManager, connectionManager)
   }
 }
diff --git a/core/src/main/scala/spark/Stage.scala b/core/src/main/scala/spark/Stage.scala
deleted file mode 100644
index 9452ea3a8e57db93c4cc31744a80bef8b3dfbd15..0000000000000000000000000000000000000000
--- a/core/src/main/scala/spark/Stage.scala
+++ /dev/null
@@ -1,41 +0,0 @@
-package spark
-
-class Stage(
-    val id: Int,
-    val rdd: RDD[_],
-    val shuffleDep: Option[ShuffleDependency[_,_,_]],
-    val parents: List[Stage]) {
-  
-  val isShuffleMap = shuffleDep != None
-  val numPartitions = rdd.splits.size
-  val outputLocs = Array.fill[List[String]](numPartitions)(Nil)
-  var numAvailableOutputs = 0
-
-  def isAvailable: Boolean = {
-    if (parents.size == 0 && !isShuffleMap) {
-      true
-    } else {
-      numAvailableOutputs == numPartitions
-    }
-  }
-
-  def addOutputLoc(partition: Int, host: String) {
-    val prevList = outputLocs(partition)
-    outputLocs(partition) = host :: prevList
-    if (prevList == Nil)
-      numAvailableOutputs += 1
-  }
-
-  def removeOutputLoc(partition: Int, host: String) {
-    val prevList = outputLocs(partition)
-    val newList = prevList.filterNot(_ == host)
-    outputLocs(partition) = newList
-    if (prevList != Nil && newList == Nil) {
-      numAvailableOutputs -= 1
-    }
-  }
-
-  override def toString = "Stage " + id
-
-  override def hashCode(): Int = id
-}
diff --git a/core/src/main/scala/spark/Task.scala b/core/src/main/scala/spark/Task.scala
deleted file mode 100644
index bc3b3743447bda9d887bbbe970beb2ef52dbf38e..0000000000000000000000000000000000000000
--- a/core/src/main/scala/spark/Task.scala
+++ /dev/null
@@ -1,9 +0,0 @@
-package spark
-
-class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Int) extends Serializable
-
-abstract class Task[T] extends Serializable {
-  def run(id: Int): T
-  def preferredLocations: Seq[String] = Nil
-  def generation: Option[Long] = None
-}
diff --git a/core/src/main/scala/spark/TaskContext.scala b/core/src/main/scala/spark/TaskContext.scala
new file mode 100644
index 0000000000000000000000000000000000000000..7a6214aab6648f6e7f5670b9839f3582dbe628bb
--- /dev/null
+++ b/core/src/main/scala/spark/TaskContext.scala
@@ -0,0 +1,3 @@
+package spark
+
+class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Int) extends Serializable
diff --git a/core/src/main/scala/spark/TaskEndReason.scala b/core/src/main/scala/spark/TaskEndReason.scala
new file mode 100644
index 0000000000000000000000000000000000000000..6e4eb25ed44ff07e94085ebaa0d01c736a2839ed
--- /dev/null
+++ b/core/src/main/scala/spark/TaskEndReason.scala
@@ -0,0 +1,16 @@
+package spark
+
+import spark.storage.BlockManagerId
+
+/**
+ * Various possible reasons why a task ended. The low-level TaskScheduler is supposed to retry
+ * tasks several times for "ephemeral" failures, and only report back failures that require some
+ * old stages to be resubmitted, such as shuffle map fetch failures.
+ */
+sealed trait TaskEndReason
+
+case object Success extends TaskEndReason
+case object Resubmitted extends TaskEndReason // Task was finished earlier but we've now lost it
+case class FetchFailed(bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int) extends TaskEndReason
+case class ExceptionFailure(exception: Throwable) extends TaskEndReason
+case class OtherFailure(message: String) extends TaskEndReason
diff --git a/core/src/main/scala/spark/TaskResult.scala b/core/src/main/scala/spark/TaskResult.scala
deleted file mode 100644
index 2b7fd1a4b225e74dae4da46ad14d8b2cba0a87e9..0000000000000000000000000000000000000000
--- a/core/src/main/scala/spark/TaskResult.scala
+++ /dev/null
@@ -1,8 +0,0 @@
-package spark
-
-import scala.collection.mutable.Map
-
-// Task result. Also contains updates to accumulator variables.
-// TODO: Use of distributed cache to return result is a hack to get around
-// what seems to be a bug with messages over 60KB in libprocess; fix it
-private class TaskResult[T](val value: T, val accumUpdates: Map[Long, Any]) extends Serializable
diff --git a/core/src/main/scala/spark/UnionRDD.scala b/core/src/main/scala/spark/UnionRDD.scala
index 4c0f255e6bb767e61ed3864f3e3600f237692247..17522e2bbb6d1077d4d8caefc778753229d820d2 100644
--- a/core/src/main/scala/spark/UnionRDD.scala
+++ b/core/src/main/scala/spark/UnionRDD.scala
@@ -33,7 +33,8 @@ class UnionRDD[T: ClassManifest](
 
   override def splits = splits_
 
-  @transient override val dependencies = {
+  @transient
+  override val dependencies = {
     val deps = new ArrayBuffer[Dependency[_]]
     var pos = 0
     for ((rdd, index) <- rdds.zipWithIndex) {
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index cfd6dc8b2aa3550e0f47dfdfbcc85732a72cd050..742e60b176f3b8103491ee01dfdc5219810a0fba 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -118,6 +118,23 @@ object Utils {
    * Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4).
    */
   def localIpAddress(): String = InetAddress.getLocalHost.getHostAddress
+
+  private var customHostname: Option[String] = None
+
+  /**
+   * Allow setting a custom host name because when we run on Mesos we need to use the same
+   * hostname it reports to the master.
+   */
+  def setCustomHostname(hostname: String) {
+    customHostname = Some(hostname)
+  }
+
+  /**
+   * Get the local machine's hostname
+   */
+  def localHostName(): String = {
+    customHostname.getOrElse(InetAddress.getLocalHost.getHostName)
+  }
   
   /**
    * Returns a standard ThreadFactory except all threads are daemons.
@@ -142,6 +159,14 @@ object Utils {
 
     return threadPool
   }
+  
+  /**
+   * Return the string to tell how long has passed in seconds. The passing parameter should be in 
+   * millisecond. 
+   */
+  def getUsedTimeMs(startTimeMs: Long): String = {
+    return " " + (System.currentTimeMillis - startTimeMs) + " ms "
+  }
 
   /**
    * Wrapper over newFixedThreadPool.
@@ -154,16 +179,6 @@ object Utils {
     return threadPool
   }
 
-  /**
-   * Get the local machine's hostname.
-   */
-  def localHostName(): String = InetAddress.getLocalHost.getHostName
-
-  /**
-   * Get current host
-   */
-  def getHost = System.getProperty("spark.hostname", localHostName())
-
   /**
    * Delete a file or directory and its contents recursively.
    */
diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala
new file mode 100644
index 0000000000000000000000000000000000000000..4546dfa0fac1b6c7f07d708a42abac2f4cedbdaa
--- /dev/null
+++ b/core/src/main/scala/spark/network/Connection.scala
@@ -0,0 +1,364 @@
+package spark.network
+
+import spark._
+
+import scala.collection.mutable.{HashMap, Queue, ArrayBuffer}
+
+import java.io._
+import java.nio._
+import java.nio.channels._
+import java.nio.channels.spi._
+import java.net._
+
+
+abstract class Connection(val channel: SocketChannel, val selector: Selector) extends Logging {
+
+  channel.configureBlocking(false)
+  channel.socket.setTcpNoDelay(true)
+  channel.socket.setReuseAddress(true)
+  channel.socket.setKeepAlive(true)
+  /*channel.socket.setReceiveBufferSize(32768) */
+
+  var onCloseCallback: Connection => Unit = null
+  var onExceptionCallback: (Connection, Exception) => Unit = null
+  var onKeyInterestChangeCallback: (Connection, Int) => Unit = null
+
+  lazy val remoteAddress = getRemoteAddress() 
+  lazy val remoteConnectionManagerId = ConnectionManagerId.fromSocketAddress(remoteAddress) 
+
+  def key() = channel.keyFor(selector)
+
+  def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]
+
+  def read() { 
+    throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString) 
+  }
+  
+  def write() { 
+    throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString) 
+  }
+
+  def close() {
+    key.cancel()
+    channel.close()
+    callOnCloseCallback()
+  }
+
+  def onClose(callback: Connection => Unit) {onCloseCallback = callback}
+
+  def onException(callback: (Connection, Exception) => Unit) {onExceptionCallback = callback}
+
+  def onKeyInterestChange(callback: (Connection, Int) => Unit) {onKeyInterestChangeCallback = callback}
+
+  def callOnExceptionCallback(e: Exception) {
+    if (onExceptionCallback != null) {
+      onExceptionCallback(this, e)
+    } else {
+      logError("Error in connection to " + remoteConnectionManagerId + 
+        " and OnExceptionCallback not registered", e)
+    }
+  }
+  
+  def callOnCloseCallback() {
+    if (onCloseCallback != null) {
+      onCloseCallback(this)
+    } else {
+      logWarning("Connection to " + remoteConnectionManagerId + 
+        " closed and OnExceptionCallback not registered")
+    }
+
+  }
+
+  def changeConnectionKeyInterest(ops: Int) {
+    if (onKeyInterestChangeCallback != null) {
+      onKeyInterestChangeCallback(this, ops) 
+    } else {
+      throw new Exception("OnKeyInterestChangeCallback not registered")
+    }
+  }
+
+  def printRemainingBuffer(buffer: ByteBuffer) {
+    val bytes = new Array[Byte](buffer.remaining)
+    val curPosition = buffer.position
+    buffer.get(bytes)
+    bytes.foreach(x => print(x + " "))
+    buffer.position(curPosition)
+    print(" (" + bytes.size + ")")
+  }
+
+  def printBuffer(buffer: ByteBuffer, position: Int, length: Int) {
+    val bytes = new Array[Byte](length)
+    val curPosition = buffer.position
+    buffer.position(position)
+    buffer.get(bytes)
+    bytes.foreach(x => print(x + " "))
+    print(" (" + position + ", " + length + ")")
+    buffer.position(curPosition)
+  }
+
+}
+
+
+class SendingConnection(val address: InetSocketAddress, selector_ : Selector) 
+extends Connection(SocketChannel.open, selector_) {
+
+  class Outbox(fair: Int = 0) {
+    val messages = new Queue[Message]()
+    val defaultChunkSize = 65536  //32768 //16384 
+    var nextMessageToBeUsed = 0
+
+    def addMessage(message: Message): Unit = {
+      messages.synchronized{ 
+        /*messages += message*/
+        messages.enqueue(message)
+        logInfo("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]")
+      }
+    }
+
+    def getChunk(): Option[MessageChunk] = {
+      fair match {
+        case 0 => getChunkFIFO()
+        case 1 => getChunkRR()
+        case _ => throw new Exception("Unexpected fairness policy in outbox")
+      }
+    }
+
+    private def getChunkFIFO(): Option[MessageChunk] = {
+      /*logInfo("Using FIFO")*/
+      messages.synchronized {
+        while (!messages.isEmpty) {
+          val message = messages(0)
+          val chunk = message.getChunkForSending(defaultChunkSize)
+          if (chunk.isDefined) {
+            messages += message  // this is probably incorrect, it wont work as fifo
+            if (!message.started) logDebug("Starting to send [" + message + "]")
+            message.started = true
+            return chunk 
+          }
+          /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/
+          logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in "  + message.timeTaken )
+        }
+      }
+      None
+    }
+    
+    private def getChunkRR(): Option[MessageChunk] = {
+      messages.synchronized {
+        while (!messages.isEmpty) {
+          /*nextMessageToBeUsed = nextMessageToBeUsed % messages.size */
+          /*val message = messages(nextMessageToBeUsed)*/
+          val message = messages.dequeue
+          val chunk = message.getChunkForSending(defaultChunkSize)
+          if (chunk.isDefined) {
+            messages.enqueue(message)
+            nextMessageToBeUsed = nextMessageToBeUsed + 1
+            if (!message.started) {
+              logDebug("Starting to send [" + message + "] to [" + remoteConnectionManagerId + "]")
+              message.started = true
+              message.startTime = System.currentTimeMillis
+            }
+            logTrace("Sending chunk from [" + message+ "] to [" + remoteConnectionManagerId + "]")
+            return chunk 
+          } 
+          /*messages -= message*/
+          message.finishTime = System.currentTimeMillis
+          logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in "  + message.timeTaken )
+        }
+      }
+      None
+    }
+  }
+  
+  val outbox = new Outbox(1) 
+  val currentBuffers = new ArrayBuffer[ByteBuffer]()
+
+  /*channel.socket.setSendBufferSize(256 * 1024)*/
+
+  override def getRemoteAddress() = address 
+
+  def send(message: Message) {
+    outbox.synchronized {
+      outbox.addMessage(message)
+      if (channel.isConnected) {
+        changeConnectionKeyInterest(SelectionKey.OP_WRITE)
+      }
+    }
+  }
+
+  def connect() {
+    try{
+      channel.connect(address)
+      channel.register(selector, SelectionKey.OP_CONNECT)
+      logInfo("Initiating connection to [" + address + "]")
+    } catch {
+      case e: Exception => {
+        logError("Error connecting to " + address, e)
+        callOnExceptionCallback(e)
+      }
+    }
+  }
+
+  def finishConnect() {
+    try {
+      channel.finishConnect
+      changeConnectionKeyInterest(SelectionKey.OP_WRITE)
+      logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending")
+    } catch {
+      case e: Exception => {
+        logWarning("Error finishing connection to " + address, e)
+        callOnExceptionCallback(e)
+      }
+    }
+  }
+
+  override def write() {
+    try{
+      while(true) {
+        if (currentBuffers.size == 0) {
+          outbox.synchronized {
+            outbox.getChunk match {
+              case Some(chunk) => {
+                currentBuffers ++= chunk.buffers 
+              }
+              case None => {
+                changeConnectionKeyInterest(0)
+                /*key.interestOps(0)*/
+                return
+              }
+            }
+          }
+        }
+        
+        if (currentBuffers.size > 0) {
+          val buffer = currentBuffers(0)
+          val remainingBytes = buffer.remaining
+          val writtenBytes = channel.write(buffer)
+          if (buffer.remaining == 0) {
+            currentBuffers -= buffer
+          }
+          if (writtenBytes < remainingBytes) {
+            return
+          }
+        }
+      }
+    } catch {
+      case e: Exception => { 
+        logWarning("Error writing in connection to " + remoteConnectionManagerId, e)
+        callOnExceptionCallback(e)
+        close()
+      }
+    }
+  }
+}
+
+
+class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector) 
+extends Connection(channel_, selector_) {
+  
+  class Inbox() {
+    val messages = new HashMap[Int, BufferMessage]()
+    
+    def getChunk(header: MessageChunkHeader): Option[MessageChunk] = {
+      
+      def createNewMessage: BufferMessage = {
+        val newMessage = Message.create(header).asInstanceOf[BufferMessage]
+        newMessage.started = true
+        newMessage.startTime = System.currentTimeMillis
+        logDebug("Starting to receive [" + newMessage + "] from [" + remoteConnectionManagerId + "]") 
+        messages += ((newMessage.id, newMessage))
+        newMessage
+      }
+      
+      val message = messages.getOrElseUpdate(header.id, createNewMessage)
+      logTrace("Receiving chunk of [" + message + "] from [" + remoteConnectionManagerId + "]")
+      message.getChunkForReceiving(header.chunkSize)
+    }
+    
+    def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = {
+      messages.get(chunk.header.id) 
+    }
+
+    def removeMessage(message: Message) {
+      messages -= message.id
+    }
+  }
+  
+  val inbox = new Inbox()
+  val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE)
+  var onReceiveCallback: (Connection , Message) => Unit = null
+  var currentChunk: MessageChunk = null
+
+  channel.register(selector, SelectionKey.OP_READ)
+
+  override def read() {
+    try {
+      while (true) {
+        if (currentChunk == null) {
+          val headerBytesRead = channel.read(headerBuffer)
+          if (headerBytesRead == -1) {
+            close()
+            return
+          }
+          if (headerBuffer.remaining > 0) {
+            return
+          }
+          headerBuffer.flip
+          if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) {
+            throw new Exception("Unexpected number of bytes (" + headerBuffer.remaining + ") in the header")
+          }
+          val header = MessageChunkHeader.create(headerBuffer)
+          headerBuffer.clear()
+          header.typ match {
+            case Message.BUFFER_MESSAGE => {
+              if (header.totalSize == 0) {
+                if (onReceiveCallback != null) {
+                  onReceiveCallback(this, Message.create(header))
+                }
+                currentChunk = null
+                return
+              } else {
+                currentChunk = inbox.getChunk(header).orNull
+              }
+            }
+            case _ => throw new Exception("Message of unknown type received")
+          }
+        }
+        
+        if (currentChunk == null) throw new Exception("No message chunk to receive data")
+       
+        val bytesRead = channel.read(currentChunk.buffer)
+        if (bytesRead == 0) {
+          return
+        } else if (bytesRead == -1) {
+          close()
+          return
+        }
+
+        /*logDebug("Read " + bytesRead + " bytes for the buffer")*/
+        
+        if (currentChunk.buffer.remaining == 0) {
+          /*println("Filled buffer at " + System.currentTimeMillis)*/
+          val bufferMessage = inbox.getMessageForChunk(currentChunk).get
+          if (bufferMessage.isCompletelyReceived) {
+            bufferMessage.flip
+            bufferMessage.finishTime = System.currentTimeMillis
+            logDebug("Finished receiving [" + bufferMessage + "] from [" + remoteConnectionManagerId + "] in " + bufferMessage.timeTaken) 
+            if (onReceiveCallback != null) {
+              onReceiveCallback(this, bufferMessage)
+            }
+            inbox.removeMessage(bufferMessage)
+          }
+          currentChunk = null
+        }
+      }
+    } catch {
+      case e: Exception  => { 
+        logWarning("Error reading from connection to " + remoteConnectionManagerId, e)
+        callOnExceptionCallback(e)
+        close()
+      }
+    }
+  }
+  
+  def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback}
+}
diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala
new file mode 100644
index 0000000000000000000000000000000000000000..e9f254d0f3b9624cedddfa8698f90eebadbc561d
--- /dev/null
+++ b/core/src/main/scala/spark/network/ConnectionManager.scala
@@ -0,0 +1,467 @@
+package spark.network
+
+import spark._
+
+import scala.actors.Future
+import scala.actors.Futures.future
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.SynchronizedMap
+import scala.collection.mutable.SynchronizedQueue
+import scala.collection.mutable.Queue
+import scala.collection.mutable.ArrayBuffer
+
+import java.io._
+import java.nio._
+import java.nio.channels._
+import java.nio.channels.spi._
+import java.net._
+import java.util.concurrent.Executors
+
+case class ConnectionManagerId(val host: String, val port: Int) {
+  def toSocketAddress() = new InetSocketAddress(host, port)
+}
+
+object ConnectionManagerId {
+  def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = {
+    new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort())
+  }
+}
+  
+class ConnectionManager(port: Int) extends Logging {
+
+  case class MessageStatus(message: Message, connectionManagerId: ConnectionManagerId) {
+    var ackMessage: Option[Message] = None
+    var attempted = false
+    var acked = false
+  }
+  
+  val selector = SelectorProvider.provider.openSelector()
+  /*val handleMessageExecutor = new ThreadPoolExecutor(4, 4, 600, TimeUnit.SECONDS, new LinkedBlockingQueue()) */
+  val handleMessageExecutor = Executors.newFixedThreadPool(4) 
+  val serverChannel = ServerSocketChannel.open()
+  val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] 
+  val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
+  val messageStatuses = new HashMap[Int, MessageStatus] 
+  val connectionRequests = new SynchronizedQueue[SendingConnection]
+  val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
+  val sendMessageRequests = new Queue[(Message, SendingConnection)]
+  
+  var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
+
+  serverChannel.configureBlocking(false)
+  serverChannel.socket.setReuseAddress(true)
+  serverChannel.socket.setReceiveBufferSize(256 * 1024) 
+
+  serverChannel.socket.bind(new InetSocketAddress(port))
+  serverChannel.register(selector, SelectionKey.OP_ACCEPT)
+
+  val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
+  logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id)
+  
+  val thisInstance = this
+  var selectorThread = new Thread("connection-manager-thread") {
+    override def run() {
+      thisInstance.run()
+    }
+  }
+  selectorThread.setDaemon(true)
+  selectorThread.start()
+
+  def run() {
+    try {
+      var interrupted = false 
+      while(!interrupted) {
+        while(!connectionRequests.isEmpty) {
+          val sendingConnection = connectionRequests.dequeue
+          sendingConnection.connect() 
+          addConnection(sendingConnection)
+        }
+        sendMessageRequests.synchronized {
+          while(!sendMessageRequests.isEmpty) {
+            val (message, connection) = sendMessageRequests.dequeue
+            connection.send(message)
+          }
+        }
+
+        while(!keyInterestChangeRequests.isEmpty) {
+          val (key, ops) = keyInterestChangeRequests.dequeue
+          val connection = connectionsByKey(key)
+          val lastOps = key.interestOps()
+          key.interestOps(ops)
+          
+          def intToOpStr(op: Int): String = {
+            val opStrs = ArrayBuffer[String]()
+            if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
+            if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
+            if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
+            if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
+            if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " "
+          }
+          
+          logTrace("Changed key for connection to [" + connection.remoteConnectionManagerId  + 
+            "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
+          
+        }
+
+        val selectedKeysCount = selector.select()
+        if (selectedKeysCount == 0) logInfo("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys")
+        
+        interrupted = selectorThread.isInterrupted
+
+        val selectedKeys = selector.selectedKeys().iterator()
+        while (selectedKeys.hasNext()) {
+          val key = selectedKeys.next.asInstanceOf[SelectionKey]
+          selectedKeys.remove()
+          if (key.isValid) {
+            if (key.isAcceptable) {
+              acceptConnection(key)
+            } else 
+            if (key.isConnectable) {
+              connectionsByKey(key).asInstanceOf[SendingConnection].finishConnect()
+            } else 
+            if (key.isReadable) {
+              connectionsByKey(key).read()
+            } else 
+            if (key.isWritable) {
+              connectionsByKey(key).write()
+            }
+          }
+        }
+      }
+    } catch {
+      case e: Exception => logError("Error in select loop", e)
+    }
+  }
+  
+  def acceptConnection(key: SelectionKey) {
+    val serverChannel = key.channel.asInstanceOf[ServerSocketChannel]
+    val newChannel = serverChannel.accept()
+    val newConnection = new ReceivingConnection(newChannel, selector)
+    newConnection.onReceive(receiveMessage)
+    newConnection.onClose(removeConnection)
+    addConnection(newConnection)
+    logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]")
+  }
+
+  def addConnection(connection: Connection) {
+    connectionsByKey += ((connection.key, connection))
+    if (connection.isInstanceOf[SendingConnection]) {
+      val sendingConnection = connection.asInstanceOf[SendingConnection]
+      connectionsById += ((sendingConnection.remoteConnectionManagerId, sendingConnection))
+    }
+    connection.onKeyInterestChange(changeConnectionKeyInterest)
+    connection.onException(handleConnectionError)
+    connection.onClose(removeConnection)
+  }
+
+  def removeConnection(connection: Connection) {
+    /*logInfo("Removing connection")*/
+    connectionsByKey -= connection.key
+    if (connection.isInstanceOf[SendingConnection]) {
+      val sendingConnection = connection.asInstanceOf[SendingConnection]
+      val sendingConnectionManagerId = sendingConnection.remoteConnectionManagerId
+      logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
+      
+      connectionsById -= sendingConnectionManagerId
+
+      messageStatuses.synchronized {
+        messageStatuses
+          .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => {
+            logInfo("Notifying " + status)
+            status.synchronized {
+            status.attempted = true 
+             status.acked = false
+             status.notifyAll
+            }
+          })
+
+        messageStatuses.retain((i, status) => { 
+          status.connectionManagerId != sendingConnectionManagerId 
+        })
+      }
+    } else if (connection.isInstanceOf[ReceivingConnection]) {
+      val receivingConnection = connection.asInstanceOf[ReceivingConnection]
+      val remoteConnectionManagerId = receivingConnection.remoteConnectionManagerId
+      logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId)
+      
+      val sendingConnectionManagerId = connectionsById.keys.find(_.host == remoteConnectionManagerId.host).orNull
+      if (sendingConnectionManagerId == null) {
+        logError("Corresponding SendingConnectionManagerId not found")
+        return
+      }
+      logInfo("Corresponding SendingConnectionManagerId is " + sendingConnectionManagerId)
+      
+      val sendingConnection = connectionsById(sendingConnectionManagerId)
+      sendingConnection.close()
+      connectionsById -= sendingConnectionManagerId
+      
+      messageStatuses.synchronized {
+        messageStatuses
+          .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => {
+            logInfo("Notifying " + status)
+            status.synchronized {
+            status.attempted = true 
+             status.acked = false
+             status.notifyAll
+            }
+          })
+
+        messageStatuses.retain((i, status) => { 
+          status.connectionManagerId != sendingConnectionManagerId 
+        })
+      }
+    }
+  }
+
+  def handleConnectionError(connection: Connection, e: Exception) {
+    logInfo("Handling connection error on connection to " + connection.remoteConnectionManagerId)
+    removeConnection(connection)
+  }
+
+  def changeConnectionKeyInterest(connection: Connection, ops: Int) {
+    keyInterestChangeRequests += ((connection.key, ops))  
+  }
+
+  def receiveMessage(connection: Connection, message: Message) {
+    val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress)
+    logInfo("Received [" + message + "] from [" + connectionManagerId + "]") 
+    val runnable = new Runnable() {
+      val creationTime = System.currentTimeMillis
+      def run() {
+        logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms")
+        handleMessage(connectionManagerId, message)
+        logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms")
+      }
+    }
+    handleMessageExecutor.execute(runnable)
+    /*handleMessage(connection, message)*/
+  }
+
+  private def handleMessage(connectionManagerId: ConnectionManagerId, message: Message) {
+    logInfo("Handling [" + message + "] from [" + connectionManagerId + "]") 
+    message match {
+      case bufferMessage: BufferMessage => {
+        if (bufferMessage.hasAckId) {
+          val sentMessageStatus = messageStatuses.synchronized {
+            messageStatuses.get(bufferMessage.ackId) match {
+              case Some(status) => { 
+                messageStatuses -= bufferMessage.ackId 
+                status
+              }
+              case None => { 
+                throw new Exception("Could not find reference for received ack message " + message.id)
+                null
+              }
+            }
+          }
+          sentMessageStatus.synchronized {
+            sentMessageStatus.ackMessage = Some(message)
+            sentMessageStatus.attempted = true
+            sentMessageStatus.acked = true
+            sentMessageStatus.notifyAll
+          }
+        } else {
+          val ackMessage = if (onReceiveCallback != null) {
+            logDebug("Calling back")
+            onReceiveCallback(bufferMessage, connectionManagerId)
+          } else {
+            logWarning("Not calling back as callback is null")
+            None
+          }
+          
+          if (ackMessage.isDefined) {
+            if (!ackMessage.get.isInstanceOf[BufferMessage]) {
+              logWarning("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass())
+            } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
+              logWarning("Response to " + bufferMessage + " does not have ack id set")
+              ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
+            }
+          }
+
+          sendMessage(connectionManagerId, ackMessage.getOrElse { 
+            Message.createBufferMessage(bufferMessage.id)
+          })
+        }
+      }
+      case _ => throw new Exception("Unknown type message received")
+    }
+  }
+
+  private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
+    def startNewConnection(): SendingConnection = {
+      val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port)
+      val newConnection = new SendingConnection(inetSocketAddress, selector)
+      connectionRequests += newConnection
+      newConnection   
+    }
+    val connection = connectionsById.getOrElse(connectionManagerId, startNewConnection) 
+    message.senderAddress = id.toSocketAddress()
+    logInfo("Sending [" + message + "] to [" + connectionManagerId + "]") 
+    /*connection.send(message)*/
+    sendMessageRequests.synchronized {
+      sendMessageRequests += ((message, connection))
+    }
+    selector.wakeup()
+  }
+
+  def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message): Future[Option[Message]] = {
+    val messageStatus = new MessageStatus(message, connectionManagerId) 
+    messageStatuses.synchronized {
+      messageStatuses += ((message.id, messageStatus))
+    }
+    sendMessage(connectionManagerId, message)
+    future {
+      messageStatus.synchronized {
+        if (!messageStatus.attempted) {
+          logTrace("Waiting, " + messageStatuses.size + " statuses" )
+          messageStatus.wait()
+          logTrace("Done waiting")
+        }
+      }
+      messageStatus.ackMessage 
+    }
+  }
+
+  def sendMessageReliablySync(connectionManagerId: ConnectionManagerId, message: Message): Option[Message] = {
+    sendMessageReliably(connectionManagerId, message)()
+  }
+
+  def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) {
+    onReceiveCallback = callback
+  }
+
+  def stop() {
+    selectorThread.interrupt()
+    selectorThread.join()
+    selector.close()
+    val connections = connectionsByKey.values
+    connections.foreach(_.close())
+    if (connectionsByKey.size != 0) {
+      logWarning("All connections not cleaned up")
+    }
+    handleMessageExecutor.shutdown()
+    logInfo("ConnectionManager stopped")
+  }
+}
+
+
+object ConnectionManager {
+
+  def main(args: Array[String]) {
+  
+    val manager = new ConnectionManager(9999)
+    manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { 
+      println("Received [" + msg + "] from [" + id + "]")
+      None
+    })
+    
+    /*testSequentialSending(manager)*/
+    /*System.gc()*/
+
+    /*testParallelSending(manager)*/
+    /*System.gc()*/
+    
+    /*testParallelDecreasingSending(manager)*/
+    /*System.gc()*/
+
+    testContinuousSending(manager)
+    System.gc()
+  }
+
+  def testSequentialSending(manager: ConnectionManager) {
+    println("--------------------------")
+    println("Sequential Sending")
+    println("--------------------------")
+    val size = 10 * 1024 * 1024 
+    val count = 10
+    
+    val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+    buffer.flip
+
+    (0 until count).map(i => {
+      val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+      manager.sendMessageReliablySync(manager.id, bufferMessage)
+    })
+    println("--------------------------")
+    println()
+  }
+
+  def testParallelSending(manager: ConnectionManager) {
+    println("--------------------------")
+    println("Parallel Sending")
+    println("--------------------------")
+    val size = 10 * 1024 * 1024 
+    val count = 10
+
+    val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+    buffer.flip
+
+    val startTime = System.currentTimeMillis
+    (0 until count).map(i => {
+      val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+      manager.sendMessageReliably(manager.id, bufferMessage)
+    }).foreach(f => {if (!f().isDefined) println("Failed")})
+    val finishTime = System.currentTimeMillis
+    
+    val mb = size * count / 1024.0 / 1024.0
+    val ms = finishTime - startTime
+    val tput = mb * 1000.0 / ms
+    println("--------------------------")
+    println("Started at " + startTime + ", finished at " + finishTime) 
+    println("Sent " + count + " messages of size " + size + " in " + ms + " ms (" + tput + " MB/s)")
+    println("--------------------------")
+    println()
+  }
+
+  def testParallelDecreasingSending(manager: ConnectionManager) {
+    println("--------------------------")
+    println("Parallel Decreasing Sending")
+    println("--------------------------")
+    val size = 10 * 1024 * 1024 
+    val count = 10
+    val buffers = Array.tabulate(count)(i => ByteBuffer.allocate(size * (i + 1)).put(Array.tabulate[Byte](size * (i + 1))(x => x.toByte)))
+    buffers.foreach(_.flip)
+    val mb = buffers.map(_.remaining).reduceLeft(_ + _) / 1024.0 / 1024.0
+
+    val startTime = System.currentTimeMillis
+    (0 until count).map(i => {
+      val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate)
+      manager.sendMessageReliably(manager.id, bufferMessage)
+    }).foreach(f => {if (!f().isDefined) println("Failed")})
+    val finishTime = System.currentTimeMillis
+    
+    val ms = finishTime - startTime
+    val tput = mb * 1000.0 / ms
+    println("--------------------------")
+    /*println("Started at " + startTime + ", finished at " + finishTime) */
+    println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)")
+    println("--------------------------")
+    println()
+  }
+
+  def testContinuousSending(manager: ConnectionManager) {
+    println("--------------------------")
+    println("Continuous Sending")
+    println("--------------------------")
+    val size = 10 * 1024 * 1024 
+    val count = 10
+
+    val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+    buffer.flip
+
+    val startTime = System.currentTimeMillis
+    while(true) {
+      (0 until count).map(i => {
+          val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+          manager.sendMessageReliably(manager.id, bufferMessage)
+        }).foreach(f => {if (!f().isDefined) println("Failed")})
+      val finishTime = System.currentTimeMillis
+      Thread.sleep(1000)
+      val mb = size * count / 1024.0 / 1024.0
+      val ms = finishTime - startTime
+      val tput = mb * 1000.0 / ms
+      println("--------------------------")
+      println()
+    }
+  }
+}
diff --git a/core/src/main/scala/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/spark/network/ConnectionManagerTest.scala
new file mode 100644
index 0000000000000000000000000000000000000000..5d21bb793f3dcefce2af736edeb602c47ff0c56f
--- /dev/null
+++ b/core/src/main/scala/spark/network/ConnectionManagerTest.scala
@@ -0,0 +1,74 @@
+package spark.network
+
+import spark._
+import spark.SparkContext._
+
+import scala.io.Source
+
+import java.nio.ByteBuffer
+import java.net.InetAddress
+
+object ConnectionManagerTest extends Logging{
+  def main(args: Array[String]) {
+    if (args.length < 2) {
+      println("Usage: ConnectionManagerTest <mesos cluster> <slaves file>")
+      System.exit(1)
+    }
+    
+    if (args(0).startsWith("local")) {
+      println("This runs only on a mesos cluster")
+    }
+    
+    val sc = new SparkContext(args(0), "ConnectionManagerTest")
+    val slavesFile = Source.fromFile(args(1))
+    val slaves = slavesFile.mkString.split("\n")
+    slavesFile.close()
+
+    /*println("Slaves")*/
+    /*slaves.foreach(println)*/
+   
+    val slaveConnManagerIds = sc.parallelize(0 until slaves.length, slaves.length).map(
+        i => SparkEnv.get.connectionManager.id).collect()
+    println("\nSlave ConnectionManagerIds")
+    slaveConnManagerIds.foreach(println)
+    println
+
+    val count = 10
+    (0 until count).foreach(i => {
+      val resultStrs = sc.parallelize(0 until slaves.length, slaves.length).map(i => {
+        val connManager = SparkEnv.get.connectionManager
+        val thisConnManagerId = connManager.id 
+        connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { 
+          logInfo("Received [" + msg + "] from [" + id + "]")
+          None
+        })
+
+        val size =  100 * 1024  * 1024 
+        val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+        buffer.flip
+        
+        val startTime = System.currentTimeMillis  
+        val futures = slaveConnManagerIds.filter(_ != thisConnManagerId).map(slaveConnManagerId => {
+          val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+          logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]")
+          connManager.sendMessageReliably(slaveConnManagerId, bufferMessage)
+        })
+        val results = futures.map(f => f())
+        val finishTime = System.currentTimeMillis
+        Thread.sleep(5000)
+        
+        val mb = size * results.size / 1024.0 / 1024.0
+        val ms = finishTime - startTime
+        val resultStr = "Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"
+        logInfo(resultStr)
+        resultStr
+      }).collect()
+      
+      println("---------------------") 
+      println("Run " + i) 
+      resultStrs.foreach(println)
+      println("---------------------") 
+    })
+  }
+}
+
diff --git a/core/src/main/scala/spark/network/Message.scala b/core/src/main/scala/spark/network/Message.scala
new file mode 100644
index 0000000000000000000000000000000000000000..2e858036791d2e5e80020c7527d6cfecf6bd9f07
--- /dev/null
+++ b/core/src/main/scala/spark/network/Message.scala
@@ -0,0 +1,219 @@
+package spark.network
+
+import spark._
+
+import scala.collection.mutable.ArrayBuffer
+
+import java.nio.ByteBuffer
+import java.net.InetAddress
+import java.net.InetSocketAddress
+
+class MessageChunkHeader(
+    val typ: Long,
+    val id: Int,
+    val totalSize: Int,
+    val chunkSize: Int,
+    val other: Int,
+    val address: InetSocketAddress) {
+  lazy val buffer = {
+    val ip = address.getAddress.getAddress() 
+    val port = address.getPort()
+    ByteBuffer.
+      allocate(MessageChunkHeader.HEADER_SIZE).
+      putLong(typ).
+      putInt(id).
+      putInt(totalSize).
+      putInt(chunkSize).
+      putInt(other).
+      putInt(ip.size).
+      put(ip).
+      putInt(port).
+      position(MessageChunkHeader.HEADER_SIZE).
+      flip.asInstanceOf[ByteBuffer]
+  }
+
+  override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + 
+      " and sizes " + totalSize + " / " + chunkSize + " bytes"
+}
+
+class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
+  val size = if (buffer == null) 0 else buffer.remaining
+  lazy val buffers = {
+    val ab = new ArrayBuffer[ByteBuffer]()
+    ab += header.buffer
+    if (buffer != null) { 
+      ab += buffer
+    }
+    ab
+  }
+
+  override def toString = "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")"
+}
+
+abstract class Message(val typ: Long, val id: Int) {
+  var senderAddress: InetSocketAddress = null
+  var started = false
+  var startTime = -1L
+  var finishTime = -1L
+
+  def size: Int
+  
+  def getChunkForSending(maxChunkSize: Int): Option[MessageChunk]
+  
+  def getChunkForReceiving(chunkSize: Int): Option[MessageChunk]
+ 
+  def timeTaken(): String = (finishTime - startTime).toString + " ms"
+
+  override def toString = "" + this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")"
+}
+
+class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int) 
+extends Message(Message.BUFFER_MESSAGE, id_) {
+  
+  val initialSize = currentSize() 
+  var gotChunkForSendingOnce = false
+  
+  def size = initialSize 
+
+  def currentSize() = {
+    if (buffers == null || buffers.isEmpty) {
+      0 
+    } else {
+      buffers.map(_.remaining).reduceLeft(_ + _)
+    }
+  }
+  
+  def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = {
+    if (maxChunkSize <= 0) {
+      throw new Exception("Max chunk size is " + maxChunkSize)
+    }
+
+    if (size == 0 && gotChunkForSendingOnce == false) {
+      val newChunk = new MessageChunk(new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null)
+      gotChunkForSendingOnce = true
+      return Some(newChunk)
+    }
+
+    while(!buffers.isEmpty) {
+      val buffer = buffers(0)
+      if (buffer.remaining == 0) {
+        buffers -= buffer
+      } else {
+        val newBuffer = if (buffer.remaining <= maxChunkSize) {
+          buffer.duplicate
+        } else {
+          buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
+        }
+        buffer.position(buffer.position + newBuffer.remaining)
+        val newChunk = new MessageChunk(new MessageChunkHeader(
+            typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+        gotChunkForSendingOnce = true
+        return Some(newChunk)
+      }
+    }
+    None
+  }
+
+  def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = {
+    // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer
+    if (buffers.size > 1) {
+      throw new Exception("Attempting to get chunk from message with multiple data buffers")
+    }
+    val buffer = buffers(0)
+    if (buffer.remaining > 0) {
+      if (buffer.remaining < chunkSize) {
+        throw new Exception("Not enough space in data buffer for receiving chunk")
+      }
+      val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
+      buffer.position(buffer.position + newBuffer.remaining)
+      val newChunk = new MessageChunk(new MessageChunkHeader(
+          typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+      return Some(newChunk)
+    }
+    None 
+  }
+
+  def flip() {
+    buffers.foreach(_.flip)
+  }
+
+  def hasAckId() = (ackId != 0)
+
+  def isCompletelyReceived() = !buffers(0).hasRemaining
+  
+  override def toString = {
+    if (hasAckId) {
+      "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")"
+    } else {
+      "BufferMessage(id = " + id + ", size = " + size + ")"
+    }
+
+  }
+}
+
+object MessageChunkHeader {
+  val HEADER_SIZE = 40 
+  
+  def create(buffer: ByteBuffer): MessageChunkHeader = {
+    if (buffer.remaining != HEADER_SIZE) {
+      throw new IllegalArgumentException("Cannot convert buffer data to Message")
+    }
+    val typ = buffer.getLong()
+    val id = buffer.getInt()
+    val totalSize = buffer.getInt()
+    val chunkSize = buffer.getInt()
+    val other = buffer.getInt()
+    val ipSize = buffer.getInt()
+    val ipBytes = new Array[Byte](ipSize)
+    buffer.get(ipBytes)
+    val ip = InetAddress.getByAddress(ipBytes)
+    val port = buffer.getInt()
+    new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port))
+  }
+}
+
+object Message {
+  val BUFFER_MESSAGE = 1111111111L
+
+  var lastId = 1
+
+  def getNewId() = synchronized {
+    lastId += 1
+    if (lastId == 0) lastId += 1
+    lastId
+  }
+
+  def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = {
+    if (dataBuffers == null) {
+      return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId)
+    } 
+    if (dataBuffers.exists(_ == null)) {
+      throw new Exception("Attempting to create buffer message with null buffer")
+    }
+    return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId)
+  }
+
+  def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage =
+    createBufferMessage(dataBuffers, 0)
+  
+  def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = {
+    if (dataBuffer == null) {
+      return createBufferMessage(Array(ByteBuffer.allocate(0)), ackId)
+    } else {
+      return createBufferMessage(Array(dataBuffer), ackId)
+    }
+  }
+ 
+  def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage = 
+    createBufferMessage(dataBuffer, 0)
+  
+  def createBufferMessage(ackId: Int): BufferMessage = createBufferMessage(new Array[ByteBuffer](0), ackId)
+
+  def create(header: MessageChunkHeader): Message = {
+    val newMessage: Message = header.typ match {
+      case BUFFER_MESSAGE => new BufferMessage(header.id, ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
+    }
+    newMessage.senderAddress = header.address
+    newMessage
+  }
+}
diff --git a/core/src/main/scala/spark/network/ReceiverTest.scala b/core/src/main/scala/spark/network/ReceiverTest.scala
new file mode 100644
index 0000000000000000000000000000000000000000..e1ba7c06c04dfd615ef5f23ae710fc73faaf6e11
--- /dev/null
+++ b/core/src/main/scala/spark/network/ReceiverTest.scala
@@ -0,0 +1,20 @@
+package spark.network
+
+import java.nio.ByteBuffer
+import java.net.InetAddress
+
+object ReceiverTest {
+
+  def main(args: Array[String]) {
+    val manager = new ConnectionManager(9999)
+    println("Started connection manager with id = " + manager.id)
+    
+    manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { 
+      /*println("Received [" + msg + "] from [" + id + "] at " + System.currentTimeMillis)*/
+      val buffer = ByteBuffer.wrap("response".getBytes())
+      Some(Message.createBufferMessage(buffer, msg.id))
+    })
+    Thread.currentThread.join()  
+  }
+}
+
diff --git a/core/src/main/scala/spark/network/SenderTest.scala b/core/src/main/scala/spark/network/SenderTest.scala
new file mode 100644
index 0000000000000000000000000000000000000000..4ab6dd34140992fdc9d6b1642b5b4d6ae1e69e2c
--- /dev/null
+++ b/core/src/main/scala/spark/network/SenderTest.scala
@@ -0,0 +1,53 @@
+package spark.network
+
+import java.nio.ByteBuffer
+import java.net.InetAddress
+
+object SenderTest {
+
+  def main(args: Array[String]) {
+    
+    if (args.length < 2) {
+      println("Usage: SenderTest <target host> <target port>")
+      System.exit(1)
+    }
+   
+    val targetHost = args(0)
+    val targetPort = args(1).toInt
+    val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort)
+
+    val manager = new ConnectionManager(0)
+    println("Started connection manager with id = " + manager.id)
+
+    manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { 
+      println("Received [" + msg + "] from [" + id + "]")
+      None
+    })
+  
+    val size =  100 * 1024  * 1024 
+    val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+    buffer.flip
+
+    val targetServer = args(0)
+
+    val count = 100
+    (0 until count).foreach(i => {
+      val dataMessage = Message.createBufferMessage(buffer.duplicate)
+      val startTime = System.currentTimeMillis  
+      /*println("Started timer at " + startTime)*/
+      val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage) match {
+        case Some(response) =>
+          val buffer = response.asInstanceOf[BufferMessage].buffers(0)
+          new String(buffer.array)
+        case None => "none"
+      }
+      val finishTime = System.currentTimeMillis
+      val mb = size / 1024.0 / 1024.0
+      val ms = finishTime - startTime
+      /*val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"*/
+      val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms (" +  (mb / ms * 1000.0).toInt + "MB/s) | Response = " + responseStr
+      println(resultStr)
+    })
+  }
+}
+
diff --git a/core/src/main/scala/spark/partial/ApproximateActionListener.scala b/core/src/main/scala/spark/partial/ApproximateActionListener.scala
new file mode 100644
index 0000000000000000000000000000000000000000..260547902bb4a743e7a48ec1fb2d5a8b3b56da9c
--- /dev/null
+++ b/core/src/main/scala/spark/partial/ApproximateActionListener.scala
@@ -0,0 +1,66 @@
+package spark.partial
+
+import spark._
+import spark.scheduler.JobListener
+
+/**
+ * A JobListener for an approximate single-result action, such as count() or non-parallel reduce().
+ * This listener waits up to timeout milliseconds and will return a partial answer even if the
+ * complete answer is not available by then.
+ *
+ * This class assumes that the action is performed on an entire RDD[T] via a function that computes
+ * a result of type U for each partition, and that the action returns a partial or complete result
+ * of type R. Note that the type R must *include* any error bars on it (e.g. see BoundedInt).
+ */
+class ApproximateActionListener[T, U, R](
+    rdd: RDD[T],
+    func: (TaskContext, Iterator[T]) => U,
+    evaluator: ApproximateEvaluator[U, R],
+    timeout: Long)
+  extends JobListener {
+
+  val startTime = System.currentTimeMillis()
+  val totalTasks = rdd.splits.size
+  var finishedTasks = 0
+  var failure: Option[Exception] = None             // Set if the job has failed (permanently)
+  var resultObject: Option[PartialResult[R]] = None // Set if we've already returned a PartialResult
+
+  override def taskSucceeded(index: Int, result: Any): Unit = synchronized {
+    evaluator.merge(index, result.asInstanceOf[U])
+    finishedTasks += 1
+    if (finishedTasks == totalTasks) {
+      // If we had already returned a PartialResult, set its final value
+      resultObject.foreach(r => r.setFinalValue(evaluator.currentResult()))
+      // Notify any waiting thread that may have called getResult
+      this.notifyAll()
+    }
+  }
+
+  override def jobFailed(exception: Exception): Unit = synchronized {
+    failure = Some(exception)
+    this.notifyAll()
+  }
+
+  /**
+   * Waits for up to timeout milliseconds since the listener was created and then returns a
+   * PartialResult with the result so far. This may be complete if the whole job is done.
+   */
+  def getResult(): PartialResult[R] = synchronized {
+    val finishTime = startTime + timeout
+    while (true) {
+      val time = System.currentTimeMillis()
+      if (failure != None) {
+        throw failure.get
+      } else if (finishedTasks == totalTasks) {
+        return new PartialResult(evaluator.currentResult(), true)
+      } else if (time >= finishTime) {
+        resultObject = Some(new PartialResult(evaluator.currentResult(), false))
+        return resultObject.get
+      } else {
+        this.wait(finishTime - time)
+      }
+    }
+    // Should never be reached, but required to keep the compiler happy
+    return null
+  }
+}
diff --git a/core/src/main/scala/spark/partial/ApproximateEvaluator.scala b/core/src/main/scala/spark/partial/ApproximateEvaluator.scala
new file mode 100644
index 0000000000000000000000000000000000000000..4772e43ef04118cc25a2555ca3c250268496264f
--- /dev/null
+++ b/core/src/main/scala/spark/partial/ApproximateEvaluator.scala
@@ -0,0 +1,10 @@
+package spark.partial
+
+/**
+ * An object that computes a function incrementally by merging in results of type U from multiple
+ * tasks. Allows partial evaluation at any point by calling currentResult().
+ */
+trait ApproximateEvaluator[U, R] {
+  def merge(outputId: Int, taskResult: U): Unit
+  def currentResult(): R
+}
diff --git a/core/src/main/scala/spark/partial/BoundedDouble.scala b/core/src/main/scala/spark/partial/BoundedDouble.scala
new file mode 100644
index 0000000000000000000000000000000000000000..463c33d6e238ebc688390accd0b66e4b4ef10cf5
--- /dev/null
+++ b/core/src/main/scala/spark/partial/BoundedDouble.scala
@@ -0,0 +1,8 @@
+package spark.partial
+
+/**
+ * A Double with error bars on it.
+ */
+class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, val high: Double) {
+  override def toString(): String = "[%.3f, %.3f]".format(low, high)
+}
diff --git a/core/src/main/scala/spark/partial/CountEvaluator.scala b/core/src/main/scala/spark/partial/CountEvaluator.scala
new file mode 100644
index 0000000000000000000000000000000000000000..1bc90d6b3930aab7b870cbca4a2b0731723be1e8
--- /dev/null
+++ b/core/src/main/scala/spark/partial/CountEvaluator.scala
@@ -0,0 +1,38 @@
+package spark.partial
+
+import cern.jet.stat.Probability
+
+/**
+ * An ApproximateEvaluator for counts.
+ *
+ * TODO: There's currently a lot of shared code between this and GroupedCountEvaluator. It might
+ * be best to make this a special case of GroupedCountEvaluator with one group.
+ */
+class CountEvaluator(totalOutputs: Int, confidence: Double)
+  extends ApproximateEvaluator[Long, BoundedDouble] {
+
+  var outputsMerged = 0
+  var sum: Long = 0
+
+  override def merge(outputId: Int, taskResult: Long) {
+    outputsMerged += 1
+    sum += taskResult
+  }
+
+  override def currentResult(): BoundedDouble = {
+    if (outputsMerged == totalOutputs) {
+      new BoundedDouble(sum, 1.0, sum, sum)
+    } else if (outputsMerged == 0) {
+      new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)
+    } else {
+      val p = outputsMerged.toDouble / totalOutputs
+      val mean = (sum + 1 - p) / p
+      val variance = (sum + 1) * (1 - p) / (p * p)
+      val stdev = math.sqrt(variance)
+      val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2)
+      val low = mean - confFactor * stdev
+      val high = mean + confFactor * stdev
+      new BoundedDouble(mean, confidence, low, high)
+    }
+  }
+}
diff --git a/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala
new file mode 100644
index 0000000000000000000000000000000000000000..3e631c0efc5517c184126ff4602988d1e79297e6
--- /dev/null
+++ b/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala
@@ -0,0 +1,62 @@
+package spark.partial
+
+import java.util.{HashMap => JHashMap}
+import java.util.{Map => JMap}
+
+import scala.collection.Map
+import scala.collection.mutable.HashMap
+import scala.collection.JavaConversions.mapAsScalaMap
+
+import cern.jet.stat.Probability
+
+import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
+
+/**
+ * An ApproximateEvaluator for counts by key. Returns a map of key to confidence interval.
+ */
+class GroupedCountEvaluator[T](totalOutputs: Int, confidence: Double)
+  extends ApproximateEvaluator[OLMap[T], Map[T, BoundedDouble]] {
+
+  var outputsMerged = 0
+  var sums = new OLMap[T]   // Sum of counts for each key
+
+  override def merge(outputId: Int, taskResult: OLMap[T]) {
+    outputsMerged += 1
+    val iter = taskResult.object2LongEntrySet.fastIterator()
+    while (iter.hasNext) {
+      val entry = iter.next()
+      sums.put(entry.getKey, sums.getLong(entry.getKey) + entry.getLongValue)
+    }
+  }
+
+  override def currentResult(): Map[T, BoundedDouble] = {
+    if (outputsMerged == totalOutputs) {
+      val result = new JHashMap[T, BoundedDouble](sums.size)
+      val iter = sums.object2LongEntrySet.fastIterator()
+      while (iter.hasNext) {
+        val entry = iter.next()
+        val sum = entry.getLongValue()
+        result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum)
+      }
+      result
+    } else if (outputsMerged == 0) {
+      new HashMap[T, BoundedDouble]
+    } else {
+      val p = outputsMerged.toDouble / totalOutputs
+      val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2)
+      val result = new JHashMap[T, BoundedDouble](sums.size)
+      val iter = sums.object2LongEntrySet.fastIterator()
+      while (iter.hasNext) {
+        val entry = iter.next()
+        val sum = entry.getLongValue
+        val mean = (sum + 1 - p) / p
+        val variance = (sum + 1) * (1 - p) / (p * p)
+        val stdev = math.sqrt(variance)
+        val low = mean - confFactor * stdev
+        val high = mean + confFactor * stdev
+        result(entry.getKey) = new BoundedDouble(mean, confidence, low, high)
+      }
+      result
+    }
+  }
+}
diff --git a/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala b/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala
new file mode 100644
index 0000000000000000000000000000000000000000..2a9ccba2055efc5121de8789b225a9808bb475b9
--- /dev/null
+++ b/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala
@@ -0,0 +1,65 @@
+package spark.partial
+
+import java.util.{HashMap => JHashMap}
+import java.util.{Map => JMap}
+
+import scala.collection.mutable.HashMap
+import scala.collection.Map
+import scala.collection.JavaConversions.mapAsScalaMap
+
+import spark.util.StatCounter
+
+/**
+ * An ApproximateEvaluator for means by key. Returns a map of key to confidence interval.
+ */
+class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Double)
+  extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] {
+
+  var outputsMerged = 0
+  var sums = new JHashMap[T, StatCounter]   // Sum of counts for each key
+
+  override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) {
+    outputsMerged += 1
+    val iter = taskResult.entrySet.iterator()
+    while (iter.hasNext) {
+      val entry = iter.next()
+      val old = sums.get(entry.getKey)
+      if (old != null) {
+        old.merge(entry.getValue)
+      } else {
+        sums.put(entry.getKey, entry.getValue)
+      }
+    }
+  }
+
+  override def currentResult(): Map[T, BoundedDouble] = {
+    if (outputsMerged == totalOutputs) {
+      val result = new JHashMap[T, BoundedDouble](sums.size)
+      val iter = sums.entrySet.iterator()
+      while (iter.hasNext) {
+        val entry = iter.next()
+        val mean = entry.getValue.mean
+        result(entry.getKey) = new BoundedDouble(mean, 1.0, mean, mean)
+      }
+      result
+    } else if (outputsMerged == 0) {
+      new HashMap[T, BoundedDouble]
+    } else {
+      val p = outputsMerged.toDouble / totalOutputs
+      val studentTCacher = new StudentTCacher(confidence)
+      val result = new JHashMap[T, BoundedDouble](sums.size)
+      val iter = sums.entrySet.iterator()
+      while (iter.hasNext) {
+        val entry = iter.next()
+        val counter = entry.getValue
+        val mean = counter.mean
+        val stdev = math.sqrt(counter.sampleVariance / counter.count)
+        val confFactor = studentTCacher.get(counter.count)
+        val low = mean - confFactor * stdev
+        val high = mean + confFactor * stdev
+        result(entry.getKey) = new BoundedDouble(mean, confidence, low, high)
+      }
+      result
+    }
+  }
+}
diff --git a/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala b/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala
new file mode 100644
index 0000000000000000000000000000000000000000..6a2ec7a7bd30e53bf4844ff1f4382f3118bbc635
--- /dev/null
+++ b/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala
@@ -0,0 +1,72 @@
+package spark.partial
+
+import java.util.{HashMap => JHashMap}
+import java.util.{Map => JMap}
+
+import scala.collection.mutable.HashMap
+import scala.collection.Map
+import scala.collection.JavaConversions.mapAsScalaMap
+
+import spark.util.StatCounter
+
+/**
+ * An ApproximateEvaluator for sums by key. Returns a map of key to confidence interval.
+ */
+class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Double)
+  extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] {
+
+  var outputsMerged = 0
+  var sums = new JHashMap[T, StatCounter]   // Sum of counts for each key
+
+  override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) {
+    outputsMerged += 1
+    val iter = taskResult.entrySet.iterator()
+    while (iter.hasNext) {
+      val entry = iter.next()
+      val old = sums.get(entry.getKey)
+      if (old != null) {
+        old.merge(entry.getValue)
+      } else {
+        sums.put(entry.getKey, entry.getValue)
+      }
+    }
+  }
+
+  override def currentResult(): Map[T, BoundedDouble] = {
+    if (outputsMerged == totalOutputs) {
+      val result = new JHashMap[T, BoundedDouble](sums.size)
+      val iter = sums.entrySet.iterator()
+      while (iter.hasNext) {
+        val entry = iter.next()
+        val sum = entry.getValue.sum
+        result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum)
+      }
+      result
+    } else if (outputsMerged == 0) {
+      new HashMap[T, BoundedDouble]
+    } else {
+      val p = outputsMerged.toDouble / totalOutputs
+      val studentTCacher = new StudentTCacher(confidence)
+      val result = new JHashMap[T, BoundedDouble](sums.size)
+      val iter = sums.entrySet.iterator()
+      while (iter.hasNext) {
+        val entry = iter.next()
+        val counter = entry.getValue
+        val meanEstimate = counter.mean
+        val meanVar = counter.sampleVariance / counter.count
+        val countEstimate = (counter.count + 1 - p) / p
+        val countVar = (counter.count + 1) * (1 - p) / (p * p)
+        val sumEstimate = meanEstimate * countEstimate
+        val sumVar = (meanEstimate * meanEstimate * countVar) +
+                     (countEstimate * countEstimate * meanVar) +
+                     (meanVar * countVar)
+        val sumStdev = math.sqrt(sumVar)
+        val confFactor = studentTCacher.get(counter.count)
+        val low = sumEstimate - confFactor * sumStdev
+        val high = sumEstimate + confFactor * sumStdev
+        result(entry.getKey) = new BoundedDouble(sumEstimate, confidence, low, high)
+      }
+      result
+    }
+  }
+}
diff --git a/core/src/main/scala/spark/partial/MeanEvaluator.scala b/core/src/main/scala/spark/partial/MeanEvaluator.scala
new file mode 100644
index 0000000000000000000000000000000000000000..b8c7cb8863539096ec9577e1c43ec1831c545423
--- /dev/null
+++ b/core/src/main/scala/spark/partial/MeanEvaluator.scala
@@ -0,0 +1,41 @@
+package spark.partial
+
+import cern.jet.stat.Probability
+
+import spark.util.StatCounter
+
+/**
+ * An ApproximateEvaluator for means.
+ */
+class MeanEvaluator(totalOutputs: Int, confidence: Double)
+  extends ApproximateEvaluator[StatCounter, BoundedDouble] {
+
+  var outputsMerged = 0
+  var counter = new StatCounter
+
+  override def merge(outputId: Int, taskResult: StatCounter) {
+    outputsMerged += 1
+    counter.merge(taskResult)
+  }
+
+  override def currentResult(): BoundedDouble = {
+    if (outputsMerged == totalOutputs) {
+      new BoundedDouble(counter.mean, 1.0, counter.mean, counter.mean)
+    } else if (outputsMerged == 0) {
+      new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)
+    } else {
+      val mean = counter.mean
+      val stdev = math.sqrt(counter.sampleVariance / counter.count)
+      val confFactor = {
+        if (counter.count > 100) {
+          Probability.normalInverse(1 - (1 - confidence) / 2)
+        } else {
+          Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt)
+        }
+      }
+      val low = mean - confFactor * stdev
+      val high = mean + confFactor * stdev
+      new BoundedDouble(mean, confidence, low, high)
+    }
+  }
+}
diff --git a/core/src/main/scala/spark/partial/PartialResult.scala b/core/src/main/scala/spark/partial/PartialResult.scala
new file mode 100644
index 0000000000000000000000000000000000000000..7095bc8ca1bbf4d134a3ce01b3cd1826e3a93722
--- /dev/null
+++ b/core/src/main/scala/spark/partial/PartialResult.scala
@@ -0,0 +1,86 @@
+package spark.partial
+
+class PartialResult[R](initialVal: R, isFinal: Boolean) {
+  private var finalValue: Option[R] = if (isFinal) Some(initialVal) else None
+  private var failure: Option[Exception] = None
+  private var completionHandler: Option[R => Unit] = None
+  private var failureHandler: Option[Exception => Unit] = None
+
+  def initialValue: R = initialVal
+
+  def isInitialValueFinal: Boolean = isFinal
+
+  /**
+   * Blocking method to wait for and return the final value.
+   */
+  def getFinalValue(): R = synchronized {
+    while (finalValue == None && failure == None) {
+      this.wait()
+    }
+    if (finalValue != None) {
+      return finalValue.get
+    } else {
+      throw failure.get
+    }
+  }
+
+  /** 
+   * Set a handler to be called when this PartialResult completes. Only one completion handler
+   * is supported per PartialResult.
+   */
+  def onComplete(handler: R => Unit): PartialResult[R] = synchronized {
+    if (completionHandler != None) {
+      throw new UnsupportedOperationException("onComplete cannot be called twice")
+    }
+    completionHandler = Some(handler)
+    if (finalValue != None) {
+      // We already have a final value, so let's call the handler
+      handler(finalValue.get)
+    }
+    return this
+  }
+
+  /** 
+   * Set a handler to be called if this PartialResult's job fails. Only one failure handler
+   * is supported per PartialResult.
+   */
+  def onFail(handler: Exception => Unit): Unit = synchronized {
+    if (failureHandler != None) {
+      throw new UnsupportedOperationException("onFail cannot be called twice")
+    }
+    failureHandler = Some(handler)
+    if (failure != None) {
+      // We already have a failure, so let's call the handler
+      handler(failure.get)
+    }
+  }
+
+  private[spark] def setFinalValue(value: R): Unit = synchronized {
+    if (finalValue != None) {
+      throw new UnsupportedOperationException("setFinalValue called twice on a PartialResult")
+    }
+    finalValue = Some(value)
+    // Call the completion handler if it was set
+    completionHandler.foreach(h => h(value))
+    // Notify any threads that may be calling getFinalValue()
+    this.notifyAll()
+  }
+
+  private[spark] def setFailure(exception: Exception): Unit = synchronized {
+    if (failure != None) {
+      throw new UnsupportedOperationException("setFailure called twice on a PartialResult")
+    }
+    failure = Some(exception)
+    // Call the failure handler if it was set
+    failureHandler.foreach(h => h(exception))
+    // Notify any threads that may be calling getFinalValue()
+    this.notifyAll()
+  }
+
+  override def toString: String = synchronized {
+    finalValue match {
+      case Some(value) => "(final: " + value + ")"
+      case None => "(partial: " + initialValue + ")"
+    }
+  }
+}
diff --git a/core/src/main/scala/spark/partial/StudentTCacher.scala b/core/src/main/scala/spark/partial/StudentTCacher.scala
new file mode 100644
index 0000000000000000000000000000000000000000..6263ee3518d8c21beb081d4c26dd0aa837f683d5
--- /dev/null
+++ b/core/src/main/scala/spark/partial/StudentTCacher.scala
@@ -0,0 +1,26 @@
+package spark.partial
+
+import cern.jet.stat.Probability
+
+/**
+ * A utility class for caching Student's T distribution values for a given confidence level
+ * and various sample sizes. This is used by the MeanEvaluator to efficiently calculate
+ * confidence intervals for many keys.
+ */
+class StudentTCacher(confidence: Double) {
+  val NORMAL_APPROX_SAMPLE_SIZE = 100  // For samples bigger than this, use Gaussian approximation
+  val normalApprox = Probability.normalInverse(1 - (1 - confidence) / 2)
+  val cache = Array.fill[Double](NORMAL_APPROX_SAMPLE_SIZE)(-1.0)
+
+  def get(sampleSize: Long): Double = {
+    if (sampleSize >= NORMAL_APPROX_SAMPLE_SIZE) {
+      normalApprox
+    } else {
+      val size = sampleSize.toInt
+      if (cache(size) < 0) {
+        cache(size) = Probability.studentTInverse(1 - confidence, size - 1)
+      }
+      cache(size)
+    }
+  }
+}
diff --git a/core/src/main/scala/spark/partial/SumEvaluator.scala b/core/src/main/scala/spark/partial/SumEvaluator.scala
new file mode 100644
index 0000000000000000000000000000000000000000..0357a6bff860a78729f759d44ff63feae76236fa
--- /dev/null
+++ b/core/src/main/scala/spark/partial/SumEvaluator.scala
@@ -0,0 +1,51 @@
+package spark.partial
+
+import cern.jet.stat.Probability
+
+import spark.util.StatCounter
+
+/**
+ * An ApproximateEvaluator for sums. It estimates the mean and the cont and multiplies them
+ * together, then uses the formula for the variance of two independent random variables to get
+ * a variance for the result and compute a confidence interval.
+ */
+class SumEvaluator(totalOutputs: Int, confidence: Double)
+  extends ApproximateEvaluator[StatCounter, BoundedDouble] {
+
+  var outputsMerged = 0
+  var counter = new StatCounter
+
+  override def merge(outputId: Int, taskResult: StatCounter) {
+    outputsMerged += 1
+    counter.merge(taskResult)
+  }
+
+  override def currentResult(): BoundedDouble = {
+    if (outputsMerged == totalOutputs) {
+      new BoundedDouble(counter.sum, 1.0, counter.sum, counter.sum)
+    } else if (outputsMerged == 0) {
+      new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)
+    } else {
+      val p = outputsMerged.toDouble / totalOutputs
+      val meanEstimate = counter.mean
+      val meanVar = counter.sampleVariance / counter.count
+      val countEstimate = (counter.count + 1 - p) / p
+      val countVar = (counter.count + 1) * (1 - p) / (p * p)
+      val sumEstimate = meanEstimate * countEstimate
+      val sumVar = (meanEstimate * meanEstimate * countVar) +
+                   (countEstimate * countEstimate * meanVar) +
+                   (meanVar * countVar)
+      val sumStdev = math.sqrt(sumVar)
+      val confFactor = {
+        if (counter.count > 100) {
+          Probability.normalInverse(1 - (1 - confidence) / 2)
+        } else {
+          Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt)
+        }
+      }
+      val low = sumEstimate - confFactor * sumStdev
+      val high = sumEstimate + confFactor * sumStdev
+      new BoundedDouble(sumEstimate, confidence, low, high)
+    }
+  }
+}
diff --git a/core/src/main/scala/spark/scheduler/ActiveJob.scala b/core/src/main/scala/spark/scheduler/ActiveJob.scala
new file mode 100644
index 0000000000000000000000000000000000000000..0ecff9ce77ea773c30d9947a342327d2bf88fa29
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/ActiveJob.scala
@@ -0,0 +1,18 @@
+package spark.scheduler
+
+import spark.TaskContext
+
+/**
+ * Tracks information about an active job in the DAGScheduler.
+ */
+class ActiveJob(
+    val runId: Int,
+    val finalStage: Stage,
+    val func: (TaskContext, Iterator[_]) => _,
+    val partitions: Array[Int],
+    val listener: JobListener) {
+
+  val numPartitions = partitions.length
+  val finished = Array.fill[Boolean](numPartitions)(false)
+  var numFinished = 0
+}
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
new file mode 100644
index 0000000000000000000000000000000000000000..f31e2c65a050d59302810be769a28a6c9bed67aa
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -0,0 +1,532 @@
+package spark.scheduler
+
+import java.net.URI
+import java.util.concurrent.atomic.AtomicInteger
+import java.util.concurrent.Future
+import java.util.concurrent.LinkedBlockingQueue
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue, Map}
+
+import spark._
+import spark.partial.ApproximateActionListener
+import spark.partial.ApproximateEvaluator
+import spark.partial.PartialResult
+import spark.storage.BlockManagerMaster
+import spark.storage.BlockManagerId
+
+/**
+ * A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for 
+ * each job, keeps track of which RDDs and stage outputs are materialized, and computes a minimal 
+ * schedule to run the job. Subclasses only need to implement the code to send a task to the cluster
+ * and to report fetch failures (the submitTasks method, and code to add CompletionEvents).
+ */
+class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with Logging {
+  taskSched.setListener(this)
+
+  // Called by TaskScheduler to report task completions or failures.
+  override def taskEnded(
+      task: Task[_],
+      reason: TaskEndReason,
+      result: Any,
+      accumUpdates: Map[Long, Any]) {
+    eventQueue.put(CompletionEvent(task, reason, result, accumUpdates))
+  }
+
+  // Called by TaskScheduler when a host fails.
+  override def hostLost(host: String) {
+    eventQueue.put(HostLost(host))
+  }
+
+  // The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
+  // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
+  // as more failure events come in
+  val RESUBMIT_TIMEOUT = 50L
+
+  // The time, in millis, to wake up between polls of the completion queue in order to potentially
+  // resubmit failed stages
+  val POLL_TIMEOUT = 10L
+
+  private val lock = new Object          // Used for access to the entire DAGScheduler
+
+  private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]
+
+  val nextRunId = new AtomicInteger(0)
+
+  val nextStageId = new AtomicInteger(0)
+
+  val idToStage = new HashMap[Int, Stage]
+
+  val shuffleToMapStage = new HashMap[Int, Stage]
+
+  var cacheLocs = new HashMap[Int, Array[List[String]]]
+
+  val env = SparkEnv.get
+  val cacheTracker = env.cacheTracker
+  val mapOutputTracker = env.mapOutputTracker
+
+  val deadHosts = new HashSet[String]  // TODO: The code currently assumes these can't come back;
+                                       // that's not going to be a realistic assumption in general
+  
+  val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
+  val running = new HashSet[Stage] // Stages we are running right now
+  val failed = new HashSet[Stage]  // Stages that must be resubmitted due to fetch failures
+  val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage
+  var lastFetchFailureTime: Long = 0  // Used to wait a bit to avoid repeated resubmits
+
+  val activeJobs = new HashSet[ActiveJob]
+  val resultStageToJob = new HashMap[Stage, ActiveJob]
+
+  // Start a thread to run the DAGScheduler event loop
+  new Thread("DAGScheduler") {
+    setDaemon(true)
+    override def run() {
+      DAGScheduler.this.run()
+    }
+  }.start()
+
+  def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
+    cacheLocs(rdd.id)
+  }
+  
+  def updateCacheLocs() {
+    cacheLocs = cacheTracker.getLocationsSnapshot()
+  }
+
+  /**
+   * Get or create a shuffle map stage for the given shuffle dependency's map side.
+   * The priority value passed in will be used if the stage doesn't already exist with
+   * a lower priority (we assume that priorities always increase across jobs for now).
+   */
+  def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_,_], priority: Int): Stage = {
+    shuffleToMapStage.get(shuffleDep.shuffleId) match {
+      case Some(stage) => stage
+      case None =>
+        val stage = newStage(shuffleDep.rdd, Some(shuffleDep), priority)
+        shuffleToMapStage(shuffleDep.shuffleId) = stage
+        stage
+    }
+  }
+
+  /**
+   * Create a Stage for the given RDD, either as a shuffle map stage (for a ShuffleDependency) or
+   * as a result stage for the final RDD used directly in an action. The stage will also be given
+   * the provided priority.
+   */
+  def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]], priority: Int): Stage = {
+    // Kind of ugly: need to register RDDs with the cache and map output tracker here
+    // since we can't do it in the RDD constructor because # of splits is unknown
+    logInfo("Registering RDD " + rdd.id + ": " + rdd)
+    cacheTracker.registerRDD(rdd.id, rdd.splits.size)
+    if (shuffleDep != None) {
+      mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
+    }
+    val id = nextStageId.getAndIncrement()
+    val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, priority), priority)
+    idToStage(id) = stage
+    stage
+  }
+
+  /**
+   * Get or create the list of parent stages for a given RDD. The stages will be assigned the
+   * provided priority if they haven't already been created with a lower priority.
+   */
+  def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = {
+    val parents = new HashSet[Stage]
+    val visited = new HashSet[RDD[_]]
+    def visit(r: RDD[_]) {
+      if (!visited(r)) {
+        visited += r
+        // Kind of ugly: need to register RDDs with the cache here since
+        // we can't do it in its constructor because # of splits is unknown
+        logInfo("Registering parent RDD " + r.id + ": " + r)
+        cacheTracker.registerRDD(r.id, r.splits.size)
+        for (dep <- r.dependencies) {
+          dep match {
+            case shufDep: ShuffleDependency[_,_,_] =>
+              parents += getShuffleMapStage(shufDep, priority)
+            case _ =>
+              visit(dep.rdd)
+          }
+        }
+      }
+    }
+    visit(rdd)
+    parents.toList
+  }
+
+  def getMissingParentStages(stage: Stage): List[Stage] = {
+    val missing = new HashSet[Stage]
+    val visited = new HashSet[RDD[_]]
+    def visit(rdd: RDD[_]) {
+      if (!visited(rdd)) {
+        visited += rdd
+        val locs = getCacheLocs(rdd)
+        for (p <- 0 until rdd.splits.size) {
+          if (locs(p) == Nil) {
+            for (dep <- rdd.dependencies) {
+              dep match {
+                case shufDep: ShuffleDependency[_,_,_] =>
+                  val mapStage = getShuffleMapStage(shufDep, stage.priority)
+                  if (!mapStage.isAvailable) {
+                    missing += mapStage
+                  }
+                case narrowDep: NarrowDependency[_] =>
+                  visit(narrowDep.rdd)
+              }
+            }
+          }
+        }
+      }
+    }
+    visit(stage.rdd)
+    missing.toList
+  }
+
+  def runJob[T, U](
+      finalRdd: RDD[T],
+      func: (TaskContext, Iterator[T]) => U,
+      partitions: Seq[Int],
+      allowLocal: Boolean)
+      (implicit m: ClassManifest[U]): Array[U] =
+  {
+    val waiter = new JobWaiter(partitions.size)
+    val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
+    eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, waiter))
+    waiter.getResult() match {
+      case JobSucceeded(results: Seq[_]) =>
+        return results.asInstanceOf[Seq[U]].toArray
+      case JobFailed(exception: Exception) =>
+        throw exception
+    }
+  }
+
+  def runApproximateJob[T, U, R](
+      rdd: RDD[T],
+      func: (TaskContext, Iterator[T]) => U,
+      evaluator: ApproximateEvaluator[U, R],
+      timeout: Long
+      ): PartialResult[R] =
+  {
+    val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
+    val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
+    val partitions = (0 until rdd.splits.size).toArray
+    eventQueue.put(JobSubmitted(rdd, func2, partitions, false, listener))
+    return listener.getResult()    // Will throw an exception if the job fails
+  }
+
+  /**
+   * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
+   * events and responds by launching tasks. This runs in a dedicated thread and receives events
+   * via the eventQueue.
+   */
+  def run() = {
+    SparkEnv.set(env)
+
+    while (true) {
+      val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
+      val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
+      if (event != null) {
+        logDebug("Got event of type " + event.getClass.getName)
+      }
+
+      event match {
+        case JobSubmitted(finalRDD, func, partitions, allowLocal, listener) =>
+          val runId = nextRunId.getAndIncrement()
+          val finalStage = newStage(finalRDD, None, runId)
+          val job = new ActiveJob(runId, finalStage, func, partitions, listener)
+          updateCacheLocs()
+          logInfo("Got job " + job.runId + " with " + partitions.length + " output partitions")
+          logInfo("Final stage: " + finalStage)
+          logInfo("Parents of final stage: " + finalStage.parents)
+          logInfo("Missing parents: " + getMissingParentStages(finalStage))
+          if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
+            // Compute very short actions like first() or take() with no parent stages locally.
+            runLocally(job)
+          } else {
+            activeJobs += job
+            resultStageToJob(finalStage) = job
+            submitStage(finalStage)
+          }
+
+        case HostLost(host) =>
+          handleHostLost(host)
+
+        case completion: CompletionEvent =>
+          handleTaskCompletion(completion)
+
+        case null =>
+          // queue.poll() timed out, ignore it
+      }
+
+      // Periodically resubmit failed stages if some map output fetches have failed and we have
+      // waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails,
+      // tasks on many other nodes are bound to get a fetch failure, and they won't all get it at
+      // the same time, so we want to make sure we've identified all the reduce tasks that depend
+      // on the failed node.
+      if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
+        logInfo("Resubmitting failed stages")
+        updateCacheLocs()
+        val failed2 = failed.toArray
+        failed.clear()
+        for (stage <- failed2.sortBy(_.priority)) {
+          submitStage(stage)
+        }
+      } else {
+        // TODO: We might want to run this less often, when we are sure that something has become
+        // runnable that wasn't before.
+        logDebug("Checking for newly runnable parent stages")
+        logDebug("running: " + running)
+        logDebug("waiting: " + waiting)
+        logDebug("failed: " + failed)
+        val waiting2 = waiting.toArray
+        waiting.clear()
+        for (stage <- waiting2.sortBy(_.priority)) {
+          submitStage(stage)
+        }
+      }
+    }
+  }
+
+  /**
+   * Run a job on an RDD locally, assuming it has only a single partition and no dependencies.
+   * We run the operation in a separate thread just in case it takes a bunch of time, so that we
+   * don't block the DAGScheduler event loop or other concurrent jobs.
+   */
+  def runLocally(job: ActiveJob) {
+    logInfo("Computing the requested partition locally")
+    new Thread("Local computation of job " + job.runId) {
+      override def run() {
+        try {
+          SparkEnv.set(env)
+          val rdd = job.finalStage.rdd
+          val split = rdd.splits(job.partitions(0))
+          val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
+          val result = job.func(taskContext, rdd.iterator(split))
+          job.listener.taskSucceeded(0, result)
+        } catch {
+          case e: Exception =>
+            job.listener.jobFailed(e)
+        }
+      }
+    }.start()
+  }
+
+  def submitStage(stage: Stage) {
+    logDebug("submitStage(" + stage + ")")
+    if (!waiting(stage) && !running(stage) && !failed(stage)) {
+      val missing = getMissingParentStages(stage).sortBy(_.id)
+      logDebug("missing: " + missing)
+      if (missing == Nil) {
+        logInfo("Submitting " + stage + ", which has no missing parents")
+        submitMissingTasks(stage)
+        running += stage
+      } else {
+        for (parent <- missing) {
+          submitStage(parent)
+        }
+        waiting += stage
+      }
+    }
+  }
+  
+  def submitMissingTasks(stage: Stage) {
+    logDebug("submitMissingTasks(" + stage + ")")
+    // Get our pending tasks and remember them in our pendingTasks entry
+    val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
+    myPending.clear()
+    var tasks = ArrayBuffer[Task[_]]()
+    if (stage.isShuffleMap) {
+      for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
+        val locs = getPreferredLocs(stage.rdd, p)
+        tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs)
+      }
+    } else {
+      // This is a final stage; figure out its job's missing partitions
+      val job = resultStageToJob(stage)
+      for (id <- 0 until job.numPartitions if (!job.finished(id))) {
+        val partition = job.partitions(id)
+        val locs = getPreferredLocs(stage.rdd, partition)
+        tasks += new ResultTask(stage.id, stage.rdd, job.func, partition, locs, id)
+      }
+    }
+    if (tasks.size > 0) {
+      logInfo("Submitting " + tasks.size + " missing tasks from " + stage)
+      myPending ++= tasks
+      logDebug("New pending tasks: " + myPending)
+      taskSched.submitTasks(
+        new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority))
+    } else {
+      logDebug("Stage " + stage + " is actually done; %b %d %d".format(
+        stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
+      running -= stage
+    }
+  }
+
+  /**
+   * Responds to a task finishing. This is called inside the event loop so it assumes that it can
+   * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
+   */
+  def handleTaskCompletion(event: CompletionEvent) {
+    val task = event.task
+    val stage = idToStage(task.stageId)
+    event.reason match {
+      case Success =>  
+        logInfo("Completed " + task)
+        if (event.accumUpdates != null) {
+          Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted
+        }
+        pendingTasks(stage) -= task
+        task match {
+          case rt: ResultTask[_, _] =>
+            resultStageToJob.get(stage) match {
+              case Some(job) =>
+                if (!job.finished(rt.outputId)) {
+                  job.finished(rt.outputId) = true
+                  job.numFinished += 1
+                  job.listener.taskSucceeded(rt.outputId, event.result)
+                  // If the whole job has finished, remove it
+                  if (job.numFinished == job.numPartitions) {
+                    activeJobs -= job
+                    resultStageToJob -= stage
+                    running -= stage
+                  }
+                }
+              case None =>
+                logInfo("Ignoring result from " + rt + " because its job has finished")
+            }
+
+          case smt: ShuffleMapTask =>
+            val stage = idToStage(smt.stageId)
+            val bmAddress = event.result.asInstanceOf[BlockManagerId]
+            val host = bmAddress.ip
+            logInfo("ShuffleMapTask finished with host " + host)
+            if (!deadHosts.contains(host)) {   // TODO: Make sure hostnames are consistent with Mesos
+              stage.addOutputLoc(smt.partition, bmAddress)
+            }
+            if (running.contains(stage) && pendingTasks(stage).isEmpty) {
+              logInfo(stage + " finished; looking for newly runnable stages")
+              running -= stage
+              logInfo("running: " + running)
+              logInfo("waiting: " + waiting)
+              logInfo("failed: " + failed)
+              if (stage.shuffleDep != None) {
+                mapOutputTracker.registerMapOutputs(
+                  stage.shuffleDep.get.shuffleId,
+                  stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray)
+              }
+              updateCacheLocs()
+              if (stage.outputLocs.count(_ == Nil) != 0) {
+                // Some tasks had failed; let's resubmit this stage
+                // TODO: Lower-level scheduler should also deal with this
+                logInfo("Resubmitting " + stage + " because some of its tasks had failed: " +
+                  stage.outputLocs.zipWithIndex.filter(_._1 == Nil).map(_._2).mkString(", "))
+                submitStage(stage)
+              } else {
+                val newlyRunnable = new ArrayBuffer[Stage]
+                for (stage <- waiting) {
+                  logInfo("Missing parents for " + stage + ": " + getMissingParentStages(stage))
+                }
+                for (stage <- waiting if getMissingParentStages(stage) == Nil) {
+                  newlyRunnable += stage
+                }
+                waiting --= newlyRunnable
+                running ++= newlyRunnable
+                for (stage <- newlyRunnable.sortBy(_.id)) {
+                  submitMissingTasks(stage)
+                }
+              }
+            }
+          }
+
+      case Resubmitted =>
+        logInfo("Resubmitted " + task + ", so marking it as still running")
+        pendingTasks(stage) += task
+
+      case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
+        // Mark the stage that the reducer was in as unrunnable
+        val failedStage = idToStage(task.stageId)
+        running -= failedStage
+        failed += failedStage
+        // TODO: Cancel running tasks in the stage
+        logInfo("Marking " + failedStage + " for resubmision due to a fetch failure")
+        // Mark the map whose fetch failed as broken in the map stage
+        val mapStage = shuffleToMapStage(shuffleId)
+        mapStage.removeOutputLoc(mapId, bmAddress)
+        mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
+        logInfo("The failed fetch was from " + mapStage + "; marking it for resubmission")
+        failed += mapStage
+        // Remember that a fetch failed now; this is used to resubmit the broken
+        // stages later, after a small wait (to give other tasks the chance to fail)
+        lastFetchFailureTime = System.currentTimeMillis() // TODO: Use pluggable clock
+        // TODO: mark the host as failed only if there were lots of fetch failures on it
+        if (bmAddress != null) {
+          handleHostLost(bmAddress.ip)
+        }
+
+      case _ =>
+        // Non-fetch failure -- probably a bug in the job, so bail out
+        // TODO: Cancel all tasks that are still running
+        resultStageToJob.get(stage) match {
+          case Some(job) =>
+            val error = new SparkException("Task failed: " + task + ", reason: " + event.reason)
+            job.listener.jobFailed(error)
+            activeJobs -= job
+            resultStageToJob -= stage
+          case None =>
+            logInfo("Ignoring result from " + task + " because its job has finished")
+        }
+    }
+  }
+
+  /**
+   * Responds to a host being lost. This is called inside the event loop so it assumes that it can
+   * modify the scheduler's internal state. Use hostLost() to post a host lost event from outside.
+   */
+  def handleHostLost(host: String) {
+    if (!deadHosts.contains(host)) {
+      logInfo("Host lost: " + host)
+      deadHosts += host
+      BlockManagerMaster.notifyADeadHost(host)
+      // TODO: This will be really slow if we keep accumulating shuffle map stages
+      for ((shuffleId, stage) <- shuffleToMapStage) {
+        stage.removeOutputsOnHost(host)
+        val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray
+        mapOutputTracker.registerMapOutputs(shuffleId, locs, true)
+      }
+      cacheTracker.cacheLost(host)
+      updateCacheLocs()
+    }
+  }
+
+  def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
+    // If the partition is cached, return the cache locations
+    val cached = getCacheLocs(rdd)(partition)
+    if (cached != Nil) {
+      return cached
+    }
+    // If the RDD has some placement preferences (as is the case for input RDDs), get those
+    val rddPrefs = rdd.preferredLocations(rdd.splits(partition)).toList
+    if (rddPrefs != Nil) {
+      return rddPrefs
+    }
+    // If the RDD has narrow dependencies, pick the first partition of the first narrow dep
+    // that has any placement preferences. Ideally we would choose based on transfer sizes,
+    // but this will do for now.
+    rdd.dependencies.foreach(_ match {
+      case n: NarrowDependency[_] =>
+        for (inPart <- n.getParents(partition)) {
+          val locs = getPreferredLocs(n.rdd, inPart)
+          if (locs != Nil)
+            return locs;
+        }
+      case _ =>
+    })
+    return Nil
+  }
+
+  def stop() {
+    // TODO: Put a stop event on our queue and break the event loop
+    taskSched.stop()
+  }
+}
diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
new file mode 100644
index 0000000000000000000000000000000000000000..c10abc92028993d9200676d60139493ee5df5f62
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
@@ -0,0 +1,30 @@
+package spark.scheduler
+
+import scala.collection.mutable.Map
+
+import spark._
+
+/**
+ * Types of events that can be handled by the DAGScheduler. The DAGScheduler uses an event queue
+ * architecture where any thread can post an event (e.g. a task finishing or a new job being
+ * submitted) but there is a single "logic" thread that reads these events and takes decisions.
+ * This greatly simplifies synchronization.
+ */
+sealed trait DAGSchedulerEvent
+
+case class JobSubmitted(
+    finalRDD: RDD[_],
+    func: (TaskContext, Iterator[_]) => _,
+    partitions: Array[Int],
+    allowLocal: Boolean,
+    listener: JobListener)
+  extends DAGSchedulerEvent
+
+case class CompletionEvent(
+    task: Task[_],
+    reason: TaskEndReason,
+    result: Any,
+    accumUpdates: Map[Long, Any])
+  extends DAGSchedulerEvent
+
+case class HostLost(host: String) extends DAGSchedulerEvent
diff --git a/core/src/main/scala/spark/scheduler/JobListener.scala b/core/src/main/scala/spark/scheduler/JobListener.scala
new file mode 100644
index 0000000000000000000000000000000000000000..d4dd536a7de553f92d3c8a506df39805bb89d77f
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/JobListener.scala
@@ -0,0 +1,11 @@
+package spark.scheduler
+
+/**
+ * Interface used to listen for job completion or failure events after submitting a job to the
+ * DAGScheduler. The listener is notified each time a task succeeds, as well as if the whole
+ * job fails (and no further taskSucceeded events will happen).
+ */
+trait JobListener {
+  def taskSucceeded(index: Int, result: Any)
+  def jobFailed(exception: Exception)
+}
diff --git a/core/src/main/scala/spark/scheduler/JobResult.scala b/core/src/main/scala/spark/scheduler/JobResult.scala
new file mode 100644
index 0000000000000000000000000000000000000000..62b458eccbd22822592b236ba2c67ad15c4a2b4b
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/JobResult.scala
@@ -0,0 +1,9 @@
+package spark.scheduler
+
+/**
+ * A result of a job in the DAGScheduler.
+ */
+sealed trait JobResult
+
+case class JobSucceeded(results: Seq[_]) extends JobResult
+case class JobFailed(exception: Exception) extends JobResult
diff --git a/core/src/main/scala/spark/scheduler/JobWaiter.scala b/core/src/main/scala/spark/scheduler/JobWaiter.scala
new file mode 100644
index 0000000000000000000000000000000000000000..be8ec9bd7b07e9d8ac8e986ae9a20b575b9bbd0c
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/JobWaiter.scala
@@ -0,0 +1,43 @@
+package spark.scheduler
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * An object that waits for a DAGScheduler job to complete.
+ */
+class JobWaiter(totalTasks: Int) extends JobListener {
+  private val taskResults = ArrayBuffer.fill[Any](totalTasks)(null)
+  private var finishedTasks = 0
+
+  private var jobFinished = false          // Is the job as a whole finished (succeeded or failed)?
+  private var jobResult: JobResult = null  // If the job is finished, this will be its result
+
+  override def taskSucceeded(index: Int, result: Any) = synchronized {
+    if (jobFinished) {
+      throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
+    }
+    taskResults(index) = result
+    finishedTasks += 1
+    if (finishedTasks == totalTasks) {
+      jobFinished = true
+      jobResult = JobSucceeded(taskResults)
+      this.notifyAll()
+    }
+  }
+
+  override def jobFailed(exception: Exception) = synchronized {
+    if (jobFinished) {
+      throw new UnsupportedOperationException("jobFailed() called on a finished JobWaiter")
+    }
+    jobFinished = true
+    jobResult = JobFailed(exception)
+    this.notifyAll()
+  }
+
+  def getResult(): JobResult = synchronized {
+    while (!jobFinished) {
+      this.wait()
+    }
+    return jobResult
+  }
+}
diff --git a/core/src/main/scala/spark/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala
similarity index 71%
rename from core/src/main/scala/spark/ResultTask.scala
rename to core/src/main/scala/spark/scheduler/ResultTask.scala
index 3952bf85b2cdb89f83aaed4bbca8c73086e08f5d..d2fab55b5e8a1aa3af9d0ea4f1f9607449dc5b2a 100644
--- a/core/src/main/scala/spark/ResultTask.scala
+++ b/core/src/main/scala/spark/scheduler/ResultTask.scala
@@ -1,14 +1,15 @@
-package spark
+package spark.scheduler
+
+import spark._
 
 class ResultTask[T, U](
-    runId: Int,
-    stageId: Int, 
-    rdd: RDD[T], 
+    stageId: Int,
+    rdd: RDD[T],
     func: (TaskContext, Iterator[T]) => U,
-    val partition: Int, 
-    locs: Seq[String],
+    val partition: Int,
+    @transient locs: Seq[String],
     val outputId: Int)
-  extends DAGTask[U](runId, stageId) {
+  extends Task[U](stageId) {
   
   val split = rdd.splits(partition)
 
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
new file mode 100644
index 0000000000000000000000000000000000000000..317faa08510c9d9969f60d13978165080d761715
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -0,0 +1,135 @@
+package spark.scheduler
+
+import java.io._
+import java.util.HashMap
+import java.util.zip.{GZIPInputStream, GZIPOutputStream}
+
+import scala.collection.mutable.ArrayBuffer
+
+import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
+
+import com.ning.compress.lzf.LZFInputStream
+import com.ning.compress.lzf.LZFOutputStream
+
+import spark._
+import spark.storage._
+
+object ShuffleMapTask {
+  val serializedInfoCache = new HashMap[Int, Array[Byte]]
+  val deserializedInfoCache = new HashMap[Int, (RDD[_], ShuffleDependency[_,_,_])]
+
+  def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_]): Array[Byte] = {
+    synchronized {
+      val old = serializedInfoCache.get(stageId)
+      if (old != null) {
+        return old
+      } else {
+        val out = new ByteArrayOutputStream
+        val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
+        objOut.writeObject(rdd)
+        objOut.writeObject(dep)
+        objOut.close()
+        val bytes = out.toByteArray
+        serializedInfoCache.put(stageId, bytes)
+        return bytes
+      }
+    }
+  }
+
+  def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_,_]) = {
+    synchronized {
+      val old = deserializedInfoCache.get(stageId)
+      if (old != null) {
+        return old
+      } else {
+        val loader = currentThread.getContextClassLoader
+        val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
+        val objIn = new ObjectInputStream(in) {
+          override def resolveClass(desc: ObjectStreamClass) =
+            Class.forName(desc.getName, false, loader)
+        }
+        val rdd = objIn.readObject().asInstanceOf[RDD[_]]
+        val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_,_]]
+        val tuple = (rdd, dep)
+        deserializedInfoCache.put(stageId, tuple)
+        return tuple
+      }
+    }
+  }
+}
+
+class ShuffleMapTask(
+    stageId: Int,
+    var rdd: RDD[_], 
+    var dep: ShuffleDependency[_,_,_],
+    var partition: Int, 
+    @transient var locs: Seq[String])
+  extends Task[BlockManagerId](stageId)
+  with Externalizable
+  with Logging {
+
+  def this() = this(0, null, null, 0, null)
+  
+  var split = if (rdd == null) {
+    null 
+  } else { 
+    rdd.splits(partition)
+  }
+
+  override def writeExternal(out: ObjectOutput) {
+    out.writeInt(stageId)
+    val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
+    out.writeInt(bytes.length)
+    out.write(bytes)
+    out.writeInt(partition)
+    out.writeObject(split)
+  }
+
+  override def readExternal(in: ObjectInput) {
+    val stageId = in.readInt()
+    val numBytes = in.readInt()
+    val bytes = new Array[Byte](numBytes)
+    in.readFully(bytes)
+    val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes)
+    rdd = rdd_
+    dep = dep_
+    partition = in.readInt()
+    split = in.readObject().asInstanceOf[Split]
+  }
+
+  override def run(attemptId: Int): BlockManagerId = {
+    val numOutputSplits = dep.partitioner.numPartitions
+    val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]]
+    val partitioner = dep.partitioner.asInstanceOf[Partitioner]
+    val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any])
+    for (elem <- rdd.iterator(split)) {
+      val (k, v) = elem.asInstanceOf[(Any, Any)]
+      var bucketId = partitioner.getPartition(k)
+      val bucket = buckets(bucketId)
+      var existing = bucket.get(k)
+      if (existing == null) {
+        bucket.put(k, aggregator.createCombiner(v))
+      } else {
+        bucket.put(k, aggregator.mergeValue(existing, v))
+      }
+    }
+    val ser = SparkEnv.get.serializer.newInstance()
+    val blockManager = SparkEnv.get.blockManager
+    for (i <- 0 until numOutputSplits) {
+      val blockId = "shuffleid_" + dep.shuffleId + "_" + partition + "_" + i
+      val arr = new ArrayBuffer[Any]
+      val iter = buckets(i).entrySet().iterator()
+      while (iter.hasNext()) {
+        val entry = iter.next()
+        arr += ((entry.getKey(), entry.getValue()))
+      }
+      // TODO: This should probably be DISK_ONLY
+      blockManager.put(blockId, arr.iterator, StorageLevel.MEMORY_ONLY, false)
+    }
+    return SparkEnv.get.blockManager.blockManagerId
+  }
+
+  override def preferredLocations: Seq[String] = locs
+
+  override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition)
+}
diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala
new file mode 100644
index 0000000000000000000000000000000000000000..cd660c9085a751193bcc99cc93c3499276b7b72a
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/Stage.scala
@@ -0,0 +1,86 @@
+package spark.scheduler
+
+import java.net.URI
+
+import spark._
+import spark.storage.BlockManagerId
+
+/**
+ * A stage is a set of independent tasks all computing the same function that need to run as part
+ * of a Spark job, where all the tasks have the same shuffle dependencies. Each DAG of tasks run
+ * by the scheduler is split up into stages at the boundaries where shuffle occurs, and then the
+ * DAGScheduler runs these stages in topological order.
+ *
+ * Each Stage can either be a shuffle map stage, in which case its tasks' results are input for
+ * another stage, or a result stage, in which case its tasks directly compute the action that
+ * initiated a job (e.g. count(), save(), etc). For shuffle map stages, we also track the nodes
+ * that each output partition is on.
+ *
+ * Each Stage also has a priority, which is (by default) based on the job it was submitted in.
+ * This allows Stages from earlier jobs to be computed first or recovered faster on failure.
+ */
+class Stage(
+    val id: Int,
+    val rdd: RDD[_],
+    val shuffleDep: Option[ShuffleDependency[_,_,_]],  // Output shuffle if stage is a map stage
+    val parents: List[Stage],
+    val priority: Int)
+  extends Logging {
+  
+  val isShuffleMap = shuffleDep != None
+  val numPartitions = rdd.splits.size
+  val outputLocs = Array.fill[List[BlockManagerId]](numPartitions)(Nil)
+  var numAvailableOutputs = 0
+
+  private var nextAttemptId = 0
+
+  def isAvailable: Boolean = {
+    if (/*parents.size == 0 &&*/ !isShuffleMap) {
+      true
+    } else {
+      numAvailableOutputs == numPartitions
+    }
+  }
+
+  def addOutputLoc(partition: Int, bmAddress: BlockManagerId) {
+    val prevList = outputLocs(partition)
+    outputLocs(partition) = bmAddress :: prevList
+    if (prevList == Nil)
+      numAvailableOutputs += 1
+  }
+
+  def removeOutputLoc(partition: Int, bmAddress: BlockManagerId) {
+    val prevList = outputLocs(partition)
+    val newList = prevList.filterNot(_ == bmAddress)
+    outputLocs(partition) = newList
+    if (prevList != Nil && newList == Nil) {
+      numAvailableOutputs -= 1
+    }
+  }
+ 
+  def removeOutputsOnHost(host: String) {
+    var becameUnavailable = false
+    for (partition <- 0 until numPartitions) {
+      val prevList = outputLocs(partition)
+      val newList = prevList.filterNot(_.ip == host)
+      outputLocs(partition) = newList
+      if (prevList != Nil && newList == Nil) {
+        becameUnavailable = true
+        numAvailableOutputs -= 1
+      }
+    }
+    if (becameUnavailable) {
+      logInfo("%s is now unavailable on %s (%d/%d, %s)".format(this, host, numAvailableOutputs, numPartitions, isAvailable))
+    }
+  }
+
+  def newAttemptId(): Int = {
+    val id = nextAttemptId
+    nextAttemptId += 1
+    return id
+  }
+
+  override def toString = "Stage " + id // + ": [RDD = " + rdd.id + ", isShuffle = " + isShuffleMap + "]"
+
+  override def hashCode(): Int = id
+}
diff --git a/core/src/main/scala/spark/scheduler/Task.scala b/core/src/main/scala/spark/scheduler/Task.scala
new file mode 100644
index 0000000000000000000000000000000000000000..42325956baa51cf1681799ad9a2b82531a7ef4ce
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/Task.scala
@@ -0,0 +1,11 @@
+package spark.scheduler
+
+/**
+ * A task to execute on a worker node.
+ */
+abstract class Task[T](val stageId: Int) extends Serializable {
+  def run(attemptId: Int): T
+  def preferredLocations: Seq[String] = Nil
+
+  var generation: Long = -1   // Map output tracker generation. Will be set by TaskScheduler.
+}
diff --git a/core/src/main/scala/spark/scheduler/TaskResult.scala b/core/src/main/scala/spark/scheduler/TaskResult.scala
new file mode 100644
index 0000000000000000000000000000000000000000..868ddb237c0a23ca8f55d443df8a2473f1604ddd
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/TaskResult.scala
@@ -0,0 +1,34 @@
+package spark.scheduler
+
+import java.io._
+
+import scala.collection.mutable.Map
+
+// Task result. Also contains updates to accumulator variables.
+// TODO: Use of distributed cache to return result is a hack to get around
+// what seems to be a bug with messages over 60KB in libprocess; fix it
+class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any]) extends Externalizable {
+  def this() = this(null.asInstanceOf[T], null)
+
+  override def writeExternal(out: ObjectOutput) {
+    out.writeObject(value)
+    out.writeInt(accumUpdates.size)
+    for ((key, value) <- accumUpdates) {
+      out.writeLong(key)
+      out.writeObject(value)
+    }
+  }
+
+  override def readExternal(in: ObjectInput) {
+    value = in.readObject().asInstanceOf[T]
+    val numUpdates = in.readInt
+    if (numUpdates == 0) {
+      accumUpdates = null
+    } else {
+      accumUpdates = Map()
+      for (i <- 0 until numUpdates) {
+        accumUpdates(in.readLong()) = in.readObject()
+      }
+    }
+  }
+}
diff --git a/core/src/main/scala/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/spark/scheduler/TaskScheduler.scala
new file mode 100644
index 0000000000000000000000000000000000000000..cb7c375d97e09e07c022fc3dcca238971efbf425
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/TaskScheduler.scala
@@ -0,0 +1,27 @@
+package spark.scheduler
+
+/**
+ * Low-level task scheduler interface, implemented by both MesosScheduler and LocalScheduler.
+ * These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage,
+ * and are responsible for sending the tasks to the cluster, running them, retrying if there
+ * are failures, and mitigating stragglers. They return events to the DAGScheduler through
+ * the TaskSchedulerListener interface.
+ */
+trait TaskScheduler {
+  def start(): Unit
+
+  // Wait for registration with Mesos.
+  def waitForRegister(): Unit
+
+  // Disconnect from the cluster.
+  def stop(): Unit
+
+  // Submit a sequence of tasks to run.
+  def submitTasks(taskSet: TaskSet): Unit
+
+  // Set a listener for upcalls. This is guaranteed to be set before submitTasks is called.
+  def setListener(listener: TaskSchedulerListener): Unit
+
+  // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
+  def defaultParallelism(): Int
+}
diff --git a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
new file mode 100644
index 0000000000000000000000000000000000000000..a647eec9e477831f5c77b84f05344efaaa7ec2d5
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
@@ -0,0 +1,16 @@
+package spark.scheduler
+
+import scala.collection.mutable.Map
+
+import spark.TaskEndReason
+
+/**
+ * Interface for getting events back from the TaskScheduler.
+ */
+trait TaskSchedulerListener {
+  // A task has finished or failed.
+  def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]): Unit
+
+  // A node was lost from the cluster.
+  def hostLost(host: String): Unit
+}
diff --git a/core/src/main/scala/spark/scheduler/TaskSet.scala b/core/src/main/scala/spark/scheduler/TaskSet.scala
new file mode 100644
index 0000000000000000000000000000000000000000..6f29dd2e9d6dd0688c3a9ac4a38f3fae4fcddb4e
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/TaskSet.scala
@@ -0,0 +1,9 @@
+package spark.scheduler
+
+/**
+ * A set of tasks submitted together to the low-level TaskScheduler, usually representing
+ * missing partitions of a particular stage.
+ */
+class TaskSet(val tasks: Array[Task[_]], val stageId: Int, val attempt: Int, val priority: Int) {
+  val id: String = stageId + "." + attempt
+}
diff --git a/core/src/main/scala/spark/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
similarity index 57%
rename from core/src/main/scala/spark/LocalScheduler.scala
rename to core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index 3910c7b09e915c173c41c8d6b96bc427d2b6aea1..8339c0ae9025aab942f26f97a078d31235f99613 100644
--- a/core/src/main/scala/spark/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -1,16 +1,21 @@
-package spark
+package spark.scheduler.local
 
 import java.util.concurrent.Executors
 import java.util.concurrent.atomic.AtomicInteger
 
+import spark._
+import spark.scheduler._
+
 /**
- * A simple Scheduler implementation that runs tasks locally in a thread pool. Optionally the 
- * scheduler also allows each task to fail up to maxFailures times, which is useful for testing
- * fault recovery.
+ * A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
+ * the scheduler also allows each task to fail up to maxFailures times, which is useful for
+ * testing fault recovery.
  */
-private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGScheduler with Logging {
+class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with Logging {
   var attemptId = new AtomicInteger(0)
   var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory)
+  val env = SparkEnv.get
+  var listener: TaskSchedulerListener = null
 
   // TODO: Need to take into account stage priority in scheduling
 
@@ -18,7 +23,12 @@ private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGSchedule
   
   override def waitForRegister() {}
 
-  override def submitTasks(tasks: Seq[Task[_]], runId: Int) {
+  override def setListener(listener: TaskSchedulerListener) { 
+    this.listener = listener
+  }
+
+  override def submitTasks(taskSet: TaskSet) {
+    val tasks = taskSet.tasks
     val failCount = new Array[Int](tasks.size)
 
     def submitTask(task: Task[_], idInJob: Int) {
@@ -38,23 +48,14 @@ private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGSchedule
         // Serialize and deserialize the task so that accumulators are changed to thread-local ones;
         // this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
         Accumulators.clear
-        val ser = SparkEnv.get.closureSerializer.newInstance()
-        val startTime = System.currentTimeMillis
-        val bytes = ser.serialize(task)
-        val timeTaken = System.currentTimeMillis - startTime
-        logInfo("Size of task %d is %d bytes and took %d ms to serialize".format(
-            idInJob, bytes.size, timeTaken))
-        val deserializedTask = ser.deserialize[Task[_]](bytes, currentThread.getContextClassLoader)
+        val bytes = Utils.serialize(task)
+        logInfo("Size of task " + idInJob + " is " + bytes.size + " bytes")
+        val deserializedTask = Utils.deserialize[Task[_]](
+            bytes, Thread.currentThread.getContextClassLoader)
         val result: Any = deserializedTask.run(attemptId)
-
-        // Serialize and deserialize the result to emulate what the mesos
-        // executor does. This is useful to catch serialization errors early
-        // on in development (so when users move their local Spark programs
-        // to the cluster, they don't get surprised by serialization errors).
-        val resultToReturn = ser.deserialize[Any](ser.serialize(result))
         val accumUpdates = Accumulators.values
         logInfo("Finished task " + idInJob)
-        taskEnded(task, Success, resultToReturn, accumUpdates)
+        listener.taskEnded(task, Success, result, accumUpdates)
       } catch {
         case t: Throwable => {
           logError("Exception in task " + idInJob, t)
@@ -64,7 +65,7 @@ private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGSchedule
               submitTask(task, idInJob)
             } else {
               // TODO: Do something nicer here to return all the way to the user
-              taskEnded(task, new ExceptionFailure(t), null, null)
+              listener.taskEnded(task, new ExceptionFailure(t), null, null)
             }
           }
         }
diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala
new file mode 100644
index 0000000000000000000000000000000000000000..8182901ce3abb6d80b5f8bbcf1008098fd44b304
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala
@@ -0,0 +1,364 @@
+package spark.scheduler.mesos
+
+import java.io.{File, FileInputStream, FileOutputStream}
+import java.util.{ArrayList => JArrayList}
+import java.util.{List => JList}
+import java.util.{HashMap => JHashMap}
+import java.util.concurrent._
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+import scala.collection.mutable.Map
+import scala.collection.mutable.PriorityQueue
+import scala.collection.JavaConversions._
+import scala.math.Ordering
+
+import akka.actor._
+import akka.actor.Actor
+import akka.actor.Actor._
+import akka.actor.Channel
+import akka.serialization.RemoteActorSerialization._
+
+import com.google.protobuf.ByteString
+
+import org.apache.mesos.{Scheduler => MScheduler}
+import org.apache.mesos._
+import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _}
+
+import spark._
+import spark.scheduler._
+
+sealed trait CoarseMesosSchedulerMessage
+case class RegisterSlave(slaveId: String, host: String, port: Int) extends CoarseMesosSchedulerMessage
+case class StatusUpdate(slaveId: String, status: TaskStatus) extends CoarseMesosSchedulerMessage
+case class LaunchTask(slaveId: String, task: MTaskInfo) extends CoarseMesosSchedulerMessage
+case class ReviveOffers() extends CoarseMesosSchedulerMessage
+
+case class FakeOffer(slaveId: String, host: String, cores: Int)
+
+/**
+ * Mesos scheduler that uses coarse-grained tasks and does its own fine-grained scheduling inside
+ * them using Akka actors for messaging. Clients should first call start(), then submit task sets
+ * through the runTasks method.
+ *
+ * TODO: This is a pretty big hack for now.
+ */
+class CoarseMesosScheduler(
+    sc: SparkContext,
+    master: String,
+    frameworkName: String)
+  extends MesosScheduler(sc, master, frameworkName) {
+
+  val CORES_PER_SLAVE = System.getProperty("spark.coarseMesosScheduler.coresPerSlave", "4").toInt
+
+  class MasterActor extends Actor {
+    val slaveActor = new HashMap[String, ActorRef]
+    val slaveHost = new HashMap[String, String]
+    val freeCores = new HashMap[String, Int]
+   
+    def receive = {
+      case RegisterSlave(slaveId, host, port) =>
+        slaveActor(slaveId) = remote.actorFor("WorkerActor", host, port)
+        logInfo("Slave actor: " + slaveActor(slaveId))
+        slaveHost(slaveId) = host
+        freeCores(slaveId) = CORES_PER_SLAVE
+        makeFakeOffers()
+
+      case StatusUpdate(slaveId, status) =>
+        fakeStatusUpdate(status)
+        if (isFinished(status.getState)) {
+          freeCores(slaveId) += 1
+          makeFakeOffers(slaveId)
+        }
+
+      case LaunchTask(slaveId, task) =>
+        freeCores(slaveId) -= 1
+        slaveActor(slaveId) ! LaunchTask(slaveId, task)
+
+      case ReviveOffers() =>
+        logInfo("Reviving offers")
+        makeFakeOffers()
+    }
+
+    // Make fake resource offers for all slaves
+    def makeFakeOffers() {
+      fakeResourceOffers(slaveHost.toSeq.map{case (id, host) => FakeOffer(id, host, freeCores(id))})
+    }
+
+    // Make fake resource offers for all slaves
+    def makeFakeOffers(slaveId: String) {
+      fakeResourceOffers(Seq(FakeOffer(slaveId, slaveHost(slaveId), freeCores(slaveId))))
+    }
+  }
+
+  val masterActor: ActorRef = actorOf(new MasterActor)
+  remote.register("MasterActor", masterActor)
+  masterActor.start()
+
+  val taskIdsOnSlave = new HashMap[String, HashSet[String]]
+
+  /**
+   * Method called by Mesos to offer resources on slaves. We resond by asking our active task sets 
+   * for tasks in order of priority. We fill each node with tasks in a round-robin manner so that
+   * tasks are balanced across the cluster.
+   */
+  override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) {
+    synchronized {
+      val tasks = offers.map(o => new JArrayList[MTaskInfo])
+      for (i <- 0 until offers.size) {
+        val o = offers.get(i)
+        val slaveId = o.getSlaveId.getValue
+        if (!slaveIdToHost.contains(slaveId)) {
+          slaveIdToHost(slaveId) = o.getHostname
+          hostsAlive += o.getHostname
+          taskIdsOnSlave(slaveId) = new HashSet[String]
+          // Launch an infinite task on the node that will talk to the MasterActor to get fake tasks
+          val cpuRes = Resource.newBuilder()
+              .setName("cpus")
+              .setType(Value.Type.SCALAR)
+              .setScalar(Value.Scalar.newBuilder().setValue(1).build())
+              .build()
+          val task = new WorkerTask(slaveId, o.getHostname)
+          val serializedTask = Utils.serialize(task)
+          tasks(i).add(MTaskInfo.newBuilder()
+              .setTaskId(newTaskId())
+              .setSlaveId(o.getSlaveId)
+              .setExecutor(executorInfo)
+              .setName("worker task")
+              .addResources(cpuRes)
+              .setData(ByteString.copyFrom(serializedTask))
+              .build())
+        }
+      }
+      val filters = Filters.newBuilder().setRefuseSeconds(10).build()
+      for (i <- 0 until offers.size) {
+        d.launchTasks(offers(i).getId(), tasks(i), filters)
+      }
+    }
+  }
+
+  override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
+    val tid = status.getTaskId.getValue
+    var taskSetToUpdate: Option[TaskSetManager] = None
+    var taskFailed = false
+    synchronized {
+      try {
+        taskIdToTaskSetId.get(tid) match {
+          case Some(taskSetId) =>
+            if (activeTaskSets.contains(taskSetId)) {
+              //activeTaskSets(taskSetId).statusUpdate(status)
+              taskSetToUpdate = Some(activeTaskSets(taskSetId))
+            }
+            if (isFinished(status.getState)) {
+              taskIdToTaskSetId.remove(tid)
+              if (taskSetTaskIds.contains(taskSetId)) {
+                taskSetTaskIds(taskSetId) -= tid
+              }
+              val slaveId = taskIdToSlaveId(tid)
+              taskIdToSlaveId -= tid
+              taskIdsOnSlave(slaveId) -= tid
+            }
+            if (status.getState == TaskState.TASK_FAILED) {
+              taskFailed = true
+            }
+          case None =>
+            logInfo("Ignoring update from TID " + tid + " because its task set is gone")
+        }
+      } catch {
+        case e: Exception => logError("Exception in statusUpdate", e)
+      }
+    }
+    // Update the task set and DAGScheduler without holding a lock on this, because that can deadlock
+    if (taskSetToUpdate != None) {
+      taskSetToUpdate.get.statusUpdate(status)
+    }
+    if (taskFailed) {
+      // Revive offers if a task had failed for some reason other than host lost
+      reviveOffers()
+    }
+  }
+
+  override def slaveLost(d: SchedulerDriver, s: SlaveID) {
+    logInfo("Slave lost: " + s.getValue)
+    var failedHost: Option[String] = None
+    var lostTids: Option[HashSet[String]] = None
+    synchronized {
+      val slaveId = s.getValue
+      val host = slaveIdToHost(slaveId)
+      if (hostsAlive.contains(host)) {
+        slaveIdsWithExecutors -= slaveId
+        hostsAlive -= host
+        failedHost = Some(host)
+        lostTids = Some(taskIdsOnSlave(slaveId))
+        logInfo("failedHost: " + host)
+        logInfo("lostTids: " + lostTids)
+        taskIdsOnSlave -= slaveId
+        activeTaskSetsQueue.foreach(_.hostLost(host))
+      }
+    }
+    if (failedHost != None) {
+      // Report all the tasks on the failed host as lost, without holding a lock on this
+      for (tid <- lostTids.get; taskSetId <- taskIdToTaskSetId.get(tid)) {
+        // TODO: Maybe call our statusUpdate() instead to clean our internal data structures
+        activeTaskSets(taskSetId).statusUpdate(TaskStatus.newBuilder()
+          .setTaskId(TaskID.newBuilder().setValue(tid).build())
+          .setState(TaskState.TASK_LOST)
+          .build())
+      }
+      // Also report the loss to the DAGScheduler
+      listener.hostLost(failedHost.get)
+      reviveOffers();
+    }
+  }
+
+  override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
+
+  // Check for speculatable tasks in all our active jobs.
+  override def checkSpeculatableTasks() {
+    var shouldRevive = false
+    synchronized {
+      for (ts <- activeTaskSetsQueue) {
+        shouldRevive |= ts.checkSpeculatableTasks()
+      }
+    }
+    if (shouldRevive) {
+      reviveOffers()
+    }
+  }
+
+
+  val lock2 = new Object
+  var firstWait = true
+
+  override def waitForRegister() {
+    lock2.synchronized {
+      if (firstWait) {
+        super.waitForRegister()
+        Thread.sleep(5000)
+        firstWait = false
+      }
+    }
+  }
+
+  def fakeStatusUpdate(status: TaskStatus) {
+    statusUpdate(driver, status)
+  }
+
+  def fakeResourceOffers(offers: Seq[FakeOffer]) {
+    logDebug("fakeResourceOffers: " + offers)
+    val availableCpus = offers.map(_.cores.toDouble).toArray
+    var launchedTask = false
+    for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) {
+      do {
+        launchedTask = false
+        for (i <- 0 until offers.size if hostsAlive.contains(offers(i).host)) {
+          manager.slaveOffer(offers(i).slaveId, offers(i).host, availableCpus(i)) match {
+            case Some(task) => 
+              val tid = task.getTaskId.getValue
+              val sid = offers(i).slaveId
+              taskIdToTaskSetId(tid) = manager.taskSet.id
+              taskSetTaskIds(manager.taskSet.id) += tid
+              taskIdToSlaveId(tid) = sid
+              taskIdsOnSlave(sid) += tid
+              slaveIdsWithExecutors += sid
+              availableCpus(i) -= getResource(task.getResourcesList(), "cpus")
+              launchedTask = true
+              masterActor ! LaunchTask(sid, task)
+              
+            case None => {}
+          }
+        }
+      } while (launchedTask)
+    }
+  }
+
+  override def reviveOffers() {
+    masterActor ! ReviveOffers()
+  }
+}
+
+class WorkerTask(slaveId: String, host: String) extends Task[Unit](-1) {
+  generation = 0
+
+  def run(id: Int): Unit = {
+    val actor = actorOf(new WorkerActor(slaveId, host))
+    if (!remote.isRunning) {
+      remote.start(Utils.localIpAddress, 7078)
+    }
+    remote.register("WorkerActor", actor)
+    actor.start()
+    while (true) {
+      Thread.sleep(10000)
+    }
+  }
+}
+
+class WorkerActor(slaveId: String, host: String) extends Actor with Logging {
+  val env = SparkEnv.get
+  val classLoader = currentThread.getContextClassLoader
+  val threadPool = new ThreadPoolExecutor(
+    1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
+
+  val masterIp: String = System.getProperty("spark.master.host", "localhost")
+  val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt
+  val masterActor = remote.actorFor("MasterActor", masterIp, masterPort)
+
+  class TaskRunner(desc: MTaskInfo)
+  extends Runnable {
+    override def run() = {
+      val tid = desc.getTaskId.getValue
+      logInfo("Running task ID " + tid)
+      try {
+        SparkEnv.set(env)
+        Thread.currentThread.setContextClassLoader(classLoader)
+        Accumulators.clear
+        val task = Utils.deserialize[Task[Any]](desc.getData.toByteArray, classLoader)
+        env.mapOutputTracker.updateGeneration(task.generation)
+        val value = task.run(tid.toInt)
+        val accumUpdates = Accumulators.values
+        val result = new TaskResult(value, accumUpdates)
+        masterActor ! StatusUpdate(slaveId, TaskStatus.newBuilder()
+            .setTaskId(desc.getTaskId)
+            .setState(TaskState.TASK_FINISHED)
+            .setData(ByteString.copyFrom(Utils.serialize(result)))
+            .build())
+        logInfo("Finished task ID " + tid)
+      } catch {
+        case ffe: FetchFailedException => {
+          val reason = ffe.toTaskEndReason
+          masterActor ! StatusUpdate(slaveId, TaskStatus.newBuilder()
+              .setTaskId(desc.getTaskId)
+              .setState(TaskState.TASK_FAILED)
+              .setData(ByteString.copyFrom(Utils.serialize(reason)))
+              .build())
+        }
+        case t: Throwable => {
+          val reason = ExceptionFailure(t)
+          masterActor ! StatusUpdate(slaveId, TaskStatus.newBuilder()
+              .setTaskId(desc.getTaskId)
+              .setState(TaskState.TASK_FAILED)
+              .setData(ByteString.copyFrom(Utils.serialize(reason)))
+              .build())
+
+          // TODO: Should we exit the whole executor here? On the one hand, the failed task may
+          // have left some weird state around depending on when the exception was thrown, but on
+          // the other hand, maybe we could detect that when future tasks fail and exit then.
+          logError("Exception in task ID " + tid, t)
+          //System.exit(1)
+        }
+      }
+    }
+  }
+
+  override def preStart {
+    val ref = toRemoteActorRefProtocol(self).toByteArray
+    logInfo("Registering with master")
+    masterActor ! RegisterSlave(slaveId, host, remote.address.getPort)
+  }
+
+  override def receive = {
+    case LaunchTask(slaveId, task) =>
+      threadPool.execute(new TaskRunner(task))    
+  }
+}
diff --git a/core/src/main/scala/spark/MesosScheduler.scala b/core/src/main/scala/spark/scheduler/mesos/MesosScheduler.scala
similarity index 58%
rename from core/src/main/scala/spark/MesosScheduler.scala
rename to core/src/main/scala/spark/scheduler/mesos/MesosScheduler.scala
index a7711e0d352f04c004aa3030413f1593f4a76849..f72618c03fc8a1b996f32c86678b19de6ecf31cd 100644
--- a/core/src/main/scala/spark/MesosScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/mesos/MesosScheduler.scala
@@ -1,4 +1,4 @@
-package spark
+package spark.scheduler.mesos
 
 import java.io.{File, FileInputStream, FileOutputStream}
 import java.util.{ArrayList => JArrayList}
@@ -17,20 +17,23 @@ import com.google.protobuf.ByteString
 
 import org.apache.mesos.{Scheduler => MScheduler}
 import org.apache.mesos._
-import org.apache.mesos.Protos._
+import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _}
+
+import spark._
+import spark.scheduler._
 
 /**
- * The main Scheduler implementation, which runs jobs on Mesos. Clients should first call start(),
- * then submit tasks through the runTasks method.
+ * The main TaskScheduler implementation, which runs tasks on Mesos. Clients should first call
+ * start(), then submit task sets through the runTasks method.
  */
-private class MesosScheduler(
+class MesosScheduler(
     sc: SparkContext,
     master: String,
     frameworkName: String)
-  extends MScheduler
-  with DAGScheduler
+  extends TaskScheduler
+  with MScheduler
   with Logging {
-  
+
   // Environment variables to pass to our executors
   val ENV_VARS_TO_SEND_TO_EXECUTORS = Array(
     "SPARK_MEM",
@@ -49,55 +52,60 @@ private class MesosScheduler(
     }
   }
 
+  // How often to check for speculative tasks
+  val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong
+
   // Lock used to wait for scheduler to be registered
-  private var isRegistered = false
-  private val registeredLock = new Object()
+  var isRegistered = false
+  val registeredLock = new Object()
 
-  private val activeJobs = new HashMap[Int, Job]
-  private var activeJobsQueue = new ArrayBuffer[Job]
+  val activeTaskSets = new HashMap[String, TaskSetManager]
+  var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager]
 
-  private val taskIdToJobId = new HashMap[String, Int]
-  private val taskIdToSlaveId = new HashMap[String, String]
-  private val jobTasks = new HashMap[Int, HashSet[String]]
+  val taskIdToTaskSetId = new HashMap[String, String]
+  val taskIdToSlaveId = new HashMap[String, String]
+  val taskSetTaskIds = new HashMap[String, HashSet[String]]
 
-  // Incrementing job and task IDs
-  private var nextJobId = 0
-  private var nextTaskId = 0
+  // Incrementing Mesos task IDs
+  var nextTaskId = 0
 
   // Driver for talking to Mesos
   var driver: SchedulerDriver = null
 
-  // Which nodes we have executors on
-  private val slavesWithExecutors = new HashSet[String]
+  // Which hosts in the cluster are alive (contains hostnames)
+  val hostsAlive = new HashSet[String]
+
+  // Which slave IDs we have executors on
+  val slaveIdsWithExecutors = new HashSet[String]
+
+  val slaveIdToHost = new HashMap[String, String]
 
   // JAR server, if any JARs were added by the user to the SparkContext
   var jarServer: HttpServer = null
 
   // URIs of JARs to pass to executor
   var jarUris: String = ""
-
+  
   // Create an ExecutorInfo for our tasks
   val executorInfo = createExecutorInfo()
 
-  // Sorts jobs in reverse order of run ID for use in our priority queue (so lower IDs run first)
-  private val jobOrdering = new Ordering[Job] {
-    override def compare(j1: Job, j2: Job): Int =  j2.runId - j1.runId
-  }
-  
-  def newJobId(): Int = this.synchronized {
-    val id = nextJobId
-    nextJobId += 1
-    return id
+  // Listener object to pass upcalls into
+  var listener: TaskSchedulerListener = null
+
+  val mapOutputTracker = SparkEnv.get.mapOutputTracker
+
+  override def setListener(listener: TaskSchedulerListener) { 
+    this.listener = listener
   }
 
   def newTaskId(): TaskID = {
-    val id = "" + nextTaskId;
-    nextTaskId += 1;
-    return TaskID.newBuilder().setValue(id).build()
+    val id = TaskID.newBuilder().setValue("" + nextTaskId).build()
+    nextTaskId += 1
+    return id
   }
   
   override def start() {
-    new Thread("Spark scheduler") {
+    new Thread("MesosScheduler driver") {
       setDaemon(true)
       override def run {
         val sched = MesosScheduler.this
@@ -110,12 +118,27 @@ private class MesosScheduler(
           case e: Exception => logError("driver.run() failed", e)
         }
       }
-    }.start
+    }.start()
+    if (System.getProperty("spark.speculation", "false") == "true") {
+      new Thread("MesosScheduler speculation check") {
+        setDaemon(true)
+        override def run {
+          waitForRegister()
+          while (true) {
+            try {
+              Thread.sleep(SPECULATION_INTERVAL)
+            } catch { case e: InterruptedException => {} }
+            checkSpeculatableTasks()
+          }
+        }
+      }.start()
+    }
   }
 
   def createExecutorInfo(): ExecutorInfo = {
     val sparkHome = sc.getSparkHome match {
-      case Some(path) => path
+      case Some(path) =>
+        path
       case None =>
         throw new SparkException("Spark home is not set; set it through the spark.home system " +
             "property, the SPARK_HOME environment variable or the SparkContext constructor")
@@ -151,27 +174,26 @@ private class MesosScheduler(
       .build()
   }
   
-  def submitTasks(tasks: Seq[Task[_]], runId: Int) {
-    logInfo("Got a job with " + tasks.size + " tasks")
+  def submitTasks(taskSet: TaskSet) {
+    val tasks = taskSet.tasks
+    logInfo("Adding task set " + taskSet.id + " with " + tasks.size + " tasks")
     waitForRegister()
     this.synchronized {
-      val jobId = newJobId()
-      val myJob = new SimpleJob(this, tasks, runId, jobId)
-      activeJobs(jobId) = myJob
-      activeJobsQueue += myJob
-      logInfo("Adding job with ID " + jobId)
-      jobTasks(jobId) = HashSet.empty[String]
+      val manager = new TaskSetManager(this, taskSet)
+      activeTaskSets(taskSet.id) = manager
+      activeTaskSetsQueue += manager
+      taskSetTaskIds(taskSet.id) = new HashSet()
     }
-    driver.reviveOffers();
+    reviveOffers();
   }
   
-  def jobFinished(job: Job) {
+  def taskSetFinished(manager: TaskSetManager) {
     this.synchronized {
-      activeJobs -= job.jobId
-      activeJobsQueue -= job
-      taskIdToJobId --= jobTasks(job.jobId)
-      taskIdToSlaveId --= jobTasks(job.jobId)
-      jobTasks.remove(job.jobId)
+      activeTaskSets -= manager.taskSet.id
+      activeTaskSetsQueue -= manager
+      taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
+      taskIdToSlaveId --= taskSetTaskIds(manager.taskSet.id)
+      taskSetTaskIds.remove(manager.taskSet.id)
     }
   }
 
@@ -196,33 +218,40 @@ private class MesosScheduler(
   override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {}
 
   /**
-   * Method called by Mesos to offer resources on slaves. We resond by asking our active jobs for 
-   * tasks in FIFO order. We fill each node with tasks in a round-robin manner so that tasks are
-   * balanced across the cluster.
+   * Method called by Mesos to offer resources on slaves. We resond by asking our active task sets 
+   * for tasks in order of priority. We fill each node with tasks in a round-robin manner so that
+   * tasks are balanced across the cluster.
    */
   override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) {
     synchronized {
-      val tasks = offers.map(o => new JArrayList[TaskInfo])
+      // Mark each slave as alive and remember its hostname
+      for (o <- offers) {
+        slaveIdToHost(o.getSlaveId.getValue) = o.getHostname
+        hostsAlive += o.getHostname
+      }
+      // Build a list of tasks to assign to each slave
+      val tasks = offers.map(o => new JArrayList[MTaskInfo])
       val availableCpus = offers.map(o => getResource(o.getResourcesList(), "cpus"))
       val enoughMem = offers.map(o => {
         val mem = getResource(o.getResourcesList(), "mem")
         val slaveId = o.getSlaveId.getValue
-        mem >= EXECUTOR_MEMORY || slavesWithExecutors.contains(slaveId)
+        mem >= EXECUTOR_MEMORY || slaveIdsWithExecutors.contains(slaveId)
       })
       var launchedTask = false
-      for (job <- activeJobsQueue.sorted(jobOrdering)) {
+      for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) {
         do {
           launchedTask = false
           for (i <- 0 until offers.size if enoughMem(i)) {
-            job.slaveOffer(offers(i), availableCpus(i)) match {
+            val sid = offers(i).getSlaveId.getValue
+            val host = offers(i).getHostname
+            manager.slaveOffer(sid, host, availableCpus(i)) match {
               case Some(task) => 
                 tasks(i).add(task)
                 val tid = task.getTaskId.getValue
-                val sid = offers(i).getSlaveId.getValue
-                taskIdToJobId(tid) = job.jobId
-                jobTasks(job.jobId) += tid
+                taskIdToTaskSetId(tid) = manager.taskSet.id
+                taskSetTaskIds(manager.taskSet.id) += tid
                 taskIdToSlaveId(tid) = sid
-                slavesWithExecutors += sid
+                slaveIdsWithExecutors += sid
                 availableCpus(i) -= getResource(task.getResourcesList(), "cpus")
                 launchedTask = true
                 
@@ -256,53 +285,74 @@ private class MesosScheduler(
   }
 
   override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
-    var jobToUpdate: Option[Job] = None
+    val tid = status.getTaskId.getValue
+    var taskSetToUpdate: Option[TaskSetManager] = None
+    var failedHost: Option[String] = None
+    var taskFailed = false
     synchronized {
       try {
-        val tid = status.getTaskId.getValue
-        if (status.getState == TaskState.TASK_LOST 
-            && taskIdToSlaveId.contains(tid)) {
+        if (status.getState == TaskState.TASK_LOST && taskIdToSlaveId.contains(tid)) {
           // We lost the executor on this slave, so remember that it's gone
-          slavesWithExecutors -= taskIdToSlaveId(tid)
+          val slaveId = taskIdToSlaveId(tid)
+          val host = slaveIdToHost(slaveId)
+          if (hostsAlive.contains(host)) {
+            slaveIdsWithExecutors -= slaveId
+            hostsAlive -= host
+            activeTaskSetsQueue.foreach(_.hostLost(host))
+            failedHost = Some(host)
+          }
         }
-        taskIdToJobId.get(tid) match {
-          case Some(jobId) =>
-            if (activeJobs.contains(jobId)) {
-              jobToUpdate = Some(activeJobs(jobId))
+        taskIdToTaskSetId.get(tid) match {
+          case Some(taskSetId) =>
+            if (activeTaskSets.contains(taskSetId)) {
+              //activeTaskSets(taskSetId).statusUpdate(status)
+              taskSetToUpdate = Some(activeTaskSets(taskSetId))
             }
             if (isFinished(status.getState)) {
-              taskIdToJobId.remove(tid)
-              if (jobTasks.contains(jobId)) {
-                jobTasks(jobId) -= tid
+              taskIdToTaskSetId.remove(tid)
+              if (taskSetTaskIds.contains(taskSetId)) {
+                taskSetTaskIds(taskSetId) -= tid
               }
               taskIdToSlaveId.remove(tid)
             }
+            if (status.getState == TaskState.TASK_FAILED) {
+              taskFailed = true
+            }
           case None =>
-            logInfo("Ignoring update from TID " + tid + " because its job is gone")
+            logInfo("Ignoring update from TID " + tid + " because its task set is gone")
         }
       } catch {
         case e: Exception => logError("Exception in statusUpdate", e)
       }
     }
-    for (j <- jobToUpdate) {
-      j.statusUpdate(status)
+    // Update the task set and DAGScheduler without holding a lock on this, because that can deadlock
+    if (taskSetToUpdate != None) {
+      taskSetToUpdate.get.statusUpdate(status)
+    }
+    if (failedHost != None) {
+      listener.hostLost(failedHost.get)
+      reviveOffers();
+    }
+    if (taskFailed) {
+      // Also revive offers if a task had failed for some reason other than host lost
+      reviveOffers()
     }
   }
 
   override def error(d: SchedulerDriver, message: String) {
     logError("Mesos error: " + message)
     synchronized {
-      if (activeJobs.size > 0) {
-        // Have each job throw a SparkException with the error
-        for ((jobId, activeJob) <- activeJobs) {
+      if (activeTaskSets.size > 0) {
+        // Have each task set throw a SparkException with the error
+        for ((taskSetId, manager) <- activeTaskSets) {
           try {
-            activeJob.error(message)
+            manager.error(message)
           } catch {
             case e: Exception => logError("Exception in error callback", e)
           }
         }
       } else {
-        // No jobs are active but we still got an error. Just exit since this
+        // No task sets are active but we still got an error. Just exit since this
         // must mean the error is during registration.
         // It might be good to do something smarter here in the future.
         System.exit(1)
@@ -373,41 +423,68 @@ private class MesosScheduler(
     return Utils.serialize(props.toArray)
   }
 
-  override def frameworkMessage(
-      d: SchedulerDriver, 
-      e: ExecutorID,
-      s: SlaveID,
-      b: Array[Byte]) {}
+  override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {}
 
   override def slaveLost(d: SchedulerDriver, s: SlaveID) {
-    slavesWithExecutors.remove(s.getValue)
+    var failedHost: Option[String] = None
+    synchronized {
+      val slaveId = s.getValue
+      val host = slaveIdToHost(slaveId)
+      if (hostsAlive.contains(host)) {
+        slaveIdsWithExecutors -= slaveId
+        hostsAlive -= host
+        activeTaskSetsQueue.foreach(_.hostLost(host))
+        failedHost = Some(host)
+      }
+    }
+    if (failedHost != None) {
+      listener.hostLost(failedHost.get)
+      reviveOffers();
+    }
   }
 
   override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int) {
-    slavesWithExecutors.remove(s.getValue)
+    logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue))
+    slaveLost(d, s)
   }
 
   override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
+
+  // Check for speculatable tasks in all our active jobs.
+  def checkSpeculatableTasks() {
+    var shouldRevive = false
+    synchronized {
+      for (ts <- activeTaskSetsQueue) {
+        shouldRevive |= ts.checkSpeculatableTasks()
+      }
+    }
+    if (shouldRevive) {
+      reviveOffers()
+    }
+  }
+
+  def reviveOffers() {
+    driver.reviveOffers()
+  }
 }
 
 object MesosScheduler {
   /**
-   * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes.
-   * This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM
+   * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes. 
+   * This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM 
    * environment variable.
    */
   def memoryStringToMb(str: String): Int = {
     val lower = str.toLowerCase
     if (lower.endsWith("k")) {
-      (lower.substring(0, lower.length - 1).toLong / 1024).toInt
+      (lower.substring(0, lower.length-1).toLong / 1024).toInt
     } else if (lower.endsWith("m")) {
-      lower.substring(0, lower.length - 1).toInt
+      lower.substring(0, lower.length-1).toInt
     } else if (lower.endsWith("g")) {
-      lower.substring(0, lower.length - 1).toInt * 1024
+      lower.substring(0, lower.length-1).toInt * 1024
     } else if (lower.endsWith("t")) {
-      lower.substring(0, lower.length - 1).toInt * 1024 * 1024
-    } else {
-      // no suffix, so it's just a number in bytes
+      lower.substring(0, lower.length-1).toInt * 1024 * 1024
+    } else {// no suffix, so it's just a number in bytes
       (lower.toLong / 1024 / 1024).toInt
     }
   }
diff --git a/core/src/main/scala/spark/scheduler/mesos/TaskInfo.scala b/core/src/main/scala/spark/scheduler/mesos/TaskInfo.scala
new file mode 100644
index 0000000000000000000000000000000000000000..af2f80ea6671756f768c66be2f4ae2142c9f23d4
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/mesos/TaskInfo.scala
@@ -0,0 +1,32 @@
+package spark.scheduler.mesos
+
+/**
+ * Information about a running task attempt.
+ */
+class TaskInfo(val taskId: String, val index: Int, val launchTime: Long, val host: String) {
+  var finishTime: Long = 0
+  var failed = false
+
+  def markSuccessful(time: Long = System.currentTimeMillis) {
+    finishTime = time
+  }
+
+  def markFailed(time: Long = System.currentTimeMillis) {
+    finishTime = time
+    failed = true
+  }
+
+  def finished: Boolean = finishTime != 0
+
+  def successful: Boolean = finished && !failed
+
+  def duration: Long = {
+    if (!finished) {
+      throw new UnsupportedOperationException("duration() called on unfinished tasks")
+    } else {
+      finishTime - launchTime
+    }
+  }
+
+  def timeRunning(currentTime: Long): Long = currentTime - launchTime
+}
diff --git a/core/src/main/scala/spark/SimpleJob.scala b/core/src/main/scala/spark/scheduler/mesos/TaskSetManager.scala
similarity index 50%
rename from core/src/main/scala/spark/SimpleJob.scala
rename to core/src/main/scala/spark/scheduler/mesos/TaskSetManager.scala
index 01c7efff1e0af2bed9c6085b0958847968441c37..535c17d9d4db78f29acca2b7e458159664a28391 100644
--- a/core/src/main/scala/spark/SimpleJob.scala
+++ b/core/src/main/scala/spark/scheduler/mesos/TaskSetManager.scala
@@ -1,28 +1,32 @@
-package spark
+package spark.scheduler.mesos
 
+import java.util.Arrays
 import java.util.{HashMap => JHashMap}
 
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+import scala.math.max
+import scala.math.min
 
 import com.google.protobuf.ByteString
 
 import org.apache.mesos._
-import org.apache.mesos.Protos._
+import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _}
+
+import spark._
+import spark.scheduler._
 
 /**
- * A Job that runs a set of tasks with no interdependencies.
+ * Schedules the tasks within a single TaskSet in the MesosScheduler.
  */
-class SimpleJob(
+class TaskSetManager(
     sched: MesosScheduler, 
-    tasksSeq: Seq[Task[_]], 
-    runId: Int,
-    jobId: Int) 
-  extends Job(runId, jobId)
-  with Logging {
+    val taskSet: TaskSet)
+  extends Logging {
   
   // Maximum time to wait to run a task in a preferred location (in ms)
-  val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "5000").toLong
+  val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
 
   // CPUs to request per task
   val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble
@@ -30,18 +34,20 @@ class SimpleJob(
   // Maximum times a task is allowed to fail before failing the job
   val MAX_TASK_FAILURES = 4
 
+  // Quantile of tasks at which to start speculation
+  val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble
+  val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble
+
   // Serializer for closures and tasks.
   val ser = SparkEnv.get.closureSerializer.newInstance()
 
-  val callingThread = Thread.currentThread
-  val tasks = tasksSeq.toArray
+  val priority = taskSet.priority
+  val tasks = taskSet.tasks
   val numTasks = tasks.length
-  val launched = new Array[Boolean](numTasks)
+  val copiesRunning = new Array[Int](numTasks)
   val finished = new Array[Boolean](numTasks)
   val numFailures = new Array[Int](numTasks)
-  val tidToIndex = HashMap[String, Int]()
-
-  var tasksLaunched = 0
+  val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
   var tasksFinished = 0
 
   // Last time when we launched a preferred task (for delay scheduling)
@@ -62,6 +68,13 @@ class SimpleJob(
   // List containing all pending tasks (also used as a stack, as above)
   val allPendingTasks = new ArrayBuffer[Int]
 
+  // Tasks that can be specualted. Since these will be a small fraction of total
+  // tasks, we'll just hold them in a HaskSet.
+  val speculatableTasks = new HashSet[Int]
+
+  // Task index, start and finish time for each task attempt (indexed by task ID)
+  val taskInfos = new HashMap[String, TaskInfo]
+
   // Did the job fail?
   var failed = false
   var causeOfFailure = ""
@@ -76,6 +89,12 @@ class SimpleJob(
   // exceptions automatically.
   val recentExceptions = HashMap[String, (Int, Long)]()
 
+  // Figure out the current map output tracker generation and set it on all tasks
+  val generation = sched.mapOutputTracker.getGeneration
+  for (t <- tasks) {
+    t.generation = generation
+  }
+
   // Add all our tasks to the pending lists. We do this in reverse order
   // of task index so that tasks with low indices get launched first.
   for (i <- (0 until numTasks).reverse) {
@@ -84,7 +103,7 @@ class SimpleJob(
 
   // Add a task to all the pending-task lists that it should be on.
   def addPendingTask(index: Int) {
-    val locations = tasks(index).preferredLocations
+    val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
     if (locations.size == 0) {
       pendingTasksWithNoPrefs += index
     } else {
@@ -110,13 +129,37 @@ class SimpleJob(
     while (!list.isEmpty) {
       val index = list.last
       list.trimEnd(1)
-      if (!launched(index) && !finished(index)) {
+      if (copiesRunning(index) == 0 && !finished(index)) {
         return Some(index)
       }
     }
     return None
   }
 
+  // Return a speculative task for a given host if any are available. The task should not have an
+  // attempt running on this host, in case the host is slow. In addition, if localOnly is set, the
+  // task must have a preference for this host (or no preferred locations at all).
+  def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = {
+    speculatableTasks.retain(index => !finished(index))  // Remove finished tasks from set
+    val localTask = speculatableTasks.find { index =>
+      val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
+      val attemptLocs = taskAttempts(index).map(_.host)
+      (locations.size == 0 || locations.contains(host)) && !attemptLocs.contains(host)
+    }
+    if (localTask != None) {
+      speculatableTasks -= localTask.get
+      return localTask
+    }
+    if (!localOnly && speculatableTasks.size > 0) {
+      val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.host).contains(host))
+      if (nonLocalTask != None) {
+        speculatableTasks -= nonLocalTask.get
+        return nonLocalTask
+      }
+    }
+    return None
+  }
+
   // Dequeue a pending task for a given node and return its index.
   // If localOnly is set to false, allow non-local tasks as well.
   def findTask(host: String, localOnly: Boolean): Option[Int] = {
@@ -129,10 +172,13 @@ class SimpleJob(
       return noPrefTask
     }
     if (!localOnly) {
-      return findTaskFromList(allPendingTasks) // Look for non-local task
-    } else {
-      return None
+      val nonLocalTask = findTaskFromList(allPendingTasks)
+      if (nonLocalTask != None) {
+        return nonLocalTask
+      }
     }
+    // Finally, if all else has failed, find a speculative task
+    return findSpeculativeTask(host, localOnly)
   }
 
   // Does a host count as a preferred location for a task? This is true if
@@ -144,11 +190,11 @@ class SimpleJob(
   }
 
   // Respond to an offer of a single slave from the scheduler by finding a task
-  def slaveOffer(offer: Offer, availableCpus: Double): Option[TaskInfo] = {
-    if (tasksLaunched < numTasks && availableCpus >= CPUS_PER_TASK) {
+  def slaveOffer(slaveId: String, host: String, availableCpus: Double): Option[MTaskInfo] = {
+    if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
       val time = System.currentTimeMillis
-      val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT)
-      val host = offer.getHostname
+      var localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT)
+      
       findTask(host, localOnly) match {
         case Some(index) => {
           // Found a task; do some bookkeeping and return a Mesos task for it
@@ -156,17 +202,17 @@ class SimpleJob(
           val taskId = sched.newTaskId()
           // Figure out whether this should count as a preferred launch
           val preferred = isPreferredLocation(task, host)
-          val prefStr = if(preferred) "preferred" else "non-preferred"
-          val message =
-            "Starting task %d:%d as TID %s on slave %s: %s (%s)".format(
-              jobId, index, taskId.getValue, offer.getSlaveId.getValue, host, prefStr)
-          logInfo(message)
+          val prefStr = if (preferred) "preferred" else "non-preferred"
+          logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format(
+              taskSet.id, index, taskId.getValue, slaveId, host, prefStr))
           // Do various bookkeeping
-          tidToIndex(taskId.getValue) = index
-          launched(index) = true
-          tasksLaunched += 1
-          if (preferred)
+          copiesRunning(index) += 1
+          val info = new TaskInfo(taskId.getValue, index, time, host)
+          taskInfos(taskId.getValue) = info
+          taskAttempts(index) = info :: taskAttempts(index)
+          if (preferred) {
             lastPreferredLaunchTime = time
+          }
           // Create and return the Mesos task object
           val cpuRes = Resource.newBuilder()
             .setName("cpus")
@@ -178,13 +224,13 @@ class SimpleJob(
           val serializedTask = ser.serialize(task)
           val timeTaken = System.currentTimeMillis - startTime
 
-          logInfo("Size of task %d:%d is %d bytes and took %d ms to serialize by %s"
-            .format(jobId, index, serializedTask.size, timeTaken, ser.getClass.getName))
+          logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
+            taskSet.id, index, serializedTask.limit, timeTaken))
 
-          val taskName = "task %d:%d".format(jobId, index)
-          return Some(TaskInfo.newBuilder()
+          val taskName = "task %s:%d".format(taskSet.id, index)
+          return Some(MTaskInfo.newBuilder()
               .setTaskId(taskId)
-              .setSlaveId(offer.getSlaveId)
+              .setSlaveId(SlaveID.newBuilder().setValue(slaveId))
               .setExecutor(sched.executorInfo)
               .setName(taskName)
               .addResources(cpuRes)
@@ -213,18 +259,21 @@ class SimpleJob(
 
   def taskFinished(status: TaskStatus) {
     val tid = status.getTaskId.getValue
-    val index = tidToIndex(tid)
+    val info = taskInfos(tid)
+    val index = info.index
+    info.markSuccessful()
     if (!finished(index)) {
       tasksFinished += 1
-      logInfo("Finished TID %s (progress: %d/%d)".format(tid, tasksFinished, numTasks))
-      // Deserialize task result
-      val result = ser.deserialize[TaskResult[_]](
-        status.getData.toByteArray, getClass.getClassLoader)
-      sched.taskEnded(tasks(index), Success, result.value, result.accumUpdates)
+      logInfo("Finished TID %s in %d ms (progress: %d/%d)".format(
+          tid, info.duration, tasksFinished, numTasks))
+      // Deserialize task result and pass it to the scheduler
+      val result = ser.deserialize[TaskResult[_]](status.getData.asReadOnlyByteBuffer)
+      sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates)
       // Mark finished and stop if we've finished all the tasks
       finished(index) = true
-      if (tasksFinished == numTasks)
-        sched.jobFinished(this)
+      if (tasksFinished == numTasks) {
+        sched.taskSetFinished(this)
+      }
     } else {
       logInfo("Ignoring task-finished event for TID " + tid +
         " because task " + index + " is already finished")
@@ -233,30 +282,29 @@ class SimpleJob(
 
   def taskLost(status: TaskStatus) {
     val tid = status.getTaskId.getValue
-    val index = tidToIndex(tid)
+    val info = taskInfos(tid)
+    val index = info.index
+    info.markFailed()
     if (!finished(index)) {
-      logInfo("Lost TID %s (task %d:%d)".format(tid, jobId, index))
-      launched(index) = false
-      tasksLaunched -= 1
+      logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
+      copiesRunning(index) -= 1
       // Check if the problem is a map output fetch failure. In that case, this
       // task will never succeed on any node, so tell the scheduler about it.
       if (status.getData != null && status.getData.size > 0) {
-        val reason = ser.deserialize[TaskEndReason](
-          status.getData.toByteArray, getClass.getClassLoader)
+        val reason = ser.deserialize[TaskEndReason](status.getData.asReadOnlyByteBuffer)
         reason match {
           case fetchFailed: FetchFailed =>
-            logInfo("Loss was due to fetch failure from " + fetchFailed.serverUri)
-            sched.taskEnded(tasks(index), fetchFailed, null, null)
+            logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
+            sched.listener.taskEnded(tasks(index), fetchFailed, null, null)
             finished(index) = true
             tasksFinished += 1
-            if (tasksFinished == numTasks) {
-              sched.jobFinished(this)
-            }
+            sched.taskSetFinished(this)
             return
+
           case ef: ExceptionFailure =>
             val key = ef.exception.toString
             val now = System.currentTimeMillis
-            val (printFull, dupCount) =
+            val (printFull, dupCount) = {
               if (recentExceptions.contains(key)) {
                 val (dupCount, printTime) = recentExceptions(key)
                 if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
@@ -267,32 +315,28 @@ class SimpleJob(
                   (false, dupCount + 1)
                 }
               } else {
-                recentExceptions += Tuple(key, (0, now))
+                recentExceptions(key) = (0, now)
                 (true, 0)
               }
-
+            }
             if (printFull) {
-              val stackTrace =
-                for (elem <- ef.exception.getStackTrace)
-                yield "\tat %s".format(elem.toString)
-              logInfo("Loss was due to %s\n%s".format(
-                ef.exception.toString, stackTrace.mkString("\n")))
+              val locs = ef.exception.getStackTrace.map(loc => "\tat %s".format(loc.toString))
+              logInfo("Loss was due to %s\n%s".format(ef.exception.toString, locs.mkString("\n")))
             } else {
-              logInfo("Loss was due to %s [duplicate %d]".format(
-                ef.exception.toString, dupCount))
+              logInfo("Loss was due to %s [duplicate %d]".format(ef.exception.toString, dupCount))
             }
+
           case _ => {}
         }
       }
-      // On other failures, re-enqueue the task as pending for a max number of retries
+      // On non-fetch failures, re-enqueue the task as pending for a max number of retries
       addPendingTask(index)
-      // Count attempts only on FAILED and LOST state (not on KILLED)
-      if (status.getState == TaskState.TASK_FAILED ||
-          status.getState == TaskState.TASK_LOST) {
+      // Count failed attempts only on FAILED and LOST state (not on KILLED)
+      if (status.getState == TaskState.TASK_FAILED || status.getState == TaskState.TASK_LOST) {
         numFailures(index) += 1
         if (numFailures(index) > MAX_TASK_FAILURES) {
-          logError("Task %d:%d failed more than %d times; aborting job".format(
-            jobId, index, MAX_TASK_FAILURES))
+          logError("Task %s:%d failed more than %d times; aborting job".format(
+              taskSet.id, index, MAX_TASK_FAILURES))
           abort("Task %d failed more than %d times".format(index, MAX_TASK_FAILURES))
         }
       }
@@ -311,6 +355,71 @@ class SimpleJob(
     failed = true
     causeOfFailure = message
     // TODO: Kill running tasks if we were not terminated due to a Mesos error
-    sched.jobFinished(this)
+    sched.taskSetFinished(this)
+  }
+
+  def hostLost(hostname: String) {
+    logInfo("Re-queueing tasks for " + hostname)
+    // If some task has preferred locations only on hostname, put it in the no-prefs list
+    // to avoid the wait from delay scheduling
+    for (index <- getPendingTasksForHost(hostname)) {
+      val newLocs = tasks(index).preferredLocations.toSet & sched.hostsAlive
+      if (newLocs.isEmpty) {
+        pendingTasksWithNoPrefs += index
+      }
+    }
+    // Also re-enqueue any tasks that ran on the failed host if this is a shuffle map stage
+    if (tasks(0).isInstanceOf[ShuffleMapTask]) {
+      for ((tid, info) <- taskInfos if info.host == hostname) {
+        val index = taskInfos(tid).index
+        if (finished(index)) {
+          finished(index) = false
+          copiesRunning(index) -= 1
+          tasksFinished -= 1
+          addPendingTask(index)
+          // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
+          // stage finishes when a total of tasks.size tasks finish.
+          sched.listener.taskEnded(tasks(index), Resubmitted, null, null)
+        }
+      }
+    }
+  }
+
+  /**
+   * Check for tasks to be speculated and return true if there are any. This is called periodically
+   * by the MesosScheduler.
+   *
+   * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
+   * we don't scan the whole task set. It might also help to make this sorted by launch time.
+   */
+  def checkSpeculatableTasks(): Boolean = {
+    // Can't speculate if we only have one task, or if all tasks have finished.
+    if (numTasks == 1 || tasksFinished == numTasks) {
+      return false
+    }
+    var foundTasks = false
+    val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
+    logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
+    if (tasksFinished >= minFinishedForSpeculation) {
+      val time = System.currentTimeMillis()
+      val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
+      Arrays.sort(durations)
+      val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1))
+      val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
+      // TODO: Threshold should also look at standard deviation of task durations and have a lower
+      // bound based on that.
+      logDebug("Task length threshold for speculation: " + threshold)
+      for ((tid, info) <- taskInfos) {
+        val index = info.index
+        if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
+            !speculatableTasks.contains(index)) {
+          logInfo("Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
+              taskSet.id, index, info.host, threshold))
+          speculatableTasks += index
+          foundTasks = true
+        }
+      }
+    }
+    return foundTasks
   }
 }
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
new file mode 100644
index 0000000000000000000000000000000000000000..367c79dd7655336188097ad07e3d792b4333374b
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -0,0 +1,507 @@
+package spark.storage
+
+import java.io._
+import java.nio._
+import java.nio.channels.FileChannel.MapMode
+import java.util.{HashMap => JHashMap}
+import java.util.LinkedHashMap
+import java.util.UUID
+import java.util.Collections
+
+import scala.actors._
+import scala.actors.Actor._
+import scala.actors.Future
+import scala.actors.Futures.future
+import scala.actors.remote._
+import scala.actors.remote.RemoteActor._
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.JavaConversions._
+
+import it.unimi.dsi.fastutil.io._
+
+import spark.CacheTracker
+import spark.Logging
+import spark.Serializer
+import spark.SizeEstimator
+import spark.SparkEnv
+import spark.SparkException
+import spark.Utils
+import spark.network._
+
+class BlockManagerId(var ip: String, var port: Int) extends Externalizable {
+  def this() = this(null, 0)
+
+  override def writeExternal(out: ObjectOutput) {
+    out.writeUTF(ip)
+    out.writeInt(port)
+  }
+
+  override def readExternal(in: ObjectInput) {
+    ip = in.readUTF()
+    port = in.readInt()
+  }
+
+  override def toString = "BlockManagerId(" + ip + ", " + port + ")"
+
+  override def hashCode = ip.hashCode * 41 + port
+
+  override def equals(that: Any) = that match {
+    case id: BlockManagerId => port == id.port && ip == id.ip
+    case _ => false
+  }
+}
+
+
+case class BlockException(blockId: String, message: String, ex: Exception = null) extends Exception(message)
+
+
+class BlockLocker(numLockers: Int) {
+  private val hashLocker = Array.fill(numLockers)(new Object())
+  
+  def getLock(blockId: String): Object = {
+    return hashLocker(Math.abs(blockId.hashCode % numLockers))
+  }
+}
+
+
+/**
+ * A start towards a block manager class. This will eventually be used for both RDD persistence
+ * and shuffle outputs.
+ *
+ * TODO: Should make the communication with Master or Peers code more robust and log friendly.
+ */
+class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging {
+  
+  private val NUM_LOCKS = 337
+  private val locker = new BlockLocker(NUM_LOCKS)
+
+  private val storageLevels = Collections.synchronizedMap(new JHashMap[String, StorageLevel])
+  
+  private val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
+  private val diskStore: BlockStore = new DiskStore(this, 
+    System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
+  
+  val connectionManager = new ConnectionManager(0)
+  
+  val connectionManagerId = connectionManager.id
+  val blockManagerId = new BlockManagerId(connectionManagerId.host, connectionManagerId.port)
+  
+  // TODO(Haoyuan): This will be removed after cacheTracker is removed from the code base.
+  var cacheTracker: CacheTracker = null
+
+  initLogging()
+
+  initialize()
+
+  /**
+   * Construct a BlockManager with a memory limit set based on system properties.
+   */
+  def this(serializer: Serializer) =
+    this(BlockManager.getMaxMemoryFromSystemProperties(), serializer)
+
+  /**
+   * Initialize the BlockManager. Register to the BlockManagerMaster, and start the
+   * BlockManagerWorker actor.
+   */
+  def initialize() {
+    BlockManagerMaster.mustRegisterBlockManager(
+      RegisterBlockManager(blockManagerId, maxMemory, maxMemory))
+    BlockManagerWorker.startBlockManagerWorker(this)
+  }
+ 
+  /**
+   * Get locations of the block.
+   */
+  def getLocations(blockId: String): Seq[String] = {
+    val startTimeMs = System.currentTimeMillis
+    var managers: Array[BlockManagerId] = BlockManagerMaster.mustGetLocations(GetLocations(blockId))
+    val locations = managers.map((manager: BlockManagerId) => { manager.ip }).toSeq
+    logDebug("Get block locations in " + Utils.getUsedTimeMs(startTimeMs))
+    return locations
+  }
+
+  /**
+   * Get locations of an array of blocks
+   */
+  def getLocationsMultipleBlockIds(blockIds: Array[String]): Array[Seq[String]] = {
+    val startTimeMs = System.currentTimeMillis
+    val locations = BlockManagerMaster.mustGetLocationsMultipleBlockIds(
+      GetLocationsMultipleBlockIds(blockIds)).map(_.map(_.ip).toSeq).toArray
+    logDebug("Get multiple block location in " + Utils.getUsedTimeMs(startTimeMs))
+    return locations
+  }
+
+  def getLocal(blockId: String): Option[Iterator[Any]] = {
+    logDebug("Getting block " + blockId)
+    locker.getLock(blockId).synchronized {
+    
+      // Check storage level of block 
+      val level = storageLevels.get(blockId)
+      if (level != null) {
+        logDebug("Level for block " + blockId + " is " + level + " on local machine")
+        
+        // Look for the block in memory
+        if (level.useMemory) {
+          logDebug("Getting block " + blockId + " from memory")
+          memoryStore.getValues(blockId) match {
+            case Some(iterator) => {
+              logDebug("Block " + blockId + " found in memory")
+              return Some(iterator)
+            }
+            case None => {
+              logDebug("Block " + blockId + " not found in memory")
+            }
+          }
+        } else {
+          logDebug("Not getting block " + blockId + " from memory")
+        }
+
+        // Look for block in disk 
+        if (level.useDisk) {
+          logDebug("Getting block " + blockId + " from disk")
+          diskStore.getValues(blockId) match {
+            case Some(iterator) => {
+              logDebug("Block " + blockId + " found in disk")
+              return Some(iterator)
+            }
+            case None => {
+              throw new Exception("Block " + blockId + " not found in disk")
+              return None
+            }
+          }
+        } else {
+          logDebug("Not getting block " + blockId + " from disk")
+        }
+
+      } else {
+        logDebug("Level for block " + blockId + " not found")
+      }
+    } 
+    return None 
+  }
+
+  def getRemote(blockId: String): Option[Iterator[Any]] = {
+    // Get locations of block
+    val locations = BlockManagerMaster.mustGetLocations(GetLocations(blockId))
+
+    // Get block from remote locations
+    for (loc <- locations) {
+      val data = BlockManagerWorker.syncGetBlock(
+          GetBlock(blockId), ConnectionManagerId(loc.ip, loc.port))
+      if (data != null) {
+        logDebug("Data is not null: " + data)
+        return Some(dataDeserialize(data))
+      }
+      logDebug("Data is null")
+    }
+    logDebug("Data not found")
+    return None
+  }
+
+  /**
+   * Read a block from the block manager.
+   */
+  def get(blockId: String): Option[Iterator[Any]] = {
+    getLocal(blockId).orElse(getRemote(blockId))
+  }
+
+  /**
+   * Read many blocks from block manager using their BlockManagerIds.
+   */
+  def get(blocksByAddress: Seq[(BlockManagerId, Seq[String])]): HashMap[String, Option[Iterator[Any]]] = {
+    logDebug("Getting " + blocksByAddress.map(_._2.size).sum + " blocks")
+    var startTime = System.currentTimeMillis
+    val blocks = new HashMap[String,Option[Iterator[Any]]]() 
+    val localBlockIds = new ArrayBuffer[String]()
+    val remoteBlockIds = new ArrayBuffer[String]()
+    val remoteBlockIdsPerLocation = new HashMap[BlockManagerId, Seq[String]]()
+
+    // Split local and remote blocks
+    for ((address, blockIds) <- blocksByAddress) {
+      if (address == blockManagerId) {
+        localBlockIds ++= blockIds
+      } else {
+        remoteBlockIds ++= blockIds
+        remoteBlockIdsPerLocation(address) = blockIds
+      }
+    }
+    
+    // Start getting remote blocks
+    val remoteBlockFutures = remoteBlockIdsPerLocation.toSeq.map { case (bmId, bIds) =>
+      val cmId = ConnectionManagerId(bmId.ip, bmId.port)
+      val blockMessages = bIds.map(bId => BlockMessage.fromGetBlock(GetBlock(bId)))
+      val blockMessageArray = new BlockMessageArray(blockMessages)
+      val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
+      (cmId, future)
+    }
+    logDebug("Started remote gets for " + remoteBlockIds.size + " blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+
+    // Get the local blocks while remote blocks are being fetched
+    startTime = System.currentTimeMillis
+    localBlockIds.foreach(id => {
+      get(id) match {
+        case Some(block) => {
+          blocks.update(id, Some(block))
+          logDebug("Got local block " + id)
+        }
+        case None => {
+          throw new BlockException(id, "Could not get block " + id + " from local machine")
+        }
+      }
+    }) 
+    logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+
+    // wait for and gather all the remote blocks
+    for ((cmId, future) <- remoteBlockFutures) {
+      var count = 0
+      val oneBlockId = remoteBlockIdsPerLocation(new BlockManagerId(cmId.host, cmId.port)).first
+      future() match {
+        case Some(message) => {
+          val bufferMessage = message.asInstanceOf[BufferMessage]
+          val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
+          blockMessageArray.foreach(blockMessage => {
+            if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
+              throw new BlockException(oneBlockId, "Unexpected message received from " + cmId)
+            }
+            val buffer = blockMessage.getData()
+            val blockId = blockMessage.getId()
+            val block = dataDeserialize(buffer)
+            blocks.update(blockId, Some(block))
+            logDebug("Got remote block " + blockId + " in " + Utils.getUsedTimeMs(startTime))
+            count += 1
+          })
+        }
+        case None => {
+          throw new BlockException(oneBlockId, "Could not get blocks from " + cmId)
+        }
+      }
+      logDebug("Got remote " + count + " blocks from " + cmId.host + " in " + Utils.getUsedTimeMs(startTime) + " ms")
+    }
+
+    logDebug("Got all blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+    return blocks
+  }
+
+  /**
+   * Write a new block to the block manager.
+   */
+  def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean = true) {
+    if (!level.useDisk && !level.useMemory) {
+      throw new IllegalArgumentException("Storage level has neither useMemory nor useDisk set")
+    }
+
+    val startTimeMs = System.currentTimeMillis 
+    var bytes: ByteBuffer = null
+    
+    locker.getLock(blockId).synchronized {
+      logDebug("Put for block " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs)
+        + " to get into synchronized block")
+      
+      // Check and warn if block with same id already exists 
+      if (storageLevels.get(blockId) != null) {
+        logWarning("Block " + blockId + " already exists in local machine")
+        return
+      }
+
+      // Store the storage level
+      storageLevels.put(blockId, level)
+      
+      if (level.useMemory && level.useDisk) {
+        // If saving to both memory and disk, then serialize only once 
+        memoryStore.putValues(blockId, values, level) match {
+          case Left(newValues) => 
+            diskStore.putValues(blockId, newValues, level) match {
+              case Right(newBytes) => bytes = newBytes
+              case _ => throw new Exception("Unexpected return value")
+            }
+          case Right(newBytes) =>
+            bytes = newBytes
+            diskStore.putBytes(blockId, newBytes, level)
+        }
+      } else if (level.useMemory) {
+        // If only save to memory 
+        memoryStore.putValues(blockId, values, level) match {
+          case Right(newBytes) => bytes = newBytes
+          case _ => 
+        }
+      } else {
+        // If only save to disk
+        diskStore.putValues(blockId, values, level) match {
+          case Right(newBytes) => bytes = newBytes
+          case _ => throw new Exception("Unexpected return value")
+        }
+      }
+        
+      if (tellMaster) {
+        notifyMaster(HeartBeat(blockManagerId, blockId, level, 0, 0))
+        logDebug("Put block " + blockId + " after notifying the master " + Utils.getUsedTimeMs(startTimeMs))
+      }
+    }
+
+    // Replicate block if required 
+    if (level.replication > 1) {
+      if (bytes == null) {
+        bytes = dataSerialize(values) // serialize the block if not already done
+      }
+      replicate(blockId, bytes, level) 
+    }
+
+    // TODO(Haoyuan): This code will be removed when CacheTracker is gone.
+    if (blockId.startsWith("rdd")) {
+      notifyTheCacheTracker(blockId)
+    }
+    logDebug("Put block " + blockId + " after notifying the CacheTracker " + Utils.getUsedTimeMs(startTimeMs))
+  }
+
+
+  def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) {
+    val startTime = System.currentTimeMillis 
+    if (!level.useDisk && !level.useMemory) {
+      throw new IllegalArgumentException("Storage level has neither useMemory nor useDisk set")
+    } else if (level.deserialized) {
+      throw new IllegalArgumentException("Storage level cannot have deserialized when putBytes is used")
+    }
+    val replicationFuture = if (level.replication > 1) {
+      future {
+        replicate(blockId, bytes, level)
+      }
+    } else {
+      null
+    }
+
+    locker.getLock(blockId).synchronized {
+      logDebug("PutBytes for block " + blockId + " used " + Utils.getUsedTimeMs(startTime)
+        + " to get into synchronized block")
+      if (storageLevels.get(blockId) != null) {
+        logWarning("Block " + blockId + " already exists")
+        return
+      }
+      storageLevels.put(blockId, level)
+
+      if (level.useMemory) {
+        memoryStore.putBytes(blockId, bytes, level)
+      }
+      if (level.useDisk) {
+        diskStore.putBytes(blockId, bytes, level)
+      }
+      if (tellMaster) {
+        notifyMaster(HeartBeat(blockManagerId, blockId, level, 0, 0))
+      }
+    }
+
+    if (blockId.startsWith("rdd")) {
+      notifyTheCacheTracker(blockId)
+    }
+    
+    if (level.replication > 1) {
+      if (replicationFuture == null) {
+        throw new Exception("Unexpected")
+      }
+      replicationFuture() 
+    }
+
+    val finishTime = System.currentTimeMillis
+    if (level.replication > 1) {
+      logDebug("PutBytes with replication took " + (finishTime - startTime) + " ms")
+    } else {
+      logDebug("PutBytes without replication took " + (finishTime - startTime) + " ms")
+    }
+
+  }
+
+  private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) {
+    val tLevel: StorageLevel =
+      new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
+    var peers: Array[BlockManagerId] = BlockManagerMaster.mustGetPeers(
+      GetPeers(blockManagerId, level.replication - 1))
+    for (peer: BlockManagerId <- peers) {
+      val start = System.nanoTime
+      logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is "
+        + data.array().length + " Bytes. To node: " + peer)
+      if (!BlockManagerWorker.syncPutBlock(PutBlock(blockId, data, tLevel),
+        new ConnectionManagerId(peer.ip, peer.port))) {
+        logError("Failed to call syncPutBlock to " + peer)
+      }
+      logDebug("Replicated BlockId " + blockId + " once used " +
+        (System.nanoTime - start) / 1e6 + " s; The size of the data is " +
+        data.array().length + " bytes.")
+    }
+  }
+
+  // TODO(Haoyuan): This code will be removed when CacheTracker is gone.
+  def notifyTheCacheTracker(key: String) {
+    val rddInfo = key.split(":")
+    val rddId: Int = rddInfo(1).toInt
+    val splitIndex: Int = rddInfo(2).toInt
+    val host = System.getProperty("spark.hostname", Utils.localHostName)
+    cacheTracker.notifyTheCacheTrackerFromBlockManager(spark.AddedToCache(rddId, splitIndex, host))
+  }
+
+  /**
+   * Read a block consisting of a single object.
+   */
+  def getSingle(blockId: String): Option[Any] = {
+    get(blockId).map(_.next)
+  }
+
+  /**
+   * Write a block consisting of a single object.
+   */
+  def putSingle(blockId: String, value: Any, level: StorageLevel) {
+    put(blockId, Iterator(value), level)
+  }
+
+  /**
+   * Drop block from memory (called when memory store has reached it limit)
+   */
+  def dropFromMemory(blockId: String) {
+    locker.getLock(blockId).synchronized {
+      val level = storageLevels.get(blockId)
+      if (level == null) {
+        logWarning("Block " + blockId + " cannot be removed from memory as it does not exist")
+        return
+      }
+      if (!level.useMemory) {
+        logWarning("Block " + blockId + " cannot be removed from memory as it is not in memory")
+        return
+      }
+      memoryStore.remove(blockId)  
+      if (!level.useDisk) {
+        storageLevels.remove(blockId) 
+      } else {
+        val newLevel = level.clone 
+        newLevel.useMemory = false
+        storageLevels.remove(blockId)
+        storageLevels.put(blockId, newLevel)
+      }
+    }
+  }
+
+  def dataSerialize(values: Iterator[Any]): ByteBuffer = {
+    /*serializer.newInstance().serializeMany(values)*/
+    val byteStream = new FastByteArrayOutputStream(4096)
+    serializer.newInstance().serializeStream(byteStream).writeAll(values).close()
+    byteStream.trim()
+    ByteBuffer.wrap(byteStream.array)
+  }
+
+  def dataDeserialize(bytes: ByteBuffer): Iterator[Any] = {
+    /*serializer.newInstance().deserializeMany(bytes)*/
+    val ser = serializer.newInstance()
+    return ser.deserializeStream(new FastByteArrayInputStream(bytes.array())).toIterator
+  }
+
+  private def notifyMaster(heartBeat: HeartBeat) {
+    BlockManagerMaster.mustHeartBeat(heartBeat)
+  }
+}
+
+object BlockManager extends Logging {
+  def getMaxMemoryFromSystemProperties(): Long = {
+    val memoryFraction = System.getProperty("spark.storage.memoryFraction", "0.66").toDouble
+    val bytes = (Runtime.getRuntime.totalMemory * memoryFraction).toLong
+    logInfo("Maximum memory to use: " + bytes)
+    bytes
+  }
+}
diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
new file mode 100644
index 0000000000000000000000000000000000000000..bd94c185e9a6287f9b3bf2dfe493611951438335
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
@@ -0,0 +1,516 @@
+package spark.storage
+
+import java.io._
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+import scala.util.Random
+
+import akka.actor._
+import akka.actor.Actor
+import akka.actor.Actor._
+import akka.util.duration._
+
+import spark.Logging
+import spark.Utils
+
+sealed trait ToBlockManagerMaster
+
+case class RegisterBlockManager(
+    blockManagerId: BlockManagerId,
+    maxMemSize: Long,
+    maxDiskSize: Long)
+  extends ToBlockManagerMaster
+  
+class HeartBeat(
+    var blockManagerId: BlockManagerId,
+    var blockId: String,
+    var storageLevel: StorageLevel,
+    var deserializedSize: Long,
+    var size: Long)
+  extends ToBlockManagerMaster
+  with Externalizable {
+
+  def this() = this(null, null, null, 0, 0)  // For deserialization only
+
+  override def writeExternal(out: ObjectOutput) {
+    blockManagerId.writeExternal(out)
+    out.writeUTF(blockId)
+    storageLevel.writeExternal(out)
+    out.writeInt(deserializedSize.toInt)
+    out.writeInt(size.toInt)
+  }
+
+  override def readExternal(in: ObjectInput) {
+    blockManagerId = new BlockManagerId()
+    blockManagerId.readExternal(in)
+    blockId = in.readUTF()
+    storageLevel = new StorageLevel()
+    storageLevel.readExternal(in)
+    deserializedSize = in.readInt()
+    size = in.readInt()
+  }
+}
+
+object HeartBeat {
+  def apply(blockManagerId: BlockManagerId,
+      blockId: String,
+      storageLevel: StorageLevel,
+      deserializedSize: Long,
+      size: Long): HeartBeat = {
+    new HeartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size)
+  }
+
+ 
+  // For pattern-matching
+  def unapply(h: HeartBeat): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = {
+    Some((h.blockManagerId, h.blockId, h.storageLevel, h.deserializedSize, h.size))
+  }
+}
+  
+case class GetLocations(
+    blockId: String)
+  extends ToBlockManagerMaster
+
+case class GetLocationsMultipleBlockIds(
+    blockIds: Array[String])
+  extends ToBlockManagerMaster
+  
+case class GetPeers(
+    blockManagerId: BlockManagerId,
+    size: Int)
+  extends ToBlockManagerMaster
+  
+case class RemoveHost(
+    host: String)
+  extends ToBlockManagerMaster
+
+class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging {
+  class BlockManagerInfo(
+      timeMs: Long,
+      maxMem: Long,
+      maxDisk: Long) {
+    private var lastSeenMs = timeMs
+    private var remainedMem = maxMem
+    private var remainedDisk = maxDisk
+    private val blocks = new HashMap[String, StorageLevel]
+    
+    def updateLastSeenMs() {
+      lastSeenMs = System.currentTimeMillis() / 1000
+    }
+    
+    def addBlock(blockId: String, storageLevel: StorageLevel, deserializedSize: Long, size: Long) =
+        synchronized {
+      updateLastSeenMs()
+      
+      if (blocks.contains(blockId)) {
+        val oriLevel: StorageLevel = blocks(blockId)
+        
+        if (oriLevel.deserialized) {
+          remainedMem += deserializedSize
+        }
+        if (oriLevel.useMemory) {
+          remainedMem += size
+        }
+        if (oriLevel.useDisk) {
+          remainedDisk += size
+        }
+      }
+
+      blocks += (blockId -> storageLevel)
+
+      if (storageLevel.deserialized) {
+        remainedMem -= deserializedSize
+      }
+      if (storageLevel.useMemory) {
+        remainedMem -= size
+      }
+      if (storageLevel.useDisk) {
+        remainedDisk -= size
+      }
+      
+      if (!(storageLevel.deserialized || storageLevel.useMemory || storageLevel.useDisk)) {
+        blocks.remove(blockId)
+      }
+    }
+
+    def getLastSeenMs(): Long = {
+      return lastSeenMs
+    }
+    
+    def getRemainedMem(): Long = {
+      return remainedMem
+    }
+    
+    def getRemainedDisk(): Long = {
+      return remainedDisk
+    }
+
+    override def toString(): String = {
+      return "BlockManagerInfo " + timeMs + " " + remainedMem + " " + remainedDisk  
+    }
+  }
+
+  private val blockManagerInfo = new HashMap[BlockManagerId, BlockManagerInfo]
+  private val blockIdMap = new HashMap[String, Pair[Int, HashSet[BlockManagerId]]]
+
+  initLogging()
+  
+  def removeHost(host: String) {
+    logInfo("Trying to remove the host: " + host + " from BlockManagerMaster.")
+    logInfo("Previous hosts: " + blockManagerInfo.keySet.toSeq)
+    val ip = host.split(":")(0)
+    val port = host.split(":")(1)
+    blockManagerInfo.remove(new BlockManagerId(ip, port.toInt))
+    logInfo("Current hosts: " + blockManagerInfo.keySet.toSeq)
+    self.reply(true)
+  }
+
+  def receive = {
+    case RegisterBlockManager(blockManagerId, maxMemSize, maxDiskSize) =>
+      register(blockManagerId, maxMemSize, maxDiskSize)
+
+    case HeartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size) =>
+      heartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size)
+
+    case GetLocations(blockId) =>
+      getLocations(blockId)
+
+    case GetLocationsMultipleBlockIds(blockIds) =>
+      getLocationsMultipleBlockIds(blockIds)
+
+    case GetPeers(blockManagerId, size) =>
+      getPeers_Deterministic(blockManagerId, size)
+      /*getPeers(blockManagerId, size)*/
+      
+    case RemoveHost(host) =>
+      removeHost(host)
+
+    case msg => 
+      logInfo("Got unknown msg: " + msg)
+  }
+  
+  private def register(blockManagerId: BlockManagerId, maxMemSize: Long, maxDiskSize: Long) {
+    val startTimeMs = System.currentTimeMillis()
+    val tmp = " " + blockManagerId + " "
+    logDebug("Got in register 0" + tmp + Utils.getUsedTimeMs(startTimeMs))
+    logInfo("Got Register Msg from " + blockManagerId)
+    if (blockManagerId.ip == Utils.localHostName() && !isLocal) {
+      logInfo("Got Register Msg from master node, don't register it")
+    } else {
+      blockManagerInfo += (blockManagerId -> new BlockManagerInfo(
+        System.currentTimeMillis() / 1000, maxMemSize, maxDiskSize))
+    }
+    logDebug("Got in register 1" + tmp + Utils.getUsedTimeMs(startTimeMs))
+    self.reply(true)
+  }
+  
+  private def heartBeat(
+      blockManagerId: BlockManagerId,
+      blockId: String,
+      storageLevel: StorageLevel,
+      deserializedSize: Long,
+      size: Long) {
+    
+    val startTimeMs = System.currentTimeMillis()
+    val tmp = " " + blockManagerId + " " + blockId + " "
+    logDebug("Got in heartBeat 0" + tmp + Utils.getUsedTimeMs(startTimeMs))
+    
+    if (blockId == null) {
+      blockManagerInfo(blockManagerId).updateLastSeenMs()
+      logDebug("Got in heartBeat 1" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs))
+      self.reply(true)
+    }
+    
+    blockManagerInfo(blockManagerId).addBlock(blockId, storageLevel, deserializedSize, size)
+    logDebug("Got in heartBeat 2" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs))
+    
+    var locations: HashSet[BlockManagerId] = null
+    if (blockIdMap.contains(blockId)) {
+      locations = blockIdMap(blockId)._2
+    } else {
+      locations = new HashSet[BlockManagerId]
+      blockIdMap += (blockId -> (storageLevel.replication, locations))
+    }
+    logDebug("Got in heartBeat 3" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs))
+    
+    if (storageLevel.deserialized || storageLevel.useDisk || storageLevel.useMemory) {
+      locations += blockManagerId
+    } else {
+      locations.remove(blockManagerId)
+    }
+    logDebug("Got in heartBeat 4" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs))
+    
+    if (locations.size == 0) {
+      blockIdMap.remove(blockId)
+    }
+    
+    logDebug("Got in heartBeat 5" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs))
+    self.reply(true)
+  }
+  
+  private def getLocations(blockId: String) {
+    val startTimeMs = System.currentTimeMillis()
+    val tmp = " " + blockId + " "
+    logDebug("Got in getLocations 0" + tmp + Utils.getUsedTimeMs(startTimeMs))
+    if (blockIdMap.contains(blockId)) {
+      var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+      res.appendAll(blockIdMap(blockId)._2)
+      logDebug("Got in getLocations 1" + tmp + " as "+ res.toSeq + " at " 
+          + Utils.getUsedTimeMs(startTimeMs))
+      self.reply(res.toSeq)
+    } else {
+      logDebug("Got in getLocations 2" + tmp + Utils.getUsedTimeMs(startTimeMs))
+      var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+      self.reply(res)
+    }
+  }
+  
+  private def getLocationsMultipleBlockIds(blockIds: Array[String]) {
+    def getLocations(blockId: String): Seq[BlockManagerId] = {
+      val tmp = blockId
+      logDebug("Got in getLocationsMultipleBlockIds Sub 0 " + tmp)
+      if (blockIdMap.contains(blockId)) {
+        var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+        res.appendAll(blockIdMap(blockId)._2)
+        logDebug("Got in getLocationsMultipleBlockIds Sub 1 " + tmp + " " + res.toSeq)
+        return res.toSeq
+      } else {
+        logDebug("Got in getLocationsMultipleBlockIds Sub 2 " + tmp)
+        var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+        return res.toSeq
+      }
+    }
+    
+    logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq)
+    var res: ArrayBuffer[Seq[BlockManagerId]] = new ArrayBuffer[Seq[BlockManagerId]]
+    for (blockId <- blockIds) {
+      res.append(getLocations(blockId))
+    }
+    logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq + " : " + res.toSeq)
+    self.reply(res.toSeq)
+  }
+
+  private def getPeers(blockManagerId: BlockManagerId, size: Int) {
+    val startTimeMs = System.currentTimeMillis()
+    val tmp = " " + blockManagerId + " "
+    logDebug("Got in getPeers 0" + tmp + Utils.getUsedTimeMs(startTimeMs))
+    var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
+    var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+    res.appendAll(peers)
+    res -= blockManagerId
+    val rand = new Random(System.currentTimeMillis())
+    logDebug("Got in getPeers 1" + tmp + Utils.getUsedTimeMs(startTimeMs))
+    while (res.length > size) {
+      res.remove(rand.nextInt(res.length))
+    }
+    logDebug("Got in getPeers 2" + tmp + Utils.getUsedTimeMs(startTimeMs))
+    self.reply(res.toSeq)
+  }
+  
+  private def getPeers_Deterministic(blockManagerId: BlockManagerId, size: Int) {
+    val startTimeMs = System.currentTimeMillis()
+    var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
+    var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+
+    val peersWithIndices = peers.zipWithIndex
+    val selfIndex = peersWithIndices.find(_._1 == blockManagerId).map(_._2).getOrElse(-1)
+    if (selfIndex == -1) {
+      throw new Exception("Self index for " + blockManagerId + " not found")
+    }
+
+    var index = selfIndex
+    while (res.size < size) {
+      index += 1
+      if (index == selfIndex) {
+        throw new Exception("More peer expected than available")
+      }
+      res += peers(index % peers.size)
+    }
+    val resStr = res.map(_.toString).reduceLeft(_ + ", " + _)
+    logDebug("Got peers for " + blockManagerId + " as [" + resStr + "]")
+    self.reply(res.toSeq)
+  }
+}
+
+object BlockManagerMaster extends Logging {
+  initLogging()
+
+  val AKKA_ACTOR_NAME: String = "BlockMasterManager"
+  val REQUEST_RETRY_INTERVAL_MS = 100
+  val DEFAULT_MASTER_IP: String = System.getProperty("spark.master.host", "localhost")
+  val DEFAULT_MASTER_PORT: Int = System.getProperty("spark.master.port", "7077").toInt
+  val DEFAULT_MANAGER_IP: String = Utils.localHostName()
+  val DEFAULT_MANAGER_PORT: String = "10902"
+
+  implicit val TIME_OUT_SEC = Actor.Timeout(3000 millis)
+  var masterActor: ActorRef = null
+
+  def startBlockManagerMaster(isMaster: Boolean, isLocal: Boolean) {
+    if (isMaster) {
+      masterActor = actorOf(new BlockManagerMaster(isLocal))
+      remote.register(AKKA_ACTOR_NAME, masterActor)
+      logInfo("Registered BlockManagerMaster Actor: " + DEFAULT_MASTER_IP + ":" + DEFAULT_MASTER_PORT)
+      masterActor.start()
+    } else {
+      masterActor = remote.actorFor(AKKA_ACTOR_NAME, DEFAULT_MASTER_IP, DEFAULT_MASTER_PORT)
+    }
+  }
+  
+  def notifyADeadHost(host: String) {
+    (masterActor ? RemoveHost(host + ":" + DEFAULT_MANAGER_PORT)).as[Any] match {
+      case Some(true) =>
+        logInfo("Removed " + host + " successfully. @ notifyADeadHost")
+      case Some(oops) =>
+        logError("Failed @ notifyADeadHost: " + oops)
+      case None =>
+        logError("None @ notifyADeadHost.")
+    }
+  }
+
+  def mustRegisterBlockManager(msg: RegisterBlockManager) {
+    while (! syncRegisterBlockManager(msg)) {
+      logWarning("Failed to register " + msg)
+      Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
+    }
+  }
+
+  def syncRegisterBlockManager(msg: RegisterBlockManager): Boolean = {
+    //val masterActor = RemoteActor.select(node, name)
+    val startTimeMs = System.currentTimeMillis()
+    val tmp = " msg " + msg + " "
+    logDebug("Got in syncRegisterBlockManager 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+    
+    (masterActor ? msg).as[Any] match {
+      case Some(true) => 
+        logInfo("BlockManager registered successfully @ syncRegisterBlockManager.")
+        logDebug("Got in syncRegisterBlockManager 1 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+        return true
+      case Some(oops) =>
+        logError("Failed @ syncRegisterBlockManager: " + oops)
+        return false
+      case None =>
+        logError("None @ syncRegisterBlockManager.")
+        return false
+    }
+  }
+  
+  def mustHeartBeat(msg: HeartBeat) {
+    while (! syncHeartBeat(msg)) {
+      logWarning("Failed to send heartbeat" + msg)
+      Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
+    }
+  }
+  
+  def syncHeartBeat(msg: HeartBeat): Boolean = {
+    val startTimeMs = System.currentTimeMillis()
+    val tmp = " msg " + msg + " "
+    logDebug("Got in syncHeartBeat " + tmp + " 0 " + Utils.getUsedTimeMs(startTimeMs))
+    
+    (masterActor ? msg).as[Any] match {
+      case Some(true) =>
+        logInfo("Heartbeat sent successfully.")
+        logDebug("Got in syncHeartBeat " + tmp + " 1 " + Utils.getUsedTimeMs(startTimeMs))
+        return true
+      case Some(oops) =>
+        logError("Failed: " + oops)
+        return false
+      case None => 
+        logError("None.")
+        return false
+    }
+  }
+  
+  def mustGetLocations(msg: GetLocations): Array[BlockManagerId] = {
+    var res: Array[BlockManagerId] = syncGetLocations(msg)
+    while (res == null) {
+      logInfo("Failed to get locations " + msg)
+      Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
+      res = syncGetLocations(msg)
+    }
+    return res
+  }
+  
+  def syncGetLocations(msg: GetLocations): Array[BlockManagerId] = {
+    val startTimeMs = System.currentTimeMillis()
+    val tmp = " msg " + msg + " "
+    logDebug("Got in syncGetLocations 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+    
+    (masterActor ? msg).as[Seq[BlockManagerId]] match {
+      case Some(arr) =>
+        logDebug("GetLocations successfully.")
+        logDebug("Got in syncGetLocations 1 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+        val res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+        for (ele <- arr) {
+          res += ele
+        }
+        logDebug("Got in syncGetLocations 2 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+        return res.toArray
+      case None => 
+        logError("GetLocations call returned None.")
+        return null
+    }
+  }
+
+  def mustGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds):
+       Seq[Seq[BlockManagerId]] = {
+    var res: Seq[Seq[BlockManagerId]] = syncGetLocationsMultipleBlockIds(msg)
+    while (res == null) {
+      logWarning("Failed to GetLocationsMultipleBlockIds " + msg)
+      Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
+      res = syncGetLocationsMultipleBlockIds(msg)
+    }
+    return res
+  }
+  
+  def syncGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds):
+      Seq[Seq[BlockManagerId]] = {
+    val startTimeMs = System.currentTimeMillis
+    val tmp = " msg " + msg + " "
+    logDebug("Got in syncGetLocationsMultipleBlockIds 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+    
+    (masterActor ? msg).as[Any] match {
+      case Some(arr: Seq[Seq[BlockManagerId]]) =>
+        logDebug("GetLocationsMultipleBlockIds successfully: " + arr)
+        logDebug("Got in syncGetLocationsMultipleBlockIds 1 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+        return arr
+      case Some(oops) =>
+        logError("Failed: " + oops)
+        return null
+      case None => 
+        logInfo("None.")
+        return null
+    }
+  }
+  
+  def mustGetPeers(msg: GetPeers): Array[BlockManagerId] = {
+    var res: Array[BlockManagerId] = syncGetPeers(msg)
+    while ((res == null) || (res.length != msg.size)) {
+      logInfo("Failed to get peers " + msg)
+      Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
+      res = syncGetPeers(msg)
+    }
+    
+    return res
+  }
+  
+  def syncGetPeers(msg: GetPeers): Array[BlockManagerId] = {
+    val startTimeMs = System.currentTimeMillis
+    val tmp = " msg " + msg + " "
+    logDebug("Got in syncGetPeers 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+    
+    (masterActor ? msg).as[Seq[BlockManagerId]] match {
+      case Some(arr) =>
+        logDebug("Got in syncGetPeers 1 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+        val res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+        logInfo("GetPeers successfully: " + arr.length)
+        res.appendAll(arr)
+        logDebug("Got in syncGetPeers 2 " + tmp + Utils.getUsedTimeMs(startTimeMs))
+        return res.toArray
+      case None => 
+        logError("GetPeers call returned None.")
+        return null
+    }
+  }
+}
diff --git a/core/src/main/scala/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/spark/storage/BlockManagerWorker.scala
new file mode 100644
index 0000000000000000000000000000000000000000..a4cdbd8ddd3aa305263a7792022a973b265b86aa
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockManagerWorker.scala
@@ -0,0 +1,142 @@
+package spark.storage
+
+import java.nio._
+
+import scala.actors._
+import scala.actors.Actor._
+import scala.actors.remote._
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+import scala.util.Random
+
+import spark.Logging
+import spark.Utils
+import spark.SparkEnv
+import spark.network._
+
+/**
+ * This should be changed to use event model late. 
+ */
+class BlockManagerWorker(val blockManager: BlockManager) extends Logging {
+  initLogging()
+  
+  blockManager.connectionManager.onReceiveMessage(onBlockMessageReceive)
+
+  def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = {
+    logDebug("Handling message " + msg)
+    msg match {
+      case bufferMessage: BufferMessage => {
+        try {
+          logDebug("Handling as a buffer message " + bufferMessage)
+          val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage)
+          logDebug("Parsed as a block message array")
+          val responseMessages = blockMessages.map(processBlockMessage _).filter(_ != None).map(_.get)
+          /*logDebug("Processed block messages")*/
+          return Some(new BlockMessageArray(responseMessages).toBufferMessage)
+        } catch {
+          case e: Exception => logError("Exception handling buffer message: " + e.getMessage)
+          return None
+        }
+      }
+      case otherMessage: Any => {
+        logError("Unknown type message received: " + otherMessage)
+        return None
+      }
+    }
+  }
+
+  def processBlockMessage(blockMessage: BlockMessage): Option[BlockMessage] = {
+    blockMessage.getType() match {
+      case BlockMessage.TYPE_PUT_BLOCK => {
+        val pB = PutBlock(blockMessage.getId(), blockMessage.getData(), blockMessage.getLevel())
+        logInfo("Received [" + pB + "]")
+        putBlock(pB.id, pB.data, pB.level)
+        return None
+      } 
+      case BlockMessage.TYPE_GET_BLOCK => {
+        val gB = new GetBlock(blockMessage.getId())
+        logInfo("Received [" + gB + "]")
+        val buffer = getBlock(gB.id)
+        if (buffer == null) {
+          return None
+        }
+        return Some(BlockMessage.fromGotBlock(GotBlock(gB.id, buffer)))
+      }
+      case _ => return None
+    }
+  }
+
+  private def putBlock(id: String, bytes: ByteBuffer, level: StorageLevel) {
+    val startTimeMs = System.currentTimeMillis()
+    logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes)
+    blockManager.putBytes(id, bytes, level)
+    logDebug("PutBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs)
+        + " with data size: " + bytes.array().length)
+  }
+
+  private def getBlock(id: String): ByteBuffer = {
+    val startTimeMs = System.currentTimeMillis()
+    logDebug("Getblock " + id + " started from " + startTimeMs)
+    val block = blockManager.get(id)
+    val buffer = block match {
+      case Some(tValues) => {
+        val values = tValues.asInstanceOf[Iterator[Any]]
+        val buffer = blockManager.dataSerialize(values)
+        buffer
+      }
+      case None => { 
+        null
+      }
+    }
+    logDebug("GetBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs)
+        + " and got buffer " + buffer)
+    return buffer
+  }
+}
+
+object BlockManagerWorker extends Logging {
+  private var blockManagerWorker: BlockManagerWorker = null
+  private val DATA_TRANSFER_TIME_OUT_MS: Long = 500
+  private val REQUEST_RETRY_INTERVAL_MS: Long = 1000
+  
+  initLogging()
+  
+  def startBlockManagerWorker(manager: BlockManager) {
+    blockManagerWorker = new BlockManagerWorker(manager)
+  }
+  
+  def syncPutBlock(msg: PutBlock, toConnManagerId: ConnectionManagerId): Boolean = {
+    val blockManager = blockManagerWorker.blockManager
+    val connectionManager = blockManager.connectionManager 
+    val serializer = blockManager.serializer
+    val blockMessage = BlockMessage.fromPutBlock(msg)
+    val blockMessageArray = new BlockMessageArray(blockMessage)
+    val resultMessage = connectionManager.sendMessageReliablySync(
+        toConnManagerId, blockMessageArray.toBufferMessage())
+    return (resultMessage != None)
+  }
+  
+  def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = {
+    val blockManager = blockManagerWorker.blockManager
+    val connectionManager = blockManager.connectionManager 
+    val serializer = blockManager.serializer
+    val blockMessage = BlockMessage.fromGetBlock(msg)
+    val blockMessageArray = new BlockMessageArray(blockMessage)
+    val responseMessage = connectionManager.sendMessageReliablySync(
+        toConnManagerId, blockMessageArray.toBufferMessage())
+    responseMessage match {
+      case Some(message) => {
+        val bufferMessage = message.asInstanceOf[BufferMessage]
+        logDebug("Response message received " + bufferMessage)
+        BlockMessageArray.fromBufferMessage(bufferMessage).foreach(blockMessage => {
+            logDebug("Found " + blockMessage)
+            return blockMessage.getData
+          })
+      }
+      case None => logDebug("No response message received"); return null
+    }
+    return null
+  }
+}
diff --git a/core/src/main/scala/spark/storage/BlockMessage.scala b/core/src/main/scala/spark/storage/BlockMessage.scala
new file mode 100644
index 0000000000000000000000000000000000000000..bb128dce7a6b8ad45c476c59d87ccf17c77ab667
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockMessage.scala
@@ -0,0 +1,219 @@
+package spark.storage
+
+import java.nio._
+
+import scala.collection.mutable.StringBuilder
+import scala.collection.mutable.ArrayBuffer
+
+import spark._
+import spark.network._
+
+case class GetBlock(id: String)
+case class GotBlock(id: String, data: ByteBuffer)
+case class PutBlock(id: String, data: ByteBuffer, level: StorageLevel) 
+
+class BlockMessage() extends Logging{
+  // Un-initialized: typ = 0
+  // GetBlock: typ = 1
+  // GotBlock: typ = 2
+  // PutBlock: typ = 3
+  private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED
+  private var id: String = null
+  private var data: ByteBuffer = null
+  private var level: StorageLevel = null
+ 
+  initLogging()
+
+  def set(getBlock: GetBlock) {
+    typ = BlockMessage.TYPE_GET_BLOCK
+    id = getBlock.id
+  }
+
+  def set(gotBlock: GotBlock) {
+    typ = BlockMessage.TYPE_GOT_BLOCK
+    id = gotBlock.id
+    data = gotBlock.data
+  }
+
+  def set(putBlock: PutBlock) {
+    typ = BlockMessage.TYPE_PUT_BLOCK
+    id = putBlock.id
+    data = putBlock.data
+    level = putBlock.level
+  }
+
+  def set(buffer: ByteBuffer) {
+    val startTime = System.currentTimeMillis
+    /*
+    println()
+    println("BlockMessage: ")
+    while(buffer.remaining > 0) {
+      print(buffer.get())
+    }
+    buffer.rewind()
+    println()
+    println()
+    */
+    typ = buffer.getInt()
+    val idLength = buffer.getInt()
+    val idBuilder = new StringBuilder(idLength)
+    for (i <- 1 to idLength) {
+      idBuilder += buffer.getChar()
+    }
+    id = idBuilder.toString()
+    
+    logDebug("Set from buffer Result: " + typ + " " + id)
+    logDebug("Buffer position is " + buffer.position)
+    if (typ == BlockMessage.TYPE_PUT_BLOCK) {
+
+      val booleanInt = buffer.getInt()
+      val replication = buffer.getInt()
+      level = new StorageLevel(booleanInt, replication)
+      
+      val dataLength = buffer.getInt()
+      data = ByteBuffer.allocate(dataLength)
+      if (dataLength != buffer.remaining) {
+        throw new Exception("Error parsing buffer")
+      }
+      data.put(buffer)
+      data.flip()
+      logDebug("Set from buffer Result 2: " + level + " " + data)
+    } else if (typ == BlockMessage.TYPE_GOT_BLOCK) {
+
+      val dataLength = buffer.getInt()
+      logDebug("Data length is "+ dataLength)
+      logDebug("Buffer position is " + buffer.position)
+      data = ByteBuffer.allocate(dataLength)
+      if (dataLength != buffer.remaining) {
+        throw new Exception("Error parsing buffer")
+      }
+      data.put(buffer)
+      data.flip()
+      logDebug("Set from buffer Result 3: " + data)
+    }
+
+    val finishTime = System.currentTimeMillis
+    logDebug("Converted " + id + " from bytebuffer in " + (finishTime - startTime) / 1000.0  + " s")
+  }
+
+  def set(bufferMsg: BufferMessage) {
+    val buffer = bufferMsg.buffers.apply(0)
+    buffer.clear()
+    set(buffer)
+  }
+  
+  def getType(): Int = {
+    return typ
+  }
+  
+  def getId(): String = {
+    return id
+  }
+  
+  def getData(): ByteBuffer = {
+    return data
+  }
+  
+  def getLevel(): StorageLevel = {
+    return level
+  }
+  
+  def toBufferMessage(): BufferMessage = {
+    val startTime = System.currentTimeMillis
+    val buffers = new ArrayBuffer[ByteBuffer]()
+    var buffer = ByteBuffer.allocate(4 + 4 + id.length() * 2)
+    buffer.putInt(typ).putInt(id.length())
+    id.foreach((x: Char) => buffer.putChar(x))
+    buffer.flip()
+    buffers += buffer
+
+    if (typ == BlockMessage.TYPE_PUT_BLOCK) {
+      buffer = ByteBuffer.allocate(8).putInt(level.toInt()).putInt(level.replication)
+      buffer.flip()
+      buffers += buffer
+      
+      buffer = ByteBuffer.allocate(4).putInt(data.remaining)
+      buffer.flip()
+      buffers += buffer
+
+      buffers += data
+    } else if (typ == BlockMessage.TYPE_GOT_BLOCK) {
+      buffer = ByteBuffer.allocate(4).putInt(data.remaining)
+      buffer.flip()
+      buffers += buffer
+
+      buffers += data
+    }
+    
+    logDebug("Start to log buffers.")
+    buffers.foreach((x: ByteBuffer) => logDebug("" + x))
+    /*
+    println()
+    println("BlockMessage: ")
+    buffers.foreach(b => {
+      while(b.remaining > 0) {
+        print(b.get())
+      }
+      b.rewind()
+    })
+    println()
+    println()
+    */
+    val finishTime = System.currentTimeMillis
+    logDebug("Converted " + id + " to buffer message in " + (finishTime - startTime) / 1000.0  + " s")
+    return Message.createBufferMessage(buffers)
+  }
+
+  override def toString(): String = { 
+    "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level + 
+    ", data = " + (if (data != null) data.remaining.toString  else "null") + "]"
+  }
+}
+
+object BlockMessage {
+  val TYPE_NON_INITIALIZED: Int = 0
+  val TYPE_GET_BLOCK: Int = 1
+  val TYPE_GOT_BLOCK: Int = 2
+  val TYPE_PUT_BLOCK: Int = 3
+ 
+  def fromBufferMessage(bufferMessage: BufferMessage): BlockMessage = {
+    val newBlockMessage = new BlockMessage()
+    newBlockMessage.set(bufferMessage)
+    newBlockMessage
+  }
+
+  def fromByteBuffer(buffer: ByteBuffer): BlockMessage = {
+    val newBlockMessage = new BlockMessage()
+    newBlockMessage.set(buffer)
+    newBlockMessage
+  }
+
+  def fromGetBlock(getBlock: GetBlock): BlockMessage = {
+    val newBlockMessage = new BlockMessage()
+    newBlockMessage.set(getBlock)
+    newBlockMessage
+  }
+
+  def fromGotBlock(gotBlock: GotBlock): BlockMessage = {
+    val newBlockMessage = new BlockMessage()
+    newBlockMessage.set(gotBlock)
+    newBlockMessage
+  }
+  
+  def fromPutBlock(putBlock: PutBlock): BlockMessage = {
+    val newBlockMessage = new BlockMessage()
+    newBlockMessage.set(putBlock)
+    newBlockMessage
+  }
+
+  def main(args: Array[String]) {
+    val B = new BlockMessage()
+    B.set(new PutBlock("ABC", ByteBuffer.allocate(10), StorageLevel.DISK_AND_MEMORY_2))
+    val bMsg = B.toBufferMessage()
+    val C = new BlockMessage()
+    C.set(bMsg)
+    
+    println(B.getId() + " " + B.getLevel())
+    println(C.getId() + " " + C.getLevel())
+  }
+}
diff --git a/core/src/main/scala/spark/storage/BlockMessageArray.scala b/core/src/main/scala/spark/storage/BlockMessageArray.scala
new file mode 100644
index 0000000000000000000000000000000000000000..5f411d34884e12871405b12b24dcb0765af01427
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockMessageArray.scala
@@ -0,0 +1,140 @@
+package spark.storage
+import java.nio._
+
+import scala.collection.mutable.StringBuilder
+import scala.collection.mutable.ArrayBuffer
+
+import spark._
+import spark.network._
+
+class BlockMessageArray(var blockMessages: Seq[BlockMessage]) extends Seq[BlockMessage] with Logging {
+  
+  def this(bm: BlockMessage) = this(Array(bm))
+
+  def this() = this(null.asInstanceOf[Seq[BlockMessage]])
+
+  def apply(i: Int) = blockMessages(i) 
+
+  def iterator = blockMessages.iterator
+
+  def length = blockMessages.length 
+
+  initLogging()
+  
+  def set(bufferMessage: BufferMessage) {
+    val startTime = System.currentTimeMillis
+    val newBlockMessages = new ArrayBuffer[BlockMessage]()
+    val buffer = bufferMessage.buffers(0)
+    buffer.clear()
+    /*
+    println()
+    println("BlockMessageArray: ")
+    while(buffer.remaining > 0) {
+      print(buffer.get())
+    }
+    buffer.rewind()
+    println()
+    println()
+    */
+    while(buffer.remaining() > 0) {
+      val size = buffer.getInt()
+      logDebug("Creating block message of size " + size + " bytes")
+      val newBuffer = buffer.slice()
+      newBuffer.clear()
+      newBuffer.limit(size)
+      logDebug("Trying to convert buffer " + newBuffer + " to block message")
+      val newBlockMessage = BlockMessage.fromByteBuffer(newBuffer)
+      logDebug("Created " + newBlockMessage)
+      newBlockMessages += newBlockMessage 
+      buffer.position(buffer.position() + size)
+    }
+    val finishTime = System.currentTimeMillis
+    logDebug("Converted block message array from buffer message in " + (finishTime - startTime) / 1000.0  + " s")
+    this.blockMessages = newBlockMessages 
+  }
+  
+  def toBufferMessage(): BufferMessage = {
+    val buffers = new ArrayBuffer[ByteBuffer]()
+
+    blockMessages.foreach(blockMessage => {
+      val bufferMessage = blockMessage.toBufferMessage
+      logDebug("Adding " + blockMessage)
+      val sizeBuffer = ByteBuffer.allocate(4).putInt(bufferMessage.size)
+      sizeBuffer.flip
+      buffers += sizeBuffer
+      buffers ++= bufferMessage.buffers
+      logDebug("Added " + bufferMessage)
+    })
+   
+    logDebug("Buffer list:")
+    buffers.foreach((x: ByteBuffer) => logDebug("" + x))
+    /*
+    println()
+    println("BlockMessageArray: ")
+    buffers.foreach(b => {
+      while(b.remaining > 0) {
+        print(b.get())
+      }
+      b.rewind()
+    })
+    println()
+    println()
+    */
+    return Message.createBufferMessage(buffers)
+  }
+}
+
+object BlockMessageArray {
+ 
+  def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = {
+    val newBlockMessageArray = new BlockMessageArray()
+    newBlockMessageArray.set(bufferMessage)
+    newBlockMessageArray
+  }
+  
+  def main(args: Array[String]) {
+    val blockMessages = 
+      (0 until 10).map(i => {
+        if (i % 2 == 0) {
+          val buffer =  ByteBuffer.allocate(100)
+          buffer.clear
+          BlockMessage.fromPutBlock(PutBlock(i.toString, buffer, StorageLevel.MEMORY_ONLY))
+        } else {
+          BlockMessage.fromGetBlock(GetBlock(i.toString))
+        }
+      })
+    val blockMessageArray = new BlockMessageArray(blockMessages)
+    println("Block message array created")
+    
+    val bufferMessage = blockMessageArray.toBufferMessage
+    println("Converted to buffer message")
+    
+    val totalSize = bufferMessage.size
+    val newBuffer = ByteBuffer.allocate(totalSize)
+    newBuffer.clear()
+    bufferMessage.buffers.foreach(buffer => {
+      newBuffer.put(buffer)
+      buffer.rewind()
+    })
+    newBuffer.flip
+    val newBufferMessage = Message.createBufferMessage(newBuffer) 
+    println("Copied to new buffer message, size = " + newBufferMessage.size)
+
+    val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage)
+    println("Converted back to block message array")
+    newBlockMessageArray.foreach(blockMessage => {
+      blockMessage.getType() match {
+        case BlockMessage.TYPE_PUT_BLOCK => {
+          val pB = PutBlock(blockMessage.getId(), blockMessage.getData(), blockMessage.getLevel())
+          println(pB)
+        } 
+        case BlockMessage.TYPE_GET_BLOCK => {
+          val gB = new GetBlock(blockMessage.getId())
+          println(gB)
+        }
+      }
+    })
+  }
+}
+
+
diff --git a/core/src/main/scala/spark/storage/BlockStore.scala b/core/src/main/scala/spark/storage/BlockStore.scala
new file mode 100644
index 0000000000000000000000000000000000000000..0584cc2d4f3992db7b43f445342f0c59c7eed835
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockStore.scala
@@ -0,0 +1,282 @@
+package spark.storage
+
+import spark.{Utils, Logging, Serializer, SizeEstimator}
+
+import scala.collection.mutable.ArrayBuffer
+
+import java.io.{File, RandomAccessFile}
+import java.nio.ByteBuffer
+import java.nio.channels.FileChannel.MapMode
+import java.util.{UUID, LinkedHashMap}
+import java.util.concurrent.Executors
+
+import it.unimi.dsi.fastutil.io._
+
+/**
+ * Abstract class to store blocks
+ */
+abstract class BlockStore(blockManager: BlockManager) extends Logging {
+  initLogging()
+
+  def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) 
+
+  def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer]
+
+  def getBytes(blockId: String): Option[ByteBuffer]
+
+  def getValues(blockId: String): Option[Iterator[Any]]
+
+  def remove(blockId: String)
+
+  def dataSerialize(values: Iterator[Any]): ByteBuffer = blockManager.dataSerialize(values)
+
+  def dataDeserialize(bytes: ByteBuffer): Iterator[Any] = blockManager.dataDeserialize(bytes)
+}
+
+/**
+ * Class to store blocks in memory 
+ */
+class MemoryStore(blockManager: BlockManager, maxMemory: Long) 
+  extends BlockStore(blockManager) {
+
+  class Entry(var value: Any, val size: Long, val deserialized: Boolean)
+  
+  private val memoryStore = new LinkedHashMap[String, Entry](32, 0.75f, true)
+  private var currentMemory = 0L
+ 
+  private val blockDropper = Executors.newSingleThreadExecutor() 
+
+  def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
+    if (level.deserialized) {
+      bytes.rewind()
+      val values = dataDeserialize(bytes)
+      val elements = new ArrayBuffer[Any]
+      elements ++= values
+      val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef])
+      ensureFreeSpace(sizeEstimate)
+      val entry = new Entry(elements, sizeEstimate, true)
+      memoryStore.synchronized { memoryStore.put(blockId, entry) }
+      currentMemory += sizeEstimate
+      logDebug("Block " + blockId + " stored as values to memory")
+    } else {
+      val entry = new Entry(bytes, bytes.array().length, false)
+      ensureFreeSpace(bytes.array.length)
+      memoryStore.synchronized { memoryStore.put(blockId, entry) }
+      currentMemory += bytes.array().length
+      logDebug("Block " + blockId + " stored as " + bytes.array().length + " bytes to memory")
+    }
+  }
+
+  def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer] = {
+    if (level.deserialized) {
+      val elements = new ArrayBuffer[Any]
+      elements ++= values
+      val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef])
+      ensureFreeSpace(sizeEstimate)
+      val entry = new Entry(elements, sizeEstimate, true)
+      memoryStore.synchronized { memoryStore.put(blockId, entry) }
+      currentMemory += sizeEstimate
+      logDebug("Block " + blockId + " stored as values to memory")
+      return Left(elements.iterator) 
+    } else {
+      val bytes = dataSerialize(values)
+      ensureFreeSpace(bytes.array().length)
+      val entry = new Entry(bytes, bytes.array().length, false)
+      memoryStore.synchronized { memoryStore.put(blockId, entry) } 
+      currentMemory += bytes.array().length
+      logDebug("Block " + blockId + " stored as " + bytes.array.length + " bytes to memory")
+      return Right(bytes)
+    }
+  }
+
+  def getBytes(blockId: String): Option[ByteBuffer] = {
+    throw new UnsupportedOperationException("Not implemented") 
+  }
+
+  def getValues(blockId: String): Option[Iterator[Any]] = {
+    val entry = memoryStore.synchronized { memoryStore.get(blockId) }
+    if (entry == null) {
+      return None 
+    }
+    if (entry.deserialized) {
+      return Some(entry.value.asInstanceOf[ArrayBuffer[Any]].toIterator)
+    } else {
+      return Some(dataDeserialize(entry.value.asInstanceOf[ByteBuffer])) 
+    }
+  }
+
+  def remove(blockId: String) {
+    memoryStore.synchronized {
+      val entry = memoryStore.get(blockId) 
+      if (entry != null) {
+        memoryStore.remove(blockId)
+        currentMemory -= entry.size
+        logDebug("Block " + blockId + " of size " + entry.size + " dropped from memory")
+      } else {
+        logWarning("Block " + blockId + " could not be removed as it doesnt exist")
+      }
+    }
+  }
+
+  private def drop(blockId: String) {
+    blockDropper.submit(new Runnable() {
+      def run() {
+        blockManager.dropFromMemory(blockId)
+      }
+    })
+  }
+
+  private def ensureFreeSpace(space: Long) {
+    logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format(
+      space, currentMemory, maxMemory))
+    
+    val droppedBlockIds = new ArrayBuffer[String]()
+    var droppedMemory = 0L
+    
+    memoryStore.synchronized {
+      val iter = memoryStore.entrySet().iterator()
+      while (maxMemory - (currentMemory - droppedMemory) < space && iter.hasNext) {
+        val pair = iter.next()
+        val blockId = pair.getKey
+        droppedBlockIds += blockId
+        droppedMemory += pair.getValue.size
+        logDebug("Decided to drop " + blockId)
+      }
+    }  
+    
+    for (blockId <- droppedBlockIds) {
+      drop(blockId)
+    }
+
+    droppedBlockIds.clear
+  }
+}
+
+
+/**
+ * Class to store blocks in disk 
+ */
+class DiskStore(blockManager: BlockManager, rootDirs: String) 
+  extends BlockStore(blockManager) {
+
+  val MAX_DIR_CREATION_ATTEMPTS: Int = 10
+  val localDirs = createLocalDirs()
+  var lastLocalDirUsed = 0
+
+  addShutdownHook()
+  
+  def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
+    logDebug("Attempting to put block " + blockId)
+    val startTime = System.currentTimeMillis
+    val file = createFile(blockId)
+    if (file != null) {
+      val channel = new RandomAccessFile(file, "rw").getChannel()
+      val buffer = channel.map(MapMode.READ_WRITE, 0, bytes.array.length)
+      buffer.put(bytes.array)
+      channel.close()
+      val finishTime = System.currentTimeMillis
+      logDebug("Block " + blockId + " stored to file of " + bytes.array.length + " bytes to disk in " + (finishTime - startTime) + " ms")
+    } else {
+      logError("File not created for block " + blockId)
+    }
+  }
+
+  def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer] = {
+    val bytes = dataSerialize(values) 
+    logDebug("Converted block " + blockId + " to " + bytes.array.length + " bytes")
+    putBytes(blockId, bytes, level)
+    return Right(bytes)
+  }
+
+  def getBytes(blockId: String): Option[ByteBuffer] = {
+    val file = getFile(blockId) 
+    val length = file.length().toInt
+    val channel = new RandomAccessFile(file, "r").getChannel()
+    val bytes = ByteBuffer.allocate(length)
+    bytes.put(channel.map(MapMode.READ_WRITE, 0, length))
+    return Some(bytes)  
+  }
+
+  def getValues(blockId: String): Option[Iterator[Any]] = {
+    val file = getFile(blockId) 
+    val length = file.length().toInt
+    val channel = new RandomAccessFile(file, "r").getChannel()
+    val bytes = channel.map(MapMode.READ_ONLY, 0, length)
+    val buffer = dataDeserialize(bytes)
+    channel.close()
+    return Some(buffer) 
+  }
+
+  def remove(blockId: String) {
+    throw new UnsupportedOperationException("Not implemented") 
+  }
+  
+  private def createFile(blockId: String): File = {
+    val file = getFile(blockId) 
+    if (file == null) {
+      lastLocalDirUsed = (lastLocalDirUsed + 1) % localDirs.size
+      val newFile = new File(localDirs(lastLocalDirUsed), blockId)
+      newFile.getParentFile.mkdirs()
+      return newFile 
+    } else {
+      logError("File for block " + blockId + " already exists on disk, " + file)
+      return null
+    }
+  }
+
+  private def getFile(blockId: String): File = {
+    logDebug("Getting file for block " + blockId)
+    // Search for the file in all the local directories, only one of them should have the file
+    val files = localDirs.map(localDir => new File(localDir, blockId)).filter(_.exists)  
+    if (files.size > 1) {
+      throw new Exception("Multiple files for same block " + blockId + " exists: " + 
+        files.map(_.toString).reduceLeft(_ + ", " + _))
+      return null
+    } else if (files.size == 0) {
+      return null 
+    } else {
+      logDebug("Got file " + files(0) + " of size " + files(0).length + " bytes")
+      return files(0)
+    }
+  }
+
+  private def createLocalDirs(): Seq[File] = {
+    logDebug("Creating local directories at root dirs '" + rootDirs + "'") 
+    rootDirs.split("[;,:]").map(rootDir => {
+        var foundLocalDir: Boolean = false
+        var localDir: File = null
+        var localDirUuid: UUID = null
+        var tries = 0
+        while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
+          tries += 1
+          try {
+            localDirUuid = UUID.randomUUID()
+            localDir = new File(rootDir, "spark-local-" + localDirUuid)
+            if (!localDir.exists) {
+              localDir.mkdirs()
+              foundLocalDir = true
+            }
+          } catch {
+            case e: Exception =>
+            logWarning("Attempt " + tries + " to create local dir failed", e)
+          }
+        }
+        if (!foundLocalDir) {
+          logError("Failed " + MAX_DIR_CREATION_ATTEMPTS + 
+            " attempts to create local dir in " + rootDir)
+          System.exit(1)
+        }
+        logDebug("Created local directory at " + localDir)
+        localDir
+    })
+  }
+
+  private def addShutdownHook() {
+    Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
+      override def run() {
+        logDebug("Shutdown hook called")
+        localDirs.foreach(localDir => Utils.deleteRecursively(localDir))
+      }
+    })
+  }
+}
diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala
new file mode 100644
index 0000000000000000000000000000000000000000..a2833a709063986d773dd356976392cf3a9c5a08
--- /dev/null
+++ b/core/src/main/scala/spark/storage/StorageLevel.scala
@@ -0,0 +1,78 @@
+package spark.storage
+
+import java.io._
+
+class StorageLevel(
+    var useDisk: Boolean, 
+    var useMemory: Boolean,
+    var deserialized: Boolean,
+    var replication: Int = 1)
+  extends Externalizable {
+
+  // TODO: Also add fields for caching priority, dataset ID, and flushing.
+  
+  def this(booleanInt: Int, replication: Int) {
+    this(((booleanInt & 4) != 0),
+        ((booleanInt & 2) != 0), 
+        ((booleanInt & 1) != 0),
+        replication)
+  }
+
+  def this() = this(false, true, false)  // For deserialization
+
+  override def clone(): StorageLevel = new StorageLevel(
+    this.useDisk, this.useMemory, this.deserialized, this.replication)
+
+  override def equals(other: Any): Boolean = other match {
+    case s: StorageLevel =>
+      s.useDisk == useDisk && 
+      s.useMemory == useMemory &&
+      s.deserialized == deserialized &&
+      s.replication == replication 
+    case _ =>
+      false
+  }
+  
+  def toInt(): Int = {
+    var ret = 0
+    if (useDisk) {
+      ret += 4
+    }
+    if (useMemory) {
+      ret += 2
+    }
+    if (deserialized) {
+      ret += 1
+    }
+    return ret
+  }
+
+  override def writeExternal(out: ObjectOutput) {
+    out.writeByte(toInt().toByte)
+    out.writeByte(replication.toByte)
+  }
+
+  override def readExternal(in: ObjectInput) {
+    val flags = in.readByte()
+    useDisk = (flags & 4) != 0
+    useMemory = (flags & 2) != 0
+    deserialized = (flags & 1) != 0
+    replication = in.readByte()
+  }
+
+  override def toString(): String =
+    "StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication)
+}
+
+object StorageLevel {
+  val NONE = new StorageLevel(false, false, false)
+  val DISK_ONLY = new StorageLevel(true, false, false)
+  val MEMORY_ONLY = new StorageLevel(false, true, false)
+  val MEMORY_ONLY_2 = new StorageLevel(false, true, false, 2)
+  val MEMORY_ONLY_DESER = new StorageLevel(false, true, true)
+  val MEMORY_ONLY_DESER_2 = new StorageLevel(false, true, true, 2)
+  val DISK_AND_MEMORY = new StorageLevel(true, true, false)
+  val DISK_AND_MEMORY_2 = new StorageLevel(true, true, false, 2)
+  val DISK_AND_MEMORY_DESER = new StorageLevel(true, true, true)
+  val DISK_AND_MEMORY_DESER_2 = new StorageLevel(true, true, true, 2)
+}
diff --git a/core/src/main/scala/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/spark/util/ByteBufferInputStream.scala
new file mode 100644
index 0000000000000000000000000000000000000000..abe2d99dd8a5f6814aa57c4ee2fc15fb08b09ac2
--- /dev/null
+++ b/core/src/main/scala/spark/util/ByteBufferInputStream.scala
@@ -0,0 +1,30 @@
+package spark.util
+
+import java.io.InputStream
+import java.nio.ByteBuffer
+
+class ByteBufferInputStream(buffer: ByteBuffer) extends InputStream {
+  override def read(): Int = {
+    if (buffer.remaining() == 0) {
+      -1
+    } else {
+      buffer.get()
+    }
+  }
+
+  override def read(dest: Array[Byte]): Int = {
+    read(dest, 0, dest.length)
+  }
+
+  override def read(dest: Array[Byte], offset: Int, length: Int): Int = {
+    val amountToGet = math.min(buffer.remaining(), length)
+    buffer.get(dest, offset, amountToGet)
+    return amountToGet
+  }
+
+  override def skip(bytes: Long): Long = {
+    val amountToSkip = math.min(bytes, buffer.remaining).toInt
+    buffer.position(buffer.position + amountToSkip)
+    return amountToSkip
+  }
+}
diff --git a/core/src/main/scala/spark/util/StatCounter.scala b/core/src/main/scala/spark/util/StatCounter.scala
new file mode 100644
index 0000000000000000000000000000000000000000..efb1ae75290f5482cb44d46b3222d34b283d9270
--- /dev/null
+++ b/core/src/main/scala/spark/util/StatCounter.scala
@@ -0,0 +1,89 @@
+package spark.util
+
+/**
+ * A class for tracking the statistics of a set of numbers (count, mean and variance) in a
+ * numerically robust way. Includes support for merging two StatCounters. Based on Welford and
+ * Chan's algorithms described at http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance.
+ */
+class StatCounter(values: TraversableOnce[Double]) {
+  private var n: Long = 0     // Running count of our values
+  private var mu: Double = 0  // Running mean of our values
+  private var m2: Double = 0  // Running variance numerator (sum of (x - mean)^2)
+
+  merge(values)
+
+  def this() = this(Nil)
+
+  def merge(value: Double): StatCounter = {
+    val delta = value - mu
+    n += 1
+    mu += delta / n
+    m2 += delta * (value - mu)
+    this
+  }
+
+  def merge(values: TraversableOnce[Double]): StatCounter = {
+    values.foreach(v => merge(v))
+    this
+  }
+
+  def merge(other: StatCounter): StatCounter = {
+    if (other == this) {
+      merge(other.copy())  // Avoid overwriting fields in a weird order
+    } else {
+      val delta = other.mu - mu
+      if (other.n * 10 < n) {
+        mu = mu + (delta * other.n) / (n + other.n)
+      } else if (n * 10 < other.n) {
+        mu = other.mu - (delta * n) / (n + other.n)
+      } else {
+        mu = (mu * n + other.mu * other.n) / (n + other.n)
+      }
+      m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
+      n += other.n
+      this
+    }
+  }
+
+  def copy(): StatCounter = {
+    val other = new StatCounter
+    other.n = n
+    other.mu = mu
+    other.m2 = m2
+    other
+  }
+
+  def count: Long = n
+
+  def mean: Double = mu
+
+  def sum: Double = n * mu
+
+  def variance: Double = {
+    if (n == 0)
+      Double.NaN
+    else
+      m2 / n
+  }
+
+  def sampleVariance: Double = {
+    if (n <= 1)
+      Double.NaN
+    else
+      m2 / (n - 1)
+  }
+
+  def stdev: Double = math.sqrt(variance)
+
+  def sampleStdev: Double = math.sqrt(sampleVariance)
+
+  override def toString: String = {
+    "(count: %d, mean: %f, stdev: %f)".format(count, mean, stdev)
+  }
+}
+
+object StatCounter {
+  def apply(values: TraversableOnce[Double]) = new StatCounter(values)
+
+  def apply(values: Double*) = new StatCounter(values)
+}
diff --git a/core/src/test/scala/spark/CacheTrackerSuite.scala b/core/src/test/scala/spark/CacheTrackerSuite.scala
index 60290d14cab69427a771004f5e1270f00708eaa4..3d170a6e22ef0cec8544454e5622d4432cb0c78c 100644
--- a/core/src/test/scala/spark/CacheTrackerSuite.scala
+++ b/core/src/test/scala/spark/CacheTrackerSuite.scala
@@ -1,95 +1,103 @@
 package spark
 
 import org.scalatest.FunSuite
-import collection.mutable.HashMap
+
+import scala.collection.mutable.HashMap
+
+import akka.actor._
+import akka.actor.Actor
+import akka.actor.Actor._
 
 class CacheTrackerSuite extends FunSuite {
 
   test("CacheTrackerActor slave initialization & cache status") {
-    System.setProperty("spark.master.port", "1345")
+    //System.setProperty("spark.master.port", "1345")
     val initialSize = 2L << 20
 
-    val tracker = new CacheTrackerActor
+    val tracker = actorOf(new CacheTrackerActor)
     tracker.start()
 
-    tracker !? SlaveCacheStarted("host001", initialSize)
+    tracker !! SlaveCacheStarted("host001", initialSize)
 
-    assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 0L)))
+    assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 0L)))
 
-    tracker !? StopCacheTracker
+    tracker !! StopCacheTracker
   }
 
   test("RegisterRDD") {
-    System.setProperty("spark.master.port", "1345")
+    //System.setProperty("spark.master.port", "1345")
     val initialSize = 2L << 20
 
-    val tracker = new CacheTrackerActor
+    val tracker = actorOf(new CacheTrackerActor)
     tracker.start()
 
-    tracker !? SlaveCacheStarted("host001", initialSize)
+    tracker !! SlaveCacheStarted("host001", initialSize)
 
-    tracker !? RegisterRDD(1, 3)
-    tracker !? RegisterRDD(2, 1)
+    tracker !! RegisterRDD(1, 3)
+    tracker !! RegisterRDD(2, 1)
 
-    assert(getCacheLocations(tracker) == Map(1 -> List(List(), List(), List()), 2 -> List(List())))
+    assert(getCacheLocations(tracker) === Map(1 -> List(List(), List(), List()), 2 -> List(List())))
 
-    tracker !? StopCacheTracker
+    tracker !! StopCacheTracker
   }
 
   test("AddedToCache") {
-    System.setProperty("spark.master.port", "1345")
+    //System.setProperty("spark.master.port", "1345")
     val initialSize = 2L << 20
 
-    val tracker = new CacheTrackerActor
+    val tracker = actorOf(new CacheTrackerActor)
     tracker.start()
 
-    tracker !? SlaveCacheStarted("host001", initialSize)
+    tracker !! SlaveCacheStarted("host001", initialSize)
 
-    tracker !? RegisterRDD(1, 2)
-    tracker !? RegisterRDD(2, 1)
+    tracker !! RegisterRDD(1, 2)
+    tracker !! RegisterRDD(2, 1)
 
-    tracker !? AddedToCache(1, 0, "host001", 2L << 15)
-    tracker !? AddedToCache(1, 1, "host001", 2L << 11)
-    tracker !? AddedToCache(2, 0, "host001", 3L << 10)
+    tracker !! AddedToCache(1, 0, "host001", 2L << 15)
+    tracker !! AddedToCache(1, 1, "host001", 2L << 11)
+    tracker !! AddedToCache(2, 0, "host001", 3L << 10)
 
-    assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 72704L)))
+    assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 72704L)))
 
-    assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
+    assert(getCacheLocations(tracker) === 
+      Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
 
-    tracker !? StopCacheTracker
+    tracker !! StopCacheTracker
   }
 
   test("DroppedFromCache") {
-    System.setProperty("spark.master.port", "1345")
+    //System.setProperty("spark.master.port", "1345")
     val initialSize = 2L << 20
 
-    val tracker = new CacheTrackerActor
+    val tracker = actorOf(new CacheTrackerActor)
     tracker.start()
 
-    tracker !? SlaveCacheStarted("host001", initialSize)
+    tracker !! SlaveCacheStarted("host001", initialSize)
 
-    tracker !? RegisterRDD(1, 2)
-    tracker !? RegisterRDD(2, 1)
+    tracker !! RegisterRDD(1, 2)
+    tracker !! RegisterRDD(2, 1)
 
-    tracker !? AddedToCache(1, 0, "host001", 2L << 15)
-    tracker !? AddedToCache(1, 1, "host001", 2L << 11)
-    tracker !? AddedToCache(2, 0, "host001", 3L << 10)
+    tracker !! AddedToCache(1, 0, "host001", 2L << 15)
+    tracker !! AddedToCache(1, 1, "host001", 2L << 11)
+    tracker !! AddedToCache(2, 0, "host001", 3L << 10)
 
-    assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 72704L)))
-    assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
+    assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 72704L)))
+    assert(getCacheLocations(tracker) ===
+      Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
 
-    tracker !? DroppedFromCache(1, 1, "host001", 2L << 11)
+    tracker !! DroppedFromCache(1, 1, "host001", 2L << 11)
 
-    assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 68608L)))
-    assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"),List()), 2 -> List(List("host001"))))
+    assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 68608L)))
+    assert(getCacheLocations(tracker) ===
+      Map(1 -> List(List("host001"),List()), 2 -> List(List("host001"))))
 
-    tracker !? StopCacheTracker
+    tracker !! StopCacheTracker
   }
 
   /**
    * Helper function to get cacheLocations from CacheTracker
    */
-  def getCacheLocations(tracker: CacheTrackerActor) = tracker !? GetCacheLocations match {
+  def getCacheLocations(tracker: ActorRef) = (tracker ? GetCacheLocations).get match {
     case h: HashMap[_, _] => h.asInstanceOf[HashMap[Int, Array[List[String]]]].map {
       case (i, arr) => (i -> arr.toList)
     }
diff --git a/core/src/test/scala/spark/MesosSchedulerSuite.scala b/core/src/test/scala/spark/MesosSchedulerSuite.scala
index 0e6820cbdcf31b0135d57283ef6b2b78681a5569..54421225d881e9b9e1f84b0cd1373498e64fa749 100644
--- a/core/src/test/scala/spark/MesosSchedulerSuite.scala
+++ b/core/src/test/scala/spark/MesosSchedulerSuite.scala
@@ -2,6 +2,8 @@ package spark
 
 import org.scalatest.FunSuite
 
+import spark.scheduler.mesos.MesosScheduler
+
 class MesosSchedulerSuite extends FunSuite {
   test("memoryStringToMb"){
 
diff --git a/core/src/test/scala/spark/UtilsSuite.scala b/core/src/test/scala/spark/UtilsSuite.scala
index f31251e509a9c14460a573f7584f42d206362e4e..1ac4737f046d35294a89e7165692fe10f809c966 100644
--- a/core/src/test/scala/spark/UtilsSuite.scala
+++ b/core/src/test/scala/spark/UtilsSuite.scala
@@ -2,7 +2,7 @@ package spark
 
 import org.scalatest.FunSuite
 import java.io.{ByteArrayOutputStream, ByteArrayInputStream}
-import util.Random
+import scala.util.Random
 
 class UtilsSuite extends FunSuite {
 
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 08c5a990b489ad3da43c60a73d61d7d3c5e48947..a2faf7399c44225a3df71f13d1fe330674a29a39 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -33,6 +33,7 @@ object SparkBuild extends Build {
       "org.scalatest" %% "scalatest" % "1.6.1" % "test",
       "org.scala-tools.testing" %% "scalacheck" % "1.9" % "test"
     ),
+    parallelExecution in Test := false,
     /* Workaround for issue #206 (fixed after SBT 0.11.0) */
     watchTransitiveSources <<= Defaults.inDependencies[Task[Seq[File]]](watchSources.task,
       const(std.TaskExtra.constant(Nil)), aggregate = true, includeRoot = true) apply { _.join.map(_.flatten) }
@@ -57,8 +58,12 @@ object SparkBuild extends Build {
       "asm" % "asm-all" % "3.3.1",
       "com.google.protobuf" % "protobuf-java" % "2.4.1",
       "de.javakaffee" % "kryo-serializers" % "0.9",
+      "se.scalablesolutions.akka" % "akka-actor" % "1.3.1",
+      "se.scalablesolutions.akka" % "akka-remote" % "1.3.1",
+      "se.scalablesolutions.akka" % "akka-slf4j" % "1.3.1",
       "org.jboss.netty" % "netty" % "3.2.6.Final",
-      "it.unimi.dsi" % "fastutil" % "6.4.2"
+      "it.unimi.dsi" % "fastutil" % "6.4.4",
+      "colt" % "colt" % "1.2.0"
     )
   ) ++ assemblySettings ++ Seq(test in assembly := {})
 
@@ -68,8 +73,7 @@ object SparkBuild extends Build {
   ) ++ assemblySettings ++ Seq(test in assembly := {})
 
   def examplesSettings = sharedSettings ++ Seq(
-    name := "spark-examples",
-    libraryDependencies += "colt" % "colt" % "1.2.0"
+    name := "spark-examples"
   )
 
   def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel")