From 5f0cdabd408cf8af3ebb15d78cd7072b396f8e47 Mon Sep 17 00:00:00 2001
From: Mosharaf Chowdhury <mosharaf@mosharaf-ubuntu.(none)>
Date: Tue, 21 Dec 2010 21:52:37 -0800
Subject: [PATCH] Added a separate thread to deserialize (1 thread per reducer)
 in CustomParallelLocalFileShuffle. Upside: no synchronized blocking on
 the "combiners" variable; 3x faster :) Downside: inefficient implementation
 that requires too much temporary data — approx. a 2x increase in memory
 requirement :( Should be fixed at some point.

---
 conf/java-opts                                |   2 +-
 .../CustomParallelLocalFileShuffle.scala      | 109 +++++++++++++-----
 2 files changed, 82 insertions(+), 29 deletions(-)

diff --git a/conf/java-opts b/conf/java-opts
index 409f5dd3e7..af2b51124a 100644
--- a/conf/java-opts
+++ b/conf/java-opts
@@ -1 +1 @@
--Dspark.shuffle.class=spark.CustomParallelLocalFileShuffle -Dspark.blockedLocalFileShuffle.maxRxConnections=2 -Dspark.blockedLocalFileShuffle.blockSize=256 -Dspark.blockedLocalFileShuffle.minKnockInterval=50 -Dspark.parallelLocalFileShuffle.maxRxConnections=2 -Dspark.parallelLocalFileShuffle.maxTxConnections=4 -Dspark.parallelLocalFileShuffle.minKnockInterval=50 -Dspark.parallelLocalFileShuffle.maxKnockInterval=5000 
+-Dspark.shuffle.class=spark.CustomParallelLocalFileShuffle -Dspark.blockedLocalFileShuffle.maxRxConnections=2 -Dspark.blockedLocalFileShuffle.blockSize=256 -Dspark.blockedLocalFileShuffle.minKnockInterval=50 -Dspark.parallelLocalFileShuffle.maxRxConnections=2 -Dspark.parallelLocalFileShuffle.maxTxConnections=2 -Dspark.parallelLocalFileShuffle.minKnockInterval=1000 -Dspark.parallelLocalFileShuffle.maxKnockInterval=5000 
diff --git a/src/scala/spark/CustomParallelLocalFileShuffle.scala b/src/scala/spark/CustomParallelLocalFileShuffle.scala
index 9ec2a2dc66..2fa3247383 100644
--- a/src/scala/spark/CustomParallelLocalFileShuffle.scala
+++ b/src/scala/spark/CustomParallelLocalFileShuffle.scala
@@ -4,7 +4,7 @@ import java.io._
 import java.net._
 import java.util.{BitSet, Random, Timer, TimerTask, UUID}
 import java.util.concurrent.atomic.AtomicLong
-import java.util.concurrent.{Executors, ThreadPoolExecutor, ThreadFactory}
+import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory}
 
 import scala.collection.mutable.{ArrayBuffer, HashMap}
 
@@ -23,6 +23,7 @@ extends Shuffle[K, V, C] with Logging {
   @transient var hasSplitsBitVector: BitSet = null
   @transient var splitsInRequestBitVector: BitSet = null
 
+  @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null  
   @transient var combiners: HashMap[K,C] = null
   
   override def compute(input: RDD[(K, V)],
@@ -87,11 +88,19 @@ extends Shuffle[K, V, C] with Logging {
       hasSplits = 0
       hasSplitsBitVector = new BitSet(totalSplits)
       splitsInRequestBitVector = new BitSet(totalSplits)
+
+      receivedData = new LinkedBlockingQueue[(Int, Array[Byte])]      
       combiners = new HashMap[K, C]
       
       var threadPool = CustomParallelLocalFileShuffle.newDaemonFixedThreadPool(
         CustomParallelLocalFileShuffle.MaxRxConnections)
         
+      // Start consumer
+      var shuffleConsumer = new ShuffleConsumer(mergeCombiners)
+      shuffleConsumer.setDaemon(true)
+      shuffleConsumer.start()
+      logInfo("ShuffleConsumer started...")
+        
       while (hasSplits < totalSplits) {
         var numThreadsToCreate = Math.min(totalSplits, 
           CustomParallelLocalFileShuffle.MaxRxConnections) - 
@@ -106,7 +115,7 @@ extends Shuffle[K, V, C] with Logging {
             val requestPath = "%d/%d/%d".format(shuffleId, inputId, myId)
 
             threadPool.execute(new ShuffleClient(splitIndex, serverAddress, 
-              serverPort, requestPath, mergeCombiners))
+              serverPort, requestPath))
               
             // splitIndex is in transit. Will be unset in the ShuffleClient
             splitsInRequestBitVector.synchronized {
@@ -145,8 +154,57 @@ extends Shuffle[K, V, C] with Logging {
     }
   }
   
+  class ShuffleConsumer(mergeCombiners: (C, C) => C) 
+  extends Thread with Logging {   
+    override def run: Unit = {
+      // Run until all splits are here
+      while (hasSplits < totalSplits) {
+        var splitIndex = -1
+        var recvByteArray: Array[Byte] = null
+      
+        try {
+          var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])]
+          splitIndex = tempPair._1
+          recvByteArray = tempPair._2
+        } catch {
+          case e: Exception => {
+            logInfo("Exception during taking data from receivedData")
+          }
+        }      
+      
+        val inputStream = 
+          new ObjectInputStream(new ByteArrayInputStream(recvByteArray))
+          
+        try{
+          while (true) {
+            val (k, c) = inputStream.readObject.asInstanceOf[(K, C)]
+            combiners(k) = combiners.get(k) match {
+              case Some(oldC) => mergeCombiners(oldC, c)
+              case None => c
+            }
+          }
+        } catch {
+          case e: EOFException => { }
+        }
+        inputStream.close()
+        
+        // Consumption completed. Update stats.
+        hasSplitsBitVector.synchronized {
+          hasSplitsBitVector.set(splitIndex)
+        }
+        hasSplits += 1
+
+        // We have received splitIndex
+        splitsInRequestBitVector.synchronized {
+          splitsInRequestBitVector.set(splitIndex, false)
+        }
+        
+      }
+    }
+  }
+  
   class ShuffleClient(splitIndex: Int, hostAddress: String, listenPort: Int, 
-    requestPath: String, mergeCombiners: (C, C) => C)
+    requestPath: String)
   extends Thread with Logging {
     private var peerSocketToSource: Socket = null
     private var oosSource: ObjectOutputStream = null
@@ -192,34 +250,29 @@ extends Shuffle[K, V, C] with Logging {
           val readStartTime = System.currentTimeMillis
           logInfo("BEGIN READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestPath))
 
-          // Add this to combiners
-          val inputStream = new ObjectInputStream(isSource)
-            
-          try{
-            while (true) {
-              val (k, c) = inputStream.readObject.asInstanceOf[(K, C)]
-              combiners.synchronized {
-                combiners(k) = combiners.get(k) match {
-                  case Some(oldC) => mergeCombiners(oldC, c)
-                  case None => c
-                }
-              }
+          // Receive data in an Array[Byte]
+          var recvByteArray = new Array[Byte](requestedFileLen)
+          var alreadyRead = 0
+          var bytesRead = 0
+          
+          while (alreadyRead != requestedFileLen) {
+            bytesRead = isSource.read(recvByteArray, alreadyRead, 
+              requestedFileLen - alreadyRead)
+            if (bytesRead > 0) {
+              alreadyRead  = alreadyRead + bytesRead
             }
+          } 
+          
+          // Make it available to the consumer
+          try {
+            receivedData.put((splitIndex, recvByteArray))
           } catch {
-            case e: EOFException => { }
+            case e: Exception => {
+              logInfo("Exception during putting data into receivedData")
+            }
           }
-          inputStream.close()
           
-          // Reception completed. Update stats.
-          hasSplitsBitVector.synchronized {
-            hasSplitsBitVector.set(splitIndex)
-          }
-          hasSplits += 1
-
-          // We have received splitIndex
-          splitsInRequestBitVector.synchronized {
-            splitsInRequestBitVector.set(splitIndex, false)
-          }
+          // NOTE: Update of bitVectors are now done by the consumer. 
           
           receptionSucceeded = true
 
@@ -329,7 +382,7 @@ object CustomParallelLocalFileShuffle extends Logging {
       // Create and start the shuffleServer      
       shuffleServer = new ShuffleServer
       shuffleServer.setDaemon(true)
-      shuffleServer.start
+      shuffleServer.start()
       logInfo("ShuffleServer started...")
       
       initialized = true
-- 
GitLab