diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala
index 3e697f36a43494989011749cc539bccfe2c0acdd..c445c15a5f644367c05e850c9738ddcc9ea572f2 100644
--- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala
@@ -64,7 +64,22 @@ private[kinesis] class KinesisCheckpointer(
   def removeCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = {
     synchronized {
       checkpointers.remove(shardId)
-      checkpoint(shardId, checkpointer)
+    }
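+    // Note: the checkpoint below runs outside the synchronized block, so a slow or
+    // retried network call does not block concurrent access to `checkpointers`.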
+    if (checkpointer != null) {
+      try {
+        // We must call `checkpoint()` with no parameter to finish reading a shard.
+        // See the URL below for details:
+        // https://forums.aws.amazon.com/thread.jspa?threadID=244218
+        KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100)
+      } catch {
+        case NonFatal(e) =>
+          logError(s"Exception:  WorkerId $workerId encountered an exception while checkpointing" +
+            s"to finish reading a shard of $shardId.", e)
+          // Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor
+          throw e
+      }
     }
   }
 
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
index 0fe66254e989dce3020cfb4eb791ebced474d8d7..f183ef00b33cdbf010a47cb6d42ca36c553a5294 100644
--- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
@@ -40,11 +40,10 @@ import org.apache.spark.internal.Logging
  *
  * PLEASE KEEP THIS FILE UNDER src/main AS PYTHON TESTS NEED ACCESS TO THIS FILE!
  */
-private[kinesis] class KinesisTestUtils extends Logging {
+private[kinesis] class KinesisTestUtils(streamShardCount: Int = 2) extends Logging {
 
   val endpointUrl = KinesisTestUtils.endpointUrl
   val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName()
-  val streamShardCount = 2
 
   private val createStreamTimeoutSeconds = 300
   private val describeStreamPollTimeSeconds = 1
@@ -88,7 +87,7 @@ private[kinesis] class KinesisTestUtils extends Logging {
     logInfo(s"Creating stream ${_streamName}")
     val createStreamRequest = new CreateStreamRequest()
     createStreamRequest.setStreamName(_streamName)
-    createStreamRequest.setShardCount(2)
+    createStreamRequest.setShardCount(streamShardCount)
     kinesisClient.createStream(createStreamRequest)
 
     // The stream is now being created. Wait for it to become active.
@@ -97,6 +96,33 @@
     logInfo(s"Created stream ${_streamName}")
   }
 
+  def getShards(): Seq[Shard] = {
+    kinesisClient.describeStream(_streamName).getStreamDescription.getShards.asScala
+  }
+
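+  /** Splits the given shard in two at the midpoint of the overall hash key range. */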
+  def splitShard(shardId: String): Unit = {
+    val splitShardRequest = new SplitShardRequest()
+    splitShardRequest.withStreamName(_streamName)
+    splitShardRequest.withShardToSplit(shardId)
+    // Split at half of the maximum hash key value (2^127)
+    splitShardRequest.withNewStartingHashKey("170141183460469231731687303715884105728")
+    kinesisClient.splitShard(splitShardRequest)
+    // Wait for the stream to become active again after the split
+    waitForStreamToBeActive(_streamName)
+  }
+
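+  /** Merges two adjacent shards back into a single shard. */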
+  def mergeShard(shardToMerge: String, adjacentShardToMerge: String): Unit = {
+    val mergeShardRequest = new MergeShardsRequest
+    mergeShardRequest.withStreamName(_streamName)
+    mergeShardRequest.withShardToMerge(shardToMerge)
+    mergeShardRequest.withAdjacentShardToMerge(adjacentShardToMerge)
+    kinesisClient.mergeShards(mergeShardRequest)
+    // Wait for the stream to become active again after the merge
+    waitForStreamToBeActive(_streamName)
+  }
+
   /**
    * Push data to Kinesis stream and return a map of
    * shardId -> seq of (data, seq number) pushed to corresponding shard
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala
index 0b455e574e6fad389e3ec70d30652bd6749e93b1..2ee3224b3c2860cbc4b233eaccd3ac72a9dca074 100644
--- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala
@@ -25,7 +25,8 @@ import scala.collection.mutable.ArrayBuffer
 import com.amazonaws.services.kinesis.producer.{KinesisProducer => KPLProducer, KinesisProducerConfiguration, UserRecordResult}
 import com.google.common.util.concurrent.{FutureCallback, Futures}
 
-private[kinesis] class KPLBasedKinesisTestUtils extends KinesisTestUtils {
+private[kinesis] class KPLBasedKinesisTestUtils(streamShardCount: Int = 2)
+    extends KinesisTestUtils(streamShardCount) {
   override protected def getProducer(aggregate: Boolean): KinesisDataGenerator = {
     if (!aggregate) {
       new SimpleDataGenerator(kinesisClient)
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala
index bcaed628a8ddd5fd98833ee24a4cf86d9f9a099f..fef24ed4c5dd0911c7b91dedca365a2fe7464b84 100644
--- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala
@@ -118,7 +118,7 @@ class KinesisCheckpointerSuite extends TestSuiteBase
     when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)
 
     kinesisCheckpointer.removeCheckpointer(shardId, checkpointerMock)
-    verify(checkpointerMock, times(1)).checkpoint(anyString())
+    verify(checkpointerMock, times(1)).checkpoint()
   }
 
   test("if checkpointing is going on, wait until finished before removing and checkpointing") {
@@ -145,7 +145,10 @@ class KinesisCheckpointerSuite extends TestSuiteBase
 
     clock.advance(checkpointInterval.milliseconds / 2)
     eventually(timeout(1 second)) {
-      verify(checkpointerMock, times(2)).checkpoint(anyString())
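+      // One checkpoint carries a sequence number (the periodic task); the final
+      // parameterless checkpoint() comes from removeCheckpointer.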
+      verify(checkpointerMock, times(1)).checkpoint(anyString())
+      verify(checkpointerMock, times(1)).checkpoint()
     }
   }
 }
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
index 0e71bf9b84332097d65fa98826f9d0ad7c29a18f..404b673c011718a96c96dc5fb09a1bad2f7d2324 100644
--- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
@@ -225,6 +225,80 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun
     ssc.stop(stopSparkContext = false)
   }
 
+  testIfEnabled("split and merge shards in a stream") {
+    // Since this test tries to split and merge shards in a stream, we create another
+    // temporary stream and then remove it when finished.
+    val localAppName = s"KinesisStreamSuite-${math.abs(Random.nextLong())}"
+    val localTestUtils = new KPLBasedKinesisTestUtils(1)
+    localTestUtils.createStream()
+    try {
+      val awsCredentials = KinesisTestUtils.getAWSCredentials()
+      val stream = KinesisUtils.createStream(ssc, localAppName, localTestUtils.streamName,
+        localTestUtils.endpointUrl, localTestUtils.regionName, InitialPositionInStream.LATEST,
+        Seconds(10), StorageLevel.MEMORY_ONLY,
+        awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey)
+
+      val collected = new mutable.HashSet[Int]
+      stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd =>
+        collected.synchronized {
+          collected ++= rdd.collect()
+          logInfo("Collected = " + collected.mkString(", "))
+        }
+      }
+      ssc.start()
+
+      val testData1 = 1 to 10
+      val testData2 = 11 to 20
+      val testData3 = 21 to 30
+
+      eventually(timeout(60 seconds), interval(10 seconds)) {
+        localTestUtils.pushData(testData1, aggregateTestData)
+        assert(collected.synchronized { collected === testData1.toSet },
+          "\nData received does not match data sent")
+      }
+
+      val shardToSplit = localTestUtils.getShards().head
+      localTestUtils.splitShard(shardToSplit.getShardId)
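+      // An open shard has no ending sequence number; closed (parent) shards do.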
+      val (splitOpenShards, splitCloseShards) = localTestUtils.getShards().partition { shard =>
+        shard.getSequenceNumberRange.getEndingSequenceNumber == null
+      }
+
+      // We should have one closed shard and two open shards
+      assert(splitCloseShards.size == 1)
+      assert(splitOpenShards.size == 2)
+
+      eventually(timeout(60 seconds), interval(10 seconds)) {
+        localTestUtils.pushData(testData2, aggregateTestData)
+        assert(collected.synchronized { collected === (testData1 ++ testData2).toSet },
+          "\nData received does not match data sent after splitting a shard")
+      }
+
+      val Seq(shardToMerge, adjShard) = splitOpenShards
+      localTestUtils.mergeShard(shardToMerge.getShardId, adjShard.getShardId)
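+      // After the merge both parents are closed and a single child shard is open.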
+      val (mergedOpenShards, mergedCloseShards) = localTestUtils.getShards().partition { shard =>
+        shard.getSequenceNumberRange.getEndingSequenceNumber == null
+      }
+
+      // We should have three closed shards and one open shard
+      assert(mergedCloseShards.size == 3)
+      assert(mergedOpenShards.size == 1)
+
+      eventually(timeout(60 seconds), interval(10 seconds)) {
+        localTestUtils.pushData(testData3, aggregateTestData)
+        assert(collected.synchronized { collected === (testData1 ++ testData2 ++ testData3).toSet },
+          "\nData received does not match data sent after merging shards")
+      }
+    } finally {
+      ssc.stop(stopSparkContext = false)
+      localTestUtils.deleteStream()
+      localTestUtils.deleteDynamoDBTable(localAppName)
+    }
+  }
+
   testIfEnabled("failure recovery") {
     val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName)
     val checkpointDir = Utils.createTempDir().getAbsolutePath
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 5ac007cd598b9e22086bbc169c8ef7c9178b3738..2e8ed698278d0bff3a462a712cb483ec30fa8634 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -1420,7 +1420,8 @@ class KinesisStreamTests(PySparkStreamingTestCase):
 
         import random
         kinesisAppName = ("KinesisStreamTests-%d" % abs(random.randint(0, 10000000)))
-        kinesisTestUtils = self.ssc._jvm.org.apache.spark.streaming.kinesis.KinesisTestUtils()
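+        # Py4J does not apply Scala default arguments, so pass the shard count (2) explicitly.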
+        kinesisTestUtils = self.ssc._jvm.org.apache.spark.streaming.kinesis.KinesisTestUtils(2)
         try:
             kinesisTestUtils.createStream()
             aWSCredentials = kinesisTestUtils.getAWSCredentials()