Skip to content
Snippets Groups Projects
Commit 5ff75c74 authored by Michael Armbrust's avatar Michael Armbrust Committed by Reynold Xin
Browse files

[SPARK-2184][SQL] AddExchange isn't idempotent

...redPartitioning.

Author: Michael Armbrust <michael@databricks.com>

Closes #1122 from marmbrus/fixAddExchange and squashes the following commits:

3417537 [Michael Armbrust] Don't bind partitioning expressions as that breaks comparison with requiredPartitioning.
parent 45a95f82
No related branches found
No related tags found
No related merge requests found
...@@ -68,7 +68,7 @@ class BindReferences[TreeNode <: QueryPlan[TreeNode]] extends Rule[TreeNode] { ...@@ -68,7 +68,7 @@ class BindReferences[TreeNode <: QueryPlan[TreeNode]] extends Rule[TreeNode] {
} }
object BindReferences extends Logging { object BindReferences extends Logging {
def bindReference(expression: Expression, input: Seq[Attribute]): Expression = { def bindReference[A <: Expression](expression: A, input: Seq[Attribute]): A = {
expression.transform { case a: AttributeReference => expression.transform { case a: AttributeReference =>
attachTree(a, "Binding attribute") { attachTree(a, "Binding attribute") {
val ordinal = input.indexWhere(_.exprId == a.exprId) val ordinal = input.indexWhere(_.exprId == a.exprId)
...@@ -83,6 +83,6 @@ object BindReferences extends Logging { ...@@ -83,6 +83,6 @@ object BindReferences extends Logging {
BoundReference(ordinal, a) BoundReference(ordinal, a)
} }
} }
} }.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible.
} }
} }
...@@ -208,6 +208,9 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow { ...@@ -208,6 +208,9 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow {
class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] { class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] {
def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
this(ordering.map(BindReferences.bindReference(_, inputSchema)))
def compare(a: Row, b: Row): Int = { def compare(a: Row, b: Row): Int = {
var i = 0 var i = 0
while (i < ordering.size) { while (i < ordering.size) {
......
...@@ -22,7 +22,7 @@ import org.apache.spark.{HashPartitioner, RangePartitioner, SparkConf} ...@@ -22,7 +22,7 @@ import org.apache.spark.{HashPartitioner, RangePartitioner, SparkConf}
import org.apache.spark.rdd.ShuffledRDD import org.apache.spark.rdd.ShuffledRDD
import org.apache.spark.sql.{SQLConf, SQLContext, Row} import org.apache.spark.sql.{SQLConf, SQLContext, Row}
import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.{MutableProjection, RowOrdering} import org.apache.spark.sql.catalyst.expressions.{NoBind, MutableProjection, RowOrdering}
import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.util.MutablePair import org.apache.spark.util.MutablePair
...@@ -31,7 +31,7 @@ import org.apache.spark.util.MutablePair ...@@ -31,7 +31,7 @@ import org.apache.spark.util.MutablePair
* :: DeveloperApi :: * :: DeveloperApi ::
*/ */
@DeveloperApi @DeveloperApi
case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode with NoBind {
override def outputPartitioning = newPartitioning override def outputPartitioning = newPartitioning
...@@ -42,7 +42,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una ...@@ -42,7 +42,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
case HashPartitioning(expressions, numPartitions) => case HashPartitioning(expressions, numPartitions) =>
// TODO: Eliminate redundant expressions in grouping key and value. // TODO: Eliminate redundant expressions in grouping key and value.
val rdd = child.execute().mapPartitions { iter => val rdd = child.execute().mapPartitions { iter =>
val hashExpressions = new MutableProjection(expressions) val hashExpressions = new MutableProjection(expressions, child.output)
val mutablePair = new MutablePair[Row, Row]() val mutablePair = new MutablePair[Row, Row]()
iter.map(r => mutablePair.update(hashExpressions(r), r)) iter.map(r => mutablePair.update(hashExpressions(r), r))
} }
...@@ -53,7 +53,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una ...@@ -53,7 +53,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
case RangePartitioning(sortingExpressions, numPartitions) => case RangePartitioning(sortingExpressions, numPartitions) =>
// TODO: RangePartitioner should take an Ordering. // TODO: RangePartitioner should take an Ordering.
implicit val ordering = new RowOrdering(sortingExpressions) implicit val ordering = new RowOrdering(sortingExpressions, child.output)
val rdd = child.execute().mapPartitions { iter => val rdd = child.execute().mapPartitions { iter =>
val mutablePair = new MutablePair[Row, Null](null, null) val mutablePair = new MutablePair[Row, Null](null, null)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment