diff --git a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java index 5b69e2bb0354699cfe9854e2973dd682f824a9de..37ba543380f07aa9f9bda4d7c65bc4334cf2acab 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java @@ -62,8 +62,20 @@ public class TransportContext { private final RpcHandler rpcHandler; private final boolean closeIdleConnections; - private final MessageEncoder encoder; - private final MessageDecoder decoder; + /** + * Force to create MessageEncoder and MessageDecoder so that we can make sure they will be created + * before switching the current context class loader to ExecutorClassLoader. + * + * Netty's MessageToMessageEncoder uses Javassist to generate a matcher class and the + * implementation calls "Class.forName" to check if this calls is already generated. If the + * following two objects are created in "ExecutorClassLoader.findClass", it will cause + * "ClassCircularityError". This is because loading this Netty generated class will call + * "ExecutorClassLoader.findClass" to search this class, and "ExecutorClassLoader" will try to use + * RPC to load it and cause to load the non-exist matcher class again. JVM will report + * `ClassCircularityError` to prevent such infinite recursion. (See SPARK-17714) + */ + private static final MessageEncoder ENCODER = MessageEncoder.INSTANCE; + private static final MessageDecoder DECODER = MessageDecoder.INSTANCE; public TransportContext(TransportConf conf, RpcHandler rpcHandler) { this(conf, rpcHandler, false); @@ -75,8 +87,6 @@ public class TransportContext { boolean closeIdleConnections) { this.conf = conf; this.rpcHandler = rpcHandler; - this.encoder = new MessageEncoder(); - this.decoder = new MessageDecoder(); this.closeIdleConnections = closeIdleConnections; } @@ -135,9 +145,9 @@ public class TransportContext { try { TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler); channel.pipeline() - .addLast("encoder", encoder) + .addLast("encoder", ENCODER) .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder()) - .addLast("decoder", decoder) + .addLast("decoder", DECODER) .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000)) // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this // would require more logic to guarantee if this were not part of the same event loop. diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index f0956438ade2461c3c8a4533a476d9bf463f034b..39a7495828a8a11efa9ec25f733f93bcd14afc80 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -35,6 +35,10 @@ public final class MessageDecoder extends MessageToMessageDecoder<ByteBuf> { private static final Logger logger = LoggerFactory.getLogger(MessageDecoder.class); + public static final MessageDecoder INSTANCE = new MessageDecoder(); + + private MessageDecoder() {} + @Override public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) { Message.Type msgType = Message.Type.decode(in); diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java index 276f16637efc937008c062422366fd3d02129353..997f74e1a21b4440461d25fd1848304db5896f10 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -35,6 +35,10 @@ public final class MessageEncoder extends MessageToMessageEncoder<Message> { private static final Logger logger = LoggerFactory.getLogger(MessageEncoder.class); + public static final MessageEncoder INSTANCE = new MessageEncoder(); + + private MessageEncoder() {} + /*** * Encodes a Message by invoking its encode() method. For non-data messages, we will add one * ByteBuf to 'out' containing the total frame length, the message type, and the message itself. diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index c6ccae18b5e06cdd23c2c6de3f748fab825f66f2..56782a8327876ad77ee94879e1b469803eba3365 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -18,7 +18,7 @@ package org.apache.spark.network.server; import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.handler.timeout.IdleState; import io.netty.handler.timeout.IdleStateEvent; import org.slf4j.Logger; @@ -26,7 +26,6 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportResponseHandler; -import org.apache.spark.network.protocol.Message; import org.apache.spark.network.protocol.RequestMessage; import org.apache.spark.network.protocol.ResponseMessage; import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; @@ -48,7 +47,7 @@ import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; * on the channel for at least `requestTimeoutMs`. Note that this is duplex traffic; we will not * timeout if the client is continuously sending but getting no responses, for simplicity. */ -public class TransportChannelHandler extends SimpleChannelInboundHandler<Message> { +public class TransportChannelHandler extends ChannelInboundHandlerAdapter { private static final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class); private final TransportClient client; @@ -114,11 +113,13 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler<Message } @Override - public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exception { + public void channelRead(ChannelHandlerContext ctx, Object request) throws Exception { if (request instanceof RequestMessage) { requestHandler.handle((RequestMessage) request); - } else { + } else if (request instanceof ResponseMessage) { responseHandler.handle((ResponseMessage) request); + } else { + ctx.fireChannelRead(request); } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java index 6c8dd742f4b6412f85901c5a8051051e4278fa54..bb1c40c4b0e0686dfaa4a2aa01249122e988fa8c 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -49,11 +49,11 @@ import org.apache.spark.network.util.NettyUtils; public class ProtocolSuite { private void testServerToClient(Message msg) { EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(), - new MessageEncoder()); + MessageEncoder.INSTANCE); serverChannel.writeOutbound(msg); EmbeddedChannel clientChannel = new EmbeddedChannel( - NettyUtils.createFrameDecoder(), new MessageDecoder()); + NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE); while (!serverChannel.outboundMessages().isEmpty()) { clientChannel.writeInbound(serverChannel.readOutbound()); @@ -65,11 +65,11 @@ public class ProtocolSuite { private void testClientToServer(Message msg) { EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(), - new MessageEncoder()); + MessageEncoder.INSTANCE); clientChannel.writeOutbound(msg); EmbeddedChannel serverChannel = new EmbeddedChannel( - NettyUtils.createFrameDecoder(), new MessageDecoder()); + NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE); while (!clientChannel.outboundMessages().isEmpty()) { serverChannel.writeInbound(clientChannel.readOutbound()); diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 1319a4ce26f5600b67135e7b2baec86d8572d7dc..00b1b54f61a528d9e23b8a47aed66e397fed88eb 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2566,12 +2566,8 @@ private[util] object CallerContext extends Logging { val callerContextSupported: Boolean = { SparkHadoopUtil.get.conf.getBoolean("hadoop.caller.context.enabled", false) && { try { - // `Utils.classForName` will make `ReplSuite` fail with `ClassCircularityError` in - // master Maven build, so do not use it before resolving SPARK-17714. - // scalastyle:off classforname - Class.forName("org.apache.hadoop.ipc.CallerContext") - Class.forName("org.apache.hadoop.ipc.CallerContext$Builder") - // scalastyle:on classforname + Utils.classForName("org.apache.hadoop.ipc.CallerContext") + Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder") true } catch { case _: ClassNotFoundException => @@ -2633,12 +2629,8 @@ private[spark] class CallerContext( def setCurrentContext(): Unit = { if (CallerContext.callerContextSupported) { try { - // `Utils.classForName` will make `ReplSuite` fail with `ClassCircularityError` in - // master Maven build, so do not use it before resolving SPARK-17714. - // scalastyle:off classforname - val callerContext = Class.forName("org.apache.hadoop.ipc.CallerContext") - val builder = Class.forName("org.apache.hadoop.ipc.CallerContext$Builder") - // scalastyle:on classforname + val callerContext = Utils.classForName("org.apache.hadoop.ipc.CallerContext") + val builder = Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder") val builderInst = builder.getConstructor(classOf[String]).newInstance(context) val hdfsContext = builder.getMethod("build").invoke(builderInst) callerContext.getMethod("setCurrent", callerContext).invoke(null, hdfsContext)