diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index b3a197cd96e3726800aec223a9e2902bf322f72c..7afdf75f3867a12a8dad89f183def7eae720d3ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.datasources.parquet.{DefaultSource => Parq
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.BaseRelation
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{DataType, StructType}
 
 object RDDConversions {
   def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = {
@@ -348,7 +348,8 @@ private[sql] object DataSourceScanExec {
     }
 
     relation match {
-      case r: HadoopFsRelation if r.fileFormat.supportBatch(r.sqlContext, relation.schema) =>
+      case r: HadoopFsRelation
+        if r.fileFormat.supportBatch(r.sqlContext, StructType.fromAttributes(output)) =>
         BatchedDataSourceScanExec(output, rdd, relation, outputPartitioning, metadata)
       case _ =>
         RowDataSourceScanExec(output, rdd, relation, outputPartitioning, metadata)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
index 38c00849529cf4b6d58377d7af8d9ffa525cc00a..bbbbc5ebe9030eb0632ad8d86293dbefef0b838f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
@@ -286,10 +286,6 @@ private[sql] class DefaultSource
       SQLConf.PARQUET_INT96_AS_TIMESTAMP.key,
       sqlContext.conf.getConf(SQLConf.PARQUET_INT96_AS_TIMESTAMP))
 
-    // Whole stage codegen (PhysicalRDD) is able to deal with batches directly
-    val returningBatch =
-      supportBatch(sqlContext, StructType(partitionSchema.fields ++ dataSchema.fields))
-
     // Try to push down filters when filter push-down is enabled.
     val pushed = if (sqlContext.getConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key).toBoolean) {
       filters
@@ -308,8 +304,11 @@ private[sql] class DefaultSource
     // TODO: if you move this into the closure it reverts to the default values.
     // If true, enable using the custom RecordReader for parquet. This only works for
     // a subset of the types (no complex types).
-    val enableVectorizedParquetReader: Boolean = sqlContext.conf.parquetVectorizedReaderEnabled &&
-          dataSchema.forall(_.dataType.isInstanceOf[AtomicType])
+    val resultSchema = StructType(partitionSchema.fields ++ requiredSchema.fields)
+    val enableVectorizedReader: Boolean = sqlContext.conf.parquetVectorizedReaderEnabled &&
+      resultSchema.forall(_.dataType.isInstanceOf[AtomicType])
+    // Whole stage codegen (PhysicalRDD) is able to deal with batches directly
+    val returningBatch = supportBatch(sqlContext, resultSchema)
 
     (file: PartitionedFile) => {
       assert(file.partitionValues.numFields == partitionSchema.size)
@@ -329,7 +328,7 @@ private[sql] class DefaultSource
       val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
       val hadoopAttemptContext = new TaskAttemptContextImpl(broadcastedConf.value.value, attemptId)
 
-      val parquetReader = if (enableVectorizedParquetReader) {
+      val parquetReader = if (enableVectorizedReader) {
         val vectorizedReader = new VectorizedParquetRecordReader()
         vectorizedReader.initialize(split, hadoopAttemptContext)
         logDebug(s"Appending $partitionSchema ${file.partitionValues}")
@@ -356,7 +355,7 @@ private[sql] class DefaultSource
 
       // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy.
       if (parquetReader.isInstanceOf[VectorizedParquetRecordReader] &&
-          enableVectorizedParquetReader) {
+          enableVectorizedReader) {
         iter.asInstanceOf[Iterator[InternalRow]]
       } else {
         val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index 7d206e7bc443d414da12fa768f5487bb3d0d6cc9..ed20c45d5f93f55378bc908f504fe2f98bd5d0bb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -24,6 +24,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
 import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
+import org.apache.spark.sql.execution.BatchedDataSourceScanExec
 import org.apache.spark.sql.execution.datasources.parquet.TestingUDT.{NestedStruct, NestedStructUDT}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -589,6 +590,30 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
       checkAnswer(sqlContext.read.parquet(path), df)
     }
   }
+
+  test("returning batch for wide table") {
+    withSQLConf("spark.sql.codegen.maxFields" -> "100") {
+      withTempPath { dir =>
+        val path = dir.getCanonicalPath
+        val df = sqlContext.range(100).select(Seq.tabulate(110) {i => ('id + i).as(s"c$i")} : _*)
+        df.write.mode(SaveMode.Overwrite).parquet(path)
+
+        // donot return batch, because whole stage codegen is disabled for wide table (>200 columns)
+        val df2 = sqlContext.read.parquet(path)
+        assert(df2.queryExecution.sparkPlan.find(_.isInstanceOf[BatchedDataSourceScanExec]).isEmpty,
+          "Should not return batch")
+        checkAnswer(df2, df)
+
+        // return batch
+        val columns = Seq.tabulate(90) {i => s"c$i"}
+        val df3 = df2.selectExpr(columns : _*)
+        assert(
+          df3.queryExecution.sparkPlan.find(_.isInstanceOf[BatchedDataSourceScanExec]).isDefined,
+          "Should not return batch")
+        checkAnswer(df3, df.selectExpr(columns : _*))
+      }
+    }
+  }
 }
 
 object TestingUDT {