diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala
index 1559f7a9f7ac00a917cc742bbc6f9287270e9e20..162b19d7f0e9eeee56a808762531e85b05c064a0 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala
@@ -42,6 +42,8 @@ object MasterFailureTest extends Logging {
 
   @volatile var killed = false
   @volatile var killCount = 0
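+  // Set by setupStreams() so the test can verify that the context was created fresh rather than loaded from a checkpoint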
+  @volatile var setupCalled = false
 
   def main(args: Array[String]) {
     if (args.size < 2) {
@@ -131,8 +133,27 @@ object MasterFailureTest extends Logging {
     // Just making sure that the expected output does not have duplicates
     assert(expectedOutput.distinct.toSet == expectedOutput.toSet)
 
+    // Reset all state
+    reset()
+
+    // Create the directories for this test
+    val uuid = UUID.randomUUID().toString
+    val rootDir = new Path(directory, uuid)
+    val fs = rootDir.getFileSystem(new Configuration())
+    val checkpointDir = new Path(rootDir, "checkpoint")
+    val testDir = new Path(rootDir, "test")
+    fs.mkdirs(checkpointDir)
+    fs.mkdirs(testDir)
+
     // Setup the stream computation with the given operation
-    val (ssc, checkpointDir, testDir) = setupStreams(directory, batchDuration, operation)
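+    // getOrCreate recovers the context from checkpoint data if any exists; otherwise it invokes the given function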
+    val ssc = StreamingContext.getOrCreate(checkpointDir.toString, () => {
+      setupStreams(batchDuration, operation, checkpointDir, testDir)
+    })
+
+    // Check that setupStreams was called to create the StreamingContext
+    // (and that it was not recovered from the checkpoint file)
+    assert(setupCalled, "Setup was not called in the first call to StreamingContext.getOrCreate")
 
     // Start generating files in a different thread
     val fileGeneratingThread = new FileGeneratingThread(input, testDir, batchDuration.milliseconds)
@@ -144,9 +165,8 @@ object MasterFailureTest extends Logging {
     val maxTimeToRun = expectedOutput.size * batchDuration.milliseconds * 2
     val mergedOutput = runStreams(ssc, lastExpectedOutput, maxTimeToRun)
 
-    // Delete directories
     fileGeneratingThread.join()
-    val fs = checkpointDir.getFileSystem(new Configuration())
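+    // Clean up the checkpoint and test directories now that the test has finished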
     fs.delete(checkpointDir, true)
     fs.delete(testDir, true)
     logInfo("Finished test after " + killCount + " failures")
@@ -159,32 +179,24 @@ object MasterFailureTest extends Logging {
    * files should be written for testing.
    */
   private def setupStreams[T: ClassTag](
-      directory: String,
       batchDuration: Duration,
-      operation: DStream[String] => DStream[T]
-    ): (StreamingContext, Path, Path) = {
-    // Reset all state
-    reset()
-
-    // Create the directories for this test
-    val uuid = UUID.randomUUID().toString
-    val rootDir = new Path(directory, uuid)
-    val fs = rootDir.getFileSystem(new Configuration())
-    val checkpointDir = new Path(rootDir, "checkpoint")
-    val testDir = new Path(rootDir, "test")
-    fs.mkdirs(checkpointDir)
-    fs.mkdirs(testDir)
+      operation: DStream[String] => DStream[T],
+      checkpointDir: Path,
+      testDir: Path
+    ): StreamingContext = {
+    // Mark that setup was called
+    setupCalled = true
 
     // Setup the streaming computation with the given operation
     System.clearProperty("spark.driver.port")
     System.clearProperty("spark.hostPort")
-    var ssc = new StreamingContext("local[4]", "MasterFailureTest", batchDuration, null, Nil, Map())
+    val ssc = new StreamingContext("local[4]", "MasterFailureTest", batchDuration, null, Nil, Map())
     ssc.checkpoint(checkpointDir.toString)
     val inputStream = ssc.textFileStream(testDir.toString)
     val operatedStream = operation(inputStream)
     val outputStream = new TestOutputStream(operatedStream)
     ssc.registerOutputStream(outputStream)
-    (ssc, checkpointDir, testDir)
+    ssc
   }
 
 
@@ -204,7 +216,7 @@ object MasterFailureTest extends Logging {
     var isTimedOut = false
     val mergedOutput = new ArrayBuffer[T]()
     val checkpointDir = ssc.checkpointDir
-    var batchDuration = ssc.graph.batchDuration
+    val batchDuration = ssc.graph.batchDuration
 
     while(!isLastOutputGenerated && !isTimedOut) {
       // Get the output buffer
@@ -261,7 +273,11 @@ object MasterFailureTest extends Logging {
         )
         Thread.sleep(sleepTime)
         // Recreate the streaming context from checkpoint
-        ssc = new StreamingContext(checkpointDir)
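+        // Checkpoint data must exist by this point, so the creating function should never be invoked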
+        ssc = StreamingContext.getOrCreate(checkpointDir, () => {
+          throw new Exception("Trying to create new context when it " +
+            "should be reading from checkpoint file")
+        })
       }
     }
     mergedOutput
@@ -297,6 +313,7 @@ object MasterFailureTest extends Logging {
   private def reset() {
     killed = false
     killCount = 0
+    setupCalled = false
   }
 }
 