Skip to content
Snippets Groups Projects
Commit 6aad02d0 authored by Takeshi Yamamuro's avatar Takeshi Yamamuro Committed by Herman van Hovell
Browse files

[SPARK-18394][SQL] Make an AttributeSet.toSeq output order consistent

## What changes were proposed in this pull request?
This PR sorts output attributes by their name and exprId in `AttributeSet.toSeq` to make the order consistent. If the order differs between otherwise-identical plans, Spark can generate different code and then miss the cache in `CodeGenerator`; e.g., `GenerateColumnAccessor` generates code that depends on the input attribute order.

## How was this patch tested?
Added tests in `AttributeSetSuite` and manually verified that the cache worked correctly for the query given in the JIRA.

Author: Takeshi Yamamuro <yamamuro@apache.org>

Closes #18959 from maropu/SPARK-18394.
parent ae9e4247
No related branches found
No related tags found
No related merge requests found
......@@ -121,7 +121,12 @@ class AttributeSet private (val baseSet: Set[AttributeEquals])
// We must force toSeq to not be strict otherwise we end up with a [[Stream]] that captures all
// sorts of things in its closure.
override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq
override def toSeq: Seq[Attribute] = {
  // Generated code (e.g., `GenerateColumnAccessor`) embeds variables in input-attribute
  // order, so a nondeterministic order here would defeat the `CodeGenerator` cache.
  // Return the attributes sorted by (name, exprId) for a stable order.
  // See SPARK-18394 for details.
  baseSet.map(_.a).toSeq.sortWith { (x, y) =>
    x.name < y.name || (x.name == y.name && x.exprId.id < y.exprId.id)
  }
}
// Renders the set as `{a, b, ...}` using the underlying attributes.
override def toString: String = baseSet.map(_.a).mkString("{", ", ", "}")
......
......@@ -78,4 +78,44 @@ class AttributeSetSuite extends SparkFunSuite {
assert(aSet == aSet)
assert(aSet == AttributeSet(aUpper :: Nil))
}
test("SPARK-18394 keep a deterministic output order along with attribute names and exprIds") {
  // Builds two AttributeSets (3 attrs ++ 2 attrs) named c1..c5 from the given exprIds
  // and returns the names in `toSeq` order. The result must not depend on the ids.
  def orderedNames(ids: Seq[Long]): Seq[String] = {
    val refs = ids.zipWithIndex.map { case (id, i) =>
      AttributeReference(s"c${i + 1}", IntegerType)(exprId = ExprId(id))
    }
    val firstSet = AttributeSet(refs.take(3))
    val secondSet = AttributeSet(refs.drop(3))
    (firstSet ++ secondSet).toSeq.map(_.name)
  }
  // Two different exprId assignments must still produce the same name ordering.
  assert(orderedNames(Seq(1098L, 107L, 838L, 389L, 89329L)) ===
    orderedNames(Seq(392L, 92L, 87L, 9023920L, 522L)))

  // Attributes sharing the same name must be ordered by exprId.
  val ref1 = AttributeReference("c", IntegerType)(exprId = ExprId(1098))
  val ref2 = AttributeReference("c", IntegerType)(exprId = ExprId(107))
  val ref3 = AttributeReference("c", IntegerType)(exprId = ExprId(389))
  val combined = AttributeSet(ref1 :: ref2 :: Nil) ++ AttributeSet(ref3 :: Nil)
  assert(combined.toSeq === ref2 :: ref3 :: ref1 :: Nil)
}
}
......@@ -162,7 +162,12 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter {
}.head
assert(actualOutputColumns === expectedOutputColumns, "Output columns mismatch")
assert(actualScannedColumns === expectedScannedColumns, "Scanned columns mismatch")
// Scanned columns in `HiveTableScanExec` are generated by the `pruneFilterProject` method
// in `SparkPlanner`. This method internally uses `AttributeSet.toSeq`, in which
// the returned output columns are sorted by the names and expression ids.
assert(actualScannedColumns.sorted === expectedScannedColumns.sorted,
"Scanned columns mismatch")
val actualPartitions = actualPartValues.map(_.asScala.mkString(",")).sorted
val expectedPartitions = expectedPartValues.map(_.mkString(",")).sorted
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment