diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 4ecee750482485e679963b58504a1892d5733469..a1ac93073916cba933732b5c8dec1ac235bbee58 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -68,7 +68,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
       PushPredicateThroughAggregate,
       ColumnPruning,
       // Operator combine
-      ProjectCollapsing,
+      CollapseRepartition,
+      CollapseProject,
       CombineFilters,
       CombineLimits,
       CombineUnions,
@@ -322,7 +323,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
  * Combines two adjacent [[Project]] operators into one and perform alias substitution,
  * merging the expressions into one single expression.
  */
-object ProjectCollapsing extends Rule[LogicalPlan] {
+object CollapseProject extends Rule[LogicalPlan] {

   def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
     case p @ Project(projectList1, Project(projectList2, child)) =>
@@ -390,6 +391,16 @@ object ProjectCollapsing extends Rule[LogicalPlan] {
   }
 }

+/**
+ * Combines adjacent [[Repartition]] operators by keeping only the last one.
+ */
+object CollapseRepartition extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    case r @ Repartition(numPartitions, shuffle, Repartition(_, _, child)) =>
+      Repartition(numPartitions, shuffle, child)
+  }
+}
+
 /**
  * Simplifies LIKE expressions that do not need full regular expressions to evaluate the condition.
  * For example, when the expression is just checking to see if a string starts with a given
@@ -857,6 +868,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
   /**
    * Splits join condition expressions into three categories based on the attributes required
    * to evaluate them.
+   *
    * @return (canEvaluateInLeft, canEvaluateInRight, haveToEvaluateInBoth)
    */
   private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
similarity index 96%
rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala
rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
index 85b6530481b0331b9e5dfb51afc3c0f530dc728a..f5fd5ca6beb15d84edc048037209169bd495796d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
@@ -25,11 +25,11 @@ import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor

-class ProjectCollapsingSuite extends PlanTest {
+class CollapseProjectSuite extends PlanTest {
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches =
       Batch("Subqueries", FixedPoint(10), EliminateSubQueries) ::
-      Batch("ProjectCollapsing", Once, ProjectCollapsing) :: Nil
+      Batch("CollapseProject", Once, CollapseProject) :: Nil
   }

   val testRelation = LocalRelation('a.int, 'b.int)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index f9f3bd55aa5784400a1a7093cedbb54003514440..b49ca928b62928155a35f34a6e39390106f25e8f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -42,7 +42,7 @@ class FilterPushdownSuite extends PlanTest {
         PushPredicateThroughGenerate,
         PushPredicateThroughAggregate,
         ColumnPruning,
-        ProjectCollapsing) :: Nil
+        CollapseProject) :: Nil
   }

   val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala
index 9b1e16c727647d022cd774435d6fe1dce3b23a40..858a0d8fde3ea0eb01a090142c10132606daca3a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOrderSuite.scala
@@ -43,7 +43,7 @@ class JoinOrderSuite extends PlanTest {
         PushPredicateThroughGenerate,
         PushPredicateThroughAggregate,
         ColumnPruning,
-        ProjectCollapsing) :: Nil
+        CollapseProject) :: Nil
   }

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 8fca5e2167d04e91f11ca67739fbfc8e1af720a6..adaeb513bc1b8094bec80ec06da59eb4fdfbe8a2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -21,8 +21,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{execution, Row, SQLConf}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin}
 import org.apache.spark.sql.functions._
@@ -223,6 +222,18 @@ class PlannerSuite extends SharedSQLContext {
     }
   }

+  test("collapse adjacent repartitions") {
+    val doubleRepartitioned = testData.repartition(10).repartition(20).coalesce(5)
+    def countRepartitions(plan: LogicalPlan): Int = plan.collect { case r: Repartition => r }.length
+    assert(countRepartitions(doubleRepartitioned.queryExecution.logical) === 3)
+    assert(countRepartitions(doubleRepartitioned.queryExecution.optimizedPlan) === 1)
+    doubleRepartitioned.queryExecution.optimizedPlan match {
+      case r: Repartition =>
+        assert(r.numPartitions === 5)
+        assert(r.shuffle === false)
+    }
+  }
+
   // --- Unit tests of EnsureRequirements ---------------------------------------------------------

   // When it comes to testing whether EnsureRequirements properly ensures distribution requirements,
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
index 1654594538366db21111683765ced415015f77e4..fc5725d6915eab9abf1a525654e8a1235950041d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -23,7 +23,7 @@ import org.apache.spark.Logging
 import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression, SortOrder}
-import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing
+import org.apache.spark.sql.catalyst.optimizer.CollapseProject
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
 import org.apache.spark.sql.execution.datasources.LogicalRelation
@@ -188,7 +188,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
         // The `WidenSetOperationTypes` analysis rule may introduce extra `Project`s over
         // `Aggregate`s to perform type casting. This rule merges these `Project`s into
         // `Aggregate`s.
-        ProjectCollapsing,
+        CollapseProject,
         // Used to handle other auxiliary `Project`s added by analyzer (e.g.
         // `ResolveAggregateFunctions` rule)
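
For reference, a minimal end-to-end sketch (not part of the patch) of the new CollapseRepartition rule's observable effect at the DataFrame level, mirroring the PlannerSuite test above. It assumes a Spark build containing this patch and an in-scope SQLContext named `sqlContext`; the value names are illustrative.

    // Build a plan with three nested Repartition operators.
    import org.apache.spark.sql.catalyst.plans.logical.Repartition

    val df = sqlContext.range(0, 100)
      .repartition(10)  // made redundant by the next repartition
      .repartition(20)  // made redundant by the final coalesce
      .coalesce(5)      // planned as a non-shuffling Repartition

    // The unoptimized logical plan nests three Repartition operators; after
    // optimization, CollapseRepartition keeps only the outermost one.
    val repartitions = df.queryExecution.optimizedPlan.collect {
      case r: Repartition => r
    }
    assert(repartitions.length == 1)
    assert(repartitions.head.numPartitions == 5)
    assert(!repartitions.head.shuffle)

Because the rule keeps only the outermost operator, the most recent repartitioning request wins: here the trailing coalesce(5) collapses the chain into a single non-shuffling Repartition with 5 partitions.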