diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index d68854214ef06c07a7774996a492cb620442b65a..03238e9fa00882cbba0768bc2fda3bd2e5134160 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -21,7 +21,7 @@ import java.lang.reflect.Method import java.security.PrivilegedExceptionAction import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.fs.FileSystem.Statistics import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} @@ -191,6 +191,21 @@ class SparkHadoopUtil extends Logging { val method = context.getClass.getMethod("getConfiguration") method.invoke(context).asInstanceOf[Configuration] } + + /** + * Get [[FileStatus]] objects for all leaf children (files) under the given base path. If the + * given path points to a file, return a single-element collection containing [[FileStatus]] of + * that file. + */ + def listLeafStatuses(fs: FileSystem, basePath: Path): Seq[FileStatus] = { + def recurse(path: Path) = { + val (directories, leaves) = fs.listStatus(path).partition(_.isDir) + leaves ++ directories.flatMap(f => listLeafStatuses(fs, f.getPath)) + } + + val baseStatus = fs.getFileStatus(basePath) + if (baseStatus.isDir) recurse(basePath) else Array(baseStatus) + } } object SparkHadoopUtil { diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 3ac8ea597e1421561a363decf39ffa40741ff073..e55f285a778c408df30b361de3423fc70314854c 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -1471,7 +1471,7 @@ class SQLContext(object): else: raise ValueError("Can only register DataFrame as table") - def parquetFile(self, path): + def parquetFile(self, *paths): """Loads a Parquet file, returning the result as a L{DataFrame}. 
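+        One or more paths may be passed; all of them are loaded into a single DataFrame.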
>>> import tempfile, shutil @@ -1483,7 +1483,12 @@ class SQLContext(object): >>> sorted(df.collect()) == sorted(df2.collect()) True """ - jdf = self._ssql_ctx.parquetFile(path) + gateway = self._sc._gateway + jpath = paths[0] + jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths) - 1) + for i in range(1, len(paths)): + jpaths[i - 1] = paths[i] + jdf = self._ssql_ctx.parquetFile(jpath, jpaths) return DataFrame(jdf, self) def jsonFile(self, path, schema=None, samplingRatio=1.0): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index c84cc95520a194a06ce346de2c31de55d16fa20d..365b1685a8e7147b2dc9163a42ff74a17643881e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types.BooleanType +import org.apache.spark.sql.types.{BinaryType, BooleanType} object InterpretedPredicate { def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = @@ -175,7 +175,10 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison null } else { val r = right.eval(input) - if (r == null) null else l == r + if (r == null) null + else if (left.dataType != BinaryType) l == r + else BinaryType.ordering.compare( + l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) == 0 } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index be362be55b56309a4e681497792a7327ff771517..91efe320546a785fa9b18eedd43649f53b8c1e22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.types import java.sql.Timestamp +import scala.collection.mutable.ArrayBuffer import scala.math.Numeric.{FloatAsIfIntegral, DoubleAsIfIntegral} import scala.reflect.ClassTag import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag} @@ -29,6 +30,7 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ +import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} @@ -159,7 +161,6 @@ object DataType { case failure: NoSuccess => throw new IllegalArgumentException(s"Unsupported dataType: $asString, $failure") } - } protected[types] def buildFormattedString( @@ -754,6 +755,57 @@ object StructType { def apply(fields: java.util.List[StructField]): StructType = { StructType(fields.toArray.asInstanceOf[Array[StructField]]) } + + private[sql] def merge(left: DataType, right: DataType): DataType = + (left, right) match { + case (ArrayType(leftElementType, leftContainsNull), + ArrayType(rightElementType, rightContainsNull)) => + ArrayType( + merge(leftElementType, rightElementType), + leftContainsNull || rightContainsNull) + + case (MapType(leftKeyType, leftValueType, leftContainsNull), + MapType(rightKeyType,
rightValueType, rightContainsNull)) => + MapType( + merge(leftKeyType, rightKeyType), + merge(leftValueType, rightValueType), + leftContainsNull || rightContainsNull) + + case (StructType(leftFields), StructType(rightFields)) => + val newFields = ArrayBuffer.empty[StructField] + + leftFields.foreach { + case leftField @ StructField(leftName, leftType, leftNullable, _) => + rightFields + .find(_.name == leftName) + .map { case rightField @ StructField(_, rightType, rightNullable, _) => + leftField.copy( + dataType = merge(leftType, rightType), + nullable = leftNullable || rightNullable) + } + .orElse(Some(leftField)) + .foreach(newFields += _) + } + + rightFields + .filterNot(f => leftFields.map(_.name).contains(f.name)) + .foreach(newFields += _) + + StructType(newFields) + + case (DecimalType.Fixed(leftPrecision, leftScale), + DecimalType.Fixed(rightPrecision, rightScale)) => + DecimalType(leftPrecision.max(rightPrecision), leftScale.max(rightScale)) + + case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_]) + if leftUdt.userClass == rightUdt.userClass => leftUdt + + case (leftType, rightType) if leftType == rightType => + leftType + + case _ => + throw new SparkException(s"Failed to merge incompatible data types $left and $right") + } } @@ -890,6 +942,20 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru val fieldTypes = fields.map(field => s"${field.name}:${field.dataType.simpleString}") s"struct<${fieldTypes.mkString(",")}>" } + + /** + * Merges with another schema (`StructType`). For a struct field A from `this` and a struct field + * B from `that`, + * + * 1. If A and B have the same name and data type, they are merged to a field C with the same name + * and data type. C is nullable if and only if either A or B is nullable. + * 2. If A doesn't exist in `that`, it's included in the result schema. + * 3. If B doesn't exist in `this`, it's also included in the result schema. + * 4. Otherwise, `this` and `that` are considered as conflicting schemas and an exception would be + * thrown. 
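+   * For example, merging `struct<a:int,b:long>` with `struct<b:long,c:string>` yields `struct<a:int,b:long,c:string>`.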
+ */ + private[sql] def merge(that: StructType): StructType = + StructType.merge(this, that).asInstanceOf[StructType] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala index d6df927f9d42c97cac3aa745ab96a7c76e9cf4da..58d11751353b31dc5d54a4076733bb7aa03f2b4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala @@ -295,7 +295,11 @@ private[sql] class DataFrameImpl protected[sql]( } override def saveAsParquetFile(path: String): Unit = { - sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd + if (sqlContext.conf.parquetUseDataSourceApi) { + save("org.apache.spark.sql.parquet", "path" -> path) + } else { + sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd + } } override def saveAsTable(tableName: String): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 7fe17944a734e2ffcffcc584dcc99564e813f99f..5ef3bb022fc5a1cc1b35f0afed7fc432ce655c21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -37,6 +37,7 @@ private[spark] object SQLConf { val PARQUET_CACHE_METADATA = "spark.sql.parquet.cacheMetadata" val PARQUET_COMPRESSION = "spark.sql.parquet.compression.codec" val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.parquet.filterPushdown" + val PARQUET_USE_DATA_SOURCE_API = "spark.sql.parquet.useDataSourceApi" val COLUMN_NAME_OF_CORRUPT_RECORD = "spark.sql.columnNameOfCorruptRecord" val BROADCAST_TIMEOUT = "spark.sql.broadcastTimeout" @@ -105,6 +106,10 @@ private[sql] class SQLConf extends Serializable { private[spark] def parquetFilterPushDown = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED, "false").toBoolean + /** When true uses Parquet implementation based on data source API */ + private[spark] def parquetUseDataSourceApi = + getConf(PARQUET_USE_DATA_SOURCE_API, "true").toBoolean + /** When true the planner will use the external sort, which may spill to disk. 
*/ private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT, "false").toBoolean diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 01620aa0acd4944aed0b8c96ae09eccb9a0e7634..706ef6ad4f174bb7af05c9526fda196abd14e173 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -20,14 +20,13 @@ package org.apache.spark.sql import java.beans.Introspector import java.util.Properties -import scala.collection.immutable import scala.collection.JavaConversions._ +import scala.collection.immutable import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.{SparkContext, Partition} import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental} -import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis._ @@ -36,11 +35,12 @@ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.execution._ -import org.apache.spark.sql.json._ import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} -import org.apache.spark.sql.sources._ +import org.apache.spark.sql.json._ +import org.apache.spark.sql.sources.{BaseRelation, DDLParser, DataSourceStrategy, LogicalRelation, _} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +import org.apache.spark.{Partition, SparkContext} /** * :: AlphaComponent :: @@ -303,8 +303,14 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @group userf */ - def parquetFile(path: String): DataFrame = - DataFrame(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration), this)) + @scala.annotation.varargs + def parquetFile(path: String, paths: String*): DataFrame = + if (conf.parquetUseDataSourceApi) { + baseRelationToDataFrame(parquet.ParquetRelation2(path +: paths, Map.empty)(this)) + } else { + DataFrame(this, parquet.ParquetRelation( + (path +: paths).mkString(","), Some(sparkContext.hadoopConfiguration), this)) + } /** * Loads a JSON file (one object per line), returning the result as a [[DataFrame]].
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f06f5fd1fc51135b3f6d1ecca54094b784bd6942..81bcf5a6f32dd114df1079e0dd0d21cb9c46bb65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,18 +17,17 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.{SQLContext, Strategy, execution} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} +import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} import org.apache.spark.sql.parquet._ +import org.apache.spark.sql.sources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _} import org.apache.spark.sql.types._ -import org.apache.spark.sql.sources.{DescribeCommand => LogicalDescribeCommand} -import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} -import org.apache.spark.sql.sources._ +import org.apache.spark.sql.{SQLContext, Strategy, execution} private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SQLContext#SparkPlanner => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index 8372decbf8aa1dc0f3a9b55bcced89c1af03a4ba..f27585d05a986cc68b99bc36e526e0097b05024d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.types.StructType private[sql] class DefaultSource - extends RelationProvider with SchemaRelationProvider with CreateableRelationProvider { + extends RelationProvider with SchemaRelationProvider with CreatableRelationProvider { /** Returns a new base relation with the parameters. 
*/ override def createRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 14c81ae4eba4e70e5c947e623984aa26966ac8b0..19bfba34b8f4a36a57865ad394cd10a53d117cd1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -159,7 +159,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { val attributesSize = attributes.size if (attributesSize > record.size) { throw new IndexOutOfBoundsException( - s"Trying to write more fields than contained in row (${attributesSize}>${record.size})") + s"Trying to write more fields than contained in row ($attributesSize > ${record.size})") } var index = 0 @@ -325,7 +325,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { val attributesSize = attributes.size if (attributesSize > record.size) { throw new IndexOutOfBoundsException( - s"Trying to write more fields than contained in row (${attributesSize}>${record.size})") + s"Trying to write more fields than contained in row ($attributesSize > ${record.size})") } var index = 0 @@ -348,10 +348,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { index: Int): Unit = { ctype match { case StringType => writer.addBinary( - Binary.fromByteArray( - record(index).asInstanceOf[String].getBytes("utf-8") - ) - ) + Binary.fromByteArray(record(index).asInstanceOf[String].getBytes("utf-8"))) case BinaryType => writer.addBinary( Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]])) case IntegerType => writer.addInteger(record.getInt(index)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index b646109b7c553629050d3df9fe069fa68e3bf657..5209581fa8357b7c105a5530ce8dfc6d9b4e83ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -19,24 +19,23 @@ package org.apache.spark.sql.parquet import java.io.IOException +import scala.collection.mutable.ArrayBuffer import scala.util.Try import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job - import parquet.format.converter.ParquetMetadataConverter -import parquet.hadoop.{ParquetFileReader, Footer, ParquetFileWriter} -import parquet.hadoop.metadata.{ParquetMetadata, FileMetaData} +import parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} import parquet.hadoop.util.ContextUtil -import parquet.schema.{Type => ParquetType, Types => ParquetTypes, PrimitiveType => ParquetPrimitiveType, MessageType} -import parquet.schema.{GroupType => ParquetGroupType, OriginalType => ParquetOriginalType, ConversionPatterns, DecimalMetadata} +import parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} import parquet.schema.Type.Repetition +import parquet.schema.{ConversionPatterns, DecimalMetadata, GroupType => ParquetGroupType, MessageType, OriginalType => ParquetOriginalType, PrimitiveType => ParquetPrimitiveType, Type => ParquetType, Types => ParquetTypes} -import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} +import 
org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.types._ +import org.apache.spark.{Logging, SparkException} // Implicits import scala.collection.JavaConversions._ @@ -285,7 +284,7 @@ private[parquet] object ParquetTypesConverter extends Logging { ctype: DataType, name: String, nullable: Boolean = true, - inArray: Boolean = false, + inArray: Boolean = false, toThriftSchemaNames: Boolean = false): ParquetType = { val repetition = if (inArray) { @@ -340,7 +339,7 @@ private[parquet] object ParquetTypesConverter extends Logging { } case StructType(structFields) => { val fields = structFields.map { - field => fromDataType(field.dataType, field.name, field.nullable, + field => fromDataType(field.dataType, field.name, field.nullable, inArray = false, toThriftSchemaNames) } new ParquetGroupType(repetition, name, fields.toSeq) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 179c0d6b2223985537c647fb97f2cf6a54b90f2a..49d46334b652562cf002037a5543d713c6aaa70a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -16,25 +16,38 @@ */ package org.apache.spark.sql.parquet -import java.util.{List => JList} +import java.io.IOException +import java.lang.{Double => JDouble, Float => JFloat, Long => JLong} +import java.math.{BigDecimal => JBigDecimal} +import java.text.SimpleDateFormat +import java.util.{List => JList, Date} import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer +import scala.util.Try -import org.apache.hadoop.conf.{Configurable, Configuration} +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat import org.apache.hadoop.mapreduce.{InputSplit, Job, JobContext} import parquet.filter2.predicate.FilterApi -import parquet.hadoop.ParquetInputFormat +import parquet.format.converter.ParquetMetadataConverter +import parquet.hadoop.{ParquetInputFormat, _} import parquet.hadoop.util.ContextUtil import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.{NewHadoopPartition, RDD} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.rdd.{NewHadoopPartition, NewHadoopRDD, RDD} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.parquet.ParquetTypesConverter._ import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{IntegerType, StructField, StructType} -import org.apache.spark.sql.{Row, SQLConf, SQLContext} -import org.apache.spark.{Logging, Partition => SparkPartition} +import org.apache.spark.sql.types.{IntegerType, StructField, StructType, _} +import org.apache.spark.sql.types.StructType._ +import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext} +import org.apache.spark.{Partition => SparkPartition, TaskContext, SerializableWritable, Logging, SparkException} /** @@ -43,19 +56,49 @@ import org.apache.spark.{Logging, Partition => SparkPartition} * required is `path`, which should be the location of a collection of, optionally partitioned, * parquet files. 
*/ -class DefaultSource extends RelationProvider { +class DefaultSource + extends RelationProvider + with SchemaRelationProvider + with CreatableRelationProvider { + private def checkPath(parameters: Map[String, String]): String = { + parameters.getOrElse("path", sys.error("'path' must be specified for parquet tables.")) + } + /** Returns a new base relation with the given parameters. */ override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { - val path = - parameters.getOrElse("path", sys.error("'path' must be specified for parquet tables.")) + ParquetRelation2(Seq(checkPath(parameters)), parameters, None)(sqlContext) + } - ParquetRelation2(path)(sqlContext) + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String], + schema: StructType): BaseRelation = { + ParquetRelation2(Seq(checkPath(parameters)), parameters, Some(schema))(sqlContext) + } + + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String], + data: DataFrame): BaseRelation = { + val path = checkPath(parameters) + ParquetRelation.createEmpty( + path, + data.schema.toAttributes, + false, + sqlContext.sparkContext.hadoopConfiguration, + sqlContext) + + val relation = createRelation(sqlContext, parameters, data.schema) + relation.asInstanceOf[ParquetRelation2].insert(data, true) + relation } } -private[parquet] case class Partition(partitionValues: Map[String, Any], files: Seq[FileStatus]) +private[parquet] case class Partition(values: Row, path: String) + +private[parquet] case class PartitionSpec(partitionColumns: StructType, partitions: Seq[Partition]) /** * An alternative to [[ParquetRelation]] that plugs in using the data sources API. This class is @@ -81,117 +124,196 @@ private[parquet] case class Partition(partitionValues: Map[String, Any], files: * discovery. */ @DeveloperApi -case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext) - extends CatalystScan with Logging { +case class ParquetRelation2 + (paths: Seq[String], parameters: Map[String, String], maybeSchema: Option[StructType] = None) + (@transient val sqlContext: SQLContext) + extends CatalystScan + with InsertableRelation + with SparkHadoopMapReduceUtil + with Logging { + + // Should we merge schemas from all Parquet part-files? 
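+  // Defaults to true; configured through the `ParquetRelation2.MERGE_SCHEMA` ("mergeSchema") option.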
+ private val shouldMergeSchemas = + parameters.getOrElse(ParquetRelation2.MERGE_SCHEMA, "true").toBoolean + + // Optional Metastore schema, used when converting Hive Metastore Parquet table + private val maybeMetastoreSchema = + parameters + .get(ParquetRelation2.METASTORE_SCHEMA) + .map(s => DataType.fromJson(s).asInstanceOf[StructType]) + + // Hive uses this as part of the default partition name when the partition column value is null + // or empty string + private val defaultPartitionName = parameters.getOrElse( + ParquetRelation2.DEFAULT_PARTITION_NAME, "__HIVE_DEFAULT_PARTITION__") + + override def equals(other: Any) = other match { + case relation: ParquetRelation2 => + paths.toSet == relation.paths.toSet && + maybeMetastoreSchema == relation.maybeMetastoreSchema && + (shouldMergeSchemas == relation.shouldMergeSchemas || schema == relation.schema) + } - def sparkContext = sqlContext.sparkContext + private[sql] def sparkContext = sqlContext.sparkContext - // Minor Hack: scala doesnt seem to respect @transient for vals declared via extraction - @transient - private var partitionKeys: Seq[String] = _ - @transient - private var partitions: Seq[Partition] = _ - discoverPartitions() + @transient private val fs = FileSystem.get(sparkContext.hadoopConfiguration) - // TODO: Only finds the first partition, assumes the key is of type Integer... - private def discoverPartitions() = { - val fs = FileSystem.get(new java.net.URI(path), sparkContext.hadoopConfiguration) - val partValue = "([^=]+)=([^=]+)".r + private class MetadataCache { + private var metadataStatuses: Array[FileStatus] = _ + private var commonMetadataStatuses: Array[FileStatus] = _ + private var footers: Map[FileStatus, Footer] = _ + private var parquetSchema: StructType = _ - val childrenOfPath = fs.listStatus(new Path(path)).filterNot(_.getPath.getName.startsWith("_")) - val childDirs = childrenOfPath.filter(s => s.isDir) + var dataStatuses: Array[FileStatus] = _ + var partitionSpec: PartitionSpec = _ + var schema: StructType = _ + var dataSchemaIncludesPartitionKeys: Boolean = _ - if (childDirs.size > 0) { - val partitionPairs = childDirs.map(_.getPath.getName).map { - case partValue(key, value) => (key, value) + def refresh(): Unit = { + val baseStatuses = { + val statuses = paths.distinct.map(p => fs.getFileStatus(fs.makeQualified(new Path(p)))) + // Support either reading a collection of raw Parquet part-files, or a collection of folders + // containing Parquet files (e.g. partitioned Parquet table). + assert(statuses.forall(!_.isDir) || statuses.forall(_.isDir)) + statuses.toArray } - val foundKeys = partitionPairs.map(_._1).distinct - if (foundKeys.size > 1) { - sys.error(s"Too many distinct partition keys: $foundKeys") + val leaves = baseStatuses.flatMap { f => + val statuses = SparkHadoopUtil.get.listLeafStatuses(fs, f.getPath).filter { f => + isSummaryFile(f.getPath) || + !(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith(".")) + } + assert(statuses.nonEmpty, s"${f.getPath} is an empty folder.") + statuses } - // Do a parallel lookup of partition metadata. - val partitionFiles = - childDirs.par.map { d => - fs.listStatus(d.getPath) - // TODO: Is there a standard hadoop function for this? 
- .filterNot(_.getPath.getName.startsWith("_")) - .filterNot(_.getPath.getName.startsWith(".")) - }.seq - - partitionKeys = foundKeys.toSeq - partitions = partitionFiles.zip(partitionPairs).map { case (files, (key, value)) => - Partition(Map(key -> value.toInt), files) - }.toSeq - } else { - partitionKeys = Nil - partitions = Partition(Map.empty, childrenOfPath) :: Nil - } - } + dataStatuses = leaves.filterNot(f => isSummaryFile(f.getPath)) + metadataStatuses = leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE) + commonMetadataStatuses = + leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE) + + footers = (dataStatuses ++ metadataStatuses ++ commonMetadataStatuses).par.map { f => + val parquetMetadata = ParquetFileReader.readFooter( + sparkContext.hadoopConfiguration, f, ParquetMetadataConverter.NO_FILTER) + f -> new Footer(f.getPath, parquetMetadata) + }.seq.toMap + + partitionSpec = { + val partitionDirs = dataStatuses + .filterNot(baseStatuses.contains) + .map(_.getPath.getParent) + .distinct + + if (partitionDirs.nonEmpty) { + ParquetRelation2.parsePartitions(partitionDirs, defaultPartitionName) + } else { + // No partition directories found, makes an empty specification + PartitionSpec(StructType(Seq.empty[StructField]), Seq.empty[Partition]) + } + } - override val sizeInBytes = partitions.flatMap(_.files).map(_.getLen).sum + parquetSchema = maybeSchema.getOrElse(readSchema()) - val dataSchema = StructType.fromAttributes( // TODO: Parquet code should not deal with attributes. - ParquetTypesConverter.readSchemaFromFile( - partitions.head.files.head.getPath, - Some(sparkContext.hadoopConfiguration), - sqlContext.conf.isParquetBinaryAsString, - sqlContext.conf.isParquetINT96AsTimestamp)) + dataSchemaIncludesPartitionKeys = + isPartitioned && + partitionColumns.forall(f => metadataCache.parquetSchema.fieldNames.contains(f.name)) - val dataIncludesKey = - partitionKeys.headOption.map(dataSchema.fieldNames.contains(_)).getOrElse(true) + schema = { + val fullParquetSchema = if (dataSchemaIncludesPartitionKeys) { + metadataCache.parquetSchema + } else { + StructType(metadataCache.parquetSchema.fields ++ partitionColumns.fields) + } - override val schema = - if (dataIncludesKey) { - dataSchema - } else { - StructType(dataSchema.fields :+ StructField(partitionKeys.head, IntegerType)) + maybeMetastoreSchema + .map(ParquetRelation2.mergeMetastoreParquetSchema(_, fullParquetSchema)) + .getOrElse(fullParquetSchema) + } } - override def buildScan(output: Seq[Attribute], predicates: Seq[Expression]): RDD[Row] = { - // This is mostly a hack so that we can use the existing parquet filter code. - val requiredColumns = output.map(_.name) + private def readSchema(): StructType = { + // Sees which file(s) we need to touch in order to figure out the schema. + val filesToTouch = + // Always tries the summary files first if users don't require a merged schema. In this case, + // "_common_metadata" is more preferable than "_metadata" because it doesn't contain row + // groups information, and could be much smaller for large Parquet files with lots of row + // groups. + // + // NOTE: Metadata stored in the summary files are merged from all part-files. However, for + // user defined key-value metadata (in which we store Spark SQL schema), Parquet doesn't know + // how to merge them correctly if some key is associated with different values in different + // part-files. When this happens, Parquet simply gives up generating the summary file. 
This + // implies that if a summary file presents, then: + // + // 1. Either all part-files have exactly the same Spark SQL schema, or + // 2. Some part-files don't contain Spark SQL schema in the key-value metadata at all (thus + // their schemas may differ from each other). + // + // Here we tend to be pessimistic and take the second case into account. Basically this means + // we can't trust the summary files if users require a merged schema, and must touch all part- + // files to do the merge. + if (shouldMergeSchemas) { + // Also includes summary files, 'cause there might be empty partition directories. + (metadataStatuses ++ commonMetadataStatuses ++ dataStatuses).toSeq + } else { + // Tries any "_common_metadata" first. Parquet files written by old versions or Parquet + // don't have this. + commonMetadataStatuses.headOption + // Falls back to "_metadata" + .orElse(metadataStatuses.headOption) + // Summary file(s) not found, the Parquet file is either corrupted, or different part- + // files contain conflicting user defined metadata (two or more values are associated + // with a same key in different files). In either case, we fall back to any of the + // first part-file, and just assume all schemas are consistent. + .orElse(dataStatuses.headOption) + .toSeq + } - val job = new Job(sparkContext.hadoopConfiguration) - ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport]) - val jobConf: Configuration = ContextUtil.getConfiguration(job) + ParquetRelation2.readSchema(filesToTouch.map(footers.apply), sqlContext) + } + } - val requestedSchema = StructType(requiredColumns.map(schema(_))) + @transient private val metadataCache = new MetadataCache + metadataCache.refresh() - val partitionKeySet = partitionKeys.toSet - val rawPredicate = - predicates - .filter(_.references.map(_.name).toSet.subsetOf(partitionKeySet)) - .reduceOption(And) - .getOrElse(Literal(true)) + private def partitionColumns = metadataCache.partitionSpec.partitionColumns - // Translate the predicate so that it reads from the information derived from the - // folder structure - val castedPredicate = rawPredicate transform { - case a: AttributeReference => - val idx = partitionKeys.indexWhere(a.name == _) - BoundReference(idx, IntegerType, nullable = true) - } + private def partitions = metadataCache.partitionSpec.partitions - val inputData = new GenericMutableRow(partitionKeys.size) - val pruningCondition = InterpretedPredicate(castedPredicate) + private def isPartitioned = partitionColumns.nonEmpty - val selectedPartitions = - if (partitionKeys.nonEmpty && predicates.nonEmpty) { - partitions.filter { part => - inputData(0) = part.partitionValues.values.head - pruningCondition(inputData) - } - } else { - partitions + private def dataSchemaIncludesPartitionKeys = metadataCache.dataSchemaIncludesPartitionKeys + + override def schema = metadataCache.schema + + private def isSummaryFile(file: Path): Boolean = { + file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE || + file.getName == ParquetFileWriter.PARQUET_METADATA_FILE + } + + // TODO Should calculate per scan size + // It's common that a query only scans a fraction of a large Parquet file. Returning size of the + // whole Parquet file disables some optimizations in this case (e.g. broadcast join). + override val sizeInBytes = metadataCache.dataStatuses.map(_.getLen).sum + + // This is mostly a hack so that we can use the existing parquet filter code. 
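+  // Predicates that reference only partition columns are evaluated against the partition values parsed from the directory layout (see prunePartitions below), while the remaining predicates may be pushed down to Parquet when spark.sql.parquet.filterPushdown is enabled.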
+ override def buildScan(output: Seq[Attribute], predicates: Seq[Expression]): RDD[Row] = { + val job = new Job(sparkContext.hadoopConfiguration) + ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport]) + val jobConf: Configuration = ContextUtil.getConfiguration(job) + + val selectedPartitions = prunePartitions(predicates, partitions) + val selectedFiles = if (isPartitioned) { + selectedPartitions.flatMap { p => + metadataCache.dataStatuses.filter(_.getPath.getParent.toString == p.path) } + } else { + metadataCache.dataStatuses.toSeq + } - val fs = FileSystem.get(new java.net.URI(path), sparkContext.hadoopConfiguration) - val selectedFiles = selectedPartitions.flatMap(_.files).map(f => fs.makeQualified(f.getPath)) // FileInputFormat cannot handle empty lists. if (selectedFiles.nonEmpty) { - org.apache.hadoop.mapreduce.lib.input.FileInputFormat.setInputPaths(job, selectedFiles: _*) + FileInputFormat.setInputPaths(job, selectedFiles.map(_.getPath): _*) } // Push down filters when possible. Notice that not all filters can be converted to Parquet @@ -203,23 +325,28 @@ case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext) .filter(_ => sqlContext.conf.parquetFilterPushDown) .foreach(ParquetInputFormat.setFilterPredicate(jobConf, _)) - def percentRead = selectedPartitions.size.toDouble / partitions.size.toDouble * 100 - logInfo(s"Reading $percentRead% of $path partitions") + if (isPartitioned) { + def percentRead = selectedPartitions.size.toDouble / partitions.size.toDouble * 100 + logInfo(s"Reading $percentRead% of partitions") + } + + val requiredColumns = output.map(_.name) + val requestedSchema = StructType(requiredColumns.map(schema(_))) // Store both requested and original schema in `Configuration` jobConf.set( RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, - ParquetTypesConverter.convertToString(requestedSchema.toAttributes)) + convertToString(requestedSchema.toAttributes)) jobConf.set( RowWriteSupport.SPARK_ROW_SCHEMA, - ParquetTypesConverter.convertToString(schema.toAttributes)) + convertToString(schema.toAttributes)) // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata val useCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "true").toBoolean jobConf.set(SQLConf.PARQUET_CACHE_METADATA, useCache.toString) val baseRDD = - new org.apache.spark.rdd.NewHadoopRDD( + new NewHadoopRDD( sparkContext, classOf[FilteringParquetRowInputFormat], classOf[Void], @@ -228,66 +355,400 @@ case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext) val cacheMetadata = useCache @transient - val cachedStatus = selectedPartitions.flatMap(_.files) + val cachedStatus = selectedFiles // Overridden so we can inject our own cached files statuses. 
override def getPartitions: Array[SparkPartition] = { - val inputFormat = - if (cacheMetadata) { - new FilteringParquetRowInputFormat { - override def listStatus(jobContext: JobContext): JList[FileStatus] = cachedStatus - } - } else { - new FilteringParquetRowInputFormat + val inputFormat = if (cacheMetadata) { + new FilteringParquetRowInputFormat { + override def listStatus(jobContext: JobContext): JList[FileStatus] = cachedStatus } - - inputFormat match { - case configurable: Configurable => - configurable.setConf(getConf) - case _ => + } else { + new FilteringParquetRowInputFormat } + val jobContext = newJobContext(getConf, jobId) - val rawSplits = inputFormat.getSplits(jobContext).toArray - val result = new Array[SparkPartition](rawSplits.size) - for (i <- 0 until rawSplits.size) { - result(i) = - new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + val rawSplits = inputFormat.getSplits(jobContext) + + Array.tabulate[SparkPartition](rawSplits.size) { i => + new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) } - result } } - // The ordinal for the partition key in the result row, if requested. - val partitionKeyLocation = - partitionKeys - .headOption - .map(requiredColumns.indexOf(_)) - .getOrElse(-1) + // The ordinals for partition keys in the result row, if requested. + val partitionKeyLocations = partitionColumns.fieldNames.zipWithIndex.map { + case (name, index) => index -> requiredColumns.indexOf(name) + }.toMap.filter { + case (_, index) => index >= 0 + } // When the data does not include the key and the key is requested then we must fill it in // based on information from the input split. - if (!dataIncludesKey && partitionKeyLocation != -1) { - baseRDD.mapPartitionsWithInputSplit { case (split, iter) => - val partValue = "([^=]+)=([^=]+)".r - val partValues = - split.asInstanceOf[parquet.hadoop.ParquetInputSplit] - .getPath - .toString - .split("/") - .flatMap { - case partValue(key, value) => Some(key -> value) - case _ => None - }.toMap - - val currentValue = partValues.values.head.toInt - iter.map { pair => - val res = pair._2.asInstanceOf[SpecificMutableRow] - res.setInt(partitionKeyLocation, currentValue) - res + if (!dataSchemaIncludesPartitionKeys && partitionKeyLocations.nonEmpty) { + baseRDD.mapPartitionsWithInputSplit { case (split: ParquetInputSplit, iterator) => + val partValues = selectedPartitions.collectFirst { + case p if split.getPath.getParent.toString == p.path => p.values + }.get + + iterator.map { pair => + val row = pair._2.asInstanceOf[SpecificMutableRow] + var i = 0 + while (i < partValues.size) { + // TODO Avoids boxing cost here! 
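+          // Fill the requested partition columns with the values parsed from this split's partition directory.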
+ row.update(partitionKeyLocations(i), partValues(i)) + i += 1 + } + row } } } else { baseRDD.map(_._2) } } + + private def prunePartitions( + predicates: Seq[Expression], + partitions: Seq[Partition]): Seq[Partition] = { + val partitionColumnNames = partitionColumns.map(_.name).toSet + val partitionPruningPredicates = predicates.filter { + _.references.map(_.name).toSet.subsetOf(partitionColumnNames) + } + + val rawPredicate = partitionPruningPredicates.reduceOption(And).getOrElse(Literal(true)) + val boundPredicate = InterpretedPredicate(rawPredicate transform { + case a: AttributeReference => + val index = partitionColumns.indexWhere(a.name == _.name) + BoundReference(index, partitionColumns(index).dataType, nullable = true) + }) + + if (isPartitioned && partitionPruningPredicates.nonEmpty) { + partitions.filter(p => boundPredicate(p.values)) + } else { + partitions + } + } + + override def insert(data: DataFrame, overwrite: Boolean): Unit = { + // TODO: currently we do not check whether the "schema"s are compatible + // That means if one first creates a table and then INSERTs data with + // and incompatible schema the execution will fail. It would be nice + // to catch this early one, maybe having the planner validate the schema + // before calling execute(). + + val job = new Job(sqlContext.sparkContext.hadoopConfiguration) + val writeSupport = if (schema.map(_.dataType).forall(_.isPrimitive)) { + log.debug("Initializing MutableRowWriteSupport") + classOf[MutableRowWriteSupport] + } else { + classOf[RowWriteSupport] + } + + ParquetOutputFormat.setWriteSupportClass(job, writeSupport) + + val conf = ContextUtil.getConfiguration(job) + RowWriteSupport.setSchema(schema.toAttributes, conf) + + val destinationPath = new Path(paths.head) + + if (overwrite) { + try { + destinationPath.getFileSystem(conf).delete(destinationPath, true) + } catch { + case e: IOException => + throw new IOException( + s"Unable to clear output directory ${destinationPath.toString} prior" + + s" to writing to Parquet file:\n${e.toString}") + } + } + + job.setOutputKeyClass(classOf[Void]) + job.setOutputValueClass(classOf[Row]) + FileOutputFormat.setOutputPath(job, destinationPath) + + val wrappedConf = new SerializableWritable(job.getConfiguration) + val jobTrackerId = new SimpleDateFormat("yyyyMMddHHmm").format(new Date()) + val stageId = sqlContext.sparkContext.newRddId() + + val taskIdOffset = if (overwrite) { + 1 + } else { + FileSystemHelper.findMaxTaskId( + FileOutputFormat.getOutputPath(job).toString, job.getConfiguration) + 1 + } + + def writeShard(context: TaskContext, iterator: Iterator[Row]): Unit = { + /* "reduce task" <split #> <attempt # = spark task #> */ + val attemptId = newTaskAttemptID( + jobTrackerId, stageId, isMap = false, context.partitionId(), context.attemptNumber()) + val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) + val format = new AppendingParquetOutputFormat(taskIdOffset) + val committer = format.getOutputCommitter(hadoopContext) + committer.setupTask(hadoopContext) + val writer = format.getRecordWriter(hadoopContext) + try { + while (iterator.hasNext) { + val row = iterator.next() + writer.write(null, row) + } + } finally { + writer.close(hadoopContext) + } + committer.commitTask(hadoopContext) + } + val jobFormat = new AppendingParquetOutputFormat(taskIdOffset) + /* apparently we need a TaskAttemptID to construct an OutputCommitter; + * however we're only going to use this local OutputCommitter for + * setupJob/commitJob, so we just use a dummy "map" task. 
+ */ + val jobAttemptId = newTaskAttemptID(jobTrackerId, stageId, isMap = true, 0, 0) + val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) + val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) + + jobCommitter.setupJob(jobTaskContext) + sqlContext.sparkContext.runJob(data.queryExecution.executedPlan.execute(), writeShard _) + jobCommitter.commitJob(jobTaskContext) + + metadataCache.refresh() + } +} + +object ParquetRelation2 { + // Whether we should merge schemas collected from all Parquet part-files. + val MERGE_SCHEMA = "mergeSchema" + + // Hive Metastore schema, passed in when the Parquet relation is converted from Metastore + val METASTORE_SCHEMA = "metastoreSchema" + + // Default partition name to use when the partition column value is null or empty string + val DEFAULT_PARTITION_NAME = "partition.defaultName" + + // When true, the Parquet data source caches Parquet metadata for performance + val CACHE_METADATA = "cacheMetadata" + + private[parquet] def readSchema(footers: Seq[Footer], sqlContext: SQLContext): StructType = { + footers.map { footer => + val metadata = footer.getParquetMetadata.getFileMetaData + val parquetSchema = metadata.getSchema + val maybeSparkSchema = metadata + .getKeyValueMetaData + .toMap + .get(RowReadSupport.SPARK_METADATA_KEY) + .map(DataType.fromJson(_).asInstanceOf[StructType]) + + maybeSparkSchema.getOrElse { + // Falls back to Parquet schema if Spark SQL schema is absent. + StructType.fromAttributes( + // TODO Really no need to use `Attribute` here, we only need to know the data type. + convertToAttributes( + parquetSchema, + sqlContext.conf.isParquetBinaryAsString, + sqlContext.conf.isParquetINT96AsTimestamp)) + } + }.reduce { (left, right) => + try left.merge(right) catch { case e: Throwable => + throw new SparkException(s"Failed to merge incompatible schemas $left and $right", e) + } + } + } + + private[parquet] def mergeMetastoreParquetSchema( + metastoreSchema: StructType, + parquetSchema: StructType): StructType = { + def schemaConflictMessage = + s"""Converting Hive Metastore Parquet, but detected conflicting schemas. Metastore schema: + |${metastoreSchema.prettyJson} + | + |Parquet schema: + |${parquetSchema.prettyJson} + """.stripMargin + + assert(metastoreSchema.size == parquetSchema.size, schemaConflictMessage) + + val ordinalMap = metastoreSchema.zipWithIndex.map { + case (field, index) => field.name.toLowerCase -> index + }.toMap + val reorderedParquetSchema = parquetSchema.sortBy(f => ordinalMap(f.name.toLowerCase)) + + StructType(metastoreSchema.zip(reorderedParquetSchema).map { + // Uses Parquet field names but retains Metastore data types. + case (mSchema, pSchema) if mSchema.name.toLowerCase == pSchema.name.toLowerCase => + mSchema.copy(name = pSchema.name) + case _ => + throw new SparkException(schemaConflictMessage) + }) + } + + // TODO Data source implementations shouldn't touch Catalyst types (`Literal`). + // However, we are already using Catalyst expressions for partition pruning and predicate + // push-down here... + private[parquet] case class PartitionValues(columnNames: Seq[String], literals: Seq[Literal]) { + require(columnNames.size == literals.size) + } + + /** + * Given a group of qualified paths, tries to parse them and returns a partition specification. 
+ * For example, given: + * {{{ + * hdfs://<host>:<port>/path/to/partition/a=1/b=hello/c=3.14 + * hdfs://<host>:<port>/path/to/partition/a=2/b=world/c=6.28 + * }}} + * it returns: + * {{{ + * PartitionSpec( + * partitionColumns = StructType( + * StructField(name = "a", dataType = IntegerType, nullable = true), + * StructField(name = "b", dataType = StringType, nullable = true), + * StructField(name = "c", dataType = DoubleType, nullable = true)), + * partitions = Seq( + * Partition( + * values = Row(1, "hello", 3.14), + * path = "hdfs://<host>:<port>/path/to/partition/a=1/b=hello/c=3.14"), + * Partition( + * values = Row(2, "world", 6.28), + * path = "hdfs://<host>:<port>/path/to/partition/a=2/b=world/c=6.28"))) + * }}} + */ + private[parquet] def parsePartitions( + paths: Seq[Path], + defaultPartitionName: String): PartitionSpec = { + val partitionValues = resolvePartitions(paths.map(parsePartition(_, defaultPartitionName))) + val fields = { + val (PartitionValues(columnNames, literals)) = partitionValues.head + columnNames.zip(literals).map { case (name, Literal(_, dataType)) => + StructField(name, dataType, nullable = true) + } + } + + val partitions = partitionValues.zip(paths).map { + case (PartitionValues(_, literals), path) => + Partition(Row(literals.map(_.value): _*), path.toString) + } + + PartitionSpec(StructType(fields), partitions) + } + + /** + * Parses a single partition, returns column names and values of each partition column. For + * example, given: + * {{{ + * path = hdfs://<host>:<port>/path/to/partition/a=42/b=hello/c=3.14 + * }}} + * it returns: + * {{{ + * PartitionValues( + * Seq("a", "b", "c"), + * Seq( + * Literal(42, IntegerType), + * Literal("hello", StringType), + * Literal(3.14, FloatType))) + * }}} + */ + private[parquet] def parsePartition( + path: Path, + defaultPartitionName: String): PartitionValues = { + val columns = ArrayBuffer.empty[(String, Literal)] + // Old Hadoop versions don't have `Path.isRoot` + var finished = path.getParent == null + var chopped = path + + while (!finished) { + val maybeColumn = parsePartitionColumn(chopped.getName, defaultPartitionName) + maybeColumn.foreach(columns += _) + chopped = chopped.getParent + finished = maybeColumn.isEmpty || chopped.getParent == null + } + + val (columnNames, values) = columns.reverse.unzip + PartitionValues(columnNames, values) + } + + private def parsePartitionColumn( + columnSpec: String, + defaultPartitionName: String): Option[(String, Literal)] = { + val equalSignIndex = columnSpec.indexOf('=') + if (equalSignIndex == -1) { + None + } else { + val columnName = columnSpec.take(equalSignIndex) + assert(columnName.nonEmpty, s"Empty partition column name in '$columnSpec'") + + val rawColumnValue = columnSpec.drop(equalSignIndex + 1) + assert(rawColumnValue.nonEmpty, s"Empty partition column value in '$columnSpec'") + + val literal = inferPartitionColumnValue(rawColumnValue, defaultPartitionName) + Some(columnName -> literal) + } + } + + /** + * Resolves possible type conflicts between partitions by up-casting "lower" types. 
The up- + * casting order is: + * {{{ + * NullType -> + * IntegerType -> LongType -> + * FloatType -> DoubleType -> DecimalType.Unlimited -> + * StringType + * }}} + */ + private[parquet] def resolvePartitions(values: Seq[PartitionValues]): Seq[PartitionValues] = { + val distinctColNamesOfPartitions = values.map(_.columnNames).distinct + val columnCount = values.head.columnNames.size + + // Column names of all partitions must match + assert(distinctColNamesOfPartitions.size == 1, { + val list = distinctColNamesOfPartitions.mkString("\t", "\n", "") + s"Conflicting partition column names detected:\n$list" + }) + + // Resolves possible type conflicts for each column + val resolvedValues = (0 until columnCount).map { i => + resolveTypeConflicts(values.map(_.literals(i))) + } + + // Fills resolved literals back to each partition + values.zipWithIndex.map { case (d, index) => + d.copy(literals = resolvedValues.map(_(index))) + } + } + + /** + * Converts a string to a `Literal` with automatic type inference. Currently only supports + * [[IntegerType]], [[LongType]], [[FloatType]], [[DoubleType]], [[DecimalType.Unlimited]], and + * [[StringType]]. + */ + private[parquet] def inferPartitionColumnValue( + raw: String, + defaultPartitionName: String): Literal = { + // First tries integral types + Try(Literal(Integer.parseInt(raw), IntegerType)) + .orElse(Try(Literal(JLong.parseLong(raw), LongType))) + // Then falls back to fractional types + .orElse(Try(Literal(JFloat.parseFloat(raw), FloatType))) + .orElse(Try(Literal(JDouble.parseDouble(raw), DoubleType))) + .orElse(Try(Literal(new JBigDecimal(raw), DecimalType.Unlimited))) + // Then falls back to string + .getOrElse { + if (raw == defaultPartitionName) Literal(null, NullType) else Literal(raw, StringType) + } + } + + private val upCastingOrder: Seq[DataType] = + Seq(NullType, IntegerType, LongType, FloatType, DoubleType, DecimalType.Unlimited, StringType) + + /** + * Given a collection of [[Literal]]s, resolves possible type conflicts by up-casting "lower" + * types. 
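+   * For example, `Seq(Literal(1, IntegerType), Literal(2.5, DoubleType))` is resolved to `Seq(Literal(1.0, DoubleType), Literal(2.5, DoubleType))`.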
+ */ + private def resolveTypeConflicts(literals: Seq[Literal]): Seq[Literal] = { + val desiredType = { + val topType = literals.map(_.dataType).maxBy(upCastingOrder.indexOf(_)) + // Falls back to string if all values of this column are null or empty string + if (topType == NullType) StringType else topType + } + + literals.map { case l @ Literal(_, dataType) => + Literal(Cast(l, desiredType).eval(), desiredType) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index 386ff2452f1a36cd9c2f67c09f22b5621dc3fa4a..d23ffb8b7a9602cb1fbc29921baee9859a7be429 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -18,12 +18,12 @@ package org.apache.spark.sql.sources import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, Strategy} import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, AttributeSet, Expression, NamedExpression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, InsertIntoTable => LogicalInsertIntoTable} -import org.apache.spark.sql.execution +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.{Row, Strategy, execution} /** * A Strategy for planning scans over data sources defined using the sources API. @@ -54,7 +54,7 @@ private[sql] object DataSourceStrategy extends Strategy { case l @ LogicalRelation(t: TableScan) => execution.PhysicalRDD(l.output, t.buildScan()) :: Nil - case i @ LogicalInsertIntoTable( + case i @ logical.InsertIntoTable( l @ LogicalRelation(t: InsertableRelation), partition, query, overwrite) => if (partition.nonEmpty) { sys.error(s"Insert into a partition is not allowed because $l is not partitioned.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 2ef740b3be0bdfd423a5f541aec19cfc31c642c0..9c37e0169ff852c20fadd09b27334064386465db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -241,20 +241,16 @@ object ResolvedDataSource { val relation = userSpecifiedSchema match { case Some(schema: StructType) => { clazz.newInstance match { - case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => - dataSource - .asInstanceOf[org.apache.spark.sql.sources.SchemaRelationProvider] - .createRelation(sqlContext, new CaseInsensitiveMap(options), schema) + case dataSource: SchemaRelationProvider => + dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema) case dataSource: org.apache.spark.sql.sources.RelationProvider => sys.error(s"${clazz.getCanonicalName} does not allow user-specified schemas.") } } case None => { clazz.newInstance match { - case dataSource: org.apache.spark.sql.sources.RelationProvider => - dataSource - .asInstanceOf[org.apache.spark.sql.sources.RelationProvider] - .createRelation(sqlContext, new CaseInsensitiveMap(options)) + case dataSource: RelationProvider => + dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options)) case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider => sys.error(s"A 
schema needs to be specified when using ${clazz.getCanonicalName}.") } @@ -279,10 +275,8 @@ object ResolvedDataSource { } val relation = clazz.newInstance match { - case dataSource: org.apache.spark.sql.sources.CreateableRelationProvider => - dataSource - .asInstanceOf[org.apache.spark.sql.sources.CreateableRelationProvider] - .createRelation(sqlContext, options, data) + case dataSource: CreatableRelationProvider => + dataSource.createRelation(sqlContext, options, data) case _ => sys.error(s"${clazz.getCanonicalName} does not allow create table as select.") } @@ -366,7 +360,7 @@ private [sql] case class CreateTempTableUsingAsSelect( /** * Builds a map in which keys are case insensitive */ -protected class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] +protected class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String] with Serializable { val baseMap = map.map(kv => kv.copy(_1 = kv._1.toLowerCase)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index ad0a35b91ebc26709212bd6008defa659acad29e..40fc1f2aa272477f1634e3e67a03cfbb0b73177d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -78,7 +78,7 @@ trait SchemaRelationProvider { } @DeveloperApi -trait CreateableRelationProvider { +trait CreatableRelationProvider { def createRelation( sqlContext: SQLContext, parameters: Map[String, String], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index ff91a0eb420493c1293b7710d627e58fe11c718a..f8117c21773ae1e361c379f5fdac4757d5b6f426 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -22,8 +22,10 @@ import parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, Predicate, Row} -import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, DataFrame, QueryTest, SQLConf} /** @@ -54,9 +56,17 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { .select(output.map(e => Column(e)): _*) .where(Column(predicate)) - val maybeAnalyzedPredicate = query.queryExecution.executedPlan.collect { - case plan: ParquetTableScan => plan.columnPruningPred - }.flatten.reduceOption(_ && _) + val maybeAnalyzedPredicate = { + val forParquetTableScan = query.queryExecution.executedPlan.collect { + case plan: ParquetTableScan => plan.columnPruningPred + }.flatten.reduceOption(_ && _) + + val forParquetDataSource = query.queryExecution.optimizedPlan.collect { + case PhysicalOperation(_, filters, LogicalRelation(_: ParquetRelation2)) => filters + }.flatten.reduceOption(_ && _) + + forParquetTableScan.orElse(forParquetDataSource) + } assert(maybeAnalyzedPredicate.isDefined) maybeAnalyzedPredicate.foreach { pred => @@ -84,213 +94,228 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd) } - 
test("filter pushdown - boolean") { - withParquetRDD((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], Seq(Row(true), Row(false))) - - checkFilterPredicate('_1 === true, classOf[Eq [_]], true) - checkFilterPredicate('_1 !== true, classOf[NotEq[_]], false) + private def checkBinaryFilterPredicate + (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) + (implicit rdd: DataFrame): Unit = { + def checkBinaryAnswer(rdd: DataFrame, expected: Seq[Row]) = { + assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).toSeq.sorted) { + rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted + } } + + checkFilterPredicate(rdd, predicate, filterClass, checkBinaryAnswer _, expected) } - test("filter pushdown - short") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit rdd => - checkFilterPredicate(Cast('_1, IntegerType) === 1, classOf[Eq [_]], 1) - checkFilterPredicate(Cast('_1, IntegerType) !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - - checkFilterPredicate(Cast('_1, IntegerType) < 2, classOf[Lt [_]], 1) - checkFilterPredicate(Cast('_1, IntegerType) > 3, classOf[Gt [_]], 4) - checkFilterPredicate(Cast('_1, IntegerType) <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate(Cast('_1, IntegerType) >= 4, classOf[GtEq[_]], 4) - - checkFilterPredicate(Literal(1) === Cast('_1, IntegerType), classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > Cast('_1, IntegerType), classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < Cast('_1, IntegerType), classOf[Gt [_]], 4) - checkFilterPredicate(Literal(1) >= Cast('_1, IntegerType), classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= Cast('_1, IntegerType), classOf[GtEq[_]], 4) - - checkFilterPredicate(!(Cast('_1, IntegerType) < 4), classOf[GtEq[_]], 4) - checkFilterPredicate(Cast('_1, IntegerType) > 2 && Cast('_1, IntegerType) < 4, - classOf[Operators.And], 3) - checkFilterPredicate(Cast('_1, IntegerType) < 2 || Cast('_1, IntegerType) > 3, - classOf[Operators.Or], Seq(Row(1), Row(4))) - } + private def checkBinaryFilterPredicate + (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Array[Byte]) + (implicit rdd: DataFrame): Unit = { + checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd) } - test("filter pushdown - integer") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + def run(prefix: String): Unit = { + test(s"$prefix: filter pushdown - boolean") { + withParquetRDD((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], Seq(Row(true), Row(false))) + + checkFilterPredicate('_1 === true, classOf[Eq[_]], true) + checkFilterPredicate('_1 !== true, classOf[NotEq[_]], false) + } + } + + test(s"$prefix: filter pushdown - short") { + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit rdd => + checkFilterPredicate(Cast('_1, IntegerType) === 1, classOf[Eq[_]], 1) + checkFilterPredicate( + Cast('_1, IntegerType) !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate(Cast('_1, IntegerType) < 2, classOf[Lt[_]], 1) + 
checkFilterPredicate(Cast('_1, IntegerType) > 3, classOf[Gt[_]], 4) + checkFilterPredicate(Cast('_1, IntegerType) <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate(Cast('_1, IntegerType) >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === Cast('_1, IntegerType), classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > Cast('_1, IntegerType), classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < Cast('_1, IntegerType), classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= Cast('_1, IntegerType), classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= Cast('_1, IntegerType), classOf[GtEq[_]], 4) + + checkFilterPredicate(!(Cast('_1, IntegerType) < 4), classOf[GtEq[_]], 4) + checkFilterPredicate( + Cast('_1, IntegerType) > 2 && Cast('_1, IntegerType) < 4, classOf[Operators.And], 3) + checkFilterPredicate( + Cast('_1, IntegerType) < 2 || Cast('_1, IntegerType) > 3, + classOf[Operators.Or], + Seq(Row(1), Row(4))) + } + } + + test(s"$prefix: filter pushdown - integer") { + withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + } } - } - test("filter pushdown - long") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + test(s"$prefix: filter pushdown - long") { + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, 
classOf[Eq[_]], 1) - checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + } } - } - test("filter pushdown - float") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + test(s"$prefix: filter pushdown - float") { + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 
4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + } } - } - test("filter pushdown - double") { - withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) + test(s"$prefix: filter pushdown - double") { + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) - checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) - checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) + checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) - checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) + checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) + } } - } - test("filter pushdown - string") { - withParquetRDD((1 to 4).map(i => Tuple1(i.toString))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate( - '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString))) - - checkFilterPredicate('_1 === "1", classOf[Eq[_]], "1") - checkFilterPredicate('_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) - - checkFilterPredicate('_1 < "2", classOf[Lt[_]], "1") - checkFilterPredicate('_1 > "3", classOf[Gt[_]], "4") - checkFilterPredicate('_1 <= "1", classOf[LtEq[_]], "1") - checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4") - - checkFilterPredicate(Literal("1") === '_1, classOf[Eq[_]], "1") - checkFilterPredicate(Literal("2") > '_1, classOf[Lt[_]], "1") - 
checkFilterPredicate(Literal("3") < '_1, classOf[Gt[_]], "4") - checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1") - checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") - - checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") - checkFilterPredicate('_1 > "2" && '_1 < "4", classOf[Operators.And], "3") - checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) + test(s"$prefix: filter pushdown - string") { + withParquetRDD((1 to 4).map(i => Tuple1(i.toString))) { implicit rdd => + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate( + '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString))) + + checkFilterPredicate('_1 === "1", classOf[Eq[_]], "1") + checkFilterPredicate( + '_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) + + checkFilterPredicate('_1 < "2", classOf[Lt[_]], "1") + checkFilterPredicate('_1 > "3", classOf[Gt[_]], "4") + checkFilterPredicate('_1 <= "1", classOf[LtEq[_]], "1") + checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4") + + checkFilterPredicate(Literal("1") === '_1, classOf[Eq[_]], "1") + checkFilterPredicate(Literal("2") > '_1, classOf[Lt[_]], "1") + checkFilterPredicate(Literal("3") < '_1, classOf[Gt[_]], "4") + checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1") + checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") + + checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") + checkFilterPredicate('_1 > "2" && '_1 < "4", classOf[Operators.And], "3") + checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) + } } - } - def checkBinaryFilterPredicate - (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) - (implicit rdd: DataFrame): Unit = { - def checkBinaryAnswer(rdd: DataFrame, expected: Seq[Row]) = { - assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).toSeq.sorted) { - rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted + test(s"$prefix: filter pushdown - binary") { + implicit class IntToBinary(int: Int) { + def b: Array[Byte] = int.toString.getBytes("UTF-8") } - } - checkFilterPredicate(rdd, predicate, filterClass, checkBinaryAnswer _, expected) - } + withParquetRDD((1 to 4).map(i => Tuple1(i.b))) { implicit rdd => + checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq[_]], 1.b) - def checkBinaryFilterPredicate - (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Array[Byte]) - (implicit rdd: DataFrame): Unit = { - checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd) - } + checkBinaryFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkBinaryFilterPredicate( + '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.b)).toSeq) - test("filter pushdown - binary") { - implicit class IntToBinary(int: Int) { - def b: Array[Byte] = int.toString.getBytes("UTF-8") - } + checkBinaryFilterPredicate( + '_1 !== 1.b, classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.b)).toSeq) + + checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt[_]], 1.b) + checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt[_]], 4.b) + checkBinaryFilterPredicate('_1 <= 1.b, classOf[LtEq[_]], 1.b) + checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b) - withParquetRDD((1 to 4).map(i => Tuple1(i.b))) { implicit rdd => - checkBinaryFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkBinaryFilterPredicate( - '_1.isNotNull, classOf[NotEq[_]], (1 to 
4).map(i => Row.apply(i.b)).toSeq) - - checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq [_]], 1.b) - checkBinaryFilterPredicate( - '_1 !== 1.b, classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.b)).toSeq) - - checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt[_]], 1.b) - checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt[_]], 4.b) - checkBinaryFilterPredicate('_1 <= 1.b, classOf[LtEq[_]], 1.b) - checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b) - - checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq[_]], 1.b) - checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt[_]], 1.b) - checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt[_]], 4.b) - checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) - checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) - - checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b) - checkBinaryFilterPredicate('_1 > 2.b && '_1 < 4.b, classOf[Operators.And], 3.b) - checkBinaryFilterPredicate( - '_1 < 2.b || '_1 > 3.b, classOf[Operators.Or], Seq(Row(1.b), Row(4.b))) + checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq[_]], 1.b) + checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt[_]], 1.b) + checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt[_]], 4.b) + checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) + checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) + + checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b) + checkBinaryFilterPredicate('_1 > 2.b && '_1 < 4.b, classOf[Operators.And], 3.b) + checkBinaryFilterPredicate( + '_1 < 2.b || '_1 > 3.b, classOf[Operators.Or], Seq(Row(1.b), Row(4.b))) + } } } + + withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") { + run("Parquet data source enabled") + } + + withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "false") { + run("Parquet data source disabled") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 0bc246c645602aa687a7dd0f4b68b5a0289535ad..c8ebbbc7d2eace14b53d21ab77d1f75ef6939813 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -73,218 +73,229 @@ class ParquetIOSuite extends QueryTest with ParquetTest { withParquetRDD(data)(r => checkAnswer(r, data.map(Row.fromTuple))) } - test("basic data types (without binary)") { - val data = (1 to 4).map { i => - (i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) + def run(prefix: String): Unit = { + test(s"$prefix: basic data types (without binary)") { + val data = (1 to 4).map { i => + (i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) + } + checkParquetFile(data) } - checkParquetFile(data) - } - test("raw binary") { - val data = (1 to 4).map(i => Tuple1(Array.fill(3)(i.toByte))) - withParquetRDD(data) { rdd => - assertResult(data.map(_._1.mkString(",")).sorted) { - rdd.collect().map(_.getAs[Array[Byte]](0).mkString(",")).sorted + test(s"$prefix: raw binary") { + val data = (1 to 4).map(i => Tuple1(Array.fill(3)(i.toByte))) + withParquetRDD(data) { rdd => + assertResult(data.map(_._1.mkString(",")).sorted) { + rdd.collect().map(_.getAs[Array[Byte]](0).mkString(",")).sorted + } } } - } - - test("string") { - val data = (1 to 4).map(i => Tuple1(i.toString)) - // Property spark.sql.parquet.binaryAsString shouldn't affect Parquet files written by Spark SQL - // as we store Spark SQL schema in the 
extra metadata. - withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "false")(checkParquetFile(data)) - withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "true")(checkParquetFile(data)) - } - test("fixed-length decimals") { - import org.apache.spark.sql.test.TestSQLContext.implicits._ - - def makeDecimalRDD(decimal: DecimalType): DataFrame = - sparkContext - .parallelize(0 to 1000) - .map(i => Tuple1(i / 100.0)) - .select($"_1" cast decimal as "abcd") + test(s"$prefix: string") { + val data = (1 to 4).map(i => Tuple1(i.toString)) + // Property spark.sql.parquet.binaryAsString shouldn't affect Parquet files written by Spark SQL + // as we store Spark SQL schema in the extra metadata. + withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "false")(checkParquetFile(data)) + withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "true")(checkParquetFile(data)) + } - for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { - withTempPath { dir => - val data = makeDecimalRDD(DecimalType(precision, scale)) - data.saveAsParquetFile(dir.getCanonicalPath) - checkAnswer(parquetFile(dir.getCanonicalPath), data.collect().toSeq) + test(s"$prefix: fixed-length decimals") { + import org.apache.spark.sql.test.TestSQLContext.implicits._ + + def makeDecimalRDD(decimal: DecimalType): DataFrame = + sparkContext + .parallelize(0 to 1000) + .map(i => Tuple1(i / 100.0)) + // Parquet doesn't allow column names with spaces, have to add an alias here + .select($"_1" cast decimal as "dec") + + for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { + withTempPath { dir => + val data = makeDecimalRDD(DecimalType(precision, scale)) + data.saveAsParquetFile(dir.getCanonicalPath) + checkAnswer(parquetFile(dir.getCanonicalPath), data.collect().toSeq) + } } - } - // Decimals with precision above 18 are not yet supported - intercept[RuntimeException] { - withTempPath { dir => - makeDecimalRDD(DecimalType(19, 10)).saveAsParquetFile(dir.getCanonicalPath) - parquetFile(dir.getCanonicalPath).collect() + // Decimals with precision above 18 are not yet supported + intercept[RuntimeException] { + withTempPath { dir => + makeDecimalRDD(DecimalType(19, 10)).saveAsParquetFile(dir.getCanonicalPath) + parquetFile(dir.getCanonicalPath).collect() + } } - } - // Unlimited-length decimals are not yet supported - intercept[RuntimeException] { - withTempPath { dir => - makeDecimalRDD(DecimalType.Unlimited).saveAsParquetFile(dir.getCanonicalPath) - parquetFile(dir.getCanonicalPath).collect() + // Unlimited-length decimals are not yet supported + intercept[RuntimeException] { + withTempPath { dir => + makeDecimalRDD(DecimalType.Unlimited).saveAsParquetFile(dir.getCanonicalPath) + parquetFile(dir.getCanonicalPath).collect() + } } } - } - test("map") { - val data = (1 to 4).map(i => Tuple1(Map(i -> s"val_$i"))) - checkParquetFile(data) - } + test(s"$prefix: map") { + val data = (1 to 4).map(i => Tuple1(Map(i -> s"val_$i"))) + checkParquetFile(data) + } - test("array") { - val data = (1 to 4).map(i => Tuple1(Seq(i, i + 1))) - checkParquetFile(data) - } + test(s"$prefix: array") { + val data = (1 to 4).map(i => Tuple1(Seq(i, i + 1))) + checkParquetFile(data) + } - test("struct") { - val data = (1 to 4).map(i => Tuple1((i, s"val_$i"))) - withParquetRDD(data) { rdd => - // Structs are converted to `Row`s - checkAnswer(rdd, data.map { case Tuple1(struct) => - Row(Row(struct.productIterator.toSeq: _*)) - }) + test(s"$prefix: struct") { + val data = (1 to 4).map(i => Tuple1((i, s"val_$i"))) + withParquetRDD(data) { rdd => + // 
Structs are converted to `Row`s + checkAnswer(rdd, data.map { case Tuple1(struct) => + Row(Row(struct.productIterator.toSeq: _*)) + }) + } } - } - test("nested struct with array of array as field") { - val data = (1 to 4).map(i => Tuple1((i, Seq(Seq(s"val_$i"))))) - withParquetRDD(data) { rdd => - // Structs are converted to `Row`s - checkAnswer(rdd, data.map { case Tuple1(struct) => - Row(Row(struct.productIterator.toSeq: _*)) - }) + test(s"$prefix: nested struct with array of array as field") { + val data = (1 to 4).map(i => Tuple1((i, Seq(Seq(s"val_$i"))))) + withParquetRDD(data) { rdd => + // Structs are converted to `Row`s + checkAnswer(rdd, data.map { case Tuple1(struct) => + Row(Row(struct.productIterator.toSeq: _*)) + }) + } } - } - test("nested map with struct as value type") { - val data = (1 to 4).map(i => Tuple1(Map(i -> (i, s"val_$i")))) - withParquetRDD(data) { rdd => - checkAnswer(rdd, data.map { case Tuple1(m) => - Row(m.mapValues(struct => Row(struct.productIterator.toSeq: _*))) - }) + test(s"$prefix: nested map with struct as value type") { + val data = (1 to 4).map(i => Tuple1(Map(i -> (i, s"val_$i")))) + withParquetRDD(data) { rdd => + checkAnswer(rdd, data.map { case Tuple1(m) => + Row(m.mapValues(struct => Row(struct.productIterator.toSeq: _*))) + }) + } } - } - test("nulls") { - val allNulls = ( - null.asInstanceOf[java.lang.Boolean], - null.asInstanceOf[Integer], - null.asInstanceOf[java.lang.Long], - null.asInstanceOf[java.lang.Float], - null.asInstanceOf[java.lang.Double]) - - withParquetRDD(allNulls :: Nil) { rdd => - val rows = rdd.collect() - assert(rows.size === 1) - assert(rows.head === Row(Seq.fill(5)(null): _*)) + test(s"$prefix: nulls") { + val allNulls = ( + null.asInstanceOf[java.lang.Boolean], + null.asInstanceOf[Integer], + null.asInstanceOf[java.lang.Long], + null.asInstanceOf[java.lang.Float], + null.asInstanceOf[java.lang.Double]) + + withParquetRDD(allNulls :: Nil) { rdd => + val rows = rdd.collect() + assert(rows.size === 1) + assert(rows.head === Row(Seq.fill(5)(null): _*)) + } } - } - test("nones") { - val allNones = ( - None.asInstanceOf[Option[Int]], - None.asInstanceOf[Option[Long]], - None.asInstanceOf[Option[String]]) + test(s"$prefix: nones") { + val allNones = ( + None.asInstanceOf[Option[Int]], + None.asInstanceOf[Option[Long]], + None.asInstanceOf[Option[String]]) - withParquetRDD(allNones :: Nil) { rdd => - val rows = rdd.collect() - assert(rows.size === 1) - assert(rows.head === Row(Seq.fill(3)(null): _*)) + withParquetRDD(allNones :: Nil) { rdd => + val rows = rdd.collect() + assert(rows.size === 1) + assert(rows.head === Row(Seq.fill(3)(null): _*)) + } } - } - test("compression codec") { - def compressionCodecFor(path: String) = { - val codecs = ParquetTypesConverter - .readMetaData(new Path(path), Some(configuration)) - .getBlocks - .flatMap(_.getColumns) - .map(_.getCodec.name()) - .distinct - - assert(codecs.size === 1) - codecs.head - } + test(s"$prefix: compression codec") { + def compressionCodecFor(path: String) = { + val codecs = ParquetTypesConverter + .readMetaData(new Path(path), Some(configuration)) + .getBlocks + .flatMap(_.getColumns) + .map(_.getCodec.name()) + .distinct + + assert(codecs.size === 1) + codecs.head + } - val data = (0 until 10).map(i => (i, i.toString)) + val data = (0 until 10).map(i => (i, i.toString)) - def checkCompressionCodec(codec: CompressionCodecName): Unit = { - withSQLConf(SQLConf.PARQUET_COMPRESSION -> codec.name()) { - withParquetFile(data) { path => - 
assertResult(conf.parquetCompressionCodec.toUpperCase) { - compressionCodecFor(path) + def checkCompressionCodec(codec: CompressionCodecName): Unit = { + withSQLConf(SQLConf.PARQUET_COMPRESSION -> codec.name()) { + withParquetFile(data) { path => + assertResult(conf.parquetCompressionCodec.toUpperCase) { + compressionCodecFor(path) + } } } } - } - // Checks default compression codec - checkCompressionCodec(CompressionCodecName.fromConf(conf.parquetCompressionCodec)) + // Checks default compression codec + checkCompressionCodec(CompressionCodecName.fromConf(conf.parquetCompressionCodec)) - checkCompressionCodec(CompressionCodecName.UNCOMPRESSED) - checkCompressionCodec(CompressionCodecName.GZIP) - checkCompressionCodec(CompressionCodecName.SNAPPY) - } + checkCompressionCodec(CompressionCodecName.UNCOMPRESSED) + checkCompressionCodec(CompressionCodecName.GZIP) + checkCompressionCodec(CompressionCodecName.SNAPPY) + } - test("read raw Parquet file") { - def makeRawParquetFile(path: Path): Unit = { - val schema = MessageTypeParser.parseMessageType( - """ - |message root { - | required boolean _1; - | required int32 _2; - | required int64 _3; - | required float _4; - | required double _5; - |} - """.stripMargin) - - val writeSupport = new TestGroupWriteSupport(schema) - val writer = new ParquetWriter[Group](path, writeSupport) - - (0 until 10).foreach { i => - val record = new SimpleGroup(schema) - record.add(0, i % 2 == 0) - record.add(1, i) - record.add(2, i.toLong) - record.add(3, i.toFloat) - record.add(4, i.toDouble) - writer.write(record) - } + test(s"$prefix: read raw Parquet file") { + def makeRawParquetFile(path: Path): Unit = { + val schema = MessageTypeParser.parseMessageType( + """ + |message root { + | required boolean _1; + | required int32 _2; + | required int64 _3; + | required float _4; + | required double _5; + |} + """.stripMargin) + + val writeSupport = new TestGroupWriteSupport(schema) + val writer = new ParquetWriter[Group](path, writeSupport) + + (0 until 10).foreach { i => + val record = new SimpleGroup(schema) + record.add(0, i % 2 == 0) + record.add(1, i) + record.add(2, i.toLong) + record.add(3, i.toFloat) + record.add(4, i.toDouble) + writer.write(record) + } - writer.close() - } + writer.close() + } - withTempDir { dir => - val path = new Path(dir.toURI.toString, "part-r-0.parquet") - makeRawParquetFile(path) - checkAnswer(parquetFile(path.toString), (0 until 10).map { i => - Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) - }) + withTempDir { dir => + val path = new Path(dir.toURI.toString, "part-r-0.parquet") + makeRawParquetFile(path) + checkAnswer(parquetFile(path.toString), (0 until 10).map { i => + Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) + }) + } } - } - test("write metadata") { - withTempPath { file => - val path = new Path(file.toURI.toString) - val fs = FileSystem.getLocal(configuration) - val attributes = ScalaReflection.attributesFor[(Int, String)] - ParquetTypesConverter.writeMetaData(attributes, path, configuration) + test(s"$prefix: write metadata") { + withTempPath { file => + val path = new Path(file.toURI.toString) + val fs = FileSystem.getLocal(configuration) + val attributes = ScalaReflection.attributesFor[(Int, String)] + ParquetTypesConverter.writeMetaData(attributes, path, configuration) - assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_COMMON_METADATA_FILE))) - assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE))) + assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_COMMON_METADATA_FILE))) + 
assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE))) - val metaData = ParquetTypesConverter.readMetaData(path, Some(configuration)) - val actualSchema = metaData.getFileMetaData.getSchema - val expectedSchema = ParquetTypesConverter.convertFromAttributes(attributes) + val metaData = ParquetTypesConverter.readMetaData(path, Some(configuration)) + val actualSchema = metaData.getFileMetaData.getSchema + val expectedSchema = ParquetTypesConverter.convertFromAttributes(attributes) - actualSchema.checkContains(expectedSchema) - expectedSchema.checkContains(actualSchema) + actualSchema.checkContains(expectedSchema) + expectedSchema.checkContains(actualSchema) + } } } + + withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") { + run("Parquet data source enabled") + } + + withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "false") { + run("Parquet data source disabled") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..ae606d11a8f6809dae30f5c83ccad5d3e9a099d7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.parquet + +import scala.collection.mutable.ArrayBuffer + +import org.apache.hadoop.fs.Path +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.parquet.ParquetRelation2._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{Row, SQLContext} + +class ParquetPartitionDiscoverySuite extends FunSuite with ParquetTest { + override val sqlContext: SQLContext = TestSQLContext + + val defaultPartitionName = "__NULL__" + + test("column type inference") { + def check(raw: String, literal: Literal): Unit = { + assert(inferPartitionColumnValue(raw, defaultPartitionName) === literal) + } + + check("10", Literal(10, IntegerType)) + check("1000000000000000", Literal(1000000000000000L, LongType)) + check("1.5", Literal(1.5, FloatType)) + check("hello", Literal("hello", StringType)) + check(defaultPartitionName, Literal(null, NullType)) + } + + test("parse partition") { + def check(path: String, expected: PartitionValues): Unit = { + assert(expected === parsePartition(new Path(path), defaultPartitionName)) + } + + def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = { + val message = intercept[T] { + parsePartition(new Path(path), defaultPartitionName) + }.getMessage + + assert(message.contains(expected)) + } + + check( + "file:///", + PartitionValues( + ArrayBuffer.empty[String], + ArrayBuffer.empty[Literal])) + + check( + "file://path/a=10", + PartitionValues( + ArrayBuffer("a"), + ArrayBuffer(Literal(10, IntegerType)))) + + check( + "file://path/a=10/b=hello/c=1.5", + PartitionValues( + ArrayBuffer("a", "b", "c"), + ArrayBuffer( + Literal(10, IntegerType), + Literal("hello", StringType), + Literal(1.5, FloatType)))) + + check( + "file://path/a=10/b_hello/c=1.5", + PartitionValues( + ArrayBuffer("c"), + ArrayBuffer(Literal(1.5, FloatType)))) + + checkThrows[AssertionError]("file://path/=10", "Empty partition column name") + checkThrows[AssertionError]("file://path/a=", "Empty partition column value") + } + + test("parse partitions") { + def check(paths: Seq[String], spec: PartitionSpec): Unit = { + assert(parsePartitions(paths.map(new Path(_)), defaultPartitionName) === spec) + } + + check(Seq( + "hdfs://host:9000/path/a=10/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", IntegerType), + StructField("b", StringType))), + Seq(Partition(Row(10, "hello"), "hdfs://host:9000/path/a=10/b=hello")))) + + check(Seq( + "hdfs://host:9000/path/a=10/b=20", + "hdfs://host:9000/path/a=10.5/b=hello"), + PartitionSpec( + StructType(Seq( + StructField("a", FloatType), + StructField("b", StringType))), + Seq( + Partition(Row(10, "20"), "hdfs://host:9000/path/a=10/b=20"), + Partition(Row(10.5, "hello"), "hdfs://host:9000/path/a=10.5/b=hello")))) + + check(Seq( + s"hdfs://host:9000/path/a=10/b=$defaultPartitionName", + s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName"), + PartitionSpec( + StructType(Seq( + StructField("a", FloatType), + StructField("b", StringType))), + Seq( + Partition(Row(10, null), s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"), + Partition(Row(10.5, null), s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName")))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 5ec7a156d935368fed3f7c8fda8aaf3c9eea279a..48c7598343e55159b3e5179d733b5eecf567adc3 100644 
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.parquet -import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.{QueryTest, SQLConf} /** * A test suite that tests various Parquet queries. @@ -28,82 +28,93 @@ import org.apache.spark.sql.test.TestSQLContext._ class ParquetQuerySuite extends QueryTest with ParquetTest { val sqlContext = TestSQLContext - test("simple projection") { - withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { - checkAnswer(sql("SELECT _1 FROM t"), (0 until 10).map(Row.apply(_))) + def run(prefix: String): Unit = { + test(s"$prefix: simple projection") { + withParquetTable((0 until 10).map(i => (i, i.toString)), "t") { + checkAnswer(sql("SELECT _1 FROM t"), (0 until 10).map(Row.apply(_))) + } } - } - test("appending") { - val data = (0 until 10).map(i => (i, i.toString)) - withParquetTable(data, "t") { - sql("INSERT INTO TABLE t SELECT * FROM t") - checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) + // TODO Re-enable this after data source insertion API is merged + test(s"$prefix: appending") { + val data = (0 until 10).map(i => (i, i.toString)) + withParquetTable(data, "t") { + sql("INSERT INTO TABLE t SELECT * FROM t") + checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) + } } - } - // This test case will trigger the NPE mentioned in - // https://issues.apache.org/jira/browse/PARQUET-151. - ignore("overwriting") { - val data = (0 until 10).map(i => (i, i.toString)) - withParquetTable(data, "t") { - sql("INSERT OVERWRITE TABLE t SELECT * FROM t") - checkAnswer(table("t"), data.map(Row.fromTuple)) + // This test case will trigger the NPE mentioned in + // https://issues.apache.org/jira/browse/PARQUET-151. 
+ ignore(s"$prefix: overwriting") { + val data = (0 until 10).map(i => (i, i.toString)) + withParquetTable(data, "t") { + sql("INSERT OVERWRITE TABLE t SELECT * FROM t") + checkAnswer(table("t"), data.map(Row.fromTuple)) + } } - } - test("self-join") { - // 4 rows, cells of column 1 of row 2 and row 4 are null - val data = (1 to 4).map { i => - val maybeInt = if (i % 2 == 0) None else Some(i) - (maybeInt, i.toString) - } + test(s"$prefix: self-join") { + // 4 rows, cells of column 1 of row 2 and row 4 are null + val data = (1 to 4).map { i => + val maybeInt = if (i % 2 == 0) None else Some(i) + (maybeInt, i.toString) + } - withParquetTable(data, "t") { - val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1") - val queryOutput = selfJoin.queryExecution.analyzed.output + withParquetTable(data, "t") { + val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1") + val queryOutput = selfJoin.queryExecution.analyzed.output - assertResult(4, s"Field count mismatches")(queryOutput.size) - assertResult(2, s"Duplicated expression ID in query plan:\n $selfJoin") { - queryOutput.filter(_.name == "_1").map(_.exprId).size - } + assertResult(4, s"Field count mismatches")(queryOutput.size) + assertResult(2, s"Duplicated expression ID in query plan:\n $selfJoin") { + queryOutput.filter(_.name == "_1").map(_.exprId).size + } - checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3"))) + checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3"))) + } } - } - test("nested data - struct with array field") { - val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) - withParquetTable(data, "t") { - checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map { - case Tuple1((_, Seq(string))) => Row(string) - }) + test(s"$prefix: nested data - struct with array field") { + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) + withParquetTable(data, "t") { + checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map { + case Tuple1((_, Seq(string))) => Row(string) + }) + } } - } - test("nested data - array of struct") { - val data = (1 to 10).map(i => Tuple1(Seq(i -> s"val_$i"))) - withParquetTable(data, "t") { - checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map { - case Tuple1(Seq((_, string))) => Row(string) - }) + test(s"$prefix: nested data - array of struct") { + val data = (1 to 10).map(i => Tuple1(Seq(i -> s"val_$i"))) + withParquetTable(data, "t") { + checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map { + case Tuple1(Seq((_, string))) => Row(string) + }) + } } - } - test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { - withParquetTable((1 to 10).map(Tuple1.apply), "t") { - checkAnswer(sql(s"SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) + test(s"$prefix: SPARK-1913 regression: columns only referenced by pushed down filters should remain") { + withParquetTable((1 to 10).map(Tuple1.apply), "t") { + checkAnswer(sql(s"SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) + } } - } - test("SPARK-5309 strings stored using dictionary compression in parquet") { - withParquetTable((0 until 1000).map(i => ("same", "run_" + i /100, 1)), "t") { + test(s"$prefix: SPARK-5309 strings stored using dictionary compression in parquet") { + withParquetTable((0 until 1000).map(i => ("same", "run_" + i /100, 1)), "t") { - checkAnswer(sql(s"SELECT _1, _2, SUM(_3) FROM t GROUP BY _1, _2"), - (0 until 10).map(i => Row("same", "run_" + i, 100))) + checkAnswer(sql(s"SELECT _1, _2, SUM(_3) FROM t GROUP BY _1, _2"), + (0 until 
10).map(i => Row("same", "run_" + i, 100))) - checkAnswer(sql(s"SELECT _1, _2, SUM(_3) FROM t WHERE _2 = 'run_5' GROUP BY _1, _2"), - List(Row("same", "run_5", 100))) + checkAnswer(sql(s"SELECT _1, _2, SUM(_3) FROM t WHERE _2 = 'run_5' GROUP BY _1, _2"), + List(Row("same", "run_5", 100))) + } } } + + withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") { + run("Parquet data source enabled") + } + + withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "false") { + run("Parquet data source disabled") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index 5f7f31d395cf7cf048105e39696567651dfd2635..2e6c2d5f9ab558967bec788745ca1f51877ff406 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -25,6 +25,7 @@ import parquet.schema.MessageTypeParser import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.types._ class ParquetSchemaSuite extends FunSuite with ParquetTest { val sqlContext = TestSQLContext @@ -192,4 +193,40 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest { assert(a.nullable === b.nullable) } } + + test("merge with metastore schema") { + // Field type conflict resolution + assertResult( + StructType(Seq( + StructField("lowerCase", StringType), + StructField("UPPERCase", DoubleType, nullable = false)))) { + + ParquetRelation2.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("lowercase", StringType), + StructField("uppercase", DoubleType, nullable = false))), + + StructType(Seq( + StructField("lowerCase", BinaryType), + StructField("UPPERCase", IntegerType, nullable = true)))) + } + + // Conflicting field count + assert(intercept[Throwable] { + ParquetRelation2.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("uppercase", DoubleType, nullable = false))), + + StructType(Seq( + StructField("lowerCase", BinaryType), + StructField("UPPERCase", IntegerType, nullable = true)))) + }.getMessage.contains("detected conflicting schemas")) + + // Conflicting field names + intercept[Throwable] { + ParquetRelation2.mergeMetastoreParquetSchema( + StructType(Seq(StructField("lower", StringType))), + StructType(Seq(StructField("lowerCase", BinaryType)))) + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 243310686d08a806a27d086afdc0e50c84e1168e..c78369d12cf55228a53db2c8050b61fe210c9399 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources.{DDLParser, LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -175,10 +176,25 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with Nil } - // Since HiveQL is case insensitive for table names we make them all lowercase. 
- MetastoreRelation( + val relation = MetastoreRelation( databaseName, tblName, alias)( table.getTTable, partitions.map(part => part.getTPartition))(hive) + + if (hive.convertMetastoreParquet && + hive.conf.parquetUseDataSourceApi && + relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet")) { + val metastoreSchema = StructType.fromAttributes(relation.output) + val paths = if (relation.hiveQlTable.isPartitioned) { + relation.hiveQlPartitions.map(p => p.getLocation) + } else { + Seq(relation.hiveQlTable.getDataLocation.toString) + } + + LogicalRelation(ParquetRelation2( + paths, Map(ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json))(hive)) + } else { + relation + } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 7857a0252ebb34a549b308dcfd91c361571cbce2..95abc363ae767bb6c7ec55449a03c4398647033e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -87,7 +87,8 @@ private[hive] trait HiveStrategies { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(projectList, predicates, relation: MetastoreRelation) if relation.tableDesc.getSerdeClassName.contains("Parquet") && - hiveContext.convertMetastoreParquet => + hiveContext.convertMetastoreParquet && + !hiveContext.conf.parquetUseDataSourceApi => // Filter out all predicates that only deal with partition keys val partitionsKeys = AttributeSet(relation.partitionKeys) @@ -136,8 +137,10 @@ private[hive] trait HiveStrategies { pruningCondition(inputData) } + val partitionLocations = partitions.map(_.getLocation) + hiveContext - .parquetFile(partitions.map(_.getLocation).mkString(",")) + .parquetFile(partitionLocations.head, partitionLocations.tail: _*) .addPartitioningAttributes(relation.partitionKeys) .lowerCase .where(unresolvedOtherPredicates) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala index 581f6663994922303082a8e32151b5b85566099f..eae69af5864aa0007cb7e0cf779d0c96d21c77b0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala @@ -28,53 +28,55 @@ class HiveParquetSuite extends QueryTest with ParquetTest { import sqlContext._ - test("Case insensitive attribute names") { - withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") { - val expected = (1 to 4).map(i => Row(i.toString)) - checkAnswer(sql("SELECT upper FROM cases"), expected) - checkAnswer(sql("SELECT LOWER FROM cases"), expected) + def run(prefix: String): Unit = { + test(s"$prefix: Case insensitive attribute names") { + withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") { + val expected = (1 to 4).map(i => Row(i.toString)) + checkAnswer(sql("SELECT upper FROM cases"), expected) + checkAnswer(sql("SELECT LOWER FROM cases"), expected) + } } - } - test("SELECT on Parquet table") { - val data = (1 to 4).map(i => (i, s"val_$i")) - withParquetTable(data, "t") { - checkAnswer(sql("SELECT * FROM t"), data.map(Row.fromTuple)) - } - } - - test("Simple column projection + filter on Parquet table") { - withParquetTable((1 to 4).map(i => (i % 2 == 0, i, s"val_$i")), "t") { - checkAnswer( - sql("SELECT `_1`, `_3` FROM t WHERE `_1` = true"), - 
Seq(Row(true, "val_2"), Row(true, "val_4"))) + test(s"$prefix: SELECT on Parquet table") { + val data = (1 to 4).map(i => (i, s"val_$i")) + withParquetTable(data, "t") { + checkAnswer(sql("SELECT * FROM t"), data.map(Row.fromTuple)) + } } - } - test("Converting Hive to Parquet Table via saveAsParquetFile") { - withTempPath { dir => - sql("SELECT * FROM src").saveAsParquetFile(dir.getCanonicalPath) - parquetFile(dir.getCanonicalPath).registerTempTable("p") - withTempTable("p") { + test(s"$prefix: Simple column projection + filter on Parquet table") { + withParquetTable((1 to 4).map(i => (i % 2 == 0, i, s"val_$i")), "t") { checkAnswer( - sql("SELECT * FROM src ORDER BY key"), - sql("SELECT * from p ORDER BY key").collect().toSeq) + sql("SELECT `_1`, `_3` FROM t WHERE `_1` = true"), + Seq(Row(true, "val_2"), Row(true, "val_4"))) } } - } - - test("INSERT OVERWRITE TABLE Parquet table") { - withParquetTable((1 to 4).map(i => (i, s"val_$i")), "t") { - withTempPath { file => - sql("SELECT * FROM t LIMIT 1").saveAsParquetFile(file.getCanonicalPath) - parquetFile(file.getCanonicalPath).registerTempTable("p") + test(s"$prefix: Converting Hive to Parquet Table via saveAsParquetFile") { + withTempPath { dir => + sql("SELECT * FROM src").saveAsParquetFile(dir.getCanonicalPath) + parquetFile(dir.getCanonicalPath).registerTempTable("p") withTempTable("p") { - // let's do three overwrites for good measure - sql("INSERT OVERWRITE TABLE p SELECT * FROM t") - sql("INSERT OVERWRITE TABLE p SELECT * FROM t") - sql("INSERT OVERWRITE TABLE p SELECT * FROM t") - checkAnswer(sql("SELECT * FROM p"), sql("SELECT * FROM t").collect().toSeq) + checkAnswer( + sql("SELECT * FROM src ORDER BY key"), + sql("SELECT * from p ORDER BY key").collect().toSeq) + } + } + } + + // TODO Re-enable this after data source insertion API is merged + ignore(s"$prefix: INSERT OVERWRITE TABLE Parquet table") { + withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") { + withTempPath { file => + sql("SELECT * FROM t LIMIT 1").saveAsParquetFile(file.getCanonicalPath) + parquetFile(file.getCanonicalPath).registerTempTable("p") + withTempTable("p") { + // let's do three overwrites for good measure + sql("INSERT OVERWRITE TABLE p SELECT * FROM t") + sql("INSERT OVERWRITE TABLE p SELECT * FROM t") + sql("INSERT OVERWRITE TABLE p SELECT * FROM t") + checkAnswer(sql("SELECT * FROM p"), sql("SELECT * FROM t").collect().toSeq) + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala index 30441bbbdf817333d78b0fcf4dc58576ac67ea94..a7479a5b9586451fa52f47735a4cb13ce0b12cd3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala @@ -23,7 +23,8 @@ import java.io.File import org.apache.spark.sql.catalyst.expressions.Row import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.{SQLConf, QueryTest} +import org.apache.spark.sql.execution.PhysicalRDD import org.apache.spark.sql.hive.execution.HiveTableScan import org.apache.spark.sql.hive.test.TestHive._ @@ -79,7 +80,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { STORED AS INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - location '${new File(partitionedTableDir, "p=1").getCanonicalPath}' + location '${new 
File(normalTableDir, "normal").getCanonicalPath}' """) (1 to 10).foreach { p => @@ -97,7 +98,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { setConf("spark.sql.hive.convertMetastoreParquet", "false") } - test("conversion is working") { + test(s"conversion is working") { assert( sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { case _: HiveTableScan => true @@ -105,6 +106,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { assert( sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { case _: ParquetTableScan => true + case _: PhysicalRDD => true }.nonEmpty) } } @@ -147,6 +149,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { */ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll { var partitionedTableDir: File = null + var normalTableDir: File = null var partitionedTableDirWithKey: File = null import org.apache.spark.sql.hive.test.TestHive.implicits._ @@ -156,6 +159,10 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll partitionedTableDir.delete() partitionedTableDir.mkdir() + normalTableDir = File.createTempFile("parquettests", "sparksql") + normalTableDir.delete() + normalTableDir.mkdir() + (1 to 10).foreach { p => val partDir = new File(partitionedTableDir, s"p=$p") sparkContext.makeRDD(1 to 10) @@ -163,6 +170,11 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll .saveAsParquetFile(partDir.getCanonicalPath) } + sparkContext + .makeRDD(1 to 10) + .map(i => ParquetData(i, s"part-1")) + .saveAsParquetFile(new File(normalTableDir, "normal").getCanonicalPath) + partitionedTableDirWithKey = File.createTempFile("parquettests", "sparksql") partitionedTableDirWithKey.delete() partitionedTableDirWithKey.mkdir() @@ -175,99 +187,107 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll } } - Seq("partitioned_parquet", "partitioned_parquet_with_key").foreach { table => - test(s"ordering of the partitioning columns $table") { - checkAnswer( - sql(s"SELECT p, stringField FROM $table WHERE p = 1"), - Seq.fill(10)(Row(1, "part-1")) - ) - - checkAnswer( - sql(s"SELECT stringField, p FROM $table WHERE p = 1"), - Seq.fill(10)(Row("part-1", 1)) - ) - } - - test(s"project the partitioning column $table") { - checkAnswer( - sql(s"SELECT p, count(*) FROM $table group by p"), - Row(1, 10) :: - Row(2, 10) :: - Row(3, 10) :: - Row(4, 10) :: - Row(5, 10) :: - Row(6, 10) :: - Row(7, 10) :: - Row(8, 10) :: - Row(9, 10) :: - Row(10, 10) :: Nil - ) - } - - test(s"project partitioning and non-partitioning columns $table") { - checkAnswer( - sql(s"SELECT stringField, p, count(intField) FROM $table GROUP BY p, stringField"), - Row("part-1", 1, 10) :: - Row("part-2", 2, 10) :: - Row("part-3", 3, 10) :: - Row("part-4", 4, 10) :: - Row("part-5", 5, 10) :: - Row("part-6", 6, 10) :: - Row("part-7", 7, 10) :: - Row("part-8", 8, 10) :: - Row("part-9", 9, 10) :: - Row("part-10", 10, 10) :: Nil - ) - } - - test(s"simple count $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table"), - Row(100)) + def run(prefix: String): Unit = { + Seq("partitioned_parquet", "partitioned_parquet_with_key").foreach { table => + test(s"$prefix: ordering of the partitioning columns $table") { + checkAnswer( + sql(s"SELECT p, stringField FROM $table WHERE p = 1"), + Seq.fill(10)(Row(1, "part-1")) + ) + + checkAnswer( + sql(s"SELECT stringField, p FROM $table WHERE p = 1"), + Seq.fill(10)(Row("part-1", 1)) + ) + } + + test(s"$prefix: 
project the partitioning column $table") { + checkAnswer( + sql(s"SELECT p, count(*) FROM $table group by p"), + Row(1, 10) :: + Row(2, 10) :: + Row(3, 10) :: + Row(4, 10) :: + Row(5, 10) :: + Row(6, 10) :: + Row(7, 10) :: + Row(8, 10) :: + Row(9, 10) :: + Row(10, 10) :: Nil + ) + } + + test(s"$prefix: project partitioning and non-partitioning columns $table") { + checkAnswer( + sql(s"SELECT stringField, p, count(intField) FROM $table GROUP BY p, stringField"), + Row("part-1", 1, 10) :: + Row("part-2", 2, 10) :: + Row("part-3", 3, 10) :: + Row("part-4", 4, 10) :: + Row("part-5", 5, 10) :: + Row("part-6", 6, 10) :: + Row("part-7", 7, 10) :: + Row("part-8", 8, 10) :: + Row("part-9", 9, 10) :: + Row("part-10", 10, 10) :: Nil + ) + } + + test(s"$prefix: simple count $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table"), + Row(100)) + } + + test(s"$prefix: pruned count $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table WHERE p = 1"), + Row(10)) + } + + test(s"$prefix: non-existent partition $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table WHERE p = 1000"), + Row(0)) + } + + test(s"$prefix: multi-partition pruned count $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table WHERE p IN (1,2,3)"), + Row(30)) + } + + test(s"$prefix: non-partition predicates $table") { + checkAnswer( + sql(s"SELECT COUNT(*) FROM $table WHERE intField IN (1,2,3)"), + Row(30)) + } + + test(s"$prefix: sum $table") { + checkAnswer( + sql(s"SELECT SUM(intField) FROM $table WHERE intField IN (1,2,3) AND p = 1"), + Row(1 + 2 + 3)) + } + + test(s"$prefix: hive udfs $table") { + checkAnswer( + sql(s"SELECT concat(stringField, stringField) FROM $table"), + sql(s"SELECT stringField FROM $table").map { + case Row(s: String) => Row(s + s) + }.collect().toSeq) + } + } - test(s"pruned count $table") { + test(s"$prefix: non-part select(*)") { checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE p = 1"), + sql("SELECT COUNT(*) FROM normal_parquet"), Row(10)) } - - test(s"non-existant partition $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE p = 1000"), - Row(0)) - } - - test(s"multi-partition pruned count $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE p IN (1,2,3)"), - Row(30)) - } - - test(s"non-partition predicates $table") { - checkAnswer( - sql(s"SELECT COUNT(*) FROM $table WHERE intField IN (1,2,3)"), - Row(30)) - } - - test(s"sum $table") { - checkAnswer( - sql(s"SELECT SUM(intField) FROM $table WHERE intField IN (1,2,3) AND p = 1"), - Row(1 + 2 + 3)) - } - - test(s"hive udfs $table") { - checkAnswer( - sql(s"SELECT concat(stringField, stringField) FROM $table"), - sql(s"SELECT stringField FROM $table").map { - case Row(s: String) => Row(s + s) - }.collect().toSeq) - } - } - test("non-part select(*)") { - checkAnswer( - sql("SELECT COUNT(*) FROM normal_parquet"), - Row(10)) - } + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + run("Parquet data source enabled") + + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + run("Parquet data source disabled") }
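
Editorial notes (not part of the patch):

The modified suites all follow the same dual-mode pattern: the test bodies move into a run(prefix: String) method, and the suite registers them twice, once with the Parquet data source code path enabled and once with it disabled, so every test name carries the mode it exercises. Below is a minimal, self-contained sketch of that registration pattern, assuming only ScalaTest on the classpath; the mutable conf map and the "useDataSourceApi" key are illustrative stand-ins for Spark's SQLConf and SQLConf.PARQUET_USE_DATA_SOURCE_API, not real Spark APIs.

    import scala.collection.mutable
    import org.scalatest.FunSuite

    // Sketch of the prefix-parameterized suite pattern used throughout this patch.
    class DualModeSuiteSketch extends FunSuite {
      private val conf = mutable.Map.empty[String, String]

      def run(prefix: String): Unit = {
        // Capture the setting at registration time; ScalaTest runs test bodies later,
        // after the constructor has already toggled the flag for the second pass.
        val useDataSourceApi = conf("useDataSourceApi").toBoolean

        test(s"$prefix: simple filter") {
          // In the real suites this is where Parquet scans and filter pushdown are
          // exercised; the captured flag selects which code path is under test.
          val label = if (useDataSourceApi) "data source scan" else "native Parquet scan"
          assert(label.nonEmpty)
        }
      }

      conf("useDataSourceApi") = "true"
      run("Parquet data source enabled")

      conf("useDataSourceApi") = "false"
      run("Parquet data source disabled")
    }

The new ParquetPartitionDiscoverySuite exercises Hive-style partition discovery: path fragments of the form name=value become partition columns, each value gets a type inferred from its text, fragments that are not name=value pairs (e.g. b_hello) are skipped, and the configured default partition name (e.g. __NULL__) maps to a null value. The sketch below illustrates that behaviour under simplifying assumptions; it is not the Spark implementation, infers only Int/Long/Double/String (the real code also distinguishes types such as FloatType), and uses plain require for the empty-name/empty-value errors.

    import scala.util.Try

    // Illustrative sketch of the behaviour checked by ParquetPartitionDiscoverySuite.
    object PartitionParsingSketch {
      sealed trait ColumnValue
      case class IntValue(v: Int) extends ColumnValue
      case class LongValue(v: Long) extends ColumnValue
      case class DoubleValue(v: Double) extends ColumnValue
      case class StringValue(v: String) extends ColumnValue
      case object NullValue extends ColumnValue

      // Try narrower numeric types first, then fall back to string; the default
      // partition name stands for a null partition value.
      def inferValue(raw: String, defaultPartitionName: String): ColumnValue =
        if (raw == defaultPartitionName) NullValue
        else Try[ColumnValue](IntValue(raw.toInt))
          .orElse(Try(LongValue(raw.toLong)))
          .orElse(Try(DoubleValue(raw.toDouble)))
          .getOrElse(StringValue(raw))

      // Keep only "name=value" fragments; empty names or values are rejected,
      // mirroring the suite's "Empty partition column name/value" checks.
      def parsePartition(path: String, defaultPartitionName: String): Seq[(String, ColumnValue)] =
        path.split("/").toSeq.collect {
          case fragment if fragment.count(_ == '=') == 1 =>
            val Array(name, value) = fragment.split("=", -1)
            require(name.nonEmpty, "Empty partition column name")
            require(value.nonEmpty, "Empty partition column value")
            name -> inferValue(value, defaultPartitionName)
        }

      def main(args: Array[String]): Unit = {
        // Prints: List((a,IntValue(10)), (b,StringValue(hello)), (c,DoubleValue(1.5)))
        println(parsePartition("file://path/a=10/b=hello/c=1.5", "__NULL__"))
      }
    }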