From d249636e59fabd8ca57a47dc2cbad9c4a4e7a750 Mon Sep 17 00:00:00 2001
From: Tathagata Das <tathagata.das1565@gmail.com>
Date: Thu, 23 Jul 2015 20:06:54 -0700
Subject: [PATCH] [SPARK-9216] [STREAMING] Define KinesisBackedBlockRDDs

For more information see master JIRA: https://issues.apache.org/jira/browse/SPARK-9215
Design Doc: https://docs.google.com/document/d/1k0dl270EnK7uExrsCE7jYw7PYx0YC935uBcxn3p0f58/edit

Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #7578 from tdas/kinesis-rdd and squashes the following commits:

543d208 [Tathagata Das] Fixed scala style
5082a30 [Tathagata Das] Fixed scala style
3f40c2d [Tathagata Das] Addressed comments
c4f25d2 [Tathagata Das] Addressed comment
d3d64d1 [Tathagata Das] Minor update
f6e35c8 [Tathagata Das] Added retry logic to make it more robust
8874b70 [Tathagata Das] Updated Kinesis RDD
575bdbc [Tathagata Das] Fix scala style issues
4a36096 [Tathagata Das] Add license
5da3995 [Tathagata Das] Changed KinesisSuiteHelper to KinesisFunSuite
528e206 [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into kinesis-rdd
3ae0814 [Tathagata Das] Added KinesisBackedBlockRDD
---
 .../kinesis/KinesisBackedBlockRDD.scala       | 285 ++++++++++++++++++
 .../streaming/kinesis/KinesisTestUtils.scala  |   2 +-
 .../kinesis/KinesisBackedBlockRDDSuite.scala  | 246 +++++++++++++++
 .../streaming/kinesis/KinesisFunSuite.scala   |  13 +-
 .../kinesis/KinesisStreamSuite.scala          |   4 +-
 5 files changed, 545 insertions(+), 5 deletions(-)
 create mode 100644 extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala
 create mode 100644 extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala

diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala
new file mode 100644
index 0000000000..8f144a4d97
--- /dev/null
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala
@@ -0,0 +1,285 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kinesis
+
+import scala.collection.JavaConversions._
+import scala.util.control.NonFatal
+
+import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain}
+import com.amazonaws.services.kinesis.AmazonKinesisClient
+import com.amazonaws.services.kinesis.model._
+
+import org.apache.spark._
+import org.apache.spark.rdd.{BlockRDD, BlockRDDPartition}
+import org.apache.spark.storage.BlockId
+import org.apache.spark.util.NextIterator
+
+
+/** Class representing a range of Kinesis sequence numbers. Both sequence numbers are inclusive. */
+private[kinesis]
+case class SequenceNumberRange(
+    streamName: String, shardId: String, fromSeqNumber: String, toSeqNumber: String)
+
+/** Class representing an array of Kinesis sequence number ranges */
+private[kinesis]
+case class SequenceNumberRanges(ranges: Array[SequenceNumberRange]) {
+  def isEmpty(): Boolean = ranges.isEmpty
+  def nonEmpty(): Boolean = ranges.nonEmpty
+  override def toString(): String = ranges.mkString("SequenceNumberRanges(", ", ", ")")
+}
+
+private[kinesis]
+object SequenceNumberRanges {
+  def apply(range: SequenceNumberRange): SequenceNumberRanges = {
+    new SequenceNumberRanges(Array(range))
+  }
+}
+
+
+/** Partition storing the information of the ranges of Kinesis sequence numbers to read */
+private[kinesis]
+class KinesisBackedBlockRDDPartition(
+    idx: Int,
+    blockId: BlockId,
+    val isBlockIdValid: Boolean,
+    val seqNumberRanges: SequenceNumberRanges
+  ) extends BlockRDDPartition(blockId, idx)
+
+/**
+ * A BlockRDD where the block data is backed by Kinesis, which can accessed using the
+ * sequence numbers of the corresponding blocks.
+ */
+private[kinesis]
+class KinesisBackedBlockRDD(
+    sc: SparkContext,
+    regionId: String,
+    endpointUrl: String,
+    @transient blockIds: Array[BlockId],
+    @transient arrayOfseqNumberRanges: Array[SequenceNumberRanges],
+    @transient isBlockIdValid: Array[Boolean] = Array.empty,
+    retryTimeoutMs: Int = 10000,
+    awsCredentialsOption: Option[SerializableAWSCredentials] = None
+  ) extends BlockRDD[Array[Byte]](sc, blockIds) {
+
+  require(blockIds.length == arrayOfseqNumberRanges.length,
+    "Number of blockIds is not equal to the number of sequence number ranges")
+
+  override def isValid(): Boolean = true
+
+  override def getPartitions: Array[Partition] = {
+    Array.tabulate(blockIds.length) { i =>
+      val isValid = if (isBlockIdValid.length == 0) true else isBlockIdValid(i)
+      new KinesisBackedBlockRDDPartition(i, blockIds(i), isValid, arrayOfseqNumberRanges(i))
+    }
+  }
+
+  override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
+    val blockManager = SparkEnv.get.blockManager
+    val partition = split.asInstanceOf[KinesisBackedBlockRDDPartition]
+    val blockId = partition.blockId
+
+    def getBlockFromBlockManager(): Option[Iterator[Array[Byte]]] = {
+      logDebug(s"Read partition data of $this from block manager, block $blockId")
+      blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[Array[Byte]]])
+    }
+
+    def getBlockFromKinesis(): Iterator[Array[Byte]] = {
+      val credenentials = awsCredentialsOption.getOrElse {
+        new DefaultAWSCredentialsProviderChain().getCredentials()
+      }
+      partition.seqNumberRanges.ranges.iterator.flatMap { range =>
+        new KinesisSequenceRangeIterator(
+          credenentials, endpointUrl, regionId, range, retryTimeoutMs)
+      }
+    }
+    if (partition.isBlockIdValid) {
+      getBlockFromBlockManager().getOrElse { getBlockFromKinesis() }
+    } else {
+      getBlockFromKinesis()
+    }
+  }
+}
+
+
+/**
+ * An iterator that return the Kinesis data based on the given range of sequence numbers.
+ * Internally, it repeatedly fetches sets of records starting from the fromSequenceNumber,
+ * until the endSequenceNumber is reached.
+ */
+private[kinesis]
+class KinesisSequenceRangeIterator(
+    credentials: AWSCredentials,
+    endpointUrl: String,
+    regionId: String,
+    range: SequenceNumberRange,
+    retryTimeoutMs: Int
+  ) extends NextIterator[Array[Byte]] with Logging {
+
+  private val client = new AmazonKinesisClient(credentials)
+  private val streamName = range.streamName
+  private val shardId = range.shardId
+
+  private var toSeqNumberReceived = false
+  private var lastSeqNumber: String = null
+  private var internalIterator: Iterator[Record] = null
+
+  client.setEndpoint(endpointUrl, "kinesis", regionId)
+
+  override protected def getNext(): Array[Byte] = {
+    var nextBytes: Array[Byte] = null
+    if (toSeqNumberReceived) {
+      finished = true
+    } else {
+
+      if (internalIterator == null) {
+
+        // If the internal iterator has not been initialized,
+        // then fetch records from starting sequence number
+        internalIterator = getRecords(ShardIteratorType.AT_SEQUENCE_NUMBER, range.fromSeqNumber)
+      } else if (!internalIterator.hasNext) {
+
+        // If the internal iterator does not have any more records,
+        // then fetch more records after the last consumed sequence number
+        internalIterator = getRecords(ShardIteratorType.AFTER_SEQUENCE_NUMBER, lastSeqNumber)
+      }
+
+      if (!internalIterator.hasNext) {
+
+        // If the internal iterator still does not have any data, then throw exception
+        // and terminate this iterator
+        finished = true
+        throw new SparkException(
+          s"Could not read until the end sequence number of the range: $range")
+      } else {
+
+        // Get the record, copy the data into a byte array and remember its sequence number
+        val nextRecord: Record = internalIterator.next()
+        val byteBuffer = nextRecord.getData()
+        nextBytes = new Array[Byte](byteBuffer.remaining())
+        byteBuffer.get(nextBytes)
+        lastSeqNumber = nextRecord.getSequenceNumber()
+
+        // If the this record's sequence number matches the stopping sequence number, then make sure
+        // the iterator is marked finished next time getNext() is called
+        if (nextRecord.getSequenceNumber == range.toSeqNumber) {
+          toSeqNumberReceived = true
+        }
+      }
+
+    }
+    nextBytes
+  }
+
+  override protected def close(): Unit = {
+    client.shutdown()
+  }
+
+  /**
+   * Get records starting from or after the given sequence number.
+   */
+  private def getRecords(iteratorType: ShardIteratorType, seqNum: String): Iterator[Record] = {
+    val shardIterator = getKinesisIterator(iteratorType, seqNum)
+    val result = getRecordsAndNextKinesisIterator(shardIterator)
+    result._1
+  }
+
+  /**
+   * Get the records starting from using a Kinesis shard iterator (which is a progress handle
+   * to get records from Kinesis), and get the next shard iterator for next consumption.
+   */
+  private def getRecordsAndNextKinesisIterator(
+      shardIterator: String): (Iterator[Record], String) = {
+    val getRecordsRequest = new GetRecordsRequest
+    getRecordsRequest.setRequestCredentials(credentials)
+    getRecordsRequest.setShardIterator(shardIterator)
+    val getRecordsResult = retryOrTimeout[GetRecordsResult](
+      s"getting records using shard iterator") {
+        client.getRecords(getRecordsRequest)
+      }
+    (getRecordsResult.getRecords.iterator(), getRecordsResult.getNextShardIterator)
+  }
+
+  /**
+   * Get the Kinesis shard iterator for getting records starting from or after the given
+   * sequence number.
+   */
+  private def getKinesisIterator(
+      iteratorType: ShardIteratorType,
+      sequenceNumber: String): String = {
+    val getShardIteratorRequest = new GetShardIteratorRequest
+    getShardIteratorRequest.setRequestCredentials(credentials)
+    getShardIteratorRequest.setStreamName(streamName)
+    getShardIteratorRequest.setShardId(shardId)
+    getShardIteratorRequest.setShardIteratorType(iteratorType.toString)
+    getShardIteratorRequest.setStartingSequenceNumber(sequenceNumber)
+    val getShardIteratorResult = retryOrTimeout[GetShardIteratorResult](
+        s"getting shard iterator from sequence number $sequenceNumber") {
+          client.getShardIterator(getShardIteratorRequest)
+        }
+    getShardIteratorResult.getShardIterator
+  }
+
+  /** Helper method to retry Kinesis API request with exponential backoff and timeouts */
+  private def retryOrTimeout[T](message: String)(body: => T): T = {
+    import KinesisSequenceRangeIterator._
+
+    var startTimeMs = System.currentTimeMillis()
+    var retryCount = 0
+    var waitTimeMs = MIN_RETRY_WAIT_TIME_MS
+    var result: Option[T] = None
+    var lastError: Throwable = null
+
+    def isTimedOut = (System.currentTimeMillis() - startTimeMs) >= retryTimeoutMs
+    def isMaxRetryDone = retryCount >= MAX_RETRIES
+
+    while (result.isEmpty && !isTimedOut && !isMaxRetryDone) {
+      if (retryCount > 0) {  // wait only if this is a retry
+        Thread.sleep(waitTimeMs)
+        waitTimeMs *= 2  // if you have waited, then double wait time for next round
+      }
+      try {
+        result = Some(body)
+      } catch {
+        case NonFatal(t) =>
+          lastError = t
+           t match {
+             case ptee: ProvisionedThroughputExceededException =>
+               logWarning(s"Error while $message [attempt = ${retryCount + 1}]", ptee)
+             case e: Throwable =>
+               throw new SparkException(s"Error while $message", e)
+           }
+      }
+      retryCount += 1
+    }
+    result.getOrElse {
+      if (isTimedOut) {
+        throw new SparkException(
+          s"Timed out after $retryTimeoutMs ms while $message, last exception: ", lastError)
+      } else {
+        throw new SparkException(
+          s"Gave up after $retryCount retries while $message, last exception: ", lastError)
+      }
+    }
+  }
+}
+
+private[streaming]
+object KinesisSequenceRangeIterator {
+  val MAX_RETRIES = 3
+  val MIN_RETRY_WAIT_TIME_MS = 100
+}
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
index f6bf552e6b..0ff1b7ed0f 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
@@ -177,7 +177,7 @@ private class KinesisTestUtils(
 
 private[kinesis] object KinesisTestUtils {
 
-  val envVarName = "RUN_KINESIS_TESTS"
+  val envVarName = "ENABLE_KINESIS_TESTS"
 
   val shouldRunTests = sys.env.get(envVarName) == Some("1")
 
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
new file mode 100644
index 0000000000..b2e2a4246d
--- /dev/null
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kinesis
+
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
+
+import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId}
+import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite}
+
+class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll {
+
+  private val regionId = "us-east-1"
+  private val endpointUrl = "https://kinesis.us-east-1.amazonaws.com"
+  private val testData = 1 to 8
+
+  private var testUtils: KinesisTestUtils = null
+  private var shardIds: Seq[String] = null
+  private var shardIdToData: Map[String, Seq[Int]] = null
+  private var shardIdToSeqNumbers: Map[String, Seq[String]] = null
+  private var shardIdToDataAndSeqNumbers: Map[String, Seq[(Int, String)]] = null
+  private var shardIdToRange: Map[String, SequenceNumberRange] = null
+  private var allRanges: Seq[SequenceNumberRange] = null
+
+  private var sc: SparkContext = null
+  private var blockManager: BlockManager = null
+
+
+  override def beforeAll(): Unit = {
+    runIfTestsEnabled("Prepare KinesisTestUtils") {
+      testUtils = new KinesisTestUtils(endpointUrl)
+      testUtils.createStream()
+
+      shardIdToDataAndSeqNumbers = testUtils.pushData(testData)
+      require(shardIdToDataAndSeqNumbers.size > 1, "Need data to be sent to multiple shards")
+
+      shardIds = shardIdToDataAndSeqNumbers.keySet.toSeq
+      shardIdToData = shardIdToDataAndSeqNumbers.mapValues { _.map { _._1 }}
+      shardIdToSeqNumbers = shardIdToDataAndSeqNumbers.mapValues { _.map { _._2 }}
+      shardIdToRange = shardIdToSeqNumbers.map { case (shardId, seqNumbers) =>
+        val seqNumRange = SequenceNumberRange(
+          testUtils.streamName, shardId, seqNumbers.head, seqNumbers.last)
+        (shardId, seqNumRange)
+      }
+      allRanges = shardIdToRange.values.toSeq
+
+      val conf = new SparkConf().setMaster("local[4]").setAppName("KinesisBackedBlockRDDSuite")
+      sc = new SparkContext(conf)
+      blockManager = sc.env.blockManager
+    }
+  }
+
+  override def afterAll(): Unit = {
+    if (sc != null) {
+      sc.stop()
+    }
+  }
+
+  testIfEnabled("Basic reading from Kinesis") {
+    // Verify all data using multiple ranges in a single RDD partition
+    val receivedData1 = new KinesisBackedBlockRDD(sc, regionId, endpointUrl,
+      fakeBlockIds(1),
+      Array(SequenceNumberRanges(allRanges.toArray))
+    ).map { bytes => new String(bytes).toInt }.collect()
+    assert(receivedData1.toSet === testData.toSet)
+
+    // Verify all data using one range in each of the multiple RDD partitions
+    val receivedData2 = new KinesisBackedBlockRDD(sc, regionId, endpointUrl,
+      fakeBlockIds(allRanges.size),
+      allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray
+    ).map { bytes => new String(bytes).toInt }.collect()
+    assert(receivedData2.toSet === testData.toSet)
+
+    // Verify ordering within each partition
+    val receivedData3 = new KinesisBackedBlockRDD(sc, regionId, endpointUrl,
+      fakeBlockIds(allRanges.size),
+      allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray
+    ).map { bytes => new String(bytes).toInt }.collectPartitions()
+    assert(receivedData3.length === allRanges.size)
+    for (i <- 0 until allRanges.size) {
+      assert(receivedData3(i).toSeq === shardIdToData(allRanges(i).shardId))
+    }
+  }
+
+  testIfEnabled("Read data available in both block manager and Kinesis") {
+    testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 2)
+  }
+
+  testIfEnabled("Read data available only in block manager, not in Kinesis") {
+    testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 0)
+  }
+
+  testIfEnabled("Read data available only in Kinesis, not in block manager") {
+    testRDD(numPartitions = 2, numPartitionsInBM = 0, numPartitionsInKinesis = 2)
+  }
+
+  testIfEnabled("Read data available partially in block manager, rest in Kinesis") {
+    testRDD(numPartitions = 2, numPartitionsInBM = 1, numPartitionsInKinesis = 1)
+  }
+
+  testIfEnabled("Test isBlockValid skips block fetching from block manager") {
+    testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 0,
+      testIsBlockValid = true)
+  }
+
+  testIfEnabled("Test whether RDD is valid after removing blocks from block anager") {
+    testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 2,
+      testBlockRemove = true)
+  }
+
+  /**
+   * Test the WriteAheadLogBackedRDD, by writing some partitions of the data to block manager
+   * and the rest to a write ahead log, and then reading reading it all back using the RDD.
+   * It can also test if the partitions that were read from the log were again stored in
+   * block manager.
+   *
+   *
+   *
+   * @param numPartitions Number of partitions in RDD
+   * @param numPartitionsInBM Number of partitions to write to the BlockManager.
+   *                          Partitions 0 to (numPartitionsInBM-1) will be written to BlockManager
+   * @param numPartitionsInKinesis Number of partitions to write to the Kinesis.
+   *                           Partitions (numPartitions - 1 - numPartitionsInKinesis) to
+   *                           (numPartitions - 1) will be written to Kinesis
+   * @param testIsBlockValid Test whether setting isBlockValid to false skips block fetching
+   * @param testBlockRemove Test whether calling rdd.removeBlock() makes the RDD still usable with
+   *                        reads falling back to the WAL
+   * Example with numPartitions = 5, numPartitionsInBM = 3, and numPartitionsInWAL = 4
+   *
+   *   numPartitionsInBM = 3
+   *   |------------------|
+   *   |                  |
+   *    0       1       2       3       4
+   *           |                         |
+   *           |-------------------------|
+   *              numPartitionsInKinesis = 4
+   */
+  private def testRDD(
+      numPartitions: Int,
+      numPartitionsInBM: Int,
+      numPartitionsInKinesis: Int,
+      testIsBlockValid: Boolean = false,
+      testBlockRemove: Boolean = false
+    ): Unit = {
+    require(shardIds.size > 1, "Need at least 2 shards to test")
+    require(numPartitionsInBM <= shardIds.size ,
+      "Number of partitions in BlockManager cannot be more than the Kinesis test shards available")
+    require(numPartitionsInKinesis <= shardIds.size ,
+      "Number of partitions in Kinesis cannot be more than the Kinesis test shards available")
+    require(numPartitionsInBM <= numPartitions,
+      "Number of partitions in BlockManager cannot be more than that in RDD")
+    require(numPartitionsInKinesis <= numPartitions,
+      "Number of partitions in Kinesis cannot be more than that in RDD")
+
+    // Put necessary blocks in the block manager
+    val blockIds = fakeBlockIds(numPartitions)
+    blockIds.foreach(blockManager.removeBlock(_))
+    (0 until numPartitionsInBM).foreach { i =>
+      val blockData = shardIdToData(shardIds(i)).iterator.map { _.toString.getBytes() }
+      blockManager.putIterator(blockIds(i), blockData, StorageLevel.MEMORY_ONLY)
+    }
+
+    // Create the necessary ranges to use in the RDD
+    val fakeRanges = Array.fill(numPartitions - numPartitionsInKinesis)(
+      SequenceNumberRanges(SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy")))
+    val realRanges = Array.tabulate(numPartitionsInKinesis) { i =>
+      val range = shardIdToRange(shardIds(i + (numPartitions - numPartitionsInKinesis)))
+      SequenceNumberRanges(Array(range))
+    }
+    val ranges = (fakeRanges ++ realRanges)
+
+
+    // Make sure that the left `numPartitionsInBM` blocks are in block manager, and others are not
+    require(
+      blockIds.take(numPartitionsInBM).forall(blockManager.get(_).nonEmpty),
+      "Expected blocks not in BlockManager"
+    )
+
+    require(
+      blockIds.drop(numPartitionsInBM).forall(blockManager.get(_).isEmpty),
+      "Unexpected blocks in BlockManager"
+    )
+
+    // Make sure that the right sequence `numPartitionsInKinesis` are configured, and others are not
+    require(
+      ranges.takeRight(numPartitionsInKinesis).forall {
+        _.ranges.forall { _.streamName == testUtils.streamName }
+      }, "Incorrect configuration of RDD, expected ranges not set: "
+    )
+
+    require(
+      ranges.dropRight(numPartitionsInKinesis).forall {
+        _.ranges.forall { _.streamName != testUtils.streamName }
+      }, "Incorrect configuration of RDD, unexpected ranges set"
+    )
+
+    val rdd = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, blockIds, ranges)
+    val collectedData = rdd.map { bytes =>
+      new String(bytes).toInt
+    }.collect()
+    assert(collectedData.toSet === testData.toSet)
+
+    // Verify that the block fetching is skipped when isBlockValid is set to false.
+    // This is done by using a RDD whose data is only in memory but is set to skip block fetching
+    // Using that RDD will throw exception, as it skips block fetching even if the blocks are in
+    // in BlockManager.
+    if (testIsBlockValid) {
+      require(numPartitionsInBM === numPartitions, "All partitions must be in BlockManager")
+      require(numPartitionsInKinesis === 0, "No partitions must be in Kinesis")
+      val rdd2 = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, blockIds.toArray,
+        ranges, isBlockIdValid = Array.fill(blockIds.length)(false))
+      intercept[SparkException] {
+        rdd2.collect()
+      }
+    }
+
+    // Verify that the RDD is not invalid after the blocks are removed and can still read data
+    // from write ahead log
+    if (testBlockRemove) {
+      require(numPartitions === numPartitionsInKinesis,
+        "All partitions must be in WAL for this test")
+      require(numPartitionsInBM > 0, "Some partitions must be in BlockManager for this test")
+      rdd.removeBlocks()
+      assert(rdd.map { bytes => new String(bytes).toInt }.collect().toSet === testData.toSet)
+    }
+  }
+
+  /** Generate fake block ids */
+  private def fakeBlockIds(num: Int): Array[BlockId] = {
+    Array.tabulate(num) { i => new StreamBlockId(0, i) }
+  }
+}
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala
index 6d011f295e..8373138785 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala
@@ -23,15 +23,24 @@ import org.apache.spark.SparkFunSuite
  * Helper class that runs Kinesis real data transfer tests or
  * ignores them based on env variable is set or not.
  */
-trait KinesisSuiteHelper { self: SparkFunSuite =>
+trait KinesisFunSuite extends SparkFunSuite  {
   import KinesisTestUtils._
 
   /** Run the test if environment variable is set or ignore the test */
-  def testOrIgnore(testName: String)(testBody: => Unit) {
+  def testIfEnabled(testName: String)(testBody: => Unit) {
     if (shouldRunTests) {
       test(testName)(testBody)
     } else {
       ignore(s"$testName [enable by setting env var $envVarName=1]")(testBody)
     }
   }
+
+  /** Run the give body of code only if Kinesis tests are enabled */
+  def runIfTestsEnabled(message: String)(body: => Unit): Unit = {
+    if (shouldRunTests) {
+      body
+    } else {
+      ignore(s"$message [enable by setting env var $envVarName=1]")()
+    }
+  }
 }
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
index 50f71413ab..f9c952b946 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming._
 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
 
-class KinesisStreamSuite extends SparkFunSuite with KinesisSuiteHelper
+class KinesisStreamSuite extends KinesisFunSuite
   with Eventually with BeforeAndAfter with BeforeAndAfterAll {
 
   // This is the name that KCL uses to save metadata to DynamoDB
@@ -83,7 +83,7 @@ class KinesisStreamSuite extends SparkFunSuite with KinesisSuiteHelper
    * you must have AWS credentials available through the default AWS provider chain,
    * and you have to set the system environment variable RUN_KINESIS_TESTS=1 .
    */
-  testOrIgnore("basic operation") {
+  testIfEnabled("basic operation") {
     val kinesisTestUtils = new KinesisTestUtils()
     try {
       kinesisTestUtils.createStream()
-- 
GitLab