diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index b34ab51f3b9969350334fa22c465d2482b6d8456..0cf078c378fd9e712389ad14f8e68333e70d5d5e 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -245,7 +245,8 @@ streaming_kafka_0_10 = Module(
     name="streaming-kafka-0-10",
     dependencies=[streaming],
     source_file_regexes=[
-        "external/kafka-0-10",
+        # The trailing "/" is necessary; otherwise this would also match the "sql-kafka" sources
+        "external/kafka-0-10/",
         "external/kafka-0-10-assembly",
     ],
     sbt_test_goals=[
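
The trailing slash matters because each source_file_regexes entry is effectively a path prefix: without it, "external/kafka-0-10" also matches everything under "external/kafka-0-10-sql". A minimal sketch of that behaviour (illustrative only; it uses Java's lookingAt, which performs a prefix match, under the assumption implied by the comment above that module regexes are matched against the start of changed file paths):

import java.util.regex.Pattern

object ModuleRegexSketch {
  // Returns true if the regex matches at the beginning of the path.
  def matchesModule(regex: String, path: String): Boolean =
    Pattern.compile(regex).matcher(path).lookingAt()

  def main(args: Array[String]): Unit = {
    val sqlSource =
      "external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala"
    println(matchesModule("external/kafka-0-10", sqlSource))   // true: the SQL connector is wrongly picked up
    println(matchesModule("external/kafka-0-10/", sqlSource))  // false: only paths under kafka-0-10/ match
  }
}
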
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
index 92ee0ed93d940b4b011567c76ebadd5dbfb4b9a6..43b8d9d6d7eef687a9f4c19a69b20d760ccc29c6 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
@@ -24,7 +24,7 @@ import java.nio.charset.StandardCharsets
 import scala.collection.JavaConverters._
 import scala.util.control.NonFatal
 
-import org.apache.kafka.clients.consumer.{Consumer, KafkaConsumer, OffsetOutOfRangeException}
+import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, KafkaConsumer, OffsetOutOfRangeException}
 import org.apache.kafka.clients.consumer.internals.NoOpConsumerRebalanceListener
 import org.apache.kafka.common.TopicPartition
 
@@ -81,14 +81,16 @@ import org.apache.spark.util.UninterruptibleThread
  * To avoid this issue, make sure you stop the query before stopping the Kafka brokers, and do
  * not use incorrect broker addresses.
  */
-private[kafka010] case class KafkaSource(
+private[kafka010] class KafkaSource(
     sqlContext: SQLContext,
     consumerStrategy: ConsumerStrategy,
+    driverKafkaParams: ju.Map[String, Object],
     executorKafkaParams: ju.Map[String, Object],
     sourceOptions: Map[String, String],
     metadataPath: String,
     startingOffsets: StartingOffsets,
-    failOnDataLoss: Boolean)
+    failOnDataLoss: Boolean,
+    driverGroupIdPrefix: String)
   extends Source with Logging {
 
   private val sc = sqlContext.sparkContext
@@ -107,11 +109,31 @@ private[kafka010] case class KafkaSource(
   private val maxOffsetsPerTrigger =
     sourceOptions.get("maxOffsetsPerTrigger").map(_.toLong)
 
+  private var groupId: String = null
+
+  private var nextId = 0
+
+  private def nextGroupId(): String = {
+    groupId = driverGroupIdPrefix + "-" + nextId
+    nextId += 1
+    groupId
+  }
+
   /**
    * A KafkaConsumer used in the driver to query the latest Kafka offsets. This only queries the
    * offsets and never commits them.
    */
-  private val consumer = consumerStrategy.createConsumer()
+  private var consumer: Consumer[Array[Byte], Array[Byte]] = createConsumer()
+
+  /**
+   * Create a consumer using a newly generated group id. We always create a new consumer rather
+   * than retry with a possibly broken one, since a broken consumer would likely fail again.
+   */
+  private def createConsumer(): Consumer[Array[Byte], Array[Byte]] = synchronized {
+    val newKafkaParams = new ju.HashMap[String, Object](driverKafkaParams)
+    newKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, nextGroupId())
+    consumerStrategy.createConsumer(newKafkaParams)
+  }
 
   /**
    * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only
@@ -171,6 +193,11 @@ private[kafka010] case class KafkaSource(
     Some(KafkaSourceOffset(offsets))
   }
 
+  private def resetConsumer(): Unit = synchronized {
+    consumer.close()
+    consumer = createConsumer()
+  }
+
   /** Proportionally distribute the limit on the number of offsets among topic partitions */
   private def rateLimit(
       limit: Long,
@@ -441,13 +468,12 @@ private[kafka010] case class KafkaSource(
               try {
                 result = Some(body)
               } catch {
-                case x: OffsetOutOfRangeException =>
-                  reportDataLoss(x.getMessage)
                 case NonFatal(e) =>
                   lastException = e
                   logWarning(s"Error in attempt $attempt getting Kafka offsets: ", e)
                   attempt += 1
                   Thread.sleep(offsetFetchAttemptIntervalMs)
+                  resetConsumer()
               }
             }
           case _ =>
@@ -511,12 +537,12 @@ private[kafka010] object KafkaSource {
   ))
 
   sealed trait ConsumerStrategy {
-    def createConsumer(): Consumer[Array[Byte], Array[Byte]]
+    def createConsumer(kafkaParams: ju.Map[String, Object]): Consumer[Array[Byte], Array[Byte]]
   }
 
-  case class AssignStrategy(partitions: Array[TopicPartition], kafkaParams: ju.Map[String, Object])
-    extends ConsumerStrategy {
-    override def createConsumer(): Consumer[Array[Byte], Array[Byte]] = {
+  case class AssignStrategy(partitions: Array[TopicPartition]) extends ConsumerStrategy {
+    override def createConsumer(
+        kafkaParams: ju.Map[String, Object]): Consumer[Array[Byte], Array[Byte]] = {
       val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams)
       consumer.assign(ju.Arrays.asList(partitions: _*))
       consumer
@@ -525,9 +551,9 @@ private[kafka010] object KafkaSource {
     override def toString: String = s"Assign[${partitions.mkString(", ")}]"
   }
 
-  case class SubscribeStrategy(topics: Seq[String], kafkaParams: ju.Map[String, Object])
-    extends ConsumerStrategy {
-    override def createConsumer(): Consumer[Array[Byte], Array[Byte]] = {
+  case class SubscribeStrategy(topics: Seq[String]) extends ConsumerStrategy {
+    override def createConsumer(
+        kafkaParams: ju.Map[String, Object]): Consumer[Array[Byte], Array[Byte]] = {
       val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams)
       consumer.subscribe(topics.asJava)
       consumer
@@ -536,10 +562,10 @@ private[kafka010] object KafkaSource {
     override def toString: String = s"Subscribe[${topics.mkString(", ")}]"
   }
 
-  case class SubscribePatternStrategy(
-    topicPattern: String, kafkaParams: ju.Map[String, Object])
+  case class SubscribePatternStrategy(topicPattern: String)
     extends ConsumerStrategy {
-    override def createConsumer(): Consumer[Array[Byte], Array[Byte]] = {
+    override def createConsumer(
+        kafkaParams: ju.Map[String, Object]): Consumer[Array[Byte], Array[Byte]] = {
       val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams)
       consumer.subscribe(
         ju.regex.Pattern.compile(topicPattern),
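
Taken together, the KafkaSource changes above replace a single driver consumer whose group id was baked into the ConsumerStrategy with a consumer that is recreated on failure, each time under a fresh "<driverGroupIdPrefix>-<n>" group id, so a retry never reuses a possibly broken consumer. A standalone sketch of that rotation pattern (the class and method names here are hypothetical; only the kafka-clients calls are real, and baseParams is assumed to already contain bootstrap.servers and the byte-array deserializers, as kafkaParamsForDriver does):

import java.{util => ju}

import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, KafkaConsumer}

class RotatingDriverConsumer(
    baseParams: ju.Map[String, Object],
    groupIdPrefix: String) {

  private var nextId = 0

  // The current driver-side consumer; replaced wholesale whenever reset() is called.
  private var consumer: Consumer[Array[Byte], Array[Byte]] = createConsumer()

  // Generate a unique group id per consumer so retries never share state with a broken one.
  private def nextGroupId(): String = {
    val id = s"$groupIdPrefix-$nextId"
    nextId += 1
    id
  }

  private def createConsumer(): Consumer[Array[Byte], Array[Byte]] = synchronized {
    val params = new ju.HashMap[String, Object](baseParams)
    params.put(ConsumerConfig.GROUP_ID_CONFIG, nextGroupId())
    new KafkaConsumer[Array[Byte], Array[Byte]](params)
  }

  // Called after a failed offset fetch: discard the old consumer and start over with a new one.
  def reset(): Unit = synchronized {
    consumer.close()
    consumer = createConsumer()
  }
}
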
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
index 585ced875caa72feccb0f72bae0d8e2efd2adb67..aa01238f91247e173521c1c9a47b4cddc953279a 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
@@ -85,14 +85,11 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider
         case None => LatestOffsets
       }
 
-    val kafkaParamsForStrategy =
+    val kafkaParamsForDriver =
       ConfigUpdater("source", specifiedKafkaParams)
         .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName)
         .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName)
 
-        // So that consumers in Kafka source do not mess with any existing group id
-        .set(ConsumerConfig.GROUP_ID_CONFIG, s"$uniqueGroupId-driver")
-
         // Set to "earliest" to avoid exceptions. However, KafkaSource will fetch the initial
         // offsets by itself instead of counting on KafkaConsumer.
         .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest")
@@ -129,17 +126,11 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider
 
     val strategy = caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match {
       case ("assign", value) =>
-        AssignStrategy(
-          JsonUtils.partitions(value),
-          kafkaParamsForStrategy)
+        AssignStrategy(JsonUtils.partitions(value))
       case ("subscribe", value) =>
-        SubscribeStrategy(
-          value.split(",").map(_.trim()).filter(_.nonEmpty),
-          kafkaParamsForStrategy)
+        SubscribeStrategy(value.split(",").map(_.trim()).filter(_.nonEmpty))
       case ("subscribepattern", value) =>
-        SubscribePatternStrategy(
-          value.trim(),
-          kafkaParamsForStrategy)
+        SubscribePatternStrategy(value.trim())
       case _ =>
         // Should never reach here as we are already matching on
         // matched strategy names
@@ -152,11 +143,13 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider
     new KafkaSource(
       sqlContext,
       strategy,
+      kafkaParamsForDriver,
       kafkaParamsForExecutors,
       parameters,
       metadataPath,
       startingOffsets,
-      failOnDataLoss)
+      failOnDataLoss,
+      driverGroupIdPrefix = s"$uniqueGroupId-driver")
   }
 
   private def validateOptions(parameters: Map[String, String]): Unit = {
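
With the group id removed from kafkaParamsForDriver, the provider now only derives the prefix ("$uniqueGroupId-driver") and lets KafkaSource append "-0", "-1", ... per created consumer. The strategies themselves receive just the user-facing option values; for instance the "subscribe" value is still split into a trimmed, non-empty topic list. A tiny self-contained illustration of that parsing step (the input string is an arbitrary example):

object SubscribeOptionParsingSketch {
  // Mirrors the parsing used for the "subscribe" option: split on commas,
  // trim whitespace, and drop empty entries.
  def parseTopics(value: String): Seq[String] =
    value.split(",").map(_.trim()).filter(_.nonEmpty).toSeq

  def main(args: Array[String]): Unit = {
    println(parseTopics("topic1, topic2 ,,topic3").mkString(", "))  // topic1, topic2, topic3
  }
}
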
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
index 5d2779aba26d01343129efd8e31e96b57773cc16..544fbc5ec36a26339ab4a9dea432dc0e20df04fa 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
@@ -845,7 +845,7 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared
     }
   }
 
-  ignore("stress test for failOnDataLoss=false") {
+  test("stress test for failOnDataLoss=false") {
     val reader = spark
       .readStream
       .format("kafka")