From bd338f60d7f30f0cb735dffb39b3a6ec60766301 Mon Sep 17 00:00:00 2001
From: Shixiong Zhu <shixiong@databricks.com>
Date: Tue, 22 Nov 2016 14:15:57 -0800
Subject: [PATCH] [SPARK-18373][SPARK-18529][SS][KAFKA] Make
 failOnDataLoss=false work with Spark jobs

## What changes were proposed in this pull request?

This PR reworks `CachedKafkaConsumer.get` to handle the corner cases of `failOnDataLoss=false`: when requested offsets are no longer available in Kafka, the source recovers from the earliest available offset in the requested range (or skips the range entirely) instead of failing.

It also resolves [SPARK-18529](https://issues.apache.org/jira/browse/SPARK-18529) as part of the refactoring: a poll that times out now throws a `TimeoutException`.
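
For reference, a streaming query opts into this behavior through the `failOnDataLoss` source option. A minimal sketch (the broker address and topic name are placeholders; the option names match those exercised in the new tests):

```scala
val df = spark
  .readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:9092") // placeholder broker address
  .option("subscribe", "topic1")                   // placeholder topic
  .option("startingOffsets", "earliest")
  .option("failOnDataLoss", "false")               // skip lost offsets instead of failing the query
  .load()
```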

## How was this patch tested?

Because I cannot find any way to manually control when the Kafka server cleans up logs, it's impossible to write a unit test for each corner case. Therefore, I just created `test("stress test for failOnDataLoss=false")`, which should cover most of the corner cases.

I also modified some existing tests to test for both `failOnDataLoss=false` and `failOnDataLoss=true` to make sure it doesn't break existing logic.
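
The parameterization follows this pattern (a condensed sketch of the loop added to `KafkaSourceSuite`; the helper names are the ones from the diff):

```scala
for (failOnDataLoss <- Seq(true, false)) {
  test(s"assign from latest offsets (failOnDataLoss: $failOnDataLoss)") {
    val topic = newTopic()
    testFromLatestOffsets(
      topic,
      addPartitions = false,
      failOnDataLoss = failOnDataLoss,
      "assign" -> assignString(topic, 0 to 4))
  }
}
```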

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #15820 from zsxwing/failOnDataLoss.

(cherry picked from commit 2fd101b2f0028e005fbb0bdd29e59af37aa637da)
Signed-off-by: Tathagata Das <tathagata.das1565@gmail.com>
---
 .../sql/kafka010/CachedKafkaConsumer.scala    | 236 ++++++++++++--
 .../spark/sql/kafka010/KafkaSource.scala      |  23 +-
 .../spark/sql/kafka010/KafkaSourceRDD.scala   |  42 ++-
 .../spark/sql/kafka010/KafkaSourceSuite.scala | 297 +++++++++++++++---
 .../spark/sql/kafka010/KafkaTestUtils.scala   |  20 +-
 5 files changed, 523 insertions(+), 95 deletions(-)

diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala
index 3b5a96534f..3f438e9918 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala
@@ -18,12 +18,16 @@
 package org.apache.spark.sql.kafka010
 
 import java.{util => ju}
+import java.util.concurrent.TimeoutException
 
-import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, KafkaConsumer}
+import scala.collection.JavaConverters._
+
+import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, KafkaConsumer, OffsetOutOfRangeException}
 import org.apache.kafka.common.TopicPartition
 
 import org.apache.spark.{SparkEnv, SparkException, TaskContext}
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.kafka010.KafkaSource._
 
 
 /**
@@ -34,10 +38,18 @@ import org.apache.spark.internal.Logging
 private[kafka010] case class CachedKafkaConsumer private(
     topicPartition: TopicPartition,
     kafkaParams: ju.Map[String, Object]) extends Logging {
+  import CachedKafkaConsumer._
 
   private val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
 
-  private val consumer = {
+  private var consumer = createConsumer
+
+  /** Iterator over the already fetched data */
+  private var fetchedData = ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]]
+  private var nextOffsetInFetchedData = UNKNOWN_OFFSET
+
+  /** Create a KafkaConsumer to fetch records for `topicPartition` */
+  private def createConsumer: KafkaConsumer[Array[Byte], Array[Byte]] = {
     val c = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams)
     val tps = new ju.ArrayList[TopicPartition]()
     tps.add(topicPartition)
@@ -45,42 +57,193 @@ private[kafka010] case class CachedKafkaConsumer private(
     c
   }
 
-  /** Iterator to the already fetch data */
-  private var fetchedData = ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]]
-  private var nextOffsetInFetchedData = -2L
-
   /**
-   * Get the record for the given offset, waiting up to timeout ms if IO is necessary.
-   * Sequential forward access will use buffers, but random access will be horribly inefficient.
+   * Get the record for the given offset if available. Otherwise it will either throw an error
+   * (if failOnDataLoss = true), or return the record at the next available offset within
+   * [offset, untilOffset), or null.
+   *
+   * @param offset the offset to fetch.
+   * @param untilOffset the max offset to fetch. Exclusive.
+   * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka.
+   * @param failOnDataLoss When `failOnDataLoss` is `true`, this method will either return the
+   *                       record at `offset` if available, or throw an exception. When
+   *                       `failOnDataLoss` is `false`, this method will either return the record
+   *                       at `offset` if available, or return the record at the earliest available
+   *                       offset before `untilOffset`, or null. It will not throw any exception.
    */
-  def get(offset: Long, pollTimeoutMs: Long): ConsumerRecord[Array[Byte], Array[Byte]] = {
+  def get(
+      offset: Long,
+      untilOffset: Long,
+      pollTimeoutMs: Long,
+      failOnDataLoss: Boolean): ConsumerRecord[Array[Byte], Array[Byte]] = {
+    require(offset < untilOffset,
+      s"offset must always be less than untilOffset [offset: $offset, untilOffset: $untilOffset]")
     logDebug(s"Get $groupId $topicPartition nextOffset $nextOffsetInFetchedData requested $offset")
-    if (offset != nextOffsetInFetchedData) {
-      logInfo(s"Initial fetch for $topicPartition $offset")
-      seek(offset)
-      poll(pollTimeoutMs)
+    // The following loop is basically for `failOnDataLoss = false`. When `failOnDataLoss` is
+    // `false`, first, we will try to fetch the record at `offset`. If no such record exists, then
+    // we will move to the next available offset within `[offset, untilOffset)` and retry.
+    // If `failOnDataLoss` is `true`, the loop body will be executed only once.
+    var toFetchOffset = offset
+    while (toFetchOffset != UNKNOWN_OFFSET) {
+      try {
+        return fetchData(toFetchOffset, pollTimeoutMs)
+      } catch {
+        case e: OffsetOutOfRangeException =>
+          // When an error is thrown, it's better to use a new consumer to drop all cached
+          // states in the old consumer. We don't need to worry about performance because this
+          // is not a common path.
+          resetConsumer()
+          reportDataLoss(failOnDataLoss, s"Cannot fetch offset $toFetchOffset", e)
+          toFetchOffset = getEarliestAvailableOffsetBetween(toFetchOffset, untilOffset)
+      }
     }
+    resetFetchedData()
+    null
+  }
 
-    if (!fetchedData.hasNext()) { poll(pollTimeoutMs) }
-    assert(fetchedData.hasNext(),
-      s"Failed to get records for $groupId $topicPartition $offset " +
-        s"after polling for $pollTimeoutMs")
-    var record = fetchedData.next()
+  /**
+   * Return the earliest available offset in [offset, untilOffset). If all offsets in
+   * [offset, untilOffset) are invalid (e.g., the topic is deleted and recreated), it will return
+   * `UNKNOWN_OFFSET`.
+   */
+  private def getEarliestAvailableOffsetBetween(offset: Long, untilOffset: Long): Long = {
+    val (earliestOffset, latestOffset) = getAvailableOffsetRange()
+    logWarning(s"Some data may be lost. Recovering from the earliest offset: $earliestOffset")
+    if (offset >= latestOffset || earliestOffset >= untilOffset) {
+      // [offset, untilOffset) and [earliestOffset, latestOffset) have no overlap,
+      // either
+      // --------------------------------------------------------
+      //         ^                 ^         ^         ^
+      //         |                 |         |         |
+      //   earliestOffset   latestOffset   offset   untilOffset
+      //
+      // or
+      // --------------------------------------------------------
+      //      ^          ^              ^                ^
+      //      |          |              |                |
+      //   offset   untilOffset   earliestOffset   latestOffset
+      val warningMessage =
+        s"""
+          |The current available offset range is [$earliestOffset, $latestOffset).
+          | Offset ${offset} is out of range, and records in [$offset, $untilOffset) will be
+          | skipped ${additionalMessage(failOnDataLoss = false)}
+        """.stripMargin
+      logWarning(warningMessage)
+      UNKNOWN_OFFSET
+    } else if (offset >= earliestOffset) {
+      // -----------------------------------------------------------------------------
+      //         ^            ^                  ^                                 ^
+      //         |            |                  |                                 |
+      //   earliestOffset   offset   min(untilOffset,latestOffset)   max(untilOffset, latestOffset)
+      //
+      // This will happen when a topic is deleted and recreated, and new data are pushed very fast;
+      // then we will see `offset` disappear first and then appear again. Although the parameters
+      // are the same, the state of the Kafka cluster has changed, so the outer loop won't be endless.
+      logWarning(s"Found a disappeared offset $offset. " +
+        s"Some data may be lost ${additionalMessage(failOnDataLoss = false)}")
+      offset
+    } else {
+      // ------------------------------------------------------------------------------
+      //      ^           ^                       ^                                 ^
+      //      |           |                       |                                 |
+      //   offset   earliestOffset   min(untilOffset,latestOffset)   max(untilOffset, latestOffset)
+      val warningMessage =
+        s"""
+           |The current available offset range is [$earliestOffset, $latestOffset).
+           | Offset ${offset} is out of range, and records in [$offset, $earliestOffset) will be
+           | skipped ${additionalMessage(failOnDataLoss = false)}
+        """.stripMargin
+      logWarning(warningMessage)
+      earliestOffset
+    }
+  }
 
-    if (record.offset != offset) {
-      logInfo(s"Buffer miss for $groupId $topicPartition $offset")
+  /**
+   * Get the record at `offset`.
+   *
+   * @throws OffsetOutOfRangeException if `offset` is out of range
+   * @throws TimeoutException if the record cannot be fetched within `pollTimeoutMs` milliseconds.
+   */
+  private def fetchData(
+      offset: Long,
+      pollTimeoutMs: Long): ConsumerRecord[Array[Byte], Array[Byte]] = {
+    if (offset != nextOffsetInFetchedData || !fetchedData.hasNext()) {
+      // This is the first fetch, or the last pre-fetched data has been drained.
+      // Seek to the offset because we may have called seekToBeginning or seekToEnd before this.
       seek(offset)
       poll(pollTimeoutMs)
-      assert(fetchedData.hasNext(),
-        s"Failed to get records for $groupId $topicPartition $offset " +
-          s"after polling for $pollTimeoutMs")
-      record = fetchedData.next()
+    }
+
+    if (!fetchedData.hasNext()) {
+      // We cannot fetch anything after `poll`. Two possible cases:
+      // - `offset` is out of range so that Kafka returns nothing. Just throw
+      // `OffsetOutOfRangeException` to let the caller handle it.
+      // - We cannot fetch any data before the timeout, so a TimeoutException will be thrown.
+      val (earliestOffset, latestOffset) = getAvailableOffsetRange()
+      if (offset < earliestOffset || offset >= latestOffset) {
+        throw new OffsetOutOfRangeException(
+          Map(topicPartition -> java.lang.Long.valueOf(offset)).asJava)
+      } else {
+        throw new TimeoutException(
+          s"Cannot fetch record for offset $offset in $pollTimeoutMs milliseconds")
+      }
+    } else {
+      val record = fetchedData.next()
+      nextOffsetInFetchedData = record.offset + 1
+      // `seek` is always called before `poll`, so `record.offset` must be the same as `offset`.
       assert(record.offset == offset,
-        s"Got wrong record for $groupId $topicPartition even after seeking to offset $offset")
+        s"The fetched data has a different offset: expected $offset but was ${record.offset}")
+      record
     }
+  }
+
+  /** Create a new consumer and reset cached states */
+  private def resetConsumer(): Unit = {
+    consumer.close()
+    consumer = createConsumer
+    resetFetchedData()
+  }
 
-    nextOffsetInFetchedData = offset + 1
-    record
+  /** Reset the internal pre-fetched data. */
+  private def resetFetchedData(): Unit = {
+    nextOffsetInFetchedData = UNKNOWN_OFFSET
+    fetchedData = ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]]
+  }
+
+  /**
+   * Return an additional message including useful information and instructions.
+   */
+  private def additionalMessage(failOnDataLoss: Boolean): String = {
+    if (failOnDataLoss) {
+      s"(GroupId: $groupId, TopicPartition: $topicPartition). " +
+        s"$INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE"
+    } else {
+      s"(GroupId: $groupId, TopicPartition: $topicPartition). " +
+        s"$INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE"
+    }
+  }
+
+  /**
+   * Throw an exception or log a warning as per `failOnDataLoss`.
+   */
+  private def reportDataLoss(
+      failOnDataLoss: Boolean,
+      message: String,
+      cause: Throwable = null): Unit = {
+    val finalMessage = s"$message ${additionalMessage(failOnDataLoss)}"
+    if (failOnDataLoss) {
+      if (cause != null) {
+        throw new IllegalStateException(finalMessage, cause)
+      } else {
+        throw new IllegalStateException(finalMessage)
+      }
+    } else {
+      if (cause != null) {
+        logWarning(finalMessage, cause)
+      } else {
+        logWarning(finalMessage)
+      }
+    }
   }
 
   private def close(): Unit = consumer.close()
@@ -96,10 +259,24 @@ private[kafka010] case class CachedKafkaConsumer private(
     logDebug(s"Polled $groupId ${p.partitions()}  ${r.size}")
     fetchedData = r.iterator
   }
+
+  /**
+   * Return the available offset range of the current partition. It's a pair of the earliest offset
+   * and the latest offset.
+   */
+  private def getAvailableOffsetRange(): (Long, Long) = {
+    consumer.seekToBeginning(Set(topicPartition).asJava)
+    val earliestOffset = consumer.position(topicPartition)
+    consumer.seekToEnd(Set(topicPartition).asJava)
+    val latestOffset = consumer.position(topicPartition)
+    (earliestOffset, latestOffset)
+  }
 }
 
 private[kafka010] object CachedKafkaConsumer extends Logging {
 
+  private val UNKNOWN_OFFSET = -2L
+
   private case class CacheKey(groupId: String, topicPartition: TopicPartition)
 
   private lazy val cache = {
@@ -140,7 +317,10 @@ private[kafka010] object CachedKafkaConsumer extends Logging {
     // If this is reattempt at running the task, then invalidate cache and start with
     // a new consumer
     if (TaskContext.get != null && TaskContext.get.attemptNumber > 1) {
-      cache.remove(key)
+      val removedConsumer = cache.remove(key)
+      if (removedConsumer != null) {
+        removedConsumer.close()
+      }
       new CachedKafkaConsumer(topicPartition, kafkaParams)
     } else {
       if (!cache.containsKey(key)) {
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
index 341081a338..1d0d402b82 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
@@ -281,7 +281,7 @@ private[kafka010] case class KafkaSource(
 
     // Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays.
     val rdd = new KafkaSourceRDD(
-      sc, executorKafkaParams, offsetRanges, pollTimeoutMs).map { cr =>
+      sc, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss).map { cr =>
       Row(cr.key, cr.value, cr.topic, cr.partition, cr.offset, cr.timestamp, cr.timestampType.id)
     }
 
@@ -463,10 +463,9 @@ private[kafka010] case class KafkaSource(
    */
   private def reportDataLoss(message: String): Unit = {
     if (failOnDataLoss) {
-      throw new IllegalStateException(message +
-        ". Set the source option 'failOnDataLoss' to 'false' if you want to ignore these checks.")
+      throw new IllegalStateException(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE")
     } else {
-      logWarning(message)
+      logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE")
     }
   }
 }
@@ -475,6 +474,22 @@ private[kafka010] case class KafkaSource(
 /** Companion object for the [[KafkaSource]]. */
 private[kafka010] object KafkaSource {
 
+  val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE =
+    """
+      |Some data may have been lost because they are not available in Kafka any more; either the
+      | data was aged out by Kafka or the topic may have been deleted before all the data in the
+      | topic was processed. If you want your streaming query to fail on such cases, set the source
+      | option "failOnDataLoss" to "true".
+    """.stripMargin
+
+  val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE =
+    """
+      |Some data may have been lost because they are not available in Kafka any more; either the
+      | data was aged out by Kafka or the topic may have been deleted before all the data in the
+      | topic was processed. If you don't want your streaming query to fail on such cases, set the
+      | source option "failOnDataLoss" to "false".
+    """.stripMargin
+
   def kafkaSchema: StructType = StructType(Seq(
     StructField("key", BinaryType),
     StructField("value", BinaryType),
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala
index 802dd040ae..244cd2c225 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala
@@ -28,6 +28,7 @@ import org.apache.spark.{Partition, SparkContext, TaskContext}
 import org.apache.spark.partial.{BoundedDouble, PartialResult}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.NextIterator
 
 
 /** Offset range that one partition of the KafkaSourceRDD has to read */
@@ -61,7 +62,8 @@ private[kafka010] class KafkaSourceRDD(
     sc: SparkContext,
     executorKafkaParams: ju.Map[String, Object],
     offsetRanges: Seq[KafkaSourceRDDOffsetRange],
-    pollTimeoutMs: Long)
+    pollTimeoutMs: Long,
+    failOnDataLoss: Boolean)
   extends RDD[ConsumerRecord[Array[Byte], Array[Byte]]](sc, Nil) {
 
   override def persist(newLevel: StorageLevel): this.type = {
@@ -130,23 +132,31 @@ private[kafka010] class KafkaSourceRDD(
       logInfo(s"Beginning offset ${range.fromOffset} is the same as ending offset " +
         s"skipping ${range.topic} ${range.partition}")
       Iterator.empty
-
     } else {
-
-      val consumer = CachedKafkaConsumer.getOrCreate(
-        range.topic, range.partition, executorKafkaParams)
-      var requestOffset = range.fromOffset
-
-      logDebug(s"Creating iterator for $range")
-
-      new Iterator[ConsumerRecord[Array[Byte], Array[Byte]]]() {
-        override def hasNext(): Boolean = requestOffset < range.untilOffset
-        override def next(): ConsumerRecord[Array[Byte], Array[Byte]] = {
-          assert(hasNext(), "Can't call next() once untilOffset has been reached")
-          val r = consumer.get(requestOffset, pollTimeoutMs)
-          requestOffset += 1
-          r
+      new NextIterator[ConsumerRecord[Array[Byte], Array[Byte]]]() {
+        val consumer = CachedKafkaConsumer.getOrCreate(
+          range.topic, range.partition, executorKafkaParams)
+        var requestOffset = range.fromOffset
+
+        override def getNext(): ConsumerRecord[Array[Byte], Array[Byte]] = {
+          if (requestOffset >= range.untilOffset) {
+            // Processed all offsets in this partition.
+            finished = true
+            null
+          } else {
+            val r = consumer.get(requestOffset, range.untilOffset, pollTimeoutMs, failOnDataLoss)
+            if (r == null) {
+              // Some data was lost. Skip the remaining offsets in this partition.
+              finished = true
+              null
+            } else {
+              requestOffset = r.offset + 1
+              r
+            }
+          }
         }
+
+        override protected def close(): Unit = {}
       }
     }
   }
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
index 89e713f92d..cd52fd93d1 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
@@ -17,8 +17,12 @@
 
 package org.apache.spark.sql.kafka010
 
+import java.util.Properties
+import java.util.concurrent.ConcurrentLinkedQueue
 import java.util.concurrent.atomic.AtomicInteger
 
+import scala.collection.JavaConverters._
+import scala.collection.mutable
 import scala.util.Random
 
 import org.apache.kafka.clients.producer.RecordMetadata
@@ -27,8 +31,9 @@ import org.scalatest.concurrent.Eventually._
 import org.scalatest.concurrent.PatienceConfiguration.Timeout
 import org.scalatest.time.SpanSugar._
 
+import org.apache.spark.sql.ForeachWriter
 import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.streaming.{ ProcessingTime, StreamTest }
+import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest}
 import org.apache.spark.sql.test.SharedSQLContext
 
 abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
@@ -202,7 +207,7 @@ class KafkaSourceSuite extends KafkaSourceTest {
 
   test("cannot stop Kafka stream") {
     val topic = newTopic()
-    testUtils.createTopic(newTopic(), partitions = 5)
+    testUtils.createTopic(topic, partitions = 5)
     testUtils.sendMessages(topic, (101 to 105).map { _.toString }.toArray)
 
     val reader = spark
@@ -223,52 +228,85 @@ class KafkaSourceSuite extends KafkaSourceTest {
     )
   }
 
-  test("assign from latest offsets") {
-    val topic = newTopic()
-    testFromLatestOffsets(topic, false, "assign" -> assignString(topic, 0 to 4))
-  }
+  for (failOnDataLoss <- Seq(true, false)) {
+    test(s"assign from latest offsets (failOnDataLoss: $failOnDataLoss)") {
+      val topic = newTopic()
+      testFromLatestOffsets(
+        topic,
+        addPartitions = false,
+        failOnDataLoss = failOnDataLoss,
+        "assign" -> assignString(topic, 0 to 4))
+    }
 
-  test("assign from earliest offsets") {
-    val topic = newTopic()
-    testFromEarliestOffsets(topic, false, "assign" -> assignString(topic, 0 to 4))
-  }
+    test(s"assign from earliest offsets (failOnDataLoss: $failOnDataLoss)") {
+      val topic = newTopic()
+      testFromEarliestOffsets(
+        topic,
+        addPartitions = false,
+        failOnDataLoss = failOnDataLoss,
+        "assign" -> assignString(topic, 0 to 4))
+    }
 
-  test("assign from specific offsets") {
-    val topic = newTopic()
-    testFromSpecificOffsets(topic, "assign" -> assignString(topic, 0 to 4))
-  }
+    test(s"assign from specific offsets (failOnDataLoss: $failOnDataLoss)") {
+      val topic = newTopic()
+      testFromSpecificOffsets(
+        topic,
+        failOnDataLoss = failOnDataLoss,
+        "assign" -> assignString(topic, 0 to 4),
+        "failOnDataLoss" -> failOnDataLoss.toString)
+    }
 
-  test("subscribing topic by name from latest offsets") {
-    val topic = newTopic()
-    testFromLatestOffsets(topic, true, "subscribe" -> topic)
-  }
+    test(s"subscribing topic by name from latest offsets (failOnDataLoss: $failOnDataLoss)") {
+      val topic = newTopic()
+      testFromLatestOffsets(
+        topic,
+        addPartitions = true,
+        failOnDataLoss = failOnDataLoss,
+        "subscribe" -> topic)
+    }
 
-  test("subscribing topic by name from earliest offsets") {
-    val topic = newTopic()
-    testFromEarliestOffsets(topic, true, "subscribe" -> topic)
-  }
+    test(s"subscribing topic by name from earliest offsets (failOnDataLoss: $failOnDataLoss)") {
+      val topic = newTopic()
+      testFromEarliestOffsets(
+        topic,
+        addPartitions = true,
+        failOnDataLoss = failOnDataLoss,
+        "subscribe" -> topic)
+    }
 
-  test("subscribing topic by name from specific offsets") {
-    val topic = newTopic()
-    testFromSpecificOffsets(topic, "subscribe" -> topic)
-  }
+    test(s"subscribing topic by name from specific offsets (failOnDataLoss: $failOnDataLoss)") {
+      val topic = newTopic()
+      testFromSpecificOffsets(topic, failOnDataLoss = failOnDataLoss, "subscribe" -> topic)
+    }
 
-  test("subscribing topic by pattern from latest offsets") {
-    val topicPrefix = newTopic()
-    val topic = topicPrefix + "-suffix"
-    testFromLatestOffsets(topic, true, "subscribePattern" -> s"$topicPrefix-.*")
-  }
+    test(s"subscribing topic by pattern from latest offsets (failOnDataLoss: $failOnDataLoss)") {
+      val topicPrefix = newTopic()
+      val topic = topicPrefix + "-suffix"
+      testFromLatestOffsets(
+        topic,
+        addPartitions = true,
+        failOnDataLoss = failOnDataLoss,
+        "subscribePattern" -> s"$topicPrefix-.*")
+    }
 
-  test("subscribing topic by pattern from earliest offsets") {
-    val topicPrefix = newTopic()
-    val topic = topicPrefix + "-suffix"
-    testFromEarliestOffsets(topic, true, "subscribePattern" -> s"$topicPrefix-.*")
-  }
+    test(s"subscribing topic by pattern from earliest offsets (failOnDataLoss: $failOnDataLoss)") {
+      val topicPrefix = newTopic()
+      val topic = topicPrefix + "-suffix"
+      testFromEarliestOffsets(
+        topic,
+        addPartitions = true,
+        failOnDataLoss = failOnDataLoss,
+        "subscribePattern" -> s"$topicPrefix-.*")
+    }
 
-  test("subscribing topic by pattern from specific offsets") {
-    val topicPrefix = newTopic()
-    val topic = topicPrefix + "-suffix"
-    testFromSpecificOffsets(topic, "subscribePattern" -> s"$topicPrefix-.*")
+    test(s"subscribing topic by pattern from specific offsets (failOnDataLoss: $failOnDataLoss)") {
+      val topicPrefix = newTopic()
+      val topic = topicPrefix + "-suffix"
+      testFromSpecificOffsets(
+        topic,
+        failOnDataLoss = failOnDataLoss,
+        "subscribePattern" -> s"$topicPrefix-.*")
+    }
   }
 
   test("subscribing topic by pattern with topic deletions") {
@@ -413,13 +451,59 @@ class KafkaSourceSuite extends KafkaSourceTest {
     )
   }
 
+  test("delete a topic when a Spark job is running") {
+    KafkaSourceSuite.collectedData.clear()
+
+    val topic = newTopic()
+    testUtils.createTopic(topic, partitions = 1)
+    testUtils.sendMessages(topic, (1 to 10).map(_.toString).toArray)
+
+    val reader = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("kafka.metadata.max.age.ms", "1")
+      .option("subscribe", topic)
+      // If a topic is deleted and we try to poll data starting from offset 0,
+      // the Kafka consumer will just block until timeout and return an empty result.
+      // So set the timeout to 1 second to make this test fast.
+      .option("kafkaConsumer.pollTimeoutMs", "1000")
+      .option("startingOffsets", "earliest")
+      .option("failOnDataLoss", "false")
+    val kafka = reader.load()
+      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+      .as[(String, String)]
+    KafkaSourceSuite.globalTestUtils = testUtils
+    // The following ForeachWriter will delete the topic before fetching data from Kafka
+    // in executors.
+    val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] {
+      override def open(partitionId: Long, version: Long): Boolean = {
+        KafkaSourceSuite.globalTestUtils.deleteTopic(topic)
+        true
+      }
+
+      override def process(value: Int): Unit = {
+        KafkaSourceSuite.collectedData.add(value)
+      }
+
+      override def close(errorOrNull: Throwable): Unit = {}
+    }).start()
+    query.processAllAvailable()
+    query.stop()
+    // Since `failOnDataLoss` is `false`, the query should not fail.
+    assert(query.exception.isEmpty)
+  }
+
   private def newTopic(): String = s"topic-${topicId.getAndIncrement()}"
 
   private def assignString(topic: String, partitions: Iterable[Int]): String = {
     JsonUtils.partitions(partitions.map(p => new TopicPartition(topic, p)))
   }
 
-  private def testFromSpecificOffsets(topic: String, options: (String, String)*): Unit = {
+  private def testFromSpecificOffsets(
+      topic: String,
+      failOnDataLoss: Boolean,
+      options: (String, String)*): Unit = {
     val partitionOffsets = Map(
       new TopicPartition(topic, 0) -> -2L,
       new TopicPartition(topic, 1) -> -1L,
@@ -448,6 +532,7 @@ class KafkaSourceSuite extends KafkaSourceTest {
       .option("startingOffsets", startingOffsets)
       .option("kafka.bootstrap.servers", testUtils.brokerAddress)
       .option("kafka.metadata.max.age.ms", "1")
+      .option("failOnDataLoss", failOnDataLoss.toString)
     options.foreach { case (k, v) => reader.option(k, v) }
     val kafka = reader.load()
       .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
@@ -469,6 +554,7 @@ class KafkaSourceSuite extends KafkaSourceTest {
   private def testFromLatestOffsets(
       topic: String,
       addPartitions: Boolean,
+      failOnDataLoss: Boolean,
       options: (String, String)*): Unit = {
     testUtils.createTopic(topic, partitions = 5)
     testUtils.sendMessages(topic, Array("-1"))
@@ -480,6 +566,7 @@ class KafkaSourceSuite extends KafkaSourceTest {
       .option("startingOffsets", s"latest")
       .option("kafka.bootstrap.servers", testUtils.brokerAddress)
       .option("kafka.metadata.max.age.ms", "1")
+      .option("failOnDataLoss", failOnDataLoss.toString)
     options.foreach { case (k, v) => reader.option(k, v) }
     val kafka = reader.load()
       .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
@@ -513,6 +600,7 @@ class KafkaSourceSuite extends KafkaSourceTest {
   private def testFromEarliestOffsets(
       topic: String,
       addPartitions: Boolean,
+      failOnDataLoss: Boolean,
       options: (String, String)*): Unit = {
     testUtils.createTopic(topic, partitions = 5)
     testUtils.sendMessages(topic, (1 to 3).map { _.toString }.toArray)
@@ -524,6 +612,7 @@ class KafkaSourceSuite extends KafkaSourceTest {
       .option("startingOffsets", s"earliest")
       .option("kafka.bootstrap.servers", testUtils.brokerAddress)
       .option("kafka.metadata.max.age.ms", "1")
+      .option("failOnDataLoss", failOnDataLoss.toString)
     options.foreach { case (k, v) => reader.option(k, v) }
     val kafka = reader.load()
       .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
@@ -552,6 +641,11 @@ class KafkaSourceSuite extends KafkaSourceTest {
   }
 }
 
+object KafkaSourceSuite {
+  @volatile var globalTestUtils: KafkaTestUtils = _
+  val collectedData = new ConcurrentLinkedQueue[Any]()
+}
+
 
 class KafkaSourceStressSuite extends KafkaSourceTest {
 
@@ -615,7 +709,7 @@ class KafkaSourceStressSuite extends KafkaSourceTest {
                 }
               })
           case 2 => // Add new partitions
-            AddKafkaData(topics.toSet, d: _*)(message = "Add partitiosn",
+            AddKafkaData(topics.toSet, d: _*)(message = "Add partitions",
               topicAction = (topic, partition) => {
                 testUtils.addPartitions(topic, partition.get + nextInt(1, 6))
               })
@@ -626,3 +720,122 @@ class KafkaSourceStressSuite extends KafkaSourceTest {
       iterations = 50)
   }
 }
+
+class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with SharedSQLContext {
+
+  import testImplicits._
+
+  private var testUtils: KafkaTestUtils = _
+
+  private val topicId = new AtomicInteger(0)
+
+  private def newTopic(): String = s"failOnDataLoss-${topicId.getAndIncrement()}"
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    testUtils = new KafkaTestUtils {
+      override def brokerConfiguration: Properties = {
+        val props = super.brokerConfiguration
+        // Try to make Kafka clean up messages as fast as possible. However, there is a hard-coded
+        // 30-second delay (kafka.log.LogManager.InitialTaskDelayMs), so this test should run for at
+        // least 30 seconds.
+        props.put("log.cleaner.backoff.ms", "100")
+        props.put("log.segment.bytes", "40")
+        props.put("log.retention.bytes", "40")
+        props.put("log.retention.check.interval.ms", "100")
+        props.put("delete.retention.ms", "10")
+        props.put("log.flush.scheduler.interval.ms", "10")
+        props
+      }
+    }
+    testUtils.setup()
+  }
+
+  override def afterAll(): Unit = {
+    if (testUtils != null) {
+      testUtils.teardown()
+      testUtils = null
+    }
+    super.afterAll()
+  }
+
+  test("stress test for failOnDataLoss=false") {
+    val reader = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("kafka.metadata.max.age.ms", "1")
+      .option("subscribePattern", "failOnDataLoss.*")
+      .option("startingOffsets", "earliest")
+      .option("failOnDataLoss", "false")
+    val kafka = reader.load()
+      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+      .as[(String, String)]
+    val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] {
+
+      override def open(partitionId: Long, version: Long): Boolean = {
+        true
+      }
+
+      override def process(value: Int): Unit = {
+        // Slow down the processing speed so that messages may be aged out.
+        Thread.sleep(Random.nextInt(500))
+      }
+
+      override def close(errorOrNull: Throwable): Unit = {
+      }
+    }).start()
+
+    val testTime = 1.minutes
+    val startTime = System.currentTimeMillis()
+    // Track the current existing topics
+    val topics = mutable.ArrayBuffer[String]()
+    // Track topics that have been deleted
+    val deletedTopics = mutable.Set[String]()
+    while (System.currentTimeMillis() - startTime < testTime.toMillis) {
+      Random.nextInt(10) match {
+        case 0 => // Create a new topic
+          val topic = newTopic()
+          topics += topic
+          // As pushing messages into Kafka updates Zookeeper asynchronously, there is a small
+          // chance that a topic will be recreated after deletion due to the asynchronous update.
+          // Hence, always overwrite to handle this race condition.
+          testUtils.createTopic(topic, partitions = 1, overwrite = true)
+          logInfo(s"Create topic $topic")
+        case 1 if topics.nonEmpty => // Delete an existing topic
+          val topic = topics.remove(Random.nextInt(topics.size))
+          testUtils.deleteTopic(topic)
+          logInfo(s"Delete topic $topic")
+          deletedTopics += topic
+        case 2 if deletedTopics.nonEmpty => // Recreate a topic that was deleted.
+          val topic = deletedTopics.toSeq(Random.nextInt(deletedTopics.size))
+          deletedTopics -= topic
+          topics += topic
+          // As pushing messages into Kafka updates Zookeeper asynchronously, there is a small
+          // chance that a topic will be recreated after deletion due to the asynchronous update.
+          // Hence, always overwrite to handle this race condition.
+          testUtils.createTopic(topic, partitions = 1, overwrite = true)
+          logInfo(s"Create topic $topic")
+        case 3 =>
+          Thread.sleep(1000)
+        case _ => // Push random messages
+          for (topic <- topics) {
+            val size = Random.nextInt(10)
+            for (_ <- 0 until size) {
+              testUtils.sendMessages(topic, Array(Random.nextInt(10).toString))
+            }
+          }
+      }
+      // Since `failOnDataLoss` is `false`, the query should not fail.
+      if (query.exception.nonEmpty) {
+        throw query.exception.get
+      }
+    }
+
+    query.stop()
+    // Since `failOnDataLoss` is `false`, the query should not fail.
+    if (query.exception.nonEmpty) {
+      throw query.exception.get
+    }
+  }
+}
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala
index 9b24ccdd56..f43917e151 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala
@@ -155,8 +155,16 @@ class KafkaTestUtils extends Logging {
   }
 
   /** Create a Kafka topic and wait until it is propagated to the whole cluster */
-  def createTopic(topic: String, partitions: Int): Unit = {
-    AdminUtils.createTopic(zkUtils, topic, partitions, 1)
+  def createTopic(topic: String, partitions: Int, overwrite: Boolean = false): Unit = {
+    var created = false
+    while (!created) {
+      try {
+        AdminUtils.createTopic(zkUtils, topic, partitions, 1)
+        created = true
+      } catch {
+        case e: kafka.common.TopicExistsException if overwrite => deleteTopic(topic)
+      }
+    }
     // wait until metadata is propagated
     (0 until partitions).foreach { p =>
       waitUntilMetadataIsPropagated(topic, p)
@@ -244,7 +252,7 @@ class KafkaTestUtils extends Logging {
     offsets
   }
 
-  private def brokerConfiguration: Properties = {
+  protected def brokerConfiguration: Properties = {
     val props = new Properties()
     props.put("broker.id", "0")
     props.put("host.name", "localhost")
@@ -302,9 +310,11 @@ class KafkaTestUtils extends Logging {
         }
         checkpoints.forall(checkpointsPerLogDir => !checkpointsPerLogDir.contains(tp))
       })
-      deletePath && topicPath && replicaManager && logManager && cleaner
+      // ensure the topic is gone
+      val deleted = !zkUtils.getAllTopics().contains(topic)
+      deletePath && topicPath && replicaManager && logManager && cleaner && deleted
     }
-    eventually(timeout(10.seconds)) {
+    eventually(timeout(60.seconds)) {
       assert(isDeleted, s"$topic not deleted after timeout")
     }
   }
-- 
GitLab