diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamStressSuite.scala
similarity index 85%
rename from sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamStressSuite.scala
index f9e236c449634b9ded5f60aa7c6fbd98f19cf2d6..28412ea07a75cabf4c9913f0fe7d1bcf70216286 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamStressSuite.scala
@@ -36,9 +36,12 @@ import org.apache.spark.util.Utils
  *
  * At the end, the resulting files are loaded and the answer is checked.
  */
-class FileStressSuite extends StreamTest {
+class FileStreamStressSuite extends StreamTest {
   import testImplicits._
 
+  // Error message thrown in the streaming job for testing recovery.
+  private val injectedErrorMsg = "test suite injected failure!"
+
   testQuietly("fault tolerance stress test - unpartitioned output") {
     stressTest(partitionWrites = false)
   }
@@ -101,13 +104,14 @@ class FileStressSuite extends StreamTest {
     val input = spark.readStream.format("text").load(inputDir)
 
     def startStream(): StreamingQuery = {
+      val errorMsg = injectedErrorMsg // work around serialization issue
       val output = input
         .repartition(5)
        .as[String]
        .mapPartitions { iter =>
          val rand = Random.nextInt(100)
          if (rand < 10) {
-            sys.error("failure")
+            sys.error(errorMsg)
          }
          iter.map(_.toLong)
        }
@@ -131,22 +135,21 @@
     }
 
     var failures = 0
-    val streamThread = new Thread("stream runner") {
-      while (continue) {
-        if (failures % 10 == 0) { logError(s"Query restart #$failures") }
-        stream = startStream()
-
-        try {
-          stream.awaitTermination()
-        } catch {
-          case ce: StreamingQueryException =>
-            failures += 1
-        }
+    while (continue) {
+      if (failures % 10 == 0) { logError(s"Query restart #$failures") }
+      stream = startStream()
+
+      try {
+        stream.awaitTermination()
+      } catch {
+        case e: StreamingQueryException
+          if e.getCause != null && e.getCause.getCause != null &&
+            e.getCause.getCause.getMessage.contains(injectedErrorMsg) =>
+          // Getting the expected error message
+          failures += 1
       }
     }
-    streamThread.join()
-
     logError(s"Stream restarted $failures times.")
 
     assert(spark.read.parquet(outputDir).distinct().count() == numRecords)
   }
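
Note on the "val errorMsg = injectedErrorMsg // work around serialization issue" line added above: referring to a field of the suite from inside the mapPartitions closure would capture the enclosing test class, and Spark would then have to serialize the whole suite to ship the task to executors. Copying the field into a local val first keeps the capture down to a single String. Below is a minimal sketch of that pattern, not part of the patch; the class name StressJob and the spark.range pipeline are illustrative only:

import scala.util.Random

import org.apache.spark.sql.SparkSession

class StressJob(spark: SparkSession) { // hypothetical class, not from the patch
  import spark.implicits._

  private val injectedErrorMsg = "test suite injected failure!"

  def run(): Long = {
    // Copy the field into a local val so the closure below captures only a
    // String; referencing injectedErrorMsg directly would capture `this`.
    val errorMsg = injectedErrorMsg
    spark.range(100).as[Long]
      .mapPartitions { iter =>
        if (Random.nextInt(100) < 10) sys.error(errorMsg) // injected failure
        iter
      }
      .count()
  }
}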
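
The rewritten catch clause counts a restart only when the injected message is found two causes deep (the StreamingQueryException wraps the task failure, which wraps the injected error), so an unrelated exception now fails the test instead of being silently retried as the old bare "case ce: StreamingQueryException" did. As a sketch only, and assuming the marker text is unique to the injected error, the same check could be written against the whole cause chain instead of a fixed depth; CauseChain and anyCauseContains are hypothetical names, not in the Spark test:

import scala.annotation.tailrec

object CauseChain {
  // Walk the cause chain and report whether any throwable's message
  // contains the given marker text.
  @tailrec
  def anyCauseContains(t: Throwable, marker: String): Boolean =
    if (t == null) false
    else if (t.getMessage != null && t.getMessage.contains(marker)) true
    else anyCauseContains(t.getCause, marker)
}

With such a helper the guard would shrink to "case e: StreamingQueryException if CauseChain.anyCauseContains(e, injectedErrorMsg) =>"; the patch instead pins the exact nesting depth, which is stricter about how the error is expected to be wrapped.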