diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
index e916f1ee0893b4a45dd8fe6c786cb6aa2df28f28..2555332d222da06558f5b55b53c39d6391920423 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
@@ -17,13 +17,13 @@
 
 package org.apache.spark.streaming.kinesis
 
-import org.scalatest.BeforeAndAfterAll
+import org.scalatest.BeforeAndAfterEach
 
-import org.apache.spark.{SparkConf, SparkContext, SparkException}
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException}
 import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId}
 
 abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean)
-  extends KinesisFunSuite with BeforeAndAfterAll {
+  extends KinesisFunSuite with BeforeAndAfterEach with LocalSparkContext {
 
   private val testData = 1 to 8
 
@@ -35,10 +35,10 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean)
   private var shardIdToRange: Map[String, SequenceNumberRange] = null
   private var allRanges: Seq[SequenceNumberRange] = null
 
-  private var sc: SparkContext = null
   private var blockManager: BlockManager = null
 
   override def beforeAll(): Unit = {
+    super.beforeAll()
     runIfTestsEnabled("Prepare KinesisTestUtils") {
       testUtils = new KPLBasedKinesisTestUtils()
       testUtils.createStream()
@@ -55,19 +55,23 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean)
         (shardId, seqNumRange)
       }
       allRanges = shardIdToRange.values.toSeq
-
-      val conf = new SparkConf().setMaster("local[4]").setAppName("KinesisBackedBlockRDDSuite")
-      sc = new SparkContext(conf)
-      blockManager = sc.env.blockManager
     }
   }
 
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    val conf = new SparkConf().setMaster("local[4]").setAppName("KinesisBackedBlockRDDSuite")
+    sc = new SparkContext(conf)
+    blockManager = sc.env.blockManager
+  }
+
   override def afterAll(): Unit = {
-    if (testUtils != null) {
-      testUtils.deleteStream()
-    }
-    if (sc != null) {
-      sc.stop()
+    try {
+      if (testUtils != null) {
+        testUtils.deleteStream()
+      }
+    } finally {
+      super.afterAll()
     }
   }
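
For context on why the suite can drop its `private var sc` field and the explicit `sc.stop()` call: the newly mixed-in LocalSparkContext trait from Spark's core test utilities declares the shared `sc` variable and stops the context after each test in its own teardown, and the added `super.beforeAll()` / `super.afterAll()` calls keep that trait's setup and cleanup in the lifecycle chain. Below is a minimal sketch of the contract this change relies on; it approximates org.apache.spark.LocalSparkContext and is not the exact Spark source.

    import org.apache.spark.SparkContext
    import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Suite}

    // Approximate shape of Spark's LocalSparkContext test trait (sketch only).
    trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self: Suite =>

      // The `sc` field that KinesisBackedBlockRDDTests.beforeEach() now assigns;
      // because the trait declares it, the suite no longer needs its own field.
      @transient var sc: SparkContext = _

      override def afterEach(): Unit = {
        try {
          resetSparkContext()  // stop the per-test context created in beforeEach()
        } finally {
          super.afterEach()
        }
      }

      def resetSparkContext(): Unit = {
        if (sc != null) {
          sc.stop()
        }
        sc = null
      }
    }

Creating the SparkContext in beforeEach() rather than beforeAll() gives every test a fresh context and BlockManager, so one test cannot leak a stopped or misconfigured context into the next; the try/finally added to afterAll() ensures the Kinesis stream cleanup cannot skip the trait's own teardown if deleteStream() throws.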