From 328b229840d6e87c7faf7ee3cd5bf66a905c9a7d Mon Sep 17 00:00:00 2001
From: Shixiong Zhu <shixiong@databricks.com>
Date: Mon, 13 Feb 2017 12:03:36 -0800
Subject: [PATCH] [SPARK-17714][CORE][TEST-MAVEN][TEST-HADOOP2.6] Avoid using
 ExecutorClassLoader to load Netty generated classes

## What changes were proposed in this pull request?

Netty's `MessageToMessageEncoder` uses [Javassist](https://github.com/netty/netty/blob/91a0bdc17a8298437d6de08a8958d753799bd4a6/common/src/main/java/io/netty/util/internal/JavassistTypeParameterMatcherGenerator.java#L62) to generate a matcher class and the implementation calls `Class.forName` to check if this class is already generated. If `MessageEncoder` or `MessageDecoder` is created in `ExecutorClassLoader.findClass`, it will cause `ClassCircularityError`. This is because loading this Netty generated class will call `ExecutorClassLoader.findClass` to search this class, and `ExecutorClassLoader` will try to use RPC to load it and cause to load the non-exist matcher class again. JVM will report `ClassCircularityError` to prevent such infinite recursion.

##### Why it only happens in Maven builds

It's because Maven and SBT have different class loader tree. The Maven build will set a URLClassLoader as the current context class loader to run the tests and expose this issue. The class loader tree is as following:

```
bootstrap class loader ------ ... ----- REPL class loader ---- ExecutorClassLoader
|
|
URLClasssLoader
```

The SBT build uses the bootstrap class loader directly and `ReplSuite.test("propagation of local properties")` is the first test in ReplSuite, which happens to load `io/netty/util/internal/__matchers__/org/apache/spark/network/protocol/MessageMatcher` into the bootstrap class loader (Note: in maven build, it's loaded into URLClasssLoader so it cannot be found in ExecutorClassLoader). This issue can be reproduced in SBT as well. Here are the produce steps:
- Enable `hadoop.caller.context.enabled`.
- Replace `Class.forName` with `Utils.classForName` in `object CallerContext`.
- Ignore `ReplSuite.test("propagation of local properties")`.
- Run `ReplSuite` using SBT.

This PR just creates a singleton MessageEncoder and MessageDecoder and makes sure they are created before switching to ExecutorClassLoader. TransportContext will be created when creating RpcEnv and that happens before creating ExecutorClassLoader.

## How was this patch tested?

Jenkins

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #16859 from zsxwing/SPARK-17714.

(cherry picked from commit 905fdf0c243e1776c54c01a25b17878361400225)
Signed-off-by: Shixiong Zhu <shixiong@databricks.com>
---
 .../spark/network/TransportContext.java       | 22 ++++++++++++++-----
 .../network/protocol/MessageDecoder.java      |  4 ++++
 .../network/protocol/MessageEncoder.java      |  4 ++++
 .../server/TransportChannelHandler.java       | 11 +++++-----
 .../apache/spark/network/ProtocolSuite.java   |  8 +++----
 .../scala/org/apache/spark/util/Utils.scala   | 16 ++++----------
 6 files changed, 38 insertions(+), 27 deletions(-)

diff --git a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
index 5b69e2bb03..37ba543380 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -62,8 +62,20 @@ public class TransportContext {
   private final RpcHandler rpcHandler;
   private final boolean closeIdleConnections;
 
-  private final MessageEncoder encoder;
-  private final MessageDecoder decoder;
+  /**
+   * Force to create MessageEncoder and MessageDecoder so that we can make sure they will be created
+   * before switching the current context class loader to ExecutorClassLoader.
+   *
+   * Netty's MessageToMessageEncoder uses Javassist to generate a matcher class and the
+   * implementation calls "Class.forName" to check if this calls is already generated. If the
+   * following two objects are created in "ExecutorClassLoader.findClass", it will cause
+   * "ClassCircularityError". This is because loading this Netty generated class will call
+   * "ExecutorClassLoader.findClass" to search this class, and "ExecutorClassLoader" will try to use
+   * RPC to load it and cause to load the non-exist matcher class again. JVM will report
+   * `ClassCircularityError` to prevent such infinite recursion. (See SPARK-17714)
+   */
+  private static final MessageEncoder ENCODER = MessageEncoder.INSTANCE;
+  private static final MessageDecoder DECODER = MessageDecoder.INSTANCE;
 
   public TransportContext(TransportConf conf, RpcHandler rpcHandler) {
     this(conf, rpcHandler, false);
@@ -75,8 +87,6 @@ public class TransportContext {
       boolean closeIdleConnections) {
     this.conf = conf;
     this.rpcHandler = rpcHandler;
-    this.encoder = new MessageEncoder();
-    this.decoder = new MessageDecoder();
     this.closeIdleConnections = closeIdleConnections;
   }
 
@@ -135,9 +145,9 @@ public class TransportContext {
     try {
       TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
       channel.pipeline()
-        .addLast("encoder", encoder)
+        .addLast("encoder", ENCODER)
         .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
-        .addLast("decoder", decoder)
+        .addLast("decoder", DECODER)
         .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
         // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
         // would require more logic to guarantee if this were not part of the same event loop.
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
index f0956438ad..39a7495828 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
@@ -35,6 +35,10 @@ public final class MessageDecoder extends MessageToMessageDecoder<ByteBuf> {
 
   private static final Logger logger = LoggerFactory.getLogger(MessageDecoder.class);
 
+  public static final MessageDecoder INSTANCE = new MessageDecoder();
+
+  private MessageDecoder() {}
+
   @Override
   public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
     Message.Type msgType = Message.Type.decode(in);
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
index 276f16637e..997f74e1a2 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
@@ -35,6 +35,10 @@ public final class MessageEncoder extends MessageToMessageEncoder<Message> {
 
   private static final Logger logger = LoggerFactory.getLogger(MessageEncoder.class);
 
+  public static final MessageEncoder INSTANCE = new MessageEncoder();
+
+  private MessageEncoder() {}
+
   /***
    * Encodes a Message by invoking its encode() method. For non-data messages, we will add one
    * ByteBuf to 'out' containing the total frame length, the message type, and the message itself.
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
index c6ccae18b5..56782a8327 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
@@ -18,7 +18,7 @@
 package org.apache.spark.network.server;
 
 import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.SimpleChannelInboundHandler;
+import io.netty.channel.ChannelInboundHandlerAdapter;
 import io.netty.handler.timeout.IdleState;
 import io.netty.handler.timeout.IdleStateEvent;
 import org.slf4j.Logger;
@@ -26,7 +26,6 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.client.TransportResponseHandler;
-import org.apache.spark.network.protocol.Message;
 import org.apache.spark.network.protocol.RequestMessage;
 import org.apache.spark.network.protocol.ResponseMessage;
 import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
@@ -48,7 +47,7 @@ import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
  * on the channel for at least `requestTimeoutMs`. Note that this is duplex traffic; we will not
  * timeout if the client is continuously sending but getting no responses, for simplicity.
  */
-public class TransportChannelHandler extends SimpleChannelInboundHandler<Message> {
+public class TransportChannelHandler extends ChannelInboundHandlerAdapter {
   private static final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class);
 
   private final TransportClient client;
@@ -114,11 +113,13 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler<Message
   }
 
   @Override
-  public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exception {
+  public void channelRead(ChannelHandlerContext ctx, Object request) throws Exception {
     if (request instanceof RequestMessage) {
       requestHandler.handle((RequestMessage) request);
-    } else {
+    } else if (request instanceof ResponseMessage) {
       responseHandler.handle((ResponseMessage) request);
+    } else {
+      ctx.fireChannelRead(request);
     }
   }
 
diff --git a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java
index 6c8dd742f4..bb1c40c4b0 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/ProtocolSuite.java
@@ -49,11 +49,11 @@ import org.apache.spark.network.util.NettyUtils;
 public class ProtocolSuite {
   private void testServerToClient(Message msg) {
     EmbeddedChannel serverChannel = new EmbeddedChannel(new FileRegionEncoder(),
-      new MessageEncoder());
+      MessageEncoder.INSTANCE);
     serverChannel.writeOutbound(msg);
 
     EmbeddedChannel clientChannel = new EmbeddedChannel(
-        NettyUtils.createFrameDecoder(), new MessageDecoder());
+        NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE);
 
     while (!serverChannel.outboundMessages().isEmpty()) {
       clientChannel.writeInbound(serverChannel.readOutbound());
@@ -65,11 +65,11 @@ public class ProtocolSuite {
 
   private void testClientToServer(Message msg) {
     EmbeddedChannel clientChannel = new EmbeddedChannel(new FileRegionEncoder(),
-      new MessageEncoder());
+      MessageEncoder.INSTANCE);
     clientChannel.writeOutbound(msg);
 
     EmbeddedChannel serverChannel = new EmbeddedChannel(
-        NettyUtils.createFrameDecoder(), new MessageDecoder());
+        NettyUtils.createFrameDecoder(), MessageDecoder.INSTANCE);
 
     while (!clientChannel.outboundMessages().isEmpty()) {
       serverChannel.writeInbound(clientChannel.readOutbound());
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 1319a4ce26..00b1b54f61 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -2566,12 +2566,8 @@ private[util] object CallerContext extends Logging {
   val callerContextSupported: Boolean = {
     SparkHadoopUtil.get.conf.getBoolean("hadoop.caller.context.enabled", false) && {
       try {
-        // `Utils.classForName` will make `ReplSuite` fail with `ClassCircularityError` in
-        // master Maven build, so do not use it before resolving SPARK-17714.
-        // scalastyle:off classforname
-        Class.forName("org.apache.hadoop.ipc.CallerContext")
-        Class.forName("org.apache.hadoop.ipc.CallerContext$Builder")
-        // scalastyle:on classforname
+        Utils.classForName("org.apache.hadoop.ipc.CallerContext")
+        Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder")
         true
       } catch {
         case _: ClassNotFoundException =>
@@ -2633,12 +2629,8 @@ private[spark] class CallerContext(
   def setCurrentContext(): Unit = {
     if (CallerContext.callerContextSupported) {
       try {
-        // `Utils.classForName` will make `ReplSuite` fail with `ClassCircularityError` in
-        // master Maven build, so do not use it before resolving SPARK-17714.
-        // scalastyle:off classforname
-        val callerContext = Class.forName("org.apache.hadoop.ipc.CallerContext")
-        val builder = Class.forName("org.apache.hadoop.ipc.CallerContext$Builder")
-        // scalastyle:on classforname
+        val callerContext = Utils.classForName("org.apache.hadoop.ipc.CallerContext")
+        val builder = Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder")
         val builderInst = builder.getConstructor(classOf[String]).newInstance(context)
         val hdfsContext = builder.getMethod("build").invoke(builderInst)
         callerContext.getMethod("setCurrent", callerContext).invoke(null, hdfsContext)
-- 
GitLab