diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index d20b42de22706407bccfc5bd5673b589a80c7b0e..b42a52ebd2f169c3e7dcd990d735e455b9b708da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -446,7 +446,7 @@ class SQLContext(@transient val sparkContext: SparkContext) baseRelationToDataFrame(parquet.ParquetRelation2(path +: paths, Map.empty)(this)) } else { DataFrame(this, parquet.ParquetRelation( - paths.mkString(","), Some(sparkContext.hadoopConfiguration), this)) + (path +: paths).mkString(","), Some(sparkContext.hadoopConfiguration), this)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 28cd17fde46ab71087af3e34c9e19f40346e92b8..7dd8bea49b8a523cd408f8c789d9bf639b081092 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -647,6 +647,6 @@ private[parquet] object FileSystemHelper { sys.error("ERROR: attempting to append to set of Parquet files and found file" + s"that does not match name pattern: $other") case _ => 0 - }.reduceLeft((a, b) => if (a < b) b else a) + }.reduceOption(_ max _).getOrElse(0) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala index 538d774eb97eb0e645f958c8ef726593da44e483..d0856df8d4f43b4a3f3216787e4ed84384228440 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala @@ -23,8 +23,8 @@ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import scala.util.Try -import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.util +import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} import org.apache.spark.util.Utils /** @@ -37,7 +37,8 @@ import org.apache.spark.util.Utils trait ParquetTest { val sqlContext: SQLContext - import sqlContext._ + import sqlContext.implicits.{localSeqToDataFrameHolder, rddToDataFrameHolder} + import sqlContext.{conf, sparkContext} protected def configuration = sparkContext.hadoopConfiguration @@ -49,11 +50,11 @@ trait ParquetTest { */ protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(getConf(key)).toOption) - (keys, values).zipped.foreach(setConf) + val currentValues = keys.map(key => Try(conf.getConf(key)).toOption) + (keys, values).zipped.foreach(conf.setConf) try f finally { keys.zip(currentValues).foreach { - case (key, Some(value)) => setConf(key, value) + case (key, Some(value)) => conf.setConf(key, value) case (key, None) => conf.unsetConf(key) } } @@ -88,7 +89,6 @@ trait ParquetTest { protected def withParquetFile[T <: Product: ClassTag: TypeTag] (data: Seq[T]) (f: String => Unit): Unit = { - import sqlContext.implicits._ withTempPath { file => sparkContext.parallelize(data).toDF().saveAsParquetFile(file.getCanonicalPath) f(file.getCanonicalPath) @@ -102,14 +102,14 @@ trait ParquetTest { protected def withParquetRDD[T <: Product: ClassTag: TypeTag] (data: Seq[T]) (f: DataFrame => Unit): Unit = { - withParquetFile(data)(path => f(parquetFile(path))) + withParquetFile(data)(path => f(sqlContext.parquetFile(path))) } /** * Drops temporary table `tableName` after calling `f`. */ protected def withTempTable(tableName: String)(f: => Unit): Unit = { - try f finally dropTempTable(tableName) + try f finally sqlContext.dropTempTable(tableName) } /** @@ -125,4 +125,26 @@ trait ParquetTest { withTempTable(tableName)(f) } } + + protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( + data: Seq[T], path: File): Unit = { + data.toDF().save(path.getCanonicalPath, "org.apache.spark.sql.parquet", SaveMode.Overwrite) + } + + protected def makePartitionDir( + basePath: File, + defaultPartitionName: String, + partitionCols: (String, Any)*): File = { + val partNames = partitionCols.map { case (k, v) => + val valueString = if (v == null || v == "") defaultPartitionName else v.toString + s"$k=$valueString" + } + + val partDir = partNames.foldLeft(basePath) { (parent, child) => + new File(parent, child) + } + + assert(partDir.mkdirs(), s"Couldn't create directory $partDir") + partDir + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 3a9f0600617bef9be460b46397f317dbe70eb1a8..9279f5a903f55660b4652f471474609ba4ebe486 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -20,7 +20,7 @@ import java.io.IOException import java.lang.{Double => JDouble, Float => JFloat, Long => JLong} import java.math.{BigDecimal => JBigDecimal} import java.text.SimpleDateFormat -import java.util.{List => JList, Date} +import java.util.{Date, List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer @@ -34,8 +34,9 @@ import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat import org.apache.hadoop.mapreduce.{InputSplit, Job, JobContext} import parquet.filter2.predicate.FilterApi import parquet.format.converter.ParquetMetadataConverter -import parquet.hadoop.{ParquetInputFormat, _} +import parquet.hadoop.metadata.CompressionCodecName import parquet.hadoop.util.ContextUtil +import parquet.hadoop.{ParquetInputFormat, _} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.SparkHadoopUtil @@ -44,21 +45,36 @@ import org.apache.spark.rdd.{NewHadoopPartition, NewHadoopRDD, RDD} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.parquet.ParquetTypesConverter._ import org.apache.spark.sql.sources._ -import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLConf, SQLContext} import org.apache.spark.sql.types.{IntegerType, StructField, StructType, _} -import org.apache.spark.{Partition => SparkPartition, TaskContext, SerializableWritable, Logging, SparkException} - +import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext, SaveMode} +import org.apache.spark.{Logging, Partition => SparkPartition, SerializableWritable, SparkException, TaskContext} /** - * Allows creation of parquet based tables using the syntax - * `CREATE TEMPORARY TABLE ... USING org.apache.spark.sql.parquet`. Currently the only option - * required is `path`, which should be the location of a collection of, optionally partitioned, - * parquet files. + * Allows creation of Parquet based tables using the syntax: + * {{{ + * CREATE TEMPORARY TABLE ... USING org.apache.spark.sql.parquet OPTIONS (...) + * }}} + * + * Supported options include: + * + * - `path`: Required. When reading Parquet files, `path` should point to the location of the + * Parquet file(s). It can be either a single raw Parquet file, or a directory of Parquet files. + * In the latter case, this data source tries to discover partitioning information if the the + * directory is structured in the same style of Hive partitioned tables. When writing Parquet + * file, `path` should point to the destination folder. + * + * - `mergeSchema`: Optional. Indicates whether we should merge potentially different (but + * compatible) schemas stored in all Parquet part-files. + * + * - `partition.defaultName`: Optional. Partition name used when a value of a partition column is + * null or empty string. This is similar to the `hive.exec.default.partition.name` configuration + * in Hive. */ class DefaultSource extends RelationProvider with SchemaRelationProvider with CreatableRelationProvider { + private def checkPath(parameters: Map[String, String]): String = { parameters.getOrElse("path", sys.error("'path' must be specified for parquet tables.")) } @@ -70,6 +86,7 @@ class DefaultSource ParquetRelation2(Seq(checkPath(parameters)), parameters, None)(sqlContext) } + /** Returns a new base relation with the given parameters and schema. */ override def createRelation( sqlContext: SQLContext, parameters: Map[String, String], @@ -77,6 +94,7 @@ class DefaultSource ParquetRelation2(Seq(checkPath(parameters)), parameters, Some(schema))(sqlContext) } + /** Returns a new base relation with the given parameters and save given data into it. */ override def createRelation( sqlContext: SQLContext, mode: SaveMode, @@ -85,33 +103,19 @@ class DefaultSource val path = checkPath(parameters) val filesystemPath = new Path(path) val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - val doSave = if (fs.exists(filesystemPath)) { - mode match { - case SaveMode.Append => - sys.error(s"Append mode is not supported by ${this.getClass.getCanonicalName}") - case SaveMode.Overwrite => - fs.delete(filesystemPath, true) - true - case SaveMode.ErrorIfExists => - sys.error(s"path $path already exists.") - case SaveMode.Ignore => false - } - } else { - true + val doInsertion = (mode, fs.exists(filesystemPath)) match { + case (SaveMode.ErrorIfExists, true) => + sys.error(s"path $path already exists.") + case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) => + true + case (SaveMode.Ignore, exists) => + !exists } - val relation = if (doSave) { - // Only save data when the save mode is not ignore. - ParquetRelation.createEmpty( - path, - data.schema.toAttributes, - false, - sqlContext.sparkContext.hadoopConfiguration, - sqlContext) - - val createdRelation = createRelation(sqlContext, parameters, data.schema) - createdRelation.asInstanceOf[ParquetRelation2].insert(data, true) - + val relation = if (doInsertion) { + val createdRelation = + createRelation(sqlContext, parameters, data.schema).asInstanceOf[ParquetRelation2] + createdRelation.insert(data, overwrite = mode == SaveMode.Overwrite) createdRelation } else { // If the save mode is Ignore, we will just create the relation based on existing data. @@ -122,37 +126,31 @@ class DefaultSource } } -private[parquet] case class Partition(values: Row, path: String) +private[sql] case class Partition(values: Row, path: String) -private[parquet] case class PartitionSpec(partitionColumns: StructType, partitions: Seq[Partition]) +private[sql] case class PartitionSpec(partitionColumns: StructType, partitions: Seq[Partition]) /** * An alternative to [[ParquetRelation]] that plugs in using the data sources API. This class is - * currently not intended as a full replacement of the parquet support in Spark SQL though it is - * likely that it will eventually subsume the existing physical plan implementation. - * - * Compared with the current implementation, this class has the following notable differences: - * - * Partitioning: Partitions are auto discovered and must be in the form of directories `key=value/` - * located at `path`. Currently only a single partitioning column is supported and it must - * be an integer. This class supports both fully self-describing data, which contains the partition - * key, and data where the partition key is only present in the folder structure. The presence - * of the partitioning key in the data is also auto-detected. The `null` partition is not yet - * supported. + * intended as a full replacement of the Parquet support in Spark SQL. The old implementation will + * be deprecated and eventually removed once this version is proved to be stable enough. * - * Metadata: The metadata is automatically discovered by reading the first parquet file present. - * There is currently no support for working with files that have different schema. Additionally, - * when parquet metadata caching is turned on, the FileStatus objects for all data will be cached - * to improve the speed of interactive querying. When data is added to a table it must be dropped - * and recreated to pick up any changes. + * Compared with the old implementation, this class has the following notable differences: * - * Statistics: Statistics for the size of the table are automatically populated during metadata - * discovery. + * - Partitioning discovery: Hive style multi-level partitions are auto discovered. + * - Metadata discovery: Parquet is a format comes with schema evolving support. This data source + * can detect and merge schemas from all Parquet part-files as long as they are compatible. + * Also, metadata and [[FileStatus]]es are cached for better performance. + * - Statistics: Statistics for the size of the table are automatically populated during schema + * discovery. */ @DeveloperApi -case class ParquetRelation2 - (paths: Seq[String], parameters: Map[String, String], maybeSchema: Option[StructType] = None) - (@transient val sqlContext: SQLContext) +case class ParquetRelation2( + paths: Seq[String], + parameters: Map[String, String], + maybeSchema: Option[StructType] = None, + maybePartitionSpec: Option[PartitionSpec] = None)( + @transient val sqlContext: SQLContext) extends CatalystScan with InsertableRelation with SparkHadoopMapReduceUtil @@ -175,43 +173,90 @@ case class ParquetRelation2 override def equals(other: Any) = other match { case relation: ParquetRelation2 => + // If schema merging is required, we don't compare the actual schemas since they may evolve. + val schemaEquality = if (shouldMergeSchemas) { + shouldMergeSchemas == relation.shouldMergeSchemas + } else { + schema == relation.schema + } + paths.toSet == relation.paths.toSet && + schemaEquality && maybeMetastoreSchema == relation.maybeMetastoreSchema && - (shouldMergeSchemas == relation.shouldMergeSchemas || schema == relation.schema) + maybePartitionSpec == relation.maybePartitionSpec + case _ => false } private[sql] def sparkContext = sqlContext.sparkContext - @transient private val fs = FileSystem.get(sparkContext.hadoopConfiguration) - private class MetadataCache { + // `FileStatus` objects of all "_metadata" files. private var metadataStatuses: Array[FileStatus] = _ + + // `FileStatus` objects of all "_common_metadata" files. private var commonMetadataStatuses: Array[FileStatus] = _ + + // Parquet footer cache. private var footers: Map[FileStatus, Footer] = _ - private var parquetSchema: StructType = _ + // `FileStatus` objects of all data files (Parquet part-files). var dataStatuses: Array[FileStatus] = _ + + // Partition spec of this table, including names, data types, and values of each partition + // column, and paths of each partition. var partitionSpec: PartitionSpec = _ + + // Schema of the actual Parquet files, without partition columns discovered from partition + // directory paths. + var parquetSchema: StructType = _ + + // Schema of the whole table, including partition columns. var schema: StructType = _ - var dataSchemaIncludesPartitionKeys: Boolean = _ + // Indicates whether partition columns are also included in Parquet data file schema. If not, + // we need to fill in partition column values into read rows when scanning the table. + var partitionKeysIncludedInParquetSchema: Boolean = _ + + def prepareMetadata(path: Path, schema: StructType, conf: Configuration): Unit = { + conf.set( + ParquetOutputFormat.COMPRESSION, + ParquetRelation + .shortParquetCompressionCodecNames + .getOrElse( + sqlContext.conf.parquetCompressionCodec.toUpperCase, + CompressionCodecName.UNCOMPRESSED).name()) + + ParquetRelation.enableLogForwarding() + ParquetTypesConverter.writeMetaData(schema.toAttributes, path, conf) + } + + /** + * Refreshes `FileStatus`es, footers, partition spec, and table schema. + */ def refresh(): Unit = { - val baseStatuses = { - val statuses = paths.distinct.map(p => fs.getFileStatus(fs.makeQualified(new Path(p)))) - // Support either reading a collection of raw Parquet part-files, or a collection of folders - // containing Parquet files (e.g. partitioned Parquet table). - assert(statuses.forall(!_.isDir) || statuses.forall(_.isDir)) - statuses.toArray - } + val fs = FileSystem.get(sparkContext.hadoopConfiguration) + + // Support either reading a collection of raw Parquet part-files, or a collection of folders + // containing Parquet files (e.g. partitioned Parquet table). + val baseStatuses = paths.distinct.map { p => + val qualified = fs.makeQualified(new Path(p)) + + if (!fs.exists(qualified) && maybeSchema.isDefined) { + fs.mkdirs(qualified) + prepareMetadata(qualified, maybeSchema.get, sparkContext.hadoopConfiguration) + } + + fs.getFileStatus(qualified) + }.toArray + assert(baseStatuses.forall(!_.isDir) || baseStatuses.forall(_.isDir)) + // Lists `FileStatus`es of all leaf nodes (files) under all base directories. val leaves = baseStatuses.flatMap { f => - val statuses = SparkHadoopUtil.get.listLeafStatuses(fs, f.getPath).filter { f => + SparkHadoopUtil.get.listLeafStatuses(fs, f.getPath).filter { f => isSummaryFile(f.getPath) || !(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith(".")) } - assert(statuses.nonEmpty, s"${f.getPath} is an empty folder.") - statuses } dataStatuses = leaves.filterNot(f => isSummaryFile(f.getPath)) @@ -225,13 +270,14 @@ case class ParquetRelation2 f -> new Footer(f.getPath, parquetMetadata) }.seq.toMap - partitionSpec = { - val partitionDirs = dataStatuses + partitionSpec = maybePartitionSpec.getOrElse { + val partitionDirs = leaves .filterNot(baseStatuses.contains) .map(_.getPath.getParent) .distinct if (partitionDirs.nonEmpty) { + // Parses names and values of partition columns, and infer their data types. ParquetRelation2.parsePartitions(partitionDirs, defaultPartitionName) } else { // No partition directories found, makes an empty specification @@ -241,20 +287,22 @@ case class ParquetRelation2 parquetSchema = maybeSchema.getOrElse(readSchema()) - dataSchemaIncludesPartitionKeys = + partitionKeysIncludedInParquetSchema = isPartitioned && - partitionColumns.forall(f => metadataCache.parquetSchema.fieldNames.contains(f.name)) + partitionColumns.forall(f => parquetSchema.fieldNames.contains(f.name)) schema = { - val fullParquetSchema = if (dataSchemaIncludesPartitionKeys) { - metadataCache.parquetSchema + val fullRelationSchema = if (partitionKeysIncludedInParquetSchema) { + parquetSchema } else { - StructType(metadataCache.parquetSchema.fields ++ partitionColumns.fields) + StructType(parquetSchema.fields ++ partitionColumns.fields) } + // If this Parquet relation is converted from a Hive Metastore table, must reconcile case + // insensitivity issue and possible schema mismatch. maybeMetastoreSchema - .map(ParquetRelation2.mergeMetastoreParquetSchema(_, fullParquetSchema)) - .getOrElse(fullParquetSchema) + .map(ParquetRelation2.mergeMetastoreParquetSchema(_, fullRelationSchema)) + .getOrElse(fullRelationSchema) } } @@ -303,13 +351,17 @@ case class ParquetRelation2 @transient private val metadataCache = new MetadataCache metadataCache.refresh() - private def partitionColumns = metadataCache.partitionSpec.partitionColumns + def partitionSpec = metadataCache.partitionSpec - private def partitions = metadataCache.partitionSpec.partitions + def partitionColumns = metadataCache.partitionSpec.partitionColumns - private def isPartitioned = partitionColumns.nonEmpty + def partitions = metadataCache.partitionSpec.partitions - private def dataSchemaIncludesPartitionKeys = metadataCache.dataSchemaIncludesPartitionKeys + def isPartitioned = partitionColumns.nonEmpty + + private def partitionKeysIncludedInDataSchema = metadataCache.partitionKeysIncludedInParquetSchema + + private def parquetSchema = metadataCache.parquetSchema override def schema = metadataCache.schema @@ -412,18 +464,21 @@ case class ParquetRelation2 // When the data does not include the key and the key is requested then we must fill it in // based on information from the input split. - if (!dataSchemaIncludesPartitionKeys && partitionKeyLocations.nonEmpty) { + if (!partitionKeysIncludedInDataSchema && partitionKeyLocations.nonEmpty) { baseRDD.mapPartitionsWithInputSplit { case (split: ParquetInputSplit, iterator) => val partValues = selectedPartitions.collectFirst { case p if split.getPath.getParent.toString == p.path => p.values }.get + val requiredPartOrdinal = partitionKeyLocations.keys.toSeq + iterator.map { pair => val row = pair._2.asInstanceOf[SpecificMutableRow] var i = 0 - while (i < partValues.size) { + while (i < requiredPartOrdinal.size) { // TODO Avoids boxing cost here! - row.update(partitionKeyLocations(i), partValues(i)) + val partOrdinal = requiredPartOrdinal(i) + row.update(partitionKeyLocations(partOrdinal), partValues(partOrdinal)) i += 1 } row @@ -457,6 +512,8 @@ case class ParquetRelation2 } override def insert(data: DataFrame, overwrite: Boolean): Unit = { + assert(paths.size == 1, s"Can't write to multiple destinations: ${paths.mkString(",")}") + // TODO: currently we do not check whether the "schema"s are compatible // That means if one first creates a table and then INSERTs data with // and incompatible schema the execution will fail. It would be nice @@ -464,7 +521,7 @@ case class ParquetRelation2 // before calling execute(). val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val writeSupport = if (schema.map(_.dataType).forall(_.isPrimitive)) { + val writeSupport = if (parquetSchema.map(_.dataType).forall(_.isPrimitive)) { log.debug("Initializing MutableRowWriteSupport") classOf[MutableRowWriteSupport] } else { @@ -474,7 +531,7 @@ case class ParquetRelation2 ParquetOutputFormat.setWriteSupportClass(job, writeSupport) val conf = ContextUtil.getConfiguration(job) - RowWriteSupport.setSchema(schema.toAttributes, conf) + RowWriteSupport.setSchema(data.schema.toAttributes, conf) val destinationPath = new Path(paths.head) @@ -544,14 +601,12 @@ object ParquetRelation2 { // Whether we should merge schemas collected from all Parquet part-files. val MERGE_SCHEMA = "mergeSchema" - // Hive Metastore schema, passed in when the Parquet relation is converted from Metastore - val METASTORE_SCHEMA = "metastoreSchema" - - // Default partition name to use when the partition column value is null or empty string + // Default partition name to use when the partition column value is null or empty string. val DEFAULT_PARTITION_NAME = "partition.defaultName" - // When true, the Parquet data source caches Parquet metadata for performance - val CACHE_METADATA = "cacheMetadata" + // Hive Metastore schema, used when converting Metastore Parquet tables. This option is only used + // internally. + private[sql] val METASTORE_SCHEMA = "metastoreSchema" private[parquet] def readSchema(footers: Seq[Footer], sqlContext: SQLContext): StructType = { footers.map { footer => @@ -579,6 +634,15 @@ object ParquetRelation2 { } } + /** + * Reconciles Hive Metastore case insensitivity issue and data type conflicts between Metastore + * schema and Parquet schema. + * + * Hive doesn't retain case information, while Parquet is case sensitive. On the other hand, the + * schema read from Parquet files may be incomplete (e.g. older versions of Parquet doesn't + * distinguish binary and string). This method generates a correct schema by merging Metastore + * schema data types and Parquet schema field names. + */ private[parquet] def mergeMetastoreParquetSchema( metastoreSchema: StructType, parquetSchema: StructType): StructType = { @@ -719,16 +783,15 @@ object ParquetRelation2 { * }}} */ private[parquet] def resolvePartitions(values: Seq[PartitionValues]): Seq[PartitionValues] = { - val distinctColNamesOfPartitions = values.map(_.columnNames).distinct - val columnCount = values.head.columnNames.size - // Column names of all partitions must match - assert(distinctColNamesOfPartitions.size == 1, { - val list = distinctColNamesOfPartitions.mkString("\t", "\n", "") + val distinctPartitionsColNames = values.map(_.columnNames).distinct + assert(distinctPartitionsColNames.size == 1, { + val list = distinctPartitionsColNames.mkString("\t", "\n", "") s"Conflicting partition column names detected:\n$list" }) // Resolves possible type conflicts for each column + val columnCount = values.head.columnNames.size val resolvedValues = (0 until columnCount).map { i => resolveTypeConflicts(values.map(_.literals(i))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index f8117c21773ae1e361c379f5fdac4757d5b6f426..eb2d5f25290b1b14857207992079c9e7a3fe2bb7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.parquet +import org.scalatest.BeforeAndAfterAll import parquet.filter2.predicate.Operators._ import parquet.filter2.predicate.{FilterPredicate, Operators} @@ -40,7 +41,7 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, SQLConf} * 2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred * data type is nullable. */ -class ParquetFilterSuite extends QueryTest with ParquetTest { +class ParquetFilterSuiteBase extends QueryTest with ParquetTest { val sqlContext = TestSQLContext private def checkFilterPredicate( @@ -112,210 +113,224 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd) } - def run(prefix: String): Unit = { - test(s"$prefix: filter pushdown - boolean") { - withParquetRDD((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], Seq(Row(true), Row(false))) + test("filter pushdown - boolean") { + withParquetRDD((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], Seq(Row(true), Row(false))) - checkFilterPredicate('_1 === true, classOf[Eq[_]], true) - checkFilterPredicate('_1 !== true, classOf[NotEq[_]], false) - } + checkFilterPredicate('_1 === true, classOf[Eq[_]], true) + checkFilterPredicate('_1 !== true, classOf[NotEq[_]], false) } + } - test(s"$prefix: filter pushdown - short") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit rdd => - checkFilterPredicate(Cast('_1, IntegerType) === 1, classOf[Eq[_]], 1) - checkFilterPredicate( - Cast('_1, IntegerType) !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - - checkFilterPredicate(Cast('_1, IntegerType) < 2, classOf[Lt[_]], 1) - checkFilterPredicate(Cast('_1, IntegerType) > 3, classOf[Gt[_]], 4) - checkFilterPredicate(Cast('_1, IntegerType) <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate(Cast('_1, IntegerType) >= 4, classOf[GtEq[_]], 4) - - checkFilterPredicate(Literal(1) === Cast('_1, IntegerType), classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > Cast('_1, IntegerType), classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < Cast('_1, IntegerType), classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= Cast('_1, IntegerType), classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= Cast('_1, IntegerType), classOf[GtEq[_]], 4) - - checkFilterPredicate(!(Cast('_1, IntegerType) < 4), classOf[GtEq[_]], 4) - checkFilterPredicate( - Cast('_1, IntegerType) > 2 && Cast('_1, IntegerType) < 4, classOf[Operators.And], 3) - checkFilterPredicate( - Cast('_1, IntegerType) < 2 || Cast('_1, IntegerType) > 3, - classOf[Operators.Or], - Seq(Row(1), Row(4))) - } + test("filter pushdown - short") { + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit rdd => + checkFilterPredicate(Cast('_1, IntegerType) === 1, classOf[Eq[_]], 1) + checkFilterPredicate( + Cast('_1, IntegerType) !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate(Cast('_1, IntegerType) < 2, classOf[Lt[_]], 1) + checkFilterPredicate(Cast('_1, IntegerType) > 3, classOf[Gt[_]], 4) + checkFilterPredicate(Cast('_1, IntegerType) <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate(Cast('_1, IntegerType) >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === Cast('_1, IntegerType), classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > Cast('_1, IntegerType), classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < Cast('_1, IntegerType), classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= Cast('_1, IntegerType), classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= Cast('_1, IntegerType), classOf[GtEq[_]], 4) + + checkFilterPredicate(!(Cast('_1, IntegerType) < 4), classOf[GtEq[_]], 4) + checkFilterPredicate( + Cast('_1, IntegerType) > 2 && Cast('_1, IntegerType) < 4, classOf[Operators.And], 3) + checkFilterPredicate( + Cast('_1, IntegerType) < 2 || Cast('_1, IntegerType) > 3, + classOf[Operators.Or], + Seq(Row(1), Row(4))) } + } - test(s"$prefix: filter pushdown - integer") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + test("filter pushdown - integer") { + withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) - } + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } + } - test(s"$prefix: filter pushdown - long") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + test("filter pushdown - long") { + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) - } + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } + } - test(s"$prefix: filter pushdown - float") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + test("filter pushdown - float") { + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) - } + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } + } - test(s"$prefix: filter pushdown - double") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + test("filter pushdown - double") { + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) - } + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } + } - test(s"$prefix: filter pushdown - string") { - withParquetRDD((1 to 4).map(i => Tuple1(i.toString))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate( - '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString))) - - checkFilterPredicate('_1 === "1", classOf[Eq[_]], "1") - checkFilterPredicate( - '_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) - - checkFilterPredicate('_1 < "2", classOf[Lt[_]], "1") - checkFilterPredicate('_1 > "3", classOf[Gt[_]], "4") - checkFilterPredicate('_1 <= "1", classOf[LtEq[_]], "1") - checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4") - - checkFilterPredicate(Literal("1") === '_1, classOf[Eq[_]], "1") - checkFilterPredicate(Literal("2") > '_1, classOf[Lt[_]], "1") - checkFilterPredicate(Literal("3") < '_1, classOf[Gt[_]], "4") - checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1") - checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") - - checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") - checkFilterPredicate('_1 > "2" && '_1 < "4", classOf[Operators.And], "3") - checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) - } + test("filter pushdown - string") { + withParquetRDD((1 to 4).map(i => Tuple1(i.toString))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate( + '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString))) + + checkFilterPredicate('_1 === "1", classOf[Eq[_]], "1") + checkFilterPredicate( + '_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) + + checkFilterPredicate('_1 < "2", classOf[Lt[_]], "1") + checkFilterPredicate('_1 > "3", classOf[Gt[_]], "4") + checkFilterPredicate('_1 <= "1", classOf[LtEq[_]], "1") + checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4") + + checkFilterPredicate(Literal("1") === '_1, classOf[Eq[_]], "1") + checkFilterPredicate(Literal("2") > '_1, classOf[Lt[_]], "1") + checkFilterPredicate(Literal("3") < '_1, classOf[Gt[_]], "4") + checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1") + checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") + + checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") + checkFilterPredicate('_1 > "2" && '_1 < "4", classOf[Operators.And], "3") + checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) } + } - test(s"$prefix: filter pushdown - binary") { - implicit class IntToBinary(int: Int) { - def b: Array[Byte] = int.toString.getBytes("UTF-8") - } + test("filter pushdown - binary") { + implicit class IntToBinary(int: Int) { + def b: Array[Byte] = int.toString.getBytes("UTF-8") + } - withParquetRDD((1 to 4).map(i => Tuple1(i.b))) { implicit rdd => - checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq[_]], 1.b) + withParquetRDD((1 to 4).map(i => Tuple1(i.b))) { implicit rdd => + checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq[_]], 1.b) - checkBinaryFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkBinaryFilterPredicate( - '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.b)).toSeq) + checkBinaryFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkBinaryFilterPredicate( + '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.b)).toSeq) - checkBinaryFilterPredicate( - '_1 !== 1.b, classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.b)).toSeq) + checkBinaryFilterPredicate( + '_1 !== 1.b, classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.b)).toSeq) - checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt[_]], 1.b) - checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt[_]], 4.b) - checkBinaryFilterPredicate('_1 <= 1.b, classOf[LtEq[_]], 1.b) - checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b) + checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt[_]], 1.b) + checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt[_]], 4.b) + checkBinaryFilterPredicate('_1 <= 1.b, classOf[LtEq[_]], 1.b) + checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b) - checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq[_]], 1.b) - checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt[_]], 1.b) - checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt[_]], 4.b) - checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) - checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) + checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq[_]], 1.b) + checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt[_]], 1.b) + checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt[_]], 4.b) + checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) + checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) - checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b) - checkBinaryFilterPredicate('_1 > 2.b && '_1 < 4.b, classOf[Operators.And], 3.b) - checkBinaryFilterPredicate( - '_1 < 2.b || '_1 > 3.b, classOf[Operators.Or], Seq(Row(1.b), Row(4.b))) - } + checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b) + checkBinaryFilterPredicate('_1 > 2.b && '_1 < 4.b, classOf[Operators.And], 3.b) + checkBinaryFilterPredicate( + '_1 < 2.b || '_1 > 3.b, classOf[Operators.Or], Seq(Row(1.b), Row(4.b))) } } +} + +class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll { + val originalConf = sqlContext.conf.parquetUseDataSourceApi + + override protected def beforeAll(): Unit = { + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + } + + override protected def afterAll(): Unit = { + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } +} + +class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll { + val originalConf = sqlContext.conf.parquetUseDataSourceApi - withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") { - run("Parquet data source enabled") + override protected def beforeAll(): Unit = { + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") } - withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "false") { - run("Parquet data source disabled") + override protected def afterAll(): Unit = { + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index c306330818c0a77f8fe3972982b04aa9b20b0cec..208f35761b8070f54ceedfec71ab551e37208fbb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -21,6 +21,9 @@ import scala.collection.JavaConversions._ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.scalatest.BeforeAndAfterAll import parquet.example.data.simple.SimpleGroup import parquet.example.data.{Group, GroupWriter} import parquet.hadoop.api.WriteSupport @@ -30,16 +33,13 @@ import parquet.hadoop.{ParquetFileWriter, ParquetWriter} import parquet.io.api.RecordConsumer import parquet.schema.{MessageType, MessageTypeParser} -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf} -import org.apache.spark.sql.functions._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types.DecimalType +import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf, SaveMode} // Write support class for nested groups: ParquetWriter initializes GroupWriteSupport // with an empty configuration (it is after all not intended to be used in this way?) @@ -64,10 +64,11 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS /** * A test suite that tests basic Parquet I/O. */ -class ParquetIOSuite extends QueryTest with ParquetTest { - +class ParquetIOSuiteBase extends QueryTest with ParquetTest { val sqlContext = TestSQLContext + import sqlContext.implicits.localSeqToDataFrameHolder + /** * Writes `data` to a Parquet file, reads it back and check file contents. */ @@ -75,229 +76,281 @@ class ParquetIOSuite extends QueryTest with ParquetTest { withParquetRDD(data)(r => checkAnswer(r, data.map(Row.fromTuple))) } - def run(prefix: String): Unit = { - test(s"$prefix: basic data types (without binary)") { - val data = (1 to 4).map { i => - (i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) - } - checkParquetFile(data) + test("basic data types (without binary)") { + val data = (1 to 4).map { i => + (i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) } + checkParquetFile(data) + } - test(s"$prefix: raw binary") { - val data = (1 to 4).map(i => Tuple1(Array.fill(3)(i.toByte))) - withParquetRDD(data) { rdd => - assertResult(data.map(_._1.mkString(",")).sorted) { - rdd.collect().map(_.getAs[Array[Byte]](0).mkString(",")).sorted - } + test("raw binary") { + val data = (1 to 4).map(i => Tuple1(Array.fill(3)(i.toByte))) + withParquetRDD(data) { rdd => + assertResult(data.map(_._1.mkString(",")).sorted) { + rdd.collect().map(_.getAs[Array[Byte]](0).mkString(",")).sorted } } + } - test(s"$prefix: string") { - val data = (1 to 4).map(i => Tuple1(i.toString)) - // Property spark.sql.parquet.binaryAsString shouldn't affect Parquet files written by Spark SQL - // as we store Spark SQL schema in the extra metadata. - withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "false")(checkParquetFile(data)) - withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "true")(checkParquetFile(data)) - } + test("string") { + val data = (1 to 4).map(i => Tuple1(i.toString)) + // Property spark.sql.parquet.binaryAsString shouldn't affect Parquet files written by Spark SQL + // as we store Spark SQL schema in the extra metadata. + withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "false")(checkParquetFile(data)) + withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "true")(checkParquetFile(data)) + } - test(s"$prefix: fixed-length decimals") { - - def makeDecimalRDD(decimal: DecimalType): DataFrame = - sparkContext - .parallelize(0 to 1000) - .map(i => Tuple1(i / 100.0)) - .toDF - // Parquet doesn't allow column names with spaces, have to add an alias here - .select($"_1" cast decimal as "dec") - - for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { - withTempPath { dir => - val data = makeDecimalRDD(DecimalType(precision, scale)) - data.saveAsParquetFile(dir.getCanonicalPath) - checkAnswer(parquetFile(dir.getCanonicalPath), data.collect().toSeq) - } + test("fixed-length decimals") { + + def makeDecimalRDD(decimal: DecimalType): DataFrame = + sparkContext + .parallelize(0 to 1000) + .map(i => Tuple1(i / 100.0)) + .toDF + // Parquet doesn't allow column names with spaces, have to add an alias here + .select($"_1" cast decimal as "dec") + + for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { + withTempPath { dir => + val data = makeDecimalRDD(DecimalType(precision, scale)) + data.saveAsParquetFile(dir.getCanonicalPath) + checkAnswer(parquetFile(dir.getCanonicalPath), data.collect().toSeq) } + } - // Decimals with precision above 18 are not yet supported - intercept[RuntimeException] { - withTempPath { dir => - makeDecimalRDD(DecimalType(19, 10)).saveAsParquetFile(dir.getCanonicalPath) - parquetFile(dir.getCanonicalPath).collect() - } + // Decimals with precision above 18 are not yet supported + intercept[RuntimeException] { + withTempPath { dir => + makeDecimalRDD(DecimalType(19, 10)).saveAsParquetFile(dir.getCanonicalPath) + parquetFile(dir.getCanonicalPath).collect() } + } - // Unlimited-length decimals are not yet supported - intercept[RuntimeException] { - withTempPath { dir => - makeDecimalRDD(DecimalType.Unlimited).saveAsParquetFile(dir.getCanonicalPath) - parquetFile(dir.getCanonicalPath).collect() - } + // Unlimited-length decimals are not yet supported + intercept[RuntimeException] { + withTempPath { dir => + makeDecimalRDD(DecimalType.Unlimited).saveAsParquetFile(dir.getCanonicalPath) + parquetFile(dir.getCanonicalPath).collect() } } + } + + test("map") { + val data = (1 to 4).map(i => Tuple1(Map(i -> s"val_$i"))) + checkParquetFile(data) + } + + test("array") { + val data = (1 to 4).map(i => Tuple1(Seq(i, i + 1))) + checkParquetFile(data) + } - test(s"$prefix: map") { - val data = (1 to 4).map(i => Tuple1(Map(i -> s"val_$i"))) - checkParquetFile(data) + test("struct") { + val data = (1 to 4).map(i => Tuple1((i, s"val_$i"))) + withParquetRDD(data) { rdd => + // Structs are converted to `Row`s + checkAnswer(rdd, data.map { case Tuple1(struct) => + Row(Row(struct.productIterator.toSeq: _*)) + }) } + } - test(s"$prefix: array") { - val data = (1 to 4).map(i => Tuple1(Seq(i, i + 1))) - checkParquetFile(data) + test("nested struct with array of array as field") { + val data = (1 to 4).map(i => Tuple1((i, Seq(Seq(s"val_$i"))))) + withParquetRDD(data) { rdd => + // Structs are converted to `Row`s + checkAnswer(rdd, data.map { case Tuple1(struct) => + Row(Row(struct.productIterator.toSeq: _*)) + }) } + } - test(s"$prefix: struct") { - val data = (1 to 4).map(i => Tuple1((i, s"val_$i"))) - withParquetRDD(data) { rdd => - // Structs are converted to `Row`s - checkAnswer(rdd, data.map { case Tuple1(struct) => - Row(Row(struct.productIterator.toSeq: _*)) - }) - } + test("nested map with struct as value type") { + val data = (1 to 4).map(i => Tuple1(Map(i -> (i, s"val_$i")))) + withParquetRDD(data) { rdd => + checkAnswer(rdd, data.map { case Tuple1(m) => + Row(m.mapValues(struct => Row(struct.productIterator.toSeq: _*))) + }) } + } - test(s"$prefix: nested struct with array of array as field") { - val data = (1 to 4).map(i => Tuple1((i, Seq(Seq(s"val_$i"))))) - withParquetRDD(data) { rdd => - // Structs are converted to `Row`s - checkAnswer(rdd, data.map { case Tuple1(struct) => - Row(Row(struct.productIterator.toSeq: _*)) - }) - } + test("nulls") { + val allNulls = ( + null.asInstanceOf[java.lang.Boolean], + null.asInstanceOf[Integer], + null.asInstanceOf[java.lang.Long], + null.asInstanceOf[java.lang.Float], + null.asInstanceOf[java.lang.Double]) + + withParquetRDD(allNulls :: Nil) { rdd => + val rows = rdd.collect() + assert(rows.size === 1) + assert(rows.head === Row(Seq.fill(5)(null): _*)) } + } - test(s"$prefix: nested map with struct as value type") { - val data = (1 to 4).map(i => Tuple1(Map(i -> (i, s"val_$i")))) - withParquetRDD(data) { rdd => - checkAnswer(rdd, data.map { case Tuple1(m) => - Row(m.mapValues(struct => Row(struct.productIterator.toSeq: _*))) - }) - } + test("nones") { + val allNones = ( + None.asInstanceOf[Option[Int]], + None.asInstanceOf[Option[Long]], + None.asInstanceOf[Option[String]]) + + withParquetRDD(allNones :: Nil) { rdd => + val rows = rdd.collect() + assert(rows.size === 1) + assert(rows.head === Row(Seq.fill(3)(null): _*)) } + } - test(s"$prefix: nulls") { - val allNulls = ( - null.asInstanceOf[java.lang.Boolean], - null.asInstanceOf[Integer], - null.asInstanceOf[java.lang.Long], - null.asInstanceOf[java.lang.Float], - null.asInstanceOf[java.lang.Double]) - - withParquetRDD(allNulls :: Nil) { rdd => - val rows = rdd.collect() - assert(rows.size === 1) - assert(rows.head === Row(Seq.fill(5)(null): _*)) - } + test("compression codec") { + def compressionCodecFor(path: String) = { + val codecs = ParquetTypesConverter + .readMetaData(new Path(path), Some(configuration)) + .getBlocks + .flatMap(_.getColumns) + .map(_.getCodec.name()) + .distinct + + assert(codecs.size === 1) + codecs.head } - test(s"$prefix: nones") { - val allNones = ( - None.asInstanceOf[Option[Int]], - None.asInstanceOf[Option[Long]], - None.asInstanceOf[Option[String]]) + val data = (0 until 10).map(i => (i, i.toString)) - withParquetRDD(allNones :: Nil) { rdd => - val rows = rdd.collect() - assert(rows.size === 1) - assert(rows.head === Row(Seq.fill(3)(null): _*)) + def checkCompressionCodec(codec: CompressionCodecName): Unit = { + withSQLConf(SQLConf.PARQUET_COMPRESSION -> codec.name()) { + withParquetFile(data) { path => + assertResult(conf.parquetCompressionCodec.toUpperCase) { + compressionCodecFor(path) + } + } } } - test(s"$prefix: compression codec") { - def compressionCodecFor(path: String) = { - val codecs = ParquetTypesConverter - .readMetaData(new Path(path), Some(configuration)) - .getBlocks - .flatMap(_.getColumns) - .map(_.getCodec.name()) - .distinct - - assert(codecs.size === 1) - codecs.head - } + // Checks default compression codec + checkCompressionCodec(CompressionCodecName.fromConf(conf.parquetCompressionCodec)) - val data = (0 until 10).map(i => (i, i.toString)) + checkCompressionCodec(CompressionCodecName.UNCOMPRESSED) + checkCompressionCodec(CompressionCodecName.GZIP) + checkCompressionCodec(CompressionCodecName.SNAPPY) + } - def checkCompressionCodec(codec: CompressionCodecName): Unit = { - withSQLConf(SQLConf.PARQUET_COMPRESSION -> codec.name()) { - withParquetFile(data) { path => - assertResult(conf.parquetCompressionCodec.toUpperCase) { - compressionCodecFor(path) - } - } - } + test("read raw Parquet file") { + def makeRawParquetFile(path: Path): Unit = { + val schema = MessageTypeParser.parseMessageType( + """ + |message root { + | required boolean _1; + | required int32 _2; + | required int64 _3; + | required float _4; + | required double _5; + |} + """.stripMargin) + + val writeSupport = new TestGroupWriteSupport(schema) + val writer = new ParquetWriter[Group](path, writeSupport) + + (0 until 10).foreach { i => + val record = new SimpleGroup(schema) + record.add(0, i % 2 == 0) + record.add(1, i) + record.add(2, i.toLong) + record.add(3, i.toFloat) + record.add(4, i.toDouble) + writer.write(record) } - // Checks default compression codec - checkCompressionCodec(CompressionCodecName.fromConf(conf.parquetCompressionCodec)) + writer.close() + } - checkCompressionCodec(CompressionCodecName.UNCOMPRESSED) - checkCompressionCodec(CompressionCodecName.GZIP) - checkCompressionCodec(CompressionCodecName.SNAPPY) + withTempDir { dir => + val path = new Path(dir.toURI.toString, "part-r-0.parquet") + makeRawParquetFile(path) + checkAnswer(parquetFile(path.toString), (0 until 10).map { i => + Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) + }) } + } - test(s"$prefix: read raw Parquet file") { - def makeRawParquetFile(path: Path): Unit = { - val schema = MessageTypeParser.parseMessageType( - """ - |message root { - | required boolean _1; - | required int32 _2; - | required int64 _3; - | required float _4; - | required double _5; - |} - """.stripMargin) - - val writeSupport = new TestGroupWriteSupport(schema) - val writer = new ParquetWriter[Group](path, writeSupport) - - (0 until 10).foreach { i => - val record = new SimpleGroup(schema) - record.add(0, i % 2 == 0) - record.add(1, i) - record.add(2, i.toLong) - record.add(3, i.toFloat) - record.add(4, i.toDouble) - writer.write(record) - } + test("write metadata") { + withTempPath { file => + val path = new Path(file.toURI.toString) + val fs = FileSystem.getLocal(configuration) + val attributes = ScalaReflection.attributesFor[(Int, String)] + ParquetTypesConverter.writeMetaData(attributes, path, configuration) - writer.close() - } + assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_COMMON_METADATA_FILE))) + assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE))) - withTempDir { dir => - val path = new Path(dir.toURI.toString, "part-r-0.parquet") - makeRawParquetFile(path) - checkAnswer(parquetFile(path.toString), (0 until 10).map { i => - Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) - }) - } + val metaData = ParquetTypesConverter.readMetaData(path, Some(configuration)) + val actualSchema = metaData.getFileMetaData.getSchema + val expectedSchema = ParquetTypesConverter.convertFromAttributes(attributes) + + actualSchema.checkContains(expectedSchema) + expectedSchema.checkContains(actualSchema) } + } - test(s"$prefix: write metadata") { - withTempPath { file => - val path = new Path(file.toURI.toString) - val fs = FileSystem.getLocal(configuration) - val attributes = ScalaReflection.attributesFor[(Int, String)] - ParquetTypesConverter.writeMetaData(attributes, path, configuration) + test("save - overwrite") { + withParquetFile((1 to 10).map(i => (i, i.toString))) { file => + val newData = (11 to 20).map(i => (i, i.toString)) + newData.toDF().save("org.apache.spark.sql.parquet", SaveMode.Overwrite, Map("path" -> file)) + checkAnswer(parquetFile(file), newData.map(Row.fromTuple)) + } + } - assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_COMMON_METADATA_FILE))) - assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE))) + test("save - ignore") { + val data = (1 to 10).map(i => (i, i.toString)) + withParquetFile(data) { file => + val newData = (11 to 20).map(i => (i, i.toString)) + newData.toDF().save("org.apache.spark.sql.parquet", SaveMode.Ignore, Map("path" -> file)) + checkAnswer(parquetFile(file), data.map(Row.fromTuple)) + } + } - val metaData = ParquetTypesConverter.readMetaData(path, Some(configuration)) - val actualSchema = metaData.getFileMetaData.getSchema - val expectedSchema = ParquetTypesConverter.convertFromAttributes(attributes) + test("save - throw") { + val data = (1 to 10).map(i => (i, i.toString)) + withParquetFile(data) { file => + val newData = (11 to 20).map(i => (i, i.toString)) + val errorMessage = intercept[Throwable] { + newData.toDF().save( + "org.apache.spark.sql.parquet", SaveMode.ErrorIfExists, Map("path" -> file)) + }.getMessage + assert(errorMessage.contains("already exists")) + } + } - actualSchema.checkContains(expectedSchema) - expectedSchema.checkContains(actualSchema) - } + test("save - append") { + val data = (1 to 10).map(i => (i, i.toString)) + withParquetFile(data) { file => + val newData = (11 to 20).map(i => (i, i.toString)) + newData.toDF().save("org.apache.spark.sql.parquet", SaveMode.Append, Map("path" -> file)) + checkAnswer(parquetFile(file), (data ++ newData).map(Row.fromTuple)) } } +} + +class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { + val originalConf = sqlContext.conf.parquetUseDataSourceApi + + override protected def beforeAll(): Unit = { + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + } + + override protected def afterAll(): Unit = { + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } +} + +class ParquetDataSourceOffIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { + val originalConf = sqlContext.conf.parquetUseDataSourceApi - withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") { - run("Parquet data source enabled") + override protected def beforeAll(): Unit = { + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") } - withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "false") { - run("Parquet data source disabled") + override protected def afterAll(): Unit = { + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index ae606d11a8f6809dae30f5c83ccad5d3e9a099d7..3bf0116c8f7e91eede5d7d95e11d7a56a55e3dd3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -19,17 +19,24 @@ package org.apache.spark.sql.parquet import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.fs.Path -import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.parquet.ParquetRelation2._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{QueryTest, Row, SQLContext} -class ParquetPartitionDiscoverySuite extends FunSuite with ParquetTest { +// The data where the partitioning key exists only in the directory structure. +case class ParquetData(intField: Int, stringField: String) + +// The data that also includes the partitioning key +case class ParquetDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) + +class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { override val sqlContext: SQLContext = TestSQLContext + import sqlContext._ + val defaultPartitionName = "__NULL__" test("column type inference") { @@ -112,6 +119,17 @@ class ParquetPartitionDiscoverySuite extends FunSuite with ParquetTest { Partition(Row(10, "20"), "hdfs://host:9000/path/a=10/b=20"), Partition(Row(10.5, "hello"), "hdfs://host:9000/path/a=10.5/b=hello")))) + check(Seq( + s"hdfs://host:9000/path/a=10/b=20", + s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", IntegerType), + StructField("b", StringType))), + Seq( + Partition(Row(10, "20"), s"hdfs://host:9000/path/a=10/b=20"), + Partition(Row(null, "hello"), s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello")))) + check(Seq( s"hdfs://host:9000/path/a=10/b=$defaultPartitionName", s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName"), @@ -123,4 +141,182 @@ class ParquetPartitionDiscoverySuite extends FunSuite with ParquetTest { Partition(Row(10, null), s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"), Partition(Row(10.5, null), s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName")))) } + + test("read partitioned table - normal case") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } { + makeParquetFile( + (1 to 10).map(i => ParquetData(i, i.toString)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + parquetFile(base.getCanonicalPath).registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } yield Row(i, i.toString, pi, ps)) + + checkAnswer( + sql("SELECT intField, pi FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + _ <- Seq("foo", "bar") + } yield Row(i, pi)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi = 1"), + for { + i <- 1 to 10 + ps <- Seq("foo", "bar") + } yield Row(i, i.toString, 1, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps = 'foo'"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, i.toString, pi, "foo")) + } + } + } + + test("read partitioned table - partition key included in Parquet file") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } { + makeParquetFile( + (1 to 10).map(i => ParquetDataWithKey(i, pi, i.toString, ps)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + parquetFile(base.getCanonicalPath).registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } yield Row(i, pi, i.toString, ps)) + + checkAnswer( + sql("SELECT intField, pi FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + _ <- Seq("foo", "bar") + } yield Row(i, pi)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi = 1"), + for { + i <- 1 to 10 + ps <- Seq("foo", "bar") + } yield Row(i, 1, i.toString, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps = 'foo'"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, pi, i.toString, "foo")) + } + } + } + + test("read partitioned table - with nulls") { + withTempDir { base => + for { + // Must be `Integer` rather than `Int` here. `null.asInstanceOf[Int]` results in a zero... + pi <- Seq(1, null.asInstanceOf[Integer]) + ps <- Seq("foo", null.asInstanceOf[String]) + } { + makeParquetFile( + (1 to 10).map(i => ParquetData(i, i.toString)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + val parquetRelation = load( + "org.apache.spark.sql.parquet", + Map( + "path" -> base.getCanonicalPath, + ParquetRelation2.DEFAULT_PARTITION_NAME -> defaultPartitionName)) + + parquetRelation.registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, null.asInstanceOf[Integer]) + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, i.toString, pi, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi IS NULL"), + for { + i <- 1 to 10 + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, i.toString, null, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps IS NULL"), + for { + i <- 1 to 10 + pi <- Seq(1, null.asInstanceOf[Integer]) + } yield Row(i, i.toString, pi, null)) + } + } + } + + test("read partitioned table - with nulls and partition keys are included in Parquet file") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", null.asInstanceOf[String]) + } { + makeParquetFile( + (1 to 10).map(i => ParquetDataWithKey(i, pi, i.toString, ps)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + val parquetRelation = load( + "org.apache.spark.sql.parquet", + Map( + "path" -> base.getCanonicalPath, + ParquetRelation2.DEFAULT_PARTITION_NAME -> defaultPartitionName)) + + parquetRelation.registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, pi, i.toString, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps IS NULL"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, pi, i.toString, null)) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index cba06835f9a6126acf91bec4bc1e27c52d8f1505..d0665450cd766db6a9f1c8bce74236c1a4d87a9d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -17,103 +17,120 @@ package org.apache.spark.sql.parquet +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.{SQLConf, QueryTest} import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.{QueryTest, SQLConf} /** * A test suite that tests various Parquet queries. */ -class ParquetQuerySuite extends QueryTest with ParquetTest { +class ParquetQuerySuiteBase extends QueryTest with ParquetTest { val sqlContext = TestSQLContext - def run(prefix: String): Unit = { - test(s"$prefix: simple projection") { - withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { - checkAnswer(sql("SELECT _1 FROM t"), (0 until 10).map(Row.apply(_))) - } + test("simple projection") { + withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { + checkAnswer(sql("SELECT _1 FROM t"), (0 until 10).map(Row.apply(_))) } + } - test(s"$prefix: appending") { - val data = (0 until 10).map(i => (i, i.toString)) - withParquetTable(data, "t") { - sql("INSERT INTO TABLE t SELECT * FROM t") - checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) - } + test("appending") { + val data = (0 until 10).map(i => (i, i.toString)) + withParquetTable(data, "t") { + sql("INSERT INTO TABLE t SELECT * FROM t") + checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) } + } - // This test case will trigger the NPE mentioned in - // https://issues.apache.org/jira/browse/PARQUET-151. - ignore(s"$prefix: overwriting") { - val data = (0 until 10).map(i => (i, i.toString)) - withParquetTable(data, "t") { - sql("INSERT OVERWRITE TABLE t SELECT * FROM t") - checkAnswer(table("t"), data.map(Row.fromTuple)) - } + // This test case will trigger the NPE mentioned in + // https://issues.apache.org/jira/browse/PARQUET-151. + // Update: This also triggers SPARK-5746, should re enable it when we get both fixed. + ignore("overwriting") { + val data = (0 until 10).map(i => (i, i.toString)) + withParquetTable(data, "t") { + sql("INSERT OVERWRITE TABLE t SELECT * FROM t") + checkAnswer(table("t"), data.map(Row.fromTuple)) } + } - test(s"$prefix: self-join") { - // 4 rows, cells of column 1 of row 2 and row 4 are null - val data = (1 to 4).map { i => - val maybeInt = if (i % 2 == 0) None else Some(i) - (maybeInt, i.toString) - } - - withParquetTable(data, "t") { - val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1") - val queryOutput = selfJoin.queryExecution.analyzed.output + test("self-join") { + // 4 rows, cells of column 1 of row 2 and row 4 are null + val data = (1 to 4).map { i => + val maybeInt = if (i % 2 == 0) None else Some(i) + (maybeInt, i.toString) + } - assertResult(4, s"Field count mismatches")(queryOutput.size) - assertResult(2, s"Duplicated expression ID in query plan:\n $selfJoin") { - queryOutput.filter(_.name == "_1").map(_.exprId).size - } + withParquetTable(data, "t") { + val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1") + val queryOutput = selfJoin.queryExecution.analyzed.output - checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3"))) + assertResult(4, "Field count mismatche")(queryOutput.size) + assertResult(2, "Duplicated expression ID in query plan:\n $selfJoin") { + queryOutput.filter(_.name == "_1").map(_.exprId).size } + + checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3"))) } + } - test(s"$prefix: nested data - struct with array field") { - val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) - withParquetTable(data, "t") { - checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map { - case Tuple1((_, Seq(string))) => Row(string) - }) - } + test("nested data - struct with array field") { + val data = (1 to 10).map(i => Tuple1((i, Seq("val_$i")))) + withParquetTable(data, "t") { + checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map { + case Tuple1((_, Seq(string))) => Row(string) + }) } + } - test(s"$prefix: nested data - array of struct") { - val data = (1 to 10).map(i => Tuple1(Seq(i -> s"val_$i"))) - withParquetTable(data, "t") { - checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map { - case Tuple1(Seq((_, string))) => Row(string) - }) - } + test("nested data - array of struct") { + val data = (1 to 10).map(i => Tuple1(Seq(i -> "val_$i"))) + withParquetTable(data, "t") { + checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map { + case Tuple1(Seq((_, string))) => Row(string) + }) } + } - test(s"$prefix: SPARK-1913 regression: columns only referenced by pushed down filters should remain") { - withParquetTable((1 to 10).map(Tuple1.apply), "t") { - checkAnswer(sql(s"SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) - } + test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { + withParquetTable((1 to 10).map(Tuple1.apply), "t") { + checkAnswer(sql("SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) } + } - test(s"$prefix: SPARK-5309 strings stored using dictionary compression in parquet") { - withParquetTable((0 until 1000).map(i => ("same", "run_" + i /100, 1)), "t") { + test("SPARK-5309 strings stored using dictionary compression in parquet") { + withParquetTable((0 until 1000).map(i => ("same", "run_" + i /100, 1)), "t") { - checkAnswer(sql(s"SELECT _1, _2, SUM(_3) FROM t GROUP BY _1, _2"), - (0 until 10).map(i => Row("same", "run_" + i, 100))) + checkAnswer(sql("SELECT _1, _2, SUM(_3) FROM t GROUP BY _1, _2"), + (0 until 10).map(i => Row("same", "run_" + i, 100))) - checkAnswer(sql(s"SELECT _1, _2, SUM(_3) FROM t WHERE _2 = 'run_5' GROUP BY _1, _2"), - List(Row("same", "run_5", 100))) - } + checkAnswer(sql("SELECT _1, _2, SUM(_3) FROM t WHERE _2 = 'run_5' GROUP BY _1, _2"), + List(Row("same", "run_5", 100))) } } +} + +class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll { + val originalConf = sqlContext.conf.parquetUseDataSourceApi + + override protected def beforeAll(): Unit = { + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + } + + override protected def afterAll(): Unit = { + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } +} + +class ParquetDataSourceOffQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll { + val originalConf = sqlContext.conf.parquetUseDataSourceApi - withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") { - run("Parquet data source enabled") + override protected def beforeAll(): Unit = { + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") } - withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "false") { - run("Parquet data source disabled") + override protected def afterAll(): Unit = { + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index ddc7b181d4d4638aed42d7e44cebd6bafdf58a16..87b380f9509793bdbfa78f41ef5fb3c9ed055bfd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -22,26 +22,24 @@ import java.sql.Timestamp import scala.collection.JavaConversions._ import scala.language.implicitConversions -import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.Table -import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.parse.VariableSubstitution +import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateSubQueries, OverrideCatalog, OverrideFunctionRegistry} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand, QueryExecutionException} -import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DescribeHiveTableCommand} -import org.apache.spark.sql.sources.{CreateTableUsing, DataSourceStrategy} +import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, QueryExecutionException, SetCommand} +import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} +import org.apache.spark.sql.sources.DataSourceStrategy import org.apache.spark.sql.types._ /** @@ -244,6 +242,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { override protected[sql] lazy val analyzer = new Analyzer(catalog, functionRegistry, caseSensitive = false) { override val extendedRules = + catalog.ParquetConversions :: catalog.CreateTables :: catalog.PreInsertionCasts :: ExtractPythonUdfs :: diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index eb1ee54247bea3c457677781b6f9dc8a661ae9d1..6d794d0e11391e1aaad61eeeeacfecebc8c9b39c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -20,25 +20,25 @@ package org.apache.spark.sql.hive import java.io.IOException import java.util.{List => JList} -import com.google.common.cache.{LoadingCache, CacheLoader, CacheBuilder} - -import org.apache.hadoop.util.ReflectionUtils -import org.apache.hadoop.hive.metastore.{Warehouse, TableType} -import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition, FieldSchema} +import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} +import org.apache.hadoop.hive.metastore.api.{FieldSchema, Partition => TPartition, Table => TTable} +import org.apache.hadoop.hive.metastore.{TableType, Warehouse} import org.apache.hadoop.hive.ql.metadata._ import org.apache.hadoop.hive.ql.plan.CreateTableDesc import org.apache.hadoop.hive.serde.serdeConstants -import org.apache.hadoop.hive.serde2.{Deserializer, SerDeException} import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe +import org.apache.hadoop.hive.serde2.{Deserializer, SerDeException} +import org.apache.hadoop.util.ReflectionUtils import org.apache.spark.Logging import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.analysis.{Catalog, OverrideCatalog} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.parquet.ParquetRelation2 +import org.apache.spark.sql.parquet.{ParquetRelation2, Partition => ParquetPartition, PartitionSpec} import org.apache.spark.sql.sources.{DDLParser, LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -101,16 +101,10 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with val caseSensitive: Boolean = false - /** * - * Creates a data source table (a table created with USING clause) in Hive's metastore. - * Returns true when the table has been created. Otherwise, false. - * @param tableName - * @param userSpecifiedSchema - * @param provider - * @param options - * @param isExternal - * @return - */ + /** + * Creates a data source table (a table created with USING clause) in Hive's metastore. + * Returns true when the table has been created. Otherwise, false. + */ def createDataSourceTable( tableName: String, userSpecifiedSchema: Option[StructType], @@ -141,7 +135,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } def hiveDefaultTableFilePath(tableName: String): String = { - val currentDatabase = client.getDatabase(hive.sessionState.getCurrentDatabase()) + val currentDatabase = client.getDatabase(hive.sessionState.getCurrentDatabase) hiveWarehouse.getTablePath(currentDatabase, tableName).toString } @@ -176,25 +170,41 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with Nil } - val relation = MetastoreRelation( - databaseName, tblName, alias)( - table.getTTable, partitions.map(part => part.getTPartition))(hive) - - if (hive.convertMetastoreParquet && - hive.conf.parquetUseDataSourceApi && - relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet")) { - val metastoreSchema = StructType.fromAttributes(relation.output) - val paths = if (relation.hiveQlTable.isPartitioned) { - relation.hiveQlPartitions.map(p => p.getLocation) - } else { - Seq(relation.hiveQlTable.getDataLocation.toString) - } + MetastoreRelation(databaseName, tblName, alias)( + table.getTTable, partitions.map(part => part.getTPartition))(hive) + } + } - LogicalRelation(ParquetRelation2( - paths, Map(ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json))(hive)) - } else { - relation + private def convertToParquetRelation(metastoreRelation: MetastoreRelation): LogicalRelation = { + val metastoreSchema = StructType.fromAttributes(metastoreRelation.output) + + // NOTE: Instead of passing Metastore schema directly to `ParquetRelation2`, we have to + // serialize the Metastore schema to JSON and pass it as a data source option because of the + // evil case insensitivity issue, which is reconciled within `ParquetRelation2`. + if (metastoreRelation.hiveQlTable.isPartitioned) { + val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys) + val partitionColumnDataTypes = partitionSchema.map(_.dataType) + val partitions = metastoreRelation.hiveQlPartitions.map { p => + val location = p.getLocation + val values = Row.fromSeq(p.getValues.zip(partitionColumnDataTypes).map { + case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null) + }) + ParquetPartition(values, location) } + val partitionSpec = PartitionSpec(partitionSchema, partitions) + val paths = partitions.map(_.path) + LogicalRelation( + ParquetRelation2( + paths, + Map(ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json), + None, + Some(partitionSpec))(hive)) + } else { + val paths = Seq(metastoreRelation.hiveQlTable.getDataLocation.toString) + LogicalRelation( + ParquetRelation2( + paths, + Map(ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json))(hive)) } } @@ -261,9 +271,9 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with logInfo(s"Default to LazySimpleSerDe for table $dbName.$tblName") tbl.setSerializationLib(classOf[LazySimpleSerDe].getName()) - import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat import org.apache.hadoop.io.Text + import org.apache.hadoop.mapred.TextInputFormat tbl.setInputFormatClass(classOf[TextInputFormat]) tbl.setOutputFormatClass(classOf[HiveIgnoreKeyTextOutputFormat[Text, Text]]) @@ -385,13 +395,56 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } } + /** + * When scanning or writing to non-partitioned Metastore Parquet tables, convert them to Parquet + * data source relations for better performance. + * + * This rule can be considered as [[HiveStrategies.ParquetConversion]] done right. + */ + object ParquetConversions extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = { + // Collects all `MetastoreRelation`s which should be replaced + val toBeReplaced = plan.collect { + // Write path + case InsertIntoTable(relation: MetastoreRelation, _, _, _) + // Inserting into partitioned table is not supported in Parquet data source (yet). + if !relation.hiveQlTable.isPartitioned && + hive.convertMetastoreParquet && + hive.conf.parquetUseDataSourceApi && + relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => + relation + + // Read path + case p @ PhysicalOperation(_, _, relation: MetastoreRelation) + if hive.convertMetastoreParquet && + hive.conf.parquetUseDataSourceApi && + relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => + relation + } + + // Replaces all `MetastoreRelation`s with corresponding `ParquetRelation2`s, and fixes + // attribute IDs referenced in other nodes. + toBeReplaced.distinct.foldLeft(plan) { (lastPlan, relation) => + val parquetRelation = convertToParquetRelation(relation) + val attributedRewrites = AttributeMap(relation.output.zip(parquetRelation.output)) + + lastPlan.transformUp { + case r: MetastoreRelation if r == relation => parquetRelation + case other => other.transformExpressions { + case a: Attribute if a.resolved => attributedRewrites.getOrElse(a, a) + } + } + } + } + } + /** * Creates any tables required for query execution. * For example, because of a CREATE TABLE X AS statement. */ object CreateTables extends Rule[LogicalPlan] { import org.apache.hadoop.hive.ql.Context - import org.apache.hadoop.hive.ql.parse.{QB, ASTNode, SemanticAnalyzer} + import org.apache.hadoop.hive.ql.parse.{ASTNode, QB, SemanticAnalyzer} def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Wait until children are resolved. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index cb138be90e2e18b67a25cf7bccf4fa50d995d1a1..965d159656d804e462e2aa95a8685b046cee67a9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -139,15 +139,19 @@ private[hive] trait HiveStrategies { val partitionLocations = partitions.map(_.getLocation) - hiveContext - .parquetFile(partitionLocations.head, partitionLocations.tail: _*) - .addPartitioningAttributes(relation.partitionKeys) - .lowerCase - .where(unresolvedOtherPredicates) - .select(unresolvedProjection: _*) - .queryExecution - .executedPlan - .fakeOutput(projectList.map(_.toAttribute)) :: Nil + if (partitionLocations.isEmpty) { + PhysicalRDD(plan.output, sparkContext.emptyRDD[Row]) :: Nil + } else { + hiveContext + .parquetFile(partitionLocations.head, partitionLocations.tail: _*) + .addPartitioningAttributes(relation.partitionKeys) + .lowerCase + .where(unresolvedOtherPredicates) + .select(unresolvedProjection: _*) + .queryExecution + .executedPlan + .fakeOutput(projectList.map(_.toAttribute)) :: Nil + } } else { hiveContext .parquetFile(relation.hiveQlTable.getDataLocation.toString) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala index e246cbb6d77f0d2ab03ebc1eb10587ba9a788763..2acf1a7767c1947bcf74b08209576822f47667fe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala @@ -40,7 +40,7 @@ case class ParquetDataWithKey(p: Int, intField: Int, stringField: String) * A suite to test the automatic conversion of metastore tables with parquet data to use the * built in parquet support. */ -class ParquetMetastoreSuite extends ParquetPartitioningTest { +class ParquetMetastoreSuiteBase extends ParquetPartitioningTest { override def beforeAll(): Unit = { super.beforeAll() @@ -97,6 +97,9 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } override def afterAll(): Unit = { + sql("DROP TABLE partitioned_parquet") + sql("DROP TABLE partitioned_parquet_with_key") + sql("DROP TABLE normal_parquet") setConf("spark.sql.hive.convertMetastoreParquet", "false") } @@ -113,10 +116,38 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { } } +class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { + val originalConf = conf.parquetUseDataSourceApi + + override def beforeAll(): Unit = { + super.beforeAll() + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + } + + override def afterAll(): Unit = { + super.afterAll() + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } +} + +class ParquetDataSourceOffMetastoreSuite extends ParquetMetastoreSuiteBase { + val originalConf = conf.parquetUseDataSourceApi + + override def beforeAll(): Unit = { + super.beforeAll() + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + } + + override def afterAll(): Unit = { + super.afterAll() + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } +} + /** * A suite of tests for the Parquet support through the data sources API. */ -class ParquetSourceSuite extends ParquetPartitioningTest { +class ParquetSourceSuiteBase extends ParquetPartitioningTest { override def beforeAll(): Unit = { super.beforeAll() @@ -146,6 +177,34 @@ class ParquetSourceSuite extends ParquetPartitioningTest { } } +class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { + val originalConf = conf.parquetUseDataSourceApi + + override def beforeAll(): Unit = { + super.beforeAll() + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + } + + override def afterAll(): Unit = { + super.afterAll() + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } +} + +class ParquetDataSourceOffSourceSuite extends ParquetSourceSuiteBase { + val originalConf = conf.parquetUseDataSourceApi + + override def beforeAll(): Unit = { + super.beforeAll() + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + } + + override def afterAll(): Unit = { + super.afterAll() + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + } +} + /** * A collection of tests for parquet data with various forms of partitioning. */ @@ -191,107 +250,99 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll } } - def run(prefix: String): Unit = { - Seq("partitioned_parquet", "partitioned_parquet_with_key").foreach { table => - test(s"$prefix: ordering of the partitioning columns $table") { - checkAnswer( - sql(s"SELECT p, stringField FROM $table WHERE p = 1"), - Seq.fill(10)(Row(1, "part-1")) - ) - - checkAnswer( - sql(s"SELECT stringField, p FROM $table WHERE p = 1"), - Seq.fill(10)(Row("part-1", 1)) - ) - } - - test(s"$prefix: project the partitioning column $table") { - checkAnswer( - sql(s"SELECT p, count(*) FROM $table group by p"), - Row(1, 10) :: - Row(2, 10) :: - Row(3, 10) :: - Row(4, 10) :: - Row(5, 10) :: - Row(6, 10) :: - Row(7, 10) :: - Row(8, 10) :: - Row(9, 10) :: - Row(10, 10) :: Nil - ) - } - - test(s"$prefix: project partitioning and non-partitioning columns $table") { - checkAnswer( - sql(s"SELECT stringField, p, count(intField) FROM $table GROUP BY p, stringField"), - Row("part-1", 1, 10) :: - Row("part-2", 2, 10) :: - Row("part-3", 3, 10) :: - Row("part-4", 4, 10) :: - Row("part-5", 5, 10) :: - Row("part-6", 6, 10) :: - Row("part-7", 7, 10) :: - Row("part-8", 8, 10) :: - Row("part-9", 9, 10) :: - Row("part-10", 10, 10) :: Nil - ) - } - - test(s"$prefix: simple count $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table"), - Row(100)) - } - - test(s"$prefix: pruned count $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE p = 1"), - Row(10)) - } - - test(s"$prefix: non-existent partition $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE p = 1000"), - Row(0)) - } - - test(s"$prefix: multi-partition pruned count $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE p IN (1,2,3)"), - Row(30)) - } - - test(s"$prefix: non-partition predicates $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE intField IN (1,2,3)"), - Row(30)) - } - - test(s"$prefix: sum $table") { - checkAnswer( - sql(s"SELECT SUM(intField) FROM $table WHERE intField IN (1,2,3) AND p = 1"), - Row(1 + 2 + 3)) - } - - test(s"$prefix: hive udfs $table") { - checkAnswer( - sql(s"SELECT concat(stringField, stringField) FROM $table"), - sql(s"SELECT stringField FROM $table").map { - case Row(s: String) => Row(s + s) - }.collect().toSeq) - } + Seq("partitioned_parquet", "partitioned_parquet_with_key").foreach { table => + test(s"ordering of the partitioning columns $table") { + checkAnswer( + sql(s"SELECT p, stringField FROM $table WHERE p = 1"), + Seq.fill(10)(Row(1, "part-1")) + ) + + checkAnswer( + sql(s"SELECT stringField, p FROM $table WHERE p = 1"), + Seq.fill(10)(Row("part-1", 1)) + ) + } + + test(s"project the partitioning column $table") { + checkAnswer( + sql(s"SELECT p, count(*) FROM $table group by p"), + Row(1, 10) :: + Row(2, 10) :: + Row(3, 10) :: + Row(4, 10) :: + Row(5, 10) :: + Row(6, 10) :: + Row(7, 10) :: + Row(8, 10) :: + Row(9, 10) :: + Row(10, 10) :: Nil + ) + } + + test(s"project partitioning and non-partitioning columns $table") { + checkAnswer( + sql(s"SELECT stringField, p, count(intField) FROM $table GROUP BY p, stringField"), + Row("part-1", 1, 10) :: + Row("part-2", 2, 10) :: + Row("part-3", 3, 10) :: + Row("part-4", 4, 10) :: + Row("part-5", 5, 10) :: + Row("part-6", 6, 10) :: + Row("part-7", 7, 10) :: + Row("part-8", 8, 10) :: + Row("part-9", 9, 10) :: + Row("part-10", 10, 10) :: Nil + ) + } + + test(s"simple count $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table"), + Row(100)) } - test(s"$prefix: $prefix: non-part select(*)") { + test(s"pruned count $table") { checkAnswer( - sql("SELECT COUNT(*) FROM normal_parquet"), + sql(s"SELECT COUNT(*) FROM $table WHERE p = 1"), Row(10)) } - } - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") - run("Parquet data source enabled") + test(s"non-existent partition $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table WHERE p = 1000"), + Row(0)) + } + + test(s"multi-partition pruned count $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table WHERE p IN (1,2,3)"), + Row(30)) + } + + test(s"non-partition predicates $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table WHERE intField IN (1,2,3)"), + Row(30)) + } - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") - run("Parquet data source disabled") + test(s"sum $table") { + checkAnswer( + sql(s"SELECT SUM(intField) FROM $table WHERE intField IN (1,2,3) AND p = 1"), + Row(1 + 2 + 3)) + } + + test(s"hive udfs $table") { + checkAnswer( + sql(s"SELECT concat(stringField, stringField) FROM $table"), + sql(s"SELECT stringField FROM $table").map { + case Row(s: String) => Row(s + s) + }.collect().toSeq) + } + } + + test("non-part select(*)") { + checkAnswer( + sql("SELECT COUNT(*) FROM normal_parquet"), + Row(10)) + } }