diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index d474853355e5baa17459876a981706ed584b4eb3..c0845e1a0102fdbee225015f71b3d65b73549c25 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import scala.collection.Map
-
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 47b06cae154363d2b084c6aded3f0d08ef53ebc7..42457d5318b4870acd9b62f41462b26e884522a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -165,6 +165,7 @@ object SetOperationPushDown extends Rule[LogicalPlan] {
  *
  *  - Inserting Projections beneath the following operators:
  *   - Aggregate
+ *   - Generate
  *   - Project <- Join
  *   - LeftSemiJoin
  */
@@ -178,6 +179,21 @@ object ColumnPruning extends Rule[LogicalPlan] {
     case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
       a.copy(child = Project(a.references.toSeq, child))
 
+    // Eliminate attributes that are not needed to calculate the Generate.
+    case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty =>
+      g.copy(child = Project(g.references.toSeq, g.child))
+
+    case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) =>
+      p.copy(child = g.copy(join = false))
+
+    case p @ Project(projectList, g: Generate) if g.join =>
+      val neededChildOutput = p.references -- g.generatorOutput ++ g.references
+      if (neededChildOutput == g.child.outputSet) {
+        p
+      } else {
+        Project(projectList, g.copy(child = Project(neededChildOutput.toSeq, g.child)))
+      }
+
     case p @ Project(projectList, a @ Aggregate(groupingExpressions, aggregateExpressions, child))
         if (a.outputSet -- p.references).nonEmpty =>
       Project(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..dbebcb86809de37f7ef73f49bcf0d6c048012ee2
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions.Explode
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation, Generate, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.types.StringType
+
+class ColumnPruningSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches = Batch("Column pruning", FixedPoint(100),
+      ColumnPruning) :: Nil
+  }
+
+  test("Column pruning for Generate when Generate.join = false") {
+    val input = LocalRelation('a.int, 'b.array(StringType))
+
+    val query = Generate(Explode('b), false, false, None, 's.string :: Nil, input).analyze
+    val optimized = Optimize.execute(query)
+
+    val correctAnswer =
+      Generate(Explode('b), false, false, None, 's.string :: Nil,
+        Project('b.attr :: Nil, input)).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Column pruning for Generate when Generate.join = true") {
+    val input = LocalRelation('a.int, 'b.int, 'c.array(StringType))
+
+    val query =
+      Project(Seq('a, 's),
+        Generate(Explode('c), true, false, None, 's.string :: Nil,
+          input)).analyze
+    val optimized = Optimize.execute(query)
+
+    val correctAnswer =
+      Project(Seq('a, 's),
+        Generate(Explode('c), true, false, None, 's.string :: Nil,
+          Project(Seq('a, 'c),
+            input))).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Turn Generate.join to false if possible") {
+    val input = LocalRelation('b.array(StringType))
+
+    val query =
+      Project(('s + 1).as("s+1") :: Nil,
+        Generate(Explode('b), true, false, None, 's.string :: Nil,
+          input)).analyze
+    val optimized = Optimize.execute(query)
+
+    val correctAnswer =
+      Project(('s + 1).as("s+1") :: Nil,
+        Generate(Explode('b), false, false, None, 's.string :: Nil,
+          input)).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  // todo: add more tests for column pruning
+}
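
For illustration only (not part of the patch): a minimal end-user sketch of the kind of query the new Generate cases target, assuming a Spark 1.4/1.5-era SQLContext and the org.apache.spark.sql.functions.explode API. When only the generated column is selected, ColumnPruning can set Generate.join = false and push a Project of just the array column beneath the Generate, which is what the new suite verifies at the plan level. The object name ExplodePruningExample and the sample data are hypothetical.

// Hypothetical usage sketch; not taken from the patch above.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions.explode

object ExplodePruningExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("explode-pruning").setMaster("local[*]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // A relation with an unrelated column `a` and an array column `b`.
    val df = Seq((1, Seq("x", "y")), (2, Seq("z"))).toDF("a", "b")

    // Only the generated column `s` is selected, so after this patch the
    // optimizer is expected to drop Generate.join and read only `b` from the
    // child relation rather than all of its columns.
    val exploded = df.select(explode($"b").as("s"))
    exploded.explain(true)
    exploded.show()

    sc.stop()
  }
}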