diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index e867fc40f7f1a7c41379cd93a71233eb27cf1139..f01211e20cbfc1b5f30e6d85a4275f0f933c9af7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.StreamSourceProvider
 import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+import org.apache.spark.util.Utils
 
 class StreamSuite extends StreamTest {
 
@@ -438,52 +439,49 @@ class StreamSuite extends StreamTest {
 
     // 1 - Test if recovery from the checkpoint is successful.
     prepareMemoryStream()
-    withTempDir { dir =>
-      // Copy the checkpoint to a temp dir to prevent changes to the original.
-      // Not doing this will lead to the test passing on the first run, but fail subsequent runs.
-      FileUtils.copyDirectory(checkpointDir, dir)
-
-      // Checkpoint data was generated by a query with 10 shuffle partitions.
-      // In order to test reading from the checkpoint, the checkpoint must have two or more batches,
-      // since the last batch may be rerun.
-      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
-        var streamingQuery: StreamingQuery = null
-        try {
-          streamingQuery =
-            query.queryName("counts").option("checkpointLocation", dir.getCanonicalPath).start()
-          streamingQuery.processAllAvailable()
-          inputData.addData(9)
-          streamingQuery.processAllAvailable()
-
-          QueryTest.checkAnswer(spark.table("counts").toDF(),
-            Row("1", 1) :: Row("2", 1) :: Row("3", 2) :: Row("4", 2) ::
-            Row("5", 2) :: Row("6", 2) :: Row("7", 1) :: Row("8", 1) :: Row("9", 1) :: Nil)
-        } finally {
-          if (streamingQuery ne null) {
-            streamingQuery.stop()
-          }
+    val dir1 = Utils.createTempDir().getCanonicalFile // withTempDir {} would make this test flaky
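+    // (Utils.createTempDir registers the directory for deletion when the JVM exits.)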
+    // Copy the checkpoint to a temp dir to prevent changes to the original.
+    // Not doing this makes the test pass on the first run but fail on subsequent runs.
+    FileUtils.copyDirectory(checkpointDir, dir1)
+    // Checkpoint data was generated by a query with 10 shuffle partitions.
+    // In order to test reading from the checkpoint, the checkpoint must have two or more batches,
+    // since the last batch may be rerun.
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      var streamingQuery: StreamingQuery = null
+      try {
+        streamingQuery =
+          query.queryName("counts").option("checkpointLocation", dir1.getCanonicalPath).start()
+        streamingQuery.processAllAvailable()
+        inputData.addData(9)
+        streamingQuery.processAllAvailable()
+
+        QueryTest.checkAnswer(spark.table("counts").toDF(),
+          Row("1", 1) :: Row("2", 1) :: Row("3", 2) :: Row("4", 2) ::
+          Row("5", 2) :: Row("6", 2) :: Row("7", 1) :: Row("8", 1) :: Row("9", 1) :: Nil)
+      } finally {
+        if (streamingQuery ne null) {
+          streamingQuery.stop()
         }
       }
     }
 
     // 2 - Check recovery with wrong num shuffle partitions
     prepareMemoryStream()
-    withTempDir { dir =>
-      FileUtils.copyDirectory(checkpointDir, dir)
-
-      // Since the number of partitions is greater than 10, should throw exception.
-      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "15") {
-        var streamingQuery: StreamingQuery = null
-        try {
-          intercept[StreamingQueryException] {
-            streamingQuery =
-              query.queryName("badQuery").option("checkpointLocation", dir.getCanonicalPath).start()
-            streamingQuery.processAllAvailable()
-          }
-        } finally {
-          if (streamingQuery ne null) {
-            streamingQuery.stop()
-          }
+    val dir2 = Utils.createTempDir().getCanonicalFile // again avoiding withTempDir {}
+    FileUtils.copyDirectory(checkpointDir, dir2)
+    // Since the number of partitions is now greater than 10, the query should throw an exception.
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "15") {
+      var streamingQuery: StreamingQuery = null
+      try {
+        intercept[StreamingQueryException] {
+          streamingQuery =
+            query.queryName("badQuery").option("checkpointLocation", dir2.getCanonicalPath).start()
+          streamingQuery.processAllAvailable()
+        }
+      } finally {
+        if (streamingQuery ne null) {
+          streamingQuery.stop()
         }
       }
     }
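
Note on the temp-dir change: a withTempDir-style helper deletes the directory the moment its block returns, while Utils.createTempDir defers deletion to a JVM shutdown hook, which is one plausible reason the eager variant was flaky here. A minimal sketch of the two strategies follows; the helper bodies are illustrative assumptions, not Spark's actual implementation:

    import java.io.File
    import java.nio.file.Files

    import org.apache.commons.io.FileUtils

    object TempDirSketch {
      // Eager cleanup: the directory is deleted as soon as `body` returns,
      // even if background threads are still writing under it.
      def withTempDir[T](body: File => T): T = {
        val dir = Files.createTempDirectory("stream-suite").toFile
        try body(dir) finally FileUtils.deleteDirectory(dir)
      }

      // Deferred cleanup: the directory survives until the JVM exits,
      // mirroring the documented behavior of Spark's Utils.createTempDir.
      def createTempDir(): File = {
        val dir = Files.createTempDirectory("stream-suite").toFile
        sys.addShutdownHook(FileUtils.deleteQuietly(dir))
        dir
      }
    }

With deferred cleanup, a streaming query's background threads cannot race against the deletion of their own checkpoint directory while the test is winding down.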