Commit ef77bb73 authored by Reynold Xin
Merge pull request #627 from shivaram/master

Netty and shuffle bug fixes
parents 8cb81782 3b0cd173
FileServer.java
@@ -37,29 +37,33 @@ class FileServer {
         .childHandler(new FileServerChannelInitializer(pResolver));
     // Start the server.
     channelFuture = bootstrap.bind(addr);
-    this.port = addr.getPort();
+    try {
+      // Get the address we bound to.
+      InetSocketAddress boundAddress =
+          ((InetSocketAddress) channelFuture.sync().channel().localAddress());
+      this.port = boundAddress.getPort();
+    } catch (InterruptedException ie) {
+      this.port = 0;
+    }
   }
 
   /**
    * Start the file server asynchronously in a new thread.
    */
   public void start() {
-    try {
-      blockingThread = new Thread() {
-        public void run() {
-          try {
-            Channel channel = channelFuture.sync().channel();
-            channel.closeFuture().sync();
-          } catch (InterruptedException e) {
-            LOG.error("File server start got interrupted", e);
-          }
-        }
-      };
-      blockingThread.setDaemon(true);
-      blockingThread.start();
-    } finally {
-      bootstrap.shutdown();
-    }
+    blockingThread = new Thread() {
+      public void run() {
+        try {
+          channelFuture.channel().closeFuture().sync();
+          LOG.info("FileServer exiting");
+        } catch (InterruptedException e) {
+          LOG.error("File server start got interrupted", e);
+        }
+        // NOTE: bootstrap is shutdown in stop()
+      }
+    };
+    blockingThread.setDaemon(true);
+    blockingThread.start();
   }
 
   public int getPort() {
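The bind fix matters because bootstrap.bind() is asynchronous and the server may be asked to bind to port 0 (let the OS choose a free port). Reading the port back from the requested address addr then yields the requested value, not the one actually assigned; only the bound channel knows the real port. Below is a minimal, self-contained sketch of the pattern, written against the current Netty 4 API, which differs slightly from the older Netty snapshot this patch targets (e.g. shutdownGracefully() instead of bootstrap.shutdown()); the class name and event-loop setup are illustrative, not part of the patch.

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import java.net.InetSocketAddress;

public class BoundPortSketch {
    public static void main(String[] args) throws InterruptedException {
        EventLoopGroup boss = new NioEventLoopGroup(1);
        EventLoopGroup worker = new NioEventLoopGroup();
        try {
            ServerBootstrap bootstrap = new ServerBootstrap()
                .group(boss, worker)
                .channel(NioServerSocketChannel.class)
                .childHandler(new ChannelInitializer<SocketChannel>() {
                    @Override
                    protected void initChannel(SocketChannel ch) {
                        // No handlers needed for this sketch.
                    }
                });
            // Bind to port 0 so the OS picks an ephemeral port.
            ChannelFuture f = bootstrap.bind(new InetSocketAddress(0)).sync();
            // The requested address said port 0; the channel reports the real one.
            int port = ((InetSocketAddress) f.channel().localAddress()).getPort();
            System.out.println("Actually bound to port " + port);
            f.channel().close().sync();
        } finally {
            boss.shutdownGracefully();
            worker.shutdownGracefully();
        }
    }
}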
@@ -67,17 +71,16 @@ class FileServer {
   }
 
   public void stop() {
-    if (blockingThread != null) {
-      blockingThread.stop();
-      blockingThread = null;
-    }
+    // Close the bound channel.
     if (channelFuture != null) {
-      channelFuture.channel().closeFuture();
+      channelFuture.channel().close();
       channelFuture = null;
     }
+    // Shutdown bootstrap.
     if (bootstrap != null) {
       bootstrap.shutdown();
       bootstrap = null;
     }
+    // TODO: Shutdown all accepted channels as well ?
   }
 }
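Two problems disappear from stop() here. First, Thread.stop() is long deprecated and unsafe (it kills the thread at an arbitrary point), and it is unnecessary once the blocking thread simply waits on the channel's close future. Second, the old channelFuture.channel().closeFuture() call was a no-op bug: closeFuture() merely returns a future that completes when the channel eventually closes, whereas close() actually initiates the close.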
DiskStore.scala
@@ -59,6 +59,8 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
     // Flush the partial writes, and set valid length to be the length of the entire file.
     // Return the number of bytes written for this commit.
     override def commit(): Long = {
+      // NOTE: Flush the serializer first and then the compressed/buffered output stream
+      objOut.flush()
       bs.flush()
       val prevPos = lastValidPosition
       lastValidPosition = channel.position()
@@ -68,6 +70,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
     override def revertPartialWrites() {
       // Discard current writes. We do this by flushing the outstanding writes and
       // truncate the file to the last valid position.
+      objOut.flush()
       bs.flush()
       channel.truncate(lastValidPosition)
     }
...
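The two objOut.flush() calls fix a flush-ordering bug in a layered stream stack: the serialization stream objOut buffers on top of the compressed/buffered stream bs, so flushing bs alone leaves serialized bytes stranded in the upper layer and channel.position() under-reports the committed length, which is how zero-sized shuffle blocks arise. A minimal java.io sketch of the same layering follows, with ObjectOutputStream standing in for Spark's SerializationStream; the file name and class name are illustrative.

import java.io.BufferedOutputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;

public class FlushOrderSketch {
    public static void main(String[] args) throws IOException {
        FileOutputStream file = new FileOutputStream("block.bin");
        // Plays the role of DiskStore's buffered/compressed stream `bs`.
        BufferedOutputStream bs = new BufferedOutputStream(file);
        // Plays the role of the serializer stream `objOut` on top of `bs`.
        ObjectOutputStream objOut = new ObjectOutputStream(bs);

        objOut.writeObject("some record");

        // Flush top-down: drain the serializer's buffer into `bs` first,
        // then drain `bs` to the file. Only after both flushes does
        // channel.position() count every byte written so far.
        objOut.flush();
        bs.flush();
        System.out.println("valid length = " + file.getChannel().position());

        objOut.close();
    }
}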
ShuffleSuite.scala
@@ -305,9 +305,32 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
     assert(c.partitioner.get === p)
   }
 
+  test("shuffle non-zero block size") {
+    sc = new SparkContext("local-cluster[2,1,512]", "test")
+    val NUM_BLOCKS = 3
+
+    val a = sc.parallelize(1 to 10, 2)
+    val b = a.map { x =>
+      (x, new ShuffleSuite.NonJavaSerializableClass(x * 2))
+    }
+    // If the Kryo serializer is not used correctly, the shuffle would fail because the
+    // default Java serializer cannot handle the non serializable class.
+    val c = new ShuffledRDD(b, new HashPartitioner(NUM_BLOCKS),
+      classOf[spark.KryoSerializer].getName)
+    val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+
+    assert(c.count === 10)
+
+    // All blocks must have non-zero size
+    (0 until NUM_BLOCKS).foreach { id =>
+      val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
+      assert(statuses.forall(s => s._2 > 0))
+    }
+  }
+
   test("shuffle serializer") {
     // Use a local cluster with 2 processes to make sure there are both local and remote blocks
-    sc = new SparkContext("local-cluster[1,2,512]", "test")
+    sc = new SparkContext("local-cluster[2,1,512]", "test")
     val a = sc.parallelize(1 to 10, 2)
     val b = a.map { x =>
       (x, new ShuffleSuite.NonJavaSerializableClass(x * 2))
...
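The new test pins the regression down at the map-output level: getServerStatuses(shuffleId, reduceId) reports a (server, size) pair per map output, so statuses.forall(s => s._2 > 0) asserts that no shuffle block was committed with zero length. Switching the shuffle-serializer test from local-cluster[1,2,512] to local-cluster[2,1,512] (two workers with one core each, rather than one worker with two cores) also makes its comment true: with two separate worker processes, some blocks really are fetched remotely instead of being served locally.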