diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala
index ee4d0863d9771354b9306e76d8ab0ffc91910a7f..11605dd280569d2dc937c7205496c2755a3f0e77 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala
@@ -17,12 +17,19 @@
 
 package org.apache.spark.sql.execution.datasources
 
+import java.io.FileNotFoundException
+
 import scala.collection.mutable
 
+import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs._
+import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
 
+import org.apache.spark.internal.Logging
+import org.apache.spark.metrics.source.HiveCatalogMetrics
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.SerializableConfiguration
 
 
 /**
@@ -84,4 +91,223 @@ class InMemoryFileIndex(
   }
 
   override def hashCode(): Int = rootPaths.toSet.hashCode()
+
+  /**
+   * Lists the leaf files of the given paths. This method will submit a Spark job to do
+   * parallel listing whenever there is a path having more files than the parallel partition
+   * discovery threshold.
+   *
+   * This is publicly visible for testing.
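+   *
+   * A minimal usage sketch (illustrative only; `index` is assumed to be an already
+   * constructed `InMemoryFileIndex` and `/data/table` an existing, readable path):
+   * {{{
+   *   val leaves = index.listLeafFiles(Seq(new Path("/data/table")))
+   *   leaves.foreach(f => println(f.getPath))
+   * }}}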
+   */
+  def listLeafFiles(paths: Seq[Path]): mutable.LinkedHashSet[FileStatus] = {
+    val output = mutable.LinkedHashSet[FileStatus]()
+    val pathsToFetch = mutable.ArrayBuffer[Path]()
+    for (path <- paths) {
+      fileStatusCache.getLeafFiles(path) match {
+        case Some(files) =>
+          HiveCatalogMetrics.incrementFileCacheHits(files.length)
+          output ++= files
+        case None =>
+          pathsToFetch += path
+      }
+    }
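+    // Respect any custom PathFilter that has been configured on the Hadoop job configuration
+    // when listing the uncached paths below.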
+    val filter = FileInputFormat.getInputPathFilter(new JobConf(hadoopConf, this.getClass))
+    val discovered = InMemoryFileIndex.bulkListLeafFiles(
+      pathsToFetch, hadoopConf, filter, sparkSession)
+    discovered.foreach { case (path, leafFiles) =>
+      HiveCatalogMetrics.incrementFilesDiscovered(leafFiles.size)
+      fileStatusCache.putLeafFiles(path, leafFiles.toArray)
+      output ++= leafFiles
+    }
+    output
+  }
+}
+
+object InMemoryFileIndex extends Logging {
+
+  /** A serializable variant of HDFS's BlockLocation. */
+  private case class SerializableBlockLocation(
+      names: Array[String],
+      hosts: Array[String],
+      offset: Long,
+      length: Long)
+
+  /** A serializable variant of HDFS's FileStatus. */
+  private case class SerializableFileStatus(
+      path: String,
+      length: Long,
+      isDir: Boolean,
+      blockReplication: Short,
+      blockSize: Long,
+      modificationTime: Long,
+      accessTime: Long,
+      blockLocations: Array[SerializableBlockLocation])
+
+  /**
+   * Lists a collection of paths recursively. Picks the listing strategy adaptively depending
+   * on the number of paths to list.
+   *
+   * This may only be called on the driver.
+   *
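+   * The switch between serial and parallel listing is driven by
+   * `SQLConf.parallelPartitionDiscoveryThreshold`, and the number of listing tasks is capped by
+   * `SQLConf.parallelPartitionDiscoveryParallelism`. A minimal sketch of tuning the threshold
+   * (illustrative value only; `spark` is a SparkSession):
+   * {{{
+   *   spark.conf.set("spark.sql.sources.parallelPartitionDiscovery.threshold", "8")
+   * }}}
+   *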
+   * @return for each input path, the set of discovered files for the path
+   */
+  private def bulkListLeafFiles(
+      paths: Seq[Path],
+      hadoopConf: Configuration,
+      filter: PathFilter,
+      sparkSession: SparkSession): Seq[(Path, Seq[FileStatus])] = {
+
+    // Short-circuits parallel listing when serial listing is likely to be faster.
+    if (paths.size <= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) {
+      return paths.map { path =>
+        (path, listLeafFiles(path, hadoopConf, filter, Some(sparkSession)))
+      }
+    }
+
+    logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}")
+    HiveCatalogMetrics.incrementParallelListingJobCount(1)
+
+    val sparkContext = sparkSession.sparkContext
+    val serializableConfiguration = new SerializableConfiguration(hadoopConf)
+    val serializedPaths = paths.map(_.toString)
+    val parallelPartitionDiscoveryParallelism =
+      sparkSession.sessionState.conf.parallelPartitionDiscoveryParallelism
+
+    // Cap the parallelism so that the file listing below does not generate an excessive number
+    // of tasks when defaultParallelism is large.
+    val numParallelism = Math.min(paths.size, parallelPartitionDiscoveryParallelism)
+
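+    // Run the listing itself as a Spark job: each task lists its share of the paths serially
+    // (sessionOpt is None on executors) and ships the results back to the driver in a
+    // serializable form.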
+    val statusMap = sparkContext
+      .parallelize(serializedPaths, numParallelism)
+      .mapPartitions { pathStrings =>
+        val hadoopConf = serializableConfiguration.value
+        pathStrings.map(new Path(_)).toSeq.map { path =>
+          (path, listLeafFiles(path, hadoopConf, filter, None))
+        }.iterator
+      }.map { case (path, statuses) =>
+        val serializableStatuses = statuses.map { status =>
+          // Turn FileStatus into SerializableFileStatus so we can send it back to the driver
+          val blockLocations = status match {
+            case f: LocatedFileStatus =>
+              f.getBlockLocations.map { loc =>
+                SerializableBlockLocation(
+                  loc.getNames,
+                  loc.getHosts,
+                  loc.getOffset,
+                  loc.getLength)
+              }
+
+            case _ =>
+              Array.empty[SerializableBlockLocation]
+          }
+
+          SerializableFileStatus(
+            status.getPath.toString,
+            status.getLen,
+            status.isDirectory,
+            status.getReplication,
+            status.getBlockSize,
+            status.getModificationTime,
+            status.getAccessTime,
+            blockLocations)
+        }
+        (path.toString, serializableStatuses)
+      }.collect()
+
+    // Turn SerializableFileStatus back to FileStatus on the driver
+    statusMap.map { case (path, serializableStatuses) =>
+      val statuses = serializableStatuses.map { f =>
+        val blockLocations = f.blockLocations.map { loc =>
+          new BlockLocation(loc.names, loc.hosts, loc.offset, loc.length)
+        }
+        new LocatedFileStatus(
+          new FileStatus(
+            f.length, f.isDir, f.blockReplication, f.blockSize, f.modificationTime,
+            new Path(f.path)),
+          blockLocations)
+      }
+      (new Path(path), statuses)
+    }
+  }
+
+  /**
+   * Lists a single filesystem path recursively. If a SparkSession object is specified, this
+   * function may launch Spark jobs to parallelize listing.
+   *
+   * If sessionOpt is None, this may be called on executors.
+   *
+   * @return all children of path that match the specified filter.
+   */
+  private def listLeafFiles(
+      path: Path,
+      hadoopConf: Configuration,
+      filter: PathFilter,
+      sessionOpt: Option[SparkSession]): Seq[FileStatus] = {
+    logTrace(s"Listing $path")
+    val fs = path.getFileSystem(hadoopConf)
+    val name = path.getName.toLowerCase
+
+    // [SPARK-17599] Prevent InMemoryFileIndex from failing if path doesn't exist
+    // Note that statuses only include FileStatus for the files and dirs directly under path,
+    // and does not include anything else recursively.
+    val statuses = try fs.listStatus(path) catch {
+      case _: FileNotFoundException =>
+        logWarning(s"The directory $path was not found. Was it deleted very recently?")
+        Array.empty[FileStatus]
+    }
+
+    val filteredStatuses = statuses.filterNot(status => shouldFilterOut(status.getPath.getName))
+
+    val allLeafStatuses = {
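+      // Split the direct children into sub-directories and files, then recurse into the
+      // sub-directories: in parallel when a SparkSession is available (driver side), serially
+      // otherwise (executor side).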
+      val (dirs, topLevelFiles) = filteredStatuses.partition(_.isDirectory)
+      val nestedFiles: Seq[FileStatus] = sessionOpt match {
+        case Some(session) =>
+          bulkListLeafFiles(dirs.map(_.getPath), hadoopConf, filter, session).flatMap(_._2)
+        case _ =>
+          dirs.flatMap(dir => listLeafFiles(dir.getPath, hadoopConf, filter, sessionOpt))
+      }
+      val allFiles = topLevelFiles ++ nestedFiles
+      if (filter != null) allFiles.filter(f => filter.accept(f.getPath)) else allFiles
+    }
+
+    allLeafStatuses.filterNot(status => shouldFilterOut(status.getPath.getName)).map {
+      case f: LocatedFileStatus =>
+        f
+
+      // NOTE:
+      //
+      // - Although the S3/S3A/S3N file systems can be quite slow for remote file metadata
+      //   operations, calling `getFileBlockLocations` does no harm here since these file system
+      //   implementations don't actually issue an RPC for this method.
+      //
+      // - Here we are calling `getFileBlockLocations` sequentially, but it should not be a big
+      //   deal since we always switch to the parallel `bulkListLeafFiles` code path when the
+      //   number of paths exceeds the threshold.
+      case f =>
+        // The other constructor of LocatedFileStatus would call FileStatus.getPermission(),
+        // which is very slow on some file systems (e.g. RawLocalFileSystem, which launches a
+        // subprocess and parses the stdout to read the permission).
+        val locations = fs.getFileBlockLocations(f, 0, f.getLen)
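+        // Access time, permission, owner and group are deliberately left as placeholders here;
+        // the symlink is restored below when present.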
+        val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize,
+          f.getModificationTime, 0, null, null, null, null, f.getPath, locations)
+        if (f.isSymlink) {
+          lfs.setSymlink(f.getSymlink)
+        }
+        lfs
+    }
+  }
+
+  /** Checks if we should filter out this path name. */
+  def shouldFilterOut(pathName: String): Boolean = {
+    // We filter out the following paths:
+    // 1. everything that starts with _ or ., except _common_metadata and _metadata,
+    //    because Parquet needs to find those metadata files among the leaf files returned by
+    //    this method. We should refactor this logic to not mix metadata files with data files.
+    // 2. everything that ends with `._COPYING_`, because this is an intermediate state of a
+    //    file being copied, and we should skip it to avoid reading the same data twice.
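+    // For example, ".ab", "_cd" and "a._COPYING_" are filtered out, while "abcd", "_metadata"
+    // and "_common_metadata" are kept (see the file filtering test in FileIndexSuite).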
+    val exclude = (pathName.startsWith("_") && !pathName.contains("=")) ||
+      pathName.startsWith(".") || pathName.endsWith("._COPYING_")
+    val include = pathName.startsWith("_common_metadata") || pathName.startsWith("_metadata")
+    exclude && !include
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
index 71500a010581e0617b23e106aa153a7c3734f806..ffd7f6c750f859107ec106b23c38f289f8d95b03 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
@@ -17,22 +17,17 @@
 
 package org.apache.spark.sql.execution.datasources
 
-import java.io.FileNotFoundException
-
 import scala.collection.mutable
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs._
-import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.metrics.source.HiveCatalogMetrics
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.{expressions, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.types.{StringType, StructType}
-import org.apache.spark.util.SerializableConfiguration
 
 /**
  * An abstract class that represents [[FileIndex]]s that are aware of partitioned tables.
@@ -241,224 +236,8 @@ abstract class PartitioningAwareFileIndex(
     val name = path.getName
     !((name.startsWith("_") && !name.contains("=")) || name.startsWith("."))
   }
-
-  /**
-   * List leaf files of given paths. This method will submit a Spark job to do parallel
-   * listing whenever there is a path having more files than the parallel partition discovery
-   * discovery threshold.
-   *
-   * This is publicly visible for testing.
-   */
-  def listLeafFiles(paths: Seq[Path]): mutable.LinkedHashSet[FileStatus] = {
-    val output = mutable.LinkedHashSet[FileStatus]()
-    val pathsToFetch = mutable.ArrayBuffer[Path]()
-    for (path <- paths) {
-      fileStatusCache.getLeafFiles(path) match {
-        case Some(files) =>
-          HiveCatalogMetrics.incrementFileCacheHits(files.length)
-          output ++= files
-        case None =>
-          pathsToFetch += path
-      }
-    }
-    val filter = FileInputFormat.getInputPathFilter(new JobConf(hadoopConf, this.getClass))
-    val discovered = PartitioningAwareFileIndex.bulkListLeafFiles(
-      pathsToFetch, hadoopConf, filter, sparkSession)
-    discovered.foreach { case (path, leafFiles) =>
-      HiveCatalogMetrics.incrementFilesDiscovered(leafFiles.size)
-      fileStatusCache.putLeafFiles(path, leafFiles.toArray)
-      output ++= leafFiles
-    }
-    output
-  }
 }
 
-object PartitioningAwareFileIndex extends Logging {
+object PartitioningAwareFileIndex {
   val BASE_PATH_PARAM = "basePath"
-
-  /** A serializable variant of HDFS's BlockLocation. */
-  private case class SerializableBlockLocation(
-      names: Array[String],
-      hosts: Array[String],
-      offset: Long,
-      length: Long)
-
-  /** A serializable variant of HDFS's FileStatus. */
-  private case class SerializableFileStatus(
-      path: String,
-      length: Long,
-      isDir: Boolean,
-      blockReplication: Short,
-      blockSize: Long,
-      modificationTime: Long,
-      accessTime: Long,
-      blockLocations: Array[SerializableBlockLocation])
-
-  /**
-   * Lists a collection of paths recursively. Picks the listing strategy adaptively depending
-   * on the number of paths to list.
-   *
-   * This may only be called on the driver.
-   *
-   * @return for each input path, the set of discovered files for the path
-   */
-  private def bulkListLeafFiles(
-      paths: Seq[Path],
-      hadoopConf: Configuration,
-      filter: PathFilter,
-      sparkSession: SparkSession): Seq[(Path, Seq[FileStatus])] = {
-
-    // Short-circuits parallel listing when serial listing is likely to be faster.
-    if (paths.size <= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) {
-      return paths.map { path =>
-        (path, listLeafFiles(path, hadoopConf, filter, Some(sparkSession)))
-      }
-    }
-
-    logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}")
-    HiveCatalogMetrics.incrementParallelListingJobCount(1)
-
-    val sparkContext = sparkSession.sparkContext
-    val serializableConfiguration = new SerializableConfiguration(hadoopConf)
-    val serializedPaths = paths.map(_.toString)
-    val parallelPartitionDiscoveryParallelism =
-      sparkSession.sessionState.conf.parallelPartitionDiscoveryParallelism
-
-    // Set the number of parallelism to prevent following file listing from generating many tasks
-    // in case of large #defaultParallelism.
-    val numParallelism = Math.min(paths.size, parallelPartitionDiscoveryParallelism)
-
-    val statusMap = sparkContext
-      .parallelize(serializedPaths, numParallelism)
-      .mapPartitions { pathStrings =>
-        val hadoopConf = serializableConfiguration.value
-        pathStrings.map(new Path(_)).toSeq.map { path =>
-          (path, listLeafFiles(path, hadoopConf, filter, None))
-        }.iterator
-      }.map { case (path, statuses) =>
-        val serializableStatuses = statuses.map { status =>
-          // Turn FileStatus into SerializableFileStatus so we can send it back to the driver
-          val blockLocations = status match {
-            case f: LocatedFileStatus =>
-              f.getBlockLocations.map { loc =>
-                SerializableBlockLocation(
-                  loc.getNames,
-                  loc.getHosts,
-                  loc.getOffset,
-                  loc.getLength)
-              }
-
-            case _ =>
-              Array.empty[SerializableBlockLocation]
-          }
-
-          SerializableFileStatus(
-            status.getPath.toString,
-            status.getLen,
-            status.isDirectory,
-            status.getReplication,
-            status.getBlockSize,
-            status.getModificationTime,
-            status.getAccessTime,
-            blockLocations)
-        }
-        (path.toString, serializableStatuses)
-      }.collect()
-
-    // turn SerializableFileStatus back to Status
-    statusMap.map { case (path, serializableStatuses) =>
-      val statuses = serializableStatuses.map { f =>
-        val blockLocations = f.blockLocations.map { loc =>
-          new BlockLocation(loc.names, loc.hosts, loc.offset, loc.length)
-        }
-        new LocatedFileStatus(
-          new FileStatus(
-            f.length, f.isDir, f.blockReplication, f.blockSize, f.modificationTime,
-            new Path(f.path)),
-          blockLocations)
-      }
-      (new Path(path), statuses)
-    }
-  }
-
-  /**
-   * Lists a single filesystem path recursively. If a SparkSession object is specified, this
-   * function may launch Spark jobs to parallelize listing.
-   *
-   * If sessionOpt is None, this may be called on executors.
-   *
-   * @return all children of path that match the specified filter.
-   */
-  private def listLeafFiles(
-      path: Path,
-      hadoopConf: Configuration,
-      filter: PathFilter,
-      sessionOpt: Option[SparkSession]): Seq[FileStatus] = {
-    logTrace(s"Listing $path")
-    val fs = path.getFileSystem(hadoopConf)
-    val name = path.getName.toLowerCase
-
-    // [SPARK-17599] Prevent InMemoryFileIndex from failing if path doesn't exist
-    // Note that statuses only include FileStatus for the files and dirs directly under path,
-    // and does not include anything else recursively.
-    val statuses = try fs.listStatus(path) catch {
-      case _: FileNotFoundException =>
-        logWarning(s"The directory $path was not found. Was it deleted very recently?")
-        Array.empty[FileStatus]
-    }
-
-    val filteredStatuses = statuses.filterNot(status => shouldFilterOut(status.getPath.getName))
-
-    val allLeafStatuses = {
-      val (dirs, topLevelFiles) = filteredStatuses.partition(_.isDirectory)
-      val nestedFiles: Seq[FileStatus] = sessionOpt match {
-        case Some(session) =>
-          bulkListLeafFiles(dirs.map(_.getPath), hadoopConf, filter, session).flatMap(_._2)
-        case _ =>
-          dirs.flatMap(dir => listLeafFiles(dir.getPath, hadoopConf, filter, sessionOpt))
-      }
-      val allFiles = topLevelFiles ++ nestedFiles
-      if (filter != null) allFiles.filter(f => filter.accept(f.getPath)) else allFiles
-    }
-
-    allLeafStatuses.filterNot(status => shouldFilterOut(status.getPath.getName)).map {
-      case f: LocatedFileStatus =>
-        f
-
-      // NOTE:
-      //
-      // - Although S3/S3A/S3N file system can be quite slow for remote file metadata
-      //   operations, calling `getFileBlockLocations` does no harm here since these file system
-      //   implementations don't actually issue RPC for this method.
-      //
-      // - Here we are calling `getFileBlockLocations` in a sequential manner, but it should not
-      //   be a big deal since we always use to `listLeafFilesInParallel` when the number of
-      //   paths exceeds threshold.
-      case f =>
-        // The other constructor of LocatedFileStatus will call FileStatus.getPermission(),
-        // which is very slow on some file system (RawLocalFileSystem, which is launch a
-        // subprocess and parse the stdout).
-        val locations = fs.getFileBlockLocations(f, 0, f.getLen)
-        val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize,
-          f.getModificationTime, 0, null, null, null, null, f.getPath, locations)
-        if (f.isSymlink) {
-          lfs.setSymlink(f.getSymlink)
-        }
-        lfs
-    }
-  }
-
-  /** Checks if we should filter out this path name. */
-  def shouldFilterOut(pathName: String): Boolean = {
-    // We filter follow paths:
-    // 1. everything that starts with _ and ., except _common_metadata and _metadata
-    // because Parquet needs to find those metadata files from leaf files returned by this method.
-    // We should refactor this logic to not mix metadata files with data files.
-    // 2. everything that ends with `._COPYING_`, because this is a intermediate state of file. we
-    // should skip this file in case of double reading.
-    val exclude = (pathName.startsWith("_") && !pathName.contains("=")) ||
-      pathName.startsWith(".") || pathName.endsWith("._COPYING_")
-    val include = pathName.startsWith("_common_metadata") || pathName.startsWith("_metadata")
-    exclude && !include
-  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala
index 7ea4064927576fcfb3e9a78683360f97cbe2b5b5..00f5d5db8f5f4988c132fb857c0e90e17b723234 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala
@@ -135,15 +135,15 @@ class FileIndexSuite extends SharedSQLContext {
     }
   }
 
-  test("PartitioningAwareFileIndex - file filtering") {
-    assert(!PartitioningAwareFileIndex.shouldFilterOut("abcd"))
-    assert(PartitioningAwareFileIndex.shouldFilterOut(".ab"))
-    assert(PartitioningAwareFileIndex.shouldFilterOut("_cd"))
-    assert(!PartitioningAwareFileIndex.shouldFilterOut("_metadata"))
-    assert(!PartitioningAwareFileIndex.shouldFilterOut("_common_metadata"))
-    assert(PartitioningAwareFileIndex.shouldFilterOut("_ab_metadata"))
-    assert(PartitioningAwareFileIndex.shouldFilterOut("_cd_common_metadata"))
-    assert(PartitioningAwareFileIndex.shouldFilterOut("a._COPYING_"))
+  test("InMemoryFileIndex - file filtering") {
+    assert(!InMemoryFileIndex.shouldFilterOut("abcd"))
+    assert(InMemoryFileIndex.shouldFilterOut(".ab"))
+    assert(InMemoryFileIndex.shouldFilterOut("_cd"))
+    assert(!InMemoryFileIndex.shouldFilterOut("_metadata"))
+    assert(!InMemoryFileIndex.shouldFilterOut("_common_metadata"))
+    assert(InMemoryFileIndex.shouldFilterOut("_ab_metadata"))
+    assert(InMemoryFileIndex.shouldFilterOut("_cd_common_metadata"))
+    assert(InMemoryFileIndex.shouldFilterOut("a._COPYING_"))
   }
 
   test("SPARK-17613 - PartitioningAwareFileIndex: base path w/o '/' at end") {