diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
index 7bbaf6eb94453cfe3470cbfa47ca100c6e79772d..b8dc5f95906b5e0a44c35355092acd6f383b5398 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
@@ -182,10 +182,15 @@ case class CatalogTable(
 
   import CatalogTable._
 
-  /** schema of this table's partition columns */
-  def partitionSchema: StructType = StructType(schema.filter {
-    c => partitionColumnNames.contains(c.name)
-  })
+  /**
+   * schema of this table's partition columns
+   */
+  def partitionSchema: StructType = {
+    val partitionFields = schema.takeRight(partitionColumnNames.length)
+    assert(partitionFields.map(_.name) == partitionColumnNames)
+
+    StructType(partitionFields)
+  }
 
   /** Return the database this table was specified to belong to, assuming it exists. */
   def database: String = identifier.database.getOrElse {