From 06aac8a88902f10182830259197d83adbafea516 Mon Sep 17 00:00:00 2001
From: Matei Zaharia <matei@eecs.berkeley.edu>
Date: Sat, 3 Apr 2010 23:44:55 -0700
Subject: [PATCH] Imported changes from old repository (mostly Mosharaf's work,
 plus some fault tolerance code).

---
 run                                  |   2 +-
 src/examples/BroadcastTest.scala     |  24 +
 src/examples/SparkALS.scala          |  12 +-
 src/scala/spark/Broadcast.scala      | 798 +++++++++++++++++++++++++++
 src/scala/spark/Cached.scala         | 110 ----
 src/scala/spark/Executor.scala       |   4 +-
 src/scala/spark/HdfsFile.scala       |  38 +-
 src/scala/spark/NexusScheduler.scala | 317 ++++++-----
 src/scala/spark/SparkContext.scala   |   5 +-
 src/scala/spark/Task.scala           |   4 +-
 src/test/spark/repl/ReplSuite.scala  |  18 +-
 11 files changed, 1039 insertions(+), 293 deletions(-)
 create mode 100644 src/examples/BroadcastTest.scala
 create mode 100644 src/scala/spark/Broadcast.scala
 delete mode 100644 src/scala/spark/Cached.scala

diff --git a/run b/run
index 456615fba4..c1156892ad 100755
--- a/run
+++ b/run
@@ -4,7 +4,7 @@
 FWDIR=`dirname $0`
 
 # Set JAVA_OPTS to be able to load libnexus.so and set various other misc options
-JAVA_OPTS="-Djava.library.path=$FWDIR/third_party:$FWDIR/src/native -Xmx750m"
+export JAVA_OPTS="-Djava.library.path=$FWDIR/third_party:$FWDIR/src/native -Xmx2000m -Dspark.broadcast.masterHostAddress=127.0.0.1 -Dspark.broadcast.masterListenPort=11111 -Dspark.broadcast.blockSize=1024 -Dspark.broadcast.maxRetryCount=2 -Dspark.broadcast.serverSocketTimeout=50000 -Dspark.broadcast.dualMode=false"
 if [ -e $FWDIR/conf/java-opts ] ; then
   JAVA_OPTS+=" `cat $FWDIR/conf/java-opts`"
 fi
diff --git a/src/examples/BroadcastTest.scala b/src/examples/BroadcastTest.scala
new file mode 100644
index 0000000000..7764013413
--- /dev/null
+++ b/src/examples/BroadcastTest.scala
@@ -0,0 +1,24 @@
+import spark.SparkContext
+
+object BroadcastTest {
+  def main(args: Array[String]) {
+    if (args.length == 0) {
+      System.err.println("Usage: BroadcastTest <host> [<slices>]")
+      System.exit(1)
+    }  
+    val spark = new SparkContext(args(0), "Broadcast Test")
+    val slices = if (args.length > 1) args(1).toInt else 2
+    val num = if (args.length > 2) args(2).toInt else 1000000
+
+    val arr = new Array[Int](num)
+    for (i <- 0 until arr.length)
+      arr(i) = i
+    
+    val barr = spark.broadcast(arr)
+    spark.parallelize(1 to 10, slices).foreach { i =>
+      println("in task: barr = " + barr)
+      println(barr.value.size)
+    }
+  }
+}
+
diff --git a/src/examples/SparkALS.scala b/src/examples/SparkALS.scala
index 2fd58ed3a5..38dd0e665d 100644
--- a/src/examples/SparkALS.scala
+++ b/src/examples/SparkALS.scala
@@ -119,18 +119,18 @@ object SparkALS {
 
     // Iteratively update movies then users
     val Rc  = spark.broadcast(R)
-    var msb = spark.broadcast(ms)
-    var usb = spark.broadcast(us)
+    var msc = spark.broadcast(ms)
+    var usc = spark.broadcast(us)
     for (iter <- 1 to ITERATIONS) {
       println("Iteration " + iter + ":")
       ms = spark.parallelize(0 until M, slices)
-                .map(i => updateMovie(i, msb.value(i), usb.value, Rc.value))
+                .map(i => updateMovie(i, msc.value(i), usc.value, Rc.value))
                 .toArray
-      msb = spark.broadcast(ms) // Re-broadcast ms because it was updated
+      msc = spark.broadcast(ms) // Re-broadcast ms because it was updated
       us = spark.parallelize(0 until U, slices)
-                .map(i => updateUser(i, usb.value(i), msb.value, Rc.value))
+                .map(i => updateUser(i, usc.value(i), msc.value, Rc.value))
                 .toArray
-      usb = spark.broadcast(us) // Re-broadcast us because it was updated
+      usc = spark.broadcast(us) // Re-broadcast us because it was updated
       println("RMSE = " + rmse(R, ms, us))
       println()
     }
diff --git a/src/scala/spark/Broadcast.scala b/src/scala/spark/Broadcast.scala
new file mode 100644
index 0000000000..2da5e28a0a
--- /dev/null
+++ b/src/scala/spark/Broadcast.scala
@@ -0,0 +1,814 @@
+package spark
+
+import java.io._
+import java.net._
+import java.util.{UUID, PriorityQueue, Comparator}
+
+import com.google.common.collect.MapMaker
+
+import java.util.concurrent.{Executors, ExecutorService}
+
+import scala.actors.Actor
+import scala.actors.Actor._
+
+import scala.collection.mutable.Map
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem}
+
+import spark.compress.lzf.{LZFInputStream, LZFOutputStream}
+
+@serializable
+trait BroadcastRecipe {
+  val uuid = UUID.randomUUID
+
+  // We cannot have an abstract readObject here due to some weird issues with 
+  // readObject having to be 'private' in sub-classes. Possibly a Scala bug!
+  def sendBroadcast: Unit
+
+  override def toString = "spark.Broadcast(" + uuid + ")"
+}
+
+// TODO: Should think about storing in HDFS in the future
+// TODO: Right now, no parallelization between multiple broadcasts
+@serializable
+class ChainedStreamingBroadcast[T] (@transient var value_ : T, local: Boolean) 
+  extends BroadcastRecipe {
+  
+  def value = value_
+
+  BroadcastCS.synchronized { BroadcastCS.values.put (uuid, value_) }
+   
+  if (!local) { sendBroadcast }
+  
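+  // Called on the master: blockify the value, publish the blocks through
+  // BroadcastCS, and also keep a persistent copy in HDFS as a fallback for
+  // workers that fail to receive the streamed version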
+  def sendBroadcast () {
+    // Create a variableInfo object and store it in valueInfos
+    var variableInfo = blockifyObject (value_, BroadcastCS.blockSize)   
+    // TODO: Even though this part is not in use now, there is a problem in the
+    // following statement: it shouldn't use a constant port and hostAddress
+    // val masterSource = 
+    //   new SourceInfo (BroadcastCS.masterHostAddress, BroadcastCS.masterListenPort, 
+    //     variableInfo.totalBlocks, variableInfo.totalBytes, 0) 
+    // variableInfo.pqOfSources.add (masterSource)
+    
+    BroadcastCS.synchronized { 
+      // BroadcastCS.valueInfos.put (uuid, variableInfo)
+      
+      // TODO: Not using variableInfo in current implementation. Manually
+      // setting all the variables inside BroadcastCS object
+      
+      BroadcastCS.initializeVariable (variableInfo)      
+    }
+    
+    // Now store a persistent copy in HDFS, just in case 
+    val out = new ObjectOutputStream (BroadcastCH.openFileForWriting(uuid))
+    out.writeObject (value_)
+    out.close
+  }
+  
+  private def readObject (in: ObjectInputStream) {
+    in.defaultReadObject
+    BroadcastCS.synchronized {
+      val cachedVal = BroadcastCS.values.get (uuid)
+      if (cachedVal != null) {
+        value_ = cachedVal.asInstanceOf[T]
+      } else {
+        // Only a single worker (the first one) on a given node can ever get
+        // here. The rest will always find the value already cached
+        val start = System.nanoTime        
+
+        val retByteArray = BroadcastCS.receiveBroadcast (uuid)
+        // If that does not succeed, fall back to the copy in HDFS
+        if (retByteArray != null) {
+          value_ = byteArrayToObject[T] (retByteArray)
+          BroadcastCS.values.put (uuid, value_)
+          // val variableInfo = blockifyObject (value_, BroadcastCS.blockSize)    
+          // BroadcastCS.valueInfos.put (uuid, variableInfo)
+        }  else {
+          val fileIn = new ObjectInputStream(BroadcastCH.openFileForReading(uuid))
+          value_ = fileIn.readObject.asInstanceOf[T]
+          BroadcastCH.values.put(uuid, value_)
+          fileIn.close
+        } 
+        
+        val time = (System.nanoTime - start) / 1e9
+        println( System.currentTimeMillis + ": " +  "Reading broadcast variable " + uuid + " took " + time + " s")
+      }
+    }
+  }
+  
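+  // Serialize the object into a byte array and split it into blockSize-sized
+  // BroadcastBlocks, returning a VariableInfo describing the whole set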
+  private def blockifyObject (obj: T, blockSize: Int): VariableInfo = {
+    val baos = new ByteArrayOutputStream
+    val oos = new ObjectOutputStream (baos)
+    oos.writeObject (obj)
+    oos.close
+    baos.close
+    val byteArray = baos.toByteArray
+    val bais = new ByteArrayInputStream (byteArray)
+    
+    var blockNum = (byteArray.length / blockSize) 
+    if (byteArray.length % blockSize != 0) 
+      blockNum += 1
+      
+    var retVal = new Array[BroadcastBlock] (blockNum)
+    var blockID = 0
+
+    // TODO: What happens when byteArray.length == 0, i.e., blockNum == 0?
+    for (i <- 0 until (byteArray.length, blockSize)) {    
+      val thisBlockSize = Math.min (blockSize, byteArray.length - i)
+      var tempByteArray = new Array[Byte] (thisBlockSize)
+      val hasRead = bais.read (tempByteArray, 0, thisBlockSize)
+      
+      retVal (blockID) = new BroadcastBlock (blockID, tempByteArray)
+      blockID += 1
+    } 
+    bais.close
+
+    var variableInfo = VariableInfo (retVal, blockNum, byteArray.length)
+    variableInfo.hasBlocks = blockNum
+    
+    return variableInfo
+  }  
+  
+  private def byteArrayToObject[A] (bytes: Array[Byte]): A = {
+    val in = new ObjectInputStream (new ByteArrayInputStream (bytes))
+    val retVal = in.readObject.asInstanceOf[A]
+    in.close
+    return retVal
+  }
+  
+  private def getByteArrayOutputStream (obj: T): ByteArrayOutputStream = {
+    val bOut = new ByteArrayOutputStream
+    val out = new ObjectOutputStream (bOut)
+    out.writeObject (obj)
+    out.close
+    bOut.close
+    return bOut
+  }  
+}
+
+@serializable 
+class CentralizedHDFSBroadcast[T](@transient var value_ : T, local: Boolean) 
+  extends BroadcastRecipe {
+  
+  def value = value_
+
+  BroadcastCH.synchronized { BroadcastCH.values.put(uuid, value_) }
+
+  if (!local) { sendBroadcast }
+
+  def sendBroadcast () {
+    val out = new ObjectOutputStream (BroadcastCH.openFileForWriting(uuid))
+    out.writeObject (value_)
+    out.close
+  }
+
+  // Called by Java when deserializing an object
+  private def readObject(in: ObjectInputStream) {
+    in.defaultReadObject
+    BroadcastCH.synchronized {
+      val cachedVal = BroadcastCH.values.get(uuid)
+      if (cachedVal != null) {
+        value_ = cachedVal.asInstanceOf[T]
+      } else {
+        val start = System.nanoTime
+        
+        val fileIn = new ObjectInputStream(BroadcastCH.openFileForReading(uuid))
+        value_ = fileIn.readObject.asInstanceOf[T]
+        BroadcastCH.values.put(uuid, value_)
+        fileIn.close
+        
+        val time = (System.nanoTime - start) / 1e9
+        println( System.currentTimeMillis + ": " +  "Reading broadcast variable " + uuid + " took " + time + " s")
+      }
+    }
+  }
+}
+
+@serializable
+case class SourceInfo (val hostAddress: String, val listenPort: Int, 
+  val totalBlocks: Int, val totalBytes: Int, val replicaID: Int)  
+  extends Comparable [SourceInfo]{
+
+  var currentLeechers = 0
+  var receptionFailed = false
+  
+  def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers)
+}
+
+@serializable
+case class BroadcastBlock (val blockID: Int, val byteArray: Array[Byte]) { }
+
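+// Metadata describing one blockified broadcast variable. The blocks and the
+// source queue are marked @transient, so only the block count and total size
+// survive Java serialization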
+@serializable
+case class VariableInfo (@transient val arrayOfBlocks : Array[BroadcastBlock], 
+  val totalBlocks: Int, val totalBytes: Int) {
+  
+  @transient var hasBlocks = 0
+
+  val listenPortLock = new AnyRef
+  val totalBlocksLock = new AnyRef
+  val hasBlocksLock = new AnyRef
+  
+  @transient var pqOfSources = new PriorityQueue[SourceInfo]
+} 
+
+private object Broadcast {
+  private var initialized = false 
+
+  // Will be called by SparkContext or Executor before using Broadcast
+  // Calls all other initializers here
+  def initialize (isMaster: Boolean) {
+    synchronized {
+      if (!initialized) {
+        // Initialization for CentralizedHDFSBroadcast
+        BroadcastCH.initialize 
+        // Initialization for ChainedStreamingBroadcast
+        BroadcastCS.initialize (isMaster)
+        
+        initialized = true
+      }
+    }
+  }
+}
+
+private object BroadcastCS {
+  val values = new MapMaker ().softValues ().makeMap[UUID, Any]
+  // val valueInfos = new MapMaker ().softValues ().makeMap[UUID, Any]  
+
+  // private var valueToPort = Map[UUID, Int] ()
+
+  private var initialized = false
+  private var isMaster_ = false
+
+  private var masterHostAddress_ = "127.0.0.1"
+  private var masterListenPort_ : Int = 11111
+  private var blockSize_ : Int = 512 * 1024
+  private var maxRetryCount_ : Int = 2
+  private var serverSocketTimeout_ : Int = 50000
+  private var dualMode_ : Boolean = false
+
+  private val hostAddress = InetAddress.getLocalHost.getHostAddress
+  private var listenPort = -1 
+  
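+  // State of the variable currently being broadcast. Because this lives in a
+  // singleton object, only one chained-streaming broadcast can be in flight
+  // at a time (see the TODO at the top of this file)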
+  var arrayOfBlocks: Array[BroadcastBlock] = null
+  var totalBytes = -1
+  var totalBlocks = -1
+  var hasBlocks = 0
+
+  val listenPortLock = new Object
+  val totalBlocksLock = new Object
+  val hasBlocksLock = new Object
+  
+  var pqOfSources = new PriorityQueue[SourceInfo]
+  
+  private var serveMR: ServeMultipleRequests = null 
+  private var guideMR: GuideMultipleRequests = null 
+
+  def initialize (isMaster__ : Boolean) {
+    synchronized {
+      if (!initialized) {
+        masterHostAddress_ = 
+          System.getProperty ("spark.broadcast.masterHostAddress", "127.0.0.1")
+        masterListenPort_ = 
+          System.getProperty ("spark.broadcast.masterListenPort", "11111").toInt
+        blockSize_ = 
+          System.getProperty ("spark.broadcast.blockSize", "512").toInt * 1024
+        maxRetryCount_ = 
+          System.getProperty ("spark.broadcast.maxRetryCount", "2").toInt          
+        serverSocketTimeout_ = 
+          System.getProperty ("spark.broadcast.serverSocketTimeout", "50000").toInt
+        dualMode_ = 
+          System.getProperty ("spark.broadcast.dualMode", "false").toBoolean          
+
+        isMaster_ = isMaster__        
+        
+        if (isMaster) {
+          guideMR = new GuideMultipleRequests
+          // guideMR.setDaemon (true)
+          guideMR.start
+          println (System.currentTimeMillis + ": " +  "GuideMultipleRequests started")
+        }        
+        serveMR = new ServeMultipleRequests
+        // serveMR.setDaemon (true)
+        serveMR.start
+        
+        println (System.currentTimeMillis + ": " +  "ServeMultipleRequests started")
+        
+        println (System.currentTimeMillis + ": " +  "BroadcastCS object has been initialized")
+                  
+        initialized = true
+      }
+    }
+  }
+  
+  // TODO: This should change in a future implementation.
+  // Called from the master to set up state for the particular variable that
+  // is being broadcast
+  def initializeVariable (variableInfo: VariableInfo) {
+    arrayOfBlocks = variableInfo.arrayOfBlocks
+    totalBytes = variableInfo.totalBytes
+    totalBlocks = variableInfo.totalBlocks
+    hasBlocks = variableInfo.totalBlocks
+    
+    // listenPort should already be valid
+    assert (listenPort != -1)
+    
+    pqOfSources = new PriorityQueue[SourceInfo]
+    val masterSource_0 = 
+      new SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes, 0) 
+    BroadcastCS.pqOfSources.add (masterSource_0)
+    // Add one more time to have two replicas of any seeds in the PQ
+    if (BroadcastCS.dualMode) {
+      val masterSource_1 = 
+        new SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes, 1) 
+      BroadcastCS.pqOfSources.add (masterSource_1)
+    }
+  }
+  
+  def masterHostAddress = masterHostAddress_
+  def masterListenPort = masterListenPort_
+  def blockSize = blockSize_
+  def maxRetryCount = maxRetryCount_
+  def serverSocketTimeout = serverSocketTimeout_
+  def dualMode = dualMode_
+
+  def isMaster = isMaster_ 
+
+  def receiveBroadcast (variableUUID: UUID): Array[Byte] = {  
+    // Wait until listenPort has been assigned by the ServeMultipleRequests
+    // thread. There is usually no waiting, because ServeMultipleRequests is
+    // started much earlier, during initialize
+    while (listenPort == -1) { 
+      listenPortLock.synchronized {
+        listenPortLock.wait 
+      }
+    } 
+
+    // Connect and receive broadcast from the specified source, retrying the
+    // specified number of times in case of failures
+    var retriesLeft = BroadcastCS.maxRetryCount
+    var retByteArray: Array[Byte] = null
+    do {
+      // Connect to Master and send this worker's Information
+      val clientSocketToMaster = 
+        new Socket(BroadcastCS.masterHostAddress, BroadcastCS.masterListenPort)  
+      println (System.currentTimeMillis + ": " +  "Connected to Master's guiding object")
+      // TODO: Guiding object connection is reusable
+      val oisMaster = 
+        new ObjectInputStream (clientSocketToMaster.getInputStream)
+      val oosMaster = 
+        new ObjectOutputStream (clientSocketToMaster.getOutputStream)
+      
+      oosMaster.writeObject(new SourceInfo (hostAddress, listenPort, -1, -1, 0))
+      oosMaster.flush
+
+      // Receive source information from Master        
+      var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo]
+      totalBlocks = sourceInfo.totalBlocks
+      arrayOfBlocks = new Array[BroadcastBlock] (totalBlocks)
+      totalBlocksLock.synchronized {
+        totalBlocksLock.notifyAll
+      }
+      totalBytes = sourceInfo.totalBytes
+      
+      println (System.currentTimeMillis + ": " +  "Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort)    
+
+      retByteArray = receiveSingleTransmission (sourceInfo)
+      
+      println (System.currentTimeMillis + ": " +  "I got this from receiveSingleTransmission: " + retByteArray)
+
+      // TODO: Update sourceInfo to add error notifications for the Master
+      if (retByteArray == null) { sourceInfo.receptionFailed = true }
+      
+      // TODO: Supposed to update values here, but we don't support advanced
+      // statistics right now. Master can handle leecherCount by itself.
+
+      // Send back statistics to the Master
+      oosMaster.writeObject (sourceInfo) 
+    
+      oisMaster.close
+      oosMaster.close
+      clientSocketToMaster.close                    
+      
+      retriesLeft -= 1
+    } while (retriesLeft > 0 && retByteArray == null)
+    
+    return retByteArray
+  }
+
+  // Tries to receive the broadcast from the given source and returns the
+  // received byte array, or null on failure. This might be called multiple
+  // times (up to maxRetryCount) to retry
+  private def receiveSingleTransmission(sourceInfo: SourceInfo): Array[Byte] = {
+    var clientSocketToSource: Socket = null    
+    var oisSource: ObjectInputStream = null
+    var oosSource: ObjectOutputStream = null
+    
+    var retByteArray:Array[Byte] = null
+    
+    try {
+      // Connect to the source to get the object itself
+      clientSocketToSource = 
+        new Socket (sourceInfo.hostAddress, sourceInfo.listenPort)        
+      oosSource = 
+        new ObjectOutputStream (clientSocketToSource.getOutputStream)
+      oisSource = 
+        new ObjectInputStream (clientSocketToSource.getInputStream)
+        
+      println (System.currentTimeMillis + ": " +  "Inside receiveSingleTransmission")
+      println (System.currentTimeMillis + ": " +  "totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
+      retByteArray = new Array[Byte] (totalBytes)
+      for (i <- 0 until totalBlocks) {
+        val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
+        System.arraycopy (bcBlock.byteArray, 0, retByteArray, 
+          i * BroadcastCS.blockSize, bcBlock.byteArray.length)
+        arrayOfBlocks(hasBlocks) = bcBlock
+        hasBlocks += 1
+        hasBlocksLock.synchronized {
+          hasBlocksLock.notifyAll
+        }
+        println (System.currentTimeMillis + ": " +  "Received block: " + i + " " + bcBlock)
+      } 
+      assert (hasBlocks == totalBlocks)
+      println (System.currentTimeMillis + ": " +  "After the receive loop")
+    } catch {
+      case e: Exception => { 
+        retByteArray = null 
+        println (System.currentTimeMillis + ": " +  "receiveSingleTransmission had a " + e)
+      }
+    } finally {    
+      if (oisSource != null) { oisSource.close }
+      if (oosSource != null) { oosSource.close }
+      if (clientSocketToSource != null) { clientSocketToSource.close }
+    }
+          
+    return retByteArray
+  } 
+  
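+  // NOTE: TrackMultipleValues is not started from anywhere yet. Per the TODO
+  // in its run method, the intent is to map each UUID to the port serving it,
+  // so that multiple variables can be tracked through one master port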
+  class TrackMultipleValues extends Thread {
+    override def run = {
+      var threadPool = Executors.newCachedThreadPool
+      var serverSocket: ServerSocket = null
+      
+      serverSocket = new ServerSocket (BroadcastCS.masterListenPort)
+      println (System.currentTimeMillis + ": " +  "TrackMultipleValues" + serverSocket + " " + listenPort)
+      
+      var keepAccepting = true
+      try {
+        while (keepAccepting) {
+          var clientSocket: Socket = null
+          try {
+            serverSocket.setSoTimeout (serverSocketTimeout)
+            clientSocket = serverSocket.accept
+          } catch {
+            case e: Exception => { 
+              println ("TrackMultipleValues Timeout. Stopping listening...") 
+              keepAccepting = false 
+            }
+          }
+          if (clientSocket != null) {
+            println (System.currentTimeMillis + ": " +  "TrackMultipleValues:Got new request:" + clientSocket)
+            try {            
+              threadPool.execute (new Runnable {
+                def run = {
+                  val oos = new ObjectOutputStream (clientSocket.getOutputStream)
+                  val ois = new ObjectInputStream (clientSocket.getInputStream)
+                  try {
+                    val variableUUID = ois.readObject.asInstanceOf[UUID]
+                    var contactPort = 0
+                    // TODO: Add logic and data structures to find out UUID->port
+                    // mapping. 0 = missed the broadcast, read from HDFS; <0 = 
+                    // Haven't started yet, wait & retry; >0 = Read from this port
+                    oos.writeObject (contactPort)
+                  } catch {
+                    case e: Exception => { }
+                  } finally {
+                    ois.close
+                    oos.close
+                    clientSocket.close
+                  }
+                }
+              })
+            } catch {
+              // In failure, close the socket here; else, the thread will close it
+              case ioe: IOException => clientSocket.close
+            }
+          }
+        }
+      } finally {
+        serverSocket.close
+      }      
+    }
+  }
+  
+  class TrackSingleValue {
+    
+  }
+  
+  class GuideMultipleRequests extends Thread {
+    override def run = {
+      var threadPool = Executors.newCachedThreadPool
+      var serverSocket: ServerSocket = null
+
+      serverSocket = new ServerSocket (BroadcastCS.masterListenPort)
+      // listenPort = BroadcastCS.masterListenPort
+      println (System.currentTimeMillis + ": " +  "GuideMultipleRequests" + serverSocket + " " + listenPort)
+      
+      var keepAccepting = true
+      try {
+        while (keepAccepting) {
+          var clientSocket: Socket = null
+          try {
+            serverSocket.setSoTimeout (serverSocketTimeout)
+            clientSocket = serverSocket.accept
+          } catch {
+            case e: Exception => { 
+              println ("GuideMultipleRequests Timeout. Stopping listening...") 
+              keepAccepting = false 
+            }
+          }
+          if (clientSocket != null) {
+            println (System.currentTimeMillis + ": " +  "Guide:Accepted new client connection:" + clientSocket)
+            try {            
+              threadPool.execute (new GuideSingleRequest (clientSocket))
+            } catch {
+              // In failure, close the socket here; else, the thread will close it
+              case ioe: IOException => clientSocket.close
+            }
+          }
+        }
+      } finally {
+        serverSocket.close
+      }
+    }
+    
+    class GuideSingleRequest (val clientSocket: Socket) extends Runnable {
+      private val oos = new ObjectOutputStream (clientSocket.getOutputStream)
+      private val ois = new ObjectInputStream (clientSocket.getInputStream)
+
+      private var selectedSourceInfo: SourceInfo = null
+      private var thisWorkerInfo:SourceInfo = null
+      
+      def run = {
+        try {
+          println (System.currentTimeMillis + ": " +  "new GuideSingleRequest is running")
+          // The connecting worker sends the hostAddress and listenPort it
+          // will be listening on. ReplicaID is 0 and other fields are invalid (-1)
+          var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
+          
+          // Select a suitable source and send it back to the worker
+          selectedSourceInfo = selectSuitableSource (sourceInfo)
+          println (System.currentTimeMillis + ": " +  "Sending selectedSourceInfo:" + selectedSourceInfo)
+          oos.writeObject (selectedSourceInfo)
+          oos.flush
+
+          // Add this new (if it can finish) source to the PQ of sources
+          thisWorkerInfo = new SourceInfo(sourceInfo.hostAddress, 
+            sourceInfo.listenPort, totalBlocks, totalBytes, 0)  
+          println (System.currentTimeMillis + ": " +  "Adding possible new source to pqOfSources: " + thisWorkerInfo)    
+          pqOfSources.synchronized {
+            pqOfSources.add (thisWorkerInfo)
+          }
+
+          // Wait till the whole transfer is done. Then receive and update source 
+          // statistics in pqOfSources
+          sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
+
+          pqOfSources.synchronized {
+            // This should work since SourceInfo is a case class
+            assert (pqOfSources.contains (selectedSourceInfo))
+            
+            // Remove first
+            pqOfSources.remove (selectedSourceInfo)        
+            // TODO: Removing a source based on just one failure notification!
+            // Update leecher count and put it back in IF reception succeeded
+            if (!sourceInfo.receptionFailed) {          
+              selectedSourceInfo.currentLeechers -= 1
+              pqOfSources.add (selectedSourceInfo)
+              
+              // No need to find and update thisWorkerInfo, but add its replica
+              if (BroadcastCS.dualMode) {
+                pqOfSources.add (new SourceInfo (thisWorkerInfo.hostAddress, 
+                  thisWorkerInfo.listenPort, totalBlocks, totalBytes, 1))
+              }              
+            }                        
+          }      
+        } catch {
+          // If something went wrong, e.g., the worker at the other end died etc. 
+          // then close everything up
+          case e: Exception => { 
+            // Assuming that exception caused due to receiver worker failure
+            // Remove failed worker from pqOfSources and update leecherCount of 
+            // corresponding source worker
+            pqOfSources.synchronized {
+              if (selectedSourceInfo != null) {
+                // Remove first
+                pqOfSources.remove (selectedSourceInfo)        
+                // Update leecher count and put it back in
+                selectedSourceInfo.currentLeechers -= 1
+                pqOfSources.add (selectedSourceInfo)
+              }
+              
+              // Remove thisWorkerInfo
+              if (pqOfSources != null) { pqOfSources.remove (thisWorkerInfo) }
+            }      
+          }
+        } finally {
+          ois.close
+          oos.close
+          clientSocket.close
+        }
+      }
+      
+      // TODO: If a worker fails to get the broadcast variable from a source
+      // and comes back to the Master, this function might choose the worker
+      // itself as a source, creating a dependency cycle (the worker was put
+      // into pqOfSources as a streaming source when it first arrived). Such
+      // a cycle can be arbitrarily long.
+      private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
+        // Select one with the lowest number of leechers
+        pqOfSources.synchronized {
+          // poll removes the PQ head, i.e., the source with the fewest leechers
+          var selectedSource = pqOfSources.poll
+          assert (selectedSource != null) 
+          // Update leecher count
+          selectedSource.currentLeechers += 1
+          // Add it back and then return
+          pqOfSources.add (selectedSource)
+          return selectedSource
+        }
+      }
+    }    
+  }
+
+  class ServeMultipleRequests extends Thread {
+    override def run = {
+      var threadPool = Executors.newCachedThreadPool
+      var serverSocket: ServerSocket = null
+
+      serverSocket = new ServerSocket (0) 
+      listenPort = serverSocket.getLocalPort
+      println (System.currentTimeMillis + ": " +  "ServeMultipleRequests" + serverSocket + " " + listenPort)
+      
+      listenPortLock.synchronized {
+        listenPortLock.notifyAll
+      }
+            
+      var keepAccepting = true
+      try {
+        while (keepAccepting) {
+          var clientSocket: Socket = null
+          try {
+            serverSocket.setSoTimeout (serverSocketTimeout)
+            clientSocket = serverSocket.accept
+          } catch {
+            case e: Exception => { 
+              println ("ServeMultipleRequests Timeout. Stopping listening...") 
+              keepAccepting = false 
+            }
+          }
+          if (clientSocket != null) {
+            println (System.currentTimeMillis + ": " +  "Serve:Accepted new client connection:" + clientSocket)
+            try {            
+              threadPool.execute (new ServeSingleRequest (clientSocket))
+            } catch {
+              // In failure, close socket here; else, the thread will close it
+              case ioe: IOException => clientSocket.close
+            }
+          }
+        }
+      } finally {
+        serverSocket.close
+      }
+    }
+    
+    class ServeSingleRequest (val clientSocket: Socket) extends Runnable {
+      private val oos = new ObjectOutputStream (clientSocket.getOutputStream)
+      private val ois = new ObjectInputStream (clientSocket.getInputStream)
+      
+      def run = {
+        try {
+          println (System.currentTimeMillis + ": " +  "new ServeSingleRequest is running")
+          sendObject
+        } catch {
+          // TODO: Need to add better exception handling here
+          // If something went wrong, e.g., the worker at the other end died etc. 
+          // then close everything up
+          case e: Exception => { 
+            println (System.currentTimeMillis + ": " +  "ServeSingleRequest had a " + e)
+          }
+        } finally {
+          println (System.currentTimeMillis + ": " +  "ServeSingleRequest is closing streams and sockets")
+          ois.close
+          oos.close
+          clientSocket.close
+        }
+      }
+
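+      // Sends blocks in order as they become available, so a worker can start
+      // serving the beginning of a variable before it has received all of it.
+      // This pipelining is what makes the broadcast "chained streaming"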
+      private def sendObject = {
+        // Wait till receiving the SourceInfo from Master
+        while (totalBlocks == -1) { 
+          totalBlocksLock.synchronized {
+            totalBlocksLock.wait
+          }
+        }
+
+        for (i <- 0 until totalBlocks) {
+          while (i == hasBlocks) { 
+            hasBlocksLock.synchronized {
+              hasBlocksLock.wait
+            }
+          }
+          try {
+            oos.writeObject (arrayOfBlocks(i))
+            oos.flush
+          } catch {
+            case e: Exception => { }
+          }
+          println (System.currentTimeMillis + ": " +  "Sent block: " + i + " " + arrayOfBlocks(i))
+        }
+      }    
+    } 
+    
+  }
+}
+
+private object BroadcastCH {
+  val values = new MapMaker ().softValues ().makeMap[UUID, Any]
+
+  private var initialized = false
+
+  private var fileSystem: FileSystem = null
+  private var workDir: String = null
+  private var compress: Boolean = false
+  private var bufferSize: Int = 65536
+
+  def initialize () {
+    synchronized {
+      if (!initialized) {
+        bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+        val dfs = System.getProperty("spark.dfs", "file:///")
+        if (!dfs.startsWith("file://")) {
+          val conf = new Configuration()
+          conf.setInt("io.file.buffer.size", bufferSize)
+          val rep = System.getProperty("spark.dfs.replication", "3").toInt
+          conf.setInt("dfs.replication", rep)
+          fileSystem = FileSystem.get(new URI(dfs), conf)
+        }
+        workDir = System.getProperty("spark.dfs.workdir", "/tmp")
+        compress = System.getProperty("spark.compress", "false").toBoolean
+
+        initialized = true
+      }
+    }
+  }
+
+  private def getPath(uuid: UUID) = new Path(workDir + "/broadcast-" + uuid)
+
+  def openFileForReading(uuid: UUID): InputStream = {
+    val fileStream = if (fileSystem != null) {
+      fileSystem.open(getPath(uuid))
+    } else {
+      // Local filesystem
+      new FileInputStream(getPath(uuid).toString)
+    }
+    if (compress)
+      new LZFInputStream(fileStream) // LZF stream does its own buffering
+    else if (fileSystem == null)
+      new BufferedInputStream(fileStream, bufferSize)
+    else
+      fileStream // Hadoop streams do their own buffering
+  }
+
+  def openFileForWriting(uuid: UUID): OutputStream = {
+    val fileStream = if (fileSystem != null) {
+      fileSystem.create(getPath(uuid))
+    } else {
+      // Local filesystem
+      new FileOutputStream(getPath(uuid).toString)
+    }
+    if (compress)
+      new LZFOutputStream(fileStream) // LZF stream does its own buffering
+    else if (fileSystem == null)
+      new BufferedOutputStream(fileStream, bufferSize)
+    else
+      fileStream // Hadoop streams do their own buffering
+  }
+}
diff --git a/src/scala/spark/Cached.scala b/src/scala/spark/Cached.scala
deleted file mode 100644
index 8113340e1f..0000000000
--- a/src/scala/spark/Cached.scala
+++ /dev/null
@@ -1,110 +0,0 @@
-package spark
-
-import java.io._
-import java.net.URI
-import java.util.UUID
-
-import com.google.common.collect.MapMaker
-
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem}
-
-import spark.compress.lzf.{LZFInputStream, LZFOutputStream}
-
-@serializable class Cached[T](@transient var value_ : T, local: Boolean) {
-  val uuid = UUID.randomUUID()
-  def value = value_
-
-  Cache.synchronized { Cache.values.put(uuid, value_) }
-
-  if (!local) writeCacheFile()
-
-  private def writeCacheFile() {
-    val out = new ObjectOutputStream(Cache.openFileForWriting(uuid))
-    out.writeObject(value_)
-    out.close()
-  }
-
-  // Called by Java when deserializing an object
-  private def readObject(in: ObjectInputStream) {
-    in.defaultReadObject
-    Cache.synchronized {
-      val cachedVal = Cache.values.get(uuid)
-      if (cachedVal != null) {
-        value_ = cachedVal.asInstanceOf[T]
-      } else {
-        val start = System.nanoTime
-        val fileIn = new ObjectInputStream(Cache.openFileForReading(uuid))
-        value_ = fileIn.readObject().asInstanceOf[T]
-        Cache.values.put(uuid, value_)
-        fileIn.close()
-        val time = (System.nanoTime - start) / 1e9
-        println("Reading cached variable " + uuid + " took " + time + " s")
-      }
-    }
-  }
-  
-  override def toString = "spark.Cached(" + uuid + ")"
-}
-
-private object Cache {
-  val values = new MapMaker().softValues().makeMap[UUID, Any]()
-
-  private var initialized = false
-  private var fileSystem: FileSystem = null
-  private var workDir: String = null
-  private var compress: Boolean = false
-  private var bufferSize: Int = 65536
-
-  // Will be called by SparkContext or Executor before using cache
-  def initialize() {
-    synchronized {
-      if (!initialized) {
-        bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
-        val dfs = System.getProperty("spark.dfs", "file:///")
-        if (!dfs.startsWith("file://")) {
-          val conf = new Configuration()
-          conf.setInt("io.file.buffer.size", bufferSize)
-          val rep = System.getProperty("spark.dfs.replication", "3").toInt
-          conf.setInt("dfs.replication", rep)
-          fileSystem = FileSystem.get(new URI(dfs), conf)
-        }
-        workDir = System.getProperty("spark.dfs.workdir", "/tmp")
-        compress = System.getProperty("spark.compress", "false").toBoolean
-        initialized = true
-      }
-    }
-  }
-
-  private def getPath(uuid: UUID) = new Path(workDir + "/cache-" + uuid)
-
-  def openFileForReading(uuid: UUID): InputStream = {
-    val fileStream = if (fileSystem != null) {
-      fileSystem.open(getPath(uuid))
-    } else {
-      // Local filesystem
-      new FileInputStream(getPath(uuid).toString)
-    }
-    if (compress)
-      new LZFInputStream(fileStream) // LZF stream does its own buffering
-    else if (fileSystem == null)
-      new BufferedInputStream(fileStream, bufferSize)
-    else
-      fileStream // Hadoop streams do their own buffering
-  }
-
-  def openFileForWriting(uuid: UUID): OutputStream = {
-    val fileStream = if (fileSystem != null) {
-      fileSystem.create(getPath(uuid))
-    } else {
-      // Local filesystem
-      new FileOutputStream(getPath(uuid).toString)
-    }
-    if (compress)
-      new LZFOutputStream(fileStream) // LZF stream does its own buffering
-    else if (fileSystem == null)
-      new BufferedOutputStream(fileStream, bufferSize)
-    else
-      fileStream // Hadoop streams do their own buffering
-  }
-}
diff --git a/src/scala/spark/Executor.scala b/src/scala/spark/Executor.scala
index 4cc8f00aa9..d115c6acd9 100644
--- a/src/scala/spark/Executor.scala
+++ b/src/scala/spark/Executor.scala
@@ -18,8 +18,8 @@ object Executor {
         for ((key, value) <- props)
           System.setProperty(key, value)
         
-        // Initialize cache (uses some properties read above)
-        Cache.initialize()
+        // Initialize broadcast system (uses some properties read above)
+        Broadcast.initialize(false)
         
         // If the REPL is in use, create a ClassLoader that will be able to
         // read new classes defined by the REPL as the user types code
diff --git a/src/scala/spark/HdfsFile.scala b/src/scala/spark/HdfsFile.scala
index 8050683f99..87d8e8cc81 100644
--- a/src/scala/spark/HdfsFile.scala
+++ b/src/scala/spark/HdfsFile.scala
@@ -27,9 +27,9 @@ import org.apache.hadoop.mapred.Reporter
 abstract class DistributedFile[T, Split](@transient sc: SparkContext) {
   def splits: Array[Split]
   def iterator(split: Split): Iterator[T]
-  def prefers(split: Split, slot: SlaveOffer): Boolean
+  def preferredLocations(split: Split): Seq[String]
 
-  def taskStarted(split: Split, slot: SlaveOffer) {}
+  def taskStarted(split: Split, offer: SlaveOffer) {}
 
   def sparkContext = sc
 
@@ -87,8 +87,8 @@ abstract class DistributedFile[T, Split](@transient sc: SparkContext) {
 abstract class FileTask[U, T, Split](val file: DistributedFile[T, Split],
   val split: Split)
 extends Task[U] {
-  override def prefers(slot: SlaveOffer) = file.prefers(split, slot)
-  override def markStarted(slot: SlaveOffer) { file.taskStarted(split, slot) }
+  override def preferredLocations: Seq[String] = file.preferredLocations(split)
+  override def markStarted(offer: SlaveOffer) { file.taskStarted(split, offer) }
 }
 
 class ForeachTask[T, Split](file: DistributedFile[T, Split],
@@ -124,31 +124,31 @@ extends FileTask[Option[T], T, Split](file, split) {
 class MappedFile[U, T, Split](prev: DistributedFile[T, Split], f: T => U) 
 extends DistributedFile[U, Split](prev.sparkContext) {
   override def splits = prev.splits
-  override def prefers(split: Split, slot: SlaveOffer) = prev.prefers(split, slot)
+  override def preferredLocations(sp: Split) = prev.preferredLocations(sp)
   override def iterator(split: Split) = prev.iterator(split).map(f)
-  override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot)
+  override def taskStarted(split: Split, offer: SlaveOffer) = prev.taskStarted(split, offer)
 }
 
 class FilteredFile[T, Split](prev: DistributedFile[T, Split], f: T => Boolean) 
 extends DistributedFile[T, Split](prev.sparkContext) {
   override def splits = prev.splits
-  override def prefers(split: Split, slot: SlaveOffer) = prev.prefers(split, slot)
+  override def preferredLocations(sp: Split) = prev.preferredLocations(sp)
   override def iterator(split: Split) = prev.iterator(split).filter(f)
-  override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot)
+  override def taskStarted(split: Split, offer: SlaveOffer) = prev.taskStarted(split, offer)
 }
 
 class CachedFile[T, Split](prev: DistributedFile[T, Split])
 extends DistributedFile[T, Split](prev.sparkContext) {
   val id = CachedFile.newId()
-  @transient val cacheLocs = Map[Split, List[Int]]()
+  @transient val cacheLocs = Map[Split, List[String]]()
 
   override def splits = prev.splits
 
-  override def prefers(split: Split, slot: SlaveOffer): Boolean = {
+  override def preferredLocations(split: Split): Seq[String] = {
     if (cacheLocs.contains(split))
-      cacheLocs(split).contains(slot.getSlaveId)
+      cacheLocs(split)
     else
-      prev.prefers(split, slot)
+      prev.preferredLocations(split)
   }
   
   override def iterator(split: Split): Iterator[T] = {
@@ -183,11 +183,11 @@ extends DistributedFile[T, Split](prev.sparkContext) {
     }
   }
 
-  override def taskStarted(split: Split, slot: SlaveOffer) {
+  override def taskStarted(split: Split, offer: SlaveOffer) {
     val oldList = cacheLocs.getOrElse(split, Nil)
-    val slaveId = slot.getSlaveId
-    if (!oldList.contains(slaveId))
-      cacheLocs(split) = slaveId :: oldList
+    val host = offer.getHost
+    if (!oldList.contains(host))
+      cacheLocs(split) = host :: oldList
   }
 }
 
@@ -251,8 +251,10 @@ extends DistributedFile[String, HdfsSplit](sc) {
     }
   }
 
-  override def prefers(split: HdfsSplit, slot: SlaveOffer) =
-    split.value.getLocations().contains(slot.getHost)
+  override def preferredLocations(split: HdfsSplit) = {
+    // TODO: Filtering out "localhost" in case of file:// URLs
+    split.value.getLocations().filter(_ != "localhost").toArray
+  }
 }
 
 object ConfigureLock {}
diff --git a/src/scala/spark/NexusScheduler.scala b/src/scala/spark/NexusScheduler.scala
index a96fca9350..a8a5e2947a 100644
--- a/src/scala/spark/NexusScheduler.scala
+++ b/src/scala/spark/NexusScheduler.scala
@@ -1,11 +1,11 @@
 package spark
 
 import java.io.File
-import java.util.concurrent.Semaphore
 
-import nexus.{ExecutorInfo, TaskDescription, TaskState, TaskStatus}
-import nexus.{SlaveOffer, SchedulerDriver, NexusSchedulerDriver}
-import nexus.{SlaveOfferVector, TaskDescriptionVector, StringMap}
+import scala.collection.mutable.Map
+
+import nexus.{Scheduler => NScheduler}
+import nexus._
 
 // The main Scheduler implementation, which talks to Nexus. Clients are expected
 // to first call start(), then submit tasks through the runTasks method.
@@ -21,30 +21,26 @@ import nexus.{SlaveOfferVector, TaskDescriptionVector, StringMap}
 //    can be made cleaner.
 private class NexusScheduler(
   master: String, frameworkName: String, execArg: Array[Byte])
-extends nexus.Scheduler with spark.Scheduler
+extends NScheduler with spark.Scheduler
 {
-  // Semaphore used by runTasks to ensure only one thread can be in it
-  val semaphore = new Semaphore(1)
+  // Lock used by runTasks to ensure only one thread can be in it
+  val runTasksMutex = new Object()
 
   // Lock used to wait for  scheduler to be registered
   var isRegistered = false
   val registeredLock = new Object()
 
-  // Trait representing a set of  scheduler callbacks
-  trait Callbacks {
-    def slotOffer(s: SlaveOffer): Option[TaskDescription]
-    def taskFinished(t: TaskStatus): Unit
-    def error(code: Int, message: String): Unit
-  }
-
   // Current callback object (may be null)
-  var callbacks: Callbacks = null
+  var activeOp: ParallelOperation = null
 
   // Incrementing task ID
-  var nextTaskId = 0
+  private var nextTaskId = 0
 
-  // Maximum time to wait to run a task in a preferred location (in ms)
-  val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "1000").toLong
+  def newTaskId(): Int = {
+    val id = nextTaskId
+    nextTaskId += 1
+    return id
+  }
 
   // Driver for talking to Nexus
   var driver: SchedulerDriver = null
@@ -66,125 +62,27 @@ extends nexus.Scheduler with spark.Scheduler
     new ExecutorInfo(new File("spark-executor").getCanonicalPath(), execArg)
 
   override def runTasks[T](tasks: Array[Task[T]]): Array[T] = {
-    val results = new Array[T](tasks.length)
-    if (tasks.length == 0)
-      return results
-
-    val launched = new Array[Boolean](tasks.length)
-
-    val callingThread = currentThread
-
-    var errorHappened = false
-    var errorCode = 0
-    var errorMessage = ""
-
-    // Wait for scheduler to be registered with Nexus
-    waitForRegister()
-
-    try {
-      // Acquire the runTasks semaphore
-      semaphore.acquire()
-
-      val myCallbacks = new Callbacks {
-        val firstTaskId = nextTaskId
-        var tasksLaunched = 0
-        var tasksFinished = 0
-        var lastPreferredLaunchTime = System.currentTimeMillis
-
-        def slotOffer(slot: SlaveOffer): Option[TaskDescription] = {
-          try {
-            if (tasksLaunched < tasks.length) {
-              // TODO: Add a short wait if no task with location pref is found
-              // TODO: Figure out why a function is needed around this to
-              // avoid  scala.runtime.NonLocalReturnException
-              def findTask: Option[TaskDescription] = {
-                var checkPrefVals: Array[Boolean] = Array(true)
-                val time = System.currentTimeMillis
-                if (time - lastPreferredLaunchTime > LOCALITY_WAIT)
-                  checkPrefVals = Array(true, false) // Allow non-preferred tasks
-                // TODO: Make desiredCpus and desiredMem configurable
-                val desiredCpus = 1
-                val desiredMem = 750L * 1024L * 1024L
-                if (slot.getParams.get("cpus").toInt < desiredCpus || 
-                    slot.getParams.get("mem").toLong < desiredMem)
-                  return None
-                for (checkPref <- checkPrefVals;
-                     i <- 0 until tasks.length;
-                     if !launched(i) && (!checkPref || tasks(i).prefers(slot)))
-                {
-                  val taskId = nextTaskId
-                  nextTaskId += 1
-                  printf("Starting task %d as TID %d on slave %d: %s (%s)\n",
-                    i, taskId, slot.getSlaveId, slot.getHost, 
-                    if(checkPref) "preferred" else "non-preferred")
-                  tasks(i).markStarted(slot)
-                  launched(i) = true
-                  tasksLaunched += 1
-                  if (checkPref)
-                    lastPreferredLaunchTime = time
-                  val params = new StringMap
-                  params.set("cpus", "" + desiredCpus)
-                  params.set("mem", "" + desiredMem)
-                  val serializedTask = Utils.serialize(tasks(i))
-                  return Some(new TaskDescription(taskId, slot.getSlaveId,
-                    "task_" + taskId, params, serializedTask))
-                }
-                return None
-              }
-              return findTask
-            } else {
-              return None
-            }
-          } catch {
-            case e: Exception => { 
-              e.printStackTrace
-              System.exit(1)
-              return None
-            }
-          }
-        }
+    runTasksMutex.synchronized {
+      waitForRegister()
+      val myOp = new SimpleParallelOperation(this, tasks)
 
-        def taskFinished(status: TaskStatus) {
-          println("Finished TID " + status.getTaskId)
-          // Deserialize task result
-          val result = Utils.deserialize[TaskResult[T]](status.getData)
-          results(status.getTaskId - firstTaskId) = result.value
-          // Update accumulators
-          Accumulators.add(callingThread, result.accumUpdates)
-          // Stop if we've finished all the tasks
-          tasksFinished += 1
-          if (tasksFinished == tasks.length) {
-            NexusScheduler.this.callbacks = null
-            NexusScheduler.this.notifyAll()
-          }
+      try {
+        this.synchronized {
+          this.activeOp = myOp
         }
-
-        def error(code: Int, message: String) {
-          // Save the error message
-          errorHappened = true
-          errorCode = code
-          errorMessage = message
-          // Indicate to caller thread that we're done
-          NexusScheduler.this.callbacks = null
-          NexusScheduler.this.notifyAll()
+        driver.reviveOffers()
+        myOp.join()
+      } finally {
+        this.synchronized {
+          this.activeOp = null
         }
       }
 
-      this.synchronized {
-        this.callbacks = myCallbacks
-      }
-      driver.reviveOffers();
-      this.synchronized {
-        while (this.callbacks != null) this.wait()
-      }
-    } finally {
-      semaphore.release()
+      if (myOp.errorHappened)
+        throw new SparkException(myOp.errorMessage, myOp.errorCode)
+      else
+        return myOp.results
     }
-
-    if (errorHappened)
-      throw new SparkException(errorMessage, errorCode)
-    else
-      return results
   }
 
   override def registered(d: SchedulerDriver, frameworkId: Int) {
@@ -197,18 +95,19 @@ extends nexus.Scheduler with spark.Scheduler
   
   override def waitForRegister() {
     registeredLock.synchronized {
-      while (!isRegistered) registeredLock.wait()
+      while (!isRegistered)
+        registeredLock.wait()
     }
   }
 
   override def resourceOffer(
-      d: SchedulerDriver, oid: Long, slots: SlaveOfferVector) {
+      d: SchedulerDriver, oid: Long, offers: SlaveOfferVector) {
     synchronized {
       val tasks = new TaskDescriptionVector
-      if (callbacks != null) {
+      if (activeOp != null) {
         try {
-          for (i <- 0 until slots.size.toInt) {
-            callbacks.slotOffer(slots.get(i)) match {
+          for (i <- 0 until offers.size.toInt) {
+            activeOp.slaveOffer(offers.get(i)) match {
               case Some(task) => tasks.add(task)
               case None => {}
             }
@@ -225,21 +124,21 @@ extends nexus.Scheduler with spark.Scheduler
 
   override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
     synchronized {
-      if (callbacks != null && status.getState == TaskState.TASK_FINISHED) {
-        try {
-          callbacks.taskFinished(status)
-        } catch {
-          case e: Exception => e.printStackTrace
+      try {
+        if (activeOp != null) {
+          activeOp.statusUpdate(status)
         }
+      } catch {
+        case e: Exception => e.printStackTrace
       }
     }
   }
 
   override def error(d: SchedulerDriver, code: Int, message: String) {
     synchronized {
-      if (callbacks != null) {
+      if (activeOp != null) {
         try {
-          callbacks.error(code, message)
+          activeOp.error(code, message)
         } catch {
           case e: Exception => e.printStackTrace
         }
@@ -256,3 +155,138 @@
       driver.stop()
   }
 }
+
+
+// Trait representing a set of scheduler callbacks
+trait ParallelOperation {
+  def slaveOffer(s: SlaveOffer): Option[TaskDescription]
+  def statusUpdate(t: TaskStatus): Unit
+  def error(code: Int, message: String): Unit
+}
+
+
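+// Tracks one runTasks call: which tasks have been launched and finished,
+// their results, and any error, re-offering lost tasks until all complete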
+class SimpleParallelOperation[T](sched: NexusScheduler, tasks: Array[Task[T]])
+extends ParallelOperation
+{
+  // Maximum time to wait to run a task in a preferred location (in ms)
+  val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "1000").toLong
+
+  val callingThread = currentThread
+  val numTasks = tasks.length
+  val results = new Array[T](numTasks)
+  val launched = new Array[Boolean](numTasks)
+  val finished = new Array[Boolean](numTasks)
+  val tidToIndex = Map[Int, Int]()
+
+  var allFinished = false
+  val joinLock = new Object()
+
+  var errorHappened = false
+  var errorCode = 0
+  var errorMessage = ""
+
+  var tasksLaunched = 0
+  var tasksFinished = 0
+  var lastPreferredLaunchTime = System.currentTimeMillis
+
+  def setAllFinished() {
+    joinLock.synchronized {
+      allFinished = true
+      joinLock.notifyAll()
+    }
+  }
+
+  def join() {
+    joinLock.synchronized {
+      while (!allFinished)
+        joinLock.wait()
+    }
+  }
+
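+  // Respond to a resource offer from one slave. Tasks preferring this host
+  // (or with no preference) are launched first; if none launches within
+  // LOCALITY_WAIT ms, any pending task may be launched on any host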
+  def slaveOffer(offer: SlaveOffer): Option[TaskDescription] = {
+    if (tasksLaunched < numTasks) {
+      var checkPrefVals: Array[Boolean] = Array(true)
+      val time = System.currentTimeMillis
+      if (time - lastPreferredLaunchTime > LOCALITY_WAIT)
+        checkPrefVals = Array(true, false) // Allow non-preferred tasks
+      // TODO: Make desiredCpus and desiredMem configurable
+      val desiredCpus = 1
+      val desiredMem = 750L * 1024L * 1024L
+      if (offer.getParams.get("cpus").toInt < desiredCpus || 
+          offer.getParams.get("mem").toLong < desiredMem)
+        return None
+      for (checkPref <- checkPrefVals; i <- 0 until numTasks) {
+        if (!launched(i) && (!checkPref ||
+            tasks(i).preferredLocations.contains(offer.getHost) ||
+            tasks(i).preferredLocations.isEmpty))
+        {
+          val taskId = sched.newTaskId()
+          tidToIndex(taskId) = i
+          printf("Starting task %d as TID %d on slave %d: %s (%s)\n",
+            i, taskId, offer.getSlaveId, offer.getHost, 
+            if(checkPref) "preferred" else "non-preferred")
+          tasks(i).markStarted(offer)
+          launched(i) = true
+          tasksLaunched += 1
+          if (checkPref)
+            lastPreferredLaunchTime = time
+          val params = new StringMap
+          params.set("cpus", "" + desiredCpus)
+          params.set("mem", "" + desiredMem)
+          val serializedTask = Utils.serialize(tasks(i))
+          return Some(new TaskDescription(taskId, offer.getSlaveId,
+            "task_" + taskId, params, serializedTask))
+        }
+      }
+    }
+    return None
+  }
+
+  def statusUpdate(status: TaskStatus) {
+    status.getState match {
+      case TaskState.TASK_FINISHED =>
+        taskFinished(status)
+      case TaskState.TASK_LOST | TaskState.TASK_FAILED | TaskState.TASK_KILLED =>
+        taskLost(status)
+      case _ =>
+    }
+  }
+
+  def taskFinished(status: TaskStatus) {
+    val tid = status.getTaskId
+    println("Finished TID " + tid)
+    // Deserialize task result
+    val result = Utils.deserialize[TaskResult[T]](status.getData)
+    results(tidToIndex(tid)) = result.value
+    // Update accumulators
+    Accumulators.add(callingThread, result.accumUpdates)
+    // Mark finished and stop if we've finished all the tasks
+    finished(tidToIndex(tid)) = true
+    tasksFinished += 1
+    if (tasksFinished == numTasks)
+      setAllFinished()
+  }
+
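+  // A task failed or its slave was lost: mark it as not launched so that a
+  // future resourceOffer will re-run it on another slave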
+  def taskLost(status: TaskStatus) {
+    val tid = status.getTaskId
+    println("Lost TID " + tid)
+    launched(tidToIndex(tid)) = false
+    tasksLaunched -= 1
+  }
+
+  def error(code: Int, message: String) {
+    // Save the error message
+    errorHappened = true
+    errorCode = code
+    errorMessage = message
+    // Indicate to caller thread that we're done
+    setAllFinished()
+  }
+}
diff --git a/src/scala/spark/SparkContext.scala b/src/scala/spark/SparkContext.scala
index 4bfbcb6f21..7972702205 100644
--- a/src/scala/spark/SparkContext.scala
+++ b/src/scala/spark/SparkContext.scala
@@ -6,7 +6,7 @@ import java.util.UUID
 import scala.collection.mutable.ArrayBuffer
 
 class SparkContext(master: String, frameworkName: String) {
-  Cache.initialize()
+  Broadcast.initialize(true)
 
   def parallelize[T](seq: Seq[T], numSlices: Int): ParallelArray[T] =
     new SimpleParallelArray[T](this, seq, numSlices)
@@ -17,7 +17,8 @@ class SparkContext(master: String, frameworkName: String) {
     new Accumulator(initialValue, param)
 
   // TODO: Keep around a weak hash map of values to Cached versions?
-  def broadcast[T](value: T) = new Cached(value, local)
+  def broadcast[T](value: T) = new ChainedStreamingBroadcast (value, local)
+  // def broadcast[T](value: T) = new CentralizedHDFSBroadcast (value, local)
 
   def textFile(path: String) = new HdfsTextFile(this, path)
 
diff --git a/src/scala/spark/Task.scala b/src/scala/spark/Task.scala
index e559996a37..efb864472d 100644
--- a/src/scala/spark/Task.scala
+++ b/src/scala/spark/Task.scala
@@ -5,8 +5,8 @@ import nexus._
 @serializable
 trait Task[T] {
   def run: T
-  def prefers(slot: SlaveOffer): Boolean = true
-  def markStarted(slot: SlaveOffer) {}
+  def preferredLocations: Seq[String] = Nil
+  def markStarted(offer: SlaveOffer) {}
 }
 
 @serializable
diff --git a/src/test/spark/repl/ReplSuite.scala b/src/test/spark/repl/ReplSuite.scala
index d71fe20a94..43ef296efe 100644
--- a/src/test/spark/repl/ReplSuite.scala
+++ b/src/test/spark/repl/ReplSuite.scala
@@ -85,15 +85,15 @@ class ReplSuite extends FunSuite {
     assertContains("res2: Int = 100", output)
   }
   
-  test ("cached vars") {
-    // Test that the value that a cached var had when it was created is used,
-    // even if that cached var is then modified in the driver program
+  test ("broadcast vars") {
+    // Test that the value that a broadcast var had when it was created is used,
+    // even if that broadcast var is then modified in the driver program
     val output = runInterpreter("local", """
       var array = new Array[Int](5)
-      val cachedArray = sc.cache(array)
-      sc.parallelize(0 to 4).map(x => cachedArray.value(x)).toArray
+      val broadcastedArray = sc.broadcast(array)
+      sc.parallelize(0 to 4).map(x => broadcastedArray.value(x)).toArray
       array(0) = 5
-      sc.parallelize(0 to 4).map(x => cachedArray.value(x)).toArray
+      sc.parallelize(0 to 4).map(x => broadcastedArray.value(x)).toArray
       """)
     assertDoesNotContain("error:", output)
     assertDoesNotContain("Exception", output)
@@ -109,10 +109,10 @@ class ReplSuite extends FunSuite {
       v = 10
       sc.parallelize(1 to 10).map(x => getV()).toArray.reduceLeft(_+_)
       var array = new Array[Int](5)
-      val cachedArray = sc.cache(array)
-      sc.parallelize(0 to 4).map(x => cachedArray.value(x)).toArray
+      val broadcastedArray = sc.broadcast(array)
+      sc.parallelize(0 to 4).map(x => broadcastedArray.value(x)).toArray
       array(0) = 5
-      sc.parallelize(0 to 4).map(x => cachedArray.value(x)).toArray
+      sc.parallelize(0 to 4).map(x => broadcastedArray.value(x)).toArray
       """)
     assertDoesNotContain("error:", output)
     assertDoesNotContain("Exception", output)
-- 
GitLab