diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index b34ab51f3b9969350334fa22c465d2482b6d8456..0cf078c378fd9e712389ad14f8e68333e70d5d5e 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -245,7 +245,8 @@ streaming_kafka_0_10 = Module(
     name="streaming-kafka-0-10",
     dependencies=[streaming],
     source_file_regexes=[
-        "external/kafka-0-10",
+        # The ending "/" is necessary otherwise it will include "sql-kafka" codes
+        "external/kafka-0-10/",
         "external/kafka-0-10-assembly",
     ],
     sbt_test_goals=[
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
index 92ee0ed93d940b4b011567c76ebadd5dbfb4b9a6..43b8d9d6d7eef687a9f4c19a69b20d760ccc29c6 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
@@ -24,7 +24,7 @@ import java.nio.charset.StandardCharsets
 import scala.collection.JavaConverters._
 import scala.util.control.NonFatal
 
-import org.apache.kafka.clients.consumer.{Consumer, KafkaConsumer, OffsetOutOfRangeException}
+import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, KafkaConsumer, OffsetOutOfRangeException}
 import org.apache.kafka.clients.consumer.internals.NoOpConsumerRebalanceListener
 import org.apache.kafka.common.TopicPartition
 
@@ -81,14 +81,16 @@ import org.apache.spark.util.UninterruptibleThread
  * To avoid this issue, you should make sure stopping the query before stopping the Kafka brokers
  * and not use wrong broker addresses.
  */
-private[kafka010] case class KafkaSource(
+private[kafka010] class KafkaSource(
     sqlContext: SQLContext,
     consumerStrategy: ConsumerStrategy,
+    driverKafkaParams: ju.Map[String, Object],
     executorKafkaParams: ju.Map[String, Object],
     sourceOptions: Map[String, String],
     metadataPath: String,
     startingOffsets: StartingOffsets,
-    failOnDataLoss: Boolean)
+    failOnDataLoss: Boolean,
+    driverGroupIdPrefix: String)
   extends Source with Logging {
 
   private val sc = sqlContext.sparkContext
@@ -107,11 +109,31 @@ private[kafka010] case class KafkaSource(
 
   private val maxOffsetsPerTrigger = sourceOptions.get("maxOffsetsPerTrigger").map(_.toLong)
 
+  private var groupId: String = null
+
+  private var nextId = 0
+
+  private def nextGroupId(): String = {
+    groupId = driverGroupIdPrefix + "-" + nextId
+    nextId += 1
+    groupId
+  }
+
   /**
    * A KafkaConsumer used in the driver to query the latest Kafka offsets. This only queries the
    * offsets and never commits them.
    */
-  private val consumer = consumerStrategy.createConsumer()
+  private var consumer: Consumer[Array[Byte], Array[Byte]] = createConsumer()
+
+  /**
+   * Create a consumer using the new generated group id. We always use a new consumer to avoid
+   * just using a broken consumer to retry on Kafka errors, which likely will fail again.
+   */
+  private def createConsumer(): Consumer[Array[Byte], Array[Byte]] = synchronized {
+    val newKafkaParams = new ju.HashMap[String, Object](driverKafkaParams)
+    newKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, nextGroupId())
+    consumerStrategy.createConsumer(newKafkaParams)
+  }
 
   /**
    * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only
@@ -171,6 +193,11 @@ private[kafka010] case class KafkaSource(
     Some(KafkaSourceOffset(offsets))
   }
 
+  private def resetConsumer(): Unit = synchronized {
+    consumer.close()
+    consumer = createConsumer()
+  }
+
   /** Proportionally distribute limit number of offsets among topicpartitions */
   private def rateLimit(
       limit: Long,
@@ -441,13 +468,12 @@ private[kafka010] case class KafkaSource(
               try {
                 result = Some(body)
               } catch {
-                case x: OffsetOutOfRangeException =>
-                  reportDataLoss(x.getMessage)
                 case NonFatal(e) =>
                   lastException = e
                   logWarning(s"Error in attempt $attempt getting Kafka offsets: ", e)
                   attempt += 1
                   Thread.sleep(offsetFetchAttemptIntervalMs)
+                  resetConsumer()
               }
             }
           case _ =>
@@ -511,12 +537,12 @@ private[kafka010] object KafkaSource {
     ))
 
   sealed trait ConsumerStrategy {
-    def createConsumer(): Consumer[Array[Byte], Array[Byte]]
+    def createConsumer(kafkaParams: ju.Map[String, Object]): Consumer[Array[Byte], Array[Byte]]
   }
 
-  case class AssignStrategy(partitions: Array[TopicPartition], kafkaParams: ju.Map[String, Object])
-    extends ConsumerStrategy {
-    override def createConsumer(): Consumer[Array[Byte], Array[Byte]] = {
+  case class AssignStrategy(partitions: Array[TopicPartition]) extends ConsumerStrategy {
+    override def createConsumer(
+        kafkaParams: ju.Map[String, Object]): Consumer[Array[Byte], Array[Byte]] = {
       val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams)
       consumer.assign(ju.Arrays.asList(partitions: _*))
       consumer
@@ -525,9 +551,9 @@ private[kafka010] object KafkaSource {
     override def toString: String = s"Assign[${partitions.mkString(", ")}]"
   }
 
-  case class SubscribeStrategy(topics: Seq[String], kafkaParams: ju.Map[String, Object])
-    extends ConsumerStrategy {
-    override def createConsumer(): Consumer[Array[Byte], Array[Byte]] = {
+  case class SubscribeStrategy(topics: Seq[String]) extends ConsumerStrategy {
+    override def createConsumer(
+        kafkaParams: ju.Map[String, Object]): Consumer[Array[Byte], Array[Byte]] = {
       val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams)
       consumer.subscribe(topics.asJava)
       consumer
@@ -536,10 +562,10 @@ private[kafka010] object KafkaSource {
     override def toString: String = s"Subscribe[${topics.mkString(", ")}]"
   }
 
-  case class SubscribePatternStrategy(
-    topicPattern: String, kafkaParams: ju.Map[String, Object])
+  case class SubscribePatternStrategy(topicPattern: String)
     extends ConsumerStrategy {
-    override def createConsumer(): Consumer[Array[Byte], Array[Byte]] = {
+    override def createConsumer(
+        kafkaParams: ju.Map[String, Object]): Consumer[Array[Byte], Array[Byte]] = {
       val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams)
       consumer.subscribe(
         ju.regex.Pattern.compile(topicPattern),
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
index 585ced875caa72feccb0f72bae0d8e2efd2adb67..aa01238f91247e173521c1c9a47b4cddc953279a 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
@@ -85,14 +85,11 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider
       case None => LatestOffsets
     }
 
-    val kafkaParamsForStrategy =
+    val kafkaParamsForDriver =
       ConfigUpdater("source", specifiedKafkaParams)
         .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName)
         .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName)
 
-        // So that consumers in Kafka source do not mess with any existing group id
-        .set(ConsumerConfig.GROUP_ID_CONFIG, s"$uniqueGroupId-driver")
-
         // Set to "earliest" to avoid exceptions. However, KafkaSource will fetch the initial
         // offsets by itself instead of counting on KafkaConsumer.
         .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest")
@@ -129,17 +126,11 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider
 
     val strategy = caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match {
       case ("assign", value) =>
-        AssignStrategy(
-          JsonUtils.partitions(value),
-          kafkaParamsForStrategy)
+        AssignStrategy(JsonUtils.partitions(value))
       case ("subscribe", value) =>
-        SubscribeStrategy(
-          value.split(",").map(_.trim()).filter(_.nonEmpty),
-          kafkaParamsForStrategy)
+        SubscribeStrategy(value.split(",").map(_.trim()).filter(_.nonEmpty))
       case ("subscribepattern", value) =>
-        SubscribePatternStrategy(
-          value.trim(),
-          kafkaParamsForStrategy)
+        SubscribePatternStrategy(value.trim())
       case _ =>
         // Should never reach here as we are already matching on
        // matched strategy names
@@ -152,11 +143,13 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider
     new KafkaSource(
       sqlContext,
       strategy,
+      kafkaParamsForDriver,
       kafkaParamsForExecutors,
       parameters,
       metadataPath,
       startingOffsets,
-      failOnDataLoss)
+      failOnDataLoss,
+      driverGroupIdPrefix = s"$uniqueGroupId-driver")
   }
 
   private def validateOptions(parameters: Map[String, String]): Unit = {
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
index 5d2779aba26d01343129efd8e31e96b57773cc16..544fbc5ec36a26339ab4a9dea432dc0e20df04fa 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
@@ -845,7 +845,7 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared
     }
   }
 
-  ignore("stress test for failOnDataLoss=false") {
+  test("stress test for failOnDataLoss=false") {
     val reader = spark
       .readStream
       .format("kafka")