diff --git a/conf/java-opts b/conf/java-opts
index b7afd7588e325a4b5ece7d25a7717b37d4ef33da..5b4ae25c067e56825a3ece43732593ac835ac33b 100644
--- a/conf/java-opts
+++ b/conf/java-opts
@@ -1 +1 @@
--Dspark.shuffle.class=spark.HttpBlockedLocalFileShuffle -Dspark.blockedLocalFileShuffle.maxRxConnections=2 -Dspark.blockedLocalFileShuffle.blockSize=256 -Dspark.blockedLocalFileShuffle.minKnockInterval=50 -Dspark.parallelLocalFileShuffle.maxRxConnections=2 -Dspark.parallelLocalFileShuffle.maxTxConnections=2 -Dspark.parallelLocalFileShuffle.minKnockInterval=50 -Dspark.parallelLocalFileShuffle.maxKnockInterval=2000
+-Dspark.shuffle.class=spark.CustomParallelInMemoryShuffle -Dspark.blockedLocalFileShuffle.maxRxConnections=2 -Dspark.blockedLocalFileShuffle.blockSize=256 -Dspark.blockedLocalFileShuffle.minKnockInterval=50 -Dspark.parallelLocalFileShuffle.maxRxConnections=2 -Dspark.parallelLocalFileShuffle.maxTxConnections=2 -Dspark.parallelLocalFileShuffle.minKnockInterval=50 -Dspark.parallelLocalFileShuffle.maxKnockInterval=2000 -Dspark.parallelInMemoryShuffle.maxRxConnections=2 -Dspark.parallelInMemoryShuffle.maxTxConnections=2 -Dspark.parallelInMemoryShuffle.minKnockInterval=50 -Dspark.parallelInMemoryShuffle.maxKnockInterval=2000
diff --git a/src/scala/spark/CustomParallelInMemoryShuffle.scala b/src/scala/spark/CustomParallelInMemoryShuffle.scala
new file mode 100644
index 0000000000000000000000000000000000000000..ee3403e49ac7d910a2adf74314e5d6f8c2ec7599
--- /dev/null
+++ b/src/scala/spark/CustomParallelInMemoryShuffle.scala
@@ -0,0 +1,555 @@
+package spark
+
+import java.io._
+import java.net._
+import java.util.{BitSet, Random, Timer, TimerTask, UUID}
+import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory}
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
+/**
+ * TODO: THIS IS AN ABSOLUTELY EXPERIMENTAL IMPLEMENTATION (FOR NOW).
+ *
+ * An implementation of shuffle that keeps map outputs in local memory and
+ * serves them through a custom server. Receivers open simultaneous
+ * connections to multiple servers, bounded by the
+ * 'spark.parallelInMemoryShuffle.maxRxConnections' config option.
+ *
+ * TODO: Add support for compression when spark.compress is set to true.
+ */
+@serializable
+class CustomParallelInMemoryShuffle[K, V, C] 
+extends Shuffle[K, V, C] with Logging {
+  @transient var totalSplits = 0
+  @transient var hasSplits = 0 
+  @transient var hasSplitsBitVector: BitSet = null
+  @transient var splitsInRequestBitVector: BitSet = null
+
+  @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null  
+  @transient var combiners: HashMap[K,C] = null
+  
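+  // compute() runs in two phases: the map side serializes each per-reducer
+  // bucket into CustomParallelInMemoryShuffle.splitsCache and reports this
+  // node's server address and port; the reduce side then pulls the matching
+  // splits from every map output in parallel and merges them into combiners.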
+  override def compute(input: RDD[(K, V)],
+                       numOutputSplits: Int,
+                       createCombiner: V => C,
+                       mergeValue: (C, V) => C,
+                       mergeCombiners: (C, C) => C)
+  : RDD[(K, C)] =
+  {
+    val sc = input.sparkContext
+    val shuffleId = CustomParallelInMemoryShuffle.newShuffleId()
+    logInfo("Shuffle ID: " + shuffleId)
+
+    val splitRdd = new NumberedSplitRDD(input)
+    val numInputSplits = splitRdd.splits.size
+
+    // Run a parallel map and collect to write the intermediate data files,
+    // returning a list of inputSplitId -> serverUri pairs
+    val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
+      val myIndex = pair._1
+      val myIterator = pair._2
+      val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
+      for ((k, v) <- myIterator) {
+        var bucketId = k.hashCode % numOutputSplits
+        if (bucketId < 0) { // Fix bucket ID if hash code was negative
+          bucketId += numOutputSplits
+        }
+        val bucket = buckets(bucketId)
+        bucket(k) = bucket.get(k) match {
+          case Some(c) => mergeValue(c, v)
+          case None => createCombiner(v)
+        }
+      }
+      
+      for (i <- 0 until numOutputSplits) {        
+        val splitName = 
+          CustomParallelInMemoryShuffle.getSplitName(shuffleId, myIndex, i)
+
+        val writeStartTime = System.currentTimeMillis
+        logInfo("BEGIN WRITE: " + splitName)
+
+        // Serialize buckets(i) into a byte array & put it in splitsCache
+        // instead of a file. Each (K, C) pair is written as a separate object
+        // so that the ShuffleConsumer can read pairs back until EOF.
+        val baos = new ByteArrayOutputStream
+        val oos = new ObjectOutputStream(baos)
+        buckets(i).foreach(pair => oos.writeObject(pair))
+        oos.close()
+        baos.close()
+        
+        CustomParallelInMemoryShuffle.splitsCache(splitName) = baos.toByteArray
+        val splitLen = 
+          CustomParallelInMemoryShuffle.splitsCache(splitName).length
+        
+        logInfo("END WRITE: " + splitName)
+        val writeTime = System.currentTimeMillis - writeStartTime
+        logInfo("Writing " + splitName + " of size " + splitLen + " bytes took " + writeTime + " millis.")
+      }
+      
+      (myIndex, CustomParallelInMemoryShuffle.serverAddress, 
+        CustomParallelInMemoryShuffle.serverPort)
+    }).collect()
+
+    val splitsByUri = new ArrayBuffer[(String, Int, Int)]
+    for ((inputId, serverAddress, serverPort) <- outputLocs) {
+      splitsByUri += ((serverAddress, serverPort, inputId))
+    }
+
+    // TODO: Could broadcast splitsByUri
+
+    // Return an RDD that does each of the merges for a given partition
+    val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
+    return indexes.flatMap((myId: Int) => {
+      totalSplits = splitsByUri.size
+      hasSplits = 0
+      hasSplitsBitVector = new BitSet(totalSplits)
+      splitsInRequestBitVector = new BitSet(totalSplits)
+
+      receivedData = new LinkedBlockingQueue[(Int, Array[Byte])]
+      combiners = new HashMap[K, C]
+      
+      var threadPool = CustomParallelInMemoryShuffle.newDaemonFixedThreadPool(
+        CustomParallelInMemoryShuffle.MaxRxConnections)
+        
+      // Start consumer
+      var shuffleConsumer = new ShuffleConsumer(mergeCombiners)
+      shuffleConsumer.setDaemon(true)
+      shuffleConsumer.start()
+      logInfo("ShuffleConsumer started...")
+        
+      while (hasSplits < totalSplits) {
+        var numThreadsToCreate = Math.min(totalSplits, 
+          CustomParallelInMemoryShuffle.MaxRxConnections) - 
+          threadPool.getActiveCount
+      
+        while (hasSplits < totalSplits && numThreadsToCreate > 0) {        
+          // Select a random split to pull
+          val splitIndex = selectRandomSplit
+          
+          if (splitIndex != -1) {
+            val (serverAddress, serverPort, inputId) = splitsByUri(splitIndex)
+            val requestSplit = "%d/%d/%d".format(shuffleId, inputId, myId)
+
+            threadPool.execute(new ShuffleClient(splitIndex, serverAddress, 
+              serverPort, requestSplit))
+              
+            // splitIndex is in transit. Will be unset in the ShuffleClient
+            splitsInRequestBitVector.synchronized {
+              splitsInRequestBitVector.set(splitIndex)
+            }
+          }
+          
+          numThreadsToCreate = numThreadsToCreate - 1
+        }
+        
+        // Sleep for a while before creating new threads
+        Thread.sleep(CustomParallelInMemoryShuffle.MinKnockInterval)
+      }
+      
+      threadPool.shutdown()
+      combiners
+    })
+  }
+  
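+  // Pick a random split that has neither been received nor is currently
+  // being fetched; returns -1 if no such split is pending right now.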
+  def selectRandomSplit: Int = {
+    var requiredSplits = new ArrayBuffer[Int]
+    
+    // Lock the same objects that the consumer and clients lock when they
+    // update these bit vectors
+    hasSplitsBitVector.synchronized {
+      splitsInRequestBitVector.synchronized {
+        for (i <- 0 until totalSplits) {
+          if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) {
+            requiredSplits += i
+          }
+        }
+      }
+    }
+    
+    if (requiredSplits.size > 0) {
+      requiredSplits(CustomParallelInMemoryShuffle.ranGen.nextInt(
+        requiredSplits.size))
+    } else {
+      -1
+    }
+  }
+  
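+  // ShuffleConsumer drains receivedData, deserializes each byte array back
+  // into (K, C) pairs, and merges them into combiners using mergeCombiners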
+  class ShuffleConsumer(mergeCombiners: (C, C) => C)
+  extends Thread with Logging {   
+    override def run: Unit = {
+      // Run until all splits are here
+      while (hasSplits < totalSplits) {
+        var splitIndex = -1
+        var recvByteArray: Array[Byte] = null
+      
+        try {
+          val tempPair = receivedData.take()
+          splitIndex = tempPair._1
+          recvByteArray = tempPair._2
+        } catch {
+          case e: Exception => {
+            logInfo("Exception during taking data from receivedData")
+          }
+        }      
+      
+        val inputStream = 
+          new ObjectInputStream(new ByteArrayInputStream(recvByteArray))
+          
+        try{
+          while (true) {
+            val (k, c) = inputStream.readObject.asInstanceOf[(K, C)]
+            combiners(k) = combiners.get(k) match {
+              case Some(oldC) => mergeCombiners(oldC, c)
+              case None => c
+            }
+          }
+        } catch {
+          case e: EOFException => { }
+        }
+        inputStream.close()
+        
+        // Consumption completed. Update stats.
+        hasSplitsBitVector.synchronized {
+          hasSplitsBitVector.set(splitIndex)
+        }
+        hasSplits += 1
+
+        // We have received splitIndex
+        splitsInRequestBitVector.synchronized {
+          splitsInRequestBitVector.set(splitIndex, false)
+        }
+        
+      }
+    }
+  }
+  
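+  // ShuffleClient fetches one split from a remote ShuffleServer: it sends the
+  // request string "shuffleId/inputId/outputId", reads back the split length
+  // (-1 if the server does not have the split), then reads that many raw
+  // bytes and hands them to the ShuffleConsumer through receivedData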
+  class ShuffleClient(splitIndex: Int, hostAddress: String, listenPort: Int, 
+    requestSplit: String)
+  extends Thread with Logging {
+    private var peerSocketToSource: Socket = null
+    private var oosSource: ObjectOutputStream = null
+    private var oisSource: ObjectInputStream = null
+    
+    private var receptionSucceeded = false
+
+    override def run: Unit = {
+      // Setup the timeout mechanism
+      var timeOutTask = new TimerTask {
+        override def run: Unit = {
+          cleanUpConnections()
+        }
+      }
+      
+      var timeOutTimer = new Timer
+      timeOutTimer.schedule(timeOutTask, 
+        CustomParallelInMemoryShuffle.MaxKnockInterval)
+      
+      logInfo("ShuffleClient started... => %s:%d#%s".format(hostAddress, listenPort, requestSplit))
+      
+      try {
+        // Connect to the source
+        peerSocketToSource = new Socket(hostAddress, listenPort)
+        oosSource =
+          new ObjectOutputStream(peerSocketToSource.getOutputStream)
+        oosSource.flush()
+        var isSource = peerSocketToSource.getInputStream
+        oisSource = new ObjectInputStream(isSource)
+        
+        // Send the request
+        oosSource.writeObject(requestSplit)
+        
+        // Receive the length of the requested file
+        var requestedFileLen = oisSource.readObject.asInstanceOf[Int]
+        logInfo("Received requestedFileLen = " + requestedFileLen)
+
+        // Turn the timer OFF, if the sender responds before timeout
+        timeOutTimer.cancel()
+        
+        // Receive the file
+        if (requestedFileLen != -1) {
+          val readStartTime = System.currentTimeMillis
+          logInfo("BEGIN READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit))
+
+          // Receive data in an Array[Byte]
+          var recvByteArray = new Array[Byte](requestedFileLen)
+          var alreadyRead = 0
+          var bytesRead = 0
+          
+          while (alreadyRead != requestedFileLen) {
+            bytesRead = isSource.read(recvByteArray, alreadyRead, 
+              requestedFileLen - alreadyRead)
+            if (bytesRead > 0) {
+              alreadyRead  = alreadyRead + bytesRead
+            }
+          } 
+          
+          // Make it available to the consumer
+          try {
+            receivedData.put((splitIndex, recvByteArray))
+          } catch {
+            case e: Exception => {
+              logInfo("Exception during putting data into receivedData")
+            }
+          }
+          
+          // NOTE: Update of bitVectors are now done by the consumer
+          
+          receptionSucceeded = true
+
+          logInfo("END READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit))
+          val readTime = System.currentTimeMillis - readStartTime
+          logInfo("Reading http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit) + " took " + readTime + " millis.")
+        } else {
+          throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestSplit)
+        }
+      } catch {
+        // EOFException is expected to happen because sender can break
+        // connection due to timeout
+        case eofe: java.io.EOFException => { }
+        case e: Exception => {
+          logInfo("ShuffleClient had a " + e)
+        }
+      } finally {
+        // If reception failed, unset for future retry
+        if (!receptionSucceeded) {
+          splitsInRequestBitVector.synchronized {
+            splitsInRequestBitVector.set(splitIndex, false)
+          }
+        }
+        cleanUpConnections()
+      }
+    }
+    
+    private def cleanUpConnections(): Unit = {
+      if (oisSource != null) {
+        oisSource.close()
+      }
+      if (oosSource != null) {
+        oosSource.close()
+      }
+      if (peerSocketToSource != null) {
+        peerSocketToSource.close()
+      }
+    }
+  }  
+}
+
+object CustomParallelInMemoryShuffle extends Logging {
+  // Cache for keeping the splits around
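+  // NOTE: Entries are never evicted in this experimental version, so map
+  // outputs stay in the sender's heap until the JVM exits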
+  val splitsCache = new HashMap[String, Array[Byte]]
+
+  // Used throughout the code for small and large waits/timeouts
+  private var MinKnockInterval_ = 1000
+  private var MaxKnockInterval_ = 5000
+  
+  // Maximum number of connections
+  private var MaxRxConnections_ = 4
+  private var MaxTxConnections_ = 8
+  
+  private var initialized = false
+  private var nextShuffleId = new AtomicLong(0)
+
+  // Variables initialized by initializeIfNeeded()
+  private var shuffleDir: File = null
+  
+  private var shuffleServer: ShuffleServer = null
+  private var serverAddress = InetAddress.getLocalHost.getHostAddress
+  private var serverPort: Int = -1
+  
+  // Random number generator
+  var ranGen = new Random
+  
+  private def initializeIfNeeded() = synchronized {
+    if (!initialized) {
+      // Load config parameters
+      MinKnockInterval_ = System.getProperty(
+        "spark.parallelInMemoryShuffle.minKnockInterval", "1000").toInt
+      MaxKnockInterval_ =  System.getProperty(
+        "spark.parallelInMemoryShuffle.maxKnockInterval", "5000").toInt
+
+      MaxRxConnections_ = System.getProperty(
+        "spark.parallelInMemoryShuffle.maxRxConnections", "4").toInt
+      MaxTxConnections_ = System.getProperty(
+        "spark.parallelInMemoryShuffle.maxTxConnections", "8").toInt
+        
+      // TODO: localDir should be created by some mechanism common to Spark
+      // so that it can be shared among shuffle, broadcast, etc
+      val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
+      var tries = 0
+      var foundLocalDir = false
+      var localDir: File = null
+      var localDirUuid: UUID = null
+      
+      while (!foundLocalDir && tries < 10) {
+        tries += 1
+        try {
+          localDirUuid = UUID.randomUUID
+          localDir = new File(localDirRoot, "spark-local-" + localDirUuid)          
+          if (!localDir.exists) {
+            localDir.mkdirs()
+            foundLocalDir = true
+          }
+        } catch {
+          case e: Exception =>
+            logWarning("Attempt " + tries + " to create local dir failed", e)
+        }
+      }
+      if (!foundLocalDir) {
+        logError("Failed 10 attempts to create local dir in " + localDirRoot)
+        System.exit(1)
+      }
+      shuffleDir = new File(localDir, "shuffle")
+      shuffleDir.mkdirs()
+      logInfo("Shuffle dir: " + shuffleDir)
+
+      // Create and start the shuffleServer      
+      shuffleServer = new ShuffleServer
+      shuffleServer.setDaemon(true)
+      shuffleServer.start()
+      logInfo("ShuffleServer started...")
+      
+      initialized = true
+    }
+  }
+  
+  def MinKnockInterval = MinKnockInterval_
+  def MaxKnockInterval = MaxKnockInterval_
+  
+  def MaxRxConnections = MaxRxConnections_
+  def MaxTxConnections = MaxTxConnections_
+  
+  def getSplitName(shuffleId: Long, inputId: Int, outputId: Int): String = {
+    initializeIfNeeded()
+    // Adding shuffleDir is unnecessary. Added to keep the parsers working
+    return "%s/%d/%d/%d".format(shuffleDir, shuffleId, inputId, outputId)
+  }
+
+  def newShuffleId(): Long = {
+    nextShuffleId.getAndIncrement()
+  }
+  
+  // Returns a standard ThreadFactory except all threads are daemons
+  private def newDaemonThreadFactory: ThreadFactory = {
+    new ThreadFactory {
+      def newThread(r: Runnable): Thread = {
+        var t = Executors.defaultThreadFactory.newThread(r)
+        t.setDaemon(true)
+        return t
+      }
+    }
+  }
+
+  // Wrapper over newFixedThreadPool
+  def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
+    var threadPool =
+      Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
+
+    threadPool.setThreadFactory(newDaemonThreadFactory)
+    
+    return threadPool
+  }
+  
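+  // ShuffleServer listens on an ephemeral port and hands each accepted
+  // connection to a ShuffleServerThread, which looks the requested split up
+  // in splitsCache and streams its bytes back to the receiver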
+  class ShuffleServer
+  extends Thread with Logging {
+    var threadPool = 
+      newDaemonFixedThreadPool(CustomParallelInMemoryShuffle.MaxTxConnections)
+
+    var serverSocket: ServerSocket = null
+
+    override def run: Unit = {
+      serverSocket = new ServerSocket(0)
+      serverPort = serverSocket.getLocalPort
+
+      logInfo("ShuffleServer started with " + serverSocket)
+      logInfo("Local URI: http://" + serverAddress + ":" + serverPort)
+
+      try {
+        while (true) {
+          var clientSocket: Socket = null
+          try {
+            clientSocket = serverSocket.accept()
+          } catch {
+            case e: Exception => { }
+          }
+          if (clientSocket != null) {
+            logInfo("Serve: Accepted new client connection:" + clientSocket)
+            try {
+              threadPool.execute(new ShuffleServerThread(clientSocket))
+            } catch {
+              // In failure, close socket here; else, the thread will close it
+              case ioe: IOException => {
+                clientSocket.close()
+              }
+            }
+          }
+        }
+      } finally {
+        if (serverSocket != null) {
+          logInfo("ShuffleServer now stopping...")
+          serverSocket.close()
+        }
+        // Shutdown the thread pool. The accept loop only exits via an
+        // exception, so this has to happen inside the finally block.
+        threadPool.shutdown()
+      }
+    }
+    
+    class ShuffleServerThread(val clientSocket: Socket)
+    extends Thread with Logging {
+      private val os = clientSocket.getOutputStream
+      os.flush()
+      private val bos = new BufferedOutputStream(os)
+      bos.flush()
+      private val oos = new ObjectOutputStream(os)
+      oos.flush()
+      private val ois = new ObjectInputStream(clientSocket.getInputStream)
+
+      logInfo("new ShuffleServerThread is running")
+      
+      override def run: Unit = {
+        try {
+          // Receive requestedSplit from the receiver
+          // Adding shuffleDir is unnecessary. Added to keep the parsers working
+          var requestedSplit = 
+            shuffleDir + "/" + ois.readObject.asInstanceOf[String]
+          logInfo("requestedSplit: " + requestedSplit)
+          
+          // Send the length of the requestedSplit to let the receiver know
+          // that the transfer is about to start
+          // In the case of receiver timeout and connection close, this will
+          // throw a java.net.SocketException: Broken pipe
+          var requestedSplitLen = -1
+          
+          try {
+            requestedSplitLen =
+              CustomParallelInMemoryShuffle.splitsCache(requestedSplit).length
+          } catch {
+            case e: Exception => { }
+          }
+
+          oos.writeObject(requestedSplitLen)
+          oos.flush()
+          
+          logInfo("requestedSplitLen = " + requestedSplitLen)
+
+          // Read and send the requested split
+          if (requestedSplitLen != -1) {
+            // Send
+            bos.write(CustomParallelInMemoryShuffle.splitsCache(requestedSplit),
+              0, requestedSplitLen)
+            bos.flush()
+          } else {
+            // Close the connection
+          }
+        } catch {
+          // If something went wrong, e.g., the worker at the other end died etc
+          // then close everything up
+          // Exception can happen if the receiver stops receiving
+          case e: Exception => {
+            logInfo("ShuffleServerThread had a " + e)
+          }
+        } finally {
+          logInfo("ShuffleServerThread is closing streams and sockets")
+          ois.close()
+          // TODO: Following can cause "java.net.SocketException: Socket closed"
+          oos.close()
+          bos.close()
+          clientSocket.close()
+        }
+      }
+    }
+  }
+}