diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala
index 3e697f36a43494989011749cc539bccfe2c0acdd..c445c15a5f644367c05e850c9738ddcc9ea572f2 100644
--- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala
@@ -64,7 +64,20 @@ private[kinesis] class KinesisCheckpointer(
   def removeCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = {
     synchronized {
       checkpointers.remove(shardId)
-      checkpoint(shardId, checkpointer)
+    }
+    if (checkpointer != null) {
+      try {
+        // We must call `checkpoint()` with no parameter to finish reading shards.
+        // See the URL below for details:
+        // https://forums.aws.amazon.com/thread.jspa?threadID=244218
+        KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100)
+      } catch {
+        case NonFatal(e) =>
+          logError(s"Exception: WorkerId $workerId encountered an exception while checkpointing " +
+            s"to finish reading the shard $shardId.", e)
+          // Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor
+          throw e
+      }
     }
   }
 
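Note on the hunk above: `KinesisRecordProcessor.retryRandom` is the module's existing helper for retrying a by-name expression with randomized backoff. The sketch below only illustrates that pattern; the helper name and parameters here are illustrative, not the module's API.

```scala
import scala.util.Random
import scala.util.control.NonFatal

object RetrySketch {
  // Minimal sketch of retry-with-random-backoff: evaluate `body`, and on a
  // non-fatal failure sleep for a random interval (bounded by maxBackOffMillis)
  // before retrying, so concurrent checkpointers do not retry in lock-step.
  def retryWithRandomBackoff[T](body: => T, retriesLeft: Int, maxBackOffMillis: Int): T = {
    try {
      body
    } catch {
      case NonFatal(_) if retriesLeft > 1 =>
        Thread.sleep(Random.nextInt(maxBackOffMillis).toLong)
        retryWithRandomBackoff(body, retriesLeft - 1, maxBackOffMillis)
    }
  }

  // Mirrors the call in the hunk above: an argument-less checkpoint() records
  // that the shard has been read completely, which is required once a shard is
  // closed by a split or merge.
  // retryWithRandomBackoff(checkpointer.checkpoint(), 4, 100)
}
```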
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
index 0fe66254e989dce3020cfb4eb791ebced474d8d7..f183ef00b33cdbf010a47cb6d42ca36c553a5294 100644
--- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
@@ -40,11 +40,10 @@ import org.apache.spark.internal.Logging
  *
  * PLEASE KEEP THIS FILE UNDER src/main AS PYTHON TESTS NEED ACCESS TO THIS FILE!
  */
-private[kinesis] class KinesisTestUtils extends Logging {
+private[kinesis] class KinesisTestUtils(streamShardCount: Int = 2) extends Logging {
 
   val endpointUrl = KinesisTestUtils.endpointUrl
   val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName()
-  val streamShardCount = 2
 
   private val createStreamTimeoutSeconds = 300
   private val describeStreamPollTimeSeconds = 1
@@ -88,7 +87,7 @@ private[kinesis] class KinesisTestUtils extends Logging {
     logInfo(s"Creating stream ${_streamName}")
     val createStreamRequest = new CreateStreamRequest()
     createStreamRequest.setStreamName(_streamName)
-    createStreamRequest.setShardCount(2)
+    createStreamRequest.setShardCount(streamShardCount)
     kinesisClient.createStream(createStreamRequest)
 
     // The stream is now being created. Wait for it to become active.
@@ -97,6 +96,31 @@ private[kinesis] class KinesisTestUtils extends Logging {
     logInfo(s"Created stream ${_streamName}")
   }
 
+  def getShards(): Seq[Shard] = {
+    kinesisClient.describeStream(_streamName).getStreamDescription.getShards.asScala
+  }
+
+  def splitShard(shardId: String): Unit = {
+    val splitShardRequest = new SplitShardRequest()
+    splitShardRequest.withStreamName(_streamName)
+    splitShardRequest.withShardToSplit(shardId)
+    // Set the new starting hash key to half of the maximum hash key value (2^127)
+    splitShardRequest.withNewStartingHashKey("170141183460469231731687303715884105728")
+    kinesisClient.splitShard(splitShardRequest)
+    // Wait for the shards to become active
+    waitForStreamToBeActive(_streamName)
+  }
+
+  def mergeShard(shardToMerge: String, adjacentShardToMerge: String): Unit = {
+    val mergeShardRequest = new MergeShardsRequest
+    mergeShardRequest.withStreamName(_streamName)
+    mergeShardRequest.withShardToMerge(shardToMerge)
+    mergeShardRequest.withAdjacentShardToMerge(adjacentShardToMerge)
+    kinesisClient.mergeShards(mergeShardRequest)
+    // Wait for the shards to become active
+    waitForStreamToBeActive(_streamName)
+  }
+
   /**
    * Push data to Kinesis stream and return a map of
    * shardId -> seq of (data, seq number) pushed to corresponding shard
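A note on `splitShard` above: the hard-coded `NewStartingHashKey` is 2^127, the midpoint of the full Kinesis hash key space [0, 2^128 - 1]. That is only correct here because the test stream starts with a single shard covering the entire range. A hedged sketch of deriving the split point from the shard itself (the helper below is illustrative, not part of the patch):

```scala
import com.amazonaws.services.kinesis.model.Shard

object SplitPointSketch {
  // Split a shard at the midpoint of its own hash key range rather than at a
  // hard-coded 2^127, so the computation also holds for multi-shard streams.
  def midpointHashKey(shard: Shard): String = {
    val range = shard.getHashKeyRange
    ((BigInt(range.getStartingHashKey) + BigInt(range.getEndingHashKey)) / 2).toString
  }
}
```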
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala
index 0b455e574e6fad389e3ec70d30652bd6749e93b1..2ee3224b3c2860cbc4b233eaccd3ac72a9dca074 100644
--- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala
@@ -25,7 +25,8 @@ import scala.collection.mutable.ArrayBuffer
 import com.amazonaws.services.kinesis.producer.{KinesisProducer => KPLProducer, KinesisProducerConfiguration, UserRecordResult}
 import com.google.common.util.concurrent.{FutureCallback, Futures}
 
-private[kinesis] class KPLBasedKinesisTestUtils extends KinesisTestUtils {
+private[kinesis] class KPLBasedKinesisTestUtils(streamShardCount: Int = 2)
+  extends KinesisTestUtils(streamShardCount) {
   override protected def getProducer(aggregate: Boolean): KinesisDataGenerator = {
     if (!aggregate) {
       new SimpleDataGenerator(kinesisClient)
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala
index bcaed628a8ddd5fd98833ee24a4cf86d9f9a099f..fef24ed4c5dd0911c7b91dedca365a2fe7464b84 100644
--- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala
@@ -118,7 +118,7 @@ class KinesisCheckpointerSuite extends TestSuiteBase
     when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)
 
     kinesisCheckpointer.removeCheckpointer(shardId, checkpointerMock)
-    verify(checkpointerMock, times(1)).checkpoint(anyString())
+    verify(checkpointerMock, times(1)).checkpoint()
   }
 
   test("if checkpointing is going on, wait until finished before removing and checkpointing") {
@@ -145,7 +145,8 @@ class KinesisCheckpointerSuite extends TestSuiteBase
     clock.advance(checkpointInterval.milliseconds / 2)
 
     eventually(timeout(1 second)) {
-      verify(checkpointerMock, times(2)).checkpoint(anyString())
+      verify(checkpointerMock, times(1)).checkpoint(anyString())
+      verify(checkpointerMock, times(1)).checkpoint()
     }
   }
 }
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
index 0e71bf9b84332097d65fa98826f9d0ad7c29a18f..404b673c011718a96c96dc5fb09a1bad2f7d2324 100644
--- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
@@ -225,6 +225,76 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun
     ssc.stop(stopSparkContext = false)
   }
 
+  testIfEnabled("split and merge shards in a stream") {
+    // Since this test splits and merges shards in a stream, we create another
+    // temporary stream and then remove it when finished.
+    val localAppName = s"KinesisStreamSuite-${math.abs(Random.nextLong())}"
+    val localTestUtils = new KPLBasedKinesisTestUtils(1)
+    localTestUtils.createStream()
+    try {
+      val awsCredentials = KinesisTestUtils.getAWSCredentials()
+      val stream = KinesisUtils.createStream(ssc, localAppName, localTestUtils.streamName,
+        localTestUtils.endpointUrl, localTestUtils.regionName, InitialPositionInStream.LATEST,
+        Seconds(10), StorageLevel.MEMORY_ONLY,
+        awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey)
+
+      val collected = new mutable.HashSet[Int]
+      stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd =>
+        collected.synchronized {
+          collected ++= rdd.collect()
+          logInfo("Collected = " + collected.mkString(", "))
+        }
+      }
+      ssc.start()
+
+      val testData1 = 1 to 10
+      val testData2 = 11 to 20
+      val testData3 = 21 to 30
+
+      eventually(timeout(60 seconds), interval(10 seconds)) {
+        localTestUtils.pushData(testData1, aggregateTestData)
+        assert(collected.synchronized { collected === testData1.toSet },
+          "\nData received does not match data sent")
+      }
+
+      val shardToSplit = localTestUtils.getShards().head
+      localTestUtils.splitShard(shardToSplit.getShardId)
+      val (splitOpenShards, splitCloseShards) = localTestUtils.getShards().partition { shard =>
+        shard.getSequenceNumberRange.getEndingSequenceNumber == null
+      }
+
+      // We should have one closed shard and two open shards
+      assert(splitCloseShards.size == 1)
+      assert(splitOpenShards.size == 2)
+
+      eventually(timeout(60 seconds), interval(10 seconds)) {
+        localTestUtils.pushData(testData2, aggregateTestData)
+        assert(collected.synchronized { collected === (testData1 ++ testData2).toSet },
+          "\nData received does not match data sent after splitting a shard")
+      }
+
+      val Seq(shardToMerge, adjShard) = splitOpenShards
+      localTestUtils.mergeShard(shardToMerge.getShardId, adjShard.getShardId)
+      val (mergedOpenShards, mergedCloseShards) = localTestUtils.getShards().partition { shard =>
+        shard.getSequenceNumberRange.getEndingSequenceNumber == null
+      }
+
+      // We should have three closed shards and one open shard
+      assert(mergedCloseShards.size == 3)
+      assert(mergedOpenShards.size == 1)
+
+      eventually(timeout(60 seconds), interval(10 seconds)) {
+        localTestUtils.pushData(testData3, aggregateTestData)
+        assert(collected.synchronized { collected === (testData1 ++ testData2 ++ testData3).toSet },
+          "\nData received does not match data sent after merging shards")
+      }
+    } finally {
+      ssc.stop(stopSparkContext = false)
+      localTestUtils.deleteStream()
+      localTestUtils.deleteDynamoDBTable(localAppName)
+    }
+  }
+
   testIfEnabled("failure recovery") {
     val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName)
     val checkpointDir = Utils.createTempDir().getAbsolutePath
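The test above partitions the shard list on the same predicate twice. For reference, a small illustrative helper (not part of the patch) naming that predicate:

```scala
import com.amazonaws.services.kinesis.model.Shard

object ShardStateSketch {
  // A shard is open while its sequence number range has no ending sequence
  // number; SplitShard and MergeShards close the parent shard(s) and create new
  // open children, which is why the counts above go from 1 open shard to
  // (1 closed, 2 open) after the split and (3 closed, 1 open) after the merge.
  def isOpen(shard: Shard): Boolean =
    shard.getSequenceNumberRange.getEndingSequenceNumber == null
}
```

With this helper, each partition in the test reads as `localTestUtils.getShards().partition(ShardStateSketch.isOpen)`.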
testIfEnabled("failure recovery") { val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName) val checkpointDir = Utils.createTempDir().getAbsolutePath diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 5ac007cd598b9e22086bbc169c8ef7c9178b3738..2e8ed698278d0bff3a462a712cb483ec30fa8634 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1420,7 +1420,7 @@ class KinesisStreamTests(PySparkStreamingTestCase): import random kinesisAppName = ("KinesisStreamTests-%d" % abs(random.randint(0, 10000000))) - kinesisTestUtils = self.ssc._jvm.org.apache.spark.streaming.kinesis.KinesisTestUtils() + kinesisTestUtils = self.ssc._jvm.org.apache.spark.streaming.kinesis.KinesisTestUtils(2) try: kinesisTestUtils.createStream() aWSCredentials = kinesisTestUtils.getAWSCredentials()