diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 68701f609f77a5bb8609ddea9dd0bbc5a7d4db87..c8fa870f50e6865142f18988d1fa170166055479 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -27,7 +27,7 @@ import javax.annotation.Nullable import scala.concurrent.{Future, Promise} import scala.reflect.ClassTag -import scala.util.{DynamicVariable, Failure, Success} +import scala.util.{DynamicVariable, Failure, Success, Try} import scala.util.control.NonFatal import org.apache.spark.{Logging, SecurityManager, SparkConf} @@ -368,13 +368,22 @@ private[netty] class NettyRpcEnv( @volatile private var error: Throwable = _ - def setError(e: Throwable): Unit = error = e + def setError(e: Throwable): Unit = { + error = e + source.close() + } override def read(dst: ByteBuffer): Int = { - if (error != null) { - throw error + val result = if (error == null) { + Try(source.read(dst)) + } else { + Failure(error) + } + + result match { + case Success(bytesRead) => bytesRead + case Failure(error) => throw error } - source.read(dst) } override def close(): Unit = source.close() diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala index eb1d2604fb2357558a83b3a2de9c0e057857262a..a2768b4252dcb788fb9ddcc9b88d55701f891cd5 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala @@ -44,7 +44,7 @@ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) case _ => throw new IllegalArgumentException(s"Invalid file type: $ftype") } - require(file != null, s"File not found: $streamId") + require(file != null && file.isFile(), s"File not found: $streamId") new FileSegmentManagedBuffer(rpcEnv.transportConf, file, 0, file.length()) } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 2b664c6313efa9c8707b3e4721f9e99e5b291c48..6cc958a5f6bc81b27195f9388a3f6c373a21c44f 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -729,23 +729,36 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val tempDir = Utils.createTempDir() val file = new File(tempDir, "file") Files.write(UUID.randomUUID().toString(), file, UTF_8) + val empty = new File(tempDir, "empty") + Files.write("", empty, UTF_8); val jar = new File(tempDir, "jar") Files.write(UUID.randomUUID().toString(), jar, UTF_8) val fileUri = env.fileServer.addFile(file) + val emptyUri = env.fileServer.addFile(empty) val jarUri = env.fileServer.addJar(jar) val destDir = Utils.createTempDir() - val destFile = new File(destDir, file.getName()) - val destJar = new File(destDir, jar.getName()) - val sm = new SecurityManager(conf) val hc = SparkHadoopUtil.get.conf - Utils.fetchFile(fileUri, destDir, conf, sm, hc, 0L, false) - Utils.fetchFile(jarUri, destDir, conf, sm, hc, 0L, false) - assert(Files.equal(file, destFile)) - assert(Files.equal(jar, destJar)) + val files = Seq( + (file, fileUri), + (empty, emptyUri), + (jar, jarUri)) + files.foreach { case (f, uri) => + val destFile = new File(destDir, f.getName()) + Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false) + assert(Files.equal(f, destFile)) + } + + // Try to download files that do not exist. + Seq("files", "jars").foreach { root => + intercept[Exception] { + val uri = env.address.toSparkURL + s"/$root/doesNotExist" + Utils.fetchFile(uri, destDir, conf, sm, hc, 0L, false) + } + } } } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index be181e0660826a335eea7939d009f57a965b333b..4c15045363b847b95ebcdcfc18a4676e61209282 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -185,16 +185,24 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> { StreamResponse resp = (StreamResponse) message; StreamCallback callback = streamCallbacks.poll(); if (callback != null) { - StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, - callback); - try { - TransportFrameDecoder frameDecoder = (TransportFrameDecoder) - channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); - frameDecoder.setInterceptor(interceptor); - streamActive = true; - } catch (Exception e) { - logger.error("Error installing stream handler.", e); - deactivateStream(); + if (resp.byteCount > 0) { + StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, + callback); + try { + TransportFrameDecoder frameDecoder = (TransportFrameDecoder) + channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); + frameDecoder.setInterceptor(interceptor); + streamActive = true; + } catch (Exception e) { + logger.error("Error installing stream handler.", e); + deactivateStream(); + } + } else { + try { + callback.onComplete(resp.streamId); + } catch (Exception e) { + logger.warn("Error in stream handler onComplete().", e); + } } } else { logger.error("Could not find callback for StreamResponse."); diff --git a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java index 00158fd0816262806936c8fa034b9821aed995e4..538f3efe8d6f22129d169fb538726d9cfee8742d 100644 --- a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -51,13 +51,14 @@ import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; public class StreamSuite { - private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "file" }; + private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "emptyBuffer", "file" }; private static TransportServer server; private static TransportClientFactory clientFactory; private static File testFile; private static File tempDir; + private static ByteBuffer emptyBuffer; private static ByteBuffer smallBuffer; private static ByteBuffer largeBuffer; @@ -73,6 +74,7 @@ public class StreamSuite { @BeforeClass public static void setUp() throws Exception { tempDir = Files.createTempDir(); + emptyBuffer = createBuffer(0); smallBuffer = createBuffer(100); largeBuffer = createBuffer(100000); @@ -103,6 +105,8 @@ public class StreamSuite { return new NioManagedBuffer(largeBuffer); case "smallBuffer": return new NioManagedBuffer(smallBuffer); + case "emptyBuffer": + return new NioManagedBuffer(emptyBuffer); case "file": return new FileSegmentManagedBuffer(conf, testFile, 0, testFile.length()); default: @@ -138,6 +142,18 @@ public class StreamSuite { } } + @Test + public void testZeroLengthStream() throws Throwable { + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + try { + StreamTask task = new StreamTask(client, "emptyBuffer", TimeUnit.SECONDS.toMillis(5)); + task.run(); + task.check(); + } finally { + client.close(); + } + } + @Test public void testSingleStream() throws Throwable { TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); @@ -226,6 +242,11 @@ public class StreamSuite { outFile = File.createTempFile("data", ".tmp", tempDir); out = new FileOutputStream(outFile); break; + case "emptyBuffer": + baos = new ByteArrayOutputStream(); + out = baos; + srcBuffer = emptyBuffer; + break; default: throw new IllegalArgumentException(streamId); }