diff --git a/conf/java-opts b/conf/java-opts
index b7afd7588e325a4b5ece7d25a7717b37d4ef33da..5b4ae25c067e56825a3ece43732593ac835ac33b 100644
--- a/conf/java-opts
+++ b/conf/java-opts
@@ -1 +1 @@
--Dspark.shuffle.class=spark.HttpBlockedLocalFileShuffle -Dspark.blockedLocalFileShuffle.maxRxConnections=2 -Dspark.blockedLocalFileShuffle.blockSize=256 -Dspark.blockedLocalFileShuffle.minKnockInterval=50 -Dspark.parallelLocalFileShuffle.maxRxConnections=2 -Dspark.parallelLocalFileShuffle.maxTxConnections=2 -Dspark.parallelLocalFileShuffle.minKnockInterval=50 -Dspark.parallelLocalFileShuffle.maxKnockInterval=2000
+-Dspark.shuffle.class=spark.CustomParallelInMemoryShuffle -Dspark.blockedLocalFileShuffle.maxRxConnections=2 -Dspark.blockedLocalFileShuffle.blockSize=256 -Dspark.blockedLocalFileShuffle.minKnockInterval=50 -Dspark.parallelLocalFileShuffle.maxRxConnections=2 -Dspark.parallelLocalFileShuffle.maxTxConnections=2 -Dspark.parallelLocalFileShuffle.minKnockInterval=50 -Dspark.parallelLocalFileShuffle.maxKnockInterval=2000 -Dspark.parallelInMemoryShuffle.maxRxConnections=2 -Dspark.parallelInMemoryShuffle.maxTxConnections=2 -Dspark.parallelInMemoryShuffle.minKnockInterval=50 -Dspark.parallelInMemoryShuffle.maxKnockInterval=2000
diff --git a/src/scala/spark/CustomParallelInMemoryShuffle.scala b/src/scala/spark/CustomParallelInMemoryShuffle.scala
new file mode 100644
index 0000000000000000000000000000000000000000..ee3403e49ac7d910a2adf74314e5d6f8c2ec7599
--- /dev/null
+++ b/src/scala/spark/CustomParallelInMemoryShuffle.scala
@@ -0,0 +1,555 @@
+package spark
+
+import java.io._
+import java.net._
+import java.util.{BitSet, Random, Timer, TimerTask, UUID}
+import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.{LinkedBlockingQueue, Executors, ThreadPoolExecutor, ThreadFactory}
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
+/**
+ * TODO: THIS IS AN ABSOLUTELY EXPERIMENTAL IMPLEMENTATION (FOR NOW).
+ *
+ * An implementation of shuffle that keeps intermediate data in local memory
+ * and serves it through a custom server. Receivers open simultaneous
+ * connections to multiple servers, bounded by the
+ * 'spark.parallelInMemoryShuffle.maxRxConnections' config option.
+ *
+ * TODO: Add support for compression when spark.compress is set to true.
+ */
+@serializable
+class CustomParallelInMemoryShuffle[K, V, C]
+extends Shuffle[K, V, C] with Logging {
+  @transient var totalSplits = 0
+  @transient var hasSplits = 0
+  @transient var hasSplitsBitVector: BitSet = null
+  @transient var splitsInRequestBitVector: BitSet = null
+
+  @transient var receivedData: LinkedBlockingQueue[(Int, Array[Byte])] = null
+  @transient var combiners: HashMap[K,C] = null
+
+  override def compute(input: RDD[(K, V)],
+                       numOutputSplits: Int,
+                       createCombiner: V => C,
+                       mergeValue: (C, V) => C,
+                       mergeCombiners: (C, C) => C)
+  : RDD[(K, C)] =
+  {
+    val sc = input.sparkContext
+    val shuffleId = CustomParallelInMemoryShuffle.newShuffleId()
+    logInfo("Shuffle ID: " + shuffleId)
+
+    val splitRdd = new NumberedSplitRDD(input)
+    val numInputSplits = splitRdd.splits.size
+
+    // Run a parallel map and collect to write the intermediate data files,
+    // returning a list of inputSplitId -> serverUri pairs
+    val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
+      val myIndex = pair._1
+      val myIterator = pair._2
+      val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
+      for ((k, v) <- myIterator) {
+        var bucketId = k.hashCode % numOutputSplits
+        if (bucketId < 0) { // Fix bucket ID if hash code was negative
+          bucketId += numOutputSplits
+        }
+        val bucket = buckets(bucketId)
+        bucket(k) = bucket.get(k) match {
+          case Some(c) => mergeValue(c, v)
+          case None => createCombiner(v)
+        }
+      }
+
+      for (i <- 0 until numOutputSplits) {
+        val splitName =
+          CustomParallelInMemoryShuffle.getSplitName(shuffleId, myIndex, i)
+
+        val writeStartTime = System.currentTimeMillis
+        logInfo("BEGIN WRITE: " + splitName)
+
+        // Write buckets(i) to a byte array & put in splitsCache instead of file
+        val baos = new ByteArrayOutputStream
+        val oos = new ObjectOutputStream(baos)
+        oos.writeObject(buckets(i))
+        oos.close
+        baos.close
+
+        CustomParallelInMemoryShuffle.splitsCache(splitName) = baos.toByteArray
+        val splitLen =
+          CustomParallelInMemoryShuffle.splitsCache(splitName).length
+
+        logInfo("END WRITE: " + splitName)
+        val writeTime = System.currentTimeMillis - writeStartTime
+        logInfo("Writing " + splitName + " of size " + splitLen + " bytes took " + writeTime + " millis.")
+      }
+
+      (myIndex, CustomParallelInMemoryShuffle.serverAddress,
+        CustomParallelInMemoryShuffle.serverPort)
+    }).collect()
+
+    val splitsByUri = new ArrayBuffer[(String, Int, Int)]
+    for ((inputId, serverAddress, serverPort) <- outputLocs) {
+      splitsByUri += ((serverAddress, serverPort, inputId))
+    }
+
+    // TODO: Could broadcast splitsByUri
+
+    // Return an RDD that does each of the merges for a given partition
+    val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
+    return indexes.flatMap((myId: Int) => {
+      totalSplits = splitsByUri.size
+      hasSplits = 0
+      hasSplitsBitVector = new BitSet(totalSplits)
+      splitsInRequestBitVector = new BitSet(totalSplits)
+
+      receivedData = new LinkedBlockingQueue[(Int, Array[Byte])]
+      combiners = new HashMap[K, C]
+
+      var threadPool = CustomParallelInMemoryShuffle.newDaemonFixedThreadPool(
+        CustomParallelInMemoryShuffle.MaxRxConnections)
+
+      // Start consumer
+      var shuffleConsumer = new ShuffleConsumer(mergeCombiners)
+      shuffleConsumer.setDaemon(true)
+      shuffleConsumer.start()
+      logInfo("ShuffleConsumer started...")
+
+      while (hasSplits < totalSplits) {
+        var numThreadsToCreate = Math.min(totalSplits,
+          CustomParallelInMemoryShuffle.MaxRxConnections) -
+          threadPool.getActiveCount
+
+        while (hasSplits < totalSplits && numThreadsToCreate > 0) {
+          // Select a random split to pull
+          val splitIndex = selectRandomSplit
+
+          if (splitIndex != -1) {
+            val (serverAddress, serverPort, inputId) = splitsByUri(splitIndex)
+            val requestSplit = "%d/%d/%d".format(shuffleId, inputId, myId)
+
+            threadPool.execute(new ShuffleClient(splitIndex, serverAddress,
+              serverPort, requestSplit))
+
+            // splitIndex is in transit. Will be unset in the ShuffleClient
+            splitsInRequestBitVector.synchronized {
+              splitsInRequestBitVector.set(splitIndex)
+            }
+          }
+
+          numThreadsToCreate = numThreadsToCreate - 1
+        }
+
+        // Sleep for a while before creating new threads
+        Thread.sleep(CustomParallelInMemoryShuffle.MinKnockInterval)
+      }
+
+      threadPool.shutdown()
+      combiners
+    })
+  }
+
+  def selectRandomSplit: Int = {
+    var requiredSplits = new ArrayBuffer[Int]
+
+    synchronized {
+      for (i <- 0 until totalSplits) {
+        if (!hasSplitsBitVector.get(i) && !splitsInRequestBitVector.get(i)) {
+          requiredSplits += i
+        }
+      }
+    }
+
+    if (requiredSplits.size > 0) {
+      requiredSplits(CustomParallelInMemoryShuffle.ranGen.nextInt(
+        requiredSplits.size))
+    } else {
+      -1
+    }
+  }
+
+  class ShuffleConsumer(mergeCombiners: (C, C) => C)
+  extends Thread with Logging {
+    override def run: Unit = {
+      // Run until all splits are here
+      while (hasSplits < totalSplits) {
+        var splitIndex = -1
+        var recvByteArray: Array[Byte] = null
+
+        try {
+          var tempPair = receivedData.take().asInstanceOf[(Int, Array[Byte])]
+          splitIndex = tempPair._1
+          recvByteArray = tempPair._2
+        } catch {
+          case e: Exception => {
+            logInfo("Exception during taking data from receivedData")
+          }
+        }
+
+        val inputStream =
+          new ObjectInputStream(new ByteArrayInputStream(recvByteArray))
+
+        try {
+          while (true) {
+//            logInfo("" + inputStream.readObject.isInstanceOf[(K, C)])
+            val (k, c) = inputStream.readObject.asInstanceOf[(K, C)]
+            combiners(k) = combiners.get(k) match {
+              case Some(oldC) => mergeCombiners(oldC, c)
+              case None => c
+            }
+          }
+        } catch {
+          case e: EOFException => { }
+        }
+        inputStream.close()
+
+        // Consumption completed. Update stats.
+        hasSplitsBitVector.synchronized {
+          hasSplitsBitVector.set(splitIndex)
+        }
+        hasSplits += 1
+
+        // We have received splitIndex
+        splitsInRequestBitVector.synchronized {
+          splitsInRequestBitVector.set(splitIndex, false)
+        }
+      }
+    }
+  }
+
+  class ShuffleClient(splitIndex: Int, hostAddress: String, listenPort: Int,
+    requestSplit: String)
+  extends Thread with Logging {
+    private var peerSocketToSource: Socket = null
+    private var oosSource: ObjectOutputStream = null
+    private var oisSource: ObjectInputStream = null
+
+    private var receptionSucceeded = false
+
+    override def run: Unit = {
+      // Setup the timeout mechanism
+      var timeOutTask = new TimerTask {
+        override def run: Unit = {
+          cleanUpConnections()
+        }
+      }
+
+      var timeOutTimer = new Timer
+      timeOutTimer.schedule(timeOutTask,
+        CustomParallelInMemoryShuffle.MaxKnockInterval)
+
+      logInfo("ShuffleClient started... => %s:%d#%s".format(hostAddress, listenPort, requestSplit))
+
+      try {
+        // Connect to the source
+        peerSocketToSource = new Socket(hostAddress, listenPort)
+        oosSource =
+          new ObjectOutputStream(peerSocketToSource.getOutputStream)
+        oosSource.flush()
+        var isSource = peerSocketToSource.getInputStream
+        oisSource = new ObjectInputStream(isSource)
+
+        // Send the request
+        oosSource.writeObject(requestSplit)
+
+        // Receive the length of the requested file
+        var requestedFileLen = oisSource.readObject.asInstanceOf[Int]
+        logInfo("Received requestedFileLen = " + requestedFileLen)
+
+        // Turn the timer OFF if the sender responds before the timeout
+        timeOutTimer.cancel()
+
+        // Receive the file
+        if (requestedFileLen != -1) {
+          val readStartTime = System.currentTimeMillis
+          logInfo("BEGIN READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit))
+
+          // Receive data in an Array[Byte]
+          var recvByteArray = new Array[Byte](requestedFileLen)
+          var alreadyRead = 0
+          var bytesRead = 0
+
+          while (alreadyRead != requestedFileLen) {
+            bytesRead = isSource.read(recvByteArray, alreadyRead,
+              requestedFileLen - alreadyRead)
+            if (bytesRead > 0) {
+              alreadyRead = alreadyRead + bytesRead
+            }
+          }
+
+          // Make it available to the consumer
+          try {
+            receivedData.put((splitIndex, recvByteArray))
+          } catch {
+            case e: Exception => {
+              logInfo("Exception during putting data into receivedData")
+            }
+          }
+
+          // NOTE: Updates of the bitVectors are now done by the consumer
+
+          receptionSucceeded = true
+
+          logInfo("END READ: http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit))
+          val readTime = System.currentTimeMillis - readStartTime
+          logInfo("Reading http://%s:%d/shuffle/%s".format(hostAddress, listenPort, requestSplit) + " took " + readTime + " millis.")
+        } else {
+          throw new SparkException("ShuffleServer " + hostAddress + " does not have " + requestSplit)
+        }
+      } catch {
+        // EOFException is expected to happen because the sender can break
+        // the connection due to a timeout
+        case eofe: java.io.EOFException => { }
+        case e: Exception => {
+          logInfo("ShuffleClient had a " + e)
+        }
+      } finally {
+        // If reception failed, unset for future retry
+        if (!receptionSucceeded) {
+          splitsInRequestBitVector.synchronized {
+            splitsInRequestBitVector.set(splitIndex, false)
+          }
+        }
+        cleanUpConnections()
+      }
+    }
+
+    private def cleanUpConnections(): Unit = {
+      if (oisSource != null) {
+        oisSource.close()
+      }
+      if (oosSource != null) {
+        oosSource.close()
+      }
+      if (peerSocketToSource != null) {
+        peerSocketToSource.close()
+      }
+    }
+  }
+}
+
+object CustomParallelInMemoryShuffle extends Logging {
+  // Cache for keeping the splits around
+  val splitsCache = new HashMap[String, Array[Byte]]
+
+  // Used throughout the code for small and large waits/timeouts
+  private var MinKnockInterval_ = 1000
+  private var MaxKnockInterval_ = 5000
+
+  // Maximum number of connections
+  private var MaxRxConnections_ = 4
+  private var MaxTxConnections_ = 8
+
+  private var initialized = false
+  private var nextShuffleId = new AtomicLong(0)
+
+  // Variables initialized by initializeIfNeeded()
+  private var shuffleDir: File = null
+
+  private var shuffleServer: ShuffleServer = null
+  private var serverAddress = InetAddress.getLocalHost.getHostAddress
+  private var serverPort: Int = -1
+
+  // Random number generator
+  var ranGen = new Random
+
+  private def initializeIfNeeded() = synchronized {
+    if (!initialized) {
+      // Load config parameters
+      MinKnockInterval_ = System.getProperty(
"spark.parallelInMemoryShuffle.minKnockInterval", "1000").toInt + MaxKnockInterval_ = System.getProperty( + "spark.parallelInMemoryShuffle.maxKnockInterval", "5000").toInt + + MaxRxConnections_ = System.getProperty( + "spark.parallelInMemoryShuffle.maxRxConnections", "4").toInt + MaxTxConnections_ = System.getProperty( + "spark.parallelInMemoryShuffle.maxTxConnections", "8").toInt + + // TODO: localDir should be created by some mechanism common to Spark + // so that it can be shared among shuffle, broadcast, etc + val localDirRoot = System.getProperty("spark.local.dir", "/tmp") + var tries = 0 + var foundLocalDir = false + var localDir: File = null + var localDirUuid: UUID = null + + while (!foundLocalDir && tries < 10) { + tries += 1 + try { + localDirUuid = UUID.randomUUID + localDir = new File(localDirRoot, "spark-local-" + localDirUuid) + if (!localDir.exists) { + localDir.mkdirs() + foundLocalDir = true + } + } catch { + case e: Exception => + logWarning("Attempt " + tries + " to create local dir failed", e) + } + } + if (!foundLocalDir) { + logError("Failed 10 attempts to create local dir in " + localDirRoot) + System.exit(1) + } + shuffleDir = new File(localDir, "shuffle") + shuffleDir.mkdirs() + logInfo("Shuffle dir: " + shuffleDir) + + // Create and start the shuffleServer + shuffleServer = new ShuffleServer + shuffleServer.setDaemon(true) + shuffleServer.start() + logInfo("ShuffleServer started...") + + initialized = true + } + } + + def MinKnockInterval = MinKnockInterval_ + def MaxKnockInterval = MaxKnockInterval_ + + def MaxRxConnections = MaxRxConnections_ + def MaxTxConnections = MaxTxConnections_ + + def getSplitName(shuffleId: Long, inputId: Int, outputId: Int): String = { + initializeIfNeeded() + // Adding shuffleDir is unnecessary. 
+    return "%s/%d/%d".format(shuffleDir, shuffleId, inputId, outputId)
+  }
+
+  def newShuffleId(): Long = {
+    nextShuffleId.getAndIncrement()
+  }
+
+  // Returns a standard ThreadFactory except all threads are daemons
+  private def newDaemonThreadFactory: ThreadFactory = {
+    new ThreadFactory {
+      def newThread(r: Runnable): Thread = {
+        var t = Executors.defaultThreadFactory.newThread(r)
+        t.setDaemon(true)
+        return t
+      }
+    }
+  }
+
+  // Wrapper over newFixedThreadPool
+  def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
+    var threadPool =
+      Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
+
+    threadPool.setThreadFactory(newDaemonThreadFactory)
+
+    return threadPool
+  }
+
+  class ShuffleServer
+  extends Thread with Logging {
+    var threadPool =
+      newDaemonFixedThreadPool(CustomParallelInMemoryShuffle.MaxTxConnections)
+
+    var serverSocket: ServerSocket = null
+
+    override def run: Unit = {
+      serverSocket = new ServerSocket(0)
+      serverPort = serverSocket.getLocalPort
+
+      logInfo("ShuffleServer started with " + serverSocket)
+      logInfo("Local URI: http://" + serverAddress + ":" + serverPort)
+
+      try {
+        while (true) {
+          var clientSocket: Socket = null
+          try {
+            clientSocket = serverSocket.accept()
+          } catch {
+            case e: Exception => { }
+          }
+          if (clientSocket != null) {
+            logInfo("Server: Accepted new client connection: " + clientSocket)
+            try {
+              threadPool.execute(new ShuffleServerThread(clientSocket))
+            } catch {
+              // On failure, close the socket here; else, the thread will close it
+              case ioe: IOException => {
+                clientSocket.close()
+              }
+            }
+          }
+        }
+      } finally {
+        if (serverSocket != null) {
+          logInfo("ShuffleServer now stopping...")
+          serverSocket.close()
+        }
+      }
+      // Shutdown the thread pool
+      threadPool.shutdown()
+    }
+
+    class ShuffleServerThread(val clientSocket: Socket)
+    extends Thread with Logging {
+      private val os = clientSocket.getOutputStream.asInstanceOf[OutputStream]
+      os.flush()
+      private val bos = new BufferedOutputStream(os)
+      bos.flush()
+      private val oos = new ObjectOutputStream(os)
+      oos.flush()
+      private val ois = new ObjectInputStream(clientSocket.getInputStream)
+
+      logInfo("new ShuffleServerThread is running")
+
+      override def run: Unit = {
+        try {
+          // Receive requestedSplit from the receiver
+          // Adding shuffleDir is unnecessary. Added to keep the parsers working.
+          var requestedSplit =
+            shuffleDir + "/" + ois.readObject.asInstanceOf[String]
+          logInfo("requestedSplit: " + requestedSplit)
+
+          // Send the length of the requestedSplit to let the receiver know
+          // that the transfer is about to start.
+          // In the case of receiver timeout and connection close, this will
+          // throw a java.net.SocketException: Broken pipe
+          var requestedSplitLen = -1
+
+          try {
+            requestedSplitLen =
+              CustomParallelInMemoryShuffle.splitsCache(requestedSplit).length
+          } catch {
+            case e: Exception => { }
+          }
+
+          oos.writeObject(requestedSplitLen)
+          oos.flush()
+
+          logInfo("requestedSplitLen = " + requestedSplitLen)
+
+          // Read and send the requested split
+          if (requestedSplitLen != -1) {
+            // Send
+            bos.write(CustomParallelInMemoryShuffle.splitsCache(requestedSplit),
+              0, requestedSplitLen)
+            bos.flush()
+          } else {
+            // Close the connection
+          }
+        } catch {
+          // If something went wrong, e.g., the worker at the other end died,
+          // then close everything up.
+          // Exception can happen if the receiver stops receiving.
+          case e: Exception => {
+            logInfo("ShuffleServerThread had a " + e)
+          }
+        } finally {
+          logInfo("ShuffleServerThread is closing streams and sockets")
+          ois.close()
+          // TODO: The following can cause "java.net.SocketException: Socket closed"
+          oos.close()
+          bos.close()
+          clientSocket.close()
+        }
+      }
+    }
+  }
+}
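
Usage note (commentary, not part of the patch): conf/java-opts selects this implementation through spark.shuffle.class and tunes it through the spark.parallelInMemoryShuffle.* properties; maxRxConnections and maxTxConnections bound the simultaneous receive and send connections, while minKnockInterval is the sleep between rounds of fetch attempts and maxKnockInterval the per-request timeout, both in milliseconds. The sketch below shows how the new class could be exercised directly through the Shuffle trait's compute() method from a driver program. It is a minimal, hypothetical example: the object name, the sample data, and the SparkContext("local", ...) constructor are assumptions about the surrounding code base, not part of this change.

    // Hypothetical driver program: word count driven directly through
    // CustomParallelInMemoryShuffle.compute(). Adjust the SparkContext
    // constructor arguments to whatever the local build expects.
    object InMemoryShuffleExample {
      def main(args: Array[String]) {
        // Mirror the settings added to conf/java-opts.
        System.setProperty("spark.shuffle.class", "spark.CustomParallelInMemoryShuffle")
        System.setProperty("spark.parallelInMemoryShuffle.maxRxConnections", "2")
        System.setProperty("spark.parallelInMemoryShuffle.maxTxConnections", "2")

        val sc = new spark.SparkContext("local", "InMemoryShuffleExample")
        val pairs = sc.parallelize(Seq("a", "b", "a", "c", "b", "a"), 2).map(w => (w, 1))

        // K = String, V = Int, C = Int: count occurrences of each word.
        val shuffle = new spark.CustomParallelInMemoryShuffle[String, Int, Int]
        val counts = shuffle.compute(pairs, 2,
          (v: Int) => v,                    // createCombiner
          (c: Int, v: Int) => c + v,        // mergeValue
          (c1: Int, c2: Int) => c1 + c2)    // mergeCombiners

        counts.collect().foreach(println)
      }
    }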