Commit c1f85fc7 authored by Marcelo Vanzin

[SPARK-11956][CORE] Fix a few bugs in network lib-based file transfer.

- NettyRpcEnv::openStream() now correctly propagates errors to
  the read side of the pipe.
- NettyStreamManager now throws if the file being transferred does
  not exist.
- The network library now correctly handles zero-sized streams.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #9941 from vanzin/SPARK-11956.
parent 0a5aef75
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -27,7 +27,7 @@ import javax.annotation.Nullable
 import scala.concurrent.{Future, Promise}
 import scala.reflect.ClassTag
-import scala.util.{DynamicVariable, Failure, Success}
+import scala.util.{DynamicVariable, Failure, Success, Try}
 import scala.util.control.NonFatal
 
 import org.apache.spark.{Logging, SecurityManager, SparkConf}
@@ -368,13 +368,22 @@ private[netty] class NettyRpcEnv(
 
     @volatile private var error: Throwable = _
 
-    def setError(e: Throwable): Unit = error = e
+    def setError(e: Throwable): Unit = {
+      error = e
+      source.close()
+    }
 
     override def read(dst: ByteBuffer): Int = {
-      if (error != null) {
-        throw error
+      val result = if (error == null) {
+        Try(source.read(dst))
+      } else {
+        Failure(error)
       }
-      source.read(dst)
+
+      result match {
+        case Success(bytesRead) => bytesRead
+        case Failure(error) => throw error
+      }
     }
 
     override def close(): Unit = source.close()
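The change above follows a small, reusable pattern: a pipe-backed channel records a write-side failure and closes the pipe, so a reader blocked in read() wakes up, and subsequent reads surface the recorded cause instead of a bare EOF. A minimal, self-contained sketch of that pattern (the DownloadChannel name and the standalone setup are illustrative, not part of the patch):

    import java.nio.ByteBuffer
    import java.nio.channels.{Pipe, ReadableByteChannel}

    import scala.util.{Failure, Success, Try}

    // Pipe-backed channel that remembers a write-side failure.
    class DownloadChannel(source: Pipe.SourceChannel) extends ReadableByteChannel {

      @volatile private var error: Throwable = _

      // Called by the write side when the transfer fails. Closing the
      // source also unblocks a reader currently parked inside read().
      def setError(e: Throwable): Unit = {
        error = e
        source.close()
      }

      override def read(dst: ByteBuffer): Int = {
        // If a failure was already recorded, surface it instead of
        // reading from the (now closed) pipe.
        val result = if (error == null) Try(source.read(dst)) else Failure(error)
        result match {
          case Success(bytesRead) => bytesRead
          case Failure(t) => throw t
        }
      }

      override def isOpen(): Boolean = source.isOpen()
      override def close(): Unit = source.close()
    }

In NettyRpcEnv the write side is the stream callback that feeds the pipe; its failure handler calls setError, which is how openStream() now propagates errors to the read side.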
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala
@@ -44,7 +44,7 @@ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv)
       case _ => throw new IllegalArgumentException(s"Invalid file type: $ftype")
     }
 
-    require(file != null, s"File not found: $streamId")
+    require(file != null && file.isFile(), s"File not found: $streamId")
     new FileSegmentManagedBuffer(rpcEnv.transportConf, file, 0, file.length())
   }
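Worth noting for the one-line change above: Scala's require throws IllegalArgumentException (with the message prefixed by "requirement failed:") when the predicate is false, so a request for a deleted or non-regular file now fails fast instead of surfacing later as an obscure I/O error from FileSegmentManagedBuffer. A tiny illustrative demo of the guard (serveFile is a made-up name, not part of the patch):

    import java.io.File

    object RequireGuardDemo extends App {
      // Mirrors the patched check: null means "never registered",
      // !isFile() means "registered but missing or not a regular file".
      def serveFile(file: File, streamId: String): File = {
        require(file != null && file.isFile(), s"File not found: $streamId")
        file
      }

      try serveFile(new File("/does/not/exist"), "/files/doesNotExist")
      catch {
        // prints: requirement failed: File not found: /files/doesNotExist
        case e: IllegalArgumentException => println(e.getMessage)
      }
    }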
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -729,23 +729,36 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
     val tempDir = Utils.createTempDir()
     val file = new File(tempDir, "file")
     Files.write(UUID.randomUUID().toString(), file, UTF_8)
+    val empty = new File(tempDir, "empty")
+    Files.write("", empty, UTF_8);
     val jar = new File(tempDir, "jar")
     Files.write(UUID.randomUUID().toString(), jar, UTF_8)
 
     val fileUri = env.fileServer.addFile(file)
+    val emptyUri = env.fileServer.addFile(empty)
     val jarUri = env.fileServer.addJar(jar)
 
     val destDir = Utils.createTempDir()
-    val destFile = new File(destDir, file.getName())
-    val destJar = new File(destDir, jar.getName())
-
     val sm = new SecurityManager(conf)
     val hc = SparkHadoopUtil.get.conf
-    Utils.fetchFile(fileUri, destDir, conf, sm, hc, 0L, false)
-    Utils.fetchFile(jarUri, destDir, conf, sm, hc, 0L, false)
 
-    assert(Files.equal(file, destFile))
-    assert(Files.equal(jar, destJar))
+    val files = Seq(
+      (file, fileUri),
+      (empty, emptyUri),
+      (jar, jarUri))
+    files.foreach { case (f, uri) =>
+      val destFile = new File(destDir, f.getName())
+      Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false)
+      assert(Files.equal(f, destFile))
+    }
+
+    // Try to download files that do not exist.
+    Seq("files", "jars").foreach { root =>
+      intercept[Exception] {
+        val uri = env.address.toSparkURL + s"/$root/doesNotExist"
+        Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false)
+      }
+    }
   }
 }
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -185,16 +185,24 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
       StreamResponse resp = (StreamResponse) message;
       StreamCallback callback = streamCallbacks.poll();
       if (callback != null) {
-        StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount,
-          callback);
-        try {
-          TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
-            channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
-          frameDecoder.setInterceptor(interceptor);
-          streamActive = true;
-        } catch (Exception e) {
-          logger.error("Error installing stream handler.", e);
-          deactivateStream();
-        }
+        if (resp.byteCount > 0) {
+          StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount,
+            callback);
+          try {
+            TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
+              channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
+            frameDecoder.setInterceptor(interceptor);
+            streamActive = true;
+          } catch (Exception e) {
+            logger.error("Error installing stream handler.", e);
+            deactivateStream();
+          }
+        } else {
+          try {
+            callback.onComplete(resp.streamId);
+          } catch (Exception e) {
+            logger.warn("Error in stream handler onComplete().", e);
+          }
+        }
       } else {
         logger.error("Could not find callback for StreamResponse.");
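The reason for the byteCount > 0 branch above: a StreamInterceptor is only driven by frame data as it arrives on the channel, so for a zero-length stream it would never fire and the callback would never complete (the client would simply time out). Calling onComplete() directly is the only way to finish such a stream. For reference, a minimal callback that handles both cases; a sketch assuming the network library's StreamCallback interface (the ByteCollector name is made up):

    import java.io.ByteArrayOutputStream
    import java.nio.ByteBuffer

    import org.apache.spark.network.client.StreamCallback

    // Accumulates stream chunks. For a zero-length stream, onData is
    // never invoked and only onComplete fires, which is exactly the
    // behavior the patched handler guarantees.
    class ByteCollector extends StreamCallback {
      private val out = new ByteArrayOutputStream()

      override def onData(streamId: String, buf: ByteBuffer): Unit = {
        val chunk = new Array[Byte](buf.remaining())
        buf.get(chunk)
        out.write(chunk)
      }

      override def onComplete(streamId: String): Unit =
        println(s"stream $streamId complete: ${out.size()} bytes")

      override def onFailure(streamId: String, cause: Throwable): Unit =
        cause.printStackTrace()
    }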
--- a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java
@@ -51,13 +51,14 @@ import org.apache.spark.network.util.SystemPropertyConfigProvider;
 import org.apache.spark.network.util.TransportConf;
 
 public class StreamSuite {
-  private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "file" };
+  private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "emptyBuffer", "file" };
 
   private static TransportServer server;
   private static TransportClientFactory clientFactory;
   private static File testFile;
   private static File tempDir;
 
+  private static ByteBuffer emptyBuffer;
   private static ByteBuffer smallBuffer;
   private static ByteBuffer largeBuffer;
@@ -73,6 +74,7 @@ public class StreamSuite {
   @BeforeClass
   public static void setUp() throws Exception {
     tempDir = Files.createTempDir();
+    emptyBuffer = createBuffer(0);
     smallBuffer = createBuffer(100);
     largeBuffer = createBuffer(100000);
@@ -103,6 +105,8 @@ public class StreamSuite {
             return new NioManagedBuffer(largeBuffer);
           case "smallBuffer":
             return new NioManagedBuffer(smallBuffer);
+          case "emptyBuffer":
+            return new NioManagedBuffer(emptyBuffer);
           case "file":
             return new FileSegmentManagedBuffer(conf, testFile, 0, testFile.length());
           default:
@@ -138,6 +142,18 @@ public class StreamSuite {
     }
   }
 
+  @Test
+  public void testZeroLengthStream() throws Throwable {
+    TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+    try {
+      StreamTask task = new StreamTask(client, "emptyBuffer", TimeUnit.SECONDS.toMillis(5));
+      task.run();
+      task.check();
+    } finally {
+      client.close();
+    }
+  }
+
   @Test
   public void testSingleStream() throws Throwable {
     TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
@@ -226,6 +242,11 @@ public class StreamSuite {
           outFile = File.createTempFile("data", ".tmp", tempDir);
           out = new FileOutputStream(outFile);
           break;
+        case "emptyBuffer":
+          baos = new ByteArrayOutputStream();
+          out = baos;
+          srcBuffer = emptyBuffer;
+          break;
         default:
          throw new IllegalArgumentException(streamId);