Commit 6a06c4b0 authored by jinxing, committed by Wenchen Fan

[SPARK-21342] Fix DownloadCallback to work well with RetryingBlockFetcher.

## What changes were proposed in this pull request?

When `RetryingBlockFetcher` retries fetching blocks, two `DownloadCallback`s can end up downloading the same content to the same target file, which can cause `ShuffleBlockFetcherIterator` to read a partial result.

This PR proposes to create and delete the temp files in `OneForOneBlockFetcher`.

Author: jinxing <jinxing6042@126.com>
Author: Shixiong Zhu <zsxwing@gmail.com>

Closes #18565 from jinxing64/SPARK-21342.
parent 647963a2
Showing 11 changed files with 108 additions and 45 deletions
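
To make the race concrete before the diffs, here is a minimal standalone Java sketch (illustrative only, not code from this commit) of how a straggling retry that reuses the same target file can leave a partial result for the reader:

```java
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;

public class RetryRaceSketch {
  public static void main(String[] args) throws IOException {
    byte[] block = {1, 2, 3, 4, 5, 6, 7, 8};
    File target = File.createTempFile("shuffle", ".data");

    // Attempt 1 downloads the whole block and reports success to the listener.
    try (FileOutputStream first = new FileOutputStream(target)) {
      first.write(block);
    }

    // A straggling retry reuses the same target file: opening the stream
    // truncates it, and only part of the block has arrived so far.
    FileOutputStream retry = new FileOutputStream(target);
    retry.write(block, 0, 2);

    // The reader, trusting attempt 1's success callback, now sees 2 bytes
    // instead of 8 -- a partial result.
    System.out.println("target length: " + target.length());
    retry.close();
    target.delete();
  }
}
```

With this commit, each `DownloadCallback` instead obtains a fresh file from `TempShuffleFileManager`, so a retry can never truncate a file that an earlier attempt has already handed to the listener.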
ExternalShuffleClient.java
@@ -17,7 +17,6 @@
 package org.apache.spark.network.shuffle;
 
-import java.io.File;
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.List;
@@ -91,15 +90,15 @@ public class ExternalShuffleClient extends ShuffleClient {
       String execId,
       String[] blockIds,
       BlockFetchingListener listener,
-      File[] shuffleFiles) {
+      TempShuffleFileManager tempShuffleFileManager) {
     checkInit();
     logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
     try {
       RetryingBlockFetcher.BlockFetchStarter blockFetchStarter =
           (blockIds1, listener1) -> {
             TransportClient client = clientFactory.createClient(host, port);
-            new OneForOneBlockFetcher(client, appId, execId, blockIds1, listener1, conf,
-              shuffleFiles).start();
+            new OneForOneBlockFetcher(client, appId, execId,
+              blockIds1, listener1, conf, tempShuffleFileManager).start();
           };
 
       int maxRetries = conf.maxIORetries();
OneForOneBlockFetcher.java
@@ -57,11 +57,21 @@ public class OneForOneBlockFetcher {
   private final String[] blockIds;
   private final BlockFetchingListener listener;
   private final ChunkReceivedCallback chunkCallback;
-  private TransportConf transportConf = null;
-  private File[] shuffleFiles = null;
+  private final TransportConf transportConf;
+  private final TempShuffleFileManager tempShuffleFileManager;
 
   private StreamHandle streamHandle = null;
 
+  public OneForOneBlockFetcher(
+      TransportClient client,
+      String appId,
+      String execId,
+      String[] blockIds,
+      BlockFetchingListener listener,
+      TransportConf transportConf) {
+    this(client, appId, execId, blockIds, listener, transportConf, null);
+  }
+
   public OneForOneBlockFetcher(
       TransportClient client,
       String appId,
@@ -69,18 +79,14 @@ public class OneForOneBlockFetcher {
       String[] blockIds,
       BlockFetchingListener listener,
       TransportConf transportConf,
-      File[] shuffleFiles) {
+      TempShuffleFileManager tempShuffleFileManager) {
     this.client = client;
     this.openMessage = new OpenBlocks(appId, execId, blockIds);
     this.blockIds = blockIds;
     this.listener = listener;
     this.chunkCallback = new ChunkCallback();
     this.transportConf = transportConf;
-    if (shuffleFiles != null) {
-      this.shuffleFiles = shuffleFiles;
-      assert this.shuffleFiles.length == blockIds.length:
-        "Number of shuffle files should equal to blocks";
-    }
+    this.tempShuffleFileManager = tempShuffleFileManager;
   }
 
   /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */
@@ -119,9 +125,9 @@ public class OneForOneBlockFetcher {
           // Immediately request all chunks -- we expect that the total size of the request is
           // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]].
           for (int i = 0; i < streamHandle.numChunks; i++) {
-            if (shuffleFiles != null) {
+            if (tempShuffleFileManager != null) {
               client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i),
-                new DownloadCallback(shuffleFiles[i], i));
+                new DownloadCallback(i));
             } else {
               client.fetchChunk(streamHandle.streamId, i, chunkCallback);
             }
@@ -157,8 +163,8 @@ public class OneForOneBlockFetcher {
     private File targetFile = null;
     private int chunkIndex;
 
-    DownloadCallback(File targetFile, int chunkIndex) throws IOException {
-      this.targetFile = targetFile;
+    DownloadCallback(int chunkIndex) throws IOException {
+      this.targetFile = tempShuffleFileManager.createTempShuffleFile();
       this.channel = Channels.newChannel(new FileOutputStream(targetFile));
       this.chunkIndex = chunkIndex;
     }
@@ -174,6 +180,9 @@ public class OneForOneBlockFetcher {
       ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0,
         targetFile.length());
       listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer);
+      if (!tempShuffleFileManager.registerTempShuffleFileToClean(targetFile)) {
+        targetFile.delete();
+      }
     }
 
     @Override
@@ -182,6 +191,7 @@ public class OneForOneBlockFetcher {
       // On receipt of a failure, fail every block from chunkIndex onwards.
       String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length);
       failRemainingBlocks(remainingBlockIds, cause);
+      targetFile.delete();
     }
   }
 }
ShuffleClient.java
@@ -18,7 +18,6 @@
 package org.apache.spark.network.shuffle;
 
 import java.io.Closeable;
-import java.io.File;
 
 /** Provides an interface for reading shuffle files, either from an Executor or external service. */
 public abstract class ShuffleClient implements Closeable {
@@ -35,6 +34,16 @@ public abstract class ShuffleClient implements Closeable {
    * Note that this API takes a sequence so the implementation can batch requests, and does not
    * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as
    * the data of a block is fetched, rather than waiting for all blocks to be fetched.
+   *
+   * @param host the host of the remote node.
+   * @param port the port of the remote node.
+   * @param execId the executor id.
+   * @param blockIds block ids to fetch.
+   * @param listener the listener to receive block fetching status.
+   * @param tempShuffleFileManager TempShuffleFileManager to create and clean temp shuffle files.
+   *                               If it's not <code>null</code>, the remote blocks will be streamed
+   *                               into temp shuffle files to reduce the memory usage, otherwise,
+   *                               they will be kept in memory.
    */
   public abstract void fetchBlocks(
       String host,
@@ -42,5 +51,5 @@ public abstract class ShuffleClient implements Closeable {
       String execId,
       String[] blockIds,
       BlockFetchingListener listener,
-      File[] shuffleFiles);
+      TempShuffleFileManager tempShuffleFileManager);
 }
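
As a hedged usage sketch of the `tempShuffleFileManager` parameter documented above (the helper class and the explicit size arguments are illustrative, not part of this commit; the real size check lives in `ShuffleBlockFetcherIterator.sendRequest`, shown further down):

```java
import org.apache.spark.network.shuffle.BlockFetchingListener;
import org.apache.spark.network.shuffle.ShuffleClient;
import org.apache.spark.network.shuffle.TempShuffleFileManager;

// Illustrative helper: stream oversized requests to disk via a
// TempShuffleFileManager; keep small ones in memory by passing null.
final class FetchModeSketch {
  static void fetch(
      ShuffleClient client,
      String host,
      int port,
      String execId,
      String[] blockIds,
      BlockFetchingListener listener,
      TempShuffleFileManager tempFileManager,
      long reqSize,
      long maxReqSizeShuffleToMem) {
    if (reqSize > maxReqSizeShuffleToMem) {
      client.fetchBlocks(host, port, execId, blockIds, listener, tempFileManager);
    } else {
      client.fetchBlocks(host, port, execId, blockIds, listener, null);
    }
  }
}
```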
TempShuffleFileManager.java (new file)
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.File;
+
+/**
+ * A manager to create temp shuffle block files to reduce the memory usage and also clean temp
+ * files when they won't be used any more.
+ */
+public interface TempShuffleFileManager {
+
+  /** Create a temp shuffle block file. */
+  File createTempShuffleFile();
+
+  /**
+   * Register a temp shuffle file to clean up when it won't be used any more. Return whether the
+   * file is registered successfully. If `false`, the caller should clean up the file by itself.
+   */
+  boolean registerTempShuffleFileToClean(File file);
+}
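
For illustration, a minimal implementation of this interface might look like the sketch below (hypothetical code, not part of this commit; Spark's actual implementer is `ShuffleBlockFetcherIterator`, which creates files through the `DiskBlockManager` and rejects registration once the iterator has become a zombie):

```java
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.HashSet;
import java.util.Set;

// Hypothetical example implementation of TempShuffleFileManager.
public class SimpleTempShuffleFileManager implements TempShuffleFileManager {
  private final Set<File> filesToClean = new HashSet<>();
  private boolean stopped = false;

  @Override
  public synchronized File createTempShuffleFile() {
    try {
      // Each caller gets a fresh, uniquely named file.
      return File.createTempFile("shuffle", ".tmp");
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
  }

  @Override
  public synchronized boolean registerTempShuffleFileToClean(File file) {
    if (stopped) {
      return false;  // per the contract, the caller must delete the file itself
    }
    filesToClean.add(file);
    return true;
  }

  /** Delete all registered files; later registrations are rejected. */
  public synchronized void stop() {
    stopped = true;
    for (File file : filesToClean) {
      file.delete();
    }
    filesToClean.clear();
  }
}
```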
SaslIntegrationSuite.java
@@ -204,7 +204,7 @@ public class SaslIntegrationSuite {
 
       String[] blockIds = { "shuffle_0_1_2", "shuffle_0_3_4" };
       OneForOneBlockFetcher fetcher =
-          new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf, null);
+          new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf);
       fetcher.start();
       blockFetchLatch.await();
       checkSecurityException(exception.get());
OneForOneBlockFetcherSuite.java
@@ -131,7 +131,7 @@ public class OneForOneBlockFetcherSuite {
     BlockFetchingListener listener = mock(BlockFetchingListener.class);
     String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
     OneForOneBlockFetcher fetcher =
-        new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener, conf, null);
+        new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener, conf);
 
     // Respond to the "OpenBlocks" message with an appropriate ShuffleStreamHandle with streamId 123
     doAnswer(invocationOnMock -> {
BlockTransferService.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.network
 
-import java.io.{Closeable, File}
+import java.io.Closeable
 import java.nio.ByteBuffer
 
 import scala.concurrent.{Future, Promise}
@@ -26,7 +26,7 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager}
 import org.apache.spark.storage.{BlockId, StorageLevel}
 import org.apache.spark.util.ThreadUtils
 
@@ -68,7 +68,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Logging {
       execId: String,
       blockIds: Array[String],
       listener: BlockFetchingListener,
-      shuffleFiles: Array[File]): Unit
+      tempShuffleFileManager: TempShuffleFileManager): Unit
 
   /**
    * Upload a single block to a remote node, available only after [[init]] is invoked.
@@ -101,7 +101,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Logging {
           ret.flip()
           result.success(new NioManagedBuffer(ret))
         }
-      }, shuffleFiles = null)
+      }, tempShuffleFileManager = null)
 
     ThreadUtils.awaitResult(result.future, Duration.Inf)
   }
NettyBlockTransferService.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.network.netty
 
-import java.io.File
 import java.nio.ByteBuffer
 
 import scala.collection.JavaConverters._
@@ -30,7 +29,7 @@ import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory}
 import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap}
 import org.apache.spark.network.server._
-import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher, TempShuffleFileManager}
 import org.apache.spark.network.shuffle.protocol.UploadBlock
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.serializer.JavaSerializer
@@ -90,14 +89,14 @@ private[spark] class NettyBlockTransferService(
       execId: String,
       blockIds: Array[String],
       listener: BlockFetchingListener,
-      shuffleFiles: Array[File]): Unit = {
+      tempShuffleFileManager: TempShuffleFileManager): Unit = {
     logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
     try {
       val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
        override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
          val client = clientFactory.createClient(host, port)
-          new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener,
-            transportConf, shuffleFiles).start()
+          new OneForOneBlockFetcher(client, appId, execId, blockIds, listener,
+            transportConf, tempShuffleFileManager).start()
         }
       }
ShuffleBlockFetcherIterator.scala
@@ -28,7 +28,7 @@ import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
 import org.apache.spark.{SparkException, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager}
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util.Utils
 import org.apache.spark.util.io.ChunkedByteBufferOutputStream
@@ -66,7 +66,7 @@ final class ShuffleBlockFetcherIterator(
     maxReqsInFlight: Int,
     maxReqSizeShuffleToMem: Long,
     detectCorrupt: Boolean)
-  extends Iterator[(BlockId, InputStream)] with Logging {
+  extends Iterator[(BlockId, InputStream)] with TempShuffleFileManager with Logging {
 
   import ShuffleBlockFetcherIterator._
 
@@ -135,7 +135,8 @@ final class ShuffleBlockFetcherIterator(
   * A set to store the files used for shuffling remote huge blocks. Files in this set will be
   * deleted when cleanup. This is a layer of defensiveness against disk file leaks.
   */
-  val shuffleFilesSet = mutable.HashSet[File]()
+  @GuardedBy("this")
+  private[this] val shuffleFilesSet = mutable.HashSet[File]()
 
   initialize()
 
@@ -149,6 +150,19 @@ final class ShuffleBlockFetcherIterator(
     currentResult = null
   }
 
+  override def createTempShuffleFile(): File = {
+    blockManager.diskBlockManager.createTempLocalBlock()._2
+  }
+
+  override def registerTempShuffleFileToClean(file: File): Boolean = synchronized {
+    if (isZombie) {
+      false
+    } else {
+      shuffleFilesSet += file
+      true
+    }
+  }
+
   /**
    * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
    */
@@ -176,7 +190,7 @@ final class ShuffleBlockFetcherIterator(
     }
     shuffleFilesSet.foreach { file =>
       if (!file.delete()) {
-        logInfo("Failed to cleanup shuffle fetch temp file " + file.getAbsolutePath());
+        logWarning("Failed to cleanup shuffle fetch temp file " + file.getAbsolutePath())
       }
     }
   }
@@ -221,12 +235,8 @@ final class ShuffleBlockFetcherIterator(
     // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch
     // the data and write it to file directly.
     if (req.size > maxReqSizeShuffleToMem) {
-      val shuffleFiles = blockIds.map { _ =>
-        blockManager.diskBlockManager.createTempLocalBlock()._2
-      }.toArray
-      shuffleFilesSet ++= shuffleFiles
       shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
-        blockFetchingListener, shuffleFiles)
+        blockFetchingListener, this)
     } else {
       shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
         blockFetchingListener, null)
BlockManagerSuite.scala
@@ -45,7 +45,7 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
 import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
 import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf}
 import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, TransportServerBootstrap}
-import org.apache.spark.network.shuffle.BlockFetchingListener
+import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager}
 import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor}
 import org.apache.spark.rpc.RpcEnv
 import org.apache.spark.scheduler.LiveListenerBus
@@ -1382,7 +1382,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach
         execId: String,
         blockIds: Array[String],
         listener: BlockFetchingListener,
-        shuffleFiles: Array[File]): Unit = {
+        tempShuffleFileManager: TempShuffleFileManager): Unit = {
       listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1)))
     }
ShuffleBlockFetcherIteratorSuite.scala
@@ -33,7 +33,7 @@ import org.scalatest.PrivateMethodTester
 import org.apache.spark.{SparkFunSuite, TaskContext}
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
-import org.apache.spark.network.shuffle.BlockFetchingListener
+import org.apache.spark.network.shuffle.{BlockFetchingListener, TempShuffleFileManager}
 import org.apache.spark.network.util.LimitedInputStream
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util.Utils
@@ -432,12 +432,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester {
     val remoteBlocks = Map[BlockId, ManagedBuffer](
       ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer())
     val transfer = mock(classOf[BlockTransferService])
-    var shuffleFiles: Array[File] = null
+    var tempShuffleFileManager: TempShuffleFileManager = null
     when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
         override def answer(invocation: InvocationOnMock): Unit = {
           val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
-          shuffleFiles = invocation.getArguments()(5).asInstanceOf[Array[File]]
+          tempShuffleFileManager = invocation.getArguments()(5).asInstanceOf[TempShuffleFileManager]
           Future {
             listener.onBlockFetchSuccess(
               ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0)))
@@ -466,13 +466,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester {
     fetchShuffleBlock(blocksByAddress1)
     // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch
     // shuffle block to disk.
-    assert(shuffleFiles === null)
+    assert(tempShuffleFileManager == null)
 
     val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
       (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq))
     fetchShuffleBlock(blocksByAddress2)
     // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch
     // shuffle block to disk.
-    assert(shuffleFiles != null)
+    assert(tempShuffleFileManager != null)
   }
 }