diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 976343ed961c554249c04f2a203afe0faf6ffe8d..13a13f0a7e4029563a4eaf604eafff104fe588ca 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -150,7 +150,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { requiredColumns: Array[String], filters: Array[Filter], bucketSet: Option[BitSet], - inputFiles: Array[FileStatus], + inputFiles: Seq[FileStatus], broadcastedConf: Broadcast[SerializableConfiguration], options: Map[String, String]): RDD[InternalRow] = { // TODO: This does not handle cases where column pruning has been performed. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala index 2c7c58e66b8551e2dac3fc4f2ecaca737fbca9cc..35884139b6be88f73fb5fa00c9c0e1dc6ede9d9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala @@ -17,8 +17,22 @@ package org.apache.spark.sql.catalyst +import org.apache.spark.sql.catalyst.analysis._ + private[spark] trait CatalystConf { def caseSensitiveAnalysis: Boolean + + /** + * Returns the [[Resolver]] for the current configuration, which can be used to determin if two + * identifiers are equal. + */ + def resolver: Resolver = { + if (caseSensitiveAnalysis) { + caseSensitiveResolution + } else { + caseInsensitiveResolution + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index b32c7d0fcbaa4b2f7c1c582ff0806199bcbf25f3..c8aadb2ed53403538d8080a2dcf5da78d31d83d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode} +import org.apache.spark.sql.types.StructType abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { @@ -116,6 +117,23 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { override lazy val canonicalized: LogicalPlan = EliminateSubqueryAliases(this) + /** + * Resolves a given schema to concrete [[Attribute]] references in this query plan. This function + * should only be called on analyzed plans since it will throw [[AnalysisException]] for + * unresolved [[Attribute]]s. + */ + def resolve(schema: StructType, resolver: Resolver): Seq[Attribute] = { + schema.map { field => + resolveQuoted(field.name, resolver).map { + case a: AttributeReference => a + case other => sys.error(s"can not handle nested schema yet... plan $this") + }.getOrElse { + throw new AnalysisException( + s"Unable to resolve ${field.name} given [${output.map(_.name).mkString(", ")}]") + } + } + } + /** * Optionally resolves the given strings to a [[NamedExpression]] using the input from all child * nodes of this LogicalPlan. 
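// Illustrative sketch (not part of the patch): in Catalyst a Resolver is just a name-equality
// function of type (String, String) => Boolean, and the new CatalystConf.resolver picks strict
// or case-insensitive equality based on caseSensitiveAnalysis. LogicalPlan.resolve then uses it
// to match each schema field name against the plan's output attributes. The object below is a
// self-contained stand-in, not the real Catalyst definitions.
object ResolverSketch {
  type Resolver = (String, String) => Boolean

  val caseSensitiveResolution: Resolver = (a, b) => a == b
  val caseInsensitiveResolution: Resolver = (a, b) => a.equalsIgnoreCase(b)

  def resolver(caseSensitiveAnalysis: Boolean): Resolver =
    if (caseSensitiveAnalysis) caseSensitiveResolution else caseInsensitiveResolution

  def main(args: Array[String]): Unit = {
    val outputAttributes = Seq("P1", "c1")   // names exposed by an analyzed plan
    val partitionSchema = Seq("p1")          // field names we want to resolve
    val resolve = resolver(caseSensitiveAnalysis = false)
    partitionSchema.foreach { name =>
      val matched = outputAttributes.find(resolve(_, name))
      println(s"$name -> ${matched.getOrElse("unresolved")}")   // prints: p1 -> P1
    }
  }
}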
The attribute is expressed as diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index d363cb000d39a52671b5fb297fa5cf49b6c5913e..e97c6be7f177a1a57abfeae3e5510920b28b5101 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -151,7 +151,7 @@ private[sql] case class DataSourceScan( override val outputPartitioning = { val bucketSpec = relation match { // TODO: this should be closer to bucket planning. - case r: HadoopFsRelation if r.sqlContext.conf.bucketingEnabled() => r.bucketSpec + case r: HadoopFsRelation if r.sqlContext.conf.bucketingEnabled => r.bucketSpec case _ => None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index d1569a4ec2b4068c2a36d9c2d6b731d8473a404c..292d366e727d3d7a042f7568bc484dcf9a1bad7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.SparkContext import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy} class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies { val sparkContext: SparkContext = sqlContext.sparkContext @@ -29,6 +29,7 @@ class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies { def strategies: Seq[Strategy] = sqlContext.experimental.extraStrategies ++ ( + FileSourceStrategy :: DataSourceStrategy :: DDLStrategy :: SpecialLimits :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 887f5469b5f8fbffeba84a0b491c1788dc5a5b2c..e65a771202bce238f5b0b711240ed63be4a13e1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -143,7 +143,7 @@ case class DataSource( SparkHadoopUtil.get.globPathIfNecessary(qualified) }.toArray - val fileCatalog: FileCatalog = new HDFSFileCatalog(sqlContext, options, globbedPaths) + val fileCatalog: FileCatalog = new HDFSFileCatalog(sqlContext, options, globbedPaths, None) val dataSchema = userSpecifiedSchema.orElse { format.inferSchema( sqlContext, @@ -208,7 +208,20 @@ case class DataSource( SparkHadoopUtil.get.globPathIfNecessary(qualified) }.toArray - val fileCatalog: FileCatalog = new HDFSFileCatalog(sqlContext, options, globbedPaths) + // If they gave a schema, then we try and figure out the types of the partition columns + // from that schema. + val partitionSchema = userSpecifiedSchema.map { schema => + StructType( + partitionColumns.map { c => + // TODO: Case sensitivity. 
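// Illustrative sketch (not part of the patch): the lookup below derives the partition columns'
// StructFields from the user-specified schema by name, lower-casing both sides because proper
// resolver-based case handling is still a TODO here. Assumes spark-sql (for StructType) is on
// the classpath; the schema and column names are invented for the example.
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

object PartitionSchemaSketch {
  def main(args: Array[String]): Unit = {
    val userSpecifiedSchema = new StructType()
      .add("P1", IntegerType)   // note the upper-case name
      .add("c1", IntegerType)
      .add("c2", StringType)
    val partitionColumns = Seq("p1")

    val partitionSchema = StructType(
      partitionColumns.map { c =>
        userSpecifiedSchema
          .find(_.name.toLowerCase() == c.toLowerCase())
          .getOrElse(sys.error(s"Invalid partition column '$c'"))
      })
    println(partitionSchema)   // StructType(StructField(P1,IntegerType,true))
  }
}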
+ schema + .find(_.name.toLowerCase() == c.toLowerCase()) + .getOrElse(throw new AnalysisException(s"Invalid partition column '$c'")) + }) + } + + val fileCatalog: FileCatalog = + new HDFSFileCatalog(sqlContext, options, globbedPaths, partitionSchema) val dataSchema = userSpecifiedSchema.orElse { format.inferSchema( sqlContext, @@ -220,22 +233,11 @@ case class DataSource( "It must be specified manually") } - // If they gave a schema, then we try and figure out the types of the partition columns - // from that schema. - val partitionSchema = userSpecifiedSchema.map { schema => - StructType( - partitionColumns.map { c => - // TODO: Case sensitivity. - schema - .find(_.name.toLowerCase() == c.toLowerCase()) - .getOrElse(throw new AnalysisException(s"Invalid partition column '$c'")) - }) - }.getOrElse(fileCatalog.partitionSpec(None).partitionColumns) HadoopFsRelation( sqlContext, fileCatalog, - partitionSchema = partitionSchema, + partitionSchema = fileCatalog.partitionSpec().partitionColumns, dataSchema = dataSchema.asNullable, bucketSpec = bucketSpec, format, @@ -296,7 +298,7 @@ case class DataSource( resolveRelation() .asInstanceOf[HadoopFsRelation] .location - .partitionSpec(None) + .partitionSpec() .partitionColumns .fieldNames .toSet) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 1adf3b6676555389464d606f8900a866bcaf3a27..7f6671552ebde800747ba0fdd401716bd3c89f49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -126,7 +126,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val partitionAndNormalColumnFilters = filters.toSet -- partitionFilters.toSet -- pushedFilters.toSet - val selectedPartitions = prunePartitions(partitionFilters, t.partitionSpec).toArray + val selectedPartitions = t.location.listFiles(partitionFilters) logInfo { val total = t.partitionSpec.partitions.length @@ -180,7 +180,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { t.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf)) t.bucketSpec match { - case Some(spec) if t.sqlContext.conf.bucketingEnabled() => + case Some(spec) if t.sqlContext.conf.bucketingEnabled => val scanBuilder: (Seq[Attribute], Array[Filter]) => RDD[InternalRow] = { (requiredColumns: Seq[Attribute], filters: Array[Filter]) => { val bucketed = @@ -200,7 +200,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { requiredColumns.map(_.name).toArray, filters, None, - bucketFiles.toArray, + bucketFiles, confBroadcast, t.options).coalesce(1) } @@ -233,7 +233,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { a.map(_.name).toArray, f, None, - t.location.allFiles().toArray, + t.location.allFiles(), confBroadcast, t.options)) :: Nil } @@ -255,7 +255,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { filters: Seq[Expression], buckets: Option[BitSet], partitionColumns: StructType, - partitions: Array[Partition], + partitions: Seq[Partition], options: Map[String, String]): SparkPlan = { val relation = logicalRelation.relation.asInstanceOf[HadoopFsRelation] @@ -272,14 +272,13 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { (requiredColumns: Seq[Attribute], filters: Array[Filter]) 
=> { relation.bucketSpec match { - case Some(spec) if relation.sqlContext.conf.bucketingEnabled() => + case Some(spec) if relation.sqlContext.conf.bucketingEnabled => val requiredDataColumns = requiredColumns.filterNot(c => partitionColumnNames.contains(c.name)) // Builds RDD[Row]s for each selected partition. val perPartitionRows: Seq[(Int, RDD[InternalRow])] = partitions.flatMap { - case Partition(partitionValues, dir) => - val files = relation.location.getStatus(dir) + case Partition(partitionValues, files) => val bucketed = files.groupBy { f => BucketingUtils .getBucketId(f.getPath.getName) @@ -327,14 +326,14 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // Builds RDD[Row]s for each selected partition. val perPartitionRows = partitions.map { - case Partition(partitionValues, dir) => + case Partition(partitionValues, files) => val dataRows = relation.fileFormat.buildInternalScan( relation.sqlContext, relation.dataSchema, requiredDataColumns.map(_.name).toArray, filters, buckets, - relation.location.getStatus(dir), + files, confBroadcast, options) @@ -525,33 +524,6 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { if (matchedBuckets.cardinality() == 0) None else Some(matchedBuckets) } - protected def prunePartitions( - predicates: Seq[Expression], - partitionSpec: PartitionSpec): Seq[Partition] = { - val PartitionSpec(partitionColumns, partitions) = partitionSpec - val partitionColumnNames = partitionColumns.map(_.name).toSet - val partitionPruningPredicates = predicates.filter { - _.references.map(_.name).toSet.subsetOf(partitionColumnNames) - } - - if (partitionPruningPredicates.nonEmpty) { - val predicate = - partitionPruningPredicates - .reduceOption(expressions.And) - .getOrElse(Literal(true)) - - val boundPredicate = InterpretedPredicate.create(predicate.transform { - case a: AttributeReference => - val index = partitionColumns.indexWhere(a.name == _.name) - BoundReference(index, partitionColumns(index).dataType, nullable = true) - }) - - partitions.filter { case Partition(values, _) => boundPredicate(values) } - } else { - partitions - } - } - // Based on Public API. protected def pruneFilterProject( relation: LogicalRelation, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala new file mode 100644 index 0000000000000000000000000000000000000000..e2cbbc34d91a4a1cc2b505066394a73b4dbaef76 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.{Partition, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.InternalRow + +/** + * A single file that should be read, along with partition column values that + * need to be prepended to each row. The reading should start at the first + * valid record found after `offset`. + */ +case class PartitionedFile( + partitionValues: InternalRow, + filePath: String, + start: Long, + length: Long) + +/** + * A collection of files that should be read as a single task possibly from multiple partitioned + * directories. + * + * IMPLEMENT ME: This is just a placeholder for a future implementation. + * TODO: This currently does not take locality information about the files into account. + */ +case class FilePartition(val index: Int, files: Seq[PartitionedFile]) extends Partition + +class FileScanRDD( + @transient val sqlContext: SQLContext, + readFunction: (PartitionedFile) => Iterator[InternalRow], + @transient val filePartitions: Seq[FilePartition]) + extends RDD[InternalRow](sqlContext.sparkContext, Nil) { + + + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + throw new NotImplementedError("Not Implemented Yet") + } + + override protected def getPartitions: Array[Partition] = Array.empty +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala new file mode 100644 index 0000000000000000000000000000000000000000..ef95d5d28961fad5510b08506056bde1839b296b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import scala.collection.mutable.ArrayBuffer + +import org.apache.hadoop.fs.Path + +import org.apache.spark.Logging +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.{DataSourceScan, SparkPlan} +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ + +/** + * A strategy for planning scans over collections of files that might be partitioned or bucketed + * by user specified columns. + * + * At a high level planning occurs in several phases: + * - Split filters by when they need to be evaluated. 
+ * - Prune the schema of the data requested based on any projections present. Today this pruning + * is only done on top level columns, but formats should support pruning of nested columns as + * well. + * - Construct a reader function by passing filters and the schema into the FileFormat. + * - Using an partition pruning predicates, enumerate the list of files that should be read. + * - Split the files into tasks and construct a FileScanRDD. + * - Add any projection or filters that must be evaluated after the scan. + * + * Files are assigned into tasks using the following algorithm: + * - If the table is bucketed, group files by bucket id into the correct number of partitions. + * - If the table is not bucketed or bucketing is turned off: + * - If any file is larger than the threshold, split it into pieces based on that threshold + * - Sort the files by decreasing file size. + * - Assign the ordered files to buckets using the following algorithm. If the current partition + * is under the threshold with the addition of the next file, add it. If not, open a new bucket + * and add it. Proceed to the next file. + */ +private[sql] object FileSourceStrategy extends Strategy with Logging { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalOperation(projects, filters, l@LogicalRelation(files: HadoopFsRelation, _, _)) + if files.fileFormat.toString == "TestFileFormat" => + // Filters on this relation fall into four categories based on where we can use them to avoid + // reading unneeded data: + // - partition keys only - used to prune directories to read + // - bucket keys only - optionally used to prune files to read + // - keys stored in the data only - optionally used to skip groups of data in files + // - filters that need to be evaluated again after the scan + val filterSet = ExpressionSet(filters) + + val partitionColumns = + AttributeSet(l.resolve(files.partitionSchema, files.sqlContext.analyzer.resolver)) + val partitionKeyFilters = + ExpressionSet(filters.filter(_.references.subsetOf(partitionColumns))) + logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}") + + val bucketColumns = + AttributeSet( + files.bucketSpec + .map(_.bucketColumnNames) + .getOrElse(Nil) + .map(l.resolveQuoted(_, files.sqlContext.conf.resolver) + .getOrElse(sys.error("")))) + + // Partition keys are not available in the statistics of the files. + val dataFilters = filters.filter(_.references.intersect(partitionColumns).isEmpty) + + // Predicates with both partition keys and attributes need to be evaluated after the scan. 
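// Illustrative sketch (not part of the patch): how the three derived filter sets relate, using
// plain strings in place of Catalyst expressions. Partition-only filters prune directories and
// never need re-evaluation; data-only filters are handed to the format as advisory pushed-down
// filters; everything not guaranteed by partition pruning is kept as an after-scan filter,
// matching the set arithmetic used just below.
object FilterSplitSketch {
  case class Pred(sql: String, references: Set[String])

  def main(args: Array[String]): Unit = {
    val partitionColumns = Set("p1")
    val filters = Seq(
      Pred("p1 = 1", Set("p1")),
      Pred("c1 = 1", Set("c1")),
      Pred("(p1 + c1) = 2", Set("p1", "c1")))

    val partitionKeyFilters = filters.filter(_.references.subsetOf(partitionColumns))
    val dataFilters = filters.filter(_.references.intersect(partitionColumns).isEmpty)
    val afterScanFilters = filters.filterNot(partitionKeyFilters.contains)

    println(partitionKeyFilters.map(_.sql))   // List(p1 = 1)
    println(dataFilters.map(_.sql))           // List(c1 = 1)
    println(afterScanFilters.map(_.sql))      // List(c1 = 1, (p1 + c1) = 2)
  }
}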
+ val afterScanFilters = filterSet -- partitionKeyFilters + logInfo(s"Post-Scan Filters: ${afterScanFilters.mkString(",")}") + + val selectedPartitions = files.location.listFiles(partitionKeyFilters.toSeq) + + val filterAttributes = AttributeSet(afterScanFilters) + val requiredExpressions: Seq[NamedExpression] = filterAttributes.toSeq ++ projects + val requiredAttributes = AttributeSet(requiredExpressions).map(_.name).toSet + + val prunedDataSchema = + StructType( + files.dataSchema.filter(f => requiredAttributes.contains(f.name))) + logInfo(s"Pruned Data Schema: ${prunedDataSchema.simpleString(5)}") + + val pushedDownFilters = dataFilters.flatMap(DataSourceStrategy.translateFilter) + logInfo(s"Pushed Filters: ${pushedDownFilters.mkString(",")}") + + val readFile = files.fileFormat.buildReader( + sqlContext = files.sqlContext, + partitionSchema = files.partitionSchema, + dataSchema = prunedDataSchema, + filters = pushedDownFilters, + options = files.options) + + val plannedPartitions = files.bucketSpec match { + case Some(bucketing) if files.sqlContext.conf.bucketingEnabled => + logInfo(s"Planning with ${bucketing.numBuckets} buckets") + val bucketed = + selectedPartitions + .flatMap { p => + p.files.map(f => PartitionedFile(p.values, f.getPath.toUri.toString, 0, f.getLen)) + }.groupBy { f => + BucketingUtils + .getBucketId(new Path(f.filePath).getName) + .getOrElse(sys.error(s"Invalid bucket file ${f.filePath}")) + } + + (0 until bucketing.numBuckets).map { bucketId => + FilePartition(bucketId, bucketed.getOrElse(bucketId, Nil)) + } + + case _ => + val maxSplitBytes = files.sqlContext.conf.filesMaxPartitionBytes + logInfo(s"Planning scan with bin packing, max size: $maxSplitBytes bytes") + + val splitFiles = selectedPartitions.flatMap { partition => + partition.files.flatMap { file => + assert(file.getLen != 0) + (0L to file.getLen by maxSplitBytes).map { offset => + val remaining = file.getLen - offset + val size = if (remaining > maxSplitBytes) maxSplitBytes else remaining + PartitionedFile(partition.values, file.getPath.toUri.toString, offset, size) + } + } + }.toArray.sortBy(_.length)(implicitly[Ordering[Long]].reverse) + + val partitions = new ArrayBuffer[FilePartition] + val currentFiles = new ArrayBuffer[PartitionedFile] + var currentSize = 0L + + /** Add the given file to the current partition. */ + def addFile(file: PartitionedFile): Unit = { + currentSize += file.length + currentFiles.append(file) + } + + /** Close the current partition and move to the next. */ + def closePartition(): Unit = { + if (currentFiles.nonEmpty) { + val newPartition = + FilePartition( + partitions.size, + currentFiles.toArray.toSeq) // Copy to a new Array. + partitions.append(newPartition) + } + currentFiles.clear() + currentSize = 0 + } + + // Assign files to partitions using "First Fit Decreasing" (FFD) + // TODO: consider adding a slop factor here? 
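// Illustrative sketch (not part of the patch): the packing loop that follows sorts file splits
// by decreasing size and closes the current partition whenever the next split would overflow
// maxSplitBytes (spark.sql.files.maxPartitionBytes, 128MB by default in this patch). Note that
// only one partition is ever open at a time. Standalone version, with Longs standing in for
// PartitionedFiles:
object PackingSketch {
  import scala.collection.mutable.ArrayBuffer

  def pack(splitSizes: Seq[Long], maxSplitBytes: Long): Seq[Seq[Long]] = {
    val partitions = ArrayBuffer.empty[Seq[Long]]
    val currentFiles = ArrayBuffer.empty[Long]
    var currentSize = 0L

    def closePartition(): Unit = {
      if (currentFiles.nonEmpty) partitions += currentFiles.toList
      currentFiles.clear()
      currentSize = 0L
    }

    splitSizes.sorted(Ordering[Long].reverse).foreach { size =>
      if (currentSize + size > maxSplitBytes) closePartition()
      currentFiles += size
      currentSize += size
    }
    closePartition()
    partitions.toList
  }

  def main(args: Array[String]): Unit = {
    // Matches the test suite: three 5-byte files -> [(5, 5), (5)] with a 10 byte limit.
    println(pack(Seq(5L, 5L, 5L), maxSplitBytes = 10))    // List(List(5, 5), List(5))
    // A 15-byte file split into 10 + 5, plus a 4-byte file -> [(10), (5, 4)].
    println(pack(Seq(10L, 5L, 4L), maxSplitBytes = 10))   // List(List(10), List(5, 4))
  }
}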
+ splitFiles.foreach { file => + if (currentSize + file.length > maxSplitBytes) { + closePartition() + addFile(file) + } else { + addFile(file) + } + } + closePartition() + partitions + } + + val scan = + DataSourceScan( + l.output, + new FileScanRDD( + files.sqlContext, + readFile, + plannedPartitions), + files, + Map("format" -> files.fileFormat.toString)) + + val afterScanFilter = afterScanFilters.toSeq.reduceOption(expressions.And) + val withFilter = afterScanFilter.map(execution.Filter(_, scan)).getOrElse(scan) + val withProjections = if (projects.forall(_.isInstanceOf[AttributeReference])) { + withFilter + } else { + execution.Project(projects, withFilter) + } + + withProjections :: Nil + + case _ => Nil + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 18a460fc85ed4aa64b2c344c7ca8ba04578b8ab4..3ac2ff494fa819119112c0a11c0b1c4559d3a5f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -32,17 +32,23 @@ import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.types._ -object Partition { - def apply(values: InternalRow, path: String): Partition = +object PartitionDirectory { + def apply(values: InternalRow, path: String): PartitionDirectory = apply(values, new Path(path)) } -private[sql] case class Partition(values: InternalRow, path: Path) +/** + * Holds a directory in a partitioned collection of files as well as as the partition values + * in the form of a Row. Before scanning, the files at `path` need to be enumerated. + */ +private[sql] case class PartitionDirectory(values: InternalRow, path: Path) -private[sql] case class PartitionSpec(partitionColumns: StructType, partitions: Seq[Partition]) +private[sql] case class PartitionSpec( + partitionColumns: StructType, + partitions: Seq[PartitionDirectory]) private[sql] object PartitionSpec { - val emptySpec = PartitionSpec(StructType(Seq.empty[StructField]), Seq.empty[Partition]) + val emptySpec = PartitionSpec(StructType(Seq.empty[StructField]), Seq.empty[PartitionDirectory]) } private[sql] object PartitioningUtils { @@ -133,7 +139,7 @@ private[sql] object PartitioningUtils { // Finally, we create `Partition`s based on paths and resolved partition values. 
val partitions = resolvedPartitionValues.zip(pathsWithPartitionValues).map { case (PartitionValues(_, literals), (path, _)) => - Partition(InternalRow.fromSeq(literals.map(_.value)), path) + PartitionDirectory(InternalRow.fromSeq(literals.map(_.value)), path) } PartitionSpec(StructType(fields), partitions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index 0e6b9855c70decf886e4c1a7e2bc1a991fe76d77..c96a508cf1baac51a5b25ec051dd457dc8515280 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -52,7 +52,7 @@ object CSVRelation extends Logging { tokenizedRDD: RDD[Array[String]], schema: StructType, requiredColumns: Array[String], - inputs: Array[FileStatus], + inputs: Seq[FileStatus], sqlContext: SQLContext, params: CSVOptions): RDD[Row] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala index 42c07c8a23f5ec89a892d611e733e61c6c2f79a3..a5f94262ff402e962c9f3623c1fe14fa0337a584 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala @@ -103,7 +103,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { requiredColumns: Array[String], filters: Array[Filter], bucketSet: Option[BitSet], - inputFiles: Array[FileStatus], + inputFiles: Seq[FileStatus], broadcastedConf: Broadcast[SerializableConfiguration], options: Map[String, String]): RDD[InternalRow] = { // TODO: Filter before calling buildInternalScan. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 05b44d1a2a04fde00782f36335c5ba53fcfdbe8b..3fa5ebf1bb81e64d567dfe5393c7c1fbc715fc8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -95,7 +95,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { requiredColumns: Array[String], filters: Array[Filter], bucketSet: Option[BitSet], - inputFiles: Array[FileStatus], + inputFiles: Seq[FileStatus], broadcastedConf: Broadcast[SerializableConfiguration], options: Map[String, String]): RDD[InternalRow] = { // TODO: Filter files for all formats before calling buildInternalScan. 
@@ -115,7 +115,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { } } - private def createBaseRdd(sqlContext: SQLContext, inputPaths: Array[FileStatus]): RDD[String] = { + private def createBaseRdd(sqlContext: SQLContext, inputPaths: Seq[FileStatus]): RDD[String] = { val job = Job.getInstance(sqlContext.sparkContext.hadoopConfiguration) val conf = job.getConfiguration diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index f1060074d6acb2fd9731132a2fc247af8224ae3f..342034ca0ff92c312c46ff71729198bfc6e3aabf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -274,7 +274,7 @@ private[sql] class DefaultSource extends FileFormat with DataSourceRegister with requiredColumns: Array[String], filters: Array[Filter], bucketSet: Option[BitSet], - allFiles: Array[FileStatus], + allFiles: Seq[FileStatus], broadcastedConf: Broadcast[SerializableConfiguration], options: Map[String, String]): RDD[InternalRow] = { val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index 2869a6a1ac0745f88cad8971b99645a85c36677c..6af403dec5fba316ae2d7ed062c6c02d296a8340 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -94,7 +94,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { requiredColumns: Array[String], filters: Array[Filter], bucketSet: Option[BitSet], - inputFiles: Array[FileStatus], + inputFiles: Seq[FileStatus], broadcastedConf: Broadcast[SerializableConfiguration], options: Map[String, String]): RDD[InternalRow] = { verifySchema(dataSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 59429d254ebb6cafe00edf60b29ebb1e0925f440..cbdc37a2a1622ec907dd1783a1e21e7aa7b87436 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -504,6 +504,11 @@ object SQLConf { " method", isPublic = false) + val FILES_MAX_PARTITION_BYTES = longConf("spark.sql.files.maxPartitionBytes", + defaultValue = Some(128 * 1024 * 1024), // parquet.block.size + doc = "The maximum number of bytes to pack into a single partition when reading files.", + isPublic = true) + val EXCHANGE_REUSE_ENABLED = booleanConf("spark.sql.exchange.reuse", defaultValue = Some(true), doc = "When true, the planner will try to find out duplicated exchanges and re-use them", @@ -538,6 +543,8 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin /** ************************ Spark SQL Params/Hints ******************* */ + def filesMaxPartitionBytes: Long = getConf(FILES_MAX_PARTITION_BYTES) + def useCompression: Boolean = getConf(COMPRESS_CACHED) def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION) @@ -605,7 +612,7 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with 
Loggin def parallelPartitionDiscoveryThreshold: Int = getConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD) - def bucketingEnabled(): Boolean = getConf(SQLConf.BUCKETING_ENABLED) + def bucketingEnabled: Boolean = getConf(SQLConf.BUCKETING_ENABLED) // Do not use a value larger than 4000 as the default value of this property. // See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 601f944fb6363e79efbbdadc3264d1a9c384a761..95ffc33011e8e78d6737ca068077bda3c12f2522 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -30,11 +30,11 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.FileRelation import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.streaming.{FileStreamSource, Sink, Source} +import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.SerializableConfiguration import org.apache.spark.util.collection.BitSet @@ -409,7 +409,7 @@ case class HadoopFsRelation( def partitionSchemaOption: Option[StructType] = if (partitionSchema.isEmpty) None else Some(partitionSchema) - def partitionSpec: PartitionSpec = location.partitionSpec(partitionSchemaOption) + def partitionSpec: PartitionSpec = location.partitionSpec() def refresh(): Unit = location.refresh() @@ -454,11 +454,41 @@ trait FileFormat { requiredColumns: Array[String], filters: Array[Filter], bucketSet: Option[BitSet], - inputFiles: Array[FileStatus], + inputFiles: Seq[FileStatus], broadcastedConf: Broadcast[SerializableConfiguration], options: Map[String, String]): RDD[InternalRow] + + /** + * Returns a function that can be used to read a single file in as an Iterator of InternalRow. + * + * @param partitionSchema The schema of the partition column row that will be present in each + * PartitionedFile. These columns should be prepended to the rows that + * are produced by the iterator. + * @param dataSchema The schema of the data that should be output for each row. This may be a + * subset of the columns that are present in the file if column pruning has + * occurred. + * @param filters A set of filters than can optionally be used to reduce the number of rows output + * @param options A set of string -> string configuration options. + * @return + */ + def buildReader( + sqlContext: SQLContext, + partitionSchema: StructType, + dataSchema: StructType, + filters: Seq[Filter], + options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = { + // TODO: Remove this default implementation when the other formats have been ported + // Until then we guard in [[FileSourceStrategy]] to only call this method on supported formats. 
+ throw new UnsupportedOperationException(s"buildReader is not supported for $this") + } } +/** + * A collection of data files from a partitioned relation, along with the partition values in the + * form of an [[InternalRow]]. + */ +case class Partition(values: InternalRow, files: Seq[FileStatus]) + /** * An interface for objects capable of enumerating the files that comprise a relation as well * as the partitioning characteristics of those files. @@ -466,7 +496,18 @@ trait FileFormat { trait FileCatalog { def paths: Seq[Path] - def partitionSpec(schema: Option[StructType]): PartitionSpec + def partitionSpec(): PartitionSpec + + /** + * Returns all valid files grouped into partitions when the data is partitioned. If the data is + * unpartitioned, this will return a single partition with no partition values. + * + * @param filters the filters used to prune which partitions are returned. These filters must + * only refer to partition columns and this method will only return files + * where these predicates are guaranteed to evaluate to `true`. Thus, these + * filters will not need to be evaluated again on the returned data. + */ + def listFiles(filters: Seq[Expression]): Seq[Partition] def allFiles(): Seq[FileStatus] @@ -478,11 +519,17 @@ trait FileCatalog { /** * A file catalog that caches metadata gathered by scanning all the files present in `paths` * recursively. + * + * @param parameters a set of options to control discovery + * @param paths a list of paths to scan + * @param partitionSchema an optional partition schema that will be used to provide types for the + * discovered partitions */ class HDFSFileCatalog( val sqlContext: SQLContext, val parameters: Map[String, String], - val paths: Seq[Path]) + val paths: Seq[Path], + val partitionSchema: Option[StructType]) extends FileCatalog with Logging { private val hadoopConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) @@ -491,9 +538,9 @@ class HDFSFileCatalog( var leafDirToChildrenFiles = mutable.Map.empty[Path, Array[FileStatus]] var cachedPartitionSpec: PartitionSpec = _ - def partitionSpec(schema: Option[StructType]): PartitionSpec = { + def partitionSpec(): PartitionSpec = { if (cachedPartitionSpec == null) { - cachedPartitionSpec = inferPartitioning(schema) + cachedPartitionSpec = inferPartitioning(partitionSchema) } cachedPartitionSpec @@ -501,6 +548,53 @@ class HDFSFileCatalog( refresh() + override def listFiles(filters: Seq[Expression]): Seq[Partition] = { + if (partitionSpec().partitionColumns.isEmpty) { + Partition(InternalRow.empty, allFiles()) :: Nil + } else { + prunePartitions(filters, partitionSpec()).map { + case PartitionDirectory(values, path) => Partition(values, getStatus(path)) + } + } + } + + protected def prunePartitions( + predicates: Seq[Expression], + partitionSpec: PartitionSpec): Seq[PartitionDirectory] = { + val PartitionSpec(partitionColumns, partitions) = partitionSpec + val partitionColumnNames = partitionColumns.map(_.name).toSet + val partitionPruningPredicates = predicates.filter { + _.references.map(_.name).toSet.subsetOf(partitionColumnNames) + } + + if (partitionPruningPredicates.nonEmpty) { + val predicate = + partitionPruningPredicates + .reduceOption(expressions.And) + .getOrElse(Literal(true)) + + val boundPredicate = InterpretedPredicate.create(predicate.transform { + case a: AttributeReference => + val index = partitionColumns.indexWhere(a.name == _.name) + BoundReference(index, partitionColumns(index).dataType, nullable = true) + }) + + val selected = partitions.filter {
case PartitionDirectory(values, _) => boundPredicate(values) + } + logInfo { + val total = partitions.length + val selectedSize = selected.length + val percentPruned = (1 - selectedSize.toDouble / total.toDouble) * 100 + s"Selected $selectedSize partitions out of $total, pruned $percentPruned% partitions." + } + + selected + } else { + partitions + } + } + def allFiles(): Seq[FileStatus] = leafFiles.values.toSeq def getStatus(path: Path): Array[FileStatus] = leafDirToChildrenFiles(path) @@ -560,7 +654,7 @@ class HDFSFileCatalog( PartitionSpec(userProvidedSchema, spec.partitions.map { part => part.copy(values = castPartitionValuesToUserSchema(part.values)) }) - case None => + case _ => PartitioningUtils.parsePartitions( leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..2f8129c5da40d90e9fa8afd6b42dec3a4b80f9fc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -0,0 +1,345 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.{File, FilenameFilter} + +import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.mapreduce.Job + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet, PredicateHelper} +import org.apache.spark.sql.catalyst.util +import org.apache.spark.sql.execution.{DataSourceScan, PhysicalRDD} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.collection.BitSet + +class FileSourceStrategySuite extends QueryTest with SharedSQLContext with PredicateHelper { + import testImplicits._ + + test("unpartitioned table, single partition") { + val table = + createTable( + files = Seq( + "file1" -> 1, + "file2" -> 1, + "file3" -> 1, + "file4" -> 1, + "file5" -> 1, + "file6" -> 1, + "file7" -> 1, + "file8" -> 1, + "file9" -> 1, + "file10" -> 1)) + + checkScan(table.select('c1)) { partitions => + // 10 one byte files should fit in a single partition with 10 files. 
+ assert(partitions.size == 1, "when checking partitions") + assert(partitions.head.files.size == 10, "when checking partition 1") + // 1 byte files are too small to split so we should read the whole thing. + assert(partitions.head.files.head.start == 0) + assert(partitions.head.files.head.length == 1) + } + + checkPartitionSchema(StructType(Nil)) + checkDataSchema(StructType(Nil).add("c1", IntegerType)) + } + + test("unpartitioned table, multiple partitions") { + val table = + createTable( + files = Seq( + "file1" -> 5, + "file2" -> 5, + "file3" -> 5)) + + withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "10") { + checkScan(table.select('c1)) { partitions => + // 5 byte files should be laid out [(5, 5), (5)] + assert(partitions.size == 2, "when checking partitions") + assert(partitions(0).files.size == 2, "when checking partition 1") + assert(partitions(1).files.size == 1, "when checking partition 2") + + // 5 byte files are too small to split so we should read the whole thing. + assert(partitions.head.files.head.start == 0) + assert(partitions.head.files.head.length == 5) + } + + checkPartitionSchema(StructType(Nil)) + checkDataSchema(StructType(Nil).add("c1", IntegerType)) + } + } + + test("Unpartitioned table, large file that gets split") { + val table = + createTable( + files = Seq( + "file1" -> 15, + "file2" -> 4)) + + withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "10") { + checkScan(table.select('c1)) { partitions => + // Files should be laid out [(0-10), (10-15, 4)] + assert(partitions.size == 2, "when checking partitions") + assert(partitions(0).files.size == 1, "when checking partition 1") + assert(partitions(1).files.size == 2, "when checking partition 2") + + // Start by reading 10 bytes of the first file + assert(partitions.head.files.head.start == 0) + assert(partitions.head.files.head.length == 10) + + // Second partition reads the remaining 5 + assert(partitions(1).files.head.start == 10) + assert(partitions(1).files.head.length == 5) + } + + checkPartitionSchema(StructType(Nil)) + checkDataSchema(StructType(Nil).add("c1", IntegerType)) + } + } + + test("partitioned table") { + val table = + createTable( + files = Seq( + "p1=1/file1" -> 10, + "p1=2/file2" -> 10)) + + // Only one file should be read. + checkScan(table.where("p1 = 1")) { partitions => + assert(partitions.size == 1, "when checking partitions") + assert(partitions.head.files.size == 1, "when checking files in partition 1") + } + // We don't need to reevaluate filters that are only on partitions. + checkDataFilters(Set.empty) + + // Only one file should be read. + checkScan(table.where("p1 = 1 AND c1 = 1 AND (p1 + c1) = 1")) { partitions => + assert(partitions.size == 1, "when checking partitions") + assert(partitions.head.files.size == 1, "when checking files in partition 1") + assert(partitions.head.files.head.partitionValues.getInt(0) == 1, + "when checking partition values") + } + // Only the filters that do not contain the partition column should be pushed down + checkDataFilters(Set(IsNotNull("c1"), EqualTo("c1", 1))) + } + + test("partitioned table - after scan filters") { + val table = + createTable( + files = Seq( + "p1=1/file1" -> 10, + "p1=2/file2" -> 10)) + + val df = table.where("p1 = 1 AND (p1 + c1) = 2 AND c1 = 1") + // Filters on data only are advisory so we have to reevaluate. + assert(getPhysicalFilters(df) contains resolve(df, "c1 = 1")) + // Need to evaluate filters that are not pushed down.
+ assert(getPhysicalFilters(df) contains resolve(df, "(p1 + c1) = 2")) + // Don't reevaluate partition only filters. + assert(!(getPhysicalFilters(df) contains resolve(df, "p1 = 1"))) + } + + test("bucketed table") { + val table = + createTable( + files = Seq( + "p1=1/file1_0000" -> 1, + "p1=1/file2_0000" -> 1, + "p1=1/file3_0002" -> 1, + "p1=2/file4_0002" -> 1, + "p1=2/file5_0000" -> 1, + "p1=2/file6_0000" -> 1, + "p1=2/file7_0000" -> 1), + buckets = 3) + + // No partition pruning + checkScan(table) { partitions => + assert(partitions.size == 3) + assert(partitions(0).files.size == 5) + assert(partitions(1).files.size == 0) + assert(partitions(2).files.size == 2) + } + + // With partition pruning + checkScan(table.where("p1=2")) { partitions => + assert(partitions.size == 3) + assert(partitions(0).files.size == 3) + assert(partitions(1).files.size == 0) + assert(partitions(2).files.size == 1) + } + } + + // Helpers for checking the arguments passed to the FileFormat. + + protected val checkPartitionSchema = + checkArgument("partition schema", _.partitionSchema, _: StructType) + protected val checkDataSchema = + checkArgument("data schema", _.dataSchema, _: StructType) + protected val checkDataFilters = + checkArgument("data filters", _.filters.toSet, _: Set[Filter]) + + /** Helper for building checks on the arguments passed to the reader. */ + protected def checkArgument[T](name: String, arg: LastArguments.type => T, expected: T): Unit = { + if (arg(LastArguments) != expected) { + fail( + s""" + |Wrong $name + |expected: $expected + |actual: ${arg(LastArguments)} + """.stripMargin) + } + } + + /** Returns a resolved expression for `str` in the context of `df`. */ + def resolve(df: DataFrame, str: String): Expression = { + df.select(expr(str)).queryExecution.analyzed.expressions.head.children.head + } + + /** Returns a set with all the filters present in the physical plan. */ + def getPhysicalFilters(df: DataFrame): ExpressionSet = { + ExpressionSet( + df.queryExecution.executedPlan.collect { + case execution.Filter(f, _) => splitConjunctivePredicates(f) + }.flatten) + } + + /** Plans the query and calls the provided validation function with the planned partitioning. */ + def checkScan(df: DataFrame)(func: Seq[FilePartition] => Unit): Unit = { + val fileScan = df.queryExecution.executedPlan.collect { + case DataSourceScan(_, scan: FileScanRDD, _, _) => scan + }.headOption.getOrElse { + fail(s"No FileScan in query\n${df.queryExecution}") + } + + func(fileScan.filePartitions) + } + + /** + * Constructs a new table given a list of file names and sizes expressed in bytes. The table + * is written out in a temporary directory and any nested directories in the files names + * are automatically created. + * + * When `buckets` is > 0 the returned [[DataFrame]] will have metadata specifying that number of + * buckets. However, it is the responsibility of the caller to assign files to each bucket + * by appending the bucket id to the file names. 
+ */ + def createTable( + files: Seq[(String, Int)], + buckets: Int = 0): DataFrame = { + val tempDir = Utils.createTempDir() + files.foreach { + case (name, size) => + val file = new File(tempDir, name) + assert(file.getParentFile.exists() || file.getParentFile.mkdirs()) + util.stringToFile(file, "*" * size) + } + + val df = sqlContext.read + .format(classOf[TestFileFormat].getName) + .load(tempDir.getCanonicalPath) + + if (buckets > 0) { + val bucketed = df.queryExecution.analyzed transform { + case l @ LogicalRelation(r: HadoopFsRelation, _, _) => + l.copy(relation = + r.copy(bucketSpec = Some(BucketSpec(numBuckets = buckets, "c1" :: Nil, Nil)))) + } + Dataset.newDataFrame(sqlContext, bucketed) + } else { + df + } + } +} + +/** Holds the last arguments passed to [[TestFileFormat]]. */ +object LastArguments { + var partitionSchema: StructType = _ + var dataSchema: StructType = _ + var filters: Seq[Filter] = _ + var options: Map[String, String] = _ +} + +/** A test [[FileFormat]] that records the arguments passed to buildReader, and returns nothing. */ +class TestFileFormat extends FileFormat { + + override def toString: String = "TestFileFormat" + + /** + * When possible, this method should return the schema of the given `files`. When the format + * does not support inference, or no valid files are given should return None. In these cases + * Spark will require that user specify the schema manually. + */ + override def inferSchema( + sqlContext: SQLContext, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = + Some( + StructType(Nil) + .add("c1", IntegerType) + .add("c2", IntegerType)) + + /** + * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can + * be put here. For example, user defined output committer can be configured here + * by setting the output committer class in the conf of spark.sql.sources.outputCommitterClass. + */ + override def prepareWrite( + sqlContext: SQLContext, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + throw new NotImplementedError("JUST FOR TESTING") + } + + override def buildInternalScan( + sqlContext: SQLContext, + dataSchema: StructType, + requiredColumns: Array[String], + filters: Array[Filter], + bucketSet: Option[BitSet], + inputFiles: Seq[FileStatus], + broadcastedConf: Broadcast[SerializableConfiguration], + options: Map[String, String]): RDD[InternalRow] = { + throw new NotImplementedError("JUST FOR TESTING") + } + + override def buildReader( + sqlContext: SQLContext, + partitionSchema: StructType, + dataSchema: StructType, + filters: Seq[Filter], + options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = { + + // Record the arguments so they can be checked in the test case. 
+ LastArguments.partitionSchema = partitionSchema + LastArguments.dataSchema = dataSchema + LastArguments.filters = filters + LastArguments.options = options + + (file: PartitionedFile) => { Iterator.empty } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 026191528ede43ccbe77a4351f886253765c3eb6..f875b54cd664921946aa52c9077f9beee7ab13af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.execution.datasources.{LogicalRelation, Partition, PartitioningUtils, PartitionSpec} +import org.apache.spark.sql.execution.datasources.{LogicalRelation, PartitionDirectory => Partition, PartitioningUtils, PartitionSpec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.HadoopFsRelation import org.apache.spark.sql.test.SharedSQLContext diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 8f6cd66f1f6813c8426f222bb9caf1bc709f7ddd..c70510b4834d6da4874811dbf713911e50998b97 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -41,11 +41,11 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.execution.{datasources, FileRelation} -import org.apache.spark.sql.execution.datasources.{Partition => ParquetPartition, _} +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{DefaultSource, ParquetRelation} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.HiveNativeCommand -import org.apache.spark.sql.sources._ +import org.apache.spark.sql.sources.{HadoopFsRelation, HDFSFileCatalog} import org.apache.spark.sql.types._ private[hive] case class HiveSerDe( @@ -469,7 +469,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte parquetRelation.location.paths.map(_.toString).toSet == pathsInMetastore.toSet && logical.schema.sameType(metastoreSchema) && parquetRelation.partitionSpec == partitionSpecInMetastore.getOrElse { - PartitionSpec(StructType(Nil), Array.empty[datasources.Partition]) + PartitionSpec(StructType(Nil), Array.empty[datasources.PartitionDirectory]) } if (useCached) { @@ -499,7 +499,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte val values = InternalRow.fromSeq(p.getValues.asScala.zip(partitionColumnDataTypes).map { case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null) }) - ParquetPartition(values, location) + PartitionDirectory(values, location) } val partitionSpec = PartitionSpec(partitionSchema, partitions) @@ -753,7 +753,7 @@ class MetaStoreFileCatalog( hive: HiveContext, paths: Seq[Path], partitionSpecFromHive: PartitionSpec) - 
extends HDFSFileCatalog(hive, Map.empty, paths) { + extends HDFSFileCatalog(hive, Map.empty, paths, Some(partitionSpecFromHive.partitionColumns)) { override def getStatus(path: Path): Array[FileStatus] = { @@ -761,7 +761,7 @@ class MetaStoreFileCatalog( fs.listStatus(path) } - override def partitionSpec(schema: Option[StructType]): PartitionSpec = partitionSpecFromHive + override def partitionSpec(): PartitionSpec = partitionSpecFromHive } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index 614f9e05d76f97fd98ffc053201d266bab86b1f9..cbb6333336383033b17d49eaadb00f884ccf9b7a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -79,6 +79,7 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx) override def strategies: Seq[Strategy] = { ctx.experimental.extraStrategies ++ Seq( + FileSourceStrategy, DataSourceStrategy, HiveCommandStrategy(ctx), HiveDDLStrategy, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 8a39d95fc5677bbc0bc887f2d3c8dfcd35a8a4c1..ae041c5137f0f150ac39d61b46e270b359753686 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -111,7 +111,7 @@ private[sql] class DefaultSource extends FileFormat with DataSourceRegister { requiredColumns: Array[String], filters: Array[Filter], bucketSet: Option[BitSet], - inputFiles: Array[FileStatus], + inputFiles: Seq[FileStatus], broadcastedConf: Broadcast[SerializableConfiguration], options: Map[String, String]): RDD[InternalRow] = { val output = StructType(requiredColumns.map(dataSchema(_))).toAttributes @@ -221,7 +221,7 @@ private[orc] case class OrcTableScan( @transient sqlContext: SQLContext, attributes: Seq[Attribute], filters: Array[Filter], - @transient inputPaths: Array[FileStatus]) + @transient inputPaths: Seq[FileStatus]) extends Logging with HiveInspectors {
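// Illustrative sketch (not part of the patch): the FileScanRDD added above deliberately leaves
// compute() as NotImplementedError. One plausible shape for it is to apply the format-provided
// read function to every PartitionedFile in the task's FilePartition and concatenate the
// resulting iterators (getPartitions would likewise return filePartitions.toArray). The
// stand-alone stand-ins below use String where the real code uses InternalRow, and ignore
// locality, metrics and the prepending of partition values.
object FileScanSketch {
  case class PartitionedFile(partitionValues: String, filePath: String, start: Long, length: Long)
  case class FilePartition(index: Int, files: Seq[PartitionedFile])

  def compute(
      split: FilePartition,
      readFunction: PartitionedFile => Iterator[String]): Iterator[String] =
    split.files.iterator.flatMap(readFunction)

  def main(args: Array[String]): Unit = {
    val partition = FilePartition(0, Seq(
      PartitionedFile("p1=1", "/tmp/table/p1=1/file1", 0, 10),
      PartitionedFile("p1=1", "/tmp/table/p1=1/file2", 0, 4)))
    val reader = (f: PartitionedFile) =>
      Iterator(s"rows from ${f.filePath} bytes [${f.start}, ${f.start + f.length})")
    compute(partition, reader).foreach(println)
  }
}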