diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
index 889b8a02784d6122425141d8f272a819269906c0..aecdda1c364983789bb2abe51a58cdcffed73b53 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
@@ -343,7 +343,8 @@ abstract class Catalog {
 
   /**
    * Invalidate and refresh all the cached data (and the associated metadata) for any dataframe that
-   * contains the given data source path.
+   * contains the given data source path. Path matching is by prefix, e.g. "/" would invalidate
+   * everything that is cached.
    *
    * @since 2.0.0
    */
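
A minimal usage sketch of the prefix semantics documented above (the session, table, and
paths here are hypothetical, not part of this patch). Both the argument and the cached
root paths are qualified against the session's Hadoop filesystem before comparison:

    // Assume a SparkSession `spark` with a cached table backed by files under /data/events.
    spark.catalog.refreshByPath("/data/events")  // exact root path: invalidates the entry
    spark.catalog.refreshByPath("/data")         // prefix of the root path: also invalidates it
    spark.catalog.refreshByPath("/other")        // unrelated prefix: no-op
    spark.catalog.refreshByPath("/")             // prefix of every qualified path: invalidates all
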
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 92fd366e101fdf72e711e4fefbc5b90bcba19878..fb72c679e3628c486cfab8bfa0363352b0f37152 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -185,9 +185,10 @@ class CacheManager extends Logging {
     plan match {
       case lr: LogicalRelation => lr.relation match {
         case hr: HadoopFsRelation =>
+          val prefixToInvalidate = qualifiedPath.toString
           val invalidate = hr.location.rootPaths
-            .map(_.makeQualified(fs.getUri, fs.getWorkingDirectory))
-            .contains(qualifiedPath)
+            .map(_.makeQualified(fs.getUri, fs.getWorkingDirectory).toString)
+            .exists(_.startsWith(prefixToInvalidate))
           if (invalidate) hr.location.refresh()
           invalidate
         case _ => false
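
A standalone sketch of the matching rule this hunk introduces (the helper name and paths
are illustrative only). One caveat of plain string startsWith: it knows nothing about path
separators, so the prefix "/data/foo" also matches "/data/foobar":

    // Invalidate when any qualified root path starts with the qualified prefix,
    // mirroring the exists/startsWith check above.
    def matchesPrefix(rootPaths: Seq[String], prefix: String): Boolean =
      rootPaths.exists(_.startsWith(prefix))

    matchesPrefix(Seq("file:/data/events"), "file:/data/events")  // true: exact match
    matchesPrefix(Seq("file:/data/events"), "file:/data")         // true: proper prefix
    matchesPrefix(Seq("file:/data/foobar"), "file:/data/foo")     // true: no separator check
    matchesPrefix(Seq("file:/data/events"), "file:/tmp")          // false
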
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/TableFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/TableFileCatalog.scala
index 5648ab480a98af0f1b467008670b22431968c154..fc08c3798ee06dde3f74530335213c6dc0922a50 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/TableFileCatalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/TableFileCatalog.scala
@@ -48,13 +48,18 @@ class TableFileCatalog(
 
   private val baseLocation = catalogTable.storage.locationUri
 
+  // Populated on demand by calls to allPartitions; cleared by refresh()
+  private var cachedAllPartitions: ListingFileCatalog = null
+
   override def rootPaths: Seq[Path] = baseLocation.map(new Path(_)).toSeq
 
   override def listFiles(filters: Seq[Expression]): Seq[PartitionDirectory] = {
     filterPartitions(filters).listFiles(Nil)
   }
 
-  override def refresh(): Unit = {}
+  override def refresh(): Unit = synchronized {
+    cachedAllPartitions = null
+  }
 
   /**
    * Returns a [[ListingFileCatalog]] for this table restricted to the subset of partitions
@@ -64,7 +69,7 @@ class TableFileCatalog(
    */
   def filterPartitions(filters: Seq[Expression]): ListingFileCatalog = {
     if (filters.isEmpty) {
-      cachedAllPartitions
+      allPartitions
     } else {
       filterPartitions0(filters)
     }
@@ -89,9 +94,14 @@ class TableFileCatalog(
   }
 
   // Not used in the hot path of queries when metastore partition pruning is enabled
-  lazy val cachedAllPartitions: ListingFileCatalog = filterPartitions0(Nil)
+  def allPartitions: ListingFileCatalog = synchronized {
+    if (cachedAllPartitions == null) {
+      cachedAllPartitions = filterPartitions0(Nil)
+    }
+    cachedAllPartitions
+  }
 
-  override def inputFiles: Array[String] = cachedAllPartitions.inputFiles
+  override def inputFiles: Array[String] = allPartitions.inputFiles
 }
 
 /**
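
The motivation for the hunks above: a lazy val is memoized exactly once and can never be
re-evaluated, so the old empty refresh() had nothing it could clear. The patch replaces it
with a nullable field guarded by synchronized. A generic sketch of that pattern (the class
and its names are hypothetical):

    // Resettable memoization: compute once, serve from cache, allow invalidation.
    class ResettableRef[T >: Null <: AnyRef](compute: () => T) {
      private var cached: T = null
      def get: T = synchronized {
        if (cached == null) cached = compute()  // first call after a reset recomputes
        cached
      }
      def reset(): Unit = synchronized { cached = null }  // what refresh() does above
    }
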
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 16e1e37b2fb02ba0f9cafe11c3fce840c0e262fb..c909eb5d20bcdaf5222fa5395bb93a80d0059818 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -235,7 +235,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
           if (lazyPruningEnabled) {
             catalog
           } else {
-            catalog.cachedAllPartitions
+            catalog.allPartitions
           }
         }
         val partitionSchemaColumnNames = partitionSchema.map(_.name.toLowerCase).toSet
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala
index 7af81a3a90504fb9f52adf1708c8a5281ab08b78..2ca1cd4c07fdb574753ba825242808b5f99dddb8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala
@@ -80,9 +80,13 @@ class HiveMetadataCacheSuite extends QueryTest with SQLTestUtils with TestHiveSi
             val df = spark.sql("select * from test")
             assert(sql("select * from test").count() == 5)
 
+            def deleteRandomFile(): Unit = {
+              val p = new Path(spark.table("test").inputFiles.head)
+              assert(p.getFileSystem(hiveContext.sessionState.newHadoopConf()).delete(p, true))
+            }
+
             // Delete a file, then assert that we tried to read it. This means the table was cached.
-            val p = new Path(spark.table("test").inputFiles.head)
-            assert(p.getFileSystem(hiveContext.sessionState.newHadoopConf()).delete(p, true))
+            deleteRandomFile()
             val e = intercept[SparkException] {
               sql("select * from test").count()
             }
@@ -91,6 +95,19 @@ class HiveMetadataCacheSuite extends QueryTest with SQLTestUtils with TestHiveSi
             // Test refreshing the cache.
             spark.catalog.refreshTable("test")
             assert(sql("select * from test").count() == 4)
+            assert(spark.table("test").inputFiles.length == 4)
+
+            // Test refresh by path separately since it goes through different code paths than
+            // refreshTable does.
+            deleteRandomFile()
+            spark.catalog.cacheTable("test")
+            spark.catalog.refreshByPath("/some-invalid-path")  // no-op
+            val e2 = intercept[SparkException] {
+              sql("select * from test").count()
+            }
+            assert(e2.getMessage.contains("FileNotFoundException"))
+            spark.catalog.refreshByPath(dir.getAbsolutePath)
+            assert(sql("select * from test").count() == 3)
           }
         }
       }
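
For readers of the new assertions: deleteRandomFile removes an arbitrary backing file (the
head of the current listing), and the added block exercises the refreshByPath contract
end to end. Condensed, with identifiers as in the test above:

    deleteRandomFile()                                 // one backing file gone on disk
    spark.catalog.cacheTable("test")                   // plan cached with the stale file listing
    spark.catalog.refreshByPath("/some-invalid-path")  // matches no root path: no-op
    sql("select * from test").count()                  // still reads the deleted file: throws
                                                       // SparkException wrapping FileNotFoundException
    spark.catalog.refreshByPath(dir.getAbsolutePath)   // prefix matches the table root
    sql("select * from test").count()                  // re-listed: 3 surviving rows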