From 1a3f5f8c55d873aaf8145a8bc4867fc9902cf93d Mon Sep 17 00:00:00 2001
From: Roberto Agostino Vitillo <>
Date: Fri, 17 Feb 2017 11:43:57 -0800
Subject: [PATCH] [SPARK-19517][SS] KafkaSource fails to initialize partition offsets

## What changes were proposed in this pull request?

This patch fixes a bug in `KafkaSource` with the (de)serialization of the length of the JSON string that contains the initial partition offsets.

## How was this patch tested?

I ran the test suite for spark-sql-kafka-0-10.

Author: Roberto Agostino Vitillo <>

Closes #16857 from vitillo/kafka_source_fix.
 dev/.rat-excludes                             |   1 +
 .../spark/sql/kafka010/KafkaSource.scala      |  32 ++++--
 ...ka-source-initial-offset-version-2.1.0.bin |   1 +
 .../spark/sql/kafka010/KafkaSourceSuite.scala | 104 ++++++++++++++++++
 4 files changed, 131 insertions(+), 7 deletions(-)
 create mode 100644 external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-version-2.1.0.bin

diff --git a/dev/.rat-excludes b/dev/.rat-excludes
index 6d24434ccc..2355d40d1e 100644
--- a/dev/.rat-excludes
+++ b/dev/.rat-excludes
@@ -105,3 +105,4 @@ org.apache.spark.scheduler.ExternalClusterManager
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
index 9c5dceca2d..92b5d91ba4 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
@@ -21,6 +21,7 @@ import java.{util => ju}
 import java.nio.charset.StandardCharsets
 import org.apache.kafka.common.TopicPartition
 import org.apache.spark.SparkContext
@@ -97,16 +98,31 @@ private[kafka010] class KafkaSource(
     val metadataLog =
       new HDFSMetadataLog[KafkaSourceOffset](sqlContext.sparkSession, metadataPath) {
         override def serialize(metadata: KafkaSourceOffset, out: OutputStream): Unit = {
-          val bytes = metadata.json.getBytes(StandardCharsets.UTF_8)
-          out.write(bytes.length)
-          out.write(bytes)
+          out.write(0) // A zero byte is written to support Spark 2.1.0 (SPARK-19517)
+          val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8))
+          writer.write(VERSION)
+          writer.write(metadata.json)
+          writer.flush
         override def deserialize(in: InputStream): KafkaSourceOffset = {
-          val length = in.read()
-          val bytes = new Array[Byte](length)
-          KafkaSourceOffset(SerializedOffset(new String(bytes, StandardCharsets.UTF_8)))
+          in.read() // A zero byte is read to support Spark 2.1.0 (SPARK-19517)
+          val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8))
+          // HDFSMetadataLog guarantees that it never creates a partial file.
+          assert(content.length != 0)
+          if (content(0) == 'v') {
+            if (content.startsWith(VERSION)) {
+              KafkaSourceOffset(SerializedOffset(content.substring(VERSION.length)))
+            } else {
+              val versionInFile = content.substring(0, content.indexOf("\n"))
+              throw new IllegalStateException(
+                s"Unsupported format. Expected version is ${VERSION.stripLineEnd} " +
+                  s"but was $versionInFile. Please upgrade your Spark.")
+            }
+          } else {
+            // The log was generated by Spark 2.1.0
+            KafkaSourceOffset(SerializedOffset(content))
+          }
@@ -335,6 +351,8 @@ private[kafka010] object KafkaSource {
       | source option "failOnDataLoss" to "false".
+  private val VERSION = "v1\n"
   def getSortedExecutorList(sc: SparkContext): Array[String] = {
     val bm = sc.env.blockManager
diff --git a/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-version-2.1.0.bin b/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-version-2.1.0.bin
new file mode 100644
index 0000000000..ae928e7249
--- /dev/null
+++ b/external/kafka-0-10-sql/src/test/resources/kafka-source-initial-offset-version-2.1.0.bin
@@ -0,0 +1 @@
\ No newline at end of file
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
index 211c8a5e73..4f82b133cb 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
@@ -17,7 +17,9 @@
 package org.apache.spark.sql.kafka010
 import java.nio.charset.StandardCharsets.UTF_8
+import java.nio.file.{Files, Paths}
 import java.util.Properties
 import java.util.concurrent.ConcurrentLinkedQueue
 import java.util.concurrent.atomic.AtomicInteger
@@ -141,6 +143,108 @@ class KafkaSourceSuite extends KafkaSourceTest {
   private val topicId = new AtomicInteger(0)
+  testWithUninterruptibleThread(
+    "deserialization of initial offset with Spark 2.1.0") {
+    withTempDir { metadataPath =>
+      val topic = newTopic
+      testUtils.createTopic(topic, partitions = 3)
+      val provider = new KafkaSourceProvider
+      val parameters = Map(
+        "kafka.bootstrap.servers" -> testUtils.brokerAddress,
+        "subscribe" -> topic
+      )
+      val source = provider.createSource(spark.sqlContext, metadataPath.getAbsolutePath, None,
+        "", parameters)
+      source.getOffset.get // Write initial offset
+      // Make sure Spark 2.1.0 will throw an exception when reading the new log
+      intercept[java.lang.IllegalArgumentException] {
+        // Simulate how Spark 2.1.0 reads the log
+        val in = new FileInputStream(metadataPath.getAbsolutePath + "/0")
+        val length = in.read()
+        val bytes = new Array[Byte](length)
+        KafkaSourceOffset(SerializedOffset(new String(bytes, UTF_8)))
+      }
+    }
+  }
+  testWithUninterruptibleThread("deserialization of initial offset written by Spark 2.1.0") {
+    withTempDir { metadataPath =>
+      val topic = "kafka-initial-offset-2-1-0"
+      testUtils.createTopic(topic, partitions = 3)
+      val provider = new KafkaSourceProvider
+      val parameters = Map(
+        "kafka.bootstrap.servers" -> testUtils.brokerAddress,
+        "subscribe" -> topic
+      )
+      val from = Paths.get(
+        getClass.getResource("/kafka-source-initial-offset-version-2.1.0.bin").getPath)
+      val to = Paths.get(s"${metadataPath.getAbsolutePath}/0")
+      Files.copy(from, to)
+      val source = provider.createSource(spark.sqlContext, metadataPath.getAbsolutePath, None,
+        "", parameters)
+      val deserializedOffset = source.getOffset.get
+      val referenceOffset = KafkaSourceOffset((topic, 0, 0L), (topic, 1, 0L), (topic, 2, 0L))
+      assert(referenceOffset == deserializedOffset)
+    }
+  }
+  testWithUninterruptibleThread("deserialization of initial offset written by future version") {
+    withTempDir { metadataPath =>
+      val futureMetadataLog =
+        new HDFSMetadataLog[KafkaSourceOffset](sqlContext.sparkSession,
+          metadataPath.getAbsolutePath) {
+          override def serialize(metadata: KafkaSourceOffset, out: OutputStream): Unit = {
+            out.write(0)
+            val writer = new BufferedWriter(new OutputStreamWriter(out, UTF_8))
+            writer.write(s"v0\n${metadata.json}")
+            writer.flush
+          }
+        }
+      val topic = newTopic
+      testUtils.createTopic(topic, partitions = 3)
+      val offset = KafkaSourceOffset((topic, 0, 0L), (topic, 1, 0L), (topic, 2, 0L))
+      futureMetadataLog.add(0, offset)
+      val provider = new KafkaSourceProvider
+      val parameters = Map(
+        "kafka.bootstrap.servers" -> testUtils.brokerAddress,
+        "subscribe" -> topic
+      )
+      val source = provider.createSource(spark.sqlContext, metadataPath.getAbsolutePath, None,
+        "", parameters)
+      val e = intercept[java.lang.IllegalStateException] {
+        source.getOffset.get // Read initial offset
+      }
+      assert(e.getMessage.contains("Please upgrade your Spark"))
+    }
+  }
+  test("(de)serialization of initial offsets") {
+    val topic = newTopic()
+    testUtils.createTopic(topic, partitions = 64)
+    val reader = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("subscribe", topic)
+    testStream(reader.load)(
+      makeSureGetOffsetCalled,
+      StopStream,
+      StartStream(),
+      StopStream)
+  }
   test("maxOffsetsPerTrigger") {
     val topic = newTopic()
     testUtils.createTopic(topic, partitions = 3)