From 9b4cd1648b6c2467a63109ba817d7e7a0c46ffb9 Mon Sep 17 00:00:00 2001
From: Matei Zaharia <matei@eecs.berkeley.edu>
Date: Wed, 12 Sep 2012 14:54:14 -0700
Subject: [PATCH] Fix bugs with Connection's shutdown callback failing to get
 its address

---
 .../main/scala/spark/network/Connection.scala    |  9 ++++++---
 .../scala/spark/network/ConnectionManager.scala  | 16 +++++++++++++---
 2 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala
index da8aff9dd5..0209f4b29d 100644
--- a/core/src/main/scala/spark/network/Connection.scala
+++ b/core/src/main/scala/spark/network/Connection.scala
@@ -23,8 +23,8 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex
   var onExceptionCallback: (Connection, Exception) => Unit = null
   var onKeyInterestChangeCallback: (Connection, Int) => Unit = null
 
-  lazy val remoteAddress = getRemoteAddress() 
-  lazy val remoteConnectionManagerId = ConnectionManagerId.fromSocketAddress(remoteAddress) 
+  val remoteAddress = getRemoteAddress()
+  val remoteConnectionManagerId = ConnectionManagerId.fromSocketAddress(remoteAddress)
 
   def key() = channel.keyFor(selector)
 
@@ -39,7 +39,10 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex
   }
 
   def close() {
-    key.cancel()
+    val k = key()
+    if (k != null) {
+      k.cancel()
+    }
     channel.close()
     callOnCloseCallback()
   }
diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala
index 0e764fff81..2bb5f5fc6b 100644
--- a/core/src/main/scala/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/spark/network/ConnectionManager.scala
@@ -16,6 +16,7 @@ import scala.collection.mutable.ArrayBuffer
 
 import akka.dispatch.{Await, Promise, ExecutionContext, Future}
 import akka.util.Duration
+import akka.util.duration._
 
 case class ConnectionManagerId(host: String, port: Int) {
   def toSocketAddress() = new InetSocketAddress(host, port)
@@ -403,7 +404,10 @@ object ConnectionManager {
     (0 until count).map(i => {
       val bufferMessage = Message.createBufferMessage(buffer.duplicate)
       manager.sendMessageReliably(manager.id, bufferMessage)
-    }).foreach(f => {if (!f().isDefined) println("Failed")})
+    }).foreach(f => {
+      val g = Await.result(f, 1 second)
+      if (!g.isDefined) println("Failed")
+    })
     val finishTime = System.currentTimeMillis
     
     val mb = size * count / 1024.0 / 1024.0
@@ -430,7 +434,10 @@ object ConnectionManager {
     (0 until count).map(i => {
       val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate)
       manager.sendMessageReliably(manager.id, bufferMessage)
-    }).foreach(f => {if (!f().isDefined) println("Failed")})
+    }).foreach(f => {
+      val g = Await.result(f, 1 second)
+      if (!g.isDefined) println("Failed")
+    })
     val finishTime = System.currentTimeMillis
     
     val ms = finishTime - startTime
@@ -457,7 +464,10 @@ object ConnectionManager {
       (0 until count).map(i => {
           val bufferMessage = Message.createBufferMessage(buffer.duplicate)
           manager.sendMessageReliably(manager.id, bufferMessage)
-        }).foreach(f => {if (!f().isDefined) println("Failed")})
+        }).foreach(f => {
+          val g = Await.result(f, 1 second)
+          if (!g.isDefined) println("Failed")
+        })
       val finishTime = System.currentTimeMillis
       Thread.sleep(1000)
       val mb = size * count / 1024.0 / 1024.0
-- 
GitLab