Skip to content
Snippets Groups Projects
Commit e908322c authored by Iulian Dragos's avatar Iulian Dragos Committed by Tathagata Das
Browse files

[SPARK-4631][streaming][FIX] Wait for a receiver to start before publishing test data.

This fixes two sources of non-deterministic failures in this test:

- wait for a receiver to be up before pushing data through MQTT
- gracefully handle the case where the MQTT client is overloaded. There’s
a hard-coded limit of 10 in-flight messages, and this test may hit it.
Instead of crashing, we retry sending the message.

Both of these are needed to make the test pass reliably on my machine.

Author: Iulian Dragos <jaguarul@gmail.com>

Closes #4270 from dragos/issue/fix-flaky-test-SPARK-4631 and squashes the following commits:

f66c482 [Iulian Dragos] [SPARK-4631][streaming] Wait for a receiver to start before publishing test data.
d408a8e [Iulian Dragos] Install callback before connecting to MQTT broker.
parent 683e9382
No related branches found
No related tags found
No related merge requests found
...@@ -55,14 +55,14 @@ class MQTTInputDStream( ...@@ -55,14 +55,14 @@ class MQTTInputDStream(
brokerUrl: String, brokerUrl: String,
topic: String, topic: String,
storageLevel: StorageLevel storageLevel: StorageLevel
) extends ReceiverInputDStream[String](ssc_) with Logging { ) extends ReceiverInputDStream[String](ssc_) {
def getReceiver(): Receiver[String] = { def getReceiver(): Receiver[String] = {
new MQTTReceiver(brokerUrl, topic, storageLevel) new MQTTReceiver(brokerUrl, topic, storageLevel)
} }
} }
private[streaming] private[streaming]
class MQTTReceiver( class MQTTReceiver(
brokerUrl: String, brokerUrl: String,
topic: String, topic: String,
...@@ -72,21 +72,15 @@ class MQTTReceiver( ...@@ -72,21 +72,15 @@ class MQTTReceiver(
def onStop() { def onStop() {
} }
def onStart() { def onStart() {
// Set up persistence for messages // Set up persistence for messages
val persistence = new MemoryPersistence() val persistence = new MemoryPersistence()
// Initializing Mqtt Client specifying brokerUrl, clientID and MqttClientPersistance // Initializing Mqtt Client specifying brokerUrl, clientID and MqttClientPersistance
val client = new MqttClient(brokerUrl, MqttClient.generateClientId(), persistence) val client = new MqttClient(brokerUrl, MqttClient.generateClientId(), persistence)
// Connect to MqttBroker
client.connect()
// Subscribe to Mqtt topic
client.subscribe(topic)
// Callback automatically triggers as and when new message arrives on specified topic // Callback automatically triggers as and when new message arrives on specified topic
val callback: MqttCallback = new MqttCallback() { val callback: MqttCallback = new MqttCallback() {
...@@ -103,7 +97,15 @@ class MQTTReceiver( ...@@ -103,7 +97,15 @@ class MQTTReceiver(
} }
} }
// Set up callback for MqttClient // Set up callback for MqttClient. This needs to happen before
// connecting or subscribing, otherwise messages may be lost
client.setCallback(callback) client.setCallback(callback)
// Connect to MqttBroker
client.connect()
// Subscribe to Mqtt topic
client.subscribe(topic)
} }
} }
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
package org.apache.spark.streaming.mqtt package org.apache.spark.streaming.mqtt
import java.net.{URI, ServerSocket} import java.net.{URI, ServerSocket}
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit
import scala.concurrent.duration._ import scala.concurrent.duration._
import scala.language.postfixOps import scala.language.postfixOps
...@@ -32,6 +34,8 @@ import org.scalatest.concurrent.Eventually ...@@ -32,6 +34,8 @@ import org.scalatest.concurrent.Eventually
import org.apache.spark.streaming.{Milliseconds, StreamingContext} import org.apache.spark.streaming.{Milliseconds, StreamingContext}
import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.dstream.ReceiverInputDStream
import org.apache.spark.streaming.scheduler.StreamingListener
import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted
import org.apache.spark.SparkConf import org.apache.spark.SparkConf
import org.apache.spark.util.Utils import org.apache.spark.util.Utils
...@@ -67,7 +71,7 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { ...@@ -67,7 +71,7 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter {
val sendMessage = "MQTT demo for spark streaming" val sendMessage = "MQTT demo for spark streaming"
val receiveStream: ReceiverInputDStream[String] = val receiveStream: ReceiverInputDStream[String] =
MQTTUtils.createStream(ssc, "tcp:" + brokerUri, topic, StorageLevel.MEMORY_ONLY) MQTTUtils.createStream(ssc, "tcp:" + brokerUri, topic, StorageLevel.MEMORY_ONLY)
var receiveMessage: List[String] = List() @volatile var receiveMessage: List[String] = List()
receiveStream.foreachRDD { rdd => receiveStream.foreachRDD { rdd =>
if (rdd.collect.length > 0) { if (rdd.collect.length > 0) {
receiveMessage = receiveMessage ::: List(rdd.first) receiveMessage = receiveMessage ::: List(rdd.first)
...@@ -75,6 +79,11 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { ...@@ -75,6 +79,11 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter {
} }
} }
ssc.start() ssc.start()
// wait for the receiver to start before publishing data, or we risk failing
// the test nondeterministically. See SPARK-4631
waitForReceiverToStart()
publishData(sendMessage) publishData(sendMessage)
eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { eventually(timeout(10000 milliseconds), interval(100 milliseconds)) {
assert(sendMessage.equals(receiveMessage(0))) assert(sendMessage.equals(receiveMessage(0)))
...@@ -121,8 +130,14 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { ...@@ -121,8 +130,14 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter {
val message: MqttMessage = new MqttMessage(data.getBytes("utf-8")) val message: MqttMessage = new MqttMessage(data.getBytes("utf-8"))
message.setQos(1) message.setQos(1)
message.setRetained(true) message.setRetained(true)
for (i <- 0 to 100) {
msgTopic.publish(message) for (i <- 0 to 10) {
try {
msgTopic.publish(message)
} catch {
case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT =>
Thread.sleep(50) // wait for Spark streaming to consume something from the message queue
}
} }
} }
} finally { } finally {
...@@ -131,4 +146,18 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { ...@@ -131,4 +146,18 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter {
client = null client = null
} }
} }
/**
* Block until at least one receiver has started or timeout occurs.
*/
private def waitForReceiverToStart() = {
val latch = new CountDownLatch(1)
ssc.addStreamingListener(new StreamingListener {
override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) {
latch.countDown()
}
})
assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.")
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment