diff --git a/alltests b/alltests index 3c9db301c404bd632ce5ae08241786f248495eab..8beab92952aa9d48622d5f4b1b35959f2c457f9b 100755 --- a/alltests +++ b/alltests @@ -1,3 +1,3 @@ #!/bin/bash FWDIR=`dirname $0` -$FWDIR/run org.scalatest.tools.Runner -p $FWDIR/build/classes -o $@ +$FWDIR/run org.scalatest.tools.Runner -p $FWDIR/classes -o $@ diff --git a/run b/run index c1156892ad4e3760ca8d90b8cf01a6a3a5823da5..82a12b011309b01b6a93486a224387dcd4920fa5 100755 --- a/run +++ b/run @@ -4,7 +4,8 @@ FWDIR=`dirname $0` # Set JAVA_OPTS to be able to load libnexus.so and set various other misc options -export JAVA_OPTS="-Djava.library.path=$FWDIR/third_party:$FWDIR/src/native -Xmx2000m -Dspark.broadcast.masterHostAddress=127.0.0.1 -Dspark.broadcast.masterListenPort=11111 -Dspark.broadcast.blockSize=1024 -Dspark.broadcast.maxRetryCount=2 -Dspark.broadcast.serverSocketTimout=50000 -Dspark.broadcast.dualMode=false" +export JAVA_OPTS="-Djava.library.path=$FWDIR/third_party:$FWDIR/src/native -Xmx2000m -Dspark.broadcast.masterHostAddress=127.0.0.1 -Dspark.broadcast.masterTrackerPort=11111 -Dspark.broadcast.blockSize=1024 -Dspark.broadcast.maxRetryCount=2 -Dspark.broadcast.serverSocketTimout=50000 -Dspark.broadcast.dualMode=false" + if [ -e $FWDIR/conf/java-opts ] ; then JAVA_OPTS+=" `cat $FWDIR/conf/java-opts`" fi diff --git a/src/examples/SparkALS.scala b/src/examples/SparkALS.scala index 38dd0e665dd4f3b5281279e3c81796e16dbb21b5..cbbbba3c7933cfbbe23ecbcd35de1d82c78beed0 100644 --- a/src/examples/SparkALS.scala +++ b/src/examples/SparkALS.scala @@ -122,6 +122,8 @@ object SparkALS { var msc = spark.broadcast(ms) var usc = spark.broadcast(us) for (iter <- 1 to ITERATIONS) { + val start = System.nanoTime + println("Iteration " + iter + ":") ms = spark.parallelize(0 until M, slices) .map(i => updateMovie(i, msc.value(i), usc.value, Rc.value)) @@ -133,6 +135,9 @@ object SparkALS { usc = spark.broadcast(us) // Re-broadcast us because it was updated println("RMSE = " + rmse(R, ms, us)) println() + + val time = (System.nanoTime - start) / 1e9 + println( "This iteration took " + time + " s") } } } diff --git a/src/scala/spark/Broadcast.scala b/src/scala/spark/Broadcast.scala index 2da5e28a0a8a5fa9b6cc2c6afd34ddb2d3ade3e2..73b9ea39e8ef1bcc3c4ed5d621a0c3fec9a63cd5 100644 --- a/src/scala/spark/Broadcast.scala +++ b/src/scala/spark/Broadcast.scala @@ -8,9 +8,6 @@ import com.google.common.collect.MapMaker import java.util.concurrent.{Executors, ExecutorService} -import scala.actors.Actor -import scala.actors.Actor._ - import scala.collection.mutable.Map import org.apache.hadoop.conf.Configuration @@ -29,8 +26,6 @@ trait BroadcastRecipe { override def toString = "spark.Broadcast(" + uuid + ")" } -// TODO: Should think about storing in HDFS in the future -// TODO: Right, now no parallelization between multiple broadcasts @serializable class ChainedStreamingBroadcast[T] (@transient var value_ : T, local: Boolean) extends BroadcastRecipe { @@ -39,31 +34,81 @@ class ChainedStreamingBroadcast[T] (@transient var value_ : T, local: Boolean) BroadcastCS.synchronized { BroadcastCS.values.put (uuid, value_) } + @transient var arrayOfBlocks: Array[BroadcastBlock] = null + @transient var totalBytes = -1 + @transient var totalBlocks = -1 + @transient var hasBlocks = 0 + + @transient var listenPortLock = new Object + @transient var guidePortLock = new Object + @transient var totalBlocksLock = new Object + @transient var hasBlocksLock = new Object + + @transient var pqOfSources = new PriorityQueue[SourceInfo] + + @transient var serveMR: 
ServeMultipleRequests = null + @transient var guideMR: GuideMultipleRequests = null + + @transient var hostAddress = InetAddress.getLocalHost.getHostAddress + @transient var listenPort = -1 + @transient var guidePort = -1 + if (!local) { sendBroadcast } def sendBroadcast () { // Create a variableInfo object and store it in valueInfos var variableInfo = blockifyObject (value_, BroadcastCS.blockSize) - // TODO: Even though this part is not in use now, there is problem in the - // following statement. Shouldn't use constant port and hostAddress anymore? - // val masterSource = - // new SourceInfo (BroadcastCS.masterHostAddress, BroadcastCS.masterListenPort, - // variableInfo.totalBlocks, variableInfo.totalBytes, 0) - // variableInfo.pqOfSources.add (masterSource) + guideMR = new GuideMultipleRequests + // guideMR.setDaemon (true) + guideMR.start + // println (System.currentTimeMillis + ": " + "GuideMultipleRequests started") + + serveMR = new ServeMultipleRequests + // serveMR.setDaemon (true) + serveMR.start + // println (System.currentTimeMillis + ": " + "ServeMultipleRequests started") + + // Prepare the value being broadcasted + // TODO: Refactoring and clean-up required here + arrayOfBlocks = variableInfo.arrayOfBlocks + totalBytes = variableInfo.totalBytes + totalBlocks = variableInfo.totalBlocks + hasBlocks = variableInfo.totalBlocks + + while (listenPort == -1) { + listenPortLock.synchronized { + listenPortLock.wait + } + } + + pqOfSources = new PriorityQueue[SourceInfo] + val masterSource_0 = + new SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes, 0) + pqOfSources.add (masterSource_0) + // Add one more time to have two replicas of any seeds in the PQ + if (BroadcastCS.dualMode) { + val masterSource_1 = + new SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes, 1) + pqOfSources.add (masterSource_1) + } + + // Register with the Tracker + while (guidePort == -1) { + guidePortLock.synchronized { + guidePortLock.wait + } + } + BroadcastCS.synchronized { - // BroadcastCS.valueInfos.put (uuid, variableInfo) - - // TODO: Not using variableInfo in current implementation. Manually - // setting all the variables inside BroadcastCS object - - BroadcastCS.initializeVariable (variableInfo) + BroadcastCS.registerValue (uuid, guidePort) } + // TODO: Make it a separate thread? // Now store a persistent copy in HDFS, just in case val out = new ObjectOutputStream (BroadcastCH.openFileForWriting(uuid)) out.writeObject (value_) - out.close + out.close } private def readObject (in: ObjectInputStream) { @@ -75,9 +120,18 @@ class ChainedStreamingBroadcast[T] (@transient var value_ : T, local: Boolean) } else { // Only a single worker (the first one) in the same node can ever be // here. 
The rest will always get the value ready + + // Initializing everything because Master will only send null/0 values + initializeSlaveVariables + + serveMR = new ServeMultipleRequests + // serveMR.setDaemon (true) + serveMR.start + // println (System.currentTimeMillis + ": " + "ServeMultipleRequests started") + val start = System.nanoTime - val retByteArray = BroadcastCS.receiveBroadcast (uuid) + val retByteArray = receiveBroadcast (uuid) // If does not succeed, then get from HDFS copy if (retByteArray != null) { value_ = byteArrayToObject[T] (retByteArray) @@ -97,6 +151,19 @@ class ChainedStreamingBroadcast[T] (@transient var value_ : T, local: Boolean) } } + private def initializeSlaveVariables = { + arrayOfBlocks = null + totalBytes = -1 + totalBlocks = -1 + hasBlocks = 0 + listenPortLock = new Object + totalBlocksLock = new Object + hasBlocksLock = new Object + serveMR = null + hostAddress = InetAddress.getLocalHost.getHostAddress + listenPort = -1 + } + private def blockifyObject (obj: T, blockSize: Int): VariableInfo = { val baos = new ByteArrayOutputStream val oos = new ObjectOutputStream (baos) @@ -108,7 +175,7 @@ class ChainedStreamingBroadcast[T] (@transient var value_ : T, local: Boolean) var blockNum = (byteArray.length / blockSize) if (byteArray.length % blockSize != 0) - blockNum += 1 + blockNum += 1 var retVal = new Array[BroadcastBlock] (blockNum) var blockID = 0 @@ -145,199 +212,50 @@ class ChainedStreamingBroadcast[T] (@transient var value_ : T, local: Boolean) bOut.close return bOut } -} - -@serializable -class CentralizedHDFSBroadcast[T](@transient var value_ : T, local: Boolean) - extends BroadcastRecipe { - - def value = value_ - - BroadcastCH.synchronized { BroadcastCH.values.put(uuid, value_) } - - if (!local) { sendBroadcast } - - def sendBroadcast () { - val out = new ObjectOutputStream (BroadcastCH.openFileForWriting(uuid)) - out.writeObject (value_) - out.close - } - - // Called by Java when deserializing an object - private def readObject(in: ObjectInputStream) { - in.defaultReadObject - BroadcastCH.synchronized { - val cachedVal = BroadcastCH.values.get(uuid) - if (cachedVal != null) { - value_ = cachedVal.asInstanceOf[T] - } else { - val start = System.nanoTime - - val fileIn = new ObjectInputStream(BroadcastCH.openFileForReading(uuid)) - value_ = fileIn.readObject.asInstanceOf[T] - BroadcastCH.values.put(uuid, value_) - fileIn.close - - val time = (System.nanoTime - start) / 1e9 - println( System.currentTimeMillis + ": " + "Reading Broadcasted variable " + uuid + " took " + time + " s") - } - } - } -} - -@serializable -case class SourceInfo (val hostAddress: String, val listenPort: Int, - val totalBlocks: Int, val totalBytes: Int, val replicaID: Int) - extends Comparable [SourceInfo]{ - - var currentLeechers = 0 - var receptionFailed = false - - def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers) -} - -@serializable -case class BroadcastBlock (val blockID: Int, val byteArray: Array[Byte]) { } - -@serializable -case class VariableInfo (@transient val arrayOfBlocks : Array[BroadcastBlock], - val totalBlocks: Int, val totalBytes: Int) { - - @transient var hasBlocks = 0 - - val listenPortLock = new AnyRef - val totalBlocksLock = new AnyRef - val hasBlocksLock = new AnyRef - - @transient var pqOfSources = new PriorityQueue[SourceInfo] -} - -private object Broadcast { - private var initialized = false - - // Will be called by SparkContext or Executor before using Broadcast - // Calls all other initializers here - def initialize (isMaster: 
Boolean) { - synchronized { - if (!initialized) { - // Initialization for CentralizedHDFSBroadcast - BroadcastCH.initialize - // Initialization for ChainedStreamingBroadcast - BroadcastCS.initialize (isMaster) - - initialized = true - } - } - } -} - -private object BroadcastCS { - val values = new MapMaker ().softValues ().makeMap[UUID, Any] - // val valueInfos = new MapMaker ().softValues ().makeMap[UUID, Any] - - // private var valueToPort = Map[UUID, Int] () - - private var initialized = false - private var isMaster_ = false - - private var masterHostAddress_ = "127.0.0.1" - private var masterListenPort_ : Int = 11111 - private var blockSize_ : Int = 512 * 1024 - private var maxRetryCount_ : Int = 2 - private var serverSocketTimout_ : Int = 50000 - private var dualMode_ : Boolean = false - - private val hostAddress = InetAddress.getLocalHost.getHostAddress - private var listenPort = -1 - - var arrayOfBlocks: Array[BroadcastBlock] = null - var totalBytes = -1 - var totalBlocks = -1 - var hasBlocks = 0 - - val listenPortLock = new Object - val totalBlocksLock = new Object - val hasBlocksLock = new Object - - var pqOfSources = new PriorityQueue[SourceInfo] - private var serveMR: ServeMultipleRequests = null - private var guideMR: GuideMultipleRequests = null - - def initialize (isMaster__ : Boolean) { - synchronized { - if (!initialized) { - masterHostAddress_ = - System.getProperty ("spark.broadcast.masterHostAddress", "127.0.0.1") - masterListenPort_ = - System.getProperty ("spark.broadcast.masterListenPort", "11111").toInt - blockSize_ = - System.getProperty ("spark.broadcast.blockSize", "512").toInt * 1024 - maxRetryCount_ = - System.getProperty ("spark.broadcast.maxRetryCount", "2").toInt - serverSocketTimout_ = - System.getProperty ("spark.broadcast.serverSocketTimout", "50000").toInt - dualMode_ = - System.getProperty ("spark.broadcast.dualMode", "false").toBoolean - - isMaster_ = isMaster__ - - if (isMaster) { - guideMR = new GuideMultipleRequests - // guideMR.setDaemon (true) - guideMR.start - println (System.currentTimeMillis + ": " + "GuideMultipleRequests started") - } - serveMR = new ServeMultipleRequests - // serveMR.setDaemon (true) - serveMR.start - - println (System.currentTimeMillis + ": " + "ServeMultipleRequests started") - - println (System.currentTimeMillis + ": " + "BroadcastCS object has been initialized") - - initialized = true - } - } - } - - // TODO: This should change in future implementation. 
- // Called from the Master constructor to setup states for this particular that - // is being broadcasted - def initializeVariable (variableInfo: VariableInfo) { - arrayOfBlocks = variableInfo.arrayOfBlocks - totalBytes = variableInfo.totalBytes - totalBlocks = variableInfo.totalBlocks - hasBlocks = variableInfo.totalBlocks + def receiveBroadcast (variableUUID: UUID): Array[Byte] = { + var clientSocketToTracker: Socket = null + var oisTracker: ObjectInputStream = null + var oosTracker: ObjectOutputStream = null - // listenPort should already be valid - assert (listenPort != -1) + var masterListenPort: Int = -1 - pqOfSources = new PriorityQueue[SourceInfo] - val masterSource_0 = - new SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes, 0) - BroadcastCS.pqOfSources.add (masterSource_0) - // Add one more time to have two replicas of any seeds in the PQ - if (BroadcastCS.dualMode) { - val masterSource_1 = - new SourceInfo (hostAddress, listenPort, totalBlocks, totalBytes, 1) - BroadcastCS.pqOfSources.add (masterSource_1) - } - } - - def masterHostAddress = masterHostAddress_ - def masterListenPort = masterListenPort_ - def blockSize = blockSize_ - def maxRetryCount = maxRetryCount_ - def serverSocketTimout = serverSocketTimout_ - def dualMode = dualMode_ + // masterListenPort aka guidePort value legend + // 0 = missed the broadcast, read from HDFS; + // <0 = hasn't started yet, wait & retry; (never happens) + // >0 = Read from this port + var retriesLeft = BroadcastCS.maxRetryCount + do { + try { + // Connect to the tracker to find out the guide + val clientSocketToTracker = + new Socket(BroadcastCS.masterHostAddress, BroadcastCS.masterTrackerPort) + val oisTracker = + new ObjectInputStream (clientSocketToTracker.getInputStream) + val oosTracker = + new ObjectOutputStream (clientSocketToTracker.getOutputStream) + + // Send UUID and receive masterListenPort + oosTracker.writeObject (uuid) + masterListenPort = oisTracker.readObject.asInstanceOf[Int] + } catch { + // In case of any failure, set masterListenPort = 0 to read from HDFS + case e: Exception => (masterListenPort = 0) + } finally { + if (oisTracker != null) { oisTracker.close } + if (oosTracker != null) { oosTracker.close } + if (clientSocketToTracker != null) { clientSocketToTracker.close } + } - def isMaster = isMaster_ + retriesLeft -= 1 + } while (retriesLeft > 0 && masterListenPort < 0) + // println (System.currentTimeMillis + ": " + "Got this guidePort from Tracker: " + masterListenPort) + + // If Tracker says that there is no guide for this object, read from HDFS + if (masterListenPort == 0) { return null } - def receiveBroadcast (variableUUID: UUID): Array[Byte] = { // Wait until hostAddress and listenPort are created by the // ServeMultipleRequests thread - // NO need to wait; ServeMultipleRequests is created much further ahead while (listenPort == -1) { listenPortLock.synchronized { listenPortLock.wait @@ -346,13 +264,13 @@ private object BroadcastCS { // Connect and receive broadcast from the specified source, retrying the // specified number of times in case of failures - var retriesLeft = BroadcastCS.maxRetryCount + retriesLeft = BroadcastCS.maxRetryCount var retByteArray: Array[Byte] = null - do { + do { // Connect to Master and send this worker's Information val clientSocketToMaster = - new Socket(BroadcastCS.masterHostAddress, BroadcastCS.masterListenPort) - println (System.currentTimeMillis + ": " + "Connected to Master's guiding object") + new Socket(BroadcastCS.masterHostAddress, masterListenPort) + // 
println (System.currentTimeMillis + ": " + "Connected to Master's guiding object") // TODO: Guiding object connection is reusable val oisMaster = new ObjectInputStream (clientSocketToMaster.getInputStream) @@ -371,11 +289,11 @@ private object BroadcastCS { } totalBytes = sourceInfo.totalBytes - println (System.currentTimeMillis + ": " + "Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort) + // println (System.currentTimeMillis + ": " + "Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort) retByteArray = receiveSingleTransmission (sourceInfo) - println (System.currentTimeMillis + ": " + "I got this from receiveSingleTransmission: " + retByteArray) + // println (System.currentTimeMillis + ": " + "I got this from receiveSingleTransmission: " + retByteArray) // TODO: Update sourceInfo to add error notifactions for Master if (retByteArray == null) { sourceInfo.receptionFailed = true } @@ -414,8 +332,8 @@ private object BroadcastCS { oisSource = new ObjectInputStream (clientSocketToSource.getInputStream) - println (System.currentTimeMillis + ": " + "Inside receiveSingleTransmission") - println (System.currentTimeMillis + ": " + "totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks) + // println (System.currentTimeMillis + ": " + "Inside receiveSingleTransmission") + // println (System.currentTimeMillis + ": " + "totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks) retByteArray = new Array[Byte] (totalBytes) for (i <- 0 until totalBlocks) { val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] @@ -426,14 +344,14 @@ private object BroadcastCS { hasBlocksLock.synchronized { hasBlocksLock.notifyAll } - println (System.currentTimeMillis + ": " + "Received block: " + i + " " + bcBlock) + // println (System.currentTimeMillis + ": " + "Received block: " + i + " " + bcBlock) } assert (hasBlocks == totalBlocks) - println (System.currentTimeMillis + ": " + "After the receive loop") + // println (System.currentTimeMillis + ": " + "After the receive loop") } catch { case e: Exception => { retByteArray = null - println (System.currentTimeMillis + ": " + "receiveSingleTransmission had a " + e) + // println (System.currentTimeMillis + ": " + "receiveSingleTransmission had a " + e) } } finally { if (oisSource != null) { oisSource.close } @@ -445,91 +363,35 @@ private object BroadcastCS { return retByteArray } - - class TrackMultipleValues extends Thread { - override def run = { - var threadPool = Executors.newCachedThreadPool - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket (BroadcastCS.masterListenPort) - println (System.currentTimeMillis + ": " + "TrackMultipleVariables" + serverSocket + " " + listenPort) - - var keepAccepting = true - try { - while (true) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout (serverSocketTimout) - clientSocket = serverSocket.accept - } catch { - case e: Exception => { - println ("TrackMultipleValues Timeout. Stopping listening...") - keepAccepting = false - } - } - println (System.currentTimeMillis + ": " + "TrackMultipleValues:Got new request:" + clientSocket) - if (clientSocket != null) { - try { - threadPool.execute (new Runnable { - def run = { - val oos = new ObjectOutputStream (clientSocket.getOutputStream) - val ois = new ObjectInputStream (clientSocket.getInputStream) - try { - val variableUUID = ois.readObject.asInstanceOf[UUID] - var contactPort = 0 - // TODO: Add logic and data structures to find out UUID->port - // mapping. 
0 = missed the broadcast, read from HDFS; <0 = - // Haven't started yet, wait & retry; >0 = Read from this port - oos.writeObject (contactPort) - } catch { - case e: Exception => { } - } finally { - ois.close - oos.close - clientSocket.close - } - } - }) - } catch { - // In failure, close the socket here; else, the thread will close it - case ioe: IOException => clientSocket.close - } - } - } - } finally { - serverSocket.close - } - } - } - - class TrackSingleValue { - - } - + class GuideMultipleRequests extends Thread { override def run = { var threadPool = Executors.newCachedThreadPool var serverSocket: ServerSocket = null - serverSocket = new ServerSocket (BroadcastCS.masterListenPort) - // listenPort = BroadcastCS.masterListenPort - println (System.currentTimeMillis + ": " + "GuideMultipleRequests" + serverSocket + " " + listenPort) + serverSocket = new ServerSocket (0) + guidePort = serverSocket.getLocalPort + // println (System.currentTimeMillis + ": " + "GuideMultipleRequests" + serverSocket + " " + guidePort) + guidePortLock.synchronized { + guidePortLock.notifyAll + } + var keepAccepting = true try { while (keepAccepting) { var clientSocket: Socket = null try { - serverSocket.setSoTimeout (serverSocketTimout) + serverSocket.setSoTimeout (BroadcastCS.serverSocketTimout) clientSocket = serverSocket.accept } catch { case e: Exception => { - println ("GuideMultipleRequests Timeout. Stopping listening...") + // println ("GuideMultipleRequests Timeout. Stopping listening...") keepAccepting = false } } if (clientSocket != null) { - println (System.currentTimeMillis + ": " + "Guide:Accepted new client connection:" + clientSocket) + // println (System.currentTimeMillis + ": " + "Guide:Accepted new client connection:" + clientSocket) try { threadPool.execute (new GuideSingleRequest (clientSocket)) } catch { @@ -552,21 +414,21 @@ private object BroadcastCS { def run = { try { - println (System.currentTimeMillis + ": " + "new GuideSingleRequest is running") + // println (System.currentTimeMillis + ": " + "new GuideSingleRequest is running") // Connecting worker is sending in its hostAddress and listenPort it will // be listening to. 
ReplicaID is 0 and other fields are invalid (-1) var sourceInfo = ois.readObject.asInstanceOf[SourceInfo] // Select a suitable source and send it back to the worker selectedSourceInfo = selectSuitableSource (sourceInfo) - println (System.currentTimeMillis + ": " + "Sending selectedSourceInfo:" + selectedSourceInfo) + // println (System.currentTimeMillis + ": " + "Sending selectedSourceInfo:" + selectedSourceInfo) oos.writeObject (selectedSourceInfo) oos.flush // Add this new (if it can finish) source to the PQ of sources thisWorkerInfo = new SourceInfo(sourceInfo.hostAddress, sourceInfo.listenPort, totalBlocks, totalBytes, 0) - println (System.currentTimeMillis + ": " + "Adding possible new source to pqOfSources: " + thisWorkerInfo) + // println (System.currentTimeMillis + ": " + "Adding possible new source to pqOfSources: " + thisWorkerInfo) pqOfSources.synchronized { pqOfSources.add (thisWorkerInfo) } @@ -649,7 +511,7 @@ private object BroadcastCS { serverSocket = new ServerSocket (0) listenPort = serverSocket.getLocalPort - println (System.currentTimeMillis + ": " + "ServeMultipleRequests" + serverSocket + " " + listenPort) + // println (System.currentTimeMillis + ": " + "ServeMultipleRequests" + serverSocket + " " + listenPort) listenPortLock.synchronized { listenPortLock.notifyAll @@ -660,16 +522,16 @@ private object BroadcastCS { while (keepAccepting) { var clientSocket: Socket = null try { - serverSocket.setSoTimeout (serverSocketTimout) + serverSocket.setSoTimeout (BroadcastCS.serverSocketTimout) clientSocket = serverSocket.accept } catch { case e: Exception => { - println ("ServeMultipleRequests Timeout. Stopping listening...") + // println ("ServeMultipleRequests Timeout. Stopping listening...") keepAccepting = false } } if (clientSocket != null) { - println (System.currentTimeMillis + ": " + "Serve:Accepted new client connection:" + clientSocket) + // println (System.currentTimeMillis + ": " + "Serve:Accepted new client connection:" + clientSocket) try { threadPool.execute (new ServeSingleRequest (clientSocket)) } catch { @@ -689,17 +551,17 @@ private object BroadcastCS { def run = { try { - println (System.currentTimeMillis + ": " + "new ServeSingleRequest is running") + // println (System.currentTimeMillis + ": " + "new ServeSingleRequest is running") sendObject } catch { // TODO: Need to add better exception handling here // If something went wrong, e.g., the worker at the other end died etc. 
// then close everything up case e: Exception => { - println (System.currentTimeMillis + ": " + "ServeSingleRequest had a " + e) + // println (System.currentTimeMillis + ": " + "ServeSingleRequest had a " + e) } } finally { - println (System.currentTimeMillis + ": " + "ServeSingleRequest is closing streams and sockets") + // println (System.currentTimeMillis + ": " + "ServeSingleRequest is closing streams and sockets") ois.close oos.close clientSocket.close @@ -726,11 +588,229 @@ private object BroadcastCS { } catch { case e: Exception => { } } - println (System.currentTimeMillis + ": " + "Send block: " + i + " " + arrayOfBlocks(i)) + // println (System.currentTimeMillis + ": " + "Send block: " + i + " " + arrayOfBlocks(i)) } } } - + } +} + +@serializable +class CentralizedHDFSBroadcast[T](@transient var value_ : T, local: Boolean) + extends BroadcastRecipe { + + def value = value_ + + BroadcastCH.synchronized { BroadcastCH.values.put(uuid, value_) } + + if (!local) { sendBroadcast } + + def sendBroadcast () { + val out = new ObjectOutputStream (BroadcastCH.openFileForWriting(uuid)) + out.writeObject (value_) + out.close + } + + // Called by Java when deserializing an object + private def readObject(in: ObjectInputStream) { + in.defaultReadObject + BroadcastCH.synchronized { + val cachedVal = BroadcastCH.values.get(uuid) + if (cachedVal != null) { + value_ = cachedVal.asInstanceOf[T] + } else { + // println( System.currentTimeMillis + ": " + "Started reading Broadcasted variable " + uuid) + val start = System.nanoTime + + val fileIn = new ObjectInputStream(BroadcastCH.openFileForReading(uuid)) + value_ = fileIn.readObject.asInstanceOf[T] + BroadcastCH.values.put(uuid, value_) + fileIn.close + + val time = (System.nanoTime - start) / 1e9 + println( System.currentTimeMillis + ": " + "Reading Broadcasted variable " + uuid + " took " + time + " s") + } + } + } +} + +@serializable +case class SourceInfo (val hostAddress: String, val listenPort: Int, + val totalBlocks: Int, val totalBytes: Int, val replicaID: Int) + extends Comparable [SourceInfo]{ + + var currentLeechers = 0 + var receptionFailed = false + + def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers) +} + +@serializable +case class BroadcastBlock (val blockID: Int, val byteArray: Array[Byte]) { } + +@serializable +case class VariableInfo (@transient val arrayOfBlocks : Array[BroadcastBlock], + val totalBlocks: Int, val totalBytes: Int) { + + @transient var hasBlocks = 0 + + val listenPortLock = new AnyRef + val totalBlocksLock = new AnyRef + val hasBlocksLock = new AnyRef + + @transient var pqOfSources = new PriorityQueue[SourceInfo] +} + +private object Broadcast { + private var initialized = false + + // Will be called by SparkContext or Executor before using Broadcast + // Calls all other initializers here + def initialize (isMaster: Boolean) { + synchronized { + if (!initialized) { + // Initialization for CentralizedHDFSBroadcast + BroadcastCH.initialize + // Initialization for ChainedStreamingBroadcast + BroadcastCS.initialize (isMaster) + + initialized = true + } + } + } +} + +private object BroadcastCS { + val values = new MapMaker ().softValues ().makeMap[UUID, Any] + // val valueInfos = new MapMaker ().softValues ().makeMap[UUID, Any] + + var valueToGuidePortMap = Map[UUID, Int] () + + private var initialized = false + private var isMaster_ = false + + private var masterHostAddress_ = "127.0.0.1" + private var masterTrackerPort_ : Int = 11111 + private var blockSize_ : Int = 512 * 1024 + private var 
maxRetryCount_ : Int = 2 + private var serverSocketTimout_ : Int = 50000 + private var dualMode_ : Boolean = false + + private var trackMV: TrackMultipleValues = null + + def initialize (isMaster__ : Boolean) { + synchronized { + if (!initialized) { + masterHostAddress_ = + System.getProperty ("spark.broadcast.masterHostAddress", "127.0.0.1") + masterTrackerPort_ = + System.getProperty ("spark.broadcast.masterTrackerPort", "11111").toInt + blockSize_ = + System.getProperty ("spark.broadcast.blockSize", "512").toInt * 1024 + maxRetryCount_ = + System.getProperty ("spark.broadcast.maxRetryCount", "2").toInt + serverSocketTimout_ = + System.getProperty ("spark.broadcast.serverSocketTimout", "50000").toInt + dualMode_ = + System.getProperty ("spark.broadcast.dualMode", "false").toBoolean + + isMaster_ = isMaster__ + + if (isMaster) { + trackMV = new TrackMultipleValues + // trackMV.setDaemon (true) + trackMV.start + // println (System.currentTimeMillis + ": " + "TrackMultipleValues started") + } + + initialized = true + } + } + } + + def masterHostAddress = masterHostAddress_ + def masterTrackerPort = masterTrackerPort_ + def blockSize = blockSize_ + def maxRetryCount = maxRetryCount_ + def serverSocketTimout = serverSocketTimout_ + def dualMode = dualMode_ + + def isMaster = isMaster_ + + def registerValue (uuid: UUID, guidePort: Int) = { + valueToGuidePortMap.synchronized { + valueToGuidePortMap += (uuid -> guidePort) + // println (System.currentTimeMillis + ": " + "New value registered with the Tracker " + valueToGuidePortMap) + } + } + + // TODO: Who call this and when? + def unregisterValue (uuid: UUID) { + valueToGuidePortMap.synchronized { + valueToGuidePortMap (uuid) = 0 + // println (System.currentTimeMillis + ": " + "Value unregistered from the Tracker " + valueToGuidePortMap) + } + } + + class TrackMultipleValues extends Thread { + override def run = { + var threadPool = Executors.newCachedThreadPool + var serverSocket: ServerSocket = null + + serverSocket = new ServerSocket (BroadcastCS.masterTrackerPort) + // println (System.currentTimeMillis + ": " + "TrackMultipleValues" + serverSocket) + + var keepAccepting = true + try { + while (true) { + var clientSocket: Socket = null + try { + // TODO: + // serverSocket.setSoTimeout (serverSocketTimout) + clientSocket = serverSocket.accept + } catch { + case e: Exception => { + // println ("TrackMultipleValues Timeout. 
Stopping listening...") + keepAccepting = false + } + } + + if (clientSocket != null) { + try { + threadPool.execute (new Runnable { + def run = { + val oos = new ObjectOutputStream (clientSocket.getOutputStream) + val ois = new ObjectInputStream (clientSocket.getInputStream) + try { + val uuid = ois.readObject.asInstanceOf[UUID] + // masterListenPort/guidePort value legend + // 0 = missed the broadcast, read from HDFS; + // <0 = hasn't started yet, wait & retry; (never happens) + // >0 = Read from this port + var guidePort = if (valueToGuidePortMap.contains (uuid)) { + valueToGuidePortMap (uuid) + } else -1 + // println (System.currentTimeMillis + ": " + "TrackMultipleValues:Got new request: " + clientSocket + " for " + uuid + " : " + guidePort) + oos.writeObject (guidePort) + } catch { + case e: Exception => { } + } finally { + ois.close + oos.close + clientSocket.close + } + } + }) + } catch { + // In failure, close the socket here; else, the thread will close it + case ioe: IOException => clientSocket.close + } + } + } + } finally { + serverSocket.close + } + } } } diff --git a/src/scala/spark/HdfsFile.scala b/src/scala/spark/HdfsFile.scala index 87d8e8cc81dec65e6333935a7d2e5bc25c769e5b..8050683f99255373f8610e015cd49e57f38f01c4 100644 --- a/src/scala/spark/HdfsFile.scala +++ b/src/scala/spark/HdfsFile.scala @@ -27,9 +27,9 @@ import org.apache.hadoop.mapred.Reporter abstract class DistributedFile[T, Split](@transient sc: SparkContext) { def splits: Array[Split] def iterator(split: Split): Iterator[T] - def preferredLocations(split: Split): Seq[String] + def prefers(split: Split, slot: SlaveOffer): Boolean - def taskStarted(split: Split, offer: SlaveOffer) {} + def taskStarted(split: Split, slot: SlaveOffer) {} def sparkContext = sc @@ -87,8 +87,8 @@ abstract class DistributedFile[T, Split](@transient sc: SparkContext) { abstract class FileTask[U, T, Split](val file: DistributedFile[T, Split], val split: Split) extends Task[U] { - override def preferredLocations: Seq[String] = file.preferredLocations(split) - override def markStarted(offer: SlaveOffer) { file.taskStarted(split, offer) } + override def prefers(slot: SlaveOffer) = file.prefers(split, slot) + override def markStarted(slot: SlaveOffer) { file.taskStarted(split, slot) } } class ForeachTask[T, Split](file: DistributedFile[T, Split], @@ -124,31 +124,31 @@ extends FileTask[Option[T], T, Split](file, split) { class MappedFile[U, T, Split](prev: DistributedFile[T, Split], f: T => U) extends DistributedFile[U, Split](prev.sparkContext) { override def splits = prev.splits - override def preferredLocations(sp: Split) = prev.preferredLocations(sp) + override def prefers(split: Split, slot: SlaveOffer) = prev.prefers(split, slot) override def iterator(split: Split) = prev.iterator(split).map(f) - override def taskStarted(split: Split, offer: SlaveOffer) = prev.taskStarted(split, offer) + override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot) } class FilteredFile[T, Split](prev: DistributedFile[T, Split], f: T => Boolean) extends DistributedFile[T, Split](prev.sparkContext) { override def splits = prev.splits - override def preferredLocations(sp: Split) = prev.preferredLocations(sp) + override def prefers(split: Split, slot: SlaveOffer) = prev.prefers(split, slot) override def iterator(split: Split) = prev.iterator(split).filter(f) - override def taskStarted(split: Split, offer: SlaveOffer) = prev.taskStarted(split, offer) + override def taskStarted(split: Split, slot: SlaveOffer) = 
prev.taskStarted(split, slot) } class CachedFile[T, Split](prev: DistributedFile[T, Split]) extends DistributedFile[T, Split](prev.sparkContext) { val id = CachedFile.newId() - @transient val cacheLocs = Map[Split, List[String]]() + @transient val cacheLocs = Map[Split, List[Int]]() override def splits = prev.splits - override def preferredLocations(split: Split): Seq[String] = { + override def prefers(split: Split, slot: SlaveOffer): Boolean = { if (cacheLocs.contains(split)) - cacheLocs(split) + cacheLocs(split).contains(slot.getSlaveId) else - prev.preferredLocations(split) + prev.prefers(split, slot) } override def iterator(split: Split): Iterator[T] = { @@ -183,11 +183,11 @@ extends DistributedFile[T, Split](prev.sparkContext) { } } - override def taskStarted(split: Split, offer: SlaveOffer) { + override def taskStarted(split: Split, slot: SlaveOffer) { val oldList = cacheLocs.getOrElse(split, Nil) - val host = offer.getHost - if (!oldList.contains(host)) - cacheLocs(split) = host :: oldList + val slaveId = slot.getSlaveId + if (!oldList.contains(slaveId)) + cacheLocs(split) = slaveId :: oldList } } @@ -251,10 +251,8 @@ extends DistributedFile[String, HdfsSplit](sc) { } } - override def preferredLocations(split: HdfsSplit) = { - // TODO: Filtering out "localhost" in case of file:// URLs - split.value.getLocations().filter(_ != "localhost").toArray - } + override def prefers(split: HdfsSplit, slot: SlaveOffer) = + split.value.getLocations().contains(slot.getHost) } object ConfigureLock {} diff --git a/src/scala/spark/NexusScheduler.scala b/src/scala/spark/NexusScheduler.scala index a8a5e2947a454d6146a8e8eb90c15f864bca40c3..a96fca9350d9db9c04df09957ccc9cc367f3445f 100644 --- a/src/scala/spark/NexusScheduler.scala +++ b/src/scala/spark/NexusScheduler.scala @@ -1,11 +1,11 @@ package spark import java.io.File +import java.util.concurrent.Semaphore -import scala.collection.mutable.Map - -import nexus.{Scheduler => NScheduler} -import nexus._ +import nexus.{ExecutorInfo, TaskDescription, TaskState, TaskStatus} +import nexus.{SlaveOffer, SchedulerDriver, NexusSchedulerDriver} +import nexus.{SlaveOfferVector, TaskDescriptionVector, StringMap} // The main Scheduler implementation, which talks to Nexus. Clients are expected // to first call start(), then submit tasks through the runTasks method. @@ -21,26 +21,30 @@ import nexus._ // can be made cleaner. 
private class NexusScheduler( master: String, frameworkName: String, execArg: Array[Byte]) -extends NScheduler with spark.Scheduler +extends nexus.Scheduler with spark.Scheduler { - // Lock used by runTasks to ensure only one thread can be in it - val runTasksMutex = new Object() + // Semaphore used by runTasks to ensure only one thread can be in it + val semaphore = new Semaphore(1) // Lock used to wait for scheduler to be registered var isRegistered = false val registeredLock = new Object() + // Trait representing a set of scheduler callbacks + trait Callbacks { + def slotOffer(s: SlaveOffer): Option[TaskDescription] + def taskFinished(t: TaskStatus): Unit + def error(code: Int, message: String): Unit + } + // Current callback object (may be null) - var activeOp: ParallelOperation = null + var callbacks: Callbacks = null // Incrementing task ID - private var nextTaskId = 0 + var nextTaskId = 0 - def newTaskId(): Int = { - val id = nextTaskId; - nextTaskId += 1; - return id - } + // Maximum time to wait to run a task in a preferred location (in ms) + val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "1000").toLong // Driver for talking to Nexus var driver: SchedulerDriver = null @@ -62,27 +66,125 @@ extends NScheduler with spark.Scheduler new ExecutorInfo(new File("spark-executor").getCanonicalPath(), execArg) override def runTasks[T](tasks: Array[Task[T]]): Array[T] = { - runTasksMutex.synchronized { - waitForRegister() - val myOp = new SimpleParallelOperation(this, tasks) + val results = new Array[T](tasks.length) + if (tasks.length == 0) + return results + + val launched = new Array[Boolean](tasks.length) + + val callingThread = currentThread + + var errorHappened = false + var errorCode = 0 + var errorMessage = "" + + // Wait for scheduler to be registered with Nexus + waitForRegister() + + try { + // Acquire the runTasks semaphore + semaphore.acquire() + + val myCallbacks = new Callbacks { + val firstTaskId = nextTaskId + var tasksLaunched = 0 + var tasksFinished = 0 + var lastPreferredLaunchTime = System.currentTimeMillis + + def slotOffer(slot: SlaveOffer): Option[TaskDescription] = { + try { + if (tasksLaunched < tasks.length) { + // TODO: Add a short wait if no task with location pref is found + // TODO: Figure out why a function is needed around this to + // avoid scala.runtime.NonLocalReturnException + def findTask: Option[TaskDescription] = { + var checkPrefVals: Array[Boolean] = Array(true) + val time = System.currentTimeMillis + if (time - lastPreferredLaunchTime > LOCALITY_WAIT) + checkPrefVals = Array(true, false) // Allow non-preferred tasks + // TODO: Make desiredCpus and desiredMem configurable + val desiredCpus = 1 + val desiredMem = 750L * 1024L * 1024L + if (slot.getParams.get("cpus").toInt < desiredCpus || + slot.getParams.get("mem").toLong < desiredMem) + return None + for (checkPref <- checkPrefVals; + i <- 0 until tasks.length; + if !launched(i) && (!checkPref || tasks(i).prefers(slot))) + { + val taskId = nextTaskId + nextTaskId += 1 + printf("Starting task %d as TID %d on slave %d: %s (%s)\n", + i, taskId, slot.getSlaveId, slot.getHost, + if(checkPref) "preferred" else "non-preferred") + tasks(i).markStarted(slot) + launched(i) = true + tasksLaunched += 1 + if (checkPref) + lastPreferredLaunchTime = time + val params = new StringMap + params.set("cpus", "" + desiredCpus) + params.set("mem", "" + desiredMem) + val serializedTask = Utils.serialize(tasks(i)) + return Some(new TaskDescription(taskId, slot.getSlaveId, + "task_" + taskId, params, 
serializedTask)) + } + return None + } + return findTask + } else { + return None + } + } catch { + case e: Exception => { + e.printStackTrace + System.exit(1) + return None + } + } + } - try { - this.synchronized { - this.activeOp = myOp + def taskFinished(status: TaskStatus) { + println("Finished TID " + status.getTaskId) + // Deserialize task result + val result = Utils.deserialize[TaskResult[T]](status.getData) + results(status.getTaskId - firstTaskId) = result.value + // Update accumulators + Accumulators.add(callingThread, result.accumUpdates) + // Stop if we've finished all the tasks + tasksFinished += 1 + if (tasksFinished == tasks.length) { + NexusScheduler.this.callbacks = null + NexusScheduler.this.notifyAll() + } } - driver.reviveOffers(); - myOp.join(); - } finally { - this.synchronized { - this.activeOp = null + + def error(code: Int, message: String) { + // Save the error message + errorHappened = true + errorCode = code + errorMessage = message + // Indicate to caller thread that we're done + NexusScheduler.this.callbacks = null + NexusScheduler.this.notifyAll() } } - if (myOp.errorHappened) - throw new SparkException(myOp.errorMessage, myOp.errorCode) - else - return myOp.results + this.synchronized { + this.callbacks = myCallbacks + } + driver.reviveOffers(); + this.synchronized { + while (this.callbacks != null) this.wait() + } + } finally { + semaphore.release() } + + if (errorHappened) + throw new SparkException(errorMessage, errorCode) + else + return results } override def registered(d: SchedulerDriver, frameworkId: Int) { @@ -95,19 +197,18 @@ extends NScheduler with spark.Scheduler override def waitForRegister() { registeredLock.synchronized { - while (!isRegistered) - registeredLock.wait() + while (!isRegistered) registeredLock.wait() } } override def resourceOffer( - d: SchedulerDriver, oid: Long, offers: SlaveOfferVector) { + d: SchedulerDriver, oid: Long, slots: SlaveOfferVector) { synchronized { val tasks = new TaskDescriptionVector - if (activeOp != null) { + if (callbacks != null) { try { - for (i <- 0 until offers.size.toInt) { - activeOp.slaveOffer(offers.get(i)) match { + for (i <- 0 until slots.size.toInt) { + callbacks.slotOffer(slots.get(i)) match { case Some(task) => tasks.add(task) case None => {} } @@ -124,21 +225,21 @@ extends NScheduler with spark.Scheduler override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { synchronized { - try { - if (activeOp != null) { - activeOp.statusUpdate(status) + if (callbacks != null && status.getState == TaskState.TASK_FINISHED) { + try { + callbacks.taskFinished(status) + } catch { + case e: Exception => e.printStackTrace } - } catch { - case e: Exception => e.printStackTrace } } } override def error(d: SchedulerDriver, code: Int, message: String) { synchronized { - if (activeOp != null) { + if (callbacks != null) { try { - activeOp.error(code, message) + callbacks.error(code, message) } catch { case e: Exception => e.printStackTrace } @@ -155,135 +256,3 @@ extends NScheduler with spark.Scheduler driver.stop() } } - - -// Trait representing a set of scheduler callbacks -trait ParallelOperation { - def slaveOffer(s: SlaveOffer): Option[TaskDescription] - def statusUpdate(t: TaskStatus): Unit - def error(code: Int, message: String): Unit -} - - -class SimpleParallelOperation[T](sched: NexusScheduler, tasks: Array[Task[T]]) -extends ParallelOperation -{ - // Maximum time to wait to run a task in a preferred location (in ms) - val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "1000").toLong - - 
val callingThread = currentThread - val numTasks = tasks.length - val results = new Array[T](numTasks) - val launched = new Array[Boolean](numTasks) - val finished = new Array[Boolean](numTasks) - val tidToIndex = Map[Int, Int]() - - var allFinished = false - val joinLock = new Object() - - var errorHappened = false - var errorCode = 0 - var errorMessage = "" - - var tasksLaunched = 0 - var tasksFinished = 0 - var lastPreferredLaunchTime = System.currentTimeMillis - - def setAllFinished() { - joinLock.synchronized { - allFinished = true - joinLock.notifyAll() - } - } - - def join() { - joinLock.synchronized { - while (!allFinished) - joinLock.wait() - } - } - - def slaveOffer(offer: SlaveOffer): Option[TaskDescription] = { - if (tasksLaunched < numTasks) { - var checkPrefVals: Array[Boolean] = Array(true) - val time = System.currentTimeMillis - if (time - lastPreferredLaunchTime > LOCALITY_WAIT) - checkPrefVals = Array(true, false) // Allow non-preferred tasks - // TODO: Make desiredCpus and desiredMem configurable - val desiredCpus = 1 - val desiredMem = 750L * 1024L * 1024L - if (offer.getParams.get("cpus").toInt < desiredCpus || - offer.getParams.get("mem").toLong < desiredMem) - return None - for (checkPref <- checkPrefVals; i <- 0 until numTasks) { - if (!launched(i) && (!checkPref || - tasks(i).preferredLocations.contains(offer.getHost) || - tasks(i).preferredLocations.isEmpty)) - { - val taskId = sched.newTaskId() - tidToIndex(taskId) = i - printf("Starting task %d as TID %d on slave %d: %s (%s)\n", - i, taskId, offer.getSlaveId, offer.getHost, - if(checkPref) "preferred" else "non-preferred") - tasks(i).markStarted(offer) - launched(i) = true - tasksLaunched += 1 - if (checkPref) - lastPreferredLaunchTime = time - val params = new StringMap - params.set("cpus", "" + desiredCpus) - params.set("mem", "" + desiredMem) - val serializedTask = Utils.serialize(tasks(i)) - return Some(new TaskDescription(taskId, offer.getSlaveId, - "task_" + taskId, params, serializedTask)) - } - } - } - return None - } - - def statusUpdate(status: TaskStatus) { - status.getState match { - case TaskState.TASK_FINISHED => - taskFinished(status) - case TaskState.TASK_LOST => - taskLost(status) - case TaskState.TASK_FAILED => - taskLost(status) - case TaskState.TASK_KILLED => - taskLost(status) - case _ => - } - } - - def taskFinished(status: TaskStatus) { - val tid = status.getTaskId - println("Finished TID " + tid) - // Deserialize task result - val result = Utils.deserialize[TaskResult[T]](status.getData) - results(tidToIndex(tid)) = result.value - // Update accumulators - Accumulators.add(callingThread, result.accumUpdates) - // Mark finished and stop if we've finished all the tasks - finished(tidToIndex(tid)) = true - tasksFinished += 1 - if (tasksFinished == numTasks) - setAllFinished() - } - - def taskLost(status: TaskStatus) { - val tid = status.getTaskId - println("Lost TID " + tid) - launched(tidToIndex(tid)) = false - tasksLaunched -= 1 - } - - def error(code: Int, message: String) { - // Save the error message - errorHappened = true - errorCode = code - errorMessage = message - // Indicate to caller thread that we're done - setAllFinished() - } -} diff --git a/src/scala/spark/Task.scala b/src/scala/spark/Task.scala index efb864472dea6323da97f03d54d52945e152fda8..e559996a379e73b7a37c599207b5eb76964fd675 100644 --- a/src/scala/spark/Task.scala +++ b/src/scala/spark/Task.scala @@ -5,8 +5,8 @@ import nexus._ @serializable trait Task[T] { def run: T - def preferredLocations: Seq[String] = Nil - def 
markStarted(offer: SlaveOffer) {} + def prefers(slot: SlaveOffer): Boolean = true + def markStarted(slot: SlaveOffer) {} } @serializable diff --git a/src/test/spark/repl/ReplSuite.scala b/src/test/spark/repl/ReplSuite.scala index 43ef296efeb839996c4174cd747e162180d8b106..d71fe20a94357ab02ffb8f5a65a3052d44362d66 100644 --- a/src/test/spark/repl/ReplSuite.scala +++ b/src/test/spark/repl/ReplSuite.scala @@ -85,15 +85,15 @@ class ReplSuite extends FunSuite { assertContains("res2: Int = 100", output) } - test ("broadcast vars") { - // Test that the value that a broadcast var had when it was created is used, - // even if that broadcast var is then modified in the driver program + test ("cached vars") { + // Test that the value that a cached var had when it was created is used, + // even if that cached var is then modified in the driver program val output = runInterpreter("local", """ var array = new Array[Int](5) - val broadcastedArray = sc.broadcast(array) - sc.parallelize(0 to 4).map(x => broadcastedArray.value(x)).toArray + val cachedArray = sc.cache(array) + sc.parallelize(0 to 4).map(x => cachedArray.value(x)).toArray array(0) = 5 - sc.parallelize(0 to 4).map(x => broadcastedArray.value(x)).toArray + sc.parallelize(0 to 4).map(x => cachedArray.value(x)).toArray """) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -109,10 +109,10 @@ class ReplSuite extends FunSuite { v = 10 sc.parallelize(1 to 10).map(x => getV()).toArray.reduceLeft(_+_) var array = new Array[Int](5) - val broadcastedArray = sc.broadcast(array) - sc.parallelize(0 to 4).map(x => broadcastedArray.value(x)).toArray + val cachedArray = sc.cache(array) + sc.parallelize(0 to 4).map(x => cachedArray.value(x)).toArray array(0) = 5 - sc.parallelize(0 to 4).map(x => broadcastedArray.value(x)).toArray + sc.parallelize(0 to 4).map(x => cachedArray.value(x)).toArray """) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output)
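
Notes on this patch. The snippets below are illustrative Scala sketches, not part of the diff; any name that does not appear above is introduced only for illustration.

Tracker protocol. Broadcast.scala replaces the single fixed masterListenPort with a two-level scheme: the master runs one TrackMultipleValues thread on spark.broadcast.masterTrackerPort (default 11111) that maps each broadcast UUID to that variable's guide port, and every ChainedStreamingBroadcast starts its own GuideMultipleRequests on an ephemeral port. The wire protocol is one object each way: the worker writes the UUID, the tracker answers with an Int (0 = no guide registered, fall back to HDFS; >0 = connect to this guide port). A minimal client-side sketch of that handshake, with trackerHost and trackerPort standing in for BroadcastCS.masterHostAddress and BroadcastCS.masterTrackerPort:

    import java.io.{ObjectInputStream, ObjectOutputStream}
    import java.net.Socket
    import java.util.UUID

    // Ask the tracker which port is guiding the given broadcast variable.
    // Stream creation order (ois before oos) mirrors receiveBroadcast above.
    def lookupGuidePort(trackerHost: String, trackerPort: Int, uuid: UUID): Int = {
      var socket: Socket = null
      var ois: ObjectInputStream = null
      var oos: ObjectOutputStream = null
      try {
        socket = new Socket(trackerHost, trackerPort)
        ois = new ObjectInputStream(socket.getInputStream)
        oos = new ObjectOutputStream(socket.getOutputStream)
        oos.writeObject(uuid)             // one request object...
        oos.flush                         // ...pushed out before we block on the reply
        ois.readObject.asInstanceOf[Int]  // ...one reply: the guide port
      } catch {
        case e: Exception => 0            // any failure means "read from HDFS"
      } finally {
        if (ois != null) { ois.close }
        if (oos != null) { oos.close }
        if (socket != null) { socket.close }
      }
    }

Two details of the patched code are worth flagging here. In receiveBroadcast, the try block re-declares clientSocketToTracker, oisTracker and oosTracker as fresh vals, shadowing the outer vars, so the finally block always sees null and the tracker connection is never closed; the sketch assigns the outer variables instead. And TrackMultipleValues loops on while (true) with its keepAccepting flag never consulted and setSoTimeout commented out, so the tracker thread currently has no shutdown path.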
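Block format. sendBroadcast serializes the value once and cuts the byte array into spark.broadcast.blockSize chunks; note the property is read in KB and multiplied by 1024, so the default "512" means 512 KiB. On the receiving side the blocks are written back, in block order, into a byte array of totalBytes. The arithmetic of blockifyObject, restated as a self-contained sketch that returns plain byte-array chunks instead of BroadcastBlock/VariableInfo:

    import java.io.{ByteArrayOutputStream, ObjectOutputStream}

    // Serialize, then split: every block is blockSize bytes except possibly
    // the last. The numBlocks expression is the same ceiling division that
    // blockifyObject spells out with a remainder check.
    def blockify(obj: AnyRef, blockSize: Int): Array[Array[Byte]] = {
      val baos = new ByteArrayOutputStream
      val oos = new ObjectOutputStream(baos)
      oos.writeObject(obj)
      oos.close
      val bytes = baos.toByteArray
      val numBlocks = (bytes.length + blockSize - 1) / blockSize
      Array.tabulate(numBlocks) { i =>
        // slice clamps at the array end, shortening the final block
        bytes.slice(i * blockSize, (i + 1) * blockSize)
      }
    }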
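Source selection. The guide keeps candidate sources in a PriorityQueue[SourceInfo], and SourceInfo.compareTo orders by currentLeechers, so (assuming java.util.PriorityQueue, a min-heap over compareTo) the head is always the least-loaded source. With spark.broadcast.dualMode=true the master seeds the queue with itself twice (replicaID 0 and 1), so two workers can pull from the seed concurrently. A toy demonstration of the ordering, assuming the SourceInfo case class from this patch is on the classpath:

    import java.util.PriorityQueue

    val s1 = SourceInfo("10.0.0.1", 5001, 10, 10240, 0)
    val s2 = SourceInfo("10.0.0.2", 5002, 10, 10240, 0)
    s1.currentLeechers = 2              // two workers already pulling from s1
    val pq = new PriorityQueue[SourceInfo]
    pq.add(s1)
    pq.add(s2)
    assert(pq.peek eq s2)               // least-loaded source comes out first

One caveat: java.util.PriorityQueue orders only at insertion time, so after bumping currentLeechers on an element that is already queued, the element has to be removed and re-added (or the queue rebuilt) before the ordering holds again.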
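Port hand-off. ServeMultipleRequests and GuideMultipleRequests both bind a ServerSocket to port 0, let the kernel pick a free port, and publish it (listenPort / guidePort) through a lock object that sendBroadcast and receiveBroadcast block on. The sketch below keeps the flag check, the write, and the wait under one monitor; the patch tests listenPort == -1 and assigns the port outside the synchronized blocks, which leaves a small window in which the notifyAll can fire before the waiter calls wait, leaving it asleep forever.

    import java.net.ServerSocket

    var port = -1
    val portLock = new Object

    new Thread {
      override def run = {
        val serverSocket = new ServerSocket(0)   // kernel assigns a free port
        portLock.synchronized {
          port = serverSocket.getLocalPort
          portLock.notifyAll
        }
        // ... accept loop, as in ServeMultipleRequests ...
      }
    }.start

    portLock.synchronized {
      while (port == -1) portLock.wait           // no missed-notify window
    }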
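HDFS fallback. Every chained broadcast still writes one persistent copy through BroadcastCH at the end of sendBroadcast, so a worker that gets 0 back from the tracker, or whose chained transfer fails maxRetryCount times, simply deserializes that copy (the CentralizedHDFSBroadcast path). The round trip in miniature, with local files standing in for BroadcastCH.openFileForWriting and openFileForReading:

    import java.io._

    def writeCopy(file: File, value: AnyRef) {
      val out = new ObjectOutputStream(new FileOutputStream(file))
      out.writeObject(value)
      out.close
    }

    def readCopy[T](file: File): T = {
      val in = new ObjectInputStream(new FileInputStream(file))
      val value = in.readObject.asInstanceOf[T]
      in.close
      value
    }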
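Scheduler serialization. NexusScheduler.runTasks used to serialize concurrent callers with the runTasksMutex object; it now takes a Semaphore(1) before installing its callback object and releases it in a finally block, so the permit is returned even when an exception unwinds past the this.wait() loop. The discipline, extracted into a helper (hypothetical, the patch inlines it):

    import java.util.concurrent.Semaphore

    val semaphore = new Semaphore(1)

    // Run one batch of tasks at a time; the permit is always returned.
    def runExclusive[T](body: => T): T = {
      semaphore.acquire()
      try { body } finally { semaphore.release() }
    }

Note that the patch calls acquire() inside the try, so an interrupt thrown during acquire would still reach the release and hand back a permit that was never taken; acquiring before entering the try, as above, avoids that.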
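Delay scheduling. The callback object inlined into runTasks keeps the old SimpleParallelOperation policy: an offer is first matched only against unlaunched tasks that prefer that slave, and once LOCALITY_WAIT ms (spark.locality.wait, default 1000) have passed without a preferred launch, a second pass accepts any unlaunched task; offers under 1 CPU or 750 MB are rejected outright. The two-pass selection, distilled into a standalone function (names hypothetical; prefers(i) stands in for tasks(i).prefers(slot)):

    // Preferred placements first; any placement once locality has gone stale.
    def pickTask(launched: Array[Boolean], prefers: Int => Boolean,
                 lastPreferredLaunchTime: Long, localityWait: Long): Option[Int] = {
      val stale = System.currentTimeMillis - lastPreferredLaunchTime > localityWait
      val preferred = (0 until launched.length).find(i => !launched(i) && prefers(i))
      if (preferred.isDefined || !stale) preferred
      else (0 until launched.length).find(i => !launched(i))
    }

The TODO about scala.runtime.NonLocalReturnException likely has this answer: a return inside the for-comprehension's closure is compiled as a NonLocalReturnException throw, and without the inner findTask wrapper that throw would be intercepted by the surrounding catch { case e: Exception => ... } instead of reaching slotOffer's boundary; the wrapper makes the return target the inner function, whose boundary sits safely inside the try.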
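Placement API. HdfsFile.scala and Task.scala flip placement from pull to push: instead of exposing preferredLocations: Seq[String] for the scheduler to match against an offer's host, each task now answers prefers(slot): Boolean per offer, and CachedFile records slave IDs instead of hostnames. The old "no preference" case (Nil, special-cased via preferredLocations.isEmpty in the deleted scheduler code) becomes the new default of prefers = true. Reduced to an illustration (trait and class names hypothetical):

    // Push-style placement: the scheduler asks a yes/no question per offer.
    trait Placement {
      def prefers(host: String): Boolean = true   // default: run anywhere
    }

    // An HDFS-backed split prefers slaves that hold one of its block replicas.
    class HdfsPlacement(blockLocations: Seq[String]) extends Placement {
      override def prefers(host: String) = blockLocations.contains(host)
    }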