diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index b3a197cd96e3726800aec223a9e2902bf322f72c..7afdf75f3867a12a8dad89f183def7eae720d3ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.datasources.parquet.{DefaultSource => Parq
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.BaseRelation
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{DataType, StructType}
 
 object RDDConversions {
   def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = {
@@ -348,7 +348,8 @@ private[sql] object DataSourceScanExec {
     }
 
     relation match {
-      case r: HadoopFsRelation if r.fileFormat.supportBatch(r.sqlContext, relation.schema) =>
+      case r: HadoopFsRelation
+        if r.fileFormat.supportBatch(r.sqlContext, StructType.fromAttributes(output)) =>
         BatchedDataSourceScanExec(output, rdd, relation, outputPartitioning, metadata)
       case _ =>
         RowDataSourceScanExec(output, rdd, relation, outputPartitioning, metadata)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
index 38c00849529cf4b6d58377d7af8d9ffa525cc00a..bbbbc5ebe9030eb0632ad8d86293dbefef0b838f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
@@ -286,10 +286,6 @@ private[sql] class DefaultSource
       SQLConf.PARQUET_INT96_AS_TIMESTAMP.key,
       sqlContext.conf.getConf(SQLConf.PARQUET_INT96_AS_TIMESTAMP))
 
-    // Whole stage codegen (PhysicalRDD) is able to deal with batches directly
-    val returningBatch =
-      supportBatch(sqlContext, StructType(partitionSchema.fields ++ dataSchema.fields))
-
     // Try to push down filters when filter push-down is enabled.
     val pushed = if (sqlContext.getConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key).toBoolean) {
       filters
@@ -308,8 +304,11 @@ private[sql] class DefaultSource
     // TODO: if you move this into the closure it reverts to the default values.
    // If true, enable using the custom RecordReader for parquet. This only works for
     // a subset of the types (no complex types).
-    val enableVectorizedParquetReader: Boolean = sqlContext.conf.parquetVectorizedReaderEnabled &&
-      dataSchema.forall(_.dataType.isInstanceOf[AtomicType])
+    val resultSchema = StructType(partitionSchema.fields ++ requiredSchema.fields)
+    val enableVectorizedReader: Boolean = sqlContext.conf.parquetVectorizedReaderEnabled &&
+      resultSchema.forall(_.dataType.isInstanceOf[AtomicType])
+    // Whole stage codegen (PhysicalRDD) is able to deal with batches directly
+    val returningBatch = supportBatch(sqlContext, resultSchema)
 
     (file: PartitionedFile) => {
       assert(file.partitionValues.numFields == partitionSchema.size)
@@ -329,7 +328,7 @@ private[sql] class DefaultSource
       val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
       val hadoopAttemptContext = new TaskAttemptContextImpl(broadcastedConf.value.value, attemptId)
 
-      val parquetReader = if (enableVectorizedParquetReader) {
+      val parquetReader = if (enableVectorizedReader) {
         val vectorizedReader = new VectorizedParquetRecordReader()
         vectorizedReader.initialize(split, hadoopAttemptContext)
         logDebug(s"Appending $partitionSchema ${file.partitionValues}")
@@ -356,7 +355,7 @@ private[sql] class DefaultSource
 
       // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy.
       if (parquetReader.isInstanceOf[VectorizedParquetRecordReader] &&
-          enableVectorizedParquetReader) {
+          enableVectorizedReader) {
         iter.asInstanceOf[Iterator[InternalRow]]
       } else {
         val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index 7d206e7bc443d414da12fa768f5487bb3d0d6cc9..ed20c45d5f93f55378bc908f504fe2f98bd5d0bb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -24,6 +24,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
 import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
+import org.apache.spark.sql.execution.BatchedDataSourceScanExec
 import org.apache.spark.sql.execution.datasources.parquet.TestingUDT.{NestedStruct, NestedStructUDT}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -589,6 +590,30 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
       checkAnswer(sqlContext.read.parquet(path), df)
     }
   }
+
+  test("returning batch for wide table") {
+    withSQLConf("spark.sql.codegen.maxFields" -> "100") {
+      withTempPath { dir =>
+        val path = dir.getCanonicalPath
+        val df = sqlContext.range(100).select(Seq.tabulate(110) {i => ('id + i).as(s"c$i")} : _*)
+        df.write.mode(SaveMode.Overwrite).parquet(path)
+
+        // do not return batch: whole stage codegen is disabled for wide tables (>100 columns here)
+        val df2 = sqlContext.read.parquet(path)
+        assert(df2.queryExecution.sparkPlan.find(_.isInstanceOf[BatchedDataSourceScanExec]).isEmpty,
+          "Should not return batch")
+        checkAnswer(df2, df)
+
+        // return batch, since only 90 of the columns are selected
+        val columns = Seq.tabulate(90) {i => s"c$i"}
+        val df3 = df2.selectExpr(columns : _*)
+        assert(
+          df3.queryExecution.sparkPlan.find(_.isInstanceOf[BatchedDataSourceScanExec]).isDefined,
+          "Should return batch")
+        checkAnswer(df3, df.selectExpr(columns : _*))
+      }
+    }
+  }
 }
 
 object TestingUDT {
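
For illustration only (not part of the patch): the change makes supportBatch() depend on the pruned output schema (StructType.fromAttributes(output), i.e. partitionSchema ++ requiredSchema) instead of the full relation schema, which is what the test above exercises. A minimal spark-shell sketch of how the same behavior can be observed from user code, assuming the APIs already used in this patch (queryExecution.sparkPlan, BatchedDataSourceScanExec); the helper name isBatched and the path are hypothetical:

    // Sketch, spark-shell style; `isBatched` is an illustrative helper, not part of the patch.
    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.execution.BatchedDataSourceScanExec

    // True if the physical plan scans the parquet source in columnar batches.
    def isBatched(df: DataFrame): Boolean =
      df.queryExecution.sparkPlan.find(_.isInstanceOf[BatchedDataSourceScanExec]).isDefined

    // With spark.sql.codegen.maxFields = 100 and a 110-column parquet table at `path`:
    //   isBatched(sqlContext.read.parquet(path))                       // false: 110 columns exceed the limit
    //   isBatched(sqlContext.read.parquet(path).selectExpr("c0", "c1")) // true: only the pruned columns are counted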