diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index ca9c0ed8cec32e01e836a765542876164121cf30..a1f941644f807d9347ffdb3d09738534871c7618 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -79,7 +79,17 @@ case class SortMergeJoinExec( override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys) + override def outputOrdering: Seq[SortOrder] = joinType match { + // For left and right outer joins, the output is ordered by the streamed input's join keys. + case LeftOuter => requiredOrders(leftKeys) + case RightOuter => requiredOrders(rightKeys) + // There are null rows in both streams, so there is no order. + case FullOuter => Nil + case _: InnerLike | LeftExistence(_) => requiredOrders(leftKeys) + case x => + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $x as the JoinType") + } override def requiredChildOrdering: Seq[Seq[SortOrder]] = requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 375da224aaa7fe7ac67c54a1b29d521ef4da30b5..6df80bca487dfddd5c4b78f8eaf9a0920e9bc3cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{execution, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.InMemoryRelation @@ -250,7 +250,9 @@ class PlannerSuite extends SharedSQLContext { } } - // --- Unit tests of EnsureRequirements --------------------------------------------------------- + /////////////////////////////////////////////////////////////////////////// + // Unit tests of EnsureRequirements for Exchange + /////////////////////////////////////////////////////////////////////////// // When it comes to testing whether EnsureRequirements properly ensures distribution requirements, // there two dimensions that need to be considered: are the child partitionings compatible and @@ -383,93 +385,6 @@ class PlannerSuite extends SharedSQLContext { } } - test("EnsureRequirements adds sort when there is no existing ordering") { - val orderingA = SortOrder(Literal(1), Ascending) - val orderingB = SortOrder(Literal(2), Ascending) - assert(orderingA != orderingB) - val inputPlan = DummySparkPlan( - children = DummySparkPlan(outputOrdering = Seq.empty) :: Nil, - requiredChildOrdering = Seq(Seq(orderingB)), - requiredChildDistribution = Seq(UnspecifiedDistribution) - ) - val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) - assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case s: SortExec => true }.isEmpty) { - fail(s"Sort should have been added:\n$outputPlan") - } - } - - test("EnsureRequirements skips sort when required ordering is prefix of existing ordering") { - val orderingA = SortOrder(Literal(1), Ascending) - val orderingB = SortOrder(Literal(2), Ascending) - assert(orderingA != orderingB) - val inputPlan = DummySparkPlan( - children = DummySparkPlan(outputOrdering = Seq(orderingA, orderingB)) :: Nil, - requiredChildOrdering = Seq(Seq(orderingA)), - requiredChildDistribution = Seq(UnspecifiedDistribution) - ) - val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) - assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case s: SortExec => true }.nonEmpty) { - fail(s"No sorts should have been added:\n$outputPlan") - } - } - - test("EnsureRequirements skips sort when required ordering is semantically equal to " + - "existing ordering") { - val exprId: ExprId = NamedExpression.newExprId - val attribute1 = - AttributeReference( - name = "col1", - dataType = LongType, - nullable = false - ) (exprId = exprId, - qualifier = Some("col1_qualifier") - ) - - val attribute2 = - AttributeReference( - name = "col1", - dataType = LongType, - nullable = false - ) (exprId = exprId) - - val orderingA1 = SortOrder(attribute1, Ascending) - val orderingA2 = SortOrder(attribute2, Ascending) - - assert(orderingA1 != orderingA2, s"$orderingA1 should NOT equal to $orderingA2") - assert(orderingA1.semanticEquals(orderingA2), - s"$orderingA1 should be semantically equal to $orderingA2") - - val inputPlan = DummySparkPlan( - children = DummySparkPlan(outputOrdering = Seq(orderingA1)) :: Nil, - requiredChildOrdering = Seq(Seq(orderingA2)), - requiredChildDistribution = Seq(UnspecifiedDistribution) - ) - val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) - assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case s: SortExec => true }.nonEmpty) { - fail(s"No sorts should have been added:\n$outputPlan") - } - } - - // This is a regression test for SPARK-11135 - test("EnsureRequirements adds sort when required ordering isn't a prefix of existing ordering") { - val orderingA = SortOrder(Literal(1), Ascending) - val orderingB = SortOrder(Literal(2), Ascending) - assert(orderingA != orderingB) - val inputPlan = DummySparkPlan( - children = DummySparkPlan(outputOrdering = Seq(orderingA)) :: Nil, - requiredChildOrdering = Seq(Seq(orderingA, orderingB)), - requiredChildDistribution = Seq(UnspecifiedDistribution) - ) - val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) - assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case s: SortExec => true }.isEmpty) { - fail(s"Sort should have been added:\n$outputPlan") - } - } - test("EnsureRequirements eliminates Exchange if child has Exchange with same partitioning") { val distribution = ClusteredDistribution(Literal(1) :: Nil) val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) @@ -480,7 +395,7 @@ class PlannerSuite extends SharedSQLContext { children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, requiredChildDistribution = Seq(distribution), requiredChildOrdering = Seq(Seq.empty)), - None) + None) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) @@ -509,8 +424,6 @@ class PlannerSuite extends SharedSQLContext { } } - // --------------------------------------------------------------------------------------------- - test("Reuse exchanges") { val distribution = ClusteredDistribution(Literal(1) :: Nil) val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) @@ -524,12 +437,12 @@ class PlannerSuite extends SharedSQLContext { None) val inputPlan = SortMergeJoinExec( - Literal(1) :: Nil, - Literal(1) :: Nil, - Inner, - None, - shuffle, - shuffle) + Literal(1) :: Nil, + Literal(1) :: Nil, + Inner, + None, + shuffle, + shuffle) val outputPlan = ReuseExchange(spark.sessionState.conf).apply(inputPlan) if (outputPlan.collect { case e: ReusedExchangeExec => true }.size != 1) { @@ -556,6 +469,130 @@ class PlannerSuite extends SharedSQLContext { fail(s"Should have only two shuffles:\n$outputPlan") } } + + /////////////////////////////////////////////////////////////////////////// + // Unit tests of EnsureRequirements for Sort + /////////////////////////////////////////////////////////////////////////// + + private val exprA = Literal(1) + private val exprB = Literal(2) + private val orderingA = SortOrder(exprA, Ascending) + private val orderingB = SortOrder(exprB, Ascending) + private val planA = DummySparkPlan(outputOrdering = Seq(orderingA), + outputPartitioning = HashPartitioning(exprA :: Nil, 5)) + private val planB = DummySparkPlan(outputOrdering = Seq(orderingB), + outputPartitioning = HashPartitioning(exprB :: Nil, 5)) + + assert(orderingA != orderingB) + + private def assertSortRequirementsAreSatisfied( + childPlan: SparkPlan, + requiredOrdering: Seq[SortOrder], + shouldHaveSort: Boolean): Unit = { + val inputPlan = DummySparkPlan( + children = childPlan :: Nil, + requiredChildOrdering = Seq(requiredOrdering), + requiredChildDistribution = Seq(UnspecifiedDistribution) + ) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (shouldHaveSort) { + if (outputPlan.collect { case s: SortExec => true }.isEmpty) { + fail(s"Sort should have been added:\n$outputPlan") + } + } else { + if (outputPlan.collect { case s: SortExec => true }.nonEmpty) { + fail(s"No sorts should have been added:\n$outputPlan") + } + } + } + + test("EnsureRequirements for sort operator after left outer sort merge join") { + // Only left key is sorted after left outer SMJ (thus doesn't need a sort). + val leftSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, LeftOuter, None, planA, planB) + Seq((orderingA, false), (orderingB, true)).foreach { case (ordering, needSort) => + assertSortRequirementsAreSatisfied( + childPlan = leftSmj, + requiredOrdering = Seq(ordering), + shouldHaveSort = needSort) + } + } + + test("EnsureRequirements for sort operator after right outer sort merge join") { + // Only right key is sorted after right outer SMJ (thus doesn't need a sort). + val rightSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, RightOuter, None, planA, planB) + Seq((orderingA, true), (orderingB, false)).foreach { case (ordering, needSort) => + assertSortRequirementsAreSatisfied( + childPlan = rightSmj, + requiredOrdering = Seq(ordering), + shouldHaveSort = needSort) + } + } + + test("EnsureRequirements adds sort after full outer sort merge join") { + // Neither keys is sorted after full outer SMJ, so they both need sorts. + val fullSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, FullOuter, None, planA, planB) + Seq(orderingA, orderingB).foreach { ordering => + assertSortRequirementsAreSatisfied( + childPlan = fullSmj, + requiredOrdering = Seq(ordering), + shouldHaveSort = true) + } + } + + test("EnsureRequirements adds sort when there is no existing ordering") { + assertSortRequirementsAreSatisfied( + childPlan = DummySparkPlan(outputOrdering = Seq.empty), + requiredOrdering = Seq(orderingB), + shouldHaveSort = true) + } + + test("EnsureRequirements skips sort when required ordering is prefix of existing ordering") { + assertSortRequirementsAreSatisfied( + childPlan = DummySparkPlan(outputOrdering = Seq(orderingA, orderingB)), + requiredOrdering = Seq(orderingA), + shouldHaveSort = false) + } + + test("EnsureRequirements skips sort when required ordering is semantically equal to " + + "existing ordering") { + val exprId: ExprId = NamedExpression.newExprId + val attribute1 = + AttributeReference( + name = "col1", + dataType = LongType, + nullable = false + ) (exprId = exprId, + qualifier = Some("col1_qualifier") + ) + + val attribute2 = + AttributeReference( + name = "col1", + dataType = LongType, + nullable = false + ) (exprId = exprId) + + val orderingA1 = SortOrder(attribute1, Ascending) + val orderingA2 = SortOrder(attribute2, Ascending) + + assert(orderingA1 != orderingA2, s"$orderingA1 should NOT equal to $orderingA2") + assert(orderingA1.semanticEquals(orderingA2), + s"$orderingA1 should be semantically equal to $orderingA2") + + assertSortRequirementsAreSatisfied( + childPlan = DummySparkPlan(outputOrdering = Seq(orderingA1)), + requiredOrdering = Seq(orderingA2), + shouldHaveSort = false) + } + + // This is a regression test for SPARK-11135 + test("EnsureRequirements adds sort when required ordering isn't a prefix of existing ordering") { + assertSortRequirementsAreSatisfied( + childPlan = DummySparkPlan(outputOrdering = Seq(orderingA)), + requiredOrdering = Seq(orderingA, orderingB), + shouldHaveSort = true) + } } // Used for unit-testing EnsureRequirements