From df47b40b764e25cbd10ce49d7152e1d33f51a263 Mon Sep 17 00:00:00 2001
From: shane-huang <shengsheng.huang@intel.com>
Date: Wed, 20 Feb 2013 11:51:13 +0800
Subject: [PATCH] Shuffle Performance fix: Use netty embeded OIO file server
 instead of ConnectionManager Shuffle Performance Optimization: do not send
 0-byte block requests to reduce network messages change reference from
 io.Source to scala.io.Source to avoid looking into io.netty package

Signed-off-by: shane-huang <shengsheng.huang@intel.com>
---
 .../java/spark/network/netty/FileClient.java  |  89 ++++++
 .../netty/FileClientChannelInitializer.java   |  29 ++
 .../network/netty/FileClientHandler.java      |  38 +++
 .../java/spark/network/netty/FileServer.java  |  59 ++++
 .../netty/FileServerChannelInitializer.java   |  33 +++
 .../network/netty/FileServerHandler.java      |  68 +++++
 .../spark/network/netty/PathResolver.java     |  12 +
 .../spark/network/netty/FileHeader.scala      |  57 ++++
 .../spark/network/netty/ShuffleCopier.scala   |  88 ++++++
 .../spark/network/netty/ShuffleSender.scala   |  50 ++++
 .../scala/spark/storage/BlockManager.scala    | 272 ++++++++++++++----
 .../main/scala/spark/storage/DiskStore.scala  |  51 +++-
 project/SparkBuild.scala                      |   3 +-
 .../spark/streaming/util/RawTextSender.scala  |   2 +-
 14 files changed, 795 insertions(+), 56 deletions(-)
 create mode 100644 core/src/main/java/spark/network/netty/FileClient.java
 create mode 100644 core/src/main/java/spark/network/netty/FileClientChannelInitializer.java
 create mode 100644 core/src/main/java/spark/network/netty/FileClientHandler.java
 create mode 100644 core/src/main/java/spark/network/netty/FileServer.java
 create mode 100644 core/src/main/java/spark/network/netty/FileServerChannelInitializer.java
 create mode 100644 core/src/main/java/spark/network/netty/FileServerHandler.java
 create mode 100755 core/src/main/java/spark/network/netty/PathResolver.java
 create mode 100644 core/src/main/scala/spark/network/netty/FileHeader.scala
 create mode 100644 core/src/main/scala/spark/network/netty/ShuffleCopier.scala
 create mode 100644 core/src/main/scala/spark/network/netty/ShuffleSender.scala

diff --git a/core/src/main/java/spark/network/netty/FileClient.java b/core/src/main/java/spark/network/netty/FileClient.java
new file mode 100644
index 0000000000..d0c5081dd2
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileClient.java
@@ -0,0 +1,89 @@
+package spark.network.netty;
+
+import io.netty.bootstrap.Bootstrap;
+import io.netty.channel.AbstractChannel;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelFutureListener;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.channel.socket.nio.NioSocketChannel;
+import io.netty.channel.oio.OioEventLoopGroup;
+import io.netty.channel.socket.oio.OioSocketChannel;
+
+import java.util.Arrays;
+
+public class FileClient {
+
+  private FileClientHandler handler = null;
+  private Channel channel = null;
+  private Bootstrap bootstrap = null;
+
+  public FileClient(FileClientHandler handler){
+    this.handler = handler;
+  }
+ 
+  public void init(){
+     bootstrap = new Bootstrap();
+     bootstrap.group(new OioEventLoopGroup())
+      .channel(OioSocketChannel.class)
+      .option(ChannelOption.SO_KEEPALIVE, true)
+      .option(ChannelOption.TCP_NODELAY, true)
+      .handler(new FileClientChannelInitializer(handler));
+  } 
+
+  public static final class ChannelCloseListener implements ChannelFutureListener {
+    private FileClient fc = null;
+    public ChannelCloseListener(FileClient fc){
+      this.fc = fc;
+    }
+    @Override
+    public void operationComplete(ChannelFuture future) {
+      if (fc.bootstrap!=null){
+        fc.bootstrap.shutdown();
+        fc.bootstrap = null;
+      }
+    }
+  }
+
+  public void connect(String host, int port){
+    try {
+      
+      // Start the connection attempt.
+      channel = bootstrap.connect(host, port).sync().channel();
+      // ChannelFuture cf = channel.closeFuture();
+      //cf.addListener(new ChannelCloseListener(this));
+    } catch (InterruptedException e) {
+      close();
+    } 
+  }
+ 
+  public void waitForClose(){
+    try {
+      channel.closeFuture().sync();
+    } catch (InterruptedException e){
+      e.printStackTrace();
+    }
+  } 
+
+  public void sendRequest(String file){
+    //assert(file == null);
+    //assert(channel == null);
+    channel.write(file+"\r\n");
+  }
+
+  public void close(){
+    if(channel != null) {
+        channel.close();
+        channel = null;
+    }
+    if ( bootstrap!=null) {
+      bootstrap.shutdown();
+      bootstrap = null;
+    }
+  }
+  
+
+}
+
+
diff --git a/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java b/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java
new file mode 100644
index 0000000000..50e5704619
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileClientChannelInitializer.java
@@ -0,0 +1,29 @@
+package spark.network.netty;
+
+import io.netty.buffer.BufType;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
+import io.netty.handler.codec.string.StringEncoder;
+import io.netty.util.CharsetUtil;
+
+import io.netty.handler.logging.LoggingHandler;
+import io.netty.handler.logging.LogLevel;
+
+public class FileClientChannelInitializer extends
+    ChannelInitializer<SocketChannel> {
+
+  private FileClientHandler fhandler;
+
+  public FileClientChannelInitializer(FileClientHandler handler) {
+    fhandler = handler;
+  }
+
+  @Override
+  public void initChannel(SocketChannel channel) {
+    // file no more than 2G
+    channel.pipeline()
+        .addLast("encoder", new StringEncoder(BufType.BYTE))
+        .addLast("handler", fhandler);
+  }
+}
diff --git a/core/src/main/java/spark/network/netty/FileClientHandler.java b/core/src/main/java/spark/network/netty/FileClientHandler.java
new file mode 100644
index 0000000000..911c8b32b5
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileClientHandler.java
@@ -0,0 +1,38 @@
+package spark.network.netty;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundByteHandlerAdapter;
+import io.netty.util.CharsetUtil;
+
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.logging.Logger;
+
+public abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
+
+  private FileHeader currentHeader = null;
+
+  public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header);
+
+  @Override
+  public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) {
+    // Use direct buffer if possible.
+    return ctx.alloc().ioBuffer();
+  }
+  
+  @Override
+  public void inboundBufferUpdated(ChannelHandlerContext ctx, ByteBuf in) {
+    // get header
+    if (currentHeader == null && in.readableBytes() >= FileHeader.HEADER_SIZE()) {
+      currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE()));
+    }
+    // get file
+    if(in.readableBytes() >= currentHeader.fileLen()){
+      handle(ctx,in,currentHeader);
+      currentHeader = null;
+      ctx.close();
+    }
+  }
+
+}
+
diff --git a/core/src/main/java/spark/network/netty/FileServer.java b/core/src/main/java/spark/network/netty/FileServer.java
new file mode 100644
index 0000000000..729e45f0a1
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileServer.java
@@ -0,0 +1,59 @@
+package spark.network.netty;
+
+import java.io.File;
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.Channel;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.channel.socket.nio.NioServerSocketChannel;
+import io.netty.channel.oio.OioEventLoopGroup;
+import io.netty.channel.socket.oio.OioServerSocketChannel;
+import io.netty.handler.logging.LogLevel;
+import io.netty.handler.logging.LoggingHandler;
+
+/**
+ * Server that accept the path of a file an echo back its content.
+ */
+public class FileServer {
+
+    private ServerBootstrap bootstrap = null;
+    private Channel channel = null;
+    private PathResolver pResolver;
+
+    public FileServer(PathResolver pResolver){
+      this.pResolver = pResolver;
+    }
+
+    public void run(int port) {
+        // Configure the server.
+        bootstrap = new ServerBootstrap();
+        try {
+            bootstrap.group(new OioEventLoopGroup(), new OioEventLoopGroup())
+             .channel(OioServerSocketChannel.class)
+             .option(ChannelOption.SO_BACKLOG, 100)
+             .option(ChannelOption.SO_RCVBUF, 1500)
+             .childHandler(new FileServerChannelInitializer(pResolver));
+            // Start the server.
+            channel = bootstrap.bind(port).sync().channel();
+            channel.closeFuture().sync();
+        } catch (InterruptedException e) {
+          // TODO Auto-generated catch block
+          e.printStackTrace();
+        } finally{
+          bootstrap.shutdown();
+        }
+    }
+    
+    public void stop(){
+      if (channel!=null){
+        channel.close();
+      }
+      if (bootstrap != null){
+        bootstrap.shutdown();
+        bootstrap = null;
+      }
+    }
+}
+
+
diff --git a/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java b/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java
new file mode 100644
index 0000000000..9d0618ff1c
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileServerChannelInitializer.java
@@ -0,0 +1,33 @@
+package spark.network.netty;
+
+import java.io.File;
+import io.netty.buffer.BufType;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.codec.string.StringDecoder;
+import io.netty.handler.codec.string.StringEncoder;
+import io.netty.handler.codec.DelimiterBasedFrameDecoder;
+import io.netty.handler.codec.Delimiters;
+import io.netty.util.CharsetUtil;
+import io.netty.handler.logging.LoggingHandler;
+import io.netty.handler.logging.LogLevel;
+
+public class FileServerChannelInitializer extends
+    ChannelInitializer<SocketChannel> {
+
+  PathResolver pResolver;  
+
+  public FileServerChannelInitializer(PathResolver pResolver) {
+    this.pResolver = pResolver;
+  }
+
+  @Override
+  public void initChannel(SocketChannel channel) {
+    channel.pipeline()
+        .addLast("framer", new DelimiterBasedFrameDecoder(
+                8192, Delimiters.lineDelimiter()))
+        .addLast("strDecoder", new StringDecoder())
+        .addLast("handler", new FileServerHandler(pResolver));
+        
+  }
+}
diff --git a/core/src/main/java/spark/network/netty/FileServerHandler.java b/core/src/main/java/spark/network/netty/FileServerHandler.java
new file mode 100644
index 0000000000..e1083e87a2
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/FileServerHandler.java
@@ -0,0 +1,68 @@
+package spark.network.netty;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundMessageHandlerAdapter;
+import io.netty.channel.DefaultFileRegion;
+import io.netty.handler.stream.ChunkedFile;
+import java.io.File;
+import java.io.FileInputStream;
+
+public class FileServerHandler extends
+    ChannelInboundMessageHandlerAdapter<String> {
+
+  PathResolver pResolver; 
+ 
+  public FileServerHandler(PathResolver pResolver){
+    this.pResolver = pResolver;
+  }
+
+  @Override
+  public void messageReceived(ChannelHandlerContext ctx, String blockId) {
+    String path = pResolver.getAbsolutePath(blockId);
+    // if getFilePath returns null, close the channel
+    if (path == null) {
+        //ctx.close();
+        return;
+    }
+    File file = new File(path);
+    if (file.exists()) {
+      if (!file.isFile()) {
+        //logger.info("Not a file : " + file.getAbsolutePath());
+        ctx.write(new FileHeader(0, blockId).buffer());
+        ctx.flush();
+        return;
+      }
+      long length = file.length();
+      if (length > Integer.MAX_VALUE || length <= 0 ) {
+        //logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length);
+        ctx.write(new FileHeader(0, blockId).buffer());
+        ctx.flush();
+        return;  
+      }
+      int len = new Long(length).intValue();
+      //logger.info("Sending block "+blockId+" filelen = "+len);
+      //logger.info("header = "+ (new FileHeader(len, blockId)).buffer());
+      ctx.write((new FileHeader(len, blockId)).buffer());
+      try {
+       ctx.sendFile(new DefaultFileRegion(new FileInputStream(file)
+            .getChannel(), 0, file.length()));
+      } catch (Exception e) {
+        // TODO Auto-generated catch block
+        //logger.warning("Exception when sending file : "
+            //+ file.getAbsolutePath());
+        e.printStackTrace();
+      }
+    } else {
+      //logger.warning("File not found: " + file.getAbsolutePath());
+      ctx.write(new FileHeader(0, blockId).buffer());
+    }
+    ctx.flush();
+  }
+ 
+  
+  @Override
+  public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
+    cause.printStackTrace();
+    ctx.close();
+  }
+}
diff --git a/core/src/main/java/spark/network/netty/PathResolver.java b/core/src/main/java/spark/network/netty/PathResolver.java
new file mode 100755
index 0000000000..5d5eda006e
--- /dev/null
+++ b/core/src/main/java/spark/network/netty/PathResolver.java
@@ -0,0 +1,12 @@
+package spark.network.netty;
+
+public interface PathResolver {
+  /**
+   * Get the absolute path of the file
+   * 
+   * @param fileId
+   * @return the absolute path of file
+   */
+  public String getAbsolutePath(String fileId);
+  
+}
diff --git a/core/src/main/scala/spark/network/netty/FileHeader.scala b/core/src/main/scala/spark/network/netty/FileHeader.scala
new file mode 100644
index 0000000000..aed4254234
--- /dev/null
+++ b/core/src/main/scala/spark/network/netty/FileHeader.scala
@@ -0,0 +1,57 @@
+package spark.network.netty
+
+import io.netty.buffer._
+
+import spark.Logging
+
+private[spark] class FileHeader (
+  val fileLen: Int,
+  val blockId: String) extends Logging {
+
+  lazy val buffer = {
+    val buf = Unpooled.buffer()
+    buf.capacity(FileHeader.HEADER_SIZE)
+    buf.writeInt(fileLen)
+    buf.writeInt(blockId.length)
+    blockId.foreach((x: Char) => buf.writeByte(x))
+    //padding the rest of header
+    if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) {
+      buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes)
+    } else {
+      throw new Exception("too long header " + buf.readableBytes) 
+      logInfo("too long header") 
+    }
+    buf
+  }
+
+}
+
+private[spark] object FileHeader {
+
+  val HEADER_SIZE = 40
+
+  def getFileLenOffset = 0
+  def getFileLenSize = Integer.SIZE/8
+
+  def create(buf: ByteBuf): FileHeader = {
+    val length = buf.readInt
+    val idLength = buf.readInt
+    val idBuilder = new StringBuilder(idLength)
+    for (i <- 1 to idLength) {
+      idBuilder += buf.readByte().asInstanceOf[Char]
+    }
+    val blockId = idBuilder.toString()
+    new FileHeader(length, blockId)
+  }
+
+
+  def main (args:Array[String]){
+
+    val header = new FileHeader(25,"block_0");
+    val buf = header.buffer;
+    val newheader = FileHeader.create(buf);
+    System.out.println("id="+newheader.blockId+",size="+newheader.fileLen)
+
+  }
+}
+
diff --git a/core/src/main/scala/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala
new file mode 100644
index 0000000000..d8d35bfeec
--- /dev/null
+++ b/core/src/main/scala/spark/network/netty/ShuffleCopier.scala
@@ -0,0 +1,88 @@
+package spark.network.netty
+
+import io.netty.buffer.ByteBuf
+import io.netty.channel.ChannelHandlerContext
+import io.netty.channel.ChannelInboundByteHandlerAdapter
+import io.netty.util.CharsetUtil
+
+import java.util.concurrent.atomic.AtomicInteger
+import java.util.logging.Logger
+import spark.Logging
+import spark.network.ConnectionManagerId
+import java.util.concurrent.Executors
+
+private[spark] class ShuffleCopier extends Logging {
+
+  def getBlock(cmId: ConnectionManagerId,
+    blockId: String,
+    resultCollectCallback: (String, Long, ByteBuf) => Unit) = {
+
+    val handler = new ShuffleClientHandler(resultCollectCallback)
+    val fc = new FileClient(handler)
+    fc.init()
+    fc.connect(cmId.host, cmId.port)
+    fc.sendRequest(blockId)
+    fc.waitForClose()
+    fc.close()
+  }
+
+  def getBlocks(cmId: ConnectionManagerId,
+    blocks: Seq[(String, Long)],
+    resultCollectCallback: (String, Long, ByteBuf) => Unit) = {
+
+    blocks.map {
+      case(blockId,size) => {
+        getBlock(cmId,blockId,resultCollectCallback)
+      }
+    }
+  }
+}
+
+private[spark] class ShuffleClientHandler(val resultCollectCallBack: (String, Long, ByteBuf) => Unit ) extends FileClientHandler with Logging {
+
+  def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) {
+    logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)");
+    resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen))
+  }
+}
+
+private[spark] object ShuffleCopier extends Logging {
+
+  def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) = {
+    logInfo("File: " + blockId + " content is : \" "
+      + content.toString(CharsetUtil.UTF_8) + "\"")
+  }
+
+  def runGetBlock(host:String, port:Int, file:String){
+    val handler = new ShuffleClientHandler(echoResultCollectCallBack)
+    val fc = new FileClient(handler)
+    fc.init();
+    fc.connect(host, port)
+    fc.sendRequest(file)
+    fc.waitForClose();
+    fc.close()
+  }
+
+  def main(args: Array[String]) {
+    if (args.length < 3) {
+      System.err.println("Usage: ShuffleCopier <host> <port> <shuffle_block_id> <threads>")
+      System.exit(1)
+    }
+    val host = args(0)
+    val port = args(1).toInt
+    val file = args(2)
+    val threads = if (args.length>3) args(3).toInt else 10
+
+    val copiers = Executors.newFixedThreadPool(80)
+    for (i <- Range(0,threads)){
+      val runnable = new Runnable() {
+        def run() {
+          runGetBlock(host,port,file)
+        }
+      }
+      copiers.execute(runnable)
+    }
+    copiers.shutdown
+  }
+
+}
diff --git a/core/src/main/scala/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/spark/network/netty/ShuffleSender.scala
new file mode 100644
index 0000000000..c1986812e9
--- /dev/null
+++ b/core/src/main/scala/spark/network/netty/ShuffleSender.scala
@@ -0,0 +1,50 @@
+package spark.network.netty
+
+import spark.Logging
+import java.io.File
+
+
+private[spark] class ShuffleSender(val port: Int, val pResolver:PathResolver) extends Logging {
+  val server = new FileServer(pResolver)
+ 
+  Runtime.getRuntime().addShutdownHook(
+    new Thread() {
+      override def run() {
+        server.stop()
+      }
+    }
+  )
+
+  def start() {
+    server.run(port)
+  }
+}
+
+private[spark] object ShuffleSender {
+  def main(args: Array[String]) {
+    if (args.length < 3) {
+      System.err.println("Usage: ShuffleSender <port> <subDirsPerLocalDir> <list of shuffle_block_directories>")
+      System.exit(1)
+    }
+    val port = args(0).toInt
+    val subDirsPerLocalDir = args(1).toInt
+    val localDirs = args.drop(2) map {new File(_)}
+    val pResovler = new PathResolver {
+      def getAbsolutePath(blockId:String):String = {
+        if (!blockId.startsWith("shuffle_")) {
+          throw new Exception("Block " + blockId + " is not a shuffle block")
+        }
+        // Figure out which local directory it hashes to, and which subdirectory in that
+        val hash = math.abs(blockId.hashCode)
+        val dirId = hash % localDirs.length
+        val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
+        val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
+        val file = new File(subDir, blockId)
+        return file.getAbsolutePath
+      }
+    }
+    val sender = new ShuffleSender(port, pResovler)
+
+    sender.start()
+  }
+}
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index 210061e972..b8b68d4283 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -23,6 +23,8 @@ import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStam
 
 import sun.nio.ch.DirectBuffer
 
+import spark.network.netty.ShuffleCopier
+import io.netty.buffer.ByteBuf
 
 private[spark]
 case class BlockException(blockId: String, message: String, ex: Exception = null)
@@ -467,6 +469,21 @@ class BlockManager(
     getLocal(blockId).orElse(getRemote(blockId))
   }
 
+  /**
+   * A request to fetch one or more blocks, complete with their sizes
+   */
+  class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
+    val size = blocks.map(_._2).sum
+  }
+
+  /**
+   * A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
+   * the block (since we want all deserializaton to happen in the calling thread); can also
+   * represent a fetch failure if size == -1.
+   */
+  class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
+    def failed: Boolean = size == -1
+  }
   /**
    * Get multiple blocks from local and remote block manager using their BlockManagerIds. Returns
    * an Iterator of (block ID, value) pairs so that clients may handle blocks in a pipelined
@@ -475,7 +492,12 @@ class BlockManager(
    */
   def getMultiple(blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])])
       : BlockFetcherIterator = {
-    return new BlockFetcherIterator(this, blocksByAddress)
+  
+    if(System.getProperty("spark.shuffle.use.netty", "false").toBoolean){
+      return new NettyBlockFetcherIterator(this, blocksByAddress)
+    } else {
+      return new BlockFetcherIterator(this, blocksByAddress)
+    }
   }
 
   def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
@@ -908,7 +930,7 @@ class BlockFetcherIterator(
   if (blocksByAddress == null) {
     throw new IllegalArgumentException("BlocksByAddress is null")
   }
-  val totalBlocks = blocksByAddress.map(_._2.size).sum
+  var totalBlocks = blocksByAddress.map(_._2.size).sum
   logDebug("Getting " + totalBlocks + " blocks")
   var startTime = System.currentTimeMillis
   val localBlockIds = new ArrayBuffer[String]()
@@ -974,68 +996,83 @@ class BlockFetcherIterator(
     }
   }
 
-  // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
-  // at most maxBytesInFlight in order to limit the amount of data in flight.
-  val remoteRequests = new ArrayBuffer[FetchRequest]
-  for ((address, blockInfos) <- blocksByAddress) {
-    if (address == blockManagerId) {
-      localBlockIds ++= blockInfos.map(_._1)
-    } else {
-      remoteBlockIds ++= blockInfos.map(_._1)
-      // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
-      // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
-      // nodes, rather than blocking on reading output from one node.
-      val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
-      logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
-      val iterator = blockInfos.iterator
-      var curRequestSize = 0L
-      var curBlocks = new ArrayBuffer[(String, Long)]
-      while (iterator.hasNext) {
-        val (blockId, size) = iterator.next()
-        curBlocks += ((blockId, size))
-        curRequestSize += size
-        if (curRequestSize >= minRequestSize) {
-          // Add this FetchRequest
+  def splitLocalRemoteBlocks():ArrayBuffer[FetchRequest] = {
+    // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
+    // at most maxBytesInFlight in order to limit the amount of data in flight.
+    val remoteRequests = new ArrayBuffer[FetchRequest]
+    for ((address, blockInfos) <- blocksByAddress) {
+      if (address == blockManagerId) {
+        localBlockIds ++= blockInfos.map(_._1)
+      } else {
+        remoteBlockIds ++= blockInfos.map(_._1)
+        // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
+        // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
+        // nodes, rather than blocking on reading output from one node.
+        val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
+        logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
+        val iterator = blockInfos.iterator
+        var curRequestSize = 0L
+        var curBlocks = new ArrayBuffer[(String, Long)]
+        while (iterator.hasNext) {
+          val (blockId, size) = iterator.next()
+          curBlocks += ((blockId, size))
+          curRequestSize += size
+          if (curRequestSize >= minRequestSize) {
+            // Add this FetchRequest
+            remoteRequests += new FetchRequest(address, curBlocks)
+            curRequestSize = 0
+            curBlocks = new ArrayBuffer[(String, Long)]
+          }
+        }
+        // Add in the final request
+        if (!curBlocks.isEmpty) {
           remoteRequests += new FetchRequest(address, curBlocks)
-          curRequestSize = 0
-          curBlocks = new ArrayBuffer[(String, Long)]
         }
       }
-      // Add in the final request
-      if (!curBlocks.isEmpty) {
-        remoteRequests += new FetchRequest(address, curBlocks)
-      }
     }
+    remoteRequests
   }
-  // Add the remote requests into our queue in a random order
-  fetchRequests ++= Utils.randomize(remoteRequests)
 
-  // Send out initial requests for blocks, up to our maxBytesInFlight
-  while (!fetchRequests.isEmpty &&
-    (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
-    sendRequest(fetchRequests.dequeue())
+  def getLocalBlocks(){
+    // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
+    // these all at once because they will just memory-map some files, so they won't consume
+    // any memory that might exceed our maxBytesInFlight
+    for (id <- localBlockIds) {
+      getLocal(id) match {
+        case Some(iter) => {
+          results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight
+          logDebug("Got local block " + id)
+        }
+        case None => {
+          throw new BlockException(id, "Could not get block " + id + " from local machine")
+        }
+      }
+    }
   }
 
-  val numGets = remoteBlockIds.size - fetchRequests.size
-  logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
-
-  // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
-  // these all at once because they will just memory-map some files, so they won't consume
-  // any memory that might exceed our maxBytesInFlight
-  startTime = System.currentTimeMillis
-  for (id <- localBlockIds) {
-    getLocal(id) match {
-      case Some(iter) => {
-        results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight
-        logDebug("Got local block " + id)
-      }
-      case None => {
-        throw new BlockException(id, "Could not get block " + id + " from local machine")
-      }
+  def initialize(){
+    // Split local and remote blocks. 
+    val remoteRequests = splitLocalRemoteBlocks()
+    // Add the remote requests into our queue in a random order
+    fetchRequests ++= Utils.randomize(remoteRequests)
+
+    // Send out initial requests for blocks, up to our maxBytesInFlight
+    while (!fetchRequests.isEmpty &&
+      (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
+      sendRequest(fetchRequests.dequeue())
     }
+
+    val numGets = remoteBlockIds.size - fetchRequests.size
+    logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
+
+    // Get Local Blocks
+    startTime = System.currentTimeMillis
+    getLocalBlocks()
+    logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+
   }
-  logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
 
+  initialize()
   //an iterator that will read fetched blocks off the queue as they arrive.
   var resultsGotten = 0
 
@@ -1066,3 +1103,132 @@ class BlockFetcherIterator(
   def remoteBytesRead = _remoteBytesRead
 
 }
+
+class NettyBlockFetcherIterator(
+    blockManager: BlockManager,
+    blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])]
+) extends BlockFetcherIterator(blockManager,blocksByAddress) {
+
+    import blockManager._
+
+    val fetchRequestsSync = new LinkedBlockingQueue[FetchRequest]
+
+    def putResult(blockId:String, blockSize:Long, blockData:ByteBuffer,
+                  results : LinkedBlockingQueue[FetchResult]){
+       results.put(new FetchResult(
+          blockId, blockSize, () => dataDeserialize(blockId, blockData) ))
+    }
+
+    def startCopiers (numCopiers: Int): List [ _ <: Thread]= {
+      (for ( i <- Range(0,numCopiers) ) yield {
+          val copier = new Thread {
+             override def run(){
+              try {
+               while(!isInterrupted && !fetchRequestsSync.isEmpty) {
+                sendRequest(fetchRequestsSync.take())
+               }
+              } catch {
+                case x: InterruptedException => logInfo("Copier Interrupted")
+                case _ => throw new SparkException("Exception Throw in Shuffle Copier")
+              }
+             }
+          }
+          copier.start
+          copier
+      }).toList
+    }
+
+    //keep this to interrupt the threads when necessary
+    def stopCopiers(copiers : List[_ <: Thread]) {
+      for (copier <- copiers) {
+        copier.interrupt()
+      }
+    }
+
+    override def sendRequest(req: FetchRequest) {
+      logDebug("Sending request for %d blocks (%s) from %s".format(
+        req.blocks.size, Utils.memoryBytesToString(req.size), req.address.ip))
+      val cmId = new ConnectionManagerId(req.address.ip, System.getProperty("spark.shuffle.sender.port", "6653").toInt)
+      val cpier = new ShuffleCopier
+      cpier.getBlocks(cmId,req.blocks,(blockId:String,blockSize:Long,blockData:ByteBuf) => putResult(blockId,blockSize,blockData.nioBuffer,results))
+      logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.ip )
+    }
+
+    override def splitLocalRemoteBlocks() : ArrayBuffer[FetchRequest] = {
+      // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
+      // at most maxBytesInFlight in order to limit the amount of data in flight.
+      val originalTotalBlocks = totalBlocks;
+      val remoteRequests = new ArrayBuffer[FetchRequest]
+      for ((address, blockInfos) <- blocksByAddress) {
+        if (address == blockManagerId) {
+          localBlockIds ++= blockInfos.map(_._1)
+        } else {
+          remoteBlockIds ++= blockInfos.map(_._1)
+          // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
+          // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
+          // nodes, rather than blocking on reading output from one node.
+          val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
+          logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
+          val iterator = blockInfos.iterator
+          var curRequestSize = 0L
+          var curBlocks = new ArrayBuffer[(String, Long)]
+          while (iterator.hasNext) {
+            val (blockId, size) = iterator.next()
+            if (size > 0) {
+              curBlocks += ((blockId, size))
+              curRequestSize += size
+            } else if (size == 0){
+              //here we changes the totalBlocks
+              totalBlocks -= 1
+            } else {
+              throw new SparkException("Negative block size "+blockId)
+            }
+            if (curRequestSize >= minRequestSize) {
+              // Add this FetchRequest
+              remoteRequests += new FetchRequest(address, curBlocks)
+              curRequestSize = 0
+              curBlocks = new ArrayBuffer[(String, Long)]
+            }
+          }
+          // Add in the final request
+          if (!curBlocks.isEmpty) {
+            remoteRequests += new FetchRequest(address, curBlocks)
+          }
+        }
+      }
+      logInfo("Getting " + totalBlocks + " non 0-byte blocks out of " + originalTotalBlocks + " blocks")
+      remoteRequests
+    }
+
+    var copiers : List[_ <: Thread] = null
+
+    override def initialize(){
+      // Split Local Remote Blocks and adjust totalBlocks to include only the non 0-byte blocks 
+      val remoteRequests = splitLocalRemoteBlocks()
+      // Add the remote requests into our queue in a random order
+      for (request <- Utils.randomize(remoteRequests)) {
+        fetchRequestsSync.put(request)
+      }
+
+      copiers = startCopiers(System.getProperty("spark.shuffle.copier.threads", "6").toInt)
+      logInfo("Started " + fetchRequestsSync.size + " remote gets in " + Utils.getUsedTimeMs(startTime))
+
+      // Get Local Blocks
+      startTime = System.currentTimeMillis
+      getLocalBlocks()
+      logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+    }
+
+    override def next(): (String, Option[Iterator[Any]]) = {
+      resultsGotten += 1
+      val result = results.take()
+      // if all the results has been retrieved
+      // shutdown the copiers
+      if (resultsGotten == totalBlocks) {
+        if( copiers != null )
+          stopCopiers(copiers)
+      }
+      (result.blockId, if (result.failed) None else Some(result.deserialize()))
+    }
+  }
+
diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala
index ddbf8821ad..d702bb23e0 100644
--- a/core/src/main/scala/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/spark/storage/DiskStore.scala
@@ -13,24 +13,35 @@ import scala.collection.mutable.ArrayBuffer
 import spark.executor.ExecutorExitCode
 
 import spark.Utils
+import spark.Logging
+import spark.network.netty.ShuffleSender
+import spark.network.netty.PathResolver
 
 /**
  * Stores BlockManager blocks on disk.
  */
 private class DiskStore(blockManager: BlockManager, rootDirs: String)
-  extends BlockStore(blockManager) {
+  extends BlockStore(blockManager) with Logging {
 
   val MAX_DIR_CREATION_ATTEMPTS: Int = 10
   val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
 
+  var shuffleSender : Thread = null
+  val thisInstance = this
   // Create one local directory for each path mentioned in spark.local.dir; then, inside this
   // directory, create multiple subdirectories that we will hash files into, in order to avoid
   // having really large inodes at the top level.
   val localDirs = createLocalDirs()
   val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
 
+  val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean
+
   addShutdownHook()
 
+  if(useNetty){
+  startShuffleBlockSender()
+  }
+
   override def getSize(blockId: String): Long = {
     getFile(blockId).length()
   }
@@ -180,10 +191,48 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
         logDebug("Shutdown hook called")
         try {
           localDirs.foreach(localDir => Utils.deleteRecursively(localDir))
+          if (useNetty && shuffleSender != null)
+            shuffleSender.stop
         } catch {
           case t: Throwable => logError("Exception while deleting local spark dirs", t)
         }
       }
     })
   }
+
+  private def startShuffleBlockSender (){
+    try {
+      val port = System.getProperty("spark.shuffle.sender.port", "6653").toInt
+
+      val pResolver = new PathResolver {
+        def getAbsolutePath(blockId:String):String = {
+          if (!blockId.startsWith("shuffle_")) {
+            return null
+          }
+          thisInstance.getFile(blockId).getAbsolutePath()
+        }
+      } 
+      shuffleSender = new Thread {
+        override def run() = {
+          val sender = new ShuffleSender(port,pResolver)
+          logInfo("created ShuffleSender binding to port : "+ port)
+          sender.start
+        }
+      }
+      shuffleSender.setDaemon(true)
+      shuffleSender.start
+  
+    } catch {
+      case interrupted: InterruptedException =>
+        logInfo("Runner thread for ShuffleBlockSender interrupted")
+
+      case e: Exception => {
+        logError("Error running ShuffleBlockSender ", e)
+        if (shuffleSender != null) {
+        shuffleSender.stop
+          shuffleSender = null
+        }
+      }
+    }
+  }
 }
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 5f378b2398..e3645653ee 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -141,7 +141,8 @@ object SparkBuild extends Build {
       "cc.spray" % "spray-can" % "1.0-M2.1",
       "cc.spray" % "spray-server" % "1.0-M2.1",
       "cc.spray" %%  "spray-json" % "1.1.1",
-      "org.apache.mesos" % "mesos" % "0.9.0-incubating"
+      "org.apache.mesos" % "mesos" % "0.9.0-incubating",
+      "io.netty" % "netty-all" % "4.0.0.Beta2"
     ) ++ (if (HADOOP_MAJOR_VERSION == "2") Some("org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION) else None).toSeq,
     unmanagedSourceDirectories in Compile <+= baseDirectory{ _ / ("src/hadoop" + HADOOP_MAJOR_VERSION + "/scala") }
   ) ++ assemblySettings ++ extraAssemblySettings ++ Twirl.settings
diff --git a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala
index d8b987ec86..bd0b0e74c1 100644
--- a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala
+++ b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala
@@ -5,7 +5,7 @@ import spark.util.{RateLimitedOutputStream, IntParam}
 import java.net.ServerSocket
 import spark.{Logging, KryoSerializer}
 import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
-import io.Source
+import scala.io.Source
 import java.io.IOException
 
 /**
-- 
GitLab