diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 0f4caec7451a27d27672095150f521e47a9c74d6..648a65e7c0eb35849f7d71ee3f5b7edd9ab8b90b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -228,10 +228,21 @@ object ColumnPruning extends Rule[LogicalPlan] {
     case Project(projectList, Limit(exp, child)) =>
       Limit(exp, Project(projectList, child))
 
-    // Push down project if possible when the child is sort
-    case p @ Project(projectList, s @ Sort(_, _, grandChild))
-      if s.references.subsetOf(p.outputSet) =>
-      s.copy(child = Project(projectList, grandChild))
+    // Push down project if possible when the child is sort.
+    case p @ Project(projectList, s @ Sort(_, _, grandChild)) =>
+      if (s.references.subsetOf(p.outputSet)) {
+        s.copy(child = Project(projectList, grandChild))
+      } else {
+        val neededReferences = s.references ++ p.references
+        if (neededReferences == grandChild.outputSet) {
+          // No column we can prune, return the original plan.
+          p
+        } else {
+          // Do not use neededReferences.toSeq directly, should respect grandChild's output order.
+          val newProjectList = grandChild.output.filter(neededReferences.contains)
+          p.copy(child = s.copy(child = Project(newProjectList, grandChild)))
+        }
+      }
 
     // Eliminate no-op Projects
     case Project(projectList, child) if child.output == projectList => child
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index dbebcb86809de37f7ef73f49bcf0d6c048012ee2..4a1e7ceaf394b635e3b5345bf6315d86013d42e1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -80,5 +80,16 @@ class ColumnPruningSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
+  test("Column pruning for Project on Sort") {
+    val input = LocalRelation('a.int, 'b.string, 'c.double)
+
+    val query = input.orderBy('b.asc).select('a).analyze
+    val optimized = Optimize.execute(query)
+
+    val correctAnswer = input.select('a, 'b).orderBy('b.asc).select('a).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
   // todo: add more tests for column pruning
 }