diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index af594c25c54cbae1f325b46865e39ab954a4708f..e50971173c499c2312d13fbce9aef6b1335f65f9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -275,13 +275,14 @@ package object dsl {
 
       def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan)
 
-      // TODO specify the output column names
       def generate(
         generator: Generator,
         join: Boolean = false,
         outer: Boolean = false,
-        alias: Option[String] = None): LogicalPlan =
-        Generate(generator, join = join, outer = outer, alias, Nil, logicalPlan)
+        alias: Option[String] = None,
+        outputNames: Seq[String] = Nil): LogicalPlan =
+        Generate(generator, join = join, outer = outer, alias,
+          outputNames.map(UnresolvedAttribute(_)), logicalPlan)
 
       def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan =
         InsertIntoTable(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 4a1e7ceaf394b635e3b5345bf6315d86013d42e1..9bf61ae0917869e0835332ed0ec8a05c50192253 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.expressions.Explode
 import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation, Generate, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
@@ -35,12 +35,11 @@ class ColumnPruningSuite extends PlanTest {
   test("Column pruning for Generate when Generate.join = false") {
     val input = LocalRelation('a.int, 'b.array(StringType))
 
-    val query = Generate(Explode('b), false, false, None, 's.string :: Nil, input).analyze
+    val query = input.generate(Explode('b), join = false).analyze
+
     val optimized = Optimize.execute(query)
 
-    val correctAnswer =
-      Generate(Explode('b), false, false, None, 's.string :: Nil,
-        Project('b.attr :: Nil, input)).analyze
+    val correctAnswer = input.select('b).generate(Explode('b), join = false).analyze
 
     comparePlans(optimized, correctAnswer)
   }
@@ -49,16 +48,19 @@ class ColumnPruningSuite extends PlanTest {
     val input = LocalRelation('a.int, 'b.int, 'c.array(StringType))
 
     val query =
-      Project(Seq('a, 's),
-        Generate(Explode('c), true, false, None, 's.string :: Nil,
-          input)).analyze
+      input
+        .generate(Explode('c), join = true, outputNames = "explode" :: Nil)
+        .select('a, 'explode)
+        .analyze
+
     val optimized = Optimize.execute(query)
 
     val correctAnswer =
-      Project(Seq('a, 's),
-        Generate(Explode('c), true, false, None, 's.string :: Nil,
-          Project(Seq('a, 'c),
-            input))).analyze
+      input
+        .select('a, 'c)
+        .generate(Explode('c), join = true, outputNames = "explode" :: Nil)
+        .select('a, 'explode)
+        .analyze
 
     comparePlans(optimized, correctAnswer)
   }
@@ -67,15 +69,18 @@ class ColumnPruningSuite extends PlanTest {
     val input = LocalRelation('b.array(StringType))
 
     val query =
-      Project(('s + 1).as("s+1") :: Nil,
-        Generate(Explode('b), true, false, None, 's.string :: Nil,
-          input)).analyze
+      input
+        .generate(Explode('b), join = true, outputNames = "explode" :: Nil)
+        .select(('explode + 1).as("result"))
+        .analyze
+
     val optimized = Optimize.execute(query)
 
     val correctAnswer =
-      Project(('s + 1).as("s+1") :: Nil,
-        Generate(Explode('b), false, false, None, 's.string :: Nil,
-          input)).analyze
+      input
+        .generate(Explode('b), join = false, outputNames = "explode" :: Nil)
+        .select(('explode + 1).as("result"))
+        .analyze
 
     comparePlans(optimized, correctAnswer)
   }