diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala
index 3f396a7e6b6985314ff05f67282de94aaccf627d..15b28256e825efcaf3d1196cc46adaf2f4c887f8 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala
@@ -44,6 +44,9 @@ private[kafka010] case class CachedKafkaConsumer private(
 
   private var consumer = createConsumer
 
+  /** Indicates whether this consumer is currently in use. */
+  private var inuse = true
+
   /** Iterator to the already fetch data */
   private var fetchedData = ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]]
   private var nextOffsetInFetchedData = UNKNOWN_OFFSET
@@ -57,6 +60,20 @@ private[kafka010] case class CachedKafkaConsumer private(
     c
   }
 
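+  /** The available offset range for this consumer's topic partition, as [earliest, latest). */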
+  case class AvailableOffsetRange(earliest: Long, latest: Long)
+
+  /**
+   * Return the available offset range of the current partition as an [[AvailableOffsetRange]]
+   * holding the earliest and the latest offset.
+   */
+  def getAvailableOffsetRange(): AvailableOffsetRange = {
+    consumer.seekToBeginning(Set(topicPartition).asJava)
+    val earliestOffset = consumer.position(topicPartition)
+    consumer.seekToEnd(Set(topicPartition).asJava)
+    val latestOffset = consumer.position(topicPartition)
+    AvailableOffsetRange(earliestOffset, latestOffset)
+  }
+
   /**
    * Get the record for the given offset if available. Otherwise it will either throw error
    * (if failOnDataLoss = true), or return the next available offset within [offset, untilOffset),
@@ -107,9 +124,9 @@ private[kafka010] case class CachedKafkaConsumer private(
    * `UNKNOWN_OFFSET`.
    */
   private def getEarliestAvailableOffsetBetween(offset: Long, untilOffset: Long): Long = {
-    val (earliestOffset, latestOffset) = getAvailableOffsetRange()
-    logWarning(s"Some data may be lost. Recovering from the earliest offset: $earliestOffset")
-    if (offset >= latestOffset || earliestOffset >= untilOffset) {
+    val range = getAvailableOffsetRange()
+    logWarning(s"Some data may be lost. Recovering from the earliest offset: ${range.earliest}")
+    if (offset >= range.latest || range.earliest >= untilOffset) {
       // [offset, untilOffset) and [earliestOffset, latestOffset) have no overlap,
       // either
       // --------------------------------------------------------
@@ -124,13 +141,13 @@ private[kafka010] case class CachedKafkaConsumer private(
       //   offset   untilOffset   earliestOffset   latestOffset
       val warningMessage =
         s"""
-          |The current available offset range is [$earliestOffset, $latestOffset).
+          |The current available offset range is $range.
           | Offset ${offset} is out of range, and records in [$offset, $untilOffset) will be
           | skipped ${additionalMessage(failOnDataLoss = false)}
         """.stripMargin
       logWarning(warningMessage)
       UNKNOWN_OFFSET
-    } else if (offset >= earliestOffset) {
+    } else if (offset >= range.earliest) {
       // -----------------------------------------------------------------------------
       //         ^            ^                  ^                                 ^
       //         |            |                  |                                 |
@@ -149,12 +166,12 @@ private[kafka010] case class CachedKafkaConsumer private(
       //   offset   earliestOffset   min(untilOffset,latestOffset)   max(untilOffset, latestOffset)
       val warningMessage =
         s"""
-           |The current available offset range is [$earliestOffset, $latestOffset).
-           | Offset ${offset} is out of range, and records in [$offset, $earliestOffset) will be
+           |The current available offset range is $range.
+           | Offset ${offset} is out of range, and records in [$offset, ${range.earliest}) will be
            | skipped ${additionalMessage(failOnDataLoss = false)}
         """.stripMargin
       logWarning(warningMessage)
-      earliestOffset
+      range.earliest
     }
   }
 
@@ -183,8 +200,8 @@ private[kafka010] case class CachedKafkaConsumer private(
       // - `offset` is out of range so that Kafka returns nothing. Just throw
       // `OffsetOutOfRangeException` to let the caller handle it.
       // - Cannot fetch any data before timeout. TimeoutException will be thrown.
-      val (earliestOffset, latestOffset) = getAvailableOffsetRange()
-      if (offset < earliestOffset || offset >= latestOffset) {
+      val range = getAvailableOffsetRange()
+      if (offset < range.earliest || offset >= range.latest) {
         throw new OffsetOutOfRangeException(
           Map(topicPartition -> java.lang.Long.valueOf(offset)).asJava)
       } else {
@@ -284,18 +301,6 @@ private[kafka010] case class CachedKafkaConsumer private(
     logDebug(s"Polled $groupId ${p.partitions()}  ${r.size}")
     fetchedData = r.iterator
   }
-
-  /**
-   * Return the available offset range of the current partition. It's a pair of the earliest offset
-   * and the latest offset.
-   */
-  private def getAvailableOffsetRange(): (Long, Long) = {
-    consumer.seekToBeginning(Set(topicPartition).asJava)
-    val earliestOffset = consumer.position(topicPartition)
-    consumer.seekToEnd(Set(topicPartition).asJava)
-    val latestOffset = consumer.position(topicPartition)
-    (earliestOffset, latestOffset)
-  }
 }
 
 private[kafka010] object CachedKafkaConsumer extends Logging {
@@ -310,7 +315,7 @@ private[kafka010] object CachedKafkaConsumer extends Logging {
     new ju.LinkedHashMap[CacheKey, CachedKafkaConsumer](capacity, 0.75f, true) {
       override def removeEldestEntry(
         entry: ju.Map.Entry[CacheKey, CachedKafkaConsumer]): Boolean = {
-        if (this.size > capacity) {
+        if (!entry.getValue.inuse && this.size > capacity) {
           logWarning(s"KafkaConsumer cache hitting max capacity of $capacity, " +
             s"removing consumer for ${entry.getKey}")
           try {
@@ -327,6 +332,43 @@ private[kafka010] object CachedKafkaConsumer extends Logging {
     }
   }
 
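+  /**
+   * Marks the cached consumer for the given topic, partition and group id as released
+   * (no longer in use), making it eligible for eviction once the cache is over capacity.
+   */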
+  def releaseKafkaConsumer(
+      topic: String,
+      partition: Int,
+      kafkaParams: ju.Map[String, Object]): Unit = {
+    val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
+    val topicPartition = new TopicPartition(topic, partition)
+    val key = CacheKey(groupId, topicPartition)
+
+    synchronized {
+      val consumer = cache.get(key)
+      if (consumer != null) {
+        consumer.inuse = false
+      } else {
+        logWarning(s"Attempting to release a consumer that is not in the cache: $key")
+      }
+    }
+  }
+
+  /**
+   * Removes (and closes) the Kafka Consumer for the given topic, partition and group id.
+   */
+  def removeKafkaConsumer(
+      topic: String,
+      partition: Int,
+      kafkaParams: ju.Map[String, Object]): Unit = {
+    val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
+    val topicPartition = new TopicPartition(topic, partition)
+    val key = CacheKey(groupId, topicPartition)
+
+    synchronized {
+      val removedConsumer = cache.remove(key)
+      if (removedConsumer != null) {
+        removedConsumer.close()
+      }
+    }
+  }
+
   /**
    * Get a cached consumer for groupId, assigned to topic and partition.
    * If matching consumer doesn't already exist, will be created using kafkaParams.
@@ -342,16 +384,18 @@ private[kafka010] object CachedKafkaConsumer extends Logging {
     // If this is reattempt at running the task, then invalidate cache and start with
     // a new consumer
     if (TaskContext.get != null && TaskContext.get.attemptNumber > 1) {
-      val removedConsumer = cache.remove(key)
-      if (removedConsumer != null) {
-        removedConsumer.close()
-      }
-      new CachedKafkaConsumer(topicPartition, kafkaParams)
+      removeKafkaConsumer(topic, partition, kafkaParams)
+      val consumer = new CachedKafkaConsumer(topicPartition, kafkaParams)
+      consumer.inuse = true
+      cache.put(key, consumer)
+      consumer
     } else {
       if (!cache.containsKey(key)) {
         cache.put(key, new CachedKafkaConsumer(topicPartition, kafkaParams))
       }
-      cache.get(key)
+      val consumer = cache.get(key)
+      consumer.inuse = true
+      consumer
     }
   }
 }
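Note on the caching change above: `getOrCreate` now marks a consumer as `inuse` and `releaseKafkaConsumer` clears the flag, so `removeEldestEntry` only evicts an entry that is both over capacity and released. The following standalone Scala sketch (illustrative only, not part of the patch) reproduces that eviction rule with a plain access-ordered `LinkedHashMap`:

```scala
import java.{util => ju}

// Standalone illustration (not patch code): an access-ordered LinkedHashMap whose eldest
// entry is evicted only when the map is over capacity AND the entry is no longer in use,
// mirroring the removeEldestEntry change above.
object InUseLruSketch {
  final class Entry(var inuse: Boolean)

  def newCache(capacity: Int): ju.LinkedHashMap[String, Entry] =
    new ju.LinkedHashMap[String, Entry](capacity, 0.75f, true) {
      override def removeEldestEntry(eldest: ju.Map.Entry[String, Entry]): Boolean =
        !eldest.getValue.inuse && this.size > capacity
    }

  def main(args: Array[String]): Unit = {
    val cache = newCache(1)
    cache.put("a", new Entry(inuse = true))
    cache.put("b", new Entry(inuse = true))  // "a" is eldest but still in use, so it survives
    println(cache.size)                      // 2: over capacity, but nothing is evictable yet
    cache.get("b").inuse = false             // release "b", as releaseKafkaConsumer would
    cache.get("a")                           // touch "a"; "b" becomes the eldest entry
    cache.put("c", new Entry(inuse = true))  // "b" is eldest and released -> evicted
    println(cache.keySet)                    // [a, c]
  }
}
```

In the patch, each `getOrCreate` is expected to be paired with a later `releaseKafkaConsumer` (or `removeKafkaConsumer`) for the same key; otherwise consumers stay pinned in the cache indefinitely.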
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/ConsumerStrategy.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/ConsumerStrategy.scala
new file mode 100644
index 0000000000000000000000000000000000000000..66511b3065415ab7d9ed2d0a48e62beb94df0852
--- /dev/null
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/ConsumerStrategy.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import java.{util => ju}
+
+import scala.collection.JavaConverters._
+
+import org.apache.kafka.clients.consumer.{Consumer, KafkaConsumer}
+import org.apache.kafka.clients.consumer.internals.NoOpConsumerRebalanceListener
+import org.apache.kafka.common.TopicPartition
+
+/**
+ * Subscribe allows you to subscribe to a fixed collection of topics.
+ * SubscribePattern allows you to use a regex to specify topics of interest.
+ * Note that unlike the 0.8 integration, using Subscribe or SubscribePattern
+ * should respond to adding partitions during a running stream.
+ * Finally, Assign allows you to specify a fixed collection of partitions.
+ * Unlike the DStream API, starting offsets are not specified through these strategies;
+ * they are supplied separately via the source's offset options.
+ */
+sealed trait ConsumerStrategy {
+  /** Create a [[KafkaConsumer]] and subscribe to topics according to a desired strategy */
+  def createConsumer(kafkaParams: ju.Map[String, Object]): Consumer[Array[Byte], Array[Byte]]
+}
+
+/**
+ * Specify a fixed collection of partitions.
+ */
+case class AssignStrategy(partitions: Array[TopicPartition]) extends ConsumerStrategy {
+  override def createConsumer(
+      kafkaParams: ju.Map[String, Object]): Consumer[Array[Byte], Array[Byte]] = {
+    val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams)
+    consumer.assign(ju.Arrays.asList(partitions: _*))
+    consumer
+  }
+
+  override def toString: String = s"Assign[${partitions.mkString(", ")}]"
+}
+
+/**
+ * Subscribe to a fixed collection of topics.
+ */
+case class SubscribeStrategy(topics: Seq[String]) extends ConsumerStrategy {
+  override def createConsumer(
+      kafkaParams: ju.Map[String, Object]): Consumer[Array[Byte], Array[Byte]] = {
+    val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams)
+    consumer.subscribe(topics.asJava)
+    consumer
+  }
+
+  override def toString: String = s"Subscribe[${topics.mkString(", ")}]"
+}
+
+/**
+ * Use a regex to specify topics of interest.
+ */
+case class SubscribePatternStrategy(topicPattern: String) extends ConsumerStrategy {
+  override def createConsumer(
+      kafkaParams: ju.Map[String, Object]): Consumer[Array[Byte], Array[Byte]] = {
+    val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams)
+    consumer.subscribe(
+      ju.regex.Pattern.compile(topicPattern),
+      new NoOpConsumerRebalanceListener())
+    consumer
+  }
+
+  override def toString: String = s"SubscribePattern[$topicPattern]"
+}
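As a quick orientation for the three strategies, here is a hedged usage sketch. The parameter map is hand-rolled with placeholder broker and topic names purely for illustration; in the patch these params are assembled by `KafkaSourceProvider`, and the classes below are intended for use within the `org.apache.spark.sql.kafka010` package.

```scala
import java.{util => ju}

import org.apache.kafka.clients.consumer.ConsumerConfig
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.serialization.ByteArrayDeserializer

object ConsumerStrategyExample {
  def main(args: Array[String]): Unit = {
    val deserializer = classOf[ByteArrayDeserializer].getName
    val params = new ju.HashMap[String, Object]()
    params.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092")  // placeholder broker
    params.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserializer)
    params.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserializer)
    params.put(ConsumerConfig.GROUP_ID_CONFIG, "example-group")

    val strategies: Seq[ConsumerStrategy] = Seq(
      AssignStrategy(Array(new TopicPartition("t", 0))),
      SubscribeStrategy(Seq("t1", "t2")),
      SubscribePatternStrategy("t.*"))

    strategies.foreach { strategy =>
      val consumer = strategy.createConsumer(params)  // assigns or subscribes per the strategy
      println(strategy)                               // Assign[...], Subscribe[...], SubscribePattern[...]
      consumer.close()
    }
  }
}
```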
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeLimit.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeLimit.scala
new file mode 100644
index 0000000000000000000000000000000000000000..80a026f4f5d7308020762f54490a5331a2336a3e
--- /dev/null
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeLimit.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import org.apache.kafka.common.TopicPartition
+
+/**
+ * Objects that represent desired offset range limits for starting,
+ * ending, and specific offsets.
+ */
+private[kafka010] sealed trait KafkaOffsetRangeLimit
+
+/**
+ * Represents the desire to bind to the earliest offsets in Kafka
+ */
+private[kafka010] case object EarliestOffsetRangeLimit extends KafkaOffsetRangeLimit
+
+/**
+ * Represents the desire to bind to the latest offsets in Kafka
+ */
+private[kafka010] case object LatestOffsetRangeLimit extends KafkaOffsetRangeLimit
+
+/**
+ * Represents the desire to bind to specific offsets. An offset of -1 binds to the
+ * latest offset, and an offset of -2 binds to the earliest offset.
+ */
+private[kafka010] case class SpecificOffsetRangeLimit(
+    partitionOffsets: Map[TopicPartition, Long]) extends KafkaOffsetRangeLimit
+
+private[kafka010] object KafkaOffsetRangeLimit {
+  /**
+   * Used to denote offset range limits that are resolved via Kafka
+   */
+  val LATEST = -1L // indicates resolution to the latest offset
+  val EARLIEST = -2L // indicates resolution to the earliest offset
+}
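For reference, the -1/-2 sentinels line up with how the `startingOffsets` option is parsed later in this diff. A condensed, illustrative restatement (the helper name is hypothetical, and the ending-offsets option is assumed to be parsed the same way in the portion of the provider not shown here):

```scala
// Hypothetical helper mirroring the provider's match over the "startingOffsets" option value.
def toRangeLimit(value: Option[String], default: KafkaOffsetRangeLimit): KafkaOffsetRangeLimit =
  value.map(_.trim.toLowerCase) match {
    case Some("latest") => LatestOffsetRangeLimit
    case Some("earliest") => EarliestOffsetRangeLimit
    case Some(json) => SpecificOffsetRangeLimit(JsonUtils.partitionOffsets(json))
    case None => default
  }

// With SpecificOffsetRangeLimit, a JSON value such as {"t":{"0":-2,"1":-1}} asks for the
// earliest offset (-2) of partition 0 and the latest offset (-1) of partition 1; the
// sentinels are resolved by KafkaOffsetReader.fetchSpecificOffsets via seekToBeginning/seekToEnd.
```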
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala
new file mode 100644
index 0000000000000000000000000000000000000000..6b2fb3c112557de6448c8759bdf2f54b1d766e78
--- /dev/null
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala
@@ -0,0 +1,312 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import java.{util => ju}
+import java.util.concurrent.{Executors, ThreadFactory}
+
+import scala.collection.JavaConverters._
+import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.duration.Duration
+import scala.util.control.NonFatal
+
+import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, KafkaConsumer}
+import org.apache.kafka.common.TopicPartition
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.types._
+import org.apache.spark.util.{ThreadUtils, UninterruptibleThread}
+
+/**
+ * This class uses Kafka's own [[KafkaConsumer]] API to read offsets from Kafka.
+ * The [[ConsumerStrategy]] class defines which Kafka topics and partitions should be read
+ * by this reader. These strategies directly correspond to the different consumption options
+ * of the source. This class creates a configured [[KafkaConsumer]] that is used to query
+ * offsets on behalf of [[KafkaSource]] and [[KafkaRelation]]. See the docs on
+ * [[org.apache.spark.sql.kafka010.ConsumerStrategy]]
+ * for more details.
+ *
+ * Note: This class is not thread-safe.
+ */
+private[kafka010] class KafkaOffsetReader(
+    consumerStrategy: ConsumerStrategy,
+    driverKafkaParams: ju.Map[String, Object],
+    readerOptions: Map[String, String],
+    driverGroupIdPrefix: String) extends Logging {
+  /**
+   * Used to ensure that fetch operations execute in an [[UninterruptibleThread]].
+   */
+  val kafkaReaderThread = Executors.newSingleThreadExecutor(new ThreadFactory {
+    override def newThread(r: Runnable): Thread = {
+      val t = new UninterruptibleThread("Kafka Offset Reader") {
+        override def run(): Unit = {
+          r.run()
+        }
+      }
+      t.setDaemon(true)
+      t
+    }
+  })
+  val execContext = ExecutionContext.fromExecutorService(kafkaReaderThread)
+
+  /**
+   * A KafkaConsumer used in the driver to query the latest Kafka offsets. This only queries the
+   * offsets and never commits them.
+   */
+  protected var consumer = createConsumer()
+
+  private val maxOffsetFetchAttempts =
+    readerOptions.getOrElse("fetchOffset.numRetries", "3").toInt
+
+  private val offsetFetchAttemptIntervalMs =
+    readerOptions.getOrElse("fetchOffset.retryIntervalMs", "1000").toLong
+
+  private var groupId: String = null
+
+  private var nextId = 0
+
+  private def nextGroupId(): String = {
+    groupId = driverGroupIdPrefix + "-" + nextId
+    nextId += 1
+    groupId
+  }
+
+  override def toString(): String = consumerStrategy.toString
+
+  /**
+   * Closes the connection to Kafka, and cleans up state.
+   */
+  def close(): Unit = {
+    consumer.close()
+    kafkaReaderThread.shutdownNow()
+  }
+
+  /**
+   * @return The set of TopicPartitions currently assigned to this reader's consumer
+   */
+  def fetchTopicPartitions(): Set[TopicPartition] = runUninterruptibly {
+    assert(Thread.currentThread().isInstanceOf[UninterruptibleThread])
+    // Poll to get the latest assigned partitions
+    consumer.poll(0)
+    val partitions = consumer.assignment()
+    consumer.pause(partitions)
+    partitions.asScala.toSet
+  }
+
+  /**
+   * Resolves the specific offsets based on Kafka seek positions.
+   * This method resolves offset value -1 to the latest and -2 to the
+   * earliest Kafka seek position.
+   */
+  def fetchSpecificOffsets(
+      partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] =
+    runUninterruptibly {
+      withRetriesWithoutInterrupt {
+        // Poll to get the latest assigned partitions
+        consumer.poll(0)
+        val partitions = consumer.assignment()
+        consumer.pause(partitions)
+        assert(partitions.asScala == partitionOffsets.keySet,
+          "If startingOffsets contains specific offsets, you must specify all TopicPartitions.\n" +
+            "Use -1 for latest, -2 for earliest, if you don't care.\n" +
+            s"Specified: ${partitionOffsets.keySet} Assigned: ${partitions.asScala}")
+        logDebug(s"Partitions assigned to consumer: $partitions. Seeking to $partitionOffsets")
+
+        partitionOffsets.foreach {
+          case (tp, KafkaOffsetRangeLimit.LATEST) =>
+            consumer.seekToEnd(ju.Arrays.asList(tp))
+          case (tp, KafkaOffsetRangeLimit.EARLIEST) =>
+            consumer.seekToBeginning(ju.Arrays.asList(tp))
+          case (tp, off) => consumer.seek(tp, off)
+        }
+        partitionOffsets.map {
+          case (tp, _) => tp -> consumer.position(tp)
+        }
+      }
+    }
+
+  /**
+   * Fetch the earliest offsets for the topic partitions that are indicated
+   * in the [[ConsumerStrategy]].
+   */
+  def fetchEarliestOffsets(): Map[TopicPartition, Long] = runUninterruptibly {
+    withRetriesWithoutInterrupt {
+      // Poll to get the latest assigned partitions
+      consumer.poll(0)
+      val partitions = consumer.assignment()
+      consumer.pause(partitions)
+      logDebug(s"Partitions assigned to consumer: $partitions. Seeking to the beginning")
+
+      consumer.seekToBeginning(partitions)
+      val partitionOffsets = partitions.asScala.map(p => p -> consumer.position(p)).toMap
+      logDebug(s"Got earliest offsets for partitions: $partitionOffsets")
+      partitionOffsets
+    }
+  }
+
+  /**
+   * Fetch the latest offsets for the topic partitions that are indicated
+   * in the [[ConsumerStrategy]].
+   */
+  def fetchLatestOffsets(): Map[TopicPartition, Long] = runUninterruptibly {
+    withRetriesWithoutInterrupt {
+      // Poll to get the latest assigned partitions
+      consumer.poll(0)
+      val partitions = consumer.assignment()
+      consumer.pause(partitions)
+      logDebug(s"Partitions assigned to consumer: $partitions. Seeking to the end.")
+
+      consumer.seekToEnd(partitions)
+      val partitionOffsets = partitions.asScala.map(p => p -> consumer.position(p)).toMap
+      logDebug(s"Got latest offsets for partitions: $partitionOffsets")
+      partitionOffsets
+    }
+  }
+
+  /**
+   * Fetch the earliest offsets for specific topic partitions.
+   * The returned result may not contain some partitions if they have been deleted.
+   */
+  def fetchEarliestOffsets(
+      newPartitions: Seq[TopicPartition]): Map[TopicPartition, Long] = {
+    if (newPartitions.isEmpty) {
+      Map.empty[TopicPartition, Long]
+    } else {
+      runUninterruptibly {
+        withRetriesWithoutInterrupt {
+          // Poll to get the latest assigned partitions
+          consumer.poll(0)
+          val partitions = consumer.assignment()
+          consumer.pause(partitions)
+          logDebug(s"\tPartitions assigned to consumer: $partitions")
+
+          // Get the earliest offset of each partition
+          consumer.seekToBeginning(partitions)
+          val partitionOffsets = newPartitions.filter { p =>
+            // When deleting topics happen at the same time, some partitions may not be in
+            // `partitions`. So we need to ignore them
+            partitions.contains(p)
+          }.map(p => p -> consumer.position(p)).toMap
+          logDebug(s"Got earliest offsets for new partitions: $partitionOffsets")
+          partitionOffsets
+        }
+      }
+    }
+  }
+
+  /**
+   * This method ensures that the closure is called in an [[UninterruptibleThread]].
+   * This is required when communicating with the [[KafkaConsumer]]. In the case
+   * of streaming queries, we are already running in an [[UninterruptibleThread]],
+   * however for batch mode this is not the case.
+   */
+  private def runUninterruptibly[T](body: => T): T = {
+    if (!Thread.currentThread.isInstanceOf[UninterruptibleThread]) {
+      val future = Future {
+        body
+      }(execContext)
+      ThreadUtils.awaitResult(future, Duration.Inf)
+    } else {
+      body
+    }
+  }
+
+  /**
+   * Helper function that does multiple retries on a body of code that returns offsets.
+   * Retries are needed to handle transient failures. For example, race conditions between getting
+   * assignment and getting position while topics/partitions are deleted can cause NPEs.
+   *
+   * This method also makes sure `body` won't be interrupted to workaround a potential issue in
+   * `KafkaConsumer.poll`. (KAFKA-1894)
+   */
+  private def withRetriesWithoutInterrupt(
+      body: => Map[TopicPartition, Long]): Map[TopicPartition, Long] = {
+    // Make sure `KafkaConsumer.poll` won't be interrupted (KAFKA-1894)
+    assert(Thread.currentThread().isInstanceOf[UninterruptibleThread])
+
+    synchronized {
+      var result: Option[Map[TopicPartition, Long]] = None
+      var attempt = 1
+      var lastException: Throwable = null
+      while (result.isEmpty && attempt <= maxOffsetFetchAttempts
+        && !Thread.currentThread().isInterrupted) {
+        Thread.currentThread match {
+          case ut: UninterruptibleThread =>
+            // "KafkaConsumer.poll" may hang forever if the thread is interrupted (E.g., the query
+            // is stopped)(KAFKA-1894). Hence, we just make sure we don't interrupt it.
+            //
+            // If the broker addresses are wrong, or Kafka cluster is down, "KafkaConsumer.poll" may
+            // hang forever as well. This cannot be resolved in KafkaSource until Kafka fixes the
+            // issue.
+            ut.runUninterruptibly {
+              try {
+                result = Some(body)
+              } catch {
+                case NonFatal(e) =>
+                  lastException = e
+                  logWarning(s"Error in attempt $attempt getting Kafka offsets: ", e)
+                  attempt += 1
+                  Thread.sleep(offsetFetchAttemptIntervalMs)
+                  resetConsumer()
+              }
+            }
+          case _ =>
+            throw new IllegalStateException(
+              "Kafka APIs must be executed on a o.a.spark.util.UninterruptibleThread")
+        }
+      }
+      if (Thread.interrupted()) {
+        throw new InterruptedException()
+      }
+      if (result.isEmpty) {
+        assert(attempt > maxOffsetFetchAttempts)
+        assert(lastException != null)
+        throw lastException
+      }
+      result.get
+    }
+  }
+
+  /**
+   * Create a consumer using the newly generated group id. We always use a new consumer to
+   * avoid retrying Kafka errors with a broken consumer, which would likely fail again.
+   */
+  private def createConsumer(): Consumer[Array[Byte], Array[Byte]] = synchronized {
+    val newKafkaParams = new ju.HashMap[String, Object](driverKafkaParams)
+    newKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, nextGroupId())
+    consumerStrategy.createConsumer(newKafkaParams)
+  }
+
+  private def resetConsumer(): Unit = synchronized {
+    consumer.close()
+    consumer = createConsumer()
+  }
+}
+
+private[kafka010] object KafkaOffsetReader {
+
+  def kafkaSchema: StructType = StructType(Seq(
+    StructField("key", BinaryType),
+    StructField("value", BinaryType),
+    StructField("topic", StringType),
+    StructField("partition", IntegerType),
+    StructField("offset", LongType),
+    StructField("timestamp", TimestampType),
+    StructField("timestampType", IntegerType)
+  ))
+}
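A hedged driver-side usage sketch of the new reader. It assumes a reachable broker and an existing topic (both placeholders below), and that the snippet lives in the `org.apache.spark.sql.kafka010` package since the class is package-private; the patch itself builds the parameter map through `kafkaParamsForDriver` rather than by hand.

```scala
import java.{util => ju}

import org.apache.kafka.clients.consumer.ConsumerConfig
import org.apache.kafka.common.serialization.ByteArrayDeserializer

// Placeholder driver params; the patch builds these via kafkaParamsForDriver(...) instead.
val deserializer = classOf[ByteArrayDeserializer].getName
val driverParams = new ju.HashMap[String, Object]()
driverParams.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092")
driverParams.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserializer)
driverParams.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserializer)
driverParams.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest")

val reader = new KafkaOffsetReader(
  SubscribeStrategy(Seq("events")),                // placeholder topic
  driverParams,
  readerOptions = Map("fetchOffset.numRetries" -> "3"),
  driverGroupIdPrefix = "example-driver")
try {
  val partitions = reader.fetchTopicPartitions()   // TopicPartitions assigned by the strategy
  val earliest = reader.fetchEarliestOffsets()     // TopicPartition -> earliest offset
  val latest = reader.fetchLatestOffsets()         // TopicPartition -> latest offset
  println(s"partitions=$partitions earliest=$earliest latest=$latest")
} finally {
  reader.close()                                   // closes the consumer and the reader thread
}
```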
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala
new file mode 100644
index 0000000000000000000000000000000000000000..f180bbad6e36305bbd23a7a195aafebff4585e70
--- /dev/null
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import java.{util => ju}
+
+import org.apache.kafka.common.TopicPartition
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.sources.{BaseRelation, TableScan}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.types.UTF8String
+
+
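+/**
+ * A [[org.apache.spark.sql.sources.BaseRelation]] that reads a fixed range of Kafka offsets as
+ * a batch. The starting and ending [[KafkaOffsetRangeLimit]]s are resolved to concrete
+ * partition offsets through the [[KafkaOffsetReader]] when the scan is built.
+ */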
+private[kafka010] class KafkaRelation(
+    override val sqlContext: SQLContext,
+    kafkaReader: KafkaOffsetReader,
+    executorKafkaParams: ju.Map[String, Object],
+    sourceOptions: Map[String, String],
+    failOnDataLoss: Boolean,
+    startingOffsets: KafkaOffsetRangeLimit,
+    endingOffsets: KafkaOffsetRangeLimit)
+    extends BaseRelation with TableScan with Logging {
+  assert(startingOffsets != LatestOffsetRangeLimit,
+    "Starting offset not allowed to be set to latest offsets.")
+  assert(endingOffsets != EarliestOffsetRangeLimit,
+    "Ending offset not allowed to be set to earliest offsets.")
+
+  private val pollTimeoutMs = sourceOptions.getOrElse(
+    "kafkaConsumer.pollTimeoutMs",
+    sqlContext.sparkContext.conf.getTimeAsMs("spark.network.timeout", "120s").toString
+  ).toLong
+
+  override def schema: StructType = KafkaOffsetReader.kafkaSchema
+
+  override def buildScan(): RDD[Row] = {
+    // Leverage the KafkaReader to obtain the relevant partition offsets
+    val fromPartitionOffsets = getPartitionOffsets(startingOffsets)
+    val untilPartitionOffsets = getPartitionOffsets(endingOffsets)
+    // Obtain topicPartitions in both from and until partition offset, ignoring
+    // topic partitions that were added and/or deleted between the two above calls.
+    if (fromPartitionOffsets.keySet != untilPartitionOffsets.keySet) {
+      implicit val topicOrdering: Ordering[TopicPartition] = Ordering.by(t => t.topic())
+      val fromTopics = fromPartitionOffsets.keySet.toList.sorted.mkString(",")
+      val untilTopics = untilPartitionOffsets.keySet.toList.sorted.mkString(",")
+      throw new IllegalStateException("different topic partitions " +
+        s"for starting offsets topics[${fromTopics}] and " +
+        s"ending offsets topics[${untilTopics}]")
+    }
+
+    // Calculate offset ranges
+    val offsetRanges = untilPartitionOffsets.keySet.map { tp =>
+      val fromOffset = fromPartitionOffsets.get(tp).getOrElse {
+          // This should not happen since topicPartitions contains all partitions not in
+          // fromPartitionOffsets
+          throw new IllegalStateException(s"$tp doesn't have a from offset")
+      }
+      val untilOffset = untilPartitionOffsets(tp)
+      KafkaSourceRDDOffsetRange(tp, fromOffset, untilOffset, None)
+    }.toArray
+
+    logInfo("GetBatch generating RDD of offset range: " +
+      offsetRanges.sortBy(_.topicPartition.toString).mkString(", "))
+
+    // Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays.
+    val rdd = new KafkaSourceRDD(
+      sqlContext.sparkContext, executorKafkaParams, offsetRanges,
+      pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer = false).map { cr =>
+      InternalRow(
+        cr.key,
+        cr.value,
+        UTF8String.fromString(cr.topic),
+        cr.partition,
+        cr.offset,
+        DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(cr.timestamp)),
+        cr.timestampType.id)
+    }
+    sqlContext.internalCreateDataFrame(rdd, schema).rdd
+  }
+
+  private def getPartitionOffsets(
+      kafkaOffsets: KafkaOffsetRangeLimit): Map[TopicPartition, Long] = {
+    def validateTopicPartitions(partitions: Set[TopicPartition],
+      partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = {
+      assert(partitions == partitionOffsets.keySet,
+        "If startingOffsets contains specific offsets, you must specify all TopicPartitions.\n" +
+          "Use -1 for latest, -2 for earliest, if you don't care.\n" +
+          s"Specified: ${partitionOffsets.keySet} Assigned: ${partitions}")
+      logDebug(s"Partitions assigned to consumer: $partitions. Seeking to $partitionOffsets")
+      partitionOffsets
+    }
+    val partitions = kafkaReader.fetchTopicPartitions()
+    // Obtain TopicPartition offsets with late binding support
+    kafkaOffsets match {
+      case EarliestOffsetRangeLimit => partitions.map {
+        case tp => tp -> KafkaOffsetRangeLimit.EARLIEST
+      }.toMap
+      case LatestOffsetRangeLimit => partitions.map {
+        case tp => tp -> KafkaOffsetRangeLimit.LATEST
+      }.toMap
+      case SpecificOffsetRangeLimit(partitionOffsets) =>
+        validateTopicPartitions(partitions, partitionOffsets)
+    }
+  }
+}
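For orientation, the user-facing path that ends up in `KafkaRelation.buildScan` looks roughly like the sketch below. It assumes an existing `spark: SparkSession`, a reachable broker and topic (placeholders), and the `subscribe`/`endingOffsets` option names, which are parsed by provider code that falls partly outside the excerpt shown here.

```scala
// Hypothetical batch read that reaches KafkaRelation.buildScan via the provider's
// createRelation; the option names other than "startingOffsets" are assumptions here.
val df = spark.read
  .format("kafka")
  .option("kafka.bootstrap.servers", "localhost:9092")  // placeholder broker
  .option("subscribe", "events")                        // placeholder topic
  .option("startingOffsets", "earliest")
  .option("endingOffsets", "latest")
  .load()

// The fixed schema comes from KafkaOffsetReader.kafkaSchema: key, value, topic,
// partition, offset, timestamp, timestampType.
df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "topic", "partition", "offset")
  .show()
```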
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
index 43b8d9d6d7eef687a9f4c19a69b20d760ccc29c6..02b23111af788d966daa655b88e3c7accc2c63bf 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
@@ -21,11 +21,6 @@ import java.{util => ju}
 import java.io._
 import java.nio.charset.StandardCharsets
 
-import scala.collection.JavaConverters._
-import scala.util.control.NonFatal
-
-import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, KafkaConsumer, OffsetOutOfRangeException}
-import org.apache.kafka.clients.consumer.internals.NoOpConsumerRebalanceListener
 import org.apache.kafka.common.TopicPartition
 
 import org.apache.spark.SparkContext
@@ -38,11 +33,9 @@ import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.kafka010.KafkaSource._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
-import org.apache.spark.util.UninterruptibleThread
 
 /**
- * A [[Source]] that uses Kafka's own [[KafkaConsumer]] API to reads data from Kafka. The design
- * for this source is as follows.
+ * A [[Source]] that reads data from Kafka using the following design.
  *
  * - The [[KafkaSourceOffset]] is the custom [[Offset]] defined for this source that contains
  *   a map of TopicPartition -> offset. Note that this offset is 1 + (available offset). For
@@ -50,20 +43,14 @@ import org.apache.spark.util.UninterruptibleThread
  *   KafkaSourceOffset will contain TopicPartition("t", 2) -> 6. This is done keep it consistent
  *   with the semantics of `KafkaConsumer.position()`.
  *
- * - The [[ConsumerStrategy]] class defines which Kafka topics and partitions should be read
- *   by this source. These strategies directly correspond to the different consumption options
- *   in . This class is designed to return a configured [[KafkaConsumer]] that is used by the
- *   [[KafkaSource]] to query for the offsets. See the docs on
- *   [[org.apache.spark.sql.kafka010.KafkaSource.ConsumerStrategy]] for more details.
- *
  * - The [[KafkaSource]] written to do the following.
  *
- *  - As soon as the source is created, the pre-configured KafkaConsumer returned by the
- *    [[ConsumerStrategy]] is used to query the initial offsets that this source should
- *    start reading from. This used to create the first batch.
+ *  - As soon as the source is created, the pre-configured [[KafkaOffsetReader]]
+ *    is used to query the initial offsets that this source should
+ *    start reading from. This is used to create the first batch.
  *
- *   - `getOffset()` uses the KafkaConsumer to query the latest available offsets, which are
- *     returned as a [[KafkaSourceOffset]].
+ *   - `getOffset()` uses the [[KafkaOffsetReader]] to query the latest
+ *      available offsets, which are returned as a [[KafkaSourceOffset]].
  *
  *   - `getBatch()` returns a DF that reads from the 'start offset' until the 'end offset' in
  *     for each partition. The end offset is excluded to be consistent with the semantics of
@@ -82,15 +69,13 @@ import org.apache.spark.util.UninterruptibleThread
  * and not use wrong broker addresses.
  */
 private[kafka010] class KafkaSource(
-    sqlContext: SQLContext,
-    consumerStrategy: ConsumerStrategy,
-    driverKafkaParams: ju.Map[String, Object],
-    executorKafkaParams: ju.Map[String, Object],
-    sourceOptions: Map[String, String],
-    metadataPath: String,
-    startingOffsets: StartingOffsets,
-    failOnDataLoss: Boolean,
-    driverGroupIdPrefix: String)
+    sqlContext: SQLContext,
+    kafkaReader: KafkaOffsetReader,
+    executorKafkaParams: ju.Map[String, Object],
+    sourceOptions: Map[String, String],
+    metadataPath: String,
+    startingOffsets: KafkaOffsetRangeLimit,
+    failOnDataLoss: Boolean)
   extends Source with Logging {
 
   private val sc = sqlContext.sparkContext
@@ -100,41 +85,9 @@ private[kafka010] class KafkaSource(
     sc.conf.getTimeAsMs("spark.network.timeout", "120s").toString
   ).toLong
 
-  private val maxOffsetFetchAttempts =
-    sourceOptions.getOrElse("fetchOffset.numRetries", "3").toInt
-
-  private val offsetFetchAttemptIntervalMs =
-    sourceOptions.getOrElse("fetchOffset.retryIntervalMs", "1000").toLong
-
   private val maxOffsetsPerTrigger =
     sourceOptions.get("maxOffsetsPerTrigger").map(_.toLong)
 
-  private var groupId: String = null
-
-  private var nextId = 0
-
-  private def nextGroupId(): String = {
-    groupId = driverGroupIdPrefix + "-" + nextId
-    nextId += 1
-    groupId
-  }
-
-  /**
-   * A KafkaConsumer used in the driver to query the latest Kafka offsets. This only queries the
-   * offsets and never commits them.
-   */
-  private var consumer: Consumer[Array[Byte], Array[Byte]] = createConsumer()
-
-  /**
-   * Create a consumer using the new generated group id. We always use a new consumer to avoid
-   * just using a broken consumer to retry on Kafka errors, which likely will fail again.
-   */
-  private def createConsumer(): Consumer[Array[Byte], Array[Byte]] = synchronized {
-    val newKafkaParams = new ju.HashMap[String, Object](driverKafkaParams)
-    newKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, nextGroupId())
-    consumerStrategy.createConsumer(newKafkaParams)
-  }
-
   /**
    * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only
    * called in StreamExecutionThread. Otherwise, interrupting a thread while running
@@ -159,9 +112,9 @@ private[kafka010] class KafkaSource(
 
     metadataLog.get(0).getOrElse {
       val offsets = startingOffsets match {
-        case EarliestOffsets => KafkaSourceOffset(fetchEarliestOffsets())
-        case LatestOffsets => KafkaSourceOffset(fetchLatestOffsets())
-        case SpecificOffsets(p) => KafkaSourceOffset(fetchSpecificStartingOffsets(p))
+        case EarliestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchEarliestOffsets())
+        case LatestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchLatestOffsets())
+        case SpecificOffsetRangeLimit(p) => fetchAndVerify(p)
       }
       metadataLog.add(0, offsets)
       logInfo(s"Initial offsets: $offsets")
@@ -169,16 +122,31 @@ private[kafka010] class KafkaSource(
     }.partitionToOffsets
   }
 
+  private def fetchAndVerify(specificOffsets: Map[TopicPartition, Long]): KafkaSourceOffset = {
+    val result = kafkaReader.fetchSpecificOffsets(specificOffsets)
+    specificOffsets.foreach {
+      case (tp, off) if off != KafkaOffsetRangeLimit.LATEST &&
+          off != KafkaOffsetRangeLimit.EARLIEST =>
+        if (result(tp) != off) {
+          reportDataLoss(
+            s"startingOffsets for $tp was $off but consumer reset to ${result(tp)}")
+        }
+      case _ =>
+        // no real way to check that beginning or end is reasonable
+    }
+    KafkaSourceOffset(result)
+  }
+
   private var currentPartitionOffsets: Option[Map[TopicPartition, Long]] = None
 
-  override def schema: StructType = KafkaSource.kafkaSchema
+  override def schema: StructType = KafkaOffsetReader.kafkaSchema
 
   /** Returns the maximum available offset for this source. */
   override def getOffset: Option[Offset] = {
     // Make sure initialPartitionOffsets is initialized
     initialPartitionOffsets
 
-    val latest = fetchLatestOffsets()
+    val latest = kafkaReader.fetchLatestOffsets()
     val offsets = maxOffsetsPerTrigger match {
       case None =>
         latest
@@ -193,17 +161,12 @@ private[kafka010] class KafkaSource(
     Some(KafkaSourceOffset(offsets))
   }
 
-  private def resetConsumer(): Unit = synchronized {
-    consumer.close()
-    consumer = createConsumer()
-  }
-
   /** Proportionally distribute limit number of offsets among topicpartitions */
   private def rateLimit(
       limit: Long,
       from: Map[TopicPartition, Long],
       until: Map[TopicPartition, Long]): Map[TopicPartition, Long] = {
-    val fromNew = fetchNewPartitionEarliestOffsets(until.keySet.diff(from.keySet).toSeq)
+    val fromNew = kafkaReader.fetchEarliestOffsets(until.keySet.diff(from.keySet).toSeq)
     val sizes = until.flatMap {
       case (tp, end) =>
         // If begin isn't defined, something's wrong, but let alert logic in getBatch handle it
@@ -253,7 +216,7 @@ private[kafka010] class KafkaSource(
 
     // Find the new partitions, and get their earliest offsets
     val newPartitions = untilPartitionOffsets.keySet.diff(fromPartitionOffsets.keySet)
-    val newPartitionOffsets = fetchNewPartitionEarliestOffsets(newPartitions.toSeq)
+    val newPartitionOffsets = kafkaReader.fetchEarliestOffsets(newPartitions.toSeq)
     if (newPartitionOffsets.keySet != newPartitions) {
       // We cannot get from offsets for some partitions. It means they got deleted.
       val deletedPartitions = newPartitions.diff(newPartitionOffsets.keySet)
@@ -311,7 +274,8 @@ private[kafka010] class KafkaSource(
 
     // Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays.
     val rdd = new KafkaSourceRDD(
-      sc, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss).map { cr =>
+      sc, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss,
+      reuseKafkaConsumer = true).map { cr =>
       InternalRow(
         cr.key,
         cr.value,
@@ -335,163 +299,10 @@ private[kafka010] class KafkaSource(
 
   /** Stop this source and free any resources it has allocated. */
   override def stop(): Unit = synchronized {
-    consumer.close()
+    kafkaReader.close()
   }
 
-  override def toString(): String = s"KafkaSource[$consumerStrategy]"
-
-  /**
-   * Set consumer position to specified offsets, making sure all assignments are set.
-   */
-  private def fetchSpecificStartingOffsets(
-      partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = {
-    val result = withRetriesWithoutInterrupt {
-      // Poll to get the latest assigned partitions
-      consumer.poll(0)
-      val partitions = consumer.assignment()
-      consumer.pause(partitions)
-      assert(partitions.asScala == partitionOffsets.keySet,
-        "If startingOffsets contains specific offsets, you must specify all TopicPartitions.\n" +
-          "Use -1 for latest, -2 for earliest, if you don't care.\n" +
-          s"Specified: ${partitionOffsets.keySet} Assigned: ${partitions.asScala}")
-      logDebug(s"Partitions assigned to consumer: $partitions. Seeking to $partitionOffsets")
-
-      partitionOffsets.foreach {
-        case (tp, -1) => consumer.seekToEnd(ju.Arrays.asList(tp))
-        case (tp, -2) => consumer.seekToBeginning(ju.Arrays.asList(tp))
-        case (tp, off) => consumer.seek(tp, off)
-      }
-      partitionOffsets.map {
-        case (tp, _) => tp -> consumer.position(tp)
-      }
-    }
-    partitionOffsets.foreach {
-      case (tp, off) if off != -1 && off != -2 =>
-        if (result(tp) != off) {
-          reportDataLoss(
-            s"startingOffsets for $tp was $off but consumer reset to ${result(tp)}")
-        }
-      case _ =>
-        // no real way to check that beginning or end is reasonable
-    }
-    result
-  }
-
-  /**
-   * Fetch the earliest offsets of partitions.
-   */
-  private def fetchEarliestOffsets(): Map[TopicPartition, Long] = withRetriesWithoutInterrupt {
-    // Poll to get the latest assigned partitions
-    consumer.poll(0)
-    val partitions = consumer.assignment()
-    consumer.pause(partitions)
-    logDebug(s"Partitions assigned to consumer: $partitions. Seeking to the beginning")
-
-    consumer.seekToBeginning(partitions)
-    val partitionOffsets = partitions.asScala.map(p => p -> consumer.position(p)).toMap
-    logDebug(s"Got earliest offsets for partition : $partitionOffsets")
-    partitionOffsets
-  }
-
-  /**
-   * Fetch the latest offset of partitions.
-   */
-  private def fetchLatestOffsets(): Map[TopicPartition, Long] = withRetriesWithoutInterrupt {
-    // Poll to get the latest assigned partitions
-    consumer.poll(0)
-    val partitions = consumer.assignment()
-    consumer.pause(partitions)
-    logDebug(s"Partitions assigned to consumer: $partitions. Seeking to the end.")
-
-    consumer.seekToEnd(partitions)
-    val partitionOffsets = partitions.asScala.map(p => p -> consumer.position(p)).toMap
-    logDebug(s"Got latest offsets for partition : $partitionOffsets")
-    partitionOffsets
-  }
-
-  /**
-   * Fetch the earliest offsets for newly discovered partitions. The return result may not contain
-   * some partitions if they are deleted.
-   */
-  private def fetchNewPartitionEarliestOffsets(
-      newPartitions: Seq[TopicPartition]): Map[TopicPartition, Long] =
-    if (newPartitions.isEmpty) {
-      Map.empty[TopicPartition, Long]
-    } else {
-      withRetriesWithoutInterrupt {
-        // Poll to get the latest assigned partitions
-        consumer.poll(0)
-        val partitions = consumer.assignment()
-        consumer.pause(partitions)
-        logDebug(s"\tPartitions assigned to consumer: $partitions")
-
-        // Get the earliest offset of each partition
-        consumer.seekToBeginning(partitions)
-        val partitionOffsets = newPartitions.filter { p =>
-          // When deleting topics happen at the same time, some partitions may not be in
-          // `partitions`. So we need to ignore them
-          partitions.contains(p)
-        }.map(p => p -> consumer.position(p)).toMap
-        logDebug(s"Got earliest offsets for new partitions: $partitionOffsets")
-        partitionOffsets
-      }
-    }
-
-  /**
-   * Helper function that does multiple retries on the a body of code that returns offsets.
-   * Retries are needed to handle transient failures. For e.g. race conditions between getting
-   * assignment and getting position while topics/partitions are deleted can cause NPEs.
-   *
-   * This method also makes sure `body` won't be interrupted to workaround a potential issue in
-   * `KafkaConsumer.poll`. (KAFKA-1894)
-   */
-  private def withRetriesWithoutInterrupt(
-      body: => Map[TopicPartition, Long]): Map[TopicPartition, Long] = {
-    // Make sure `KafkaConsumer.poll` won't be interrupted (KAFKA-1894)
-    assert(Thread.currentThread().isInstanceOf[StreamExecutionThread])
-
-    synchronized {
-      var result: Option[Map[TopicPartition, Long]] = None
-      var attempt = 1
-      var lastException: Throwable = null
-      while (result.isEmpty && attempt <= maxOffsetFetchAttempts
-        && !Thread.currentThread().isInterrupted) {
-        Thread.currentThread match {
-          case ut: UninterruptibleThread =>
-            // "KafkaConsumer.poll" may hang forever if the thread is interrupted (E.g., the query
-            // is stopped)(KAFKA-1894). Hence, we just make sure we don't interrupt it.
-            //
-            // If the broker addresses are wrong, or Kafka cluster is down, "KafkaConsumer.poll" may
-            // hang forever as well. This cannot be resolved in KafkaSource until Kafka fixes the
-            // issue.
-            ut.runUninterruptibly {
-              try {
-                result = Some(body)
-              } catch {
-                case NonFatal(e) =>
-                  lastException = e
-                  logWarning(s"Error in attempt $attempt getting Kafka offsets: ", e)
-                  attempt += 1
-                  Thread.sleep(offsetFetchAttemptIntervalMs)
-                  resetConsumer()
-              }
-            }
-          case _ =>
-            throw new IllegalStateException(
-              "Kafka APIs must be executed on a o.a.spark.util.UninterruptibleThread")
-        }
-      }
-      if (Thread.interrupted()) {
-        throw new InterruptedException()
-      }
-      if (result.isEmpty) {
-        assert(attempt > maxOffsetFetchAttempts)
-        assert(lastException != null)
-        throw lastException
-      }
-      result.get
-    }
-  }
+  override def toString(): String = s"KafkaSource[$kafkaReader]"
 
   /**
    * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`.
@@ -506,10 +317,8 @@ private[kafka010] class KafkaSource(
   }
 }
 
-
 /** Companion object for the [[KafkaSource]]. */
 private[kafka010] object KafkaSource {
-
   val INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE =
     """
       |Some data may have been lost because they are not available in Kafka any more; either the
@@ -526,57 +335,7 @@ private[kafka010] object KafkaSource {
       | source option "failOnDataLoss" to "false".
     """.stripMargin
 
-  def kafkaSchema: StructType = StructType(Seq(
-    StructField("key", BinaryType),
-    StructField("value", BinaryType),
-    StructField("topic", StringType),
-    StructField("partition", IntegerType),
-    StructField("offset", LongType),
-    StructField("timestamp", TimestampType),
-    StructField("timestampType", IntegerType)
-  ))
-
-  sealed trait ConsumerStrategy {
-    def createConsumer(kafkaParams: ju.Map[String, Object]): Consumer[Array[Byte], Array[Byte]]
-  }
-
-  case class AssignStrategy(partitions: Array[TopicPartition]) extends ConsumerStrategy {
-    override def createConsumer(
-        kafkaParams: ju.Map[String, Object]): Consumer[Array[Byte], Array[Byte]] = {
-      val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams)
-      consumer.assign(ju.Arrays.asList(partitions: _*))
-      consumer
-    }
-
-    override def toString: String = s"Assign[${partitions.mkString(", ")}]"
-  }
-
-  case class SubscribeStrategy(topics: Seq[String]) extends ConsumerStrategy {
-    override def createConsumer(
-        kafkaParams: ju.Map[String, Object]): Consumer[Array[Byte], Array[Byte]] = {
-      val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams)
-      consumer.subscribe(topics.asJava)
-      consumer
-    }
-
-    override def toString: String = s"Subscribe[${topics.mkString(", ")}]"
-  }
-
-  case class SubscribePatternStrategy(topicPattern: String)
-    extends ConsumerStrategy {
-    override def createConsumer(
-        kafkaParams: ju.Map[String, Object]): Consumer[Array[Byte], Array[Byte]] = {
-      val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams)
-      consumer.subscribe(
-        ju.regex.Pattern.compile(topicPattern),
-        new NoOpConsumerRebalanceListener())
-      consumer
-    }
-
-    override def toString: String = s"SubscribePattern[$topicPattern]"
-  }
-
-  private def getSortedExecutorList(sc: SparkContext): Array[String] = {
+  def getSortedExecutorList(sc: SparkContext): Array[String] = {
     val bm = sc.env.blockManager
     bm.master.getPeers(bm.blockManagerId).toArray
       .map(x => ExecutorCacheTaskLocation(x.host, x.executorId))
@@ -588,5 +347,5 @@ private[kafka010] object KafkaSource {
     if (a.host == b.host) { a.executorId > b.executorId } else { a.host > b.host }
   }
 
-  private def floorMod(a: Long, b: Int): Int = ((a % b).toInt + b) % b
+  def floorMod(a: Long, b: Int): Int = ((a % b).toInt + b) % b
 }
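A small note on the helpers this hunk makes non-private: `floorMod` is a non-negative modulus, so a possibly negative hash can be mapped onto an index into the array returned by `getSortedExecutorList` without going out of bounds. A quick check:

```scala
// floorMod(a, b) == ((a % b).toInt + b) % b always lands in [0, b), unlike plain %:
KafkaSource.floorMod(7L, 3)   // 1
KafkaSource.floorMod(-7L, 3)  // 2  (whereas -7 % 3 == -1)
KafkaSource.floorMod(0L, 3)   // 0
```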
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
index aa01238f91247e173521c1c9a47b4cddc953279a..597c99e093a4280528c1da00187656c301e9393e 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
@@ -28,8 +28,7 @@ import org.apache.kafka.common.serialization.ByteArrayDeserializer
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.execution.streaming.Source
-import org.apache.spark.sql.kafka010.KafkaSource._
-import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
+import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -37,11 +36,12 @@ import org.apache.spark.sql.types.StructType
  * IllegalArgumentException when the Kafka Dataset is created, so that it can catch
  * missing options even before the query is started.
  */
-private[kafka010] class KafkaSourceProvider extends StreamSourceProvider
-  with DataSourceRegister with Logging {
-
+private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSourceProvider
+  with RelationProvider with Logging {
   import KafkaSourceProvider._
 
+  override def shortName(): String = "kafka"
+
   /**
    * Returns the name and schema of the source. In addition, it also verifies whether the options
    * are correct and sufficient to create the [[KafkaSource]] when the query is started.
@@ -51,9 +51,9 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider
       schema: Option[StructType],
       providerName: String,
       parameters: Map[String, String]): (String, StructType) = {
+    validateStreamOptions(parameters)
     require(schema.isEmpty, "Kafka source has a fixed schema and cannot be set with a custom one")
-    validateOptions(parameters)
-    ("kafka", KafkaSource.kafkaSchema)
+    (shortName(), KafkaOffsetReader.kafkaSchema)
   }
 
   override def createSource(
@@ -62,7 +62,12 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider
       schema: Option[StructType],
       providerName: String,
       parameters: Map[String, String]): Source = {
-      validateOptions(parameters)
+    validateStreamOptions(parameters)
+    // Each running query should use its own group id. Otherwise, the query may be only assigned
+    // partial data since Kafka will assign partitions to multiple consumers having the same group
+    // id. Hence, we should generate a unique id for each query.
+    val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}"
+
     val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) }
     val specifiedKafkaParams =
       parameters
@@ -71,94 +76,145 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider
         .map { k => k.drop(6).toString -> parameters(k) }
         .toMap
 
-    val deserClassName = classOf[ByteArrayDeserializer].getName
-    // Each running query should use its own group id. Otherwise, the query may be only assigned
-    // partial data since Kafka will assign partitions to multiple consumers having the same group
-    // id. Hence, we should generate a unique id for each query.
-    val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}"
-
-    val startingOffsets =
+    val startingStreamOffsets =
       caseInsensitiveParams.get(STARTING_OFFSETS_OPTION_KEY).map(_.trim.toLowerCase) match {
-        case Some("latest") => LatestOffsets
-        case Some("earliest") => EarliestOffsets
-        case Some(json) => SpecificOffsets(JsonUtils.partitionOffsets(json))
-        case None => LatestOffsets
+        case Some("latest") => LatestOffsetRangeLimit
+        case Some("earliest") => EarliestOffsetRangeLimit
+        case Some(json) => SpecificOffsetRangeLimit(JsonUtils.partitionOffsets(json))
+        case None => LatestOffsetRangeLimit
       }
 
-    val kafkaParamsForDriver =
-      ConfigUpdater("source", specifiedKafkaParams)
-        .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName)
-        .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName)
-
-        // Set to "earliest" to avoid exceptions. However, KafkaSource will fetch the initial
-        // offsets by itself instead of counting on KafkaConsumer.
-        .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest")
-
-        // So that consumers in the driver does not commit offsets unnecessarily
-        .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false")
-
-        // So that the driver does not pull too much data
-        .set(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, new java.lang.Integer(1))
-
-        // If buffer config is not set, set it to reasonable value to work around
-        // buffer issues (see KAFKA-3135)
-        .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer)
-        .build()
-
-    val kafkaParamsForExecutors =
-      ConfigUpdater("executor", specifiedKafkaParams)
-        .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName)
-        .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName)
-
-        // Make sure executors do only what the driver tells them.
-        .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "none")
+    val kafkaOffsetReader = new KafkaOffsetReader(
+      strategy(caseInsensitiveParams),
+      kafkaParamsForDriver(specifiedKafkaParams),
+      parameters,
+      driverGroupIdPrefix = s"$uniqueGroupId-driver")
 
-        // So that consumers in executors do not mess with any existing group id
-        .set(ConsumerConfig.GROUP_ID_CONFIG, s"$uniqueGroupId-executor")
+    new KafkaSource(
+      sqlContext,
+      kafkaOffsetReader,
+      kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId),
+      parameters,
+      metadataPath,
+      startingStreamOffsets,
+      failOnDataLoss(caseInsensitiveParams))
+  }
 
-        // So that consumers in executors does not commit offsets unnecessarily
-        .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false")
+  /**
+   * Returns a new base relation with the given parameters.
+   *
+   * @note The parameters' keywords are case insensitive and this insensitivity is enforced
+   *       by the Map that is passed to the function.
+   */
+  override def createRelation(
+      sqlContext: SQLContext,
+      parameters: Map[String, String]): BaseRelation = {
+    validateBatchOptions(parameters)
+    // Each running query should use its own group id. Otherwise, the query may only be assigned
+    // partial data since Kafka will assign partitions to multiple consumers having the same group
+    // id. Hence, we should generate a unique id for each query.
+    val uniqueGroupId = s"spark-kafka-relation-${UUID.randomUUID}"
+    val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) }
+    val specifiedKafkaParams =
+      parameters
+        .keySet
+        .filter(_.toLowerCase.startsWith("kafka."))
+        .map { k => k.drop(6).toString -> parameters(k) }
+        .toMap
 
-        // If buffer config is not set, set it to reasonable value to work around
-        // buffer issues (see KAFKA-3135)
-        .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer)
-        .build()
+    val startingRelationOffsets =
+      caseInsensitiveParams.get(STARTING_OFFSETS_OPTION_KEY).map(_.trim.toLowerCase) match {
+        case Some("earliest") => EarliestOffsetRangeLimit
+        case Some(json) => SpecificOffsetRangeLimit(JsonUtils.partitionOffsets(json))
+        case None => EarliestOffsetRangeLimit
+      }
 
-    val strategy = caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match {
-      case ("assign", value) =>
-        AssignStrategy(JsonUtils.partitions(value))
-      case ("subscribe", value) =>
-        SubscribeStrategy(value.split(",").map(_.trim()).filter(_.nonEmpty))
-      case ("subscribepattern", value) =>
-        SubscribePatternStrategy(value.trim())
-      case _ =>
-        // Should never reach here as we are already matching on
-        // matched strategy names
-        throw new IllegalArgumentException("Unknown option")
-    }
+    val endingRelationOffsets =
+      caseInsensitiveParams.get(ENDING_OFFSETS_OPTION_KEY).map(_.trim.toLowerCase) match {
+        case Some("latest") => LatestOffsetRangeLimit
+        case Some(json) => SpecificOffsetRangeLimit(JsonUtils.partitionOffsets(json))
+        case None => LatestOffsetRangeLimit
+      }
 
-    val failOnDataLoss =
-      caseInsensitiveParams.getOrElse(FAIL_ON_DATA_LOSS_OPTION_KEY, "true").toBoolean
+    val kafkaOffsetReader = new KafkaOffsetReader(
+      strategy(caseInsensitiveParams),
+      kafkaParamsForDriver(specifiedKafkaParams),
+      parameters,
+      driverGroupIdPrefix = s"$uniqueGroupId-driver")
 
-    new KafkaSource(
+    new KafkaRelation(
       sqlContext,
-      strategy,
-      kafkaParamsForDriver,
-      kafkaParamsForExecutors,
+      kafkaOffsetReader,
+      kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId),
       parameters,
-      metadataPath,
-      startingOffsets,
-      failOnDataLoss,
-      driverGroupIdPrefix = s"$uniqueGroupId-driver")
+      failOnDataLoss(caseInsensitiveParams),
+      startingRelationOffsets,
+      endingRelationOffsets)
   }
 
-  private def validateOptions(parameters: Map[String, String]): Unit = {
+  private def kafkaParamsForDriver(specifiedKafkaParams: Map[String, String]) =
+    ConfigUpdater("source", specifiedKafkaParams)
+      .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName)
+      .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName)
+
+      // Set to "earliest" to avoid exceptions. However, KafkaSource will fetch the initial
+      // offsets by itself instead of counting on KafkaConsumer.
+      .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest")
+
+      // So that consumers in the driver do not commit offsets unnecessarily
+      .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false")
+
+      // So that the driver does not pull too much data
+      .set(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, new java.lang.Integer(1))
+
+      // If buffer config is not set, set it to reasonable value to work around
+      // buffer issues (see KAFKA-3135)
+      .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer)
+      .build()
+
+  private def kafkaParamsForExecutors(
+      specifiedKafkaParams: Map[String, String], uniqueGroupId: String) =
+    ConfigUpdater("executor", specifiedKafkaParams)
+      .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName)
+      .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName)
+
+      // Make sure executors do only what the driver tells them.
+      .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "none")
+
+      // So that consumers in executors do not mess with any existing group id
+      .set(ConsumerConfig.GROUP_ID_CONFIG, s"$uniqueGroupId-executor")
+
+      // So that consumers in executors do not commit offsets unnecessarily
+      .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false")
+
+      // If buffer config is not set, set it to reasonable value to work around
+      // buffer issues (see KAFKA-3135)
+      .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer)
+      .build()
+
+  private def strategy(caseInsensitiveParams: Map[String, String]) =
+      caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match {
+    case ("assign", value) =>
+      AssignStrategy(JsonUtils.partitions(value))
+    case ("subscribe", value) =>
+      SubscribeStrategy(value.split(",").map(_.trim()).filter(_.nonEmpty))
+    case ("subscribepattern", value) =>
+      SubscribePatternStrategy(value.trim())
+    case _ =>
+      // Should never reach here as we are already matching on
+      // the valid strategy names above
+      throw new IllegalArgumentException("Unknown option")
+  }
 
-    // Validate source options
+  private def failOnDataLoss(caseInsensitiveParams: Map[String, String]) =
+    caseInsensitiveParams.getOrElse(FAIL_ON_DATA_LOSS_OPTION_KEY, "true").toBoolean
 
+  private def validateGeneralOptions(parameters: Map[String, String]): Unit = {
+    // Validate source options
     val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) }
     val specifiedStrategies =
       caseInsensitiveParams.filter { case (k, _) => STRATEGY_OPTION_KEYS.contains(k) }.toSeq
+
     if (specifiedStrategies.isEmpty) {
       throw new IllegalArgumentException(
         "One of the following options must be specified for Kafka source: "
@@ -251,7 +307,52 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider
     }
   }
 
-  override def shortName(): String = "kafka"
+  private def validateStreamOptions(caseInsensitiveParams: Map[String, String]) = {
+    // Stream specific options
+    caseInsensitiveParams.get(ENDING_OFFSETS_OPTION_KEY).map(_ =>
+      throw new IllegalArgumentException("ending offset not valid in streaming queries"))
+    validateGeneralOptions(caseInsensitiveParams)
+  }
+
+  private def validateBatchOptions(caseInsensitiveParams: Map[String, String]) = {
+    // Batch specific options
+    caseInsensitiveParams.get(STARTING_OFFSETS_OPTION_KEY).map(_.trim.toLowerCase) match {
+      case Some("earliest") => // good to go
+      case Some("latest") =>
+        throw new IllegalArgumentException("starting offset can't be latest " +
+          "for batch queries on Kafka")
+      case Some(json) => (SpecificOffsetRangeLimit(JsonUtils.partitionOffsets(json)))
+        .partitionOffsets.foreach {
+          case (tp, off) if off == KafkaOffsetRangeLimit.LATEST =>
+            throw new IllegalArgumentException(s"startingOffsets for $tp can't " +
+              "be latest for batch queries on Kafka")
+          case _ => // ignore
+        }
+      case _ => // default to earliest
+    }
+
+    caseInsensitiveParams.get(ENDING_OFFSETS_OPTION_KEY).map(_.trim.toLowerCase) match {
+      case Some("earliest") =>
+        throw new IllegalArgumentException("ending offset can't be earliest " +
+          "for batch queries on Kafka")
+      case Some("latest") => // good to go
+      case Some(json) => (SpecificOffsetRangeLimit(JsonUtils.partitionOffsets(json)))
+        .partitionOffsets.foreach {
+          case (tp, off) if off == KafkaOffsetRangeLimit.EARLIEST =>
+            throw new IllegalArgumentException(s"ending offset for $tp can't be " +
+              "earliest for batch queries on Kafka")
+          case _ => // ignore
+        }
+      case _ => // default to latest
+    }
+
+    validateGeneralOptions(caseInsensitiveParams)
+
+    // Don't want to throw an error, but at least log a warning.
+    if (caseInsensitiveParams.get("maxoffsetspertrigger").isDefined) {
+      logWarning("maxOffsetsPerTrigger option ignored in batch queries")
+    }
+  }
 
   /** Class to conveniently update Kafka config params, while logging the changes */
   private case class ConfigUpdater(module: String, kafkaParams: Map[String, String]) {
@@ -278,5 +379,8 @@ private[kafka010] class KafkaSourceProvider extends StreamSourceProvider
 private[kafka010] object KafkaSourceProvider {
   private val STRATEGY_OPTION_KEYS = Set("subscribe", "subscribepattern", "assign")
   private val STARTING_OFFSETS_OPTION_KEY = "startingoffsets"
+  private val ENDING_OFFSETS_OPTION_KEY = "endingoffsets"
   private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss"
+
+  private val deserClassName = classOf[ByteArrayDeserializer].getName
 }
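
Note: with RelationProvider wired in above, the same "kafka" short name now serves batch reads via spark.read as well as streaming reads. Below is a minimal sketch of the batch path, assuming an existing SparkSession named `spark`; the topic name and broker address are placeholders, and the option keys are the ones validated by validateBatchOptions.

// Minimal batch-read sketch against the new RelationProvider path.
// "topic-a" and the broker address are placeholders, not values from this patch.
val batchDF = spark.read
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:9092")
  .option("subscribe", "topic-a")
  .option("startingOffsets", "earliest") // "latest" is rejected by validateBatchOptions
  .option("endingOffsets", "latest")     // "earliest" is rejected by validateBatchOptions
  .load()
  .selectExpr("CAST(value AS STRING)")
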
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala
index 244cd2c225bdd7213bca27585e66ac6002cef1f4..6fb3473eb75f5092f081f645e8cb2be3f7333b9f 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala
@@ -21,7 +21,7 @@ import java.{util => ju}
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.kafka.clients.consumer.ConsumerRecord
+import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord}
 import org.apache.kafka.common.TopicPartition
 
 import org.apache.spark.{Partition, SparkContext, TaskContext}
@@ -63,7 +63,8 @@ private[kafka010] class KafkaSourceRDD(
     executorKafkaParams: ju.Map[String, Object],
     offsetRanges: Seq[KafkaSourceRDDOffsetRange],
     pollTimeoutMs: Long,
-    failOnDataLoss: Boolean)
+    failOnDataLoss: Boolean,
+    reuseKafkaConsumer: Boolean)
   extends RDD[ConsumerRecord[Array[Byte], Array[Byte]]](sc, Nil) {
 
   override def persist(newLevel: StorageLevel): this.type = {
@@ -122,7 +123,19 @@ private[kafka010] class KafkaSourceRDD(
   override def compute(
       thePart: Partition,
       context: TaskContext): Iterator[ConsumerRecord[Array[Byte], Array[Byte]]] = {
-    val range = thePart.asInstanceOf[KafkaSourceRDDPartition].offsetRange
+    val sourcePartition = thePart.asInstanceOf[KafkaSourceRDDPartition]
+    val topic = sourcePartition.offsetRange.topic
+    if (!reuseKafkaConsumer) {
+      // if we can't reuse CachedKafkaConsumers, let's reset the groupId to something unique
+      // to each task (i.e., append the task's unique partition id), because we will have
+      // multiple tasks (e.g., in the case of union) reading from the same topic partitions
+      val old = executorKafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
+      val id = TaskContext.getPartitionId()
+      executorKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, old + "-" + id)
+    }
+    val kafkaPartition = sourcePartition.offsetRange.partition
+    val consumer = CachedKafkaConsumer.getOrCreate(topic, kafkaPartition, executorKafkaParams)
+    val range = resolveRange(consumer, sourcePartition.offsetRange)
     assert(
       range.fromOffset <= range.untilOffset,
       s"Beginning offset ${range.fromOffset} is after the ending offset ${range.untilOffset} " +
@@ -133,9 +146,7 @@ private[kafka010] class KafkaSourceRDD(
         s"skipping ${range.topic} ${range.partition}")
       Iterator.empty
     } else {
-      new NextIterator[ConsumerRecord[Array[Byte], Array[Byte]]]() {
-        val consumer = CachedKafkaConsumer.getOrCreate(
-          range.topic, range.partition, executorKafkaParams)
+      val underlying = new NextIterator[ConsumerRecord[Array[Byte], Array[Byte]]]() {
         var requestOffset = range.fromOffset
 
         override def getNext(): ConsumerRecord[Array[Byte], Array[Byte]] = {
@@ -156,8 +167,46 @@ private[kafka010] class KafkaSourceRDD(
           }
         }
 
-        override protected def close(): Unit = {}
+        override protected def close(): Unit = {
+          if (!reuseKafkaConsumer) {
+            // Don't forget to close non-reused KafkaConsumers; they may take down your cluster!
+            CachedKafkaConsumer.removeKafkaConsumer(topic, kafkaPartition, executorKafkaParams)
+          } else {
+            // Indicate that we're no longer using this consumer
+            CachedKafkaConsumer.releaseKafkaConsumer(topic, kafkaPartition, executorKafkaParams)
+          }
+        }
       }
+      // Release consumer, either by removing it or indicating we're no longer using it
+      context.addTaskCompletionListener { _ =>
+        underlying.closeIfNeeded()
+      }
+      underlying
+    }
+  }
+
+  private def resolveRange(consumer: CachedKafkaConsumer, range: KafkaSourceRDDOffsetRange) = {
+    if (range.fromOffset < 0 || range.untilOffset < 0) {
+      // Late bind the offset range
+      val availableOffsetRange = consumer.getAvailableOffsetRange()
+      val fromOffset = if (range.fromOffset < 0) {
+        assert(range.fromOffset == KafkaOffsetRangeLimit.EARLIEST,
+          s"earliest offset ${range.fromOffset} does not equal ${KafkaOffsetRangeLimit.EARLIEST}")
+        availableOffsetRange.earliest
+      } else {
+        range.fromOffset
+      }
+      val untilOffset = if (range.untilOffset < 0) {
+        assert(range.untilOffset == KafkaOffsetRangeLimit.LATEST,
+          s"latest offset ${range.untilOffset} does not equal ${KafkaOffsetRangeLimit.LATEST}")
+        availableOffsetRange.latest
+      } else {
+        range.untilOffset
+      }
+      KafkaSourceRDDOffsetRange(range.topicPartition,
+        fromOffset, untilOffset, range.preferredLoc)
+    } else {
+      range
     }
   }
 }
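
Note: resolveRange above late-binds ranges that still carry sentinel offsets. The sketch below is illustrative only (not part of the patch) and assumes the sentinel values -2 and -1 that the KafkaRelationSuite comments associate with KafkaOffsetRangeLimit.EARLIEST and KafkaOffsetRangeLimit.LATEST.

// Illustrative only: mirrors how resolveRange late-binds the sentinel offsets
// used in this patch onto the range a consumer reports at task runtime.
object LateBindingSketch {
  val EARLIEST = -2L // assumed to match KafkaOffsetRangeLimit.EARLIEST
  val LATEST = -1L   // assumed to match KafkaOffsetRangeLimit.LATEST

  def bind(requested: Long, available: (Long, Long)): Long = requested match {
    case EARLIEST => available._1 // earliest available offset
    case LATEST   => available._2 // latest available offset
    case concrete => concrete     // explicit offsets pass through unchanged
  }
}

// e.g. LateBindingSketch.bind(-2L, (5L, 42L)) == 5L
//      LateBindingSketch.bind(-1L, (5L, 42L)) == 42L
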
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/StartingOffsets.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/StartingOffsets.scala
deleted file mode 100644
index 83959e597171a6cd764fd7523141e74363d027a9..0000000000000000000000000000000000000000
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/StartingOffsets.scala
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.kafka010
-
-import org.apache.kafka.common.TopicPartition
-
-/*
- * Values that can be specified for config startingOffsets
- */
-private[kafka010] sealed trait StartingOffsets
-
-private[kafka010] case object EarliestOffsets extends StartingOffsets
-
-private[kafka010] case object LatestOffsets extends StartingOffsets
-
-private[kafka010] case class SpecificOffsets(
-  partitionOffsets: Map[TopicPartition, Long]) extends StartingOffsets
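
Note: the deleted StartingOffsets hierarchy is superseded by the KafkaOffsetRangeLimit types referenced throughout this patch. The replacement file is not part of this excerpt; the following is only a presumed sketch of its shape, inferred from the identifiers used above.

// Presumed shape of the replacement hierarchy, inferred purely from usages in
// this patch; the real source file is not shown here.
package org.apache.spark.sql.kafka010

import org.apache.kafka.common.TopicPartition

private[kafka010] sealed trait KafkaOffsetRangeLimit

private[kafka010] case object EarliestOffsetRangeLimit extends KafkaOffsetRangeLimit

private[kafka010] case object LatestOffsetRangeLimit extends KafkaOffsetRangeLimit

private[kafka010] case class SpecificOffsetRangeLimit(
    partitionOffsets: Map[TopicPartition, Long]) extends KafkaOffsetRangeLimit

private[kafka010] object KafkaOffsetRangeLimit {
  // Sentinel values used for late binding (see resolveRange above and the
  // "-2 => earliest" / "-1 => latest" comments in KafkaRelationSuite).
  val LATEST = -1L
  val EARLIEST = -2L
}
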
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..673d60ff6f87a2838242440ac823a99df5aca997
--- /dev/null
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.apache.kafka.common.TopicPartition
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.test.SharedSQLContext
+
+class KafkaRelationSuite extends QueryTest with BeforeAndAfter with SharedSQLContext {
+
+  import testImplicits._
+
+  private val topicId = new AtomicInteger(0)
+
+  private var testUtils: KafkaTestUtils = _
+
+  private def newTopic(): String = s"topic-${topicId.getAndIncrement()}"
+
+  private def assignString(topic: String, partitions: Iterable[Int]): String = {
+    JsonUtils.partitions(partitions.map(p => new TopicPartition(topic, p)))
+  }
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    testUtils = new KafkaTestUtils
+    testUtils.setup()
+  }
+
+  override def afterAll(): Unit = {
+    if (testUtils != null) {
+      testUtils.teardown()
+      testUtils = null
+    }
+    super.afterAll()
+  }
+
+  private def createDF(
+      topic: String,
+      withOptions: Map[String, String] = Map.empty[String, String],
+      brokerAddress: Option[String] = None) = {
+    val df = spark
+      .read
+      .format("kafka")
+      .option("kafka.bootstrap.servers",
+        brokerAddress.getOrElse(testUtils.brokerAddress))
+      .option("subscribe", topic)
+    withOptions.foreach {
+      case (key, value) => df.option(key, value)
+    }
+    df.load().selectExpr("CAST(value AS STRING)")
+  }
+
+
+  test("explicit earliest to latest offsets") {
+    val topic = newTopic()
+    testUtils.createTopic(topic, partitions = 3)
+    testUtils.sendMessages(topic, (0 to 9).map(_.toString).toArray, Some(0))
+    testUtils.sendMessages(topic, (10 to 19).map(_.toString).toArray, Some(1))
+    testUtils.sendMessages(topic, Array("20"), Some(2))
+
+    // Specify explicit earliest and latest offset values
+    val df = createDF(topic,
+      withOptions = Map("startingOffsets" -> "earliest", "endingOffsets" -> "latest"))
+    checkAnswer(df, (0 to 20).map(_.toString).toDF)
+
+    // "latest" should late bind to the current (latest) offset in the df
+    testUtils.sendMessages(topic, (21 to 29).map(_.toString).toArray, Some(2))
+    checkAnswer(df, (0 to 29).map(_.toString).toDF)
+  }
+
+  test("default starting and ending offsets") {
+    val topic = newTopic()
+    testUtils.createTopic(topic, partitions = 3)
+    testUtils.sendMessages(topic, (0 to 9).map(_.toString).toArray, Some(0))
+    testUtils.sendMessages(topic, (10 to 19).map(_.toString).toArray, Some(1))
+    testUtils.sendMessages(topic, Array("20"), Some(2))
+
+    // Implicit offset values, should default to earliest and latest
+    val df = createDF(topic)
+    // Test that we default to "earliest" and "latest"
+    checkAnswer(df, (0 to 20).map(_.toString).toDF)
+  }
+
+  test("explicit offsets") {
+    val topic = newTopic()
+    testUtils.createTopic(topic, partitions = 3)
+    testUtils.sendMessages(topic, (0 to 9).map(_.toString).toArray, Some(0))
+    testUtils.sendMessages(topic, (10 to 19).map(_.toString).toArray, Some(1))
+    testUtils.sendMessages(topic, Array("20"), Some(2))
+
+    // Test explicitly specified offsets
+    val startPartitionOffsets = Map(
+      new TopicPartition(topic, 0) -> -2L, // -2 => earliest
+      new TopicPartition(topic, 1) -> -2L,
+      new TopicPartition(topic, 2) -> 0L   // explicit earliest
+    )
+    val startingOffsets = JsonUtils.partitionOffsets(startPartitionOffsets)
+
+    val endPartitionOffsets = Map(
+      new TopicPartition(topic, 0) -> -1L, // -1 => latest
+      new TopicPartition(topic, 1) -> -1L,
+      new TopicPartition(topic, 2) -> 1L   // explicit offset that happens to equal the latest
+    )
+    val endingOffsets = JsonUtils.partitionOffsets(endPartitionOffsets)
+    val df = createDF(topic,
+        withOptions = Map("startingOffsets" -> startingOffsets, "endingOffsets" -> endingOffsets))
+    checkAnswer(df, (0 to 20).map(_.toString).toDF)
+
+    // static offset partition 2, nothing should change
+    testUtils.sendMessages(topic, (31 to 39).map(_.toString).toArray, Some(2))
+    checkAnswer(df, (0 to 20).map(_.toString).toDF)
+
+    // latest offset partition 1, should change
+    testUtils.sendMessages(topic, (21 to 30).map(_.toString).toArray, Some(1))
+    checkAnswer(df, (0 to 30).map(_.toString).toDF)
+  }
+
+  test("reuse same dataframe in query") {
+    // This test ensures that we do not cache the Kafka Consumer in KafkaRelation
+    val topic = newTopic()
+    testUtils.createTopic(topic, partitions = 1)
+    testUtils.sendMessages(topic, (0 to 10).map(_.toString).toArray, Some(0))
+
+    // Specify explicit earliest and latest offset values
+    val df = createDF(topic,
+      withOptions = Map("startingOffsets" -> "earliest", "endingOffsets" -> "latest"))
+    checkAnswer(df.union(df), ((0 to 10) ++ (0 to 10)).map(_.toString).toDF)
+  }
+
+  test("test late binding start offsets") {
+    var kafkaUtils: KafkaTestUtils = null
+    try {
+      /**
+       * The following settings will ensure that all log entries
+       * are removed following a call to cleanupLogs
+       */
+      val brokerProps = Map[String, Object](
+        "log.retention.bytes" -> 1.asInstanceOf[AnyRef], // retain nothing
+        "log.retention.ms" -> 1.asInstanceOf[AnyRef]     // no wait time
+      )
+      kafkaUtils = new KafkaTestUtils(withBrokerProps = brokerProps)
+      kafkaUtils.setup()
+
+      val topic = newTopic()
+      kafkaUtils.createTopic(topic, partitions = 1)
+      kafkaUtils.sendMessages(topic, (0 to 9).map(_.toString).toArray, Some(0))
+      // Specify explicit earliest and latest offset values
+      val df = createDF(topic,
+        withOptions = Map("startingOffsets" -> "earliest", "endingOffsets" -> "latest"),
+        Some(kafkaUtils.brokerAddress))
+      checkAnswer(df, (0 to 9).map(_.toString).toDF)
+      // Blow away current set of messages.
+      kafkaUtils.cleanupLogs()
+      // Add some more data, but do not call cleanup
+      kafkaUtils.sendMessages(topic, (10 to 19).map(_.toString).toArray, Some(0))
+      // Ensure that we late bind to the new starting position
+      checkAnswer(df, (10 to 19).map(_.toString).toDF)
+    } finally {
+      if (kafkaUtils != null) {
+        kafkaUtils.teardown()
+      }
+    }
+  }
+
+  test("bad batch query options") {
+    def testBadOptions(options: (String, String)*)(expectedMsgs: String*): Unit = {
+      val ex = intercept[IllegalArgumentException] {
+        val reader = spark
+          .read
+          .format("kafka")
+        options.foreach { case (k, v) => reader.option(k, v) }
+        reader.load()
+      }
+      expectedMsgs.foreach { m =>
+        assert(ex.getMessage.toLowerCase.contains(m.toLowerCase))
+      }
+    }
+
+    // Specifying "latest" (only valid as an ending offset) as the starting point
+    testBadOptions("startingOffsets" -> "latest")("starting offset can't be latest " +
+      "for batch queries on Kafka")
+
+    // Now do it with an explicit JSON start offset indicating latest
+    val startPartitionOffsets = Map(new TopicPartition("t", 0) -> -1L)
+    val startingOffsets = JsonUtils.partitionOffsets(startPartitionOffsets)
+    testBadOptions("subscribe" -> "t", "startingOffsets" -> startingOffsets)(
+      "startingOffsets for t-0 can't be latest for batch queries on Kafka")
+
+
+    // Make sure we catch ending offsets that indicate earliest
+    testBadOptions("endingOffsets" -> "earliest")("ending offset can't be earliest " +
+      "for batch queries on Kafka")
+
+    // Make sure we catch JSON ending offsets that indicate earliest
+    val endPartitionOffsets = Map(new TopicPartition("t", 0) -> -2L)
+    val endingOffsets = JsonUtils.partitionOffsets(endPartitionOffsets)
+    testBadOptions("subscribe" -> "t", "endingOffsets" -> endingOffsets)(
+      "ending offset for t-0 can't be earliest for batch queries on Kafka")
+
+    // No strategy specified
+    testBadOptions()("options must be specified", "subscribe", "subscribePattern")
+
+    // Multiple strategies specified
+    testBadOptions("subscribe" -> "t", "subscribePattern" -> "t.*")(
+      "only one", "options can be specified")
+
+    testBadOptions("subscribe" -> "t", "assign" -> """{"a":[0]}""")(
+      "only one", "options can be specified")
+
+    testBadOptions("assign" -> "")("no topicpartitions to assign")
+    testBadOptions("subscribe" -> "")("no topics to subscribe")
+    testBadOptions("subscribePattern" -> "")("pattern to subscribe is empty")
+  }
+}
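
Note: the suite builds its offset JSON through JsonUtils.partitionOffsets. For readers wiring this up by hand, the strings presumably look like the following; the nested topic/partition layout is an assumption, while the -2/-1 sentinel semantics come from the test comments above.

// Hand-written equivalent of the offset JSON the tests build through
// JsonUtils.partitionOffsets. "topic-0" is a placeholder topic name.
val startingOffsetsJson = """{"topic-0":{"0":-2,"1":-2,"2":0}}""" // -2 => earliest
val endingOffsetsJson = """{"topic-0":{"0":-1,"1":-1,"2":1}}"""   // -1 => latest
// These strings are what a user would pass directly as the startingOffsets and
// endingOffsets options instead of going through JsonUtils.
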
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
index 544fbc5ec36a26339ab4a9dea432dc0e20df04fa..211c8a5e73e4560b63fd5aa6bd44f0f4e0861a00 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
@@ -384,6 +384,9 @@ class KafkaSourceSuite extends KafkaSourceTest {
       }
     }
 
+    // Specifying an ending offset
+    testBadOptions("endingOffsets" -> "latest")("Ending offset not valid in streaming queries")
+
     // No strategy specified
     testBadOptions()("options must be specified", "subscribe", "subscribePattern")
 
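
Note: this is the streaming counterpart of the batch validation, exercising the endingOffsets check added to validateStreamOptions. A sketch of what the user-facing failure looks like, assuming a ScalaTest context and a SparkSession named `spark`; the broker address is a placeholder.

// Illustrative only: endingOffsets is a batch-only option, so the streaming
// path rejects it when the source schema is resolved.
import org.scalatest.Assertions._

val ex = intercept[IllegalArgumentException] {
  spark.readStream
    .format("kafka")
    .option("kafka.bootstrap.servers", "host1:9092") // placeholder
    .option("subscribe", "t")
    .option("endingOffsets", "latest")
    .load()
}
assert(ex.getMessage.toLowerCase.contains("ending offset not valid in streaming queries"))
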
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala
index fd1689acf6727de8bb5a5e730e6e593accf60a6e..c2cbd86260bc5c61bd236235f8e0ad42bc95f265 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala
@@ -50,7 +50,7 @@ import org.apache.spark.SparkConf
  *
  * The reason to put Kafka test utility class in src is to test Python related Kafka APIs.
  */
-class KafkaTestUtils extends Logging {
+class KafkaTestUtils(withBrokerProps: Map[String, Object] = Map.empty) extends Logging {
 
   // Zookeeper related configurations
   private val zkHost = "localhost"
@@ -238,6 +238,24 @@ class KafkaTestUtils extends Logging {
     offsets
   }
 
+  def cleanupLogs(): Unit = {
+    server.logManager.cleanupLogs()
+  }
+
+  def getEarliestOffsets(topics: Set[String]): Map[TopicPartition, Long] = {
+    val kc = new KafkaConsumer[String, String](consumerConfiguration)
+    logInfo("Created consumer to get earliest offsets")
+    kc.subscribe(topics.asJavaCollection)
+    kc.poll(0)
+    val partitions = kc.assignment()
+    kc.pause(partitions)
+    kc.seekToBeginning(partitions)
+    val offsets = partitions.asScala.map(p => p -> kc.position(p)).toMap
+    kc.close()
+    logInfo("Closed consumer to get earliest offsets")
+    offsets
+  }
+
   def getLatestOffsets(topics: Set[String]): Map[TopicPartition, Long] = {
     val kc = new KafkaConsumer[String, String](consumerConfiguration)
     logInfo("Created consumer to get latest offsets")
@@ -263,6 +281,7 @@ class KafkaTestUtils extends Logging {
     props.put("log.flush.interval.messages", "1")
     props.put("replica.socket.timeout.ms", "1500")
     props.put("delete.topic.enable", "true")
+    props.putAll(withBrokerProps.asJava)
     props
   }
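
Note: the new constructor parameter and helpers are meant to be combined as in the late-binding test above. A condensed usage sketch follows, with a placeholder topic name and the same aggressive retention settings.

// Condensed usage of the new KafkaTestUtils hooks; mirrors the late-binding test.
val utils = new KafkaTestUtils(withBrokerProps = Map[String, Object](
  "log.retention.bytes" -> 1.asInstanceOf[AnyRef], // retain nothing
  "log.retention.ms" -> 1.asInstanceOf[AnyRef]))   // no wait time
utils.setup()
try {
  utils.createTopic("t", partitions = 1)
  utils.sendMessages("t", (0 to 9).map(_.toString).toArray, Some(0))
  utils.cleanupLogs() // with the settings above, existing entries become eligible for removal
  // After cleanup the earliest offset advances past the deleted records, which is
  // exactly what the late-binding test relies on.
  val earliest = utils.getEarliestOffsets(Set("t"))
  val latest = utils.getLatestOffsets(Set("t"))
} finally {
  utils.teardown()
}
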