diff --git a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala index d8d35bfeecc0e2e851aea12f1441f6ad64c34135..a91f5a886d732234b2640b6305e0ecd54baec1c8 100644 --- a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala +++ b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala @@ -1,23 +1,21 @@ package spark.network.netty +import java.util.concurrent.Executors + import io.netty.buffer.ByteBuf import io.netty.channel.ChannelHandlerContext -import io.netty.channel.ChannelInboundByteHandlerAdapter import io.netty.util.CharsetUtil -import java.util.concurrent.atomic.AtomicInteger -import java.util.logging.Logger import spark.Logging import spark.network.ConnectionManagerId -import java.util.concurrent.Executors + private[spark] class ShuffleCopier extends Logging { - def getBlock(cmId: ConnectionManagerId, - blockId: String, - resultCollectCallback: (String, Long, ByteBuf) => Unit) = { + def getBlock(cmId: ConnectionManagerId, blockId: String, + resultCollectCallback: (String, Long, ByteBuf) => Unit) { - val handler = new ShuffleClientHandler(resultCollectCallback) + val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) val fc = new FileClient(handler) fc.init() fc.connect(cmId.host, cmId.port) @@ -28,29 +26,28 @@ private[spark] class ShuffleCopier extends Logging { def getBlocks(cmId: ConnectionManagerId, blocks: Seq[(String, Long)], - resultCollectCallback: (String, Long, ByteBuf) => Unit) = { + resultCollectCallback: (String, Long, ByteBuf) => Unit) { - blocks.map { - case(blockId,size) => { - getBlock(cmId,blockId,resultCollectCallback) - } + for ((blockId, size) <- blocks) { + getBlock(cmId, blockId, resultCollectCallback) } } } -private[spark] class ShuffleClientHandler(val resultCollectCallBack: (String, Long, ByteBuf) => Unit ) extends FileClientHandler with Logging { - - def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) { - logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)"); - resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) - } -} private[spark] object ShuffleCopier extends Logging { - def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) = { - logInfo("File: " + blockId + " content is : \" " - + content.toString(CharsetUtil.UTF_8) + "\"") + private class ShuffleClientHandler(resultCollectCallBack: (String, Long, ByteBuf) => Unit) + extends FileClientHandler with Logging { + + override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) { + logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)"); + resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) + } + } + + def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) { + logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"") } def runGetBlock(host:String, port:Int, file:String){ @@ -71,18 +68,17 @@ private[spark] object ShuffleCopier extends Logging { val host = args(0) val port = args(1).toInt val file = args(2) - val threads = if (args.length>3) args(3).toInt else 10 + val threads = if (args.length > 3) args(3).toInt else 10 val copiers = Executors.newFixedThreadPool(80) - for (i <- Range(0,threads)){ + for (i <- Range(0, threads)) { val runnable = new Runnable() { def run() { - runGetBlock(host,port,file) + runGetBlock(host, port, file) } } copiers.execute(runnable) } copiers.shutdown } - } diff --git a/core/src/main/scala/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/spark/network/netty/ShuffleSender.scala index c1986812e93e8fd8dc17253593038a1abf10e128..dc87fefc567949b4bc05e4de2c6002ce61752070 100644 --- a/core/src/main/scala/spark/network/netty/ShuffleSender.scala +++ b/core/src/main/scala/spark/network/netty/ShuffleSender.scala @@ -1,12 +1,13 @@ package spark.network.netty -import spark.Logging import java.io.File +import spark.Logging + -private[spark] class ShuffleSender(val port: Int, val pResolver:PathResolver) extends Logging { +private[spark] class ShuffleSender(val port: Int, val pResolver: PathResolver) extends Logging { val server = new FileServer(pResolver) - + Runtime.getRuntime().addShutdownHook( new Thread() { override def run() { @@ -20,17 +21,22 @@ private[spark] class ShuffleSender(val port: Int, val pResolver:PathResolver) ex } } + private[spark] object ShuffleSender { + def main(args: Array[String]) { if (args.length < 3) { - System.err.println("Usage: ShuffleSender <port> <subDirsPerLocalDir> <list of shuffle_block_directories>") + System.err.println( + "Usage: ShuffleSender <port> <subDirsPerLocalDir> <list of shuffle_block_directories>") System.exit(1) } + val port = args(0).toInt val subDirsPerLocalDir = args(1).toInt - val localDirs = args.drop(2) map {new File(_)} + val localDirs = args.drop(2).map(new File(_)) + val pResovler = new PathResolver { - def getAbsolutePath(blockId:String):String = { + override def getAbsolutePath(blockId: String): String = { if (!blockId.startsWith("shuffle_")) { throw new Exception("Block " + blockId + " is not a shuffle block") }