diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
index c0f1da50f5e653e8704e06870f2a8e657d1857b9..fc7bba41185f00509221799b0fbcd25c16360318 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -44,7 +44,6 @@ import org.apache.spark.network.shuffle.protocol.*;
 import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
 import org.apache.spark.network.util.TransportConf;
 
-
 /**
  * RPC Handler for a server which can serve shuffle blocks from outside of an Executor process.
  *
@@ -91,26 +90,8 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
       try {
         OpenBlocks msg = (OpenBlocks) msgObj;
         checkAuth(client, msg.appId);
-
-        Iterator<ManagedBuffer> iter = new Iterator<ManagedBuffer>() {
-          private int index = 0;
-
-          @Override
-          public boolean hasNext() {
-            return index < msg.blockIds.length;
-          }
-
-          @Override
-          public ManagedBuffer next() {
-            final ManagedBuffer block = blockManager.getBlockData(msg.appId, msg.execId,
-              msg.blockIds[index]);
-            index++;
-            metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0);
-            return block;
-          }
-        };
-
-        long streamId = streamManager.registerStream(client.getClientId(), iter);
+        long streamId = streamManager.registerStream(client.getClientId(),
+          new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds));
         if (logger.isTraceEnabled()) {
           logger.trace("Registered streamId {} with {} buffers for client {} from host {}",
             streamId,
@@ -209,4 +190,51 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
     }
   }
 
+  private class ManagedBufferIterator implements Iterator<ManagedBuffer> {
+
+    private int index = 0;
+    private final String appId;
+    private final String execId;
+    private final int shuffleId;
+    // An array containing mapId and reduceId pairs.
+    private final int[] mapIdAndReduceIds;
+
+    ManagedBufferIterator(String appId, String execId, String[] blockIds) {
+      this.appId = appId;
+      this.execId = execId;
+      String[] blockId0Parts = blockIds[0].split("_");
+      if (blockId0Parts.length != 4 || !blockId0Parts[0].equals("shuffle")) {
+        throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[0]);
+      }
+      this.shuffleId = Integer.parseInt(blockId0Parts[1]);
+      mapIdAndReduceIds = new int[2 * blockIds.length];
+      for (int i = 0; i < blockIds.length; i++) {
+        String[] blockIdParts = blockIds[i].split("_");
+        if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) {
+          throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[i]);
+        }
+        if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
+          throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
+            ", got:" + blockIds[i]);
+        }
+        mapIdAndReduceIds[2 * i] = Integer.parseInt(blockIdParts[2]);
+        mapIdAndReduceIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]);
+      }
+    }
+
+    @Override
+    public boolean hasNext() {
+      return index < mapIdAndReduceIds.length;
+    }
+
+    @Override
+    public ManagedBuffer next() {
+      final ManagedBuffer block = blockManager.getBlockData(appId, execId, shuffleId,
+        mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]);
+      index += 2;
+      metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0);
+      return block;
+    }
+  }
+
 }
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
index 62d58aba4c1e7d811b0b94dca5ab13d1404f52c8..d7ec0e299deadfa7a3de3391039bfe8a5d265697 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
@@ -150,27 +150,20 @@ public class ExternalShuffleBlockResolver {
   }
 
   /**
-   * Obtains a FileSegmentManagedBuffer from a shuffle block id. We expect the blockId has the
-   * format "shuffle_ShuffleId_MapId_ReduceId" (from ShuffleBlockId), and additionally make
-   * assumptions about how the hash and sort based shuffles store their data.
+   * Obtains a FileSegmentManagedBuffer from (shuffleId, mapId, reduceId). We make assumptions
+   * about how the hash and sort based shuffles store their data.
    */
-  public ManagedBuffer getBlockData(String appId, String execId, String blockId) {
-    String[] blockIdParts = blockId.split("_");
-    if (blockIdParts.length < 4) {
-      throw new IllegalArgumentException("Unexpected block id format: " + blockId);
-    } else if (!blockIdParts[0].equals("shuffle")) {
-      throw new IllegalArgumentException("Expected shuffle block id, got: " + blockId);
-    }
-    int shuffleId = Integer.parseInt(blockIdParts[1]);
-    int mapId = Integer.parseInt(blockIdParts[2]);
-    int reduceId = Integer.parseInt(blockIdParts[3]);
-
+  public ManagedBuffer getBlockData(
+      String appId,
+      String execId,
+      int shuffleId,
+      int mapId,
+      int reduceId) {
     ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId));
     if (executor == null) {
       throw new RuntimeException(
         String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId));
     }
-
     return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId);
   }
 
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
index 0c054fc5db8f4041d2770344bc3772ac92f9fd9d..8110f1e004c736565d9b4f8add6f1b1844afb51d 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -202,7 +202,7 @@ public class SaslIntegrationSuite {
       }
     };
 
-    String[] blockIds = { "shuffle_2_3_4", "shuffle_6_7_8" };
+    String[] blockIds = { "shuffle_0_1_2", "shuffle_0_3_4" };
     OneForOneBlockFetcher fetcher =
       new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf, null);
     fetcher.start();
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
index 4d48b1897038674f4278c894a68a6285f5b81f1d..7846b71d5a8b14ccac4871e4f2caec57cb120b75 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
@@ -83,9 +83,10 @@ public class ExternalShuffleBlockHandlerSuite {
 
     ManagedBuffer block0Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[3]));
     ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
-    when(blockResolver.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker);
-    when(blockResolver.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker);
-    ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" })
+    when(blockResolver.getBlockData("app0", "exec1", 0, 0, 0)).thenReturn(block0Marker);
+    when(blockResolver.getBlockData("app0", "exec1", 0, 0, 1)).thenReturn(block1Marker);
+    ByteBuffer openBlocks = new OpenBlocks("app0", "exec1",
+      new String[] { "shuffle_0_0_0", "shuffle_0_0_1" })
       .toByteBuffer();
     handler.receive(client, openBlocks, callback);
 
@@ -105,8 +106,8 @@ public class ExternalShuffleBlockHandlerSuite {
     assertEquals(block0Marker, buffers.next());
     assertEquals(block1Marker, buffers.next());
     assertFalse(buffers.hasNext());
-    verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0");
-    verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1");
+    verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 0);
+    verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 1);
 
     // Verify open block request latency metrics
     Timer openBlockRequestLatencyMillis = (Timer) ((ExternalShuffleBlockHandler) handler)
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
index bc97594903bef1e36b6650417eb9858024b59eba..23438a08fa09475b18ec8d32ff39622d00c06467 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
@@ -65,7 +65,7 @@ public class ExternalShuffleBlockResolverSuite {
     ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null);
     // Unregistered executor
     try {
-      resolver.getBlockData("app0", "exec1", "shuffle_1_1_0");
+      resolver.getBlockData("app0", "exec1", 1, 1, 0);
       fail("Should have failed");
     } catch (RuntimeException e) {
       assertTrue("Bad error message: " + e, e.getMessage().contains("not registered"));
@@ -74,7 +74,7 @@ public class ExternalShuffleBlockResolverSuite {
     // Invalid shuffle manager
     try {
       resolver.registerExecutor("app0", "exec2", dataContext.createExecutorInfo("foobar"));
-      resolver.getBlockData("app0", "exec2", "shuffle_1_1_0");
+      resolver.getBlockData("app0", "exec2", 1, 1, 0);
       fail("Should have failed");
     } catch (UnsupportedOperationException e) {
       // pass
@@ -84,7 +84,7 @@ public class ExternalShuffleBlockResolverSuite {
     resolver.registerExecutor("app0", "exec3",
       dataContext.createExecutorInfo(SORT_MANAGER));
     try {
-      resolver.getBlockData("app0", "exec3", "shuffle_1_1_0");
+      resolver.getBlockData("app0", "exec3", 1, 1, 0);
       fail("Should have failed");
     } catch (Exception e) {
       // pass
@@ -98,14 +98,14 @@ public class ExternalShuffleBlockResolverSuite {
       dataContext.createExecutorInfo(SORT_MANAGER));
 
     InputStream block0Stream =
-      resolver.getBlockData("app0", "exec0", "shuffle_0_0_0").createInputStream();
+      resolver.getBlockData("app0", "exec0", 0, 0, 0).createInputStream();
     String block0 = CharStreams.toString(
       new InputStreamReader(block0Stream, StandardCharsets.UTF_8));
     block0Stream.close();
     assertEquals(sortBlock0, block0);
 
     InputStream block1Stream =
-      resolver.getBlockData("app0", "exec0", "shuffle_0_0_1").createInputStream();
+      resolver.getBlockData("app0", "exec0", 0, 0, 1).createInputStream();
     String block1 = CharStreams.toString(
       new InputStreamReader(block1Stream, StandardCharsets.UTF_8));
     block1Stream.close();
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
index d1d8f5b4e188a0cba3e2e3b875887c7ad221ab4c..4391e3023491bd9b4fa9d9b25719dd119fb4b153 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
@@ -214,10 +214,10 @@ public class ExternalShuffleIntegrationSuite {
   @Test
   public void testFetchWrongExecutor() throws Exception {
     registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
-    FetchResult execFetch = fetchBlocks("exec-0",
-      new String[] { "shuffle_0_0_0" /* right */, "shuffle_1_0_0" /* wrong */ });
-    assertEquals(Sets.newHashSet("shuffle_0_0_0"), execFetch.successBlocks);
-    assertEquals(Sets.newHashSet("shuffle_1_0_0"), execFetch.failedBlocks);
+    FetchResult execFetch0 = fetchBlocks("exec-0", new String[] { "shuffle_0_0_0" /* right */});
+    FetchResult execFetch1 = fetchBlocks("exec-0", new String[] { "shuffle_1_0_0" /* wrong */ });
+    assertEquals(Sets.newHashSet("shuffle_0_0_0"), execFetch0.successBlocks);
+    assertEquals(Sets.newHashSet("shuffle_1_0_0"), execFetch1.failedBlocks);
   }
 
   @Test