diff --git a/core/src/main/java/spark/network/netty/FileServer.java b/core/src/main/java/spark/network/netty/FileServer.java index dd3f12561cb655cfb14b822a0d61376550276196..dd3a557ae59522187f6e15f68d352f5e19c5c900 100644 --- a/core/src/main/java/spark/network/netty/FileServer.java +++ b/core/src/main/java/spark/network/netty/FileServer.java @@ -37,29 +37,33 @@ class FileServer { .childHandler(new FileServerChannelInitializer(pResolver)); // Start the server. channelFuture = bootstrap.bind(addr); - this.port = addr.getPort(); + try { + // Get the address we bound to. + InetSocketAddress boundAddress = + ((InetSocketAddress) channelFuture.sync().channel().localAddress()); + this.port = boundAddress.getPort(); + } catch (InterruptedException ie) { + this.port = 0; + } } /** * Start the file server asynchronously in a new thread. */ public void start() { - try { - blockingThread = new Thread() { - public void run() { - try { - Channel channel = channelFuture.sync().channel(); - channel.closeFuture().sync(); - } catch (InterruptedException e) { - LOG.error("File server start got interrupted", e); - } + blockingThread = new Thread() { + public void run() { + try { + channelFuture.channel().closeFuture().sync(); + LOG.info("FileServer exiting"); + } catch (InterruptedException e) { + LOG.error("File server start got interrupted", e); } - }; - blockingThread.setDaemon(true); - blockingThread.start(); - } finally { - bootstrap.shutdown(); - } + // NOTE: bootstrap is shutdown in stop() + } + }; + blockingThread.setDaemon(true); + blockingThread.start(); } public int getPort() { @@ -67,17 +71,16 @@ class FileServer { } public void stop() { - if (blockingThread != null) { - blockingThread.stop(); - blockingThread = null; - } + // Close the bound channel. if (channelFuture != null) { - channelFuture.channel().closeFuture(); + channelFuture.channel().close(); channelFuture = null; } + // Shutdown bootstrap. if (bootstrap != null) { bootstrap.shutdown(); bootstrap = null; } + // TODO: Shutdown all accepted channels as well ? } } diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index 57d4dafefc56a10728e4e62a86b3a41460ddad01..c7281200e7e0086660e9cfeb65d288ce28f7875b 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -59,6 +59,8 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) // Flush the partial writes, and set valid length to be the length of the entire file. // Return the number of bytes written for this commit. override def commit(): Long = { + // NOTE: Flush the serializer first and then the compressed/buffered output stream + objOut.flush() bs.flush() val prevPos = lastValidPosition lastValidPosition = channel.position() @@ -68,6 +70,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) override def revertPartialWrites() { // Discard current writes. We do this by flushing the outstanding writes and // truncate the file to the last valid position. + objOut.flush() bs.flush() channel.truncate(lastValidPosition) } diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index 4e50ae2ca9821338def50771620c5dd4b48abd27..b967016cf726791b543781a9f42cf8c9607aab71 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -305,9 +305,32 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { assert(c.partitioner.get === p) } + test("shuffle non-zero block size") { + sc = new SparkContext("local-cluster[2,1,512]", "test") + val NUM_BLOCKS = 3 + + val a = sc.parallelize(1 to 10, 2) + val b = a.map { x => + (x, new ShuffleSuite.NonJavaSerializableClass(x * 2)) + } + // If the Kryo serializer is not used correctly, the shuffle would fail because the + // default Java serializer cannot handle the non serializable class. + val c = new ShuffledRDD(b, new HashPartitioner(NUM_BLOCKS), + classOf[spark.KryoSerializer].getName) + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + + assert(c.count === 10) + + // All blocks must have non-zero size + (0 until NUM_BLOCKS).foreach { id => + val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) + assert(statuses.forall(s => s._2 > 0)) + } + } + test("shuffle serializer") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[1,2,512]", "test") + sc = new SparkContext("local-cluster[2,1,512]", "test") val a = sc.parallelize(1 to 10, 2) val b = a.map { x => (x, new ShuffleSuite.NonJavaSerializableClass(x * 2))