diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
index 6840a3ae831f0d0bf8ba0e640760878d15fe7008..a039d543c35e71da6ae6c6ea17834587adb7d6da 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala
@@ -47,7 +47,8 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana
 
   private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0)
   private val blockHandler = newShuffleBlockHandler(transportConf)
-  private val transportContext: TransportContext = new TransportContext(transportConf, blockHandler)
+  private val transportContext: TransportContext =
+    new TransportContext(transportConf, blockHandler, true)
 
   private var server: TransportServer = _
 
diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
index 43900e6f2c9726fa8dd3e9294d2ddb603c3d4481..1b64b863a9fe5c85c6ff30f0b67f4599e62e1553 100644
--- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java
+++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -59,15 +59,24 @@ public class TransportContext {
 
   private final TransportConf conf;
   private final RpcHandler rpcHandler;
+  private final boolean closeIdleConnections;
 
   private final MessageEncoder encoder;
   private final MessageDecoder decoder;
 
   public TransportContext(TransportConf conf, RpcHandler rpcHandler) {
+    this(conf, rpcHandler, false);
+  }
+
+  public TransportContext(
+      TransportConf conf,
+      RpcHandler rpcHandler,
+      boolean closeIdleConnections) {
     this.conf = conf;
     this.rpcHandler = rpcHandler;
     this.encoder = new MessageEncoder();
     this.decoder = new MessageDecoder();
+    this.closeIdleConnections = closeIdleConnections;
   }
 
   /**
@@ -144,7 +153,7 @@ public class TransportContext {
     TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
       rpcHandler);
     return new TransportChannelHandler(client, responseHandler, requestHandler,
-      conf.connectionTimeoutMs());
+      conf.connectionTimeoutMs(), closeIdleConnections);
   }
 
   public TransportConf getConf() { return conf; }
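Note on the constructor change above: the two-argument constructor now delegates to the three-argument form with closeIdleConnections = false, so existing call sites keep the old behavior and only opted-in callers (the external shuffle service and client below) change. A minimal usage sketch follows; the class name is hypothetical, and NoOpRpcHandler / SystemPropertyConfigProvider are just convenient stand-ins.

import org.apache.spark.network.TransportContext;
import org.apache.spark.network.server.NoOpRpcHandler;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;

public class TransportContextUsageSketch {
  public static void main(String[] args) {
    TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
    RpcHandler handler = new NoOpRpcHandler();
    // Existing behavior: the two-argument constructor defaults closeIdleConnections to false,
    // so connections are only torn down when the timeout fires with requests still outstanding.
    TransportContext defaultContext = new TransportContext(conf, handler);
    // Opt-in behavior added by this patch: connections idle past
    // spark.shuffle.io.connectionTimeout are closed even with no outstanding requests.
    TransportContext idleClosingContext = new TransportContext(conf, handler, true);
  }
}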
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
index 4952ffb44bb8b33a32dc3ff951752aef3f087c28..42a4f664e697c3e9a1b7864b360a525af62f713b 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -158,6 +158,16 @@ public class TransportClientFactory implements Closeable {
     }
   }
 
+  /**
+   * Create a completely new {@link TransportClient} to the given remote host / port.
+   * This connection is not pooled by the factory, so the caller is responsible for closing it.
+   */
+  public TransportClient createUnmanagedClient(String remoteHost, int remotePort)
+      throws IOException {
+    final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
+    return createClient(address);
+  }
+
   /** Create a completely new {@link TransportClient} to the remote address. */
   private TransportClient createClient(InetSocketAddress address) throws IOException {
     logger.debug("Creating new connection to " + address);
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
index 8e0ee709e38e3856f437eec3382bbfd6906db0e6..f8fcd1c3d7d76f1827a43338194d093e4452da22 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
@@ -55,16 +55,19 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler<Message
   private final TransportResponseHandler responseHandler;
   private final TransportRequestHandler requestHandler;
   private final long requestTimeoutNs;
+  private final boolean closeIdleConnections;
 
   public TransportChannelHandler(
       TransportClient client,
       TransportResponseHandler responseHandler,
       TransportRequestHandler requestHandler,
-      long requestTimeoutMs) {
+      long requestTimeoutMs,
+      boolean closeIdleConnections) {
     this.client = client;
     this.responseHandler = responseHandler;
     this.requestHandler = requestHandler;
     this.requestTimeoutNs = requestTimeoutMs * 1000L * 1000;
+    this.closeIdleConnections = closeIdleConnections;
   }
 
   public TransportClient getClient() {
@@ -111,16 +114,21 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler<Message
       IdleStateEvent e = (IdleStateEvent) evt;
       // See class comment for timeout semantics. In addition to ensuring we only timeout while
       // there are outstanding requests, we also do a secondary consistency check to ensure
-      // there's no race between the idle timeout and incrementing the numOutstandingRequests.
-      boolean hasInFlightRequests = responseHandler.numOutstandingRequests() > 0;
+      // there's no race between the idle timeout and incrementing the numOutstandingRequests
+      // (see SPARK-7003).
       boolean isActuallyOverdue =
         System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs;
-      if (e.state() == IdleState.ALL_IDLE && hasInFlightRequests && isActuallyOverdue) {
-        String address = NettyUtils.getRemoteAddress(ctx.channel());
-        logger.error("Connection to {} has been quiet for {} ms while there are outstanding " +
-          "requests. Assuming connection is dead; please adjust spark.network.timeout if this " +
-          "is wrong.", address, requestTimeoutNs / 1000 / 1000);
-        ctx.close();
+      if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) {
+        if (responseHandler.numOutstandingRequests() > 0) {
+          String address = NettyUtils.getRemoteAddress(ctx.channel());
+          logger.error("Connection to {} has been quiet for {} ms while there are outstanding " +
+            "requests. Assuming connection is dead; please adjust spark.network.timeout if this " +
+            "is wrong.", address, requestTimeoutNs / 1000 / 1000);
+          ctx.close();
+        } else if (closeIdleConnections) {
+          // When closeIdleConnections is enabled, close the idle connection as well.
+          ctx.close();
+        }
       }
     }
   }
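The ALL_IDLE events handled above are assumed to come from Netty's io.netty.handler.timeout.IdleStateHandler, which TransportContext installs ahead of this handler when it initializes the channel pipeline (that code is unchanged and therefore not shown in this patch). Roughly, the wiring looks like the sketch below, where channel, encoder, decoder, conf, and channelHandler stand for the surrounding TransportContext state:

// An IdleStateHandler fires an ALL_IDLE IdleStateEvent once the channel has seen neither
// reads nor writes for the connection timeout; that event reaches userEventTriggered above.
channel.pipeline()
  .addLast("encoder", encoder)
  .addLast("decoder", decoder)
  .addLast("idleStateHandler",
    new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
  .addLast("handler", channelHandler);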
diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
index 35de5e57ccb98cdba6a5fdd5bbd88b419ce6dbb9..f4471374193060d05464a47c33a4be5445a2505d 100644
--- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java
@@ -21,6 +21,7 @@ import java.io.IOException;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.Map;
+import java.util.NoSuchElementException;
 import java.util.Set;
 import java.util.concurrent.atomic.AtomicInteger;
 
@@ -37,6 +38,7 @@ import org.apache.spark.network.client.TransportClientFactory;
 import org.apache.spark.network.server.NoOpRpcHandler;
 import org.apache.spark.network.server.RpcHandler;
 import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.ConfigProvider;
 import org.apache.spark.network.util.SystemPropertyConfigProvider;
 import org.apache.spark.network.util.JavaUtils;
 import org.apache.spark.network.util.MapConfigProvider;
@@ -177,4 +179,36 @@ public class TransportClientFactorySuite {
     assertFalse(c1.isActive());
     assertFalse(c2.isActive());
   }
+
+  @Test
+  public void closeIdleConnectionForRequestTimeOut() throws IOException, InterruptedException {
+    TransportConf conf = new TransportConf(new ConfigProvider() {
+
+      @Override
+      public String get(String name) {
+        if ("spark.shuffle.io.connectionTimeout".equals(name)) {
+          // We should make sure there is enough time for us to observe the channel is active
+          return "1s";
+        }
+        String value = System.getProperty(name);
+        if (value == null) {
+          throw new NoSuchElementException(name);
+        }
+        return value;
+      }
+    });
+    TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true);
+    TransportClientFactory factory = context.createClientFactory();
+    try {
+      TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
+      assertTrue(c1.isActive());
+      long expiredTime = System.currentTimeMillis() + 10000; // 10 seconds
+      while (c1.isActive() && System.currentTimeMillis() < expiredTime) {
+        Thread.sleep(10);
+      }
+      assertFalse(c1.isActive());
+    } finally {
+      factory.close();
+    }
+  }
 }
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
index ea6d248d66be36ad3bb305a93737d78940b60782..ef3a9dcc8711fcb96892be888fd5df308a882c6d 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -78,7 +78,7 @@ public class ExternalShuffleClient extends ShuffleClient {
   @Override
   public void init(String appId) {
     this.appId = appId;
-    TransportContext context = new TransportContext(conf, new NoOpRpcHandler());
+    TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true);
     List<TransportClientBootstrap> bootstraps = Lists.newArrayList();
     if (saslEnabled) {
       bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder, saslEncryptionEnabled));
@@ -137,9 +137,13 @@ public class ExternalShuffleClient extends ShuffleClient {
       String execId,
       ExecutorShuffleInfo executorInfo) throws IOException {
     checkInit();
-    TransportClient client = clientFactory.createClient(host, port);
-    byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray();
-    client.sendRpcSync(registerMessage, 5000 /* timeoutMs */);
+    TransportClient client = clientFactory.createUnmanagedClient(host, port);
+    try {
+      byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray();
+      client.sendRpcSync(registerMessage, 5000 /* timeoutMs */);
+    } finally {
+      client.close();
+    }
   }
 
   @Override
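Since always closing the registration connection is the point of switching to an unmanaged client here, the same cleanup could also be written with try-with-resources; this is a sketch only, assuming TransportClient implements java.io.Closeable (the explicit client.close() above suggests it does), and it is not part of this patch:

try (TransportClient client = clientFactory.createUnmanagedClient(host, port)) {
  byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray();
  client.sendRpcSync(registerMessage, 5000 /* timeoutMs */);
}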