Skip to content
Snippets Groups Projects
Commit 7ccbbdac authored by Tathagata Das's avatar Tathagata Das
Browse files

Made block generator thread safe to fix Kafka bug.

parent 23b53efc
No related branches found
No related tags found
No related merge requests found
...@@ -232,11 +232,11 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log ...@@ -232,11 +232,11 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log
logInfo("Data handler stopped") logInfo("Data handler stopped")
} }
def += (obj: T) { def += (obj: T): Unit = synchronized {
currentBuffer += obj currentBuffer += obj
} }
private def updateCurrentBuffer(time: Long) { private def updateCurrentBuffer(time: Long): Unit = synchronized {
try { try {
val newBlockBuffer = currentBuffer val newBlockBuffer = currentBuffer
currentBuffer = new ArrayBuffer[T] currentBuffer = new ArrayBuffer[T]
......
...@@ -23,15 +23,15 @@ import akka.actor.IOManager ...@@ -23,15 +23,15 @@ import akka.actor.IOManager
import akka.actor.Props import akka.actor.Props
import akka.util.ByteString import akka.util.ByteString
import dstream.SparkFlumeEvent import org.apache.spark.streaming.dstream.{NetworkReceiver, SparkFlumeEvent}
import java.net.{InetSocketAddress, SocketException, Socket, ServerSocket} import java.net.{InetSocketAddress, SocketException, Socket, ServerSocket}
import java.io.{File, BufferedWriter, OutputStreamWriter} import java.io.{File, BufferedWriter, OutputStreamWriter}
import java.util.concurrent.{TimeUnit, ArrayBlockingQueue} import java.util.concurrent.{Executors, TimeUnit, ArrayBlockingQueue}
import collection.mutable.{SynchronizedBuffer, ArrayBuffer} import collection.mutable.{SynchronizedBuffer, ArrayBuffer}
import util.ManualClock import util.ManualClock
import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.receivers.Receiver import org.apache.spark.streaming.receivers.Receiver
import org.apache.spark.Logging import org.apache.spark.{SparkContext, Logging}
import scala.util.Random import scala.util.Random
import org.apache.commons.io.FileUtils import org.apache.commons.io.FileUtils
import org.scalatest.BeforeAndAfter import org.scalatest.BeforeAndAfter
...@@ -44,6 +44,7 @@ import java.nio.ByteBuffer ...@@ -44,6 +44,7 @@ import java.nio.ByteBuffer
import collection.JavaConversions._ import collection.JavaConversions._
import java.nio.charset.Charset import java.nio.charset.Charset
import com.google.common.io.Files import com.google.common.io.Files
import java.util.concurrent.atomic.AtomicInteger
class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
...@@ -61,7 +62,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { ...@@ -61,7 +62,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
System.clearProperty("spark.hostPort") System.clearProperty("spark.hostPort")
} }
test("socket input stream") { test("socket input stream") {
// Start the server // Start the server
val testServer = new TestServer() val testServer = new TestServer()
...@@ -275,10 +275,49 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { ...@@ -275,10 +275,49 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
kafka.serializer.StringDecoder, kafka.serializer.StringDecoder,
kafka.serializer.StringDecoder](kafkaParams, topics, StorageLevel.MEMORY_AND_DISK) kafka.serializer.StringDecoder](kafkaParams, topics, StorageLevel.MEMORY_AND_DISK)
} }
test("multi-thread receiver") {
  // Spin up a receiver whose worker threads all push into one block
  // generator, to exercise the generator's thread safety.
  val threadCount = 10
  val recordsPerThread = 1000
  val expectedTotal = threadCount * recordsPerThread
  val testReceiver = new MultiThreadTestReceiver(threadCount, recordsPerThread)
  MultiThreadTestReceiver.haveAllThreadsFinished = false

  // Wire the receiver into a streaming context that counts received records.
  val ssc = new StreamingContext(master, framework, batchDuration)
  val networkStream = ssc.networkStream[Int](testReceiver)
  val countStream = networkStream.count
  val outputBuffer = new ArrayBuffer[Seq[Long]] with SynchronizedBuffer[Seq[Long]]
  val outputStream = new TestOutputStream(countStream, outputBuffer)
  def output = outputBuffer.flatMap(x => x)
  ssc.registerOutputStream(outputStream)
  ssc.start()

  // Advance the manual clock until every worker thread reports done and all
  // records are counted, or until a 5 second wall-clock deadline passes.
  val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
  val deadline = System.currentTimeMillis() + 5000
  def allReceived = MultiThreadTestReceiver.haveAllThreadsFinished && output.sum >= expectedTotal
  while (!allReceived && System.currentTimeMillis() < deadline) {
    Thread.sleep(100)
    clock.addToTime(batchDuration.milliseconds)
  }
  Thread.sleep(1000)
  logInfo("Stopping context")
  ssc.stop()

  // Verify that every record generated by the receiver threads came through.
  logInfo("--------------------------------")
  logInfo("output.size = " + outputBuffer.size)
  logInfo("output")
  outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]"))
  logInfo("--------------------------------")
  assert(output.sum === expectedTotal)
}
} }
/** This is server to test the network input stream */ /** This is a server to test the network input stream */
class TestServer() extends Logging { class TestServer() extends Logging {
val queue = new ArrayBlockingQueue[String](100) val queue = new ArrayBlockingQueue[String](100)
...@@ -340,6 +379,7 @@ object TestServer { ...@@ -340,6 +379,7 @@ object TestServer {
} }
} }
/** This is an actor for testing actor input stream */
class TestActor(port: Int) extends Actor with Receiver { class TestActor(port: Int) extends Actor with Receiver {
def bytesToString(byteString: ByteString) = byteString.utf8String def bytesToString(byteString: ByteString) = byteString.utf8String
...@@ -351,3 +391,36 @@ class TestActor(port: Int) extends Actor with Receiver { ...@@ -351,3 +391,36 @@ class TestActor(port: Int) extends Actor with Receiver {
pushBlock(bytesToString(bytes)) pushBlock(bytesToString(bytes))
} }
} }
/**
 * Receiver that feeds a single [[BlockGenerator]] from `numThreads` concurrent
 * threads, each inserting `numRecordsPerThread` integers, to test that the
 * block generator is safe under multi-threaded insertion.
 *
 * When the last thread finishes it sets the shared
 * `MultiThreadTestReceiver.haveAllThreadsFinished` flag so the test driver
 * knows all records have been generated.
 */
class MultiThreadTestReceiver(numThreads: Int, numRecordsPerThread: Int)
  extends NetworkReceiver[Int] {
  lazy val executorPool = Executors.newFixedThreadPool(numThreads)
  lazy val blockGenerator = new BlockGenerator(StorageLevel.MEMORY_ONLY)
  // Counts threads that have completed all their inserts.
  lazy val finishCount = new AtomicInteger(0)

  protected def onStart() {
    blockGenerator.start()
    // foreach (not map): the submissions are pure side effects, so there is
    // no point building and discarding a collection of Future results.
    (1 to numThreads).foreach { threadId =>
      val runnable = new Runnable {
        def run() {
          // Each thread inserts a distinct, non-overlapping range of ints.
          (1 to numRecordsPerThread).foreach(i =>
            blockGenerator += (threadId * numRecordsPerThread + i))
          // The last thread to complete flips the shared flag.
          if (finishCount.incrementAndGet() == numThreads) {
            MultiThreadTestReceiver.haveAllThreadsFinished = true
          }
          logInfo("Finished thread " + threadId)
        }
      }
      executorPool.submit(runnable)
    }
  }

  protected def onStop() {
    // NOTE(review): shutdown() stops accepting new tasks but does not wait
    // for running inserts to finish; the test tolerates this via its flag+poll.
    executorPool.shutdown()
  }
}
/** Companion holding the completion flag shared between receiver threads and the test. */
object MultiThreadTestReceiver {
  // Written by executor-pool worker threads and polled by the test thread.
  // @volatile is required for the write to be guaranteed visible across
  // threads; without it the polling loop may never observe the update.
  @volatile var haveAllThreadsFinished = false
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment