Skip to content
Snippets Groups Projects
Commit 52086cef authored by Mosharaf Chowdhury's avatar Mosharaf Chowdhury
Browse files

Building blocks are in place. Still not pulling parallely though.

parent 540a4116
No related branches found
No related tags found
No related merge requests found
package spark
import java.io._
import java.net.URL
import java.util.UUID
import java.net._
import java.util.{Timer, TimerTask, UUID}
import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.{Executors, ThreadPoolExecutor, ThreadFactory}
import scala.collection.mutable.{ArrayBuffer, HashMap}
/**
* A simple implementation of shuffle using local files served through HTTP.
*
......@@ -46,6 +46,7 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
case None => createCombiner(v)
}
}
for (i <- 0 until numOutputSplits) {
val file = LocalFileShuffle.getOutputFile(shuffleId, myIndex, i)
val writeStartTime = System.currentTimeMillis
......@@ -57,30 +58,12 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
val writeTime = (System.currentTimeMillis - writeStartTime)
logInfo ("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.")
}
(myIndex, LocalFileShuffle.serverUri)
(myIndex, LocalFileShuffle.serverAddress, LocalFileShuffle.serverPort)
}).collect()
// Load config option to decide whether or not to use HTTP pipelining
val UseHttpPipelining =
System.getProperty("spark.shuffle.UseHttpPipelining", "true").toBoolean
// Build a traversable list of pairs of server URI and split. Needs to be
// of type TraversableOnce[(String, ArrayBuffer[Int])]
val splitsByUri = if (UseHttpPipelining) {
// Build a hashmap from server URI to list of splits (to facillitate
// fetching all the URIs on a server within a single connection)
val splitsByUriHM = new HashMap[String, ArrayBuffer[Int]]
for ((inputId, serverUri) <- outputLocs) {
splitsByUriHM.getOrElseUpdate(serverUri, ArrayBuffer()) += inputId
}
splitsByUriHM
} else {
// Don't use HTTP pipelining
val splitsByUriAB = new ArrayBuffer[(String, ArrayBuffer[Int])]
for ((inputId, serverUri) <- outputLocs) {
splitsByUriAB += ((serverUri, new ArrayBuffer[Int] += inputId))
}
splitsByUriAB
val splitsByUri = new ArrayBuffer[(String, Int, Int)]
for ((inputId, serverAddress, serverPort) <- outputLocs) {
splitsByUri += ((serverAddress, serverPort, inputId))
}
// TODO: Could broadcast splitsByUri
......@@ -89,44 +72,49 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
return indexes.flatMap((myId: Int) => {
val combiners = new HashMap[K, C]
for ((serverUri, inputIds) <- Utils.shuffle(splitsByUri)) {
for (i <- inputIds) {
val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, myId)
val readStartTime = System.currentTimeMillis
logInfo ("BEGIN READ: " + url)
// TODO: Insert data transfer code before this place
val inputStream = new ObjectInputStream(new URL(url).openStream())
try {
while (true) {
val (k, c) = inputStream.readObject().asInstanceOf[(K, C)]
combiners(k) = combiners.get(k) match {
case Some(oldC) => mergeCombiners(oldC, c)
case None => c
}
for ((serverAddress, serverPort, inputId) <- splitsByUri) {
val requestPath = "%d/%d/%d".format(shuffleId, inputId, myId)
val shuffleClient = new ShuffleClient(serverAddress, serverPort, requestPath)
val readStartTime = System.currentTimeMillis
logInfo ("BEGIN READ: " + requestPath)
shuffleClient.start
shuffleClient.join
val inputStream = new ObjectInputStream (
new ByteArrayInputStream(shuffleClient.byteArray))
try {
while (true) {
val (k, c) = inputStream.readObject().asInstanceOf[(K, C)]
combiners(k) = combiners.get(k) match {
case Some(oldC) => mergeCombiners(oldC, c)
case None => c
}
} catch {
case e: EOFException => {}
}
inputStream.close()
logInfo ("END READ: " + url)
val readTime = (System.currentTimeMillis - readStartTime)
logInfo ("Reading " + url + " took " + readTime + " millis.")
} catch {
case e: EOFException => {}
}
inputStream.close
logInfo ("END READ: " + requestPath)
val readTime = (System.currentTimeMillis - readStartTime)
logInfo ("Reading " + requestPath + " took " + readTime + " millis.")
}
combiners
})
}
}
object LocalFileShuffle extends Logging {
private var initialized = false
private var nextShuffleId = new AtomicLong(0)
// Variables initialized by initializeIfNeeded()
private var shuffleDir: File = null
private var server: HttpServer = null
private var serverUri: String = null
private var shuffleServer: ShuffleServer = null
private var serverAddress = InetAddress.getLocalHost.getHostAddress
private var serverPort: Int = -1
private def initializeIfNeeded() = synchronized {
if (!initialized) {
......@@ -137,11 +125,12 @@ object LocalFileShuffle extends Logging {
var foundLocalDir = false
var localDir: File = null
var localDirUuid: UUID = null
while (!foundLocalDir && tries < 10) {
tries += 1
try {
localDirUuid = UUID.randomUUID()
localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
if (!localDir.exists()) {
localDir.mkdirs()
foundLocalDir = true
......@@ -158,25 +147,14 @@ object LocalFileShuffle extends Logging {
shuffleDir = new File(localDir, "shuffle")
shuffleDir.mkdirs()
logInfo("Shuffle dir: " + shuffleDir)
val extServerPort = System.getProperty(
"spark.localFileShuffle.external.server.port", "-1").toInt
if (extServerPort != -1) {
// We're using an external HTTP server; set URI relative to its root
var extServerPath = System.getProperty(
"spark.localFileShuffle.external.server.path", "")
if (extServerPath != "" && !extServerPath.endsWith("/")) {
extServerPath += "/"
}
serverUri = "http://%s:%d/%s/spark-local-%s".format(
Utils.localIpAddress, extServerPort, extServerPath, localDirUuid)
} else {
// Create our own server
server = new HttpServer(localDir)
server.start()
serverUri = server.uri
}
// Create and start the shuffleServer
shuffleServer = new ShuffleServer
shuffleServer.setDaemon (true)
shuffleServer.start
logInfo ("ShuffleServer started...")
initialized = true
logInfo ("Local URI: " + serverUri)
}
}
......@@ -188,12 +166,233 @@ object LocalFileShuffle extends Logging {
return file
}
def getServerUri(): String = {
initializeIfNeeded()
serverUri
}
def newShuffleId(): Long = {
nextShuffleId.getAndIncrement()
}
// Returns a standard ThreadFactory except all threads are daemons
private def newDaemonThreadFactory: ThreadFactory = {
new ThreadFactory {
def newThread(r: Runnable): Thread = {
var t = Executors.defaultThreadFactory.newThread (r)
t.setDaemon (true)
return t
}
}
}
// Wrapper over newFixedThreadPool
def newDaemonFixedThreadPool (nThreads: Int): ThreadPoolExecutor = {
var threadPool =
Executors.newFixedThreadPool (nThreads).asInstanceOf[ThreadPoolExecutor]
threadPool.setThreadFactory (newDaemonThreadFactory)
return threadPool
}
class ShuffleServer
extends Thread with Logging {
// TODO: Set config param
var threadPool = newDaemonFixedThreadPool(2)
var serverSocket: ServerSocket = null
override def run: Unit = {
serverSocket = new ServerSocket (0)
serverPort = serverSocket.getLocalPort
logInfo ("ShuffleServer started with " + serverSocket)
logInfo ("Local URI: " + serverAddress + ":" + serverPort)
try {
while (true) {
var clientSocket: Socket = null
try {
clientSocket = serverSocket.accept
} catch {
case e: Exception => { }
}
if (clientSocket != null) {
logInfo ("Serve: Accepted new client connection:" + clientSocket)
try {
threadPool.execute (new ShuffleServerThread (clientSocket))
} catch {
// In failure, close socket here; else, the thread will close it
case ioe: IOException => {
clientSocket.close
}
}
}
}
} finally {
if (serverSocket != null) {
logInfo ("ShuffleServer now stopping...")
serverSocket.close
}
}
// Shutdown the thread pool
threadPool.shutdown
}
class ShuffleServerThread (val clientSocket: Socket)
extends Thread with Logging {
private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream]
os.flush
private val oos = new ObjectOutputStream (os)
oos.flush
private val ois = new ObjectInputStream (clientSocket.getInputStream)
logInfo ("new ShuffleServerThread is running")
override def run: Unit = {
try {
// Receive requestPath from the receiver
var requestPath = ois.readObject.asInstanceOf[String]
logInfo("requestPath: " + shuffleDir + "/" + requestPath)
// Open the file
var requestedFile: File = null
var requestedFileLen = -1
try {
requestedFile = new File(shuffleDir + "/" + requestPath)
requestedFileLen = requestedFile.length.toInt
} catch {
case e: Exception => { }
}
// Send the lendth of the requestPath to let the receiver know that
// transfer is about to start
// In the case of receiver timeout and connection close, this will
// throw a java.net.SocketException: Broken pipe
oos.writeObject(requestedFileLen)
oos.flush
logInfo ("requestedFileLen = " + requestedFileLen)
// Read and send the requested file
if (requestedFileLen != -1) {
// Read
var byteArray = new Array[Byte](requestedFileLen)
val bis =
new BufferedInputStream (new FileInputStream (requestedFile))
var bytesRead = bis.read (byteArray, 0, byteArray.length)
var alreadyRead = bytesRead
while (alreadyRead < requestedFileLen) {
bytesRead = bis.read(byteArray, alreadyRead,
(byteArray.length - alreadyRead))
if(bytesRead > 0) {
alreadyRead = alreadyRead + bytesRead
}
}
bis.close
// Send
os.write (byteArray, 0, byteArray.length)
os.flush
} else {
// Close the connection
}
} catch {
// If something went wrong, e.g., the worker at the other end died etc.
// then close everything up
// Exception can happen if the receiver stops receiving
case e: Exception => {
logInfo ("ShuffleServerThread had a " + e)
}
} finally {
logInfo ("ShuffleServerThread is closing streams and sockets")
ois.close
// TODO: Following can cause "java.net.SocketException: Socket closed"
oos.close
clientSocket.close
}
}
}
}
}
class ShuffleClient (hostAddress: String, listenPort: Int, requestPath: String)
extends Thread with Logging {
private var peerSocketToSource: Socket = null
private var oosSource: ObjectOutputStream = null
private var oisSource: ObjectInputStream = null
var byteArray: Array[Byte] = null
override def run: Unit = {
// Setup the timeout mechanism
var timeOutTask = new TimerTask {
override def run: Unit = {
cleanUpConnections
}
}
var timeOutTimer = new Timer
// TODO: Set wait timer
timeOutTimer.schedule (timeOutTask, 1000)
logInfo ("ShuffleClient started... => %s:%d#%s".format(hostAddress, listenPort, requestPath))
try {
// Connect to the source
peerSocketToSource = new Socket (hostAddress, listenPort)
oosSource =
new ObjectOutputStream (peerSocketToSource.getOutputStream)
oosSource.flush
var isSource = peerSocketToSource.getInputStream
oisSource = new ObjectInputStream (isSource)
// Send the request
oosSource.writeObject(requestPath)
// Receive the length of the requested file
var requestedFileLen = oisSource.readObject.asInstanceOf[Int]
logInfo ("Received requestedFileLen = " + requestedFileLen)
// Turn the timer OFF, if the sender responds before timeout
timeOutTimer.cancel
// Receive the file
if (requestedFileLen != -1) {
byteArray = new Array[Byte] (requestedFileLen)
var bytesRead = isSource.read (byteArray, 0, byteArray.length)
var alreadyRead = bytesRead
while (alreadyRead < requestedFileLen) {
bytesRead = isSource.read(byteArray, alreadyRead,
(byteArray.length - alreadyRead))
if(bytesRead > 0) {
alreadyRead = alreadyRead + bytesRead
}
}
} else {
throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestPath)
}
} catch {
// EOFException is expected to happen because sender can break
// connection due to timeout
case eofe: java.io.EOFException => { }
case e: Exception => {
logInfo ("ShuffleClient had a " + e)
}
} finally {
cleanUpConnections
}
}
private def cleanUpConnections: Unit = {
if (oisSource != null) {
oisSource.close
}
if (oosSource != null) {
oosSource.close
}
if (peerSocketToSource != null) {
peerSocketToSource.close
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment