diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala
index bf6c0900c97e16cc4de48e07fabdb6d35dc4e496..7c4f38e02fb2a6bb58ebaeae4d0d4fd55014ecf7 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala
@@ -287,7 +287,7 @@ private[kafka010] case class CachedKafkaConsumer private(
     reportDataLoss0(failOnDataLoss, finalMessage, cause)
   }
 
-  private def close(): Unit = consumer.close()
+  def close(): Unit = consumer.close()
 
   private def seek(offset: Long): Unit = {
     logDebug(s"Seeking to $groupId $topicPartition $offset")
@@ -382,7 +382,7 @@ private[kafka010] object CachedKafkaConsumer extends Logging {
 
     // If this is reattempt at running the task, then invalidate cache and start with
     // a new consumer
-    if (TaskContext.get != null && TaskContext.get.attemptNumber > 1) {
+    if (TaskContext.get != null && TaskContext.get.attemptNumber >= 1) {
       removeKafkaConsumer(topic, partition, kafkaParams)
       val consumer = new CachedKafkaConsumer(topicPartition, kafkaParams)
       consumer.inuse = true
@@ -398,6 +398,14 @@ private[kafka010] object CachedKafkaConsumer extends Logging {
     }
   }
 
+  /** Create an [[CachedKafkaConsumer]] but don't put it into cache. */
+  def createUncached(
+      topic: String,
+      partition: Int,
+      kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer = {
+    new CachedKafkaConsumer(new TopicPartition(topic, partition), kafkaParams)
+  }
+
   private def reportDataLoss0(
       failOnDataLoss: Boolean,
       finalMessage: String,
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala
index 2696d6f089d2f6102394d81f9f1acea6dd1fbbba..3e65949a6fd1b428b8783b3e7e9964e6f134ffae 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala
@@ -95,8 +95,10 @@ private[kafka010] class KafkaOffsetReader(
    * Closes the connection to Kafka, and cleans up state.
    */
   def close(): Unit = {
-    consumer.close()
-    kafkaReaderThread.shutdownNow()
+    runUninterruptibly {
+      consumer.close()
+    }
+    kafkaReaderThread.shutdown()
   }
 
   /**
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala
index f180bbad6e36305bbd23a7a195aafebff4585e70..97bd2831693237995aaa822c7b3698af9fe11db7 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.kafka010
 
 import java.{util => ju}
+import java.util.UUID
 
 import org.apache.kafka.common.TopicPartition
 
@@ -33,9 +34,9 @@ import org.apache.spark.unsafe.types.UTF8String
 
 private[kafka010] class KafkaRelation(
     override val sqlContext: SQLContext,
-    kafkaReader: KafkaOffsetReader,
-    executorKafkaParams: ju.Map[String, Object],
+    strategy: ConsumerStrategy,
     sourceOptions: Map[String, String],
+    specifiedKafkaParams: Map[String, String],
     failOnDataLoss: Boolean,
     startingOffsets: KafkaOffsetRangeLimit,
     endingOffsets: KafkaOffsetRangeLimit)
@@ -53,9 +54,27 @@ private[kafka010] class KafkaRelation(
   override def schema: StructType = KafkaOffsetReader.kafkaSchema
 
   override def buildScan(): RDD[Row] = {
+    // Each running query should use its own group id. Otherwise, the query may be only assigned
+    // partial data since Kafka will assign partitions to multiple consumers having the same group
+    // id. Hence, we should generate a unique id for each query.
+    val uniqueGroupId = s"spark-kafka-relation-${UUID.randomUUID}"
+
+    val kafkaOffsetReader = new KafkaOffsetReader(
+      strategy,
+      KafkaSourceProvider.kafkaParamsForDriver(specifiedKafkaParams),
+      sourceOptions,
+      driverGroupIdPrefix = s"$uniqueGroupId-driver")
+
     // Leverage the KafkaReader to obtain the relevant partition offsets
-    val fromPartitionOffsets = getPartitionOffsets(startingOffsets)
-    val untilPartitionOffsets = getPartitionOffsets(endingOffsets)
+    val (fromPartitionOffsets, untilPartitionOffsets) = {
+      try {
+        (getPartitionOffsets(kafkaOffsetReader, startingOffsets),
+          getPartitionOffsets(kafkaOffsetReader, endingOffsets))
+      } finally {
+        kafkaOffsetReader.close()
+      }
+    }
+
     // Obtain topicPartitions in both from and until partition offset, ignoring
     // topic partitions that were added and/or deleted between the two above calls.
     if (fromPartitionOffsets.keySet != untilPartitionOffsets.keySet) {
@@ -82,6 +101,8 @@ private[kafka010] class KafkaRelation(
       offsetRanges.sortBy(_.topicPartition.toString).mkString(", "))
 
     // Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays.
+    val executorKafkaParams =
+      KafkaSourceProvider.kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId)
     val rdd = new KafkaSourceRDD(
       sqlContext.sparkContext, executorKafkaParams, offsetRanges,
       pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer = false).map { cr =>
@@ -98,6 +119,7 @@ private[kafka010] class KafkaRelation(
   }
 
   private def getPartitionOffsets(
+      kafkaReader: KafkaOffsetReader,
       kafkaOffsets: KafkaOffsetRangeLimit): Map[TopicPartition, Long] = {
     def validateTopicPartitions(partitions: Set[TopicPartition],
       partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = {
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
index ab1ce347cbe346944e74126d14bf94e6afc098e5..3cb4d8cad12cc3fb90622898edf7b02ac4b04eea 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
@@ -111,10 +111,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
       sqlContext: SQLContext,
       parameters: Map[String, String]): BaseRelation = {
     validateBatchOptions(parameters)
-    // Each running query should use its own group id. Otherwise, the query may be only assigned
-    // partial data since Kafka will assign partitions to multiple consumers having the same group
-    // id. Hence, we should generate a unique id for each query.
-    val uniqueGroupId = s"spark-kafka-relation-${UUID.randomUUID}"
     val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
     val specifiedKafkaParams =
       parameters
@@ -131,20 +127,14 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
       ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
     assert(endingRelationOffsets != EarliestOffsetRangeLimit)
 
-    val kafkaOffsetReader = new KafkaOffsetReader(
-      strategy(caseInsensitiveParams),
-      kafkaParamsForDriver(specifiedKafkaParams),
-      parameters,
-      driverGroupIdPrefix = s"$uniqueGroupId-driver")
-
     new KafkaRelation(
       sqlContext,
-      kafkaOffsetReader,
-      kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId),
-      parameters,
-      failOnDataLoss(caseInsensitiveParams),
-      startingRelationOffsets,
-      endingRelationOffsets)
+      strategy(caseInsensitiveParams),
+      sourceOptions = parameters,
+      specifiedKafkaParams = specifiedKafkaParams,
+      failOnDataLoss = failOnDataLoss(caseInsensitiveParams),
+      startingOffsets = startingRelationOffsets,
+      endingOffsets = endingRelationOffsets)
   }
 
   override def createSink(
@@ -213,46 +203,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
         ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName)
   }
 
-  private def kafkaParamsForDriver(specifiedKafkaParams: Map[String, String]) =
-    ConfigUpdater("source", specifiedKafkaParams)
-      .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName)
-      .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName)
-
-      // Set to "earliest" to avoid exceptions. However, KafkaSource will fetch the initial
-      // offsets by itself instead of counting on KafkaConsumer.
-      .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest")
-
-      // So that consumers in the driver does not commit offsets unnecessarily
-      .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false")
-
-      // So that the driver does not pull too much data
-      .set(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, new java.lang.Integer(1))
-
-      // If buffer config is not set, set it to reasonable value to work around
-      // buffer issues (see KAFKA-3135)
-      .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer)
-      .build()
-
-  private def kafkaParamsForExecutors(
-      specifiedKafkaParams: Map[String, String], uniqueGroupId: String) =
-    ConfigUpdater("executor", specifiedKafkaParams)
-      .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName)
-      .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName)
-
-      // Make sure executors do only what the driver tells them.
-      .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "none")
-
-      // So that consumers in executors do not mess with any existing group id
-      .set(ConsumerConfig.GROUP_ID_CONFIG, s"$uniqueGroupId-executor")
-
-      // So that consumers in executors does not commit offsets unnecessarily
-      .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false")
-
-      // If buffer config is not set, set it to reasonable value to work around
-      // buffer issues (see KAFKA-3135)
-      .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer)
-      .build()
-
   private def strategy(caseInsensitiveParams: Map[String, String]) =
       caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match {
     case ("assign", value) =>
@@ -414,30 +364,9 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
       logWarning("maxOffsetsPerTrigger option ignored in batch queries")
     }
   }
-
-  /** Class to conveniently update Kafka config params, while logging the changes */
-  private case class ConfigUpdater(module: String, kafkaParams: Map[String, String]) {
-    private val map = new ju.HashMap[String, Object](kafkaParams.asJava)
-
-    def set(key: String, value: Object): this.type = {
-      map.put(key, value)
-      logInfo(s"$module: Set $key to $value, earlier value: ${kafkaParams.getOrElse(key, "")}")
-      this
-    }
-
-    def setIfUnset(key: String, value: Object): ConfigUpdater = {
-      if (!map.containsKey(key)) {
-        map.put(key, value)
-        logInfo(s"$module: Set $key to $value")
-      }
-      this
-    }
-
-    def build(): ju.Map[String, Object] = map
-  }
 }
 
-private[kafka010] object KafkaSourceProvider {
+private[kafka010] object KafkaSourceProvider extends Logging {
   private val STRATEGY_OPTION_KEYS = Set("subscribe", "subscribepattern", "assign")
   private[kafka010] val STARTING_OFFSETS_OPTION_KEY = "startingoffsets"
   private[kafka010] val ENDING_OFFSETS_OPTION_KEY = "endingoffsets"
@@ -459,4 +388,66 @@ private[kafka010] object KafkaSourceProvider {
       case None => defaultOffsets
     }
   }
+
+  def kafkaParamsForDriver(specifiedKafkaParams: Map[String, String]): ju.Map[String, Object] =
+    ConfigUpdater("source", specifiedKafkaParams)
+      .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName)
+      .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName)
+
+      // Set to "earliest" to avoid exceptions. However, KafkaSource will fetch the initial
+      // offsets by itself instead of counting on KafkaConsumer.
+      .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest")
+
+      // So that consumers in the driver does not commit offsets unnecessarily
+      .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false")
+
+      // So that the driver does not pull too much data
+      .set(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, new java.lang.Integer(1))
+
+      // If buffer config is not set, set it to reasonable value to work around
+      // buffer issues (see KAFKA-3135)
+      .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer)
+      .build()
+
+  def kafkaParamsForExecutors(
+      specifiedKafkaParams: Map[String, String],
+      uniqueGroupId: String): ju.Map[String, Object] =
+    ConfigUpdater("executor", specifiedKafkaParams)
+      .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName)
+      .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName)
+
+      // Make sure executors do only what the driver tells them.
+      .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "none")
+
+      // So that consumers in executors do not mess with any existing group id
+      .set(ConsumerConfig.GROUP_ID_CONFIG, s"$uniqueGroupId-executor")
+
+      // So that consumers in executors does not commit offsets unnecessarily
+      .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false")
+
+      // If buffer config is not set, set it to reasonable value to work around
+      // buffer issues (see KAFKA-3135)
+      .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer)
+      .build()
+
+  /** Class to conveniently update Kafka config params, while logging the changes */
+  private case class ConfigUpdater(module: String, kafkaParams: Map[String, String]) {
+    private val map = new ju.HashMap[String, Object](kafkaParams.asJava)
+
+    def set(key: String, value: Object): this.type = {
+      map.put(key, value)
+      logDebug(s"$module: Set $key to $value, earlier value: ${kafkaParams.getOrElse(key, "")}")
+      this
+    }
+
+    def setIfUnset(key: String, value: Object): ConfigUpdater = {
+      if (!map.containsKey(key)) {
+        map.put(key, value)
+        logDebug(s"$module: Set $key to $value")
+      }
+      this
+    }
+
+    def build(): ju.Map[String, Object] = map
+  }
 }
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala
index 6fb3473eb75f5092f081f645e8cb2be3f7333b9f..9d9e2aaba80799dbfaf71cbb6c48e4ed0ec97d36 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala
@@ -125,16 +125,15 @@ private[kafka010] class KafkaSourceRDD(
       context: TaskContext): Iterator[ConsumerRecord[Array[Byte], Array[Byte]]] = {
     val sourcePartition = thePart.asInstanceOf[KafkaSourceRDDPartition]
     val topic = sourcePartition.offsetRange.topic
-    if (!reuseKafkaConsumer) {
-      // if we can't reuse CachedKafkaConsumers, let's reset the groupId to something unique
-      // to each task (i.e., append the task's unique partition id), because we will have
-      // multiple tasks (e.g., in the case of union) reading from the same topic partitions
-      val old = executorKafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
-      val id = TaskContext.getPartitionId()
-      executorKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, old + "-" + id)
-    }
     val kafkaPartition = sourcePartition.offsetRange.partition
-    val consumer = CachedKafkaConsumer.getOrCreate(topic, kafkaPartition, executorKafkaParams)
+    val consumer =
+      if (!reuseKafkaConsumer) {
+        // If we can't reuse CachedKafkaConsumers, creating a new CachedKafkaConsumer. As here we
+        // uses `assign`, we don't need to worry about the "group.id" conflicts.
+        CachedKafkaConsumer.createUncached(topic, kafkaPartition, executorKafkaParams)
+      } else {
+        CachedKafkaConsumer.getOrCreate(topic, kafkaPartition, executorKafkaParams)
+      }
     val range = resolveRange(consumer, sourcePartition.offsetRange)
     assert(
       range.fromOffset <= range.untilOffset,
@@ -170,7 +169,7 @@ private[kafka010] class KafkaSourceRDD(
         override protected def close(): Unit = {
           if (!reuseKafkaConsumer) {
             // Don't forget to close non-reuse KafkaConsumers. You may take down your cluster!
-            CachedKafkaConsumer.removeKafkaConsumer(topic, kafkaPartition, executorKafkaParams)
+            consumer.close()
           } else {
             // Indicate that we're no longer using this consumer
             CachedKafkaConsumer.releaseKafkaConsumer(topic, kafkaPartition, executorKafkaParams)
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala
index 4c6e2ce87e2956a8dd03ed81200e2f23e2731a28..62cdf5b1134e4b31be356f6f9824780ca9ed08b8 100644
--- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala
@@ -199,7 +199,7 @@ private[spark] class KafkaRDD[K, V](
 
     val consumer = if (useConsumerCache) {
       CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor)
-      if (context.attemptNumber > 1) {
+      if (context.attemptNumber >= 1) {
         // just in case the prior attempt failures were cache related
         CachedKafkaConsumer.remove(groupId, part.topic, part.partition)
       }