diff --git a/src/scala/spark/LocalFileShuffle.scala b/src/scala/spark/LocalFileShuffle.scala index 03b6931f0a0330dcf8e35cecbba6776dafb2839c..eb9bd40698351c843ab31bf50e863bf759831b4a 100644 --- a/src/scala/spark/LocalFileShuffle.scala +++ b/src/scala/spark/LocalFileShuffle.scala @@ -1,13 +1,13 @@ package spark import java.io._ -import java.net.URL -import java.util.UUID +import java.net._ +import java.util.{Timer, TimerTask, UUID} import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.{Executors, ThreadPoolExecutor, ThreadFactory} import scala.collection.mutable.{ArrayBuffer, HashMap} - /** * A simple implementation of shuffle using local files served through HTTP. * @@ -46,6 +46,7 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { case None => createCombiner(v) } } + for (i <- 0 until numOutputSplits) { val file = LocalFileShuffle.getOutputFile(shuffleId, myIndex, i) val writeStartTime = System.currentTimeMillis @@ -57,30 +58,12 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { val writeTime = (System.currentTimeMillis - writeStartTime) logInfo ("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.") } - (myIndex, LocalFileShuffle.serverUri) + (myIndex, LocalFileShuffle.serverAddress, LocalFileShuffle.serverPort) }).collect() - // Load config option to decide whether or not to use HTTP pipelining - val UseHttpPipelining = - System.getProperty("spark.shuffle.UseHttpPipelining", "true").toBoolean - - // Build a traversable list of pairs of server URI and split. Needs to be - // of type TraversableOnce[(String, ArrayBuffer[Int])] - val splitsByUri = if (UseHttpPipelining) { - // Build a hashmap from server URI to list of splits (to facillitate - // fetching all the URIs on a server within a single connection) - val splitsByUriHM = new HashMap[String, ArrayBuffer[Int]] - for ((inputId, serverUri) <- outputLocs) { - splitsByUriHM.getOrElseUpdate(serverUri, ArrayBuffer()) += inputId - } - splitsByUriHM - } else { - // Don't use HTTP pipelining - val splitsByUriAB = new ArrayBuffer[(String, ArrayBuffer[Int])] - for ((inputId, serverUri) <- outputLocs) { - splitsByUriAB += ((serverUri, new ArrayBuffer[Int] += inputId)) - } - splitsByUriAB + val splitsByUri = new ArrayBuffer[(String, Int, Int)] + for ((inputId, serverAddress, serverPort) <- outputLocs) { + splitsByUri += ((serverAddress, serverPort, inputId)) } // TODO: Could broadcast splitsByUri @@ -89,44 +72,49 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging { val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits) return indexes.flatMap((myId: Int) => { val combiners = new HashMap[K, C] - for ((serverUri, inputIds) <- Utils.shuffle(splitsByUri)) { - for (i <- inputIds) { - val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, myId) - val readStartTime = System.currentTimeMillis - logInfo ("BEGIN READ: " + url) - // TODO: Insert data transfer code before this place - val inputStream = new ObjectInputStream(new URL(url).openStream()) - try { - while (true) { - val (k, c) = inputStream.readObject().asInstanceOf[(K, C)] - combiners(k) = combiners.get(k) match { - case Some(oldC) => mergeCombiners(oldC, c) - case None => c - } + for ((serverAddress, serverPort, inputId) <- splitsByUri) { + val requestPath = "%d/%d/%d".format(shuffleId, inputId, myId) + + val shuffleClient = new ShuffleClient(serverAddress, serverPort, requestPath) + val readStartTime = System.currentTimeMillis + logInfo ("BEGIN READ: " + requestPath) + shuffleClient.start + shuffleClient.join + + val inputStream = new ObjectInputStream ( + new ByteArrayInputStream(shuffleClient.byteArray)) + try { + while (true) { + val (k, c) = inputStream.readObject().asInstanceOf[(K, C)] + combiners(k) = combiners.get(k) match { + case Some(oldC) => mergeCombiners(oldC, c) + case None => c } - } catch { - case e: EOFException => {} } - inputStream.close() - logInfo ("END READ: " + url) - val readTime = (System.currentTimeMillis - readStartTime) - logInfo ("Reading " + url + " took " + readTime + " millis.") + } catch { + case e: EOFException => {} } + inputStream.close + + logInfo ("END READ: " + requestPath) + val readTime = (System.currentTimeMillis - readStartTime) + logInfo ("Reading " + requestPath + " took " + readTime + " millis.") } combiners }) } } - object LocalFileShuffle extends Logging { private var initialized = false private var nextShuffleId = new AtomicLong(0) // Variables initialized by initializeIfNeeded() private var shuffleDir: File = null - private var server: HttpServer = null - private var serverUri: String = null + + private var shuffleServer: ShuffleServer = null + private var serverAddress = InetAddress.getLocalHost.getHostAddress + private var serverPort: Int = -1 private def initializeIfNeeded() = synchronized { if (!initialized) { @@ -137,11 +125,12 @@ object LocalFileShuffle extends Logging { var foundLocalDir = false var localDir: File = null var localDirUuid: UUID = null + while (!foundLocalDir && tries < 10) { tries += 1 try { localDirUuid = UUID.randomUUID() - localDir = new File(localDirRoot, "spark-local-" + localDirUuid) + localDir = new File(localDirRoot, "spark-local-" + localDirUuid) if (!localDir.exists()) { localDir.mkdirs() foundLocalDir = true @@ -158,25 +147,14 @@ object LocalFileShuffle extends Logging { shuffleDir = new File(localDir, "shuffle") shuffleDir.mkdirs() logInfo("Shuffle dir: " + shuffleDir) - val extServerPort = System.getProperty( - "spark.localFileShuffle.external.server.port", "-1").toInt - if (extServerPort != -1) { - // We're using an external HTTP server; set URI relative to its root - var extServerPath = System.getProperty( - "spark.localFileShuffle.external.server.path", "") - if (extServerPath != "" && !extServerPath.endsWith("/")) { - extServerPath += "/" - } - serverUri = "http://%s:%d/%s/spark-local-%s".format( - Utils.localIpAddress, extServerPort, extServerPath, localDirUuid) - } else { - // Create our own server - server = new HttpServer(localDir) - server.start() - serverUri = server.uri - } + + // Create and start the shuffleServer + shuffleServer = new ShuffleServer + shuffleServer.setDaemon (true) + shuffleServer.start + logInfo ("ShuffleServer started...") + initialized = true - logInfo ("Local URI: " + serverUri) } } @@ -188,12 +166,233 @@ object LocalFileShuffle extends Logging { return file } - def getServerUri(): String = { - initializeIfNeeded() - serverUri - } - def newShuffleId(): Long = { nextShuffleId.getAndIncrement() } + + // Returns a standard ThreadFactory except all threads are daemons + private def newDaemonThreadFactory: ThreadFactory = { + new ThreadFactory { + def newThread(r: Runnable): Thread = { + var t = Executors.defaultThreadFactory.newThread (r) + t.setDaemon (true) + return t + } + } + } + + // Wrapper over newFixedThreadPool + def newDaemonFixedThreadPool (nThreads: Int): ThreadPoolExecutor = { + var threadPool = + Executors.newFixedThreadPool (nThreads).asInstanceOf[ThreadPoolExecutor] + + threadPool.setThreadFactory (newDaemonThreadFactory) + + return threadPool + } + + class ShuffleServer + extends Thread with Logging { + // TODO: Set config param + var threadPool = newDaemonFixedThreadPool(2) + + var serverSocket: ServerSocket = null + + override def run: Unit = { + serverSocket = new ServerSocket (0) + serverPort = serverSocket.getLocalPort + + logInfo ("ShuffleServer started with " + serverSocket) + logInfo ("Local URI: " + serverAddress + ":" + serverPort) + + try { + while (true) { + var clientSocket: Socket = null + try { + clientSocket = serverSocket.accept + } catch { + case e: Exception => { } + } + if (clientSocket != null) { + logInfo ("Serve: Accepted new client connection:" + clientSocket) + try { + threadPool.execute (new ShuffleServerThread (clientSocket)) + } catch { + // In failure, close socket here; else, the thread will close it + case ioe: IOException => { + clientSocket.close + } + } + } + } + } finally { + if (serverSocket != null) { + logInfo ("ShuffleServer now stopping...") + serverSocket.close + } + } + // Shutdown the thread pool + threadPool.shutdown + } + + class ShuffleServerThread (val clientSocket: Socket) + extends Thread with Logging { + private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream] + os.flush + private val oos = new ObjectOutputStream (os) + oos.flush + private val ois = new ObjectInputStream (clientSocket.getInputStream) + + logInfo ("new ShuffleServerThread is running") + + override def run: Unit = { + try { + // Receive requestPath from the receiver + var requestPath = ois.readObject.asInstanceOf[String] + logInfo("requestPath: " + shuffleDir + "/" + requestPath) + + // Open the file + var requestedFile: File = null + var requestedFileLen = -1 + try { + requestedFile = new File(shuffleDir + "/" + requestPath) + requestedFileLen = requestedFile.length.toInt + } catch { + case e: Exception => { } + } + + // Send the lendth of the requestPath to let the receiver know that + // transfer is about to start + // In the case of receiver timeout and connection close, this will + // throw a java.net.SocketException: Broken pipe + oos.writeObject(requestedFileLen) + oos.flush + + logInfo ("requestedFileLen = " + requestedFileLen) + + // Read and send the requested file + if (requestedFileLen != -1) { + // Read + var byteArray = new Array[Byte](requestedFileLen) + val bis = + new BufferedInputStream (new FileInputStream (requestedFile)) + + var bytesRead = bis.read (byteArray, 0, byteArray.length) + var alreadyRead = bytesRead + + while (alreadyRead < requestedFileLen) { + bytesRead = bis.read(byteArray, alreadyRead, + (byteArray.length - alreadyRead)) + if(bytesRead > 0) { + alreadyRead = alreadyRead + bytesRead + } + } + + bis.close + + // Send + os.write (byteArray, 0, byteArray.length) + os.flush + } else { + // Close the connection + } + } catch { + // If something went wrong, e.g., the worker at the other end died etc. + // then close everything up + // Exception can happen if the receiver stops receiving + case e: Exception => { + logInfo ("ShuffleServerThread had a " + e) + } + } finally { + logInfo ("ShuffleServerThread is closing streams and sockets") + ois.close + // TODO: Following can cause "java.net.SocketException: Socket closed" + oos.close + clientSocket.close + } + } + } + } +} + +class ShuffleClient (hostAddress: String, listenPort: Int, requestPath: String) +extends Thread with Logging { + private var peerSocketToSource: Socket = null + private var oosSource: ObjectOutputStream = null + private var oisSource: ObjectInputStream = null + + var byteArray: Array[Byte] = null + + override def run: Unit = { + // Setup the timeout mechanism + var timeOutTask = new TimerTask { + override def run: Unit = { + cleanUpConnections + } + } + + var timeOutTimer = new Timer + // TODO: Set wait timer + timeOutTimer.schedule (timeOutTask, 1000) + + logInfo ("ShuffleClient started... => %s:%d#%s".format(hostAddress, listenPort, requestPath)) + + try { + // Connect to the source + peerSocketToSource = new Socket (hostAddress, listenPort) + oosSource = + new ObjectOutputStream (peerSocketToSource.getOutputStream) + oosSource.flush + var isSource = peerSocketToSource.getInputStream + oisSource = new ObjectInputStream (isSource) + + // Send the request + oosSource.writeObject(requestPath) + + // Receive the length of the requested file + var requestedFileLen = oisSource.readObject.asInstanceOf[Int] + logInfo ("Received requestedFileLen = " + requestedFileLen) + + // Turn the timer OFF, if the sender responds before timeout + timeOutTimer.cancel + + // Receive the file + if (requestedFileLen != -1) { + byteArray = new Array[Byte] (requestedFileLen) + var bytesRead = isSource.read (byteArray, 0, byteArray.length) + var alreadyRead = bytesRead + + while (alreadyRead < requestedFileLen) { + bytesRead = isSource.read(byteArray, alreadyRead, + (byteArray.length - alreadyRead)) + if(bytesRead > 0) { + alreadyRead = alreadyRead + bytesRead + } + } + } else { + throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestPath) + } + } catch { + // EOFException is expected to happen because sender can break + // connection due to timeout + case eofe: java.io.EOFException => { } + case e: Exception => { + logInfo ("ShuffleClient had a " + e) + } + } finally { + cleanUpConnections + } + } + + private def cleanUpConnections: Unit = { + if (oisSource != null) { + oisSource.close + } + if (oosSource != null) { + oosSource.close + } + if (peerSocketToSource != null) { + peerSocketToSource.close + } + } }