Skip to content
Snippets Groups Projects
Commit 1c31ebd1 authored by Marcelo Vanzin, committed by Reynold Xin
Browse files

[SPARK-6578] [core] Fix thread-safety issue in outbound path of network library.


While the inbound path of a netty pipeline is thread-safe, the outbound
path is not. That means that multiple threads can compete to write messages
to the next stage of the pipeline.

The network library sometimes breaks a single RPC message into multiple
buffers internally to avoid copying data (see MessageEncoder). This can
result in the following scenario (where "FxBy" means "frame x, buffer y"):

               T1         F1B1            F1B2
                            \               \
                             \               \
               socket        F1B1   F2B1    F1B2  F2B2
                                     /             /
                                    /             /
               T2                  F2B1         F2B2

And the frames now cannot be rebuilt on the receiving side because the
different messages have been mixed up on the wire.

The fix wraps these multi-buffer messages into a `FileRegion` object
so that these messages are written "atomically" to the next pipeline handler.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #5234 from vanzin/SPARK-6578 and squashes the following commits:

16b2d70 [Marcelo Vanzin] Forgot to update a type.
c9c2e4e [Marcelo Vanzin] Review comments: simplify some code.
9c888ac [Marcelo Vanzin] Small style nits.
8474bab [Marcelo Vanzin] Fix multiple calls to MessageWithHeader.transferTo().
e26509f [Marcelo Vanzin] Merge branch 'master' into SPARK-6578
c503f6c [Marcelo Vanzin] Implement a custom FileRegion instead of using locks.
84aa7ce [Marcelo Vanzin] Rename handler to the correct name.
432f3bd [Marcelo Vanzin] Remove unneeded method.
8d70e60 [Marcelo Vanzin] Fix thread-safety issue in outbound path of network library.

(cherry picked from commit f084c5de)
Signed-off-by: Reynold Xin <rxin@databricks.com>
parent e347a7af
No related branches found
No related tags found
No related merge requests found
...@@ -80,6 +80,11 @@ ...@@ -80,6 +80,11 @@
<artifactId>mockito-all</artifactId> <artifactId>mockito-all</artifactId>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-log4j12</artifactId>
<scope>test</scope>
</dependency>
</dependencies> </dependencies>
<build> <build>
......
...@@ -72,9 +72,11 @@ public final class MessageEncoder extends MessageToMessageEncoder<Message> { ...@@ -72,9 +72,11 @@ public final class MessageEncoder extends MessageToMessageEncoder<Message> {
in.encode(header); in.encode(header);
assert header.writableBytes() == 0; assert header.writableBytes() == 0;
out.add(header);
if (body != null && bodyLength > 0) { if (body != null && bodyLength > 0) {
out.add(body); out.add(new MessageWithHeader(header, body, bodyLength));
} else {
out.add(header);
} }
} }
} }
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.protocol;
import java.io.IOException;
import java.nio.channels.WritableByteChannel;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import io.netty.buffer.ByteBuf;
import io.netty.channel.FileRegion;
import io.netty.util.AbstractReferenceCounted;
import io.netty.util.ReferenceCountUtil;
/**
 * A {@link FileRegion} that presents a header buffer followed by a body as a single message,
 * so the body's content never has to be copied into the same buffer as the header.
 *
 * The body must be either a {@link ByteBuf} or another {@link FileRegion}.
 */
class MessageWithHeader extends AbstractReferenceCounted implements FileRegion {

  private final ByteBuf header;
  private final int headerLength;
  private final Object body;
  private final long bodyLength;
  private long totalBytesTransferred;

  /**
   * @param header buffer holding the header; its readable byte count at construction time is
   *               recorded as the header length.
   * @param body the payload; must be a ByteBuf or a FileRegion.
   * @param bodyLength number of bytes in the body.
   * @throws IllegalArgumentException if the body is of any other type.
   */
  MessageWithHeader(ByteBuf header, Object body, long bodyLength) {
    boolean supportedBody = body instanceof ByteBuf || body instanceof FileRegion;
    Preconditions.checkArgument(supportedBody, "Body must be a ByteBuf or a FileRegion.");
    this.header = header;
    this.headerLength = header.readableBytes();
    this.body = body;
    this.bodyLength = bodyLength;
  }

  /** Total size of the message: header bytes plus body bytes. */
  @Override
  public long count() {
    return headerLength + bodyLength;
  }

  /** This region always starts at offset zero. */
  @Override
  public long position() {
    return 0;
  }

  /** Number of bytes handed to the target channel so far, across all transferTo() calls. */
  @Override
  public long transfered() {
    return totalBytesTransferred;
  }

  /**
   * Writes as much of the message as the target accepts in one pass: first whatever is left of
   * the header, then (only once the header is fully drained) the body.
   *
   * @param target channel receiving the bytes.
   * @param position must equal the number of bytes already transferred by this object.
   * @return the number of bytes written by this call.
   * @throws IllegalArgumentException if position does not match the bytes transferred so far.
   */
  @Override
  public long transferTo(WritableByteChannel target, long position) throws IOException {
    Preconditions.checkArgument(position == totalBytesTransferred, "Invalid position.");

    long headerBytes = 0;
    if (position < headerLength) {
      headerBytes = copyByteBuf(header, target);
      if (header.readableBytes() > 0) {
        // The channel did not take the whole header; the body has to wait for a later call.
        totalBytesTransferred += headerBytes;
        return headerBytes;
      }
    }

    long bodyBytes = 0;
    if (body instanceof FileRegion) {
      // The wrapped region counts positions from its own start. When this call also wrote
      // header bytes, `position` is still inside the header and the body starts at 0.
      long bodyOffset = Math.max(position - headerLength, 0L);
      bodyBytes = ((FileRegion) body).transferTo(target, bodyOffset);
    } else if (body instanceof ByteBuf) {
      bodyBytes = copyByteBuf((ByteBuf) body, target);
    }

    long writtenNow = headerBytes + bodyBytes;
    totalBytesTransferred += writtenNow;
    return writtenNow;
  }

  /** Releases the header buffer and the body once the reference count drops to zero. */
  @Override
  protected void deallocate() {
    header.release();
    ReferenceCountUtil.release(body);
  }

  /**
   * Writes the readable part of the buffer to the channel and advances the buffer's reader
   * index by however many bytes the channel reported as written.
   */
  private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException {
    int accepted = target.write(buf.nioBuffer());
    buf.skipBytes(accepted);
    return accepted;
  }
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network;
import java.nio.ByteBuffer;
import java.nio.channels.WritableByteChannel;
/**
 * A {@link WritableByteChannel} that stores everything written to it in a fixed-size byte
 * array allocated up front.
 *
 * Once the array is full, further writes transfer zero bytes instead of failing, which is
 * the partial-write behavior the WritableByteChannel contract allows. Callers that loop
 * until a source buffer is drained must therefore size the channel for all expected data.
 */
public class ByteArrayWritableChannel implements WritableByteChannel {

  private final byte[] data;
  private int offset;

  /**
   * @param size capacity of the backing array, in bytes.
   */
  public ByteArrayWritableChannel(int size) {
    this.data = new byte[size];
    this.offset = 0;
  }

  /**
   * Returns the backing array itself (not a copy). Its length is always the capacity given
   * to the constructor, regardless of how many bytes have been written so far.
   */
  public byte[] getData() {
    return data;
  }

  /**
   * Copies bytes from the source buffer into the backing array, starting where the previous
   * write stopped.
   *
   * @param src buffer to read from; its position advances by the number of bytes copied.
   * @return the number of bytes copied, which is less than {@code src.remaining()} (possibly
   *         zero) when the backing array does not have enough free space.
   */
  @Override
  public int write(ByteBuffer src) {
    // Cap the transfer at the free space so an oversized source results in a short write
    // rather than an IndexOutOfBoundsException from ByteBuffer.get().
    int toTransfer = Math.min(src.remaining(), data.length - offset);
    src.get(data, offset, toTransfer);
    offset += toTransfer;
    return toTransfer;
  }

  /** No resources are held, so closing does nothing. */
  @Override
  public void close() {
  }

  /** The channel is always reported as open, even after {@link #close()}. */
  @Override
  public boolean isOpen() {
    return true;
  }
}
...@@ -17,26 +17,34 @@ ...@@ -17,26 +17,34 @@
package org.apache.spark.network; package org.apache.spark.network;
import java.util.List;
import com.google.common.primitives.Ints;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.FileRegion;
import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.MessageToMessageEncoder;
import org.junit.Test; import org.junit.Test;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import org.apache.spark.network.protocol.Message;
import org.apache.spark.network.protocol.StreamChunkId;
import org.apache.spark.network.protocol.ChunkFetchRequest;
import org.apache.spark.network.protocol.ChunkFetchFailure; import org.apache.spark.network.protocol.ChunkFetchFailure;
import org.apache.spark.network.protocol.ChunkFetchRequest;
import org.apache.spark.network.protocol.ChunkFetchSuccess; import org.apache.spark.network.protocol.ChunkFetchSuccess;
import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.Message;
import org.apache.spark.network.protocol.RpcFailure;
import org.apache.spark.network.protocol.RpcResponse;
import org.apache.spark.network.protocol.MessageDecoder; import org.apache.spark.network.protocol.MessageDecoder;
import org.apache.spark.network.protocol.MessageEncoder; import org.apache.spark.network.protocol.MessageEncoder;
import org.apache.spark.network.protocol.RpcFailure;
import org.apache.spark.network.protocol.RpcRequest;
import org.apache.spark.network.protocol.RpcResponse;
import org.apache.spark.network.protocol.StreamChunkId;
import org.apache.spark.network.util.NettyUtils; import org.apache.spark.network.util.NettyUtils;
public class ProtocolSuite { public class ProtocolSuite {
private void testServerToClient(Message msg) { private void testServerToClient(Message msg) {
EmbeddedChannel serverChannel = new EmbeddedChannel(new MessageEncoder()); EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(),
new MessageEncoder());
serverChannel.writeOutbound(msg); serverChannel.writeOutbound(msg);
EmbeddedChannel clientChannel = new EmbeddedChannel( EmbeddedChannel clientChannel = new EmbeddedChannel(
...@@ -51,7 +59,8 @@ public class ProtocolSuite { ...@@ -51,7 +59,8 @@ public class ProtocolSuite {
} }
private void testClientToServer(Message msg) { private void testClientToServer(Message msg) {
EmbeddedChannel clientChannel = new EmbeddedChannel(new MessageEncoder()); EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(),
new MessageEncoder());
clientChannel.writeOutbound(msg); clientChannel.writeOutbound(msg);
EmbeddedChannel serverChannel = new EmbeddedChannel( EmbeddedChannel serverChannel = new EmbeddedChannel(
...@@ -83,4 +92,25 @@ public class ProtocolSuite { ...@@ -83,4 +92,25 @@ public class ProtocolSuite {
testServerToClient(new RpcFailure(0, "this is an error")); testServerToClient(new RpcFailure(0, "this is an error"));
testServerToClient(new RpcFailure(0, "")); testServerToClient(new RpcFailure(0, ""));
} }
/**
 * Test-only handler that flattens a FileRegion into a single byte buffer. EmbeddedChannel
 * passes message objects along rather than raw bytes, so without this step the frame decoder
 * on the receiving side would be handed the MessageWithHeader object instead of its contents.
 */
private static class FileRegionEncoder extends MessageToMessageEncoder<FileRegion> {

  @Override
  public void encode(ChannelHandlerContext ctx, FileRegion in, List<Object> out)
      throws Exception {
    int capacity = Ints.checkedCast(in.count());
    ByteArrayWritableChannel sink = new ByteArrayWritableChannel(capacity);
    // Keep asking the region for more until it reports that everything has been transferred.
    for (long done = in.transfered(); done < in.count(); done = in.transfered()) {
      in.transferTo(sink, done);
    }
    out.add(Unpooled.wrappedBuffer(sink.getData()));
  }
}
} }
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.protocol;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.WritableByteChannel;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.FileRegion;
import io.netty.util.AbstractReferenceCounted;
import org.junit.Test;
import static org.junit.Assert.*;
import org.apache.spark.network.ByteArrayWritableChannel;
/**
 * Tests that MessageWithHeader writes its header followed by its body to a channel, for both
 * FileRegion and ByteBuf bodies, including the case where the body needs several transferTo()
 * calls to be written completely.
 */
public class MessageWithHeaderSuite {

  /** The body region emits all of its data in a single transferTo() call. */
  @Test
  public void testSingleWrite() throws Exception {
    testFileRegionBody(8, 8);
  }

  /** The body region emits one long per transferTo() call, forcing repeated calls. */
  @Test
  public void testShortWrite() throws Exception {
    testFileRegionBody(8, 1);
  }

  /** The body is a plain ByteBuf rather than a FileRegion. */
  @Test
  public void testByteBufBody() throws Exception {
    ByteBuf header = Unpooled.copyLong(42);
    ByteBuf body = Unpooled.copyLong(84);
    MessageWithHeader msg = new MessageWithHeader(header, body, body.readableBytes());

    ByteBuf result = doWrite(msg, 1);
    assertEquals(msg.count(), result.readableBytes());
    assertEquals(42, result.readLong());
    assertEquals(84, result.readLong());
  }

  /**
   * Writes a message whose body is a TestFileRegion and checks the resulting bytes.
   *
   * @param totalWrites number of longs the region contains.
   * @param writesPerCall number of longs the region emits per transferTo() call.
   */
  private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception {
    ByteBuf header = Unpooled.copyLong(42);
    int headerLength = header.readableBytes();
    TestFileRegion region = new TestFileRegion(totalWrites, writesPerCall);
    MessageWithHeader msg = new MessageWithHeader(header, region, region.count());

    ByteBuf result = doWrite(msg, totalWrites / writesPerCall);
    assertEquals(headerLength + region.count(), result.readableBytes());
    assertEquals(42, result.readLong());
    // The region holds exactly totalWrites longs (0, 1, 2, ...), so check that many rather
    // than a fixed count that only matches the current callers.
    for (long i = 0; i < totalWrites; i++) {
      assertEquals(i, result.readLong());
    }
  }

  /**
   * Drains the message into an in-memory channel sized to msg.count().
   *
   * @param minExpectedWrites lower bound on the number of transferTo() calls that must have
   *                          been needed; the test fails if fewer were made.
   * @return a buffer wrapping the channel's backing array.
   */
  private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception {
    int writes = 0;
    ByteArrayWritableChannel channel = new ByteArrayWritableChannel((int) msg.count());
    while (msg.transfered() < msg.count()) {
      msg.transferTo(channel, msg.transfered());
      writes++;
    }
    assertTrue("Not enough writes!", minExpectedWrites <= writes);
    return Unpooled.wrappedBuffer(channel.getData());
  }

  /**
   * A fake FileRegion made of writeCount consecutive longs. Each transferTo() call emits
   * writesPerCall of them, with values derived from the requested position.
   */
  private static class TestFileRegion extends AbstractReferenceCounted implements FileRegion {

    private final int writeCount;
    private final int writesPerCall;
    private int written;

    TestFileRegion(int totalWrites, int writesPerCall) {
      this.writeCount = totalWrites;
      this.writesPerCall = writesPerCall;
    }

    /** Size in bytes: eight per long. */
    @Override
    public long count() {
      return 8 * writeCount;
    }

    @Override
    public long position() {
      return 0;
    }

    /** Bytes emitted so far: eight per long written. */
    @Override
    public long transfered() {
      return 8 * written;
    }

    /**
     * Emits writesPerCall longs; the first has the value position / 8 and each following one
     * is one larger. Always reports 8 * writesPerCall bytes as written.
     */
    @Override
    public long transferTo(WritableByteChannel target, long position) throws IOException {
      for (int i = 0; i < writesPerCall; i++) {
        ByteBuf buf = Unpooled.copyLong((position / 8) + i);
        ByteBuffer nio = buf.nioBuffer();
        // Retry until the channel has consumed the whole long.
        while (nio.remaining() > 0) {
          target.write(nio);
        }
        buf.release();
        written++;
      }
      return 8 * writesPerCall;
    }

    /** Nothing to free: the region owns no buffers between calls. */
    @Override
    protected void deallocate() {
    }
  }
}
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Set everything to be logged to the file target/unit-tests.log
log4j.rootCategory=DEBUG, file
log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=true
log4j.appender.file.file=target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
# Silence verbose logs from 3rd-party libraries.
log4j.logger.io.netty=INFO
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment