Skip to content
Snippets Groups Projects
Commit 248995c5 authored by Matei Zaharia's avatar Matei Zaharia
Browse files

Merge pull request #356 from shane-huang/master

Fix an issue in ConnectionManager where sendMessage may create too many unnecessary connections
parents 14972141 9930a95d
No related branches found
No related tags found
No related merge requests found
......@@ -135,8 +135,11 @@ extends Connection(SocketChannel.open, selector_) {
val chunk = message.getChunkForSending(defaultChunkSize)
if (chunk.isDefined) {
messages += message // this is probably incorrect, it wont work as fifo
if (!message.started) logDebug("Starting to send [" + message + "]")
message.started = true
if (!message.started) {
logDebug("Starting to send [" + message + "]")
message.started = true
message.startTime = System.currentTimeMillis
}
return chunk
} else {
/*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/
......
......@@ -43,12 +43,12 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
}
val selector = SelectorProvider.provider.openSelector()
val handleMessageExecutor = Executors.newFixedThreadPool(4)
val handleMessageExecutor = Executors.newFixedThreadPool(System.getProperty("spark.core.connection.handler.threads","20").toInt)
val serverChannel = ServerSocketChannel.open()
val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
val messageStatuses = new HashMap[Int, MessageStatus]
val connectionRequests = new SynchronizedQueue[SendingConnection]
val connectionRequests = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
val sendMessageRequests = new Queue[(Message, SendingConnection)]
......@@ -79,10 +79,10 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
def run() {
try {
while(!selectorThread.isInterrupted) {
while(!connectionRequests.isEmpty) {
val sendingConnection = connectionRequests.dequeue
for( (connectionManagerId, sendingConnection) <- connectionRequests) {
sendingConnection.connect()
addConnection(sendingConnection)
connectionRequests -= connectionManagerId
}
sendMessageRequests.synchronized {
while(!sendMessageRequests.isEmpty) {
......@@ -300,8 +300,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
def startNewConnection(): SendingConnection = {
val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port)
val newConnection = new SendingConnection(inetSocketAddress, selector)
connectionRequests += newConnection
val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId, new SendingConnection(inetSocketAddress, selector))
newConnection
}
val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress)
......@@ -473,6 +472,7 @@ private[spark] object ConnectionManager {
val mb = size * count / 1024.0 / 1024.0
val ms = finishTime - startTime
val tput = mb * 1000.0 / ms
println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)")
println("--------------------------")
println()
}
......
......@@ -13,8 +13,14 @@ import akka.util.duration._
private[spark] object ConnectionManagerTest extends Logging{
def main(args: Array[String]) {
//<mesos cluster> - the master URL
//<slaves file> - a list slaves to run connectionTest on
//[num of tasks] - the number of parallel tasks to be initiated default is number of slave hosts
//[size of msg in MB (integer)] - the size of messages to be sent in each task, default is 10
//[count] - how many times to run, default is 3
//[await time in seconds] : await time (in seconds), default is 600
if (args.length < 2) {
println("Usage: ConnectionManagerTest <mesos cluster> <slaves file>")
println("Usage: ConnectionManagerTest <mesos cluster> <slaves file> [num of tasks] [size of msg in MB (integer)] [count] [await time in seconds)] ")
System.exit(1)
}
......@@ -29,16 +35,19 @@ private[spark] object ConnectionManagerTest extends Logging{
/*println("Slaves")*/
/*slaves.foreach(println)*/
val slaveConnManagerIds = sc.parallelize(0 until slaves.length, slaves.length).map(
val tasknum = if (args.length > 2) args(2).toInt else slaves.length
val size = ( if (args.length > 3) (args(3).toInt) else 10 ) * 1024 * 1024
val count = if (args.length > 4) args(4).toInt else 3
val awaitTime = (if (args.length > 5) args(5).toInt else 600 ).second
println("Running "+count+" rounds of test: " + "parallel tasks = " + tasknum + ", msg size = " + size/1024/1024 + " MB, awaitTime = " + awaitTime)
val slaveConnManagerIds = sc.parallelize(0 until tasknum, tasknum).map(
i => SparkEnv.get.connectionManager.id).collect()
println("\nSlave ConnectionManagerIds")
slaveConnManagerIds.foreach(println)
println
val count = 10
(0 until count).foreach(i => {
val resultStrs = sc.parallelize(0 until slaves.length, slaves.length).map(i => {
val resultStrs = sc.parallelize(0 until tasknum, tasknum).map(i => {
val connManager = SparkEnv.get.connectionManager
val thisConnManagerId = connManager.id
connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
......@@ -46,7 +55,6 @@ private[spark] object ConnectionManagerTest extends Logging{
None
})
val size = 100 * 1024 * 1024
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
buffer.flip
......@@ -56,13 +64,13 @@ private[spark] object ConnectionManagerTest extends Logging{
logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]")
connManager.sendMessageReliably(slaveConnManagerId, bufferMessage)
})
val results = futures.map(f => Await.result(f, 1.second))
val results = futures.map(f => Await.result(f, awaitTime))
val finishTime = System.currentTimeMillis
Thread.sleep(5000)
val mb = size * results.size / 1024.0 / 1024.0
val ms = finishTime - startTime
val resultStr = "Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"
val resultStr = thisConnManagerId + " Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"
logInfo(resultStr)
resultStr
}).collect()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment