Commit 5d9fa550 authored by Michael Armbrust

[SPARK-5049][SQL] Fix ordering of partition columns in ParquetTableScan

Follow-up to #3870. Props to rahulaggarwalguavus for identifying the issue.

Author: Michael Armbrust <michael@databricks.com>

Closes #3990 from marmbrus/SPARK-5049 and squashes the following commits:

dd03e4e [Michael Armbrust] Fill in the partition values of parquet scans instead of using JoinedRow
parent 3aed3051
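The fix in miniature: instead of wrapping each Parquet row in a JoinedRow that unconditionally prepends the partition columns, the scan now computes, once per query, the slot each partition value occupies in the requested output and writes it into the mutable row. Below is a stand-alone Scala sketch of that idea using plain collections; every name in it is illustrative rather than Spark's own.

object PartitionFillSketch extends App {
  // Requested output columns, in the order the query asked for them.
  val requestedColumns = Seq("stringField", "p")
  // Values recovered from the partition directory (e.g. .../p=1/...).
  val partitionValues = Map("p" -> 1)

  // Computed once per scan: the final-output ordinal of each requested
  // partition column.
  val requestedPartitionOrdinals = requestedColumns.zipWithIndex.collect {
    case (name, finalOrdinal) if partitionValues.contains(name) =>
      name -> finalOrdinal
  }

  // Per row: Parquet leaves the partition slots empty; fill them in place.
  def fill(row: Array[Any]): Array[Any] = {
    var i = 0
    while (i < requestedPartitionOrdinals.size) {
      val (name, finalOrdinal) = requestedPartitionOrdinals(i)
      row(finalOrdinal) = partitionValues(name)
      i += 1
    }
    row
  }

  println(fill(Array("part-1", null)).mkString(", "))  // part-1, 1
}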
@@ -28,7 +28,7 @@
 import parquet.schema.MessageType
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException}
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}

 /**
@@ -67,6 +67,8 @@ private[sql] case class ParquetRelation(
     conf,
     sqlContext.isParquetBinaryAsString)

+  lazy val attributeMap = AttributeMap(output.map(o => o -> o))
+
   override def newInstance() = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type]

   // Equals must also take into account the output attributes so that we can distinguish between
...
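Why an identity map helps: catalyst's AttributeMap looks attributes up by exprId rather than by name, so mapping each output attribute to itself lets the scan exchange an incoming attribute (whose name may differ, for example in case) for the relation's canonical attribute. A simplified sketch, with hypothetical Attr/AttrMap stand-ins for the catalyst types:

object AttributeMapSketch extends App {
  // Stand-ins for catalyst's Attribute / AttributeMap.
  final case class Attr(name: String, exprId: Long)
  final class AttrMap(kvs: Seq[(Attr, Attr)]) {
    private val byId = kvs.map { case (k, v) => k.exprId -> v }.toMap
    def apply(a: Attr): Attr = byId(a.exprId)  // lookup by exprId, not name
  }

  val output = Seq(Attr("stringField", 1L), Attr("p", 2L))
  val attributeMap = new AttrMap(output.map(o => o -> o))

  // A requested attribute that differs in case still resolves, because
  // resolution goes through the exprId assigned during analysis.
  println(attributeMap(Attr("STRINGFIELD", 1L)))  // Attr(stringField,1)
}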
@@ -64,18 +64,17 @@ case class ParquetTableScan(
   // The resolution of Parquet attributes is case sensitive, so we resolve the original attributes
   // by exprId. note: output cannot be transient, see
   // https://issues.apache.org/jira/browse/SPARK-1367
-  val normalOutput =
-    attributes
-      .filterNot(a => relation.partitioningAttributes.map(_.exprId).contains(a.exprId))
-      .flatMap(a => relation.output.find(o => o.exprId == a.exprId))
+  val output = attributes.map(relation.attributeMap)

-  val partOutput =
-    attributes.flatMap(a => relation.partitioningAttributes.find(o => o.exprId == a.exprId))
+  // A mapping of ordinals partitionRow -> finalOutput.
+  val requestedPartitionOrdinals = {
+    val partitionAttributeOrdinals = AttributeMap(relation.partitioningAttributes.zipWithIndex)

-  def output = partOutput ++ normalOutput
-
-  assert(normalOutput.size + partOutput.size == attributes.size,
-    s"$normalOutput + $partOutput != $attributes, ${relation.output}")
+    attributes.zipWithIndex.flatMap {
+      case (attribute, finalOrdinal) =>
+        partitionAttributeOrdinals.get(attribute).map(_ -> finalOrdinal)
+    }
+  }.toArray

   override def execute(): RDD[Row] = {
     import parquet.filter2.compat.FilterCompat.FilterPredicateCompat
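requestedPartitionOrdinals pairs each partition column's index in relation.partitioningAttributes with its index in the requested output. A worked example with plain strings standing in for attributes (the column names are borrowed from the tests further down):

object OrdinalSketch extends App {
  val partitioningAttributes = Seq("p")       // partition columns, table order
  val attributes = Seq("stringField", "p")    // output order the query requested

  val partitionAttributeOrdinals = partitioningAttributes.zipWithIndex.toMap

  // (ordinal in the partition-values array, ordinal in the final output row)
  val requestedPartitionOrdinals = attributes.zipWithIndex.flatMap {
    case (attribute, finalOrdinal) =>
      partitionAttributeOrdinals.get(attribute).map(_ -> finalOrdinal)
  }.toArray

  // p is partition column 0 and must land in output slot 1.
  println(requestedPartitionOrdinals.mkString(", "))  // (0,1)
}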
@@ -97,7 +96,7 @@ case class ParquetTableScan(
     // Store both requested and original schema in `Configuration`
     conf.set(
       RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA,
-      ParquetTypesConverter.convertToString(normalOutput))
+      ParquetTypesConverter.convertToString(output))
     conf.set(
       RowWriteSupport.SPARK_ROW_SCHEMA,
       ParquetTypesConverter.convertToString(relation.output))
@@ -125,7 +124,7 @@ case class ParquetTableScan(
       classOf[Row],
       conf)

-    if (partOutput.nonEmpty) {
+    if (requestedPartitionOrdinals.nonEmpty) {
       baseRDD.mapPartitionsWithInputSplit { case (split, iter) =>
         val partValue = "([^=]+)=([^=]+)".r
         val partValues =
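The partValue regex above recovers partition values from Hive-style paths, where each partition column appears as a name=value directory segment. The elided lines build partValues from the input split's path; a simplified stand-alone version, with a made-up path:

object PartValuesSketch extends App {
  val partValue = "([^=]+)=([^=]+)".r

  val splitPath = "/warehouse/partitioned_parquet/p=1/part-00000.parquet"
  val partValues = splitPath.split("/").flatMap {
    case partValue(key, value) => Some(key -> value)
    case _ => None
  }.toMap

  println(partValues)  // Map(p -> 1): note the values are still strings here
}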
@@ -138,15 +137,25 @@ case class ParquetTableScan(
             case _ => None
           }.toMap

           // Convert the partitioning attributes into the correct types
           val partitionRowValues =
-            partOutput.map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow))
+            relation.partitioningAttributes
+              .map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow))

           new Iterator[Row] {
-            private[this] val joinedRow = new JoinedRow5(Row(partitionRowValues:_*), null)
-
             def hasNext = iter.hasNext
-            def next() = joinedRow.withRight(iter.next()._2)
+            def next() = {
+              val row = iter.next()._2.asInstanceOf[SpecificMutableRow]
+
+              // Parquet will leave partitioning columns empty, so we fill them in here.
+              var i = 0
+              while (i < requestedPartitionOrdinals.size) {
+                row(requestedPartitionOrdinals(i)._2) =
+                  partitionRowValues(requestedPartitionOrdinals(i)._1)
+                i += 1
+              }
+              row
+            }
           }
         }
       } else {
...
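Values pulled from the path are strings, so Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow) converts each one to its column's declared type before the fill-in loop writes it into the row. A plain-Scala sketch of that conversion step; the miniature DataType here is an assumption, not catalyst's type hierarchy:

object CastSketch extends App {
  sealed trait DataType
  case object IntType extends DataType
  case object StringType extends DataType

  // Simplified stand-in for Cast(Literal(raw), dataType).eval(EmptyRow).
  def cast(raw: String, dataType: DataType): Any = dataType match {
    case IntType    => raw.toInt
    case StringType => raw
  }

  val partitioningAttributes = Seq("p" -> (IntType: DataType))
  val partValues = Map("p" -> "1")

  val partitionRowValues =
    partitioningAttributes.map { case (name, dt) => cast(partValues(name), dt) }

  println(partitionRowValues)  // List(1): an Int, ready for an integer column
}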
@@ -174,6 +174,18 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll
   }

   Seq("partitioned_parquet", "partitioned_parquet_with_key").foreach { table =>
+    test(s"ordering of the partitioning columns $table") {
+      checkAnswer(
+        sql(s"SELECT p, stringField FROM $table WHERE p = 1"),
+        Seq.fill(10)((1, "part-1"))
+      )
+
+      checkAnswer(
+        sql(s"SELECT stringField, p FROM $table WHERE p = 1"),
+        Seq.fill(10)(("part-1", 1))
+      )
+    }
+
     test(s"project the partitioning column $table") {
       checkAnswer(
         sql(s"SELECT p, count(*) FROM $table group by p"),
...
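For contrast, a sketch of the failure mode these tests pin down: the old def output = partOutput ++ normalOutput combined with a JoinedRow always emitted partition columns on the left, so a query requesting stringField, p got its columns back transposed (illustrative values taken from the tests above):

object OrderingBugSketch extends App {
  val partitionSide = Seq[Any](1)         // p = 1
  val parquetSide = Seq[Any]("part-1")    // stringField

  // Old behaviour: partition columns unconditionally on the left.
  val joined = partitionSide ++ parquetSide
  println(joined)  // List(1, part-1)

  // But SELECT stringField, p asked for ("part-1", 1); hence SPARK-5049.
}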