diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
index b237a07c72d07e46135b7eab3632a54f3231a9ee..2835dc3408b962293943bd31c1f65fd9c58a6cd1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
@@ -28,7 +28,7 @@ import parquet.schema.MessageType
 
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException}
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
 
 /**
@@ -67,6 +67,8 @@ private[sql] case class ParquetRelation(
       conf,
       sqlContext.isParquetBinaryAsString)
 
+  lazy val attributeMap = AttributeMap(output.map(o => o -> o))
+
   override def newInstance() = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type]
 
   // Equals must also take into account the output attributes so that we can distinguish between
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index 96bace1769f719831aa06a78ec99480011f6fcb9..f5487740d3af92530c222ab248e7445c0f8c1592 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -64,18 +64,17 @@ case class ParquetTableScan(
   // The resolution of Parquet attributes is case sensitive, so we resolve the original attributes
   // by exprId. note: output cannot be transient, see
   // https://issues.apache.org/jira/browse/SPARK-1367
-  val normalOutput =
-    attributes
-      .filterNot(a => relation.partitioningAttributes.map(_.exprId).contains(a.exprId))
-      .flatMap(a => relation.output.find(o => o.exprId == a.exprId))
+  val output = attributes.map(relation.attributeMap)
 
-  val partOutput =
-    attributes.flatMap(a => relation.partitioningAttributes.find(o => o.exprId == a.exprId))
+  // A mapping of ordinals partitionRow -> finalOutput.
+  val requestedPartitionOrdinals = {
+    val partitionAttributeOrdinals = AttributeMap(relation.partitioningAttributes.zipWithIndex)
 
-  def output = partOutput ++ normalOutput
-
-  assert(normalOutput.size + partOutput.size == attributes.size,
-    s"$normalOutput + $partOutput != $attributes, ${relation.output}")
+    attributes.zipWithIndex.flatMap {
+      case (attribute, finalOrdinal) =>
+        partitionAttributeOrdinals.get(attribute).map(_ -> finalOrdinal)
+    }
+  }.toArray
 
   override def execute(): RDD[Row] = {
     import parquet.filter2.compat.FilterCompat.FilterPredicateCompat
@@ -97,7 +96,7 @@ case class ParquetTableScan(
     // Store both requested and original schema in `Configuration`
     conf.set(
       RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA,
-      ParquetTypesConverter.convertToString(normalOutput))
+      ParquetTypesConverter.convertToString(output))
     conf.set(
       RowWriteSupport.SPARK_ROW_SCHEMA,
       ParquetTypesConverter.convertToString(relation.output))
@@ -125,7 +124,7 @@ case class ParquetTableScan(
       classOf[Row],
       conf)
 
-    if (partOutput.nonEmpty) {
+    if (requestedPartitionOrdinals.nonEmpty) {
       baseRDD.mapPartitionsWithInputSplit { case (split, iter) =>
         val partValue = "([^=]+)=([^=]+)".r
         val partValues =
@@ -138,15 +137,25 @@ case class ParquetTableScan(
             case _ => None
           }.toMap
 
+        // Convert the partitioning attributes into the correct types
         val partitionRowValues =
-          partOutput.map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow))
+          relation.partitioningAttributes
+            .map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow))
 
         new Iterator[Row] {
-          private[this] val joinedRow = new JoinedRow5(Row(partitionRowValues:_*), null)
-
           def hasNext = iter.hasNext
-
-          def next() = joinedRow.withRight(iter.next()._2)
+          def next() = {
+            val row = iter.next()._2.asInstanceOf[SpecificMutableRow]
+
+            // Parquet will leave partitioning columns empty, so we fill them in here.
+            var i = 0
+            while (i < requestedPartitionOrdinals.size) {
+              row(requestedPartitionOrdinals(i)._2) =
+                partitionRowValues(requestedPartitionOrdinals(i)._1)
+              i += 1
+            }
+            row
+          }
         }
       }
     } else {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala
index fc0e42c201d56e846bf760f2fd9a798eeb21be75..8bbb7f2fdbf48ed75f55b3c2876158bbdaf4cdee 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala
@@ -174,6 +174,18 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll
   }
 
   Seq("partitioned_parquet", "partitioned_parquet_with_key").foreach { table =>
+    test(s"ordering of the partitioning columns $table") {
+      checkAnswer(
+        sql(s"SELECT p, stringField FROM $table WHERE p = 1"),
+        Seq.fill(10)((1, "part-1"))
+      )
+
+      checkAnswer(
+        sql(s"SELECT stringField, p FROM $table WHERE p = 1"),
+        Seq.fill(10)(("part-1", 1))
+      )
+    }
+
     test(s"project the partitioning column $table") {
       checkAnswer(
         sql(s"SELECT p, count(*) FROM $table group by p"),
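For readers of the patch, here is a minimal, self-contained sketch of the technique the scan now applies: the requested partition columns are mapped once to their ordinals in the final output, and each row coming back from the reader (which leaves partition columns empty) is then patched in place. Everything in the sketch is illustrative only: the PartitionFill object, the plain Array[Any] standing in for a mutable row, and the hard-coded column names are assumptions for this example, not Spark APIs.

// Standalone Scala sketch of the ordinal-mapping idea used above (assumed,
// simplified types; not Spark code).
object PartitionFill {
  def main(args: Array[String]): Unit = {
    // Columns the relation is partitioned by, in partition-directory order.
    val partitioningColumns = Seq("p")
    // Columns requested by the query, in the order the caller expects them.
    val requestedColumns = Seq("stringField", "p")

    // Ordinal of each partition column within the parsed partition values.
    val partitionOrdinals: Map[String, Int] = partitioningColumns.zipWithIndex.toMap

    // (partitionOrdinal, finalOrdinal) pairs for every requested partition column,
    // mirroring requestedPartitionOrdinals in the patch.
    val requestedPartitionOrdinals: Array[(Int, Int)] =
      requestedColumns.zipWithIndex.flatMap { case (name, finalOrdinal) =>
        partitionOrdinals.get(name).map(_ -> finalOrdinal)
      }.toArray

    // Values parsed from the partition directory, e.g. ".../p=1/part-00000".
    val partitionRowValues: Array[Any] = Array(1)

    // A row as produced by the reader: the partition column comes back empty.
    val row: Array[Any] = Array("part-1", null)

    // Fill the missing partition slots in place, as the new Iterator.next() does.
    var i = 0
    while (i < requestedPartitionOrdinals.length) {
      val (partitionOrdinal, finalOrdinal) = requestedPartitionOrdinals(i)
      row(finalOrdinal) = partitionRowValues(partitionOrdinal)
      i += 1
    }

    println(row.mkString(", "))  // prints: part-1, 1
  }
}

Because the fill-in is keyed by the requested ordinals rather than by concatenating partition columns onto one side of the row, the output order follows the query's projection, which is what the new "ordering of the partitioning columns" test exercises.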