Commit 52086cef authored by Mosharaf Chowdhury

Building blocks are in place. Still not pulling in parallel, though.

parent 540a4116
package spark

import java.io._
-import java.net.URL
-import java.util.UUID
+import java.net._
+import java.util.{Timer, TimerTask, UUID}
import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.{Executors, ThreadPoolExecutor, ThreadFactory}

import scala.collection.mutable.{ArrayBuffer, HashMap}

/**
 * A simple implementation of shuffle using local files served through HTTP.
 *
@@ -46,6 +46,7 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
          case None => createCombiner(v)
        }
      }

      for (i <- 0 until numOutputSplits) {
        val file = LocalFileShuffle.getOutputFile(shuffleId, myIndex, i)
        val writeStartTime = System.currentTimeMillis
@@ -57,30 +58,12 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
        val writeTime = (System.currentTimeMillis - writeStartTime)
        logInfo ("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.")
      }
-      (myIndex, LocalFileShuffle.serverUri)
+      (myIndex, LocalFileShuffle.serverAddress, LocalFileShuffle.serverPort)
    }).collect()

-    // Load config option to decide whether or not to use HTTP pipelining
-    val UseHttpPipelining =
-      System.getProperty("spark.shuffle.UseHttpPipelining", "true").toBoolean
-
-    // Build a traversable list of pairs of server URI and split. Needs to be
-    // of type TraversableOnce[(String, ArrayBuffer[Int])]
-    val splitsByUri = if (UseHttpPipelining) {
-      // Build a hashmap from server URI to list of splits (to facillitate
-      // fetching all the URIs on a server within a single connection)
-      val splitsByUriHM = new HashMap[String, ArrayBuffer[Int]]
-      for ((inputId, serverUri) <- outputLocs) {
-        splitsByUriHM.getOrElseUpdate(serverUri, ArrayBuffer()) += inputId
-      }
-      splitsByUriHM
-    } else {
-      // Don't use HTTP pipelining
-      val splitsByUriAB = new ArrayBuffer[(String, ArrayBuffer[Int])]
-      for ((inputId, serverUri) <- outputLocs) {
-        splitsByUriAB += ((serverUri, new ArrayBuffer[Int] += inputId))
-      }
-      splitsByUriAB
-    }
+    val splitsByUri = new ArrayBuffer[(String, Int, Int)]
+    for ((inputId, serverAddress, serverPort) <- outputLocs) {
+      splitsByUri += ((serverAddress, serverPort, inputId))
+    }

    // TODO: Could broadcast splitsByUri
@@ -89,44 +72,49 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
    val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
    return indexes.flatMap((myId: Int) => {
      val combiners = new HashMap[K, C]
-      for ((serverUri, inputIds) <- Utils.shuffle(splitsByUri)) {
-        for (i <- inputIds) {
-          val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, myId)
-          val readStartTime = System.currentTimeMillis
-          logInfo ("BEGIN READ: " + url)
-          // TODO: Insert data transfer code before this place
-          val inputStream = new ObjectInputStream(new URL(url).openStream())
-          try {
-            while (true) {
-              val (k, c) = inputStream.readObject().asInstanceOf[(K, C)]
-              combiners(k) = combiners.get(k) match {
-                case Some(oldC) => mergeCombiners(oldC, c)
-                case None => c
-              }
-            }
-          } catch {
-            case e: EOFException => {}
-          }
-          inputStream.close()
-          logInfo ("END READ: " + url)
-          val readTime = (System.currentTimeMillis - readStartTime)
-          logInfo ("Reading " + url + " took " + readTime + " millis.")
-        }
-      }
+      for ((serverAddress, serverPort, inputId) <- splitsByUri) {
+        val requestPath = "%d/%d/%d".format(shuffleId, inputId, myId)
+
+        val shuffleClient = new ShuffleClient(serverAddress, serverPort, requestPath)
+        val readStartTime = System.currentTimeMillis
+        logInfo ("BEGIN READ: " + requestPath)
+        shuffleClient.start
+        shuffleClient.join
+
+        val inputStream = new ObjectInputStream (
+          new ByteArrayInputStream(shuffleClient.byteArray))
+        try {
+          while (true) {
+            val (k, c) = inputStream.readObject().asInstanceOf[(K, C)]
+            combiners(k) = combiners.get(k) match {
+              case Some(oldC) => mergeCombiners(oldC, c)
+              case None => c
+            }
+          }
+        } catch {
+          case e: EOFException => {}
+        }
+        inputStream.close
+        logInfo ("END READ: " + requestPath)
+        val readTime = (System.currentTimeMillis - readStartTime)
+        logInfo ("Reading " + requestPath + " took " + readTime + " millis.")
+      }
      combiners
    })
  }
}
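As the commit message notes, the reduce side above still does not pull in parallel: each ShuffleClient is started and then immediately joined, so a reducer fetches its splits strictly one after another. A minimal sketch of an overlapped fetch, assuming only the ShuffleClient interface defined later in this file (start/join plus the byteArray result); the helper name fetchAllParallel is hypothetical and not part of this commit:

// Hypothetical sketch: start one ShuffleClient per split before joining any
// of them, so the transfers overlap instead of running back to back.
// Assumes ShuffleClient(host, port, requestPath) as defined later in this file.
def fetchAllParallel(splits: Seq[(String, Int, Int)], shuffleId: Long, myId: Int)
    : Seq[Array[Byte]] = {
  val clients = for ((serverAddress, serverPort, inputId) <- splits) yield {
    val requestPath = "%d/%d/%d".format(shuffleId, inputId, myId)
    val client = new ShuffleClient(serverAddress, serverPort, requestPath)
    client.start  // begin the transfer; do not wait yet
    client
  }
  clients.map { client =>
    client.join   // now wait for each transfer to finish
    client.byteArray
  }
}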
object LocalFileShuffle extends Logging {
  private var initialized = false
  private var nextShuffleId = new AtomicLong(0)

  // Variables initialized by initializeIfNeeded()
  private var shuffleDir: File = null
-  private var server: HttpServer = null
-  private var serverUri: String = null
+  private var shuffleServer: ShuffleServer = null
+  private var serverAddress = InetAddress.getLocalHost.getHostAddress
+  private var serverPort: Int = -1

  private def initializeIfNeeded() = synchronized {
    if (!initialized) {
@@ -137,11 +125,12 @@ object LocalFileShuffle extends Logging {
      var foundLocalDir = false
      var localDir: File = null
      var localDirUuid: UUID = null
      while (!foundLocalDir && tries < 10) {
        tries += 1
        try {
          localDirUuid = UUID.randomUUID()
          localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
          if (!localDir.exists()) {
            localDir.mkdirs()
            foundLocalDir = true
@@ -158,25 +147,14 @@ object LocalFileShuffle extends Logging {
      shuffleDir = new File(localDir, "shuffle")
      shuffleDir.mkdirs()
      logInfo("Shuffle dir: " + shuffleDir)
-      val extServerPort = System.getProperty(
-        "spark.localFileShuffle.external.server.port", "-1").toInt
-      if (extServerPort != -1) {
-        // We're using an external HTTP server; set URI relative to its root
-        var extServerPath = System.getProperty(
-          "spark.localFileShuffle.external.server.path", "")
-        if (extServerPath != "" && !extServerPath.endsWith("/")) {
-          extServerPath += "/"
-        }
-        serverUri = "http://%s:%d/%s/spark-local-%s".format(
-          Utils.localIpAddress, extServerPort, extServerPath, localDirUuid)
-      } else {
-        // Create our own server
-        server = new HttpServer(localDir)
-        server.start()
-        serverUri = server.uri
-      }
+
+      // Create and start the shuffleServer
+      shuffleServer = new ShuffleServer
+      shuffleServer.setDaemon (true)
+      shuffleServer.start
+      logInfo ("ShuffleServer started...")

      initialized = true
-      logInfo ("Local URI: " + serverUri)
    }
  }
@@ -188,12 +166,233 @@ object LocalFileShuffle extends Logging {
    return file
  }

-  def getServerUri(): String = {
-    initializeIfNeeded()
-    serverUri
-  }

  def newShuffleId(): Long = {
    nextShuffleId.getAndIncrement()
  }
  // Returns a standard ThreadFactory except all threads are daemons
  private def newDaemonThreadFactory: ThreadFactory = {
    new ThreadFactory {
      def newThread(r: Runnable): Thread = {
        var t = Executors.defaultThreadFactory.newThread (r)
        t.setDaemon (true)
        return t
      }
    }
  }

  // Wrapper over newFixedThreadPool
  def newDaemonFixedThreadPool (nThreads: Int): ThreadPoolExecutor = {
    var threadPool =
      Executors.newFixedThreadPool (nThreads).asInstanceOf[ThreadPoolExecutor]
    threadPool.setThreadFactory (newDaemonThreadFactory)
    return threadPool
  }

  class ShuffleServer
  extends Thread with Logging {
    // TODO: Set config param
    var threadPool = newDaemonFixedThreadPool(2)

    var serverSocket: ServerSocket = null

    override def run: Unit = {
      serverSocket = new ServerSocket (0)
      serverPort = serverSocket.getLocalPort

      logInfo ("ShuffleServer started with " + serverSocket)
      logInfo ("Local URI: " + serverAddress + ":" + serverPort)

      try {
        while (true) {
          var clientSocket: Socket = null
          try {
            clientSocket = serverSocket.accept
          } catch {
            case e: Exception => { }
          }
          if (clientSocket != null) {
            logInfo ("Server: Accepted new client connection: " + clientSocket)
            try {
              threadPool.execute (new ShuffleServerThread (clientSocket))
            } catch {
              // On failure, close the socket here; otherwise the thread will close it
              case ioe: IOException => {
                clientSocket.close
              }
            }
          }
        }
      } finally {
        if (serverSocket != null) {
          logInfo ("ShuffleServer now stopping...")
          serverSocket.close
        }
      }

      // Shutdown the thread pool
      threadPool.shutdown
    }

    class ShuffleServerThread (val clientSocket: Socket)
    extends Thread with Logging {
      private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream]
      os.flush
      private val oos = new ObjectOutputStream (os)
      oos.flush
      private val ois = new ObjectInputStream (clientSocket.getInputStream)

      logInfo ("new ShuffleServerThread is running")

      override def run: Unit = {
        try {
          // Receive requestPath from the receiver
          var requestPath = ois.readObject.asInstanceOf[String]
          logInfo("requestPath: " + shuffleDir + "/" + requestPath)

          // Open the file
          var requestedFile: File = null
          var requestedFileLen = -1
          try {
            requestedFile = new File(shuffleDir + "/" + requestPath)
            requestedFileLen = requestedFile.length.toInt
          } catch {
            case e: Exception => { }
          }

          // Send the length of the requested file to let the receiver know that
          // the transfer is about to start
          // In the case of receiver timeout and connection close, this will
          // throw a java.net.SocketException: Broken pipe
          oos.writeObject(requestedFileLen)
          oos.flush

          logInfo ("requestedFileLen = " + requestedFileLen)

          // Read and send the requested file
          if (requestedFileLen != -1) {
            // Read
            var byteArray = new Array[Byte](requestedFileLen)
            val bis =
              new BufferedInputStream (new FileInputStream (requestedFile))

            var bytesRead = bis.read (byteArray, 0, byteArray.length)
            var alreadyRead = bytesRead

            while (alreadyRead < requestedFileLen) {
              bytesRead = bis.read(byteArray, alreadyRead,
                (byteArray.length - alreadyRead))
              if (bytesRead > 0) {
                alreadyRead = alreadyRead + bytesRead
              }
            }
            bis.close

            // Send
            os.write (byteArray, 0, byteArray.length)
            os.flush
          } else {
            // Close the connection
          }
        } catch {
          // If something went wrong, e.g., the worker at the other end died,
          // then close everything up
          // Exception can happen if the receiver stops receiving
          case e: Exception => {
            logInfo ("ShuffleServerThread had a " + e)
          }
        } finally {
          logInfo ("ShuffleServerThread is closing streams and sockets")
          ois.close
          // TODO: Following can cause "java.net.SocketException: Socket closed"
          oos.close
          clientSocket.close
        }
      }
    }
  }
}
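Both ShuffleServerThread above and ShuffleClient below read the payload with a loop rather than a single read, because InputStream.read may return fewer bytes than requested. A minimal, self-contained sketch of that pattern; the readFully helper and its enclosing object are hypothetical names, not part of this commit:

import java.io.{EOFException, InputStream}

// Hypothetical helper illustrating the read-until-full loop used in this file:
// keep calling read until the expected number of bytes has arrived, and fail
// if the stream ends early.
object ReadFully {
  def readFully(in: InputStream, expectedLen: Int): Array[Byte] = {
    val buffer = new Array[Byte](expectedLen)
    var alreadyRead = 0
    while (alreadyRead < expectedLen) {
      val bytesRead = in.read(buffer, alreadyRead, expectedLen - alreadyRead)
      if (bytesRead == -1) {
        throw new EOFException("Stream ended after " + alreadyRead + " bytes")
      }
      alreadyRead += bytesRead
    }
    buffer
  }
}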
class ShuffleClient (hostAddress: String, listenPort: Int, requestPath: String)
extends Thread with Logging {
  private var peerSocketToSource: Socket = null
  private var oosSource: ObjectOutputStream = null
  private var oisSource: ObjectInputStream = null

  var byteArray: Array[Byte] = null

  override def run: Unit = {
    // Set up the timeout mechanism
    var timeOutTask = new TimerTask {
      override def run: Unit = {
        cleanUpConnections
      }
    }

    var timeOutTimer = new Timer
    // TODO: Set wait timer
    timeOutTimer.schedule (timeOutTask, 1000)

    logInfo ("ShuffleClient started... => %s:%d#%s".format(hostAddress, listenPort, requestPath))

    try {
      // Connect to the source
      peerSocketToSource = new Socket (hostAddress, listenPort)
      oosSource =
        new ObjectOutputStream (peerSocketToSource.getOutputStream)
      oosSource.flush
      var isSource = peerSocketToSource.getInputStream
      oisSource = new ObjectInputStream (isSource)

      // Send the request
      oosSource.writeObject(requestPath)

      // Receive the length of the requested file
      var requestedFileLen = oisSource.readObject.asInstanceOf[Int]
      logInfo ("Received requestedFileLen = " + requestedFileLen)

      // Turn the timer off if the sender responds before the timeout
      timeOutTimer.cancel

      // Receive the file
      if (requestedFileLen != -1) {
        byteArray = new Array[Byte] (requestedFileLen)
        var bytesRead = isSource.read (byteArray, 0, byteArray.length)
        var alreadyRead = bytesRead

        while (alreadyRead < requestedFileLen) {
          bytesRead = isSource.read(byteArray, alreadyRead,
            (byteArray.length - alreadyRead))
          if (bytesRead > 0) {
            alreadyRead = alreadyRead + bytesRead
          }
        }
      } else {
        throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestPath)
      }
    } catch {
      // EOFException is expected to happen because the sender can break
      // the connection due to timeout
      case eofe: java.io.EOFException => { }
      case e: Exception => {
        logInfo ("ShuffleClient had a " + e)
      }
    } finally {
      cleanUpConnections
    }
  }

  private def cleanUpConnections: Unit = {
    if (oisSource != null) {
      oisSource.close
    }
    if (oosSource != null) {
      oosSource.close
    }
    if (peerSocketToSource != null) {
      peerSocketToSource.close
    }
  }
}
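For reference, a minimal usage sketch of the new client against an already-running ShuffleServer. The host, port, and ids below are placeholders; in the code above they come from the (serverAddress, serverPort) pairs returned by the map stage and from the reducer's own split id:

import java.io.{ByteArrayInputStream, EOFException, ObjectInputStream}

// Hypothetical usage sketch (placeholder host/port/ids): fetch one shuffle
// split and deserialize its (key, combiner) pairs, mirroring what compute() does.
object ShuffleClientExample {
  def main(args: Array[String]): Unit = {
    val serverAddress = "127.0.0.1"              // placeholder: a map-side worker
    val serverPort = 12345                       // placeholder: its ShuffleServer port
    val requestPath = "%d/%d/%d".format(0, 0, 0) // shuffleId/inputId/myId

    val client = new ShuffleClient(serverAddress, serverPort, requestPath)
    client.start
    client.join

    // byteArray stays null if the server did not have the requested split
    if (client.byteArray != null) {
      val in = new ObjectInputStream(new ByteArrayInputStream(client.byteArray))
      try {
        while (true) {
          val (k, c) = in.readObject().asInstanceOf[(Any, Any)]
          println(k + " -> " + c)
        }
      } catch {
        case e: EOFException => // end of split
      } finally {
        in.close()
      }
    }
  }
}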