Commit 95efc895 authored by Shixiong Zhu, committed by Tathagata Das

[SPARK-18588][SS][KAFKA] Create a new KafkaConsumer when error happens to fix the flaky test

## What changes were proposed in this pull request?

When KafkaSource fails on a Kafka error, it should create a new consumer to retry rather than reuse the existing broken one, because the broken consumer is likely to fail again.

This PR also assigns a new group id to each newly created consumer to guard against a possible race condition: the broken consumer might be unable to reach the Kafka cluster in `close` while the new consumer can, so for a short window the cluster could see two consumers with the same group id. I'm not sure whether this can actually happen; the new group id is just a safety measure. (Note: CachedKafkaConsumer doesn't need this fix since `assign` never uses the group id.)
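
Roughly, the retry pattern looks like the sketch below. This is a simplified, hypothetical `OffsetFetcher` helper for illustration only (the class name, `withRetries`, and `maxAttempts` are not part of this PR; the real change is the `createConsumer()`/`resetConsumer()` pair in `KafkaSource` shown in the diff): on every failure the current consumer is closed and replaced by a brand-new `KafkaConsumer` carrying a freshly generated group id.

```scala
import java.{util => ju}

import scala.util.control.NonFatal

import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, KafkaConsumer}

// Hypothetical helper illustrating the idea: never reuse a broken consumer,
// and give every replacement consumer a fresh group id.
class OffsetFetcher(baseParams: ju.Map[String, Object], groupIdPrefix: String) {
  private var nextId = 0
  private var consumer: Consumer[Array[Byte], Array[Byte]] = createConsumer()

  // Copy the base Kafka params and override the group id with a unique value, so a
  // lingering old consumer and its replacement never share a group id.
  private def createConsumer(): Consumer[Array[Byte], Array[Byte]] = {
    val params = new ju.HashMap[String, Object](baseParams)
    params.put(ConsumerConfig.GROUP_ID_CONFIG, s"$groupIdPrefix-$nextId")
    nextId += 1
    new KafkaConsumer[Array[Byte], Array[Byte]](params)
  }

  // Run `body` with the current consumer; on failure, discard the consumer,
  // create a new one, and retry up to `maxAttempts` times.
  def withRetries[T](maxAttempts: Int)(body: Consumer[Array[Byte], Array[Byte]] => T): T = {
    var attempt = 1
    while (true) {
      try {
        return body(consumer)
      } catch {
        case NonFatal(_) if attempt < maxAttempts =>
          consumer.close()            // the broken consumer would likely fail again
          consumer = createConsumer() // retry with a brand-new consumer and group id
          attempt += 1
      }
    }
    throw new IllegalStateException("unreachable")
  }
}
```

A caller would then wrap each offset query, e.g. `fetcher.withRetries(3)(_.partitionsFor("topic"))` (hypothetical usage), so any retry automatically runs on a fresh consumer.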

## How was this patch tested?

In https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/70370/console, this flaky test was run 120 times and all runs passed.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #16282 from zsxwing/kafka-fix.
parent 354e9361
@@ -245,7 +245,8 @@ streaming_kafka_0_10 = Module(
     name="streaming-kafka-0-10",
     dependencies=[streaming],
     source_file_regexes=[
-        "external/kafka-0-10",
+        # The ending "/" is necessary otherwise it will include "sql-kafka" codes
+        "external/kafka-0-10/",
         "external/kafka-0-10-assembly",
     ],
     sbt_test_goals=[
@@ -24,7 +24,7 @@ import java.nio.charset.StandardCharsets
 import scala.collection.JavaConverters._
 import scala.util.control.NonFatal
 
-import org.apache.kafka.clients.consumer.{Consumer, KafkaConsumer, OffsetOutOfRangeException}
+import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, KafkaConsumer, OffsetOutOfRangeException}
 import org.apache.kafka.clients.consumer.internals.NoOpConsumerRebalanceListener
 import org.apache.kafka.common.TopicPartition
@@ -81,14 +81,16 @@ import org.apache.spark.util.UninterruptibleThread
  * To avoid this issue, you should make sure stopping the query before stopping the Kafka brokers
  * and not use wrong broker addresses.
  */
-private[kafka010] case class KafkaSource(
+private[kafka010] class KafkaSource(
     sqlContext: SQLContext,
     consumerStrategy: ConsumerStrategy,
+    driverKafkaParams: ju.Map[String, Object],
     executorKafkaParams: ju.Map[String, Object],
     sourceOptions: Map[String, String],
     metadataPath: String,
     startingOffsets: StartingOffsets,
-    failOnDataLoss: Boolean)
+    failOnDataLoss: Boolean,
+    driverGroupIdPrefix: String)
   extends Source with Logging {
 
   private val sc = sqlContext.sparkContext
@@ -107,11 +109,31 @@ private[kafka010] case class KafkaSource(
   private val maxOffsetsPerTrigger =
     sourceOptions.get("maxOffsetsPerTrigger").map(_.toLong)
 
+  private var groupId: String = null
+
+  private var nextId = 0
+
+  private def nextGroupId(): String = {
+    groupId = driverGroupIdPrefix + "-" + nextId
+    nextId += 1
+    groupId
+  }
+
   /**
    * A KafkaConsumer used in the driver to query the latest Kafka offsets. This only queries the
    * offsets and never commits them.
    */
-  private val consumer = consumerStrategy.createConsumer()
+  private var consumer: Consumer[Array[Byte], Array[Byte]] = createConsumer()
+
+  /**
+   * Create a consumer using the new generated group id. We always use a new consumer to avoid
+   * just using a broken consumer to retry on Kafka errors, which likely will fail again.
+   */
+  private def createConsumer(): Consumer[Array[Byte], Array[Byte]] = synchronized {
+    val newKafkaParams = new ju.HashMap[String, Object](driverKafkaParams)
+    newKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, nextGroupId())
+    consumerStrategy.createConsumer(newKafkaParams)
+  }
 
   /**
    * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only
@@ -171,6 +193,11 @@ private[kafka010] case class KafkaSource(
     Some(KafkaSourceOffset(offsets))
   }
 
+  private def resetConsumer(): Unit = synchronized {
+    consumer.close()
+    consumer = createConsumer()
+  }
+
   /** Proportionally distribute limit number of offsets among topicpartitions */
   private def rateLimit(
       limit: Long,
@@ -441,13 +468,12 @@ private[kafka010] case class KafkaSource(
             try {
               result = Some(body)
             } catch {
-              case x: OffsetOutOfRangeException =>
-                reportDataLoss(x.getMessage)
               case NonFatal(e) =>
                 lastException = e
                 logWarning(s"Error in attempt $attempt getting Kafka offsets: ", e)
                 attempt += 1
                 Thread.sleep(offsetFetchAttemptIntervalMs)
+                resetConsumer()
            }
          }
        case _ =>
@@ -511,12 +537,12 @@ private[kafka010] object KafkaSource {
   ))
 
   sealed trait ConsumerStrategy {
-    def createConsumer(): Consumer[Array[Byte], Array[Byte]]
+    def createConsumer(kafkaParams: ju.Map[String, Object]): Consumer[Array[Byte], Array[Byte]]
   }
 
-  case class AssignStrategy(partitions: Array[TopicPartition], kafkaParams: ju.Map[String, Object])
-    extends ConsumerStrategy {
-    override def createConsumer(): Consumer[Array[Byte], Array[Byte]] = {
+  case class AssignStrategy(partitions: Array[TopicPartition]) extends ConsumerStrategy {
+    override def createConsumer(
+        kafkaParams: ju.Map[String, Object]): Consumer[Array[Byte], Array[Byte]] = {
       val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams)
       consumer.assign(ju.Arrays.asList(partitions: _*))
       consumer
@@ -525,9 +551,9 @@ private[kafka010] object KafkaSource {
     override def toString: String = s"Assign[${partitions.mkString(", ")}]"
   }
 
-  case class SubscribeStrategy(topics: Seq[String], kafkaParams: ju.Map[String, Object])
-    extends ConsumerStrategy {
-    override def createConsumer(): Consumer[Array[Byte], Array[Byte]] = {
+  case class SubscribeStrategy(topics: Seq[String]) extends ConsumerStrategy {
+    override def createConsumer(
+        kafkaParams: ju.Map[String, Object]): Consumer[Array[Byte], Array[Byte]] = {
       val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams)
       consumer.subscribe(topics.asJava)
       consumer
@@ -536,10 +562,10 @@ private[kafka010] object KafkaSource {
     override def toString: String = s"Subscribe[${topics.mkString(", ")}]"
   }
 
-  case class SubscribePatternStrategy(
-    topicPattern: String, kafkaParams: ju.Map[String, Object])
+  case class SubscribePatternStrategy(topicPattern: String)
     extends ConsumerStrategy {
-    override def createConsumer(): Consumer[Array[Byte], Array[Byte]] = {
+    override def createConsumer(
+        kafkaParams: ju.Map[String, Object]): Consumer[Array[Byte], Array[Byte]] = {
       val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams)
       consumer.subscribe(
         ju.regex.Pattern.compile(topicPattern),
@@ -85,14 +85,11 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider
       case None => LatestOffsets
     }
 
-    val kafkaParamsForStrategy =
+    val kafkaParamsForDriver =
      ConfigUpdater("source", specifiedKafkaParams)
        .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName)
        .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName)
 
-        // So that consumers in Kafka source do not mess with any existing group id
-        .set(ConsumerConfig.GROUP_ID_CONFIG, s"$uniqueGroupId-driver")
-
        // Set to "earliest" to avoid exceptions. However, KafkaSource will fetch the initial
        // offsets by itself instead of counting on KafkaConsumer.
        .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest")
@@ -129,17 +126,11 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider
     val strategy = caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match {
       case ("assign", value) =>
-        AssignStrategy(
-          JsonUtils.partitions(value),
-          kafkaParamsForStrategy)
+        AssignStrategy(JsonUtils.partitions(value))
       case ("subscribe", value) =>
-        SubscribeStrategy(
-          value.split(",").map(_.trim()).filter(_.nonEmpty),
-          kafkaParamsForStrategy)
+        SubscribeStrategy(value.split(",").map(_.trim()).filter(_.nonEmpty))
       case ("subscribepattern", value) =>
-        SubscribePatternStrategy(
-          value.trim(),
-          kafkaParamsForStrategy)
+        SubscribePatternStrategy(value.trim())
       case _ =>
         // Should never reach here as we are already matching on
         // matched strategy names
@@ -152,11 +143,13 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider
     new KafkaSource(
       sqlContext,
       strategy,
+      kafkaParamsForDriver,
       kafkaParamsForExecutors,
       parameters,
       metadataPath,
       startingOffsets,
-      failOnDataLoss)
+      failOnDataLoss,
+      driverGroupIdPrefix = s"$uniqueGroupId-driver")
   }
 
   private def validateOptions(parameters: Map[String, String]): Unit = {
@@ -845,7 +845,7 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared
     }
   }
 
-  ignore("stress test for failOnDataLoss=false") {
+  test("stress test for failOnDataLoss=false") {
     val reader = spark
       .readStream
       .format("kafka")