diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index 32387707612f432baf964023d226d9d78a1e6e44..4bbbd66132b751f9968bda49ac5ea1d25e8688ca 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -24,6 +24,7 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, Set => JS
 import java.util.concurrent.TimeUnit
 
 import scala.collection.JavaConverters._
+import scala.util.Try
 import scala.util.control.NonFatal
 
 import org.apache.hadoop.fs.{FileSystem, Path}
@@ -585,7 +586,20 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
         getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]]
       } else {
         logDebug(s"Hive metastore filter is '$filter'.")
-        getPartitionsByFilterMethod.invoke(hive, table, filter).asInstanceOf[JArrayList[Partition]]
+        try {
+          getPartitionsByFilterMethod.invoke(hive, table, filter)
+            .asInstanceOf[JArrayList[Partition]]
+        } catch {
+          case e: InvocationTargetException =>
+            // SPARK-18167: retry once to investigate the flaky test failure. This diagnostic
+            // code should be reverted before the release is cut.
+            val retry = Try(getPartitionsByFilterMethod.invoke(hive, table, filter))
+            val full = Try(getAllPartitionsMethod.invoke(hive, table))
+            logError(s"getPartitionsByFilter failed, retry success = ${retry.isSuccess}")
+            logError(s"getPartitionsByFilter failed, full fetch success = ${full.isSuccess}")
+            throw e
+        }
       }
 
     partitions.asScala.toSeq
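
For reviewers: the shape of the temporary diagnostic is easier to see outside the reflection
plumbing. The sketch below is illustrative only; `flakyCall` and `fallbackCall` are hypothetical
stand-ins for the `getPartitionsByFilterMethod.invoke` and `getAllPartitionsMethod.invoke` calls,
and none of this code is part of the change itself.

```scala
import scala.util.Try

object RetryDiagnosticSketch {
  // Hypothetical stand-ins for the reflective metastore calls in the patch above.
  def flakyCall(): Seq[String] = sys.error("simulated transient metastore failure")
  def fallbackCall(): Seq[String] = Seq("ds=2016-01-01", "ds=2016-01-02")

  def main(args: Array[String]): Unit = {
    try {
      flakyCall()
    } catch {
      case e: RuntimeException =>
        // Each probe is wrapped in Try so a probe failure cannot mask the
        // original exception; only the success flag is reported.
        val retry = Try(flakyCall())
        val full = Try(fallbackCall())
        println(s"retry success = ${retry.isSuccess}")      // false: the call still fails
        println(s"full fetch success = ${full.isSuccess}")  // true: unfiltered fetch works
        throw e // the original exception still reaches the caller unchanged
    }
  }
}
```

Using `Try` for the probes, rather than nested try/catch blocks, keeps the diagnostics themselves
from throwing, so the original exception is always the one that propagates.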