diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamStressSuite.scala
similarity index 85%
rename from sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamStressSuite.scala
index f9e236c449634b9ded5f60aa7c6fbd98f19cf2d6..28412ea07a75cabf4c9913f0fe7d1bcf70216286 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamStressSuite.scala
@@ -36,9 +36,12 @@ import org.apache.spark.util.Utils
  *
  * At the end, the resulting files are loaded and the answer is checked.
  */
-class FileStressSuite extends StreamTest {
+class FileStreamStressSuite extends StreamTest {
   import testImplicits._
 
+  // Error message injected into the streaming job to exercise failure recovery.
+  private val injectedErrorMsg = "test suite injected failure!"
+
   testQuietly("fault tolerance stress test - unpartitioned output") {
     stressTest(partitionWrites = false)
   }
@@ -101,13 +104,14 @@ class FileStressSuite extends StreamTest {
     val input = spark.readStream.format("text").load(inputDir)
 
     def startStream(): StreamingQuery = {
+      val errorMsg = injectedErrorMsg  // local copy; avoids capturing the non-serializable suite
       val output = input
         .repartition(5)
         .as[String]
         .mapPartitions { iter =>
           val rand = Random.nextInt(100)
           if (rand < 10) {
-            sys.error("failure")
+            sys.error(errorMsg)
           }
           iter.map(_.toLong)
         }
@@ -131,22 +135,21 @@ class FileStressSuite extends StreamTest {
     }
 
     var failures = 0
-    val streamThread = new Thread("stream runner") {
-      while (continue) {
-        if (failures % 10 == 0) { logError(s"Query restart #$failures") }
-        stream = startStream()
-
-        try {
-          stream.awaitTermination()
-        } catch {
-          case ce: StreamingQueryException =>
-            failures += 1
-        }
+    while (continue) {
+      if (failures % 10 == 0) { logError(s"Query restart #$failures") }
+      stream = startStream()
+
+      try {
+        stream.awaitTermination()
+      } catch {
+        case e: StreamingQueryException
+          if e.getCause != null && e.getCause.getCause != null &&
+              e.getCause.getCause.getMessage.contains(injectedErrorMsg) =>
+          // Failure carries the injected message, so this restart is expected
+          failures += 1
       }
     }
 
-    streamThread.join()
-
     logError(s"Stream restarted $failures times.")
     assert(spark.read.parquet(outputDir).distinct().count() == numRecords)
   }
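
A note on the errorMsg local copy added in startStream(): referencing the suite-level injectedErrorMsg field inside mapPartitions makes the closure capture the enclosing suite instance, which Spark must then serialize with the task and generally cannot. The sketch below is illustrative only and not part of the patch; FakeSuite, ClosureCaptureDemo, and serializes are made-up names, and it assumes Scala 2.12+ (where anonymous functions are Java-serializable). It reproduces the capture difference with plain JDK serialization, the same mechanism Spark uses for task closures.

import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

// Illustrative stand-in for a non-serializable test suite holding the message.
class FakeSuite {
  val injectedErrorMsg = "test suite injected failure!"

  // Uses the field directly, so the lambda captures `this` (the whole FakeSuite).
  def capturesField: String => Boolean = s => s.contains(injectedErrorMsg)

  // Copies the field into a local val first, so the lambda captures only a String.
  def capturesLocalCopy: String => Boolean = {
    val msg = injectedErrorMsg
    s => s.contains(msg)
  }
}

object ClosureCaptureDemo {
  // Returns true if `obj` survives a round of Java serialization.
  private def serializes(obj: AnyRef): Boolean =
    try {
      new ObjectOutputStream(new ByteArrayOutputStream()).writeObject(obj)
      true
    } catch {
      case _: NotSerializableException => false
    }

  def main(args: Array[String]): Unit = {
    val suite = new FakeSuite
    println(s"captures field:      ${serializes(suite.capturesField)}")     // false
    println(s"captures local copy: ${serializes(suite.capturesLocalCopy)}") // true
  }
}

Copying the field into a local val limits the captured state to a plain String, which is why the patch adds the one-line workaround instead of trying to make the suite itself serializable.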
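
The new catch guard accepts a failure only when the message found two cause levels below the StreamingQueryException contains injectedErrorMsg. If the nesting depth were less predictable, walking the whole cause chain gives the same kind of filtering; the helper below is a hypothetical alternative, not code from the patch.

import scala.annotation.tailrec

object InjectedFailureCheck {
  // Hypothetical helper: true if any throwable in the cause chain carries the message.
  @tailrec
  def hasInjectedCause(t: Throwable, msg: String): Boolean =
    if (t == null) false
    else if (t.getMessage != null && t.getMessage.contains(msg)) true
    else hasInjectedCause(t.getCause, msg)
}

With such a helper the guard would read: case e: StreamingQueryException if InjectedFailureCheck.hasInjectedCause(e, injectedErrorMsg) =>. The patch's two-level check is stricter, since it only accepts the message at the expected nesting depth.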